Compare commits

...

141 Commits

Author SHA1 Message Date
Daniel Hiltgen
c8e0878814 enable flash attention for gemma4 (#15296) 2026-04-03 12:46:18 -07:00
Jesse Gross
bb0c58e134 ggml: skip cublasGemmBatchedEx during graph reservation
cublasGemmBatchedEx fails during graph capture when pool allocations
return fake pointers. This is triggered when NUM_PARALLEL is greater
than 1 for models like gemma4 that use batched matmuls. Skip it
during reservation since the memory tracking is already handled by
the pool allocations.

Fixes #15249
2026-04-03 12:41:09 -07:00
Devon Rifkin
036ed1b9b5 model/parsers: fix gemma4 arg parsing when quoted strings contain " (#15254)
* model/parsers: fix gemma4 arg parsing when quoted strings contain "

Fixes: #15241

* add more tests, be careful about what we escape

We want Windows-style paths to not get misinterpreted

* fix backslash-quote case, it really should be a literal backslash

h/t to @chathaway-codes for pointing this out!

Co-Authored-By: Charles H <2773397+chathaway-codes@users.noreply.github.com>

---------

Co-authored-by: Charles H <2773397+chathaway-codes@users.noreply.github.com>
2026-04-02 22:52:51 -07:00
Daniel Hiltgen
3536ef58f6 bench: add prompt calibration, context size flag, and NumCtx reporting (#15158)
Add --num-ctx flag to set context size, and report NumCtx in model info
header. Calibrate tokens-per-word ratio during warmup using actual
tokenization metrics from the model, replacing the fixed 1.3 heuristic.
This produces more accurate prompt token counts for --prompt-tokens.

Also add fetchContextLength() to query running model context via /api/ps.
2026-04-02 14:23:53 -07:00
Daniel Hiltgen
de9673ac3f tokenizer: add byte fallback for SentencePiece BPE encoding (#15232)
* tokenizer: add byte fallback for SentencePiece BPE encoding

When BPE merging produces tokens not in the vocabulary, fall back to
encoding each UTF-8 byte as <0xHH> byte tokens instead of silently
dropping the character. Also teach Decode to convert <0xHH> tokens
back to raw bytes.

Fixes #15229, fixes #15231

* tokenizer fixes
2026-04-02 13:04:45 -07:00
Daniel Hiltgen
96b202d34b Add support for gemma4 (#15214)
* bench: add prompt calibration, context size flag, and NumCtx reporting

Add --num-ctx flag to set context size, and report NumCtx in model info
header. Calibrate tokens-per-word ratio during warmup using actual
tokenization metrics from the model, replacing the fixed 1.3 heuristic.
This produces more accurate prompt token counts for --prompt-tokens.

Also add fetchContextLength() to query running model context via /api/ps.

* integration: improve vision test robustness and add thinking tests

Add skipIfNoVisionOverride() to skip vision tests when OLLAMA_TEST_MODEL
is set to a non-vision model. Add Think:false to context exhaustion test
to prevent thinking models from using all context before the test can
measure it. Add third test image (ollama homepage) and replace OCR test
with ImageDescription test using it. Relax match strings for broader
model compatibility. Add TestThinkingEnabled and TestThinkingSuppressed
to verify thinking output and channel tag handling.

* gemma4: add Gemma 4 GGML model support

Add full Gemma 4 model family support (E2B, E4B, 26B MoE, 31B Dense)
for the GGML backend including text, vision, converter, parser, and
renderer.

Text model features:
- Sliding window + full attention with per-layer patterns
- KV sharing across layers with donor map
- Per-layer embeddings (PLE) with learned projections
- MoE routing with RMSNorm + learned scale
- Proportional RoPE with freq_factors for global attention
- Final logit softcapping

Vision model features:
- SigLIP vision encoder with 2D RoPE
- ClippableLinear with input/output clamping via packed v.clamp_data
- Adaptive average pooling with nMerge kernel
- Multi-modal projection with unweighted RMSNorm

Converter:
- Safetensors to GGUF with vision tensor renaming
- Fused MoE gate_up_proj splitting
- Vision patch embedding reshape (HF to Conv2D layout)
- Packed clamp data tensor for ClippableLinear bounds
- Proportional RoPE freq_factors generation

Also includes:
- BackendGet() on ml.Tensor for reading weight tensor data
- Q6_K CUDA get_rows kernel support
- MoE-aware ffn_down quantization layer counting
- Gemma4 parser with tool calling and thinking support
- Gemma4 renderer with structured tool format
- Architecture-based auto-detection of renderer/parser/stop tokens
- Integration test gemma4 model list additions

* gemma4: add audio support with USM conformer encoder

Add audio encoding for Gemma 4 using the USM conformer architecture:
- Converter: audio tensor mapping, SSCP/conformer/embedder name replacements,
  softplus repacker for per_dim_scale, F32 enforcement for conv weights
- GGML backend: Conv1DDW and PadExt tensor ops
- Audio encoder: SSCP Conv2D, 12 conformer blocks (FFW + block-local
  attention with relative position embeddings + LightConv1d + FFW),
  output projection, audio-to-text embedding projector
- Audio preprocessing: WAV decode, mel spectrogram, FFT (pure Go)
- Model wiring: WAV detection, audio token handling, unified PostTokenize

Correctly transcribes "why is the sky blue" from test audio.

* integration: add gemma4 audio tests including OpenAI API coverage

Test audio transcription and response via the Ollama native API, plus
two new tests exercising the OpenAI-compatible endpoints:
- /v1/audio/transcriptions (multipart form upload)
- /v1/chat/completions with input_audio content type

All tests use capability checks and skip models without audio support.

* gemma4: add OpenAI audio API support and capability detection

- Add CapabilityAudio and detect from audio.block_count in GGUF
- Add /v1/audio/transcriptions endpoint with TranscriptionMiddleware
- Add input_audio content type support in /v1/chat/completions
- Add TranscriptionRequest/Response types in openai package

* gemma4: add audio input support for run command

- /audio toggle in interactive mode for voice chat
- Platform-specific microphone recording (AVFoundation on macOS,
  PulseAudio/ALSA on Linux, WASAPI on Windows)
- Space to start/stop recording, automatic chunking for long audio

* gemma4: add transcribe command (ollama transcribe MODEL)

- Interactive mode with readline prompt and slash commands
- Non-interactive mode for piped audio or record-until-Ctrl+C
- Chunked streaming transcription for long recordings
- Word-wrapped output matching run command style

* gemma4: add parser, renderer, and integration test plumbing

* gemma4: fix renderer to emit BOS token

* gemma4: add OpenAI audio transcription API and input_audio support

* gemma4: update converter for new weight drop naming

* gemma4: add per_expert_scale to MoE router and fix moe_intermediate_size config

* gemma4: rewrite renderer to match HF Jinja2 template exactly

Fix 8 bugs found by building 55 reference tests verified against the
HF Jinja2 chat template (VERIFY_JINJA2=1 shells out to Python):

- Tool responses use separate <|turn>tool turns (not inline tags)
- Tool calls emitted before content in assistant messages
- Thinking content stripped from assistant history (strip_thinking)
- User, tool, and system content trimmed (template does | trim)
- Empty system message still emits system turn (check role, not content)
- Nested object properties rendered recursively with required field
- Array items specification rendered for array-type properties
- OBJECT/ARRAY type-specific rendering comma logic matches template

Also adds Required field to api.ToolProperty for nested object schemas,
replaces old gemma4_test.go with comprehensive gemma4_reference_test.go,
and commits the Jinja2 template as testdata for verification.

* gemma4: fix MoE fused gate_up split and multiline tool-call arg parsing

- Text MoE: split `ffn_gate_up_exps` into contiguous `[gate|up]` halves instead of stride-2 slices.
- Parser: escape control characters in `<|"|>...<|"|>` string literals when converting tool-call args to JSON.
- Fixes warnings like `invalid character '\n' in string literal` for multiline tool arguments.
- Add Gemma4 parser regressions for multiline tool-call args and `gemma4ArgsToJSON`.

* cmd: simplify audio input to dropped file attachments

* gemma4: use full SWA memory for better cache reuse

* gemma4: initialize clamps after backend load

* convert: align gemma4 audio tensor renames with llama.cpp

* Remove redundant comments in gemma4 vision model

* Format Gemma4 MoE block field alignment

* use 4096 kvcache.NewSWAMemCache

* convert: support new Gemma4 audio_tower tensor naming (#15221)

Co-authored-by: jmorganca <jmorganca@gmail.com>

* fix integration test defaults for audio

* review comments and lint fixes

* remove unused audio/video files

---------

Co-authored-by: jmorganca <jmorganca@gmail.com>
2026-04-02 11:33:33 -07:00
Devon Rifkin
79865e6c5a app: use the same client for inference and other requests (#15204)
Previously we were accidentally using different clients/UAs depending on
whether it was an inference call or a different call. This change makes
them consistent, other than the timeout being different.
2026-04-02 11:07:50 -07:00
Parth Sareen
5ab10d347a app: add launch page for a simple way to launch integrations (#15182) 2026-04-02 10:31:19 -07:00
Eva H
a8292dd85f launch: replace deprecated OPENAI_BASE_URL with config.toml profile for codex (#15041) 2026-04-01 11:43:23 -04:00
Daniel Hiltgen
cb0033598e tokenizer: add SentencePiece-style BPE support (#15162)
* tokenizer: add SentencePiece-style BPE support

Add WithSentencePieceNormalizer option to BytePairEncoding for models
that use BPE with SentencePiece-style space markers (space to/from
U+2581).

NewBytePairEncoding is unchanged; the new NewBytePairEncodingWithOptions
constructor accepts BPEOption functions. Decoding handles the reverse
mapping of U+2581 back to spaces.

* review comments
2026-03-31 17:00:36 -07:00
Daniel Hiltgen
4d14b0ff92 mlx: respect tokenizer add_bos_token setting in pipeline (#15185)
Replace hardcoded Encode(prompt, true) with
Encode(prompt, r.Tokenizer.AddBOS()) so the pipeline respects each
model's tokenizer configuration.

Models with add_bos_token=true (gemma3, llama): unchanged, tokenizer
still prepends BOS.

Models with bos_token=null (qwen3, qwen3.5): unchanged, the BOS
guard (vocab.BOS >= 0) already prevented prepending regardless of
the flag.

This aligns the pipeline with the /v1/tokenize endpoint which already
uses Tokenizer.AddBOS().
2026-03-31 16:46:30 -07:00
Parth Sareen
d9cb70c270 docs: update pi docs (#15152) 2026-03-31 16:37:55 -07:00
Jeffrey Morgan
31f968fe1f cmd: set OpenCode default model in config (#15127) 2026-03-29 12:11:36 -07:00
Jeffrey Morgan
b7bda92d52 model: add qwen3-next compatibility for legacy ssm_in projections (#15133) 2026-03-29 11:50:47 -07:00
Parth Sareen
8e54823fd3 revert context length warnings change (#15121) 2026-03-28 16:43:59 -07:00
Parth Sareen
7c8da5679e launch: improve multi-select for already added models (#15113) 2026-03-28 13:44:40 -07:00
Parth Sareen
6214103e66 launch: auto-install pi and manage web-search lifecycle (#15118) 2026-03-28 13:06:20 -07:00
Patrick Devine
9e7cb9697e mlx: fix vision capability + min version (#15106) 2026-03-27 17:09:28 -07:00
Bruce MacDonald
3824e380a8 server: preserve raw manifest bytes during pull (#15104)
pullModelManifest unmarshals the registry response into a Go struct
then re-marshals with json.Marshal before writing to disk. When the
registry's JSON formatting or field ordering differs from Go's
output, the local SHA256 won't match the registry's
Ollama-Content-Digest header, causing false "out of date" warnings.

Preserve the raw bytes from the registry response and write them
directly to disk so the local manifest is byte-for-byte identical
to what the registry serves.
2026-03-27 15:42:31 -07:00
Devon Rifkin
c9b2dcfc52 anthropic: fix empty inputs in content blocks (#15105)
* anthropic: fix empty inputs in content blocks

When we switched to `api.ToolCallFunctionArguments`, `omitempty` stopped
doing what we were relying on it for before. This would cause non-tool
content blocks to have an `"input": {}` field, which doesn't match our
old behavior.

* use omitzero instead
2026-03-27 15:41:27 -07:00
Parth Sareen
b00bd1dfd4 launch: skip context length warning for MLX models and show model name (#15102) 2026-03-27 15:01:33 -07:00
Jesse Gross
ac83ac20c4 anthropic: fix KV cache reuse degraded by tool call argument reordering
Use typed structs for tool call arguments instead of map[string]any to
preserve JSON key order, which Go maps do not guarantee.
2026-03-27 14:30:16 -07:00
Bruce MacDonald
e7ccc129ea app: fix false "out of date" model warnings (#15101)
The staleness check compared the local manifest digest (SHA256 of the
file on disk) against the registry's Ollama-Content-Digest header.
These never matched because PullModel re-serializes the manifest JSON
before writing, producing different bytes than the registry's original.

The fallback comparison (local modified_at vs upstream push time) was
also broken: the generated TypeScript Time class discards the actual
timestamp value, so Date parsing always produced NaN.

Fix by moving the staleness comparison server-side where we have
reliable access to both the local manifest file mtime and the upstream
push time. The /api/v1/model/upstream endpoint now returns a simple
`stale` boolean instead of raw digests for the frontend to compare.

Also adds User-Agent to the CORS allowed headers for dev mode.
2026-03-27 14:15:10 -07:00
Jeffrey Morgan
69ed0c2729 parsers: qwen3.5 streaming tool-call parsing and add regression test (#15098) 2026-03-27 14:04:14 -07:00
Alfredo Matas
1cefa749aa model/parsers: close think block if tool block starts in Qwen3.5 (#15022) 2026-03-27 11:28:34 -07:00
Daniel Hiltgen
aec2fef95d ci: harden cuda include path handling (#15093)
On windows we can get multiple include dirs, so find where the headers are then
copy from that location.
2026-03-27 07:57:07 -07:00
Eva H
366625a831 launch: warn when server context length is below 64k for local models (#15044)
A stop-gap for now to guide users better. We'll add more in-depth recommendations per integration as well.

---------

Co-authored-by: Parth Sareen <parth.sareen@ollama.com>
2026-03-27 00:15:53 -07:00
Daniel Hiltgen
516ebd8548 ci: include mlx jit headers on linux (#15083)
* ci: include mlx jit headers on linux

* handle CUDA JIT headers
2026-03-26 23:10:07 -07:00
Parth Sareen
f567abc63f tui: update chat title (#15082) 2026-03-26 18:06:53 -07:00
Eva H
1adfc27f04 launch/vscode: prefer known vs code paths over code on PATH (#15073) 2026-03-26 18:06:28 -04:00
Parth Sareen
4a2b9f9dbc launch: hide cline integration (#15080) 2026-03-26 14:33:43 -07:00
Parth Sareen
e46b67a6cc launch: hide vs code (#15076) 2026-03-26 13:52:50 -07:00
Eva H
c000afe76c doc: update vscode doc (#15064)
---------

Co-authored-by: ParthSareen <parth.sareen@ollama.com>
2026-03-26 13:45:48 -07:00
Jesse Gross
9d7b18f81e mlxrunner: combine setStateRaw and setStateDetached into setState 2026-03-26 13:32:11 -07:00
Jesse Gross
4f5999fd3f mlxrunner: schedule periodic snapshots during prefill
Add periodic snapshots every 8k tokens and near the end of the prompt
so that long prompts can be partially restored and thinking/generation
can be retried without full reprocessing.
2026-03-26 13:32:11 -07:00
Jesse Gross
ac5f0dbb6a mlxrunner: improve eviction and LRU tracking
Update LRU last used time just on the nodes that actually used
during processing rather than all snapshots along the path. This
allows eviction to remove nodes more accurately so we can avoid
other heuristics to auto-merge nodes.
2026-03-26 13:32:11 -07:00
Jesse Gross
d1151e18a1 mlx: fix KV cache snapshot memory leak
mlx.Copy shares the backing buffer with its source (via
copy_shared_buffer) rather than allocating independent storage.
When used to snapshot a slice of the KV cache, the snapshot array
holds the entire original cache buffer alive through the shared
data pointer — even after eval detaches the computation graph.

Replace Copy with Contiguous in Snapshot and Split. Contiguous
allocates a compact buffer when the source buffer is significantly
larger than the logical slice (Contiguous::eval checks
buffer_size > nbytes + 16384), which is always the case for KV
cache slices.
2026-03-25 17:26:34 -07:00
rick
ebbce136c7 ggml: force flash attention off for grok 2026-03-25 16:15:49 -07:00
Devon Rifkin
26b9f53f8e api/show: overwrite basename for copilot chat (#15062)
Copilot Chat prefers to use `general.basename` in the built-in Ollama
integration, but this name isn't usually shown directly to users (and
there may be many models that share this name). Instead we pass back
`req.Model`, which for this extension is the value that we return from
`/api/tags`
2026-03-25 14:02:22 -07:00
Eva H
7575438366 cmd: ollama launch vscode (#15060)
Co-authored-by: Parth Sareen <parth.sareen@ollama.com>
2026-03-25 16:37:02 -04:00
Eva H
7d7c90d702 tui: add left arrow back navigation in model selector (#14940) 2026-03-25 11:53:48 -07:00
Daniel Hiltgen
4fda69809a ci: fix windows cgo compiler error (#15046) 2026-03-24 16:45:36 -07:00
Daniel Hiltgen
c9b5da6b0c integration: improve ability to test individual models (#14948)
* integration: improve ability to test individual models

Add OLLAMA_TEST_MODEL env var to run integration tests against a
single model.

Enhance vision tests: multi-turn chat with cached image tokens, object
counting, spatial reasoning, detail recognition, scene understanding, OCR, and
multi-image comparison.

Add tool calling stress tests with complex agent-style prompts, large
system messages, and multi-turn tool response handling.

* review comments
2026-03-24 14:28:23 -07:00
Patrick Devine
de5cb7311f mlx: add mxfp4/mxfp8/nvfp4 importing (#15015)
This change allows importing bf16 and converting to mxfp4/mxfp8/nvfp4
and also importing fp8 and converting directly to mxfp8.
2026-03-24 13:45:44 -07:00
Jesse Gross
95ee7fbd29 mlxrunner: panic on double unpin 2026-03-23 17:44:19 -07:00
Jesse Gross
ec55536734 mlxrunner: show time since last used in cache dump tree 2026-03-23 17:44:19 -07:00
Jesse Gross
77491439c2 mlxrunner: support partial match on pure transformer caches
Previously, a partial match within a node's edge would truncate the path
to the parent snapshot - effectively making all cache types behave as
recurrent caches. Caches with only transformer layers can rewind to
arbitrary boundary so this restores this capability to improve cache
hits
2026-03-23 17:44:19 -07:00
Parth Sareen
b166b36cd2 docs: update Claude Code with Telegram guide (#15026) 2026-03-23 16:31:21 -07:00
Daniel Hiltgen
c2b0bb7a52 mlx: update as of 3/23 (#14789)
* mlx: update to HEAD on 3/23

Also fixes a few misc vendoring bugs uncovered with this first update.
This also renames the version files to make them clearer.

* CUDA Fast Gated Delta kernel

* mlx: detect eval errors and panic

On model errors or missing kernels, don't mask the error, bubble it up.
2026-03-23 11:28:44 -07:00
Bruce MacDonald
22c2bdbd8a docs: nemoclaw integration (#14962)
---------

Co-authored-by: ParthSareen <parth.sareen@ollama.com>
2026-03-20 15:27:37 -07:00
Bruce MacDonald
6df6d097d9 launch: skip openclaw gateway health check when no daemon install (#14984) 2026-03-20 15:20:14 -07:00
Jesse Gross
d7c176ab91 llm, mlxrunner: fix done channel value consumed by first receiver
Receiving from a buffered chan error consumes the value, so only the
first caller (WaitUntilRunning, HasExited, or Close) sees the signal.
Subsequent receivers block or take the wrong branch. Replace with a
closed chan struct{} which can be received from any number of times,
and store the error in a separate field.
2026-03-19 17:44:28 -07:00
Jesse Gross
0ff7d724ff mlx: fix subprocess log deadlock
The stderr reader used bufio.Scanner which has a 64KB max line size.
If the subprocess wrote a line exceeding this limit, the scanner would
stop reading, the OS pipe buffer would fill, and the subprocess would
deadlock.

Replace the scanner with a statusWriter that wraps io.Copy. The writer
forwards all stderr to os.Stderr while capturing the last short line
(≤256 bytes) for error reporting, avoiding both the deadlock and the
need to buffer arbitrarily long lines.
2026-03-19 17:44:28 -07:00
Devon Rifkin
46cb7795e1 add ability to turn on debug request logging (#14106)
If `OLLAMA_DEBUG_LOG_REQUESTS` is set, then on server startup a temp
folder will be created. Upon any inference request, the body will be
logged to a file in this folder, as well as a small shell script to
"replay" the request using cURL.

This is just intended for debugging scenarios, not as something to turn
on normally.
2026-03-19 17:08:17 -07:00
Bruce MacDonald
126d8db7f3 parsers: robust xml tool repair (#14961)
Previous xml repair for glm was a good start, but we need to go further and repair any incorrect open or closing tags

Co-authored-by: Dongluo Chen <dongluo.chen@gmail.com>
2026-03-19 11:24:48 -07:00
Eva H
3f3a24b418 app: fix desktop app stuck loading when OLLAMA_HOST is an unspecified bind address (#14885) 2026-03-19 12:57:57 -04:00
Jesse Gross
96e36c0d90 mlxrunner: share KV cache across conversations with common prefixes
Enable multiple conversations to reuse cached computations when they
share token prefixes (e.g. the same system prompt). A prefix trie
tracks shared regions so switching between conversations only
recomputes tokens that diverge. Inactive conversation state is paged
from active GPU memory to other memory and restored on demand, with LRU
eviction to keep memory usage bounded.
2026-03-18 16:06:33 -07:00
Jesse Gross
6f8ddbb26b mlxrunner: fix Slice(0, 0) returning full dimension instead of empty
Slice used cmp.Or to resolve a zero stop value to the dimension size,
intended to support open-ended slices like a[i:]. This made Slice(0, 0)
indistinguishable from Slice(), so any slice with a zero stop would
silently include the entire dimension instead of being empty.

Replace cmp.Or with an explicit End sentinel and resolve negative
indices against the dimension size, matching Python/PyTorch semantics.
2026-03-18 16:06:33 -07:00
Eva H
b5e7888414 cmd/launch: skip redundant config writes when model unchanged (#14941) 2026-03-18 17:36:52 -04:00
Parth Sareen
eab4d22269 docs: update claude code and openclaw for web search (#14922) 2026-03-18 14:18:49 -07:00
Bruce MacDonald
5759c2d2d2 launch: fix openclaw not picking up newly selected model (#14943)
Sessions with a stale model field were not updated when the primary
changed, so the old model continued to be used.
2026-03-18 13:20:10 -07:00
Bruce MacDonald
42b1c2642b docs: update minimax-m2.5 references to m2.7 (#14942) 2026-03-18 12:59:28 -07:00
Bruce MacDonald
727d69ddf3 tui: fix signin on headless Linux systems (#14627)
Defensively handle environments without a display server to ensure signin remains usable on headless VMs and SSH sessions.

- Skip calling xdg-open when neither DISPLAY nor WAYLAND_DISPLAY is set, preventing silent failures or unexpected browser handlers
- Render the signin URL as plain text instead of wrapping it in OSC 8 hyperlink escape sequences, which can be garbled or hidden by terminals that don't support them
2026-03-18 11:11:17 -07:00
Jesse Gross
f622b0c5fc launch: disable claude attribution header to preserve KV cache
Claude Code sends an x-anthropic-billing-header that changes on every
request. This is embedded in the system prompt and consequently
breaks the KV cache for every request. Given the size of the prompts
that Claude Code usees, this has significant performance impact.
2026-03-17 20:48:03 -07:00
Bruce MacDonald
5d0000634c cmd/launch: check for both npm and git before installing OpenClaw (#14888)
The OpenClaw installer requires git in addition to npm. Update the
dependency check to detect both and provide specific install guidance
for whichever dependencies are missing.
2026-03-17 18:20:05 -07:00
Parth Sareen
676d9845ba launch: register websearch for openclaw (#14914) 2026-03-17 15:03:15 -07:00
Devon Rifkin
e37a9b4c01 cloud_proxy: for the web_search legacy path, flush on newlines (#14897)
`WebSearchAnthropicWriter` expects a single object per write. The new
transparent proxy will instead send it whatever bytes it sees. This
cloud-model + local-orchestration + cloud-search is a temporary code
path, so instead of making the web search code more robust to this, I
put an adapter in the middle that will flush line-by-line to preserve
the old behavior.
2026-03-17 13:30:17 -07:00
Patrick Devine
d727aacd04 mlx: quantized embeddings, fast SwiGLU, and runtime fixes (#14884)
Add QuantizedEmbedding and EmbeddingLayer interface so models can
use quantized embedding weights and expose tied output projections.
This change updates gemma3, glm4_moe_lite, llama, qwen3, and qwen3_5
to use the new interface.
2026-03-17 11:21:38 -07:00
Patrick Devine
fa69b833cd mlx: add prequantized tensor packing + changes for qwen35 (#14878)
This change adds a tensorImportTransform interface for model-specific
tensor transformations during safetensors import. This allows importing
and modifying the standard HF based weights as well as the mlx-community
derived pre-quantized safetensors repos to be directly
imported into `ollama create`. Right now this only works with Qwen3.5
importing which does tensor renaming, norm weight shifting (it
adds +1 to each value of the norm vectors), conv1d transposition,
and casts to BF16s for F32 based vectors.
2026-03-17 11:21:18 -07:00
Jesse Gross
bbbad97686 sched: Model eviction for MLX
MLX runners (image generation and LLM) previously bypassed the
scheduler's standard load path via a separate loadMLX method. This meant
they skipped VRAM fitting checks and couldn't participate in model
eviction.

Now all model types flow through the same load function. Model eviction
for MLX is based on weights as KV cache and compute graph are dynamic.
This means that eviction does not take into account the worst case
memory and models can still compete for memory but it is a significant
improvement.
2026-03-16 17:40:29 -07:00
Parth Sareen
bcf6d55b54 launch: fix web search, add web fetch, and enable both for local (#14886) 2026-03-16 16:26:19 -07:00
easonysliu
810d4f9c22 runner: fix swallowed error in allocModel graph reservation
In allocModel(), the first call to reserveWorstCaseGraph(true) had its
error silently discarded — `return nil` was used instead of `return err`.

This meant that if the prompt-sized graph reservation failed (e.g. due
to insufficient memory), the error was swallowed, allocModel reported
success, and the model appeared to load correctly. Subsequent inference
would then fail in unexpected ways because the worst-case graph was
never properly reserved.

Fix: return the actual error so the caller can handle the failure
(retry with reduced parallelism, report OOM, etc.).

Co-Authored-By: Claude (claude-opus-4-6) <noreply@anthropic.com>
2026-03-16 15:48:45 -07:00
Bruce MacDonald
856c047a6c cmd/launch: skip --install-daemon when systemd is unavailable (#14883)
In container environments without systemd, `openclaw onboard
--install-daemon` exits non-zero because it cannot create a systemd
user service. This causes `ollama launch openclaw` to abort even
though the gateway can be started as a foreground child process.

Only pass --install-daemon when systemd user services are reachable
(Linux with /run/systemd/system present and XDG_RUNTIME_DIR set).
On all other platforms the flag is still included by default.
2026-03-16 13:50:04 -07:00
Daniel Hiltgen
79c1e93c00 bench: improve benchmarking tool (#14240)
New features:
- Warmup phase to eliminate cold-start outliers
- time-to-first-token measured in each epoch
- VRAM/memory tracking to identify CPU spillover
- Controlled prompt length
- Defaults to 6 epochs and 200 tokens max

Benchstat fixes:
- ns/request instead of ns/op — non-standard unit created a separate group instead of grouping with timing metrics
- Token count as the N field — benchstat interprets N as iteration count for statistical weighting, not as a token count
2026-03-15 11:47:31 -07:00
Parth Sareen
f8b657c967 cmd/launch: add guards for headless mode (#14837) 2026-03-14 00:10:02 -07:00
Bruce MacDonald
10fefe0d57 config: use native OpenClaw Ollama onboarding (#14829)
OpenClaw now accepts the Ollama onboarding flags directly upstream, so rely on its wizard state instead of the legacy integration onboarding flag.

Update first-run setup to pass the Ollama auth and model flags during onboarding, perform a best-effort update before onboarding when needed, and drop the stale test that asserted persistence of the old onboarding flag.
2026-03-13 16:28:40 -07:00
Daniel Hiltgen
2f9a68f9e9 rocm: doc driver constraints (#14833) 2026-03-13 15:53:35 -07:00
Bruce MacDonald
3980c0217d server: decompress zstd request bodies in cloud passthrough middleware (#14827)
When a zstd-compressed request (e.g. from Codex CLI) hits /v1/responses
with a cloud model the request failed.

Fix by decompressing zstd bodies before
model extraction, so cloud models are detected and proxied directly
without the writer being wrapped.
2026-03-13 15:06:47 -07:00
Parth Sareen
870599f5da launch: remove warning for default policy (#14830) 2026-03-13 15:01:38 -07:00
Bruce MacDonald
abf8e8e9c8 middleware: handle non-JSON error responses gracefully (#14828)
writeError in both OpenAI and Anthropic middleware writers would return
a raw json.SyntaxError when the error payload wasn't valid JSON (e.g.
"invalid character 'e' looking for beginning of value"). Fall back to
using the raw bytes as the error message instead.

Also use the actual HTTP status code rather than hardcoding 500, so
error types map correctly
2026-03-13 14:50:49 -07:00
Shivam Tiwari
f3f31a8192 anthropic: close thinking block before tool_use when no text in between (#14825)
Root cause: StreamConverter.Process() only incremented contentIndex when
closing a thinking block if text content was present. When a model emitted
thinking followed directly by a tool_use block (no text in between),
thinkingDone was never set and contentIndex was not incremented, causing the
tool_use content_block_start to reuse index 0. Clients expecting sequential
indices would then fail to find the tool content block.

Fix: In the tool call loop, close any open thinking block (thinkingStarted &&
!thinkingDone) and increment contentIndex before opening the tool_use block,
mirroring the existing logic that closes an open text block.

Fixes #14816
2026-03-13 13:12:05 -07:00
Devon Rifkin
9e7ba835da cmd: still populate ollama ls when using ollama run <model:cloud> (#14824)
This is temporary until `api/tags` supports cloud natively
2026-03-13 12:24:45 -07:00
Parth Sareen
347f17b8d1 launch: add compact window for claude code (#14823) 2026-03-13 12:09:23 -07:00
Devon Rifkin
081b9eb423 api/create: always propagate :cloud source for cloud models (#14822)
Otherwise, using `/save` would try to run the local model instead
2026-03-13 11:58:00 -07:00
Parth Sareen
bb867c6fdb launch: fix headless --yes integration flow and policy scoping (#14815) 2026-03-13 11:45:36 -07:00
Cadu
81f4506a61 docs: document reasoning_effort support in OpenAI-compatible API (#14821)
Add reasoning_effort and reasoning to the supported features and
request fields for /v1/chat/completions. These fields control
thinking on thinking-capable models but were previously undocumented.

Closes #14820
2026-03-13 10:57:14 -07:00
Parth Sareen
76925f1284 cmd: TUI model ordering (#14814) 2026-03-13 10:19:22 -07:00
Devon Rifkin
f676231de9 server: remove experimental aliases support (#14810) 2026-03-12 20:27:24 -07:00
Parth Sareen
af5f7c0a9e cmd: refactor tui and launch (#14609) 2026-03-12 18:39:06 -07:00
Daniel Hiltgen
a6b27d776b ci: fix missing windows zip file (#14807)
Use 7z compression (better compression rate) if found in path.  That
alone isn't sufficient to get us under 2G, so MLX is now split out as a
discrete download.  Fix CI so it will fail if artifacts fail to upload.
2026-03-12 16:14:00 -07:00
Daniel Hiltgen
539741199e mlx: perf improvements (#14768)
* mlx: perf improvements

Fix nn.go to call mlx_fast_layer_norm instead of manually implementing (mean,
subtract, variance, rsqrt, multiply, add — 6 ops)

Fix llama.go, gemma3.go to remove RepeatKV to tile K/V tensors to match the Q
head count, since scaled_dot_product_attention natively handles GQA (it just
requires n_q_heads % n_kv_heads == 0)

* review comments
2026-03-12 12:01:28 -07:00
Eva H
8f45236d09 middleware: enable local tool model for web search (#14787) 2026-03-11 17:51:39 -04:00
Parth Sareen
97013a190c openai: split mixed thinking stream chunks via ToChunks (#14648) 2026-03-11 14:21:29 -07:00
Daniel Hiltgen
c222735c02 mlx: only log load errors when MLX is needed (#14764)
This suppresses irrelevant/noisy errors in the GGML runner.
2026-03-11 10:31:31 -07:00
Daniel Hiltgen
87d21c7fc0 MLX: harden for init failures (#14777)
The CLI now links to the lazy-load MLX code, but that still happens in
init functions.  On internal MLX errors, the CLI exits before it has a
chance to start.  This change re-wires the MLX error handling so it
doesn't exit by default.  The MLX based runners currently expect exits
on failure, so they re-initialize the default error handling.  We can
refine error handling for better go stack traces in the future.
2026-03-10 22:52:23 -07:00
Jeffrey Morgan
54e05172a0 Revert "runner: add token history sampling parameters to ollama runner (#14537)" (#14776)
This reverts commit 86513cb697.
2026-03-10 21:07:52 -07:00
Parth Sareen
464186e995 config: qwen3.5 recommendations (#14758) 2026-03-10 18:04:57 -07:00
Devon Rifkin
8c4d5d6c2f cloud_proxy: send ollama client version (#14769)
This was previously included in the user agent, and we've made use of it
in the past to hotpatch bugs server-side for particular Ollama versions.
2026-03-10 15:53:25 -07:00
Parth Sareen
bc72b14016 docs: update claude code docs (#14770) 2026-03-10 15:52:41 -07:00
Parth Sareen
61086083eb server: add experimental web search and web fetch routes (#14753) 2026-03-09 21:52:12 -07:00
Daniel Hiltgen
62d1f01ab4 ci: Fix windows build (#14754)
Instead of relying on sh for wildcard, do it in Go for better windows
compatibility.
2026-03-09 19:27:59 -07:00
Daniel Hiltgen
10e51c5177 MLX: add header vendoring and remove go build tag (#14642)
* prefer rocm v6 on windows

Avoid building with v7 - more changes are needed

* MLX: add header vendoring and remove go build tag

This switches to using a vendoring approach for the mlx-c headers so that Go
can build without requiring a cmake first.  This enables building the new MLX
based code by default.  Every time cmake runs, the headers are refreshed, so we
can easily keep them in sync when we bump mlx versions.  Basic Windows
and Linux support are verified.

* ci: harden for flaky choco repo servers

CI sometimes fails due to choco not actually installing cache.  Since it just speeds up the build, we can proceed without.

* review comments
2026-03-09 17:24:45 -07:00
Patrick Devine
3e06bde643 mlx: get parameters from modelfile during model creation (#14747) 2026-03-09 15:33:24 -07:00
Eva H
6be2de8214 app: auto update should be enabled when reset to defaults (#14741) 2026-03-09 15:02:36 -04:00
Daniel Hiltgen
ebb1b9ec14 rocm: update linux to v7.2 (#14391)
* rocm: update linux to v7.2

* review comments
2026-03-09 08:26:55 -07:00
Patrick Devine
d126467d5d x/mlxrunner: replace sampler interface chain with single stateful Sampler (#14652)
- Collapse MLX sampling state into a single sample.Sampler struct (options + history).
- Replace interface-based sampler chain (TopP, TopK, penalty, etc.) with function-based transforms.
- Update request/pipeline wiring to use *sample.Sampler, seed history from prompt tokens, and append generated tokens each step.
- Implement top_p, min_p, repeat_penalty, and frequency_penalty
2026-03-07 17:50:57 -08:00
Devon Rifkin
afb4c62fbf cloud_proxy: handle stream disconnects gracefully (#14685)
Previously we were printing out bad errors for expected cases like
clients disconnecting. Now we only debug log when that happens (which
still might help in cases where we're figuring out why an integration
isn't working). For other errors, we print out a proper warning now
2026-03-06 19:18:52 -08:00
Patrick Devine
e790dc435b mlx: int4 groupsize 64 (#14682)
Change affine 4bit integers to use groupsize 64
2026-03-06 16:39:47 -08:00
Daniel Hiltgen
288077c3a3 build: smarter docker parallelism (#14653)
Our Dockerfile leverages parallel stages for more efficient builds.  However,
our old parallel settings were naive and lead to under/over utilization
depending on the capabilities of your build system.

This change switches to using Ninja for all our docker cmake builds to leverage
its smarter parallel logic.  We tell Ninja to target a load of nproc so each of
the build stages will share the load on the system aiming for full CPU use
without oversaturation.

The GPU parallelism settings are also adjusted to 4 to avoid a long-tail for
the last few GPU targets as they work through the long list of GPU
architectures.

This also fixes the Dockerfile to move Vulkan install to just the stage that
needs it instead of blocking most other GPU installs.  This should speed up CI
which always has a clean build cache.
2026-03-06 16:36:22 -08:00
Daniel Hiltgen
4425c54eda create: fix localhost handling (#14681) 2026-03-06 16:35:58 -08:00
Michael Yang
778899a5d2 docs: format compat docs (#14678) 2026-03-06 14:53:17 -08:00
Jeffrey Morgan
4eab60c1e2 Reapply "don't require pulling stubs for cloud models" again (#14608)
* Revert "Revert "Reapply "don't require pulling stubs for cloud models"" (#14606)"

This reverts commit 39982a954e.

* fix test + do cloud lookup only when seeing cloud models

---------

Co-authored-by: ParthSareen <parth.sareen@ollama.com>
2026-03-06 14:27:47 -08:00
Bruce MacDonald
1af850e6e3 parsers: repair unclosed arg_value tags in GLM tool calls (#14656)
GLM models sometimes omits </arg_value> closing tags in tool call XML, causing xml.Unmarshal to fail with "element <arg_value> closed by </tool_call>".

This is a known issue across the GLM family.

Sanitize the input to fix closing arg_key values so encoding/xml can handle it.
2026-03-06 14:08:34 -08:00
Parth Sareen
9b0c7cc7b9 cmd: override stale entries for context window pi (#14655) 2026-03-05 16:30:24 -08:00
Daniel Hiltgen
6928630601 mlx: prevent remote creation mismatch (#14651)
If the user is pointing at a remote OLLAMA_HOST, fail experimental safetensor
based create operations as we only support local creation currently.
2026-03-05 14:59:00 -08:00
Parth Sareen
9896e3627f cmd/config: fix cloud model limit lookups in integrations (#14650) 2026-03-05 13:57:28 -08:00
Bruce MacDonald
15732f0ea7 cmd: use native Ollama API endpoint for OpenClaw (#14649)
Remove the /v1 suffix from the OpenClaw provider baseUrl so it uses
the native Ollama API instead of the OpenAI-compatible endpoint. The
/v1 endpoint my break tool calling in OpenClaw.
2026-03-05 13:29:17 -08:00
Parth Sareen
562c76d7cc cmd: add qwen3.5 context length for launch (#14626) 2026-03-04 14:10:52 -08:00
Parth Sareen
122c68c151 server: loosen thinking level constraint (#14625) 2026-03-04 13:42:18 -08:00
Jeffrey Morgan
82848a7806 model: fix renderer and parser for qwen3.5 (#14605) 2026-03-03 20:58:29 -08:00
Jeffrey Morgan
39982a954e Revert "Reapply "don't require pulling stubs for cloud models"" (#14606)
This reverts commit 799e51d419.
2026-03-03 20:56:10 -08:00
Patrick Devine
e9f6ea232f Add qwen3.5-next-moe support to MLX runner and models (#14417)
This change adds support for qwen3.5-next-moe models (qwen3-next/qwen3.5-next/qwen3-coder) to the MLX runner. It also:

* introduces recurrent cache support and related MLX ops
* updates pipeline/runner integration and adds tests
* properly quantizes stacked expert tensors
* a Gated Delta Metal kernel for fast SSM inference
* adds new MLX calls for Conv1d, DepthwideConv1d, Contiguous, Exp, Log, SoftmaxAxis
2026-03-03 16:39:22 -08:00
Patrick Devine
110eff01a9 chore: remove old imagegen LLMs models (#14597)
These models are implemented in the x/mlxrunner instead.
2026-03-03 13:23:40 -08:00
Jeffrey Morgan
799e51d419 Reapply "don't require pulling stubs for cloud models"
This reverts commit 97d2f05a6d.
2026-03-03 13:17:10 -08:00
Victor-Quqi
e8fcb29586 model/renderers: fix glm-ocr image tags in renderer prompts (#14584) 2026-03-03 12:51:34 -08:00
Jeffrey Morgan
97d2f05a6d Revert "don't require pulling stubs for cloud models (#14574)" (#14596)
This reverts commit 8207e55ec7.
2026-03-03 12:51:23 -08:00
Devon Rifkin
8207e55ec7 don't require pulling stubs for cloud models (#14574)
* don't require pulling stubs for cloud models

This is a first in a series of PRs that will better integrate Ollama's
cloud into the API and CLI. Previously we used to have a layer of
indirection where you'd first have to pull a "stub" model that contains
a reference to a cloud model. With this change, you don't have to pull
first, you can just use a cloud model in various routes like `/api/chat`
and `/api/show`. This change respects
<https://github.com/ollama/ollama/pull/14221>, so if cloud is disabled,
these models won't be accessible.

There's also a new, simpler pass-through proxy that doesn't convert the
requests ahead of hitting the cloud models, which they themselves
already support various formats (e.g., `v1/chat/completions` or Open
Responses, etc.). This will help prevent issues caused by double
converting (e.g., `v1/chat/completions` converted to `api/chat` on the
client, then calling cloud and converting back to a
`v1/chat/completions` response instead of the cloud model handling the
original `v1/chat/completions` request first).

There's now a notion of "source tags", which can be mixed with existing
tags. So instead of having different formats like`gpt-oss:20b-cloud` vs.
`kimi-k2.5:cloud` (`-cloud` suffix vs. `:cloud`), you can now specify
cloud by simply appending `:cloud`. This PR doesn't change model
resolution yet, but sets us up to allow for things like omitting the
non-source tag, which would make something like `ollama run
gpt-oss:cloud` work the same way that `ollama run gpt-oss` already works
today.

More detailed changes:

- Added a shared model selector parser in `types/modelselector`:
  - supports `:cloud` and `:local`
  - accepts source tags in any position
  - supports legacy `:<tag>-cloud`
  - rejects conflicting source tags
- Integrated selector handling across server inference/show routes:
  - `GenerateHandler`, `ChatHandler`, `EmbedHandler`,
    `EmbeddingsHandler`, `ShowHandler`
- Added explicit-cloud passthrough proxy for ollama.com:
  - same-endpoint forwarding for `/api/*`, `/v1/*`, and `/v1/messages`
  - normalizes `model` (and `name` for `/api/show`) before forwarding
  - forwards request headers except hop-by-hop/proxy-managed headers
  - uses bounded response-header timeout
  - handles auth failures in a friendly way
- Preserved cloud-disable behavior (`OLLAMA_NO_CLOUD`)
- Updated create flow to support `FROM ...:cloud` model sources (though
  this flow uses the legacy proxy still, supporting Modelfile overrides
  is more complicated with the direct proxy approach)
- Updated CLI/TUI/config cloud detection to use shared selector logic
- Updated CLI preflight behavior so explicit cloud requests do not
  auto-pull local stubs

What's next?

- Cloud discovery/listing and cache-backed `ollama ls` / `/api/tags`
- Modelfile overlay support for virtual cloud models on OpenAI/Anthropic
  request families
- Recommender/default-selection behavior for ambiguous model families
- Fully remove the legacy flow

Fixes: https://github.com/ollama/ollama/issues/13801

* consolidate pull logic into confirmAndPull helper

pullIfNeeded and ShowOrPull shared identical confirm-and-pull logic.
Extract confirmAndPull to eliminate the duplication.

* skip local existence checks for cloud models

ModelExists and the TUI's modelExists both check the local model list,
which causes cloud models to appear missing. Return true early for
explicit cloud models so the TUI displays them beside the integration
name and skips re-prompting the model picker on relaunch.

* support optionally pulling stubs for newly-style names

We now normalize names like `<family>:<size>:cloud` into legacy-style
names like `<family>:<size>-cloud` for pulling and deleting (this also
supports stripping `:local`). Support for pulling cloud models is
temporary, once we integrate properly into `/api/tags` we won't need
this anymore.

* Fix server alias syncing

* Update cmd/cmd.go

Co-authored-by: Parth Sareen <parth.sareen@ollama.com>

* address comments

* improve some naming

---------

Co-authored-by: ParthSareen <parth.sareen@ollama.com>
2026-03-03 10:46:33 -08:00
Jesse Gross
ad16bffc7d mlx: Remove peak memory from the API
This is still in flux so it is better to just log it for now.
2026-03-02 15:56:18 -08:00
Jesse Gross
c1e3ef4bcc mlxrunner: Refcount pinned tensors
Otherwise, it is error prone to manage multiple components working
with the same tensor.
2026-03-02 15:56:06 -08:00
Parth Sareen
a3093cd5e5 cmd/opencode: rename provider from "Ollama (local)" to "Ollama" (#14566)
The "(local)" qualifier is unnecessary since there's only one Ollama
provider. Existing configs with the old name are migrated automatically;
custom names are left unchanged.
2026-03-02 14:17:18 -08:00
Bruce MacDonald
23d4cad1a2 server: verify digest is not empty on create (#14555)
An empty digest is not a valid digest for an incoming create request. Reject empty digests at the api level.
2026-03-02 13:43:35 -08:00
Jeffrey Morgan
86513cb697 runner: add token history sampling parameters to ollama runner (#14537) 2026-03-01 19:16:07 -08:00
Jeffrey Morgan
3490e9590b model/qwen3next: avoid crash in in DeltaNet when offloading (#14541)
Co-authored-by: Yossi Ovadia <jabadia@gmail.com>
2026-03-01 18:44:04 -08:00
Jeffrey Morgan
8da09b1e7e qwen3next: add compatibility with imported GGUF models (#14517) 2026-02-28 14:21:42 -08:00
Jesse Gross
a60b9adcce mlxrunner: Fix prompt eval timing and count metrics
Only the last token's processing time is included in prompt processing,
giving an artificially high rate. In addition, the number of tokens
only included the tokens that miss the cache, instead of our historic
total tokens.
2026-02-27 17:29:47 -08:00
Jesse Gross
a16f96658b mlxrunner: Enforce model context limit
Currently, context length is unbounded - the cache will keep
growing forever independent of the model's trained context
length. This caps it and enforces semantics similar to most
cloud services:
 - Long prompts will result in an error, not truncation.
 - Generation that exceeds the context will be stopped
2026-02-27 17:29:47 -08:00
Jesse Gross
18ab09b431 mlxrunner: Propagate pipeline errors to client via api.StatusError
Errors that occur during pipeline processing are currently only
logged but not sent back to the client. Rather than using HTTP
status codes as we have historically done, this serializes errors
as messages to allow sending them at any time during the stream.
2026-02-27 17:29:47 -08:00
Jesse Gross
638faeac54 mlxrunner: Report actual memory usage from runner
The MLX runner previously reported a static VRAM estimate that was
computed at load time and consisted only of the weights. This is
strictly less than the actual memory usage, as it does not include
the KV cache or compute graph.
2026-02-27 17:29:47 -08:00
Jesse Gross
dd5eb6337d mlxrunner: Fix panic on full KV cache hit
When the entire prompt was already cached (e.g. repeated prompt),
findRemaining returned an empty slice, causing FromValues to panic
on an index-out-of-range accessing a zero-length byte slice.

Fix by always keeping at least one token to re-evaluate so the
pipeline can seed token generation. Also reject empty prompts
early rather than panicking.
2026-02-27 11:07:03 -08:00
Patrick Devine
79917cf80b show peak memory usage (#14485) 2026-02-26 18:38:27 -08:00
Parth Sareen
cc90a035a0 model/parsers: add stable tool call indexing for glm47 and qwen3 parsers (#14484) 2026-02-26 18:14:29 -08:00
395 changed files with 44936 additions and 11765 deletions

View File

@@ -117,6 +117,25 @@ jobs:
install: https://sdk.lunarg.com/sdk/download/1.4.321.1/windows/vulkansdk-windows-X64-1.4.321.1.exe install: https://sdk.lunarg.com/sdk/download/1.4.321.1/windows/vulkansdk-windows-X64-1.4.321.1.exe
flags: '' flags: ''
runner_dir: 'vulkan' runner_dir: 'vulkan'
- os: windows
arch: amd64
preset: 'MLX CUDA 13'
install: https://developer.download.nvidia.com/compute/cuda/13.0.0/local_installers/cuda_13.0.0_windows.exe
cudnn-install: https://developer.download.nvidia.com/compute/cudnn/redist/cudnn/windows-x86_64/cudnn-windows-x86_64-9.18.1.3_cuda13-archive.zip
cuda-components:
- '"cudart"'
- '"nvcc"'
- '"cublas"'
- '"cublas_dev"'
- '"cufft"'
- '"cufft_dev"'
- '"nvrtc"'
- '"nvrtc_dev"'
- '"crt"'
- '"nvvm"'
- '"nvptxcompiler"'
cuda-version: '13.0'
flags: ''
runs-on: ${{ matrix.arch == 'arm64' && format('{0}-{1}', matrix.os, matrix.arch) || matrix.os }} runs-on: ${{ matrix.arch == 'arm64' && format('{0}-{1}', matrix.os, matrix.arch) || matrix.os }}
environment: release environment: release
env: env:
@@ -125,8 +144,10 @@ jobs:
- name: Install system dependencies - name: Install system dependencies
run: | run: |
choco install -y --no-progress ccache ninja choco install -y --no-progress ccache ninja
ccache -o cache_dir=${{ github.workspace }}\.ccache if (Get-Command ccache -ErrorAction SilentlyContinue) {
- if: startsWith(matrix.preset, 'CUDA ') || startsWith(matrix.preset, 'ROCm ') || startsWith(matrix.preset, 'Vulkan') ccache -o cache_dir=${{ github.workspace }}\.ccache
}
- if: startsWith(matrix.preset, 'CUDA ') || startsWith(matrix.preset, 'ROCm ') || startsWith(matrix.preset, 'Vulkan') || startsWith(matrix.preset, 'MLX ')
id: cache-install id: cache-install
uses: actions/cache/restore@v4 uses: actions/cache/restore@v4
with: with:
@@ -134,8 +155,9 @@ jobs:
C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA
C:\Program Files\AMD\ROCm C:\Program Files\AMD\ROCm
C:\VulkanSDK C:\VulkanSDK
key: ${{ matrix.install }} C:\Program Files\NVIDIA\CUDNN
- if: startsWith(matrix.preset, 'CUDA ') key: ${{ matrix.install }}-${{ matrix.cudnn-install }}
- if: startsWith(matrix.preset, 'CUDA ') || startsWith(matrix.preset, 'MLX ')
name: Install CUDA ${{ matrix.cuda-version }} name: Install CUDA ${{ matrix.cuda-version }}
run: | run: |
$ErrorActionPreference = "Stop" $ErrorActionPreference = "Stop"
@@ -179,6 +201,23 @@ jobs:
run: | run: |
echo "CC=clang.exe" | Out-File -FilePath $env:GITHUB_ENV -Append echo "CC=clang.exe" | Out-File -FilePath $env:GITHUB_ENV -Append
echo "CXX=clang++.exe" | Out-File -FilePath $env:GITHUB_ENV -Append echo "CXX=clang++.exe" | Out-File -FilePath $env:GITHUB_ENV -Append
- if: startsWith(matrix.preset, 'MLX ')
name: Install cuDNN for MLX
run: |
$ErrorActionPreference = "Stop"
$cudnnRoot = "C:\Program Files\NVIDIA\CUDNN"
if ("${{ steps.cache-install.outputs.cache-hit }}" -ne 'true') {
Invoke-WebRequest -Uri "${{ matrix.cudnn-install }}" -OutFile "cudnn.zip"
Expand-Archive -Path cudnn.zip -DestinationPath cudnn-extracted
$cudnnDir = (Get-ChildItem -Path cudnn-extracted -Directory)[0].FullName
New-Item -ItemType Directory -Force -Path $cudnnRoot
Copy-Item -Path "$cudnnDir\*" -Destination "$cudnnRoot\" -Recurse
}
echo "CUDNN_ROOT_DIR=$cudnnRoot" | Out-File -FilePath $env:GITHUB_ENV -Append
echo "CUDNN_INCLUDE_PATH=$cudnnRoot\include" | Out-File -FilePath $env:GITHUB_ENV -Append
echo "CUDNN_LIBRARY_PATH=$cudnnRoot\lib\x64" | Out-File -FilePath $env:GITHUB_ENV -Append
echo "$cudnnRoot\bin\x64" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
- if: ${{ !cancelled() && steps.cache-install.outputs.cache-hit != 'true' }} - if: ${{ !cancelled() && steps.cache-install.outputs.cache-hit != 'true' }}
uses: actions/cache/save@v4 uses: actions/cache/save@v4
with: with:
@@ -186,7 +225,8 @@ jobs:
C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA
C:\Program Files\AMD\ROCm C:\Program Files\AMD\ROCm
C:\VulkanSDK C:\VulkanSDK
key: ${{ matrix.install }} C:\Program Files\NVIDIA\CUDNN
key: ${{ matrix.install }}-${{ matrix.cudnn-install }}
- uses: actions/checkout@v4 - uses: actions/checkout@v4
- uses: actions/cache@v4 - uses: actions/cache@v4
with: with:
@@ -198,7 +238,7 @@ jobs:
Enter-VsDevShell -VsInstallPath 'C:\Program Files\Microsoft Visual Studio\2022\Enterprise' -SkipAutomaticLocation -DevCmdArguments '-arch=x64 -no_logo' Enter-VsDevShell -VsInstallPath 'C:\Program Files\Microsoft Visual Studio\2022\Enterprise' -SkipAutomaticLocation -DevCmdArguments '-arch=x64 -no_logo'
cmake --preset "${{ matrix.preset }}" ${{ matrix.flags }} --install-prefix "$((pwd).Path)\dist\${{ matrix.os }}-${{ matrix.arch }}" cmake --preset "${{ matrix.preset }}" ${{ matrix.flags }} --install-prefix "$((pwd).Path)\dist\${{ matrix.os }}-${{ matrix.arch }}"
cmake --build --parallel ([Environment]::ProcessorCount) --preset "${{ matrix.preset }}" cmake --build --parallel ([Environment]::ProcessorCount) --preset "${{ matrix.preset }}"
cmake --install build --component "${{ startsWith(matrix.preset, 'CUDA ') && 'CUDA' || startsWith(matrix.preset, 'ROCm ') && 'HIP' || startsWith(matrix.preset, 'Vulkan') && 'Vulkan' || 'CPU' }}" --strip cmake --install build --component "${{ startsWith(matrix.preset, 'MLX ') && 'MLX' || startsWith(matrix.preset, 'CUDA ') && 'CUDA' || startsWith(matrix.preset, 'ROCm ') && 'HIP' || startsWith(matrix.preset, 'Vulkan') && 'Vulkan' || 'CPU' }}" --strip
Remove-Item -Path dist\lib\ollama\rocm\rocblas\library\*gfx906* -ErrorAction SilentlyContinue Remove-Item -Path dist\lib\ollama\rocm\rocblas\library\*gfx906* -ErrorAction SilentlyContinue
env: env:
CMAKE_GENERATOR: Ninja CMAKE_GENERATOR: Ninja
@@ -384,6 +424,7 @@ jobs:
lib/ollama/cuda_v*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;; lib/ollama/cuda_v*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
lib/ollama/vulkan*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;; lib/ollama/vulkan*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
lib/ollama/mlx*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;; lib/ollama/mlx*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
lib/ollama/include*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
lib/ollama/cuda_jetpack5) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-jetpack5.tar.in ;; lib/ollama/cuda_jetpack5) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-jetpack5.tar.in ;;
lib/ollama/cuda_jetpack6) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-jetpack6.tar.in ;; lib/ollama/cuda_jetpack6) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-jetpack6.tar.in ;;
lib/ollama/rocm) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-rocm.tar.in ;; lib/ollama/rocm) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-rocm.tar.in ;;
@@ -543,11 +584,19 @@ jobs:
for payload in dist/*.txt dist/*.zip dist/*.tgz dist/*.tar.zst dist/*.exe dist/*.dmg dist/*.ps1 dist/*.sh ; do for payload in dist/*.txt dist/*.zip dist/*.tgz dist/*.tar.zst dist/*.exe dist/*.dmg dist/*.ps1 dist/*.sh ; do
echo "Uploading $payload" echo "Uploading $payload"
gh release upload ${GITHUB_REF_NAME} $payload --clobber & gh release upload ${GITHUB_REF_NAME} $payload --clobber &
pids[$!]=$! pids+=($!)
sleep 1 sleep 1
done done
echo "Waiting for uploads to complete" echo "Waiting for uploads to complete"
for pid in "${pids[*]}"; do failed=0
wait $pid for pid in "${pids[@]}"; do
if ! wait $pid; then
echo "::error::Upload failed (pid $pid)"
failed=1
fi
done done
if [ $failed -ne 0 ]; then
echo "One or more uploads failed"
exit 1
fi
echo "done" echo "done"

View File

@@ -37,7 +37,7 @@ jobs:
| xargs python3 -c "import sys; from pathlib import Path; print(any(Path(x).match(glob) for x in sys.argv[1:] for glob in '$*'.split(' ')))" | xargs python3 -c "import sys; from pathlib import Path; print(any(Path(x).match(glob) for x in sys.argv[1:] for glob in '$*'.split(' ')))"
} }
echo changed=$(changed 'llama/llama.cpp/**/*' 'ml/backend/ggml/ggml/**/*') | tee -a $GITHUB_OUTPUT echo changed=$(changed 'llama/llama.cpp/**/*' 'ml/backend/ggml/ggml/**/*' '.github/**/*') | tee -a $GITHUB_OUTPUT
echo vendorsha=$(make -f Makefile.sync print-base) | tee -a $GITHUB_OUTPUT echo vendorsha=$(make -f Makefile.sync print-base) | tee -a $GITHUB_OUTPUT
linux: linux:
@@ -51,7 +51,7 @@ jobs:
container: nvidia/cuda:13.0.0-devel-ubuntu22.04 container: nvidia/cuda:13.0.0-devel-ubuntu22.04
flags: '-DCMAKE_CUDA_ARCHITECTURES=87' flags: '-DCMAKE_CUDA_ARCHITECTURES=87'
- preset: ROCm - preset: ROCm
container: rocm/dev-ubuntu-22.04:6.1.2 container: rocm/dev-ubuntu-22.04:7.2
extra-packages: rocm-libs extra-packages: rocm-libs
flags: '-DAMDGPU_TARGETS=gfx1010 -DCMAKE_PREFIX_PATH=/opt/rocm' flags: '-DAMDGPU_TARGETS=gfx1010 -DCMAKE_PREFIX_PATH=/opt/rocm'
- preset: Vulkan - preset: Vulkan
@@ -60,6 +60,11 @@ jobs:
mesa-vulkan-drivers vulkan-tools mesa-vulkan-drivers vulkan-tools
libvulkan1 libvulkan-dev libvulkan1 libvulkan-dev
vulkan-sdk cmake ccache g++ make vulkan-sdk cmake ccache g++ make
- preset: 'MLX CUDA 13'
container: nvidia/cuda:13.0.0-devel-ubuntu22.04
extra-packages: libcudnn9-dev-cuda-13 libopenblas-dev liblapack-dev liblapacke-dev git curl
flags: '-DCMAKE_CUDA_ARCHITECTURES=87 -DBLAS_INCLUDE_DIRS=/usr/include/x86_64-linux-gnu -DLAPACK_INCLUDE_DIRS=/usr/include/x86_64-linux-gnu'
install-go: true
runs-on: linux runs-on: linux
container: ${{ matrix.container }} container: ${{ matrix.container }}
steps: steps:
@@ -76,19 +81,29 @@ jobs:
$sudo apt-get update $sudo apt-get update
fi fi
$sudo apt-get install -y cmake ccache ${{ matrix.extra-packages }} $sudo apt-get install -y cmake ccache ${{ matrix.extra-packages }}
# MLX requires CMake 3.25+, install from official releases
if [ "${{ matrix.preset }}" = "MLX CUDA 13" ]; then
curl -fsSL https://github.com/Kitware/CMake/releases/download/v3.31.2/cmake-3.31.2-linux-$(uname -m).tar.gz | $sudo tar xz -C /usr/local --strip-components 1
fi
# Export VULKAN_SDK if provided by LunarG package (defensive) # Export VULKAN_SDK if provided by LunarG package (defensive)
if [ -d "/usr/lib/x86_64-linux-gnu/vulkan" ] && [ "${{ matrix.preset }}" = "Vulkan" ]; then if [ -d "/usr/lib/x86_64-linux-gnu/vulkan" ] && [ "${{ matrix.preset }}" = "Vulkan" ]; then
echo "VULKAN_SDK=/usr" >> $GITHUB_ENV echo "VULKAN_SDK=/usr" >> $GITHUB_ENV
fi fi
env: env:
DEBIAN_FRONTEND: noninteractive DEBIAN_FRONTEND: noninteractive
- if: matrix.install-go
name: Install Go
run: |
GO_VERSION=$(awk '/^go / { print $2 }' go.mod)
curl -fsSL "https://golang.org/dl/go${GO_VERSION}.linux-$(dpkg --print-architecture).tar.gz" | tar xz -C /usr/local
echo "/usr/local/go/bin" >> $GITHUB_PATH
- uses: actions/cache@v4 - uses: actions/cache@v4
with: with:
path: /github/home/.cache/ccache path: /github/home/.cache/ccache
key: ccache-${{ runner.os }}-${{ runner.arch }}-${{ matrix.preset }}-${{ needs.changes.outputs.vendorsha }} key: ccache-${{ runner.os }}-${{ runner.arch }}-${{ matrix.preset }}-${{ needs.changes.outputs.vendorsha }}
- run: | - run: |
cmake --preset ${{ matrix.preset }} ${{ matrix.flags }} cmake --preset "${{ matrix.preset }}" ${{ matrix.flags }}
cmake --build --preset ${{ matrix.preset }} --parallel cmake --build --preset "${{ matrix.preset }}" --parallel
windows: windows:
needs: [changes] needs: [changes]
@@ -114,12 +129,31 @@ jobs:
flags: '-DAMDGPU_TARGETS=gfx1010 -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -DCMAKE_C_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma" -DCMAKE_CXX_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma"' flags: '-DAMDGPU_TARGETS=gfx1010 -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -DCMAKE_C_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma" -DCMAKE_CXX_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma"'
- preset: Vulkan - preset: Vulkan
install: https://sdk.lunarg.com/sdk/download/1.4.321.1/windows/vulkansdk-windows-X64-1.4.321.1.exe install: https://sdk.lunarg.com/sdk/download/1.4.321.1/windows/vulkansdk-windows-X64-1.4.321.1.exe
- preset: 'MLX CUDA 13'
install: https://developer.download.nvidia.com/compute/cuda/13.0.0/local_installers/cuda_13.0.0_windows.exe
cudnn-install: https://developer.download.nvidia.com/compute/cudnn/redist/cudnn/windows-x86_64/cudnn-windows-x86_64-9.18.1.3_cuda13-archive.zip
flags: '-DCMAKE_CUDA_ARCHITECTURES=80'
cuda-components:
- '"cudart"'
- '"nvcc"'
- '"cublas"'
- '"cublas_dev"'
- '"cufft"'
- '"cufft_dev"'
- '"nvrtc"'
- '"nvrtc_dev"'
- '"crt"'
- '"nvvm"'
- '"nvptxcompiler"'
cuda-version: '13.0'
runs-on: windows runs-on: windows
steps: steps:
- run: | - run: |
choco install -y --no-progress ccache ninja choco install -y --no-progress ccache ninja
ccache -o cache_dir=${{ github.workspace }}\.ccache if (Get-Command ccache -ErrorAction SilentlyContinue) {
- if: matrix.preset == 'CUDA' || matrix.preset == 'ROCm' || matrix.preset == 'Vulkan' ccache -o cache_dir=${{ github.workspace }}\.ccache
}
- if: matrix.preset == 'CUDA' || matrix.preset == 'ROCm' || matrix.preset == 'Vulkan' || matrix.preset == 'MLX CUDA 13'
id: cache-install id: cache-install
uses: actions/cache/restore@v4 uses: actions/cache/restore@v4
with: with:
@@ -127,8 +161,9 @@ jobs:
C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA
C:\Program Files\AMD\ROCm C:\Program Files\AMD\ROCm
C:\VulkanSDK C:\VulkanSDK
key: ${{ matrix.install }} C:\Program Files\NVIDIA\CUDNN
- if: matrix.preset == 'CUDA' key: ${{ matrix.install }}-${{ matrix.cudnn-install }}
- if: matrix.preset == 'CUDA' || matrix.preset == 'MLX CUDA 13'
name: Install CUDA ${{ matrix.cuda-version }} name: Install CUDA ${{ matrix.cuda-version }}
run: | run: |
$ErrorActionPreference = "Stop" $ErrorActionPreference = "Stop"
@@ -164,10 +199,27 @@ jobs:
Invoke-WebRequest -Uri "${{ matrix.install }}" -OutFile "install.exe" Invoke-WebRequest -Uri "${{ matrix.install }}" -OutFile "install.exe"
Start-Process -FilePath .\install.exe -ArgumentList "-c","--am","--al","in" -NoNewWindow -Wait Start-Process -FilePath .\install.exe -ArgumentList "-c","--am","--al","in" -NoNewWindow -Wait
} }
$vulkanPath = (Resolve-Path "C:\VulkanSDK\*").path $vulkanPath = (Resolve-Path "C:\VulkanSDK\*").path
echo "$vulkanPath\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append echo "$vulkanPath\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
echo "VULKAN_SDK=$vulkanPath" >> $env:GITHUB_ENV echo "VULKAN_SDK=$vulkanPath" >> $env:GITHUB_ENV
- if: matrix.preset == 'MLX CUDA 13'
name: Install cuDNN for MLX
run: |
$ErrorActionPreference = "Stop"
$cudnnRoot = "C:\Program Files\NVIDIA\CUDNN"
if ("${{ steps.cache-install.outputs.cache-hit }}" -ne 'true') {
Invoke-WebRequest -Uri "${{ matrix.cudnn-install }}" -OutFile "cudnn.zip"
Expand-Archive -Path cudnn.zip -DestinationPath cudnn-extracted
$cudnnDir = (Get-ChildItem -Path cudnn-extracted -Directory)[0].FullName
New-Item -ItemType Directory -Force -Path $cudnnRoot
Copy-Item -Path "$cudnnDir\*" -Destination "$cudnnRoot\" -Recurse
}
echo "CUDNN_ROOT_DIR=$cudnnRoot" | Out-File -FilePath $env:GITHUB_ENV -Append
echo "CUDNN_INCLUDE_PATH=$cudnnRoot\include" | Out-File -FilePath $env:GITHUB_ENV -Append
echo "CUDNN_LIBRARY_PATH=$cudnnRoot\lib\x64" | Out-File -FilePath $env:GITHUB_ENV -Append
echo "$cudnnRoot\bin\x64" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
- if: ${{ !cancelled() && steps.cache-install.outputs.cache-hit != 'true' }} - if: ${{ !cancelled() && steps.cache-install.outputs.cache-hit != 'true' }}
uses: actions/cache/save@v4 uses: actions/cache/save@v4
with: with:
@@ -175,7 +227,8 @@ jobs:
C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA
C:\Program Files\AMD\ROCm C:\Program Files\AMD\ROCm
C:\VulkanSDK C:\VulkanSDK
key: ${{ matrix.install }} C:\Program Files\NVIDIA\CUDNN
key: ${{ matrix.install }}-${{ matrix.cudnn-install }}
- uses: actions/checkout@v4 - uses: actions/checkout@v4
- uses: actions/cache@v4 - uses: actions/cache@v4
with: with:

View File

@@ -64,10 +64,15 @@ set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${OLLAMA_BUILD_DIR})
set(CMAKE_LIBRARY_OUTPUT_DIRECTORY_DEBUG ${OLLAMA_BUILD_DIR}) set(CMAKE_LIBRARY_OUTPUT_DIRECTORY_DEBUG ${OLLAMA_BUILD_DIR})
set(CMAKE_LIBRARY_OUTPUT_DIRECTORY_RELEASE ${OLLAMA_BUILD_DIR}) set(CMAKE_LIBRARY_OUTPUT_DIRECTORY_RELEASE ${OLLAMA_BUILD_DIR})
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src) # Store ggml include paths for use with target_include_directories later.
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/include) # We avoid global include_directories() to prevent polluting the include path
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-cpu) # for other projects like MLX (whose openblas dependency has its own common.h).
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-cpu/amx) set(GGML_INCLUDE_DIRS
${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src
${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/include
${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-cpu
${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-cpu/amx
)
add_compile_definitions(NDEBUG GGML_VERSION=0x0 GGML_COMMIT=0x0) add_compile_definitions(NDEBUG GGML_VERSION=0x0 GGML_COMMIT=0x0)
@@ -87,6 +92,14 @@ if(NOT CPU_VARIANTS)
set(CPU_VARIANTS "ggml-cpu") set(CPU_VARIANTS "ggml-cpu")
endif() endif()
# Apply ggml include directories to ggml targets only (not globally)
target_include_directories(ggml-base PRIVATE ${GGML_INCLUDE_DIRS})
foreach(variant ${CPU_VARIANTS})
if(TARGET ${variant})
target_include_directories(${variant} PRIVATE ${GGML_INCLUDE_DIRS})
endif()
endforeach()
install(TARGETS ggml-base ${CPU_VARIANTS} install(TARGETS ggml-base ${CPU_VARIANTS}
RUNTIME_DEPENDENCIES RUNTIME_DEPENDENCIES
PRE_EXCLUDE_REGEXES ".*" PRE_EXCLUDE_REGEXES ".*"
@@ -103,6 +116,7 @@ if(CMAKE_CUDA_COMPILER)
find_package(CUDAToolkit) find_package(CUDAToolkit)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-cuda) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-cuda)
target_include_directories(ggml-cuda PRIVATE ${GGML_INCLUDE_DIRS})
install(TARGETS ggml-cuda install(TARGETS ggml-cuda
RUNTIME_DEPENDENCIES RUNTIME_DEPENDENCIES
DIRECTORIES ${CUDAToolkit_BIN_DIR} ${CUDAToolkit_BIN_DIR}/x64 ${CUDAToolkit_LIBRARY_DIR} DIRECTORIES ${CUDAToolkit_BIN_DIR} ${CUDAToolkit_BIN_DIR}/x64 ${CUDAToolkit_LIBRARY_DIR}
@@ -134,6 +148,7 @@ if(CMAKE_HIP_COMPILER)
if(AMDGPU_TARGETS) if(AMDGPU_TARGETS)
find_package(hip REQUIRED) find_package(hip REQUIRED)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-hip) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-hip)
target_include_directories(ggml-hip PRIVATE ${GGML_INCLUDE_DIRS})
if (WIN32) if (WIN32)
target_compile_definitions(ggml-hip PRIVATE GGML_CUDA_NO_PEER_COPY) target_compile_definitions(ggml-hip PRIVATE GGML_CUDA_NO_PEER_COPY)
@@ -148,7 +163,7 @@ if(CMAKE_HIP_COMPILER)
) )
install(RUNTIME_DEPENDENCY_SET rocm install(RUNTIME_DEPENDENCY_SET rocm
DIRECTORIES ${HIP_BIN_INSTALL_DIR} ${HIP_LIB_INSTALL_DIR} DIRECTORIES ${HIP_BIN_INSTALL_DIR} ${HIP_LIB_INSTALL_DIR}
PRE_INCLUDE_REGEXES hipblas rocblas amdhip64 rocsolver amd_comgr hsa-runtime64 rocsparse tinfo rocprofiler-register drm drm_amdgpu numa elf PRE_INCLUDE_REGEXES hipblas rocblas amdhip64 rocsolver amd_comgr hsa-runtime64 rocsparse tinfo rocprofiler-register roctx64 rocroller drm drm_amdgpu numa elf
PRE_EXCLUDE_REGEXES ".*" PRE_EXCLUDE_REGEXES ".*"
POST_EXCLUDE_REGEXES "system32" POST_EXCLUDE_REGEXES "system32"
RUNTIME DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT HIP RUNTIME DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT HIP
@@ -168,6 +183,7 @@ if(NOT APPLE)
find_package(Vulkan) find_package(Vulkan)
if(Vulkan_FOUND) if(Vulkan_FOUND)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-vulkan) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-vulkan)
target_include_directories(ggml-vulkan PRIVATE ${GGML_INCLUDE_DIRS})
install(TARGETS ggml-vulkan install(TARGETS ggml-vulkan
RUNTIME_DEPENDENCIES RUNTIME_DEPENDENCIES
PRE_INCLUDE_REGEXES vulkan PRE_INCLUDE_REGEXES vulkan
@@ -179,7 +195,6 @@ if(NOT APPLE)
endif() endif()
option(MLX_ENGINE "Enable MLX backend" OFF) option(MLX_ENGINE "Enable MLX backend" OFF)
if(MLX_ENGINE) if(MLX_ENGINE)
message(STATUS "Setting up MLX (this takes a while...)") message(STATUS "Setting up MLX (this takes a while...)")
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/x/imagegen/mlx) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/x/imagegen/mlx)
@@ -187,10 +202,36 @@ if(MLX_ENGINE)
# Find CUDA toolkit if MLX is built with CUDA support # Find CUDA toolkit if MLX is built with CUDA support
find_package(CUDAToolkit) find_package(CUDAToolkit)
# Build list of directories for runtime dependency resolution
set(MLX_RUNTIME_DIRS ${CUDAToolkit_BIN_DIR} ${CUDAToolkit_BIN_DIR}/x64 ${CUDAToolkit_LIBRARY_DIR})
# Add cuDNN bin paths for DLLs (Windows MLX CUDA builds)
# CUDNN_ROOT_DIR is the standard CMake variable for cuDNN location
if(DEFINED ENV{CUDNN_ROOT_DIR})
# cuDNN 9.x has versioned subdirectories under bin/ (e.g., bin/13.0/)
file(GLOB CUDNN_BIN_SUBDIRS "$ENV{CUDNN_ROOT_DIR}/bin/*")
list(APPEND MLX_RUNTIME_DIRS ${CUDNN_BIN_SUBDIRS})
endif()
# Add build output directory and MLX dependency build directories
list(APPEND MLX_RUNTIME_DIRS ${OLLAMA_BUILD_DIR})
# OpenBLAS DLL location (pre-built zip extracts into openblas-src/bin/)
list(APPEND MLX_RUNTIME_DIRS ${CMAKE_BINARY_DIR}/_deps/openblas-src/bin)
# NCCL: on Linux, if real NCCL is found, cmake bundles libnccl.so via the
# regex below. If NCCL is not found, MLX links a static stub (OBJECT lib)
# so there is no runtime dependency. This path covers the stub build dir
# for windows so we include the DLL in our dependencies.
list(APPEND MLX_RUNTIME_DIRS ${CMAKE_BINARY_DIR}/_deps/mlx-build/mlx/distributed/nccl/nccl_stub-prefix/src/nccl_stub-build/Release)
# Base regexes for runtime dependencies (cross-platform)
set(MLX_INCLUDE_REGEXES cublas cublasLt cudart cufft nvrtc nvrtc-builtins cudnn nccl openblas gfortran)
# On Windows, also include dl.dll (dlfcn-win32 POSIX emulation layer)
if(WIN32)
list(APPEND MLX_INCLUDE_REGEXES "^dl\\.dll$")
endif()
install(TARGETS mlx mlxc install(TARGETS mlx mlxc
RUNTIME_DEPENDENCIES RUNTIME_DEPENDENCIES
DIRECTORIES ${CUDAToolkit_BIN_DIR} ${CUDAToolkit_BIN_DIR}/x64 ${CUDAToolkit_LIBRARY_DIR} DIRECTORIES ${MLX_RUNTIME_DIRS}
PRE_INCLUDE_REGEXES cublas cublasLt cudart nvrtc nvrtc-builtins cudnn nccl openblas gfortran PRE_INCLUDE_REGEXES ${MLX_INCLUDE_REGEXES}
PRE_EXCLUDE_REGEXES ".*" PRE_EXCLUDE_REGEXES ".*"
RUNTIME DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT MLX RUNTIME DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT MLX
LIBRARY DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT MLX LIBRARY DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT MLX
@@ -205,13 +246,117 @@ if(MLX_ENGINE)
COMPONENT MLX) COMPONENT MLX)
endif() endif()
# Manually install cudart and cublas since they might not be picked up as direct dependencies # Install headers for NVRTC JIT compilation at runtime.
# MLX's own install rules use the default component so they get skipped by
# --component MLX. Headers are installed alongside libmlx in OLLAMA_INSTALL_DIR.
#
# Layout:
# ${OLLAMA_INSTALL_DIR}/include/cccl/{cuda,nv}/ — CCCL headers
# ${OLLAMA_INSTALL_DIR}/include/*.h — CUDA toolkit headers
#
# MLX's jit_module.cpp resolves CCCL via
# current_binary_dir()[.parent_path()] / "include" / "cccl"
# On Linux, MLX's jit_module.cpp resolves CCCL via
# current_binary_dir().parent_path() / "include" / "cccl", so we create a
# symlink from lib/ollama/include -> ${OLLAMA_RUNNER_DIR}/include
# This will need refinement if we add multiple CUDA versions for MLX in the future.
# CUDA runtime headers are found via CUDA_PATH env var (set by mlxrunner).
if(EXISTS ${CMAKE_BINARY_DIR}/_deps/cccl-src/include/cuda)
install(DIRECTORY ${CMAKE_BINARY_DIR}/_deps/cccl-src/include/cuda
DESTINATION ${OLLAMA_INSTALL_DIR}/include/cccl
COMPONENT MLX)
install(DIRECTORY ${CMAKE_BINARY_DIR}/_deps/cccl-src/include/nv
DESTINATION ${OLLAMA_INSTALL_DIR}/include/cccl
COMPONENT MLX)
if(NOT WIN32 AND NOT APPLE)
install(CODE "
set(_link \"${CMAKE_INSTALL_PREFIX}/lib/ollama/include\")
set(_target \"${OLLAMA_RUNNER_DIR}/include\")
if(NOT EXISTS \${_link})
execute_process(COMMAND \${CMAKE_COMMAND} -E create_symlink \${_target} \${_link})
endif()
" COMPONENT MLX)
endif()
endif()
# Install minimal CUDA toolkit headers needed by MLX JIT kernels.
# These are the transitive closure of includes from mlx/backend/cuda/device/*.cuh.
# The Go mlxrunner sets CUDA_PATH to OLLAMA_INSTALL_DIR so MLX finds them at
# $CUDA_PATH/include/*.h via NVRTC --include-path.
if(CUDAToolkit_FOUND) if(CUDAToolkit_FOUND)
file(GLOB CUDART_LIBS # CUDAToolkit_INCLUDE_DIRS may be a semicolon-separated list
# (e.g. ".../include;.../include/cccl"). Find the entry that
# contains the CUDA runtime headers we need.
set(_cuda_inc "")
foreach(_dir ${CUDAToolkit_INCLUDE_DIRS})
if(EXISTS "${_dir}/cuda_runtime_api.h")
set(_cuda_inc "${_dir}")
break()
endif()
endforeach()
if(NOT _cuda_inc)
message(WARNING "Could not find cuda_runtime_api.h in CUDAToolkit_INCLUDE_DIRS: ${CUDAToolkit_INCLUDE_DIRS}")
else()
set(_dst "${OLLAMA_INSTALL_DIR}/include")
set(_MLX_JIT_CUDA_HEADERS
builtin_types.h
cooperative_groups.h
cuda_bf16.h
cuda_bf16.hpp
cuda_device_runtime_api.h
cuda_fp16.h
cuda_fp16.hpp
cuda_fp8.h
cuda_fp8.hpp
cuda_runtime_api.h
device_types.h
driver_types.h
math_constants.h
surface_types.h
texture_types.h
vector_functions.h
vector_functions.hpp
vector_types.h
)
foreach(_hdr ${_MLX_JIT_CUDA_HEADERS})
install(FILES "${_cuda_inc}/${_hdr}"
DESTINATION ${_dst}
COMPONENT MLX)
endforeach()
# Subdirectory headers
install(DIRECTORY "${_cuda_inc}/cooperative_groups"
DESTINATION ${_dst}
COMPONENT MLX
FILES_MATCHING PATTERN "*.h")
install(FILES "${_cuda_inc}/crt/host_defines.h"
DESTINATION "${_dst}/crt"
COMPONENT MLX)
endif()
endif()
# On Windows, explicitly install dl.dll (dlfcn-win32 POSIX dlopen emulation)
# RUNTIME_DEPENDENCIES auto-excludes it via POST_EXCLUDE_FILES_STRICT because
# dlfcn-win32 is a known CMake target with its own install rules (which install
# to the wrong destination). We must install it explicitly here.
if(WIN32)
install(FILES ${OLLAMA_BUILD_DIR}/dl.dll
DESTINATION ${OLLAMA_INSTALL_DIR}
COMPONENT MLX)
endif()
# Manually install CUDA runtime libraries that MLX loads via dlopen
# (not detected by RUNTIME_DEPENDENCIES since they aren't link-time deps)
if(CUDAToolkit_FOUND)
file(GLOB MLX_CUDA_LIBS
"${CUDAToolkit_LIBRARY_DIR}/libcudart.so*" "${CUDAToolkit_LIBRARY_DIR}/libcudart.so*"
"${CUDAToolkit_LIBRARY_DIR}/libcublas.so*") "${CUDAToolkit_LIBRARY_DIR}/libcublas.so*"
if(CUDART_LIBS) "${CUDAToolkit_LIBRARY_DIR}/libcublasLt.so*"
install(FILES ${CUDART_LIBS} "${CUDAToolkit_LIBRARY_DIR}/libnvrtc.so*"
"${CUDAToolkit_LIBRARY_DIR}/libnvrtc-builtins.so*"
"${CUDAToolkit_LIBRARY_DIR}/libcufft.so*"
"${CUDAToolkit_LIBRARY_DIR}/libcudnn.so*")
if(MLX_CUDA_LIBS)
install(FILES ${MLX_CUDA_LIBS}
DESTINATION ${OLLAMA_INSTALL_DIR} DESTINATION ${OLLAMA_INSTALL_DIR}
COMPONENT MLX) COMPONENT MLX)
endif() endif()

View File

@@ -77,6 +77,15 @@
"OLLAMA_RUNNER_DIR": "rocm" "OLLAMA_RUNNER_DIR": "rocm"
} }
}, },
{
"name": "ROCm 7",
"inherits": [ "ROCm" ],
"cacheVariables": {
"CMAKE_HIP_FLAGS": "-parallel-jobs=4",
"AMDGPU_TARGETS": "gfx942;gfx950;gfx1010;gfx1012;gfx1030;gfx1100;gfx1101;gfx1102;gfx1103;gfx1150;gfx1151;gfx1200;gfx1201;gfx908:xnack-;gfx90a:xnack+;gfx90a:xnack-",
"OLLAMA_RUNNER_DIR": "rocm"
}
},
{ {
"name": "Vulkan", "name": "Vulkan",
"inherits": [ "Default" ], "inherits": [ "Default" ],
@@ -103,6 +112,7 @@
"name": "MLX CUDA 13", "name": "MLX CUDA 13",
"inherits": [ "MLX", "CUDA 13" ], "inherits": [ "MLX", "CUDA 13" ],
"cacheVariables": { "cacheVariables": {
"MLX_CUDA_ARCHITECTURES": "86;89;90;90a;100;103;75-virtual;80-virtual;110-virtual;120-virtual;121-virtual",
"OLLAMA_RUNNER_DIR": "mlx_cuda_v13" "OLLAMA_RUNNER_DIR": "mlx_cuda_v13"
} }
} }
@@ -158,6 +168,11 @@
"inherits": [ "ROCm" ], "inherits": [ "ROCm" ],
"configurePreset": "ROCm 6" "configurePreset": "ROCm 6"
}, },
{
"name": "ROCm 7",
"inherits": [ "ROCm" ],
"configurePreset": "ROCm 7"
},
{ {
"name": "Vulkan", "name": "Vulkan",
"targets": [ "ggml-vulkan" ], "targets": [ "ggml-vulkan" ],

View File

@@ -1,28 +1,23 @@
# vim: filetype=dockerfile # vim: filetype=dockerfile
ARG FLAVOR=${TARGETARCH} ARG FLAVOR=${TARGETARCH}
ARG PARALLEL=8
ARG ROCMVERSION=6.3.3 ARG ROCMVERSION=7.2
ARG JETPACK5VERSION=r35.4.1 ARG JETPACK5VERSION=r35.4.1
ARG JETPACK6VERSION=r36.4.0 ARG JETPACK6VERSION=r36.4.0
ARG CMAKEVERSION=3.31.2 ARG CMAKEVERSION=3.31.2
ARG NINJAVERSION=1.12.1
ARG VULKANVERSION=1.4.321.1 ARG VULKANVERSION=1.4.321.1
# Default empty stages for local MLX source overrides.
# Override with: docker build --build-context local-mlx=../mlx --build-context local-mlx-c=../mlx-c
FROM scratch AS local-mlx
FROM scratch AS local-mlx-c
FROM --platform=linux/amd64 rocm/dev-almalinux-8:${ROCMVERSION}-complete AS base-amd64 FROM --platform=linux/amd64 rocm/dev-almalinux-8:${ROCMVERSION}-complete AS base-amd64
RUN dnf install -y yum-utils ccache gcc-toolset-11-gcc gcc-toolset-11-gcc-c++ gcc-toolset-11-binutils \ RUN dnf install -y yum-utils ccache gcc-toolset-11-gcc gcc-toolset-11-gcc-c++ gcc-toolset-11-binutils \
&& yum-config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel8/x86_64/cuda-rhel8.repo && yum-config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel8/x86_64/cuda-rhel8.repo
ENV PATH=/opt/rh/gcc-toolset-11/root/usr/bin:$PATH ENV PATH=/opt/rh/gcc-toolset-11/root/usr/bin:$PATH
ARG VULKANVERSION
RUN wget https://sdk.lunarg.com/sdk/download/${VULKANVERSION}/linux/vulkansdk-linux-x86_64-${VULKANVERSION}.tar.xz -O /tmp/vulkansdk-linux-x86_64-${VULKANVERSION}.tar.xz \
&& tar xvf /tmp/vulkansdk-linux-x86_64-${VULKANVERSION}.tar.xz \
&& dnf -y install ninja-build \
&& ln -s /usr/bin/python3 /usr/bin/python \
&& /${VULKANVERSION}/vulkansdk -j 8 vulkan-headers \
&& /${VULKANVERSION}/vulkansdk -j 8 shaderc
RUN cp -r /${VULKANVERSION}/x86_64/include/* /usr/local/include/ \
&& cp -r /${VULKANVERSION}/x86_64/lib/* /usr/local/lib
ENV PATH=/${VULKANVERSION}/x86_64/bin:$PATH
FROM --platform=linux/arm64 almalinux:8 AS base-arm64 FROM --platform=linux/arm64 almalinux:8 AS base-arm64
# install epel-release for ccache # install epel-release for ccache
@@ -33,100 +28,119 @@ ENV CC=clang CXX=clang++
FROM base-${TARGETARCH} AS base FROM base-${TARGETARCH} AS base
ARG CMAKEVERSION ARG CMAKEVERSION
ARG NINJAVERSION
RUN curl -fsSL https://github.com/Kitware/CMake/releases/download/v${CMAKEVERSION}/cmake-${CMAKEVERSION}-linux-$(uname -m).tar.gz | tar xz -C /usr/local --strip-components 1 RUN curl -fsSL https://github.com/Kitware/CMake/releases/download/v${CMAKEVERSION}/cmake-${CMAKEVERSION}-linux-$(uname -m).tar.gz | tar xz -C /usr/local --strip-components 1
RUN dnf install -y unzip \
&& curl -fsSL -o /tmp/ninja.zip https://github.com/ninja-build/ninja/releases/download/v${NINJAVERSION}/ninja-linux$([ "$(uname -m)" = "aarch64" ] && echo "-aarch64").zip \
&& unzip /tmp/ninja.zip -d /usr/local/bin \
&& rm /tmp/ninja.zip
ENV CMAKE_GENERATOR=Ninja
ENV LDFLAGS=-s ENV LDFLAGS=-s
FROM base AS cpu FROM base AS cpu
RUN dnf install -y gcc-toolset-11-gcc gcc-toolset-11-gcc-c++ RUN dnf install -y gcc-toolset-11-gcc gcc-toolset-11-gcc-c++
ENV PATH=/opt/rh/gcc-toolset-11/root/usr/bin:$PATH ENV PATH=/opt/rh/gcc-toolset-11/root/usr/bin:$PATH
ARG PARALLEL
COPY CMakeLists.txt CMakePresets.json . COPY CMakeLists.txt CMakePresets.json .
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
RUN --mount=type=cache,target=/root/.ccache \ RUN --mount=type=cache,target=/root/.ccache \
cmake --preset 'CPU' \ cmake --preset 'CPU' \
&& cmake --build --parallel ${PARALLEL} --preset 'CPU' \ && cmake --build --preset 'CPU' -- -l $(nproc) \
&& cmake --install build --component CPU --strip --parallel ${PARALLEL} && cmake --install build --component CPU --strip
FROM base AS cuda-11 FROM base AS cuda-11
ARG CUDA11VERSION=11.8 ARG CUDA11VERSION=11.8
RUN dnf install -y cuda-toolkit-${CUDA11VERSION//./-} RUN dnf install -y cuda-toolkit-${CUDA11VERSION//./-}
ENV PATH=/usr/local/cuda-11/bin:$PATH ENV PATH=/usr/local/cuda-11/bin:$PATH
ARG PARALLEL
COPY CMakeLists.txt CMakePresets.json . COPY CMakeLists.txt CMakePresets.json .
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
RUN --mount=type=cache,target=/root/.ccache \ RUN --mount=type=cache,target=/root/.ccache \
cmake --preset 'CUDA 11' \ cmake --preset 'CUDA 11' \
&& cmake --build --parallel ${PARALLEL} --preset 'CUDA 11' \ && cmake --build --preset 'CUDA 11' -- -l $(nproc) \
&& cmake --install build --component CUDA --strip --parallel ${PARALLEL} && cmake --install build --component CUDA --strip
FROM base AS cuda-12 FROM base AS cuda-12
ARG CUDA12VERSION=12.8 ARG CUDA12VERSION=12.8
RUN dnf install -y cuda-toolkit-${CUDA12VERSION//./-} RUN dnf install -y cuda-toolkit-${CUDA12VERSION//./-}
ENV PATH=/usr/local/cuda-12/bin:$PATH ENV PATH=/usr/local/cuda-12/bin:$PATH
ARG PARALLEL
COPY CMakeLists.txt CMakePresets.json . COPY CMakeLists.txt CMakePresets.json .
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
RUN --mount=type=cache,target=/root/.ccache \ RUN --mount=type=cache,target=/root/.ccache \
cmake --preset 'CUDA 12' \ cmake --preset 'CUDA 12' \
&& cmake --build --parallel ${PARALLEL} --preset 'CUDA 12' \ && cmake --build --preset 'CUDA 12' -- -l $(nproc) \
&& cmake --install build --component CUDA --strip --parallel ${PARALLEL} && cmake --install build --component CUDA --strip
FROM base AS cuda-13 FROM base AS cuda-13
ARG CUDA13VERSION=13.0 ARG CUDA13VERSION=13.0
RUN dnf install -y cuda-toolkit-${CUDA13VERSION//./-} RUN dnf install -y cuda-toolkit-${CUDA13VERSION//./-}
ENV PATH=/usr/local/cuda-13/bin:$PATH ENV PATH=/usr/local/cuda-13/bin:$PATH
ARG PARALLEL
COPY CMakeLists.txt CMakePresets.json . COPY CMakeLists.txt CMakePresets.json .
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
RUN --mount=type=cache,target=/root/.ccache \ RUN --mount=type=cache,target=/root/.ccache \
cmake --preset 'CUDA 13' \ cmake --preset 'CUDA 13' \
&& cmake --build --parallel ${PARALLEL} --preset 'CUDA 13' \ && cmake --build --preset 'CUDA 13' -- -l $(nproc) \
&& cmake --install build --component CUDA --strip --parallel ${PARALLEL} && cmake --install build --component CUDA --strip
FROM base AS rocm-6 FROM base AS rocm-7
ENV PATH=/opt/rocm/hcc/bin:/opt/rocm/hip/bin:/opt/rocm/bin:/opt/rocm/hcc/bin:$PATH ENV PATH=/opt/rocm/hcc/bin:/opt/rocm/hip/bin:/opt/rocm/bin:/opt/rocm/hcc/bin:$PATH
ARG PARALLEL
COPY CMakeLists.txt CMakePresets.json . COPY CMakeLists.txt CMakePresets.json .
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
RUN --mount=type=cache,target=/root/.ccache \ RUN --mount=type=cache,target=/root/.ccache \
cmake --preset 'ROCm 6' \ cmake --preset 'ROCm 7' \
&& cmake --build --parallel ${PARALLEL} --preset 'ROCm 6' \ && cmake --build --preset 'ROCm 7' -- -l $(nproc) \
&& cmake --install build --component HIP --strip --parallel ${PARALLEL} && cmake --install build --component HIP --strip
RUN rm -f dist/lib/ollama/rocm/rocblas/library/*gfx90[06]* RUN rm -f dist/lib/ollama/rocm/rocblas/library/*gfx90[06]*
FROM --platform=linux/arm64 nvcr.io/nvidia/l4t-jetpack:${JETPACK5VERSION} AS jetpack-5 FROM --platform=linux/arm64 nvcr.io/nvidia/l4t-jetpack:${JETPACK5VERSION} AS jetpack-5
ARG CMAKEVERSION ARG CMAKEVERSION
RUN apt-get update && apt-get install -y curl ccache \ ARG NINJAVERSION
&& curl -fsSL https://github.com/Kitware/CMake/releases/download/v${CMAKEVERSION}/cmake-${CMAKEVERSION}-linux-$(uname -m).tar.gz | tar xz -C /usr/local --strip-components 1 RUN apt-get update && apt-get install -y curl ccache unzip \
&& curl -fsSL https://github.com/Kitware/CMake/releases/download/v${CMAKEVERSION}/cmake-${CMAKEVERSION}-linux-$(uname -m).tar.gz | tar xz -C /usr/local --strip-components 1 \
&& curl -fsSL -o /tmp/ninja.zip https://github.com/ninja-build/ninja/releases/download/v${NINJAVERSION}/ninja-linux-aarch64.zip \
&& unzip /tmp/ninja.zip -d /usr/local/bin \
&& rm /tmp/ninja.zip
ENV CMAKE_GENERATOR=Ninja
COPY CMakeLists.txt CMakePresets.json . COPY CMakeLists.txt CMakePresets.json .
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
ARG PARALLEL
RUN --mount=type=cache,target=/root/.ccache \ RUN --mount=type=cache,target=/root/.ccache \
cmake --preset 'JetPack 5' \ cmake --preset 'JetPack 5' \
&& cmake --build --parallel ${PARALLEL} --preset 'JetPack 5' \ && cmake --build --preset 'JetPack 5' -- -l $(nproc) \
&& cmake --install build --component CUDA --strip --parallel ${PARALLEL} && cmake --install build --component CUDA --strip
FROM --platform=linux/arm64 nvcr.io/nvidia/l4t-jetpack:${JETPACK6VERSION} AS jetpack-6 FROM --platform=linux/arm64 nvcr.io/nvidia/l4t-jetpack:${JETPACK6VERSION} AS jetpack-6
ARG CMAKEVERSION ARG CMAKEVERSION
RUN apt-get update && apt-get install -y curl ccache \ ARG NINJAVERSION
&& curl -fsSL https://github.com/Kitware/CMake/releases/download/v${CMAKEVERSION}/cmake-${CMAKEVERSION}-linux-$(uname -m).tar.gz | tar xz -C /usr/local --strip-components 1 RUN apt-get update && apt-get install -y curl ccache unzip \
&& curl -fsSL https://github.com/Kitware/CMake/releases/download/v${CMAKEVERSION}/cmake-${CMAKEVERSION}-linux-$(uname -m).tar.gz | tar xz -C /usr/local --strip-components 1 \
&& curl -fsSL -o /tmp/ninja.zip https://github.com/ninja-build/ninja/releases/download/v${NINJAVERSION}/ninja-linux-aarch64.zip \
&& unzip /tmp/ninja.zip -d /usr/local/bin \
&& rm /tmp/ninja.zip
ENV CMAKE_GENERATOR=Ninja
COPY CMakeLists.txt CMakePresets.json . COPY CMakeLists.txt CMakePresets.json .
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
ARG PARALLEL
RUN --mount=type=cache,target=/root/.ccache \ RUN --mount=type=cache,target=/root/.ccache \
cmake --preset 'JetPack 6' \ cmake --preset 'JetPack 6' \
&& cmake --build --parallel ${PARALLEL} --preset 'JetPack 6' \ && cmake --build --preset 'JetPack 6' -- -l $(nproc) \
&& cmake --install build --component CUDA --strip --parallel ${PARALLEL} && cmake --install build --component CUDA --strip
FROM base AS vulkan FROM base AS vulkan
ARG VULKANVERSION
RUN ln -s /usr/bin/python3 /usr/bin/python \
&& wget https://sdk.lunarg.com/sdk/download/${VULKANVERSION}/linux/vulkansdk-linux-x86_64-${VULKANVERSION}.tar.xz -O /tmp/vulkansdk.tar.xz \
&& tar xvf /tmp/vulkansdk.tar.xz -C /tmp \
&& /tmp/${VULKANVERSION}/vulkansdk -j 8 vulkan-headers \
&& /tmp/${VULKANVERSION}/vulkansdk -j 8 shaderc \
&& cp -r /tmp/${VULKANVERSION}/x86_64/include/* /usr/local/include/ \
&& cp -r /tmp/${VULKANVERSION}/x86_64/lib/* /usr/local/lib \
&& cp -r /tmp/${VULKANVERSION}/x86_64/bin/* /usr/local/bin/ \
&& rm -rf /tmp/${VULKANVERSION} /tmp/vulkansdk.tar.xz
COPY CMakeLists.txt CMakePresets.json . COPY CMakeLists.txt CMakePresets.json .
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
RUN --mount=type=cache,target=/root/.ccache \ RUN --mount=type=cache,target=/root/.ccache \
cmake --preset 'Vulkan' \ cmake --preset 'Vulkan' \
&& cmake --build --parallel --preset 'Vulkan' \ && cmake --build --preset 'Vulkan' -- -l $(nproc) \
&& cmake --install build --component Vulkan --strip --parallel 8 && cmake --install build --component Vulkan --strip
FROM base AS mlx FROM base AS mlx
ARG CUDA13VERSION=13.0 ARG CUDA13VERSION=13.0
@@ -138,20 +152,27 @@ ENV PATH=/usr/local/cuda-13/bin:$PATH
ENV BLAS_INCLUDE_DIRS=/usr/include/openblas ENV BLAS_INCLUDE_DIRS=/usr/include/openblas
ENV LAPACK_INCLUDE_DIRS=/usr/include/openblas ENV LAPACK_INCLUDE_DIRS=/usr/include/openblas
ENV CGO_LDFLAGS="-L/usr/local/cuda-13/lib64 -L/usr/local/cuda-13/targets/x86_64-linux/lib/stubs" ENV CGO_LDFLAGS="-L/usr/local/cuda-13/lib64 -L/usr/local/cuda-13/targets/x86_64-linux/lib/stubs"
ARG PARALLEL
WORKDIR /go/src/github.com/ollama/ollama WORKDIR /go/src/github.com/ollama/ollama
COPY CMakeLists.txt CMakePresets.json . COPY CMakeLists.txt CMakePresets.json .
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
COPY x/imagegen/mlx x/imagegen/mlx COPY x/imagegen/mlx x/imagegen/mlx
COPY go.mod go.sum . COPY go.mod go.sum .
COPY MLX_VERSION . COPY MLX_VERSION MLX_C_VERSION .
RUN curl -fsSL https://golang.org/dl/go$(awk '/^go/ { print $2 }' go.mod).linux-$(case $(uname -m) in x86_64) echo amd64 ;; aarch64) echo arm64 ;; esac).tar.gz | tar xz -C /usr/local RUN curl -fsSL https://golang.org/dl/go$(awk '/^go/ { print $2 }' go.mod).linux-$(case $(uname -m) in x86_64) echo amd64 ;; aarch64) echo arm64 ;; esac).tar.gz | tar xz -C /usr/local
ENV PATH=/usr/local/go/bin:$PATH ENV PATH=/usr/local/go/bin:$PATH
RUN go mod download RUN go mod download
RUN --mount=type=cache,target=/root/.ccache \ RUN --mount=type=cache,target=/root/.ccache \
cmake --preset 'MLX CUDA 13' -DBLAS_INCLUDE_DIRS=/usr/include/openblas -DLAPACK_INCLUDE_DIRS=/usr/include/openblas \ --mount=type=bind,from=local-mlx,target=/tmp/local-mlx \
&& cmake --build --parallel ${PARALLEL} --preset 'MLX CUDA 13' \ --mount=type=bind,from=local-mlx-c,target=/tmp/local-mlx-c \
&& cmake --install build --component MLX --strip --parallel ${PARALLEL} if [ -f /tmp/local-mlx/CMakeLists.txt ]; then \
export OLLAMA_MLX_SOURCE=/tmp/local-mlx; \
fi \
&& if [ -f /tmp/local-mlx-c/CMakeLists.txt ]; then \
export OLLAMA_MLX_C_SOURCE=/tmp/local-mlx-c; \
fi \
&& cmake --preset 'MLX CUDA 13' -DBLAS_INCLUDE_DIRS=/usr/include/openblas -DLAPACK_INCLUDE_DIRS=/usr/include/openblas \
&& cmake --build --preset 'MLX CUDA 13' -- -l $(nproc) \
&& cmake --install build --component MLX --strip
FROM base AS build FROM base AS build
WORKDIR /go/src/github.com/ollama/ollama WORKDIR /go/src/github.com/ollama/ollama
@@ -160,16 +181,14 @@ RUN curl -fsSL https://golang.org/dl/go$(awk '/^go/ { print $2 }' go.mod).linux-
ENV PATH=/usr/local/go/bin:$PATH ENV PATH=/usr/local/go/bin:$PATH
RUN go mod download RUN go mod download
COPY . . COPY . .
# Clone mlx-c headers for CGO (version from MLX_VERSION file)
RUN git clone --depth 1 --branch "$(cat MLX_VERSION)" https://github.com/ml-explore/mlx-c.git build/_deps/mlx-c-src
ARG GOFLAGS="'-ldflags=-w -s'" ARG GOFLAGS="'-ldflags=-w -s'"
ENV CGO_ENABLED=1 ENV CGO_ENABLED=1
ARG CGO_CFLAGS ARG CGO_CFLAGS
ARG CGO_CXXFLAGS ARG CGO_CXXFLAGS
ENV CGO_CFLAGS="${CGO_CFLAGS} -I/go/src/github.com/ollama/ollama/build/_deps/mlx-c-src" ENV CGO_CFLAGS="${CGO_CFLAGS}"
ENV CGO_CXXFLAGS="${CGO_CXXFLAGS}" ENV CGO_CXXFLAGS="${CGO_CXXFLAGS}"
RUN --mount=type=cache,target=/root/.cache/go-build \ RUN --mount=type=cache,target=/root/.cache/go-build \
go build -tags mlx -trimpath -buildmode=pie -o /bin/ollama . go build -trimpath -buildmode=pie -o /bin/ollama .
FROM --platform=linux/amd64 scratch AS amd64 FROM --platform=linux/amd64 scratch AS amd64
# COPY --from=cuda-11 dist/lib/ollama/ /lib/ollama/ # COPY --from=cuda-11 dist/lib/ollama/ /lib/ollama/
@@ -186,10 +205,9 @@ COPY --from=jetpack-5 dist/lib/ollama/ /lib/ollama/
COPY --from=jetpack-6 dist/lib/ollama/ /lib/ollama/ COPY --from=jetpack-6 dist/lib/ollama/ /lib/ollama/
FROM scratch AS rocm FROM scratch AS rocm
COPY --from=rocm-6 dist/lib/ollama /lib/ollama COPY --from=rocm-7 dist/lib/ollama /lib/ollama
FROM ${FLAVOR} AS archive FROM ${FLAVOR} AS archive
ARG VULKANVERSION
COPY --from=cpu dist/lib/ollama /lib/ollama COPY --from=cpu dist/lib/ollama /lib/ollama
COPY --from=build /bin/ollama /bin/ollama COPY --from=build /bin/ollama /bin/ollama

1
MLX_C_VERSION Normal file
View File

@@ -0,0 +1 @@
0726ca922fc902c4c61ef9c27d94132be418e945

View File

@@ -1 +1 @@
v0.5.0 38ad257088fb2193ad47e527cf6534a689f30943

View File

@@ -68,7 +68,7 @@ type MessagesRequest struct {
Model string `json:"model"` Model string `json:"model"`
MaxTokens int `json:"max_tokens"` MaxTokens int `json:"max_tokens"`
Messages []MessageParam `json:"messages"` Messages []MessageParam `json:"messages"`
System any `json:"system,omitempty"` // string or []ContentBlock System any `json:"system,omitempty"` // string or []map[string]any (JSON-decoded ContentBlock)
Stream bool `json:"stream,omitempty"` Stream bool `json:"stream,omitempty"`
Temperature *float64 `json:"temperature,omitempty"` Temperature *float64 `json:"temperature,omitempty"`
TopP *float64 `json:"top_p,omitempty"` TopP *float64 `json:"top_p,omitempty"`
@@ -82,8 +82,27 @@ type MessagesRequest struct {
// MessageParam represents a message in the request // MessageParam represents a message in the request
type MessageParam struct { type MessageParam struct {
Role string `json:"role"` // "user" or "assistant" Role string `json:"role"` // "user" or "assistant"
Content any `json:"content"` // string or []ContentBlock Content []ContentBlock `json:"content"` // always []ContentBlock; plain strings are normalized on unmarshal
}
func (m *MessageParam) UnmarshalJSON(data []byte) error {
var raw struct {
Role string `json:"role"`
Content json.RawMessage `json:"content"`
}
if err := json.Unmarshal(data, &raw); err != nil {
return err
}
m.Role = raw.Role
var s string
if err := json.Unmarshal(raw.Content, &s); err == nil {
m.Content = []ContentBlock{{Type: "text", Text: &s}}
return nil
}
return json.Unmarshal(raw.Content, &m.Content)
} }
// ContentBlock represents a content block in a message. // ContentBlock represents a content block in a message.
@@ -102,9 +121,9 @@ type ContentBlock struct {
Source *ImageSource `json:"source,omitempty"` Source *ImageSource `json:"source,omitempty"`
// For tool_use and server_tool_use blocks // For tool_use and server_tool_use blocks
ID string `json:"id,omitempty"` ID string `json:"id,omitempty"`
Name string `json:"name,omitempty"` Name string `json:"name,omitempty"`
Input any `json:"input,omitempty"` Input api.ToolCallFunctionArguments `json:"input,omitzero"`
// For tool_result and web_search_tool_result blocks // For tool_result and web_search_tool_result blocks
ToolUseID string `json:"tool_use_id,omitempty"` ToolUseID string `json:"tool_use_id,omitempty"`
@@ -377,178 +396,145 @@ func convertMessage(msg MessageParam) ([]api.Message, error) {
var messages []api.Message var messages []api.Message
role := strings.ToLower(msg.Role) role := strings.ToLower(msg.Role)
switch content := msg.Content.(type) { var textContent strings.Builder
case string: var images []api.ImageData
messages = append(messages, api.Message{Role: role, Content: content}) var toolCalls []api.ToolCall
var thinking string
var toolResults []api.Message
textBlocks := 0
imageBlocks := 0
toolUseBlocks := 0
toolResultBlocks := 0
serverToolUseBlocks := 0
webSearchToolResultBlocks := 0
thinkingBlocks := 0
unknownBlocks := 0
case []any: for _, block := range msg.Content {
var textContent strings.Builder switch block.Type {
var images []api.ImageData case "text":
var toolCalls []api.ToolCall textBlocks++
var thinking string if block.Text != nil {
var toolResults []api.Message textContent.WriteString(*block.Text)
textBlocks := 0
imageBlocks := 0
toolUseBlocks := 0
toolResultBlocks := 0
serverToolUseBlocks := 0
webSearchToolResultBlocks := 0
thinkingBlocks := 0
unknownBlocks := 0
for _, block := range content {
blockMap, ok := block.(map[string]any)
if !ok {
logutil.Trace("anthropic: invalid content block format", "role", role)
return nil, errors.New("invalid content block format")
} }
blockType, _ := blockMap["type"].(string) case "image":
imageBlocks++
if block.Source == nil {
logutil.Trace("anthropic: invalid image source", "role", role)
return nil, errors.New("invalid image source")
}
switch blockType { if block.Source.Type == "base64" {
case "text": decoded, err := base64.StdEncoding.DecodeString(block.Source.Data)
textBlocks++ if err != nil {
if text, ok := blockMap["text"].(string); ok { logutil.Trace("anthropic: invalid base64 image data", "role", role, "error", err)
textContent.WriteString(text) return nil, fmt.Errorf("invalid base64 image data: %w", err)
} }
images = append(images, decoded)
} else {
logutil.Trace("anthropic: unsupported image source type", "role", role, "source_type", block.Source.Type)
return nil, fmt.Errorf("invalid image source type: %s. Only base64 images are supported.", block.Source.Type)
}
case "image": case "tool_use":
imageBlocks++ toolUseBlocks++
source, ok := blockMap["source"].(map[string]any) if block.ID == "" {
if !ok { logutil.Trace("anthropic: tool_use block missing id", "role", role)
logutil.Trace("anthropic: invalid image source", "role", role) return nil, errors.New("tool_use block missing required 'id' field")
return nil, errors.New("invalid image source") }
} if block.Name == "" {
logutil.Trace("anthropic: tool_use block missing name", "role", role)
return nil, errors.New("tool_use block missing required 'name' field")
}
toolCalls = append(toolCalls, api.ToolCall{
ID: block.ID,
Function: api.ToolCallFunction{
Name: block.Name,
Arguments: block.Input,
},
})
sourceType, _ := source["type"].(string) case "tool_result":
if sourceType == "base64" { toolResultBlocks++
data, _ := source["data"].(string) var resultContent string
decoded, err := base64.StdEncoding.DecodeString(data)
if err != nil {
logutil.Trace("anthropic: invalid base64 image data", "role", role, "error", err)
return nil, fmt.Errorf("invalid base64 image data: %w", err)
}
images = append(images, decoded)
} else {
logutil.Trace("anthropic: unsupported image source type", "role", role, "source_type", sourceType)
return nil, fmt.Errorf("invalid image source type: %s. Only base64 images are supported.", sourceType)
}
// URL images would need to be fetched - skip for now
case "tool_use": switch c := block.Content.(type) {
toolUseBlocks++ case string:
id, ok := blockMap["id"].(string) resultContent = c
if !ok { case []any:
logutil.Trace("anthropic: tool_use block missing id", "role", role) for _, cb := range c {
return nil, errors.New("tool_use block missing required 'id' field") if cbMap, ok := cb.(map[string]any); ok {
} if cbMap["type"] == "text" {
name, ok := blockMap["name"].(string) if text, ok := cbMap["text"].(string); ok {
if !ok { resultContent += text
logutil.Trace("anthropic: tool_use block missing name", "role", role)
return nil, errors.New("tool_use block missing required 'name' field")
}
tc := api.ToolCall{
ID: id,
Function: api.ToolCallFunction{
Name: name,
},
}
if input, ok := blockMap["input"].(map[string]any); ok {
tc.Function.Arguments = mapToArgs(input)
}
toolCalls = append(toolCalls, tc)
case "tool_result":
toolResultBlocks++
toolUseID, _ := blockMap["tool_use_id"].(string)
var resultContent string
switch c := blockMap["content"].(type) {
case string:
resultContent = c
case []any:
for _, cb := range c {
if cbMap, ok := cb.(map[string]any); ok {
if cbMap["type"] == "text" {
if text, ok := cbMap["text"].(string); ok {
resultContent += text
}
} }
} }
} }
} }
toolResults = append(toolResults, api.Message{
Role: "tool",
Content: resultContent,
ToolCallID: toolUseID,
})
case "thinking":
thinkingBlocks++
if t, ok := blockMap["thinking"].(string); ok {
thinking = t
}
case "server_tool_use":
serverToolUseBlocks++
id, _ := blockMap["id"].(string)
name, _ := blockMap["name"].(string)
tc := api.ToolCall{
ID: id,
Function: api.ToolCallFunction{
Name: name,
},
}
if input, ok := blockMap["input"].(map[string]any); ok {
tc.Function.Arguments = mapToArgs(input)
}
toolCalls = append(toolCalls, tc)
case "web_search_tool_result":
webSearchToolResultBlocks++
toolUseID, _ := blockMap["tool_use_id"].(string)
toolResults = append(toolResults, api.Message{
Role: "tool",
Content: formatWebSearchToolResultContent(blockMap["content"]),
ToolCallID: toolUseID,
})
default:
unknownBlocks++
} }
}
if textContent.Len() > 0 || len(images) > 0 || len(toolCalls) > 0 || thinking != "" { toolResults = append(toolResults, api.Message{
m := api.Message{ Role: "tool",
Role: role, Content: resultContent,
Content: textContent.String(), ToolCallID: block.ToolUseID,
Images: images, })
ToolCalls: toolCalls,
Thinking: thinking, case "thinking":
thinkingBlocks++
if block.Thinking != nil {
thinking = *block.Thinking
} }
messages = append(messages, m)
case "server_tool_use":
serverToolUseBlocks++
toolCalls = append(toolCalls, api.ToolCall{
ID: block.ID,
Function: api.ToolCallFunction{
Name: block.Name,
Arguments: block.Input,
},
})
case "web_search_tool_result":
webSearchToolResultBlocks++
toolResults = append(toolResults, api.Message{
Role: "tool",
Content: formatWebSearchToolResultContent(block.Content),
ToolCallID: block.ToolUseID,
})
default:
unknownBlocks++
} }
// Add tool results as separate messages
messages = append(messages, toolResults...)
logutil.Trace("anthropic: converted block message",
"role", role,
"blocks", len(content),
"text", textBlocks,
"image", imageBlocks,
"tool_use", toolUseBlocks,
"tool_result", toolResultBlocks,
"server_tool_use", serverToolUseBlocks,
"web_search_result", webSearchToolResultBlocks,
"thinking", thinkingBlocks,
"unknown", unknownBlocks,
"messages", TraceAPIMessages(messages),
)
default:
return nil, fmt.Errorf("invalid message content type: %T", content)
} }
if textContent.Len() > 0 || len(images) > 0 || len(toolCalls) > 0 || thinking != "" {
m := api.Message{
Role: role,
Content: textContent.String(),
Images: images,
ToolCalls: toolCalls,
Thinking: thinking,
}
messages = append(messages, m)
}
// Add tool results as separate messages
messages = append(messages, toolResults...)
logutil.Trace("anthropic: converted block message",
"role", role,
"blocks", len(msg.Content),
"text", textBlocks,
"image", imageBlocks,
"tool_use", toolUseBlocks,
"tool_result", toolResultBlocks,
"server_tool_use", serverToolUseBlocks,
"web_search_result", webSearchToolResultBlocks,
"thinking", thinkingBlocks,
"unknown", unknownBlocks,
"messages", TraceAPIMessages(messages),
)
return messages, nil return messages, nil
} }
@@ -852,6 +838,19 @@ func (c *StreamConverter) Process(r api.ChatResponse) []StreamEvent {
continue continue
} }
// Close thinking block if still open (thinking → tool_use without text in between)
if c.thinkingStarted && !c.thinkingDone {
c.thinkingDone = true
events = append(events, StreamEvent{
Event: "content_block_stop",
Data: ContentBlockStopEvent{
Type: "content_block_stop",
Index: c.contentIndex,
},
})
c.contentIndex++
}
if c.textStarted { if c.textStarted {
events = append(events, StreamEvent{ events = append(events, StreamEvent{
Event: "content_block_stop", Event: "content_block_stop",
@@ -869,7 +868,6 @@ func (c *StreamConverter) Process(r api.ChatResponse) []StreamEvent {
slog.Error("failed to marshal tool arguments", "error", err, "tool_id", tc.ID) slog.Error("failed to marshal tool arguments", "error", err, "tool_id", tc.ID)
continue continue
} }
events = append(events, StreamEvent{ events = append(events, StreamEvent{
Event: "content_block_start", Event: "content_block_start",
Data: ContentBlockStartEvent{ Data: ContentBlockStartEvent{
@@ -879,7 +877,7 @@ func (c *StreamConverter) Process(r api.ChatResponse) []StreamEvent {
Type: "tool_use", Type: "tool_use",
ID: tc.ID, ID: tc.ID,
Name: tc.Function.Name, Name: tc.Function.Name,
Input: map[string]any{}, Input: api.NewToolCallFunctionArguments(),
}, },
}, },
}) })
@@ -976,15 +974,6 @@ func ptr(s string) *string {
return &s return &s
} }
// mapToArgs converts a map to ToolCallFunctionArguments
func mapToArgs(m map[string]any) api.ToolCallFunctionArguments {
args := api.NewToolCallFunctionArguments()
for k, v := range m {
args.Set(k, v)
}
return args
}
// CountTokensRequest represents an Anthropic count_tokens request // CountTokensRequest represents an Anthropic count_tokens request
type CountTokensRequest struct { type CountTokensRequest struct {
Model string `json:"model"` Model string `json:"model"`
@@ -1017,17 +1006,13 @@ func estimateTokens(req CountTokensRequest) int {
var totalLen int var totalLen int
// Count system prompt // Count system prompt
if req.System != nil { totalLen += countAnyContent(req.System)
totalLen += countAnyContent(req.System)
}
// Count messages
for _, msg := range req.Messages { for _, msg := range req.Messages {
// Count role (always present) // Count role (always present)
totalLen += len(msg.Role) totalLen += len(msg.Role)
// Count content // Count content
contentLen := countAnyContent(msg.Content) totalLen += countAnyContent(msg.Content)
totalLen += contentLen
} }
for _, tool := range req.Tools { for _, tool := range req.Tools {
@@ -1050,12 +1035,25 @@ func countAnyContent(content any) int {
switch c := content.(type) { switch c := content.(type) {
case string: case string:
return len(c) return len(c)
case []any: case []ContentBlock:
total := 0 total := 0
for _, block := range c { for _, block := range c {
total += countContentBlock(block) total += countContentBlock(block)
} }
return total return total
case []any:
total := 0
for _, item := range c {
data, err := json.Marshal(item)
if err != nil {
continue
}
var block ContentBlock
if err := json.Unmarshal(data, &block); err == nil {
total += countContentBlock(block)
}
}
return total
default: default:
if data, err := json.Marshal(content); err == nil { if data, err := json.Marshal(content); err == nil {
return len(data) return len(data)
@@ -1064,38 +1062,19 @@ func countAnyContent(content any) int {
} }
} }
func countContentBlock(block any) int { func countContentBlock(block ContentBlock) int {
blockMap, ok := block.(map[string]any)
if !ok {
if s, ok := block.(string); ok {
return len(s)
}
return 0
}
total := 0 total := 0
blockType, _ := blockMap["type"].(string) if block.Text != nil {
total += len(*block.Text)
if text, ok := blockMap["text"].(string); ok {
total += len(text)
} }
if block.Thinking != nil {
if thinking, ok := blockMap["thinking"].(string); ok { total += len(*block.Thinking)
total += len(thinking)
} }
if block.Type == "tool_use" || block.Type == "tool_result" {
if blockType == "tool_use" { if data, err := json.Marshal(block); err == nil {
if data, err := json.Marshal(blockMap); err == nil {
total += len(data) total += len(data)
} }
} }
if blockType == "tool_result" {
if data, err := json.Marshal(blockMap); err == nil {
total += len(data)
}
}
return total return total
} }

View File

@@ -15,11 +15,16 @@ const (
testImage = `iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=` testImage = `iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=`
) )
// testArgs creates ToolCallFunctionArguments from a map (convenience function for tests) // textContent is a convenience for constructing []ContentBlock with a single text block in tests.
func testArgs(m map[string]any) api.ToolCallFunctionArguments { func textContent(s string) []ContentBlock {
return []ContentBlock{{Type: "text", Text: &s}}
}
// makeArgs creates ToolCallFunctionArguments from key-value pairs (convenience function for tests)
func makeArgs(kvs ...any) api.ToolCallFunctionArguments {
args := api.NewToolCallFunctionArguments() args := api.NewToolCallFunctionArguments()
for k, v := range m { for i := 0; i < len(kvs)-1; i += 2 {
args.Set(k, v) args.Set(kvs[i].(string), kvs[i+1])
} }
return args return args
} }
@@ -29,7 +34,7 @@ func TestFromMessagesRequest_Basic(t *testing.T) {
Model: "test-model", Model: "test-model",
MaxTokens: 1024, MaxTokens: 1024,
Messages: []MessageParam{ Messages: []MessageParam{
{Role: "user", Content: "Hello"}, {Role: "user", Content: textContent("Hello")},
}, },
} }
@@ -61,7 +66,7 @@ func TestFromMessagesRequest_WithSystemPrompt(t *testing.T) {
MaxTokens: 1024, MaxTokens: 1024,
System: "You are a helpful assistant.", System: "You are a helpful assistant.",
Messages: []MessageParam{ Messages: []MessageParam{
{Role: "user", Content: "Hello"}, {Role: "user", Content: textContent("Hello")},
}, },
} }
@@ -88,7 +93,7 @@ func TestFromMessagesRequest_WithSystemPromptArray(t *testing.T) {
map[string]any{"type": "text", "text": " Be concise."}, map[string]any{"type": "text", "text": " Be concise."},
}, },
Messages: []MessageParam{ Messages: []MessageParam{
{Role: "user", Content: "Hello"}, {Role: "user", Content: textContent("Hello")},
}, },
} }
@@ -113,7 +118,7 @@ func TestFromMessagesRequest_WithOptions(t *testing.T) {
req := MessagesRequest{ req := MessagesRequest{
Model: "test-model", Model: "test-model",
MaxTokens: 2048, MaxTokens: 2048,
Messages: []MessageParam{{Role: "user", Content: "Hello"}}, Messages: []MessageParam{{Role: "user", Content: textContent("Hello")}},
Temperature: &temp, Temperature: &temp,
TopP: &topP, TopP: &topP,
TopK: &topK, TopK: &topK,
@@ -148,14 +153,14 @@ func TestFromMessagesRequest_WithImage(t *testing.T) {
Messages: []MessageParam{ Messages: []MessageParam{
{ {
Role: "user", Role: "user",
Content: []any{ Content: []ContentBlock{
map[string]any{"type": "text", "text": "What's in this image?"}, {Type: "text", Text: ptr("What's in this image?")},
map[string]any{ {
"type": "image", Type: "image",
"source": map[string]any{ Source: &ImageSource{
"type": "base64", Type: "base64",
"media_type": "image/png", MediaType: "image/png",
"data": testImage, Data: testImage,
}, },
}, },
}, },
@@ -190,15 +195,15 @@ func TestFromMessagesRequest_WithToolUse(t *testing.T) {
Model: "test-model", Model: "test-model",
MaxTokens: 1024, MaxTokens: 1024,
Messages: []MessageParam{ Messages: []MessageParam{
{Role: "user", Content: "What's the weather in Paris?"}, {Role: "user", Content: textContent("What's the weather in Paris?")},
{ {
Role: "assistant", Role: "assistant",
Content: []any{ Content: []ContentBlock{
map[string]any{ {
"type": "tool_use", Type: "tool_use",
"id": "call_123", ID: "call_123",
"name": "get_weather", Name: "get_weather",
"input": map[string]any{"location": "Paris"}, Input: makeArgs("location", "Paris"),
}, },
}, },
}, },
@@ -234,11 +239,11 @@ func TestFromMessagesRequest_WithToolResult(t *testing.T) {
Messages: []MessageParam{ Messages: []MessageParam{
{ {
Role: "user", Role: "user",
Content: []any{ Content: []ContentBlock{
map[string]any{ {
"type": "tool_result", Type: "tool_result",
"tool_use_id": "call_123", ToolUseID: "call_123",
"content": "The weather in Paris is sunny, 22°C", Content: "The weather in Paris is sunny, 22°C",
}, },
}, },
}, },
@@ -270,7 +275,7 @@ func TestFromMessagesRequest_WithTools(t *testing.T) {
req := MessagesRequest{ req := MessagesRequest{
Model: "test-model", Model: "test-model",
MaxTokens: 1024, MaxTokens: 1024,
Messages: []MessageParam{{Role: "user", Content: "Hello"}}, Messages: []MessageParam{{Role: "user", Content: textContent("Hello")}},
Tools: []Tool{ Tools: []Tool{
{ {
Name: "get_weather", Name: "get_weather",
@@ -305,7 +310,7 @@ func TestFromMessagesRequest_DropsCustomWebSearchWhenBuiltinPresent(t *testing.T
req := MessagesRequest{ req := MessagesRequest{
Model: "test-model", Model: "test-model",
MaxTokens: 1024, MaxTokens: 1024,
Messages: []MessageParam{{Role: "user", Content: "Hello"}}, Messages: []MessageParam{{Role: "user", Content: textContent("Hello")}},
Tools: []Tool{ Tools: []Tool{
{ {
Type: "web_search_20250305", Type: "web_search_20250305",
@@ -346,7 +351,7 @@ func TestFromMessagesRequest_KeepsCustomWebSearchWhenBuiltinAbsent(t *testing.T)
req := MessagesRequest{ req := MessagesRequest{
Model: "test-model", Model: "test-model",
MaxTokens: 1024, MaxTokens: 1024,
Messages: []MessageParam{{Role: "user", Content: "Hello"}}, Messages: []MessageParam{{Role: "user", Content: textContent("Hello")}},
Tools: []Tool{ Tools: []Tool{
{ {
Type: "custom", Type: "custom",
@@ -377,7 +382,7 @@ func TestFromMessagesRequest_WithThinking(t *testing.T) {
req := MessagesRequest{ req := MessagesRequest{
Model: "test-model", Model: "test-model",
MaxTokens: 1024, MaxTokens: 1024,
Messages: []MessageParam{{Role: "user", Content: "Hello"}}, Messages: []MessageParam{{Role: "user", Content: textContent("Hello")}},
Thinking: &ThinkingConfig{Type: "enabled", BudgetTokens: 1000}, Thinking: &ThinkingConfig{Type: "enabled", BudgetTokens: 1000},
} }
@@ -399,13 +404,13 @@ func TestFromMessagesRequest_ThinkingOnlyBlock(t *testing.T) {
Model: "test-model", Model: "test-model",
MaxTokens: 1024, MaxTokens: 1024,
Messages: []MessageParam{ Messages: []MessageParam{
{Role: "user", Content: "Hello"}, {Role: "user", Content: textContent("Hello")},
{ {
Role: "assistant", Role: "assistant",
Content: []any{ Content: []ContentBlock{
map[string]any{ {
"type": "thinking", Type: "thinking",
"thinking": "Let me think about this...", Thinking: ptr("Let me think about this..."),
}, },
}, },
}, },
@@ -434,10 +439,10 @@ func TestFromMessagesRequest_ToolUseMissingID(t *testing.T) {
Messages: []MessageParam{ Messages: []MessageParam{
{ {
Role: "assistant", Role: "assistant",
Content: []any{ Content: []ContentBlock{
map[string]any{ {
"type": "tool_use", Type: "tool_use",
"name": "get_weather", Name: "get_weather",
}, },
}, },
}, },
@@ -460,10 +465,10 @@ func TestFromMessagesRequest_ToolUseMissingName(t *testing.T) {
Messages: []MessageParam{ Messages: []MessageParam{
{ {
Role: "assistant", Role: "assistant",
Content: []any{ Content: []ContentBlock{
map[string]any{ {
"type": "tool_use", Type: "tool_use",
"id": "call_123", ID: "call_123",
}, },
}, },
}, },
@@ -483,7 +488,7 @@ func TestFromMessagesRequest_InvalidToolSchema(t *testing.T) {
req := MessagesRequest{ req := MessagesRequest{
Model: "test-model", Model: "test-model",
MaxTokens: 1024, MaxTokens: 1024,
Messages: []MessageParam{{Role: "user", Content: "Hello"}}, Messages: []MessageParam{{Role: "user", Content: textContent("Hello")}},
Tools: []Tool{ Tools: []Tool{
{ {
Name: "bad_tool", Name: "bad_tool",
@@ -548,7 +553,7 @@ func TestToMessagesResponse_WithToolCalls(t *testing.T) {
ID: "call_123", ID: "call_123",
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: "get_weather", Name: "get_weather",
Arguments: testArgs(map[string]any{"location": "Paris"}), Arguments: makeArgs("location", "Paris"),
}, },
}, },
}, },
@@ -760,7 +765,7 @@ func TestStreamConverter_WithToolCalls(t *testing.T) {
ID: "call_123", ID: "call_123",
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: "get_weather", Name: "get_weather",
Arguments: testArgs(map[string]any{"location": "Paris"}), Arguments: makeArgs("location", "Paris"),
}, },
}, },
}, },
@@ -799,6 +804,107 @@ func TestStreamConverter_WithToolCalls(t *testing.T) {
} }
} }
// TestStreamConverter_ThinkingDirectlyFollowedByToolCall verifies that when a
// model emits a thinking block followed directly by a tool_use block (with no
// text block in between), the streaming converter correctly closes the thinking
// block and increments the content index before opening the tool_use block.
// Previously, the converter reused contentIndex=0 for the tool_use block,
// which caused "Content block not found" errors in clients. See #14816.
func TestStreamConverter_ThinkingDirectlyFollowedByToolCall(t *testing.T) {
conv := NewStreamConverter("msg_123", "test-model", 0)
// First chunk: thinking content (no text)
resp1 := api.ChatResponse{
Model: "test-model",
Message: api.Message{
Role: "assistant",
Thinking: "I should call the tool.",
},
}
events1 := conv.Process(resp1)
// Should have: message_start, content_block_start(thinking), content_block_delta(thinking)
if len(events1) < 3 {
t.Fatalf("expected at least 3 events for thinking chunk, got %d", len(events1))
}
if events1[0].Event != "message_start" {
t.Errorf("expected first event 'message_start', got %q", events1[0].Event)
}
thinkingStart, ok := events1[1].Data.(ContentBlockStartEvent)
if !ok || thinkingStart.ContentBlock.Type != "thinking" {
t.Errorf("expected content_block_start(thinking) as second event, got %+v", events1[1])
}
if thinkingStart.Index != 0 {
t.Errorf("expected thinking block at index 0, got %d", thinkingStart.Index)
}
// Second chunk: tool call (no text between thinking and tool)
resp2 := api.ChatResponse{
Model: "test-model",
Message: api.Message{
Role: "assistant",
ToolCalls: []api.ToolCall{
{
ID: "call_abc",
Function: api.ToolCallFunction{
Name: "ask_user",
Arguments: makeArgs("question", "cats or dogs?"),
},
},
},
},
Done: true,
DoneReason: "stop",
Metrics: api.Metrics{PromptEvalCount: 10, EvalCount: 5},
}
events2 := conv.Process(resp2)
// Expect: content_block_stop(index=0), content_block_start(tool_use, index=1),
// content_block_delta(input_json_delta, index=1), content_block_stop(index=1),
// message_delta, message_stop
var thinkingStop, toolStart, toolDelta, toolStop *StreamEvent
for i := range events2 {
e := &events2[i]
switch e.Event {
case "content_block_stop":
if stop, ok := e.Data.(ContentBlockStopEvent); ok {
if stop.Index == 0 && thinkingStop == nil {
thinkingStop = e
} else if stop.Index == 1 {
toolStop = e
}
}
case "content_block_start":
if start, ok := e.Data.(ContentBlockStartEvent); ok && start.ContentBlock.Type == "tool_use" {
toolStart = e
}
case "content_block_delta":
if delta, ok := e.Data.(ContentBlockDeltaEvent); ok && delta.Delta.Type == "input_json_delta" {
toolDelta = e
}
}
}
if thinkingStop == nil {
t.Error("expected content_block_stop for thinking block (index 0)")
}
if toolStart == nil {
t.Fatal("expected content_block_start for tool_use block")
}
if start, ok := toolStart.Data.(ContentBlockStartEvent); !ok || start.Index != 1 {
t.Errorf("expected tool_use block at index 1, got %+v", toolStart.Data)
}
if toolDelta == nil {
t.Fatal("expected input_json_delta event for tool call")
}
if delta, ok := toolDelta.Data.(ContentBlockDeltaEvent); !ok || delta.Index != 1 {
t.Errorf("expected tool delta at index 1, got %+v", toolDelta.Data)
}
if toolStop == nil {
t.Error("expected content_block_stop for tool_use block (index 1)")
}
}
func TestStreamConverter_ToolCallWithUnmarshalableArgs(t *testing.T) { func TestStreamConverter_ToolCallWithUnmarshalableArgs(t *testing.T) {
// Test that unmarshalable arguments (like channels) are handled gracefully // Test that unmarshalable arguments (like channels) are handled gracefully
// and don't cause a panic or corrupt stream // and don't cause a panic or corrupt stream
@@ -864,7 +970,7 @@ func TestStreamConverter_MultipleToolCallsWithMixedValidity(t *testing.T) {
ID: "call_good", ID: "call_good",
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: "good_function", Name: "good_function",
Arguments: testArgs(map[string]any{"location": "Paris"}), Arguments: makeArgs("location", "Paris"),
}, },
}, },
{ {
@@ -966,6 +1072,57 @@ func TestContentBlockJSON_EmptyFieldsPresent(t *testing.T) {
} }
} }
func TestContentBlockJSON_NonToolBlocksDoNotIncludeInput(t *testing.T) {
tests := []struct {
name string
block ContentBlock
}{
{
name: "text block",
block: ContentBlock{
Type: "text",
Text: ptr("hello"),
},
},
{
name: "thinking block",
block: ContentBlock{
Type: "thinking",
Thinking: ptr("let me think"),
},
},
{
name: "image block",
block: ContentBlock{
Type: "image",
Source: &ImageSource{
Type: "base64",
MediaType: "image/png",
Data: testImage,
},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
data, err := json.Marshal(tt.block)
if err != nil {
t.Fatalf("failed to marshal: %v", err)
}
var result map[string]any
if err := json.Unmarshal(data, &result); err != nil {
t.Fatalf("failed to unmarshal: %v", err)
}
if _, ok := result["input"]; ok {
t.Fatalf("unexpected input field in non-tool block JSON: %s", string(data))
}
})
}
}
func TestStreamConverter_ContentBlockStartIncludesEmptyFields(t *testing.T) { func TestStreamConverter_ContentBlockStartIncludesEmptyFields(t *testing.T) {
t.Run("text block start includes empty text", func(t *testing.T) { t.Run("text block start includes empty text", func(t *testing.T) {
conv := NewStreamConverter("msg_123", "test-model", 0) conv := NewStreamConverter("msg_123", "test-model", 0)
@@ -986,7 +1143,9 @@ func TestStreamConverter_ContentBlockStartIncludesEmptyFields(t *testing.T) {
// Marshal and verify the text field is present // Marshal and verify the text field is present
data, _ := json.Marshal(start) data, _ := json.Marshal(start)
var result map[string]any var result map[string]any
json.Unmarshal(data, &result) if err := json.Unmarshal(data, &result); err != nil {
t.Fatalf("failed to unmarshal content_block_start JSON: %v", err)
}
cb := result["content_block"].(map[string]any) cb := result["content_block"].(map[string]any)
if _, ok := cb["text"]; !ok { if _, ok := cb["text"]; !ok {
t.Error("content_block_start for text should include 'text' field") t.Error("content_block_start for text should include 'text' field")
@@ -1033,13 +1192,71 @@ func TestStreamConverter_ContentBlockStartIncludesEmptyFields(t *testing.T) {
t.Error("expected thinking content_block_start event") t.Error("expected thinking content_block_start event")
} }
}) })
t.Run("tool_use block start includes empty input object", func(t *testing.T) {
conv := NewStreamConverter("msg_123", "test-model", 0)
resp := api.ChatResponse{
Model: "test-model",
Message: api.Message{
Role: "assistant",
ToolCalls: []api.ToolCall{
{
ID: "call_123",
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: makeArgs("location", "Paris"),
},
},
},
},
}
events := conv.Process(resp)
var foundToolStart bool
for _, e := range events {
if e.Event == "content_block_start" {
if start, ok := e.Data.(ContentBlockStartEvent); ok {
if start.ContentBlock.Type == "tool_use" {
foundToolStart = true
if start.ContentBlock.Input.Len() != 0 {
t.Errorf("expected empty input object, got len=%d", start.ContentBlock.Input.Len())
}
data, _ := json.Marshal(start)
var result map[string]any
json.Unmarshal(data, &result)
cb := result["content_block"].(map[string]any)
input, ok := cb["input"]
if !ok {
t.Error("content_block_start for tool_use should include 'input' field")
continue
}
inputMap, ok := input.(map[string]any)
if !ok {
t.Errorf("input field should be an object, got %T", input)
continue
}
if len(inputMap) != 0 {
t.Errorf("expected empty input object in content_block_start, got %v", inputMap)
}
}
}
}
}
if !foundToolStart {
t.Error("expected tool_use content_block_start event")
}
})
} }
func TestEstimateTokens_SimpleMessage(t *testing.T) { func TestEstimateTokens_SimpleMessage(t *testing.T) {
req := CountTokensRequest{ req := CountTokensRequest{
Model: "test-model", Model: "test-model",
Messages: []MessageParam{ Messages: []MessageParam{
{Role: "user", Content: "Hello, world!"}, {Role: "user", Content: textContent("Hello, world!")},
}, },
} }
@@ -1060,7 +1277,7 @@ func TestEstimateTokens_WithSystemPrompt(t *testing.T) {
Model: "test-model", Model: "test-model",
System: "You are a helpful assistant.", System: "You are a helpful assistant.",
Messages: []MessageParam{ Messages: []MessageParam{
{Role: "user", Content: "Hello"}, {Role: "user", Content: textContent("Hello")},
}, },
} }
@@ -1076,7 +1293,7 @@ func TestEstimateTokens_WithTools(t *testing.T) {
req := CountTokensRequest{ req := CountTokensRequest{
Model: "test-model", Model: "test-model",
Messages: []MessageParam{ Messages: []MessageParam{
{Role: "user", Content: "What's the weather?"}, {Role: "user", Content: textContent("What's the weather?")},
}, },
Tools: []Tool{ Tools: []Tool{
{ {
@@ -1099,17 +1316,17 @@ func TestEstimateTokens_WithThinking(t *testing.T) {
req := CountTokensRequest{ req := CountTokensRequest{
Model: "test-model", Model: "test-model",
Messages: []MessageParam{ Messages: []MessageParam{
{Role: "user", Content: "Hello"}, {Role: "user", Content: textContent("Hello")},
{ {
Role: "assistant", Role: "assistant",
Content: []any{ Content: []ContentBlock{
map[string]any{ {
"type": "thinking", Type: "thinking",
"thinking": "Let me think about this carefully...", Thinking: ptr("Let me think about this carefully..."),
}, },
map[string]any{ {
"type": "text", Type: "text",
"text": "Here is my response.", Text: ptr("Here is my response."),
}, },
}, },
}, },
@@ -1207,12 +1424,12 @@ func TestConvertTool_RegularTool(t *testing.T) {
func TestConvertMessage_ServerToolUse(t *testing.T) { func TestConvertMessage_ServerToolUse(t *testing.T) {
msg := MessageParam{ msg := MessageParam{
Role: "assistant", Role: "assistant",
Content: []any{ Content: []ContentBlock{
map[string]any{ {
"type": "server_tool_use", Type: "server_tool_use",
"id": "srvtoolu_123", ID: "srvtoolu_123",
"name": "web_search", Name: "web_search",
"input": map[string]any{"query": "test query"}, Input: makeArgs("query", "test query"),
}, },
}, },
} }
@@ -1243,11 +1460,11 @@ func TestConvertMessage_ServerToolUse(t *testing.T) {
func TestConvertMessage_WebSearchToolResult(t *testing.T) { func TestConvertMessage_WebSearchToolResult(t *testing.T) {
msg := MessageParam{ msg := MessageParam{
Role: "user", Role: "user",
Content: []any{ Content: []ContentBlock{
map[string]any{ {
"type": "web_search_tool_result", Type: "web_search_tool_result",
"tool_use_id": "srvtoolu_123", ToolUseID: "srvtoolu_123",
"content": []any{ Content: []any{
map[string]any{ map[string]any{
"type": "web_search_result", "type": "web_search_result",
"title": "Test Result", "title": "Test Result",
@@ -1284,11 +1501,11 @@ func TestConvertMessage_WebSearchToolResult(t *testing.T) {
func TestConvertMessage_WebSearchToolResultEmptyStillCreatesToolMessage(t *testing.T) { func TestConvertMessage_WebSearchToolResultEmptyStillCreatesToolMessage(t *testing.T) {
msg := MessageParam{ msg := MessageParam{
Role: "user", Role: "user",
Content: []any{ Content: []ContentBlock{
map[string]any{ {
"type": "web_search_tool_result", Type: "web_search_tool_result",
"tool_use_id": "srvtoolu_empty", ToolUseID: "srvtoolu_empty",
"content": []any{}, Content: []any{},
}, },
}, },
} }
@@ -1315,11 +1532,11 @@ func TestConvertMessage_WebSearchToolResultEmptyStillCreatesToolMessage(t *testi
func TestConvertMessage_WebSearchToolResultErrorStillCreatesToolMessage(t *testing.T) { func TestConvertMessage_WebSearchToolResultErrorStillCreatesToolMessage(t *testing.T) {
msg := MessageParam{ msg := MessageParam{
Role: "user", Role: "user",
Content: []any{ Content: []ContentBlock{
map[string]any{ {
"type": "web_search_tool_result", Type: "web_search_tool_result",
"tool_use_id": "srvtoolu_error", ToolUseID: "srvtoolu_error",
"content": map[string]any{ Content: map[string]any{
"type": "web_search_tool_result_error", "type": "web_search_tool_result_error",
"error_code": "max_uses_exceeded", "error_code": "max_uses_exceeded",
}, },

View File

@@ -476,25 +476,3 @@ func (c *Client) Whoami(ctx context.Context) (*UserResponse, error) {
} }
return &resp, nil return &resp, nil
} }
// AliasRequest is the request body for creating or updating a model alias.
type AliasRequest struct {
Alias string `json:"alias"`
Target string `json:"target"`
PrefixMatching bool `json:"prefix_matching,omitempty"`
}
// SetAliasExperimental creates or updates a model alias via the experimental aliases API.
func (c *Client) SetAliasExperimental(ctx context.Context, req *AliasRequest) error {
return c.do(ctx, http.MethodPost, "/api/experimental/aliases", req, nil)
}
// AliasDeleteRequest is the request body for deleting a model alias.
type AliasDeleteRequest struct {
Alias string `json:"alias"`
}
// DeleteAliasExperimental deletes a model alias via the experimental aliases API.
func (c *Client) DeleteAliasExperimental(ctx context.Context, req *AliasDeleteRequest) error {
return c.do(ctx, http.MethodDelete, "/api/experimental/aliases", req, nil)
}

View File

@@ -436,6 +436,7 @@ type ToolProperty struct {
Description string `json:"description,omitempty"` Description string `json:"description,omitempty"`
Enum []any `json:"enum,omitempty"` Enum []any `json:"enum,omitempty"`
Properties *ToolPropertiesMap `json:"properties,omitempty"` Properties *ToolPropertiesMap `json:"properties,omitempty"`
Required []string `json:"required,omitempty"`
} }
// ToTypeScriptType converts a ToolProperty to a TypeScript type string // ToTypeScriptType converts a ToolProperty to a TypeScript type string

View File

@@ -14,7 +14,7 @@ import (
// currentSchemaVersion defines the current database schema version. // currentSchemaVersion defines the current database schema version.
// Increment this when making schema changes that require migrations. // Increment this when making schema changes that require migrations.
const currentSchemaVersion = 15 const currentSchemaVersion = 16
// database wraps the SQLite connection. // database wraps the SQLite connection.
// SQLite handles its own locking for concurrent access: // SQLite handles its own locking for concurrent access:
@@ -82,6 +82,7 @@ func (db *database) init() error {
websearch_enabled BOOLEAN NOT NULL DEFAULT 0, websearch_enabled BOOLEAN NOT NULL DEFAULT 0,
selected_model TEXT NOT NULL DEFAULT '', selected_model TEXT NOT NULL DEFAULT '',
sidebar_open BOOLEAN NOT NULL DEFAULT 0, sidebar_open BOOLEAN NOT NULL DEFAULT 0,
last_home_view TEXT NOT NULL DEFAULT 'launch',
think_enabled BOOLEAN NOT NULL DEFAULT 0, think_enabled BOOLEAN NOT NULL DEFAULT 0,
think_level TEXT NOT NULL DEFAULT '', think_level TEXT NOT NULL DEFAULT '',
cloud_setting_migrated BOOLEAN NOT NULL DEFAULT 0, cloud_setting_migrated BOOLEAN NOT NULL DEFAULT 0,
@@ -264,6 +265,12 @@ func (db *database) migrate() error {
return fmt.Errorf("migrate v14 to v15: %w", err) return fmt.Errorf("migrate v14 to v15: %w", err)
} }
version = 15 version = 15
case 15:
// add last_home_view column to settings table
if err := db.migrateV15ToV16(); err != nil {
return fmt.Errorf("migrate v15 to v16: %w", err)
}
version = 16
default: default:
// If we have a version we don't recognize, just set it to current // If we have a version we don't recognize, just set it to current
// This might happen during development // This might happen during development
@@ -518,6 +525,21 @@ func (db *database) migrateV14ToV15() error {
return nil return nil
} }
// migrateV15ToV16 adds the last_home_view column to the settings table
func (db *database) migrateV15ToV16() error {
_, err := db.conn.Exec(`ALTER TABLE settings ADD COLUMN last_home_view TEXT NOT NULL DEFAULT 'launch'`)
if err != nil && !duplicateColumnError(err) {
return fmt.Errorf("add last_home_view column: %w", err)
}
_, err = db.conn.Exec(`UPDATE settings SET schema_version = 16`)
if err != nil {
return fmt.Errorf("update schema version: %w", err)
}
return nil
}
// cleanupOrphanedData removes orphaned records that may exist due to the foreign key bug // cleanupOrphanedData removes orphaned records that may exist due to the foreign key bug
func (db *database) cleanupOrphanedData() error { func (db *database) cleanupOrphanedData() error {
_, err := db.conn.Exec(` _, err := db.conn.Exec(`
@@ -1166,9 +1188,9 @@ func (db *database) getSettings() (Settings, error) {
var s Settings var s Settings
err := db.conn.QueryRow(` err := db.conn.QueryRow(`
SELECT expose, survey, browser, models, agent, tools, working_dir, context_length, turbo_enabled, websearch_enabled, selected_model, sidebar_open, think_enabled, think_level, auto_update_enabled SELECT expose, survey, browser, models, agent, tools, working_dir, context_length, turbo_enabled, websearch_enabled, selected_model, sidebar_open, last_home_view, think_enabled, think_level, auto_update_enabled
FROM settings FROM settings
`).Scan(&s.Expose, &s.Survey, &s.Browser, &s.Models, &s.Agent, &s.Tools, &s.WorkingDir, &s.ContextLength, &s.TurboEnabled, &s.WebSearchEnabled, &s.SelectedModel, &s.SidebarOpen, &s.ThinkEnabled, &s.ThinkLevel, &s.AutoUpdateEnabled) `).Scan(&s.Expose, &s.Survey, &s.Browser, &s.Models, &s.Agent, &s.Tools, &s.WorkingDir, &s.ContextLength, &s.TurboEnabled, &s.WebSearchEnabled, &s.SelectedModel, &s.SidebarOpen, &s.LastHomeView, &s.ThinkEnabled, &s.ThinkLevel, &s.AutoUpdateEnabled)
if err != nil { if err != nil {
return Settings{}, fmt.Errorf("get settings: %w", err) return Settings{}, fmt.Errorf("get settings: %w", err)
} }
@@ -1177,10 +1199,26 @@ func (db *database) getSettings() (Settings, error) {
} }
func (db *database) setSettings(s Settings) error { func (db *database) setSettings(s Settings) error {
lastHomeView := strings.ToLower(strings.TrimSpace(s.LastHomeView))
validLaunchView := map[string]struct{}{
"launch": {},
"openclaw": {},
"claude": {},
"codex": {},
"opencode": {},
"droid": {},
"pi": {},
}
if lastHomeView != "chat" {
if _, ok := validLaunchView[lastHomeView]; !ok {
lastHomeView = "chat"
}
}
_, err := db.conn.Exec(` _, err := db.conn.Exec(`
UPDATE settings UPDATE settings
SET expose = ?, survey = ?, browser = ?, models = ?, agent = ?, tools = ?, working_dir = ?, context_length = ?, turbo_enabled = ?, websearch_enabled = ?, selected_model = ?, sidebar_open = ?, think_enabled = ?, think_level = ?, auto_update_enabled = ? SET expose = ?, survey = ?, browser = ?, models = ?, agent = ?, tools = ?, working_dir = ?, context_length = ?, turbo_enabled = ?, websearch_enabled = ?, selected_model = ?, sidebar_open = ?, last_home_view = ?, think_enabled = ?, think_level = ?, auto_update_enabled = ?
`, s.Expose, s.Survey, s.Browser, s.Models, s.Agent, s.Tools, s.WorkingDir, s.ContextLength, s.TurboEnabled, s.WebSearchEnabled, s.SelectedModel, s.SidebarOpen, s.ThinkEnabled, s.ThinkLevel, s.AutoUpdateEnabled) `, s.Expose, s.Survey, s.Browser, s.Models, s.Agent, s.Tools, s.WorkingDir, s.ContextLength, s.TurboEnabled, s.WebSearchEnabled, s.SelectedModel, s.SidebarOpen, lastHomeView, s.ThinkEnabled, s.ThinkLevel, s.AutoUpdateEnabled)
if err != nil { if err != nil {
return fmt.Errorf("set settings: %w", err) return fmt.Errorf("set settings: %w", err)
} }

View File

@@ -167,6 +167,9 @@ type Settings struct {
// SidebarOpen indicates if the chat sidebar is open // SidebarOpen indicates if the chat sidebar is open
SidebarOpen bool SidebarOpen bool
// LastHomeView stores the preferred home route target ("chat" or integration name)
LastHomeView string
// AutoUpdateEnabled indicates if automatic updates should be downloaded // AutoUpdateEnabled indicates if automatic updates should be downloaded
AutoUpdateEnabled bool AutoUpdateEnabled bool
} }
@@ -389,6 +392,10 @@ func (s *Store) Settings() (Settings, error) {
} }
} }
if settings.LastHomeView == "" {
settings.LastHomeView = "launch"
}
return settings, nil return settings, nil
} }

View File

@@ -414,6 +414,7 @@ export class Settings {
ThinkLevel: string; ThinkLevel: string;
SelectedModel: string; SelectedModel: string;
SidebarOpen: boolean; SidebarOpen: boolean;
LastHomeView: string;
AutoUpdateEnabled: boolean; AutoUpdateEnabled: boolean;
constructor(source: any = {}) { constructor(source: any = {}) {
@@ -432,6 +433,7 @@ export class Settings {
this.ThinkLevel = source["ThinkLevel"]; this.ThinkLevel = source["ThinkLevel"];
this.SelectedModel = source["SelectedModel"]; this.SelectedModel = source["SelectedModel"];
this.SidebarOpen = source["SidebarOpen"]; this.SidebarOpen = source["SidebarOpen"];
this.LastHomeView = source["LastHomeView"];
this.AutoUpdateEnabled = source["AutoUpdateEnabled"]; this.AutoUpdateEnabled = source["AutoUpdateEnabled"];
} }
} }
@@ -550,14 +552,12 @@ export class Error {
} }
} }
export class ModelUpstreamResponse { export class ModelUpstreamResponse {
digest?: string; stale: boolean;
pushTime: number;
error?: string; error?: string;
constructor(source: any = {}) { constructor(source: any = {}) {
if ('string' === typeof source) source = JSON.parse(source); if ('string' === typeof source) source = JSON.parse(source);
this.digest = source["digest"]; this.stale = source["stale"];
this.pushTime = source["pushTime"];
this.error = source["error"]; this.error = source["error"];
} }
} }

View File

@@ -0,0 +1,7 @@
<?xml version="1.0" encoding="UTF-8"?>
<!-- Generated by Pixelmator Pro 3.6.17 -->
<svg width="1200" height="1200" viewBox="0 0 1200 1200" xmlns="http://www.w3.org/2000/svg">
<g id="g314">
<path id="path147" fill="#d97757" stroke="none" d="M 233.959793 800.214905 L 468.644287 668.536987 L 472.590637 657.100647 L 468.644287 650.738403 L 457.208069 650.738403 L 417.986633 648.322144 L 283.892639 644.69812 L 167.597321 639.865845 L 54.926208 633.825623 L 26.577238 627.785339 L 3.3e-05 592.751709 L 2.73832 575.27533 L 26.577238 559.248352 L 60.724873 562.228149 L 136.187973 567.382629 L 249.422867 575.194763 L 331.570496 580.026978 L 453.261841 592.671082 L 472.590637 592.671082 L 475.328857 584.859009 L 468.724915 580.026978 L 463.570557 575.194763 L 346.389313 495.785217 L 219.543671 411.865906 L 153.100723 363.543762 L 117.181267 339.060425 L 99.060455 316.107361 L 91.248367 266.01355 L 123.865784 230.093994 L 167.677887 233.073853 L 178.872513 236.053772 L 223.248367 270.201477 L 318.040283 343.570496 L 441.825592 434.738342 L 459.946411 449.798706 L 467.194672 444.64447 L 468.080597 441.020203 L 459.946411 427.409485 L 392.617493 305.718323 L 320.778564 181.932983 L 288.80542 130.630859 L 280.348999 99.865845 C 277.369171 87.221436 275.194641 76.590698 275.194641 63.624268 L 312.322174 13.20813 L 332.8591 6.604126 L 382.389313 13.20813 L 403.248352 31.328979 L 434.013519 101.71814 L 483.865753 212.537048 L 561.181274 363.221497 L 583.812134 407.919434 L 595.892639 449.315491 L 600.40271 461.959839 L 608.214783 461.959839 L 608.214783 454.711609 L 614.577271 369.825623 L 626.335632 265.61084 L 637.771851 131.516846 L 641.718201 93.745117 L 660.402832 48.483276 L 697.530334 24.000122 L 726.52356 37.852417 L 750.362549 72 L 747.060486 94.067139 L 732.886047 186.201416 L 705.100708 330.52356 L 686.979919 427.167847 L 697.530334 427.167847 L 709.61084 415.087341 L 758.496704 350.174561 L 840.644348 247.490051 L 876.885925 206.738342 L 919.167847 161.71814 L 946.308838 140.29541 L 997.61084 140.29541 L 1035.38269 196.429626 L 1018.469849 254.416199 L 965.637634 321.422852 L 921.825562 378.201538 L 859.006714 462.765259 L 819.785278 530.41626 L 823.409424 535.812073 L 832.75177 534.92627 L 974.657776 504.724915 L 1051.328979 490.872559 L 1142.818848 475.167786 L 1184.214844 494.496582 L 1188.724854 514.147644 L 1172.456421 554.335693 L 1074.604126 578.496765 L 959.838989 601.449829 L 788.939636 641.879272 L 786.845764 643.409485 L 789.261841 646.389343 L 866.255127 653.637634 L 899.194702 655.409424 L 979.812134 655.409424 L 1129.932861 666.604187 L 1169.154419 692.537109 L 1192.671265 724.268677 L 1188.724854 748.429688 L 1128.322144 779.194641 L 1046.818848 759.865845 L 856.590759 714.604126 L 791.355774 698.335754 L 782.335693 698.335754 L 782.335693 703.731567 L 836.69812 756.885986 L 936.322205 846.845581 L 1061.073975 962.81897 L 1067.436279 991.490112 L 1051.409424 1014.120911 L 1034.496704 1011.704712 L 924.885986 929.234924 L 882.604126 892.107544 L 786.845764 811.48999 L 780.483276 811.48999 L 780.483276 819.946289 L 802.550415 852.241699 L 919.087341 1027.409424 L 925.127625 1081.127686 L 916.671204 1098.604126 L 886.469849 1109.154419 L 853.288696 1103.114136 L 785.073914 1007.355835 L 714.684631 899.516785 L 657.906067 802.872498 L 650.979858 806.81897 L 617.476624 1167.704834 L 601.771851 1186.147705 L 565.530212 1200 L 535.328857 1177.046997 L 519.302124 1139.919556 L 535.328857 1066.550537 L 554.657776 970.792053 L 570.362488 894.68457 L 584.536926 800.134277 L 592.993347 768.724976 L 592.429626 766.630859 L 585.503479 767.516968 L 514.22821 865.369263 L 405.825531 1011.865906 L 320.053711 1103.677979 L 299.516815 1111.812256 L 263.919525 1093.369263 L 267.221497 1060.429688 L 287.114136 1031.114136 L 405.825531 880.107361 L 477.422913 786.52356 L 523.651062 732.483276 L 523.328918 724.671265 L 520.590698 724.671265 L 205.288605 929.395935 L 149.154434 936.644409 L 124.993355 914.01355 L 127.973183 876.885986 L 139.409409 864.80542 L 234.201385 799.570435 L 233.879227 799.8927 Z"/>
</g>
</svg>

After

Width:  |  Height:  |  Size: 4.0 KiB

View File

@@ -0,0 +1 @@
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 320 320"><path fill="#fff" d="m297.06 130.97c7.26-21.79 4.76-45.66-6.85-65.48-17.46-30.4-52.56-46.04-86.84-38.68-15.25-17.18-37.16-26.95-60.13-26.81-35.04-.08-66.13 22.48-76.91 55.82-22.51 4.61-41.94 18.7-53.31 38.67-17.59 30.32-13.58 68.54 9.92 94.54-7.26 21.79-4.76 45.66 6.85 65.48 17.46 30.4 52.56 46.04 86.84 38.68 15.24 17.18 37.16 26.95 60.13 26.8 35.06.09 66.16-22.49 76.94-55.86 22.51-4.61 41.94-18.7 53.31-38.67 17.57-30.32 13.55-68.51-9.94-94.51zm-120.28 168.11c-14.03.02-27.62-4.89-38.39-13.88.49-.26 1.34-.73 1.89-1.07l63.72-36.8c3.26-1.85 5.26-5.32 5.24-9.07v-89.83l26.93 15.55c.29.14.48.42.52.74v74.39c-.04 33.08-26.83 59.9-59.91 59.97zm-128.84-55.03c-7.03-12.14-9.56-26.37-7.15-40.18.47.28 1.3.79 1.89 1.13l63.72 36.8c3.23 1.89 7.23 1.89 10.47 0l77.79-44.92v31.1c.02.32-.13.63-.38.83l-64.41 37.19c-28.69 16.52-65.33 6.7-81.92-21.95zm-16.77-139.09c7-12.16 18.05-21.46 31.21-26.29 0 .55-.03 1.52-.03 2.2v73.61c-.02 3.74 1.98 7.21 5.23 9.06l77.79 44.91-26.93 15.55c-.27.18-.61.21-.91.08l-64.42-37.22c-28.63-16.58-38.45-53.21-21.95-81.89zm221.26 51.49-77.79-44.92 26.93-15.54c.27-.18.61-.21.91-.08l64.42 37.19c28.68 16.57 38.51 53.26 21.94 81.94-7.01 12.14-18.05 21.44-31.2 26.28v-75.81c.03-3.74-1.96-7.2-5.2-9.06zm26.8-40.34c-.47-.29-1.3-.79-1.89-1.13l-63.72-36.8c-3.23-1.89-7.23-1.89-10.47 0l-77.79 44.92v-31.1c-.02-.32.13-.63.38-.83l64.41-37.16c28.69-16.55 65.37-6.7 81.91 22 6.99 12.12 9.52 26.31 7.15 40.1zm-168.51 55.43-26.94-15.55c-.29-.14-.48-.42-.52-.74v-74.39c.02-33.12 26.89-59.96 60.01-59.94 14.01 0 27.57 4.92 38.34 13.88-.49.26-1.33.73-1.89 1.07l-63.72 36.8c-3.26 1.85-5.26 5.31-5.24 9.06l-.04 89.79zm14.63-31.54 34.65-20.01 34.65 20v40.01l-34.65 20-34.65-20z"/></svg>

After

Width:  |  Height:  |  Size: 1.7 KiB

View File

@@ -0,0 +1 @@
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 320 320"><path d="m297.06 130.97c7.26-21.79 4.76-45.66-6.85-65.48-17.46-30.4-52.56-46.04-86.84-38.68-15.25-17.18-37.16-26.95-60.13-26.81-35.04-.08-66.13 22.48-76.91 55.82-22.51 4.61-41.94 18.7-53.31 38.67-17.59 30.32-13.58 68.54 9.92 94.54-7.26 21.79-4.76 45.66 6.85 65.48 17.46 30.4 52.56 46.04 86.84 38.68 15.24 17.18 37.16 26.95 60.13 26.8 35.06.09 66.16-22.49 76.94-55.86 22.51-4.61 41.94-18.7 53.31-38.67 17.57-30.32 13.55-68.51-9.94-94.51zm-120.28 168.11c-14.03.02-27.62-4.89-38.39-13.88.49-.26 1.34-.73 1.89-1.07l63.72-36.8c3.26-1.85 5.26-5.32 5.24-9.07v-89.83l26.93 15.55c.29.14.48.42.52.74v74.39c-.04 33.08-26.83 59.9-59.91 59.97zm-128.84-55.03c-7.03-12.14-9.56-26.37-7.15-40.18.47.28 1.3.79 1.89 1.13l63.72 36.8c3.23 1.89 7.23 1.89 10.47 0l77.79-44.92v31.1c.02.32-.13.63-.38.83l-64.41 37.19c-28.69 16.52-65.33 6.7-81.92-21.95zm-16.77-139.09c7-12.16 18.05-21.46 31.21-26.29 0 .55-.03 1.52-.03 2.2v73.61c-.02 3.74 1.98 7.21 5.23 9.06l77.79 44.91-26.93 15.55c-.27.18-.61.21-.91.08l-64.42-37.22c-28.63-16.58-38.45-53.21-21.95-81.89zm221.26 51.49-77.79-44.92 26.93-15.54c.27-.18.61-.21.91-.08l64.42 37.19c28.68 16.57 38.51 53.26 21.94 81.94-7.01 12.14-18.05 21.44-31.2 26.28v-75.81c.03-3.74-1.96-7.2-5.2-9.06zm26.8-40.34c-.47-.29-1.3-.79-1.89-1.13l-63.72-36.8c-3.23-1.89-7.23-1.89-10.47 0l-77.79 44.92v-31.1c-.02-.32.13-.63.38-.83l64.41-37.16c28.69-16.55 65.37-6.7 81.91 22 6.99 12.12 9.52 26.31 7.15 40.1zm-168.51 55.43-26.94-15.55c-.29-.14-.48-.42-.52-.74v-74.39c.02-33.12 26.89-59.96 60.01-59.94 14.01 0 27.57 4.92 38.34 13.88-.49.26-1.33.73-1.89 1.07l-63.72 36.8c-3.26 1.85-5.26 5.31-5.24 9.06l-.04 89.79zm14.63-31.54 34.65-20.01 34.65 20v40.01l-34.65 20-34.65-20z"/></svg>

After

Width:  |  Height:  |  Size: 1.7 KiB

File diff suppressed because one or more lines are too long

After

Width:  |  Height:  |  Size: 6.2 KiB

View File

@@ -0,0 +1,242 @@
<svg version="1.2" xmlns="http://www.w3.org/2000/svg" viewBox="0 0 500 500" width="500" height="500">
<style>
.s0 { fill: #f6f4f4 }
.s1 { fill: #0b0303 }
.s2 { fill: #ef0011 }
.s3 { fill: #f3e2e2 }
.s4 { fill: #f00212 }
.s5 { fill: #ba000d }
.s6 { fill: #faf1f1 }
.s7 { fill: #0b0100 }
.s8 { fill: #fbedee }
.s9 { fill: #faeaea }
.s10 { fill: #ab797d }
.s11 { fill: #f8eaea }
.s12 { fill: #902021 }
.s13 { fill: #f9eeee }
.s14 { fill: #f6ecec }
.s15 { fill: #080201 }
.s16 { fill: #150100 }
.s17 { fill: #f2e7e7 }
.s18 { fill: #fbe7e8 }
.s19 { fill: #060101 }
.s20 { fill: #f5e7e7 }
.s21 { fill: #fa999e }
.s22 { fill: #c46064 }
.s23 { fill: #180300 }
.s24 { fill: #f6dcdd }
.s25 { fill: #f2e6e6 }
.s26 { fill: #110200 }
.s27 { fill: #eb0011 }
.s28 { fill: #e20010 }
.s29 { fill: #ea0011 }
.s30 { fill: #760007 }
.s31 { fill: #f00514 }
.s32 { fill: #fcebeb }
.s33 { fill: #ecd6d6 }
.s34 { fill: #f5e3e3 }
.s35 { fill: #f5e4e4 }
.s36 { fill: #faf6f6 }
.s37 { fill: #e50010 }
.s38 { fill: #d5000f }
.s39 { fill: #f2e2e3 }
.s40 { fill: #ef1018 }
.s41 { fill: #f4e8e9 }
.s42 { fill: #ef0513 }
.s43 { fill: #f5e5e5 }
.s44 { fill: #f00413 }
.s45 { fill: #f4e9ea }
.s46 { fill: #ed0011 }
.s47 { fill: #e80011 }
.s48 { fill: #e60613 }
.s49 { fill: #f0d6d6 }
.s50 { fill: #fca9ac }
.s51 { fill: #9c000c }
.s52 { fill: #73393b }
</style>
<g>
<path fill-rule="evenodd" class="s0" d="m166.5 52.5q3.5 0 7 0 2.75 2.99 1.5 7-21.27 45.61-20.5 96 39.99 2.76 72 26.5 7.87 6.86 13.5 15.5 42.88-56.39 103.5-92.5 47.35-25.46 101-25 14.52 0.38 23.5 11.5 3.19 7.74 2 16-1.81 7.18-4.5 14-1 0-1 1-5.04 6.05-9 13-1 0-1 1 0 0.5 0 1-12.42 12.15-28.5 19-6.02 36.27-41.5 45-0.83 2.75 0 5 19.02-12.85 41.5-9 10.85-8.09 23.5-13 15.01-6.37 31-2.5 14.09 7.43 14 23.5-2.83 23.25-15.5 43-6.42 9.92-14 19-10.04 8.8-19.5 18-72.02 48.88-156.5 27-19.63 9.6-41.5 10.5-4.59 1.27-9 3 2 1 4 2 20.09-1.11 35 12 25.46 6.95 37.5 30.5 1.26 5.69-1 11-3.38 3.79-7.5 6.5 5.74 10.07 1.5 20.5-7.55 7.47-17.5 3.5-11.01-5.34-22.5-9.5-18.26 10-38.5 13-15.5 0-31 0-26.62-4.54-51-17-4.17 1.33-8 3.5-7.23 5.87-15 11-8.62 2.58-13.5-4.5-1.82 2.32-4.5 3.5-6.06 2.24-12 3.5-7.5 0-15 0-27.42-2.56-50-18.5-18-17.25-23-41.5 0-11.5 0-23 4.12-22.7 25-33 6.95-16.67 22-26.5-20.39-20.8-14.5-49.5 7.01-26.98 28.5-44.5 7.56-5.27 15-10.5-13.09-30.88-7.5-64 3.16-15.57 14.5-26.5 6.85-2.48 8 4.5-6.59 39.53 11 75.5 7.99-0.49 16-2 2.42-34.57 14.5-67.5 8.51-22.23 27.5-36z"/>
</g>
<g>
<path fill-rule="evenodd" class="s1" d="m113.5 401.5q0.48-5.1-1-10-0.91 0.19-1 1-2.46 1.74-5 3.5 5.65 9.54-5 13-32.21 5.55-61-10-32.89-23.11-29.5-63.5 2.96-22.67 23.5-32 7.99-19.75 27-29.5-27.65-23.7-15.5-58.5 7.33-16.82 20.5-29.5 10.79-8.14 22-15.5-16.49-37.08-5.5-76 3.19-6.13 7.5-11.5 1.48-0.89 2 1-5.69 41.09 12.5 78.5 1 1 2 2 9.97-3.24 20.5-4 2 0 4 0 0-7.5 0-15 0.99-42.22 24.5-77 6.12-7.12 14-12-4.65 13.43-10 27-11.93 37.6-9.5 77 49.38 0.7 83.5 36 2.75 4.5 5.5 9 38.99-52.24 93-88.5 45.84-29.03 100-32.5 15.69-1.56 29 6.5 5.68 7.29 3.5 16.5-10.38 33.62-43.5 45-4.39 37.33-41 45-0.79 8.63-6 15.5 1.91 1.83 4.5 2.5 22.27-17.25 50.5-14.5 12.93-9.41 28-15 36.22-8.28 31.5 28.5-15.19 51.69-62.5 77.5-65.92 35.87-138 15.5-19.67 10.42-42 10.5-8.39 2.88-17 5 3.58 6.08 10 9 20.92-1.14 36 13 22.67 5.23 34.5 25.5 3.33 7.13-3.5 11.5-3.88 1.8-8 3 7.36 8.45 6.5 19.5-4.43 5.66-11.5 3.5-12.84-5.67-26-10.5-39.4 21.02-83 10.5-18.85-5.78-36.5-14.5-13.65 4.14-23.5 14.5-9.51 3.74-11-6.5z"/>
</g>
<g>
<path fill-rule="evenodd" class="s2" d="m153.5 173.5q24.62 1.46 46 13.5 12.11 8.1 17.5 21.5 0.74 2.45 0.5 5 0.09 0.81 1 1 1.48-4.9 1-10 5.04 10.48 1.5 22-9.81 27.86-35.5 42.5-26.17 14.97-56 19.5-2.77-0.4-2 1 2.86 1.27 6 1 25.64 1.53 48.5-10 0.34 10.08 2 20 1.08 5.76 5 10 1 1.5 0 3-31.11 20.84-68.5 17.5-23.7-5.7-32.5-28.5-4.39-9.18-3.5-19 15.41 6.23 32 4.5-20.68-6.39-39-18-34.81-27.22-12.5-65.5 11.84-14.83 29-23 4.21 7.66 11.5 12.5 3 1 6 0-26.04-34.62-29-78-0.13-8.46 2-16.5 1 6.5 2 13 3.43 39.53 24.5 73 2.03 2.28 4.5 4 0.5-1.25 1-2.5-1.27-6.54-5-12 0.5-0.75 1-1.5 9.72-3.43 20-4 0.55 10.34 8 17.5 1.94 0.74 4 0.5-17.8-64.6 16.5-122 0.98-1.79 1.5 0-28.21 56.64-13.5 118 1.08 1.43 2.5 0.5 2.21-4.98 2-10.5z"/>
</g>
<g>
<path fill-rule="evenodd" class="s3" d="m454.5 97.5q-18.37-2.97-37-1.5-16.14 2.08-32 5.5 32.38-14.09 67-7.5 1.98 1.22 2 3.5z"/>
</g>
<g>
<path fill-rule="evenodd" class="s4" d="m454.5 97.5q-1.33 11.18-8.5 20-21.81 26.28-55.5 32-1.11-0.2-2 0.5 2.31 2.82 5.5 4.5 1 2 0 4-9.56 11.3-19.5 20 19.71-8.72 31-27 2.68-0.43 5 1-14.24 30.97-48 36.5-9.93 1.71-20 1.5-6.8-0.48-13 1 5.81 6.92 14 11-10.78 16.03-27 26.5 27.16-7.4 38-33.5 4.34 1.35 9 1-9.08 23.84-33 33.5-18.45 6.41-38 7 22.59 8.92 45-1 12.05-5.52 24-11 9.01-1.79 17 2.5 5.28-4.38 11-8 12.8-6.07 27-5 0 0.5 0 1-19.34 2.69-34 15.5 0.5 0.25 1 0.5 17.79-8.09 36-15 2.71-0.79 5-2 2.5-1 5-2 5.53-4.04 11-8 11.7-4.18 24-6.5 7.78-1.36 15 1.5-2.97 18.45-13.5 34-34.92 49.37-94.5 62.5-59.27 12.45-108-23-15.53-12.52-21.5-31.5-2.47-14.26 4-27-3.15 24.41 14 42-4.92-10.28-7-22-1.97-17.63 7-33 47.28-69.5 125.5-100 15.86-3.42 32-5.5 18.63-1.47 37 1.5z"/>
</g>
<g>
<path fill-rule="evenodd" class="s5" d="m86.5 112.5q-1-6.5-2-13 0.7-5.34 3.5-10-1.8 11.32-1.5 23z"/>
</g>
<g>
<path fill-rule="evenodd" class="s6" d="m433.5 97.5q2.22-0.39 4 1-10 13.75-27 14-0.24-2.06 0.5-4 10.3-7.78 22.5-11z"/>
</g>
<g>
<path fill-rule="evenodd" class="s7" d="m407.5 101.5q2.55-0.24 5 0.5-52.87 18.31-84.5 64.5-6.94 7.95-17 11-9.38-2.38-5-11 40.38-48.62 101.5-65z"/>
</g>
<g>
<path fill-rule="evenodd" class="s8" d="m402.5 112.5q3 0 6 0-2.56 8.8-12 7-0.22-1.58 0.5-3 2.72-2.22 5.5-4z"/>
</g>
<g>
</g>
<g>
</g>
<g>
</g>
<g>
</g>
<g>
</g>
<g>
</g>
<g>
</g>
<g>
<path fill-rule="evenodd" class="s9" d="m390.5 149.5q7.77 0.52 15 2-11.29 18.28-31 27 9.94-8.7 19.5-20 1-2 0-4-3.19-1.68-5.5-4.5 0.89-0.7 2-0.5z"/>
</g>
<g>
<path fill-rule="evenodd" class="s10" d="m131.5 145.5q0 7.5 0 15-2 0-4 0 1.06-1.36 3-1-0.48-7.29 1-14z"/>
</g>
<g>
<path fill-rule="evenodd" class="s11" d="m219.5 204.5q-1 4.5-2 9 0.24-2.55-0.5-5-5.39-13.4-17.5-21.5-21.38-12.04-46-13.5 0-2 0-4 36.7-0.86 61.5 26 3.06 4.11 4.5 9z"/>
</g>
<g>
<path fill-rule="evenodd" class="s12" d="m329.5 191.5q6.2-1.48 13-1-3.5 1-7 2-2.9-0.97-6-1z"/>
</g>
<g>
<path fill-rule="evenodd" class="s13" d="m329.5 191.5q3.1 0.03 6 1 9.55 1.31 19 3-10.84 26.1-38 33.5 16.22-10.47 27-26.5-8.19-4.08-14-11z"/>
</g>
<g>
<path fill-rule="evenodd" class="s14" d="m479.5 199.5q-7.22-2.86-15-1.5-12.3 2.32-24 6.5 15.6-13.11 36-11.5 3.63 2.26 3 6.5z"/>
</g>
<g>
<path fill-rule="evenodd" class="s15" d="m193.5 216.5q-12.01 1.52-22 8-2.83 1.29-5.5 3-4.79-4.57-6.5-11-5.04 2.2-9.5-1-3.47-6.4 3.5-3 4.4 0.05 8-2.5 9.22-9.73 21-16 6.3-3.24 12 1-2.9 1.22-6 1.5 2.61 5.74 4.5 12 0.75 3.97 0.5 8z"/>
</g>
<g>
<path fill-rule="evenodd" class="s16" d="m458.5 200.5q3.04-0.24 6 0.5-18.02 7.05-33 19-1 1-2 0 11.53-14.3 29-19.5z"/>
</g>
<g>
<path fill-rule="evenodd" class="s17" d="m178.5 202.5q6.85-0.63 4.5 6-7.6 5.09-6-4 1.08-0.82 1.5-2z"/>
</g>
<g>
<path fill-rule="evenodd" class="s18" d="m469.5 201.5q-2.26 13.65-14.5 22-0.47-2.11 1-4 7.08-8.82 13.5-18z"/>
</g>
<g>
<path fill-rule="evenodd" class="s19" d="m74.5 208.5q8.22-0.2 16 2.5 11.8 4.26 23.5 8.5 5.65-0.63 8-6 2.41 11.83-9.5 13 0.55 3.61 2 7-0.5 1-1 2-4.67-0.94-9.5-1-9.96 0.44-19.5 2.5-5.05-3.55-6.5-9.5-0.75-7.48-0.5-15-6.47 0.15-3-4z"/>
</g>
<g>
<path fill-rule="evenodd" class="s20" d="m429.5 212.5q-2.5 1-5 2-4 0-8 0-14.2-1.07-27 5 15.27-12.44 35-9.5 2.72 1.14 5 2.5z"/>
</g>
<g>
<path fill-rule="evenodd" class="s21" d="m219.5 204.5q0.48 5.1-1 10-0.91-0.19-1-1 1-4.5 2-9z"/>
</g>
<g>
<path fill-rule="evenodd" class="s22" d="m416.5 215.5q0-0.5 0-1 4 0 8 0-2.29 1.21-5 2-1.06-1.36-3-1z"/>
</g>
<g>
<path fill-rule="evenodd" class="s23" d="m416.5 215.5q1.94-0.36 3 1-18.21 6.91-36 15-0.5-0.25-1-0.5 14.66-12.81 34-15.5z"/>
</g>
<g>
<path fill-rule="evenodd" class="s24" d="m193.5 216.5q4.39 1.3 9 3-0.79 1.04-2 1.5-14.77-0.13-29 3.5 9.99-6.48 22-8z"/>
</g>
<g>
<path fill-rule="evenodd" class="s25" d="m98.5 219.5q6.09-0.98 6 5-3.04 0.24-6-0.5-1.84-2.24 0-4.5z"/>
</g>
<g>
<path fill-rule="evenodd" class="s26" d="m176.5 229.5q8.85-1.14 16 4-4.98 1.75-10 0-13.56 14.3-33 19.5-28.06 8.2-55 1 3.32-6.4 10-5.5-0.71 1.47-2 2.5 36.58 4.24 69-14 4.68-2.13 1-5 2.35-0.91 4-2.5z"/>
</g>
<g>
<path fill-rule="evenodd" class="s27" d="m231.5 238.5q1.31-0.2 2 1-3.13 28.62 15 51-16.25 6.75-27-7.5-1-1-2 0 14.73 29.34 46 18.5 1.79 0.52 0 1.5-37.63 16.82-50.5-22.5-5.1-26.48 16.5-42z"/>
</g>
<g>
<path fill-rule="evenodd" class="s28" d="m243.5 259.5q5.88 3.62 10.5 9 12.96 18.46 32.5 29.5-31.51-7.75-43-38.5z"/>
</g>
<g>
<path fill-rule="evenodd" class="s29" d="m203.5 266.5q1.31-0.2 2 1-2.48 22.08 12 39-6.99 1.35-14 0.5 4.59 4.08 10 7-8.71 0.28-14.5-6.5-16.98-22.76 4.5-41z"/>
</g>
<g>
<path fill-rule="evenodd" class="s27" d="m58.5 284.5q9.6-2.17 14.5 6 5.15 14.18-1 28-11.05-13.14-27.5-17.5 5.15-9.9 14-16.5z"/>
</g>
<g>
<path fill-rule="evenodd" class="s30" d="m129.5 288.5q2 1 4 2-3.14 0.27-6-1-0.77-1.4 2-1z"/>
</g>
<g>
<path fill-rule="evenodd" class="s31" d="m56.5 313.5q3.43 5.43 8 10-4.88 0.44-8 4-1.11-0.2-2 0.5 28.91 1.65 38 28.5 0.45 3.16-1 6-11.02-7.01-23-12.5-4.75-3.75-9.5-7.5 1.47 7.42 7 13 8.34 27.18 32 43 0.99 2.41-1.5 3.5-40.25 5.58-66.5-25.5-15.67-22.01-8-48 10.46-23.87 34.5-15z"/>
</g>
<g>
<path fill-rule="evenodd" class="s32" d="m45.5 317.5q4.03-0.25 8 0.5 2.46 4.16-2 6-6.04 2.01-9-3.5 1.26-1.85 3-3z"/>
</g>
<g>
<path fill-rule="evenodd" class="s33" d="m56.5 313.5q4.91 3.14 9.5 7 0.88 2.25-1.5 3-4.57-4.57-8-10z"/>
</g>
<g>
<path fill-rule="evenodd" class="s34" d="m198.5 319.5q-11.1 11.56-27 15.5-15.75 4.88-32 2.5 28.81-3.69 54-18.5 2.65-0.96 5 0.5z"/>
</g>
<g>
<path fill-rule="evenodd" class="s4" d="m198.5 319.5q1.44 0.68 2.5 2 2.41 8.23 6 16 1.2 2.64-0.5 5-30.65 21.41-68 18.5-25.16-6.17-32.5-30.5 6.96 4.99 15.5 6.5 8.99 0.75 18 0.5 16.25 2.38 32-2.5 15.9-3.94 27-15.5z"/>
</g>
<g>
<path fill-rule="evenodd" class="s35" d="m92.5 356.5q-9.09-26.85-38-28.5 0.89-0.7 2-0.5 25.47-4.89 35.5 19 0.75 4.98 0.5 10z"/>
</g>
<g>
<path fill-rule="evenodd" class="s36" d="m72.5 335.5q3.62-0.38 5 3-4.22 1.83-5-3z"/>
</g>
<g>
<path fill-rule="evenodd" class="s37" d="m223.5 336.5q5.59-0.48 11 1-4.04 4.16-8.5 8-5.99-3.8-2.5-9z"/>
</g>
<g>
<path fill-rule="evenodd" class="s38" d="m90.5 334.5q0.59-1.54 2-0.5 3.94 5.45 9 10 7 6 14 12-6.91-1.7-13-6-6.21-7.72-12-15.5z"/>
</g>
<g>
<path fill-rule="evenodd" class="s39" d="m261.5 346.5q-3.54-2.44-8-3.5-6.98-0.75-14-0.5 0.63-1.08 2-1.5 13.82-2.52 26 4-2.63 1.98-6 1.5z"/>
</g>
<g>
<path fill-rule="evenodd" class="s40" d="m239.5 342.5q7.02-0.25 14 0.5 4.46 1.06 8 3.5-5.2 2.35-10 5.5-3.88 4.65-9 7.5-9.89-3.09-9.5-13 2.36-3.63 6.5-4z"/>
</g>
<g>
<path fill-rule="evenodd" class="s41" d="m214.5 349.5q-21.43 15.48-48 16 22.82-5.9 43-18.5 3.64-1.12 5 2.5z"/>
</g>
<g>
<path fill-rule="evenodd" class="s42" d="m214.5 349.5q5.96 7.2 13.5 13 1 1 0 2-28.58 23.34-65.5 20.5-18.15-4.24-27.5-19.5 1.13 0.94 2.5 1.5 14.7 1.42 29-1.5 26.57-0.52 48-16z"/>
</g>
<g>
<path fill-rule="evenodd" class="s43" d="m302.5 373.5q-14.74-16.73-37-19-4.55 0.25-9 1 25.3-10.24 43.5 11 2.85 2.91 2.5 7z"/>
</g>
<g>
<path fill-rule="evenodd" class="s44" d="m302.5 373.5q0.21 2.44-2 3.5-28.69 7.6-50.5-12.5-0.06-6.71 6.5-9 4.45-0.75 9-1 22.26 2.27 37 19z"/>
</g>
<g>
<path fill-rule="evenodd" class="s45" d="m100.5 356.5q5.42 2.71 11 5.5-13.04 7.54-18.5 21.5-7.57-7.14-10.5-17 5.58 1.54 10 5.5 4.2 0.84 5.5-3.5 1.41-5.99 2.5-12z"/>
</g>
<g>
<path fill-rule="evenodd" class="s8" d="m83.5 394.5q-18.9-10.15-29.5-29-1.54-3.52-2-7 5.79 2.39 10 7 7.82 16.63 21.5 29z"/>
</g>
<g>
<path fill-rule="evenodd" class="s46" d="m232.5 365.5q17.6 6.19 10.5 23-10.6 10.42-25.5 11.5-25.94 3.21-49-9 36.75-1.65 64-25.5z"/>
</g>
<g>
<path fill-rule="evenodd" class="s47" d="m113.5 367.5q7.7-0.01 9.5 7-9.69 7.19-18.5 15.5-7.23 5.76-5.5-3.5 3.12-12.84 14.5-19z"/>
</g>
<g>
<path fill-rule="evenodd" class="s29" d="m126.5 380.5q7.88-0.4 12 6.5-8.5 7.25-17 14.5-5.62-12.55 5-21z"/>
</g>
<g>
<path fill-rule="evenodd" class="s48" d="m283.5 385.5q3.22 2.95 7 5.5 2.8 4.03 6 7.5 0.42 2.77-2 4-15.5-9.75-31-19.5-1.79-0.98 0-1.5 9.96 2.49 20 4z"/>
</g>
<g>
<path fill-rule="evenodd" class="s49" d="m283.5 385.5q8.71-1.27 11.5 7 1.22 2.9 1.5 6-3.2-3.47-6-7.5-3.78-2.55-7-5.5z"/>
</g>
<g>
<path fill-rule="evenodd" class="s50" d="m83.5 394.5q1.88-0.06 3 1.5-2.25 0.88-3-1.5z"/>
</g>
<g>
<path fill-rule="evenodd" class="s51" d="m258.5 392.5q3.51 0.41 0 2.5-2.33 1.93-5 2 2.61-2.28 5-4.5z"/>
</g>
<g>
<path fill-rule="evenodd" class="s52" d="m111.5 392.5q0.09-0.81 1-1 1.48 4.9 1 10-1-4.5-2-9z"/>
</g>
</svg>

After

Width:  |  Height:  |  Size: 13 KiB

View File

@@ -0,0 +1,7 @@
<svg xmlns="http://www.w3.org/2000/svg" version="1.1" xmlns:xlink="http://www.w3.org/1999/xlink" width="512" height="512"><svg width="512" height="512" viewBox="0 0 512 512" fill="none" xmlns="http://www.w3.org/2000/svg">
<rect width="512" height="512" fill="#131010"></rect>
<path d="M320 224V352H192V224H320Z" fill="#5A5858"></path>
<path fill-rule="evenodd" clip-rule="evenodd" d="M384 416H128V96H384V416ZM320 160H192V352H320V160Z" fill="white"></path>
</svg><style>@media (prefers-color-scheme: light) { :root { filter: none; } }
@media (prefers-color-scheme: dark) { :root { filter: none; } }
</style></svg>

After

Width:  |  Height:  |  Size: 612 B

View File

@@ -0,0 +1,9 @@
<?xml version="1.0" encoding="UTF-8"?>
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 800 800">
<rect width="800" height="800" rx="160" fill="#fff"/>
<path fill="#000" fill-rule="evenodd" d="
M165.29 165.29 H517.36 V400 H400 V517.36 H282.65 V634.72 H165.29 Z
M282.65 282.65 V400 H400 V282.65 Z
"/>
<path fill="#000" d="M517.36 400 H634.72 V634.72 H517.36 Z"/>
</svg>

After

Width:  |  Height:  |  Size: 389 B

View File

@@ -0,0 +1,9 @@
<?xml version="1.0" encoding="UTF-8"?>
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 800 800">
<rect width="800" height="800" rx="160" fill="#000"/>
<path fill="#fff" fill-rule="evenodd" d="
M165.29 165.29 H517.36 V400 H400 V517.36 H282.65 V634.72 H165.29 Z
M282.65 282.65 V400 H400 V282.65 Z
"/>
<path fill="#fff" d="M517.36 400 H634.72 V634.72 H517.36 Z"/>
</svg>

After

Width:  |  Height:  |  Size: 389 B

View File

@@ -161,7 +161,7 @@ export async function getModels(query?: string): Promise<Model[]> {
// Add query if it's in the registry and not already in the list // Add query if it's in the registry and not already in the list
if (!exactMatch) { if (!exactMatch) {
const result = await getModelUpstreamInfo(new Model({ model: query })); const result = await getModelUpstreamInfo(new Model({ model: query }));
const existsUpstream = !!result.digest && !result.error; const existsUpstream = result.exists;
if (existsUpstream) { if (existsUpstream) {
filteredModels.push(new Model({ model: query })); filteredModels.push(new Model({ model: query }));
} }
@@ -339,7 +339,7 @@ export async function deleteChat(chatId: string): Promise<void> {
// Get upstream information for model staleness checking // Get upstream information for model staleness checking
export async function getModelUpstreamInfo( export async function getModelUpstreamInfo(
model: Model, model: Model,
): Promise<{ digest?: string; pushTime: number; error?: string }> { ): Promise<{ stale: boolean; exists: boolean; error?: string }> {
try { try {
const response = await fetch(`${API_BASE}/api/v1/model/upstream`, { const response = await fetch(`${API_BASE}/api/v1/model/upstream`, {
method: "POST", method: "POST",
@@ -353,22 +353,22 @@ export async function getModelUpstreamInfo(
if (!response.ok) { if (!response.ok) {
console.warn( console.warn(
`Failed to check upstream digest for ${model.model}: ${response.status}`, `Failed to check upstream for ${model.model}: ${response.status}`,
); );
return { pushTime: 0 }; return { stale: false, exists: false };
} }
const data = await response.json(); const data = await response.json();
if (data.error) { if (data.error) {
console.warn(`Upstream digest check: ${data.error}`); console.warn(`Upstream check: ${data.error}`);
return { error: data.error, pushTime: 0 }; return { stale: false, exists: false, error: data.error };
} }
return { digest: data.digest, pushTime: data.pushTime || 0 }; return { stale: !!data.stale, exists: true };
} catch (error) { } catch (error) {
console.warn(`Error checking model staleness:`, error); console.warn(`Error checking model staleness:`, error);
return { pushTime: 0 }; return { stale: false, exists: false };
} }
} }

View File

@@ -6,7 +6,7 @@ import { getChat } from "@/api";
import { Link } from "@/components/ui/link"; import { Link } from "@/components/ui/link";
import { useState, useRef, useEffect, useCallback, useMemo } from "react"; import { useState, useRef, useEffect, useCallback, useMemo } from "react";
import { ChatsResponse } from "@/gotypes"; import { ChatsResponse } from "@/gotypes";
import { CogIcon } from "@heroicons/react/24/outline"; import { CogIcon, RocketLaunchIcon } from "@heroicons/react/24/outline";
// there's a hidden debug feature to copy a chat's data to the clipboard by // there's a hidden debug feature to copy a chat's data to the clipboard by
// holding shift and clicking this many times within this many seconds // holding shift and clicking this many times within this many seconds
@@ -267,9 +267,8 @@ export function ChatSidebar({ currentChatId }: ChatSidebarProps) {
<Link <Link
href="/c/new" href="/c/new"
mask={{ to: "/" }} mask={{ to: "/" }}
className={`flex w-full items-center gap-3 rounded-lg px-2 py-2 text-left text-sm text-neutral-700 hover:bg-neutral-100 dark:hover:bg-neutral-800 dark:text-neutral-100 ${ className={`flex w-full items-center gap-3 rounded-lg px-2 py-2 text-left text-sm text-neutral-700 hover:bg-neutral-100 dark:hover:bg-neutral-800 dark:text-neutral-100 ${currentChatId === "new" ? "bg-neutral-100 dark:bg-neutral-800" : ""
currentChatId === "new" ? "bg-neutral-100 dark:bg-neutral-800" : "" }`}
}`}
draggable={false} draggable={false}
> >
<svg <svg
@@ -283,6 +282,18 @@ export function ChatSidebar({ currentChatId }: ChatSidebarProps) {
</svg> </svg>
<span className="truncate">New Chat</span> <span className="truncate">New Chat</span>
</Link> </Link>
<Link
to="/c/$chatId"
params={{ chatId: "launch" }}
className={`flex w-full items-center gap-3 rounded-lg px-2 py-2 text-left text-sm text-neutral-700 hover:bg-neutral-100 dark:hover:bg-neutral-800 dark:text-neutral-100 cursor-pointer ${currentChatId === "launch"
? "bg-neutral-100 dark:bg-neutral-800"
: ""
}`}
draggable={false}
>
<RocketLaunchIcon className="h-5 w-5 stroke-current" />
<span className="truncate">Launch</span>
</Link>
{isWindows && ( {isWindows && (
<Link <Link
href="/settings" href="/settings"
@@ -304,19 +315,18 @@ export function ChatSidebar({ currentChatId }: ChatSidebarProps) {
{group.chats.map((chat) => ( {group.chats.map((chat) => (
<div <div
key={chat.id} key={chat.id}
className={`allow-context-menu flex items-center relative text-sm text-neutral-800 dark:text-neutral-400 rounded-lg hover:bg-neutral-100 dark:hover:bg-neutral-800 ${ className={`allow-context-menu flex items-center relative text-sm text-neutral-800 dark:text-neutral-400 rounded-lg hover:bg-neutral-100 dark:hover:bg-neutral-800 ${chat.id === currentChatId
chat.id === currentChatId ? "bg-neutral-100 text-black dark:bg-neutral-800"
? "bg-neutral-100 text-black dark:bg-neutral-800" : ""
: "" }`}
}`}
onMouseEnter={() => handleMouseEnter(chat.id)} onMouseEnter={() => handleMouseEnter(chat.id)}
onContextMenu={(e) => onContextMenu={(e) =>
handleContextMenu( handleContextMenu(
e, e,
chat.id, chat.id,
chat.title || chat.title ||
chat.userExcerpt || chat.userExcerpt ||
chat.createdAt.toLocaleString(), chat.createdAt.toLocaleString(),
) )
} }
> >

View File

@@ -10,6 +10,7 @@ interface CopyButtonProps {
showLabels?: boolean; showLabels?: boolean;
className?: string; className?: string;
title?: string; title?: string;
onCopy?: () => void;
} }
const CopyButton: React.FC<CopyButtonProps> = ({ const CopyButton: React.FC<CopyButtonProps> = ({
@@ -20,6 +21,7 @@ const CopyButton: React.FC<CopyButtonProps> = ({
showLabels = false, showLabels = false,
className = "", className = "",
title = "", title = "",
onCopy,
}) => { }) => {
const [isCopied, setIsCopied] = useState(false); const [isCopied, setIsCopied] = useState(false);
@@ -48,12 +50,14 @@ const CopyButton: React.FC<CopyButtonProps> = ({
} }
setIsCopied(true); setIsCopied(true);
onCopy?.();
setTimeout(() => setIsCopied(false), 2000); setTimeout(() => setIsCopied(false), 2000);
} catch (error) { } catch (error) {
console.error("Clipboard API failed, falling back to plain text", error); console.error("Clipboard API failed, falling back to plain text", error);
try { try {
await navigator.clipboard.writeText(content); await navigator.clipboard.writeText(content);
setIsCopied(true); setIsCopied(true);
onCopy?.();
setTimeout(() => setIsCopied(false), 2000); setTimeout(() => setIsCopied(false), 2000);
} catch (fallbackError) { } catch (fallbackError) {
console.error("Fallback copy also failed:", fallbackError); console.error("Fallback copy also failed:", fallbackError);

View File

@@ -0,0 +1,133 @@
import { useSettings } from "@/hooks/useSettings";
import CopyButton from "@/components/CopyButton";
interface LaunchCommand {
id: string;
name: string;
command: string;
description: string;
icon: string;
darkIcon?: string;
iconClassName?: string;
borderless?: boolean;
}
const LAUNCH_COMMANDS: LaunchCommand[] = [
{
id: "openclaw",
name: "OpenClaw",
command: "ollama launch openclaw",
description: "Personal AI with 100+ skills",
icon: "/launch-icons/openclaw.svg",
},
{
id: "claude",
name: "Claude",
command: "ollama launch claude",
description: "Anthropic's coding tool with subagents",
icon: "/launch-icons/claude.svg",
iconClassName: "h-7 w-7",
},
{
id: "codex",
name: "Codex",
command: "ollama launch codex",
description: "OpenAI's open-source coding agent",
icon: "/launch-icons/codex.svg",
darkIcon: "/launch-icons/codex-dark.svg",
iconClassName: "h-7 w-7",
},
{
id: "opencode",
name: "OpenCode",
command: "ollama launch opencode",
description: "Anomaly's open-source coding agent",
icon: "/launch-icons/opencode.svg",
iconClassName: "h-7 w-7 rounded",
},
{
id: "droid",
name: "Droid",
command: "ollama launch droid",
description: "Factory's coding agent across terminal and IDEs",
icon: "/launch-icons/droid.svg",
},
{
id: "pi",
name: "Pi",
command: "ollama launch pi",
description: "Minimal AI agent toolkit with plugin support",
icon: "/launch-icons/pi.svg",
darkIcon: "/launch-icons/pi-dark.svg",
iconClassName: "h-7 w-7",
},
];
export default function LaunchCommands() {
const isWindows = navigator.platform.toLowerCase().includes("win");
const { setSettings } = useSettings();
const renderCommandCard = (item: LaunchCommand) => (
<div key={item.command} className="w-full text-left">
<div className="flex items-start gap-4 sm:gap-5">
<div
aria-hidden="true"
className={`flex h-10 w-10 shrink-0 items-center justify-center rounded-lg overflow-hidden ${item.borderless ? "" : "border border-neutral-200 bg-white dark:border-neutral-700 dark:bg-neutral-900"}`}
>
{item.darkIcon ? (
<picture>
<source srcSet={item.darkIcon} media="(prefers-color-scheme: dark)" />
<img src={item.icon} alt="" className={`${item.iconClassName ?? "h-8 w-8"} rounded-sm`} />
</picture>
) : (
<img src={item.icon} alt="" className={item.borderless ? "h-full w-full rounded-xl" : `${item.iconClassName ?? "h-8 w-8"} rounded-sm`} />
)}
</div>
<div className="min-w-0 flex-1">
<span className="text-sm font-medium text-neutral-900 dark:text-neutral-100">
{item.name}
</span>
<p className="mt-0.5 text-xs text-neutral-500 dark:text-neutral-400">
{item.description}
</p>
<div className="mt-2 flex items-center gap-2 rounded-xl border-neutral-200 dark:border-neutral-700 bg-neutral-50 dark:bg-neutral-800 px-3 py-2">
<code className="min-w-0 flex-1 truncate text-xs text-neutral-600 dark:text-neutral-300">
{item.command}
</code>
<CopyButton
content={item.command}
size="md"
title="Copy command to clipboard"
className="text-neutral-500 dark:text-neutral-400 hover:text-neutral-700 dark:hover:text-neutral-200 hover:bg-neutral-200/60 dark:hover:bg-neutral-700/70"
onCopy={() => {
setSettings({ LastHomeView: item.id }).catch(() => { });
}}
/>
</div>
</div>
</div>
</div>
);
return (
<main className="flex h-screen w-full flex-col relative">
<section
className={`flex-1 overflow-y-auto overscroll-contain relative min-h-0 ${isWindows ? "xl:pt-4" : "xl:pt-8"}`}
>
<div className="max-w-[730px] mx-auto w-full px-4 pt-4 pb-20 sm:px-6 sm:pt-6 sm:pb-24 lg:px-8 lg:pt-8 lg:pb-28">
<h1 className="text-xl font-semibold text-neutral-900 dark:text-neutral-100">
Launch
</h1>
<p className="mt-1 text-sm text-neutral-500 dark:text-neutral-400">
Copy a command and run it in your terminal.
</p>
<div className="mt-6 grid gap-7">
{LAUNCH_COMMANDS.map(renderCommandCard)}
</div>
</div>
</section>
</main>
);
}

View File

@@ -61,24 +61,7 @@ export const ModelPicker = forwardRef<
try { try {
const upstreamInfo = await getModelUpstreamInfo(model); const upstreamInfo = await getModelUpstreamInfo(model);
// Compare local digest with upstream digest if (upstreamInfo.stale) {
let isStale =
model.digest &&
upstreamInfo.digest &&
model.digest !== upstreamInfo.digest;
// If the model has a modified time and upstream has a push time,
// check if the model was modified after the push time - if so, it's not stale
if (isStale && model.modified_at && upstreamInfo.pushTime > 0) {
const modifiedAtTime =
new Date(model.modified_at as string | number | Date).getTime() /
1000;
if (modifiedAtTime > upstreamInfo.pushTime) {
isStale = false;
}
}
if (isStale) {
const currentStaleModels = const currentStaleModels =
queryClient.getQueryData<Map<string, boolean>>(["staleModels"]) || queryClient.getQueryData<Map<string, boolean>>(["staleModels"]) ||
new Map(); new Map();

View File

@@ -214,6 +214,7 @@ export default function Settings() {
Agent: false, Agent: false,
Tools: false, Tools: false,
ContextLength: 0, ContextLength: 0,
AutoUpdateEnabled: true,
}); });
updateSettingsMutation.mutate(defaultSettings); updateSettingsMutation.mutate(defaultSettings);
} }
@@ -272,6 +273,10 @@ export default function Settings() {
} }
const isWindows = navigator.platform.toLowerCase().includes("win"); const isWindows = navigator.platform.toLowerCase().includes("win");
const handleCloseSettings = () => {
const chatId = settings.LastHomeView === "chat" ? "new" : "launch";
navigate({ to: "/c/$chatId", params: { chatId } });
};
return ( return (
<main className="flex h-screen w-full flex-col select-none dark:bg-neutral-900"> <main className="flex h-screen w-full flex-col select-none dark:bg-neutral-900">
@@ -285,7 +290,7 @@ export default function Settings() {
> >
{isWindows && ( {isWindows && (
<button <button
onClick={() => navigate({ to: "/" })} onClick={handleCloseSettings}
className="hover:bg-neutral-100 mr-3 dark:hover:bg-neutral-800 rounded-full p-1.5" className="hover:bg-neutral-100 mr-3 dark:hover:bg-neutral-800 rounded-full p-1.5"
> >
<ArrowLeftIcon className="w-5 h-5 dark:text-white" /> <ArrowLeftIcon className="w-5 h-5 dark:text-white" />
@@ -295,7 +300,7 @@ export default function Settings() {
</h1> </h1>
{!isWindows && ( {!isWindows && (
<button <button
onClick={() => navigate({ to: "/" })} onClick={handleCloseSettings}
className="p-1 hover:bg-neutral-100 mr-3 dark:hover:bg-neutral-800 rounded-full" className="p-1 hover:bg-neutral-100 mr-3 dark:hover:bg-neutral-800 rounded-full"
> >
<XMarkIcon className="w-6 h-6 dark:text-white" /> <XMarkIcon className="w-6 h-6 dark:text-white" />

View File

@@ -9,6 +9,7 @@ interface SettingsState {
webSearchEnabled: boolean; webSearchEnabled: boolean;
selectedModel: string; selectedModel: string;
sidebarOpen: boolean; sidebarOpen: boolean;
lastHomeView: string;
thinkEnabled: boolean; thinkEnabled: boolean;
thinkLevel: string; thinkLevel: string;
} }
@@ -21,6 +22,7 @@ type SettingsUpdate = Partial<{
ThinkLevel: string; ThinkLevel: string;
SelectedModel: string; SelectedModel: string;
SidebarOpen: boolean; SidebarOpen: boolean;
LastHomeView: string;
}>; }>;
export function useSettings() { export function useSettings() {
@@ -50,6 +52,7 @@ export function useSettings() {
thinkLevel: settingsData?.settings?.ThinkLevel ?? "none", thinkLevel: settingsData?.settings?.ThinkLevel ?? "none",
selectedModel: settingsData?.settings?.SelectedModel ?? "", selectedModel: settingsData?.settings?.SelectedModel ?? "",
sidebarOpen: settingsData?.settings?.SidebarOpen ?? false, sidebarOpen: settingsData?.settings?.SidebarOpen ?? false,
lastHomeView: settingsData?.settings?.LastHomeView ?? "launch",
}), }),
[settingsData?.settings], [settingsData?.settings],
); );

View File

@@ -4,12 +4,15 @@ import Chat from "@/components/Chat";
import { getChat } from "@/api"; import { getChat } from "@/api";
import { SidebarLayout } from "@/components/layout/layout"; import { SidebarLayout } from "@/components/layout/layout";
import { ChatSidebar } from "@/components/ChatSidebar"; import { ChatSidebar } from "@/components/ChatSidebar";
import LaunchCommands from "@/components/LaunchCommands";
import { useEffect } from "react";
import { useSettings } from "@/hooks/useSettings";
export const Route = createFileRoute("/c/$chatId")({ export const Route = createFileRoute("/c/$chatId")({
component: RouteComponent, component: RouteComponent,
loader: async ({ context, params }) => { loader: async ({ context, params }) => {
// Skip loading for "new" chat // Skip loading for special non-chat views
if (params.chatId !== "new") { if (params.chatId !== "new" && params.chatId !== "launch") {
context.queryClient.ensureQueryData({ context.queryClient.ensureQueryData({
queryKey: ["chat", params.chatId], queryKey: ["chat", params.chatId],
queryFn: () => getChat(params.chatId), queryFn: () => getChat(params.chatId),
@@ -21,13 +24,42 @@ export const Route = createFileRoute("/c/$chatId")({
function RouteComponent() { function RouteComponent() {
const { chatId } = Route.useParams(); const { chatId } = Route.useParams();
const { settingsData, setSettings } = useSettings();
// Always call hooks at the top level - use a flag to skip data when chatId is "new" // Always call hooks at the top level - use a flag to skip data when chatId is a special view
const { const {
data: chatData, data: chatData,
isLoading: chatLoading, isLoading: chatLoading,
error: chatError, error: chatError,
} = useChat(chatId === "new" ? "" : chatId); } = useChat(chatId === "new" || chatId === "launch" ? "" : chatId);
useEffect(() => {
if (!settingsData) {
return;
}
if (chatId === "launch") {
if (
settingsData.LastHomeView !== "chat" &&
settingsData.LastHomeView !== "launch"
) {
return;
}
setSettings({ LastHomeView: "openclaw" }).catch(() => {
// Best effort persistence for home view preference.
});
return;
}
if (settingsData.LastHomeView === "chat") {
return;
}
setSettings({ LastHomeView: "chat" }).catch(() => {
// Best effort persistence for home view preference.
});
}, [chatId, settingsData, setSettings]);
// Handle "new" chat case - just use Chat component which handles everything // Handle "new" chat case - just use Chat component which handles everything
if (chatId === "new") { if (chatId === "new") {
@@ -38,6 +70,14 @@ function RouteComponent() {
); );
} }
if (chatId === "launch") {
return (
<SidebarLayout sidebar={<ChatSidebar currentChatId={chatId} />}>
<LaunchCommands />
</SidebarLayout>
);
}
// Handle existing chat case // Handle existing chat case
if (chatLoading) { if (chatLoading) {
return ( return (

View File

@@ -1,10 +1,17 @@
import { createFileRoute, redirect } from "@tanstack/react-router"; import { createFileRoute, redirect } from "@tanstack/react-router";
import { getSettings } from "@/api";
export const Route = createFileRoute("/")({ export const Route = createFileRoute("/")({
beforeLoad: () => { beforeLoad: async ({ context }) => {
const settingsData = await context.queryClient.ensureQueryData({
queryKey: ["settings"],
queryFn: getSettings,
});
const chatId = settingsData?.settings?.LastHomeView === "chat" ? "new" : "launch";
throw redirect({ throw redirect({
to: "/c/$chatId", to: "/c/$chatId",
params: { chatId: "new" }, params: { chatId },
mask: { mask: {
to: "/", to: "/",
}, },

View File

@@ -0,0 +1,57 @@
import { describe, expect, it, vi, beforeEach } from "vitest";
import { copyTextToClipboard } from "./clipboard";
describe("copyTextToClipboard", () => {
beforeEach(() => {
vi.restoreAllMocks();
});
it("copies via Clipboard API when available", async () => {
const writeText = vi.fn().mockResolvedValue(undefined);
vi.stubGlobal("navigator", {
clipboard: {
writeText,
},
});
const copied = await copyTextToClipboard("ollama launch claude");
expect(copied).toBe(true);
expect(writeText).toHaveBeenCalledWith("ollama launch claude");
});
it("falls back to execCommand when Clipboard API fails", async () => {
const writeText = vi.fn().mockRejectedValue(new Error("not allowed"));
vi.stubGlobal("navigator", {
clipboard: {
writeText,
},
});
const textarea = {
value: "",
setAttribute: vi.fn(),
style: {} as Record<string, string>,
focus: vi.fn(),
select: vi.fn(),
};
const appendChild = vi.fn();
const removeChild = vi.fn();
const execCommand = vi.fn().mockReturnValue(true);
vi.stubGlobal("document", {
createElement: vi.fn().mockReturnValue(textarea),
body: {
appendChild,
removeChild,
},
execCommand,
});
const copied = await copyTextToClipboard("ollama launch openclaw");
expect(copied).toBe(true);
expect(execCommand).toHaveBeenCalledWith("copy");
expect(appendChild).toHaveBeenCalled();
expect(removeChild).toHaveBeenCalled();
});
});

View File

@@ -0,0 +1,30 @@
export async function copyTextToClipboard(text: string): Promise<boolean> {
try {
await navigator.clipboard.writeText(text);
return true;
} catch (clipboardError) {
console.error(
"Clipboard API failed, falling back to execCommand",
clipboardError,
);
}
try {
const textarea = document.createElement("textarea");
textarea.value = text;
textarea.setAttribute("readonly", "true");
textarea.style.position = "fixed";
textarea.style.left = "-9999px";
textarea.style.opacity = "0";
document.body.appendChild(textarea);
textarea.focus();
textarea.select();
const copied = document.execCommand("copy");
document.body.removeChild(textarea);
return copied;
} catch (fallbackError) {
console.error("Fallback copy failed", fallbackError);
return false;
}
}

View File

@@ -133,9 +133,8 @@ type Error struct {
} }
type ModelUpstreamResponse struct { type ModelUpstreamResponse struct {
Digest string `json:"digest,omitempty"` Stale bool `json:"stale"`
PushTime int64 `json:"pushTime"` Error string `json:"error,omitempty"`
Error string `json:"error,omitempty"`
} }
// Serializable data for the browser state // Serializable data for the browser state

View File

@@ -32,6 +32,7 @@ import (
"github.com/ollama/ollama/app/version" "github.com/ollama/ollama/app/version"
ollamaAuth "github.com/ollama/ollama/auth" ollamaAuth "github.com/ollama/ollama/auth"
"github.com/ollama/ollama/envconfig" "github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/manifest"
"github.com/ollama/ollama/types/model" "github.com/ollama/ollama/types/model"
_ "github.com/tkrajina/typescriptify-golang-structs/typescriptify" _ "github.com/tkrajina/typescriptify-golang-structs/typescriptify"
) )
@@ -155,7 +156,7 @@ func (s *Server) ollamaProxy() http.Handler {
return return
} }
target := envconfig.Host() target := envconfig.ConnectableHost()
s.log().Info("configuring ollama proxy", "target", target.String()) s.log().Info("configuring ollama proxy", "target", target.String())
newProxy := httputil.NewSingleHostReverseProxy(target) newProxy := httputil.NewSingleHostReverseProxy(target)
@@ -193,7 +194,7 @@ func (s *Server) Handler() http.Handler {
if CORS() { if CORS() {
w.Header().Set("Access-Control-Allow-Origin", "*") w.Header().Set("Access-Control-Allow-Origin", "*")
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS") w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization, X-Requested-With") w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization, User-Agent, Accept, X-Requested-With")
w.Header().Set("Access-Control-Allow-Credentials", "true") w.Header().Set("Access-Control-Allow-Credentials", "true")
// Handle preflight requests // Handle preflight requests
@@ -318,7 +319,7 @@ func (s *Server) handleError(w http.ResponseWriter, e error) {
if CORS() { if CORS() {
w.Header().Set("Access-Control-Allow-Origin", "*") w.Header().Set("Access-Control-Allow-Origin", "*")
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS") w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization, X-Requested-With") w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization, User-Agent, Accept, X-Requested-With")
w.Header().Set("Access-Control-Allow-Credentials", "true") w.Header().Set("Access-Control-Allow-Credentials", "true")
} }
@@ -341,8 +342,18 @@ func (t *userAgentTransport) RoundTrip(req *http.Request) (*http.Response, error
// httpClient returns an HTTP client that automatically adds the User-Agent header // httpClient returns an HTTP client that automatically adds the User-Agent header
func (s *Server) httpClient() *http.Client { func (s *Server) httpClient() *http.Client {
return userAgentHTTPClient(10 * time.Second)
}
// inferenceClient uses almost the same HTTP client, but without a timeout so
// long requests aren't truncated
func (s *Server) inferenceClient() *api.Client {
return api.NewClient(envconfig.Host(), userAgentHTTPClient(0))
}
func userAgentHTTPClient(timeout time.Duration) *http.Client {
return &http.Client{ return &http.Client{
Timeout: 10 * time.Second, Timeout: timeout,
Transport: &userAgentTransport{ Transport: &userAgentTransport{
base: http.DefaultTransport, base: http.DefaultTransport,
}, },
@@ -720,11 +731,7 @@ func (s *Server) chat(w http.ResponseWriter, r *http.Request) error {
_, cancelLoading := context.WithCancel(ctx) _, cancelLoading := context.WithCancel(ctx)
loading := false loading := false
c, err := api.ClientFromEnvironment() c := s.inferenceClient()
if err != nil {
cancelLoading()
return err
}
// Check if the model exists locally by trying to show it // Check if the model exists locally by trying to show it
// TODO (jmorganca): skip this round trip and instead just act // TODO (jmorganca): skip this round trip and instead just act
@@ -1572,9 +1579,18 @@ func (s *Server) modelUpstream(w http.ResponseWriter, r *http.Request) error {
return json.NewEncoder(w).Encode(response) return json.NewEncoder(w).Encode(response)
} }
n := model.ParseName(req.Model)
stale := true
if m, err := manifest.ParseNamedManifest(n); err == nil {
if m.Digest() == digest {
stale = false
} else if pushTime > 0 && m.FileInfo().ModTime().Unix() >= pushTime {
stale = false
}
}
response := responses.ModelUpstreamResponse{ response := responses.ModelUpstreamResponse{
Digest: digest, Stale: stale,
PushTime: pushTime,
} }
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
@@ -1672,7 +1688,6 @@ func supportsBrowserTools(model string) bool {
return strings.HasPrefix(strings.ToLower(model), "gpt-oss") return strings.HasPrefix(strings.ToLower(model), "gpt-oss")
} }
// buildChatRequest converts store.Chat to api.ChatRequest // buildChatRequest converts store.Chat to api.ChatRequest
func (s *Server) buildChatRequest(chat *store.Chat, model string, think any, availableTools []map[string]any) (*api.ChatRequest, error) { func (s *Server) buildChatRequest(chat *store.Chat, model string, think any, availableTools []map[string]any) (*api.ChatRequest, error) {
var msgs []api.Message var msgs []api.Message

View File

@@ -15,6 +15,7 @@ import (
"sync/atomic" "sync/atomic"
"testing" "testing"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/app/store" "github.com/ollama/ollama/app/store"
"github.com/ollama/ollama/app/updater" "github.com/ollama/ollama/app/updater"
) )
@@ -526,6 +527,33 @@ func TestUserAgentTransport(t *testing.T) {
t.Logf("User-Agent transport successfully set: %s", receivedUA) t.Logf("User-Agent transport successfully set: %s", receivedUA)
} }
func TestInferenceClientUsesUserAgent(t *testing.T) {
var gotUserAgent atomic.Value
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
gotUserAgent.Store(r.Header.Get("User-Agent"))
w.Header().Set("Content-Type", "application/json")
w.Write([]byte(`{}`))
}))
defer ts.Close()
t.Setenv("OLLAMA_HOST", ts.URL)
server := &Server{}
client := server.inferenceClient()
_, err := client.Show(context.Background(), &api.ShowRequest{Model: "test"})
if err != nil {
t.Fatalf("show request failed: %v", err)
}
receivedUA, _ := gotUserAgent.Load().(string)
expectedUA := userAgent()
if receivedUA != expectedUA {
t.Errorf("User-Agent mismatch\nExpected: %s\nReceived: %s", expectedUA, receivedUA)
}
}
func TestSupportsBrowserTools(t *testing.T) { func TestSupportsBrowserTools(t *testing.T) {
tests := []struct { tests := []struct {
model string model string

View File

@@ -1,27 +1,31 @@
Ollama Benchmark Tool Ollama Benchmark Tool
--------------------- ---------------------
A Go-based command-line tool for benchmarking Ollama models with configurable parameters and multiple output formats. A Go-based command-line tool for benchmarking Ollama models with configurable parameters, warmup phases, TTFT tracking, VRAM monitoring, and benchstat/CSV output.
## Features ## Features
* Benchmark multiple models in a single run * Benchmark multiple models in a single run
* Support for both text and image prompts * Support for both text and image prompts
* Configurable generation parameters (temperature, max tokens, seed, etc.) * Configurable generation parameters (temperature, max tokens, seed, etc.)
* Supports benchstat and CSV output formats * Warmup phase before timed epochs to stabilize measurements
* Detailed performance metrics (prefill, generate, load, total durations) * Time-to-first-token (TTFT) tracking per epoch
* Model metadata display (parameter size, quantization level, family)
* VRAM and CPU memory usage tracking via running process info
* Controlled prompt token length for reproducible benchmarks
* Benchstat and CSV output formats
## Building from Source ## Building from Source
``` ```
go build -o ollama-bench bench.go go build -o ollama-bench ./cmd/bench
./ollama-bench -model gpt-oss:20b -epochs 6 -format csv ./ollama-bench -model gemma3 -epochs 6 -format csv
``` ```
Using Go Run (without building) Using Go Run (without building)
``` ```
go run bench.go -model gpt-oss:20b -epochs 3 go run ./cmd/bench -model gemma3 -epochs 3
``` ```
## Usage ## Usage
@@ -45,10 +49,16 @@ benchstat -col /name gemma.bench
./ollama-bench -model qwen3-vl -image photo.jpg -epochs 6 -max-tokens 100 -p "Describe this image" ./ollama-bench -model qwen3-vl -image photo.jpg -epochs 6 -max-tokens 100 -p "Describe this image"
``` ```
### Controlled Prompt Length
```
./ollama-bench -model gemma3 -epochs 6 -prompt-tokens 512
```
### Advanced Example ### Advanced Example
``` ```
./ollama-bench -model llama3 -epochs 10 -temperature 0.7 -max-tokens 500 -seed 42 -format csv -output results.csv ./ollama-bench -model llama3 -epochs 10 -temperature 0.7 -max-tokens 500 -seed 42 -warmup 2 -format csv -output results.csv
``` ```
## Command Line Options ## Command Line Options
@@ -56,41 +66,48 @@ benchstat -col /name gemma.bench
| Option | Description | Default | | Option | Description | Default |
|----------|-------------|---------| |----------|-------------|---------|
| -model | Comma-separated list of models to benchmark | (required) | | -model | Comma-separated list of models to benchmark | (required) |
| -epochs | Number of iterations per model | 1 | | -epochs | Number of iterations per model | 6 |
| -max-tokens | Maximum tokens for model response | 0 (unlimited) | | -max-tokens | Maximum tokens for model response | 200 |
| -temperature | Temperature parameter | 0.0 | | -temperature | Temperature parameter | 0.0 |
| -seed | Random seed | 0 (random) | | -seed | Random seed | 0 (random) |
| -timeout | Timeout in seconds | 300 | | -timeout | Timeout in seconds | 300 |
| -p | Prompt text | "Write a long story." | | -p | Prompt text | (default story prompt) |
| -image | Image file to include in prompt | | | -image | Image file to include in prompt | |
| -k | Keep-alive duration in seconds | 0 | | -k | Keep-alive duration in seconds | 0 |
| -format | Output format (benchstat, csv) | benchstat | | -format | Output format (benchstat, csv) | benchstat |
| -output | Output file for results | "" (stdout) | | -output | Output file for results | "" (stdout) |
| -warmup | Number of warmup requests before timing | 1 |
| -prompt-tokens | Generate prompt targeting ~N tokens (0 = use -p) | 0 |
| -v | Verbose mode | false | | -v | Verbose mode | false |
| -debug | Show debug information | false | | -debug | Show debug information | false |
## Output Formats ## Output Formats
### Markdown Format ### Benchstat Format (default)
The default markdown format is suitable for copying and pasting into a GitHub issue and will look like: Compatible with Go's benchstat tool for statistical analysis. Uses one value/unit pair per line, standard `ns/op` for timing metrics, and `ns/token` for throughput. Each epoch produces one set of lines -- benchstat aggregates across repeated runs to compute statistics.
```
Model | Step | Count | Duration | nsPerToken | tokensPerSec |
|-------|------|-------|----------|------------|--------------|
| gpt-oss:20b | prefill | 124 | 30.006458ms | 241987.56 | 4132.44 |
| gpt-oss:20b | generate | 200 | 2.646843954s | 13234219.77 | 75.56 |
| gpt-oss:20b | load | 1 | 121.674208ms | - | - |
| gpt-oss:20b | total | 1 | 2.861047625s | - | - |
```
### Benchstat Format
Compatible with Go's benchstat tool for statistical analysis:
``` ```
BenchmarkModel/name=gpt-oss:20b/step=prefill 128 78125.00 ns/token 12800.00 token/sec # Model: gemma3 | Params: 4.3B | Quant: Q4_K_M | Family: gemma3 | Size: 4080218931 | VRAM: 4080218931
BenchmarkModel/name=gpt-oss:20b/step=generate 512 19531.25 ns/token 51200.00 token/sec BenchmarkModel/name=gemma3/step=prefill 1 78125.00 ns/token 12800.00 token/sec
BenchmarkModel/name=gpt-oss:20b/step=load 1 1500000000 ns/request BenchmarkModel/name=gemma3/step=generate 1 19531.25 ns/token 51200.00 token/sec
BenchmarkModel/name=gemma3/step=ttft 1 45123000 ns/op
BenchmarkModel/name=gemma3/step=load 1 1500000000 ns/op
BenchmarkModel/name=gemma3/step=total 1 2861047625 ns/op
```
Use with benchstat:
```
./ollama-bench -model gemma3 -epochs 6 > gemma3.bench
benchstat -col /step gemma3.bench
```
Compare two runs:
```
./ollama-bench -model gemma3 -epochs 6 > before.bench
# ... make changes ...
./ollama-bench -model gemma3 -epochs 6 > after.bench
benchstat before.bench after.bench
``` ```
### CSV Format ### CSV Format
@@ -99,17 +116,28 @@ Machine-readable comma-separated values:
``` ```
NAME,STEP,COUNT,NS_PER_COUNT,TOKEN_PER_SEC NAME,STEP,COUNT,NS_PER_COUNT,TOKEN_PER_SEC
gpt-oss:20b,prefill,128,78125.00,12800.00 # Model: gemma3 | Params: 4.3B | Quant: Q4_K_M | Family: gemma3 | Size: 4080218931 | VRAM: 4080218931
gpt-oss:20b,generate,512,19531.25,51200.00 gemma3,prefill,128,78125.00,12800.00
gpt-oss:20b,load,1,1500000000,0 gemma3,generate,512,19531.25,51200.00
gemma3,ttft,1,45123000,0
gemma3,load,1,1500000000,0
gemma3,total,1,2861047625,0
``` ```
## Metrics Explained ## Metrics Explained
The tool reports four types of metrics for each model: The tool reports the following metrics for each epoch:
* prefill: Time spent processing the prompt * **prefill**: Time spent processing the prompt (ns/token)
* generate: Time spent generating the response * **generate**: Time spent generating the response (ns/token)
* load: Model loading time (one-time cost) * **ttft**: Time to first token -- latency from request start to first response content
* total: Total request duration * **load**: Model loading time (one-time cost)
* **total**: Total request duration
Additionally, the model info comment line (displayed once per model before epochs) includes:
* **Params**: Model parameter count (e.g., 4.3B)
* **Quant**: Quantization level (e.g., Q4_K_M)
* **Family**: Model family (e.g., gemma3)
* **Size**: Total model memory in bytes
* **VRAM**: GPU memory used by the loaded model (when Size > VRAM, the difference is CPU spill)

View File

@@ -17,19 +17,22 @@ import (
) )
type flagOptions struct { type flagOptions struct {
models *string models *string
epochs *int epochs *int
maxTokens *int maxTokens *int
temperature *float64 temperature *float64
seed *int seed *int
timeout *int timeout *int
prompt *string prompt *string
imageFile *string imageFile *string
keepAlive *float64 keepAlive *float64
format *string format *string
outputFile *string outputFile *string
debug *bool debug *bool
verbose *bool verbose *bool
warmup *int
promptTokens *int
numCtx *int
} }
type Metrics struct { type Metrics struct {
@@ -39,48 +42,203 @@ type Metrics struct {
Duration time.Duration Duration time.Duration
} }
var once sync.Once type ModelInfo struct {
Name string
ParameterSize string
QuantizationLevel string
Family string
SizeBytes int64
VRAMBytes int64
NumCtx int64
}
const DefaultPrompt = `Please write a descriptive story about a llama named Alonso who grows up to be President of the Land of Llamas. Include details about Alonso's childhood, adolescent years, and how he grew up to be a political mover and shaker. Write the story with a sense of whimsy.` const DefaultPrompt = `Please write a descriptive story about a llama named Alonso who grows up to be President of the Land of Llamas. Include details about Alonso's childhood, adolescent years, and how he grew up to be a political mover and shaker. Write the story with a sense of whimsy.`
// Word list for generating prompts targeting a specific token count.
var promptWordList = []string{
"the", "quick", "brown", "fox", "jumps", "over", "lazy", "dog",
"a", "bright", "sunny", "day", "in", "the", "meadow", "where",
"flowers", "bloom", "and", "birds", "sing", "their", "morning",
"songs", "while", "gentle", "breeze", "carries", "sweet", "scent",
"of", "pine", "trees", "across", "rolling", "hills", "toward",
"distant", "mountains", "covered", "with", "fresh", "snow",
"beneath", "clear", "blue", "sky", "children", "play", "near",
"old", "stone", "bridge", "that", "crosses", "winding", "river",
}
// tokensPerWord is the calibrated ratio of tokens to words for the current model.
// Initialized with a heuristic, then updated during warmup based on actual tokenization.
var tokensPerWord = 1.3
func generatePromptForTokenCount(targetTokens int, epoch int) string {
targetWords := int(float64(targetTokens) / tokensPerWord)
if targetWords < 1 {
targetWords = 1
}
// Vary the starting offset by epoch to defeat KV cache prefix matching
offset := epoch * 7 // stride by a prime to get good distribution
n := len(promptWordList)
words := make([]string, targetWords)
for i := range words {
words[i] = promptWordList[((i+offset)%n+n)%n]
}
return strings.Join(words, " ")
}
// calibratePromptTokens adjusts tokensPerWord based on actual tokenization from a warmup run.
func calibratePromptTokens(targetTokens, actualTokens, wordCount int) {
if actualTokens <= 0 || wordCount <= 0 {
return
}
tokensPerWord = float64(actualTokens) / float64(wordCount)
newWords := int(float64(targetTokens) / tokensPerWord)
fmt.Fprintf(os.Stderr, "bench: calibrated %.2f tokens/word (target=%d, got=%d, words=%d → %d)\n",
tokensPerWord, targetTokens, actualTokens, wordCount, newWords)
}
func buildGenerateRequest(model string, fOpt flagOptions, imgData api.ImageData, epoch int) *api.GenerateRequest {
options := make(map[string]interface{})
if *fOpt.maxTokens > 0 {
options["num_predict"] = *fOpt.maxTokens
}
options["temperature"] = *fOpt.temperature
if fOpt.seed != nil && *fOpt.seed > 0 {
options["seed"] = *fOpt.seed
}
if fOpt.numCtx != nil && *fOpt.numCtx > 0 {
options["num_ctx"] = *fOpt.numCtx
}
var keepAliveDuration *api.Duration
if *fOpt.keepAlive > 0 {
duration := api.Duration{Duration: time.Duration(*fOpt.keepAlive * float64(time.Second))}
keepAliveDuration = &duration
}
prompt := *fOpt.prompt
if *fOpt.promptTokens > 0 {
prompt = generatePromptForTokenCount(*fOpt.promptTokens, epoch)
} else {
// Vary the prompt per epoch to defeat KV cache prefix matching
prompt = fmt.Sprintf("[%d] %s", epoch, prompt)
}
req := &api.GenerateRequest{
Model: model,
Prompt: prompt,
Raw: true,
Options: options,
KeepAlive: keepAliveDuration,
}
if imgData != nil {
req.Images = []api.ImageData{imgData}
}
return req
}
func fetchModelInfo(ctx context.Context, client *api.Client, model string) ModelInfo {
info := ModelInfo{Name: model}
resp, err := client.Show(ctx, &api.ShowRequest{Model: model})
if err != nil {
fmt.Fprintf(os.Stderr, "WARNING: Could not fetch model info for '%s': %v\n", model, err)
return info
}
info.ParameterSize = resp.Details.ParameterSize
info.QuantizationLevel = resp.Details.QuantizationLevel
info.Family = resp.Details.Family
return info
}
func fetchMemoryUsage(ctx context.Context, client *api.Client, model string) (size, vram int64) {
resp, err := client.ListRunning(ctx)
if err != nil {
if debug := os.Getenv("OLLAMA_DEBUG"); debug != "" {
fmt.Fprintf(os.Stderr, "WARNING: Could not fetch memory usage: %v\n", err)
}
return 0, 0
}
for _, m := range resp.Models {
if m.Name == model || m.Model == model {
return m.Size, m.SizeVRAM
}
}
for _, m := range resp.Models {
if strings.HasPrefix(m.Name, model) || strings.HasPrefix(m.Model, model) {
return m.Size, m.SizeVRAM
}
}
return 0, 0
}
func fetchContextLength(ctx context.Context, client *api.Client, model string) int64 {
resp, err := client.ListRunning(ctx)
if err != nil {
return 0
}
for _, m := range resp.Models {
if m.Name == model || m.Model == model || strings.HasPrefix(m.Name, model) || strings.HasPrefix(m.Model, model) {
return int64(m.ContextLength)
}
}
return 0
}
func outputFormatHeader(w io.Writer, format string, verbose bool) {
switch format {
case "benchstat":
if verbose {
fmt.Fprintf(w, "goos: %s\n", runtime.GOOS)
fmt.Fprintf(w, "goarch: %s\n", runtime.GOARCH)
}
case "csv":
headings := []string{"NAME", "STEP", "COUNT", "NS_PER_COUNT", "TOKEN_PER_SEC"}
fmt.Fprintln(w, strings.Join(headings, ","))
}
}
func outputModelInfo(w io.Writer, format string, info ModelInfo) {
params := cmp.Or(info.ParameterSize, "unknown")
quant := cmp.Or(info.QuantizationLevel, "unknown")
family := cmp.Or(info.Family, "unknown")
memStr := ""
if info.SizeBytes > 0 {
memStr = fmt.Sprintf(" | Size: %d | VRAM: %d", info.SizeBytes, info.VRAMBytes)
}
ctxStr := ""
if info.NumCtx > 0 {
ctxStr = fmt.Sprintf(" | NumCtx: %d", info.NumCtx)
}
fmt.Fprintf(w, "# Model: %s | Params: %s | Quant: %s | Family: %s%s%s\n",
info.Name, params, quant, family, memStr, ctxStr)
}
func OutputMetrics(w io.Writer, format string, metrics []Metrics, verbose bool) { func OutputMetrics(w io.Writer, format string, metrics []Metrics, verbose bool) {
switch format { switch format {
case "benchstat": case "benchstat":
if verbose {
printHeader := func() {
fmt.Fprintf(w, "sysname: %s\n", runtime.GOOS)
fmt.Fprintf(w, "machine: %s\n", runtime.GOARCH)
}
once.Do(printHeader)
}
for _, m := range metrics { for _, m := range metrics {
if m.Step == "generate" || m.Step == "prefill" { if m.Step == "generate" || m.Step == "prefill" {
if m.Count > 0 { if m.Count > 0 {
nsPerToken := float64(m.Duration.Nanoseconds()) / float64(m.Count) nsPerToken := float64(m.Duration.Nanoseconds()) / float64(m.Count)
tokensPerSec := float64(m.Count) / (float64(m.Duration.Nanoseconds()) + 1e-12) * 1e9 tokensPerSec := float64(m.Count) / (float64(m.Duration.Nanoseconds()) + 1e-12) * 1e9
fmt.Fprintf(w, "BenchmarkModel/name=%s/step=%s 1 %.2f ns/token %.2f token/sec\n",
fmt.Fprintf(w, "BenchmarkModel/name=%s/step=%s %d %.2f ns/token %.2f token/sec\n", m.Model, m.Step, nsPerToken, tokensPerSec)
m.Model, m.Step, m.Count, nsPerToken, tokensPerSec)
} else { } else {
fmt.Fprintf(w, "BenchmarkModel/name=%s/step=%s %d 0 ns/token 0 token/sec\n", fmt.Fprintf(w, "BenchmarkModel/name=%s/step=%s 1 0 ns/token 0 token/sec\n",
m.Model, m.Step, m.Count) m.Model, m.Step)
} }
} else if m.Step == "ttft" {
fmt.Fprintf(w, "BenchmarkModel/name=%s/step=ttft 1 %d ns/op\n",
m.Model, m.Duration.Nanoseconds())
} else { } else {
var suffix string fmt.Fprintf(w, "BenchmarkModel/name=%s/step=%s 1 %d ns/op\n",
if m.Step == "load" { m.Model, m.Step, m.Duration.Nanoseconds())
suffix = "/step=load"
}
fmt.Fprintf(w, "BenchmarkModel/name=%s%s 1 %d ns/request\n",
m.Model, suffix, m.Duration.Nanoseconds())
} }
} }
case "csv": case "csv":
printHeader := func() {
headings := []string{"NAME", "STEP", "COUNT", "NS_PER_COUNT", "TOKEN_PER_SEC"}
fmt.Fprintln(w, strings.Join(headings, ","))
}
once.Do(printHeader)
for _, m := range metrics { for _, m := range metrics {
if m.Step == "generate" || m.Step == "prefill" { if m.Step == "generate" || m.Step == "prefill" {
var nsPerToken float64 var nsPerToken float64
@@ -94,39 +252,14 @@ func OutputMetrics(w io.Writer, format string, metrics []Metrics, verbose bool)
fmt.Fprintf(w, "%s,%s,1,%d,0\n", m.Model, m.Step, m.Duration.Nanoseconds()) fmt.Fprintf(w, "%s,%s,1,%d,0\n", m.Model, m.Step, m.Duration.Nanoseconds())
} }
} }
case "markdown":
printHeader := func() {
fmt.Fprintln(w, "| Model | Step | Count | Duration | nsPerToken | tokensPerSec |")
fmt.Fprintln(w, "|-------|------|-------|----------|------------|--------------|")
}
once.Do(printHeader)
for _, m := range metrics {
var nsPerToken, tokensPerSec float64
var nsPerTokenStr, tokensPerSecStr string
if m.Step == "generate" || m.Step == "prefill" {
nsPerToken = float64(m.Duration.Nanoseconds()) / float64(m.Count)
tokensPerSec = float64(m.Count) / (float64(m.Duration.Nanoseconds()) + 1e-12) * 1e9
nsPerTokenStr = fmt.Sprintf("%.2f", nsPerToken)
tokensPerSecStr = fmt.Sprintf("%.2f", tokensPerSec)
} else {
nsPerTokenStr = "-"
tokensPerSecStr = "-"
}
fmt.Fprintf(w, "| %s | %s | %d | %v | %s | %s |\n",
m.Model, m.Step, m.Count, m.Duration, nsPerTokenStr, tokensPerSecStr)
}
default: default:
fmt.Fprintf(os.Stderr, "Unknown output format '%s'\n", format) fmt.Fprintf(os.Stderr, "Unknown output format '%s'\n", format)
} }
} }
func BenchmarkChat(fOpt flagOptions) error { func BenchmarkModel(fOpt flagOptions) error {
models := strings.Split(*fOpt.models, ",") models := strings.Split(*fOpt.models, ",")
// todo - add multi-image support
var imgData api.ImageData var imgData api.ImageData
var err error var err error
if *fOpt.imageFile != "" { if *fOpt.imageFile != "" {
@@ -158,71 +291,141 @@ func BenchmarkChat(fOpt flagOptions) error {
out = f out = f
} }
outputFormatHeader(out, *fOpt.format, *fOpt.verbose)
// Log prompt-tokens info in debug mode
if *fOpt.debug && *fOpt.promptTokens > 0 {
prompt := generatePromptForTokenCount(*fOpt.promptTokens, 0)
wordCount := len(strings.Fields(prompt))
fmt.Fprintf(os.Stderr, "Generated prompt targeting ~%d tokens (%d words, varied per epoch)\n", *fOpt.promptTokens, wordCount)
}
for _, model := range models { for _, model := range models {
for range *fOpt.epochs { // Fetch model info
options := make(map[string]interface{}) infoCtx, infoCancel := context.WithTimeout(context.Background(), 10*time.Second)
if *fOpt.maxTokens > 0 { info := fetchModelInfo(infoCtx, client, model)
options["num_predict"] = *fOpt.maxTokens infoCancel()
}
options["temperature"] = *fOpt.temperature
if fOpt.seed != nil && *fOpt.seed > 0 {
options["seed"] = *fOpt.seed
}
var keepAliveDuration *api.Duration
if *fOpt.keepAlive > 0 {
duration := api.Duration{Duration: time.Duration(*fOpt.keepAlive * float64(time.Second))}
keepAliveDuration = &duration
}
req := &api.ChatRequest{
Model: model,
Messages: []api.Message{
{
Role: "user",
Content: *fOpt.prompt,
},
},
Options: options,
KeepAlive: keepAliveDuration,
}
if imgData != nil {
req.Messages[0].Images = []api.ImageData{imgData}
}
var responseMetrics *api.Metrics
// Warmup phase (uses negative epoch numbers to avoid colliding with timed epochs)
for i := range *fOpt.warmup {
req := buildGenerateRequest(model, fOpt, imgData, -(i + 1))
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(*fOpt.timeout)*time.Second) ctx, cancel := context.WithTimeout(context.Background(), time.Duration(*fOpt.timeout)*time.Second)
defer cancel()
err = client.Chat(ctx, req, func(resp api.ChatResponse) error {
if *fOpt.debug {
fmt.Fprintf(os.Stderr, "%s", cmp.Or(resp.Message.Thinking, resp.Message.Content))
}
var warmupMetrics *api.Metrics
err = client.Generate(ctx, req, func(resp api.GenerateResponse) error {
if resp.Done { if resp.Done {
responseMetrics = &resp.Metrics warmupMetrics = &resp.Metrics
} }
return nil return nil
}) })
cancel()
if *fOpt.debug {
fmt.Fprintln(os.Stderr)
}
if err != nil { if err != nil {
if ctx.Err() == context.DeadlineExceeded { fmt.Fprintf(os.Stderr, "WARNING: Warmup %d/%d for %s failed: %v\n", i+1, *fOpt.warmup, model, err)
fmt.Fprintf(os.Stderr, "ERROR: Chat request timed out with model '%s' after %vs\n", model, 1) } else {
continue if *fOpt.debug {
fmt.Fprintf(os.Stderr, "Warmup %d/%d for %s complete\n", i+1, *fOpt.warmup, model)
} }
fmt.Fprintf(os.Stderr, "ERROR: Couldn't chat with model '%s': %v\n", model, err) // Calibrate prompt token count on last warmup run
if i == *fOpt.warmup-1 && *fOpt.promptTokens > 0 && warmupMetrics != nil {
prompt := generatePromptForTokenCount(*fOpt.promptTokens, -(i + 1))
wordCount := len(strings.Fields(prompt))
calibratePromptTokens(*fOpt.promptTokens, warmupMetrics.PromptEvalCount, wordCount)
}
}
}
// Fetch memory/context info once after warmup (model is loaded and stable)
memCtx, memCancel := context.WithTimeout(context.Background(), 5*time.Second)
info.SizeBytes, info.VRAMBytes = fetchMemoryUsage(memCtx, client, model)
if fOpt.numCtx != nil && *fOpt.numCtx > 0 {
info.NumCtx = int64(*fOpt.numCtx)
} else {
info.NumCtx = fetchContextLength(memCtx, client, model)
}
memCancel()
outputModelInfo(out, *fOpt.format, info)
// Timed epoch loop
shortCount := 0
for epoch := range *fOpt.epochs {
var responseMetrics *api.Metrics
var ttft time.Duration
short := false
// Retry loop: if the model hits a stop token before max-tokens,
// retry with a different prompt (up to maxRetries times).
const maxRetries = 3
for attempt := range maxRetries + 1 {
responseMetrics = nil
ttft = 0
var ttftOnce sync.Once
req := buildGenerateRequest(model, fOpt, imgData, epoch+attempt*1000)
requestStart := time.Now()
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(*fOpt.timeout)*time.Second)
err = client.Generate(ctx, req, func(resp api.GenerateResponse) error {
if *fOpt.debug {
fmt.Fprintf(os.Stderr, "%s", cmp.Or(resp.Thinking, resp.Response))
}
// Capture TTFT on first content
ttftOnce.Do(func() {
if resp.Response != "" || resp.Thinking != "" {
ttft = time.Since(requestStart)
}
})
if resp.Done {
responseMetrics = &resp.Metrics
}
return nil
})
cancel()
if *fOpt.debug {
fmt.Fprintln(os.Stderr)
}
if err != nil {
if ctx.Err() == context.DeadlineExceeded {
fmt.Fprintf(os.Stderr, "ERROR: Request timed out with model '%s' after %vs\n", model, *fOpt.timeout)
} else {
fmt.Fprintf(os.Stderr, "ERROR: Couldn't generate with model '%s': %v\n", model, err)
}
break
}
if responseMetrics == nil {
fmt.Fprintf(os.Stderr, "ERROR: No metrics received for model '%s'\n", model)
break
}
// Check if the response was shorter than requested
short = *fOpt.maxTokens > 0 && responseMetrics.EvalCount < *fOpt.maxTokens
if !short || attempt == maxRetries {
break
}
if *fOpt.debug {
fmt.Fprintf(os.Stderr, "Short response (%d/%d tokens), retrying with different prompt (attempt %d/%d)\n",
responseMetrics.EvalCount, *fOpt.maxTokens, attempt+1, maxRetries)
}
}
if err != nil || responseMetrics == nil {
continue continue
} }
if responseMetrics == nil { if short {
fmt.Fprintf(os.Stderr, "ERROR: No metrics received for model '%s'\n", model) shortCount++
continue if *fOpt.debug {
fmt.Fprintf(os.Stderr, "WARNING: Short response (%d/%d tokens) after %d retries for epoch %d\n",
responseMetrics.EvalCount, *fOpt.maxTokens, maxRetries, epoch+1)
}
} }
metrics := []Metrics{ metrics := []Metrics{
@@ -238,6 +441,12 @@ func BenchmarkChat(fOpt flagOptions) error {
Count: responseMetrics.EvalCount, Count: responseMetrics.EvalCount,
Duration: responseMetrics.EvalDuration, Duration: responseMetrics.EvalDuration,
}, },
{
Model: model,
Step: "ttft",
Count: 1,
Duration: ttft,
},
{ {
Model: model, Model: model,
Step: "load", Step: "load",
@@ -254,15 +463,42 @@ func BenchmarkChat(fOpt flagOptions) error {
OutputMetrics(out, *fOpt.format, metrics, *fOpt.verbose) OutputMetrics(out, *fOpt.format, metrics, *fOpt.verbose)
if *fOpt.debug && *fOpt.promptTokens > 0 {
fmt.Fprintf(os.Stderr, "Generated prompt targeting ~%d tokens (actual: %d)\n",
*fOpt.promptTokens, responseMetrics.PromptEvalCount)
}
if *fOpt.keepAlive > 0 { if *fOpt.keepAlive > 0 {
time.Sleep(time.Duration(*fOpt.keepAlive*float64(time.Second)) + 200*time.Millisecond) time.Sleep(time.Duration(*fOpt.keepAlive*float64(time.Second)) + 200*time.Millisecond)
} }
} }
if shortCount > 0 {
fmt.Fprintf(os.Stderr, "WARNING: %d/%d epochs for '%s' had short responses (<%d tokens). Generation metrics may be unreliable.\n",
shortCount, *fOpt.epochs, model, *fOpt.maxTokens)
}
// Unload model before moving to the next one
unloadModel(client, model, *fOpt.timeout)
} }
return nil return nil
} }
func unloadModel(client *api.Client, model string, timeout int) {
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(timeout)*time.Second)
defer cancel()
zero := api.Duration{Duration: 0}
req := &api.GenerateRequest{
Model: model,
KeepAlive: &zero,
}
_ = client.Generate(ctx, req, func(resp api.GenerateResponse) error {
return nil
})
}
func readImage(filePath string) (api.ImageData, error) { func readImage(filePath string) (api.ImageData, error) {
file, err := os.Open(filePath) file, err := os.Open(filePath)
if err != nil { if err != nil {
@@ -280,19 +516,22 @@ func readImage(filePath string) (api.ImageData, error) {
func main() { func main() {
fOpt := flagOptions{ fOpt := flagOptions{
models: flag.String("model", "", "Model to benchmark"), models: flag.String("model", "", "Model to benchmark"),
epochs: flag.Int("epochs", 6, "Number of epochs (iterations) per model"), epochs: flag.Int("epochs", 6, "Number of epochs (iterations) per model"),
maxTokens: flag.Int("max-tokens", 200, "Maximum tokens for model response"), maxTokens: flag.Int("max-tokens", 200, "Maximum tokens for model response"),
temperature: flag.Float64("temperature", 0, "Temperature parameter"), temperature: flag.Float64("temperature", 0, "Temperature parameter"),
seed: flag.Int("seed", 0, "Random seed"), seed: flag.Int("seed", 0, "Random seed"),
timeout: flag.Int("timeout", 60*5, "Timeout in seconds (default 300s)"), timeout: flag.Int("timeout", 60*5, "Timeout in seconds (default 300s)"),
prompt: flag.String("p", DefaultPrompt, "Prompt to use"), prompt: flag.String("p", DefaultPrompt, "Prompt to use"),
imageFile: flag.String("image", "", "Filename for an image to include"), imageFile: flag.String("image", "", "Filename for an image to include"),
keepAlive: flag.Float64("k", 0, "Keep alive duration in seconds"), keepAlive: flag.Float64("k", 0, "Keep alive duration in seconds"),
format: flag.String("format", "markdown", "Output format [benchstat|csv] (default benchstat)"), format: flag.String("format", "benchstat", "Output format [benchstat|csv]"),
outputFile: flag.String("output", "", "Output file for results (stdout if empty)"), outputFile: flag.String("output", "", "Output file for results (stdout if empty)"),
verbose: flag.Bool("v", false, "Show system information"), verbose: flag.Bool("v", false, "Show system information"),
debug: flag.Bool("debug", false, "Show debug information"), debug: flag.Bool("debug", false, "Show debug information"),
warmup: flag.Int("warmup", 1, "Number of warmup requests before timing"),
promptTokens: flag.Int("prompt-tokens", 0, "Generate prompt targeting ~N tokens (0 = use -p prompt)"),
numCtx: flag.Int("num-ctx", 0, "Context size (0 = server default)"),
} }
flag.Usage = func() { flag.Usage = func() {
@@ -302,11 +541,12 @@ func main() {
fmt.Fprintf(os.Stderr, "Options:\n") fmt.Fprintf(os.Stderr, "Options:\n")
flag.PrintDefaults() flag.PrintDefaults()
fmt.Fprintf(os.Stderr, "\nExamples:\n") fmt.Fprintf(os.Stderr, "\nExamples:\n")
fmt.Fprintf(os.Stderr, " bench -model gpt-oss:20b -epochs 3 -temperature 0.7\n") fmt.Fprintf(os.Stderr, " bench -model gemma3,llama3 -epochs 6\n")
fmt.Fprintf(os.Stderr, " bench -model gemma3 -epochs 6 -prompt-tokens 512 -format csv\n")
} }
flag.Parse() flag.Parse()
if !slices.Contains([]string{"markdown", "benchstat", "csv"}, *fOpt.format) { if !slices.Contains([]string{"benchstat", "csv"}, *fOpt.format) {
fmt.Fprintf(os.Stderr, "ERROR: Unknown format '%s'\n", *fOpt.format) fmt.Fprintf(os.Stderr, "ERROR: Unknown format '%s'\n", *fOpt.format)
os.Exit(1) os.Exit(1)
} }
@@ -317,5 +557,5 @@ func main() {
return return
} }
BenchmarkChat(fOpt) BenchmarkModel(fOpt)
} }

File diff suppressed because it is too large Load Diff

View File

@@ -11,6 +11,7 @@ import (
"fmt" "fmt"
"io" "io"
"log" "log"
"log/slog"
"math" "math"
"net" "net"
"net/http" "net/http"
@@ -38,9 +39,12 @@ import (
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
"github.com/ollama/ollama/cmd/config" "github.com/ollama/ollama/cmd/config"
"github.com/ollama/ollama/cmd/launch"
"github.com/ollama/ollama/cmd/tui" "github.com/ollama/ollama/cmd/tui"
"github.com/ollama/ollama/envconfig" "github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/format" "github.com/ollama/ollama/format"
"github.com/ollama/ollama/internal/modelref"
"github.com/ollama/ollama/logutil"
"github.com/ollama/ollama/parser" "github.com/ollama/ollama/parser"
"github.com/ollama/ollama/progress" "github.com/ollama/ollama/progress"
"github.com/ollama/ollama/readline" "github.com/ollama/ollama/readline"
@@ -57,36 +61,42 @@ import (
func init() { func init() {
// Override default selectors to use Bubbletea TUI instead of raw terminal I/O. // Override default selectors to use Bubbletea TUI instead of raw terminal I/O.
config.DefaultSingleSelector = func(title string, items []config.ModelItem, current string) (string, error) { launch.DefaultSingleSelector = func(title string, items []launch.ModelItem, current string) (string, error) {
if !term.IsTerminal(int(os.Stdin.Fd())) || !term.IsTerminal(int(os.Stdout.Fd())) {
return "", fmt.Errorf("model selection requires an interactive terminal; use --model to run in headless mode")
}
tuiItems := tui.ReorderItems(tui.ConvertItems(items)) tuiItems := tui.ReorderItems(tui.ConvertItems(items))
result, err := tui.SelectSingle(title, tuiItems, current) result, err := tui.SelectSingle(title, tuiItems, current)
if errors.Is(err, tui.ErrCancelled) { if errors.Is(err, tui.ErrCancelled) {
return "", config.ErrCancelled return "", launch.ErrCancelled
} }
return result, err return result, err
} }
config.DefaultMultiSelector = func(title string, items []config.ModelItem, preChecked []string) ([]string, error) { launch.DefaultMultiSelector = func(title string, items []launch.ModelItem, preChecked []string) ([]string, error) {
if !term.IsTerminal(int(os.Stdin.Fd())) || !term.IsTerminal(int(os.Stdout.Fd())) {
return nil, fmt.Errorf("model selection requires an interactive terminal; use --model to run in headless mode")
}
tuiItems := tui.ReorderItems(tui.ConvertItems(items)) tuiItems := tui.ReorderItems(tui.ConvertItems(items))
result, err := tui.SelectMultiple(title, tuiItems, preChecked) result, err := tui.SelectMultiple(title, tuiItems, preChecked)
if errors.Is(err, tui.ErrCancelled) { if errors.Is(err, tui.ErrCancelled) {
return nil, config.ErrCancelled return nil, launch.ErrCancelled
} }
return result, err return result, err
} }
config.DefaultSignIn = func(modelName, signInURL string) (string, error) { launch.DefaultSignIn = func(modelName, signInURL string) (string, error) {
userName, err := tui.RunSignIn(modelName, signInURL) userName, err := tui.RunSignIn(modelName, signInURL)
if errors.Is(err, tui.ErrCancelled) { if errors.Is(err, tui.ErrCancelled) {
return "", config.ErrCancelled return "", launch.ErrCancelled
} }
return userName, err return userName, err
} }
config.DefaultConfirmPrompt = func(prompt string) (bool, error) { launch.DefaultConfirmPrompt = func(prompt string) (bool, error) {
ok, err := tui.RunConfirm(prompt) ok, err := tui.RunConfirm(prompt)
if errors.Is(err, tui.ErrCancelled) { if errors.Is(err, tui.ErrCancelled) {
return false, config.ErrCancelled return false, launch.ErrCancelled
} }
return ok, err return ok, err
} }
@@ -131,6 +141,17 @@ func getModelfileName(cmd *cobra.Command) (string, error) {
return absName, nil return absName, nil
} }
// isLocalhost returns true if the configured Ollama host is a loopback or unspecified address.
func isLocalhost() bool {
host := envconfig.Host()
h, _, _ := net.SplitHostPort(host.Host)
if h == "localhost" {
return true
}
ip := net.ParseIP(h)
return ip != nil && (ip.IsLoopback() || ip.IsUnspecified())
}
func CreateHandler(cmd *cobra.Command, args []string) error { func CreateHandler(cmd *cobra.Command, args []string) error {
p := progress.NewProgress(os.Stderr) p := progress.NewProgress(os.Stderr)
defer p.Stop() defer p.Stop()
@@ -145,6 +166,9 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
// Check for --experimental flag for safetensors model creation // Check for --experimental flag for safetensors model creation
experimental, _ := cmd.Flags().GetBool("experimental") experimental, _ := cmd.Flags().GetBool("experimental")
if experimental { if experimental {
if !isLocalhost() {
return errors.New("remote safetensor model creation not yet supported")
}
// Get Modelfile content - either from -f flag or default to "FROM ." // Get Modelfile content - either from -f flag or default to "FROM ."
var reader io.Reader var reader io.Reader
filename, err := getModelfileName(cmd) filename, err := getModelfileName(cmd)
@@ -168,29 +192,9 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
return fmt.Errorf("failed to parse Modelfile: %w", err) return fmt.Errorf("failed to parse Modelfile: %w", err)
} }
// Extract FROM path and configuration modelDir, mfConfig, err := xcreateclient.ConfigFromModelfile(modelfile)
var modelDir string if err != nil {
mfConfig := &xcreateclient.ModelfileConfig{} return err
for _, cmd := range modelfile.Commands {
switch cmd.Name {
case "model":
modelDir = cmd.Args
case "template":
mfConfig.Template = cmd.Args
case "system":
mfConfig.System = cmd.Args
case "license":
mfConfig.License = cmd.Args
case "parser":
mfConfig.Parser = cmd.Args
case "renderer":
mfConfig.Renderer = cmd.Args
}
}
if modelDir == "" {
modelDir = "."
} }
// Resolve relative paths based on Modelfile location // Resolve relative paths based on Modelfile location
@@ -214,6 +218,9 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
if filename == "" { if filename == "" {
// No Modelfile found - check if current directory is an image gen model // No Modelfile found - check if current directory is an image gen model
if create.IsTensorModelDir(".") { if create.IsTensorModelDir(".") {
if !isLocalhost() {
return errors.New("remote safetensor model creation not yet supported")
}
quantize, _ := cmd.Flags().GetString("quantize") quantize, _ := cmd.Flags().GetString("quantize")
return xcreateclient.CreateModel(xcreateclient.CreateOptions{ return xcreateclient.CreateModel(xcreateclient.CreateOptions{
ModelName: modelName, ModelName: modelName,
@@ -406,12 +413,14 @@ func loadOrUnloadModel(cmd *cobra.Command, opts *runOptions) error {
return err return err
} }
requestedCloud := modelref.HasExplicitCloudSource(opts.Model)
if info, err := client.Show(cmd.Context(), &api.ShowRequest{Model: opts.Model}); err != nil { if info, err := client.Show(cmd.Context(), &api.ShowRequest{Model: opts.Model}); err != nil {
return err return err
} else if info.RemoteHost != "" { } else if info.RemoteHost != "" || requestedCloud {
// Cloud model, no need to load/unload // Cloud model, no need to load/unload
isCloud := strings.HasPrefix(info.RemoteHost, "https://ollama.com") isCloud := requestedCloud || strings.HasPrefix(info.RemoteHost, "https://ollama.com")
// Check if user is signed in for ollama.com cloud models // Check if user is signed in for ollama.com cloud models
if isCloud { if isCloud {
@@ -422,10 +431,14 @@ func loadOrUnloadModel(cmd *cobra.Command, opts *runOptions) error {
if opts.ShowConnect { if opts.ShowConnect {
p.StopAndClear() p.StopAndClear()
remoteModel := info.RemoteModel
if remoteModel == "" {
remoteModel = opts.Model
}
if isCloud { if isCloud {
fmt.Fprintf(os.Stderr, "Connecting to '%s' on 'ollama.com' ⚡\n", info.RemoteModel) fmt.Fprintf(os.Stderr, "Connecting to '%s' on 'ollama.com' ⚡\n", remoteModel)
} else { } else {
fmt.Fprintf(os.Stderr, "Connecting to '%s' on '%s'\n", info.RemoteModel, info.RemoteHost) fmt.Fprintf(os.Stderr, "Connecting to '%s' on '%s'\n", remoteModel, info.RemoteHost)
} }
} }
@@ -497,6 +510,64 @@ func generateEmbedding(cmd *cobra.Command, modelName, input string, keepAlive *a
return nil return nil
} }
// TODO(parthsareen): consolidate with TUI signin flow
func handleCloudAuthorizationError(err error) bool {
var authErr api.AuthorizationError
if errors.As(err, &authErr) && authErr.StatusCode == http.StatusUnauthorized {
fmt.Printf("You need to be signed in to Ollama to run Cloud models.\n\n")
if authErr.SigninURL != "" {
fmt.Printf(ConnectInstructions, authErr.SigninURL)
}
return true
}
return false
}
// TEMP(drifkin): To match legacy `ollama run some-model:cloud` behavior, we
// best-effort pull cloud stub files for any explicit cloud source models.
// Remove this once `/api/tags` is cloud-aware.
func ensureCloudStub(ctx context.Context, client *api.Client, modelName string) {
if !modelref.HasExplicitCloudSource(modelName) {
return
}
normalizedName, _, err := modelref.NormalizePullName(modelName)
if err != nil {
slog.Warn("failed to normalize pull name", "model", modelName, "error", err, "normalizedName", normalizedName)
return
}
listResp, err := client.List(ctx)
if err != nil {
slog.Warn("failed to list models", "error", err)
return
}
if hasListedModelName(listResp.Models, modelName) || hasListedModelName(listResp.Models, normalizedName) {
return
}
logutil.Trace("pulling cloud stub", "model", modelName, "normalizedName", normalizedName)
err = client.Pull(ctx, &api.PullRequest{
Model: normalizedName,
}, func(api.ProgressResponse) error {
return nil
})
if err != nil {
slog.Warn("failed to pull cloud stub", "model", modelName, "error", err)
}
}
func hasListedModelName(models []api.ListModelResponse, name string) bool {
for _, m := range models {
if strings.EqualFold(m.Name, name) || strings.EqualFold(m.Model, name) {
return true
}
}
return false
}
func RunHandler(cmd *cobra.Command, args []string) error { func RunHandler(cmd *cobra.Command, args []string) error {
interactive := true interactive := true
@@ -585,17 +656,6 @@ func RunHandler(cmd *cobra.Command, args []string) error {
} }
opts.WordWrap = !nowrap opts.WordWrap = !nowrap
useImagegen := false
if cmd.Flags().Lookup("imagegen") != nil {
useImagegen, err = cmd.Flags().GetBool("imagegen")
if err != nil {
return err
}
}
if useImagegen {
opts.Options["use_imagegen_runner"] = true
}
// Fill out the rest of the options based on information about the // Fill out the rest of the options based on information about the
// model. // model.
client, err := api.ClientFromEnvironment() client, err := api.ClientFromEnvironment()
@@ -604,12 +664,16 @@ func RunHandler(cmd *cobra.Command, args []string) error {
} }
name := args[0] name := args[0]
requestedCloud := modelref.HasExplicitCloudSource(name)
info, err := func() (*api.ShowResponse, error) { info, err := func() (*api.ShowResponse, error) {
showReq := &api.ShowRequest{Name: name} showReq := &api.ShowRequest{Name: name}
info, err := client.Show(cmd.Context(), showReq) info, err := client.Show(cmd.Context(), showReq)
var se api.StatusError var se api.StatusError
if errors.As(err, &se) && se.StatusCode == http.StatusNotFound { if errors.As(err, &se) && se.StatusCode == http.StatusNotFound {
if requestedCloud {
return nil, err
}
if err := PullHandler(cmd, []string{name}); err != nil { if err := PullHandler(cmd, []string{name}); err != nil {
return nil, err return nil, err
} }
@@ -618,15 +682,21 @@ func RunHandler(cmd *cobra.Command, args []string) error {
return info, err return info, err
}() }()
if err != nil { if err != nil {
if handleCloudAuthorizationError(err) {
return nil
}
return err return err
} }
ensureCloudStub(cmd.Context(), client, name)
opts.Think, err = inferThinkingOption(&info.Capabilities, &opts, thinkFlag.Changed) opts.Think, err = inferThinkingOption(&info.Capabilities, &opts, thinkFlag.Changed)
if err != nil { if err != nil {
return err return err
} }
opts.MultiModal = slices.Contains(info.Capabilities, model.CapabilityVision) audioCapable := slices.Contains(info.Capabilities, model.CapabilityAudio)
opts.MultiModal = slices.Contains(info.Capabilities, model.CapabilityVision) || audioCapable
// TODO: remove the projector info and vision info checks below, // TODO: remove the projector info and vision info checks below,
// these are left in for backwards compatibility with older servers // these are left in for backwards compatibility with older servers
@@ -712,7 +782,13 @@ func RunHandler(cmd *cobra.Command, args []string) error {
return generateInteractive(cmd, opts) return generateInteractive(cmd, opts)
} }
return generate(cmd, opts) if err := generate(cmd, opts); err != nil {
if handleCloudAuthorizationError(err) {
return nil
}
return err
}
return nil
} }
func SigninHandler(cmd *cobra.Command, args []string) error { func SigninHandler(cmd *cobra.Command, args []string) error {
@@ -1419,6 +1495,9 @@ type displayResponseState struct {
func displayResponse(content string, wordWrap bool, state *displayResponseState) { func displayResponse(content string, wordWrap bool, state *displayResponseState) {
termWidth, _, _ := term.GetSize(int(os.Stdout.Fd())) termWidth, _, _ := term.GetSize(int(os.Stdout.Fd()))
if termWidth == 0 {
termWidth = 80
}
if wordWrap && termWidth >= 10 { if wordWrap && termWidth >= 10 {
for _, ch := range content { for _, ch := range content {
if state.lineLength+1 > termWidth-5 { if state.lineLength+1 > termWidth-5 {
@@ -1892,6 +1971,24 @@ func ensureServerRunning(ctx context.Context) error {
} }
} }
func launchInteractiveModel(cmd *cobra.Command, modelName string) error {
opts := runOptions{
Model: modelName,
WordWrap: os.Getenv("TERM") == "xterm-256color",
Options: map[string]any{},
ShowConnect: true,
}
// loadOrUnloadModel is cloud-safe here: remote/cloud models skip local preload
// and only validate auth/connectivity before interactive chat starts.
if err := loadOrUnloadModel(cmd, &opts); err != nil {
return fmt.Errorf("error loading model: %w", err)
}
if err := generateInteractive(cmd, opts); err != nil {
return fmt.Errorf("error running model: %w", err)
}
return nil
}
// runInteractiveTUI runs the main interactive TUI menu. // runInteractiveTUI runs the main interactive TUI menu.
func runInteractiveTUI(cmd *cobra.Command) { func runInteractiveTUI(cmd *cobra.Command) {
// Ensure the server is running before showing the TUI // Ensure the server is running before showing the TUI
@@ -1900,175 +1997,85 @@ func runInteractiveTUI(cmd *cobra.Command) {
return return
} }
// Selector adapters for tui deps := launcherDeps{
singleSelector := func(title string, items []config.ModelItem, current string) (string, error) { buildState: launch.BuildLauncherState,
tuiItems := tui.ReorderItems(tui.ConvertItems(items)) runMenu: tui.RunMenu,
result, err := tui.SelectSingle(title, tuiItems, current) resolveRunModel: launch.ResolveRunModel,
if errors.Is(err, tui.ErrCancelled) { launchIntegration: launch.LaunchIntegration,
return "", config.ErrCancelled runModel: launchInteractiveModel,
}
return result, err
}
multiSelector := func(title string, items []config.ModelItem, preChecked []string) ([]string, error) {
tuiItems := tui.ReorderItems(tui.ConvertItems(items))
result, err := tui.SelectMultiple(title, tuiItems, preChecked)
if errors.Is(err, tui.ErrCancelled) {
return nil, config.ErrCancelled
}
return result, err
} }
for { for {
result, err := tui.Run() continueLoop, err := runInteractiveTUIStep(cmd, deps)
if err != nil { if err != nil {
fmt.Fprintf(os.Stderr, "Error: %v\n", err) fmt.Fprintf(os.Stderr, "Error: %v\n", err)
}
if !continueLoop {
return return
} }
}
}
runModel := func(modelName string) { type launcherDeps struct {
client, err := api.ClientFromEnvironment() buildState func(context.Context) (*launch.LauncherState, error)
if err != nil { runMenu func(*launch.LauncherState) (tui.TUIAction, error)
fmt.Fprintf(os.Stderr, "Error: %v\n", err) resolveRunModel func(context.Context, launch.RunModelRequest) (string, error)
return launchIntegration func(context.Context, launch.IntegrationLaunchRequest) error
} runModel func(*cobra.Command, string) error
if err := config.ShowOrPull(cmd.Context(), client, modelName); err != nil { }
if errors.Is(err, config.ErrCancelled) {
return
}
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
return
}
_ = config.SetLastModel(modelName)
opts := runOptions{
Model: modelName,
WordWrap: os.Getenv("TERM") == "xterm-256color",
Options: map[string]any{},
ShowConnect: true,
}
if err := loadOrUnloadModel(cmd, &opts); err != nil {
fmt.Fprintf(os.Stderr, "Error loading model: %v\n", err)
return
}
if err := generateInteractive(cmd, opts); err != nil {
fmt.Fprintf(os.Stderr, "Error running model: %v\n", err)
}
}
launchIntegration := func(name string) bool { func runInteractiveTUIStep(cmd *cobra.Command, deps launcherDeps) (bool, error) {
if err := config.EnsureInstalled(name); err != nil { state, err := deps.buildState(cmd.Context())
fmt.Fprintf(os.Stderr, "Error: %v\n", err) if err != nil {
return true return false, fmt.Errorf("build launcher state: %w", err)
} }
// If not configured or model no longer exists, prompt for model selection
configuredModel := config.IntegrationModel(name)
if configuredModel == "" || !config.ModelExists(cmd.Context(), configuredModel) || config.IsCloudModelDisabled(cmd.Context(), configuredModel) {
err := config.ConfigureIntegrationWithSelectors(cmd.Context(), name, singleSelector, multiSelector)
if errors.Is(err, config.ErrCancelled) {
return false // Return to main menu
}
if err != nil {
fmt.Fprintf(os.Stderr, "Error configuring %s: %v\n", name, err)
return true
}
}
if err := config.LaunchIntegration(name); err != nil {
fmt.Fprintf(os.Stderr, "Error launching %s: %v\n", name, err)
}
return true
}
switch result.Selection { action, err := deps.runMenu(state)
case tui.SelectionNone: if err != nil {
// User quit return false, fmt.Errorf("run launcher menu: %w", err)
return }
case tui.SelectionRunModel:
_ = config.SetLastSelection("run") return runLauncherAction(cmd, action, deps)
if modelName := config.LastModel(); modelName != "" && !config.IsCloudModelDisabled(cmd.Context(), modelName) { }
runModel(modelName)
} else { func saveLauncherSelection(action tui.TUIAction) {
modelName, err := config.SelectModelWithSelector(cmd.Context(), singleSelector) // Best effort only: this affects menu recall, not launch correctness.
if errors.Is(err, config.ErrCancelled) { _ = config.SetLastSelection(action.LastSelection())
continue // Return to main menu }
}
if err != nil { func runLauncherAction(cmd *cobra.Command, action tui.TUIAction, deps launcherDeps) (bool, error) {
fmt.Fprintf(os.Stderr, "Error selecting model: %v\n", err) switch action.Kind {
continue case tui.TUIActionNone:
} return false, nil
runModel(modelName) case tui.TUIActionRunModel:
} saveLauncherSelection(action)
case tui.SelectionChangeRunModel: modelName, err := deps.resolveRunModel(cmd.Context(), action.RunModelRequest())
_ = config.SetLastSelection("run") if errors.Is(err, launch.ErrCancelled) {
// Use model from modal if selected, otherwise show picker return true, nil
modelName := result.Model
if modelName == "" {
var err error
modelName, err = config.SelectModelWithSelector(cmd.Context(), singleSelector)
if errors.Is(err, config.ErrCancelled) {
continue // Return to main menu
}
if err != nil {
fmt.Fprintf(os.Stderr, "Error selecting model: %v\n", err)
continue
}
}
if config.IsCloudModelDisabled(cmd.Context(), modelName) {
continue // Return to main menu
}
runModel(modelName)
case tui.SelectionIntegration:
_ = config.SetLastSelection(result.Integration)
if !launchIntegration(result.Integration) {
continue // Return to main menu
}
case tui.SelectionChangeIntegration:
_ = config.SetLastSelection(result.Integration)
if len(result.Models) > 0 {
// Filter out cloud-disabled models
var filtered []string
for _, m := range result.Models {
if !config.IsCloudModelDisabled(cmd.Context(), m) {
filtered = append(filtered, m)
}
}
if len(filtered) == 0 {
continue
}
result.Models = filtered
// Multi-select from modal (Editor integrations)
if err := config.SaveAndEditIntegration(result.Integration, result.Models); err != nil {
fmt.Fprintf(os.Stderr, "Error configuring %s: %v\n", result.Integration, err)
continue
}
if err := config.LaunchIntegrationWithModel(result.Integration, result.Models[0]); err != nil {
fmt.Fprintf(os.Stderr, "Error launching %s: %v\n", result.Integration, err)
}
} else if result.Model != "" {
if config.IsCloudModelDisabled(cmd.Context(), result.Model) {
continue
}
// Single-select from modal - save and launch
if err := config.SaveIntegration(result.Integration, []string{result.Model}); err != nil {
fmt.Fprintf(os.Stderr, "Error saving config: %v\n", err)
continue
}
if err := config.LaunchIntegrationWithModel(result.Integration, result.Model); err != nil {
fmt.Fprintf(os.Stderr, "Error launching %s: %v\n", result.Integration, err)
}
} else {
err := config.ConfigureIntegrationWithSelectors(cmd.Context(), result.Integration, singleSelector, multiSelector)
if errors.Is(err, config.ErrCancelled) {
continue // Return to main menu
}
if err != nil {
fmt.Fprintf(os.Stderr, "Error configuring %s: %v\n", result.Integration, err)
continue
}
if err := config.LaunchIntegration(result.Integration); err != nil {
fmt.Fprintf(os.Stderr, "Error launching %s: %v\n", result.Integration, err)
}
}
} }
if err != nil {
return true, fmt.Errorf("selecting model: %w", err)
}
if err := deps.runModel(cmd, modelName); err != nil {
return true, err
}
return true, nil
case tui.TUIActionLaunchIntegration:
saveLauncherSelection(action)
err := deps.launchIntegration(cmd.Context(), action.IntegrationLaunchRequest())
if errors.Is(err, launch.ErrCancelled) {
return true, nil
}
if err != nil {
return true, fmt.Errorf("launching %s: %w", action.Integration, err)
}
// VS Code is a GUI app — exit the TUI loop after launching
if action.Integration == "vscode" {
return false, nil
}
return true, nil
default:
return false, fmt.Errorf("unknown launcher action: %d", action.Kind)
} }
} }
@@ -2338,7 +2345,7 @@ func NewCLI() *cobra.Command {
copyCmd, copyCmd,
deleteCmd, deleteCmd,
runnerCmd, runnerCmd,
config.LaunchCmd(checkServerHeartbeat, runInteractiveTUI), launch.LaunchCmd(checkServerHeartbeat, runInteractiveTUI),
) )
return rootCmd return rootCmd

270
cmd/cmd_launcher_test.go Normal file
View File

@@ -0,0 +1,270 @@
package cmd
import (
"context"
"testing"
"github.com/spf13/cobra"
"github.com/ollama/ollama/cmd/config"
"github.com/ollama/ollama/cmd/launch"
"github.com/ollama/ollama/cmd/tui"
)
func setCmdTestHome(t *testing.T, dir string) {
t.Helper()
t.Setenv("HOME", dir)
t.Setenv("USERPROFILE", dir)
}
func unexpectedRunModelResolution(t *testing.T) func(context.Context, launch.RunModelRequest) (string, error) {
t.Helper()
return func(ctx context.Context, req launch.RunModelRequest) (string, error) {
t.Fatalf("did not expect run-model resolution: %+v", req)
return "", nil
}
}
func unexpectedIntegrationLaunch(t *testing.T) func(context.Context, launch.IntegrationLaunchRequest) error {
t.Helper()
return func(ctx context.Context, req launch.IntegrationLaunchRequest) error {
t.Fatalf("did not expect integration launch: %+v", req)
return nil
}
}
func unexpectedModelLaunch(t *testing.T) func(*cobra.Command, string) error {
t.Helper()
return func(cmd *cobra.Command, model string) error {
t.Fatalf("did not expect chat launch: %s", model)
return nil
}
}
func TestRunInteractiveTUI_RunModelActionsUseResolveRunModel(t *testing.T) {
tests := []struct {
name string
action tui.TUIAction
wantForce bool
wantModel string
}{
{
name: "enter uses saved model flow",
action: tui.TUIAction{Kind: tui.TUIActionRunModel},
wantModel: "qwen3:8b",
},
{
name: "right forces picker",
action: tui.TUIAction{Kind: tui.TUIActionRunModel, ForceConfigure: true},
wantForce: true,
wantModel: "glm-5:cloud",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
setCmdTestHome(t, t.TempDir())
var menuCalls int
runMenu := func(state *launch.LauncherState) (tui.TUIAction, error) {
menuCalls++
if menuCalls == 1 {
return tt.action, nil
}
return tui.TUIAction{Kind: tui.TUIActionNone}, nil
}
var gotReq launch.RunModelRequest
var launched string
deps := launcherDeps{
buildState: func(ctx context.Context) (*launch.LauncherState, error) {
return &launch.LauncherState{}, nil
},
runMenu: runMenu,
resolveRunModel: func(ctx context.Context, req launch.RunModelRequest) (string, error) {
gotReq = req
return tt.wantModel, nil
},
launchIntegration: unexpectedIntegrationLaunch(t),
runModel: func(cmd *cobra.Command, model string) error {
launched = model
return nil
},
}
cmd := &cobra.Command{}
cmd.SetContext(context.Background())
for {
continueLoop, err := runInteractiveTUIStep(cmd, deps)
if err != nil {
t.Fatalf("unexpected step error: %v", err)
}
if !continueLoop {
break
}
}
if gotReq.ForcePicker != tt.wantForce {
t.Fatalf("expected ForcePicker=%v, got %v", tt.wantForce, gotReq.ForcePicker)
}
if launched != tt.wantModel {
t.Fatalf("expected interactive launcher to run %q, got %q", tt.wantModel, launched)
}
if got := config.LastSelection(); got != "run" {
t.Fatalf("expected last selection to be run, got %q", got)
}
})
}
}
func TestRunInteractiveTUI_IntegrationActionsUseLaunchIntegration(t *testing.T) {
tests := []struct {
name string
action tui.TUIAction
wantForce bool
}{
{
name: "enter launches integration",
action: tui.TUIAction{Kind: tui.TUIActionLaunchIntegration, Integration: "claude"},
},
{
name: "right forces configure",
action: tui.TUIAction{Kind: tui.TUIActionLaunchIntegration, Integration: "claude", ForceConfigure: true},
wantForce: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
setCmdTestHome(t, t.TempDir())
var menuCalls int
runMenu := func(state *launch.LauncherState) (tui.TUIAction, error) {
menuCalls++
if menuCalls == 1 {
return tt.action, nil
}
return tui.TUIAction{Kind: tui.TUIActionNone}, nil
}
var gotReq launch.IntegrationLaunchRequest
deps := launcherDeps{
buildState: func(ctx context.Context) (*launch.LauncherState, error) {
return &launch.LauncherState{}, nil
},
runMenu: runMenu,
resolveRunModel: unexpectedRunModelResolution(t),
launchIntegration: func(ctx context.Context, req launch.IntegrationLaunchRequest) error {
gotReq = req
return nil
},
runModel: unexpectedModelLaunch(t),
}
cmd := &cobra.Command{}
cmd.SetContext(context.Background())
for {
continueLoop, err := runInteractiveTUIStep(cmd, deps)
if err != nil {
t.Fatalf("unexpected step error: %v", err)
}
if !continueLoop {
break
}
}
if gotReq.Name != "claude" {
t.Fatalf("expected integration name to be passed through, got %q", gotReq.Name)
}
if gotReq.ForceConfigure != tt.wantForce {
t.Fatalf("expected ForceConfigure=%v, got %v", tt.wantForce, gotReq.ForceConfigure)
}
if got := config.LastSelection(); got != "claude" {
t.Fatalf("expected last selection to be claude, got %q", got)
}
})
}
}
func TestRunLauncherAction_RunModelContinuesAfterCancellation(t *testing.T) {
setCmdTestHome(t, t.TempDir())
cmd := &cobra.Command{}
cmd.SetContext(context.Background())
continueLoop, err := runLauncherAction(cmd, tui.TUIAction{Kind: tui.TUIActionRunModel}, launcherDeps{
buildState: nil,
runMenu: nil,
resolveRunModel: func(ctx context.Context, req launch.RunModelRequest) (string, error) {
return "", launch.ErrCancelled
},
launchIntegration: unexpectedIntegrationLaunch(t),
runModel: unexpectedModelLaunch(t),
})
if err != nil {
t.Fatalf("expected nil error on cancellation, got %v", err)
}
if !continueLoop {
t.Fatal("expected cancellation to continue the menu loop")
}
}
func TestRunLauncherAction_VSCodeExitsTUILoop(t *testing.T) {
setCmdTestHome(t, t.TempDir())
cmd := &cobra.Command{}
cmd.SetContext(context.Background())
// VS Code should exit the TUI loop (return false) after a successful launch.
continueLoop, err := runLauncherAction(cmd, tui.TUIAction{Kind: tui.TUIActionLaunchIntegration, Integration: "vscode"}, launcherDeps{
resolveRunModel: unexpectedRunModelResolution(t),
launchIntegration: func(ctx context.Context, req launch.IntegrationLaunchRequest) error {
return nil
},
runModel: unexpectedModelLaunch(t),
})
if err != nil {
t.Fatalf("expected nil error, got %v", err)
}
if continueLoop {
t.Fatal("expected vscode launch to exit the TUI loop (return false)")
}
// Other integrations should continue the TUI loop (return true).
continueLoop, err = runLauncherAction(cmd, tui.TUIAction{Kind: tui.TUIActionLaunchIntegration, Integration: "claude"}, launcherDeps{
resolveRunModel: unexpectedRunModelResolution(t),
launchIntegration: func(ctx context.Context, req launch.IntegrationLaunchRequest) error {
return nil
},
runModel: unexpectedModelLaunch(t),
})
if err != nil {
t.Fatalf("expected nil error, got %v", err)
}
if !continueLoop {
t.Fatal("expected non-vscode integration to continue the TUI loop (return true)")
}
}
func TestRunLauncherAction_IntegrationContinuesAfterCancellation(t *testing.T) {
setCmdTestHome(t, t.TempDir())
cmd := &cobra.Command{}
cmd.SetContext(context.Background())
continueLoop, err := runLauncherAction(cmd, tui.TUIAction{Kind: tui.TUIActionLaunchIntegration, Integration: "claude"}, launcherDeps{
buildState: nil,
runMenu: nil,
resolveRunModel: unexpectedRunModelResolution(t),
launchIntegration: func(ctx context.Context, req launch.IntegrationLaunchRequest) error {
return launch.ErrCancelled
},
runModel: unexpectedModelLaunch(t),
})
if err != nil {
t.Fatalf("expected nil error on cancellation, got %v", err)
}
if !continueLoop {
t.Fatal("expected cancellation to continue the menu loop")
}
}

View File

@@ -301,7 +301,7 @@ Weigh anchor!
ParameterSize: "7B", ParameterSize: "7B",
QuantizationLevel: "FP16", QuantizationLevel: "FP16",
}, },
Requires: "0.14.0", Requires: "0.19.0",
}, false, &b); err != nil { }, false, &b); err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -310,10 +310,17 @@ Weigh anchor!
architecture test architecture test
parameters 7B parameters 7B
quantization FP16 quantization FP16
requires 0.14.0 requires 0.19.0
` `
if diff := cmp.Diff(expect, b.String()); diff != "" { trimLinePadding := func(s string) string {
lines := strings.Split(s, "\n")
for i, line := range lines {
lines[i] = strings.TrimRight(line, " \t\r")
}
return strings.Join(lines, "\n")
}
if diff := cmp.Diff(trimLinePadding(expect), trimLinePadding(b.String())); diff != "" {
t.Errorf("unexpected output (-want +got):\n%s", diff) t.Errorf("unexpected output (-want +got):\n%s", diff)
} }
}) })
@@ -705,6 +712,347 @@ func TestRunEmbeddingModelNoInput(t *testing.T) {
} }
} }
func TestRunHandler_CloudAuthErrorOnShow_PrintsSigninMessage(t *testing.T) {
var generateCalled bool
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch {
case r.URL.Path == "/api/show" && r.Method == http.MethodPost:
w.WriteHeader(http.StatusUnauthorized)
if err := json.NewEncoder(w).Encode(map[string]string{
"error": "unauthorized",
"signin_url": "https://ollama.com/signin",
}); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
}
return
case r.URL.Path == "/api/generate" && r.Method == http.MethodPost:
generateCalled = true
w.WriteHeader(http.StatusOK)
if err := json.NewEncoder(w).Encode(api.GenerateResponse{Done: true}); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
}
return
default:
http.NotFound(w, r)
}
}))
t.Setenv("OLLAMA_HOST", mockServer.URL)
t.Cleanup(mockServer.Close)
cmd := &cobra.Command{}
cmd.SetContext(t.Context())
cmd.Flags().String("keepalive", "", "")
cmd.Flags().Bool("truncate", false, "")
cmd.Flags().Int("dimensions", 0, "")
cmd.Flags().Bool("verbose", false, "")
cmd.Flags().Bool("insecure", false, "")
cmd.Flags().Bool("nowordwrap", false, "")
cmd.Flags().String("format", "", "")
cmd.Flags().String("think", "", "")
cmd.Flags().Bool("hidethinking", false, "")
oldStdout := os.Stdout
readOut, writeOut, _ := os.Pipe()
os.Stdout = writeOut
t.Cleanup(func() { os.Stdout = oldStdout })
err := RunHandler(cmd, []string{"gpt-oss:20b:cloud", "hi"})
_ = writeOut.Close()
var out bytes.Buffer
_, _ = io.Copy(&out, readOut)
if err != nil {
t.Fatalf("RunHandler returned error: %v", err)
}
if generateCalled {
t.Fatal("expected run to stop before /api/generate after unauthorized /api/show")
}
if !strings.Contains(out.String(), "You need to be signed in to Ollama to run Cloud models.") {
t.Fatalf("expected sign-in guidance message, got %q", out.String())
}
if !strings.Contains(out.String(), "https://ollama.com/signin") {
t.Fatalf("expected signin_url in output, got %q", out.String())
}
}
func TestRunHandler_CloudAuthErrorOnGenerate_PrintsSigninMessage(t *testing.T) {
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch {
case r.URL.Path == "/api/show" && r.Method == http.MethodPost:
w.WriteHeader(http.StatusOK)
if err := json.NewEncoder(w).Encode(api.ShowResponse{
Capabilities: []model.Capability{model.CapabilityCompletion},
}); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
}
return
case r.URL.Path == "/api/generate" && r.Method == http.MethodPost:
w.WriteHeader(http.StatusUnauthorized)
if err := json.NewEncoder(w).Encode(map[string]string{
"error": "unauthorized",
"signin_url": "https://ollama.com/signin",
}); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
}
return
default:
http.NotFound(w, r)
}
}))
t.Setenv("OLLAMA_HOST", mockServer.URL)
t.Cleanup(mockServer.Close)
cmd := &cobra.Command{}
cmd.SetContext(t.Context())
cmd.Flags().String("keepalive", "", "")
cmd.Flags().Bool("truncate", false, "")
cmd.Flags().Int("dimensions", 0, "")
cmd.Flags().Bool("verbose", false, "")
cmd.Flags().Bool("insecure", false, "")
cmd.Flags().Bool("nowordwrap", false, "")
cmd.Flags().String("format", "", "")
cmd.Flags().String("think", "", "")
cmd.Flags().Bool("hidethinking", false, "")
oldStdout := os.Stdout
readOut, writeOut, _ := os.Pipe()
os.Stdout = writeOut
t.Cleanup(func() { os.Stdout = oldStdout })
err := RunHandler(cmd, []string{"gpt-oss:20b:cloud", "hi"})
_ = writeOut.Close()
var out bytes.Buffer
_, _ = io.Copy(&out, readOut)
if err != nil {
t.Fatalf("RunHandler returned error: %v", err)
}
if !strings.Contains(out.String(), "You need to be signed in to Ollama to run Cloud models.") {
t.Fatalf("expected sign-in guidance message, got %q", out.String())
}
if !strings.Contains(out.String(), "https://ollama.com/signin") {
t.Fatalf("expected signin_url in output, got %q", out.String())
}
}
func TestRunHandler_ExplicitCloudStubMissing_PullsNormalizedNameTEMP(t *testing.T) {
var pulledModel string
var generateCalled bool
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch {
case r.URL.Path == "/api/show" && r.Method == http.MethodPost:
w.WriteHeader(http.StatusOK)
if err := json.NewEncoder(w).Encode(api.ShowResponse{
Capabilities: []model.Capability{model.CapabilityCompletion},
RemoteModel: "gpt-oss:20b",
}); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
}
return
case r.URL.Path == "/api/tags" && r.Method == http.MethodGet:
w.WriteHeader(http.StatusOK)
if err := json.NewEncoder(w).Encode(api.ListResponse{Models: nil}); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
}
return
case r.URL.Path == "/api/pull" && r.Method == http.MethodPost:
var req api.PullRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
pulledModel = req.Model
w.WriteHeader(http.StatusOK)
if err := json.NewEncoder(w).Encode(api.ProgressResponse{Status: "success"}); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
}
return
case r.URL.Path == "/api/generate" && r.Method == http.MethodPost:
generateCalled = true
w.WriteHeader(http.StatusOK)
if err := json.NewEncoder(w).Encode(api.GenerateResponse{Done: true}); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
}
return
default:
http.NotFound(w, r)
}
}))
t.Setenv("OLLAMA_HOST", mockServer.URL)
t.Cleanup(mockServer.Close)
cmd := &cobra.Command{}
cmd.SetContext(t.Context())
cmd.Flags().String("keepalive", "", "")
cmd.Flags().Bool("truncate", false, "")
cmd.Flags().Int("dimensions", 0, "")
cmd.Flags().Bool("verbose", false, "")
cmd.Flags().Bool("insecure", false, "")
cmd.Flags().Bool("nowordwrap", false, "")
cmd.Flags().String("format", "", "")
cmd.Flags().String("think", "", "")
cmd.Flags().Bool("hidethinking", false, "")
err := RunHandler(cmd, []string{"gpt-oss:20b:cloud", "hi"})
if err != nil {
t.Fatalf("RunHandler returned error: %v", err)
}
if pulledModel != "gpt-oss:20b-cloud" {
t.Fatalf("expected normalized pull model %q, got %q", "gpt-oss:20b-cloud", pulledModel)
}
if !generateCalled {
t.Fatal("expected /api/generate to be called")
}
}
func TestRunHandler_ExplicitCloudStubPresent_SkipsPullTEMP(t *testing.T) {
var pullCalled bool
var generateCalled bool
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch {
case r.URL.Path == "/api/show" && r.Method == http.MethodPost:
w.WriteHeader(http.StatusOK)
if err := json.NewEncoder(w).Encode(api.ShowResponse{
Capabilities: []model.Capability{model.CapabilityCompletion},
RemoteModel: "gpt-oss:20b",
}); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
}
return
case r.URL.Path == "/api/tags" && r.Method == http.MethodGet:
w.WriteHeader(http.StatusOK)
if err := json.NewEncoder(w).Encode(api.ListResponse{
Models: []api.ListModelResponse{{Name: "gpt-oss:20b-cloud"}},
}); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
}
return
case r.URL.Path == "/api/pull" && r.Method == http.MethodPost:
pullCalled = true
w.WriteHeader(http.StatusOK)
if err := json.NewEncoder(w).Encode(api.ProgressResponse{Status: "success"}); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
}
return
case r.URL.Path == "/api/generate" && r.Method == http.MethodPost:
generateCalled = true
w.WriteHeader(http.StatusOK)
if err := json.NewEncoder(w).Encode(api.GenerateResponse{Done: true}); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
}
return
default:
http.NotFound(w, r)
}
}))
t.Setenv("OLLAMA_HOST", mockServer.URL)
t.Cleanup(mockServer.Close)
cmd := &cobra.Command{}
cmd.SetContext(t.Context())
cmd.Flags().String("keepalive", "", "")
cmd.Flags().Bool("truncate", false, "")
cmd.Flags().Int("dimensions", 0, "")
cmd.Flags().Bool("verbose", false, "")
cmd.Flags().Bool("insecure", false, "")
cmd.Flags().Bool("nowordwrap", false, "")
cmd.Flags().String("format", "", "")
cmd.Flags().String("think", "", "")
cmd.Flags().Bool("hidethinking", false, "")
err := RunHandler(cmd, []string{"gpt-oss:20b:cloud", "hi"})
if err != nil {
t.Fatalf("RunHandler returned error: %v", err)
}
if pullCalled {
t.Fatal("expected /api/pull not to be called when cloud stub already exists")
}
if !generateCalled {
t.Fatal("expected /api/generate to be called")
}
}
func TestRunHandler_ExplicitCloudStubPullFailure_IsBestEffortTEMP(t *testing.T) {
var generateCalled bool
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch {
case r.URL.Path == "/api/show" && r.Method == http.MethodPost:
w.WriteHeader(http.StatusOK)
if err := json.NewEncoder(w).Encode(api.ShowResponse{
Capabilities: []model.Capability{model.CapabilityCompletion},
RemoteModel: "gpt-oss:20b",
}); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
}
return
case r.URL.Path == "/api/tags" && r.Method == http.MethodGet:
w.WriteHeader(http.StatusOK)
if err := json.NewEncoder(w).Encode(api.ListResponse{Models: nil}); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
}
return
case r.URL.Path == "/api/pull" && r.Method == http.MethodPost:
w.WriteHeader(http.StatusInternalServerError)
if err := json.NewEncoder(w).Encode(map[string]string{"error": "pull failed"}); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
}
return
case r.URL.Path == "/api/generate" && r.Method == http.MethodPost:
generateCalled = true
w.WriteHeader(http.StatusOK)
if err := json.NewEncoder(w).Encode(api.GenerateResponse{Done: true}); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
}
return
default:
http.NotFound(w, r)
}
}))
t.Setenv("OLLAMA_HOST", mockServer.URL)
t.Cleanup(mockServer.Close)
cmd := &cobra.Command{}
cmd.SetContext(t.Context())
cmd.Flags().String("keepalive", "", "")
cmd.Flags().Bool("truncate", false, "")
cmd.Flags().Int("dimensions", 0, "")
cmd.Flags().Bool("verbose", false, "")
cmd.Flags().Bool("insecure", false, "")
cmd.Flags().Bool("nowordwrap", false, "")
cmd.Flags().String("format", "", "")
cmd.Flags().String("think", "", "")
cmd.Flags().Bool("hidethinking", false, "")
err := RunHandler(cmd, []string{"gpt-oss:20b:cloud", "hi"})
if err != nil {
t.Fatalf("RunHandler returned error: %v", err)
}
if !generateCalled {
t.Fatal("expected /api/generate to be called despite pull failure")
}
}
func TestGetModelfileName(t *testing.T) { func TestGetModelfileName(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
@@ -1212,6 +1560,20 @@ func TestNewCreateRequest(t *testing.T) {
Model: "newmodel", Model: "newmodel",
}, },
}, },
{
"explicit cloud model preserves source when parent lacks it",
"newmodel",
runOptions{
Model: "qwen3.5:cloud",
ParentModel: "qwen3.5",
Messages: []api.Message{},
WordWrap: true,
},
&api.CreateRequest{
From: "qwen3.5:cloud",
Model: "newmodel",
},
},
{ {
"parent model as filepath test", "parent model as filepath test",
"newmodel", "newmodel",
@@ -1557,7 +1919,7 @@ func TestShowInfoImageGen(t *testing.T) {
QuantizationLevel: "Q8", QuantizationLevel: "Q8",
}, },
Capabilities: []model.Capability{model.CapabilityImage}, Capabilities: []model.Capability{model.CapabilityImage},
Requires: "0.14.0", Requires: "0.19.0",
}, false, &b) }, false, &b)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@@ -1567,7 +1929,7 @@ func TestShowInfoImageGen(t *testing.T) {
" architecture ZImagePipeline \n" + " architecture ZImagePipeline \n" +
" parameters 10.3B \n" + " parameters 10.3B \n" +
" quantization Q8 \n" + " quantization Q8 \n" +
" requires 0.14.0 \n" + " requires 0.19.0 \n" +
"\n" + "\n" +
" Capabilities\n" + " Capabilities\n" +
" image \n" + " image \n" +
@@ -1663,31 +2025,81 @@ func TestRunOptions_Copy_Independence(t *testing.T) {
func TestLoadOrUnloadModel_CloudModelAuth(t *testing.T) { func TestLoadOrUnloadModel_CloudModelAuth(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
remoteHost string model string
whoamiStatus int showStatus int
whoamiResp any remoteHost string
expectedError string remoteModel string
whoamiStatus int
whoamiResp any
expectWhoami bool
expectedError string
expectAuthError bool
}{ }{
{ {
name: "ollama.com cloud model - user signed in", name: "ollama.com cloud model - user signed in",
model: "test-cloud-model",
remoteHost: "https://ollama.com", remoteHost: "https://ollama.com",
remoteModel: "test-model",
whoamiStatus: http.StatusOK, whoamiStatus: http.StatusOK,
whoamiResp: api.UserResponse{Name: "testuser"}, whoamiResp: api.UserResponse{Name: "testuser"},
expectWhoami: true,
}, },
{ {
name: "ollama.com cloud model - user not signed in", name: "ollama.com cloud model - user not signed in",
model: "test-cloud-model",
remoteHost: "https://ollama.com", remoteHost: "https://ollama.com",
remoteModel: "test-model",
whoamiStatus: http.StatusUnauthorized, whoamiStatus: http.StatusUnauthorized,
whoamiResp: map[string]string{ whoamiResp: map[string]string{
"error": "unauthorized", "error": "unauthorized",
"signin_url": "https://ollama.com/signin", "signin_url": "https://ollama.com/signin",
}, },
expectedError: "unauthorized", expectWhoami: true,
expectedError: "unauthorized",
expectAuthError: true,
}, },
{ {
name: "non-ollama.com remote - no auth check", name: "non-ollama.com remote - no auth check",
model: "test-cloud-model",
remoteHost: "https://other-remote.com", remoteHost: "https://other-remote.com",
remoteModel: "test-model",
whoamiStatus: http.StatusUnauthorized, // should not be called
whoamiResp: nil,
},
{
name: "explicit :cloud model - auth check without remote metadata",
model: "kimi-k2.5:cloud",
remoteHost: "",
remoteModel: "",
whoamiStatus: http.StatusOK,
whoamiResp: api.UserResponse{Name: "testuser"},
expectWhoami: true,
},
{
name: "explicit :cloud model without local stub returns not found by default",
model: "minimax-m2.7:cloud",
showStatus: http.StatusNotFound,
whoamiStatus: http.StatusOK,
whoamiResp: api.UserResponse{Name: "testuser"},
expectedError: "not found",
expectWhoami: false,
expectAuthError: false,
},
{
name: "explicit -cloud model - auth check without remote metadata",
model: "kimi-k2.5:latest-cloud",
remoteHost: "",
remoteModel: "",
whoamiStatus: http.StatusOK,
whoamiResp: api.UserResponse{Name: "testuser"},
expectWhoami: true,
},
{
name: "dash cloud-like name without explicit source does not require auth",
model: "test-cloud-model",
remoteHost: "",
remoteModel: "",
whoamiStatus: http.StatusUnauthorized, // should not be called whoamiStatus: http.StatusUnauthorized, // should not be called
whoamiResp: nil, whoamiResp: nil,
}, },
@@ -1699,10 +2111,15 @@ func TestLoadOrUnloadModel_CloudModelAuth(t *testing.T) {
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path { switch r.URL.Path {
case "/api/show": case "/api/show":
if tt.showStatus != 0 && tt.showStatus != http.StatusOK {
w.WriteHeader(tt.showStatus)
_ = json.NewEncoder(w).Encode(map[string]string{"error": "not found"})
return
}
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(api.ShowResponse{ if err := json.NewEncoder(w).Encode(api.ShowResponse{
RemoteHost: tt.remoteHost, RemoteHost: tt.remoteHost,
RemoteModel: "test-model", RemoteModel: tt.remoteModel,
}); err != nil { }); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError) http.Error(w, err.Error(), http.StatusInternalServerError)
} }
@@ -1715,6 +2132,8 @@ func TestLoadOrUnloadModel_CloudModelAuth(t *testing.T) {
http.Error(w, err.Error(), http.StatusInternalServerError) http.Error(w, err.Error(), http.StatusInternalServerError)
} }
} }
case "/api/generate":
w.WriteHeader(http.StatusOK)
default: default:
http.NotFound(w, r) http.NotFound(w, r)
} }
@@ -1727,29 +2146,28 @@ func TestLoadOrUnloadModel_CloudModelAuth(t *testing.T) {
cmd.SetContext(t.Context()) cmd.SetContext(t.Context())
opts := &runOptions{ opts := &runOptions{
Model: "test-cloud-model", Model: tt.model,
ShowConnect: false, ShowConnect: false,
} }
err := loadOrUnloadModel(cmd, opts) err := loadOrUnloadModel(cmd, opts)
if strings.HasPrefix(tt.remoteHost, "https://ollama.com") { if whoamiCalled != tt.expectWhoami {
if !whoamiCalled { t.Errorf("whoami called = %v, want %v", whoamiCalled, tt.expectWhoami)
t.Error("expected whoami to be called for ollama.com cloud model")
}
} else {
if whoamiCalled {
t.Error("whoami should not be called for non-ollama.com remote")
}
} }
if tt.expectedError != "" { if tt.expectedError != "" {
if err == nil { if err == nil {
t.Errorf("expected error containing %q, got nil", tt.expectedError) t.Errorf("expected error containing %q, got nil", tt.expectedError)
} else { } else {
var authErr api.AuthorizationError if !tt.expectAuthError && !strings.Contains(strings.ToLower(err.Error()), strings.ToLower(tt.expectedError)) {
if !errors.As(err, &authErr) { t.Errorf("expected error containing %q, got %v", tt.expectedError, err)
t.Errorf("expected AuthorizationError, got %T: %v", err, err) }
if tt.expectAuthError {
var authErr api.AuthorizationError
if !errors.As(err, &authErr) {
t.Errorf("expected AuthorizationError, got %T: %v", err, err)
}
} }
} }
} else { } else {
@@ -1760,3 +2178,38 @@ func TestLoadOrUnloadModel_CloudModelAuth(t *testing.T) {
}) })
} }
} }
func TestIsLocalhost(t *testing.T) {
tests := []struct {
name string
host string
expected bool
}{
{"default empty", "", true},
{"localhost no port", "localhost", true},
{"localhost with port", "localhost:11435", true},
{"127.0.0.1 no port", "127.0.0.1", true},
{"127.0.0.1 with port", "127.0.0.1:11434", true},
{"0.0.0.0 no port", "0.0.0.0", true},
{"0.0.0.0 with port", "0.0.0.0:11434", true},
{"::1 no port", "::1", true},
{"[::1] with port", "[::1]:11434", true},
{"loopback with scheme", "http://localhost:11434", true},
{"remote hostname", "example.com", false},
{"remote hostname with port", "example.com:11434", false},
{"remote IP", "192.168.1.1", false},
{"remote IP with port", "192.168.1.1:11434", false},
{"remote with scheme", "http://example.com:11434", false},
{"https remote", "https://example.com:443", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Setenv("OLLAMA_HOST", tt.host)
got := isLocalhost()
if got != tt.expected {
t.Errorf("isLocalhost() with OLLAMA_HOST=%q = %v, want %v", tt.host, got, tt.expected)
}
})
}
}

View File

@@ -1,192 +0,0 @@
package config
import (
"context"
"fmt"
"os"
"os/exec"
"path/filepath"
"runtime"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/envconfig"
)
// Claude implements Runner and AliasConfigurer for Claude Code integration
type Claude struct{}
// Compile-time check that Claude implements AliasConfigurer
var _ AliasConfigurer = (*Claude)(nil)
func (c *Claude) String() string { return "Claude Code" }
func (c *Claude) args(model string, extra []string) []string {
var args []string
if model != "" {
args = append(args, "--model", model)
}
args = append(args, extra...)
return args
}
func (c *Claude) findPath() (string, error) {
if p, err := exec.LookPath("claude"); err == nil {
return p, nil
}
home, err := os.UserHomeDir()
if err != nil {
return "", err
}
name := "claude"
if runtime.GOOS == "windows" {
name = "claude.exe"
}
fallback := filepath.Join(home, ".claude", "local", name)
if _, err := os.Stat(fallback); err != nil {
return "", err
}
return fallback, nil
}
func (c *Claude) Run(model string, args []string) error {
claudePath, err := c.findPath()
if err != nil {
return fmt.Errorf("claude is not installed, install from https://code.claude.com/docs/en/quickstart")
}
cmd := exec.Command(claudePath, c.args(model, args)...)
cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
env := append(os.Environ(),
"ANTHROPIC_BASE_URL="+envconfig.Host().String(),
"ANTHROPIC_API_KEY=",
"ANTHROPIC_AUTH_TOKEN=ollama",
)
env = append(env, c.modelEnvVars(model)...)
cmd.Env = env
return cmd.Run()
}
// modelEnvVars returns Claude Code env vars that route all model tiers through Ollama.
func (c *Claude) modelEnvVars(model string) []string {
primary := model
fast := model
if cfg, err := loadIntegration("claude"); err == nil && cfg.Aliases != nil {
if p := cfg.Aliases["primary"]; p != "" {
primary = p
}
if f := cfg.Aliases["fast"]; f != "" {
fast = f
}
}
return []string{
"ANTHROPIC_DEFAULT_OPUS_MODEL=" + primary,
"ANTHROPIC_DEFAULT_SONNET_MODEL=" + primary,
"ANTHROPIC_DEFAULT_HAIKU_MODEL=" + fast,
"CLAUDE_CODE_SUBAGENT_MODEL=" + primary,
}
}
// ConfigureAliases sets up model aliases for Claude Code.
// model: the model to use (if empty, user will be prompted to select)
// aliases: existing alias configuration to preserve/update
// Cloud-only: subagent routing (fast model) is gated to cloud models only until
// there is a better strategy for prompt caching on local models.
func (c *Claude) ConfigureAliases(ctx context.Context, model string, existingAliases map[string]string, force bool) (map[string]string, bool, error) {
aliases := make(map[string]string)
for k, v := range existingAliases {
aliases[k] = v
}
if model != "" {
aliases["primary"] = model
}
if !force && aliases["primary"] != "" {
client, _ := api.ClientFromEnvironment()
if isCloudModel(ctx, client, aliases["primary"]) {
if isCloudModel(ctx, client, aliases["fast"]) {
return aliases, false, nil
}
} else {
delete(aliases, "fast")
return aliases, false, nil
}
}
items, existingModels, cloudModels, client, err := listModels(ctx)
if err != nil {
return nil, false, err
}
fmt.Fprintf(os.Stderr, "\n%sModel Configuration%s\n\n", ansiBold, ansiReset)
if aliases["primary"] == "" || force {
primary, err := DefaultSingleSelector("Select model:", items, aliases["primary"])
if err != nil {
return nil, false, err
}
if err := pullIfNeeded(ctx, client, existingModels, primary); err != nil {
return nil, false, err
}
if err := ensureAuth(ctx, client, cloudModels, []string{primary}); err != nil {
return nil, false, err
}
aliases["primary"] = primary
}
if isCloudModel(ctx, client, aliases["primary"]) {
if aliases["fast"] == "" || !isCloudModel(ctx, client, aliases["fast"]) {
aliases["fast"] = aliases["primary"]
}
} else {
delete(aliases, "fast")
}
return aliases, true, nil
}
// SetAliases syncs the configured aliases to the Ollama server using prefix matching.
// Cloud-only: for local models (fast is empty), we delete any existing aliases to
// prevent stale routing to a previous cloud model.
func (c *Claude) SetAliases(ctx context.Context, aliases map[string]string) error {
client, err := api.ClientFromEnvironment()
if err != nil {
return err
}
prefixes := []string{"claude-sonnet-", "claude-haiku-"}
if aliases["fast"] == "" {
for _, prefix := range prefixes {
_ = client.DeleteAliasExperimental(ctx, &api.AliasDeleteRequest{Alias: prefix})
}
return nil
}
prefixAliases := map[string]string{
"claude-sonnet-": aliases["primary"],
"claude-haiku-": aliases["fast"],
}
var errs []string
for prefix, target := range prefixAliases {
req := &api.AliasRequest{
Alias: prefix,
Target: target,
PrefixMatching: true,
}
if err := client.SetAliasExperimental(ctx, req); err != nil {
errs = append(errs, prefix)
}
}
if len(errs) > 0 {
return fmt.Errorf("failed to set aliases: %v", errs)
}
return nil
}

View File

@@ -1,67 +0,0 @@
package config
import (
"fmt"
"os"
"os/exec"
"strings"
"github.com/ollama/ollama/envconfig"
"golang.org/x/mod/semver"
)
// Codex implements Runner for Codex integration
type Codex struct{}
func (c *Codex) String() string { return "Codex" }
func (c *Codex) args(model string, extra []string) []string {
args := []string{"--oss"}
if model != "" {
args = append(args, "-m", model)
}
args = append(args, extra...)
return args
}
func (c *Codex) Run(model string, args []string) error {
if err := checkCodexVersion(); err != nil {
return err
}
cmd := exec.Command("codex", c.args(model, args)...)
cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
cmd.Env = append(os.Environ(),
"OPENAI_BASE_URL="+envconfig.Host().String()+"/v1/",
"OPENAI_API_KEY=ollama",
)
return cmd.Run()
}
func checkCodexVersion() error {
if _, err := exec.LookPath("codex"); err != nil {
return fmt.Errorf("codex is not installed, install with: npm install -g @openai/codex")
}
out, err := exec.Command("codex", "--version").Output()
if err != nil {
return fmt.Errorf("failed to get codex version: %w", err)
}
// Parse output like "codex-cli 0.87.0"
fields := strings.Fields(strings.TrimSpace(string(out)))
if len(fields) < 2 {
return fmt.Errorf("unexpected codex version output: %s", string(out))
}
version := "v" + fields[len(fields)-1]
minVersion := "v0.81.0"
if semver.Compare(version, minVersion) < 0 {
return fmt.Errorf("codex version %s is too old, minimum required is %s, update with: npm update -g @openai/codex", fields[len(fields)-1], "0.81.0")
}
return nil
}

View File

@@ -1,31 +0,0 @@
package config
import (
"slices"
"testing"
)
func TestCodexArgs(t *testing.T) {
c := &Codex{}
tests := []struct {
name string
model string
args []string
want []string
}{
{"with model", "llama3.2", nil, []string{"--oss", "-m", "llama3.2"}},
{"empty model", "", nil, []string{"--oss"}},
{"with model and profile", "qwen3-coder", []string{"-p", "myprofile"}, []string{"--oss", "-m", "qwen3-coder", "-p", "myprofile"}},
{"with sandbox flag", "llama3.2", []string{"--sandbox", "workspace-write"}, []string{"--oss", "-m", "llama3.2", "--sandbox", "workspace-write"}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := c.args(tt.model, tt.args)
if !slices.Equal(got, tt.want) {
t.Errorf("args(%q, %v) = %v, want %v", tt.model, tt.args, got, tt.want)
}
})
}
}

View File

@@ -3,7 +3,6 @@
package config package config
import ( import (
"context"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
@@ -11,7 +10,7 @@ import (
"path/filepath" "path/filepath"
"strings" "strings"
"github.com/ollama/ollama/api" "github.com/ollama/ollama/cmd/internal/fileutil"
) )
type integration struct { type integration struct {
@@ -20,6 +19,9 @@ type integration struct {
Onboarded bool `json:"onboarded,omitempty"` Onboarded bool `json:"onboarded,omitempty"`
} }
// IntegrationConfig is the persisted config for one integration.
type IntegrationConfig = integration
type config struct { type config struct {
Integrations map[string]*integration `json:"integrations"` Integrations map[string]*integration `json:"integrations"`
LastModel string `json:"last_model,omitempty"` LastModel string `json:"last_model,omitempty"`
@@ -124,7 +126,7 @@ func save(cfg *config) error {
return err return err
} }
return writeWithBackup(path, data) return fileutil.WriteWithBackup(path, data)
} }
func SaveIntegration(appName string, models []string) error { func SaveIntegration(appName string, models []string) error {
@@ -155,8 +157,8 @@ func SaveIntegration(appName string, models []string) error {
return save(cfg) return save(cfg)
} }
// integrationOnboarded marks an integration as onboarded in ollama's config. // MarkIntegrationOnboarded marks an integration as onboarded in Ollama's config.
func integrationOnboarded(appName string) error { func MarkIntegrationOnboarded(appName string) error {
cfg, err := load() cfg, err := load()
if err != nil { if err != nil {
return err return err
@@ -174,7 +176,7 @@ func integrationOnboarded(appName string) error {
// IntegrationModel returns the first configured model for an integration, or empty string if not configured. // IntegrationModel returns the first configured model for an integration, or empty string if not configured.
func IntegrationModel(appName string) string { func IntegrationModel(appName string) string {
integrationConfig, err := loadIntegration(appName) integrationConfig, err := LoadIntegration(appName)
if err != nil || len(integrationConfig.Models) == 0 { if err != nil || len(integrationConfig.Models) == 0 {
return "" return ""
} }
@@ -183,7 +185,7 @@ func IntegrationModel(appName string) string {
// IntegrationModels returns all configured models for an integration, or nil. // IntegrationModels returns all configured models for an integration, or nil.
func IntegrationModels(appName string) []string { func IntegrationModels(appName string) []string {
integrationConfig, err := loadIntegration(appName) integrationConfig, err := LoadIntegration(appName)
if err != nil || len(integrationConfig.Models) == 0 { if err != nil || len(integrationConfig.Models) == 0 {
return nil return nil
} }
@@ -228,28 +230,8 @@ func SetLastSelection(selection string) error {
return save(cfg) return save(cfg)
} }
// ModelExists checks if a model exists on the Ollama server. // LoadIntegration returns the saved config for one integration.
func ModelExists(ctx context.Context, name string) bool { func LoadIntegration(appName string) (*integration, error) {
if name == "" {
return false
}
client, err := api.ClientFromEnvironment()
if err != nil {
return false
}
models, err := client.List(ctx)
if err != nil {
return false
}
for _, m := range models.Models {
if m.Name == name || strings.HasPrefix(m.Name, name+":") {
return true
}
}
return false
}
func loadIntegration(appName string) (*integration, error) {
cfg, err := load() cfg, err := load()
if err != nil { if err != nil {
return nil, err return nil, err
@@ -263,7 +245,8 @@ func loadIntegration(appName string) (*integration, error) {
return integrationConfig, nil return integrationConfig, nil
} }
func saveAliases(appName string, aliases map[string]string) error { // SaveAliases replaces the saved aliases for one integration.
func SaveAliases(appName string, aliases map[string]string) error {
if appName == "" { if appName == "" {
return errors.New("app name cannot be empty") return errors.New("app name cannot be empty")
} }

View File

@@ -1,7 +1,6 @@
package config package config
import ( import (
"context"
"errors" "errors"
"os" "os"
"path/filepath" "path/filepath"
@@ -45,12 +44,12 @@ func TestSaveAliases_ReplacesNotMerges(t *testing.T) {
"primary": "cloud-model", "primary": "cloud-model",
"fast": "cloud-model", "fast": "cloud-model",
} }
if err := saveAliases("claude", initial); err != nil { if err := SaveAliases("claude", initial); err != nil {
t.Fatalf("failed to save initial aliases: %v", err) t.Fatalf("failed to save initial aliases: %v", err)
} }
// Verify both are saved // Verify both are saved
loaded, err := loadIntegration("claude") loaded, err := LoadIntegration("claude")
if err != nil { if err != nil {
t.Fatalf("failed to load: %v", err) t.Fatalf("failed to load: %v", err)
} }
@@ -63,12 +62,12 @@ func TestSaveAliases_ReplacesNotMerges(t *testing.T) {
"primary": "local-model", "primary": "local-model",
// fast intentionally missing // fast intentionally missing
} }
if err := saveAliases("claude", updated); err != nil { if err := SaveAliases("claude", updated); err != nil {
t.Fatalf("failed to save updated aliases: %v", err) t.Fatalf("failed to save updated aliases: %v", err)
} }
// Verify fast is GONE (not merged/preserved) // Verify fast is GONE (not merged/preserved)
loaded, err = loadIntegration("claude") loaded, err = LoadIntegration("claude")
if err != nil { if err != nil {
t.Fatalf("failed to load after update: %v", err) t.Fatalf("failed to load after update: %v", err)
} }
@@ -91,12 +90,12 @@ func TestSaveAliases_PreservesModels(t *testing.T) {
// Then update aliases // Then update aliases
aliases := map[string]string{"primary": "new-model"} aliases := map[string]string{"primary": "new-model"}
if err := saveAliases("claude", aliases); err != nil { if err := SaveAliases("claude", aliases); err != nil {
t.Fatalf("failed to save aliases: %v", err) t.Fatalf("failed to save aliases: %v", err)
} }
// Verify models are preserved // Verify models are preserved
loaded, err := loadIntegration("claude") loaded, err := LoadIntegration("claude")
if err != nil { if err != nil {
t.Fatalf("failed to load: %v", err) t.Fatalf("failed to load: %v", err)
} }
@@ -111,16 +110,16 @@ func TestSaveAliases_EmptyMap(t *testing.T) {
setTestHome(t, tmpDir) setTestHome(t, tmpDir)
// Save with aliases // Save with aliases
if err := saveAliases("claude", map[string]string{"primary": "model", "fast": "model"}); err != nil { if err := SaveAliases("claude", map[string]string{"primary": "model", "fast": "model"}); err != nil {
t.Fatalf("failed to save: %v", err) t.Fatalf("failed to save: %v", err)
} }
// Save empty map // Save empty map
if err := saveAliases("claude", map[string]string{}); err != nil { if err := SaveAliases("claude", map[string]string{}); err != nil {
t.Fatalf("failed to save empty: %v", err) t.Fatalf("failed to save empty: %v", err)
} }
loaded, err := loadIntegration("claude") loaded, err := LoadIntegration("claude")
if err != nil { if err != nil {
t.Fatalf("failed to load: %v", err) t.Fatalf("failed to load: %v", err)
} }
@@ -135,16 +134,16 @@ func TestSaveAliases_NilMap(t *testing.T) {
setTestHome(t, tmpDir) setTestHome(t, tmpDir)
// Save with aliases first // Save with aliases first
if err := saveAliases("claude", map[string]string{"primary": "model"}); err != nil { if err := SaveAliases("claude", map[string]string{"primary": "model"}); err != nil {
t.Fatalf("failed to save: %v", err) t.Fatalf("failed to save: %v", err)
} }
// Save nil map - should clear aliases // Save nil map - should clear aliases
if err := saveAliases("claude", nil); err != nil { if err := SaveAliases("claude", nil); err != nil {
t.Fatalf("failed to save nil: %v", err) t.Fatalf("failed to save nil: %v", err)
} }
loaded, err := loadIntegration("claude") loaded, err := LoadIntegration("claude")
if err != nil { if err != nil {
t.Fatalf("failed to load: %v", err) t.Fatalf("failed to load: %v", err)
} }
@@ -155,7 +154,7 @@ func TestSaveAliases_NilMap(t *testing.T) {
// TestSaveAliases_EmptyAppName returns error // TestSaveAliases_EmptyAppName returns error
func TestSaveAliases_EmptyAppName(t *testing.T) { func TestSaveAliases_EmptyAppName(t *testing.T) {
err := saveAliases("", map[string]string{"primary": "model"}) err := SaveAliases("", map[string]string{"primary": "model"})
if err == nil { if err == nil {
t.Error("expected error for empty app name") t.Error("expected error for empty app name")
} }
@@ -165,12 +164,12 @@ func TestSaveAliases_CaseInsensitive(t *testing.T) {
tmpDir := t.TempDir() tmpDir := t.TempDir()
setTestHome(t, tmpDir) setTestHome(t, tmpDir)
if err := saveAliases("Claude", map[string]string{"primary": "model1"}); err != nil { if err := SaveAliases("Claude", map[string]string{"primary": "model1"}); err != nil {
t.Fatalf("failed to save: %v", err) t.Fatalf("failed to save: %v", err)
} }
// Load with different case // Load with different case
loaded, err := loadIntegration("claude") loaded, err := LoadIntegration("claude")
if err != nil { if err != nil {
t.Fatalf("failed to load: %v", err) t.Fatalf("failed to load: %v", err)
} }
@@ -179,11 +178,11 @@ func TestSaveAliases_CaseInsensitive(t *testing.T) {
} }
// Update with different case // Update with different case
if err := saveAliases("CLAUDE", map[string]string{"primary": "model2"}); err != nil { if err := SaveAliases("CLAUDE", map[string]string{"primary": "model2"}); err != nil {
t.Fatalf("failed to update: %v", err) t.Fatalf("failed to update: %v", err)
} }
loaded, err = loadIntegration("claude") loaded, err = LoadIntegration("claude")
if err != nil { if err != nil {
t.Fatalf("failed to load after update: %v", err) t.Fatalf("failed to load after update: %v", err)
} }
@@ -198,11 +197,11 @@ func TestSaveAliases_CreatesIntegration(t *testing.T) {
setTestHome(t, tmpDir) setTestHome(t, tmpDir)
// Save aliases for non-existent integration // Save aliases for non-existent integration
if err := saveAliases("newintegration", map[string]string{"primary": "model"}); err != nil { if err := SaveAliases("newintegration", map[string]string{"primary": "model"}); err != nil {
t.Fatalf("failed to save: %v", err) t.Fatalf("failed to save: %v", err)
} }
loaded, err := loadIntegration("newintegration") loaded, err := LoadIntegration("newintegration")
if err != nil { if err != nil {
t.Fatalf("failed to load: %v", err) t.Fatalf("failed to load: %v", err)
} }
@@ -371,12 +370,12 @@ func TestAtomicUpdate_ServerSucceedsConfigSaved(t *testing.T) {
t.Fatal("server should succeed") t.Fatal("server should succeed")
} }
if err := saveAliases("claude", map[string]string{"primary": "model"}); err != nil { if err := SaveAliases("claude", map[string]string{"primary": "model"}); err != nil {
t.Fatalf("saveAliases failed: %v", err) t.Fatalf("saveAliases failed: %v", err)
} }
// Verify it was actually saved // Verify it was actually saved
loaded, err := loadIntegration("claude") loaded, err := LoadIntegration("claude")
if err != nil { if err != nil {
t.Fatalf("failed to load: %v", err) t.Fatalf("failed to load: %v", err)
} }
@@ -408,7 +407,7 @@ func TestConfigFile_PreservesUnknownFields(t *testing.T) {
os.WriteFile(configPath, []byte(initialConfig), 0o644) os.WriteFile(configPath, []byte(initialConfig), 0o644)
// Update aliases // Update aliases
if err := saveAliases("claude", map[string]string{"primary": "model2"}); err != nil { if err := SaveAliases("claude", map[string]string{"primary": "model2"}); err != nil {
t.Fatalf("failed to save: %v", err) t.Fatalf("failed to save: %v", err)
} }
@@ -440,11 +439,6 @@ func containsHelper(s, substr string) bool {
return false return false
} }
func TestClaudeImplementsAliasConfigurer(t *testing.T) {
c := &Claude{}
var _ AliasConfigurer = c // Compile-time check
}
func TestModelNameEdgeCases(t *testing.T) { func TestModelNameEdgeCases(t *testing.T) {
testCases := []struct { testCases := []struct {
name string name string
@@ -464,11 +458,11 @@ func TestModelNameEdgeCases(t *testing.T) {
setTestHome(t, tmpDir) setTestHome(t, tmpDir)
aliases := map[string]string{"primary": tc.model} aliases := map[string]string{"primary": tc.model}
if err := saveAliases("claude", aliases); err != nil { if err := SaveAliases("claude", aliases); err != nil {
t.Fatalf("failed to save model %q: %v", tc.model, err) t.Fatalf("failed to save model %q: %v", tc.model, err)
} }
loaded, err := loadIntegration("claude") loaded, err := LoadIntegration("claude")
if err != nil { if err != nil {
t.Fatalf("failed to load: %v", err) t.Fatalf("failed to load: %v", err)
} }
@@ -485,7 +479,7 @@ func TestSwitchingScenarios(t *testing.T) {
setTestHome(t, tmpDir) setTestHome(t, tmpDir)
// Initial cloud config // Initial cloud config
if err := saveAliases("claude", map[string]string{ if err := SaveAliases("claude", map[string]string{
"primary": "cloud-model", "primary": "cloud-model",
"fast": "cloud-model", "fast": "cloud-model",
}); err != nil { }); err != nil {
@@ -493,13 +487,13 @@ func TestSwitchingScenarios(t *testing.T) {
} }
// Switch to local (no fast) // Switch to local (no fast)
if err := saveAliases("claude", map[string]string{ if err := SaveAliases("claude", map[string]string{
"primary": "local-model", "primary": "local-model",
}); err != nil { }); err != nil {
t.Fatal(err) t.Fatal(err)
} }
loaded, _ := loadIntegration("claude") loaded, _ := LoadIntegration("claude")
if loaded.Aliases["fast"] != "" { if loaded.Aliases["fast"] != "" {
t.Errorf("fast should be removed, got %q", loaded.Aliases["fast"]) t.Errorf("fast should be removed, got %q", loaded.Aliases["fast"])
} }
@@ -513,21 +507,21 @@ func TestSwitchingScenarios(t *testing.T) {
setTestHome(t, tmpDir) setTestHome(t, tmpDir)
// Initial local config // Initial local config
if err := saveAliases("claude", map[string]string{ if err := SaveAliases("claude", map[string]string{
"primary": "local-model", "primary": "local-model",
}); err != nil { }); err != nil {
t.Fatal(err) t.Fatal(err)
} }
// Switch to cloud (with fast) // Switch to cloud (with fast)
if err := saveAliases("claude", map[string]string{ if err := SaveAliases("claude", map[string]string{
"primary": "cloud-model", "primary": "cloud-model",
"fast": "cloud-model", "fast": "cloud-model",
}); err != nil { }); err != nil {
t.Fatal(err) t.Fatal(err)
} }
loaded, _ := loadIntegration("claude") loaded, _ := LoadIntegration("claude")
if loaded.Aliases["fast"] != "cloud-model" { if loaded.Aliases["fast"] != "cloud-model" {
t.Errorf("fast should be cloud-model, got %q", loaded.Aliases["fast"]) t.Errorf("fast should be cloud-model, got %q", loaded.Aliases["fast"])
} }
@@ -538,7 +532,7 @@ func TestSwitchingScenarios(t *testing.T) {
setTestHome(t, tmpDir) setTestHome(t, tmpDir)
// Initial cloud config // Initial cloud config
if err := saveAliases("claude", map[string]string{ if err := SaveAliases("claude", map[string]string{
"primary": "cloud-model-1", "primary": "cloud-model-1",
"fast": "cloud-model-1", "fast": "cloud-model-1",
}); err != nil { }); err != nil {
@@ -546,14 +540,14 @@ func TestSwitchingScenarios(t *testing.T) {
} }
// Switch to different cloud // Switch to different cloud
if err := saveAliases("claude", map[string]string{ if err := SaveAliases("claude", map[string]string{
"primary": "cloud-model-2", "primary": "cloud-model-2",
"fast": "cloud-model-2", "fast": "cloud-model-2",
}); err != nil { }); err != nil {
t.Fatal(err) t.Fatal(err)
} }
loaded, _ := loadIntegration("claude") loaded, _ := LoadIntegration("claude")
if loaded.Aliases["primary"] != "cloud-model-2" { if loaded.Aliases["primary"] != "cloud-model-2" {
t.Errorf("primary should be cloud-model-2, got %q", loaded.Aliases["primary"]) t.Errorf("primary should be cloud-model-2, got %q", loaded.Aliases["primary"])
} }
@@ -563,43 +557,13 @@ func TestSwitchingScenarios(t *testing.T) {
}) })
} }
func TestToolCapabilityFiltering(t *testing.T) {
t.Run("all models checked for tool capability", func(t *testing.T) {
// Both cloud and local models are checked for tool capability via Show API
// Only models with "tools" in capabilities are included
m := modelInfo{Name: "tool-model", Remote: false, ToolCapable: true}
if !m.ToolCapable {
t.Error("tool capable model should be marked as such")
}
})
t.Run("modelInfo includes ToolCapable field", func(t *testing.T) {
m := modelInfo{Name: "test", Remote: true, ToolCapable: true}
if !m.ToolCapable {
t.Error("ToolCapable field should be accessible")
}
})
}
func TestIsCloudModel_RequiresClient(t *testing.T) {
t.Run("nil client always returns false", func(t *testing.T) {
// isCloudModel now only uses Show API, no suffix detection
if isCloudModel(context.Background(), nil, "model:cloud") {
t.Error("nil client should return false regardless of suffix")
}
if isCloudModel(context.Background(), nil, "local-model") {
t.Error("nil client should return false")
}
})
}
func TestModelsAndAliasesMustStayInSync(t *testing.T) { func TestModelsAndAliasesMustStayInSync(t *testing.T) {
t.Run("saveAliases followed by saveIntegration keeps them in sync", func(t *testing.T) { t.Run("saveAliases followed by saveIntegration keeps them in sync", func(t *testing.T) {
tmpDir := t.TempDir() tmpDir := t.TempDir()
setTestHome(t, tmpDir) setTestHome(t, tmpDir)
// Save aliases with one model // Save aliases with one model
if err := saveAliases("claude", map[string]string{"primary": "model-a"}); err != nil { if err := SaveAliases("claude", map[string]string{"primary": "model-a"}); err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -608,7 +572,7 @@ func TestModelsAndAliasesMustStayInSync(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
loaded, _ := loadIntegration("claude") loaded, _ := LoadIntegration("claude")
if loaded.Aliases["primary"] != loaded.Models[0] { if loaded.Aliases["primary"] != loaded.Models[0] {
t.Errorf("aliases.primary (%q) != models[0] (%q)", loaded.Aliases["primary"], loaded.Models[0]) t.Errorf("aliases.primary (%q) != models[0] (%q)", loaded.Aliases["primary"], loaded.Models[0])
} }
@@ -622,11 +586,11 @@ func TestModelsAndAliasesMustStayInSync(t *testing.T) {
if err := SaveIntegration("claude", []string{"old-model"}); err != nil { if err := SaveIntegration("claude", []string{"old-model"}); err != nil {
t.Fatal(err) t.Fatal(err)
} }
if err := saveAliases("claude", map[string]string{"primary": "new-model"}); err != nil { if err := SaveAliases("claude", map[string]string{"primary": "new-model"}); err != nil {
t.Fatal(err) t.Fatal(err)
} }
loaded, _ := loadIntegration("claude") loaded, _ := LoadIntegration("claude")
// They should be different (this is the bug state) // They should be different (this is the bug state)
if loaded.Models[0] == loaded.Aliases["primary"] { if loaded.Models[0] == loaded.Aliases["primary"] {
@@ -638,7 +602,7 @@ func TestModelsAndAliasesMustStayInSync(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
loaded, _ = loadIntegration("claude") loaded, _ = LoadIntegration("claude")
if loaded.Models[0] != loaded.Aliases["primary"] { if loaded.Models[0] != loaded.Aliases["primary"] {
t.Errorf("after fix: models[0] (%q) should equal aliases.primary (%q)", t.Errorf("after fix: models[0] (%q) should equal aliases.primary (%q)",
loaded.Models[0], loaded.Aliases["primary"]) loaded.Models[0], loaded.Aliases["primary"])
@@ -653,20 +617,20 @@ func TestModelsAndAliasesMustStayInSync(t *testing.T) {
if err := SaveIntegration("claude", []string{"initial-model"}); err != nil { if err := SaveIntegration("claude", []string{"initial-model"}); err != nil {
t.Fatal(err) t.Fatal(err)
} }
if err := saveAliases("claude", map[string]string{"primary": "initial-model"}); err != nil { if err := SaveAliases("claude", map[string]string{"primary": "initial-model"}); err != nil {
t.Fatal(err) t.Fatal(err)
} }
// Update aliases AND models together // Update aliases AND models together
newAliases := map[string]string{"primary": "updated-model"} newAliases := map[string]string{"primary": "updated-model"}
if err := saveAliases("claude", newAliases); err != nil { if err := SaveAliases("claude", newAliases); err != nil {
t.Fatal(err) t.Fatal(err)
} }
if err := SaveIntegration("claude", []string{newAliases["primary"]}); err != nil { if err := SaveIntegration("claude", []string{newAliases["primary"]}); err != nil {
t.Fatal(err) t.Fatal(err)
} }
loaded, _ := loadIntegration("claude") loaded, _ := LoadIntegration("claude")
if loaded.Models[0] != "updated-model" { if loaded.Models[0] != "updated-model" {
t.Errorf("models[0] should be updated-model, got %q", loaded.Models[0]) t.Errorf("models[0] should be updated-model, got %q", loaded.Models[0])
} }

View File

@@ -10,17 +10,10 @@ import (
// setTestHome sets both HOME (Unix) and USERPROFILE (Windows) for cross-platform tests // setTestHome sets both HOME (Unix) and USERPROFILE (Windows) for cross-platform tests
func setTestHome(t *testing.T, dir string) { func setTestHome(t *testing.T, dir string) {
t.Setenv("HOME", dir) t.Setenv("HOME", dir)
t.Setenv("TMPDIR", dir)
t.Setenv("USERPROFILE", dir) t.Setenv("USERPROFILE", dir)
} }
// editorPaths is a test helper that safely calls Paths if the runner implements Editor
func editorPaths(r Runner) []string {
if editor, ok := r.(Editor); ok {
return editor.Paths()
}
return nil
}
func TestIntegrationConfig(t *testing.T) { func TestIntegrationConfig(t *testing.T) {
tmpDir := t.TempDir() tmpDir := t.TempDir()
setTestHome(t, tmpDir) setTestHome(t, tmpDir)
@@ -31,7 +24,7 @@ func TestIntegrationConfig(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
config, err := loadIntegration("claude") config, err := LoadIntegration("claude")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -55,11 +48,11 @@ func TestIntegrationConfig(t *testing.T) {
"primary": "llama3.2:70b", "primary": "llama3.2:70b",
"fast": "llama3.2:8b", "fast": "llama3.2:8b",
} }
if err := saveAliases("claude", aliases); err != nil { if err := SaveAliases("claude", aliases); err != nil {
t.Fatal(err) t.Fatal(err)
} }
config, err := loadIntegration("claude") config, err := LoadIntegration("claude")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -77,14 +70,14 @@ func TestIntegrationConfig(t *testing.T) {
if err := SaveIntegration("claude", []string{"model-a"}); err != nil { if err := SaveIntegration("claude", []string{"model-a"}); err != nil {
t.Fatal(err) t.Fatal(err)
} }
if err := saveAliases("claude", map[string]string{"primary": "model-a", "fast": "model-small"}); err != nil { if err := SaveAliases("claude", map[string]string{"primary": "model-a", "fast": "model-small"}); err != nil {
t.Fatal(err) t.Fatal(err)
} }
if err := SaveIntegration("claude", []string{"model-b"}); err != nil { if err := SaveIntegration("claude", []string{"model-b"}); err != nil {
t.Fatal(err) t.Fatal(err)
} }
config, err := loadIntegration("claude") config, err := LoadIntegration("claude")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -96,7 +89,7 @@ func TestIntegrationConfig(t *testing.T) {
t.Run("defaultModel returns first model", func(t *testing.T) { t.Run("defaultModel returns first model", func(t *testing.T) {
SaveIntegration("codex", []string{"model-a", "model-b"}) SaveIntegration("codex", []string{"model-a", "model-b"})
config, _ := loadIntegration("codex") config, _ := LoadIntegration("codex")
defaultModel := "" defaultModel := ""
if len(config.Models) > 0 { if len(config.Models) > 0 {
defaultModel = config.Models[0] defaultModel = config.Models[0]
@@ -120,7 +113,7 @@ func TestIntegrationConfig(t *testing.T) {
t.Run("app name is case-insensitive", func(t *testing.T) { t.Run("app name is case-insensitive", func(t *testing.T) {
SaveIntegration("Claude", []string{"model-x"}) SaveIntegration("Claude", []string{"model-x"})
config, err := loadIntegration("claude") config, err := LoadIntegration("claude")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -137,8 +130,8 @@ func TestIntegrationConfig(t *testing.T) {
SaveIntegration("app1", []string{"model-1"}) SaveIntegration("app1", []string{"model-1"})
SaveIntegration("app2", []string{"model-2"}) SaveIntegration("app2", []string{"model-2"})
config1, _ := loadIntegration("app1") config1, _ := LoadIntegration("app1")
config2, _ := loadIntegration("app2") config2, _ := LoadIntegration("app2")
defaultModel1 := "" defaultModel1 := ""
if len(config1.Models) > 0 { if len(config1.Models) > 0 {
@@ -185,64 +178,6 @@ func TestListIntegrations(t *testing.T) {
}) })
} }
func TestEditorPaths(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
t.Run("returns empty for claude (no Editor)", func(t *testing.T) {
r := integrations["claude"]
paths := editorPaths(r)
if len(paths) != 0 {
t.Errorf("expected no paths for claude, got %v", paths)
}
})
t.Run("returns empty for codex (no Editor)", func(t *testing.T) {
r := integrations["codex"]
paths := editorPaths(r)
if len(paths) != 0 {
t.Errorf("expected no paths for codex, got %v", paths)
}
})
t.Run("returns empty for droid when no config exists", func(t *testing.T) {
r := integrations["droid"]
paths := editorPaths(r)
if len(paths) != 0 {
t.Errorf("expected no paths, got %v", paths)
}
})
t.Run("returns path for droid when config exists", func(t *testing.T) {
settingsDir, _ := os.UserHomeDir()
settingsDir = filepath.Join(settingsDir, ".factory")
os.MkdirAll(settingsDir, 0o755)
os.WriteFile(filepath.Join(settingsDir, "settings.json"), []byte(`{}`), 0o644)
r := integrations["droid"]
paths := editorPaths(r)
if len(paths) != 1 {
t.Errorf("expected 1 path, got %d", len(paths))
}
})
t.Run("returns paths for opencode when configs exist", func(t *testing.T) {
home, _ := os.UserHomeDir()
configDir := filepath.Join(home, ".config", "opencode")
stateDir := filepath.Join(home, ".local", "state", "opencode")
os.MkdirAll(configDir, 0o755)
os.MkdirAll(stateDir, 0o755)
os.WriteFile(filepath.Join(configDir, "opencode.json"), []byte(`{}`), 0o644)
os.WriteFile(filepath.Join(stateDir, "model.json"), []byte(`{}`), 0o644)
r := integrations["opencode"]
paths := editorPaths(r)
if len(paths) != 2 {
t.Errorf("expected 2 paths, got %d: %v", len(paths), paths)
}
})
}
func TestLoadIntegration_CorruptedJSON(t *testing.T) { func TestLoadIntegration_CorruptedJSON(t *testing.T) {
tmpDir := t.TempDir() tmpDir := t.TempDir()
setTestHome(t, tmpDir) setTestHome(t, tmpDir)
@@ -251,7 +186,7 @@ func TestLoadIntegration_CorruptedJSON(t *testing.T) {
os.MkdirAll(dir, 0o755) os.MkdirAll(dir, 0o755)
os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{corrupted json`), 0o644) os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{corrupted json`), 0o644)
_, err := loadIntegration("test") _, err := LoadIntegration("test")
if err == nil { if err == nil {
t.Error("expected error for nonexistent integration in corrupted file") t.Error("expected error for nonexistent integration in corrupted file")
} }
@@ -265,7 +200,7 @@ func TestSaveIntegration_NilModels(t *testing.T) {
t.Fatalf("saveIntegration with nil models failed: %v", err) t.Fatalf("saveIntegration with nil models failed: %v", err)
} }
config, err := loadIntegration("test") config, err := LoadIntegration("test")
if err != nil { if err != nil {
t.Fatalf("loadIntegration failed: %v", err) t.Fatalf("loadIntegration failed: %v", err)
} }
@@ -294,7 +229,7 @@ func TestLoadIntegration_NonexistentIntegration(t *testing.T) {
tmpDir := t.TempDir() tmpDir := t.TempDir()
setTestHome(t, tmpDir) setTestHome(t, tmpDir)
_, err := loadIntegration("nonexistent") _, err := LoadIntegration("nonexistent")
if err == nil { if err == nil {
t.Error("expected error for nonexistent integration, got nil") t.Error("expected error for nonexistent integration, got nil")
} }

File diff suppressed because it is too large Load Diff

View File

@@ -1,59 +0,0 @@
package config
import (
"errors"
"fmt"
"os"
"golang.org/x/term"
)
// ANSI escape sequences for terminal formatting.
const (
ansiBold = "\033[1m"
ansiReset = "\033[0m"
ansiGray = "\033[37m"
ansiGreen = "\033[32m"
ansiYellow = "\033[33m"
)
// ErrCancelled is returned when the user cancels a selection.
var ErrCancelled = errors.New("cancelled")
// errCancelled is kept as an alias for backward compatibility within the package.
var errCancelled = ErrCancelled
// DefaultConfirmPrompt provides a TUI-based confirmation prompt.
// When set, confirmPrompt delegates to it instead of using raw terminal I/O.
var DefaultConfirmPrompt func(prompt string) (bool, error)
func confirmPrompt(prompt string) (bool, error) {
if DefaultConfirmPrompt != nil {
return DefaultConfirmPrompt(prompt)
}
fd := int(os.Stdin.Fd())
oldState, err := term.MakeRaw(fd)
if err != nil {
return false, err
}
defer term.Restore(fd, oldState)
fmt.Fprintf(os.Stderr, "%s (\033[1my\033[0m/n) ", prompt)
buf := make([]byte, 1)
for {
if _, err := os.Stdin.Read(buf); err != nil {
return false, err
}
switch buf[0] {
case 'Y', 'y', 13:
fmt.Fprintf(os.Stderr, "yes\r\n")
return true, nil
case 'N', 'n', 27, 3:
fmt.Fprintf(os.Stderr, "no\r\n")
return false, nil
}
}
}

View File

@@ -1,19 +0,0 @@
package config
import (
"testing"
)
func TestErrCancelled(t *testing.T) {
t.Run("NotNil", func(t *testing.T) {
if errCancelled == nil {
t.Error("errCancelled should not be nil")
}
})
t.Run("Message", func(t *testing.T) {
if errCancelled.Error() != "cancelled" {
t.Errorf("expected 'cancelled', got %q", errCancelled.Error())
}
})
}

View File

@@ -17,6 +17,7 @@ import (
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
"github.com/ollama/ollama/envconfig" "github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/internal/modelref"
"github.com/ollama/ollama/readline" "github.com/ollama/ollama/readline"
"github.com/ollama/ollama/types/errtypes" "github.com/ollama/ollama/types/errtypes"
"github.com/ollama/ollama/types/model" "github.com/ollama/ollama/types/model"
@@ -46,7 +47,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
fmt.Fprintln(os.Stderr, "Use \"\"\" to begin a multi-line message.") fmt.Fprintln(os.Stderr, "Use \"\"\" to begin a multi-line message.")
if opts.MultiModal { if opts.MultiModal {
fmt.Fprintf(os.Stderr, "Use %s to include .jpg, .png, or .webp images.\n", filepath.FromSlash("/path/to/file")) fmt.Fprintf(os.Stderr, "Use %s to include .jpg, .png, .webp images, or .wav audio files.\n", filepath.FromSlash("/path/to/file"))
} }
fmt.Fprintln(os.Stderr, "") fmt.Fprintln(os.Stderr, "")
@@ -540,6 +541,13 @@ func NewCreateRequest(name string, opts runOptions) *api.CreateRequest {
parentModel = "" parentModel = ""
} }
// Preserve explicit cloud intent for sessions started with `:cloud`.
// Cloud model metadata can return a source-less parent_model (for example
// "qwen3.5"), which would otherwise make `/save` create a local derivative.
if modelref.HasExplicitCloudSource(opts.Model) && !modelref.HasExplicitCloudSource(parentModel) {
parentModel = ""
}
req := &api.CreateRequest{ req := &api.CreateRequest{
Model: name, Model: name,
From: cmp.Or(parentModel, opts.Model), From: cmp.Or(parentModel, opts.Model),
@@ -584,7 +592,7 @@ func extractFileNames(input string) []string {
// Regex to match file paths starting with optional drive letter, / ./ \ or .\ and include escaped or unescaped spaces (\ or %20) // Regex to match file paths starting with optional drive letter, / ./ \ or .\ and include escaped or unescaped spaces (\ or %20)
// and followed by more characters and a file extension // and followed by more characters and a file extension
// This will capture non filename strings, but we'll check for file existence to remove mismatches // This will capture non filename strings, but we'll check for file existence to remove mismatches
regexPattern := `(?:[a-zA-Z]:)?(?:\./|/|\\)[\S\\ ]+?\.(?i:jpg|jpeg|png|webp)\b` regexPattern := `(?:[a-zA-Z]:)?(?:\./|/|\\)[\S\\ ]+?\.(?i:jpg|jpeg|png|webp|wav)\b`
re := regexp.MustCompile(regexPattern) re := regexp.MustCompile(regexPattern)
return re.FindAllString(input, -1) return re.FindAllString(input, -1)
@@ -600,10 +608,16 @@ func extractFileData(input string) (string, []api.ImageData, error) {
if errors.Is(err, os.ErrNotExist) { if errors.Is(err, os.ErrNotExist) {
continue continue
} else if err != nil { } else if err != nil {
fmt.Fprintf(os.Stderr, "Couldn't process image: %q\n", err) fmt.Fprintf(os.Stderr, "Couldn't process file: %q\n", err)
return "", imgs, err return "", imgs, err
} }
fmt.Fprintf(os.Stderr, "Added image '%s'\n", nfp) ext := strings.ToLower(filepath.Ext(nfp))
switch ext {
case ".wav":
fmt.Fprintf(os.Stderr, "Added audio '%s'\n", nfp)
default:
fmt.Fprintf(os.Stderr, "Added image '%s'\n", nfp)
}
input = strings.ReplaceAll(input, "'"+nfp+"'", "") input = strings.ReplaceAll(input, "'"+nfp+"'", "")
input = strings.ReplaceAll(input, "'"+fp+"'", "") input = strings.ReplaceAll(input, "'"+fp+"'", "")
input = strings.ReplaceAll(input, fp, "") input = strings.ReplaceAll(input, fp, "")
@@ -677,9 +691,9 @@ func getImageData(filePath string) ([]byte, error) {
} }
contentType := http.DetectContentType(buf) contentType := http.DetectContentType(buf)
allowedTypes := []string{"image/jpeg", "image/jpg", "image/png", "image/webp"} allowedTypes := []string{"image/jpeg", "image/jpg", "image/png", "image/webp", "audio/wave"}
if !slices.Contains(allowedTypes, contentType) { if !slices.Contains(allowedTypes, contentType) {
return nil, fmt.Errorf("invalid image type: %s", contentType) return nil, fmt.Errorf("invalid file type: %s", contentType)
} }
info, err := file.Stat() info, err := file.Stat()
@@ -687,8 +701,7 @@ func getImageData(filePath string) ([]byte, error) {
return nil, err return nil, err
} }
// Check if the file size exceeds 100MB var maxSize int64 = 100 * 1024 * 1024 // 100MB
var maxSize int64 = 100 * 1024 * 1024 // 100MB in bytes
if info.Size() > maxSize { if info.Size() > maxSize {
return nil, errors.New("file size exceeds maximum limit (100MB)") return nil, errors.New("file size exceeds maximum limit (100MB)")
} }

View File

@@ -84,3 +84,33 @@ func TestExtractFileDataRemovesQuotedFilepath(t *testing.T) {
assert.Len(t, imgs, 1) assert.Len(t, imgs, 1)
assert.Equal(t, cleaned, "before after") assert.Equal(t, cleaned, "before after")
} }
func TestExtractFileDataWAV(t *testing.T) {
dir := t.TempDir()
fp := filepath.Join(dir, "sample.wav")
data := make([]byte, 600)
copy(data[:44], []byte{
'R', 'I', 'F', 'F',
0x58, 0x02, 0x00, 0x00, // file size - 8
'W', 'A', 'V', 'E',
'f', 'm', 't', ' ',
0x10, 0x00, 0x00, 0x00, // fmt chunk size
0x01, 0x00, // PCM
0x01, 0x00, // mono
0x80, 0x3e, 0x00, 0x00, // 16000 Hz
0x00, 0x7d, 0x00, 0x00, // byte rate
0x02, 0x00, // block align
0x10, 0x00, // 16-bit
'd', 'a', 't', 'a',
0x34, 0x02, 0x00, 0x00, // data size
})
if err := os.WriteFile(fp, data, 0o600); err != nil {
t.Fatalf("failed to write test audio: %v", err)
}
input := "before " + fp + " after"
cleaned, imgs, err := extractFileData(input)
assert.NoError(t, err)
assert.Len(t, imgs, 1)
assert.Equal(t, "before after", cleaned)
}

View File

@@ -1,4 +1,6 @@
package config // Package fileutil provides small shared helpers for reading JSON files
// and writing config files with backup-on-overwrite semantics.
package fileutil
import ( import (
"bytes" "bytes"
@@ -9,7 +11,8 @@ import (
"time" "time"
) )
func readJSONFile(path string) (map[string]any, error) { // ReadJSON reads a JSON object file into a generic map.
func ReadJSON(path string) (map[string]any, error) {
data, err := os.ReadFile(path) data, err := os.ReadFile(path)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -33,12 +36,13 @@ func copyFile(src, dst string) error {
return os.WriteFile(dst, data, info.Mode().Perm()) return os.WriteFile(dst, data, info.Mode().Perm())
} }
func backupDir() string { // BackupDir returns the shared backup directory used before overwriting files.
func BackupDir() string {
return filepath.Join(os.TempDir(), "ollama-backups") return filepath.Join(os.TempDir(), "ollama-backups")
} }
func backupToTmp(srcPath string) (string, error) { func backupToTmp(srcPath string) (string, error) {
dir := backupDir() dir := BackupDir()
if err := os.MkdirAll(dir, 0o755); err != nil { if err := os.MkdirAll(dir, 0o755); err != nil {
return "", err return "", err
} }
@@ -50,8 +54,8 @@ func backupToTmp(srcPath string) (string, error) {
return backupPath, nil return backupPath, nil
} }
// writeWithBackup writes data to path via temp file + rename, backing up any existing file first // WriteWithBackup writes data to path via temp file + rename, backing up any existing file first.
func writeWithBackup(path string, data []byte) error { func WriteWithBackup(path string, data []byte) error {
var backupPath string var backupPath string
// backup must be created before any writes to the target file // backup must be created before any writes to the target file
if existingContent, err := os.ReadFile(path); err == nil { if existingContent, err := os.ReadFile(path); err == nil {

View File

@@ -1,4 +1,4 @@
package config package fileutil
import ( import (
"encoding/json" "encoding/json"
@@ -9,6 +9,21 @@ import (
"testing" "testing"
) )
func TestMain(m *testing.M) {
tmpRoot, err := os.MkdirTemp("", "fileutil-test-*")
if err != nil {
panic(err)
}
if err := os.Setenv("TMPDIR", tmpRoot); err != nil {
panic(err)
}
code := m.Run()
_ = os.RemoveAll(tmpRoot)
os.Exit(code)
}
func mustMarshal(t *testing.T, v any) []byte { func mustMarshal(t *testing.T, v any) []byte {
t.Helper() t.Helper()
data, err := json.MarshalIndent(v, "", " ") data, err := json.MarshalIndent(v, "", " ")
@@ -18,14 +33,19 @@ func mustMarshal(t *testing.T, v any) []byte {
return data return data
} }
func isolatedTempDir(t *testing.T) string {
t.Helper()
return t.TempDir()
}
func TestWriteWithBackup(t *testing.T) { func TestWriteWithBackup(t *testing.T) {
tmpDir := t.TempDir() tmpDir := isolatedTempDir(t)
t.Run("creates file", func(t *testing.T) { t.Run("creates file", func(t *testing.T) {
path := filepath.Join(tmpDir, "new.json") path := filepath.Join(tmpDir, "new.json")
data := mustMarshal(t, map[string]string{"key": "value"}) data := mustMarshal(t, map[string]string{"key": "value"})
if err := writeWithBackup(path, data); err != nil { if err := WriteWithBackup(path, data); err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -43,17 +63,17 @@ func TestWriteWithBackup(t *testing.T) {
} }
}) })
t.Run("creates backup in /tmp/ollama-backups", func(t *testing.T) { t.Run("creates backup in the temp backup directory", func(t *testing.T) {
path := filepath.Join(tmpDir, "backup.json") path := filepath.Join(tmpDir, "backup.json")
os.WriteFile(path, []byte(`{"original": true}`), 0o644) os.WriteFile(path, []byte(`{"original": true}`), 0o644)
data := mustMarshal(t, map[string]bool{"updated": true}) data := mustMarshal(t, map[string]bool{"updated": true})
if err := writeWithBackup(path, data); err != nil { if err := WriteWithBackup(path, data); err != nil {
t.Fatal(err) t.Fatal(err)
} }
entries, err := os.ReadDir(backupDir()) entries, err := os.ReadDir(BackupDir())
if err != nil { if err != nil {
t.Fatal("backup directory not created") t.Fatal("backup directory not created")
} }
@@ -63,7 +83,7 @@ func TestWriteWithBackup(t *testing.T) {
if filepath.Ext(entry.Name()) != ".json" { if filepath.Ext(entry.Name()) != ".json" {
name := entry.Name() name := entry.Name()
if len(name) > len("backup.json.") && name[:len("backup.json.")] == "backup.json." { if len(name) > len("backup.json.") && name[:len("backup.json.")] == "backup.json." {
backupPath := filepath.Join(backupDir(), name) backupPath := filepath.Join(BackupDir(), name)
backup, err := os.ReadFile(backupPath) backup, err := os.ReadFile(backupPath)
if err == nil { if err == nil {
var backupData map[string]bool var backupData map[string]bool
@@ -79,7 +99,7 @@ func TestWriteWithBackup(t *testing.T) {
} }
if !foundBackup { if !foundBackup {
t.Error("backup file not created in /tmp/ollama-backups") t.Error("backup file not created in backup directory")
} }
current, _ := os.ReadFile(path) current, _ := os.ReadFile(path)
@@ -94,11 +114,11 @@ func TestWriteWithBackup(t *testing.T) {
path := filepath.Join(tmpDir, "nobak.json") path := filepath.Join(tmpDir, "nobak.json")
data := mustMarshal(t, map[string]string{"new": "file"}) data := mustMarshal(t, map[string]string{"new": "file"})
if err := writeWithBackup(path, data); err != nil { if err := WriteWithBackup(path, data); err != nil {
t.Fatal(err) t.Fatal(err)
} }
entries, _ := os.ReadDir(backupDir()) entries, _ := os.ReadDir(BackupDir())
for _, entry := range entries { for _, entry := range entries {
if len(entry.Name()) > len("nobak.json.") && entry.Name()[:len("nobak.json.")] == "nobak.json." { if len(entry.Name()) > len("nobak.json.") && entry.Name()[:len("nobak.json.")] == "nobak.json." {
t.Error("backup should not exist for new file") t.Error("backup should not exist for new file")
@@ -111,11 +131,11 @@ func TestWriteWithBackup(t *testing.T) {
data := mustMarshal(t, map[string]string{"key": "value"}) data := mustMarshal(t, map[string]string{"key": "value"})
if err := writeWithBackup(path, data); err != nil { if err := WriteWithBackup(path, data); err != nil {
t.Fatal(err) t.Fatal(err)
} }
entries1, _ := os.ReadDir(backupDir()) entries1, _ := os.ReadDir(BackupDir())
countBefore := 0 countBefore := 0
for _, e := range entries1 { for _, e := range entries1 {
if len(e.Name()) > len("unchanged.json.") && e.Name()[:len("unchanged.json.")] == "unchanged.json." { if len(e.Name()) > len("unchanged.json.") && e.Name()[:len("unchanged.json.")] == "unchanged.json." {
@@ -123,11 +143,11 @@ func TestWriteWithBackup(t *testing.T) {
} }
} }
if err := writeWithBackup(path, data); err != nil { if err := WriteWithBackup(path, data); err != nil {
t.Fatal(err) t.Fatal(err)
} }
entries2, _ := os.ReadDir(backupDir()) entries2, _ := os.ReadDir(BackupDir())
countAfter := 0 countAfter := 0
for _, e := range entries2 { for _, e := range entries2 {
if len(e.Name()) > len("unchanged.json.") && e.Name()[:len("unchanged.json.")] == "unchanged.json." { if len(e.Name()) > len("unchanged.json.") && e.Name()[:len("unchanged.json.")] == "unchanged.json." {
@@ -145,11 +165,11 @@ func TestWriteWithBackup(t *testing.T) {
os.WriteFile(path, []byte(`{"v": 1}`), 0o644) os.WriteFile(path, []byte(`{"v": 1}`), 0o644)
data := mustMarshal(t, map[string]int{"v": 2}) data := mustMarshal(t, map[string]int{"v": 2})
if err := writeWithBackup(path, data); err != nil { if err := WriteWithBackup(path, data); err != nil {
t.Fatal(err) t.Fatal(err)
} }
entries, _ := os.ReadDir(backupDir()) entries, _ := os.ReadDir(BackupDir())
var found bool var found bool
for _, entry := range entries { for _, entry := range entries {
name := entry.Name() name := entry.Name()
@@ -161,7 +181,7 @@ func TestWriteWithBackup(t *testing.T) {
} }
} }
found = true found = true
os.Remove(filepath.Join(backupDir(), name)) os.Remove(filepath.Join(BackupDir(), name))
break break
} }
} }
@@ -180,7 +200,7 @@ func TestWriteWithBackup_FailsIfBackupFails(t *testing.T) {
t.Skip("permission tests unreliable on Windows") t.Skip("permission tests unreliable on Windows")
} }
tmpDir := t.TempDir() tmpDir := isolatedTempDir(t)
path := filepath.Join(tmpDir, "config.json") path := filepath.Join(tmpDir, "config.json")
// Create original file // Create original file
@@ -188,13 +208,13 @@ func TestWriteWithBackup_FailsIfBackupFails(t *testing.T) {
os.WriteFile(path, originalContent, 0o644) os.WriteFile(path, originalContent, 0o644)
// Make backup directory read-only to force backup failure // Make backup directory read-only to force backup failure
backupDir := backupDir() backupDir := BackupDir()
os.MkdirAll(backupDir, 0o755) os.MkdirAll(backupDir, 0o755)
os.Chmod(backupDir, 0o444) // Read-only os.Chmod(backupDir, 0o444) // Read-only
defer os.Chmod(backupDir, 0o755) defer os.Chmod(backupDir, 0o755)
newContent := []byte(`{"updated": true}`) newContent := []byte(`{"updated": true}`)
err := writeWithBackup(path, newContent) err := WriteWithBackup(path, newContent)
// Should fail because backup couldn't be created // Should fail because backup couldn't be created
if err == nil { if err == nil {
@@ -215,7 +235,7 @@ func TestWriteWithBackup_PermissionDenied(t *testing.T) {
t.Skip("permission tests unreliable on Windows") t.Skip("permission tests unreliable on Windows")
} }
tmpDir := t.TempDir() tmpDir := isolatedTempDir(t)
// Create a read-only directory // Create a read-only directory
readOnlyDir := filepath.Join(tmpDir, "readonly") readOnlyDir := filepath.Join(tmpDir, "readonly")
@@ -224,7 +244,7 @@ func TestWriteWithBackup_PermissionDenied(t *testing.T) {
defer os.Chmod(readOnlyDir, 0o755) defer os.Chmod(readOnlyDir, 0o755)
path := filepath.Join(readOnlyDir, "config.json") path := filepath.Join(readOnlyDir, "config.json")
err := writeWithBackup(path, []byte(`{"test": true}`)) err := WriteWithBackup(path, []byte(`{"test": true}`))
if err == nil { if err == nil {
t.Error("expected permission error, got nil") t.Error("expected permission error, got nil")
@@ -234,10 +254,10 @@ func TestWriteWithBackup_PermissionDenied(t *testing.T) {
// TestWriteWithBackup_DirectoryDoesNotExist verifies behavior when target directory doesn't exist. // TestWriteWithBackup_DirectoryDoesNotExist verifies behavior when target directory doesn't exist.
// writeWithBackup doesn't create directories - caller is responsible. // writeWithBackup doesn't create directories - caller is responsible.
func TestWriteWithBackup_DirectoryDoesNotExist(t *testing.T) { func TestWriteWithBackup_DirectoryDoesNotExist(t *testing.T) {
tmpDir := t.TempDir() tmpDir := isolatedTempDir(t)
path := filepath.Join(tmpDir, "nonexistent", "subdir", "config.json") path := filepath.Join(tmpDir, "nonexistent", "subdir", "config.json")
err := writeWithBackup(path, []byte(`{"test": true}`)) err := WriteWithBackup(path, []byte(`{"test": true}`))
// Should fail because directory doesn't exist // Should fail because directory doesn't exist
if err == nil { if err == nil {
@@ -252,7 +272,7 @@ func TestWriteWithBackup_SymlinkTarget(t *testing.T) {
t.Skip("symlink tests may require admin on Windows") t.Skip("symlink tests may require admin on Windows")
} }
tmpDir := t.TempDir() tmpDir := isolatedTempDir(t)
realFile := filepath.Join(tmpDir, "real.json") realFile := filepath.Join(tmpDir, "real.json")
symlink := filepath.Join(tmpDir, "link.json") symlink := filepath.Join(tmpDir, "link.json")
@@ -261,7 +281,7 @@ func TestWriteWithBackup_SymlinkTarget(t *testing.T) {
os.Symlink(realFile, symlink) os.Symlink(realFile, symlink)
// Write through symlink // Write through symlink
err := writeWithBackup(symlink, []byte(`{"v": 2}`)) err := WriteWithBackup(symlink, []byte(`{"v": 2}`))
if err != nil { if err != nil {
t.Fatalf("writeWithBackup through symlink failed: %v", err) t.Fatalf("writeWithBackup through symlink failed: %v", err)
} }
@@ -276,7 +296,7 @@ func TestWriteWithBackup_SymlinkTarget(t *testing.T) {
// TestBackupToTmp_SpecialCharsInFilename verifies backup works with special characters. // TestBackupToTmp_SpecialCharsInFilename verifies backup works with special characters.
// User may have config files with unusual names. // User may have config files with unusual names.
func TestBackupToTmp_SpecialCharsInFilename(t *testing.T) { func TestBackupToTmp_SpecialCharsInFilename(t *testing.T) {
tmpDir := t.TempDir() tmpDir := isolatedTempDir(t)
// File with spaces and special chars // File with spaces and special chars
path := filepath.Join(tmpDir, "my config (backup).json") path := filepath.Join(tmpDir, "my config (backup).json")
@@ -305,7 +325,7 @@ func TestCopyFile_PreservesPermissions(t *testing.T) {
t.Skip("permission preservation tests unreliable on Windows") t.Skip("permission preservation tests unreliable on Windows")
} }
tmpDir := t.TempDir() tmpDir := isolatedTempDir(t)
src := filepath.Join(tmpDir, "src.json") src := filepath.Join(tmpDir, "src.json")
dst := filepath.Join(tmpDir, "dst.json") dst := filepath.Join(tmpDir, "dst.json")
@@ -327,7 +347,7 @@ func TestCopyFile_PreservesPermissions(t *testing.T) {
// TestCopyFile_SourceNotFound verifies clear error when source doesn't exist. // TestCopyFile_SourceNotFound verifies clear error when source doesn't exist.
func TestCopyFile_SourceNotFound(t *testing.T) { func TestCopyFile_SourceNotFound(t *testing.T) {
tmpDir := t.TempDir() tmpDir := isolatedTempDir(t)
src := filepath.Join(tmpDir, "nonexistent.json") src := filepath.Join(tmpDir, "nonexistent.json")
dst := filepath.Join(tmpDir, "dst.json") dst := filepath.Join(tmpDir, "dst.json")
@@ -339,11 +359,11 @@ func TestCopyFile_SourceNotFound(t *testing.T) {
// TestWriteWithBackup_TargetIsDirectory verifies error when path points to a directory. // TestWriteWithBackup_TargetIsDirectory verifies error when path points to a directory.
func TestWriteWithBackup_TargetIsDirectory(t *testing.T) { func TestWriteWithBackup_TargetIsDirectory(t *testing.T) {
tmpDir := t.TempDir() tmpDir := isolatedTempDir(t)
dirPath := filepath.Join(tmpDir, "actualdir") dirPath := filepath.Join(tmpDir, "actualdir")
os.MkdirAll(dirPath, 0o755) os.MkdirAll(dirPath, 0o755)
err := writeWithBackup(dirPath, []byte(`{"test": true}`)) err := WriteWithBackup(dirPath, []byte(`{"test": true}`))
if err == nil { if err == nil {
t.Error("expected error when target is a directory, got nil") t.Error("expected error when target is a directory, got nil")
} }
@@ -351,10 +371,10 @@ func TestWriteWithBackup_TargetIsDirectory(t *testing.T) {
// TestWriteWithBackup_EmptyData verifies writing zero bytes works correctly. // TestWriteWithBackup_EmptyData verifies writing zero bytes works correctly.
func TestWriteWithBackup_EmptyData(t *testing.T) { func TestWriteWithBackup_EmptyData(t *testing.T) {
tmpDir := t.TempDir() tmpDir := isolatedTempDir(t)
path := filepath.Join(tmpDir, "empty.json") path := filepath.Join(tmpDir, "empty.json")
err := writeWithBackup(path, []byte{}) err := WriteWithBackup(path, []byte{})
if err != nil { if err != nil {
t.Fatalf("writeWithBackup with empty data failed: %v", err) t.Fatalf("writeWithBackup with empty data failed: %v", err)
} }
@@ -375,7 +395,7 @@ func TestWriteWithBackup_FileUnreadableButDirWritable(t *testing.T) {
t.Skip("permission tests unreliable on Windows") t.Skip("permission tests unreliable on Windows")
} }
tmpDir := t.TempDir() tmpDir := isolatedTempDir(t)
path := filepath.Join(tmpDir, "unreadable.json") path := filepath.Join(tmpDir, "unreadable.json")
// Create file and make it unreadable // Create file and make it unreadable
@@ -384,7 +404,7 @@ func TestWriteWithBackup_FileUnreadableButDirWritable(t *testing.T) {
defer os.Chmod(path, 0o644) defer os.Chmod(path, 0o644)
// Should fail because we can't read the file to compare/backup // Should fail because we can't read the file to compare/backup
err := writeWithBackup(path, []byte(`{"updated": true}`)) err := WriteWithBackup(path, []byte(`{"updated": true}`))
if err == nil { if err == nil {
t.Error("expected error when file is unreadable, got nil") t.Error("expected error when file is unreadable, got nil")
} }
@@ -393,7 +413,7 @@ func TestWriteWithBackup_FileUnreadableButDirWritable(t *testing.T) {
// TestWriteWithBackup_RapidSuccessiveWrites verifies backup works with multiple writes // TestWriteWithBackup_RapidSuccessiveWrites verifies backup works with multiple writes
// within the same second (timestamp collision scenario). // within the same second (timestamp collision scenario).
func TestWriteWithBackup_RapidSuccessiveWrites(t *testing.T) { func TestWriteWithBackup_RapidSuccessiveWrites(t *testing.T) {
tmpDir := t.TempDir() tmpDir := isolatedTempDir(t)
path := filepath.Join(tmpDir, "rapid.json") path := filepath.Join(tmpDir, "rapid.json")
// Create initial file // Create initial file
@@ -402,7 +422,7 @@ func TestWriteWithBackup_RapidSuccessiveWrites(t *testing.T) {
// Rapid successive writes // Rapid successive writes
for i := 1; i <= 3; i++ { for i := 1; i <= 3; i++ {
data := []byte(fmt.Sprintf(`{"v": %d}`, i)) data := []byte(fmt.Sprintf(`{"v": %d}`, i))
if err := writeWithBackup(path, data); err != nil { if err := WriteWithBackup(path, data); err != nil {
t.Fatalf("write %d failed: %v", i, err) t.Fatalf("write %d failed: %v", i, err)
} }
} }
@@ -414,7 +434,7 @@ func TestWriteWithBackup_RapidSuccessiveWrites(t *testing.T) {
} }
// Verify at least one backup exists // Verify at least one backup exists
entries, _ := os.ReadDir(backupDir()) entries, _ := os.ReadDir(BackupDir())
var backupCount int var backupCount int
for _, e := range entries { for _, e := range entries {
if len(e.Name()) > len("rapid.json.") && e.Name()[:len("rapid.json.")] == "rapid.json." { if len(e.Name()) > len("rapid.json.") && e.Name()[:len("rapid.json.")] == "rapid.json." {
@@ -432,8 +452,9 @@ func TestWriteWithBackup_BackupDirIsFile(t *testing.T) {
t.Skip("test modifies system temp directory") t.Skip("test modifies system temp directory")
} }
tmpDir := isolatedTempDir(t)
// Create a file at the backup directory path // Create a file at the backup directory path
backupPath := backupDir() backupPath := BackupDir()
// Clean up any existing directory first // Clean up any existing directory first
os.RemoveAll(backupPath) os.RemoveAll(backupPath)
// Create a file instead of directory // Create a file instead of directory
@@ -443,11 +464,10 @@ func TestWriteWithBackup_BackupDirIsFile(t *testing.T) {
os.MkdirAll(backupPath, 0o755) os.MkdirAll(backupPath, 0o755)
}() }()
tmpDir := t.TempDir()
path := filepath.Join(tmpDir, "test.json") path := filepath.Join(tmpDir, "test.json")
os.WriteFile(path, []byte(`{"original": true}`), 0o644) os.WriteFile(path, []byte(`{"original": true}`), 0o644)
err := writeWithBackup(path, []byte(`{"updated": true}`)) err := WriteWithBackup(path, []byte(`{"updated": true}`))
if err == nil { if err == nil {
t.Error("expected error when backup dir is a file, got nil") t.Error("expected error when backup dir is a file, got nil")
} }
@@ -459,7 +479,7 @@ func TestWriteWithBackup_NoOrphanTempFiles(t *testing.T) {
t.Skip("permission tests unreliable on Windows") t.Skip("permission tests unreliable on Windows")
} }
tmpDir := t.TempDir() tmpDir := isolatedTempDir(t)
// Count existing temp files // Count existing temp files
countTempFiles := func() int { countTempFiles := func() int {
@@ -493,7 +513,7 @@ func TestWriteWithBackup_NoOrphanTempFiles(t *testing.T) {
badPath := filepath.Join(tmpDir, "isdir") badPath := filepath.Join(tmpDir, "isdir")
os.MkdirAll(badPath, 0o755) os.MkdirAll(badPath, 0o755)
_ = writeWithBackup(badPath, []byte(`{"test": true}`)) _ = WriteWithBackup(badPath, []byte(`{"test": true}`))
after := countTempFiles() after := countTempFiles()
if after > before { if after > before {

87
cmd/launch/claude.go Normal file
View File

@@ -0,0 +1,87 @@
package launch
import (
"fmt"
"os"
"os/exec"
"path/filepath"
"runtime"
"strconv"
"github.com/ollama/ollama/envconfig"
)
// Claude implements Runner for Claude Code integration.
type Claude struct{}
func (c *Claude) String() string { return "Claude Code" }
func (c *Claude) args(model string, extra []string) []string {
var args []string
if model != "" {
args = append(args, "--model", model)
}
args = append(args, extra...)
return args
}
func (c *Claude) findPath() (string, error) {
if p, err := exec.LookPath("claude"); err == nil {
return p, nil
}
home, err := os.UserHomeDir()
if err != nil {
return "", err
}
name := "claude"
if runtime.GOOS == "windows" {
name = "claude.exe"
}
fallback := filepath.Join(home, ".claude", "local", name)
if _, err := os.Stat(fallback); err != nil {
return "", err
}
return fallback, nil
}
func (c *Claude) Run(model string, args []string) error {
claudePath, err := c.findPath()
if err != nil {
return fmt.Errorf("claude is not installed, install from https://code.claude.com/docs/en/quickstart")
}
cmd := exec.Command(claudePath, c.args(model, args)...)
cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
env := append(os.Environ(),
"ANTHROPIC_BASE_URL="+envconfig.Host().String(),
"ANTHROPIC_API_KEY=",
"ANTHROPIC_AUTH_TOKEN=ollama",
"CLAUDE_CODE_ATTRIBUTION_HEADER=0",
)
env = append(env, c.modelEnvVars(model)...)
cmd.Env = env
return cmd.Run()
}
// modelEnvVars returns Claude Code env vars that route all model tiers through Ollama.
func (c *Claude) modelEnvVars(model string) []string {
env := []string{
"ANTHROPIC_DEFAULT_OPUS_MODEL=" + model,
"ANTHROPIC_DEFAULT_SONNET_MODEL=" + model,
"ANTHROPIC_DEFAULT_HAIKU_MODEL=" + model,
"CLAUDE_CODE_SUBAGENT_MODEL=" + model,
}
if isCloudModelName(model) {
if l, ok := lookupCloudModelLimit(model); ok {
env = append(env, "CLAUDE_CODE_AUTO_COMPACT_WINDOW="+strconv.Itoa(l.Context))
}
}
return env
}

View File

@@ -1,4 +1,4 @@
package config package launch
import ( import (
"os" "os"
@@ -117,10 +117,7 @@ func TestClaudeModelEnvVars(t *testing.T) {
return m return m
} }
t.Run("falls back to model param when no aliases saved", func(t *testing.T) { t.Run("maps all Claude model env vars to the provided model", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
got := envMap(c.modelEnvVars("llama3.2")) got := envMap(c.modelEnvVars("llama3.2"))
if got["ANTHROPIC_DEFAULT_OPUS_MODEL"] != "llama3.2" { if got["ANTHROPIC_DEFAULT_OPUS_MODEL"] != "llama3.2" {
t.Errorf("OPUS = %q, want llama3.2", got["ANTHROPIC_DEFAULT_OPUS_MODEL"]) t.Errorf("OPUS = %q, want llama3.2", got["ANTHROPIC_DEFAULT_OPUS_MODEL"])
@@ -134,65 +131,41 @@ func TestClaudeModelEnvVars(t *testing.T) {
if got["CLAUDE_CODE_SUBAGENT_MODEL"] != "llama3.2" { if got["CLAUDE_CODE_SUBAGENT_MODEL"] != "llama3.2" {
t.Errorf("SUBAGENT = %q, want llama3.2", got["CLAUDE_CODE_SUBAGENT_MODEL"]) t.Errorf("SUBAGENT = %q, want llama3.2", got["CLAUDE_CODE_SUBAGENT_MODEL"])
} }
}) if got["CLAUDE_CODE_AUTO_COMPACT_WINDOW"] != "" {
t.Errorf("AUTO_COMPACT_WINDOW = %q, want empty for local models", got["CLAUDE_CODE_AUTO_COMPACT_WINDOW"])
t.Run("uses primary alias for opus sonnet and subagent", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
SaveIntegration("claude", []string{"qwen3:8b"})
saveAliases("claude", map[string]string{"primary": "qwen3:8b"})
got := envMap(c.modelEnvVars("qwen3:8b"))
if got["ANTHROPIC_DEFAULT_OPUS_MODEL"] != "qwen3:8b" {
t.Errorf("OPUS = %q, want qwen3:8b", got["ANTHROPIC_DEFAULT_OPUS_MODEL"])
}
if got["ANTHROPIC_DEFAULT_SONNET_MODEL"] != "qwen3:8b" {
t.Errorf("SONNET = %q, want qwen3:8b", got["ANTHROPIC_DEFAULT_SONNET_MODEL"])
}
if got["ANTHROPIC_DEFAULT_HAIKU_MODEL"] != "qwen3:8b" {
t.Errorf("HAIKU = %q, want qwen3:8b (no fast alias)", got["ANTHROPIC_DEFAULT_HAIKU_MODEL"])
}
if got["CLAUDE_CODE_SUBAGENT_MODEL"] != "qwen3:8b" {
t.Errorf("SUBAGENT = %q, want qwen3:8b", got["CLAUDE_CODE_SUBAGENT_MODEL"])
} }
}) })
t.Run("uses fast alias for haiku", func(t *testing.T) { t.Run("supports empty model", func(t *testing.T) {
tmpDir := t.TempDir() got := envMap(c.modelEnvVars(""))
setTestHome(t, tmpDir) if got["ANTHROPIC_DEFAULT_OPUS_MODEL"] != "" {
t.Errorf("OPUS = %q, want empty", got["ANTHROPIC_DEFAULT_OPUS_MODEL"])
SaveIntegration("claude", []string{"llama3.2:70b"})
saveAliases("claude", map[string]string{
"primary": "llama3.2:70b",
"fast": "llama3.2:8b",
})
got := envMap(c.modelEnvVars("llama3.2:70b"))
if got["ANTHROPIC_DEFAULT_OPUS_MODEL"] != "llama3.2:70b" {
t.Errorf("OPUS = %q, want llama3.2:70b", got["ANTHROPIC_DEFAULT_OPUS_MODEL"])
} }
if got["ANTHROPIC_DEFAULT_SONNET_MODEL"] != "llama3.2:70b" { if got["ANTHROPIC_DEFAULT_SONNET_MODEL"] != "" {
t.Errorf("SONNET = %q, want llama3.2:70b", got["ANTHROPIC_DEFAULT_SONNET_MODEL"]) t.Errorf("SONNET = %q, want empty", got["ANTHROPIC_DEFAULT_SONNET_MODEL"])
} }
if got["ANTHROPIC_DEFAULT_HAIKU_MODEL"] != "llama3.2:8b" { if got["ANTHROPIC_DEFAULT_HAIKU_MODEL"] != "" {
t.Errorf("HAIKU = %q, want llama3.2:8b", got["ANTHROPIC_DEFAULT_HAIKU_MODEL"]) t.Errorf("HAIKU = %q, want empty", got["ANTHROPIC_DEFAULT_HAIKU_MODEL"])
} }
if got["CLAUDE_CODE_SUBAGENT_MODEL"] != "llama3.2:70b" { if got["CLAUDE_CODE_SUBAGENT_MODEL"] != "" {
t.Errorf("SUBAGENT = %q, want llama3.2:70b", got["CLAUDE_CODE_SUBAGENT_MODEL"]) t.Errorf("SUBAGENT = %q, want empty", got["CLAUDE_CODE_SUBAGENT_MODEL"])
}
if got["CLAUDE_CODE_AUTO_COMPACT_WINDOW"] != "" {
t.Errorf("AUTO_COMPACT_WINDOW = %q, want empty", got["CLAUDE_CODE_AUTO_COMPACT_WINDOW"])
} }
}) })
t.Run("alias primary overrides model param", func(t *testing.T) { t.Run("sets auto compact window for known cloud models", func(t *testing.T) {
tmpDir := t.TempDir() got := envMap(c.modelEnvVars("glm-5:cloud"))
setTestHome(t, tmpDir) if got["CLAUDE_CODE_AUTO_COMPACT_WINDOW"] != "202752" {
t.Errorf("AUTO_COMPACT_WINDOW = %q, want 202752", got["CLAUDE_CODE_AUTO_COMPACT_WINDOW"])
}
})
SaveIntegration("claude", []string{"saved-model"}) t.Run("does not set auto compact window for unknown cloud models", func(t *testing.T) {
saveAliases("claude", map[string]string{"primary": "saved-model"}) got := envMap(c.modelEnvVars("unknown-model:cloud"))
if got["CLAUDE_CODE_AUTO_COMPACT_WINDOW"] != "" {
got := envMap(c.modelEnvVars("different-model")) t.Errorf("AUTO_COMPACT_WINDOW = %q, want empty", got["CLAUDE_CODE_AUTO_COMPACT_WINDOW"])
if got["ANTHROPIC_DEFAULT_OPUS_MODEL"] != "saved-model" {
t.Errorf("OPUS = %q, want saved-model", got["ANTHROPIC_DEFAULT_OPUS_MODEL"])
} }
}) })
} }

View File

@@ -1,14 +1,13 @@
package config package launch
import ( import (
"context"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"os" "os"
"os/exec" "os/exec"
"path/filepath" "path/filepath"
"github.com/ollama/ollama/cmd/internal/fileutil"
"github.com/ollama/ollama/envconfig" "github.com/ollama/ollama/envconfig"
) )
@@ -22,24 +21,6 @@ func (c *Cline) Run(model string, args []string) error {
return fmt.Errorf("cline is not installed, install with: npm install -g cline") return fmt.Errorf("cline is not installed, install with: npm install -g cline")
} }
models := []string{model}
if config, err := loadIntegration("cline"); err == nil && len(config.Models) > 0 {
models = config.Models
}
var err error
models, err = resolveEditorModels("cline", models, func() ([]string, error) {
return selectModels(context.Background(), "cline", "")
})
if errors.Is(err, errCancelled) {
return nil
}
if err != nil {
return err
}
if err := c.Edit(models); err != nil {
return fmt.Errorf("setup failed: %w", err)
}
cmd := exec.Command("cline", args...) cmd := exec.Command("cline", args...)
cmd.Stdin = os.Stdin cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout cmd.Stdout = os.Stdout
@@ -97,7 +78,7 @@ func (c *Cline) Edit(models []string) error {
if err != nil { if err != nil {
return err return err
} }
return writeWithBackup(configPath, data) return fileutil.WriteWithBackup(configPath, data)
} }
func (c *Cline) Models() []string { func (c *Cline) Models() []string {
@@ -106,7 +87,7 @@ func (c *Cline) Models() []string {
return nil return nil
} }
config, err := readJSONFile(filepath.Join(home, ".cline", "data", "globalState.json")) config, err := fileutil.ReadJSON(filepath.Join(home, ".cline", "data", "globalState.json"))
if err != nil { if err != nil {
return nil return nil
} }

View File

@@ -1,4 +1,4 @@
package config package launch
import ( import (
"encoding/json" "encoding/json"

148
cmd/launch/codex.go Normal file
View File

@@ -0,0 +1,148 @@
package launch
import (
"fmt"
"os"
"os/exec"
"path/filepath"
"strings"
"github.com/ollama/ollama/envconfig"
"golang.org/x/mod/semver"
)
// Codex implements Runner for Codex integration
type Codex struct{}
func (c *Codex) String() string { return "Codex" }
const codexProfileName = "ollama-launch"
func (c *Codex) args(model string, extra []string) []string {
args := []string{"--profile", codexProfileName}
if model != "" {
args = append(args, "-m", model)
}
args = append(args, extra...)
return args
}
func (c *Codex) Run(model string, args []string) error {
if err := checkCodexVersion(); err != nil {
return err
}
if err := ensureCodexConfig(); err != nil {
return fmt.Errorf("failed to configure codex: %w", err)
}
cmd := exec.Command("codex", c.args(model, args)...)
cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
cmd.Env = append(os.Environ(),
"OPENAI_API_KEY=ollama",
)
return cmd.Run()
}
// ensureCodexConfig writes a [profiles.ollama-launch] section to ~/.codex/config.toml
// with openai_base_url pointing to the local Ollama server.
func ensureCodexConfig() error {
home, err := os.UserHomeDir()
if err != nil {
return err
}
codexDir := filepath.Join(home, ".codex")
if err := os.MkdirAll(codexDir, 0o755); err != nil {
return err
}
configPath := filepath.Join(codexDir, "config.toml")
return writeCodexProfile(configPath)
}
// writeCodexProfile ensures ~/.codex/config.toml has the ollama-launch profile
// and model provider sections with the correct base URL.
func writeCodexProfile(configPath string) error {
baseURL := envconfig.Host().String() + "/v1/"
sections := []struct {
header string
lines []string
}{
{
header: fmt.Sprintf("[profiles.%s]", codexProfileName),
lines: []string{
fmt.Sprintf("openai_base_url = %q", baseURL),
`forced_login_method = "api"`,
fmt.Sprintf("model_provider = %q", codexProfileName),
},
},
{
header: fmt.Sprintf("[model_providers.%s]", codexProfileName),
lines: []string{
`name = "Ollama"`,
fmt.Sprintf("base_url = %q", baseURL),
},
},
}
content, readErr := os.ReadFile(configPath)
text := ""
if readErr == nil {
text = string(content)
}
for _, s := range sections {
block := strings.Join(append([]string{s.header}, s.lines...), "\n") + "\n"
if idx := strings.Index(text, s.header); idx >= 0 {
// Replace the existing section up to the next section header.
rest := text[idx+len(s.header):]
if endIdx := strings.Index(rest, "\n["); endIdx >= 0 {
text = text[:idx] + block + rest[endIdx+1:]
} else {
text = text[:idx] + block
}
} else {
// Append the section.
if text != "" && !strings.HasSuffix(text, "\n") {
text += "\n"
}
if text != "" {
text += "\n"
}
text += block
}
}
return os.WriteFile(configPath, []byte(text), 0o644)
}
func checkCodexVersion() error {
if _, err := exec.LookPath("codex"); err != nil {
return fmt.Errorf("codex is not installed, install with: npm install -g @openai/codex")
}
out, err := exec.Command("codex", "--version").Output()
if err != nil {
return fmt.Errorf("failed to get codex version: %w", err)
}
// Parse output like "codex-cli 0.87.0"
fields := strings.Fields(strings.TrimSpace(string(out)))
if len(fields) < 2 {
return fmt.Errorf("unexpected codex version output: %s", string(out))
}
version := "v" + fields[len(fields)-1]
minVersion := "v0.81.0"
if semver.Compare(version, minVersion) < 0 {
return fmt.Errorf("codex version %s is too old, minimum required is %s, update with: npm update -g @openai/codex", fields[len(fields)-1], "0.81.0")
}
return nil
}

229
cmd/launch/codex_test.go Normal file
View File

@@ -0,0 +1,229 @@
package launch
import (
"os"
"path/filepath"
"slices"
"strings"
"testing"
)
func TestCodexArgs(t *testing.T) {
c := &Codex{}
tests := []struct {
name string
model string
args []string
want []string
}{
{"with model", "llama3.2", nil, []string{"--profile", "ollama-launch", "-m", "llama3.2"}},
{"empty model", "", nil, []string{"--profile", "ollama-launch"}},
{"with model and extra args", "qwen3.5", []string{"-p", "myprofile"}, []string{"--profile", "ollama-launch", "-m", "qwen3.5", "-p", "myprofile"}},
{"with sandbox flag", "llama3.2", []string{"--sandbox", "workspace-write"}, []string{"--profile", "ollama-launch", "-m", "llama3.2", "--sandbox", "workspace-write"}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := c.args(tt.model, tt.args)
if !slices.Equal(got, tt.want) {
t.Errorf("args(%q, %v) = %v, want %v", tt.model, tt.args, got, tt.want)
}
})
}
}
func TestWriteCodexProfile(t *testing.T) {
t.Run("creates new file when none exists", func(t *testing.T) {
tmpDir := t.TempDir()
configPath := filepath.Join(tmpDir, "config.toml")
if err := writeCodexProfile(configPath); err != nil {
t.Fatal(err)
}
data, err := os.ReadFile(configPath)
if err != nil {
t.Fatal(err)
}
content := string(data)
if !strings.Contains(content, "[profiles.ollama-launch]") {
t.Error("missing [profiles.ollama-launch] header")
}
if !strings.Contains(content, "openai_base_url") {
t.Error("missing openai_base_url key")
}
if !strings.Contains(content, "/v1/") {
t.Error("missing /v1/ suffix in base URL")
}
if !strings.Contains(content, `forced_login_method = "api"`) {
t.Error("missing forced_login_method key")
}
if !strings.Contains(content, `model_provider = "ollama-launch"`) {
t.Error("missing model_provider key")
}
if !strings.Contains(content, "[model_providers.ollama-launch]") {
t.Error("missing [model_providers.ollama-launch] section")
}
if !strings.Contains(content, `name = "Ollama"`) {
t.Error("missing model provider name")
}
})
t.Run("appends profile to existing file without profile", func(t *testing.T) {
tmpDir := t.TempDir()
configPath := filepath.Join(tmpDir, "config.toml")
existing := "[some_other_section]\nkey = \"value\"\n"
os.WriteFile(configPath, []byte(existing), 0o644)
if err := writeCodexProfile(configPath); err != nil {
t.Fatal(err)
}
data, _ := os.ReadFile(configPath)
content := string(data)
if !strings.Contains(content, "[some_other_section]") {
t.Error("existing section was removed")
}
if !strings.Contains(content, "[profiles.ollama-launch]") {
t.Error("missing [profiles.ollama-launch] header")
}
})
t.Run("replaces existing profile section", func(t *testing.T) {
tmpDir := t.TempDir()
configPath := filepath.Join(tmpDir, "config.toml")
existing := "[profiles.ollama-launch]\nopenai_base_url = \"http://old:1234/v1/\"\n\n[model_providers.ollama-launch]\nname = \"Ollama\"\nbase_url = \"http://old:1234/v1/\"\n"
os.WriteFile(configPath, []byte(existing), 0o644)
if err := writeCodexProfile(configPath); err != nil {
t.Fatal(err)
}
data, _ := os.ReadFile(configPath)
content := string(data)
if strings.Contains(content, "old:1234") {
t.Error("old URL was not replaced")
}
if strings.Count(content, "[profiles.ollama-launch]") != 1 {
t.Errorf("expected exactly one [profiles.ollama-launch] section, got %d", strings.Count(content, "[profiles.ollama-launch]"))
}
if strings.Count(content, "[model_providers.ollama-launch]") != 1 {
t.Errorf("expected exactly one [model_providers.ollama-launch] section, got %d", strings.Count(content, "[model_providers.ollama-launch]"))
}
})
t.Run("replaces profile while preserving following sections", func(t *testing.T) {
tmpDir := t.TempDir()
configPath := filepath.Join(tmpDir, "config.toml")
existing := "[profiles.ollama-launch]\nopenai_base_url = \"http://old:1234/v1/\"\n[another_section]\nfoo = \"bar\"\n"
os.WriteFile(configPath, []byte(existing), 0o644)
if err := writeCodexProfile(configPath); err != nil {
t.Fatal(err)
}
data, _ := os.ReadFile(configPath)
content := string(data)
if strings.Contains(content, "old:1234") {
t.Error("old URL was not replaced")
}
if !strings.Contains(content, "[another_section]") {
t.Error("following section was removed")
}
if !strings.Contains(content, "foo = \"bar\"") {
t.Error("following section content was removed")
}
})
t.Run("appends newline to file not ending with newline", func(t *testing.T) {
tmpDir := t.TempDir()
configPath := filepath.Join(tmpDir, "config.toml")
existing := "[other]\nkey = \"val\""
os.WriteFile(configPath, []byte(existing), 0o644)
if err := writeCodexProfile(configPath); err != nil {
t.Fatal(err)
}
data, _ := os.ReadFile(configPath)
content := string(data)
if !strings.Contains(content, "[profiles.ollama-launch]") {
t.Error("missing [profiles.ollama-launch] header")
}
// Should not have double blank lines from missing trailing newline
if strings.Contains(content, "\n\n\n") {
t.Error("unexpected triple newline in output")
}
})
t.Run("uses custom OLLAMA_HOST", func(t *testing.T) {
t.Setenv("OLLAMA_HOST", "http://myhost:9999")
tmpDir := t.TempDir()
configPath := filepath.Join(tmpDir, "config.toml")
if err := writeCodexProfile(configPath); err != nil {
t.Fatal(err)
}
data, _ := os.ReadFile(configPath)
content := string(data)
if !strings.Contains(content, "myhost:9999/v1/") {
t.Errorf("expected custom host in URL, got:\n%s", content)
}
})
}
func TestEnsureCodexConfig(t *testing.T) {
t.Run("creates .codex dir and config.toml", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
if err := ensureCodexConfig(); err != nil {
t.Fatal(err)
}
configPath := filepath.Join(tmpDir, ".codex", "config.toml")
data, err := os.ReadFile(configPath)
if err != nil {
t.Fatalf("config.toml not created: %v", err)
}
content := string(data)
if !strings.Contains(content, "[profiles.ollama-launch]") {
t.Error("missing [profiles.ollama-launch] header")
}
if !strings.Contains(content, "openai_base_url") {
t.Error("missing openai_base_url key")
}
})
t.Run("is idempotent", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
if err := ensureCodexConfig(); err != nil {
t.Fatal(err)
}
if err := ensureCodexConfig(); err != nil {
t.Fatal(err)
}
configPath := filepath.Join(tmpDir, ".codex", "config.toml")
data, _ := os.ReadFile(configPath)
content := string(data)
if strings.Count(content, "[profiles.ollama-launch]") != 1 {
t.Errorf("expected exactly one [profiles.ollama-launch] section after two calls, got %d", strings.Count(content, "[profiles.ollama-launch]"))
}
if strings.Count(content, "[model_providers.ollama-launch]") != 1 {
t.Errorf("expected exactly one [model_providers.ollama-launch] section after two calls, got %d", strings.Count(content, "[model_providers.ollama-launch]"))
}
})
}

598
cmd/launch/command_test.go Normal file
View File

@@ -0,0 +1,598 @@
package launch
import (
"bytes"
"fmt"
"io"
"net/http"
"net/http/httptest"
"os"
"strings"
"testing"
"github.com/google/go-cmp/cmp"
"github.com/ollama/ollama/cmd/config"
"github.com/spf13/cobra"
)
func captureStderr(t *testing.T, fn func()) string {
t.Helper()
oldStderr := os.Stderr
r, w, err := os.Pipe()
if err != nil {
t.Fatalf("failed to create stderr pipe: %v", err)
}
os.Stderr = w
defer func() {
os.Stderr = oldStderr
}()
done := make(chan string, 1)
go func() {
var buf bytes.Buffer
_, _ = io.Copy(&buf, r)
done <- buf.String()
}()
fn()
_ = w.Close()
return <-done
}
func TestLaunchCmd(t *testing.T) {
mockCheck := func(cmd *cobra.Command, args []string) error {
return nil
}
mockTUI := func(cmd *cobra.Command) {}
cmd := LaunchCmd(mockCheck, mockTUI)
t.Run("command structure", func(t *testing.T) {
if cmd.Use != "launch [INTEGRATION] [-- [EXTRA_ARGS...]]" {
t.Errorf("Use = %q, want %q", cmd.Use, "launch [INTEGRATION] [-- [EXTRA_ARGS...]]")
}
if cmd.Short == "" {
t.Error("Short description should not be empty")
}
if cmd.Long == "" {
t.Error("Long description should not be empty")
}
})
t.Run("flags exist", func(t *testing.T) {
if cmd.Flags().Lookup("model") == nil {
t.Error("--model flag should exist")
}
if cmd.Flags().Lookup("config") == nil {
t.Error("--config flag should exist")
}
if cmd.Flags().Lookup("yes") == nil {
t.Error("--yes flag should exist")
}
})
t.Run("PreRunE is set", func(t *testing.T) {
if cmd.PreRunE == nil {
t.Error("PreRunE should be set to checkServerHeartbeat")
}
})
}
func TestLaunchCmdTUICallback(t *testing.T) {
mockCheck := func(cmd *cobra.Command, args []string) error {
return nil
}
t.Run("no args calls TUI", func(t *testing.T) {
tuiCalled := false
mockTUI := func(cmd *cobra.Command) {
tuiCalled = true
}
cmd := LaunchCmd(mockCheck, mockTUI)
cmd.SetArgs([]string{})
_ = cmd.Execute()
if !tuiCalled {
t.Error("TUI callback should be called when no args provided")
}
})
t.Run("integration arg bypasses TUI", func(t *testing.T) {
srv := httptest.NewServer(http.NotFoundHandler())
defer srv.Close()
t.Setenv("OLLAMA_HOST", srv.URL)
tuiCalled := false
mockTUI := func(cmd *cobra.Command) {
tuiCalled = true
}
cmd := LaunchCmd(mockCheck, mockTUI)
cmd.SetArgs([]string{"claude"})
_ = cmd.Execute()
if tuiCalled {
t.Error("TUI callback should NOT be called when integration arg provided")
}
})
t.Run("--model flag without integration returns error", func(t *testing.T) {
tuiCalled := false
mockTUI := func(cmd *cobra.Command) {
tuiCalled = true
}
cmd := LaunchCmd(mockCheck, mockTUI)
cmd.SetArgs([]string{"--model", "test-model"})
err := cmd.Execute()
if err == nil {
t.Fatal("expected --model without an integration to fail")
}
if !strings.Contains(err.Error(), "require an integration name") {
t.Fatalf("expected integration-name guidance, got %v", err)
}
if tuiCalled {
t.Error("TUI callback should NOT be called when --model is provided without an integration")
}
})
t.Run("--config flag without integration returns error", func(t *testing.T) {
tuiCalled := false
mockTUI := func(cmd *cobra.Command) {
tuiCalled = true
}
cmd := LaunchCmd(mockCheck, mockTUI)
cmd.SetArgs([]string{"--config"})
err := cmd.Execute()
if err == nil {
t.Fatal("expected --config without an integration to fail")
}
if !strings.Contains(err.Error(), "require an integration name") {
t.Fatalf("expected integration-name guidance, got %v", err)
}
if tuiCalled {
t.Error("TUI callback should NOT be called when --config is provided without an integration")
}
})
t.Run("--yes flag without integration returns error", func(t *testing.T) {
tuiCalled := false
mockTUI := func(cmd *cobra.Command) {
tuiCalled = true
}
cmd := LaunchCmd(mockCheck, mockTUI)
cmd.SetArgs([]string{"--yes"})
err := cmd.Execute()
if err == nil {
t.Fatal("expected --yes without an integration to fail")
}
if !strings.Contains(err.Error(), "require an integration name") {
t.Fatalf("expected integration-name guidance, got %v", err)
}
if tuiCalled {
t.Error("TUI callback should NOT be called when --yes is provided without an integration")
}
})
t.Run("extra args without integration return error", func(t *testing.T) {
tuiCalled := false
mockTUI := func(cmd *cobra.Command) {
tuiCalled = true
}
cmd := LaunchCmd(mockCheck, mockTUI)
cmd.SetArgs([]string{"--model", "test-model", "--", "--sandbox", "workspace-write"})
err := cmd.Execute()
if err == nil {
t.Fatal("expected flags and extra args without an integration to fail")
}
if !strings.Contains(err.Error(), "require an integration name") {
t.Fatalf("expected integration-name guidance, got %v", err)
}
if tuiCalled {
t.Error("TUI callback should NOT be called when flags or extra args are provided without an integration")
}
})
}
func TestLaunchCmdNilHeartbeat(t *testing.T) {
cmd := LaunchCmd(nil, nil)
if cmd == nil {
t.Fatal("LaunchCmd returned nil")
}
if cmd.PreRunE != nil {
t.Log("Note: PreRunE is set even when nil is passed (acceptable)")
}
}
func TestLaunchCmdModelFlagFiltersDisabledCloudFromSavedConfig(t *testing.T) {
tmpDir := t.TempDir()
setLaunchTestHome(t, tmpDir)
if err := config.SaveIntegration("stubeditor", []string{"glm-5:cloud"}); err != nil {
t.Fatalf("failed to seed saved config: %v", err)
}
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/api/status":
fmt.Fprintf(w, `{"cloud":{"disabled":true,"source":"config"}}`)
case "/api/show":
fmt.Fprintf(w, `{"model":"llama3.2"}`)
default:
w.WriteHeader(http.StatusNotFound)
}
}))
defer srv.Close()
t.Setenv("OLLAMA_HOST", srv.URL)
stub := &launcherEditorRunner{}
restore := OverrideIntegration("stubeditor", stub)
defer restore()
cmd := LaunchCmd(func(cmd *cobra.Command, args []string) error { return nil }, func(cmd *cobra.Command) {})
cmd.SetArgs([]string{"stubeditor", "--model", "llama3.2"})
if err := cmd.Execute(); err != nil {
t.Fatalf("launch command failed: %v", err)
}
saved, err := config.LoadIntegration("stubeditor")
if err != nil {
t.Fatalf("failed to reload integration config: %v", err)
}
if diff := cmp.Diff([]string{"llama3.2"}, saved.Models); diff != "" {
t.Fatalf("saved models mismatch (-want +got):\n%s", diff)
}
if diff := cmp.Diff([][]string{{"llama3.2"}}, stub.edited); diff != "" {
t.Fatalf("editor models mismatch (-want +got):\n%s", diff)
}
if stub.ranModel != "llama3.2" {
t.Fatalf("expected launch to run with llama3.2, got %q", stub.ranModel)
}
}
func TestLaunchCmdModelFlagClearsDisabledCloudOverride(t *testing.T) {
tmpDir := t.TempDir()
setLaunchTestHome(t, tmpDir)
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/api/status":
fmt.Fprintf(w, `{"cloud":{"disabled":true,"source":"config"}}`)
case "/api/tags":
fmt.Fprint(w, `{"models":[{"name":"llama3.2"}]}`)
case "/api/show":
fmt.Fprint(w, `{"model":"llama3.2"}`)
default:
w.WriteHeader(http.StatusNotFound)
}
}))
defer srv.Close()
t.Setenv("OLLAMA_HOST", srv.URL)
stub := &launcherSingleRunner{}
restore := OverrideIntegration("stubapp", stub)
defer restore()
oldSelector := DefaultSingleSelector
defer func() { DefaultSingleSelector = oldSelector }()
var selectorCalls int
var gotCurrent string
DefaultSingleSelector = func(title string, items []ModelItem, current string) (string, error) {
selectorCalls++
gotCurrent = current
return "llama3.2", nil
}
cmd := LaunchCmd(func(cmd *cobra.Command, args []string) error { return nil }, func(cmd *cobra.Command) {})
cmd.SetArgs([]string{"stubapp", "--model", "glm-5:cloud"})
stderr := captureStderr(t, func() {
if err := cmd.Execute(); err != nil {
t.Fatalf("launch command failed: %v", err)
}
})
if selectorCalls != 1 {
t.Fatalf("expected disabled cloud override to fall back to selector, got %d calls", selectorCalls)
}
if gotCurrent != "" {
t.Fatalf("expected disabled override to be cleared before selection, got current %q", gotCurrent)
}
if stub.ranModel != "llama3.2" {
t.Fatalf("expected launch to run with replacement local model, got %q", stub.ranModel)
}
if !strings.Contains(stderr, "Warning: ignoring --model glm-5:cloud because cloud is disabled") {
t.Fatalf("expected disabled-cloud warning, got stderr: %q", stderr)
}
saved, err := config.LoadIntegration("stubapp")
if err != nil {
t.Fatalf("failed to reload integration config: %v", err)
}
if diff := cmp.Diff([]string{"llama3.2"}, saved.Models); diff != "" {
t.Fatalf("saved models mismatch (-want +got):\n%s", diff)
}
}
func TestLaunchCmdYes_AutoConfirmsLaunchPromptPath(t *testing.T) {
tmpDir := t.TempDir()
setLaunchTestHome(t, tmpDir)
withLauncherHooks(t)
withInteractiveSession(t, false)
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/api/show":
fmt.Fprint(w, `{"model":"llama3.2"}`)
case "/api/status":
w.WriteHeader(http.StatusNotFound)
fmt.Fprint(w, `{"error":"not found"}`)
default:
w.WriteHeader(http.StatusNotFound)
}
}))
defer srv.Close()
t.Setenv("OLLAMA_HOST", srv.URL)
stub := &launcherEditorRunner{paths: []string{"/tmp/stubeditor.json"}}
restore := OverrideIntegration("stubeditor", stub)
defer restore()
DefaultConfirmPrompt = func(prompt string) (bool, error) {
t.Fatalf("unexpected prompt with --yes: %q", prompt)
return false, nil
}
cmd := LaunchCmd(func(cmd *cobra.Command, args []string) error { return nil }, func(cmd *cobra.Command) {})
cmd.SetArgs([]string{"stubeditor", "--model", "llama3.2", "--yes"})
if err := cmd.Execute(); err != nil {
t.Fatalf("launch command with --yes failed: %v", err)
}
if diff := cmp.Diff([][]string{{"llama3.2"}}, stub.edited); diff != "" {
t.Fatalf("editor models mismatch (-want +got):\n%s", diff)
}
if stub.ranModel != "llama3.2" {
t.Fatalf("expected launch to run with llama3.2, got %q", stub.ranModel)
}
}
func TestLaunchCmdHeadlessWithYes_AutoPullsMissingLocalModel(t *testing.T) {
tmpDir := t.TempDir()
setLaunchTestHome(t, tmpDir)
withLauncherHooks(t)
withInteractiveSession(t, false)
var pullCalled bool
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/api/show":
w.WriteHeader(http.StatusNotFound)
fmt.Fprint(w, `{"error":"model not found"}`)
case "/api/pull":
pullCalled = true
w.WriteHeader(http.StatusOK)
fmt.Fprint(w, `{"status":"success"}`)
default:
w.WriteHeader(http.StatusNotFound)
}
}))
defer srv.Close()
t.Setenv("OLLAMA_HOST", srv.URL)
stub := &launcherSingleRunner{}
restore := OverrideIntegration("stubapp", stub)
defer restore()
DefaultConfirmPrompt = func(prompt string) (bool, error) {
t.Fatalf("unexpected prompt with --yes in headless autopull path: %q", prompt)
return false, nil
}
cmd := LaunchCmd(func(cmd *cobra.Command, args []string) error { return nil }, func(cmd *cobra.Command) {})
cmd.SetArgs([]string{"stubapp", "--model", "missing-model", "--yes"})
if err := cmd.Execute(); err != nil {
t.Fatalf("launch command with --yes failed: %v", err)
}
if !pullCalled {
t.Fatal("expected missing local model to be auto-pulled with --yes in headless mode")
}
if stub.ranModel != "missing-model" {
t.Fatalf("expected launch to run with pulled model, got %q", stub.ranModel)
}
}
func TestLaunchCmdHeadlessWithoutYes_ReturnsActionableConfirmError(t *testing.T) {
tmpDir := t.TempDir()
setLaunchTestHome(t, tmpDir)
withLauncherHooks(t)
withInteractiveSession(t, false)
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/api/show":
fmt.Fprint(w, `{"model":"llama3.2"}`)
case "/api/status":
w.WriteHeader(http.StatusNotFound)
fmt.Fprint(w, `{"error":"not found"}`)
default:
w.WriteHeader(http.StatusNotFound)
}
}))
defer srv.Close()
t.Setenv("OLLAMA_HOST", srv.URL)
stub := &launcherEditorRunner{paths: []string{"/tmp/stubeditor.json"}}
restore := OverrideIntegration("stubeditor", stub)
defer restore()
DefaultConfirmPrompt = func(prompt string) (bool, error) {
t.Fatalf("unexpected prompt in headless non-yes mode: %q", prompt)
return false, nil
}
cmd := LaunchCmd(func(cmd *cobra.Command, args []string) error { return nil }, func(cmd *cobra.Command) {})
cmd.SetArgs([]string{"stubeditor", "--model", "llama3.2"})
err := cmd.Execute()
if err == nil {
t.Fatal("expected launch command to fail without --yes in headless mode")
}
if !strings.Contains(err.Error(), "re-run with --yes") {
t.Fatalf("expected actionable --yes guidance, got %v", err)
}
if len(stub.edited) != 0 {
t.Fatalf("expected no editor writes when confirmation is blocked, got %v", stub.edited)
}
if stub.ranModel != "" {
t.Fatalf("expected launch to abort before run, got %q", stub.ranModel)
}
}
func TestLaunchCmdIntegrationArgPromptsForModelWithSavedSelection(t *testing.T) {
tmpDir := t.TempDir()
setLaunchTestHome(t, tmpDir)
if err := config.SaveIntegration("stubapp", []string{"llama3.2"}); err != nil {
t.Fatalf("failed to seed saved config: %v", err)
}
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/api/tags":
fmt.Fprint(w, `{"models":[{"name":"llama3.2"},{"name":"qwen3:8b"}]}`)
case "/api/show":
fmt.Fprint(w, `{"model":"qwen3:8b"}`)
default:
w.WriteHeader(http.StatusNotFound)
}
}))
defer srv.Close()
t.Setenv("OLLAMA_HOST", srv.URL)
stub := &launcherSingleRunner{}
restore := OverrideIntegration("stubapp", stub)
defer restore()
oldSelector := DefaultSingleSelector
defer func() { DefaultSingleSelector = oldSelector }()
var gotCurrent string
DefaultSingleSelector = func(title string, items []ModelItem, current string) (string, error) {
gotCurrent = current
return "qwen3:8b", nil
}
cmd := LaunchCmd(func(cmd *cobra.Command, args []string) error { return nil }, func(cmd *cobra.Command) {})
cmd.SetArgs([]string{"stubapp"})
if err := cmd.Execute(); err != nil {
t.Fatalf("launch command failed: %v", err)
}
if gotCurrent != "llama3.2" {
t.Fatalf("expected selector current model to be saved model llama3.2, got %q", gotCurrent)
}
if stub.ranModel != "qwen3:8b" {
t.Fatalf("expected launch to run selected model qwen3:8b, got %q", stub.ranModel)
}
saved, err := config.LoadIntegration("stubapp")
if err != nil {
t.Fatalf("failed to reload integration config: %v", err)
}
if diff := cmp.Diff([]string{"qwen3:8b"}, saved.Models); diff != "" {
t.Fatalf("saved models mismatch (-want +got):\n%s", diff)
}
}
func TestLaunchCmdHeadlessYes_IntegrationRequiresModelEvenWhenSaved(t *testing.T) {
tmpDir := t.TempDir()
setLaunchTestHome(t, tmpDir)
withLauncherHooks(t)
withInteractiveSession(t, false)
if err := config.SaveIntegration("stubapp", []string{"llama3.2"}); err != nil {
t.Fatalf("failed to seed saved config: %v", err)
}
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/api/show":
fmt.Fprint(w, `{"model":"llama3.2"}`)
default:
w.WriteHeader(http.StatusNotFound)
}
}))
defer srv.Close()
t.Setenv("OLLAMA_HOST", srv.URL)
stub := &launcherSingleRunner{}
restore := OverrideIntegration("stubapp", stub)
defer restore()
oldSelector := DefaultSingleSelector
defer func() { DefaultSingleSelector = oldSelector }()
DefaultSingleSelector = func(title string, items []ModelItem, current string) (string, error) {
t.Fatal("selector should not be called for headless --yes saved-model launch")
return "", nil
}
cmd := LaunchCmd(func(cmd *cobra.Command, args []string) error { return nil }, func(cmd *cobra.Command) {})
cmd.SetArgs([]string{"stubapp", "--yes"})
err := cmd.Execute()
if err == nil {
t.Fatal("expected launch command to fail when --yes is used headlessly without --model")
}
if !strings.Contains(err.Error(), "requires --model <model>") {
t.Fatalf("expected actionable --model guidance, got %v", err)
}
if stub.ranModel != "" {
t.Fatalf("expected launch to abort before run, got %q", stub.ranModel)
}
}
func TestLaunchCmdHeadlessYes_IntegrationWithoutSavedModelReturnsError(t *testing.T) {
tmpDir := t.TempDir()
setLaunchTestHome(t, tmpDir)
withLauncherHooks(t)
withInteractiveSession(t, false)
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNotFound)
}))
defer srv.Close()
t.Setenv("OLLAMA_HOST", srv.URL)
stub := &launcherSingleRunner{}
restore := OverrideIntegration("stubapp", stub)
defer restore()
oldSelector := DefaultSingleSelector
defer func() { DefaultSingleSelector = oldSelector }()
DefaultSingleSelector = func(title string, items []ModelItem, current string) (string, error) {
t.Fatal("selector should not be called for headless --yes without saved model")
return "", nil
}
cmd := LaunchCmd(func(cmd *cobra.Command, args []string) error { return nil }, func(cmd *cobra.Command) {})
cmd.SetArgs([]string{"stubapp", "--yes"})
err := cmd.Execute()
if err == nil {
t.Fatal("expected launch command to fail when --yes is used headlessly without --model")
}
if !strings.Contains(err.Error(), "requires --model <model>") {
t.Fatalf("expected actionable --model guidance, got %v", err)
}
if stub.ranModel != "" {
t.Fatalf("expected launch to abort before run, got %q", stub.ranModel)
}
}

View File

@@ -1,16 +1,14 @@
package config package launch
import ( import (
"context"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"os" "os"
"os/exec" "os/exec"
"path/filepath" "path/filepath"
"slices" "slices"
"github.com/ollama/ollama/api" "github.com/ollama/ollama/cmd/internal/fileutil"
"github.com/ollama/ollama/envconfig" "github.com/ollama/ollama/envconfig"
) )
@@ -47,25 +45,6 @@ func (d *Droid) Run(model string, args []string) error {
return fmt.Errorf("droid is not installed, install from https://docs.factory.ai/cli/getting-started/quickstart") return fmt.Errorf("droid is not installed, install from https://docs.factory.ai/cli/getting-started/quickstart")
} }
// Call Edit() to ensure config is up-to-date before launch
models := []string{model}
if config, err := loadIntegration("droid"); err == nil && len(config.Models) > 0 {
models = config.Models
}
var err error
models, err = resolveEditorModels("droid", models, func() ([]string, error) {
return selectModels(context.Background(), "droid", "")
})
if errors.Is(err, errCancelled) {
return nil
}
if err != nil {
return err
}
if err := d.Edit(models); err != nil {
return fmt.Errorf("setup failed: %w", err)
}
cmd := exec.Command("droid", args...) cmd := exec.Command("droid", args...)
cmd.Stdin = os.Stdin cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout cmd.Stdout = os.Stdout
@@ -111,6 +90,16 @@ func (d *Droid) Edit(models []string) error {
json.Unmarshal(data, &settings) // ignore error, zero values are fine json.Unmarshal(data, &settings) // ignore error, zero values are fine
} }
settingsMap = updateDroidSettings(settingsMap, settings, models)
data, err := json.MarshalIndent(settingsMap, "", " ")
if err != nil {
return err
}
return fileutil.WriteWithBackup(settingsPath, data)
}
func updateDroidSettings(settingsMap map[string]any, settings droidSettings, models []string) map[string]any {
// Keep only non-Ollama models from the raw map (preserves extra fields) // Keep only non-Ollama models from the raw map (preserves extra fields)
// Rebuild Ollama models // Rebuild Ollama models
var nonOllamaModels []any var nonOllamaModels []any
@@ -125,13 +114,12 @@ func (d *Droid) Edit(models []string) error {
} }
// Build new Ollama model entries with sequential indices (0, 1, 2, ...) // Build new Ollama model entries with sequential indices (0, 1, 2, ...)
client, _ := api.ClientFromEnvironment()
var newModels []any var newModels []any
var defaultModelID string var defaultModelID string
for i, model := range models { for i, model := range models {
maxOutput := 64000 maxOutput := 64000
if isCloudModel(context.Background(), client, model) { if isCloudModelName(model) {
if l, ok := lookupCloudModelLimit(model); ok { if l, ok := lookupCloudModelLimit(model); ok {
maxOutput = l.Output maxOutput = l.Output
} }
@@ -167,12 +155,7 @@ func (d *Droid) Edit(models []string) error {
} }
settingsMap["sessionDefaultSettings"] = sessionSettings settingsMap["sessionDefaultSettings"] = sessionSettings
return settingsMap
data, err := json.MarshalIndent(settingsMap, "", " ")
if err != nil {
return err
}
return writeWithBackup(settingsPath, data)
} }
func (d *Droid) Models() []string { func (d *Droid) Models() []string {

View File

@@ -1,4 +1,4 @@
package config package launch
import ( import (
"encoding/json" "encoding/json"
@@ -6,6 +6,8 @@ import (
"os" "os"
"path/filepath" "path/filepath"
"testing" "testing"
"github.com/ollama/ollama/cmd/internal/fileutil"
) )
func TestDroidIntegration(t *testing.T) { func TestDroidIntegration(t *testing.T) {
@@ -362,7 +364,7 @@ func TestDroidEdit_DuplicateModels(t *testing.T) {
t.Fatalf("Edit with duplicates failed: %v", err) t.Fatalf("Edit with duplicates failed: %v", err)
} }
settings, err := readJSONFile(settingsPath) settings, err := fileutil.ReadJSON(settingsPath)
if err != nil { if err != nil {
t.Fatalf("readJSONFile failed: %v", err) t.Fatalf("readJSONFile failed: %v", err)
} }
@@ -392,7 +394,7 @@ func TestDroidEdit_MalformedModelEntry(t *testing.T) {
} }
// Malformed entries (non-object) are dropped - only valid model objects are preserved // Malformed entries (non-object) are dropped - only valid model objects are preserved
settings, _ := readJSONFile(settingsPath) settings, _ := fileutil.ReadJSON(settingsPath)
customModels, _ := settings["customModels"].([]any) customModels, _ := settings["customModels"].([]any)
// Should have: 1 new Ollama model only (malformed entries dropped) // Should have: 1 new Ollama model only (malformed entries dropped)
@@ -419,7 +421,7 @@ func TestDroidEdit_WrongTypeSessionSettings(t *testing.T) {
} }
// Should create proper sessionDefaultSettings // Should create proper sessionDefaultSettings
settings, _ := readJSONFile(settingsPath) settings, _ := fileutil.ReadJSON(settingsPath)
session, ok := settings["sessionDefaultSettings"].(map[string]any) session, ok := settings["sessionDefaultSettings"].(map[string]any)
if !ok { if !ok {
t.Fatalf("sessionDefaultSettings should be map after setup, got %T", settings["sessionDefaultSettings"]) t.Fatalf("sessionDefaultSettings should be map after setup, got %T", settings["sessionDefaultSettings"])
@@ -1008,34 +1010,34 @@ func TestDroidEdit_ModelNamesWithSpecialCharacters(t *testing.T) {
} }
func TestDroidEdit_MissingCustomModelsKey(t *testing.T) { func TestDroidEdit_MissingCustomModelsKey(t *testing.T) {
d := &Droid{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
settingsDir := filepath.Join(tmpDir, ".factory")
settingsPath := filepath.Join(settingsDir, "settings.json")
os.MkdirAll(settingsDir, 0o755)
// No customModels key at all // No customModels key at all
original := `{ original := `{
"diffMode": "github", "diffMode": "github",
"sessionDefaultSettings": {"autonomyMode": "auto-high"} "sessionDefaultSettings": {"autonomyMode": "auto-high"}
}` }`
os.WriteFile(settingsPath, []byte(original), 0o644)
if err := d.Edit([]string{"model-a"}); err != nil { var settingsStruct droidSettings
var settings map[string]any
if err := json.Unmarshal([]byte(original), &settings); err != nil {
t.Fatal(err)
}
if err := json.Unmarshal([]byte(original), &settingsStruct); err != nil {
t.Fatal(err) t.Fatal(err)
} }
data, _ := os.ReadFile(settingsPath) settings = updateDroidSettings(settings, settingsStruct, []string{"model-a"})
var settings map[string]any
json.Unmarshal(data, &settings)
// Original fields preserved // Original fields preserved
if settings["diffMode"] != "github" { if settings["diffMode"] != "github" {
t.Error("diffMode not preserved") t.Error("diffMode not preserved")
} }
session, ok := settings["sessionDefaultSettings"].(map[string]any)
if !ok {
t.Fatal("sessionDefaultSettings not preserved")
}
if session["autonomyMode"] != "auto-high" {
t.Error("sessionDefaultSettings.autonomyMode not preserved")
}
// customModels created // customModels created
models, ok := settings["customModels"].([]any) models, ok := settings["customModels"].([]any)
@@ -1276,25 +1278,17 @@ func TestDroidEdit_LocalModelDefaultMaxOutput(t *testing.T) {
func TestDroidEdit_CloudModelLimitsUsed(t *testing.T) { func TestDroidEdit_CloudModelLimitsUsed(t *testing.T) {
// Verify that every cloud model in cloudModelLimits has a valid output // Verify that every cloud model in cloudModelLimits has a valid output
// value that would be used for maxOutputTokens when isCloudModel returns true. // value that would be used for maxOutputTokens when the selected model uses
// :cloud suffix stripping must also work since that's how users specify them. // the explicit :cloud source tag.
for name, expected := range cloudModelLimits { for name, expected := range cloudModelLimits {
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
l, ok := lookupCloudModelLimit(name)
if !ok {
t.Fatalf("lookupCloudModelLimit(%q) returned false", name)
}
if l.Output != expected.Output {
t.Errorf("output = %d, want %d", l.Output, expected.Output)
}
// Also verify :cloud suffix lookup
cloudName := name + ":cloud" cloudName := name + ":cloud"
l2, ok := lookupCloudModelLimit(cloudName) l, ok := lookupCloudModelLimit(cloudName)
if !ok { if !ok {
t.Fatalf("lookupCloudModelLimit(%q) returned false", cloudName) t.Fatalf("lookupCloudModelLimit(%q) returned false", cloudName)
} }
if l2.Output != expected.Output { if l.Output != expected.Output {
t.Errorf(":cloud output = %d, want %d", l2.Output, expected.Output) t.Errorf("output = %d, want %d", l.Output, expected.Output)
} }
}) })
} }

881
cmd/launch/launch.go Normal file
View File

@@ -0,0 +1,881 @@
package launch
import (
"context"
"errors"
"fmt"
"net/http"
"os"
"strings"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/cmd/config"
"github.com/spf13/cobra"
"golang.org/x/term"
)
// LauncherState is the launch-owned snapshot used to render the root launcher menu.
type LauncherState struct {
LastSelection string
RunModel string
RunModelUsable bool
Integrations map[string]LauncherIntegrationState
}
// LauncherIntegrationState is the launch-owned status for one launcher integration.
type LauncherIntegrationState struct {
Name string
DisplayName string
Description string
Installed bool
AutoInstallable bool
Selectable bool
Changeable bool
CurrentModel string
ModelUsable bool
InstallHint string
Editor bool
}
// RunModelRequest controls how the root launcher resolves the chat model.
type RunModelRequest struct {
ForcePicker bool
Policy *LaunchPolicy
}
// LaunchConfirmMode controls confirmation behavior across launch flows.
type LaunchConfirmMode int
const (
// LaunchConfirmPrompt prompts the user for confirmation.
LaunchConfirmPrompt LaunchConfirmMode = iota
// LaunchConfirmAutoApprove skips prompts and treats confirmation as accepted.
LaunchConfirmAutoApprove
// LaunchConfirmRequireYes rejects confirmation requests with a --yes hint.
LaunchConfirmRequireYes
)
// LaunchMissingModelMode controls local missing-model handling in launch flows.
type LaunchMissingModelMode int
const (
// LaunchMissingModelPromptToPull prompts to pull a missing local model.
LaunchMissingModelPromptToPull LaunchMissingModelMode = iota
// LaunchMissingModelAutoPull pulls a missing local model without prompting.
LaunchMissingModelAutoPull
// LaunchMissingModelFail fails immediately when a local model is missing.
LaunchMissingModelFail
)
// LaunchPolicy controls launch behavior that may vary by caller context.
type LaunchPolicy struct {
Confirm LaunchConfirmMode
MissingModel LaunchMissingModelMode
}
func defaultLaunchPolicy(interactive bool, yes bool) LaunchPolicy {
policy := LaunchPolicy{
Confirm: LaunchConfirmPrompt,
MissingModel: LaunchMissingModelPromptToPull,
}
switch {
case yes:
// if yes flag is set, auto approve and auto pull
policy.Confirm = LaunchConfirmAutoApprove
policy.MissingModel = LaunchMissingModelAutoPull
case !interactive:
// otherwise make sure to stop when needed
policy.Confirm = LaunchConfirmRequireYes
policy.MissingModel = LaunchMissingModelFail
}
return policy
}
func (p LaunchPolicy) confirmPolicy() launchConfirmPolicy {
switch p.Confirm {
case LaunchConfirmAutoApprove:
return launchConfirmPolicy{yes: true}
case LaunchConfirmRequireYes:
return launchConfirmPolicy{requireYesMessage: true}
default:
return launchConfirmPolicy{}
}
}
func (p LaunchPolicy) missingModelPolicy() missingModelPolicy {
switch p.MissingModel {
case LaunchMissingModelAutoPull:
return missingModelAutoPull
case LaunchMissingModelFail:
return missingModelFail
default:
return missingModelPromptPull
}
}
// IntegrationLaunchRequest controls the canonical integration launcher flow.
type IntegrationLaunchRequest struct {
Name string
ModelOverride string
ForceConfigure bool
ConfigureOnly bool
ExtraArgs []string
Policy *LaunchPolicy
}
var isInteractiveSession = func() bool {
return term.IsTerminal(int(os.Stdin.Fd())) && term.IsTerminal(int(os.Stdout.Fd()))
}
// Runner executes a model with an integration.
type Runner interface {
Run(model string, args []string) error
String() string
}
// Editor can edit config files for integrations that support model configuration.
type Editor interface {
Paths() []string
Edit(models []string) error
Models() []string
}
type modelInfo struct {
Name string
Remote bool
ToolCapable bool
}
// ModelInfo re-exports launcher model inventory details for callers.
type ModelInfo = modelInfo
// ModelItem represents a model for selection UIs.
type ModelItem struct {
Name string
Description string
Recommended bool
}
// LaunchCmd returns the cobra command for launching integrations.
// The runTUI callback is called when the root launcher UI should be shown.
func LaunchCmd(checkServerHeartbeat func(cmd *cobra.Command, args []string) error, runTUI func(cmd *cobra.Command)) *cobra.Command {
var modelFlag string
var configFlag bool
var yesFlag bool
cmd := &cobra.Command{
Use: "launch [INTEGRATION] [-- [EXTRA_ARGS...]]",
Short: "Launch the Ollama menu or an integration",
Long: `Launch the Ollama interactive menu, or directly launch a specific integration.
Without arguments, this is equivalent to running 'ollama' directly.
Flags and extra arguments require an integration name.
Supported integrations:
claude Claude Code
cline Cline
codex Codex
droid Droid
opencode OpenCode
openclaw OpenClaw (aliases: clawdbot, moltbot)
pi Pi
vscode    VS Code (aliases: code)
Examples:
ollama launch
ollama launch claude
ollama launch claude --model <model>
ollama launch droid --config (does not auto-launch)
ollama launch codex -- -p myprofile (pass extra args to integration)
ollama launch codex -- --sandbox workspace-write`,
Args: cobra.ArbitraryArgs,
PreRunE: checkServerHeartbeat,
RunE: func(cmd *cobra.Command, args []string) error {
policy := defaultLaunchPolicy(isInteractiveSession(), yesFlag)
// reset when done to make sure state doens't leak between launches
restoreConfirmPolicy := withLaunchConfirmPolicy(policy.confirmPolicy())
defer restoreConfirmPolicy()
var name string
var passArgs []string
dashIdx := cmd.ArgsLenAtDash()
if dashIdx == -1 {
if len(args) > 1 {
return fmt.Errorf("unexpected arguments: %v\nUse '--' to pass extra arguments to the integration", args[1:])
}
if len(args) == 1 {
name = args[0]
}
} else {
if dashIdx > 1 {
return fmt.Errorf("expected at most 1 integration name before '--', got %d", dashIdx)
}
if dashIdx == 1 {
name = args[0]
}
passArgs = args[dashIdx:]
}
if name == "" {
if cmd.Flags().Changed("model") || cmd.Flags().Changed("config") || cmd.Flags().Changed("yes") || len(passArgs) > 0 {
return fmt.Errorf("flags and extra args require an integration name, for example: 'ollama launch claude --model qwen3.5'")
}
runTUI(cmd)
return nil
}
if modelFlag != "" && isCloudModelName(modelFlag) {
if client, err := api.ClientFromEnvironment(); err == nil {
if disabled, _ := cloudStatusDisabled(cmd.Context(), client); disabled {
fmt.Fprintf(os.Stderr, "Warning: ignoring --model %s because cloud is disabled\n", modelFlag)
modelFlag = ""
}
}
}
headlessYes := yesFlag && !isInteractiveSession()
err := LaunchIntegration(cmd.Context(), IntegrationLaunchRequest{
Name: name,
ModelOverride: modelFlag,
ForceConfigure: configFlag || (modelFlag == "" && !headlessYes),
ConfigureOnly: configFlag,
ExtraArgs: passArgs,
Policy: &policy,
})
if errors.Is(err, ErrCancelled) {
return nil
}
return err
},
}
cmd.Flags().StringVar(&modelFlag, "model", "", "Model to use")
cmd.Flags().BoolVar(&configFlag, "config", false, "Configure without launching")
cmd.Flags().BoolVarP(&yesFlag, "yes", "y", false, "Automatically answer yes to confirmation prompts")
return cmd
}
type launcherClient struct {
apiClient *api.Client
modelInventory []ModelInfo
inventoryLoaded bool
policy LaunchPolicy
}
func newLauncherClient(policy LaunchPolicy) (*launcherClient, error) {
apiClient, err := api.ClientFromEnvironment()
if err != nil {
return nil, err
}
return &launcherClient{
apiClient: apiClient,
policy: policy,
}, nil
}
// BuildLauncherState returns the launch-owned root launcher menu snapshot.
func BuildLauncherState(ctx context.Context) (*LauncherState, error) {
launchClient, err := newLauncherClient(defaultLaunchPolicy(isInteractiveSession(), false))
if err != nil {
return nil, err
}
return launchClient.buildLauncherState(ctx)
}
// ResolveRunModel returns the model that should be used for interactive chat.
func ResolveRunModel(ctx context.Context, req RunModelRequest) (string, error) {
// Called by the launcher TUI "Run a model" action (cmd/runLauncherAction),
// which resolves models separately from LaunchIntegration. Callers can pass
// Policy directly; otherwise we fall back to ambient --yes/session defaults.
policy := defaultLaunchPolicy(isInteractiveSession(), currentLaunchConfirmPolicy.yes)
if req.Policy != nil {
policy = *req.Policy
}
launchClient, err := newLauncherClient(policy)
if err != nil {
return "", err
}
return launchClient.resolveRunModel(ctx, req)
}
// LaunchIntegration runs the canonical launcher flow for one integration.
func LaunchIntegration(ctx context.Context, req IntegrationLaunchRequest) error {
name, runner, err := LookupIntegration(req.Name)
if err != nil {
return err
}
if !req.ConfigureOnly {
if err := EnsureIntegrationInstalled(name, runner); err != nil {
return err
}
}
var policy LaunchPolicy
// TUI does not set a policy, whereas ollama launch <app> does as it can have flags which change the behavior
if req.Policy == nil {
policy = defaultLaunchPolicy(isInteractiveSession(), false)
} else {
policy = *req.Policy
}
launchClient, err := newLauncherClient(policy)
if err != nil {
return err
}
saved, _ := loadStoredIntegrationConfig(name)
// In headless --yes mode we cannot prompt, so require an explicit --model.
if policy.Confirm == LaunchConfirmAutoApprove && !isInteractiveSession() && req.ModelOverride == "" {
return fmt.Errorf("headless --yes launch for %s requires --model <model>", name)
}
if editor, ok := runner.(Editor); ok {
return launchClient.launchEditorIntegration(ctx, name, runner, editor, saved, req)
}
return launchClient.launchSingleIntegration(ctx, name, runner, saved, req)
}
func (c *launcherClient) buildLauncherState(ctx context.Context) (*LauncherState, error) {
_ = c.loadModelInventoryOnce(ctx)
state := &LauncherState{
LastSelection: config.LastSelection(),
RunModel: config.LastModel(),
Integrations: make(map[string]LauncherIntegrationState),
}
runModelUsable, err := c.savedModelUsable(ctx, state.RunModel)
if err != nil {
runModelUsable = false
}
state.RunModelUsable = runModelUsable
for _, info := range ListIntegrationInfos() {
integrationState, err := c.buildLauncherIntegrationState(ctx, info)
if err != nil {
return nil, err
}
state.Integrations[info.Name] = integrationState
}
return state, nil
}
func (c *launcherClient) buildLauncherIntegrationState(ctx context.Context, info IntegrationInfo) (LauncherIntegrationState, error) {
integration, err := integrationFor(info.Name)
if err != nil {
return LauncherIntegrationState{}, err
}
currentModel, usable, err := c.launcherModelState(ctx, info.Name, integration.editor)
if err != nil {
return LauncherIntegrationState{}, err
}
return LauncherIntegrationState{
Name: info.Name,
DisplayName: info.DisplayName,
Description: info.Description,
Installed: integration.installed,
AutoInstallable: integration.autoInstallable,
Selectable: integration.installed || integration.autoInstallable,
Changeable: integration.installed || integration.autoInstallable,
CurrentModel: currentModel,
ModelUsable: usable,
InstallHint: integration.installHint,
Editor: integration.editor,
}, nil
}
func (c *launcherClient) launcherModelState(ctx context.Context, name string, isEditor bool) (string, bool, error) {
cfg, loadErr := loadStoredIntegrationConfig(name)
hasModels := loadErr == nil && len(cfg.Models) > 0
if !hasModels {
return "", false, nil
}
if isEditor {
filtered := c.filterDisabledCloudModels(ctx, cfg.Models)
if len(filtered) > 0 {
return filtered[0], true, nil
}
return cfg.Models[0], false, nil
}
model := cfg.Models[0]
usable, usableErr := c.savedModelUsable(ctx, model)
return model, usableErr == nil && usable, nil
}
func (c *launcherClient) resolveRunModel(ctx context.Context, req RunModelRequest) (string, error) {
current := config.LastModel()
if !req.ForcePicker && current != "" && c.policy.Confirm == LaunchConfirmAutoApprove && !isInteractiveSession() {
if err := c.ensureModelsReady(ctx, []string{current}); err != nil {
return "", err
}
fmt.Fprintf(os.Stderr, "Headless mode: auto-selected last used model %q\n", current)
return current, nil
}
if !req.ForcePicker {
usable, err := c.savedModelUsable(ctx, current)
if err != nil {
return "", err
}
if usable {
if err := c.ensureModelsReady(ctx, []string{current}); err != nil {
return "", err
}
return current, nil
}
}
model, err := c.selectSingleModelWithSelector(ctx, "Select model to run:", current, DefaultSingleSelector)
if err != nil {
return "", err
}
if model != current {
if err := config.SetLastModel(model); err != nil {
return "", err
}
}
return model, nil
}
func (c *launcherClient) launchSingleIntegration(ctx context.Context, name string, runner Runner, saved *config.IntegrationConfig, req IntegrationLaunchRequest) error {
current := primaryModelFromConfig(saved)
target := req.ModelOverride
needsConfigure := req.ForceConfigure
if target == "" {
target = current
usable, err := c.savedModelUsable(ctx, target)
if err != nil {
return err
}
if !usable {
needsConfigure = true
}
}
if needsConfigure {
selected, err := c.selectSingleModelWithSelector(ctx, fmt.Sprintf("Select model for %s:", runner), target, DefaultSingleSelector)
if err != nil {
return err
}
target = selected
} else if err := c.ensureModelsReady(ctx, []string{target}); err != nil {
return err
}
if target == "" {
return nil
}
if target != current {
if err := config.SaveIntegration(name, []string{target}); err != nil {
return fmt.Errorf("failed to save: %w", err)
}
}
return launchAfterConfiguration(name, runner, target, req)
}
func (c *launcherClient) launchEditorIntegration(ctx context.Context, name string, runner Runner, editor Editor, saved *config.IntegrationConfig, req IntegrationLaunchRequest) error {
models, needsConfigure := c.resolveEditorLaunchModels(ctx, saved, req)
if needsConfigure {
selected, err := c.selectMultiModelsForIntegration(ctx, runner, models)
if err != nil {
return err
}
models = selected
} else if len(models) > 0 {
if err := c.ensureModelsReady(ctx, models[:1]); err != nil {
return err
}
}
if len(models) == 0 {
return nil
}
if needsConfigure || req.ModelOverride != "" {
if err := prepareEditorIntegration(name, runner, editor, models); err != nil {
return err
}
}
return launchAfterConfiguration(name, runner, models[0], req)
}
func (c *launcherClient) selectSingleModelWithSelector(ctx context.Context, title, current string, selector SingleSelector) (string, error) {
if selector == nil {
return "", fmt.Errorf("no selector configured")
}
items, _, err := c.loadSelectableModels(ctx, nil, current, "no models available, run 'ollama pull <model>' first")
if err != nil {
return "", err
}
selected, err := selector(title, items, current)
if err != nil {
return "", err
}
if err := c.ensureModelsReady(ctx, []string{selected}); err != nil {
return "", err
}
return selected, nil
}
func (c *launcherClient) selectMultiModelsForIntegration(ctx context.Context, runner Runner, preChecked []string) ([]string, error) {
if DefaultMultiSelector == nil {
return nil, fmt.Errorf("no selector configured")
}
current := firstModel(preChecked)
items, orderedChecked, err := c.loadSelectableModels(ctx, preChecked, current, "no models available")
if err != nil {
return nil, err
}
if len(preChecked) > 0 {
// Keep list order stable in multi-select even when there are existing checks.
// checked/default state still comes from orderedChecked.
stableItems, _, stableErr := c.loadSelectableModels(ctx, nil, current, "no models available")
if stableErr != nil {
return nil, stableErr
}
items = stableItems
}
selected, err := DefaultMultiSelector(fmt.Sprintf("Select models for %s:", runner), items, orderedChecked)
if err != nil {
return nil, err
}
accepted, skipped, err := c.selectReadyModelsForSave(ctx, selected)
if err != nil {
return nil, err
}
for _, skip := range skipped {
fmt.Fprintf(os.Stderr, "Skipped %s: %s\n", skip.model, skip.reason)
}
return accepted, nil
}
func (c *launcherClient) loadSelectableModels(ctx context.Context, preChecked []string, current, emptyMessage string) ([]ModelItem, []string, error) {
if err := c.loadModelInventoryOnce(ctx); err != nil {
return nil, nil, err
}
cloudDisabled, _ := cloudStatusDisabled(ctx, c.apiClient)
items, orderedChecked, _, _ := buildModelList(c.modelInventory, preChecked, current)
if cloudDisabled {
items = filterCloudItems(items)
orderedChecked = c.filterDisabledCloudModels(ctx, orderedChecked)
}
if len(items) == 0 {
return nil, nil, errors.New(emptyMessage)
}
return items, orderedChecked, nil
}
func (c *launcherClient) ensureModelsReady(ctx context.Context, models []string) error {
models = dedupeModelList(models)
if len(models) == 0 {
return nil
}
cloudModels := make(map[string]bool, len(models))
for _, model := range models {
isCloudModel := isCloudModelName(model)
if isCloudModel {
cloudModels[model] = true
}
if err := showOrPullWithPolicy(ctx, c.apiClient, model, c.policy.missingModelPolicy(), isCloudModel); err != nil {
return err
}
}
return ensureAuth(ctx, c.apiClient, cloudModels, models)
}
func dedupeModelList(models []string) []string {
deduped := make([]string, 0, len(models))
seen := make(map[string]bool, len(models))
for _, model := range models {
if model == "" || seen[model] {
continue
}
seen[model] = true
deduped = append(deduped, model)
}
return deduped
}
type skippedModel struct {
model string
reason string
}
func (c *launcherClient) selectReadyModelsForSave(ctx context.Context, selected []string) ([]string, []skippedModel, error) {
selected = dedupeModelList(selected)
accepted := make([]string, 0, len(selected))
skipped := make([]skippedModel, 0, len(selected))
for _, model := range selected {
if err := c.ensureModelsReady(ctx, []string{model}); err != nil {
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
return nil, nil, err
}
skipped = append(skipped, skippedModel{
model: model,
reason: skippedModelReason(model, err),
})
continue
}
accepted = append(accepted, model)
}
return accepted, skipped, nil
}
func skippedModelReason(model string, err error) string {
if errors.Is(err, ErrCancelled) {
if isCloudModelName(model) {
return "sign in was cancelled"
}
return "download was cancelled"
}
return err.Error()
}
func (c *launcherClient) resolveEditorLaunchModels(ctx context.Context, saved *config.IntegrationConfig, req IntegrationLaunchRequest) ([]string, bool) {
if req.ForceConfigure {
return editorPreCheckedModels(saved, req.ModelOverride), true
}
if req.ModelOverride != "" {
models := append([]string{req.ModelOverride}, additionalSavedModels(saved, req.ModelOverride)...)
models = c.filterDisabledCloudModels(ctx, models)
return models, len(models) == 0
}
if saved == nil || len(saved.Models) == 0 {
return nil, true
}
models := c.filterDisabledCloudModels(ctx, saved.Models)
return models, len(models) == 0
}
func (c *launcherClient) filterDisabledCloudModels(ctx context.Context, models []string) []string {
// if connection cannot be established or there is a 404, cloud models will continue to be displayed
cloudDisabled, _ := cloudStatusDisabled(ctx, c.apiClient)
if !cloudDisabled {
return append([]string(nil), models...)
}
filtered := make([]string, 0, len(models))
for _, model := range models {
if !isCloudModelName(model) {
filtered = append(filtered, model)
}
}
return filtered
}
func (c *launcherClient) savedModelUsable(ctx context.Context, name string) (bool, error) {
if err := c.loadModelInventoryOnce(ctx); err != nil {
return c.showBasedModelUsable(ctx, name)
}
return c.singleModelUsable(ctx, name), nil
}
func (c *launcherClient) showBasedModelUsable(ctx context.Context, name string) (bool, error) {
if name == "" {
return false, nil
}
info, err := c.apiClient.Show(ctx, &api.ShowRequest{Model: name})
if err != nil {
var statusErr api.StatusError
if errors.As(err, &statusErr) && statusErr.StatusCode == http.StatusNotFound {
return false, nil
}
return false, err
}
if isCloudModelName(name) || info.RemoteModel != "" {
cloudDisabled, _ := cloudStatusDisabled(ctx, c.apiClient)
return !cloudDisabled, nil
}
return true, nil
}
func (c *launcherClient) singleModelUsable(ctx context.Context, name string) bool {
if name == "" {
return false
}
if isCloudModelName(name) {
cloudDisabled, _ := cloudStatusDisabled(ctx, c.apiClient)
return !cloudDisabled
}
return c.hasLocalModel(name)
}
func (c *launcherClient) hasLocalModel(name string) bool {
for _, model := range c.modelInventory {
if model.Remote {
continue
}
if model.Name == name || strings.HasPrefix(model.Name, name+":") {
return true
}
}
return false
}
func (c *launcherClient) loadModelInventoryOnce(ctx context.Context) error {
if c.inventoryLoaded {
return nil
}
resp, err := c.apiClient.List(ctx)
if err != nil {
return err
}
c.modelInventory = c.modelInventory[:0]
for _, model := range resp.Models {
c.modelInventory = append(c.modelInventory, ModelInfo{
Name: model.Name,
Remote: model.RemoteModel != "",
})
}
cloudDisabled, _ := cloudStatusDisabled(ctx, c.apiClient)
if cloudDisabled {
c.modelInventory = filterCloudModels(c.modelInventory)
}
c.inventoryLoaded = true
return nil
}
func runIntegration(runner Runner, modelName string, args []string) error {
fmt.Fprintf(os.Stderr, "\nLaunching %s with %s...\n", runner, modelName)
return runner.Run(modelName, args)
}
func launchAfterConfiguration(name string, runner Runner, model string, req IntegrationLaunchRequest) error {
if req.ConfigureOnly {
launch, err := ConfirmPrompt(fmt.Sprintf("Launch %s now?", runner))
if err != nil {
return err
}
if !launch {
return nil
}
}
if err := EnsureIntegrationInstalled(name, runner); err != nil {
return err
}
return runIntegration(runner, model, req.ExtraArgs)
}
func loadStoredIntegrationConfig(name string) (*config.IntegrationConfig, error) {
cfg, err := config.LoadIntegration(name)
if err == nil {
return cfg, nil
}
if !errors.Is(err, os.ErrNotExist) {
return nil, err
}
spec, specErr := LookupIntegrationSpec(name)
if specErr != nil {
return nil, err
}
for _, alias := range spec.Aliases {
legacy, legacyErr := config.LoadIntegration(alias)
if legacyErr == nil {
migrateLegacyIntegrationConfig(spec.Name, legacy)
if migrated, migratedErr := config.LoadIntegration(spec.Name); migratedErr == nil {
return migrated, nil
}
return legacy, nil
}
if legacyErr != nil && !errors.Is(legacyErr, os.ErrNotExist) {
return nil, legacyErr
}
}
return nil, err
}
func migrateLegacyIntegrationConfig(canonical string, legacy *config.IntegrationConfig) {
if legacy == nil {
return
}
_ = config.SaveIntegration(canonical, append([]string(nil), legacy.Models...))
if len(legacy.Aliases) > 0 {
_ = config.SaveAliases(canonical, cloneAliases(legacy.Aliases))
}
if legacy.Onboarded {
_ = config.MarkIntegrationOnboarded(canonical)
}
}
func primaryModelFromConfig(cfg *config.IntegrationConfig) string {
if cfg == nil || len(cfg.Models) == 0 {
return ""
}
return cfg.Models[0]
}
func cloneAliases(aliases map[string]string) map[string]string {
if len(aliases) == 0 {
return make(map[string]string)
}
cloned := make(map[string]string, len(aliases))
for key, value := range aliases {
cloned[key] = value
}
return cloned
}
func firstModel(models []string) string {
if len(models) == 0 {
return ""
}
return models[0]
}
func editorPreCheckedModels(saved *config.IntegrationConfig, override string) []string {
if override == "" {
if saved == nil {
return nil
}
return append([]string(nil), saved.Models...)
}
return append([]string{override}, additionalSavedModels(saved, override)...)
}
func additionalSavedModels(saved *config.IntegrationConfig, exclude string) []string {
if saved == nil {
return nil
}
var models []string
for _, model := range saved.Models {
if model != exclude {
models = append(models, model)
}
}
return models
}

1990
cmd/launch/launch_test.go Normal file

File diff suppressed because it is too large Load Diff

494
cmd/launch/models.go Normal file
View File

@@ -0,0 +1,494 @@
package launch
import (
"context"
"errors"
"fmt"
"net/http"
"os"
"os/exec"
"runtime"
"slices"
"strings"
"time"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/cmd/config"
"github.com/ollama/ollama/cmd/internal/fileutil"
internalcloud "github.com/ollama/ollama/internal/cloud"
"github.com/ollama/ollama/internal/modelref"
"github.com/ollama/ollama/progress"
)
var recommendedModels = []ModelItem{
{Name: "kimi-k2.5:cloud", Description: "Multimodal reasoning with subagents", Recommended: true},
{Name: "qwen3.5:cloud", Description: "Reasoning, coding, and agentic tool use with vision", Recommended: true},
{Name: "glm-5:cloud", Description: "Reasoning and code generation", Recommended: true},
{Name: "minimax-m2.7:cloud", Description: "Fast, efficient coding and real-world productivity", Recommended: true},
{Name: "glm-4.7-flash", Description: "Reasoning and code generation locally", Recommended: true},
{Name: "qwen3.5", Description: "Reasoning, coding, and visual understanding locally", Recommended: true},
}
var recommendedVRAM = map[string]string{
"glm-4.7-flash": "~25GB",
"qwen3.5": "~11GB",
}
// cloudModelLimit holds context and output token limits for a cloud model.
type cloudModelLimit struct {
Context int
Output int
}
// cloudModelLimits maps cloud model base names to their token limits.
// TODO(parthsareen): grab context/output limits from model info instead of hardcoding
var cloudModelLimits = map[string]cloudModelLimit{
"minimax-m2.7": {Context: 204_800, Output: 128_000},
"cogito-2.1:671b": {Context: 163_840, Output: 65_536},
"deepseek-v3.1:671b": {Context: 163_840, Output: 163_840},
"deepseek-v3.2": {Context: 163_840, Output: 65_536},
"glm-4.6": {Context: 202_752, Output: 131_072},
"glm-4.7": {Context: 202_752, Output: 131_072},
"glm-5": {Context: 202_752, Output: 131_072},
"gpt-oss:120b": {Context: 131_072, Output: 131_072},
"gpt-oss:20b": {Context: 131_072, Output: 131_072},
"kimi-k2:1t": {Context: 262_144, Output: 262_144},
"kimi-k2.5": {Context: 262_144, Output: 262_144},
"kimi-k2-thinking": {Context: 262_144, Output: 262_144},
"nemotron-3-nano:30b": {Context: 1_048_576, Output: 131_072},
"qwen3-coder:480b": {Context: 262_144, Output: 65_536},
"qwen3-coder-next": {Context: 262_144, Output: 32_768},
"qwen3-next:80b": {Context: 262_144, Output: 32_768},
"qwen3.5": {Context: 262_144, Output: 32_768},
}
// lookupCloudModelLimit returns the token limits for a cloud model.
// It normalizes explicit cloud source suffixes before checking the shared limit map.
func lookupCloudModelLimit(name string) (cloudModelLimit, bool) {
base, stripped := modelref.StripCloudSourceTag(name)
if stripped {
if l, ok := cloudModelLimits[base]; ok {
return l, true
}
}
return cloudModelLimit{}, false
}
// missingModelPolicy controls how model-not-found errors should be handled.
type missingModelPolicy int
const (
// missingModelPromptPull prompts the user to download missing local models.
missingModelPromptPull missingModelPolicy = iota
// missingModelAutoPull downloads missing local models without prompting.
missingModelAutoPull
// missingModelFail returns an error for missing local models without prompting.
missingModelFail
)
// OpenBrowser opens the URL in the user's browser.
func OpenBrowser(url string) {
switch runtime.GOOS {
case "darwin":
_ = exec.Command("open", url).Start()
case "linux":
// Skip on headless systems where no display server is available
if os.Getenv("DISPLAY") == "" && os.Getenv("WAYLAND_DISPLAY") == "" {
return
}
_ = exec.Command("xdg-open", url).Start()
case "windows":
_ = exec.Command("rundll32", "url.dll,FileProtocolHandler", url).Start()
}
}
// ensureAuth ensures the user is signed in before cloud-backed models run.
func ensureAuth(ctx context.Context, client *api.Client, cloudModels map[string]bool, selected []string) error {
var selectedCloudModels []string
for _, m := range selected {
if cloudModels[m] {
selectedCloudModels = append(selectedCloudModels, m)
}
}
if len(selectedCloudModels) == 0 {
return nil
}
if disabled, known := cloudStatusDisabled(ctx, client); known && disabled {
return errors.New(internalcloud.DisabledError("remote inference is unavailable"))
}
user, err := client.Whoami(ctx)
if err == nil && user != nil && user.Name != "" {
return nil
}
var aErr api.AuthorizationError
if !errors.As(err, &aErr) || aErr.SigninURL == "" {
return err
}
modelList := strings.Join(selectedCloudModels, ", ")
if DefaultSignIn != nil {
_, err := DefaultSignIn(modelList, aErr.SigninURL)
if errors.Is(err, ErrCancelled) {
return ErrCancelled
}
if err != nil {
return fmt.Errorf("%s requires sign in", modelList)
}
return nil
}
yes, err := ConfirmPrompt(fmt.Sprintf("sign in to use %s?", modelList))
if errors.Is(err, ErrCancelled) {
return ErrCancelled
}
if err != nil {
return err
}
if !yes {
return ErrCancelled
}
fmt.Fprintf(os.Stderr, "\nTo sign in, navigate to:\n %s\n\n", aErr.SigninURL)
OpenBrowser(aErr.SigninURL)
spinnerFrames := []string{"|", "/", "-", "\\"}
frame := 0
fmt.Fprintf(os.Stderr, "\033[90mwaiting for sign in to complete... %s\033[0m", spinnerFrames[0])
ticker := time.NewTicker(200 * time.Millisecond)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
fmt.Fprintf(os.Stderr, "\r\033[K")
return ctx.Err()
case <-ticker.C:
frame++
fmt.Fprintf(os.Stderr, "\r\033[90mwaiting for sign in to complete... %s\033[0m", spinnerFrames[frame%len(spinnerFrames)])
if frame%10 == 0 {
u, err := client.Whoami(ctx)
if err == nil && u != nil && u.Name != "" {
fmt.Fprintf(os.Stderr, "\r\033[K\033[A\r\033[K\033[1msigned in:\033[0m %s\n", u.Name)
return nil
}
}
}
}
}
// showOrPullWithPolicy checks if a model exists and applies the provided missing-model policy.
func showOrPullWithPolicy(ctx context.Context, client *api.Client, model string, policy missingModelPolicy, isCloudModel bool) error {
if _, err := client.Show(ctx, &api.ShowRequest{Model: model}); err == nil {
return nil
} else {
var statusErr api.StatusError
if !errors.As(err, &statusErr) || statusErr.StatusCode != http.StatusNotFound {
return err
}
}
if isCloudModel {
if disabled, known := cloudStatusDisabled(ctx, client); known && disabled {
return errors.New(internalcloud.DisabledError("remote inference is unavailable"))
}
return fmt.Errorf("model %q not found", model)
}
switch policy {
case missingModelAutoPull:
return pullMissingModel(ctx, client, model)
case missingModelFail:
return fmt.Errorf("model %q not found; run 'ollama pull %s' first, or use --yes to auto-pull", model, model)
default:
return confirmAndPull(ctx, client, model)
}
}
func confirmAndPull(ctx context.Context, client *api.Client, model string) error {
if ok, err := ConfirmPrompt(fmt.Sprintf("Download %s?", model)); err != nil {
return err
} else if !ok {
return errCancelled
}
fmt.Fprintf(os.Stderr, "\n")
return pullMissingModel(ctx, client, model)
}
func pullMissingModel(ctx context.Context, client *api.Client, model string) error {
if err := pullModel(ctx, client, model, false); err != nil {
return fmt.Errorf("failed to pull %s: %w", model, err)
}
return nil
}
// prepareEditorIntegration persists models and applies editor-managed config files.
func prepareEditorIntegration(name string, runner Runner, editor Editor, models []string) error {
if ok, err := confirmEditorEdit(runner, editor); err != nil {
return err
} else if !ok {
return errCancelled
}
if err := editor.Edit(models); err != nil {
return fmt.Errorf("setup failed: %w", err)
}
if err := config.SaveIntegration(name, models); err != nil {
return fmt.Errorf("failed to save: %w", err)
}
return nil
}
func confirmEditorEdit(runner Runner, editor Editor) (bool, error) {
paths := editor.Paths()
if len(paths) == 0 {
return true, nil
}
fmt.Fprintf(os.Stderr, "This will modify your %s configuration:\n", runner)
for _, path := range paths {
fmt.Fprintf(os.Stderr, " %s\n", path)
}
fmt.Fprintf(os.Stderr, "Backups will be saved to %s/\n\n", fileutil.BackupDir())
return ConfirmPrompt("Proceed?")
}
// buildModelList merges existing models with recommendations for selection UIs.
func buildModelList(existing []modelInfo, preChecked []string, current string) (items []ModelItem, orderedChecked []string, existingModels, cloudModels map[string]bool) {
existingModels = make(map[string]bool)
cloudModels = make(map[string]bool)
recommended := make(map[string]bool)
var hasLocalModel, hasCloudModel bool
recDesc := make(map[string]string)
for _, rec := range recommendedModels {
recommended[rec.Name] = true
recDesc[rec.Name] = rec.Description
}
for _, m := range existing {
existingModels[m.Name] = true
if m.Remote {
cloudModels[m.Name] = true
hasCloudModel = true
} else {
hasLocalModel = true
}
displayName := strings.TrimSuffix(m.Name, ":latest")
existingModels[displayName] = true
item := ModelItem{Name: displayName, Recommended: recommended[displayName], Description: recDesc[displayName]}
items = append(items, item)
}
for _, rec := range recommendedModels {
if existingModels[rec.Name] || existingModels[rec.Name+":latest"] {
continue
}
items = append(items, rec)
if isCloudModelName(rec.Name) {
cloudModels[rec.Name] = true
}
}
checked := make(map[string]bool, len(preChecked))
for _, n := range preChecked {
checked[n] = true
}
if current != "" {
matchedCurrent := false
for _, item := range items {
if item.Name == current {
current = item.Name
matchedCurrent = true
break
}
}
if !matchedCurrent {
for _, item := range items {
if strings.HasPrefix(item.Name, current+":") {
current = item.Name
break
}
}
}
}
if checked[current] {
preChecked = append([]string{current}, slices.DeleteFunc(preChecked, func(m string) bool { return m == current })...)
}
notInstalled := make(map[string]bool)
for i := range items {
if !existingModels[items[i].Name] && !cloudModels[items[i].Name] {
notInstalled[items[i].Name] = true
var parts []string
if items[i].Description != "" {
parts = append(parts, items[i].Description)
}
if vram := recommendedVRAM[items[i].Name]; vram != "" {
parts = append(parts, vram)
}
parts = append(parts, "(not downloaded)")
items[i].Description = strings.Join(parts, ", ")
}
}
recRank := make(map[string]int)
for i, rec := range recommendedModels {
recRank[rec.Name] = i + 1
}
onlyLocal := hasLocalModel && !hasCloudModel
if hasLocalModel || hasCloudModel {
slices.SortStableFunc(items, func(a, b ModelItem) int {
ac, bc := checked[a.Name], checked[b.Name]
aNew, bNew := notInstalled[a.Name], notInstalled[b.Name]
aRec, bRec := recRank[a.Name] > 0, recRank[b.Name] > 0
aCloud, bCloud := cloudModels[a.Name], cloudModels[b.Name]
if ac != bc {
if ac {
return -1
}
return 1
}
if aRec != bRec {
if aRec {
return -1
}
return 1
}
if aRec && bRec {
if aCloud != bCloud {
if onlyLocal {
if aCloud {
return 1
}
return -1
}
if aCloud {
return -1
}
return 1
}
return recRank[a.Name] - recRank[b.Name]
}
if aNew != bNew {
if aNew {
return 1
}
return -1
}
return strings.Compare(strings.ToLower(a.Name), strings.ToLower(b.Name))
})
}
return items, preChecked, existingModels, cloudModels
}
// isCloudModelName reports whether the model name has an explicit cloud source.
func isCloudModelName(name string) bool {
return modelref.HasExplicitCloudSource(name)
}
// filterCloudModels drops remote-only models from the given inventory.
func filterCloudModels(existing []modelInfo) []modelInfo {
filtered := existing[:0]
for _, m := range existing {
if !m.Remote {
filtered = append(filtered, m)
}
}
return filtered
}
// filterCloudItems removes cloud models from selection items.
func filterCloudItems(items []ModelItem) []ModelItem {
filtered := items[:0]
for _, item := range items {
if !isCloudModelName(item.Name) {
filtered = append(filtered, item)
}
}
return filtered
}
func isCloudModel(ctx context.Context, client *api.Client, name string) bool {
if client == nil {
return false
}
resp, err := client.Show(ctx, &api.ShowRequest{Model: name})
if err != nil {
return false
}
return resp.RemoteModel != ""
}
// cloudStatusDisabled returns whether cloud usage is currently disabled.
func cloudStatusDisabled(ctx context.Context, client *api.Client) (disabled bool, known bool) {
status, err := client.CloudStatusExperimental(ctx)
if err != nil {
var statusErr api.StatusError
if errors.As(err, &statusErr) && statusErr.StatusCode == http.StatusNotFound {
return false, false
}
return false, false
}
return status.Cloud.Disabled, true
}
// TODO(parthsareen): this duplicates the pull progress UI in cmd.PullHandler.
// Move the shared pull rendering to a small utility once the package boundary settles.
func pullModel(ctx context.Context, client *api.Client, model string, insecure bool) error {
p := progress.NewProgress(os.Stderr)
defer p.Stop()
bars := make(map[string]*progress.Bar)
var status string
var spinner *progress.Spinner
fn := func(resp api.ProgressResponse) error {
if resp.Digest != "" {
if resp.Completed == 0 {
return nil
}
if spinner != nil {
spinner.Stop()
}
bar, ok := bars[resp.Digest]
if !ok {
name, isDigest := strings.CutPrefix(resp.Digest, "sha256:")
name = strings.TrimSpace(name)
if isDigest {
name = name[:min(12, len(name))]
}
bar = progress.NewBar(fmt.Sprintf("pulling %s:", name), resp.Total, resp.Completed)
bars[resp.Digest] = bar
p.Add(resp.Digest, bar)
}
bar.Set(resp.Completed)
} else if status != resp.Status {
if spinner != nil {
spinner.Stop()
}
status = resp.Status
spinner = progress.NewSpinner(status)
p.Add(status, spinner)
}
return nil
}
request := api.PullRequest{Name: model, Insecure: insecure}
return client.Pull(ctx, &request, fn)
}

View File

@@ -1,4 +1,4 @@
package config package launch
import ( import (
"context" "context"
@@ -14,7 +14,10 @@ import (
"strings" "strings"
"time" "time"
"golang.org/x/mod/semver"
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
"github.com/ollama/ollama/cmd/internal/fileutil"
"github.com/ollama/ollama/envconfig" "github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/types/model" "github.com/ollama/ollama/types/model"
) )
@@ -24,6 +27,9 @@ const defaultGatewayPort = 18789
// Bound model capability probing so launch/config cannot hang on slow/unreachable API calls. // Bound model capability probing so launch/config cannot hang on slow/unreachable API calls.
var openclawModelShowTimeout = 5 * time.Second var openclawModelShowTimeout = 5 * time.Second
// openclawFreshInstall is set to true when ensureOpenclawInstalled performs an install
var openclawFreshInstall bool
type Openclaw struct{} type Openclaw struct{}
func (c *Openclaw) String() string { return "OpenClaw" } func (c *Openclaw) String() string { return "OpenClaw" }
@@ -34,10 +40,7 @@ func (c *Openclaw) Run(model string, args []string) error {
return err return err
} }
firstLaunch := true firstLaunch := !c.onboarded()
if integrationConfig, err := loadIntegration("openclaw"); err == nil {
firstLaunch = !integrationConfig.Onboarded
}
if firstLaunch { if firstLaunch {
fmt.Fprintf(os.Stderr, "\n%sSecurity%s\n\n", ansiBold, ansiReset) fmt.Fprintf(os.Stderr, "\n%sSecurity%s\n\n", ansiBold, ansiReset)
@@ -45,28 +48,46 @@ func (c *Openclaw) Run(model string, args []string) error {
fmt.Fprintf(os.Stderr, " A bad prompt can trick it into doing unsafe things.\n\n") fmt.Fprintf(os.Stderr, " A bad prompt can trick it into doing unsafe things.\n\n")
fmt.Fprintf(os.Stderr, "%s Learn more: https://docs.openclaw.ai/gateway/security%s\n\n", ansiGray, ansiReset) fmt.Fprintf(os.Stderr, "%s Learn more: https://docs.openclaw.ai/gateway/security%s\n\n", ansiGray, ansiReset)
ok, err := confirmPrompt("I understand the risks. Continue?") ok, err := ConfirmPrompt("I understand the risks. Continue?")
if err != nil { if err != nil {
return err return err
} }
if !ok { if !ok {
return nil return nil
} }
}
if !c.onboarded() { // Ensure the latest version is installed before onboarding so we get
// the newest wizard flags (e.g. --auth-choice ollama).
if !openclawFreshInstall {
update := exec.Command(bin, "update")
update.Stdout = os.Stdout
update.Stderr = os.Stderr
_ = update.Run() // best-effort; continue even if update fails
}
fmt.Fprintf(os.Stderr, "\n%sSetting up OpenClaw with Ollama...%s\n", ansiGreen, ansiReset) fmt.Fprintf(os.Stderr, "\n%sSetting up OpenClaw with Ollama...%s\n", ansiGreen, ansiReset)
fmt.Fprintf(os.Stderr, "%s Model: %s%s\n\n", ansiGray, model, ansiReset) fmt.Fprintf(os.Stderr, "%s Model: %s%s\n\n", ansiGray, model, ansiReset)
cmd := exec.Command(bin, "onboard", onboardArgs := []string{
"onboard",
"--non-interactive", "--non-interactive",
"--accept-risk", "--accept-risk",
"--auth-choice", "skip", "--auth-choice", "ollama",
"--gateway-token", "ollama", "--custom-base-url", envconfig.Host().String(),
"--install-daemon", "--custom-model-id", model,
"--skip-channels", "--skip-channels",
"--skip-skills", "--skip-skills",
) }
if canInstallDaemon() {
onboardArgs = append(onboardArgs, "--install-daemon")
} else {
// When we can't install a daemon (e.g. no systemd, sudo dropped
// XDG_RUNTIME_DIR, or container environment), skip the gateway
// health check so non-interactive onboarding completes. The
// gateway is started as a foreground child process after onboarding.
onboardArgs = append(onboardArgs, "--skip-health")
}
cmd := exec.Command(bin, onboardArgs...)
cmd.Stdin = os.Stdin cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr cmd.Stderr = os.Stderr
@@ -75,25 +96,13 @@ func (c *Openclaw) Run(model string, args []string) error {
} }
patchDeviceScopes() patchDeviceScopes()
// Onboarding overwrites openclaw.json, so re-apply the model config
// that Edit() wrote before Run() was called.
if err := c.Edit([]string{model}); err != nil {
fmt.Fprintf(os.Stderr, "%s Warning: could not re-apply model config: %v%s\n", ansiYellow, err, ansiReset)
}
} }
if strings.HasSuffix(model, ":cloud") || strings.HasSuffix(model, "-cloud") { if ensureWebSearchPlugin() {
if ensureWebSearchPlugin() { registerWebSearchPlugin()
registerWebSearchPlugin()
}
} }
if firstLaunch { fmt.Fprintf(os.Stderr, "\n%sStarting your assistant — this may take a moment...%s\n\n", ansiGray, ansiReset)
fmt.Fprintf(os.Stderr, "\n%sPreparing your assistant — this may take a moment...%s\n\n", ansiGray, ansiReset)
} else {
fmt.Fprintf(os.Stderr, "\n%sStarting your assistant — this may take a moment...%s\n\n", ansiGray, ansiReset)
}
// When extra args are passed through, run exactly what the user asked for // When extra args are passed through, run exactly what the user asked for
// after setup and skip the built-in gateway+TUI convenience flow. // after setup and skip the built-in gateway+TUI convenience flow.
@@ -106,11 +115,6 @@ func (c *Openclaw) Run(model string, args []string) error {
if err := cmd.Run(); err != nil { if err := cmd.Run(); err != nil {
return windowsHint(err) return windowsHint(err)
} }
if firstLaunch {
if err := integrationOnboarded("openclaw"); err != nil {
return fmt.Errorf("failed to save onboarding state: %w", err)
}
}
return nil return nil
} }
@@ -118,7 +122,7 @@ func (c *Openclaw) Run(model string, args []string) error {
addr := fmt.Sprintf("localhost:%d", port) addr := fmt.Sprintf("localhost:%d", port)
// If the gateway is already running (e.g. via the daemon), restart it // If the gateway is already running (e.g. via the daemon), restart it
// so it picks up any config changes from Edit() above (model, provider, etc.). // so it picks up any config changes (model, provider, etc.).
if portOpen(addr) { if portOpen(addr) {
restart := exec.Command(bin, "daemon", "restart") restart := exec.Command(bin, "daemon", "restart")
restart.Env = openclawEnv() restart.Env = openclawEnv()
@@ -165,11 +169,6 @@ func (c *Openclaw) Run(model string, args []string) error {
return windowsHint(err) return windowsHint(err)
} }
if firstLaunch {
if err := integrationOnboarded("openclaw"); err != nil {
return fmt.Errorf("failed to save onboarding state: %w", err)
}
}
return nil return nil
} }
@@ -409,6 +408,25 @@ func patchScopes(obj map[string]any, key string, required []string) bool {
return added return added
} }
// canInstallDaemon reports whether the openclaw daemon can be installed as a
// background service. Returns false on Linux when systemd is absent (e.g.
// containers) so that --install-daemon is omitted and the gateway is started
// as a foreground child process instead. Returns true in all other cases.
func canInstallDaemon() bool {
if runtime.GOOS != "linux" {
return true
}
// /run/systemd/system exists as a directory when systemd is the init system.
// This is absent in most containers.
fi, err := os.Stat("/run/systemd/system")
if err != nil || !fi.IsDir() {
return false
}
// Even when systemd is the init system, user services require a user
// manager instance. XDG_RUNTIME_DIR being set is a prerequisite.
return os.Getenv("XDG_RUNTIME_DIR") != ""
}
func ensureOpenclawInstalled() (string, error) { func ensureOpenclawInstalled() (string, error) {
if _, err := exec.LookPath("openclaw"); err == nil { if _, err := exec.LookPath("openclaw"); err == nil {
return "openclaw", nil return "openclaw", nil
@@ -417,16 +435,20 @@ func ensureOpenclawInstalled() (string, error) {
return "clawdbot", nil return "clawdbot", nil
} }
if _, err := exec.LookPath("npm"); err != nil { _, npmErr := exec.LookPath("npm")
return "", fmt.Errorf("openclaw is not installed and npm was not found\n\n" + _, gitErr := exec.LookPath("git")
"Install Node.js first:\n" + if npmErr != nil || gitErr != nil {
" https://nodejs.org/\n\n" + var missing []string
"Then rerun:\n" + if npmErr != nil {
" ollama launch\n" + missing = append(missing, "npm (Node.js): https://nodejs.org/")
"and select OpenClaw") }
if gitErr != nil {
missing = append(missing, "git: https://git-scm.com/")
}
return "", fmt.Errorf("openclaw is not installed and required dependencies are missing\n\nInstall the following first:\n %s", strings.Join(missing, "\n "))
} }
ok, err := confirmPrompt("OpenClaw is not installed. Install with npm?") ok, err := ConfirmPrompt("OpenClaw is not installed. Install with npm?")
if err != nil { if err != nil {
return "", err return "", err
} }
@@ -448,6 +470,7 @@ func ensureOpenclawInstalled() (string, error) {
} }
fmt.Fprintf(os.Stderr, "%sOpenClaw installed successfully%s\n\n", ansiGreen, ansiReset) fmt.Fprintf(os.Stderr, "%sOpenClaw installed successfully%s\n\n", ansiGreen, ansiReset)
openclawFreshInstall = true
return "openclaw", nil return "openclaw", nil
} }
@@ -502,7 +525,7 @@ func (c *Openclaw) Edit(models []string) error {
ollama = make(map[string]any) ollama = make(map[string]any)
} }
ollama["baseUrl"] = envconfig.Host().String() + "/v1" ollama["baseUrl"] = envconfig.Host().String()
// needed to register provider // needed to register provider
ollama["apiKey"] = "ollama-local" ollama["apiKey"] = "ollama-local"
ollama["api"] = "ollama" ollama["api"] = "ollama"
@@ -561,7 +584,7 @@ func (c *Openclaw) Edit(models []string) error {
if err != nil { if err != nil {
return err return err
} }
if err := writeWithBackup(configPath, data); err != nil { if err := fileutil.WriteWithBackup(configPath, data); err != nil {
return err return err
} }
@@ -592,6 +615,8 @@ func clearSessionModelOverride(primary string) {
if override, _ := sess["modelOverride"].(string); override != "" && override != primary { if override, _ := sess["modelOverride"].(string); override != "" && override != primary {
delete(sess, "modelOverride") delete(sess, "modelOverride")
delete(sess, "providerOverride") delete(sess, "providerOverride")
}
if model, _ := sess["model"].(string); model != "" && model != primary {
sess["model"] = primary sess["model"] = primary
changed = true changed = true
} }
@@ -606,11 +631,15 @@ func clearSessionModelOverride(primary string) {
_ = os.WriteFile(path, out, 0o600) _ = os.WriteFile(path, out, 0o600)
} }
const webSearchNpmPackage = "@ollama/openclaw-web-search" const (
webSearchNpmPackage = "@ollama/openclaw-web-search"
webSearchMinVersion = "0.2.1"
)
// ensureWebSearchPlugin installs the openclaw-web-search extension into the // ensureWebSearchPlugin installs the openclaw-web-search extension into the
// user-level extensions directory (~/.openclaw/extensions/) if it isn't already // user-level extensions directory (~/.openclaw/extensions/) if it isn't already
// present. Returns true if the extension is available. // present, or re-installs if the installed version is older than webSearchMinVersion.
// Returns true if the extension is available.
func ensureWebSearchPlugin() bool { func ensureWebSearchPlugin() bool {
home, err := os.UserHomeDir() home, err := os.UserHomeDir()
if err != nil { if err != nil {
@@ -618,8 +647,8 @@ func ensureWebSearchPlugin() bool {
} }
pluginDir := filepath.Join(home, ".openclaw", "extensions", "openclaw-web-search") pluginDir := filepath.Join(home, ".openclaw", "extensions", "openclaw-web-search")
if _, err := os.Stat(filepath.Join(pluginDir, "index.ts")); err == nil { if webSearchPluginUpToDate(pluginDir) {
return true // already installed return true
} }
npmBin, err := exec.LookPath("npm") npmBin, err := exec.LookPath("npm")
@@ -653,6 +682,34 @@ func ensureWebSearchPlugin() bool {
return true return true
} }
// webSearchPluginUpToDate returns true if the plugin is installed and its
// package.json version is >= webSearchMinVersion.
func webSearchPluginUpToDate(pluginDir string) bool {
data, err := os.ReadFile(filepath.Join(pluginDir, "package.json"))
if err != nil {
return false
}
var pkg struct {
Version string `json:"version"`
}
if json.Unmarshal(data, &pkg) != nil || pkg.Version == "" {
return false
}
return !versionLessThan(pkg.Version, webSearchMinVersion)
}
// versionLessThan compares two semver version strings (major.minor.patch).
// Inputs may omit the "v" prefix; it is added automatically for semver.Compare.
func versionLessThan(a, b string) bool {
if !strings.HasPrefix(a, "v") {
a = "v" + a
}
if !strings.HasPrefix(b, "v") {
b = "v" + b
}
return semver.Compare(a, b) < 0
}
// registerWebSearchPlugin adds plugins.entries.openclaw-web-search to the OpenClaw // registerWebSearchPlugin adds plugins.entries.openclaw-web-search to the OpenClaw
// config so the gateway activates it on next start. Best-effort; silently returns // config so the gateway activates it on next start. Best-effort; silently returns
// on any error. // on any error.
@@ -679,23 +736,67 @@ func registerWebSearchPlugin() {
if entries == nil { if entries == nil {
entries = make(map[string]any) entries = make(map[string]any)
} }
if _, ok := entries["openclaw-web-search"]; ok {
return // already registered
}
entries["openclaw-web-search"] = map[string]any{"enabled": true} entries["openclaw-web-search"] = map[string]any{"enabled": true}
plugins["entries"] = entries plugins["entries"] = entries
// Pin trust so the gateway doesn't warn about untracked plugins.
allow, _ := plugins["allow"].([]any)
hasAllow := false
for _, v := range allow {
if s, ok := v.(string); ok && s == "openclaw-web-search" {
hasAllow = true
break
}
}
if !hasAllow {
allow = append(allow, "openclaw-web-search")
}
plugins["allow"] = allow
// Record install provenance so the loader can verify the plugin origin.
installs, _ := plugins["installs"].(map[string]any)
if installs == nil {
installs = make(map[string]any)
}
pluginDir := filepath.Join(home, ".openclaw", "extensions", "openclaw-web-search")
installs["openclaw-web-search"] = map[string]any{
"source": "npm",
"spec": webSearchNpmPackage,
"installPath": pluginDir,
}
plugins["installs"] = installs
config["plugins"] = plugins config["plugins"] = plugins
// Disable the built-in web search since our plugin replaces it. // Add plugin tools to tools.alsoAllow so they survive the coding profile's
// policy pipeline (which has an explicit allow list of core tools only).
tools, _ := config["tools"].(map[string]any) tools, _ := config["tools"].(map[string]any)
if tools == nil { if tools == nil {
tools = make(map[string]any) tools = make(map[string]any)
} }
alsoAllow, _ := tools["alsoAllow"].([]any)
needed := []string{"ollama_web_search", "ollama_web_fetch"}
have := make(map[string]bool, len(alsoAllow))
for _, v := range alsoAllow {
if s, ok := v.(string); ok {
have[s] = true
}
}
for _, name := range needed {
if !have[name] {
alsoAllow = append(alsoAllow, name)
}
}
tools["alsoAllow"] = alsoAllow
// Disable built-in web search/fetch since our plugin replaces them.
web, _ := tools["web"].(map[string]any) web, _ := tools["web"].(map[string]any)
if web == nil { if web == nil {
web = make(map[string]any) web = make(map[string]any)
} }
web["search"] = map[string]any{"enabled": false} web["search"] = map[string]any{"enabled": false}
web["fetch"] = map[string]any{"enabled": false}
tools["web"] = web tools["web"] = web
config["tools"] = tools config["tools"] = tools
@@ -776,9 +877,9 @@ func (c *Openclaw) Models() []string {
return nil return nil
} }
config, err := readJSONFile(filepath.Join(home, ".openclaw", "openclaw.json")) config, err := fileutil.ReadJSON(filepath.Join(home, ".openclaw", "openclaw.json"))
if err != nil { if err != nil {
config, err = readJSONFile(filepath.Join(home, ".clawdbot", "clawdbot.json")) config, err = fileutil.ReadJSON(filepath.Join(home, ".clawdbot", "clawdbot.json"))
if err != nil { if err != nil {
return nil return nil
} }

View File

@@ -1,4 +1,4 @@
package config package launch
import ( import (
"bytes" "bytes"
@@ -82,78 +82,6 @@ func TestOpenclawRunPassthroughArgs(t *testing.T) {
} }
} }
func TestOpenclawRunFirstLaunchPersistence(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("uses a POSIX shell test binary")
}
oldHook := DefaultConfirmPrompt
DefaultConfirmPrompt = func(prompt string) (bool, error) {
return true, nil
}
defer func() { DefaultConfirmPrompt = oldHook }()
t.Run("success persists onboarding flag", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
t.Setenv("PATH", tmpDir)
configDir := filepath.Join(tmpDir, ".openclaw")
if err := os.MkdirAll(configDir, 0o755); err != nil {
t.Fatal(err)
}
// Mark OpenClaw onboarding complete so Run takes passthrough path directly.
if err := os.WriteFile(filepath.Join(configDir, "openclaw.json"), []byte(`{
"wizard": {"lastRunAt": "2026-01-01T00:00:00Z"}
}`), 0o644); err != nil {
t.Fatal(err)
}
if err := os.WriteFile(filepath.Join(tmpDir, "openclaw"), []byte("#!/bin/sh\nexit 0\n"), 0o755); err != nil {
t.Fatal(err)
}
c := &Openclaw{}
if err := c.Run("llama3.2", []string{"gateway", "--status"}); err != nil {
t.Fatalf("Run() error = %v", err)
}
integrationConfig, err := loadIntegration("openclaw")
if err != nil {
t.Fatalf("loadIntegration() error = %v", err)
}
if !integrationConfig.Onboarded {
t.Fatal("expected onboarding flag to be persisted after successful run")
}
})
t.Run("failure does not persist onboarding flag", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
t.Setenv("PATH", tmpDir)
configDir := filepath.Join(tmpDir, ".openclaw")
if err := os.MkdirAll(configDir, 0o755); err != nil {
t.Fatal(err)
}
if err := os.WriteFile(filepath.Join(configDir, "openclaw.json"), []byte(`{
"wizard": {"lastRunAt": "2026-01-01T00:00:00Z"}
}`), 0o644); err != nil {
t.Fatal(err)
}
if err := os.WriteFile(filepath.Join(tmpDir, "openclaw"), []byte("#!/bin/sh\nexit 1\n"), 0o755); err != nil {
t.Fatal(err)
}
c := &Openclaw{}
if err := c.Run("llama3.2", []string{"gateway", "--status"}); err == nil {
t.Fatal("expected run failure")
}
integrationConfig, err := loadIntegration("openclaw")
if err == nil && integrationConfig.Onboarded {
t.Fatal("expected onboarding flag to remain unset after failed run")
}
})
}
func TestOpenclawEdit(t *testing.T) { func TestOpenclawEdit(t *testing.T) {
c := &Openclaw{} c := &Openclaw{}
tmpDir := t.TempDir() tmpDir := t.TempDir()
@@ -589,7 +517,7 @@ const testOpenclawFixture = `{
"providers": { "providers": {
"anthropic": {"apiKey": "xxx"}, "anthropic": {"apiKey": "xxx"},
"ollama": { "ollama": {
"baseUrl": "http://127.0.0.1:11434/v1", "baseUrl": "http://127.0.0.1:11434",
"models": [{"id": "old-model", "customField": "preserved"}] "models": [{"id": "old-model", "customField": "preserved"}]
} }
} }
@@ -1448,7 +1376,7 @@ func TestOpenclawModelConfig(t *testing.T) {
// report it as a remote/cloud model // report it as a remote/cloud model
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/api/show" { if r.URL.Path == "/api/show" {
fmt.Fprintf(w, `{"capabilities":[],"model_info":{},"remote_model":"minimax-m2.5"}`) fmt.Fprintf(w, `{"capabilities":[],"model_info":{},"remote_model":"minimax-m2.7"}`)
return return
} }
w.WriteHeader(http.StatusNotFound) w.WriteHeader(http.StatusNotFound)
@@ -1458,7 +1386,7 @@ func TestOpenclawModelConfig(t *testing.T) {
u, _ := url.Parse(srv.URL) u, _ := url.Parse(srv.URL)
client := api.NewClient(u, srv.Client()) client := api.NewClient(u, srv.Client())
cfg, isCloud := openclawModelConfig(context.Background(), client, "minimax-m2.5:cloud") cfg, isCloud := openclawModelConfig(context.Background(), client, "minimax-m2.7:cloud")
if !isCloud { if !isCloud {
t.Error("expected isCloud = true for cloud model") t.Error("expected isCloud = true for cloud model")
@@ -1528,7 +1456,7 @@ func TestIntegrationOnboarded(t *testing.T) {
tmpDir := t.TempDir() tmpDir := t.TempDir()
setTestHome(t, tmpDir) setTestHome(t, tmpDir)
integrationConfig, err := loadIntegration("openclaw") integrationConfig, err := LoadIntegration("openclaw")
if err == nil && integrationConfig.Onboarded { if err == nil && integrationConfig.Onboarded {
t.Error("expected false for fresh config") t.Error("expected false for fresh config")
} }
@@ -1542,7 +1470,7 @@ func TestIntegrationOnboarded(t *testing.T) {
if err := integrationOnboarded("openclaw"); err != nil { if err := integrationOnboarded("openclaw"); err != nil {
t.Fatal(err) t.Fatal(err)
} }
integrationConfig, err := loadIntegration("openclaw") integrationConfig, err := LoadIntegration("openclaw")
if err != nil || !integrationConfig.Onboarded { if err != nil || !integrationConfig.Onboarded {
t.Error("expected true after integrationOnboarded") t.Error("expected true after integrationOnboarded")
} }
@@ -1556,7 +1484,7 @@ func TestIntegrationOnboarded(t *testing.T) {
if err := integrationOnboarded("OpenClaw"); err != nil { if err := integrationOnboarded("OpenClaw"); err != nil {
t.Fatal(err) t.Fatal(err)
} }
integrationConfig, err := loadIntegration("openclaw") integrationConfig, err := LoadIntegration("openclaw")
if err != nil || !integrationConfig.Onboarded { if err != nil || !integrationConfig.Onboarded {
t.Error("expected true when set with different case") t.Error("expected true when set with different case")
} }
@@ -1575,7 +1503,7 @@ func TestIntegrationOnboarded(t *testing.T) {
} }
// Verify onboarded is set // Verify onboarded is set
integrationConfig, err := loadIntegration("openclaw") integrationConfig, err := LoadIntegration("openclaw")
if err != nil || !integrationConfig.Onboarded { if err != nil || !integrationConfig.Onboarded {
t.Error("expected true after integrationOnboarded") t.Error("expected true after integrationOnboarded")
} }
@@ -1587,3 +1515,377 @@ func TestIntegrationOnboarded(t *testing.T) {
} }
}) })
} }
func TestVersionLessThan(t *testing.T) {
tests := []struct {
a, b string
want bool
}{
{"0.1.7", "0.2.1", true},
{"0.2.0", "0.2.1", true},
{"0.2.1", "0.2.1", false},
{"0.2.2", "0.2.1", false},
{"1.0.0", "0.2.1", false},
{"0.2.1", "1.0.0", true},
{"v0.1.7", "0.2.1", true},
{"0.2.1", "v0.2.1", false},
}
for _, tt := range tests {
t.Run(tt.a+"_vs_"+tt.b, func(t *testing.T) {
if got := versionLessThan(tt.a, tt.b); got != tt.want {
t.Errorf("versionLessThan(%q, %q) = %v, want %v", tt.a, tt.b, got, tt.want)
}
})
}
}
func TestWebSearchPluginUpToDate(t *testing.T) {
t.Run("missing directory", func(t *testing.T) {
if webSearchPluginUpToDate(filepath.Join(t.TempDir(), "nonexistent")) {
t.Error("expected false for missing directory")
}
})
t.Run("missing package.json", func(t *testing.T) {
dir := t.TempDir()
if webSearchPluginUpToDate(dir) {
t.Error("expected false for missing package.json")
}
})
t.Run("old version", func(t *testing.T) {
dir := t.TempDir()
if err := os.WriteFile(filepath.Join(dir, "package.json"), []byte(`{"version":"0.1.7"}`), 0o644); err != nil {
t.Fatal(err)
}
if webSearchPluginUpToDate(dir) {
t.Error("expected false for old version 0.1.7")
}
})
t.Run("exact minimum version", func(t *testing.T) {
dir := t.TempDir()
if err := os.WriteFile(filepath.Join(dir, "package.json"), []byte(`{"version":"0.2.1"}`), 0o644); err != nil {
t.Fatal(err)
}
if !webSearchPluginUpToDate(dir) {
t.Error("expected true for exact minimum version 0.2.1")
}
})
t.Run("newer version", func(t *testing.T) {
dir := t.TempDir()
if err := os.WriteFile(filepath.Join(dir, "package.json"), []byte(`{"version":"1.0.0"}`), 0o644); err != nil {
t.Fatal(err)
}
if !webSearchPluginUpToDate(dir) {
t.Error("expected true for newer version 1.0.0")
}
})
t.Run("invalid json", func(t *testing.T) {
dir := t.TempDir()
if err := os.WriteFile(filepath.Join(dir, "package.json"), []byte(`not json`), 0o644); err != nil {
t.Fatal(err)
}
if webSearchPluginUpToDate(dir) {
t.Error("expected false for invalid json")
}
})
t.Run("empty version", func(t *testing.T) {
dir := t.TempDir()
if err := os.WriteFile(filepath.Join(dir, "package.json"), []byte(`{"version":""}`), 0o644); err != nil {
t.Fatal(err)
}
if webSearchPluginUpToDate(dir) {
t.Error("expected false for empty version")
}
})
}
func TestRegisterWebSearchPlugin(t *testing.T) {
home := t.TempDir()
setTestHome(t, home)
configDir := filepath.Join(home, ".openclaw")
if err := os.MkdirAll(configDir, 0o755); err != nil {
t.Fatal(err)
}
configPath := filepath.Join(configDir, "openclaw.json")
t.Run("fresh config", func(t *testing.T) {
if err := os.WriteFile(configPath, []byte(`{}`), 0o644); err != nil {
t.Fatal(err)
}
registerWebSearchPlugin()
data, err := os.ReadFile(configPath)
if err != nil {
t.Fatal(err)
}
var config map[string]any
if err := json.Unmarshal(data, &config); err != nil {
t.Fatal(err)
}
plugins, _ := config["plugins"].(map[string]any)
if plugins == nil {
t.Fatal("plugins section missing")
}
// Check entries
entries, _ := plugins["entries"].(map[string]any)
entry, _ := entries["openclaw-web-search"].(map[string]any)
if enabled, _ := entry["enabled"].(bool); !enabled {
t.Error("expected entries.openclaw-web-search.enabled = true")
}
// Check allow list
allow, _ := plugins["allow"].([]any)
found := false
for _, v := range allow {
if s, ok := v.(string); ok && s == "openclaw-web-search" {
found = true
}
}
if !found {
t.Error("expected plugins.allow to contain openclaw-web-search")
}
// Check install provenance
installs, _ := plugins["installs"].(map[string]any)
record, _ := installs["openclaw-web-search"].(map[string]any)
if record == nil {
t.Fatal("expected plugins.installs.openclaw-web-search")
}
if source, _ := record["source"].(string); source != "npm" {
t.Errorf("install source = %q, want %q", source, "npm")
}
if spec, _ := record["spec"].(string); spec != webSearchNpmPackage {
t.Errorf("install spec = %q, want %q", spec, webSearchNpmPackage)
}
expectedPath := filepath.Join(home, ".openclaw", "extensions", "openclaw-web-search")
if installPath, _ := record["installPath"].(string); installPath != expectedPath {
t.Errorf("installPath = %q, want %q", installPath, expectedPath)
}
})
t.Run("idempotent", func(t *testing.T) {
if err := os.WriteFile(configPath, []byte(`{}`), 0o644); err != nil {
t.Fatal(err)
}
registerWebSearchPlugin()
registerWebSearchPlugin()
data, err := os.ReadFile(configPath)
if err != nil {
t.Fatal(err)
}
var config map[string]any
if err := json.Unmarshal(data, &config); err != nil {
t.Fatal(err)
}
plugins, _ := config["plugins"].(map[string]any)
allow, _ := plugins["allow"].([]any)
count := 0
for _, v := range allow {
if s, ok := v.(string); ok && s == "openclaw-web-search" {
count++
}
}
if count != 1 {
t.Errorf("expected exactly 1 openclaw-web-search in allow, got %d", count)
}
})
t.Run("preserves existing config", func(t *testing.T) {
initial := map[string]any{
"plugins": map[string]any{
"allow": []any{"some-other-plugin"},
"entries": map[string]any{
"some-other-plugin": map[string]any{"enabled": true},
},
"installs": map[string]any{
"some-other-plugin": map[string]any{
"source": "npm",
"installPath": "/some/path",
},
},
},
"customField": "preserved",
}
data, _ := json.Marshal(initial)
if err := os.WriteFile(configPath, data, 0o644); err != nil {
t.Fatal(err)
}
registerWebSearchPlugin()
out, err := os.ReadFile(configPath)
if err != nil {
t.Fatal(err)
}
var config map[string]any
if err := json.Unmarshal(out, &config); err != nil {
t.Fatal(err)
}
if config["customField"] != "preserved" {
t.Error("customField was not preserved")
}
plugins, _ := config["plugins"].(map[string]any)
entries, _ := plugins["entries"].(map[string]any)
if entries["some-other-plugin"] == nil {
t.Error("existing plugin entry was lost")
}
installs, _ := plugins["installs"].(map[string]any)
if installs["some-other-plugin"] == nil {
t.Error("existing install record was lost")
}
allow, _ := plugins["allow"].([]any)
hasOther, hasWebSearch := false, false
for _, v := range allow {
s, _ := v.(string)
if s == "some-other-plugin" {
hasOther = true
}
if s == "openclaw-web-search" {
hasWebSearch = true
}
}
if !hasOther {
t.Error("existing allow entry was lost")
}
if !hasWebSearch {
t.Error("openclaw-web-search not added to allow")
}
})
}
func TestClearSessionModelOverride(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
sessionsDir := filepath.Join(tmpDir, ".openclaw", "agents", "main", "sessions")
sessionsPath := filepath.Join(sessionsDir, "sessions.json")
writeSessionsFile := func(t *testing.T, sessions map[string]map[string]any) {
t.Helper()
if err := os.MkdirAll(sessionsDir, 0o755); err != nil {
t.Fatal(err)
}
data, err := json.Marshal(sessions)
if err != nil {
t.Fatal(err)
}
if err := os.WriteFile(sessionsPath, data, 0o600); err != nil {
t.Fatal(err)
}
}
readSessionsFile := func(t *testing.T) map[string]map[string]any {
t.Helper()
data, err := os.ReadFile(sessionsPath)
if err != nil {
t.Fatalf("reading sessions file: %v", err)
}
var sessions map[string]map[string]any
if err := json.Unmarshal(data, &sessions); err != nil {
t.Fatalf("parsing sessions file: %v", err)
}
return sessions
}
t.Run("clears modelOverride and updates model", func(t *testing.T) {
writeSessionsFile(t, map[string]map[string]any{
"sess1": {"model": "ollama/old-model", "modelOverride": "old-model", "providerOverride": "ollama"},
})
clearSessionModelOverride("new-model")
sessions := readSessionsFile(t)
sess := sessions["sess1"]
if _, ok := sess["modelOverride"]; ok {
t.Error("modelOverride should have been deleted")
}
if _, ok := sess["providerOverride"]; ok {
t.Error("providerOverride should have been deleted")
}
if sess["model"] != "new-model" {
t.Errorf("model = %q, want %q", sess["model"], "new-model")
}
})
t.Run("updates model field in sessions without modelOverride", func(t *testing.T) {
// This is the bug case: session has model pointing to old primary,
// but no explicit modelOverride. After changing primary, the session
// model field must also be updated.
writeSessionsFile(t, map[string]map[string]any{
"sess1": {"model": "ollama/old-model"},
})
clearSessionModelOverride("new-model")
sessions := readSessionsFile(t)
if sessions["sess1"]["model"] != "new-model" {
t.Errorf("model = %q, want %q", sessions["sess1"]["model"], "new-model")
}
})
t.Run("does not update session already using primary", func(t *testing.T) {
writeSessionsFile(t, map[string]map[string]any{
"sess1": {"model": "current-model"},
})
clearSessionModelOverride("current-model")
sessions := readSessionsFile(t)
if sessions["sess1"]["model"] != "current-model" {
t.Errorf("model = %q, want %q", sessions["sess1"]["model"], "current-model")
}
})
t.Run("does not update session with empty model field", func(t *testing.T) {
writeSessionsFile(t, map[string]map[string]any{
"sess1": {"other": "data"},
})
clearSessionModelOverride("new-model")
sessions := readSessionsFile(t)
if _, ok := sessions["sess1"]["model"]; ok {
t.Error("model field should not have been added to session with no model")
}
})
t.Run("handles multiple sessions mixed", func(t *testing.T) {
writeSessionsFile(t, map[string]map[string]any{
"with-override": {"model": "old", "modelOverride": "old", "providerOverride": "ollama"},
"without-override": {"model": "old"},
"already-current": {"model": "new-model"},
"no-model": {"other": "data"},
})
clearSessionModelOverride("new-model")
sessions := readSessionsFile(t)
if sessions["with-override"]["model"] != "new-model" {
t.Errorf("with-override model = %q, want %q", sessions["with-override"]["model"], "new-model")
}
if _, ok := sessions["with-override"]["modelOverride"]; ok {
t.Error("with-override: modelOverride should be deleted")
}
if sessions["without-override"]["model"] != "new-model" {
t.Errorf("without-override model = %q, want %q", sessions["without-override"]["model"], "new-model")
}
if sessions["already-current"]["model"] != "new-model" {
t.Errorf("already-current model = %q, want %q", sessions["already-current"]["model"], "new-model")
}
if _, ok := sessions["no-model"]["model"]; ok {
t.Error("no-model: model should not have been added")
}
})
t.Run("no-op when sessions file missing", func(t *testing.T) {
os.RemoveAll(sessionsDir)
clearSessionModelOverride("new-model") // should not panic or error
})
}

View File

@@ -1,9 +1,7 @@
package config package launch
import ( import (
"context"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"maps" "maps"
"os" "os"
@@ -12,34 +10,13 @@ import (
"slices" "slices"
"strings" "strings"
"github.com/ollama/ollama/api" "github.com/ollama/ollama/cmd/internal/fileutil"
"github.com/ollama/ollama/envconfig" "github.com/ollama/ollama/envconfig"
) )
// OpenCode implements Runner and Editor for OpenCode integration // OpenCode implements Runner and Editor for OpenCode integration
type OpenCode struct{} type OpenCode struct{}
// cloudModelLimit holds context and output token limits for a cloud model.
type cloudModelLimit struct {
Context int
Output int
}
// lookupCloudModelLimit returns the token limits for a cloud model.
// It tries the exact name first, then strips the ":cloud" suffix.
func lookupCloudModelLimit(name string) (cloudModelLimit, bool) {
if l, ok := cloudModelLimits[name]; ok {
return l, true
}
base := strings.TrimSuffix(name, ":cloud")
if base != name {
if l, ok := cloudModelLimits[base]; ok {
return l, true
}
}
return cloudModelLimit{}, false
}
func (o *OpenCode) String() string { return "OpenCode" } func (o *OpenCode) String() string { return "OpenCode" }
func (o *OpenCode) Run(model string, args []string) error { func (o *OpenCode) Run(model string, args []string) error {
@@ -47,25 +24,6 @@ func (o *OpenCode) Run(model string, args []string) error {
return fmt.Errorf("opencode is not installed, install from https://opencode.ai") return fmt.Errorf("opencode is not installed, install from https://opencode.ai")
} }
// Call Edit() to ensure config is up-to-date before launch
models := []string{model}
if config, err := loadIntegration("opencode"); err == nil && len(config.Models) > 0 {
models = config.Models
}
var err error
models, err = resolveEditorModels("opencode", models, func() ([]string, error) {
return selectModels(context.Background(), "opencode", "")
})
if errors.Is(err, errCancelled) {
return nil
}
if err != nil {
return err
}
if err := o.Edit(models); err != nil {
return fmt.Errorf("setup failed: %w", err)
}
cmd := exec.Command("opencode", args...) cmd := exec.Command("opencode", args...)
cmd.Stdin = os.Stdin cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout cmd.Stdout = os.Stdout
@@ -122,13 +80,18 @@ func (o *OpenCode) Edit(modelList []string) error {
if !ok { if !ok {
ollama = map[string]any{ ollama = map[string]any{
"npm": "@ai-sdk/openai-compatible", "npm": "@ai-sdk/openai-compatible",
"name": "Ollama (local)", "name": "Ollama",
"options": map[string]any{ "options": map[string]any{
"baseURL": envconfig.Host().String() + "/v1", "baseURL": envconfig.Host().String() + "/v1",
}, },
} }
} }
// Migrate legacy provider name
if name, _ := ollama["name"].(string); name == "Ollama (local)" {
ollama["name"] = "Ollama"
}
models, ok := ollama["models"].(map[string]any) models, ok := ollama["models"].(map[string]any)
if !ok { if !ok {
models = make(map[string]any) models = make(map[string]any)
@@ -147,8 +110,6 @@ func (o *OpenCode) Edit(modelList []string) error {
} }
} }
client, _ := api.ClientFromEnvironment()
for _, model := range modelList { for _, model := range modelList {
if existing, ok := models[model].(map[string]any); ok { if existing, ok := models[model].(map[string]any); ok {
// migrate existing models without _launch marker // migrate existing models without _launch marker
@@ -158,7 +119,7 @@ func (o *OpenCode) Edit(modelList []string) error {
existing["name"] = strings.TrimSuffix(name, " [Ollama]") existing["name"] = strings.TrimSuffix(name, " [Ollama]")
} }
} }
if isCloudModel(context.Background(), client, model) { if isCloudModelName(model) {
if l, ok := lookupCloudModelLimit(model); ok { if l, ok := lookupCloudModelLimit(model); ok {
existing["limit"] = map[string]any{ existing["limit"] = map[string]any{
"context": l.Context, "context": l.Context,
@@ -172,7 +133,7 @@ func (o *OpenCode) Edit(modelList []string) error {
"name": model, "name": model,
"_launch": true, "_launch": true,
} }
if isCloudModel(context.Background(), client, model) { if isCloudModelName(model) {
if l, ok := lookupCloudModelLimit(model); ok { if l, ok := lookupCloudModelLimit(model); ok {
entry["limit"] = map[string]any{ entry["limit"] = map[string]any{
"context": l.Context, "context": l.Context,
@@ -186,12 +147,13 @@ func (o *OpenCode) Edit(modelList []string) error {
ollama["models"] = models ollama["models"] = models
provider["ollama"] = ollama provider["ollama"] = ollama
config["provider"] = provider config["provider"] = provider
config["model"] = "ollama/" + modelList[0]
configData, err := json.MarshalIndent(config, "", " ") configData, err := json.MarshalIndent(config, "", " ")
if err != nil { if err != nil {
return err return err
} }
if err := writeWithBackup(configPath, configData); err != nil { if err := fileutil.WriteWithBackup(configPath, configData); err != nil {
return err return err
} }
@@ -243,7 +205,7 @@ func (o *OpenCode) Edit(modelList []string) error {
if err != nil { if err != nil {
return err return err
} }
return writeWithBackup(statePath, stateData) return fileutil.WriteWithBackup(statePath, stateData)
} }
func (o *OpenCode) Models() []string { func (o *OpenCode) Models() []string {
@@ -251,7 +213,7 @@ func (o *OpenCode) Models() []string {
if err != nil { if err != nil {
return nil return nil
} }
config, err := readJSONFile(filepath.Join(home, ".config", "opencode", "opencode.json")) config, err := fileutil.ReadJSON(filepath.Join(home, ".config", "opencode", "opencode.json"))
if err != nil { if err != nil {
return nil return nil
} }

View File

@@ -1,8 +1,10 @@
package config package launch
import ( import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"net/http"
"net/http/httptest"
"os" "os"
"path/filepath" "path/filepath"
"testing" "testing"
@@ -47,6 +49,7 @@ func TestOpenCodeEdit(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
assertOpenCodeModelExists(t, configPath, "llama3.2") assertOpenCodeModelExists(t, configPath, "llama3.2")
assertOpenCodeDefaultModel(t, configPath, "ollama/llama3.2")
assertOpenCodeRecentModel(t, statePath, 0, "ollama", "llama3.2") assertOpenCodeRecentModel(t, statePath, 0, "ollama", "llama3.2")
}) })
@@ -155,11 +158,13 @@ func TestOpenCodeEdit(t *testing.T) {
o.Edit([]string{"llama3.2", "mistral"}) o.Edit([]string{"llama3.2", "mistral"})
assertOpenCodeModelExists(t, configPath, "llama3.2") assertOpenCodeModelExists(t, configPath, "llama3.2")
assertOpenCodeModelExists(t, configPath, "mistral") assertOpenCodeModelExists(t, configPath, "mistral")
assertOpenCodeDefaultModel(t, configPath, "ollama/llama3.2")
// Then remove one by only selecting the other // Then remove one by only selecting the other
o.Edit([]string{"llama3.2"}) o.Edit([]string{"llama3.2"})
assertOpenCodeModelExists(t, configPath, "llama3.2") assertOpenCodeModelExists(t, configPath, "llama3.2")
assertOpenCodeModelNotExists(t, configPath, "mistral") assertOpenCodeModelNotExists(t, configPath, "mistral")
assertOpenCodeDefaultModel(t, configPath, "ollama/llama3.2")
}) })
t.Run("preserve user customizations on managed models", func(t *testing.T) { t.Run("preserve user customizations on managed models", func(t *testing.T) {
@@ -232,6 +237,44 @@ func TestOpenCodeEdit(t *testing.T) {
} }
}) })
t.Run("migrate Ollama (local) provider name", func(t *testing.T) {
cleanup()
os.MkdirAll(configDir, 0o755)
os.WriteFile(configPath, []byte(`{"provider":{"ollama":{"name":"Ollama (local)","npm":"@ai-sdk/openai-compatible","options":{"baseURL":"http://localhost:11434/v1"}}}}`), 0o644)
if err := o.Edit([]string{"llama3.2"}); err != nil {
t.Fatal(err)
}
data, _ := os.ReadFile(configPath)
var cfg map[string]any
json.Unmarshal(data, &cfg)
provider := cfg["provider"].(map[string]any)
ollama := provider["ollama"].(map[string]any)
if ollama["name"] != "Ollama" {
t.Errorf("provider name not migrated: got %q, want %q", ollama["name"], "Ollama")
}
})
t.Run("preserve custom provider name", func(t *testing.T) {
cleanup()
os.MkdirAll(configDir, 0o755)
os.WriteFile(configPath, []byte(`{"provider":{"ollama":{"name":"My Custom Ollama","npm":"@ai-sdk/openai-compatible","options":{"baseURL":"http://localhost:11434/v1"}}}}`), 0o644)
if err := o.Edit([]string{"llama3.2"}); err != nil {
t.Fatal(err)
}
data, _ := os.ReadFile(configPath)
var cfg map[string]any
json.Unmarshal(data, &cfg)
provider := cfg["provider"].(map[string]any)
ollama := provider["ollama"].(map[string]any)
if ollama["name"] != "My Custom Ollama" {
t.Errorf("custom provider name was changed: got %q, want %q", ollama["name"], "My Custom Ollama")
}
})
t.Run("remove model preserves non-ollama models", func(t *testing.T) { t.Run("remove model preserves non-ollama models", func(t *testing.T) {
cleanup() cleanup()
os.MkdirAll(configDir, 0o755) os.MkdirAll(configDir, 0o755)
@@ -298,6 +341,22 @@ func assertOpenCodeModelNotExists(t *testing.T, path, model string) {
} }
} }
func assertOpenCodeDefaultModel(t *testing.T, path, want string) {
t.Helper()
data, err := os.ReadFile(path)
if err != nil {
t.Fatal(err)
}
var cfg map[string]any
if err := json.Unmarshal(data, &cfg); err != nil {
t.Fatal(err)
}
got, _ := cfg["model"].(string)
if got != want {
t.Fatalf("default model = %q, want %q", got, want)
}
}
func assertOpenCodeRecentModel(t *testing.T, path string, index int, providerID, modelID string) { func assertOpenCodeRecentModel(t *testing.T, path string, index int, providerID, modelID string) {
t.Helper() t.Helper()
data, err := os.ReadFile(path) data, err := os.ReadFile(path)
@@ -619,6 +678,54 @@ func TestOpenCodeEdit_CloudModelLimitStructure(t *testing.T) {
} }
} }
func TestOpenCodeEdit_BackfillsCloudModelLimitOnExistingEntry(t *testing.T) {
o := &OpenCode{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/api/show" {
fmt.Fprintf(w, `{"capabilities":[],"model_info":{},"remote_model":"glm-5"}`)
return
}
w.WriteHeader(http.StatusNotFound)
}))
defer srv.Close()
t.Setenv("OLLAMA_HOST", srv.URL)
configDir := filepath.Join(tmpDir, ".config", "opencode")
configPath := filepath.Join(configDir, "opencode.json")
os.MkdirAll(configDir, 0o755)
os.WriteFile(configPath, []byte(`{
"provider": {
"ollama": {
"models": {
"glm-5:cloud": {
"name": "glm-5:cloud",
"_launch": true
}
}
}
}
}`), 0o644)
if err := o.Edit([]string{"glm-5:cloud"}); err != nil {
t.Fatal(err)
}
entry := readOpenCodeModel(t, configPath, "glm-5:cloud")
limit, ok := entry["limit"].(map[string]any)
if !ok {
t.Fatal("cloud model limit was not added on re-edit")
}
if limit["context"] != float64(202_752) {
t.Errorf("context = %v, want 202752", limit["context"])
}
if limit["output"] != float64(131_072) {
t.Errorf("output = %v, want 131072", limit["output"])
}
}
func TestLookupCloudModelLimit(t *testing.T) { func TestLookupCloudModelLimit(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
@@ -626,13 +733,19 @@ func TestLookupCloudModelLimit(t *testing.T) {
wantContext int wantContext int
wantOutput int wantOutput int
}{ }{
{"glm-4.7", true, 202_752, 131_072}, {"glm-4.7", false, 0, 0},
{"glm-4.7:cloud", true, 202_752, 131_072}, {"glm-4.7:cloud", true, 202_752, 131_072},
{"kimi-k2.5", true, 262_144, 262_144}, {"glm-5:cloud", true, 202_752, 131_072},
{"gpt-oss:120b-cloud", true, 131_072, 131_072},
{"gpt-oss:20b-cloud", true, 131_072, 131_072},
{"kimi-k2.5", false, 0, 0},
{"kimi-k2.5:cloud", true, 262_144, 262_144}, {"kimi-k2.5:cloud", true, 262_144, 262_144},
{"deepseek-v3.2", true, 163_840, 65_536}, {"deepseek-v3.2", false, 0, 0},
{"deepseek-v3.2:cloud", true, 163_840, 65_536}, {"deepseek-v3.2:cloud", true, 163_840, 65_536},
{"qwen3-coder:480b", true, 262_144, 65_536}, {"qwen3.5", false, 0, 0},
{"qwen3.5:cloud", true, 262_144, 32_768},
{"qwen3-coder:480b", false, 0, 0},
{"qwen3-coder:480b:cloud", true, 262_144, 65_536},
{"qwen3-coder-next:cloud", true, 262_144, 32_768}, {"qwen3-coder-next:cloud", true, 262_144, 32_768},
{"llama3.2", false, 0, 0}, {"llama3.2", false, 0, 0},
{"unknown-model:cloud", false, 0, 0}, {"unknown-model:cloud", false, 0, 0},

View File

@@ -1,4 +1,4 @@
package config package launch
import ( import (
"context" "context"
@@ -12,6 +12,7 @@ import (
"strings" "strings"
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
"github.com/ollama/ollama/cmd/internal/fileutil"
"github.com/ollama/ollama/envconfig" "github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/types/model" "github.com/ollama/ollama/types/model"
) )
@@ -19,29 +20,151 @@ import (
// Pi implements Runner and Editor for Pi (Pi Coding Agent) integration // Pi implements Runner and Editor for Pi (Pi Coding Agent) integration
type Pi struct{} type Pi struct{}
const (
piNpmPackage = "@mariozechner/pi-coding-agent"
piWebSearchSource = "npm:@ollama/pi-web-search"
piWebSearchPkg = "@ollama/pi-web-search"
)
func (p *Pi) String() string { return "Pi" } func (p *Pi) String() string { return "Pi" }
func (p *Pi) Run(model string, args []string) error { func (p *Pi) Run(model string, args []string) error {
if _, err := exec.LookPath("pi"); err != nil { fmt.Fprintf(os.Stderr, "\n%sPreparing Pi...%s\n", ansiGray, ansiReset)
return fmt.Errorf("pi is not installed, install with: npm install -g @mariozechner/pi-coding-agent") if err := ensureNpmInstalled(); err != nil {
return err
} }
// Call Edit() to ensure config is up-to-date before launch fmt.Fprintf(os.Stderr, "%sChecking Pi installation...%s\n", ansiGray, ansiReset)
models := []string{model} bin, err := ensurePiInstalled()
if config, err := loadIntegration("pi"); err == nil && len(config.Models) > 0 { if err != nil {
models = config.Models return err
}
if err := p.Edit(models); err != nil {
return fmt.Errorf("setup failed: %w", err)
} }
cmd := exec.Command("pi", args...) ensurePiWebSearchPackage(bin)
fmt.Fprintf(os.Stderr, "\n%sLaunching Pi...%s\n\n", ansiGray, ansiReset)
cmd := exec.Command(bin, args...)
cmd.Stdin = os.Stdin cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr cmd.Stderr = os.Stderr
return cmd.Run() return cmd.Run()
} }
func ensureNpmInstalled() error {
if _, err := exec.LookPath("npm"); err != nil {
return fmt.Errorf("npm (Node.js) is required to launch pi\n\nInstall it first:\n https://nodejs.org/")
}
return nil
}
func ensurePiInstalled() (string, error) {
if _, err := exec.LookPath("pi"); err == nil {
return "pi", nil
}
if _, err := exec.LookPath("npm"); err != nil {
return "", fmt.Errorf("pi is not installed and required dependencies are missing\n\nInstall the following first:\n npm (Node.js): https://nodejs.org/")
}
ok, err := ConfirmPrompt("Pi is not installed. Install with npm?")
if err != nil {
return "", err
}
if !ok {
return "", fmt.Errorf("pi installation cancelled")
}
fmt.Fprintf(os.Stderr, "\nInstalling Pi...\n")
cmd := exec.Command("npm", "install", "-g", piNpmPackage+"@latest")
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
if err := cmd.Run(); err != nil {
return "", fmt.Errorf("failed to install pi: %w", err)
}
if _, err := exec.LookPath("pi"); err != nil {
return "", fmt.Errorf("pi was installed but the binary was not found on PATH\n\nYou may need to restart your shell")
}
fmt.Fprintf(os.Stderr, "%sPi installed successfully%s\n\n", ansiGreen, ansiReset)
return "pi", nil
}
func ensurePiWebSearchPackage(bin string) {
if !shouldManagePiWebSearch() {
fmt.Fprintf(os.Stderr, "%sCloud is disabled; skipping %s setup.%s\n", ansiGray, piWebSearchPkg, ansiReset)
return
}
fmt.Fprintf(os.Stderr, "%sChecking Pi web search package...%s\n", ansiGray, ansiReset)
installed, err := piPackageInstalled(bin, piWebSearchSource)
if err != nil {
fmt.Fprintf(os.Stderr, "%s Warning: could not check %s installation: %v%s\n", ansiYellow, piWebSearchPkg, err, ansiReset)
return
}
if !installed {
fmt.Fprintf(os.Stderr, "%sInstalling %s...%s\n", ansiGray, piWebSearchPkg, ansiReset)
cmd := exec.Command(bin, "install", piWebSearchSource)
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
if err := cmd.Run(); err != nil {
fmt.Fprintf(os.Stderr, "%s Warning: could not install %s: %v%s\n", ansiYellow, piWebSearchPkg, err, ansiReset)
return
}
fmt.Fprintf(os.Stderr, "%s ✓ Installed %s%s\n", ansiGreen, piWebSearchPkg, ansiReset)
return
}
fmt.Fprintf(os.Stderr, "%sUpdating %s...%s\n", ansiGray, piWebSearchPkg, ansiReset)
cmd := exec.Command(bin, "update", piWebSearchSource)
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
if err := cmd.Run(); err != nil {
fmt.Fprintf(os.Stderr, "%s Warning: could not update %s: %v%s\n", ansiYellow, piWebSearchPkg, err, ansiReset)
return
}
fmt.Fprintf(os.Stderr, "%s ✓ Updated %s%s\n", ansiGreen, piWebSearchPkg, ansiReset)
}
func shouldManagePiWebSearch() bool {
client, err := api.ClientFromEnvironment()
if err != nil {
return true
}
disabled, known := cloudStatusDisabled(context.Background(), client)
if known && disabled {
return false
}
return true
}
func piPackageInstalled(bin, source string) (bool, error) {
cmd := exec.Command(bin, "list")
out, err := cmd.CombinedOutput()
if err != nil {
msg := strings.TrimSpace(string(out))
if msg == "" {
return false, err
}
return false, fmt.Errorf("%w: %s", err, msg)
}
for _, line := range strings.Split(string(out), "\n") {
trimmed := strings.TrimSpace(line)
if strings.HasPrefix(trimmed, source) {
return true, nil
}
}
return false, nil
}
func (p *Pi) Paths() []string { func (p *Pi) Paths() []string {
home, err := os.UserHomeDir() home, err := os.UserHomeDir()
if err != nil { if err != nil {
@@ -107,7 +230,8 @@ func (p *Pi) Edit(models []string) error {
// Build new models list: // Build new models list:
// 1. Keep user-managed models (no _launch marker) - untouched // 1. Keep user-managed models (no _launch marker) - untouched
// 2. Keep ollama-managed models (_launch marker) that are still selected // 2. Keep ollama-managed models (_launch marker) that are still selected,
// except stale cloud entries that should be rebuilt below
// 3. Add new ollama-managed models // 3. Add new ollama-managed models
var newModels []any var newModels []any
for _, m := range existingModels { for _, m := range existingModels {
@@ -117,7 +241,13 @@ func (p *Pi) Edit(models []string) error {
if !isPiOllamaModel(modelObj) { if !isPiOllamaModel(modelObj) {
newModels = append(newModels, m) newModels = append(newModels, m)
} else if selectedSet[id] { } else if selectedSet[id] {
// Ollama-managed and still selected - keep it // Rebuild stale managed cloud entries so createConfig refreshes
// the whole entry instead of patching it in place.
if !hasContextWindow(modelObj) {
if _, ok := lookupCloudModelLimit(id); ok {
continue
}
}
newModels = append(newModels, m) newModels = append(newModels, m)
selectedSet[id] = false selectedSet[id] = false
} }
@@ -142,7 +272,7 @@ func (p *Pi) Edit(models []string) error {
if err != nil { if err != nil {
return err return err
} }
if err := writeWithBackup(configPath, configData); err != nil { if err := fileutil.WriteWithBackup(configPath, configData); err != nil {
return err return err
} }
@@ -160,7 +290,7 @@ func (p *Pi) Edit(models []string) error {
if err != nil { if err != nil {
return err return err
} }
return writeWithBackup(settingsPath, settingsData) return fileutil.WriteWithBackup(settingsPath, settingsData)
} }
func (p *Pi) Models() []string { func (p *Pi) Models() []string {
@@ -170,7 +300,7 @@ func (p *Pi) Models() []string {
} }
configPath := filepath.Join(home, ".pi", "agent", "models.json") configPath := filepath.Join(home, ".pi", "agent", "models.json")
config, err := readJSONFile(configPath) config, err := fileutil.ReadJSON(configPath)
if err != nil { if err != nil {
return nil return nil
} }
@@ -199,15 +329,38 @@ func isPiOllamaModel(cfg map[string]any) bool {
return false return false
} }
func hasContextWindow(cfg map[string]any) bool {
switch v := cfg["contextWindow"].(type) {
case float64:
return v > 0
case int:
return v > 0
case int64:
return v > 0
default:
return false
}
}
// createConfig builds Pi model config with capability detection // createConfig builds Pi model config with capability detection
func createConfig(ctx context.Context, client *api.Client, modelID string) map[string]any { func createConfig(ctx context.Context, client *api.Client, modelID string) map[string]any {
cfg := map[string]any{ cfg := map[string]any{
"id": modelID, "id": modelID,
"_launch": true, "_launch": true,
} }
if l, ok := lookupCloudModelLimit(modelID); ok {
cfg["contextWindow"] = l.Context
}
applyCloudContextFallback := func() {
if l, ok := lookupCloudModelLimit(modelID); ok {
cfg["contextWindow"] = l.Context
}
}
resp, err := client.Show(ctx, &api.ShowRequest{Model: modelID}) resp, err := client.Show(ctx, &api.ShowRequest{Model: modelID})
if err != nil { if err != nil {
applyCloudContextFallback()
return cfg return cfg
} }
@@ -223,15 +376,21 @@ func createConfig(ctx context.Context, client *api.Client, modelID string) map[s
cfg["reasoning"] = true cfg["reasoning"] = true
} }
// Extract context window from ModelInfo // Extract context window from ModelInfo. For known cloud models, the
// pre-filled shared limit remains unless the server provides a positive value.
hasContextWindow := false
for key, val := range resp.ModelInfo { for key, val := range resp.ModelInfo {
if strings.HasSuffix(key, ".context_length") { if strings.HasSuffix(key, ".context_length") {
if ctxLen, ok := val.(float64); ok && ctxLen > 0 { if ctxLen, ok := val.(float64); ok && ctxLen > 0 {
cfg["contextWindow"] = int(ctxLen) cfg["contextWindow"] = int(ctxLen)
hasContextWindow = true
} }
break break
} }
} }
if !hasContextWindow {
applyCloudContextFallback()
}
return cfg return cfg
} }

View File

@@ -1,4 +1,4 @@
package config package launch
import ( import (
"context" "context"
@@ -9,6 +9,8 @@ import (
"net/url" "net/url"
"os" "os"
"path/filepath" "path/filepath"
"runtime"
"strings"
"testing" "testing"
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
@@ -33,6 +35,339 @@ func TestPiIntegration(t *testing.T) {
}) })
} }
func TestPiRun_InstallAndWebSearchLifecycle(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("uses POSIX shell test binaries")
}
writeScript := func(t *testing.T, path, content string) {
t.Helper()
if err := os.WriteFile(path, []byte(content), 0o755); err != nil {
t.Fatal(err)
}
}
seedPiScript := func(t *testing.T, dir string) {
t.Helper()
piPath := filepath.Join(dir, "pi")
listPath := filepath.Join(dir, "pi-list.txt")
piScript := fmt.Sprintf(`#!/bin/sh
echo "$@" >> %q
if [ "$1" = "list" ]; then
if [ -f %q ]; then
/bin/cat %q
fi
exit 0
fi
if [ "$1" = "update" ] && [ "$PI_FAIL_UPDATE" = "1" ]; then
echo "update failed" >&2
exit 1
fi
if [ "$1" = "install" ] && [ "$PI_FAIL_INSTALL" = "1" ]; then
echo "install failed" >&2
exit 1
fi
exit 0
`, filepath.Join(dir, "pi.log"), listPath, listPath)
writeScript(t, piPath, piScript)
}
seedNpmNoop := func(t *testing.T, dir string) {
t.Helper()
writeScript(t, filepath.Join(dir, "npm"), "#!/bin/sh\nexit 0\n")
}
withConfirm := func(t *testing.T, fn func(prompt string) (bool, error)) {
t.Helper()
oldConfirm := DefaultConfirmPrompt
DefaultConfirmPrompt = fn
t.Cleanup(func() { DefaultConfirmPrompt = oldConfirm })
}
setCloudStatus := func(t *testing.T, disabled bool) {
t.Helper()
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/api/status" {
fmt.Fprintf(w, `{"cloud":{"disabled":%t,"source":"config"}}`, disabled)
return
}
http.NotFound(w, r)
}))
t.Cleanup(srv.Close)
t.Setenv("OLLAMA_HOST", srv.URL)
}
t.Run("pi missing + user accepts install", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
t.Setenv("PATH", tmpDir)
setCloudStatus(t, false)
if err := os.WriteFile(filepath.Join(tmpDir, "pi-list.txt"), []byte("User packages:\n npm:@ollama/pi-web-search\n"), 0o644); err != nil {
t.Fatal(err)
}
npmScript := fmt.Sprintf(`#!/bin/sh
echo "$@" >> %q
if [ "$1" = "install" ] && [ "$2" = "-g" ] && [ "$3" = %q ]; then
/bin/cat > %q <<'EOS'
#!/bin/sh
echo "$@" >> %q
if [ "$1" = "list" ]; then
if [ -f %q ]; then
/bin/cat %q
fi
exit 0
fi
exit 0
EOS
/bin/chmod +x %q
fi
exit 0
`, filepath.Join(tmpDir, "npm.log"), piNpmPackage+"@latest", filepath.Join(tmpDir, "pi"), filepath.Join(tmpDir, "pi.log"), filepath.Join(tmpDir, "pi-list.txt"), filepath.Join(tmpDir, "pi-list.txt"), filepath.Join(tmpDir, "pi"))
writeScript(t, filepath.Join(tmpDir, "npm"), npmScript)
withConfirm(t, func(prompt string) (bool, error) {
if strings.Contains(prompt, "Pi is not installed.") {
return true, nil
}
return true, nil
})
p := &Pi{}
if err := p.Run("ignored", []string{"--version"}); err != nil {
t.Fatalf("Run() error = %v", err)
}
npmCalls, err := os.ReadFile(filepath.Join(tmpDir, "npm.log"))
if err != nil {
t.Fatal(err)
}
if !strings.Contains(string(npmCalls), "install -g "+piNpmPackage+"@latest") {
t.Fatalf("expected npm install call, got:\n%s", npmCalls)
}
piCalls, err := os.ReadFile(filepath.Join(tmpDir, "pi.log"))
if err != nil {
t.Fatal(err)
}
got := string(piCalls)
if !strings.Contains(got, "list\n") {
t.Fatalf("expected pi list call, got:\n%s", got)
}
if !strings.Contains(got, "update "+piWebSearchSource+"\n") {
t.Fatalf("expected pi update call, got:\n%s", got)
}
if !strings.Contains(got, "--version\n") {
t.Fatalf("expected final pi launch call, got:\n%s", got)
}
})
t.Run("pi missing + user declines install", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
t.Setenv("PATH", tmpDir)
setCloudStatus(t, false)
writeScript(t, filepath.Join(tmpDir, "npm"), "#!/bin/sh\nexit 0\n")
withConfirm(t, func(prompt string) (bool, error) {
if strings.Contains(prompt, "Pi is not installed.") {
return false, nil
}
return true, nil
})
p := &Pi{}
err := p.Run("ignored", nil)
if err == nil || !strings.Contains(err.Error(), "pi installation cancelled") {
t.Fatalf("expected install cancellation error, got %v", err)
}
})
t.Run("pi installed + web search missing auto-installs", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
t.Setenv("PATH", tmpDir)
setCloudStatus(t, false)
if err := os.WriteFile(filepath.Join(tmpDir, "pi-list.txt"), []byte("User packages:\n"), 0o644); err != nil {
t.Fatal(err)
}
seedPiScript(t, tmpDir)
seedNpmNoop(t, tmpDir)
withConfirm(t, func(prompt string) (bool, error) {
t.Fatalf("did not expect confirmation prompt, got %q", prompt)
return false, nil
})
p := &Pi{}
if err := p.Run("ignored", []string{"session"}); err != nil {
t.Fatalf("Run() error = %v", err)
}
piCalls, err := os.ReadFile(filepath.Join(tmpDir, "pi.log"))
if err != nil {
t.Fatal(err)
}
got := string(piCalls)
if !strings.Contains(got, "list\n") {
t.Fatalf("expected pi list call, got:\n%s", got)
}
if !strings.Contains(got, "install "+piWebSearchSource+"\n") {
t.Fatalf("expected pi install call, got:\n%s", got)
}
if strings.Contains(got, "update "+piWebSearchSource+"\n") {
t.Fatalf("did not expect pi update call when package missing, got:\n%s", got)
}
if !strings.Contains(got, "session\n") {
t.Fatalf("expected final pi launch call, got:\n%s", got)
}
})
t.Run("pi installed + web search present updates every launch", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
t.Setenv("PATH", tmpDir)
setCloudStatus(t, false)
if err := os.WriteFile(filepath.Join(tmpDir, "pi-list.txt"), []byte("User packages:\n "+piWebSearchSource+"\n"), 0o644); err != nil {
t.Fatal(err)
}
seedPiScript(t, tmpDir)
seedNpmNoop(t, tmpDir)
p := &Pi{}
if err := p.Run("ignored", []string{"doctor"}); err != nil {
t.Fatalf("Run() error = %v", err)
}
piCalls, err := os.ReadFile(filepath.Join(tmpDir, "pi.log"))
if err != nil {
t.Fatal(err)
}
got := string(piCalls)
if !strings.Contains(got, "update "+piWebSearchSource+"\n") {
t.Fatalf("expected pi update call, got:\n%s", got)
}
})
t.Run("web search update failure warns and continues", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
t.Setenv("PATH", tmpDir)
setCloudStatus(t, false)
t.Setenv("PI_FAIL_UPDATE", "1")
if err := os.WriteFile(filepath.Join(tmpDir, "pi-list.txt"), []byte("User packages:\n "+piWebSearchSource+"\n"), 0o644); err != nil {
t.Fatal(err)
}
seedPiScript(t, tmpDir)
seedNpmNoop(t, tmpDir)
p := &Pi{}
stderr := captureStderr(t, func() {
if err := p.Run("ignored", []string{"session"}); err != nil {
t.Fatalf("Run() should continue after web search update failure, got %v", err)
}
})
if !strings.Contains(stderr, "Warning: could not update "+piWebSearchPkg) {
t.Fatalf("expected update warning, got:\n%s", stderr)
}
piCalls, err := os.ReadFile(filepath.Join(tmpDir, "pi.log"))
if err != nil {
t.Fatal(err)
}
if !strings.Contains(string(piCalls), "session\n") {
t.Fatalf("expected final pi launch call, got:\n%s", piCalls)
}
})
t.Run("web search install failure warns and continues", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
t.Setenv("PATH", tmpDir)
setCloudStatus(t, false)
t.Setenv("PI_FAIL_INSTALL", "1")
if err := os.WriteFile(filepath.Join(tmpDir, "pi-list.txt"), []byte("User packages:\n"), 0o644); err != nil {
t.Fatal(err)
}
seedPiScript(t, tmpDir)
seedNpmNoop(t, tmpDir)
withConfirm(t, func(prompt string) (bool, error) {
t.Fatalf("did not expect confirmation prompt, got %q", prompt)
return false, nil
})
p := &Pi{}
stderr := captureStderr(t, func() {
if err := p.Run("ignored", []string{"session"}); err != nil {
t.Fatalf("Run() should continue after web search install failure, got %v", err)
}
})
if !strings.Contains(stderr, "Warning: could not install "+piWebSearchPkg) {
t.Fatalf("expected install warning, got:\n%s", stderr)
}
piCalls, err := os.ReadFile(filepath.Join(tmpDir, "pi.log"))
if err != nil {
t.Fatal(err)
}
if !strings.Contains(string(piCalls), "session\n") {
t.Fatalf("expected final pi launch call, got:\n%s", piCalls)
}
})
t.Run("cloud disabled skips web search package management", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
t.Setenv("PATH", tmpDir)
setCloudStatus(t, true)
if err := os.WriteFile(filepath.Join(tmpDir, "pi-list.txt"), []byte("User packages:\n"), 0o644); err != nil {
t.Fatal(err)
}
seedPiScript(t, tmpDir)
seedNpmNoop(t, tmpDir)
p := &Pi{}
stderr := captureStderr(t, func() {
if err := p.Run("ignored", []string{"session"}); err != nil {
t.Fatalf("Run() error = %v", err)
}
})
if !strings.Contains(stderr, "Cloud is disabled; skipping "+piWebSearchPkg+" setup.") {
t.Fatalf("expected cloud-disabled skip message, got:\n%s", stderr)
}
piCalls, err := os.ReadFile(filepath.Join(tmpDir, "pi.log"))
if err != nil {
t.Fatal(err)
}
got := string(piCalls)
if strings.Contains(got, "list\n") || strings.Contains(got, "install "+piWebSearchSource+"\n") || strings.Contains(got, "update "+piWebSearchSource+"\n") {
t.Fatalf("did not expect web search package management calls, got:\n%s", got)
}
if !strings.Contains(got, "session\n") {
t.Fatalf("expected final pi launch call, got:\n%s", got)
}
})
t.Run("missing npm returns error before pi flow", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
t.Setenv("PATH", tmpDir)
setCloudStatus(t, false)
seedPiScript(t, tmpDir)
p := &Pi{}
err := p.Run("ignored", []string{"session"})
if err == nil || !strings.Contains(err.Error(), "npm (Node.js) is required to launch pi") {
t.Fatalf("expected missing npm error, got %v", err)
}
if _, statErr := os.Stat(filepath.Join(tmpDir, "pi.log")); !os.IsNotExist(statErr) {
t.Fatalf("expected pi not to run when npm is missing, stat err = %v", statErr)
}
})
}
func TestPiPaths(t *testing.T) { func TestPiPaths(t *testing.T) {
pi := &Pi{} pi := &Pi{}
@@ -192,6 +527,48 @@ func TestPiEdit(t *testing.T) {
} }
}) })
t.Run("rebuilds stale existing managed cloud model", func(t *testing.T) {
cleanup()
os.MkdirAll(configDir, 0o755)
existingConfig := `{
"providers": {
"ollama": {
"baseUrl": "http://localhost:11434/v1",
"api": "openai-completions",
"apiKey": "ollama",
"models": [
{"id": "glm-5:cloud", "_launch": true, "legacyField": "stale"}
]
}
}
}`
if err := os.WriteFile(configPath, []byte(existingConfig), 0o644); err != nil {
t.Fatal(err)
}
if err := pi.Edit([]string{"glm-5:cloud"}); err != nil {
t.Fatalf("Edit() error = %v", err)
}
cfg := readConfig()
providers := cfg["providers"].(map[string]any)
ollama := providers["ollama"].(map[string]any)
modelsArray := ollama["models"].([]any)
modelEntry := modelsArray[0].(map[string]any)
if modelEntry["contextWindow"] != float64(202_752) {
t.Errorf("contextWindow = %v, want 202752", modelEntry["contextWindow"])
}
input, ok := modelEntry["input"].([]any)
if !ok || len(input) != 1 || input[0] != "text" {
t.Errorf("input = %v, want [text]", modelEntry["input"])
}
if _, ok := modelEntry["legacyField"]; ok {
t.Error("legacyField should be removed when stale managed cloud entry is rebuilt")
}
})
t.Run("replaces old models with new ones", func(t *testing.T) { t.Run("replaces old models with new ones", func(t *testing.T) {
cleanup() cleanup()
os.MkdirAll(configDir, 0o755) os.MkdirAll(configDir, 0o755)
@@ -798,6 +1175,59 @@ func TestCreateConfig(t *testing.T) {
} }
}) })
t.Run("cloud model falls back to hardcoded context when show fails", func(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNotFound)
fmt.Fprintf(w, `{"error":"model not found"}`)
}))
defer srv.Close()
u, _ := url.Parse(srv.URL)
client := api.NewClient(u, srv.Client())
cfg := createConfig(context.Background(), client, "kimi-k2.5:cloud")
if cfg["contextWindow"] != 262_144 {
t.Errorf("contextWindow = %v, want 262144", cfg["contextWindow"])
}
})
t.Run("cloud model falls back to hardcoded context when show omits model info", func(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/api/show" {
fmt.Fprintf(w, `{"capabilities":[],"model_info":{}}`)
return
}
w.WriteHeader(http.StatusNotFound)
}))
defer srv.Close()
u, _ := url.Parse(srv.URL)
client := api.NewClient(u, srv.Client())
cfg := createConfig(context.Background(), client, "glm-5:cloud")
if cfg["contextWindow"] != 202_752 {
t.Errorf("contextWindow = %v, want 202752", cfg["contextWindow"])
}
})
t.Run("cloud model with dash suffix falls back to hardcoded context", func(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNotFound)
fmt.Fprintf(w, `{"error":"model not found"}`)
}))
defer srv.Close()
u, _ := url.Parse(srv.URL)
client := api.NewClient(u, srv.Client())
cfg := createConfig(context.Background(), client, "gpt-oss:120b-cloud")
if cfg["contextWindow"] != 131_072 {
t.Errorf("contextWindow = %v, want 131072", cfg["contextWindow"])
}
})
t.Run("skips zero context length", func(t *testing.T) { t.Run("skips zero context length", func(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/api/show" { if r.URL.Path == "/api/show" {

373
cmd/launch/registry.go Normal file
View File

@@ -0,0 +1,373 @@
package launch
import (
"fmt"
"os"
"os/exec"
"slices"
"strings"
)
// IntegrationInstallSpec describes how launcher should detect and guide installation.
type IntegrationInstallSpec struct {
CheckInstalled func() bool
EnsureInstalled func() error
URL string
Command []string
}
// IntegrationSpec is the canonical registry entry for one integration.
type IntegrationSpec struct {
Name string
Runner Runner
Aliases []string
Hidden bool
Description string
Install IntegrationInstallSpec
}
// IntegrationInfo contains display information about a registered integration.
type IntegrationInfo struct {
Name string
DisplayName string
Description string
}
var launcherIntegrationOrder = []string{"opencode", "droid", "pi"}
var integrationSpecs = []*IntegrationSpec{
{
Name: "claude",
Runner: &Claude{},
Description: "Anthropic's coding tool with subagents",
Install: IntegrationInstallSpec{
CheckInstalled: func() bool {
_, err := (&Claude{}).findPath()
return err == nil
},
URL: "https://code.claude.com/docs/en/quickstart",
},
},
{
Name: "cline",
Runner: &Cline{},
Description: "Autonomous coding agent with parallel execution",
Hidden: true,
Install: IntegrationInstallSpec{
CheckInstalled: func() bool {
_, err := exec.LookPath("cline")
return err == nil
},
Command: []string{"npm", "install", "-g", "cline"},
},
},
{
Name: "codex",
Runner: &Codex{},
Description: "OpenAI's open-source coding agent",
Install: IntegrationInstallSpec{
CheckInstalled: func() bool {
_, err := exec.LookPath("codex")
return err == nil
},
URL: "https://developers.openai.com/codex/cli/",
Command: []string{"npm", "install", "-g", "@openai/codex"},
},
},
{
Name: "droid",
Runner: &Droid{},
Description: "Factory's coding agent across terminal and IDEs",
Install: IntegrationInstallSpec{
CheckInstalled: func() bool {
_, err := exec.LookPath("droid")
return err == nil
},
URL: "https://docs.factory.ai/cli/getting-started/quickstart",
},
},
{
Name: "opencode",
Runner: &OpenCode{},
Description: "Anomaly's open-source coding agent",
Install: IntegrationInstallSpec{
CheckInstalled: func() bool {
_, err := exec.LookPath("opencode")
return err == nil
},
URL: "https://opencode.ai",
},
},
{
Name: "openclaw",
Runner: &Openclaw{},
Aliases: []string{"clawdbot", "moltbot"},
Description: "Personal AI with 100+ skills",
Install: IntegrationInstallSpec{
CheckInstalled: func() bool {
if _, err := exec.LookPath("openclaw"); err == nil {
return true
}
if _, err := exec.LookPath("clawdbot"); err == nil {
return true
}
return false
},
EnsureInstalled: func() error {
_, err := ensureOpenclawInstalled()
return err
},
URL: "https://docs.openclaw.ai",
},
},
{
Name: "pi",
Runner: &Pi{},
Description: "Minimal AI agent toolkit with plugin support",
Install: IntegrationInstallSpec{
CheckInstalled: func() bool {
_, err := exec.LookPath("pi")
return err == nil
},
EnsureInstalled: func() error {
_, err := ensurePiInstalled()
return err
},
Command: []string{"npm", "install", "-g", "@mariozechner/pi-coding-agent@latest"},
},
},
{
Name: "vscode",
Runner: &VSCode{},
Aliases: []string{"code"},
Description: "Microsoft's open-source AI code editor",
Hidden: true,
Install: IntegrationInstallSpec{
CheckInstalled: func() bool {
return (&VSCode{}).findBinary() != ""
},
URL: "https://code.visualstudio.com",
},
},
}
var integrationSpecsByName map[string]*IntegrationSpec
func init() {
rebuildIntegrationSpecIndexes()
}
func hyperlink(url, text string) string {
return fmt.Sprintf("\033]8;;%s\033\\%s\033]8;;\033\\", url, text)
}
func rebuildIntegrationSpecIndexes() {
integrationSpecsByName = make(map[string]*IntegrationSpec, len(integrationSpecs))
canonical := make(map[string]bool, len(integrationSpecs))
for _, spec := range integrationSpecs {
key := strings.ToLower(spec.Name)
if key == "" {
panic("launch: integration spec missing name")
}
if canonical[key] {
panic(fmt.Sprintf("launch: duplicate integration name %q", key))
}
canonical[key] = true
integrationSpecsByName[key] = spec
}
seenAliases := make(map[string]string)
for _, spec := range integrationSpecs {
for _, alias := range spec.Aliases {
key := strings.ToLower(alias)
if key == "" {
panic(fmt.Sprintf("launch: integration %q has empty alias", spec.Name))
}
if canonical[key] {
panic(fmt.Sprintf("launch: alias %q collides with canonical integration name", key))
}
if owner, exists := seenAliases[key]; exists {
panic(fmt.Sprintf("launch: alias %q collides between %q and %q", key, owner, spec.Name))
}
seenAliases[key] = spec.Name
integrationSpecsByName[key] = spec
}
}
orderSeen := make(map[string]bool, len(launcherIntegrationOrder))
for _, name := range launcherIntegrationOrder {
key := strings.ToLower(name)
if orderSeen[key] {
panic(fmt.Sprintf("launch: duplicate launcher order entry %q", key))
}
orderSeen[key] = true
spec, ok := integrationSpecsByName[key]
if !ok {
panic(fmt.Sprintf("launch: unknown launcher order entry %q", key))
}
if spec.Name != key {
panic(fmt.Sprintf("launch: launcher order entry %q must use canonical name, not alias", key))
}
if spec.Hidden {
panic(fmt.Sprintf("launch: hidden integration %q cannot appear in launcher order", key))
}
}
}
// LookupIntegrationSpec resolves either a canonical integration name or alias to its spec.
func LookupIntegrationSpec(name string) (*IntegrationSpec, error) {
spec, ok := integrationSpecsByName[strings.ToLower(name)]
if !ok {
return nil, fmt.Errorf("unknown integration: %s", name)
}
return spec, nil
}
// LookupIntegration resolves a registry name to the canonical key and runner.
func LookupIntegration(name string) (string, Runner, error) {
spec, err := LookupIntegrationSpec(name)
if err != nil {
return "", nil, err
}
return spec.Name, spec.Runner, nil
}
// ListVisibleIntegrationSpecs returns the canonical integrations that should appear in interactive UIs.
func ListVisibleIntegrationSpecs() []IntegrationSpec {
visible := make([]IntegrationSpec, 0, len(integrationSpecs))
for _, spec := range integrationSpecs {
if spec.Hidden {
continue
}
visible = append(visible, *spec)
}
orderRank := make(map[string]int, len(launcherIntegrationOrder))
for i, name := range launcherIntegrationOrder {
orderRank[name] = i + 1
}
slices.SortFunc(visible, func(a, b IntegrationSpec) int {
aRank, bRank := orderRank[a.Name], orderRank[b.Name]
if aRank > 0 && bRank > 0 {
return aRank - bRank
}
if aRank > 0 {
return 1
}
if bRank > 0 {
return -1
}
return strings.Compare(a.Name, b.Name)
})
return visible
}
// ListIntegrationInfos returns the registered integrations in launcher display order.
func ListIntegrationInfos() []IntegrationInfo {
visible := ListVisibleIntegrationSpecs()
infos := make([]IntegrationInfo, 0, len(visible))
for _, spec := range visible {
infos = append(infos, IntegrationInfo{
Name: spec.Name,
DisplayName: spec.Runner.String(),
Description: spec.Description,
})
}
return infos
}
// IntegrationSelectionItems returns the sorted integration items shown by launcher selection UIs.
func IntegrationSelectionItems() ([]ModelItem, error) {
visible := ListVisibleIntegrationSpecs()
if len(visible) == 0 {
return nil, fmt.Errorf("no integrations available")
}
items := make([]ModelItem, 0, len(visible))
for _, spec := range visible {
description := spec.Runner.String()
if conn, err := loadStoredIntegrationConfig(spec.Name); err == nil && len(conn.Models) > 0 {
description = fmt.Sprintf("%s (%s)", spec.Runner.String(), conn.Models[0])
}
items = append(items, ModelItem{Name: spec.Name, Description: description})
}
return items, nil
}
// IsIntegrationInstalled checks if an integration binary is installed.
func IsIntegrationInstalled(name string) bool {
integration, err := integrationFor(name)
if err != nil {
fmt.Fprintf(os.Stderr, "Ollama couldn't find integration %q, so it'll show up as not installed.\n", name)
return false
}
return integration.installed
}
// integration is resolved registry metadata used by launcher state and install checks.
// It combines immutable registry spec data with computed runtime traits.
type integration struct {
spec *IntegrationSpec
installed bool
autoInstallable bool
editor bool
installHint string
}
// integrationFor resolves an integration name into the canonical spec plus
// derived launcher/install traits used across registry and launch flows.
func integrationFor(name string) (integration, error) {
spec, err := LookupIntegrationSpec(name)
if err != nil {
return integration{}, err
}
installed := true
if spec.Install.CheckInstalled != nil {
installed = spec.Install.CheckInstalled()
}
_, editor := spec.Runner.(Editor)
hint := ""
if spec.Install.URL != "" {
hint = "Install from " + hyperlink(spec.Install.URL, spec.Install.URL)
} else if len(spec.Install.Command) > 0 {
hint = "Install with: " + strings.Join(spec.Install.Command, " ")
}
return integration{
spec: spec,
installed: installed,
autoInstallable: spec.Install.EnsureInstalled != nil,
editor: editor,
installHint: hint,
}, nil
}
// EnsureIntegrationInstalled installs auto-installable integrations when missing.
func EnsureIntegrationInstalled(name string, runner Runner) error {
integration, err := integrationFor(name)
if err != nil {
return fmt.Errorf("%s is not installed", runner)
}
if integration.installed {
return nil
}
if integration.autoInstallable {
return integration.spec.Install.EnsureInstalled()
}
switch {
case integration.spec.Install.URL != "":
return fmt.Errorf("%s is not installed, install from %s", integration.spec.Name, integration.spec.Install.URL)
case len(integration.spec.Install.Command) > 0:
return fmt.Errorf("%s is not installed, install with: %s", integration.spec.Name, strings.Join(integration.spec.Install.Command, " "))
default:
return fmt.Errorf("%s is not installed", runner)
}
}

View File

@@ -0,0 +1,21 @@
package launch
import "strings"
// OverrideIntegration replaces one registry entry's runner for tests and returns a restore function.
func OverrideIntegration(name string, runner Runner) func() {
spec, err := LookupIntegrationSpec(name)
if err != nil {
key := strings.ToLower(name)
integrationSpecsByName[key] = &IntegrationSpec{Name: key, Runner: runner}
return func() {
delete(integrationSpecsByName, key)
}
}
original := spec.Runner
spec.Runner = runner
return func() {
spec.Runner = original
}
}

View File

@@ -0,0 +1,71 @@
package launch
import (
"os"
"path/filepath"
"testing"
)
func TestEditorRunsDoNotRewriteConfig(t *testing.T) {
tests := []struct {
name string
binary string
runner Runner
checkPath func(home string) string
}{
{
name: "droid",
binary: "droid",
runner: &Droid{},
checkPath: func(home string) string {
return filepath.Join(home, ".factory", "settings.json")
},
},
{
name: "opencode",
binary: "opencode",
runner: &OpenCode{},
checkPath: func(home string) string {
return filepath.Join(home, ".config", "opencode", "opencode.json")
},
},
{
name: "cline",
binary: "cline",
runner: &Cline{},
checkPath: func(home string) string {
return filepath.Join(home, ".cline", "data", "globalState.json")
},
},
{
name: "pi",
binary: "pi",
runner: &Pi{},
checkPath: func(home string) string {
return filepath.Join(home, ".pi", "agent", "models.json")
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
home := t.TempDir()
setTestHome(t, home)
binDir := t.TempDir()
writeFakeBinary(t, binDir, tt.binary)
if tt.name == "pi" {
writeFakeBinary(t, binDir, "npm")
}
t.Setenv("PATH", binDir)
configPath := tt.checkPath(home)
if err := tt.runner.Run("llama3.2", nil); err != nil {
t.Fatalf("Run returned error: %v", err)
}
if _, err := os.Stat(configPath); !os.IsNotExist(err) {
t.Fatalf("expected Run to leave %s untouched, got err=%v", configPath, err)
}
})
}
}

View File

@@ -0,0 +1,103 @@
package launch
import (
"errors"
"fmt"
"os"
"golang.org/x/term"
)
// ANSI escape sequences for terminal formatting.
const (
ansiBold = "\033[1m"
ansiReset = "\033[0m"
ansiGray = "\033[37m"
ansiGreen = "\033[32m"
ansiYellow = "\033[33m"
)
// ErrCancelled is returned when the user cancels a selection.
var ErrCancelled = errors.New("cancelled")
// errCancelled is kept as an internal alias for existing call sites.
var errCancelled = ErrCancelled
// DefaultConfirmPrompt provides a TUI-based confirmation prompt.
// When set, ConfirmPrompt delegates to it instead of using raw terminal I/O.
var DefaultConfirmPrompt func(prompt string) (bool, error)
// SingleSelector is a function type for single item selection.
// current is the name of the previously selected item to highlight; empty means no pre-selection.
type SingleSelector func(title string, items []ModelItem, current string) (string, error)
// MultiSelector is a function type for multi item selection.
type MultiSelector func(title string, items []ModelItem, preChecked []string) ([]string, error)
// DefaultSingleSelector is the default single-select implementation.
var DefaultSingleSelector SingleSelector
// DefaultMultiSelector is the default multi-select implementation.
var DefaultMultiSelector MultiSelector
// DefaultSignIn provides a TUI-based sign-in flow.
// When set, ensureAuth uses it instead of plain text prompts.
// Returns the signed-in username or an error.
var DefaultSignIn func(modelName, signInURL string) (string, error)
type launchConfirmPolicy struct {
yes bool
requireYesMessage bool
}
var currentLaunchConfirmPolicy launchConfirmPolicy
func withLaunchConfirmPolicy(policy launchConfirmPolicy) func() {
old := currentLaunchConfirmPolicy
currentLaunchConfirmPolicy = policy
return func() {
currentLaunchConfirmPolicy = old
}
}
// ConfirmPrompt is the shared confirmation gate for launch flows (integration
// edits, missing-model pulls, sign-in prompts, OpenClaw install/security, etc).
// Behavior is controlled by currentLaunchConfirmPolicy, typically scoped by
// withLaunchConfirmPolicy in LaunchCmd (e.g. auto-approve with --yes).
func ConfirmPrompt(prompt string) (bool, error) {
if currentLaunchConfirmPolicy.yes {
return true, nil
}
if currentLaunchConfirmPolicy.requireYesMessage {
return false, fmt.Errorf("%s requires confirmation; re-run with --yes to continue", prompt)
}
if DefaultConfirmPrompt != nil {
return DefaultConfirmPrompt(prompt)
}
fd := int(os.Stdin.Fd())
oldState, err := term.MakeRaw(fd)
if err != nil {
return false, err
}
defer term.Restore(fd, oldState)
fmt.Fprintf(os.Stderr, "%s (\033[1my\033[0m/n) ", prompt)
buf := make([]byte, 1)
for {
if _, err := os.Stdin.Read(buf); err != nil {
return false, err
}
switch buf[0] {
case 'Y', 'y', 13:
fmt.Fprintf(os.Stderr, "yes\r\n")
return true, nil
case 'N', 'n', 27, 3:
fmt.Fprintf(os.Stderr, "no\r\n")
return false, nil
}
}
}

View File

@@ -0,0 +1,76 @@
package launch
import (
"strings"
"testing"
)
func TestErrCancelled(t *testing.T) {
t.Run("NotNil", func(t *testing.T) {
if errCancelled == nil {
t.Error("errCancelled should not be nil")
}
})
t.Run("Message", func(t *testing.T) {
if errCancelled.Error() != "cancelled" {
t.Errorf("expected 'cancelled', got %q", errCancelled.Error())
}
})
}
func TestWithLaunchConfirmPolicy_ScopesAndRestores(t *testing.T) {
oldPolicy := currentLaunchConfirmPolicy
oldHook := DefaultConfirmPrompt
t.Cleanup(func() {
currentLaunchConfirmPolicy = oldPolicy
DefaultConfirmPrompt = oldHook
})
currentLaunchConfirmPolicy = launchConfirmPolicy{}
var hookCalls int
DefaultConfirmPrompt = func(prompt string) (bool, error) {
hookCalls++
return true, nil
}
restoreOuter := withLaunchConfirmPolicy(launchConfirmPolicy{requireYesMessage: true})
restoreInner := withLaunchConfirmPolicy(launchConfirmPolicy{yes: true})
ok, err := ConfirmPrompt("test prompt")
if err != nil {
t.Fatalf("expected --yes policy to allow prompt, got error: %v", err)
}
if !ok {
t.Fatal("expected --yes policy to auto-accept prompt")
}
if hookCalls != 0 {
t.Fatalf("expected --yes to skip hook, got %d hook calls", hookCalls)
}
restoreInner()
_, err = ConfirmPrompt("test prompt")
if err == nil {
t.Fatal("expected requireYesMessage policy to block prompt")
}
if !strings.Contains(err.Error(), "re-run with --yes") {
t.Fatalf("expected actionable --yes error, got: %v", err)
}
if hookCalls != 0 {
t.Fatalf("expected blocking policy to skip hook, got %d hook calls", hookCalls)
}
restoreOuter()
ok, err = ConfirmPrompt("test prompt")
if err != nil {
t.Fatalf("expected restored default behavior to use hook, got error: %v", err)
}
if !ok {
t.Fatal("expected hook to return true")
}
if hookCalls != 1 {
t.Fatalf("expected one hook call after restore, got %d", hookCalls)
}
}

View File

@@ -0,0 +1,82 @@
package launch
import (
"strings"
"testing"
"github.com/ollama/ollama/cmd/config"
)
var (
integrations map[string]Runner
integrationAliases map[string]bool
integrationOrder = launcherIntegrationOrder
)
func init() {
integrations = buildTestIntegrations()
integrationAliases = buildTestIntegrationAliases()
}
func buildTestIntegrations() map[string]Runner {
result := make(map[string]Runner, len(integrationSpecsByName))
for name, spec := range integrationSpecsByName {
result[strings.ToLower(name)] = spec.Runner
}
return result
}
func buildTestIntegrationAliases() map[string]bool {
result := make(map[string]bool)
for _, spec := range integrationSpecs {
for _, alias := range spec.Aliases {
result[strings.ToLower(alias)] = true
}
}
return result
}
func setTestHome(t *testing.T, dir string) {
t.Helper()
setLaunchTestHome(t, dir)
}
func SaveIntegration(appName string, models []string) error {
return config.SaveIntegration(appName, models)
}
func LoadIntegration(appName string) (*config.IntegrationConfig, error) {
return config.LoadIntegration(appName)
}
func SaveAliases(appName string, aliases map[string]string) error {
return config.SaveAliases(appName, aliases)
}
func LastModel() string {
return config.LastModel()
}
func SetLastModel(model string) error {
return config.SetLastModel(model)
}
func LastSelection() string {
return config.LastSelection()
}
func SetLastSelection(selection string) error {
return config.SetLastSelection(selection)
}
func IntegrationModel(appName string) string {
return config.IntegrationModel(appName)
}
func IntegrationModels(appName string) []string {
return config.IntegrationModels(appName)
}
func integrationOnboarded(appName string) error {
return config.MarkIntegrationOnboarded(appName)
}

591
cmd/launch/vscode.go Normal file
View File

@@ -0,0 +1,591 @@
package launch
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"os"
"os/exec"
"path/filepath"
"runtime"
"strconv"
"strings"
"time"
_ "github.com/mattn/go-sqlite3"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/cmd/internal/fileutil"
"github.com/ollama/ollama/envconfig"
)
// VSCode implements Runner and Editor for Visual Studio Code integration.
type VSCode struct{}
func (v *VSCode) String() string { return "Visual Studio Code" }
// findBinary returns the path/command to launch VS Code, or "" if not found.
// It checks platform-specific locations only.
func (v *VSCode) findBinary() string {
var candidates []string
switch runtime.GOOS {
case "darwin":
candidates = []string{
"/Applications/Visual Studio Code.app",
}
case "windows":
if localAppData := os.Getenv("LOCALAPPDATA"); localAppData != "" {
candidates = append(candidates, filepath.Join(localAppData, "Programs", "Microsoft VS Code", "bin", "code.cmd"))
}
default: // linux
candidates = []string{
"/usr/bin/code",
"/snap/bin/code",
}
}
for _, c := range candidates {
if _, err := os.Stat(c); err == nil {
return c
}
}
return ""
}
// IsRunning reports whether VS Code is currently running.
// Each platform uses a pattern specific enough to avoid matching Cursor or
// other VS Code forks.
func (v *VSCode) IsRunning() bool {
switch runtime.GOOS {
case "darwin":
out, err := exec.Command("pgrep", "-f", "Visual Studio Code.app/Contents/MacOS/Code").Output()
return err == nil && len(out) > 0
case "windows":
// Match VS Code by executable path to avoid matching Cursor or other forks.
out, err := exec.Command("powershell", "-NoProfile", "-Command",
`Get-Process Code -ErrorAction SilentlyContinue | Where-Object { $_.Path -like '*Microsoft VS Code*' } | Select-Object -First 1`).Output()
return err == nil && len(strings.TrimSpace(string(out))) > 0
default:
// Match VS Code specifically by its install path to avoid matching
// Cursor (/cursor/) or other forks.
for _, pattern := range []string{"/usr/share/code/", "/snap/code/"} {
out, err := exec.Command("pgrep", "-f", pattern).Output()
if err == nil && len(out) > 0 {
return true
}
}
return false
}
}
// Quit gracefully quits VS Code and waits for it to exit so that it flushes
// its in-memory state back to the database.
func (v *VSCode) Quit() {
if !v.IsRunning() {
return
}
switch runtime.GOOS {
case "darwin":
_ = exec.Command("osascript", "-e", `quit app "Visual Studio Code"`).Run()
case "windows":
// Kill VS Code by executable path to avoid killing Cursor or other forks.
_ = exec.Command("powershell", "-NoProfile", "-Command",
`Get-Process Code -ErrorAction SilentlyContinue | Where-Object { $_.Path -like '*Microsoft VS Code*' } | Stop-Process -Force`).Run()
default:
for _, pattern := range []string{"/usr/share/code/", "/snap/code/"} {
_ = exec.Command("pkill", "-f", pattern).Run()
}
}
// Wait for the process to fully exit and flush its state to disk
// TODO(hoyyeva): update spinner to use bubble tea
spinnerFrames := []string{"|", "/", "-", "\\"}
frame := 0
fmt.Fprintf(os.Stderr, "\033[90mRestarting VS Code... %s\033[0m", spinnerFrames[0])
ticker := time.NewTicker(200 * time.Millisecond)
defer ticker.Stop()
for range 150 { // 150 ticks × 200ms = 30s timeout
<-ticker.C
frame++
fmt.Fprintf(os.Stderr, "\r\033[90mRestarting VS Code... %s\033[0m", spinnerFrames[frame%len(spinnerFrames)])
if frame%5 == 0 { // check every ~1s
if !v.IsRunning() {
fmt.Fprintf(os.Stderr, "\r\033[K")
// Give VS Code a moment to finish writing its state DB
time.Sleep(1 * time.Second)
return
}
}
}
fmt.Fprintf(os.Stderr, "\r\033[K")
}
const (
minCopilotChatVersion = "0.41.0"
minVSCodeVersion = "1.113"
)
func (v *VSCode) Run(model string, args []string) error {
v.checkVSCodeVersion()
v.checkCopilotChatVersion()
// Get all configured models (saved by the launcher framework before Run is called)
models := []string{model}
if cfg, err := loadStoredIntegrationConfig("vscode"); err == nil && len(cfg.Models) > 0 {
models = cfg.Models
}
// VS Code discovers models from ollama ls. Cloud models that pass Show
// (the server knows about them) but aren't in ls need to be pulled to
// register them so VS Code can find them.
if client, err := api.ClientFromEnvironment(); err == nil {
v.ensureModelsRegistered(context.Background(), client, models)
}
// Warn if the default model doesn't support tool calling
if client, err := api.ClientFromEnvironment(); err == nil {
if resp, err := client.Show(context.Background(), &api.ShowRequest{Model: models[0]}); err == nil {
hasTools := false
for _, c := range resp.Capabilities {
if c == "tools" {
hasTools = true
break
}
}
if !hasTools {
fmt.Fprintf(os.Stderr, "Note: %s does not support tool calling and may not appear in the Copilot Chat model picker.\n", models[0])
}
}
}
v.printModelAccessTip()
if v.IsRunning() {
restart, err := ConfirmPrompt("Restart VS Code?")
if err != nil {
restart = false
}
if restart {
v.Quit()
if err := v.ShowInModelPicker(models); err != nil {
fmt.Fprintf(os.Stderr, "%s Warning: could not update VS Code model picker: %v%s\n", ansiYellow, err, ansiReset)
}
v.FocusVSCode()
} else {
fmt.Fprintf(os.Stderr, "\nTo get the latest model configuration, restart VS Code when you're ready.\n")
}
} else {
if err := v.ShowInModelPicker(models); err != nil {
fmt.Fprintf(os.Stderr, "%s Warning: could not update VS Code model picker: %v%s\n", ansiYellow, err, ansiReset)
}
v.FocusVSCode()
}
return nil
}
// ensureModelsRegistered pulls models that the server knows about (Show succeeds)
// but aren't in ollama ls yet. This is needed for cloud models so that VS Code
// can discover them from the Ollama API.
func (v *VSCode) ensureModelsRegistered(ctx context.Context, client *api.Client, models []string) {
listed, err := client.List(ctx)
if err != nil {
return
}
registered := make(map[string]bool, len(listed.Models))
for _, m := range listed.Models {
registered[m.Name] = true
}
for _, model := range models {
if registered[model] {
continue
}
// Also check without :latest suffix
if !strings.Contains(model, ":") && registered[model+":latest"] {
continue
}
if err := pullModel(ctx, client, model, false); err != nil {
fmt.Fprintf(os.Stderr, "%s Warning: could not register model %s: %v%s\n", ansiYellow, model, err, ansiReset)
}
}
}
// FocusVSCode brings VS Code to the foreground.
func (v *VSCode) FocusVSCode() {
binary := v.findBinary()
if binary == "" {
return
}
if runtime.GOOS == "darwin" && strings.HasSuffix(binary, ".app") {
_ = exec.Command("open", "-a", binary).Run()
} else {
_ = exec.Command(binary).Start()
}
}
// printModelAccessTip shows instructions for finding Ollama models in VS Code.
func (v *VSCode) printModelAccessTip() {
fmt.Fprintf(os.Stderr, "\nTip: To use Ollama models, open Copilot Chat and click the model picker.\n")
fmt.Fprintf(os.Stderr, " If you don't see your models, click \"Other models\" to find them.\n\n")
}
func (v *VSCode) Paths() []string {
if p := v.chatLanguageModelsPath(); fileExists(p) {
return []string{p}
}
return nil
}
func (v *VSCode) Edit(models []string) error {
if len(models) == 0 {
return nil
}
// Write chatLanguageModels.json with Ollama vendor entry
clmPath := v.chatLanguageModelsPath()
if err := os.MkdirAll(filepath.Dir(clmPath), 0o755); err != nil {
return err
}
var entries []map[string]any
if data, err := os.ReadFile(clmPath); err == nil {
_ = json.Unmarshal(data, &entries)
}
// Remove any existing Ollama entries, preserve others
filtered := make([]map[string]any, 0, len(entries))
for _, entry := range entries {
if vendor, _ := entry["vendor"].(string); vendor != "ollama" {
filtered = append(filtered, entry)
}
}
// Add new Ollama entry
filtered = append(filtered, map[string]any{
"vendor": "ollama",
"name": "Ollama",
"url": envconfig.Host().String(),
})
data, err := json.MarshalIndent(filtered, "", " ")
if err != nil {
return err
}
if err := fileutil.WriteWithBackup(clmPath, data); err != nil {
return err
}
// Clean up legacy settings from older Ollama integrations
v.updateSettings()
return nil
}
func (v *VSCode) Models() []string {
if !v.hasOllamaVendor() {
return nil
}
if cfg, err := loadStoredIntegrationConfig("vscode"); err == nil {
return cfg.Models
}
return nil
}
// hasOllamaVendor checks if chatLanguageModels.json contains an Ollama vendor entry.
func (v *VSCode) hasOllamaVendor() bool {
data, err := os.ReadFile(v.chatLanguageModelsPath())
if err != nil {
return false
}
var entries []map[string]any
if err := json.Unmarshal(data, &entries); err != nil {
return false
}
for _, entry := range entries {
if vendor, _ := entry["vendor"].(string); vendor == "ollama" {
return true
}
}
return false
}
func (v *VSCode) chatLanguageModelsPath() string {
return v.vscodePath("chatLanguageModels.json")
}
func (v *VSCode) settingsPath() string {
return v.vscodePath("settings.json")
}
// updateSettings cleans up legacy settings from older Ollama integrations.
func (v *VSCode) updateSettings() {
settingsPath := v.settingsPath()
data, err := os.ReadFile(settingsPath)
if err != nil {
return
}
var settings map[string]any
if err := json.Unmarshal(data, &settings); err != nil {
return
}
changed := false
for _, key := range []string{"github.copilot.chat.byok.ollamaEndpoint", "ollama.launch.configured"} {
if _, ok := settings[key]; ok {
delete(settings, key)
changed = true
}
}
if !changed {
return
}
updated, err := json.MarshalIndent(settings, "", " ")
if err != nil {
return
}
_ = fileutil.WriteWithBackup(settingsPath, updated)
}
func (v *VSCode) statePath() string {
return v.vscodePath("globalStorage", "state.vscdb")
}
// ShowInModelPicker ensures the given models are visible in VS Code's Copilot
// Chat model picker. It sets the configured models to true in the picker
// preferences so they appear in the dropdown. Models use the VS Code identifier
// format "ollama/Ollama/<name>".
func (v *VSCode) ShowInModelPicker(models []string) error {
if len(models) == 0 {
return nil
}
dbPath := v.statePath()
needsCreate := !fileExists(dbPath)
if needsCreate {
if err := os.MkdirAll(filepath.Dir(dbPath), 0o755); err != nil {
return fmt.Errorf("creating state directory: %w", err)
}
}
db, err := sql.Open("sqlite3", dbPath+"?_busy_timeout=5000")
if err != nil {
return fmt.Errorf("opening state database: %w", err)
}
defer db.Close()
// Create the table if this is a fresh DB. Schema must match what VS Code creates.
if needsCreate {
if _, err := db.Exec("CREATE TABLE ItemTable (key TEXT UNIQUE ON CONFLICT REPLACE, value BLOB)"); err != nil {
return fmt.Errorf("initializing state database: %w", err)
}
}
// Read existing preferences
prefs := make(map[string]bool)
var prefsJSON string
if err := db.QueryRow("SELECT value FROM ItemTable WHERE key = 'chatModelPickerPreferences'").Scan(&prefsJSON); err == nil {
_ = json.Unmarshal([]byte(prefsJSON), &prefs)
}
// Build name→ID map from VS Code's cached model list.
// VS Code uses numeric IDs like "ollama/Ollama/4", not "ollama/Ollama/kimi-k2.5:cloud".
nameToID := make(map[string]string)
var cacheJSON string
if err := db.QueryRow("SELECT value FROM ItemTable WHERE key = 'chat.cachedLanguageModels.v2'").Scan(&cacheJSON); err == nil {
var cached []map[string]any
if json.Unmarshal([]byte(cacheJSON), &cached) == nil {
for _, entry := range cached {
meta, _ := entry["metadata"].(map[string]any)
if meta == nil {
continue
}
if vendor, _ := meta["vendor"].(string); vendor == "ollama" {
name, _ := meta["name"].(string)
id, _ := entry["identifier"].(string)
if name != "" && id != "" {
nameToID[name] = id
}
}
}
}
}
// Ollama config is authoritative: always show configured models,
// hide Ollama models that are no longer in the config.
configuredIDs := make(map[string]bool)
for _, m := range models {
for _, id := range v.modelVSCodeIDs(m, nameToID) {
prefs[id] = true
configuredIDs[id] = true
}
}
for id := range prefs {
if strings.HasPrefix(id, "ollama/") && !configuredIDs[id] {
prefs[id] = false
}
}
data, _ := json.Marshal(prefs)
if _, err = db.Exec("INSERT OR REPLACE INTO ItemTable (key, value) VALUES ('chatModelPickerPreferences', ?)", string(data)); err != nil {
return err
}
return nil
}
// modelVSCodeIDs returns all possible VS Code picker IDs for a model name.
func (v *VSCode) modelVSCodeIDs(model string, nameToID map[string]string) []string {
var ids []string
if id, ok := nameToID[model]; ok {
ids = append(ids, id)
} else if !strings.Contains(model, ":") {
if id, ok := nameToID[model+":latest"]; ok {
ids = append(ids, id)
}
}
ids = append(ids, "ollama/Ollama/"+model)
if !strings.Contains(model, ":") {
ids = append(ids, "ollama/Ollama/"+model+":latest")
}
return ids
}
func (v *VSCode) vscodePath(parts ...string) string {
home, _ := os.UserHomeDir()
var base string
switch runtime.GOOS {
case "darwin":
base = filepath.Join(home, "Library", "Application Support", "Code", "User")
case "windows":
base = filepath.Join(os.Getenv("APPDATA"), "Code", "User")
default:
base = filepath.Join(home, ".config", "Code", "User")
}
return filepath.Join(append([]string{base}, parts...)...)
}
// checkVSCodeVersion warns if VS Code is older than minVSCodeVersion.
func (v *VSCode) checkVSCodeVersion() {
codeCLI := v.findCodeCLI()
if codeCLI == "" {
return
}
out, err := exec.Command(codeCLI, "--version").Output()
if err != nil {
return
}
// "code --version" outputs: version\ncommit\narch
lines := strings.Split(strings.TrimSpace(string(out)), "\n")
if len(lines) == 0 || lines[0] == "" {
return
}
version := strings.TrimSpace(lines[0])
if compareVersions(version, minVSCodeVersion) < 0 {
fmt.Fprintf(os.Stderr, "\n%sWarning: VS Code version (%s) is older than the recommended version (%s)%s\n", ansiYellow, version, minVSCodeVersion, ansiReset)
fmt.Fprintf(os.Stderr, "Please update VS Code to the latest version.\n\n")
}
}
// checkCopilotChatVersion warns if the GitHub Copilot Chat extension is
// missing or older than minCopilotChatVersion.
func (v *VSCode) checkCopilotChatVersion() {
codeCLI := v.findCodeCLI()
if codeCLI == "" {
return
}
out, err := exec.Command(codeCLI, "--list-extensions", "--show-versions").Output()
if err != nil {
return
}
installed, version := parseCopilotChatVersion(string(out))
if !installed {
fmt.Fprintf(os.Stderr, "\n%sWarning: GitHub Copilot Chat extension is not installed%s\n", ansiYellow, ansiReset)
fmt.Fprintf(os.Stderr, "Install it in VS Code: Extensions → search \"GitHub Copilot Chat\" → Install\n\n")
return
}
if compareVersions(version, minCopilotChatVersion) < 0 {
fmt.Fprintf(os.Stderr, "\n%sWarning: GitHub Copilot Chat extension version (%s) is older than the recommended version (%s)%s\n", ansiYellow, version, minCopilotChatVersion, ansiReset)
fmt.Fprintf(os.Stderr, "Please update it in VS Code: Extensions → search \"GitHub Copilot Chat\" → Update\n\n")
}
}
// findCodeCLI returns the path to the VS Code CLI for querying extensions.
// On macOS, findBinary may return an .app bundle which can't run --list-extensions,
// so this resolves to the actual CLI binary inside the bundle.
func (v *VSCode) findCodeCLI() string {
binary := v.findBinary()
if binary == "" {
return ""
}
if runtime.GOOS == "darwin" && strings.HasSuffix(binary, ".app") {
bundleCLI := binary + "/Contents/Resources/app/bin/code"
if _, err := os.Stat(bundleCLI); err == nil {
return bundleCLI
}
return ""
}
return binary
}
// parseCopilotChatVersion extracts the version of the GitHub Copilot Chat
// extension from "code --list-extensions --show-versions" output.
func parseCopilotChatVersion(output string) (installed bool, version string) {
for _, line := range strings.Split(output, "\n") {
// Format: github.copilot-chat@0.40.1
if !strings.HasPrefix(strings.ToLower(line), "github.copilot-chat@") {
continue
}
parts := strings.SplitN(line, "@", 2)
if len(parts) != 2 {
continue
}
return true, strings.TrimSpace(parts[1])
}
return false, ""
}
// compareVersions compares two dot-separated version strings.
// Returns -1 if a < b, 0 if a == b, 1 if a > b.
func compareVersions(a, b string) int {
aParts := strings.Split(a, ".")
bParts := strings.Split(b, ".")
maxLen := len(aParts)
if len(bParts) > maxLen {
maxLen = len(bParts)
}
for i := range maxLen {
var aNum, bNum int
if i < len(aParts) {
aNum, _ = strconv.Atoi(aParts[i])
}
if i < len(bParts) {
bNum, _ = strconv.Atoi(bParts[i])
}
if aNum < bNum {
return -1
}
if aNum > bNum {
return 1
}
}
return 0
}
func fileExists(path string) bool {
_, err := os.Stat(path)
return err == nil
}

486
cmd/launch/vscode_test.go Normal file
View File

@@ -0,0 +1,486 @@
package launch
import (
"database/sql"
"encoding/json"
"os"
"path/filepath"
"runtime"
"testing"
_ "github.com/mattn/go-sqlite3"
)
func TestVSCodeIntegration(t *testing.T) {
v := &VSCode{}
t.Run("String", func(t *testing.T) {
if got := v.String(); got != "Visual Studio Code" {
t.Errorf("String() = %q, want %q", got, "Visual Studio Code")
}
})
t.Run("implements Runner", func(t *testing.T) {
var _ Runner = v
})
t.Run("implements Editor", func(t *testing.T) {
var _ Editor = v
})
}
func TestVSCodeEdit(t *testing.T) {
v := &VSCode{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
t.Setenv("XDG_CONFIG_HOME", "")
clmPath := testVSCodePath(t, tmpDir, "chatLanguageModels.json")
tests := []struct {
name string
setup string // initial chatLanguageModels.json content, empty means no file
models []string
validate func(t *testing.T, data []byte)
}{
{
name: "fresh install",
models: []string{"llama3.2"},
validate: func(t *testing.T, data []byte) {
assertOllamaVendorConfigured(t, data)
},
},
{
name: "preserve other vendor entries",
setup: `[{"vendor": "azure", "name": "Azure", "url": "https://example.com"}]`,
models: []string{"llama3.2"},
validate: func(t *testing.T, data []byte) {
var entries []map[string]any
json.Unmarshal(data, &entries)
if len(entries) != 2 {
t.Errorf("expected 2 entries, got %d", len(entries))
}
// Check Azure entry preserved
found := false
for _, e := range entries {
if v, _ := e["vendor"].(string); v == "azure" {
found = true
}
}
if !found {
t.Error("azure vendor entry was not preserved")
}
assertOllamaVendorConfigured(t, data)
},
},
{
name: "update existing ollama entry",
setup: `[{"vendor": "ollama", "name": "Ollama", "url": "http://old:11434"}]`,
models: []string{"llama3.2"},
validate: func(t *testing.T, data []byte) {
assertOllamaVendorConfigured(t, data)
},
},
{
name: "empty models is no-op",
setup: `[{"vendor": "azure", "name": "Azure"}]`,
models: []string{},
validate: func(t *testing.T, data []byte) {
if string(data) != `[{"vendor": "azure", "name": "Azure"}]` {
t.Error("empty models should not modify file")
}
},
},
{
name: "corrupted JSON treated as empty",
setup: `{corrupted json`,
models: []string{"llama3.2"},
validate: func(t *testing.T, data []byte) {
var entries []map[string]any
if err := json.Unmarshal(data, &entries); err != nil {
t.Errorf("result is not valid JSON: %v", err)
}
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
os.RemoveAll(filepath.Dir(clmPath))
if tt.setup != "" {
os.MkdirAll(filepath.Dir(clmPath), 0o755)
os.WriteFile(clmPath, []byte(tt.setup), 0o644)
}
if err := v.Edit(tt.models); err != nil {
t.Fatal(err)
}
data, _ := os.ReadFile(clmPath)
tt.validate(t, data)
})
}
}
func TestVSCodeEditCleansUpOldSettings(t *testing.T) {
v := &VSCode{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
t.Setenv("XDG_CONFIG_HOME", "")
settingsPath := testVSCodePath(t, tmpDir, "settings.json")
// Create settings.json with old byok setting
os.MkdirAll(filepath.Dir(settingsPath), 0o755)
os.WriteFile(settingsPath, []byte(`{"github.copilot.chat.byok.ollamaEndpoint": "http://old:11434", "ollama.launch.configured": true, "editor.fontSize": 14}`), 0o644)
if err := v.Edit([]string{"llama3.2"}); err != nil {
t.Fatal(err)
}
// Verify old settings were removed
data, err := os.ReadFile(settingsPath)
if err != nil {
t.Fatal(err)
}
var settings map[string]any
json.Unmarshal(data, &settings)
if _, ok := settings["github.copilot.chat.byok.ollamaEndpoint"]; ok {
t.Error("github.copilot.chat.byok.ollamaEndpoint should have been removed")
}
if _, ok := settings["ollama.launch.configured"]; ok {
t.Error("ollama.launch.configured should have been removed")
}
if settings["editor.fontSize"] != float64(14) {
t.Error("editor.fontSize should have been preserved")
}
}
func TestVSCodePaths(t *testing.T) {
v := &VSCode{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
t.Setenv("XDG_CONFIG_HOME", "")
clmPath := testVSCodePath(t, tmpDir, "chatLanguageModels.json")
t.Run("no file returns nil", func(t *testing.T) {
os.Remove(clmPath)
if paths := v.Paths(); paths != nil {
t.Errorf("expected nil, got %v", paths)
}
})
t.Run("existing file returns path", func(t *testing.T) {
os.MkdirAll(filepath.Dir(clmPath), 0o755)
os.WriteFile(clmPath, []byte(`[]`), 0o644)
if paths := v.Paths(); len(paths) != 1 {
t.Errorf("expected 1 path, got %d", len(paths))
}
})
}
// testVSCodePath returns the expected VS Code config path for the given file in tests.
func testVSCodePath(t *testing.T, tmpDir, filename string) string {
t.Helper()
switch runtime.GOOS {
case "darwin":
return filepath.Join(tmpDir, "Library", "Application Support", "Code", "User", filename)
case "windows":
t.Setenv("APPDATA", tmpDir)
return filepath.Join(tmpDir, "Code", "User", filename)
default:
return filepath.Join(tmpDir, ".config", "Code", "User", filename)
}
}
func assertOllamaVendorConfigured(t *testing.T, data []byte) {
t.Helper()
var entries []map[string]any
if err := json.Unmarshal(data, &entries); err != nil {
t.Fatalf("invalid JSON: %v", err)
}
for _, entry := range entries {
if vendor, _ := entry["vendor"].(string); vendor == "ollama" {
if name, _ := entry["name"].(string); name != "Ollama" {
t.Errorf("expected name \"Ollama\", got %q", name)
}
if url, _ := entry["url"].(string); url == "" {
t.Error("url not set")
}
return
}
}
t.Error("no ollama vendor entry found")
}
func TestShowInModelPicker(t *testing.T) {
v := &VSCode{}
// helper to create a state DB with optional seed data
setupDB := func(t *testing.T, tmpDir string, seedPrefs map[string]bool, seedCache []map[string]any) string {
t.Helper()
dbDir := filepath.Join(tmpDir, "globalStorage")
os.MkdirAll(dbDir, 0o755)
dbPath := filepath.Join(dbDir, "state.vscdb")
db, err := sql.Open("sqlite3", dbPath)
if err != nil {
t.Fatal(err)
}
defer db.Close()
if _, err := db.Exec("CREATE TABLE ItemTable (key TEXT UNIQUE ON CONFLICT REPLACE, value BLOB)"); err != nil {
t.Fatal(err)
}
if seedPrefs != nil {
data, _ := json.Marshal(seedPrefs)
db.Exec("INSERT INTO ItemTable (key, value) VALUES ('chatModelPickerPreferences', ?)", string(data))
}
if seedCache != nil {
data, _ := json.Marshal(seedCache)
db.Exec("INSERT INTO ItemTable (key, value) VALUES ('chat.cachedLanguageModels.v2', ?)", string(data))
}
return dbPath
}
// helper to read prefs back from DB
readPrefs := func(t *testing.T, dbPath string) map[string]bool {
t.Helper()
db, err := sql.Open("sqlite3", dbPath)
if err != nil {
t.Fatal(err)
}
defer db.Close()
var raw string
if err := db.QueryRow("SELECT value FROM ItemTable WHERE key = 'chatModelPickerPreferences'").Scan(&raw); err != nil {
t.Fatal(err)
}
prefs := make(map[string]bool)
json.Unmarshal([]byte(raw), &prefs)
return prefs
}
t.Run("fresh DB creates table and shows models", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
t.Setenv("XDG_CONFIG_HOME", "")
if runtime.GOOS == "windows" {
t.Setenv("APPDATA", tmpDir)
}
err := v.ShowInModelPicker([]string{"llama3.2"})
if err != nil {
t.Fatal(err)
}
dbPath := testVSCodePath(t, tmpDir, filepath.Join("globalStorage", "state.vscdb"))
prefs := readPrefs(t, dbPath)
if !prefs["ollama/Ollama/llama3.2"] {
t.Error("expected llama3.2 to be shown")
}
if !prefs["ollama/Ollama/llama3.2:latest"] {
t.Error("expected llama3.2:latest to be shown")
}
})
t.Run("configured models are shown", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
t.Setenv("XDG_CONFIG_HOME", "")
dbPath := setupDB(t, testVSCodePath(t, tmpDir, ""), nil, nil)
err := v.ShowInModelPicker([]string{"llama3.2", "qwen3:8b"})
if err != nil {
t.Fatal(err)
}
prefs := readPrefs(t, dbPath)
if !prefs["ollama/Ollama/llama3.2"] {
t.Error("expected llama3.2 to be shown")
}
if !prefs["ollama/Ollama/qwen3:8b"] {
t.Error("expected qwen3:8b to be shown")
}
})
t.Run("removed models are hidden", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
t.Setenv("XDG_CONFIG_HOME", "")
dbPath := setupDB(t, testVSCodePath(t, tmpDir, ""), map[string]bool{
"ollama/Ollama/llama3.2": true,
"ollama/Ollama/llama3.2:latest": true,
"ollama/Ollama/mistral": true,
"ollama/Ollama/mistral:latest": true,
}, nil)
// Only configure llama3.2 — mistral should get hidden
err := v.ShowInModelPicker([]string{"llama3.2"})
if err != nil {
t.Fatal(err)
}
prefs := readPrefs(t, dbPath)
if !prefs["ollama/Ollama/llama3.2"] {
t.Error("expected llama3.2 to stay shown")
}
if prefs["ollama/Ollama/mistral"] {
t.Error("expected mistral to be hidden")
}
if prefs["ollama/Ollama/mistral:latest"] {
t.Error("expected mistral:latest to be hidden")
}
})
t.Run("non-ollama prefs are preserved", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
t.Setenv("XDG_CONFIG_HOME", "")
dbPath := setupDB(t, testVSCodePath(t, tmpDir, ""), map[string]bool{
"copilot/gpt-4o": true,
}, nil)
err := v.ShowInModelPicker([]string{"llama3.2"})
if err != nil {
t.Fatal(err)
}
prefs := readPrefs(t, dbPath)
if !prefs["copilot/gpt-4o"] {
t.Error("expected copilot/gpt-4o to stay shown")
}
})
t.Run("uses cached numeric IDs when available", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
t.Setenv("XDG_CONFIG_HOME", "")
cache := []map[string]any{
{
"identifier": "ollama/Ollama/4",
"metadata": map[string]any{"vendor": "ollama", "name": "llama3.2"},
},
}
dbPath := setupDB(t, testVSCodePath(t, tmpDir, ""), nil, cache)
err := v.ShowInModelPicker([]string{"llama3.2"})
if err != nil {
t.Fatal(err)
}
prefs := readPrefs(t, dbPath)
if !prefs["ollama/Ollama/4"] {
t.Error("expected numeric ID ollama/Ollama/4 to be shown")
}
// Name-based fallback should also be set
if !prefs["ollama/Ollama/llama3.2"] {
t.Error("expected name-based ID to also be shown")
}
})
t.Run("empty models is no-op", func(t *testing.T) {
err := v.ShowInModelPicker([]string{})
if err != nil {
t.Fatal(err)
}
})
t.Run("previously hidden model is re-shown when configured", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
t.Setenv("XDG_CONFIG_HOME", "")
dbPath := setupDB(t, testVSCodePath(t, tmpDir, ""), map[string]bool{
"ollama/Ollama/llama3.2": false,
"ollama/Ollama/llama3.2:latest": false,
}, nil)
// Ollama config is authoritative — should override the hidden state
err := v.ShowInModelPicker([]string{"llama3.2"})
if err != nil {
t.Fatal(err)
}
prefs := readPrefs(t, dbPath)
if !prefs["ollama/Ollama/llama3.2"] {
t.Error("expected llama3.2 to be re-shown")
}
})
}
func TestParseCopilotChatVersion(t *testing.T) {
tests := []struct {
name string
output string
wantInstalled bool
wantVersion string
}{
{
name: "found among other extensions",
output: "ms-python.python@2024.1.1\ngithub.copilot-chat@0.40.1\ngithub.copilot@1.200.0\n",
wantInstalled: true,
wantVersion: "0.40.1",
},
{
name: "only extension",
output: "GitHub.copilot-chat@0.41.0\n",
wantInstalled: true,
wantVersion: "0.41.0",
},
{
name: "not installed",
output: "ms-python.python@2024.1.1\ngithub.copilot@1.200.0\n",
wantInstalled: false,
},
{
name: "empty output",
output: "",
wantInstalled: false,
},
{
name: "case insensitive match",
output: "GitHub.Copilot-Chat@0.39.0\n",
wantInstalled: true,
wantVersion: "0.39.0",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
installed, version := parseCopilotChatVersion(tt.output)
if installed != tt.wantInstalled {
t.Errorf("installed = %v, want %v", installed, tt.wantInstalled)
}
if installed && version != tt.wantVersion {
t.Errorf("version = %q, want %q", version, tt.wantVersion)
}
})
}
}
func TestCompareVersions(t *testing.T) {
tests := []struct {
a, b string
want int
}{
{"0.40.1", "0.40.1", 0},
{"0.40.2", "0.40.1", 1},
{"0.40.0", "0.40.1", -1},
{"0.41.0", "0.40.1", 1},
{"0.39.9", "0.40.1", -1},
{"1.0.0", "0.40.1", 1},
{"0.40", "0.40.1", -1},
{"0.40.1.1", "0.40.1", 1},
}
for _, tt := range tests {
t.Run(tt.a+"_vs_"+tt.b, func(t *testing.T) {
got := compareVersions(tt.a, tt.b)
if got != tt.want {
t.Errorf("compareVersions(%q, %q) = %d, want %d", tt.a, tt.b, got, tt.want)
}
})
}
}

View File

@@ -7,7 +7,7 @@ import (
tea "github.com/charmbracelet/bubbletea" tea "github.com/charmbracelet/bubbletea"
"github.com/charmbracelet/lipgloss" "github.com/charmbracelet/lipgloss"
"github.com/ollama/ollama/cmd/config" "github.com/ollama/ollama/cmd/launch"
) )
var ( var (
@@ -64,8 +64,8 @@ type SelectItem struct {
Recommended bool Recommended bool
} }
// ConvertItems converts config.ModelItem slice to SelectItem slice. // ConvertItems converts launch.ModelItem slice to SelectItem slice.
func ConvertItems(items []config.ModelItem) []SelectItem { func ConvertItems(items []launch.ModelItem) []SelectItem {
out := make([]SelectItem, len(items)) out := make([]SelectItem, len(items))
for i, item := range items { for i, item := range items {
out[i] = SelectItem{Name: item.Name, Description: item.Description, Recommended: item.Recommended} out[i] = SelectItem{Name: item.Name, Description: item.Description, Recommended: item.Recommended}
@@ -101,6 +101,16 @@ type selectorModel struct {
width int width int
} }
func selectorModelWithCurrent(title string, items []SelectItem, current string) selectorModel {
m := selectorModel{
title: title,
items: items,
cursor: cursorForCurrent(items, current),
}
m.updateScroll(m.otherStart())
return m
}
func (m selectorModel) filteredItems() []SelectItem { func (m selectorModel) filteredItems() []SelectItem {
if m.filter == "" { if m.filter == "" {
return m.items return m.items
@@ -232,6 +242,10 @@ func (m selectorModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
m.cancelled = true m.cancelled = true
return m, tea.Quit return m, tea.Quit
case tea.KeyLeft:
m.cancelled = true
return m, tea.Quit
case tea.KeyEnter: case tea.KeyEnter:
filtered := m.filteredItems() filtered := m.filteredItems()
if len(filtered) > 0 && m.cursor < len(filtered) { if len(filtered) > 0 && m.cursor < len(filtered) {
@@ -344,7 +358,7 @@ func (m selectorModel) renderContent() string {
} }
s.WriteString("\n") s.WriteString("\n")
help := "↑/↓ navigate • enter select • esc cancel" help := "↑/↓ navigate • enter select • ← back"
if m.helpText != "" { if m.helpText != "" {
help = m.helpText help = m.helpText
} }
@@ -367,13 +381,24 @@ func (m selectorModel) View() string {
// cursorForCurrent returns the item index matching current, or 0 if not found. // cursorForCurrent returns the item index matching current, or 0 if not found.
func cursorForCurrent(items []SelectItem, current string) int { func cursorForCurrent(items []SelectItem, current string) int {
if current != "" { if current == "" {
for i, item := range items { return 0
if item.Name == current || strings.HasPrefix(item.Name, current+":") || strings.HasPrefix(current, item.Name+":") { }
return i
} // Prefer exact name matches before tag-prefix fallback so "qwen3.5" does not
// incorrectly select "qwen3.5:cloud" (and vice versa) based on list order.
for i, item := range items {
if item.Name == current {
return i
} }
} }
for i, item := range items {
if strings.HasPrefix(item.Name, current+":") || strings.HasPrefix(current, item.Name+":") {
return i
}
}
return 0 return 0
} }
@@ -382,11 +407,7 @@ func SelectSingle(title string, items []SelectItem, current string) (string, err
return "", fmt.Errorf("no items to select from") return "", fmt.Errorf("no items to select from")
} }
m := selectorModel{ m := selectorModelWithCurrent(title, items, current)
title: title,
items: items,
cursor: cursorForCurrent(items, current),
}
p := tea.NewProgram(m) p := tea.NewProgram(m)
finalModel, err := p.Run() finalModel, err := p.Run()
@@ -523,6 +544,7 @@ func (m *multiSelectorModel) toggleItem() {
origIdx := m.itemIndex[item.Name] origIdx := m.itemIndex[item.Name]
if m.checked[origIdx] { if m.checked[origIdx] {
wasDefault := len(m.checkOrder) > 0 && m.checkOrder[len(m.checkOrder)-1] == origIdx
delete(m.checked, origIdx) delete(m.checked, origIdx)
for i, idx := range m.checkOrder { for i, idx := range m.checkOrder {
if idx == origIdx { if idx == origIdx {
@@ -530,6 +552,34 @@ func (m *multiSelectorModel) toggleItem() {
break break
} }
} }
if wasDefault {
// When removing the default, pick the nearest checked model above it
// (or below if none above) so default fallback follows list order.
newDefault := -1
for i := origIdx - 1; i >= 0; i-- {
if m.checked[i] {
newDefault = i
break
}
}
if newDefault == -1 {
for i := origIdx + 1; i < len(m.items); i++ {
if m.checked[i] {
newDefault = i
break
}
}
}
if newDefault != -1 {
for i, idx := range m.checkOrder {
if idx == newDefault {
m.checkOrder = append(m.checkOrder[:i], m.checkOrder[i+1:]...)
break
}
}
m.checkOrder = append(m.checkOrder, newDefault)
}
}
} else { } else {
m.checked[origIdx] = true m.checked[origIdx] = true
m.checkOrder = append(m.checkOrder, origIdx) m.checkOrder = append(m.checkOrder, origIdx)
@@ -562,6 +612,10 @@ func (m multiSelectorModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
m.cancelled = true m.cancelled = true
return m, tea.Quit return m, tea.Quit
case tea.KeyLeft:
m.cancelled = true
return m, tea.Quit
case tea.KeyTab: case tea.KeyTab:
m.multi = !m.multi m.multi = !m.multi
@@ -764,7 +818,7 @@ func (m multiSelectorModel) View() string {
s.WriteString("\n") s.WriteString("\n")
if !m.multi { if !m.multi {
s.WriteString(selectorHelpStyle.Render("↑/↓ navigate • enter select • tab add multiple • esc cancel")) s.WriteString(selectorHelpStyle.Render("↑/↓ navigate • enter select • tab add multiple • ← back"))
} else { } else {
count := m.selectedCount() count := m.selectedCount()
if count == 0 { if count == 0 {
@@ -773,7 +827,7 @@ func (m multiSelectorModel) View() string {
s.WriteString(selectorDescStyle.Render(fmt.Sprintf(" %d selected - press enter to continue", count))) s.WriteString(selectorDescStyle.Render(fmt.Sprintf(" %d selected - press enter to continue", count)))
} }
s.WriteString("\n\n") s.WriteString("\n\n")
s.WriteString(selectorHelpStyle.Render("↑/↓ navigate • space toggle • tab select single • enter confirm • esc cancel")) s.WriteString(selectorHelpStyle.Render("↑/↓ navigate • space toggle • tab select single • enter confirm • ← back"))
} }
result := s.String() result := s.String()

View File

@@ -216,6 +216,41 @@ func TestUpdateScroll(t *testing.T) {
} }
} }
func TestSelectorModelWithCurrent_ScrollsToCurrentInMoreSection(t *testing.T) {
m := selectorModelWithCurrent("Pick:", mixedItems(), "other-10")
if m.cursor != 11 {
t.Fatalf("cursor = %d, want 11", m.cursor)
}
if m.scrollOffset == 0 {
t.Fatal("scrollOffset should move to reveal current item in More section")
}
content := m.renderContent()
if !strings.Contains(content, "▸ other-10") {
t.Fatalf("expected current item to be visible and highlighted\n%s", content)
}
}
func TestSelectorModelWithCurrent_HighlightsExactLocalWhenCloudVariantExists(t *testing.T) {
m := selectorModelWithCurrent("Pick:", []SelectItem{
{Name: "qwen3.5:cloud", Recommended: true},
{Name: "qwen3.5", Recommended: true},
}, "qwen3.5")
if m.cursor != 1 {
t.Fatalf("cursor = %d, want 1", m.cursor)
}
content := m.renderContent()
if !strings.Contains(content, "▸ qwen3.5") {
t.Fatalf("expected local qwen3.5 to be highlighted\n%s", content)
}
if strings.Contains(content, "▸ qwen3.5:cloud") {
t.Fatalf("did not expect cloud qwen3.5:cloud to be highlighted\n%s", content)
}
}
func TestRenderContent_SectionHeaders(t *testing.T) { func TestRenderContent_SectionHeaders(t *testing.T) {
m := selectorModel{ m := selectorModel{
title: "Pick:", title: "Pick:",
@@ -418,6 +453,28 @@ func TestCursorForCurrent(t *testing.T) {
} }
} }
func TestCursorForCurrent_PrefersExactLocalOverCloudPrefix(t *testing.T) {
testItems := []SelectItem{
{Name: "qwen3.5:cloud", Recommended: true},
{Name: "qwen3.5", Recommended: true},
}
if got := cursorForCurrent(testItems, "qwen3.5"); got != 1 {
t.Errorf("cursorForCurrent(%q) = %d, want %d", "qwen3.5", got, 1)
}
}
func TestCursorForCurrent_PrefersExactCloudOverLocalPrefix(t *testing.T) {
testItems := []SelectItem{
{Name: "qwen3.5", Recommended: true},
{Name: "qwen3.5:cloud", Recommended: true},
}
if got := cursorForCurrent(testItems, "qwen3.5:cloud"); got != 1 {
t.Errorf("cursorForCurrent(%q) = %d, want %d", "qwen3.5:cloud", got, 1)
}
}
// --- ReorderItems --- // --- ReorderItems ---
func TestReorderItems(t *testing.T) { func TestReorderItems(t *testing.T) {
@@ -725,6 +782,9 @@ func TestMulti_MultiModeHelpText(t *testing.T) {
if !strings.Contains(content, "tab select single") { if !strings.Contains(content, "tab select single") {
t.Error("multi mode should show 'tab select single' in help") t.Error("multi mode should show 'tab select single' in help")
} }
if !strings.Contains(content, "← back") {
t.Error("multi mode should show '← back' in help")
}
} }
// --- preChecked initialization order --- // --- preChecked initialization order ---
@@ -783,6 +843,74 @@ func TestMulti_LastCheckedIsDefault(t *testing.T) {
} }
} }
func TestMulti_UncheckingDefaultFallsBackToNearestCheckedAbove(t *testing.T) {
// Default is "b", and checked models are "a", "b", "c".
// Unticking default should make "a" (the nearest checked item above) default.
m := newMultiSelectorModel("Pick:", items("a", "b", "c"), []string{"b", "c", "a"})
m.multi = true
m.cursor = 1 // "b"
m.toggleItem()
lastIdx := m.checkOrder[len(m.checkOrder)-1]
if m.items[lastIdx].Name != "a" {
t.Fatalf("expected default to fall back to 'a', got %q", m.items[lastIdx].Name)
}
}
func TestMulti_UncheckingTopDefaultFallsBackToNearestCheckedBelow(t *testing.T) {
// Default is top item "a". With no checked item above, fallback should pick
// the nearest checked item below ("b").
m := newMultiSelectorModel("Pick:", items("a", "b", "c"), []string{"a", "c", "b"})
m.multi = true
m.cursor = 0 // "a"
m.toggleItem()
lastIdx := m.checkOrder[len(m.checkOrder)-1]
if m.items[lastIdx].Name != "b" {
t.Fatalf("expected default to fall back to 'b', got %q", m.items[lastIdx].Name)
}
}
// --- Left arrow back navigation ---
func TestSelectorLeftArrowCancelsWhenNoFilter(t *testing.T) {
m := selectorModelWithCurrent("Pick:", items("a", "b", "c"), "")
updated, _ := m.Update(tea.KeyMsg{Type: tea.KeyLeft})
got := updated.(selectorModel)
if !got.cancelled {
t.Error("left arrow with empty filter should cancel (go back)")
}
}
func TestSelectorLeftArrowCancelsWhenFiltering(t *testing.T) {
m := selectorModelWithCurrent("Pick:", items("a", "b", "c"), "")
m.filter = "a"
updated, _ := m.Update(tea.KeyMsg{Type: tea.KeyLeft})
got := updated.(selectorModel)
if !got.cancelled {
t.Error("left arrow with active filter should still cancel (go back)")
}
}
func TestMultiSelectorLeftArrowCancelsWhenNoFilter(t *testing.T) {
m := newMultiSelectorModel("Pick:", items("a", "b", "c"), nil)
updated, _ := m.Update(tea.KeyMsg{Type: tea.KeyLeft})
got := updated.(multiSelectorModel)
if !got.cancelled {
t.Error("left arrow with empty filter should cancel (go back)")
}
}
func TestMultiSelectorLeftArrowCancelsWhenFiltering(t *testing.T) {
m := newMultiSelectorModel("Pick:", items("a", "b", "c"), nil)
m.filter = "a"
updated, _ := m.Update(tea.KeyMsg{Type: tea.KeyLeft})
got := updated.(multiSelectorModel)
if !got.cancelled {
t.Error("left arrow with active filter should still cancel (go back)")
}
}
// Key message helpers for testing // Key message helpers for testing
type keyType = int type keyType = int

View File

@@ -1,15 +1,24 @@
package tui package tui
import ( import (
"context"
"fmt" "fmt"
"strings" "strings"
"time" "time"
tea "github.com/charmbracelet/bubbletea" tea "github.com/charmbracelet/bubbletea"
"github.com/charmbracelet/lipgloss" "github.com/charmbracelet/lipgloss"
"github.com/ollama/ollama/cmd/config" "github.com/ollama/ollama/api"
"github.com/ollama/ollama/cmd/launch"
) )
type signInTickMsg struct{}
type signInCheckMsg struct {
signedIn bool
userName string
}
type signInModel struct { type signInModel struct {
modelName string modelName string
signInURL string signInURL string
@@ -88,11 +97,8 @@ func renderSignIn(modelName, signInURL string, spinner, width int) string {
fmt.Fprintf(&s, "To use %s, please sign in.\n\n", selectorSelectedItemStyle.Render(modelName)) fmt.Fprintf(&s, "To use %s, please sign in.\n\n", selectorSelectedItemStyle.Render(modelName))
// Wrap in OSC 8 hyperlink so the entire URL is clickable even when wrapped.
// Padding is outside the hyperlink so spaces don't get underlined.
link := fmt.Sprintf("\033]8;;%s\033\\%s\033]8;;\033\\", signInURL, urlColor.Render(signInURL))
s.WriteString("Navigate to:\n") s.WriteString("Navigate to:\n")
s.WriteString(urlWrap.Render(link)) s.WriteString(urlWrap.Render(urlColor.Render(signInURL)))
s.WriteString("\n\n") s.WriteString("\n\n")
s.WriteString(lipgloss.NewStyle().Foreground(lipgloss.AdaptiveColor{Light: "242", Dark: "246"}).Render( s.WriteString(lipgloss.NewStyle().Foreground(lipgloss.AdaptiveColor{Light: "242", Dark: "246"}).Render(
@@ -104,9 +110,21 @@ func renderSignIn(modelName, signInURL string, spinner, width int) string {
return lipgloss.NewStyle().PaddingLeft(2).Render(s.String()) return lipgloss.NewStyle().PaddingLeft(2).Render(s.String())
} }
func checkSignIn() tea.Msg {
client, err := api.ClientFromEnvironment()
if err != nil {
return signInCheckMsg{signedIn: false}
}
user, err := client.Whoami(context.Background())
if err == nil && user != nil && user.Name != "" {
return signInCheckMsg{signedIn: true, userName: user.Name}
}
return signInCheckMsg{signedIn: false}
}
// RunSignIn shows a bubbletea sign-in dialog and polls until the user signs in or cancels. // RunSignIn shows a bubbletea sign-in dialog and polls until the user signs in or cancels.
func RunSignIn(modelName, signInURL string) (string, error) { func RunSignIn(modelName, signInURL string) (string, error) {
config.OpenBrowser(signInURL) launch.OpenBrowser(signInURL)
m := signInModel{ m := signInModel{
modelName: modelName, modelName: modelName,

View File

@@ -25,22 +25,6 @@ func TestRenderSignIn_ContainsURL(t *testing.T) {
} }
} }
func TestRenderSignIn_OSC8Hyperlink(t *testing.T) {
url := "https://ollama.com/connect?key=abc123"
got := renderSignIn("test:cloud", url, 0, 120)
// Should contain OSC 8 open sequence with the URL
osc8Open := "\033]8;;" + url + "\033\\"
if !strings.Contains(got, osc8Open) {
t.Error("should contain OSC 8 open sequence with URL")
}
// Should contain OSC 8 close sequence
osc8Close := "\033]8;;\033\\"
if !strings.Contains(got, osc8Close) {
t.Error("should contain OSC 8 close sequence")
}
}
func TestRenderSignIn_ContainsSpinner(t *testing.T) { func TestRenderSignIn_ContainsSpinner(t *testing.T) {
got := renderSignIn("test:cloud", "https://example.com", 0, 80) got := renderSignIn("test:cloud", "https://example.com", 0, 80)

View File

@@ -1,16 +1,11 @@
package tui package tui
import ( import (
"context"
"errors"
"fmt" "fmt"
"strings"
"time"
tea "github.com/charmbracelet/bubbletea" tea "github.com/charmbracelet/bubbletea"
"github.com/charmbracelet/lipgloss" "github.com/charmbracelet/lipgloss"
"github.com/ollama/ollama/api" "github.com/ollama/ollama/cmd/launch"
"github.com/ollama/ollama/cmd/config"
"github.com/ollama/ollama/version" "github.com/ollama/ollama/version"
) )
@@ -45,30 +40,24 @@ var (
type menuItem struct { type menuItem struct {
title string title string
description string description string
integration string // integration name for loading model config, empty if not an integration integration string
isRunModel bool isRunModel bool
isOthers bool isOthers bool
} }
var mainMenuItems = []menuItem{ var mainMenuItems = []menuItem{
{ {
title: "Run a model", title: "Chat with a model",
description: "Start an interactive chat with a model", description: "Start an interactive chat with a model",
isRunModel: true, isRunModel: true,
}, },
{ {
title: "Launch Claude Code",
description: "Agentic coding across large codebases",
integration: "claude", integration: "claude",
}, },
{ {
title: "Launch Codex",
description: "OpenAI's open-source coding agent",
integration: "codex", integration: "codex",
}, },
{ {
title: "Launch OpenClaw",
description: "Personal AI with 100+ skills",
integration: "openclaw", integration: "openclaw",
}, },
} }
@@ -79,277 +68,106 @@ var othersMenuItem = menuItem{
isOthers: true, isOthers: true,
} }
// getOtherIntegrations dynamically builds the "Others" list from the integration type model struct {
// registry, excluding any integrations already present in the pinned mainMenuItems. state *launch.LauncherState
func getOtherIntegrations() []menuItem { items []menuItem
pinned := map[string]bool{ cursor int
"run": true, // not an integration but in the pinned list showOthers bool
width int
quitting bool
selected bool
action TUIAction
}
func newModel(state *launch.LauncherState) model {
m := model{
state: state,
} }
m.showOthers = shouldExpandOthers(state)
m.items = buildMenuItems(state, m.showOthers)
m.cursor = initialCursor(state, m.items)
return m
}
func shouldExpandOthers(state *launch.LauncherState) bool {
if state == nil {
return false
}
for _, item := range otherIntegrationItems(state) {
if item.integration == state.LastSelection {
return true
}
}
return false
}
func buildMenuItems(state *launch.LauncherState, showOthers bool) []menuItem {
items := make([]menuItem, 0, len(mainMenuItems)+1)
for _, item := range mainMenuItems { for _, item := range mainMenuItems {
if item.integration != "" { if item.integration == "" {
pinned[item.integration] = true items = append(items, item)
continue
}
if integrationState, ok := state.Integrations[item.integration]; ok {
items = append(items, integrationMenuItem(integrationState))
} }
} }
var others []menuItem if showOthers {
for _, info := range config.ListIntegrationInfos() { items = append(items, otherIntegrationItems(state)...)
} else {
items = append(items, othersMenuItem)
}
return items
}
func integrationMenuItem(state launch.LauncherIntegrationState) menuItem {
description := state.Description
if description == "" {
description = "Open " + state.DisplayName + " integration"
}
return menuItem{
title: "Launch " + state.DisplayName,
description: description,
integration: state.Name,
}
}
func otherIntegrationItems(state *launch.LauncherState) []menuItem {
pinned := map[string]bool{
"claude": true,
"codex": true,
"openclaw": true,
}
var items []menuItem
for _, info := range launch.ListIntegrationInfos() {
if pinned[info.Name] { if pinned[info.Name] {
continue continue
} }
desc := info.Description integrationState, ok := state.Integrations[info.Name]
if desc == "" { if !ok {
desc = "Open " + info.DisplayName + " integration"
}
others = append(others, menuItem{
title: "Launch " + info.DisplayName,
description: desc,
integration: info.Name,
})
}
return others
}
type model struct {
items []menuItem
cursor int
quitting bool
selected bool
changeModel bool
changeModels []string // multi-select result for Editor integrations
showOthers bool
availableModels map[string]bool
err error
showingModal bool
modalSelector selectorModel
modalItems []SelectItem
showingMultiModal bool
multiModalSelector multiSelectorModel
showingSignIn bool
signInURL string
signInModel string
signInSpinner int
signInFromModal bool // true if sign-in was triggered from modal (not main menu)
width int // terminal width from WindowSizeMsg
statusMsg string // temporary status message shown near help text
}
type signInTickMsg struct{}
type signInCheckMsg struct {
signedIn bool
userName string
}
type clearStatusMsg struct{}
func (m *model) modelExists(name string) bool {
if m.availableModels == nil || name == "" {
return false
}
if m.availableModels[name] {
return true
}
// Check for prefix match (e.g., "llama2" matches "llama2:latest")
for modelName := range m.availableModels {
if strings.HasPrefix(modelName, name+":") {
return true
}
}
return false
}
func (m *model) buildModalItems() []SelectItem {
modelItems, _ := config.GetModelItems(context.Background())
return ReorderItems(ConvertItems(modelItems))
}
func (m *model) openModelModal(currentModel string) {
m.modalItems = m.buildModalItems()
cursor := 0
if currentModel != "" {
for i, item := range m.modalItems {
if item.Name == currentModel || strings.HasPrefix(item.Name, currentModel+":") || strings.HasPrefix(currentModel, item.Name+":") {
cursor = i
break
}
}
}
m.modalSelector = selectorModel{
title: "Select model:",
items: m.modalItems,
cursor: cursor,
helpText: "↑/↓ navigate • enter select • ← back",
}
m.modalSelector.updateScroll(m.modalSelector.otherStart())
m.showingModal = true
}
func (m *model) openMultiModelModal(integration string) {
items := m.buildModalItems()
var preChecked []string
if models := config.IntegrationModels(integration); len(models) > 0 {
preChecked = models
}
m.multiModalSelector = newMultiSelectorModel("Select models:", items, preChecked)
// Set cursor to the first pre-checked (last used) model
if len(preChecked) > 0 {
for i, item := range items {
if item.Name == preChecked[0] {
m.multiModalSelector.cursor = i
m.multiModalSelector.updateScroll(m.multiModalSelector.otherStart())
break
}
}
}
m.showingMultiModal = true
}
func isCloudModel(name string) bool {
return strings.HasSuffix(name, ":cloud") || strings.HasSuffix(name, "-cloud")
}
func cloudStatusDisabled(client *api.Client) bool {
status, err := client.CloudStatusExperimental(context.Background())
if err != nil {
return false
}
return status.Cloud.Disabled
}
func cloudModelDisabled(name string) bool {
if !isCloudModel(name) {
return false
}
client, err := api.ClientFromEnvironment()
if err != nil {
return false
}
return cloudStatusDisabled(client)
}
// checkCloudSignIn checks if a cloud model needs sign-in.
// Returns a command to start sign-in if needed, or nil if already signed in.
func (m *model) checkCloudSignIn(modelName string, fromModal bool) tea.Cmd {
if modelName == "" || !isCloudModel(modelName) {
return nil
}
client, err := api.ClientFromEnvironment()
if err != nil {
return nil
}
if cloudStatusDisabled(client) {
return nil
}
user, err := client.Whoami(context.Background())
if err == nil && user != nil && user.Name != "" {
return nil
}
var aErr api.AuthorizationError
if errors.As(err, &aErr) && aErr.SigninURL != "" {
return m.startSignIn(modelName, aErr.SigninURL, fromModal)
}
return nil
}
// startSignIn initiates the sign-in flow for a cloud model.
// fromModal indicates if this was triggered from the model picker modal.
func (m *model) startSignIn(modelName, signInURL string, fromModal bool) tea.Cmd {
m.showingModal = false
m.showingSignIn = true
m.signInURL = signInURL
m.signInModel = modelName
m.signInSpinner = 0
m.signInFromModal = fromModal
config.OpenBrowser(signInURL)
return tea.Tick(200*time.Millisecond, func(t time.Time) tea.Msg {
return signInTickMsg{}
})
}
func checkSignIn() tea.Msg {
client, err := api.ClientFromEnvironment()
if err != nil {
return signInCheckMsg{signedIn: false}
}
user, err := client.Whoami(context.Background())
if err == nil && user != nil && user.Name != "" {
return signInCheckMsg{signedIn: true, userName: user.Name}
}
return signInCheckMsg{signedIn: false}
}
func (m *model) loadAvailableModels() {
m.availableModels = make(map[string]bool)
client, err := api.ClientFromEnvironment()
if err != nil {
return
}
models, err := client.List(context.Background())
if err != nil {
return
}
cloudDisabled := cloudStatusDisabled(client)
for _, mdl := range models.Models {
if cloudDisabled && mdl.RemoteModel != "" {
continue continue
} }
m.availableModels[mdl.Name] = true items = append(items, integrationMenuItem(integrationState))
} }
return items
} }
func (m *model) buildItems() { func initialCursor(state *launch.LauncherState, items []menuItem) int {
others := getOtherIntegrations() if state == nil || state.LastSelection == "" {
m.items = make([]menuItem, 0, len(mainMenuItems)+1+len(others)) return 0
m.items = append(m.items, mainMenuItems...)
if m.showOthers {
m.items = append(m.items, others...)
} else {
m.items = append(m.items, othersMenuItem)
} }
} for i, item := range items {
if state.LastSelection == "run" && item.isRunModel {
func isOthersIntegration(name string) bool { return i
for _, item := range getOtherIntegrations() { }
if item.integration == name { if item.integration == state.LastSelection {
return true return i
} }
} }
return false return 0
}
func initialModel() model {
m := model{
cursor: 0,
}
m.loadAvailableModels()
lastSelection := config.LastSelection()
if isOthersIntegration(lastSelection) {
m.showOthers = true
}
m.buildItems()
if lastSelection != "" {
for i, item := range m.items {
if lastSelection == "run" && item.isRunModel {
m.cursor = i
break
} else if item.integration == lastSelection {
m.cursor = i
break
}
}
}
return m
} }
func (m model) Init() tea.Cmd { func (m model) Init() tea.Cmd {
@@ -357,143 +175,11 @@ func (m model) Init() tea.Cmd {
} }
func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) { func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
if wmsg, ok := msg.(tea.WindowSizeMsg); ok {
wasSet := m.width > 0
m.width = wmsg.Width
if wasSet {
return m, tea.EnterAltScreen
}
return m, nil
}
if _, ok := msg.(clearStatusMsg); ok {
m.statusMsg = ""
return m, nil
}
if m.showingSignIn {
switch msg := msg.(type) {
case tea.KeyMsg:
switch msg.Type {
case tea.KeyCtrlC, tea.KeyEsc:
m.showingSignIn = false
if m.signInFromModal {
m.showingModal = true
}
return m, nil
}
case signInTickMsg:
m.signInSpinner++
// Check sign-in status every 5th tick (~1 second)
if m.signInSpinner%5 == 0 {
return m, tea.Batch(
tea.Tick(200*time.Millisecond, func(t time.Time) tea.Msg {
return signInTickMsg{}
}),
checkSignIn,
)
}
return m, tea.Tick(200*time.Millisecond, func(t time.Time) tea.Msg {
return signInTickMsg{}
})
case signInCheckMsg:
if msg.signedIn {
if m.signInFromModal {
m.modalSelector.selected = m.signInModel
m.changeModel = true
} else {
m.selected = true
}
m.quitting = true
return m, tea.Quit
}
}
return m, nil
}
if m.showingMultiModal {
switch msg := msg.(type) {
case tea.KeyMsg:
if msg.Type == tea.KeyLeft {
m.showingMultiModal = false
return m, nil
}
updated, cmd := m.multiModalSelector.Update(msg)
m.multiModalSelector = updated.(multiSelectorModel)
if m.multiModalSelector.cancelled {
m.showingMultiModal = false
return m, nil
}
if m.multiModalSelector.confirmed {
var selected []string
if m.multiModalSelector.singleAdd != "" {
// Single-add mode: prepend picked model, keep existing deduped
selected = []string{m.multiModalSelector.singleAdd}
for _, name := range config.IntegrationModels(m.items[m.cursor].integration) {
if name != m.multiModalSelector.singleAdd {
selected = append(selected, name)
}
}
} else {
// Last checked is default (first in result)
co := m.multiModalSelector.checkOrder
last := co[len(co)-1]
selected = []string{m.multiModalSelector.items[last].Name}
for _, idx := range co {
if idx != last {
selected = append(selected, m.multiModalSelector.items[idx].Name)
}
}
}
if len(selected) > 0 {
m.changeModels = selected
m.changeModel = true
m.quitting = true
return m, tea.Quit
}
m.multiModalSelector.confirmed = false
return m, nil
}
return m, cmd
}
return m, nil
}
if m.showingModal {
switch msg := msg.(type) {
case tea.KeyMsg:
switch msg.Type {
case tea.KeyCtrlC, tea.KeyEsc, tea.KeyLeft:
m.showingModal = false
return m, nil
case tea.KeyEnter:
filtered := m.modalSelector.filteredItems()
if len(filtered) > 0 && m.modalSelector.cursor < len(filtered) {
m.modalSelector.selected = filtered[m.modalSelector.cursor].Name
}
if m.modalSelector.selected != "" {
if cmd := m.checkCloudSignIn(m.modalSelector.selected, true); cmd != nil {
return m, cmd
}
m.changeModel = true
m.quitting = true
return m, tea.Quit
}
return m, nil
default:
// Delegate navigation (up/down/pgup/pgdown/filter/backspace) to selectorModel
m.modalSelector.updateNavigation(msg)
}
}
return m, nil
}
switch msg := msg.(type) { switch msg := msg.(type) {
case tea.WindowSizeMsg:
m.width = msg.Width
return m, nil
case tea.KeyMsg: case tea.KeyMsg:
switch msg.String() { switch msg.String() {
case "ctrl+c", "q", "esc": case "ctrl+c", "q", "esc":
@@ -504,162 +190,78 @@ func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
if m.cursor > 0 { if m.cursor > 0 {
m.cursor-- m.cursor--
} }
// Auto-collapse "Others" when cursor moves back into pinned items
if m.showOthers && m.cursor < len(mainMenuItems) { if m.showOthers && m.cursor < len(mainMenuItems) {
m.showOthers = false m.showOthers = false
m.buildItems() m.items = buildMenuItems(m.state, false)
m.cursor = min(m.cursor, len(m.items)-1)
} }
return m, nil
case "down", "j": case "down", "j":
if m.cursor < len(m.items)-1 { if m.cursor < len(m.items)-1 {
m.cursor++ m.cursor++
} }
// Auto-expand "Others..." when cursor lands on it
if m.cursor < len(m.items) && m.items[m.cursor].isOthers && !m.showOthers { if m.cursor < len(m.items) && m.items[m.cursor].isOthers && !m.showOthers {
m.showOthers = true m.showOthers = true
m.buildItems() m.items = buildMenuItems(m.state, true)
// cursor now points at the first "other" integration
} }
return m, nil
case "enter", " ": case "enter", " ":
item := m.items[m.cursor] if m.selectableItem(m.items[m.cursor]) {
m.selected = true
if item.integration != "" && !config.IsIntegrationInstalled(item.integration) && !config.AutoInstallable(item.integration) { m.action = actionForMenuItem(m.items[m.cursor], false)
return m, nil m.quitting = true
return m, tea.Quit
} }
return m, nil
var configuredModel string
if item.isRunModel {
configuredModel = config.LastModel()
} else if item.integration != "" {
configuredModel = config.IntegrationModel(item.integration)
}
if cmd := m.checkCloudSignIn(configuredModel, false); cmd != nil {
return m, cmd
}
if configuredModel != "" && isCloudModel(configuredModel) && cloudModelDisabled(configuredModel) {
if item.integration != "" && config.IsEditorIntegration(item.integration) {
m.openMultiModelModal(item.integration)
} else {
m.openModelModal(configuredModel)
}
return m, nil
}
m.selected = true
m.quitting = true
return m, tea.Quit
case "right", "l": case "right", "l":
item := m.items[m.cursor] item := m.items[m.cursor]
if item.integration != "" || item.isRunModel { if item.isRunModel || m.changeableItem(item) {
if item.integration != "" && !config.IsIntegrationInstalled(item.integration) { m.selected = true
if config.AutoInstallable(item.integration) { m.action = actionForMenuItem(item, true)
// Auto-installable: select to trigger install flow m.quitting = true
m.selected = true return m, tea.Quit
m.quitting = true
return m, tea.Quit
}
return m, nil
}
if item.integration != "" && config.IsEditorIntegration(item.integration) {
m.openMultiModelModal(item.integration)
} else {
var currentModel string
if item.isRunModel {
currentModel = config.LastModel()
} else if item.integration != "" {
currentModel = config.IntegrationModel(item.integration)
}
m.openModelModal(currentModel)
}
} }
return m, nil
} }
} }
return m, nil return m, nil
} }
func (m model) selectableItem(item menuItem) bool {
if item.isRunModel {
return true
}
if item.integration == "" || item.isOthers {
return false
}
state, ok := m.state.Integrations[item.integration]
return ok && state.Selectable
}
func (m model) changeableItem(item menuItem) bool {
if item.integration == "" || item.isOthers {
return false
}
state, ok := m.state.Integrations[item.integration]
return ok && state.Changeable
}
func (m model) View() string { func (m model) View() string {
if m.quitting { if m.quitting {
return "" return ""
} }
if m.showingSignIn {
return m.renderSignInDialog()
}
if m.showingMultiModal {
return m.multiModalSelector.View()
}
if m.showingModal {
return m.renderModal()
}
s := selectorTitleStyle.Render("Ollama "+versionStyle.Render(version.Version)) + "\n\n" s := selectorTitleStyle.Render("Ollama "+versionStyle.Render(version.Version)) + "\n\n"
for i, item := range m.items { for i, item := range m.items {
cursor := "" s += m.renderMenuItem(i, item)
style := menuItemStyle
isInstalled := true
if item.integration != "" {
isInstalled = config.IsIntegrationInstalled(item.integration)
}
if m.cursor == i {
cursor = "▸ "
if isInstalled {
style = menuSelectedItemStyle
} else {
style = greyedSelectedStyle
}
} else if !isInstalled && item.integration != "" {
style = greyedStyle
}
title := item.title
var modelSuffix string
if item.integration != "" {
if !isInstalled {
if config.AutoInstallable(item.integration) {
title += " " + notInstalledStyle.Render("(install)")
} else {
title += " " + notInstalledStyle.Render("(not installed)")
}
} else if m.cursor == i {
if mdl := config.IntegrationModel(item.integration); mdl != "" && m.modelExists(mdl) {
modelSuffix = " " + modelStyle.Render("("+mdl+")")
}
}
} else if item.isRunModel && m.cursor == i {
if mdl := config.LastModel(); mdl != "" && m.modelExists(mdl) {
modelSuffix = " " + modelStyle.Render("("+mdl+")")
}
}
s += style.Render(cursor+title) + modelSuffix + "\n"
desc := item.description
if !isInstalled && item.integration != "" && m.cursor == i {
if config.AutoInstallable(item.integration) {
desc = "Press enter to install"
} else if hint := config.IntegrationInstallHint(item.integration); hint != "" {
desc = hint
} else {
desc = "not installed"
}
}
s += menuDescStyle.Render(desc) + "\n\n"
} }
if m.statusMsg != "" { s += "\n" + selectorHelpStyle.Render("↑/↓ navigate • enter launch • → configure • esc quit")
s += "\n" + lipgloss.NewStyle().Foreground(lipgloss.AdaptiveColor{Light: "124", Dark: "210"}).Render(m.statusMsg) + "\n"
}
s += "\n" + selectorHelpStyle.Render("↑/↓ navigate • enter launch • → change model • esc quit")
if m.width > 0 { if m.width > 0 {
return lipgloss.NewStyle().MaxWidth(m.width).Render(s) return lipgloss.NewStyle().MaxWidth(m.width).Render(s)
@@ -667,80 +269,125 @@ func (m model) View() string {
return s return s
} }
func (m model) renderModal() string { func (m model) renderMenuItem(index int, item menuItem) string {
modalStyle := lipgloss.NewStyle(). cursor := ""
PaddingBottom(1). style := menuItemStyle
PaddingRight(2) title := item.title
description := item.description
modelSuffix := ""
s := modalStyle.Render(m.modalSelector.renderContent()) if m.cursor == index {
if m.width > 0 { cursor = "▸ "
return lipgloss.NewStyle().MaxWidth(m.width).Render(s)
}
return s
}
func (m model) renderSignInDialog() string {
return renderSignIn(m.signInModel, m.signInURL, m.signInSpinner, m.width)
}
type Selection int
const (
SelectionNone Selection = iota
SelectionRunModel
SelectionChangeRunModel
SelectionIntegration // Generic integration selection
SelectionChangeIntegration // Generic change model for integration
)
type Result struct {
Selection Selection
Integration string // integration name if applicable
Model string // model name if selected from single-select modal
Models []string // models selected from multi-select modal (Editor integrations)
}
func Run() (Result, error) {
m := initialModel()
p := tea.NewProgram(m)
finalModel, err := p.Run()
if err != nil {
return Result{Selection: SelectionNone}, fmt.Errorf("error running TUI: %w", err)
}
fm := finalModel.(model)
if fm.err != nil {
return Result{Selection: SelectionNone}, fm.err
}
if !fm.selected && !fm.changeModel {
return Result{Selection: SelectionNone}, nil
}
item := fm.items[fm.cursor]
if fm.changeModel {
if item.isRunModel {
return Result{
Selection: SelectionChangeRunModel,
Model: fm.modalSelector.selected,
}, nil
}
return Result{
Selection: SelectionChangeIntegration,
Integration: item.integration,
Model: fm.modalSelector.selected,
Models: fm.changeModels,
}, nil
} }
if item.isRunModel { if item.isRunModel {
return Result{Selection: SelectionRunModel}, nil if m.cursor == index && m.state.RunModel != "" {
modelSuffix = " " + modelStyle.Render("("+m.state.RunModel+")")
}
if m.cursor == index {
style = menuSelectedItemStyle
}
} else if item.isOthers {
if m.cursor == index {
style = menuSelectedItemStyle
}
} else {
integrationState := m.state.Integrations[item.integration]
if !integrationState.Selectable {
if m.cursor == index {
style = greyedSelectedStyle
} else {
style = greyedStyle
}
} else if m.cursor == index {
style = menuSelectedItemStyle
}
if m.cursor == index && integrationState.CurrentModel != "" {
modelSuffix = " " + modelStyle.Render("("+integrationState.CurrentModel+")")
}
if !integrationState.Installed {
if integrationState.AutoInstallable {
title += " " + notInstalledStyle.Render("(install)")
} else {
title += " " + notInstalledStyle.Render("(not installed)")
}
if m.cursor == index {
if integrationState.AutoInstallable {
description = "Press enter to install"
} else if integrationState.InstallHint != "" {
description = integrationState.InstallHint
} else {
description = "not installed"
}
}
}
} }
return Result{ return style.Render(cursor+title) + modelSuffix + "\n" + menuDescStyle.Render(description) + "\n\n"
Selection: SelectionIntegration, }
Integration: item.integration,
}, nil type TUIActionKind int
const (
TUIActionNone TUIActionKind = iota
TUIActionRunModel
TUIActionLaunchIntegration
)
type TUIAction struct {
Kind TUIActionKind
Integration string
ForceConfigure bool
}
func (a TUIAction) LastSelection() string {
switch a.Kind {
case TUIActionRunModel:
return "run"
case TUIActionLaunchIntegration:
return a.Integration
default:
return ""
}
}
func (a TUIAction) RunModelRequest() launch.RunModelRequest {
return launch.RunModelRequest{ForcePicker: a.ForceConfigure}
}
func (a TUIAction) IntegrationLaunchRequest() launch.IntegrationLaunchRequest {
return launch.IntegrationLaunchRequest{
Name: a.Integration,
ForceConfigure: a.ForceConfigure,
}
}
func actionForMenuItem(item menuItem, forceConfigure bool) TUIAction {
switch {
case item.isRunModel:
return TUIAction{Kind: TUIActionRunModel, ForceConfigure: forceConfigure}
case item.integration != "":
return TUIAction{Kind: TUIActionLaunchIntegration, Integration: item.integration, ForceConfigure: forceConfigure}
default:
return TUIAction{Kind: TUIActionNone}
}
}
func RunMenu(state *launch.LauncherState) (TUIAction, error) {
menu := newModel(state)
program := tea.NewProgram(menu)
finalModel, err := program.Run()
if err != nil {
return TUIAction{Kind: TUIActionNone}, fmt.Errorf("error running TUI: %w", err)
}
finalMenu := finalModel.(model)
if !finalMenu.selected {
return TUIAction{Kind: TUIActionNone}, nil
}
return finalMenu.action, nil
} }

178
cmd/tui/tui_test.go Normal file
View File

@@ -0,0 +1,178 @@
package tui
import (
"strings"
"testing"
tea "github.com/charmbracelet/bubbletea"
"github.com/ollama/ollama/cmd/launch"
)
func launcherTestState() *launch.LauncherState {
return &launch.LauncherState{
LastSelection: "run",
RunModel: "qwen3:8b",
Integrations: map[string]launch.LauncherIntegrationState{
"claude": {
Name: "claude",
DisplayName: "Claude Code",
Description: "Anthropic's coding tool with subagents",
Selectable: true,
Changeable: true,
CurrentModel: "glm-5:cloud",
},
"codex": {
Name: "codex",
DisplayName: "Codex",
Description: "OpenAI's open-source coding agent",
Selectable: true,
Changeable: true,
},
"openclaw": {
Name: "openclaw",
DisplayName: "OpenClaw",
Description: "Personal AI with 100+ skills",
Selectable: true,
Changeable: true,
AutoInstallable: true,
},
"droid": {
Name: "droid",
DisplayName: "Droid",
Description: "Factory's coding agent across terminal and IDEs",
Selectable: true,
Changeable: true,
},
"pi": {
Name: "pi",
DisplayName: "Pi",
Description: "Minimal AI agent toolkit with plugin support",
Selectable: true,
Changeable: true,
},
},
}
}
func TestMenuRendersPinnedItemsAndMore(t *testing.T) {
view := newModel(launcherTestState()).View()
for _, want := range []string{"Chat with a model", "Launch Claude Code", "Launch Codex", "Launch OpenClaw", "More..."} {
if !strings.Contains(view, want) {
t.Fatalf("expected menu view to contain %q\n%s", want, view)
}
}
}
func TestMenuExpandsOthersFromLastSelection(t *testing.T) {
state := launcherTestState()
state.LastSelection = "pi"
menu := newModel(state)
if !menu.showOthers {
t.Fatal("expected others section to expand when last selection is in the overflow list")
}
view := menu.View()
if !strings.Contains(view, "Launch Pi") {
t.Fatalf("expected expanded view to contain overflow integration\n%s", view)
}
if strings.Contains(view, "More...") {
t.Fatalf("expected expanded view to replace More... item\n%s", view)
}
}
func TestMenuEnterOnRunSelectsRun(t *testing.T) {
menu := newModel(launcherTestState())
updated, _ := menu.Update(tea.KeyMsg{Type: tea.KeyEnter})
got := updated.(model)
want := TUIAction{Kind: TUIActionRunModel}
if !got.selected || got.action != want {
t.Fatalf("expected enter on run to select run action, got selected=%v action=%v", got.selected, got.action)
}
}
func TestMenuRightOnRunSelectsChangeRun(t *testing.T) {
menu := newModel(launcherTestState())
updated, _ := menu.Update(tea.KeyMsg{Type: tea.KeyRight})
got := updated.(model)
want := TUIAction{Kind: TUIActionRunModel, ForceConfigure: true}
if !got.selected || got.action != want {
t.Fatalf("expected right on run to select change-run action, got selected=%v action=%v", got.selected, got.action)
}
}
func TestMenuEnterOnIntegrationSelectsLaunch(t *testing.T) {
menu := newModel(launcherTestState())
menu.cursor = 1
updated, _ := menu.Update(tea.KeyMsg{Type: tea.KeyEnter})
got := updated.(model)
want := TUIAction{Kind: TUIActionLaunchIntegration, Integration: "claude"}
if !got.selected || got.action != want {
t.Fatalf("expected enter on integration to launch, got selected=%v action=%v", got.selected, got.action)
}
}
func TestMenuRightOnIntegrationSelectsConfigure(t *testing.T) {
menu := newModel(launcherTestState())
menu.cursor = 1
updated, _ := menu.Update(tea.KeyMsg{Type: tea.KeyRight})
got := updated.(model)
want := TUIAction{Kind: TUIActionLaunchIntegration, Integration: "claude", ForceConfigure: true}
if !got.selected || got.action != want {
t.Fatalf("expected right on integration to configure, got selected=%v action=%v", got.selected, got.action)
}
}
func TestMenuIgnoresDisabledActions(t *testing.T) {
state := launcherTestState()
claude := state.Integrations["claude"]
claude.Selectable = false
claude.Changeable = false
state.Integrations["claude"] = claude
menu := newModel(state)
menu.cursor = 1
updatedEnter, _ := menu.Update(tea.KeyMsg{Type: tea.KeyEnter})
if updatedEnter.(model).selected {
t.Fatal("expected non-selectable integration to ignore enter")
}
updatedRight, _ := menu.Update(tea.KeyMsg{Type: tea.KeyRight})
if updatedRight.(model).selected {
t.Fatal("expected non-changeable integration to ignore right")
}
}
func TestMenuShowsCurrentModelSuffixes(t *testing.T) {
menu := newModel(launcherTestState())
runView := menu.View()
if !strings.Contains(runView, "(qwen3:8b)") {
t.Fatalf("expected run row to show current model suffix\n%s", runView)
}
menu.cursor = 1
integrationView := menu.View()
if !strings.Contains(integrationView, "(glm-5:cloud)") {
t.Fatalf("expected integration row to show current model suffix\n%s", integrationView)
}
}
func TestMenuShowsInstallStatusAndHint(t *testing.T) {
state := launcherTestState()
codex := state.Integrations["codex"]
codex.Installed = false
codex.Selectable = false
codex.Changeable = false
codex.InstallHint = "Install from https://example.com/codex"
state.Integrations["codex"] = codex
menu := newModel(state)
menu.cursor = 2
view := menu.View()
if !strings.Contains(view, "(not installed)") {
t.Fatalf("expected not-installed marker\n%s", view)
}
if !strings.Contains(view, codex.InstallHint) {
t.Fatalf("expected install hint in description\n%s", view)
}
}

View File

@@ -290,6 +290,8 @@ func LoadModelMetadata(fsys fs.FS) (ModelKV, *Tokenizer, error) {
conv = &gemma3Model{Architecture: p.Architectures[0]} conv = &gemma3Model{Architecture: p.Architectures[0]}
case "Gemma3nForConditionalGeneration": case "Gemma3nForConditionalGeneration":
conv = &gemma3nModel{} conv = &gemma3nModel{}
case "Gemma4ForCausalLM", "Gemma4ForConditionalGeneration":
conv = &gemma4Model{Architecture: p.Architectures[0]}
case "Phi3ForCausalLM": case "Phi3ForCausalLM":
conv = &phi3Model{} conv = &phi3Model{}
case "Qwen2ForCausalLM": case "Qwen2ForCausalLM":

574
convert/convert_gemma4.go Normal file
View File

@@ -0,0 +1,574 @@
package convert
import (
"bytes"
"encoding/binary"
"fmt"
"math"
"slices"
"strings"
"github.com/ollama/ollama/fs/ggml"
)
type gemma4Model struct {
gemmaModel
Architecture string
TextModel struct {
HiddenSize uint32 `json:"hidden_size"`
NumHiddenLayers uint32 `json:"num_hidden_layers"`
IntermediateSize uint32 `json:"intermediate_size"`
NumAttentionHeads uint32 `json:"num_attention_heads"`
NumKeyValueHeads uint32 `json:"num_key_value_heads"`
HeadDim uint32 `json:"head_dim"`
GlobalHeadDim uint32 `json:"global_head_dim"`
VocabSize uint32 `json:"vocab_size"`
RMSNormEps float32 `json:"rms_norm_eps"`
MaxPositionEmbeddings uint32 `json:"max_position_embeddings"`
SlidingWindow uint32 `json:"sliding_window"`
SlidingWindowPattern *int32 `json:"_sliding_window_pattern"`
LayerTypes []string `json:"layer_types"`
FinalLogitSoftcapping float32 `json:"final_logit_softcapping"`
EnableMoeBlock bool `json:"enable_moe_block"`
NumExperts *uint32 `json:"num_experts"`
TopKExperts *uint32 `json:"top_k_experts"`
ExpertIntermediateSize *uint32 `json:"moe_intermediate_size"`
HiddenSizePerLayerInput *uint32 `json:"hidden_size_per_layer_input"`
NumKVSharedLayers uint32 `json:"num_kv_shared_layers"`
AttentionKEqV bool `json:"attention_k_eq_v"`
NumGlobalKeyValueHeads *uint32 `json:"num_global_key_value_heads"`
QueryPreAttnScalar *uint32 `json:"query_pre_attn_scalar"`
UseDoubleWideMLP bool `json:"use_double_wide_mlp"`
RopeParameters map[string]*struct {
RopeTheta float32 `json:"rope_theta"`
PartialRotaryFactor *float32 `json:"partial_rotary_factor"`
} `json:"rope_parameters"`
} `json:"text_config"`
VisionModel struct {
HiddenSize uint32 `json:"hidden_size"`
NumHiddenLayers uint32 `json:"num_hidden_layers"`
NumAttentionHeads uint32 `json:"num_attention_heads"`
IntermediateSize uint32 `json:"intermediate_size"`
PatchSize uint32 `json:"patch_size"`
NumChannels uint32 `json:"num_channels"`
PoolingKernelSize uint32 `json:"pooling_kernel_size"`
LayerNormEps float32 `json:"layer_norm_eps"`
} `json:"vision_config"`
AudioModel *struct {
HiddenSize uint32 `json:"hidden_size"`
OutputProjDims uint32 `json:"output_proj_dims"`
NumHiddenLayers uint32 `json:"num_hidden_layers"`
NumAttentionHeads uint32 `json:"num_attention_heads"`
ConvKernelSize uint32 `json:"conv_kernel_size"`
RMSNormEps float32 `json:"rms_norm_eps"`
} `json:"audio_config"`
}
func (p *gemma4Model) KV(t *Tokenizer) KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "gemma4"
kv["tokenizer.ggml.model"] = "llama"
kv["tokenizer.ggml.pre"] = "gemma4"
tc := p.TextModel
kv["gemma4.block_count"] = tc.NumHiddenLayers
kv["gemma4.embedding_length"] = tc.HiddenSize
// Per-layer FFN width: when use_double_wide_mlp is set, KV-shared layers get 2x FFN width.
if tc.UseDoubleWideMLP && tc.NumKVSharedLayers > 0 {
firstShared := int(tc.NumHiddenLayers) - int(tc.NumKVSharedLayers)
ffnWidths := make([]int32, tc.NumHiddenLayers)
for i := range ffnWidths {
if i >= firstShared {
ffnWidths[i] = int32(tc.IntermediateSize * 2)
} else {
ffnWidths[i] = int32(tc.IntermediateSize)
}
}
kv["gemma4.feed_forward_length"] = ffnWidths
} else {
kv["gemma4.feed_forward_length"] = tc.IntermediateSize
}
kv["gemma4.context_length"] = tc.MaxPositionEmbeddings
kv["gemma4.attention.head_count"] = tc.NumAttentionHeads
// Per-layer KV head count array: SWA layers use NumKeyValueHeads, global layers use NumGlobalKeyValueHeads
if tc.NumGlobalKeyValueHeads != nil && *tc.NumGlobalKeyValueHeads != tc.NumKeyValueHeads && len(tc.LayerTypes) > 0 {
kvHeads := make([]int32, len(tc.LayerTypes))
for i, lt := range tc.LayerTypes {
if lt == "sliding_attention" {
kvHeads[i] = int32(tc.NumKeyValueHeads)
} else {
kvHeads[i] = int32(*tc.NumGlobalKeyValueHeads)
}
}
kv["gemma4.attention.head_count_kv"] = kvHeads
} else {
kv["gemma4.attention.head_count_kv"] = tc.NumKeyValueHeads
}
// key_length = global head dim, key_length_swa = local (SWA) head dim
kv["gemma4.attention.key_length"] = tc.GlobalHeadDim
kv["gemma4.attention.value_length"] = tc.GlobalHeadDim
kv["gemma4.attention.key_length_swa"] = tc.HeadDim
kv["gemma4.attention.value_length_swa"] = tc.HeadDim
kv["gemma4.attention.layer_norm_rms_epsilon"] = tc.RMSNormEps
kv["gemma4.attention.sliding_window"] = tc.SlidingWindow
// Sliding window pattern from layer_types
if len(tc.LayerTypes) > 0 {
kv["gemma4.attention.sliding_window_pattern"] = slices.Collect(func(yield func(bool) bool) {
for _, lt := range tc.LayerTypes {
if !yield(lt == "sliding_attention") {
break
}
}
})
}
kv["gemma4.attention.shared_kv_layers"] = tc.NumKVSharedLayers
// RoPE: dimension_count is the full global head dim (freq_factors handle partial rotation)
if rp, ok := tc.RopeParameters["full_attention"]; ok && rp != nil {
kv["gemma4.rope.freq_base"] = rp.RopeTheta
kv["gemma4.rope.dimension_count"] = tc.GlobalHeadDim
}
if rp, ok := tc.RopeParameters["sliding_attention"]; ok && rp != nil {
kv["gemma4.rope.freq_base_swa"] = rp.RopeTheta
kv["gemma4.rope.dimension_count_swa"] = tc.HeadDim
}
if tc.FinalLogitSoftcapping > 0 {
kv["gemma4.final_logit_softcapping"] = tc.FinalLogitSoftcapping
}
// MoE
if tc.EnableMoeBlock && tc.NumExperts != nil {
kv["gemma4.expert_count"] = *tc.NumExperts
if tc.TopKExperts != nil {
kv["gemma4.expert_used_count"] = *tc.TopKExperts
}
if tc.ExpertIntermediateSize != nil {
kv["gemma4.expert_feed_forward_length"] = *tc.ExpertIntermediateSize
}
}
// PLE — always emit, even when 0
pleSize := uint32(0)
if tc.HiddenSizePerLayerInput != nil {
pleSize = *tc.HiddenSizePerLayerInput
}
kv["gemma4.embedding_length_per_layer_input"] = pleSize
// Vision model KV metadata
vc := p.VisionModel
if vc.NumHiddenLayers > 0 {
kv["gemma4.vision.block_count"] = vc.NumHiddenLayers
kv["gemma4.vision.embedding_length"] = vc.HiddenSize
kv["gemma4.vision.attention.head_count"] = vc.NumAttentionHeads
kv["gemma4.vision.feed_forward_length"] = vc.IntermediateSize
kv["gemma4.vision.patch_size"] = vc.PatchSize
numCh := vc.NumChannels
if numCh == 0 {
numCh = 3
}
kv["gemma4.vision.num_channels"] = numCh
nMerge := vc.PoolingKernelSize
if nMerge == 0 {
nMerge = 3
}
kv["gemma4.vision.projector.scale_factor"] = nMerge
eps := vc.LayerNormEps
if eps == 0 {
eps = 1e-6
}
kv["gemma4.vision.attention.layer_norm_epsilon"] = eps
}
// Audio model KV metadata
if p.AudioModel != nil && p.AudioModel.NumHiddenLayers > 0 {
ac := p.AudioModel
kv["gemma4.audio.block_count"] = ac.NumHiddenLayers
kv["gemma4.audio.embedding_length"] = ac.HiddenSize
kv["gemma4.audio.feed_forward_length"] = ac.HiddenSize * 4
kv["gemma4.audio.attention.head_count"] = ac.NumAttentionHeads
eps := ac.RMSNormEps
if eps == 0 {
eps = 1e-6
}
kv["gemma4.audio.attention.layer_norm_epsilon"] = eps
if ac.ConvKernelSize > 0 {
kv["gemma4.audio.conv_kernel_size"] = ac.ConvKernelSize
}
}
return kv
}
func (p *gemma4Model) Tensors(ts []Tensor) []*ggml.Tensor {
// First pass: collect vision clamp scalar values into a packed tensor.
// Layout: per vision layer (0..N-1), 7 linears (q,k,v,out,gate,up,down) × 4 values (inMin,inMax,outMin,outMax).
// Then 4 values for the projector (mm.input_projection).
clampSuffixes := []string{".input_min", ".input_max", ".output_min", ".output_max"}
clampMap := make(map[string]float32)
for _, t := range ts {
name := t.Name()
for _, sfx := range clampSuffixes {
if strings.HasSuffix(name, sfx) && (strings.Contains(name, "vision_tower") || strings.Contains(name, "embed_vision")) {
var buf bytes.Buffer
t.WriteTo(&buf)
data := buf.Bytes()
if len(data) >= 4 {
clampMap[name] = math.Float32frombits(uint32(data[0]) | uint32(data[1])<<8 | uint32(data[2])<<16 | uint32(data[3])<<24)
}
}
}
}
var out []*ggml.Tensor
for _, t := range ts {
name := t.Name()
// Skip embedding_post_projection_norm — used as weightless RMS norm in inference
if strings.Contains(name, "embedding_post_projection_norm") {
continue
}
// Vision tensor renaming: match published mmproj GGUF names
if strings.HasPrefix(name, "v.blk.") {
name = strings.Replace(name, ".attn_norm.", ".ln1.", 1)
name = strings.Replace(name, ".ffn_norm.", ".ln2.", 1)
name = strings.Replace(name, ".attn_output.", ".attn_out.", 1)
name = strings.Replace(name, ".post_attention_norm.", ".attn_post_norm.", 1)
name = strings.Replace(name, ".post_ffw_norm.", ".ffn_post_norm.", 1)
name = strings.Replace(name, ".layer_output_scale.", ".out_scale.", 1)
}
// per_dim_scale: apply softplus to weight data and add .weight suffix.
if strings.HasPrefix(name, "a.blk.") && strings.HasSuffix(name, "per_dim_scale") {
name = name + ".weight"
t.SetRepacker(softplusRepacker)
}
// Depthwise conv1d: squeeze middle dimension [C, 1, K] → [C, K].
if strings.HasPrefix(name, "a.blk.") && strings.Contains(name, "conv_dw") && strings.HasSuffix(name, ".weight") {
t.SetRepacker(squeezeMiddleDim)
}
shape := t.Shape()
// Convert scalar tensors (input_min/max, output_min/max) to 1D
if len(shape) == 0 {
shape = []uint64{1}
}
// Depthwise conv1d shape: safetensors [C, 1, K] → GGUF ne[K, C].
// Shape array here maps to GGUF ne[] directly, but safetensors reader
// stores shape in PyTorch order [C, 1, K] which the GGUF writer inverts.
// Published GGUF has ne[0]=K, ne[1]=C → shape array must be [K, C].
if strings.HasPrefix(name, "a.blk.") && strings.Contains(name, "conv_dw") && strings.HasSuffix(name, ".weight") && len(shape) == 3 {
shape = []uint64{shape[0], shape[2]}
}
// MoE expert weights: no transpose needed. Safetensors stores [experts, out, in]
// which the framework reverses to GGUF ne=[in, out, experts], matching ggml_mul_mat_id.
// (transposeExperts was incorrectly swapping dims — removed)
// Audio conv weights are forced to F32 via tensorBase.Kind() in reader.go
// (im2col doesn't support BF16). No kindOverride needed — the Kind() method
// controls both the GGUF header type AND the WriteTo data encoding path.
var kindOverride *uint32
// Vision patch embedding: reshape from [n_embd, ksize_sq_c] to [n_embd, 3, patch_size, patch_size]
// Must be stored as F16 (not BF16) because the Conv2D im2col kernel requires F16/F32.
if strings.Contains(name, "v.patch_embd.weight") && len(shape) == 2 {
nEmbd := shape[0]
patchSize := uint64(p.VisionModel.PatchSize)
if patchSize == 0 {
patchSize = 16
}
numCh := uint64(p.VisionModel.NumChannels)
if numCh == 0 {
numCh = 3
}
t.SetRepacker(p.reshapePatchEmbed)
shape = []uint64{nEmbd, numCh, patchSize, patchSize}
f16Kind := uint32(1) // tensorKindFP16
kindOverride = &f16Kind
}
// Vision position embedding: keep 3D [2, maxPos, nEmbd] — matching published mmproj format.
// The framework reverses shape to GGUF ne=[nEmbd, maxPos, 2]. No data repacking needed.
kind := t.Kind()
if kindOverride != nil {
kind = *kindOverride
}
out = append(out, &ggml.Tensor{
Name: name,
Kind: kind,
Shape: shape,
WriterTo: t,
})
}
// Generate a single global rope_freqs.weight for proportional RoPE on global attention layers.
// This matches the published GGUF format: one global tensor shared by all layers.
// Global layers use partial_rotary_factor (0.25) — only rotate that fraction of dims.
// Dimensions beyond the rotated portion get freq_factor=1e30 (effectively no rotation).
tc := p.TextModel
if tc.GlobalHeadDim > 0 {
globalFreqsSize := tc.GlobalHeadDim / 2 // freq_factors are per dimension pair
// Compute number of rotated pairs for global layers
partialRotaryFactor := float32(0.25) // default
if rp, ok := tc.RopeParameters["full_attention"]; ok && rp != nil && rp.PartialRotaryFactor != nil {
partialRotaryFactor = *rp.PartialRotaryFactor
}
nRotFull := int(float32(tc.GlobalHeadDim) * partialRotaryFactor / 2)
freqs := make(ropeFactor, globalFreqsSize)
for j := range freqs {
if j < nRotFull {
freqs[j] = 1.0
} else {
freqs[j] = 1e30 // effectively disable rotation
}
}
out = append(out, &ggml.Tensor{
Name: "rope_freqs.weight",
Kind: 0, // F32
Shape: []uint64{uint64(len(freqs))},
WriterTo: freqs,
})
}
// Emit packed vision clamp data as a single F32 tensor.
// Layout: numLayers × 7 linears (q,k,v,out,gate,up,down) × 4 floats (inMin,inMax,outMin,outMax)
// then 4 floats for the projector. Total = (numLayers*7 + 1) * 4 floats.
if len(clampMap) > 0 {
numLayers := int(p.VisionModel.NumHiddenLayers)
linearNames := []string{"attn_q", "attn_k", "attn_v", "attn_out", "ffn_gate", "ffn_up", "ffn_down"}
suffixes := []string{".input_min", ".input_max", ".output_min", ".output_max"}
totalFloats := (numLayers*len(linearNames) + 1) * 4 // +1 for projector
clampData := make([]float32, totalFloats)
for layer := range numLayers {
for li, ln := range linearNames {
for si, sfx := range suffixes {
sfxMap := map[string]string{"attn_q": "q_proj", "attn_k": "k_proj", "attn_v": "v_proj", "attn_out": "o_proj", "ffn_gate": "gate_proj", "ffn_up": "up_proj", "ffn_down": "down_proj"}
for origName, val := range clampMap {
if strings.Contains(origName, fmt.Sprintf("layers.%d.", layer)) && strings.HasSuffix(origName, sfx) && strings.Contains(origName, sfxMap[ln]) {
idx := (layer*len(linearNames)+li)*4 + si
clampData[idx] = val
break
}
}
}
}
}
// Projector clamp values
projIdx := numLayers * len(linearNames) * 4
for si, sfx := range suffixes {
for origName, val := range clampMap {
if strings.Contains(origName, "input_projection") && strings.HasSuffix(origName, sfx) {
clampData[projIdx+si] = val
break
}
}
}
var buf bytes.Buffer
binary.Write(&buf, binary.LittleEndian, clampData)
out = append(out, &ggml.Tensor{
Name: "v.clamp_data",
Kind: 0, // F32
Shape: []uint64{uint64(totalFloats)},
WriterTo: &buf,
})
}
return out
}
// reshapePatchEmbed reshapes the vision patch embedding from HF layout [n_embd, ksize*ksize*channels]
// to GGUF layout [n_embd, channels, patch_size, patch_size].
func (*gemma4Model) reshapePatchEmbed(_ string, data []float32, shape []uint64) ([]float32, error) {
if len(shape) != 2 {
return data, nil
}
nEmbd := int(shape[0])
ksqC := int(shape[1])
nChannels := 3
patchSize := int(math.Sqrt(float64(ksqC / nChannels)))
// HF layout: [n_embd, patch_size * patch_size * channels] (row-major)
// Need: [n_embd, channels, patch_size, patch_size]
result := make([]float32, len(data))
for e := range nEmbd {
for c := range nChannels {
for h := range patchSize {
for w := range patchSize {
srcIdx := e*ksqC + h*patchSize*nChannels + w*nChannels + c
dstIdx := e*nChannels*patchSize*patchSize + c*patchSize*patchSize + h*patchSize + w
result[dstIdx] = data[srcIdx]
}
}
}
}
shape[0] = uint64(nEmbd)
shape[1] = uint64(nChannels * patchSize * patchSize)
return result, nil
}
// softplusRepacker applies softplus (ln(1 + exp(x))) to tensor data.
// Used for per_dim_scale tensors which the published GGUF stores pre-activated.
func softplusRepacker(_ string, data []float32, shape []uint64) ([]float32, error) {
result := make([]float32, len(data))
for i, x := range data {
result[i] = float32(math.Log(1 + math.Exp(float64(x))))
}
return result, nil
}
// squeezeMiddleDim squeezes the middle dimension from [C, 1, K] → [C, K] for depthwise conv1d weights.
// Data layout stays the same since the middle dim is 1 — just a shape change.
func squeezeMiddleDim(_ string, data []float32, _ []uint64) ([]float32, error) {
return data, nil
}
func (p *gemma4Model) Replacements() []string {
return []string{
// ClippableLinear wraps nn.Linear — strip .linear. from weight path
".linear.weight", ".weight",
".linear.bias", ".bias",
// Audio SSCP (Sub-Sample Convolution Projection)
"model.audio_tower.subsample_conv_projection.conv_0.conv", "a.conv1d.0",
"model.audio_tower.subsample_conv_projection.conv_0.norm", "a.conv1d.0.norm",
"model.audio_tower.subsample_conv_projection.conv_1.conv", "a.conv1d.1",
"model.audio_tower.subsample_conv_projection.conv_1.norm", "a.conv1d.1.norm",
"model.audio_tower.subsample_conv_projection.layer0.conv", "a.conv1d.0",
"model.audio_tower.subsample_conv_projection.layer0.norm", "a.conv1d.0.norm",
"model.audio_tower.subsample_conv_projection.layer1.conv", "a.conv1d.1",
"model.audio_tower.subsample_conv_projection.layer1.norm", "a.conv1d.1.norm",
"model.audio_tower.subsample_conv_projection.input_proj_linear", "a.pre_encode.out",
// Audio conformer blocks
"model.audio_tower.conformer", "a.blk",
"model.audio_tower.layers", "a.blk",
// Audio conformer attention
"attention.attn.relative_position_embedding.pos_proj", "linear_pos",
"self_attn.relative_k_proj", "linear_pos",
"attention.attn.per_dim_key_scale", "per_dim_k_scale",
"attention.attn.per_dim_scale", "per_dim_scale",
"self_attn.per_dim_scale", "per_dim_scale",
"attention.attn.q_proj", "attn_q",
"attention.attn.k_proj", "attn_k",
"attention.attn.v_proj", "attn_v",
"attention.pre_attn_norm", "ln1",
"attention.post_norm", "ln2",
"attention.post", "attn_out",
"self_attn.post", "attn_out",
"norm_pre_attn", "ln1",
"norm_post_attn", "ln2",
// Audio conformer feedforward
"ffw_layer_start.pre_layer_norm", "ffn_norm",
"ffw_layer_start.post_layer_norm", "ffn_post_norm",
"ffw_layer_start.ffw_layer_1", "ffn_up",
"ffw_layer_start.ffw_layer_2", "ffn_down",
"ffw_layer_end.pre_layer_norm", "ffn_norm_1",
"ffw_layer_end.post_layer_norm", "ffn_post_norm_1",
"ffw_layer_end.ffw_layer_1", "ffn_up_1",
"ffw_layer_end.ffw_layer_2", "ffn_down_1",
"feed_forward1.pre_layer_norm", "ffn_norm",
"feed_forward1.post_layer_norm", "ffn_post_norm",
"feed_forward1.ffw_layer_1", "ffn_up",
"feed_forward1.ffw_layer_2", "ffn_down",
"feed_forward2.pre_layer_norm", "ffn_norm_1",
"feed_forward2.post_layer_norm", "ffn_post_norm_1",
"feed_forward2.ffw_layer_1", "ffn_up_1",
"feed_forward2.ffw_layer_2", "ffn_down_1",
// Audio conformer lightweight conv1d
"lconv1d.depthwise_conv1d", "conv_dw",
"lconv1d.pre_layer_norm", "conv_norm",
"lconv1d.conv_norm", "norm_conv",
"lconv1d.linear_start", "conv_pw1",
"lconv1d.linear_end", "conv_pw2",
// Audio block final norm
"norm_out", "layer_pre_norm",
// Audio embedder and output projection
"model.embed_audio.embedding_projection", "mm.a.input_projection",
"model.audio_tower.output_proj", "mm.a.fc",
// Vision encoder
"model.vision_tower.encoder.layers", "v.blk",
"model.vision_tower.patch_embedder.input_proj", "v.patch_embd",
"model.vision_tower.patch_embedder.position_embedding_table", "v.position_embd.weight",
"model.vision_tower.std_bias", "v.std_bias",
"model.vision_tower.std_scale", "v.std_scale",
// Vision multimodal projector
"model.embed_vision.embedding_projection", "mm.input_projection",
// Text model
"model.language_model.embed_tokens_per_layer", "per_layer_token_embd",
"model.language_model.embed_tokens", "token_embd",
"model.language_model.per_layer_model_projection", "per_layer_model_proj",
"model.language_model.per_layer_projection_norm", "per_layer_proj_norm",
"model.language_model.norm", "output_norm",
"model.language_model.layers", "blk",
// Shared attention replacements (work for both text and vision tensors)
"input_layernorm", "attn_norm",
"self_attn.q_proj", "attn_q",
"self_attn.q_norm", "attn_q_norm",
"self_attn.k_proj", "attn_k",
"self_attn.k_norm", "attn_k_norm",
"self_attn.v_proj", "attn_v",
"self_attn.o_proj", "attn_output",
"mlp.gate_proj", "ffn_gate",
"mlp.down_proj", "ffn_down",
"mlp.up_proj", "ffn_up",
// Post norms
"post_attention_layernorm", "post_attention_norm",
"pre_feedforward_layernorm_2", "pre_ffw_norm_2",
"pre_feedforward_layernorm", "ffn_norm",
"post_feedforward_layernorm_1", "post_ffw_norm_1",
"post_feedforward_layernorm_2", "post_ffw_norm_2",
"post_feedforward_layernorm", "post_ffw_norm",
// PLE
"per_layer_input_gate", "inp_gate",
"per_layer_projection", "proj",
"post_per_layer_input_norm", "post_norm",
// MoE
"router.proj", "ffn_gate_inp",
"router.scale", "ffn_gate_inp.scale",
"router.per_expert_scale.weight", "ffn_down_exps.scale",
"router.per_expert_scale", "ffn_down_exps.scale",
"experts.gate_up_proj.weight", "ffn_gate_up_exps.weight",
"experts.gate_up_proj", "ffn_gate_up_exps.weight",
"experts.down_proj.weight", "ffn_down_exps.weight",
"experts.down_proj", "ffn_down_exps.weight",
"moe.gate_proj", "ffn_gate_exps.weight",
"moe.up_proj", "ffn_up_exps.weight",
"moe.gate_up_proj.weight", "ffn_gate_up_exps.weight",
"moe.gate_up_proj", "ffn_gate_up_exps.weight",
"moe.down_proj", "ffn_down_exps.weight",
"moe.per_expert_scale.weight", "ffn_down_exps.scale",
"moe.per_expert_scale", "ffn_down_exps.scale",
// Layer scalar
"layer_scalar", "layer_output_scale.weight",
}
}

View File

@@ -0,0 +1,318 @@
package convert
import (
"strings"
"testing"
)
func TestGemma4AudioReplacements(t *testing.T) {
p := gemma4Model{}
r := strings.NewReplacer(p.Replacements()...)
tests := []struct {
name string
in string
want string
}{
// SSCP convolution blocks
{
"sscp conv0 weight",
"model.audio_tower.subsample_conv_projection.conv_0.conv.weight",
"a.conv1d.0.weight",
},
{
"sscp conv0 norm",
"model.audio_tower.subsample_conv_projection.conv_0.norm.weight",
"a.conv1d.0.norm.weight",
},
{
"sscp conv1 weight",
"model.audio_tower.subsample_conv_projection.conv_1.conv.weight",
"a.conv1d.1.weight",
},
{
"sscp input proj weight",
"model.audio_tower.subsample_conv_projection.input_proj_linear.weight",
"a.pre_encode.out.weight",
},
{
"sscp input proj bias",
"model.audio_tower.subsample_conv_projection.input_proj_linear.bias",
"a.pre_encode.out.bias",
},
{
"sscp layer0 conv weight (new naming)",
"model.audio_tower.subsample_conv_projection.layer0.conv.weight",
"a.conv1d.0.weight",
},
{
"sscp layer1 norm weight (new naming)",
"model.audio_tower.subsample_conv_projection.layer1.norm.weight",
"a.conv1d.1.norm.weight",
},
// Conformer attention
{
"attn q weight",
"model.audio_tower.conformer.0.attention.attn.q_proj.linear.weight",
"a.blk.0.attn_q.weight",
},
{
"attn k weight",
"model.audio_tower.conformer.5.attention.attn.k_proj.linear.weight",
"a.blk.5.attn_k.weight",
},
{
"attn v clamp input_min",
"model.audio_tower.conformer.0.attention.attn.v_proj.input_min",
"a.blk.0.attn_v.input_min",
},
{
"attn out weight (ClippableLinear)",
"model.audio_tower.conformer.0.attention.post.linear.weight",
"a.blk.0.attn_out.weight",
},
{
"attn out clamp output_max",
"model.audio_tower.conformer.0.attention.post.output_max",
"a.blk.0.attn_out.output_max",
},
{
"attn pre norm",
"model.audio_tower.conformer.0.attention.pre_attn_norm.weight",
"a.blk.0.ln1.weight",
},
{
"attn post norm",
"model.audio_tower.conformer.0.attention.post_norm.weight",
"a.blk.0.ln2.weight",
},
{
"linear pos",
"model.audio_tower.conformer.0.attention.attn.relative_position_embedding.pos_proj.weight",
"a.blk.0.linear_pos.weight",
},
{
"per dim scale",
"model.audio_tower.conformer.0.attention.attn.per_dim_scale",
"a.blk.0.per_dim_scale",
},
{
"per dim key scale",
"model.audio_tower.conformer.0.attention.attn.per_dim_key_scale",
"a.blk.0.per_dim_k_scale",
},
{
"attn relative k proj (new naming)",
"model.audio_tower.layers.0.self_attn.relative_k_proj.weight",
"a.blk.0.linear_pos.weight",
},
{
"attn pre norm (new naming)",
"model.audio_tower.layers.0.norm_pre_attn.weight",
"a.blk.0.ln1.weight",
},
{
"attn post norm (new naming)",
"model.audio_tower.layers.0.norm_post_attn.weight",
"a.blk.0.ln2.weight",
},
{
"attn out clamp output_max (new naming)",
"model.audio_tower.layers.0.self_attn.post.output_max",
"a.blk.0.attn_out.output_max",
},
{
"per dim scale (new naming)",
"model.audio_tower.layers.0.self_attn.per_dim_scale",
"a.blk.0.per_dim_scale",
},
// Conformer feedforward start
{
"ffn up weight",
"model.audio_tower.conformer.0.ffw_layer_start.ffw_layer_1.linear.weight",
"a.blk.0.ffn_up.weight",
},
{
"ffn down weight",
"model.audio_tower.conformer.0.ffw_layer_start.ffw_layer_2.linear.weight",
"a.blk.0.ffn_down.weight",
},
{
"ffn norm",
"model.audio_tower.conformer.0.ffw_layer_start.pre_layer_norm.weight",
"a.blk.0.ffn_norm.weight",
},
{
"ffn post norm",
"model.audio_tower.conformer.0.ffw_layer_start.post_layer_norm.weight",
"a.blk.0.ffn_post_norm.weight",
},
// Conformer feedforward end
{
"ffn up 1 weight",
"model.audio_tower.conformer.0.ffw_layer_end.ffw_layer_1.linear.weight",
"a.blk.0.ffn_up_1.weight",
},
{
"ffn down 1 weight",
"model.audio_tower.conformer.0.ffw_layer_end.ffw_layer_2.linear.weight",
"a.blk.0.ffn_down_1.weight",
},
{
"ffn norm 1",
"model.audio_tower.conformer.0.ffw_layer_end.pre_layer_norm.weight",
"a.blk.0.ffn_norm_1.weight",
},
{
"ffn post norm 1",
"model.audio_tower.conformer.0.ffw_layer_end.post_layer_norm.weight",
"a.blk.0.ffn_post_norm_1.weight",
},
{
"ffn up output_max (new naming)",
"model.audio_tower.layers.10.feed_forward1.ffw_layer_1.output_max",
"a.blk.10.ffn_up.output_max",
},
{
"ffn down output_min (new naming)",
"model.audio_tower.layers.0.feed_forward1.ffw_layer_2.output_min",
"a.blk.0.ffn_down.output_min",
},
{
"ffn up 1 input_max (new naming)",
"model.audio_tower.layers.0.feed_forward2.ffw_layer_1.input_max",
"a.blk.0.ffn_up_1.input_max",
},
{
"ffn norm 1 (new naming)",
"model.audio_tower.layers.0.feed_forward2.pre_layer_norm.weight",
"a.blk.0.ffn_norm_1.weight",
},
// Conformer lightweight conv1d
{
"conv dw weight",
"model.audio_tower.conformer.0.lconv1d.depthwise_conv1d.weight",
"a.blk.0.conv_dw.weight",
},
{
"conv norm (pre_layer_norm)",
"model.audio_tower.conformer.0.lconv1d.pre_layer_norm.weight",
"a.blk.0.conv_norm.weight",
},
{
"norm conv (conv_norm)",
"model.audio_tower.conformer.0.lconv1d.conv_norm.weight",
"a.blk.0.norm_conv.weight",
},
{
"conv pw1 weight",
"model.audio_tower.conformer.0.lconv1d.linear_start.linear.weight",
"a.blk.0.conv_pw1.weight",
},
{
"conv pw2 weight",
"model.audio_tower.conformer.0.lconv1d.linear_end.linear.weight",
"a.blk.0.conv_pw2.weight",
},
// Audio embedder
{
"audio embedder projection weight",
"model.embed_audio.embedding_projection.linear.weight",
"mm.a.input_projection.weight",
},
{
"audio embedder projection bias",
"model.embed_audio.embedding_projection.linear.bias",
"mm.a.input_projection.bias",
},
// Audio output projection
{
"audio output proj weight",
"model.audio_tower.output_proj.weight",
"mm.a.fc.weight",
},
{
"audio output proj bias",
"model.audio_tower.output_proj.bias",
"mm.a.fc.bias",
},
// Verify vision tensors still work
{
"vision q weight",
"model.vision_tower.encoder.layers.0.self_attn.q_proj.linear.weight",
"v.blk.0.attn_q.weight",
},
{
"vision std bias",
"model.vision_tower.std_bias",
"v.std_bias",
},
{
"vision std scale",
"model.vision_tower.std_scale",
"v.std_scale",
},
{
"vision patch embd",
"model.vision_tower.patch_embedder.input_proj.weight",
"v.patch_embd.weight",
},
{
"vision projector",
"model.embed_vision.embedding_projection.linear.weight",
"mm.input_projection.weight",
},
// Verify text tensors still work
{
"text attn q",
"model.language_model.layers.0.self_attn.q_proj.weight",
"blk.0.attn_q.weight",
},
{
"text token embd",
"model.language_model.embed_tokens.weight",
"token_embd.weight",
},
{
"text moe gate up fused",
"model.language_model.layers.0.experts.gate_up_proj",
"blk.0.ffn_gate_up_exps.weight",
},
{
"text moe down",
"model.language_model.layers.0.experts.down_proj",
"blk.0.ffn_down_exps.weight",
},
{
"text moe down with weight suffix",
"model.language_model.layers.0.experts.down_proj.weight",
"blk.0.ffn_down_exps.weight",
},
{
"text moe per expert scale",
"model.language_model.layers.0.router.per_expert_scale",
"blk.0.ffn_down_exps.scale",
},
{
"text moe per expert scale with weight suffix",
"model.language_model.layers.0.router.per_expert_scale.weight",
"blk.0.ffn_down_exps.scale",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := r.Replace(tt.in); got != tt.want {
t.Errorf("Replace(%q) = %q, want %q", tt.in, got, tt.want)
}
})
}
}

View File

@@ -205,8 +205,8 @@ func TestConvertInvalidDatatype(t *testing.T) {
generateSafetensorTestData(t, tempDir, td) generateSafetensorTestData(t, tempDir, td)
err = ConvertModel(os.DirFS(tempDir), f) err = ConvertModel(os.DirFS(tempDir), f)
if err == nil || err.Error() != "unsupported safetensors model" { if err == nil || !strings.Contains(err.Error(), "unknown data type") {
t.Errorf("expected error but didn't get one") t.Errorf("expected 'unknown data type' error but got: %v", err)
} }
} }

View File

@@ -42,8 +42,11 @@ func (t tensorBase) Kind() uint32 {
strings.HasSuffix(t.name, ".bias") || strings.HasSuffix(t.name, ".bias") ||
strings.HasSuffix(t.name, ".shortconv.conv.weight") || strings.HasSuffix(t.name, ".shortconv.conv.weight") ||
strings.HasSuffix(t.name, ".ssm_conv1d.weight") || // SSM conv kernel must be F32 for Metal strings.HasSuffix(t.name, ".ssm_conv1d.weight") || // SSM conv kernel must be F32 for Metal
strings.HasPrefix(t.name, "a.conv1d.") || // audio SSCP conv weights must be F32 for im2col
strings.Contains(t.name, ".conv_dw.") || // audio depthwise conv weights must be F32
t.name == "token_types.weight" || t.name == "token_types.weight" ||
t.name == "v.positional_embedding_vlm" || t.name == "v.positional_embedding_vlm" ||
t.name == "v.position_embd.weight" ||
t.name == "v.tile_position_embd.weight" || t.name == "v.tile_position_embd.weight" ||
t.name == "v.pre_tile_position_embd.weight" || t.name == "v.pre_tile_position_embd.weight" ||
t.name == "v.post_tile_position_embd.weight" || t.name == "v.post_tile_position_embd.weight" ||

View File

@@ -5,7 +5,6 @@ import (
"bytes" "bytes"
"encoding/binary" "encoding/binary"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"io" "io"
"io/fs" "io/fs"
@@ -53,9 +52,10 @@ func parseSafetensors(fsys fs.FS, replacer *strings.Replacer, ps ...string) ([]T
for _, key := range keys { for _, key := range keys {
if value := headers[key]; value.Type != "" { if value := headers[key]; value.Type != "" {
// bitsandbytes quantized models are unsupported // Scalar tensors (e.g. clipped linear min/max) are 0-dim in safetensors.
// Promote them to 1-dim so they can be stored in GGUF.
if len(value.Shape) == 0 { if len(value.Shape) == 0 {
return nil, errors.New("unsupported safetensors model") value.Shape = []uint64{1}
} }
ggufName := replacer.Replace(key) ggufName := replacer.Replace(key)
if _, ok := names[ggufName]; ok { if _, ok := names[ggufName]; ok {

View File

@@ -12,7 +12,6 @@ To use Ollama with tools that expect the Anthropic API (like Claude Code), set t
```shell ```shell
export ANTHROPIC_AUTH_TOKEN=ollama # required but ignored export ANTHROPIC_AUTH_TOKEN=ollama # required but ignored
export ANTHROPIC_API_KEY="" # required but ignored
export ANTHROPIC_BASE_URL=http://localhost:11434 export ANTHROPIC_BASE_URL=http://localhost:11434
``` ```
@@ -269,7 +268,7 @@ ollama launch claude --config
Set the environment variables and run Claude Code: Set the environment variables and run Claude Code:
```shell ```shell
ANTHROPIC_AUTH_TOKEN=ollama ANTHROPIC_BASE_URL=http://localhost:11434 ANTHROPIC_API_KEY="" claude --model qwen3-coder ANTHROPIC_AUTH_TOKEN=ollama ANTHROPIC_BASE_URL=http://localhost:11434 claude --model qwen3-coder
``` ```
Or set the environment variables in your shell profile: Or set the environment variables in your shell profile:
@@ -277,7 +276,6 @@ Or set the environment variables in your shell profile:
```shell ```shell
export ANTHROPIC_AUTH_TOKEN=ollama export ANTHROPIC_AUTH_TOKEN=ollama
export ANTHROPIC_BASE_URL=http://localhost:11434 export ANTHROPIC_BASE_URL=http://localhost:11434
export ANTHROPIC_API_KEY=""
``` ```
Then run Claude Code with any Ollama model: Then run Claude Code with any Ollama model:

View File

@@ -6,7 +6,7 @@ Ollama provides compatibility with parts of the [OpenAI API](https://platform.op
## Usage ## Usage
### Simple `v1/chat/completions` example ### Simple `/v1/chat/completions` example
<CodeGroup dropdown> <CodeGroup dropdown>
@@ -57,7 +57,7 @@ curl -X POST http://localhost:11434/v1/chat/completions \
</CodeGroup> </CodeGroup>
### Simple `v1/responses` example ### Simple `/v1/responses` example
<CodeGroup dropdown> <CodeGroup dropdown>
@@ -103,7 +103,7 @@ curl -X POST http://localhost:11434/v1/responses \
</CodeGroup> </CodeGroup>
### v1/chat/completions with vision example ### `/v1/chat/completions` with vision example
<CodeGroup dropdown> <CodeGroup dropdown>
@@ -184,6 +184,7 @@ curl -X POST http://localhost:11434/v1/chat/completions \
- [x] Reproducible outputs - [x] Reproducible outputs
- [x] Vision - [x] Vision
- [x] Tools - [x] Tools
- [x] Reasoning/thinking control (for thinking models)
- [ ] Logprobs - [ ] Logprobs
#### Supported request fields #### Supported request fields
@@ -207,6 +208,9 @@ curl -X POST http://localhost:11434/v1/chat/completions \
- [x] `top_p` - [x] `top_p`
- [x] `max_tokens` - [x] `max_tokens`
- [x] `tools` - [x] `tools`
- [x] `reasoning_effort` (`"high"`, `"medium"`, `"low"`, `"none"`)
- [x] `reasoning`
- [x] `effort` (`"high"`, `"medium"`, `"low"`, `"none"`)
- [ ] `tool_choice` - [ ] `tool_choice`
- [ ] `logit_bias` - [ ] `logit_bias`
- [ ] `user` - [ ] `user`

View File

@@ -21,6 +21,7 @@ Configure and launch external applications to use Ollama models. This provides a
- **OpenCode** - Open-source coding assistant - **OpenCode** - Open-source coding assistant
- **Claude Code** - Anthropic's agentic coding tool - **Claude Code** - Anthropic's agentic coding tool
- **Codex** - OpenAI's coding assistant - **Codex** - OpenAI's coding assistant
- **VS Code** - Microsoft's IDE with built-in AI chat
- **Droid** - Factory's AI coding agent - **Droid** - Factory's AI coding agent
#### Examples #### Examples
@@ -40,7 +41,7 @@ ollama launch claude
Launch with a specific model: Launch with a specific model:
``` ```
ollama launch claude --model qwen3-coder ollama launch claude --model qwen3.5
``` ```
Configure without launching: Configure without launching:

View File

@@ -51,6 +51,9 @@ Install prerequisites:
- [CUDA SDK](https://developer.nvidia.com/cuda-downloads?target_os=Windows&target_arch=x86_64&target_version=11&target_type=exe_network) - [CUDA SDK](https://developer.nvidia.com/cuda-downloads?target_os=Windows&target_arch=x86_64&target_version=11&target_type=exe_network)
- (Optional) VULKAN GPU support - (Optional) VULKAN GPU support
- [VULKAN SDK](https://vulkan.lunarg.com/sdk/home) - useful for AMD/Intel GPUs - [VULKAN SDK](https://vulkan.lunarg.com/sdk/home) - useful for AMD/Intel GPUs
- (Optional) MLX engine support
- [CUDA 13+ SDK](https://developer.nvidia.com/cuda-downloads)
- [cuDNN 9+](https://developer.nvidia.com/cudnn)
Then, configure and build the project: Then, configure and build the project:
@@ -101,6 +104,10 @@ Install prerequisites:
- (Optional) VULKAN GPU support - (Optional) VULKAN GPU support
- [VULKAN SDK](https://vulkan.lunarg.com/sdk/home) - useful for AMD/Intel GPUs - [VULKAN SDK](https://vulkan.lunarg.com/sdk/home) - useful for AMD/Intel GPUs
- Or install via package manager: `sudo apt install vulkan-sdk` (Ubuntu/Debian) or `sudo dnf install vulkan-sdk` (Fedora/CentOS) - Or install via package manager: `sudo apt install vulkan-sdk` (Ubuntu/Debian) or `sudo dnf install vulkan-sdk` (Fedora/CentOS)
- (Optional) MLX engine support
- [CUDA 13+ SDK](https://developer.nvidia.com/cuda-downloads)
- [cuDNN 9+](https://developer.nvidia.com/cudnn)
- OpenBLAS/LAPACK: `sudo apt install libopenblas-dev liblapack-dev liblapacke-dev` (Ubuntu/Debian)
> [!IMPORTANT] > [!IMPORTANT]
> Ensure prerequisites are in `PATH` before running CMake. > Ensure prerequisites are in `PATH` before running CMake.
@@ -118,6 +125,67 @@ Lastly, run Ollama:
go run . serve go run . serve
``` ```
## MLX Engine (Optional)
The MLX engine enables running safetensor based models. It requires building the [MLX](https://github.com/ml-explore/mlx) and [MLX-C](https://github.com/ml-explore/mlx-c) shared libraries separately via CMake. On MacOS, MLX leverages the Metal library to run on the GPU, and on Windows and Linux, runs on NVIDIA GPUs via CUDA v13.
### macOS (Apple Silicon)
Requires the Metal toolchain. Install [Xcode](https://developer.apple.com/xcode/) first, then:
```shell
xcodebuild -downloadComponent MetalToolchain
```
Verify it's installed correctly (should print "no input files"):
```shell
xcrun metal
```
Then build:
```shell
cmake -B build --preset MLX
cmake --build build --preset MLX --parallel
cmake --install build --component MLX
```
> [!NOTE]
> Without the Metal toolchain, cmake will silently complete with Metal disabled. Check the cmake output for `Setting MLX_BUILD_METAL=OFF` which indicates the toolchain is missing.
### Windows / Linux (CUDA)
Requires CUDA 13+ and [cuDNN](https://developer.nvidia.com/cudnn) 9+.
```shell
cmake -B build --preset "MLX CUDA 13"
cmake --build build --target mlx --target mlxc --config Release --parallel
cmake --install build --component MLX --strip
```
### Local MLX source overrides
To build against a local checkout of MLX and/or MLX-C (useful for development), set environment variables before running CMake:
```shell
export OLLAMA_MLX_SOURCE=/path/to/mlx
export OLLAMA_MLX_C_SOURCE=/path/to/mlx-c
```
For example, using the helper scripts with local mlx and mlx-c repos:
```shell
OLLAMA_MLX_SOURCE=../mlx OLLAMA_MLX_C_SOURCE=../mlx-c ./scripts/build_linux.sh
OLLAMA_MLX_SOURCE=../mlx OLLAMA_MLX_C_SOURCE=../mlx-c ./scripts/build_darwin.sh
```
```powershell
$env:OLLAMA_MLX_SOURCE="../mlx"
$env:OLLAMA_MLX_C_SOURCE="../mlx-c"
./scripts/build_darwin.ps1
```
## Docker ## Docker
```shell ```shell

View File

@@ -127,6 +127,7 @@
}, },
{ {
"group": "IDEs & Editors", "group": "IDEs & Editors",
"expanded": true,
"pages": [ "pages": [
"/integrations/cline", "/integrations/cline",
"/integrations/jetbrains", "/integrations/jetbrains",
@@ -160,6 +161,12 @@
"group": "More information", "group": "More information",
"pages": [ "pages": [
"/cli", "/cli",
{
"group": "Assistant Sandboxing",
"pages": [
"/integrations/nemoclaw"
]
},
"/modelfile", "/modelfile",
"/context-length", "/context-length",
"/linux", "/linux",

View File

@@ -61,11 +61,17 @@ Ollama supports the following AMD GPUs via the ROCm library:
### Linux Support ### Linux Support
| Family | Cards and accelerators | Ollama requires the AMD ROCm v7 driver on Linux. You can install or upgrade
| -------------- | ---------------------------------------------------------------------------------------------------------------------------------------------- | using the `amdgpu-install` utility from
| AMD Radeon RX | `7900 XTX` `7900 XT` `7900 GRE` `7800 XT` `7700 XT` `7600 XT` `7600` `6950 XT` `6900 XTX` `6900XT` `6800 XT` `6800` `Vega 64` | [AMD's ROCm documentation](https://rocm.docs.amd.com/projects/install-on-linux/en/latest/).
| AMD Radeon PRO | `W7900` `W7800` `W7700` `W7600` `W7500` `W6900X` `W6800X Duo` `W6800X` `W6800` `V620` `V420` `V340` `V320` `Vega II Duo` `Vega II` `SSG` |
| AMD Instinct | `MI300X` `MI300A` `MI300` `MI250X` `MI250` `MI210` `MI200` `MI100` `MI60` | | Family | Cards and accelerators |
| -------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| AMD Radeon RX | `9070 XT` `9070 GRE` `9070` `9060 XT` `9060 XT LP` `9060` `7900 XTX` `7900 XT` `7900 GRE` `7800 XT` `7700 XT` `7700` `7600 XT` `7600` `6950 XT` `6900 XTX` `6900XT` `6800 XT` `6800` `5700 XT` `5700` `5600 XT` `5500 XT` |
| AMD Radeon AI PRO | `R9700` `R9600D` |
| AMD Radeon PRO | `W7900` `W7800` `W7700` `W7600` `W7500` `W6900X` `W6800X Duo` `W6800X` `W6800` `V620` |
| AMD Ryzen AI | `Ryzen AI Max+ 395` `Ryzen AI Max 390` `Ryzen AI Max 385` `Ryzen AI 9 HX 475` `Ryzen AI 9 HX 470` `Ryzen AI 9 465` `Ryzen AI 9 HX 375` `Ryzen AI 9 HX 370` `Ryzen AI 9 365` |
| AMD Instinct | `MI350X` `MI300X` `MI300A` `MI250X` `MI250` `MI210` `MI100` |
### Windows Support ### Windows Support
@@ -97,17 +103,20 @@ This table shows some example GPUs that map to these LLVM targets:
| **LLVM Target** | **An Example GPU** | | **LLVM Target** | **An Example GPU** |
|-----------------|---------------------| |-----------------|---------------------|
| gfx908 | Radeon Instinct MI100 | | gfx908 | Radeon Instinct MI100 |
| gfx90a | Radeon Instinct MI210 | | gfx90a | Radeon Instinct MI210/MI250 |
| gfx940 | Radeon Instinct MI300 | | gfx942 | Radeon Instinct MI300X/MI300A |
| gfx941 | | | gfx950 | Radeon Instinct MI350X |
| gfx942 | | | gfx1010 | Radeon RX 5700 XT |
| gfx1012 | Radeon RX 5500 XT |
| gfx1030 | Radeon PRO V620 | | gfx1030 | Radeon PRO V620 |
| gfx1100 | Radeon PRO W7900 | | gfx1100 | Radeon PRO W7900 |
| gfx1101 | Radeon PRO W7700 | | gfx1101 | Radeon PRO W7700 |
| gfx1102 | Radeon RX 7600 | | gfx1102 | Radeon RX 7600 |
| gfx1103 | Radeon 780M |
AMD is working on enhancing ROCm v6 to broaden support for families of GPUs in a | gfx1150 | Ryzen AI 9 HX 375 |
future release which should increase support for more GPUs. | gfx1151 | Ryzen AI Max+ 395 |
| gfx1200 | Radeon RX 9070 |
| gfx1201 | Radeon RX 9070 XT |
Reach out on [Discord](https://discord.gg/ollama) or file an Reach out on [Discord](https://discord.gg/ollama) or file an
[issue](https://github.com/ollama/ollama/issues) for additional help. [issue](https://github.com/ollama/ollama/issues) for additional help.

Some files were not shown because too many files have changed in this diff Show More