Compare commits

...

104 Commits

Author SHA1 Message Date
Parth Sareen
676d9845ba launch: register websearch for openclaw (#14914) 2026-03-17 15:03:15 -07:00
Devon Rifkin
e37a9b4c01 cloud_proxy: for the web_search legacy path, flush on newlines (#14897)
`WebSearchAnthropicWriter` expects a single object per write. The new
transparent proxy will instead send it whatever bytes it sees. This
cloud-model + local-orchestration + cloud-search is a temporary code
path, so instead of making the web search code more robust to this, I
put an adapter in the middle that will flush line-by-line to preserve
the old behavior.
2026-03-17 13:30:17 -07:00
Patrick Devine
d727aacd04 mlx: quantized embeddings, fast SwiGLU, and runtime fixes (#14884)
Add QuantizedEmbedding and EmbeddingLayer interface so models can
use quantized embedding weights and expose tied output projections.
This change updates gemma3, glm4_moe_lite, llama, qwen3, and qwen3_5
to use the new interface.
2026-03-17 11:21:38 -07:00
Patrick Devine
fa69b833cd mlx: add prequantized tensor packing + changes for qwen35 (#14878)
This change adds a tensorImportTransform interface for model-specific
tensor transformations during safetensors import. This allows importing
and modifying the standard HF based weights as well as the mlx-community
derived pre-quantized safetensors repos to be directly
imported into `ollama create`. Right now this only works with Qwen3.5
importing which does tensor renaming, norm weight shifting (it
adds +1 to each value of the norm vectors), conv1d transposition,
and casts to BF16s for F32 based vectors.
2026-03-17 11:21:18 -07:00
Jesse Gross
bbbad97686 sched: Model eviction for MLX
MLX runners (image generation and LLM) previously bypassed the
scheduler's standard load path via a separate loadMLX method. This meant
they skipped VRAM fitting checks and couldn't participate in model
eviction.

Now all model types flow through the same load function. Model eviction
for MLX is based on weights as KV cache and compute graph are dynamic.
This means that eviction does not take into account the worst case
memory and models can still compete for memory but it is a significant
improvement.
2026-03-16 17:40:29 -07:00
Parth Sareen
bcf6d55b54 launch: fix web search, add web fetch, and enable both for local (#14886) 2026-03-16 16:26:19 -07:00
easonysliu
810d4f9c22 runner: fix swallowed error in allocModel graph reservation
In allocModel(), the first call to reserveWorstCaseGraph(true) had its
error silently discarded — `return nil` was used instead of `return err`.

This meant that if the prompt-sized graph reservation failed (e.g. due
to insufficient memory), the error was swallowed, allocModel reported
success, and the model appeared to load correctly. Subsequent inference
would then fail in unexpected ways because the worst-case graph was
never properly reserved.

Fix: return the actual error so the caller can handle the failure
(retry with reduced parallelism, report OOM, etc.).

Co-Authored-By: Claude (claude-opus-4-6) <noreply@anthropic.com>
2026-03-16 15:48:45 -07:00
Bruce MacDonald
856c047a6c cmd/launch: skip --install-daemon when systemd is unavailable (#14883)
In container environments without systemd, `openclaw onboard
--install-daemon` exits non-zero because it cannot create a systemd
user service. This causes `ollama launch openclaw` to abort even
though the gateway can be started as a foreground child process.

Only pass --install-daemon when systemd user services are reachable
(Linux with /run/systemd/system present and XDG_RUNTIME_DIR set).
On all other platforms the flag is still included by default.
2026-03-16 13:50:04 -07:00
Daniel Hiltgen
79c1e93c00 bench: improve benchmarking tool (#14240)
New features:
- Warmup phase to eliminate cold-start outliers
- time-to-first-token measured in each epoch
- VRAM/memory tracking to identify CPU spillover
- Controlled prompt length
- Defaults to 6 epochs and 200 tokens max

Benchstat fixes:
- ns/request instead of ns/op — non-standard unit created a separate group instead of grouping with timing metrics
- Token count as the N field — benchstat interprets N as iteration count for statistical weighting, not as a token count
2026-03-15 11:47:31 -07:00
Parth Sareen
f8b657c967 cmd/launch: add guards for headless mode (#14837) 2026-03-14 00:10:02 -07:00
Bruce MacDonald
10fefe0d57 config: use native OpenClaw Ollama onboarding (#14829)
OpenClaw now accepts the Ollama onboarding flags directly upstream, so rely on its wizard state instead of the legacy integration onboarding flag.

Update first-run setup to pass the Ollama auth and model flags during onboarding, perform a best-effort update before onboarding when needed, and drop the stale test that asserted persistence of the old onboarding flag.
2026-03-13 16:28:40 -07:00
Daniel Hiltgen
2f9a68f9e9 rocm: doc driver constraints (#14833) 2026-03-13 15:53:35 -07:00
Bruce MacDonald
3980c0217d server: decompress zstd request bodies in cloud passthrough middleware (#14827)
When a zstd-compressed request (e.g. from Codex CLI) hits /v1/responses
with a cloud model the request failed.

Fix by decompressing zstd bodies before
model extraction, so cloud models are detected and proxied directly
without the writer being wrapped.
2026-03-13 15:06:47 -07:00
Parth Sareen
870599f5da launch: remove warning for default policy (#14830) 2026-03-13 15:01:38 -07:00
Bruce MacDonald
abf8e8e9c8 middleware: handle non-JSON error responses gracefully (#14828)
writeError in both OpenAI and Anthropic middleware writers would return
a raw json.SyntaxError when the error payload wasn't valid JSON (e.g.
"invalid character 'e' looking for beginning of value"). Fall back to
using the raw bytes as the error message instead.

Also use the actual HTTP status code rather than hardcoding 500, so
error types map correctly
2026-03-13 14:50:49 -07:00
Shivam Tiwari
f3f31a8192 anthropic: close thinking block before tool_use when no text in between (#14825)
Root cause: StreamConverter.Process() only incremented contentIndex when
closing a thinking block if text content was present. When a model emitted
thinking followed directly by a tool_use block (no text in between),
thinkingDone was never set and contentIndex was not incremented, causing the
tool_use content_block_start to reuse index 0. Clients expecting sequential
indices would then fail to find the tool content block.

Fix: In the tool call loop, close any open thinking block (thinkingStarted &&
!thinkingDone) and increment contentIndex before opening the tool_use block,
mirroring the existing logic that closes an open text block.

Fixes #14816
2026-03-13 13:12:05 -07:00
Devon Rifkin
9e7ba835da cmd: still populate ollama ls when using ollama run <model:cloud> (#14824)
This is temporary until `api/tags` supports cloud natively
2026-03-13 12:24:45 -07:00
Parth Sareen
347f17b8d1 launch: add compact window for claude code (#14823) 2026-03-13 12:09:23 -07:00
Devon Rifkin
081b9eb423 api/create: always propagate :cloud source for cloud models (#14822)
Otherwise, using `/save` would try to run the local model instead
2026-03-13 11:58:00 -07:00
Parth Sareen
bb867c6fdb launch: fix headless --yes integration flow and policy scoping (#14815) 2026-03-13 11:45:36 -07:00
Cadu
81f4506a61 docs: document reasoning_effort support in OpenAI-compatible API (#14821)
Add reasoning_effort and reasoning to the supported features and
request fields for /v1/chat/completions. These fields control
thinking on thinking-capable models but were previously undocumented.

Closes #14820
2026-03-13 10:57:14 -07:00
Parth Sareen
76925f1284 cmd: TUI model ordering (#14814) 2026-03-13 10:19:22 -07:00
Devon Rifkin
f676231de9 server: remove experimental aliases support (#14810) 2026-03-12 20:27:24 -07:00
Parth Sareen
af5f7c0a9e cmd: refactor tui and launch (#14609) 2026-03-12 18:39:06 -07:00
Daniel Hiltgen
a6b27d776b ci: fix missing windows zip file (#14807)
Use 7z compression (better compression rate) if found in path.  That
alone isn't sufficient to get us under 2G, so MLX is now split out as a
discrete download.  Fix CI so it will fail if artifacts fail to upload.
2026-03-12 16:14:00 -07:00
Daniel Hiltgen
539741199e mlx: perf improvements (#14768)
* mlx: perf improvements

Fix nn.go to call mlx_fast_layer_norm instead of manually implementing (mean,
subtract, variance, rsqrt, multiply, add — 6 ops)

Fix llama.go, gemma3.go to remove RepeatKV to tile K/V tensors to match the Q
head count, since scaled_dot_product_attention natively handles GQA (it just
requires n_q_heads % n_kv_heads == 0)

* review comments
2026-03-12 12:01:28 -07:00
Eva H
8f45236d09 middleware: enable local tool model for web search (#14787) 2026-03-11 17:51:39 -04:00
Parth Sareen
97013a190c openai: split mixed thinking stream chunks via ToChunks (#14648) 2026-03-11 14:21:29 -07:00
Daniel Hiltgen
c222735c02 mlx: only log load errors when MLX is needed (#14764)
This suppresses irrelevant/noisy errors in the GGML runner.
2026-03-11 10:31:31 -07:00
Daniel Hiltgen
87d21c7fc0 MLX: harden for init failures (#14777)
The CLI now links to the lazy-load MLX code, but that still happens in
init functions.  On internal MLX errors, the CLI exits before it has a
chance to start.  This change re-wires the MLX error handling so it
doesn't exit by default.  The MLX based runners currently expect exits
on failure, so they re-initialize the default error handling.  We can
refine error handling for better go stack traces in the future.
2026-03-10 22:52:23 -07:00
Jeffrey Morgan
54e05172a0 Revert "runner: add token history sampling parameters to ollama runner (#14537)" (#14776)
This reverts commit 86513cb697.
2026-03-10 21:07:52 -07:00
Parth Sareen
464186e995 config: qwen3.5 recommendations (#14758) 2026-03-10 18:04:57 -07:00
Devon Rifkin
8c4d5d6c2f cloud_proxy: send ollama client version (#14769)
This was previously included in the user agent, and we've made use of it
in the past to hotpatch bugs server-side for particular Ollama versions.
2026-03-10 15:53:25 -07:00
Parth Sareen
bc72b14016 docs: update claude code docs (#14770) 2026-03-10 15:52:41 -07:00
Parth Sareen
61086083eb server: add experimental web search and web fetch routes (#14753) 2026-03-09 21:52:12 -07:00
Daniel Hiltgen
62d1f01ab4 ci: Fix windows build (#14754)
Instead of relying on sh for wildcard, do it in Go for better windows
compatibility.
2026-03-09 19:27:59 -07:00
Daniel Hiltgen
10e51c5177 MLX: add header vendoring and remove go build tag (#14642)
* prefer rocm v6 on windows

Avoid building with v7 - more changes are needed

* MLX: add header vendoring and remove go build tag

This switches to using a vendoring approach for the mlx-c headers so that Go
can build without requiring a cmake first.  This enables building the new MLX
based code by default.  Every time cmake runs, the headers are refreshed, so we
can easily keep them in sync when we bump mlx versions.  Basic Windows
and Linux support are verified.

* ci: harden for flaky choco repo servers

CI sometimes fails due to choco not actually installing cache.  Since it just speeds up the build, we can proceed without.

* review comments
2026-03-09 17:24:45 -07:00
Patrick Devine
3e06bde643 mlx: get parameters from modelfile during model creation (#14747) 2026-03-09 15:33:24 -07:00
Eva H
6be2de8214 app: auto update should be enabled when reset to defaults (#14741) 2026-03-09 15:02:36 -04:00
Daniel Hiltgen
ebb1b9ec14 rocm: update linux to v7.2 (#14391)
* rocm: update linux to v7.2

* review comments
2026-03-09 08:26:55 -07:00
Patrick Devine
d126467d5d x/mlxrunner: replace sampler interface chain with single stateful Sampler (#14652)
- Collapse MLX sampling state into a single sample.Sampler struct (options + history).
- Replace interface-based sampler chain (TopP, TopK, penalty, etc.) with function-based transforms.
- Update request/pipeline wiring to use *sample.Sampler, seed history from prompt tokens, and append generated tokens each step.
- Implement top_p, min_p, repeat_penalty, and frequency_penalty
2026-03-07 17:50:57 -08:00
Devon Rifkin
afb4c62fbf cloud_proxy: handle stream disconnects gracefully (#14685)
Previously we were printing out bad errors for expected cases like
clients disconnecting. Now we only debug log when that happens (which
still might help in cases where we're figuring out why an integration
isn't working). For other errors, we print out a proper warning now
2026-03-06 19:18:52 -08:00
Patrick Devine
e790dc435b mlx: int4 groupsize 64 (#14682)
Change affine 4bit integers to use groupsize 64
2026-03-06 16:39:47 -08:00
Daniel Hiltgen
288077c3a3 build: smarter docker parallelism (#14653)
Our Dockerfile leverages parallel stages for more efficient builds.  However,
our old parallel settings were naive and lead to under/over utilization
depending on the capabilities of your build system.

This change switches to using Ninja for all our docker cmake builds to leverage
its smarter parallel logic.  We tell Ninja to target a load of nproc so each of
the build stages will share the load on the system aiming for full CPU use
without oversaturation.

The GPU parallelism settings are also adjusted to 4 to avoid a long-tail for
the last few GPU targets as they work through the long list of GPU
architectures.

This also fixes the Dockerfile to move Vulkan install to just the stage that
needs it instead of blocking most other GPU installs.  This should speed up CI
which always has a clean build cache.
2026-03-06 16:36:22 -08:00
Daniel Hiltgen
4425c54eda create: fix localhost handling (#14681) 2026-03-06 16:35:58 -08:00
Michael Yang
778899a5d2 docs: format compat docs (#14678) 2026-03-06 14:53:17 -08:00
Jeffrey Morgan
4eab60c1e2 Reapply "don't require pulling stubs for cloud models" again (#14608)
* Revert "Revert "Reapply "don't require pulling stubs for cloud models"" (#14606)"

This reverts commit 39982a954e.

* fix test + do cloud lookup only when seeing cloud models

---------

Co-authored-by: ParthSareen <parth.sareen@ollama.com>
2026-03-06 14:27:47 -08:00
Bruce MacDonald
1af850e6e3 parsers: repair unclosed arg_value tags in GLM tool calls (#14656)
GLM models sometimes omits </arg_value> closing tags in tool call XML, causing xml.Unmarshal to fail with "element <arg_value> closed by </tool_call>".

This is a known issue across the GLM family.

Sanitize the input to fix closing arg_key values so encoding/xml can handle it.
2026-03-06 14:08:34 -08:00
Parth Sareen
9b0c7cc7b9 cmd: override stale entries for context window pi (#14655) 2026-03-05 16:30:24 -08:00
Daniel Hiltgen
6928630601 mlx: prevent remote creation mismatch (#14651)
If the user is pointing at a remote OLLAMA_HOST, fail experimental safetensor
based create operations as we only support local creation currently.
2026-03-05 14:59:00 -08:00
Parth Sareen
9896e3627f cmd/config: fix cloud model limit lookups in integrations (#14650) 2026-03-05 13:57:28 -08:00
Bruce MacDonald
15732f0ea7 cmd: use native Ollama API endpoint for OpenClaw (#14649)
Remove the /v1 suffix from the OpenClaw provider baseUrl so it uses
the native Ollama API instead of the OpenAI-compatible endpoint. The
/v1 endpoint my break tool calling in OpenClaw.
2026-03-05 13:29:17 -08:00
Parth Sareen
562c76d7cc cmd: add qwen3.5 context length for launch (#14626) 2026-03-04 14:10:52 -08:00
Parth Sareen
122c68c151 server: loosen thinking level constraint (#14625) 2026-03-04 13:42:18 -08:00
Jeffrey Morgan
82848a7806 model: fix renderer and parser for qwen3.5 (#14605) 2026-03-03 20:58:29 -08:00
Jeffrey Morgan
39982a954e Revert "Reapply "don't require pulling stubs for cloud models"" (#14606)
This reverts commit 799e51d419.
2026-03-03 20:56:10 -08:00
Patrick Devine
e9f6ea232f Add qwen3.5-next-moe support to MLX runner and models (#14417)
This change adds support for qwen3.5-next-moe models (qwen3-next/qwen3.5-next/qwen3-coder) to the MLX runner. It also:

* introduces recurrent cache support and related MLX ops
* updates pipeline/runner integration and adds tests
* properly quantizes stacked expert tensors
* a Gated Delta Metal kernel for fast SSM inference
* adds new MLX calls for Conv1d, DepthwideConv1d, Contiguous, Exp, Log, SoftmaxAxis
2026-03-03 16:39:22 -08:00
Patrick Devine
110eff01a9 chore: remove old imagegen LLMs models (#14597)
These models are implemented in the x/mlxrunner instead.
2026-03-03 13:23:40 -08:00
Jeffrey Morgan
799e51d419 Reapply "don't require pulling stubs for cloud models"
This reverts commit 97d2f05a6d.
2026-03-03 13:17:10 -08:00
Victor-Quqi
e8fcb29586 model/renderers: fix glm-ocr image tags in renderer prompts (#14584) 2026-03-03 12:51:34 -08:00
Jeffrey Morgan
97d2f05a6d Revert "don't require pulling stubs for cloud models (#14574)" (#14596)
This reverts commit 8207e55ec7.
2026-03-03 12:51:23 -08:00
Devon Rifkin
8207e55ec7 don't require pulling stubs for cloud models (#14574)
* don't require pulling stubs for cloud models

This is a first in a series of PRs that will better integrate Ollama's
cloud into the API and CLI. Previously we used to have a layer of
indirection where you'd first have to pull a "stub" model that contains
a reference to a cloud model. With this change, you don't have to pull
first, you can just use a cloud model in various routes like `/api/chat`
and `/api/show`. This change respects
<https://github.com/ollama/ollama/pull/14221>, so if cloud is disabled,
these models won't be accessible.

There's also a new, simpler pass-through proxy that doesn't convert the
requests ahead of hitting the cloud models, which they themselves
already support various formats (e.g., `v1/chat/completions` or Open
Responses, etc.). This will help prevent issues caused by double
converting (e.g., `v1/chat/completions` converted to `api/chat` on the
client, then calling cloud and converting back to a
`v1/chat/completions` response instead of the cloud model handling the
original `v1/chat/completions` request first).

There's now a notion of "source tags", which can be mixed with existing
tags. So instead of having different formats like`gpt-oss:20b-cloud` vs.
`kimi-k2.5:cloud` (`-cloud` suffix vs. `:cloud`), you can now specify
cloud by simply appending `:cloud`. This PR doesn't change model
resolution yet, but sets us up to allow for things like omitting the
non-source tag, which would make something like `ollama run
gpt-oss:cloud` work the same way that `ollama run gpt-oss` already works
today.

More detailed changes:

- Added a shared model selector parser in `types/modelselector`:
  - supports `:cloud` and `:local`
  - accepts source tags in any position
  - supports legacy `:<tag>-cloud`
  - rejects conflicting source tags
- Integrated selector handling across server inference/show routes:
  - `GenerateHandler`, `ChatHandler`, `EmbedHandler`,
    `EmbeddingsHandler`, `ShowHandler`
- Added explicit-cloud passthrough proxy for ollama.com:
  - same-endpoint forwarding for `/api/*`, `/v1/*`, and `/v1/messages`
  - normalizes `model` (and `name` for `/api/show`) before forwarding
  - forwards request headers except hop-by-hop/proxy-managed headers
  - uses bounded response-header timeout
  - handles auth failures in a friendly way
- Preserved cloud-disable behavior (`OLLAMA_NO_CLOUD`)
- Updated create flow to support `FROM ...:cloud` model sources (though
  this flow uses the legacy proxy still, supporting Modelfile overrides
  is more complicated with the direct proxy approach)
- Updated CLI/TUI/config cloud detection to use shared selector logic
- Updated CLI preflight behavior so explicit cloud requests do not
  auto-pull local stubs

What's next?

- Cloud discovery/listing and cache-backed `ollama ls` / `/api/tags`
- Modelfile overlay support for virtual cloud models on OpenAI/Anthropic
  request families
- Recommender/default-selection behavior for ambiguous model families
- Fully remove the legacy flow

Fixes: https://github.com/ollama/ollama/issues/13801

* consolidate pull logic into confirmAndPull helper

pullIfNeeded and ShowOrPull shared identical confirm-and-pull logic.
Extract confirmAndPull to eliminate the duplication.

* skip local existence checks for cloud models

ModelExists and the TUI's modelExists both check the local model list,
which causes cloud models to appear missing. Return true early for
explicit cloud models so the TUI displays them beside the integration
name and skips re-prompting the model picker on relaunch.

* support optionally pulling stubs for newly-style names

We now normalize names like `<family>:<size>:cloud` into legacy-style
names like `<family>:<size>-cloud` for pulling and deleting (this also
supports stripping `:local`). Support for pulling cloud models is
temporary, once we integrate properly into `/api/tags` we won't need
this anymore.

* Fix server alias syncing

* Update cmd/cmd.go

Co-authored-by: Parth Sareen <parth.sareen@ollama.com>

* address comments

* improve some naming

---------

Co-authored-by: ParthSareen <parth.sareen@ollama.com>
2026-03-03 10:46:33 -08:00
Jesse Gross
ad16bffc7d mlx: Remove peak memory from the API
This is still in flux so it is better to just log it for now.
2026-03-02 15:56:18 -08:00
Jesse Gross
c1e3ef4bcc mlxrunner: Refcount pinned tensors
Otherwise, it is error prone to manage multiple components working
with the same tensor.
2026-03-02 15:56:06 -08:00
Parth Sareen
a3093cd5e5 cmd/opencode: rename provider from "Ollama (local)" to "Ollama" (#14566)
The "(local)" qualifier is unnecessary since there's only one Ollama
provider. Existing configs with the old name are migrated automatically;
custom names are left unchanged.
2026-03-02 14:17:18 -08:00
Bruce MacDonald
23d4cad1a2 server: verify digest is not empty on create (#14555)
An empty digest is not a valid digest for an incoming create request. Reject empty digests at the api level.
2026-03-02 13:43:35 -08:00
Jeffrey Morgan
86513cb697 runner: add token history sampling parameters to ollama runner (#14537) 2026-03-01 19:16:07 -08:00
Jeffrey Morgan
3490e9590b model/qwen3next: avoid crash in in DeltaNet when offloading (#14541)
Co-authored-by: Yossi Ovadia <jabadia@gmail.com>
2026-03-01 18:44:04 -08:00
Jeffrey Morgan
8da09b1e7e qwen3next: add compatibility with imported GGUF models (#14517) 2026-02-28 14:21:42 -08:00
Jesse Gross
a60b9adcce mlxrunner: Fix prompt eval timing and count metrics
Only the last token's processing time is included in prompt processing,
giving an artificially high rate. In addition, the number of tokens
only included the tokens that miss the cache, instead of our historic
total tokens.
2026-02-27 17:29:47 -08:00
Jesse Gross
a16f96658b mlxrunner: Enforce model context limit
Currently, context length is unbounded - the cache will keep
growing forever independent of the model's trained context
length. This caps it and enforces semantics similar to most
cloud services:
 - Long prompts will result in an error, not truncation.
 - Generation that exceeds the context will be stopped
2026-02-27 17:29:47 -08:00
Jesse Gross
18ab09b431 mlxrunner: Propagate pipeline errors to client via api.StatusError
Errors that occur during pipeline processing are currently only
logged but not sent back to the client. Rather than using HTTP
status codes as we have historically done, this serializes errors
as messages to allow sending them at any time during the stream.
2026-02-27 17:29:47 -08:00
Jesse Gross
638faeac54 mlxrunner: Report actual memory usage from runner
The MLX runner previously reported a static VRAM estimate that was
computed at load time and consisted only of the weights. This is
strictly less than the actual memory usage, as it does not include
the KV cache or compute graph.
2026-02-27 17:29:47 -08:00
Jesse Gross
dd5eb6337d mlxrunner: Fix panic on full KV cache hit
When the entire prompt was already cached (e.g. repeated prompt),
findRemaining returned an empty slice, causing FromValues to panic
on an index-out-of-range accessing a zero-length byte slice.

Fix by always keeping at least one token to re-evaluate so the
pipeline can seed token generation. Also reject empty prompts
early rather than panicking.
2026-02-27 11:07:03 -08:00
Patrick Devine
79917cf80b show peak memory usage (#14485) 2026-02-26 18:38:27 -08:00
Parth Sareen
cc90a035a0 model/parsers: add stable tool call indexing for glm47 and qwen3 parsers (#14484) 2026-02-26 18:14:29 -08:00
Jeffrey Morgan
d98dda4676 model: fix qwen3 tool calling in thinking (#14477)
Align Qwen parser behavior with Transformers serve by allowing <tool_call> parsing while still in thinking collection.

Changes:

- qwen3vl: detect <tool_call> before </think> in thinking state and transition to tool parsing

- qwen3: same thinking-state tool detection and partial-tag overlap handling

- tests: update qwen3vl thinking/tool interleaving expectations

- tests: add qwen3 cases for tool call before </think> and split <tool_call> streaming
2026-02-26 16:13:18 -08:00
Eva H
d69ddc1edc fix: window app crash on startup when update is pending (#14451) 2026-02-26 16:47:12 -05:00
Eva H
9bf41969f0 app: fix first update check delayed by 1 hour (#14427) 2026-02-25 18:29:55 -05:00
Jesse Gross
0f23b7bff5 mlxrunner: Cancel in-flight requests when the client disconnects
Currently, a canceled request can result in computation continuing
in the background to completion. It can also trigger a deadlock
when there is nobody to read the output tokens and the pipeline
cannot continue to the next request.
2026-02-25 14:00:42 -08:00
Jesse Gross
4e57d2094e mlxrunner: Simplify pipeline memory and cache management
Particularly in error cases, it can be difficult to ensure that
all pinned memory is unpinned, MLX buffers are released and cache
state is consistent. This encapsulates those pieces and sets up
proper deferrals so that this happens automatically on exit.
2026-02-25 14:00:42 -08:00
Jeffrey Morgan
7f9efd53df model: add support for qwen3.5-27b model (#14415) 2026-02-25 01:09:58 -08:00
Jeffrey Morgan
da70c3222e model: support for qwen3.5 architecture (#14378) 2026-02-24 20:08:05 -08:00
Bruce MacDonald
9d902d63ce ggml: ensure tensor size is valid (#14406)
When quantizing tensors during model creation validate that the resulting sizes match what is expected based on the shape.
2026-02-24 21:52:44 -04:00
Daniel Hiltgen
f4f0a4a471 update mlx-c bindings to 0.5.0 (#14380)
* chore: update mlx-c bindings to 0.5.0 (#14303)

* linux: use gcc 11

---------

Co-authored-by: Patrick Devine <patrick@infrahq.com>
2026-02-23 16:44:29 -08:00
Eva H
3323c1d319 app: add upgrade configuration to settings page (#13512) 2026-02-23 18:08:52 -05:00
Jesse Gross
f20dc6b698 mlx: don't default to affine quantization for unquantized models
Otherwise the BF16 version of models trigger segfaults when they
call into quantized kernels.
2026-02-23 15:03:53 -08:00
Jeffrey Morgan
4b2ac1f369 model: improvements to LFM architectures (#14368) 2026-02-23 14:38:10 -08:00
Jesse Gross
8daf47fb3a mlxrunner: Fix duplicate log prefixes and reduce log noise
Pass subprocess stdout/stderr through to the parent's stderr directly
instead of re-wrapping each line with slog. The subprocess already
writes structured slog output, so the re-wrapping produced nested
timestamps, levels, and message fields that were hard to read.

Also downgrade verbose KV cache debug logs to trace level.
2026-02-23 14:09:20 -08:00
Eva H
6c980579cd ui: use capability-based detection for web search (#14336) 2026-02-23 15:00:09 -05:00
Jesse Gross
5c73c4e2ee mlxrunner: Simplify KV cache to single-entry prefix matching
The KV cache previously used a tree structure which could
store multiple divergent sequences, which is good for cache
reuse. However, this is typically used in conjunction with
paged attention so each node in the tree can store just a
chunk of the KV cache and they can be stitched together later.
We don't currently do this, so the cache was storing copies of
the full cache for each past sequence.

This redundancy plus the lack of resource limits, caused significant
memory use as a conversation grew. Instead, this changes to store
a single entry for the cache, which can be prefix matched. Although
it is less ideal for multiple users, it largely matches Ollama's
current behavior. It can be improved as additional pieces are fleshed
out.
2026-02-23 09:50:07 -08:00
Jesse Gross
5daf59cc66 mlxrunner: Fix memory leaks with pin/sweep lifecycle management
The previous approach tracked array lifecycles through reference
counting, where each array recorded its inputs and a reference count
that was decremented as dependents were freed. This is not really
necessary as MLX tracks references internally. It is also error
prone as it is easy to create new arrays and forget to free them
when the Go variable goes out of scope.

Instead, we can pin just the arrays we want (typically outputs and
specific intermediates, like the cache). All other arrays are freed
by default when we run sweep. This avoids most causes of memory leaks
while still giving the freedom to save what we want.
2026-02-23 09:50:07 -08:00
Jeffrey Morgan
0ade9205cc models: add nemotronh architecture support (#14356) 2026-02-22 15:09:14 -08:00
Parth Sareen
06edabdde1 cmd/config: install web search plugin to user-level extensions dir (#14362) 2026-02-22 02:17:03 -08:00
Jeffrey Morgan
8b4e5a82a8 mlx: remove noisy error output from dynamic library loading (#14346)
The recent change in #14322 added tryLoadByName() which attempts to
load libmlxc.dylib via rpath before searching directories. This is an
optimization for Homebrew installations where rpath is correctly set.

However, when rpath isn't set (which is the common case for app bundle
installations), dlopen fails and the CHECK macro prints an error to
stderr:

  ERROR - dynamic.c:21 - CHECK failed: handle->ctx != NULL

This error is misleading because it's an expected failure path - the
code correctly falls back to searching the executable directory and
loads the library successfully. The error message causes user confusion
and makes it appear that something is broken.

Replace the CHECK macro with a simple return code so the C code fails
silently. The Go code already handles error logging appropriately:
tryLoadByName() fails silently (intentional fallback), while
tryLoadFromDir() logs via slog.Error() when explicit path loading fails.
2026-02-20 23:46:07 -08:00
Parth Sareen
3445223311 cmd: openclaw onboarding (#14344) 2026-02-20 19:08:38 -08:00
Jeffrey Morgan
fa6c0127e6 app: expose server's default context length to UI (#14037)
Parse the default_num_ctx from the server's "vram-based default context"
log line and expose it through the inference compute API. This eliminates
duplicate VRAM tier calculation logic in the frontend.

- Add InferenceInfo struct with Computes and DefaultContextLength
- Rename GetInferenceComputer to GetInferenceInfo
- Handle missing default context line gracefully (older servers)
- Add DefaultContextLength to InferenceComputeResponse
- Update Settings UI to use server's default, disable slider while loading
- Add disabled prop to Slider component (grays out + hides handle)
- Migrate existing users with context_length=4096 to 0 (auto mode)
2026-02-20 18:56:30 -08:00
Patrick Devine
97323d1c68 consolidate the tokenizer (#14327)
This change adds a new x/tokenizer package which includes:
  * New BPE and SentencePiece tokenizers
  * Removing the dependency on the imagegen tokenizers
  * Fixes to multibyte decoding in the pipeline
  * Various correctness and benchmark tests

Not included in this PR is the WordPiece tokenizer for BERT models which will be
added when we add embedding models. The imagegen tokenizers will also be removed in
a follow-up PR.
2026-02-19 15:55:45 -08:00
natl-set
458dd1b9d9 mlx: try loading library via rpath before searching directories (#14322)
The existing code manually searches directories for libmlxc.* and passes
full paths to dlopen, bypassing the binary's rpath. This means MLX
libraries installed via package managers (e.g., Homebrew) aren't found
even when rpath is correctly set at link time.

This change adds a fallback that tries loading via rpath first (using
just the library name), before falling back to the existing directory
search. This follows standard Unix/macOS conventions and works with any
installation that sets rpath.

Fixes library loading on macOS with Homebrew-installed mlx-c without
requiring OLLAMA_LIBRARY_PATH environment variable.

Co-authored-by: Natl <nat@MacBook-Pro.local>
2026-02-19 10:55:02 -08:00
Bruce MacDonald
9d02d1d767 install: prevent partial download script execution (#14311)
Wrap script in main function so that a truncated partial download doesn't end up executing half a script.
2026-02-18 18:32:45 -08:00
Bruce MacDonald
1a636fb47a cmd: set codex env vars on launch and handle zstd request bodies (#14122)
The Codex runner was not setting OPENAI_BASE_URL or OPENAI_API_KEY, this prevents Codex from sending requests to api.openai.com instead of the local Ollama server. This mirrors the approach used by the Claude runner.

Codex v0.98.0 sends zstd-compressed request bodies to the /v1/responses endpoint. Add decompression support in ResponsesMiddleware with an 8MB max decompressed size limit to prevent resource exhaustion.
2026-02-18 17:19:36 -08:00
Patrick Devine
0759fface9 Revert "chore: update mlx-c bindings to 0.5.0 (#14303)" (#14316)
This reverts commit f01a9a7859.
2026-02-18 17:01:25 -08:00
Parth Sareen
325b72bc31 cmd/tui: default to single-select for editor integrations (#14302) 2026-02-17 18:17:27 -08:00
Patrick Devine
f01a9a7859 chore: update mlx-c bindings to 0.5.0 (#14303) 2026-02-17 16:48:16 -08:00
369 changed files with 39937 additions and 15345 deletions

View File

@@ -117,6 +117,25 @@ jobs:
install: https://sdk.lunarg.com/sdk/download/1.4.321.1/windows/vulkansdk-windows-X64-1.4.321.1.exe
flags: ''
runner_dir: 'vulkan'
- os: windows
arch: amd64
preset: 'MLX CUDA 13'
install: https://developer.download.nvidia.com/compute/cuda/13.0.0/local_installers/cuda_13.0.0_windows.exe
cudnn-install: https://developer.download.nvidia.com/compute/cudnn/redist/cudnn/windows-x86_64/cudnn-windows-x86_64-9.18.1.3_cuda13-archive.zip
cuda-components:
- '"cudart"'
- '"nvcc"'
- '"cublas"'
- '"cublas_dev"'
- '"cufft"'
- '"cufft_dev"'
- '"nvrtc"'
- '"nvrtc_dev"'
- '"crt"'
- '"nvvm"'
- '"nvptxcompiler"'
cuda-version: '13.0'
flags: ''
runs-on: ${{ matrix.arch == 'arm64' && format('{0}-{1}', matrix.os, matrix.arch) || matrix.os }}
environment: release
env:
@@ -125,8 +144,10 @@ jobs:
- name: Install system dependencies
run: |
choco install -y --no-progress ccache ninja
ccache -o cache_dir=${{ github.workspace }}\.ccache
- if: startsWith(matrix.preset, 'CUDA ') || startsWith(matrix.preset, 'ROCm ') || startsWith(matrix.preset, 'Vulkan')
if (Get-Command ccache -ErrorAction SilentlyContinue) {
ccache -o cache_dir=${{ github.workspace }}\.ccache
}
- if: startsWith(matrix.preset, 'CUDA ') || startsWith(matrix.preset, 'ROCm ') || startsWith(matrix.preset, 'Vulkan') || startsWith(matrix.preset, 'MLX ')
id: cache-install
uses: actions/cache/restore@v4
with:
@@ -134,8 +155,9 @@ jobs:
C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA
C:\Program Files\AMD\ROCm
C:\VulkanSDK
key: ${{ matrix.install }}
- if: startsWith(matrix.preset, 'CUDA ')
C:\Program Files\NVIDIA\CUDNN
key: ${{ matrix.install }}-${{ matrix.cudnn-install }}
- if: startsWith(matrix.preset, 'CUDA ') || startsWith(matrix.preset, 'MLX ')
name: Install CUDA ${{ matrix.cuda-version }}
run: |
$ErrorActionPreference = "Stop"
@@ -179,6 +201,23 @@ jobs:
run: |
echo "CC=clang.exe" | Out-File -FilePath $env:GITHUB_ENV -Append
echo "CXX=clang++.exe" | Out-File -FilePath $env:GITHUB_ENV -Append
- if: startsWith(matrix.preset, 'MLX ')
name: Install cuDNN for MLX
run: |
$ErrorActionPreference = "Stop"
$cudnnRoot = "C:\Program Files\NVIDIA\CUDNN"
if ("${{ steps.cache-install.outputs.cache-hit }}" -ne 'true') {
Invoke-WebRequest -Uri "${{ matrix.cudnn-install }}" -OutFile "cudnn.zip"
Expand-Archive -Path cudnn.zip -DestinationPath cudnn-extracted
$cudnnDir = (Get-ChildItem -Path cudnn-extracted -Directory)[0].FullName
New-Item -ItemType Directory -Force -Path $cudnnRoot
Copy-Item -Path "$cudnnDir\*" -Destination "$cudnnRoot\" -Recurse
}
echo "CUDNN_ROOT_DIR=$cudnnRoot" | Out-File -FilePath $env:GITHUB_ENV -Append
echo "CUDNN_INCLUDE_PATH=$cudnnRoot\include" | Out-File -FilePath $env:GITHUB_ENV -Append
echo "CUDNN_LIBRARY_PATH=$cudnnRoot\lib\x64" | Out-File -FilePath $env:GITHUB_ENV -Append
echo "$cudnnRoot\bin\x64" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
- if: ${{ !cancelled() && steps.cache-install.outputs.cache-hit != 'true' }}
uses: actions/cache/save@v4
with:
@@ -186,7 +225,8 @@ jobs:
C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA
C:\Program Files\AMD\ROCm
C:\VulkanSDK
key: ${{ matrix.install }}
C:\Program Files\NVIDIA\CUDNN
key: ${{ matrix.install }}-${{ matrix.cudnn-install }}
- uses: actions/checkout@v4
- uses: actions/cache@v4
with:
@@ -198,7 +238,7 @@ jobs:
Enter-VsDevShell -VsInstallPath 'C:\Program Files\Microsoft Visual Studio\2022\Enterprise' -SkipAutomaticLocation -DevCmdArguments '-arch=x64 -no_logo'
cmake --preset "${{ matrix.preset }}" ${{ matrix.flags }} --install-prefix "$((pwd).Path)\dist\${{ matrix.os }}-${{ matrix.arch }}"
cmake --build --parallel ([Environment]::ProcessorCount) --preset "${{ matrix.preset }}"
cmake --install build --component "${{ startsWith(matrix.preset, 'CUDA ') && 'CUDA' || startsWith(matrix.preset, 'ROCm ') && 'HIP' || startsWith(matrix.preset, 'Vulkan') && 'Vulkan' || 'CPU' }}" --strip
cmake --install build --component "${{ startsWith(matrix.preset, 'MLX ') && 'MLX' || startsWith(matrix.preset, 'CUDA ') && 'CUDA' || startsWith(matrix.preset, 'ROCm ') && 'HIP' || startsWith(matrix.preset, 'Vulkan') && 'Vulkan' || 'CPU' }}" --strip
Remove-Item -Path dist\lib\ollama\rocm\rocblas\library\*gfx906* -ErrorAction SilentlyContinue
env:
CMAKE_GENERATOR: Ninja
@@ -543,11 +583,19 @@ jobs:
for payload in dist/*.txt dist/*.zip dist/*.tgz dist/*.tar.zst dist/*.exe dist/*.dmg dist/*.ps1 dist/*.sh ; do
echo "Uploading $payload"
gh release upload ${GITHUB_REF_NAME} $payload --clobber &
pids[$!]=$!
pids+=($!)
sleep 1
done
echo "Waiting for uploads to complete"
for pid in "${pids[*]}"; do
wait $pid
failed=0
for pid in "${pids[@]}"; do
if ! wait $pid; then
echo "::error::Upload failed (pid $pid)"
failed=1
fi
done
if [ $failed -ne 0 ]; then
echo "One or more uploads failed"
exit 1
fi
echo "done"

View File

@@ -37,7 +37,7 @@ jobs:
| xargs python3 -c "import sys; from pathlib import Path; print(any(Path(x).match(glob) for x in sys.argv[1:] for glob in '$*'.split(' ')))"
}
echo changed=$(changed 'llama/llama.cpp/**/*' 'ml/backend/ggml/ggml/**/*') | tee -a $GITHUB_OUTPUT
echo changed=$(changed 'llama/llama.cpp/**/*' 'ml/backend/ggml/ggml/**/*' '.github/**/*') | tee -a $GITHUB_OUTPUT
echo vendorsha=$(make -f Makefile.sync print-base) | tee -a $GITHUB_OUTPUT
linux:
@@ -51,7 +51,7 @@ jobs:
container: nvidia/cuda:13.0.0-devel-ubuntu22.04
flags: '-DCMAKE_CUDA_ARCHITECTURES=87'
- preset: ROCm
container: rocm/dev-ubuntu-22.04:6.1.2
container: rocm/dev-ubuntu-22.04:7.2
extra-packages: rocm-libs
flags: '-DAMDGPU_TARGETS=gfx1010 -DCMAKE_PREFIX_PATH=/opt/rocm'
- preset: Vulkan
@@ -60,6 +60,10 @@ jobs:
mesa-vulkan-drivers vulkan-tools
libvulkan1 libvulkan-dev
vulkan-sdk cmake ccache g++ make
- preset: 'MLX CUDA 13'
container: nvidia/cuda:13.0.0-devel-ubuntu22.04
extra-packages: libcudnn9-dev-cuda-13 libopenblas-dev liblapack-dev liblapacke-dev git curl
flags: '-DCMAKE_CUDA_ARCHITECTURES=87 -DBLAS_INCLUDE_DIRS=/usr/include/x86_64-linux-gnu -DLAPACK_INCLUDE_DIRS=/usr/include/x86_64-linux-gnu'
runs-on: linux
container: ${{ matrix.container }}
steps:
@@ -76,6 +80,10 @@ jobs:
$sudo apt-get update
fi
$sudo apt-get install -y cmake ccache ${{ matrix.extra-packages }}
# MLX requires CMake 3.25+, install from official releases
if [ "${{ matrix.preset }}" = "MLX CUDA 13" ]; then
curl -fsSL https://github.com/Kitware/CMake/releases/download/v3.31.2/cmake-3.31.2-linux-$(uname -m).tar.gz | $sudo tar xz -C /usr/local --strip-components 1
fi
# Export VULKAN_SDK if provided by LunarG package (defensive)
if [ -d "/usr/lib/x86_64-linux-gnu/vulkan" ] && [ "${{ matrix.preset }}" = "Vulkan" ]; then
echo "VULKAN_SDK=/usr" >> $GITHUB_ENV
@@ -87,8 +95,8 @@ jobs:
path: /github/home/.cache/ccache
key: ccache-${{ runner.os }}-${{ runner.arch }}-${{ matrix.preset }}-${{ needs.changes.outputs.vendorsha }}
- run: |
cmake --preset ${{ matrix.preset }} ${{ matrix.flags }}
cmake --build --preset ${{ matrix.preset }} --parallel
cmake --preset "${{ matrix.preset }}" ${{ matrix.flags }}
cmake --build --preset "${{ matrix.preset }}" --parallel
windows:
needs: [changes]
@@ -114,12 +122,31 @@ jobs:
flags: '-DAMDGPU_TARGETS=gfx1010 -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -DCMAKE_C_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma" -DCMAKE_CXX_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma"'
- preset: Vulkan
install: https://sdk.lunarg.com/sdk/download/1.4.321.1/windows/vulkansdk-windows-X64-1.4.321.1.exe
- preset: 'MLX CUDA 13'
install: https://developer.download.nvidia.com/compute/cuda/13.0.0/local_installers/cuda_13.0.0_windows.exe
cudnn-install: https://developer.download.nvidia.com/compute/cudnn/redist/cudnn/windows-x86_64/cudnn-windows-x86_64-9.18.1.3_cuda13-archive.zip
flags: '-DCMAKE_CUDA_ARCHITECTURES=80'
cuda-components:
- '"cudart"'
- '"nvcc"'
- '"cublas"'
- '"cublas_dev"'
- '"cufft"'
- '"cufft_dev"'
- '"nvrtc"'
- '"nvrtc_dev"'
- '"crt"'
- '"nvvm"'
- '"nvptxcompiler"'
cuda-version: '13.0'
runs-on: windows
steps:
- run: |
choco install -y --no-progress ccache ninja
ccache -o cache_dir=${{ github.workspace }}\.ccache
- if: matrix.preset == 'CUDA' || matrix.preset == 'ROCm' || matrix.preset == 'Vulkan'
if (Get-Command ccache -ErrorAction SilentlyContinue) {
ccache -o cache_dir=${{ github.workspace }}\.ccache
}
- if: matrix.preset == 'CUDA' || matrix.preset == 'ROCm' || matrix.preset == 'Vulkan' || matrix.preset == 'MLX CUDA 13'
id: cache-install
uses: actions/cache/restore@v4
with:
@@ -127,8 +154,9 @@ jobs:
C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA
C:\Program Files\AMD\ROCm
C:\VulkanSDK
key: ${{ matrix.install }}
- if: matrix.preset == 'CUDA'
C:\Program Files\NVIDIA\CUDNN
key: ${{ matrix.install }}-${{ matrix.cudnn-install }}
- if: matrix.preset == 'CUDA' || matrix.preset == 'MLX CUDA 13'
name: Install CUDA ${{ matrix.cuda-version }}
run: |
$ErrorActionPreference = "Stop"
@@ -164,10 +192,27 @@ jobs:
Invoke-WebRequest -Uri "${{ matrix.install }}" -OutFile "install.exe"
Start-Process -FilePath .\install.exe -ArgumentList "-c","--am","--al","in" -NoNewWindow -Wait
}
$vulkanPath = (Resolve-Path "C:\VulkanSDK\*").path
echo "$vulkanPath\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
echo "VULKAN_SDK=$vulkanPath" >> $env:GITHUB_ENV
- if: matrix.preset == 'MLX CUDA 13'
name: Install cuDNN for MLX
run: |
$ErrorActionPreference = "Stop"
$cudnnRoot = "C:\Program Files\NVIDIA\CUDNN"
if ("${{ steps.cache-install.outputs.cache-hit }}" -ne 'true') {
Invoke-WebRequest -Uri "${{ matrix.cudnn-install }}" -OutFile "cudnn.zip"
Expand-Archive -Path cudnn.zip -DestinationPath cudnn-extracted
$cudnnDir = (Get-ChildItem -Path cudnn-extracted -Directory)[0].FullName
New-Item -ItemType Directory -Force -Path $cudnnRoot
Copy-Item -Path "$cudnnDir\*" -Destination "$cudnnRoot\" -Recurse
}
echo "CUDNN_ROOT_DIR=$cudnnRoot" | Out-File -FilePath $env:GITHUB_ENV -Append
echo "CUDNN_INCLUDE_PATH=$cudnnRoot\include" | Out-File -FilePath $env:GITHUB_ENV -Append
echo "CUDNN_LIBRARY_PATH=$cudnnRoot\lib\x64" | Out-File -FilePath $env:GITHUB_ENV -Append
echo "$cudnnRoot\bin\x64" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
- if: ${{ !cancelled() && steps.cache-install.outputs.cache-hit != 'true' }}
uses: actions/cache/save@v4
with:
@@ -175,7 +220,8 @@ jobs:
C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA
C:\Program Files\AMD\ROCm
C:\VulkanSDK
key: ${{ matrix.install }}
C:\Program Files\NVIDIA\CUDNN
key: ${{ matrix.install }}-${{ matrix.cudnn-install }}
- uses: actions/checkout@v4
- uses: actions/cache@v4
with:

View File

@@ -64,10 +64,15 @@ set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${OLLAMA_BUILD_DIR})
set(CMAKE_LIBRARY_OUTPUT_DIRECTORY_DEBUG ${OLLAMA_BUILD_DIR})
set(CMAKE_LIBRARY_OUTPUT_DIRECTORY_RELEASE ${OLLAMA_BUILD_DIR})
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src)
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/include)
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-cpu)
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-cpu/amx)
# Store ggml include paths for use with target_include_directories later.
# We avoid global include_directories() to prevent polluting the include path
# for other projects like MLX (whose openblas dependency has its own common.h).
set(GGML_INCLUDE_DIRS
${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src
${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/include
${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-cpu
${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-cpu/amx
)
add_compile_definitions(NDEBUG GGML_VERSION=0x0 GGML_COMMIT=0x0)
@@ -87,6 +92,14 @@ if(NOT CPU_VARIANTS)
set(CPU_VARIANTS "ggml-cpu")
endif()
# Apply ggml include directories to ggml targets only (not globally)
target_include_directories(ggml-base PRIVATE ${GGML_INCLUDE_DIRS})
foreach(variant ${CPU_VARIANTS})
if(TARGET ${variant})
target_include_directories(${variant} PRIVATE ${GGML_INCLUDE_DIRS})
endif()
endforeach()
install(TARGETS ggml-base ${CPU_VARIANTS}
RUNTIME_DEPENDENCIES
PRE_EXCLUDE_REGEXES ".*"
@@ -103,6 +116,7 @@ if(CMAKE_CUDA_COMPILER)
find_package(CUDAToolkit)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-cuda)
target_include_directories(ggml-cuda PRIVATE ${GGML_INCLUDE_DIRS})
install(TARGETS ggml-cuda
RUNTIME_DEPENDENCIES
DIRECTORIES ${CUDAToolkit_BIN_DIR} ${CUDAToolkit_BIN_DIR}/x64 ${CUDAToolkit_LIBRARY_DIR}
@@ -134,6 +148,7 @@ if(CMAKE_HIP_COMPILER)
if(AMDGPU_TARGETS)
find_package(hip REQUIRED)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-hip)
target_include_directories(ggml-hip PRIVATE ${GGML_INCLUDE_DIRS})
if (WIN32)
target_compile_definitions(ggml-hip PRIVATE GGML_CUDA_NO_PEER_COPY)
@@ -148,7 +163,7 @@ if(CMAKE_HIP_COMPILER)
)
install(RUNTIME_DEPENDENCY_SET rocm
DIRECTORIES ${HIP_BIN_INSTALL_DIR} ${HIP_LIB_INSTALL_DIR}
PRE_INCLUDE_REGEXES hipblas rocblas amdhip64 rocsolver amd_comgr hsa-runtime64 rocsparse tinfo rocprofiler-register drm drm_amdgpu numa elf
PRE_INCLUDE_REGEXES hipblas rocblas amdhip64 rocsolver amd_comgr hsa-runtime64 rocsparse tinfo rocprofiler-register roctx64 rocroller drm drm_amdgpu numa elf
PRE_EXCLUDE_REGEXES ".*"
POST_EXCLUDE_REGEXES "system32"
RUNTIME DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT HIP
@@ -168,6 +183,7 @@ if(NOT APPLE)
find_package(Vulkan)
if(Vulkan_FOUND)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-vulkan)
target_include_directories(ggml-vulkan PRIVATE ${GGML_INCLUDE_DIRS})
install(TARGETS ggml-vulkan
RUNTIME_DEPENDENCIES
PRE_INCLUDE_REGEXES vulkan
@@ -179,7 +195,6 @@ if(NOT APPLE)
endif()
option(MLX_ENGINE "Enable MLX backend" OFF)
if(MLX_ENGINE)
message(STATUS "Setting up MLX (this takes a while...)")
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/x/imagegen/mlx)
@@ -187,10 +202,36 @@ if(MLX_ENGINE)
# Find CUDA toolkit if MLX is built with CUDA support
find_package(CUDAToolkit)
# Build list of directories for runtime dependency resolution
set(MLX_RUNTIME_DIRS ${CUDAToolkit_BIN_DIR} ${CUDAToolkit_BIN_DIR}/x64 ${CUDAToolkit_LIBRARY_DIR})
# Add cuDNN bin paths for DLLs (Windows MLX CUDA builds)
# CUDNN_ROOT_DIR is the standard CMake variable for cuDNN location
if(DEFINED ENV{CUDNN_ROOT_DIR})
# cuDNN 9.x has versioned subdirectories under bin/ (e.g., bin/13.0/)
file(GLOB CUDNN_BIN_SUBDIRS "$ENV{CUDNN_ROOT_DIR}/bin/*")
list(APPEND MLX_RUNTIME_DIRS ${CUDNN_BIN_SUBDIRS})
endif()
# Add build output directory and MLX dependency build directories
list(APPEND MLX_RUNTIME_DIRS ${OLLAMA_BUILD_DIR})
# OpenBLAS DLL location (pre-built zip extracts into openblas-src/bin/)
list(APPEND MLX_RUNTIME_DIRS ${CMAKE_BINARY_DIR}/_deps/openblas-src/bin)
# NCCL: on Linux, if real NCCL is found, cmake bundles libnccl.so via the
# regex below. If NCCL is not found, MLX links a static stub (OBJECT lib)
# so there is no runtime dependency. This path covers the stub build dir
# for windows so we include the DLL in our dependencies.
list(APPEND MLX_RUNTIME_DIRS ${CMAKE_BINARY_DIR}/_deps/mlx-build/mlx/distributed/nccl/nccl_stub-prefix/src/nccl_stub-build/Release)
# Base regexes for runtime dependencies (cross-platform)
set(MLX_INCLUDE_REGEXES cublas cublasLt cudart cufft nvrtc nvrtc-builtins cudnn nccl openblas gfortran)
# On Windows, also include dl.dll (dlfcn-win32 POSIX emulation layer)
if(WIN32)
list(APPEND MLX_INCLUDE_REGEXES "^dl\\.dll$")
endif()
install(TARGETS mlx mlxc
RUNTIME_DEPENDENCIES
DIRECTORIES ${CUDAToolkit_BIN_DIR} ${CUDAToolkit_BIN_DIR}/x64 ${CUDAToolkit_LIBRARY_DIR}
PRE_INCLUDE_REGEXES cublas cublasLt cudart nvrtc nvrtc-builtins cudnn nccl openblas gfortran
DIRECTORIES ${MLX_RUNTIME_DIRS}
PRE_INCLUDE_REGEXES ${MLX_INCLUDE_REGEXES}
PRE_EXCLUDE_REGEXES ".*"
RUNTIME DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT MLX
LIBRARY DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT MLX
@@ -205,13 +246,54 @@ if(MLX_ENGINE)
COMPONENT MLX)
endif()
# Manually install cudart and cublas since they might not be picked up as direct dependencies
# Install CCCL headers for NVRTC JIT compilation at runtime.
# MLX's own install rules use the default component so they get skipped by
# --component MLX. Headers are installed alongside libmlx in OLLAMA_INSTALL_DIR.
# On Linux, MLX's jit_module.cpp resolves CCCL via
# current_binary_dir().parent_path() / "include" / "cccl", so we create a
# symlink from lib/ollama/include -> ${OLLAMA_RUNNER_DIR}/include
# This will need refinement if we add multiple CUDA versions for MLX in the future.
if(EXISTS ${CMAKE_BINARY_DIR}/_deps/cccl-src/include/cuda)
install(DIRECTORY ${CMAKE_BINARY_DIR}/_deps/cccl-src/include/cuda
DESTINATION ${OLLAMA_INSTALL_DIR}/include/cccl
COMPONENT MLX)
install(DIRECTORY ${CMAKE_BINARY_DIR}/_deps/cccl-src/include/nv
DESTINATION ${OLLAMA_INSTALL_DIR}/include/cccl
COMPONENT MLX)
if(NOT WIN32 AND NOT APPLE)
install(CODE "
set(_link \"${CMAKE_INSTALL_PREFIX}/lib/ollama/include\")
set(_target \"${OLLAMA_RUNNER_DIR}/include\")
if(NOT EXISTS \${_link})
execute_process(COMMAND \${CMAKE_COMMAND} -E create_symlink \${_target} \${_link})
endif()
" COMPONENT MLX)
endif()
endif()
# On Windows, explicitly install dl.dll (dlfcn-win32 POSIX dlopen emulation)
# RUNTIME_DEPENDENCIES auto-excludes it via POST_EXCLUDE_FILES_STRICT because
# dlfcn-win32 is a known CMake target with its own install rules (which install
# to the wrong destination). We must install it explicitly here.
if(WIN32)
install(FILES ${OLLAMA_BUILD_DIR}/dl.dll
DESTINATION ${OLLAMA_INSTALL_DIR}
COMPONENT MLX)
endif()
# Manually install CUDA runtime libraries that MLX loads via dlopen
# (not detected by RUNTIME_DEPENDENCIES since they aren't link-time deps)
if(CUDAToolkit_FOUND)
file(GLOB CUDART_LIBS
file(GLOB MLX_CUDA_LIBS
"${CUDAToolkit_LIBRARY_DIR}/libcudart.so*"
"${CUDAToolkit_LIBRARY_DIR}/libcublas.so*")
if(CUDART_LIBS)
install(FILES ${CUDART_LIBS}
"${CUDAToolkit_LIBRARY_DIR}/libcublas.so*"
"${CUDAToolkit_LIBRARY_DIR}/libcublasLt.so*"
"${CUDAToolkit_LIBRARY_DIR}/libnvrtc.so*"
"${CUDAToolkit_LIBRARY_DIR}/libnvrtc-builtins.so*"
"${CUDAToolkit_LIBRARY_DIR}/libcufft.so*"
"${CUDAToolkit_LIBRARY_DIR}/libcudnn.so*")
if(MLX_CUDA_LIBS)
install(FILES ${MLX_CUDA_LIBS}
DESTINATION ${OLLAMA_INSTALL_DIR}
COMPONENT MLX)
endif()

View File

@@ -77,6 +77,15 @@
"OLLAMA_RUNNER_DIR": "rocm"
}
},
{
"name": "ROCm 7",
"inherits": [ "ROCm" ],
"cacheVariables": {
"CMAKE_HIP_FLAGS": "-parallel-jobs=4",
"AMDGPU_TARGETS": "gfx942;gfx950;gfx1010;gfx1012;gfx1030;gfx1100;gfx1101;gfx1102;gfx1103;gfx1150;gfx1151;gfx1200;gfx1201;gfx908:xnack-;gfx90a:xnack+;gfx90a:xnack-",
"OLLAMA_RUNNER_DIR": "rocm"
}
},
{
"name": "Vulkan",
"inherits": [ "Default" ],
@@ -103,6 +112,7 @@
"name": "MLX CUDA 13",
"inherits": [ "MLX", "CUDA 13" ],
"cacheVariables": {
"MLX_CUDA_ARCHITECTURES": "86;89;90;90a;100;103;75-virtual;80-virtual;110-virtual;120-virtual;121-virtual",
"OLLAMA_RUNNER_DIR": "mlx_cuda_v13"
}
}
@@ -158,6 +168,11 @@
"inherits": [ "ROCm" ],
"configurePreset": "ROCm 6"
},
{
"name": "ROCm 7",
"inherits": [ "ROCm" ],
"configurePreset": "ROCm 7"
},
{
"name": "Vulkan",
"targets": [ "ggml-vulkan" ],

View File

@@ -1,33 +1,23 @@
# vim: filetype=dockerfile
ARG FLAVOR=${TARGETARCH}
ARG PARALLEL=8
ARG ROCMVERSION=6.3.3
ARG ROCMVERSION=7.2
ARG JETPACK5VERSION=r35.4.1
ARG JETPACK6VERSION=r36.4.0
ARG CMAKEVERSION=3.31.2
ARG NINJAVERSION=1.12.1
ARG VULKANVERSION=1.4.321.1
# We require gcc v10 minimum. v10.3 has regressions, so the rockylinux 8.5 AppStream has the latest compatible version
# Default empty stages for local MLX source overrides.
# Override with: docker build --build-context local-mlx=../mlx --build-context local-mlx-c=../mlx-c
FROM scratch AS local-mlx
FROM scratch AS local-mlx-c
FROM --platform=linux/amd64 rocm/dev-almalinux-8:${ROCMVERSION}-complete AS base-amd64
RUN yum install -y yum-utils \
&& yum-config-manager --add-repo https://dl.rockylinux.org/vault/rocky/8.5/AppStream/\$basearch/os/ \
&& rpm --import https://dl.rockylinux.org/pub/rocky/RPM-GPG-KEY-Rocky-8 \
&& dnf install -y yum-utils ccache gcc-toolset-10-gcc-10.2.1-8.2.el8 gcc-toolset-10-gcc-c++-10.2.1-8.2.el8 gcc-toolset-10-binutils-2.35-11.el8 \
&& dnf install -y ccache \
RUN dnf install -y yum-utils ccache gcc-toolset-11-gcc gcc-toolset-11-gcc-c++ gcc-toolset-11-binutils \
&& yum-config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel8/x86_64/cuda-rhel8.repo
ENV PATH=/opt/rh/gcc-toolset-10/root/usr/bin:$PATH
ARG VULKANVERSION
RUN wget https://sdk.lunarg.com/sdk/download/${VULKANVERSION}/linux/vulkansdk-linux-x86_64-${VULKANVERSION}.tar.xz -O /tmp/vulkansdk-linux-x86_64-${VULKANVERSION}.tar.xz \
&& tar xvf /tmp/vulkansdk-linux-x86_64-${VULKANVERSION}.tar.xz \
&& dnf -y install ninja-build \
&& ln -s /usr/bin/python3 /usr/bin/python \
&& /${VULKANVERSION}/vulkansdk -j 8 vulkan-headers \
&& /${VULKANVERSION}/vulkansdk -j 8 shaderc
RUN cp -r /${VULKANVERSION}/x86_64/include/* /usr/local/include/ \
&& cp -r /${VULKANVERSION}/x86_64/lib/* /usr/local/lib
ENV PATH=/${VULKANVERSION}/x86_64/bin:$PATH
ENV PATH=/opt/rh/gcc-toolset-11/root/usr/bin:$PATH
FROM --platform=linux/arm64 almalinux:8 AS base-arm64
# install epel-release for ccache
@@ -38,100 +28,119 @@ ENV CC=clang CXX=clang++
FROM base-${TARGETARCH} AS base
ARG CMAKEVERSION
ARG NINJAVERSION
RUN curl -fsSL https://github.com/Kitware/CMake/releases/download/v${CMAKEVERSION}/cmake-${CMAKEVERSION}-linux-$(uname -m).tar.gz | tar xz -C /usr/local --strip-components 1
RUN dnf install -y unzip \
&& curl -fsSL -o /tmp/ninja.zip https://github.com/ninja-build/ninja/releases/download/v${NINJAVERSION}/ninja-linux$([ "$(uname -m)" = "aarch64" ] && echo "-aarch64").zip \
&& unzip /tmp/ninja.zip -d /usr/local/bin \
&& rm /tmp/ninja.zip
ENV CMAKE_GENERATOR=Ninja
ENV LDFLAGS=-s
FROM base AS cpu
RUN dnf install -y gcc-toolset-11-gcc gcc-toolset-11-gcc-c++
ENV PATH=/opt/rh/gcc-toolset-11/root/usr/bin:$PATH
ARG PARALLEL
COPY CMakeLists.txt CMakePresets.json .
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
RUN --mount=type=cache,target=/root/.ccache \
cmake --preset 'CPU' \
&& cmake --build --parallel ${PARALLEL} --preset 'CPU' \
&& cmake --install build --component CPU --strip --parallel ${PARALLEL}
&& cmake --build --preset 'CPU' -- -l $(nproc) \
&& cmake --install build --component CPU --strip
FROM base AS cuda-11
ARG CUDA11VERSION=11.8
RUN dnf install -y cuda-toolkit-${CUDA11VERSION//./-}
ENV PATH=/usr/local/cuda-11/bin:$PATH
ARG PARALLEL
COPY CMakeLists.txt CMakePresets.json .
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
RUN --mount=type=cache,target=/root/.ccache \
cmake --preset 'CUDA 11' \
&& cmake --build --parallel ${PARALLEL} --preset 'CUDA 11' \
&& cmake --install build --component CUDA --strip --parallel ${PARALLEL}
&& cmake --build --preset 'CUDA 11' -- -l $(nproc) \
&& cmake --install build --component CUDA --strip
FROM base AS cuda-12
ARG CUDA12VERSION=12.8
RUN dnf install -y cuda-toolkit-${CUDA12VERSION//./-}
ENV PATH=/usr/local/cuda-12/bin:$PATH
ARG PARALLEL
COPY CMakeLists.txt CMakePresets.json .
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
RUN --mount=type=cache,target=/root/.ccache \
cmake --preset 'CUDA 12' \
&& cmake --build --parallel ${PARALLEL} --preset 'CUDA 12' \
&& cmake --install build --component CUDA --strip --parallel ${PARALLEL}
&& cmake --build --preset 'CUDA 12' -- -l $(nproc) \
&& cmake --install build --component CUDA --strip
FROM base AS cuda-13
ARG CUDA13VERSION=13.0
RUN dnf install -y cuda-toolkit-${CUDA13VERSION//./-}
ENV PATH=/usr/local/cuda-13/bin:$PATH
ARG PARALLEL
COPY CMakeLists.txt CMakePresets.json .
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
RUN --mount=type=cache,target=/root/.ccache \
cmake --preset 'CUDA 13' \
&& cmake --build --parallel ${PARALLEL} --preset 'CUDA 13' \
&& cmake --install build --component CUDA --strip --parallel ${PARALLEL}
&& cmake --build --preset 'CUDA 13' -- -l $(nproc) \
&& cmake --install build --component CUDA --strip
FROM base AS rocm-6
FROM base AS rocm-7
ENV PATH=/opt/rocm/hcc/bin:/opt/rocm/hip/bin:/opt/rocm/bin:/opt/rocm/hcc/bin:$PATH
ARG PARALLEL
COPY CMakeLists.txt CMakePresets.json .
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
RUN --mount=type=cache,target=/root/.ccache \
cmake --preset 'ROCm 6' \
&& cmake --build --parallel ${PARALLEL} --preset 'ROCm 6' \
&& cmake --install build --component HIP --strip --parallel ${PARALLEL}
cmake --preset 'ROCm 7' \
&& cmake --build --preset 'ROCm 7' -- -l $(nproc) \
&& cmake --install build --component HIP --strip
RUN rm -f dist/lib/ollama/rocm/rocblas/library/*gfx90[06]*
FROM --platform=linux/arm64 nvcr.io/nvidia/l4t-jetpack:${JETPACK5VERSION} AS jetpack-5
ARG CMAKEVERSION
RUN apt-get update && apt-get install -y curl ccache \
&& curl -fsSL https://github.com/Kitware/CMake/releases/download/v${CMAKEVERSION}/cmake-${CMAKEVERSION}-linux-$(uname -m).tar.gz | tar xz -C /usr/local --strip-components 1
ARG NINJAVERSION
RUN apt-get update && apt-get install -y curl ccache unzip \
&& curl -fsSL https://github.com/Kitware/CMake/releases/download/v${CMAKEVERSION}/cmake-${CMAKEVERSION}-linux-$(uname -m).tar.gz | tar xz -C /usr/local --strip-components 1 \
&& curl -fsSL -o /tmp/ninja.zip https://github.com/ninja-build/ninja/releases/download/v${NINJAVERSION}/ninja-linux-aarch64.zip \
&& unzip /tmp/ninja.zip -d /usr/local/bin \
&& rm /tmp/ninja.zip
ENV CMAKE_GENERATOR=Ninja
COPY CMakeLists.txt CMakePresets.json .
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
ARG PARALLEL
RUN --mount=type=cache,target=/root/.ccache \
cmake --preset 'JetPack 5' \
&& cmake --build --parallel ${PARALLEL} --preset 'JetPack 5' \
&& cmake --install build --component CUDA --strip --parallel ${PARALLEL}
&& cmake --build --preset 'JetPack 5' -- -l $(nproc) \
&& cmake --install build --component CUDA --strip
FROM --platform=linux/arm64 nvcr.io/nvidia/l4t-jetpack:${JETPACK6VERSION} AS jetpack-6
ARG CMAKEVERSION
RUN apt-get update && apt-get install -y curl ccache \
&& curl -fsSL https://github.com/Kitware/CMake/releases/download/v${CMAKEVERSION}/cmake-${CMAKEVERSION}-linux-$(uname -m).tar.gz | tar xz -C /usr/local --strip-components 1
ARG NINJAVERSION
RUN apt-get update && apt-get install -y curl ccache unzip \
&& curl -fsSL https://github.com/Kitware/CMake/releases/download/v${CMAKEVERSION}/cmake-${CMAKEVERSION}-linux-$(uname -m).tar.gz | tar xz -C /usr/local --strip-components 1 \
&& curl -fsSL -o /tmp/ninja.zip https://github.com/ninja-build/ninja/releases/download/v${NINJAVERSION}/ninja-linux-aarch64.zip \
&& unzip /tmp/ninja.zip -d /usr/local/bin \
&& rm /tmp/ninja.zip
ENV CMAKE_GENERATOR=Ninja
COPY CMakeLists.txt CMakePresets.json .
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
ARG PARALLEL
RUN --mount=type=cache,target=/root/.ccache \
cmake --preset 'JetPack 6' \
&& cmake --build --parallel ${PARALLEL} --preset 'JetPack 6' \
&& cmake --install build --component CUDA --strip --parallel ${PARALLEL}
&& cmake --build --preset 'JetPack 6' -- -l $(nproc) \
&& cmake --install build --component CUDA --strip
FROM base AS vulkan
ARG VULKANVERSION
RUN ln -s /usr/bin/python3 /usr/bin/python \
&& wget https://sdk.lunarg.com/sdk/download/${VULKANVERSION}/linux/vulkansdk-linux-x86_64-${VULKANVERSION}.tar.xz -O /tmp/vulkansdk.tar.xz \
&& tar xvf /tmp/vulkansdk.tar.xz -C /tmp \
&& /tmp/${VULKANVERSION}/vulkansdk -j 8 vulkan-headers \
&& /tmp/${VULKANVERSION}/vulkansdk -j 8 shaderc \
&& cp -r /tmp/${VULKANVERSION}/x86_64/include/* /usr/local/include/ \
&& cp -r /tmp/${VULKANVERSION}/x86_64/lib/* /usr/local/lib \
&& cp -r /tmp/${VULKANVERSION}/x86_64/bin/* /usr/local/bin/ \
&& rm -rf /tmp/${VULKANVERSION} /tmp/vulkansdk.tar.xz
COPY CMakeLists.txt CMakePresets.json .
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
RUN --mount=type=cache,target=/root/.ccache \
cmake --preset 'Vulkan' \
&& cmake --build --parallel --preset 'Vulkan' \
&& cmake --install build --component Vulkan --strip --parallel 8
&& cmake --build --preset 'Vulkan' -- -l $(nproc) \
&& cmake --install build --component Vulkan --strip
FROM base AS mlx
ARG CUDA13VERSION=13.0
@@ -143,20 +152,27 @@ ENV PATH=/usr/local/cuda-13/bin:$PATH
ENV BLAS_INCLUDE_DIRS=/usr/include/openblas
ENV LAPACK_INCLUDE_DIRS=/usr/include/openblas
ENV CGO_LDFLAGS="-L/usr/local/cuda-13/lib64 -L/usr/local/cuda-13/targets/x86_64-linux/lib/stubs"
ARG PARALLEL
WORKDIR /go/src/github.com/ollama/ollama
COPY CMakeLists.txt CMakePresets.json .
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
COPY x/imagegen/mlx x/imagegen/mlx
COPY go.mod go.sum .
COPY MLX_VERSION .
COPY MLX_VERSION MLX_CORE_VERSION .
RUN curl -fsSL https://golang.org/dl/go$(awk '/^go/ { print $2 }' go.mod).linux-$(case $(uname -m) in x86_64) echo amd64 ;; aarch64) echo arm64 ;; esac).tar.gz | tar xz -C /usr/local
ENV PATH=/usr/local/go/bin:$PATH
RUN go mod download
RUN --mount=type=cache,target=/root/.ccache \
cmake --preset 'MLX CUDA 13' -DBLAS_INCLUDE_DIRS=/usr/include/openblas -DLAPACK_INCLUDE_DIRS=/usr/include/openblas \
&& cmake --build --parallel ${PARALLEL} --preset 'MLX CUDA 13' \
&& cmake --install build --component MLX --strip --parallel ${PARALLEL}
--mount=type=bind,from=local-mlx,target=/tmp/local-mlx \
--mount=type=bind,from=local-mlx-c,target=/tmp/local-mlx-c \
if [ -f /tmp/local-mlx/CMakeLists.txt ]; then \
export OLLAMA_MLX_SOURCE=/tmp/local-mlx; \
fi \
&& if [ -f /tmp/local-mlx-c/CMakeLists.txt ]; then \
export OLLAMA_MLX_C_SOURCE=/tmp/local-mlx-c; \
fi \
&& cmake --preset 'MLX CUDA 13' -DBLAS_INCLUDE_DIRS=/usr/include/openblas -DLAPACK_INCLUDE_DIRS=/usr/include/openblas \
&& cmake --build --preset 'MLX CUDA 13' -- -l $(nproc) \
&& cmake --install build --component MLX --strip
FROM base AS build
WORKDIR /go/src/github.com/ollama/ollama
@@ -165,16 +181,14 @@ RUN curl -fsSL https://golang.org/dl/go$(awk '/^go/ { print $2 }' go.mod).linux-
ENV PATH=/usr/local/go/bin:$PATH
RUN go mod download
COPY . .
# Clone mlx-c headers for CGO (version from MLX_VERSION file)
RUN git clone --depth 1 --branch "$(cat MLX_VERSION)" https://github.com/ml-explore/mlx-c.git build/_deps/mlx-c-src
ARG GOFLAGS="'-ldflags=-w -s'"
ENV CGO_ENABLED=1
ARG CGO_CFLAGS
ARG CGO_CXXFLAGS
ENV CGO_CFLAGS="${CGO_CFLAGS} -I/go/src/github.com/ollama/ollama/build/_deps/mlx-c-src"
ENV CGO_CFLAGS="${CGO_CFLAGS}"
ENV CGO_CXXFLAGS="${CGO_CXXFLAGS}"
RUN --mount=type=cache,target=/root/.cache/go-build \
go build -tags mlx -trimpath -buildmode=pie -o /bin/ollama .
go build -trimpath -buildmode=pie -o /bin/ollama .
FROM --platform=linux/amd64 scratch AS amd64
# COPY --from=cuda-11 dist/lib/ollama/ /lib/ollama/
@@ -191,10 +205,9 @@ COPY --from=jetpack-5 dist/lib/ollama/ /lib/ollama/
COPY --from=jetpack-6 dist/lib/ollama/ /lib/ollama/
FROM scratch AS rocm
COPY --from=rocm-6 dist/lib/ollama /lib/ollama
COPY --from=rocm-7 dist/lib/ollama /lib/ollama
FROM ${FLAVOR} AS archive
ARG VULKANVERSION
COPY --from=cpu dist/lib/ollama /lib/ollama
COPY --from=build /bin/ollama /bin/ollama

1
MLX_CORE_VERSION Normal file
View File

@@ -0,0 +1 @@
v0.30.6

View File

@@ -1 +1 @@
v0.4.1
v0.5.0

View File

@@ -852,6 +852,19 @@ func (c *StreamConverter) Process(r api.ChatResponse) []StreamEvent {
continue
}
// Close thinking block if still open (thinking → tool_use without text in between)
if c.thinkingStarted && !c.thinkingDone {
c.thinkingDone = true
events = append(events, StreamEvent{
Event: "content_block_stop",
Data: ContentBlockStopEvent{
Type: "content_block_stop",
Index: c.contentIndex,
},
})
c.contentIndex++
}
if c.textStarted {
events = append(events, StreamEvent{
Event: "content_block_stop",

View File

@@ -799,6 +799,107 @@ func TestStreamConverter_WithToolCalls(t *testing.T) {
}
}
// TestStreamConverter_ThinkingDirectlyFollowedByToolCall verifies that when a
// model emits a thinking block followed directly by a tool_use block (with no
// text block in between), the streaming converter correctly closes the thinking
// block and increments the content index before opening the tool_use block.
// Previously, the converter reused contentIndex=0 for the tool_use block,
// which caused "Content block not found" errors in clients. See #14816.
func TestStreamConverter_ThinkingDirectlyFollowedByToolCall(t *testing.T) {
conv := NewStreamConverter("msg_123", "test-model", 0)
// First chunk: thinking content (no text)
resp1 := api.ChatResponse{
Model: "test-model",
Message: api.Message{
Role: "assistant",
Thinking: "I should call the tool.",
},
}
events1 := conv.Process(resp1)
// Should have: message_start, content_block_start(thinking), content_block_delta(thinking)
if len(events1) < 3 {
t.Fatalf("expected at least 3 events for thinking chunk, got %d", len(events1))
}
if events1[0].Event != "message_start" {
t.Errorf("expected first event 'message_start', got %q", events1[0].Event)
}
thinkingStart, ok := events1[1].Data.(ContentBlockStartEvent)
if !ok || thinkingStart.ContentBlock.Type != "thinking" {
t.Errorf("expected content_block_start(thinking) as second event, got %+v", events1[1])
}
if thinkingStart.Index != 0 {
t.Errorf("expected thinking block at index 0, got %d", thinkingStart.Index)
}
// Second chunk: tool call (no text between thinking and tool)
resp2 := api.ChatResponse{
Model: "test-model",
Message: api.Message{
Role: "assistant",
ToolCalls: []api.ToolCall{
{
ID: "call_abc",
Function: api.ToolCallFunction{
Name: "ask_user",
Arguments: testArgs(map[string]any{"question": "cats or dogs?"}),
},
},
},
},
Done: true,
DoneReason: "stop",
Metrics: api.Metrics{PromptEvalCount: 10, EvalCount: 5},
}
events2 := conv.Process(resp2)
// Expect: content_block_stop(index=0), content_block_start(tool_use, index=1),
// content_block_delta(input_json_delta, index=1), content_block_stop(index=1),
// message_delta, message_stop
var thinkingStop, toolStart, toolDelta, toolStop *StreamEvent
for i := range events2 {
e := &events2[i]
switch e.Event {
case "content_block_stop":
if stop, ok := e.Data.(ContentBlockStopEvent); ok {
if stop.Index == 0 && thinkingStop == nil {
thinkingStop = e
} else if stop.Index == 1 {
toolStop = e
}
}
case "content_block_start":
if start, ok := e.Data.(ContentBlockStartEvent); ok && start.ContentBlock.Type == "tool_use" {
toolStart = e
}
case "content_block_delta":
if delta, ok := e.Data.(ContentBlockDeltaEvent); ok && delta.Delta.Type == "input_json_delta" {
toolDelta = e
}
}
}
if thinkingStop == nil {
t.Error("expected content_block_stop for thinking block (index 0)")
}
if toolStart == nil {
t.Fatal("expected content_block_start for tool_use block")
}
if start, ok := toolStart.Data.(ContentBlockStartEvent); !ok || start.Index != 1 {
t.Errorf("expected tool_use block at index 1, got %+v", toolStart.Data)
}
if toolDelta == nil {
t.Fatal("expected input_json_delta event for tool call")
}
if delta, ok := toolDelta.Data.(ContentBlockDeltaEvent); !ok || delta.Index != 1 {
t.Errorf("expected tool delta at index 1, got %+v", toolDelta.Data)
}
if toolStop == nil {
t.Error("expected content_block_stop for tool_use block (index 1)")
}
}
func TestStreamConverter_ToolCallWithUnmarshalableArgs(t *testing.T) {
// Test that unmarshalable arguments (like channels) are handled gracefully
// and don't cause a panic or corrupt stream

View File

@@ -476,25 +476,3 @@ func (c *Client) Whoami(ctx context.Context) (*UserResponse, error) {
}
return &resp, nil
}
// AliasRequest is the request body for creating or updating a model alias.
type AliasRequest struct {
Alias string `json:"alias"`
Target string `json:"target"`
PrefixMatching bool `json:"prefix_matching,omitempty"`
}
// SetAliasExperimental creates or updates a model alias via the experimental aliases API.
func (c *Client) SetAliasExperimental(ctx context.Context, req *AliasRequest) error {
return c.do(ctx, http.MethodPost, "/api/experimental/aliases", req, nil)
}
// AliasDeleteRequest is the request body for deleting a model alias.
type AliasDeleteRequest struct {
Alias string `json:"alias"`
}
// DeleteAliasExperimental deletes a model alias via the experimental aliases API.
func (c *Client) DeleteAliasExperimental(ctx context.Context, req *AliasDeleteRequest) error {
return c.do(ctx, http.MethodDelete, "/api/experimental/aliases", req, nil)
}

View File

@@ -35,6 +35,7 @@ import (
var (
wv = &Webview{}
uiServerPort int
appStore *store.Store
)
var debug = strings.EqualFold(os.Getenv("OLLAMA_DEBUG"), "true") || os.Getenv("OLLAMA_DEBUG") == "1"
@@ -208,6 +209,7 @@ func main() {
uiServerPort = port
st := &store.Store{}
appStore = st
// Enable CORS in development mode
if devMode {
@@ -253,6 +255,8 @@ func main() {
done <- osrv.Run(octx)
}()
upd := &updater.Updater{Store: st}
uiServer := ui.Server{
Token: token,
Restart: func() {
@@ -267,6 +271,10 @@ func main() {
ToolRegistry: toolRegistry,
Dev: devMode,
Logger: slog.Default(),
Updater: upd,
UpdateAvailableFunc: func() {
UpdateAvailable("")
},
}
srv := &http.Server{
@@ -284,8 +292,20 @@ func main() {
slog.Debug("background desktop server done")
}()
updater := &updater.Updater{Store: st}
updater.StartBackgroundUpdaterChecker(ctx, UpdateAvailable)
upd.StartBackgroundUpdaterChecker(ctx, UpdateAvailable)
// Check for pending updates on startup (show tray notification if update is ready)
if updater.IsUpdatePending() {
// On Windows, the tray is initialized in osRun(). Calling UpdateAvailable
// before that would dereference a nil tray callback.
// TODO: refactor so the update check runs after platform init on all platforms.
if runtime.GOOS == "windows" {
slog.Debug("update pending on startup, deferring tray notification until tray initialization")
} else {
slog.Debug("update pending on startup, showing tray notification")
UpdateAvailable("")
}
}
hasCompletedFirstRun, err := st.HasCompletedFirstRun()
if err != nil {
@@ -348,6 +368,17 @@ func startHiddenTasks() {
// CLI triggered app startup use-case
slog.Info("deferring pending update for fast startup")
} else {
// Check if auto-update is enabled before automatically upgrading
settings, err := appStore.Settings()
if err != nil {
slog.Warn("failed to load settings for upgrade check", "error", err)
} else if !settings.AutoUpdateEnabled {
slog.Info("auto-update disabled, skipping automatic upgrade at startup")
// Still show tray notification so user knows update is ready
UpdateAvailable("")
return
}
if err := updater.DoUpgradeAtStartup(); err != nil {
slog.Info("unable to perform upgrade at startup", "error", err)
// Make sure the restart to upgrade menu shows so we can attempt an interactive upgrade to get authorization

View File

@@ -154,6 +154,10 @@ func handleURLSchemeRequest(urlScheme string) {
}
func UpdateAvailable(ver string) error {
if app.t == nil {
slog.Debug("tray not yet initialized, skipping update notification")
return nil
}
return app.t.UpdateAvailable(ver)
}
@@ -165,6 +169,14 @@ func osRun(shutdown func(), hasCompletedFirstRun, startHidden bool) {
log.Fatalf("Failed to start: %s", err)
}
// Check for pending updates now that the tray is initialized.
// The platform-independent check in app.go fires before osRun,
// when app.t is still nil, so we must re-check here.
if updater.IsUpdatePending() {
slog.Debug("update pending on startup, showing tray notification")
UpdateAvailable("")
}
signals := make(chan os.Signal, 1)
signal.Notify(signals, syscall.SIGINT, syscall.SIGTERM)

View File

@@ -41,6 +41,11 @@ type InferenceCompute struct {
VRAM string
}
type InferenceInfo struct {
Computes []InferenceCompute
DefaultContextLength int
}
func New(s *store.Store, devMode bool) *Server {
p := resolvePath("ollama")
return &Server{store: s, bin: p, dev: devMode}
@@ -272,9 +277,12 @@ func openRotatingLog() (io.WriteCloser, error) {
// Attempt to retrieve inference compute information from the server
// log. Set ctx to timeout to control how long to wait for the logs to appear
func GetInferenceComputer(ctx context.Context) ([]InferenceCompute, error) {
inference := []InferenceCompute{}
marker := regexp.MustCompile(`inference compute.*library=`)
func GetInferenceInfo(ctx context.Context) (*InferenceInfo, error) {
info := &InferenceInfo{}
computeMarker := regexp.MustCompile(`inference compute.*library=`)
defaultCtxMarker := regexp.MustCompile(`vram-based default context`)
defaultCtxRegex := regexp.MustCompile(`default_num_ctx=(\d+)`)
q := `inference compute.*%s=["]([^"]*)["]`
nq := `inference compute.*%s=(\S+)\s`
type regex struct {
@@ -340,8 +348,8 @@ func GetInferenceComputer(ctx context.Context) ([]InferenceCompute, error) {
scanner := bufio.NewScanner(file)
for scanner.Scan() {
line := scanner.Text()
match := marker.FindStringSubmatch(line)
if len(match) > 0 {
// Check for inference compute lines
if computeMarker.MatchString(line) {
ic := InferenceCompute{
Library: get("library", line),
Variant: get("variant", line),
@@ -352,12 +360,25 @@ func GetInferenceComputer(ctx context.Context) ([]InferenceCompute, error) {
}
slog.Info("Matched", "inference compute", ic)
inference = append(inference, ic)
} else {
// Break out on first non matching line after we start matching
if len(inference) > 0 {
return inference, nil
info.Computes = append(info.Computes, ic)
continue
}
// Check for default context length line
if defaultCtxMarker.MatchString(line) {
match := defaultCtxRegex.FindStringSubmatch(line)
if len(match) > 1 {
numCtx, err := strconv.Atoi(match[1])
if err == nil {
info.DefaultContextLength = numCtx
slog.Info("Matched default context length", "default_num_ctx", numCtx)
}
}
return info, nil
}
// If we've found compute info but hit a non-matching line, return what we have
// This handles older server versions that don't log the default context line
if len(info.Computes) > 0 {
return info, nil
}
}
time.Sleep(100 * time.Millisecond)

View File

@@ -205,44 +205,50 @@ func TestServerCmdCloudSettingEnv(t *testing.T) {
}
}
func TestGetInferenceComputer(t *testing.T) {
func TestGetInferenceInfo(t *testing.T) {
tests := []struct {
name string
log string
exp []InferenceCompute
name string
log string
expComputes []InferenceCompute
expDefaultCtxLen int
}{
{
name: "metal",
log: `time=2025-06-30T09:23:07.374-07:00 level=DEBUG source=sched.go:108 msg="starting llm scheduler"
time=2025-06-30T09:23:07.416-07:00 level=INFO source=types.go:130 msg="inference compute" id=0 library=metal variant="" compute="" driver=0.0 name="" total="96.0 GiB" available="96.0 GiB"
time=2025-06-30T09:23:07.417-07:00 level=INFO source=routes.go:1721 msg="vram-based default context" total_vram="96.0 GiB" default_num_ctx=262144
time=2025-06-30T09:25:56.197-07:00 level=DEBUG source=ggml.go:155 msg="key not found" key=general.alignment default=32
`,
exp: []InferenceCompute{{
expComputes: []InferenceCompute{{
Library: "metal",
Driver: "0.0",
VRAM: "96.0 GiB",
}},
expDefaultCtxLen: 262144,
},
{
name: "cpu",
log: `time=2025-07-01T17:59:51.470Z level=INFO source=gpu.go:377 msg="no compatible GPUs were discovered"
time=2025-07-01T17:59:51.470Z level=INFO source=types.go:130 msg="inference compute" id=0 library=cpu variant="" compute="" driver=0.0 name="" total="31.3 GiB" available="30.4 GiB"
time=2025-07-01T17:59:51.471Z level=INFO source=routes.go:1721 msg="vram-based default context" total_vram="31.3 GiB" default_num_ctx=32768
[GIN] 2025/07/01 - 18:00:09 | 200 | 50.263µs | 100.126.204.152 | HEAD "/"
`,
exp: []InferenceCompute{{
expComputes: []InferenceCompute{{
Library: "cpu",
Driver: "0.0",
VRAM: "31.3 GiB",
}},
expDefaultCtxLen: 32768,
},
{
name: "cuda1",
log: `time=2025-07-01T19:33:43.162Z level=DEBUG source=amd_linux.go:419 msg="amdgpu driver not detected /sys/module/amdgpu"
releasing cuda driver library
time=2025-07-01T19:33:43.162Z level=INFO source=types.go:130 msg="inference compute" id=GPU-452cac9f-6960-839c-4fb3-0cec83699196 library=cuda variant=v12 compute=6.1 driver=12.7 name="NVIDIA GeForce GT 1030" total="3.9 GiB" available="3.9 GiB"
time=2025-07-01T19:33:43.163Z level=INFO source=routes.go:1721 msg="vram-based default context" total_vram="3.9 GiB" default_num_ctx=4096
[GIN] 2025/07/01 - 18:00:09 | 200 | 50.263µs | 100.126.204.152 | HEAD "/"
`,
exp: []InferenceCompute{{
expComputes: []InferenceCompute{{
Library: "cuda",
Variant: "v12",
Compute: "6.1",
@@ -250,6 +256,7 @@ time=2025-07-01T19:33:43.162Z level=INFO source=types.go:130 msg="inference comp
Name: "NVIDIA GeForce GT 1030",
VRAM: "3.9 GiB",
}},
expDefaultCtxLen: 4096,
},
{
name: "frank",
@@ -257,9 +264,10 @@ time=2025-07-01T19:33:43.162Z level=INFO source=types.go:130 msg="inference comp
releasing cuda driver library
time=2025-07-01T19:36:13.315Z level=INFO source=types.go:130 msg="inference compute" id=GPU-d6de3398-9932-6902-11ec-fee8e424c8a2 library=cuda variant=v12 compute=7.5 driver=12.8 name="NVIDIA GeForce RTX 2080 Ti" total="10.6 GiB" available="10.4 GiB"
time=2025-07-01T19:36:13.315Z level=INFO source=types.go:130 msg="inference compute" id=GPU-9abb57639fa80c50 library=rocm variant="" compute=gfx1030 driver=6.3 name=1002:73bf total="16.0 GiB" available="1.3 GiB"
time=2025-07-01T19:36:13.316Z level=INFO source=routes.go:1721 msg="vram-based default context" total_vram="26.6 GiB" default_num_ctx=32768
[GIN] 2025/07/01 - 18:00:09 | 200 | 50.263µs | 100.126.204.152 | HEAD "/"
`,
exp: []InferenceCompute{
expComputes: []InferenceCompute{
{
Library: "cuda",
Variant: "v12",
@@ -276,6 +284,20 @@ time=2025-07-01T19:33:43.162Z level=INFO source=types.go:130 msg="inference comp
VRAM: "16.0 GiB",
},
},
expDefaultCtxLen: 32768,
},
{
name: "missing_default_context",
log: `time=2025-06-30T09:23:07.374-07:00 level=DEBUG source=sched.go:108 msg="starting llm scheduler"
time=2025-06-30T09:23:07.416-07:00 level=INFO source=types.go:130 msg="inference compute" id=0 library=metal variant="" compute="" driver=0.0 name="" total="96.0 GiB" available="96.0 GiB"
time=2025-06-30T09:25:56.197-07:00 level=DEBUG source=ggml.go:155 msg="key not found" key=general.alignment default=32
`,
expComputes: []InferenceCompute{{
Library: "metal",
Driver: "0.0",
VRAM: "96.0 GiB",
}},
expDefaultCtxLen: 0, // No default context line, should return 0
},
}
for _, tt := range tests {
@@ -288,18 +310,21 @@ time=2025-07-01T19:33:43.162Z level=INFO source=types.go:130 msg="inference comp
}
ctx, cancel := context.WithTimeout(t.Context(), 10*time.Millisecond)
defer cancel()
ics, err := GetInferenceComputer(ctx)
info, err := GetInferenceInfo(ctx)
if err != nil {
t.Fatalf(" failed to get inference compute: %v", err)
t.Fatalf("failed to get inference info: %v", err)
}
if !reflect.DeepEqual(ics, tt.exp) {
t.Fatalf("got:\n%#v\nwant:\n%#v", ics, tt.exp)
if !reflect.DeepEqual(info.Computes, tt.expComputes) {
t.Fatalf("computes mismatch\ngot:\n%#v\nwant:\n%#v", info.Computes, tt.expComputes)
}
if info.DefaultContextLength != tt.expDefaultCtxLen {
t.Fatalf("default context length mismatch: got %d, want %d", info.DefaultContextLength, tt.expDefaultCtxLen)
}
})
}
}
func TestGetInferenceComputerTimeout(t *testing.T) {
func TestGetInferenceInfoTimeout(t *testing.T) {
ctx, cancel := context.WithTimeout(t.Context(), 10*time.Millisecond)
defer cancel()
tmpDir := t.TempDir()
@@ -308,7 +333,7 @@ func TestGetInferenceComputerTimeout(t *testing.T) {
if err != nil {
t.Fatalf("failed to write log file %s: %s", serverLogPath, err)
}
_, err = GetInferenceComputer(ctx)
_, err = GetInferenceInfo(ctx)
if err == nil {
t.Fatal("expected timeout")
}

View File

@@ -9,12 +9,12 @@ import (
"strings"
"time"
sqlite3 "github.com/mattn/go-sqlite3"
_ "github.com/mattn/go-sqlite3"
)
// currentSchemaVersion defines the current database schema version.
// Increment this when making schema changes that require migrations.
const currentSchemaVersion = 13
const currentSchemaVersion = 15
// database wraps the SQLite connection.
// SQLite handles its own locking for concurrent access:
@@ -73,7 +73,7 @@ func (db *database) init() error {
agent BOOLEAN NOT NULL DEFAULT 0,
tools BOOLEAN NOT NULL DEFAULT 0,
working_dir TEXT NOT NULL DEFAULT '',
context_length INTEGER NOT NULL DEFAULT 4096,
context_length INTEGER NOT NULL DEFAULT 0,
window_width INTEGER NOT NULL DEFAULT 0,
window_height INTEGER NOT NULL DEFAULT 0,
config_migrated BOOLEAN NOT NULL DEFAULT 0,
@@ -86,6 +86,7 @@ func (db *database) init() error {
think_level TEXT NOT NULL DEFAULT '',
cloud_setting_migrated BOOLEAN NOT NULL DEFAULT 0,
remote TEXT NOT NULL DEFAULT '', -- deprecated
auto_update_enabled BOOLEAN NOT NULL DEFAULT 1,
schema_version INTEGER NOT NULL DEFAULT %d
);
@@ -251,6 +252,18 @@ func (db *database) migrate() error {
return fmt.Errorf("migrate v12 to v13: %w", err)
}
version = 13
case 13:
// change default context_length from 4096 to 0 (VRAM-based tiered defaults)
if err := db.migrateV13ToV14(); err != nil {
return fmt.Errorf("migrate v13 to v14: %w", err)
}
version = 14
case 14:
// add auto_update_enabled column to settings table
if err := db.migrateV14ToV15(); err != nil {
return fmt.Errorf("migrate v14 to v15: %w", err)
}
version = 15
default:
// If we have a version we don't recognize, just set it to current
// This might happen during development
@@ -474,6 +487,37 @@ func (db *database) migrateV12ToV13() error {
return nil
}
// migrateV13ToV14 changes the default context_length from 4096 to 0.
// When context_length is 0, the ollama server uses VRAM-based tiered defaults.
func (db *database) migrateV13ToV14() error {
_, err := db.conn.Exec(`UPDATE settings SET context_length = 0 WHERE context_length = 4096`)
if err != nil {
return fmt.Errorf("update context_length default: %w", err)
}
_, err = db.conn.Exec(`UPDATE settings SET schema_version = 14`)
if err != nil {
return fmt.Errorf("update schema version: %w", err)
}
return nil
}
// migrateV14ToV15 adds the auto_update_enabled column to the settings table
func (db *database) migrateV14ToV15() error {
_, err := db.conn.Exec(`ALTER TABLE settings ADD COLUMN auto_update_enabled BOOLEAN NOT NULL DEFAULT 1`)
if err != nil && !duplicateColumnError(err) {
return fmt.Errorf("add auto_update_enabled column: %w", err)
}
_, err = db.conn.Exec(`UPDATE settings SET schema_version = 15`)
if err != nil {
return fmt.Errorf("update schema version: %w", err)
}
return nil
}
// cleanupOrphanedData removes orphaned records that may exist due to the foreign key bug
func (db *database) cleanupOrphanedData() error {
_, err := db.conn.Exec(`
@@ -504,19 +548,11 @@ func (db *database) cleanupOrphanedData() error {
}
func duplicateColumnError(err error) bool {
if sqlite3Err, ok := err.(sqlite3.Error); ok {
return sqlite3Err.Code == sqlite3.ErrError &&
strings.Contains(sqlite3Err.Error(), "duplicate column name")
}
return false
return err != nil && strings.Contains(err.Error(), "duplicate column name")
}
func columnNotExists(err error) bool {
if sqlite3Err, ok := err.(sqlite3.Error); ok {
return sqlite3Err.Code == sqlite3.ErrError &&
strings.Contains(sqlite3Err.Error(), "no such column")
}
return false
return err != nil && strings.Contains(err.Error(), "no such column")
}
func (db *database) getAllChats() ([]Chat, error) {
@@ -1130,9 +1166,9 @@ func (db *database) getSettings() (Settings, error) {
var s Settings
err := db.conn.QueryRow(`
SELECT expose, survey, browser, models, agent, tools, working_dir, context_length, turbo_enabled, websearch_enabled, selected_model, sidebar_open, think_enabled, think_level
SELECT expose, survey, browser, models, agent, tools, working_dir, context_length, turbo_enabled, websearch_enabled, selected_model, sidebar_open, think_enabled, think_level, auto_update_enabled
FROM settings
`).Scan(&s.Expose, &s.Survey, &s.Browser, &s.Models, &s.Agent, &s.Tools, &s.WorkingDir, &s.ContextLength, &s.TurboEnabled, &s.WebSearchEnabled, &s.SelectedModel, &s.SidebarOpen, &s.ThinkEnabled, &s.ThinkLevel)
`).Scan(&s.Expose, &s.Survey, &s.Browser, &s.Models, &s.Agent, &s.Tools, &s.WorkingDir, &s.ContextLength, &s.TurboEnabled, &s.WebSearchEnabled, &s.SelectedModel, &s.SidebarOpen, &s.ThinkEnabled, &s.ThinkLevel, &s.AutoUpdateEnabled)
if err != nil {
return Settings{}, fmt.Errorf("get settings: %w", err)
}
@@ -1142,9 +1178,9 @@ func (db *database) getSettings() (Settings, error) {
func (db *database) setSettings(s Settings) error {
_, err := db.conn.Exec(`
UPDATE settings
SET expose = ?, survey = ?, browser = ?, models = ?, agent = ?, tools = ?, working_dir = ?, context_length = ?, turbo_enabled = ?, websearch_enabled = ?, selected_model = ?, sidebar_open = ?, think_enabled = ?, think_level = ?
`, s.Expose, s.Survey, s.Browser, s.Models, s.Agent, s.Tools, s.WorkingDir, s.ContextLength, s.TurboEnabled, s.WebSearchEnabled, s.SelectedModel, s.SidebarOpen, s.ThinkEnabled, s.ThinkLevel)
UPDATE settings
SET expose = ?, survey = ?, browser = ?, models = ?, agent = ?, tools = ?, working_dir = ?, context_length = ?, turbo_enabled = ?, websearch_enabled = ?, selected_model = ?, sidebar_open = ?, think_enabled = ?, think_level = ?, auto_update_enabled = ?
`, s.Expose, s.Survey, s.Browser, s.Models, s.Agent, s.Tools, s.WorkingDir, s.ContextLength, s.TurboEnabled, s.WebSearchEnabled, s.SelectedModel, s.SidebarOpen, s.ThinkEnabled, s.ThinkLevel, s.AutoUpdateEnabled)
if err != nil {
return fmt.Errorf("set settings: %w", err)
}

View File

@@ -98,6 +98,43 @@ func TestSchemaMigrations(t *testing.T) {
})
}
func TestMigrationV13ToV14ContextLength(t *testing.T) {
tmpDir := t.TempDir()
dbPath := filepath.Join(tmpDir, "test.db")
db, err := newDatabase(dbPath)
if err != nil {
t.Fatalf("failed to create database: %v", err)
}
defer db.Close()
_, err = db.conn.Exec("UPDATE settings SET context_length = 4096, schema_version = 13")
if err != nil {
t.Fatalf("failed to seed v13 settings row: %v", err)
}
if err := db.migrate(); err != nil {
t.Fatalf("migration from v13 to v14 failed: %v", err)
}
var contextLength int
if err := db.conn.QueryRow("SELECT context_length FROM settings").Scan(&contextLength); err != nil {
t.Fatalf("failed to read context_length: %v", err)
}
if contextLength != 0 {
t.Fatalf("expected context_length to migrate to 0, got %d", contextLength)
}
version, err := db.getSchemaVersion()
if err != nil {
t.Fatalf("failed to get schema version: %v", err)
}
if version != currentSchemaVersion {
t.Fatalf("expected schema version %d, got %d", currentSchemaVersion, version)
}
}
func TestChatDeletionWithCascade(t *testing.T) {
t.Run("chat deletion cascades to related messages", func(t *testing.T) {
tmpDir := t.TempDir()

View File

@@ -166,6 +166,9 @@ type Settings struct {
// SidebarOpen indicates if the chat sidebar is open
SidebarOpen bool
// AutoUpdateEnabled indicates if automatic updates should be downloaded
AutoUpdateEnabled bool
}
type Store struct {

View File

@@ -13,7 +13,7 @@ CREATE TABLE IF NOT EXISTS settings (
agent BOOLEAN NOT NULL DEFAULT 0,
tools BOOLEAN NOT NULL DEFAULT 0,
working_dir TEXT NOT NULL DEFAULT '',
context_length INTEGER NOT NULL DEFAULT 4096,
context_length INTEGER NOT NULL DEFAULT 0,
window_width INTEGER NOT NULL DEFAULT 0,
window_height INTEGER NOT NULL DEFAULT 0,
config_migrated BOOLEAN NOT NULL DEFAULT 0,

View File

@@ -289,10 +289,12 @@ export class InferenceCompute {
}
export class InferenceComputeResponse {
inferenceComputes: InferenceCompute[];
defaultContextLength: number;
constructor(source: any = {}) {
if ('string' === typeof source) source = JSON.parse(source);
this.inferenceComputes = this.convertValues(source["inferenceComputes"], InferenceCompute);
this.defaultContextLength = source["defaultContextLength"];
}
convertValues(a: any, classs: any, asMap: boolean = false): any {
@@ -412,6 +414,7 @@ export class Settings {
ThinkLevel: string;
SelectedModel: string;
SidebarOpen: boolean;
AutoUpdateEnabled: boolean;
constructor(source: any = {}) {
if ('string' === typeof source) source = JSON.parse(source);
@@ -429,6 +432,7 @@ export class Settings {
this.ThinkLevel = source["ThinkLevel"];
this.SelectedModel = source["SelectedModel"];
this.SidebarOpen = source["SidebarOpen"];
this.AutoUpdateEnabled = source["AutoUpdateEnabled"];
}
}
export class SettingsResponse {

View File

@@ -4,7 +4,6 @@ import {
ChatEvent,
DownloadEvent,
ErrorEvent,
InferenceCompute,
InferenceComputeResponse,
ModelCapabilitiesResponse,
Model,
@@ -407,7 +406,7 @@ export async function* pullModel(
}
}
export async function getInferenceCompute(): Promise<InferenceCompute[]> {
export async function getInferenceCompute(): Promise<InferenceComputeResponse> {
const response = await fetch(`${API_BASE}/api/v1/inference-compute`);
if (!response.ok) {
throw new Error(
@@ -416,8 +415,7 @@ export async function getInferenceCompute(): Promise<InferenceCompute[]> {
}
const data = await response.json();
const inferenceComputeResponse = new InferenceComputeResponse(data);
return inferenceComputeResponse.inferenceComputes || [];
return new InferenceComputeResponse(data);
}
export async function fetchHealth(): Promise<boolean> {

View File

@@ -17,7 +17,10 @@ import {
} from "@/hooks/useChats";
import { useNavigate } from "@tanstack/react-router";
import { useSelectedModel } from "@/hooks/useSelectedModel";
import { useHasVisionCapability } from "@/hooks/useModelCapabilities";
import {
useHasVisionCapability,
useHasToolsCapability,
} from "@/hooks/useModelCapabilities";
import { useUser } from "@/hooks/useUser";
import { DisplayLogin } from "@/components/DisplayLogin";
import { ErrorEvent, Message } from "@/gotypes";
@@ -149,12 +152,7 @@ function ChatForm({
} = useSettings();
const { cloudDisabled } = useCloudStatus();
// current supported models for web search
const modelLower = selectedModel?.model.toLowerCase() || "";
const supportsWebSearch =
modelLower.startsWith("gpt-oss") ||
modelLower.startsWith("qwen3") ||
modelLower.startsWith("deepseek-v3");
const supportsWebSearch = useHasToolsCapability(selectedModel?.model);
// Use per-chat thinking level instead of global
const thinkLevel: ThinkingLevel =
settingsThinkLevel === "none" || !settingsThinkLevel

View File

@@ -15,6 +15,7 @@ import {
XMarkIcon,
CogIcon,
ArrowLeftIcon,
ArrowDownTrayIcon,
} from "@heroicons/react/20/solid";
import { Settings as SettingsType } from "@/gotypes";
import { useNavigate } from "@tanstack/react-router";
@@ -26,6 +27,7 @@ import {
type CloudStatusResponse,
updateCloudSetting,
updateSettings,
getInferenceCompute,
} from "@/api";
function AnimatedDots() {
@@ -77,6 +79,13 @@ export default function Settings() {
const settings = settingsData?.settings || null;
const { data: inferenceComputeResponse } = useQuery({
queryKey: ["inferenceCompute"],
queryFn: getInferenceCompute,
});
const defaultContextLength = inferenceComputeResponse?.defaultContextLength;
const updateSettingsMutation = useMutation({
mutationFn: updateSettings,
onSuccess: () => {
@@ -204,7 +213,8 @@ export default function Settings() {
Models: "",
Agent: false,
Tools: false,
ContextLength: 4096,
ContextLength: 0,
AutoUpdateEnabled: true,
});
updateSettingsMutation.mutate(defaultSettings);
}
@@ -432,6 +442,29 @@ export default function Settings() {
</div>
</Field>
{/* Auto Update */}
<Field>
<div className="flex items-start justify-between gap-4">
<div className="flex items-start space-x-3 flex-1">
<ArrowDownTrayIcon className="mt-1 h-5 w-5 flex-shrink-0 text-black dark:text-neutral-100" />
<div>
<Label>Auto-download updates</Label>
<Description>
{settings.AutoUpdateEnabled
? "Automatically download updates when available."
: "Updates will not be downloaded automatically."}
</Description>
</div>
</div>
<div className="flex-shrink-0">
<Switch
checked={settings.AutoUpdateEnabled}
onChange={(checked) => handleChange("AutoUpdateEnabled", checked)}
/>
</div>
</div>
</Field>
{/* Expose Ollama */}
<Field>
<div className="flex items-start justify-between gap-4">
@@ -507,13 +540,11 @@ export default function Settings() {
</Description>
<div className="mt-3">
<Slider
value={(() => {
// Otherwise use the settings value
return settings.ContextLength || 4096;
})()}
value={settings.ContextLength || defaultContextLength || 0}
onChange={(value) => {
handleChange("ContextLength", value);
}}
disabled={!defaultContextLength}
options={[
{ value: 4096, label: "4k" },
{ value: 8192, label: "8k" },

View File

@@ -6,10 +6,11 @@ export interface SliderProps {
value?: number;
onChange?: (value: number) => void;
className?: string;
disabled?: boolean;
}
const Slider = React.forwardRef<HTMLDivElement, SliderProps>(
({ label, options, value = 0, onChange }, ref) => {
({ label, options, value = 0, onChange, disabled = false }, ref) => {
const [selectedValue, setSelectedValue] = React.useState(value);
const [isDragging, setIsDragging] = React.useState(false);
const containerRef = React.useRef<HTMLDivElement>(null);
@@ -20,6 +21,7 @@ const Slider = React.forwardRef<HTMLDivElement, SliderProps>(
}, [value]);
const handleClick = (optionValue: number) => {
if (disabled) return;
setSelectedValue(optionValue);
onChange?.(optionValue);
};
@@ -39,6 +41,7 @@ const Slider = React.forwardRef<HTMLDivElement, SliderProps>(
};
const handleMouseDown = (e: React.MouseEvent) => {
if (disabled) return;
setIsDragging(true);
e.preventDefault();
};
@@ -77,7 +80,7 @@ const Slider = React.forwardRef<HTMLDivElement, SliderProps>(
}
return (
<div className="space-y-2" ref={ref}>
<div className={`space-y-2 ${disabled ? "opacity-50" : ""}`} ref={ref}>
{label && <label className="text-sm font-medium">{label}</label>}
<div className="relative">
<div className="absolute top-[9px] left-2 right-2 h-1 bg-neutral-200 dark:bg-neutral-700 pointer-events-none rounded-full" />
@@ -88,10 +91,11 @@ const Slider = React.forwardRef<HTMLDivElement, SliderProps>(
<button
onClick={() => handleClick(option.value)}
onMouseDown={handleMouseDown}
className="relative px-3 py-6 -mx-3 -my-6 z-10 cursor-pointer"
disabled={disabled}
className={`relative px-3 py-6 -mx-3 -my-6 z-10 ${disabled ? "cursor-not-allowed" : "cursor-pointer"}`}
>
<div className="relative w-5 h-5 flex items-center justify-center">
{selectedValue === option.value && (
{selectedValue === option.value && !disabled && (
<div className="w-4 h-4 bg-white dark:bg-white border border-neutral-400 dark:border-neutral-500 rounded-full cursor-grab active:cursor-grabbing" />
)}
</div>

View File

@@ -20,3 +20,8 @@ export function useHasVisionCapability(modelName: string | undefined) {
const { data: capabilitiesResponse } = useModelCapabilities(modelName);
return capabilitiesResponse?.capabilities?.includes("vision") ?? false;
}
export function useHasToolsCapability(modelName: string | undefined) {
const { data: capabilitiesResponse } = useModelCapabilities(modelName);
return capabilitiesResponse?.capabilities?.includes("tools") ?? false;
}

View File

@@ -28,12 +28,14 @@ export function useSelectedModel(currentChatId?: string, searchQuery?: string) {
currentChatId && currentChatId !== "new" ? currentChatId : "",
);
const { data: inferenceComputes = [] } = useQuery({
queryKey: ["inference-compute"],
const { data: inferenceComputeResponse } = useQuery({
queryKey: ["inferenceCompute"],
queryFn: getInferenceCompute,
enabled: !settings.selectedModel, // Only fetch if no model is selected
});
const inferenceComputes = inferenceComputeResponse?.inferenceComputes || [];
const totalVRAM = useMemo(
() => getTotalVRAM(inferenceComputes),
[inferenceComputes],

View File

@@ -45,7 +45,8 @@ type InferenceCompute struct {
}
type InferenceComputeResponse struct {
InferenceComputes []InferenceCompute `json:"inferenceComputes"`
InferenceComputes []InferenceCompute `json:"inferenceComputes"`
DefaultContextLength int `json:"defaultContextLength"`
}
type ModelCapabilitiesResponse struct {

View File

@@ -28,6 +28,7 @@ import (
"github.com/ollama/ollama/app/tools"
"github.com/ollama/ollama/app/types/not"
"github.com/ollama/ollama/app/ui/responses"
"github.com/ollama/ollama/app/updater"
"github.com/ollama/ollama/app/version"
ollamaAuth "github.com/ollama/ollama/auth"
"github.com/ollama/ollama/envconfig"
@@ -106,6 +107,10 @@ type Server struct {
// Dev is true if the server is running in development mode
Dev bool
// Updater for checking and downloading updates
Updater *updater.Updater
UpdateAvailableFunc func()
}
func (s *Server) log() *slog.Logger {
@@ -829,8 +834,9 @@ func (s *Server) chat(w http.ResponseWriter, r *http.Request) error {
if !hasAttachments {
WebSearchEnabled := req.WebSearch != nil && *req.WebSearch
hasToolsCapability := slices.Contains(details.Capabilities, model.CapabilityTools)
if WebSearchEnabled {
if WebSearchEnabled && hasToolsCapability {
if supportsBrowserTools(req.Model) {
browserState, ok := s.browserState(chat)
if !ok {
@@ -840,7 +846,7 @@ func (s *Server) chat(w http.ResponseWriter, r *http.Request) error {
registry.Register(tools.NewBrowserSearch(browser))
registry.Register(tools.NewBrowserOpen(browser))
registry.Register(tools.NewBrowserFind(browser))
} else if supportsWebSearchTools(req.Model) {
} else {
registry.Register(&tools.WebSearch{})
registry.Register(&tools.WebFetch{})
}
@@ -1420,11 +1426,6 @@ func (s *Server) getSettings(w http.ResponseWriter, r *http.Request) error {
settings.Models = envconfig.Models()
}
// set default context length if not set
if settings.ContextLength == 0 {
settings.ContextLength = 4096
}
// Include current runtime settings
settings.Agent = s.Agent
settings.Tools = s.Tools
@@ -1451,6 +1452,24 @@ func (s *Server) settings(w http.ResponseWriter, r *http.Request) error {
return fmt.Errorf("failed to save settings: %w", err)
}
// Handle auto-update toggle changes
if old.AutoUpdateEnabled != settings.AutoUpdateEnabled {
if !settings.AutoUpdateEnabled {
// Auto-update disabled: cancel any ongoing download
if s.Updater != nil {
s.Updater.CancelOngoingDownload()
}
} else {
// Auto-update re-enabled: show notification if update is already staged, or trigger immediate check
if (updater.IsUpdatePending() || updater.UpdateDownloaded) && s.UpdateAvailableFunc != nil {
s.UpdateAvailableFunc()
} else if s.Updater != nil {
// Trigger the background checker to run immediately
s.Updater.TriggerImmediateCheck()
}
}
}
if old.ContextLength != settings.ContextLength ||
old.Models != settings.Models ||
old.Expose != settings.Expose {
@@ -1500,14 +1519,14 @@ func (s *Server) writeCloudStatus(w http.ResponseWriter) error {
func (s *Server) getInferenceCompute(w http.ResponseWriter, r *http.Request) error {
ctx, cancel := context.WithTimeout(r.Context(), 500*time.Millisecond)
defer cancel()
serverInferenceComputes, err := server.GetInferenceComputer(ctx)
info, err := server.GetInferenceInfo(ctx)
if err != nil {
s.log().Error("failed to get inference compute", "error", err)
return fmt.Errorf("failed to get inference compute: %w", err)
s.log().Error("failed to get inference info", "error", err)
return fmt.Errorf("failed to get inference info: %w", err)
}
inferenceComputes := make([]responses.InferenceCompute, len(serverInferenceComputes))
for i, ic := range serverInferenceComputes {
inferenceComputes := make([]responses.InferenceCompute, len(info.Computes))
for i, ic := range info.Computes {
inferenceComputes[i] = responses.InferenceCompute{
Library: ic.Library,
Variant: ic.Variant,
@@ -1519,7 +1538,8 @@ func (s *Server) getInferenceCompute(w http.ResponseWriter, r *http.Request) err
}
response := responses.InferenceComputeResponse{
InferenceComputes: inferenceComputes,
InferenceComputes: inferenceComputes,
DefaultContextLength: info.DefaultContextLength,
}
w.Header().Set("Content-Type", "application/json")
@@ -1652,17 +1672,6 @@ func supportsBrowserTools(model string) bool {
return strings.HasPrefix(strings.ToLower(model), "gpt-oss")
}
// Web search tools are simpler, providing only basic web search and fetch capabilities (e.g., "web_search", "web_fetch") without simulating a browser. Currently only qwen3 and deepseek-v3 support web search tools.
func supportsWebSearchTools(model string) bool {
model = strings.ToLower(model)
prefixes := []string{"qwen3", "deepseek-v3"}
for _, p := range prefixes {
if strings.HasPrefix(model, p) {
return true
}
}
return false
}
// buildChatRequest converts store.Chat to api.ChatRequest
func (s *Server) buildChatRequest(chat *store.Chat, model string, think any, availableTools []map[string]any) (*api.ChatRequest, error) {

View File

@@ -4,6 +4,7 @@ package ui
import (
"bytes"
"context"
"encoding/json"
"io"
"net/http"
@@ -11,9 +12,11 @@ import (
"path/filepath"
"runtime"
"strings"
"sync/atomic"
"testing"
"github.com/ollama/ollama/app/store"
"github.com/ollama/ollama/app/updater"
)
func TestHandlePostApiSettings(t *testing.T) {
@@ -522,3 +525,290 @@ func TestUserAgentTransport(t *testing.T) {
t.Logf("User-Agent transport successfully set: %s", receivedUA)
}
func TestSupportsBrowserTools(t *testing.T) {
tests := []struct {
model string
want bool
}{
{"gpt-oss", true},
{"gpt-oss-latest", true},
{"GPT-OSS", true},
{"Gpt-Oss-v2", true},
{"qwen3", false},
{"deepseek-v3", false},
{"llama3.3", false},
{"", false},
}
for _, tt := range tests {
t.Run(tt.model, func(t *testing.T) {
if got := supportsBrowserTools(tt.model); got != tt.want {
t.Errorf("supportsBrowserTools(%q) = %v, want %v", tt.model, got, tt.want)
}
})
}
}
func TestWebSearchToolRegistration(t *testing.T) {
// Validates that the capability-gating logic in chat() correctly
// decides which tools to register based on model capabilities and
// the web search flag.
tests := []struct {
name string
webSearchEnabled bool
hasToolsCap bool
model string
wantBrowser bool // expects browser tools (gpt-oss)
wantWebSearch bool // expects basic web search/fetch tools
wantNone bool // expects no tools registered
}{
{
name: "web search enabled with tools capability - browser model",
webSearchEnabled: true,
hasToolsCap: true,
model: "gpt-oss-latest",
wantBrowser: true,
},
{
name: "web search enabled with tools capability - non-browser model",
webSearchEnabled: true,
hasToolsCap: true,
model: "qwen3",
wantWebSearch: true,
},
{
name: "web search enabled without tools capability",
webSearchEnabled: true,
hasToolsCap: false,
model: "llama3.3",
wantNone: true,
},
{
name: "web search disabled with tools capability",
webSearchEnabled: false,
hasToolsCap: true,
model: "qwen3",
wantNone: true,
},
{
name: "web search disabled without tools capability",
webSearchEnabled: false,
hasToolsCap: false,
model: "llama3.3",
wantNone: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Replicate the decision logic from chat() handler
gotBrowser := false
gotWebSearch := false
if tt.webSearchEnabled && tt.hasToolsCap {
if supportsBrowserTools(tt.model) {
gotBrowser = true
} else {
gotWebSearch = true
}
}
if tt.wantBrowser && !gotBrowser {
t.Error("expected browser tools to be registered")
}
if tt.wantWebSearch && !gotWebSearch {
t.Error("expected web search tools to be registered")
}
if tt.wantNone && (gotBrowser || gotWebSearch) {
t.Error("expected no tools to be registered")
}
if !tt.wantBrowser && gotBrowser {
t.Error("unexpected browser tools registered")
}
if !tt.wantWebSearch && gotWebSearch {
t.Error("unexpected web search tools registered")
}
})
}
}
func TestSettingsToggleAutoUpdateOff_CancelsDownload(t *testing.T) {
testStore := &store.Store{
DBPath: filepath.Join(t.TempDir(), "db.sqlite"),
}
defer testStore.Close()
// Start with auto-update enabled
settings, err := testStore.Settings()
if err != nil {
t.Fatal(err)
}
settings.AutoUpdateEnabled = true
if err := testStore.SetSettings(settings); err != nil {
t.Fatal(err)
}
upd := &updater.Updater{Store: &store.Store{
DBPath: filepath.Join(t.TempDir(), "db2.sqlite"),
}}
defer upd.Store.Close()
// We can't easily mock CancelOngoingDownload, but we can verify
// the full settings handler flow works without error
server := &Server{
Store: testStore,
Restart: func() {},
Updater: upd,
}
// Disable auto-update via settings API
settings.AutoUpdateEnabled = false
body, err := json.Marshal(settings)
if err != nil {
t.Fatal(err)
}
req := httptest.NewRequest("POST", "/api/v1/settings", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
rr := httptest.NewRecorder()
if err := server.settings(rr, req); err != nil {
t.Fatalf("settings() error = %v", err)
}
if rr.Code != http.StatusOK {
t.Fatalf("settings() status = %d, want %d", rr.Code, http.StatusOK)
}
// Verify settings were saved with auto-update disabled
saved, err := testStore.Settings()
if err != nil {
t.Fatal(err)
}
if saved.AutoUpdateEnabled {
t.Fatal("expected AutoUpdateEnabled to be false after toggle off")
}
}
func TestSettingsToggleAutoUpdateOn_WithPendingUpdate_ShowsNotification(t *testing.T) {
testStore := &store.Store{
DBPath: filepath.Join(t.TempDir(), "db.sqlite"),
}
defer testStore.Close()
// Start with auto-update disabled
settings, err := testStore.Settings()
if err != nil {
t.Fatal(err)
}
settings.AutoUpdateEnabled = false
if err := testStore.SetSettings(settings); err != nil {
t.Fatal(err)
}
// Simulate that an update was previously downloaded
oldVal := updater.UpdateDownloaded
updater.UpdateDownloaded = true
defer func() { updater.UpdateDownloaded = oldVal }()
var notificationCalled atomic.Bool
server := &Server{
Store: testStore,
Restart: func() {},
UpdateAvailableFunc: func() {
notificationCalled.Store(true)
},
}
// Re-enable auto-update via settings API
settings.AutoUpdateEnabled = true
body, err := json.Marshal(settings)
if err != nil {
t.Fatal(err)
}
req := httptest.NewRequest("POST", "/api/v1/settings", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
rr := httptest.NewRecorder()
if err := server.settings(rr, req); err != nil {
t.Fatalf("settings() error = %v", err)
}
if rr.Code != http.StatusOK {
t.Fatalf("settings() status = %d, want %d", rr.Code, http.StatusOK)
}
if !notificationCalled.Load() {
t.Fatal("expected UpdateAvailableFunc to be called when re-enabling with a downloaded update")
}
}
func TestSettingsToggleAutoUpdateOn_NoPendingUpdate_TriggersCheck(t *testing.T) {
testStore := &store.Store{
DBPath: filepath.Join(t.TempDir(), "db.sqlite"),
}
defer testStore.Close()
// Start with auto-update disabled
settings, err := testStore.Settings()
if err != nil {
t.Fatal(err)
}
settings.AutoUpdateEnabled = false
if err := testStore.SetSettings(settings); err != nil {
t.Fatal(err)
}
// Ensure no pending update - clear both the downloaded flag and the stage dir
oldVal := updater.UpdateDownloaded
updater.UpdateDownloaded = false
defer func() { updater.UpdateDownloaded = oldVal }()
oldStageDir := updater.UpdateStageDir
updater.UpdateStageDir = t.TempDir() // empty dir means IsUpdatePending() returns false
defer func() { updater.UpdateStageDir = oldStageDir }()
upd := &updater.Updater{Store: &store.Store{
DBPath: filepath.Join(t.TempDir(), "db2.sqlite"),
}}
defer upd.Store.Close()
// Initialize the checkNow channel by starting (and immediately stopping) the checker
// so TriggerImmediateCheck doesn't panic on nil channel
ctx, cancel := context.WithCancel(t.Context())
upd.StartBackgroundUpdaterChecker(ctx, func(string) error { return nil })
defer cancel()
var notificationCalled atomic.Bool
server := &Server{
Store: testStore,
Restart: func() {},
Updater: upd,
UpdateAvailableFunc: func() {
notificationCalled.Store(true)
},
}
// Re-enable auto-update via settings API
settings.AutoUpdateEnabled = true
body, err := json.Marshal(settings)
if err != nil {
t.Fatal(err)
}
req := httptest.NewRequest("POST", "/api/v1/settings", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
rr := httptest.NewRecorder()
if err := server.settings(rr, req); err != nil {
t.Fatalf("settings() error = %v", err)
}
if rr.Code != http.StatusOK {
t.Fatalf("settings() status = %d, want %d", rr.Code, http.StatusOK)
}
// UpdateAvailableFunc should NOT be called since there's no pending update
if notificationCalled.Load() {
t.Fatal("UpdateAvailableFunc should not be called when there is no pending update")
}
}

View File

@@ -19,6 +19,7 @@ import (
"runtime"
"strconv"
"strings"
"sync"
"time"
"github.com/ollama/ollama/app/store"
@@ -58,7 +59,8 @@ func (u *Updater) checkForUpdate(ctx context.Context) (bool, UpdateResponse) {
query := requestURL.Query()
query.Add("os", runtime.GOOS)
query.Add("arch", runtime.GOARCH)
query.Add("version", version.Version)
currentVersion := version.Version
query.Add("version", currentVersion)
query.Add("ts", strconv.FormatInt(time.Now().Unix(), 10))
// The original macOS app used to use the device ID
@@ -131,15 +133,27 @@ func (u *Updater) checkForUpdate(ctx context.Context) (bool, UpdateResponse) {
}
func (u *Updater) DownloadNewRelease(ctx context.Context, updateResp UpdateResponse) error {
// Create a cancellable context for this download
downloadCtx, cancel := context.WithCancel(ctx)
u.cancelDownloadLock.Lock()
u.cancelDownload = cancel
u.cancelDownloadLock.Unlock()
defer func() {
u.cancelDownloadLock.Lock()
u.cancelDownload = nil
u.cancelDownloadLock.Unlock()
cancel()
}()
// Do a head first to check etag info
req, err := http.NewRequestWithContext(ctx, http.MethodHead, updateResp.UpdateURL, nil)
req, err := http.NewRequestWithContext(downloadCtx, http.MethodHead, updateResp.UpdateURL, nil)
if err != nil {
return err
}
// In case of slow downloads, continue the update check in the background
bgctx, cancel := context.WithCancel(ctx)
defer cancel()
bgctx, bgcancel := context.WithCancel(downloadCtx)
defer bgcancel()
go func() {
for {
select {
@@ -176,6 +190,7 @@ func (u *Updater) DownloadNewRelease(ctx context.Context, updateResp UpdateRespo
_, err = os.Stat(stageFilename)
if err == nil {
slog.Info("update already downloaded", "bundle", stageFilename)
UpdateDownloaded = true
return nil
}
@@ -244,33 +259,85 @@ func cleanupOldDownloads(stageDir string) {
}
type Updater struct {
Store *store.Store
Store *store.Store
cancelDownload context.CancelFunc
cancelDownloadLock sync.Mutex
checkNow chan struct{}
}
// CancelOngoingDownload cancels any currently running download
func (u *Updater) CancelOngoingDownload() {
u.cancelDownloadLock.Lock()
defer u.cancelDownloadLock.Unlock()
if u.cancelDownload != nil {
slog.Info("cancelling ongoing update download")
u.cancelDownload()
u.cancelDownload = nil
}
}
// TriggerImmediateCheck signals the background checker to check for updates immediately
func (u *Updater) TriggerImmediateCheck() {
if u.checkNow != nil {
select {
case u.checkNow <- struct{}{}:
default:
// Check already pending, no need to queue another
}
}
}
func (u *Updater) StartBackgroundUpdaterChecker(ctx context.Context, cb func(string) error) {
u.checkNow = make(chan struct{}, 1)
u.checkNow <- struct{}{} // Trigger first check after initial delay
go func() {
// Don't blast an update message immediately after startup
time.Sleep(UpdateCheckInitialDelay)
slog.Info("beginning update checker", "interval", UpdateCheckInterval)
ticker := time.NewTicker(UpdateCheckInterval)
defer ticker.Stop()
for {
available, resp := u.checkForUpdate(ctx)
if available {
err := u.DownloadNewRelease(ctx, resp)
if err != nil {
slog.Error(fmt.Sprintf("failed to download new release: %s", err))
} else {
err = cb(resp.UpdateVersion)
if err != nil {
slog.Warn(fmt.Sprintf("failed to register update available with tray: %s", err))
}
}
}
select {
case <-ctx.Done():
slog.Debug("stopping background update checker")
return
default:
time.Sleep(UpdateCheckInterval)
case <-u.checkNow:
// Immediate check triggered
case <-ticker.C:
// Regular interval check
}
// Always check for updates
available, resp := u.checkForUpdate(ctx)
if !available {
continue
}
// Update is available - check if auto-update is enabled for downloading
settings, err := u.Store.Settings()
if err != nil {
slog.Error("failed to load settings", "error", err)
continue
}
if !settings.AutoUpdateEnabled {
// Auto-update disabled - don't download, just log
slog.Debug("update available but auto-update disabled", "version", resp.UpdateVersion)
continue
}
// Auto-update is enabled - download
err = u.DownloadNewRelease(ctx, resp)
if err != nil {
slog.Error("failed to download new release", "error", err)
continue
}
// Download successful - show tray notification
err = cb(resp.UpdateVersion)
if err != nil {
slog.Warn("failed to register update available with tray", "error", err)
}
}
}()

View File

@@ -11,6 +11,8 @@ import (
"log/slog"
"net/http"
"net/http/httptest"
"path/filepath"
"sync/atomic"
"testing"
"time"
@@ -33,7 +35,7 @@ func TestIsNewReleaseAvailable(t *testing.T) {
defer server.Close()
slog.Debug("server", "url", server.URL)
updater := &Updater{Store: &store.Store{}}
updater := &Updater{Store: &store.Store{DBPath: filepath.Join(t.TempDir(), "test.db")}}
defer updater.Store.Close() // Ensure database is closed
UpdateCheckURLBase = server.URL + "/update.json"
updatePresent, resp := updater.checkForUpdate(t.Context())
@@ -84,8 +86,18 @@ func TestBackgoundChecker(t *testing.T) {
defer server.Close()
UpdateCheckURLBase = server.URL + "/update.json"
updater := &Updater{Store: &store.Store{}}
defer updater.Store.Close() // Ensure database is closed
updater := &Updater{Store: &store.Store{DBPath: filepath.Join(t.TempDir(), "test.db")}}
defer updater.Store.Close()
settings, err := updater.Store.Settings()
if err != nil {
t.Fatal(err)
}
settings.AutoUpdateEnabled = true
if err := updater.Store.SetSettings(settings); err != nil {
t.Fatal(err)
}
updater.StartBackgroundUpdaterChecker(ctx, cb)
select {
case <-stallTimer.C:
@@ -99,3 +111,267 @@ func TestBackgoundChecker(t *testing.T) {
}
}
}
func TestAutoUpdateDisabledSkipsDownload(t *testing.T) {
UpdateStageDir = t.TempDir()
var downloadAttempted atomic.Bool
done := make(chan struct{})
ctx, cancel := context.WithCancel(t.Context())
defer cancel()
UpdateCheckInitialDelay = 5 * time.Millisecond
UpdateCheckInterval = 5 * time.Millisecond
VerifyDownload = func() error {
return nil
}
var server *httptest.Server
server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/update.json" {
w.Write([]byte(
fmt.Sprintf(`{"version": "9.9.9", "url": "%s"}`,
server.URL+"/9.9.9/"+Installer)))
} else if r.URL.Path == "/9.9.9/"+Installer {
downloadAttempted.Store(true)
buf := &bytes.Buffer{}
zw := zip.NewWriter(buf)
zw.Close()
io.Copy(w, buf)
}
}))
defer server.Close()
UpdateCheckURLBase = server.URL + "/update.json"
updater := &Updater{Store: &store.Store{DBPath: filepath.Join(t.TempDir(), "test.db")}}
defer updater.Store.Close()
// Ensure auto-update is disabled
settings, err := updater.Store.Settings()
if err != nil {
t.Fatal(err)
}
settings.AutoUpdateEnabled = false
if err := updater.Store.SetSettings(settings); err != nil {
t.Fatal(err)
}
cb := func(ver string) error {
t.Fatal("callback should not be called when auto-update is disabled")
return nil
}
updater.StartBackgroundUpdaterChecker(ctx, cb)
// Wait enough time for multiple check cycles
time.Sleep(50 * time.Millisecond)
close(done)
if downloadAttempted.Load() {
t.Fatal("download should not be attempted when auto-update is disabled")
}
}
func TestAutoUpdateReenabledDownloadsUpdate(t *testing.T) {
UpdateStageDir = t.TempDir()
var downloadAttempted atomic.Bool
callbackCalled := make(chan struct{}, 1)
ctx, cancel := context.WithCancel(t.Context())
defer cancel()
UpdateCheckInitialDelay = 5 * time.Millisecond
UpdateCheckInterval = 5 * time.Millisecond
VerifyDownload = func() error {
return nil
}
var server *httptest.Server
server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/update.json" {
w.Write([]byte(
fmt.Sprintf(`{"version": "9.9.9", "url": "%s"}`,
server.URL+"/9.9.9/"+Installer)))
} else if r.URL.Path == "/9.9.9/"+Installer {
downloadAttempted.Store(true)
buf := &bytes.Buffer{}
zw := zip.NewWriter(buf)
zw.Close()
io.Copy(w, buf)
}
}))
defer server.Close()
UpdateCheckURLBase = server.URL + "/update.json"
upd := &Updater{Store: &store.Store{DBPath: filepath.Join(t.TempDir(), "test.db")}}
defer upd.Store.Close()
// Start with auto-update disabled
settings, err := upd.Store.Settings()
if err != nil {
t.Fatal(err)
}
settings.AutoUpdateEnabled = false
if err := upd.Store.SetSettings(settings); err != nil {
t.Fatal(err)
}
cb := func(ver string) error {
select {
case callbackCalled <- struct{}{}:
default:
}
return nil
}
upd.StartBackgroundUpdaterChecker(ctx, cb)
// Wait for a few cycles with auto-update disabled - no download should happen
time.Sleep(50 * time.Millisecond)
if downloadAttempted.Load() {
t.Fatal("download should not happen while auto-update is disabled")
}
// Re-enable auto-update
settings.AutoUpdateEnabled = true
if err := upd.Store.SetSettings(settings); err != nil {
t.Fatal(err)
}
// Wait for the checker to pick it up and download
select {
case <-callbackCalled:
// Success: download happened and callback was called after re-enabling
if !downloadAttempted.Load() {
t.Fatal("expected download to be attempted after re-enabling")
}
case <-time.After(5 * time.Second):
t.Fatal("expected download and callback after re-enabling auto-update")
}
}
func TestCancelOngoingDownload(t *testing.T) {
UpdateStageDir = t.TempDir()
downloadStarted := make(chan struct{})
downloadCancelled := make(chan struct{})
ctx := t.Context()
VerifyDownload = func() error {
return nil
}
var server *httptest.Server
server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/update.json" {
w.Write([]byte(
fmt.Sprintf(`{"version": "9.9.9", "url": "%s"}`,
server.URL+"/9.9.9/"+Installer)))
} else if r.URL.Path == "/9.9.9/"+Installer {
if r.Method == http.MethodHead {
w.Header().Set("Content-Length", "1000000")
w.WriteHeader(http.StatusOK)
return
}
// Signal that download has started
close(downloadStarted)
// Wait for cancellation or timeout
select {
case <-r.Context().Done():
close(downloadCancelled)
return
case <-time.After(5 * time.Second):
t.Error("download was not cancelled in time")
}
}
}))
defer server.Close()
UpdateCheckURLBase = server.URL + "/update.json"
updater := &Updater{Store: &store.Store{DBPath: filepath.Join(t.TempDir(), "test.db")}}
defer updater.Store.Close()
_, resp := updater.checkForUpdate(ctx)
// Start download in goroutine
go func() {
_ = updater.DownloadNewRelease(ctx, resp)
}()
// Wait for download to start
select {
case <-downloadStarted:
case <-time.After(2 * time.Second):
t.Fatal("download did not start in time")
}
// Cancel the download
updater.CancelOngoingDownload()
// Verify cancellation was received
select {
case <-downloadCancelled:
// Success
case <-time.After(2 * time.Second):
t.Fatal("download cancellation was not received by server")
}
}
func TestTriggerImmediateCheck(t *testing.T) {
UpdateStageDir = t.TempDir()
checkCount := atomic.Int32{}
checkDone := make(chan struct{}, 10)
ctx, cancel := context.WithCancel(t.Context())
defer cancel()
// Set a very long interval so only TriggerImmediateCheck causes checks
UpdateCheckInitialDelay = 1 * time.Millisecond
UpdateCheckInterval = 1 * time.Hour
VerifyDownload = func() error {
return nil
}
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/update.json" {
checkCount.Add(1)
select {
case checkDone <- struct{}{}:
default:
}
// Return no update available
w.WriteHeader(http.StatusNoContent)
}
}))
defer server.Close()
UpdateCheckURLBase = server.URL + "/update.json"
updater := &Updater{Store: &store.Store{DBPath: filepath.Join(t.TempDir(), "test.db")}}
defer updater.Store.Close()
cb := func(ver string) error {
return nil
}
updater.StartBackgroundUpdaterChecker(ctx, cb)
// Wait for the initial check that fires after the initial delay
select {
case <-checkDone:
case <-time.After(2 * time.Second):
t.Fatal("initial check did not happen")
}
initialCount := checkCount.Load()
// Trigger immediate check
updater.TriggerImmediateCheck()
// Wait for the triggered check
select {
case <-checkDone:
case <-time.After(2 * time.Second):
t.Fatal("triggered check did not happen")
}
finalCount := checkCount.Load()
if finalCount <= initialCount {
t.Fatalf("TriggerImmediateCheck did not cause additional check: initial=%d, final=%d", initialCount, finalCount)
}
}

View File

@@ -369,25 +369,6 @@ func (t *winTray) addSeparatorMenuItem(menuItemId, parentId uint32) error {
return nil
}
// func (t *winTray) hideMenuItem(menuItemId, parentId uint32) error {
// const ERROR_SUCCESS syscall.Errno = 0
// t.muMenus.RLock()
// menu := uintptr(t.menus[parentId])
// t.muMenus.RUnlock()
// res, _, err := pRemoveMenu.Call(
// menu,
// uintptr(menuItemId),
// MF_BYCOMMAND,
// )
// if res == 0 && err.(syscall.Errno) != ERROR_SUCCESS {
// return err
// }
// t.delFromVisibleItems(parentId, menuItemId)
// return nil
// }
func (t *winTray) showMenu() error {
p := point{}
boolRet, _, err := pGetCursorPos.Call(uintptr(unsafe.Pointer(&p)))

View File

@@ -51,7 +51,6 @@ const (
IMAGE_ICON = 1 // Loads an icon
LR_DEFAULTSIZE = 0x00000040 // Loads default-size icon for windows(SM_CXICON x SM_CYICON) if cx, cy are set to zero
LR_LOADFROMFILE = 0x00000010 // Loads the stand-alone image from the file
MF_BYCOMMAND = 0x00000000
MFS_DISABLED = 0x00000003
MFT_SEPARATOR = 0x00000800
MFT_STRING = 0x00000000

View File

@@ -1,27 +1,31 @@
Ollama Benchmark Tool
---------------------
A Go-based command-line tool for benchmarking Ollama models with configurable parameters and multiple output formats.
A Go-based command-line tool for benchmarking Ollama models with configurable parameters, warmup phases, TTFT tracking, VRAM monitoring, and benchstat/CSV output.
## Features
* Benchmark multiple models in a single run
* Support for both text and image prompts
* Configurable generation parameters (temperature, max tokens, seed, etc.)
* Supports benchstat and CSV output formats
* Detailed performance metrics (prefill, generate, load, total durations)
* Warmup phase before timed epochs to stabilize measurements
* Time-to-first-token (TTFT) tracking per epoch
* Model metadata display (parameter size, quantization level, family)
* VRAM and CPU memory usage tracking via running process info
* Controlled prompt token length for reproducible benchmarks
* Benchstat and CSV output formats
## Building from Source
```
go build -o ollama-bench bench.go
./ollama-bench -model gpt-oss:20b -epochs 6 -format csv
go build -o ollama-bench ./cmd/bench
./ollama-bench -model gemma3 -epochs 6 -format csv
```
Using Go Run (without building)
```
go run bench.go -model gpt-oss:20b -epochs 3
go run ./cmd/bench -model gemma3 -epochs 3
```
## Usage
@@ -45,10 +49,16 @@ benchstat -col /name gemma.bench
./ollama-bench -model qwen3-vl -image photo.jpg -epochs 6 -max-tokens 100 -p "Describe this image"
```
### Controlled Prompt Length
```
./ollama-bench -model gemma3 -epochs 6 -prompt-tokens 512
```
### Advanced Example
```
./ollama-bench -model llama3 -epochs 10 -temperature 0.7 -max-tokens 500 -seed 42 -format csv -output results.csv
./ollama-bench -model llama3 -epochs 10 -temperature 0.7 -max-tokens 500 -seed 42 -warmup 2 -format csv -output results.csv
```
## Command Line Options
@@ -56,41 +66,48 @@ benchstat -col /name gemma.bench
| Option | Description | Default |
|----------|-------------|---------|
| -model | Comma-separated list of models to benchmark | (required) |
| -epochs | Number of iterations per model | 1 |
| -max-tokens | Maximum tokens for model response | 0 (unlimited) |
| -epochs | Number of iterations per model | 6 |
| -max-tokens | Maximum tokens for model response | 200 |
| -temperature | Temperature parameter | 0.0 |
| -seed | Random seed | 0 (random) |
| -timeout | Timeout in seconds | 300 |
| -p | Prompt text | "Write a long story." |
| -p | Prompt text | (default story prompt) |
| -image | Image file to include in prompt | |
| -k | Keep-alive duration in seconds | 0 |
| -format | Output format (benchstat, csv) | benchstat |
| -output | Output file for results | "" (stdout) |
| -warmup | Number of warmup requests before timing | 1 |
| -prompt-tokens | Generate prompt targeting ~N tokens (0 = use -p) | 0 |
| -v | Verbose mode | false |
| -debug | Show debug information | false |
## Output Formats
### Markdown Format
### Benchstat Format (default)
The default markdown format is suitable for copying and pasting into a GitHub issue and will look like:
```
Model | Step | Count | Duration | nsPerToken | tokensPerSec |
|-------|------|-------|----------|------------|--------------|
| gpt-oss:20b | prefill | 124 | 30.006458ms | 241987.56 | 4132.44 |
| gpt-oss:20b | generate | 200 | 2.646843954s | 13234219.77 | 75.56 |
| gpt-oss:20b | load | 1 | 121.674208ms | - | - |
| gpt-oss:20b | total | 1 | 2.861047625s | - | - |
```
### Benchstat Format
Compatible with Go's benchstat tool for statistical analysis:
Compatible with Go's benchstat tool for statistical analysis. Uses one value/unit pair per line, standard `ns/op` for timing metrics, and `ns/token` for throughput. Each epoch produces one set of lines -- benchstat aggregates across repeated runs to compute statistics.
```
BenchmarkModel/name=gpt-oss:20b/step=prefill 128 78125.00 ns/token 12800.00 token/sec
BenchmarkModel/name=gpt-oss:20b/step=generate 512 19531.25 ns/token 51200.00 token/sec
BenchmarkModel/name=gpt-oss:20b/step=load 1 1500000000 ns/request
# Model: gemma3 | Params: 4.3B | Quant: Q4_K_M | Family: gemma3 | Size: 4080218931 | VRAM: 4080218931
BenchmarkModel/name=gemma3/step=prefill 1 78125.00 ns/token 12800.00 token/sec
BenchmarkModel/name=gemma3/step=generate 1 19531.25 ns/token 51200.00 token/sec
BenchmarkModel/name=gemma3/step=ttft 1 45123000 ns/op
BenchmarkModel/name=gemma3/step=load 1 1500000000 ns/op
BenchmarkModel/name=gemma3/step=total 1 2861047625 ns/op
```
Use with benchstat:
```
./ollama-bench -model gemma3 -epochs 6 > gemma3.bench
benchstat -col /step gemma3.bench
```
Compare two runs:
```
./ollama-bench -model gemma3 -epochs 6 > before.bench
# ... make changes ...
./ollama-bench -model gemma3 -epochs 6 > after.bench
benchstat before.bench after.bench
```
### CSV Format
@@ -99,17 +116,28 @@ Machine-readable comma-separated values:
```
NAME,STEP,COUNT,NS_PER_COUNT,TOKEN_PER_SEC
gpt-oss:20b,prefill,128,78125.00,12800.00
gpt-oss:20b,generate,512,19531.25,51200.00
gpt-oss:20b,load,1,1500000000,0
# Model: gemma3 | Params: 4.3B | Quant: Q4_K_M | Family: gemma3 | Size: 4080218931 | VRAM: 4080218931
gemma3,prefill,128,78125.00,12800.00
gemma3,generate,512,19531.25,51200.00
gemma3,ttft,1,45123000,0
gemma3,load,1,1500000000,0
gemma3,total,1,2861047625,0
```
## Metrics Explained
The tool reports four types of metrics for each model:
The tool reports the following metrics for each epoch:
* prefill: Time spent processing the prompt
* generate: Time spent generating the response
* load: Model loading time (one-time cost)
* total: Total request duration
* **prefill**: Time spent processing the prompt (ns/token)
* **generate**: Time spent generating the response (ns/token)
* **ttft**: Time to first token -- latency from request start to first response content
* **load**: Model loading time (one-time cost)
* **total**: Total request duration
Additionally, the model info comment line (displayed once per model before epochs) includes:
* **Params**: Model parameter count (e.g., 4.3B)
* **Quant**: Quantization level (e.g., Q4_K_M)
* **Family**: Model family (e.g., gemma3)
* **Size**: Total model memory in bytes
* **VRAM**: GPU memory used by the loaded model (when Size > VRAM, the difference is CPU spill)

View File

@@ -17,19 +17,21 @@ import (
)
type flagOptions struct {
models *string
epochs *int
maxTokens *int
temperature *float64
seed *int
timeout *int
prompt *string
imageFile *string
keepAlive *float64
format *string
outputFile *string
debug *bool
verbose *bool
models *string
epochs *int
maxTokens *int
temperature *float64
seed *int
timeout *int
prompt *string
imageFile *string
keepAlive *float64
format *string
outputFile *string
debug *bool
verbose *bool
warmup *int
promptTokens *int
}
type Metrics struct {
@@ -39,48 +41,169 @@ type Metrics struct {
Duration time.Duration
}
var once sync.Once
type ModelInfo struct {
Name string
ParameterSize string
QuantizationLevel string
Family string
SizeBytes int64
VRAMBytes int64
}
const DefaultPrompt = `Please write a descriptive story about a llama named Alonso who grows up to be President of the Land of Llamas. Include details about Alonso's childhood, adolescent years, and how he grew up to be a political mover and shaker. Write the story with a sense of whimsy.`
// Word list for generating prompts targeting a specific token count.
var promptWordList = []string{
"the", "quick", "brown", "fox", "jumps", "over", "lazy", "dog",
"a", "bright", "sunny", "day", "in", "the", "meadow", "where",
"flowers", "bloom", "and", "birds", "sing", "their", "morning",
"songs", "while", "gentle", "breeze", "carries", "sweet", "scent",
"of", "pine", "trees", "across", "rolling", "hills", "toward",
"distant", "mountains", "covered", "with", "fresh", "snow",
"beneath", "clear", "blue", "sky", "children", "play", "near",
"old", "stone", "bridge", "that", "crosses", "winding", "river",
}
func generatePromptForTokenCount(targetTokens int, epoch int) string {
// ~1.3 tokens per word heuristic
targetWords := int(float64(targetTokens) / 1.3)
if targetWords < 1 {
targetWords = 1
}
// Vary the starting offset by epoch to defeat KV cache prefix matching
offset := epoch * 7 // stride by a prime to get good distribution
n := len(promptWordList)
words := make([]string, targetWords)
for i := range words {
words[i] = promptWordList[((i+offset)%n+n)%n]
}
return strings.Join(words, " ")
}
func buildGenerateRequest(model string, fOpt flagOptions, imgData api.ImageData, epoch int) *api.GenerateRequest {
options := make(map[string]interface{})
if *fOpt.maxTokens > 0 {
options["num_predict"] = *fOpt.maxTokens
}
options["temperature"] = *fOpt.temperature
if fOpt.seed != nil && *fOpt.seed > 0 {
options["seed"] = *fOpt.seed
}
var keepAliveDuration *api.Duration
if *fOpt.keepAlive > 0 {
duration := api.Duration{Duration: time.Duration(*fOpt.keepAlive * float64(time.Second))}
keepAliveDuration = &duration
}
prompt := *fOpt.prompt
if *fOpt.promptTokens > 0 {
prompt = generatePromptForTokenCount(*fOpt.promptTokens, epoch)
} else {
// Vary the prompt per epoch to defeat KV cache prefix matching
prompt = fmt.Sprintf("[%d] %s", epoch, prompt)
}
req := &api.GenerateRequest{
Model: model,
Prompt: prompt,
Raw: true,
Options: options,
KeepAlive: keepAliveDuration,
}
if imgData != nil {
req.Images = []api.ImageData{imgData}
}
return req
}
func fetchModelInfo(ctx context.Context, client *api.Client, model string) ModelInfo {
info := ModelInfo{Name: model}
resp, err := client.Show(ctx, &api.ShowRequest{Model: model})
if err != nil {
fmt.Fprintf(os.Stderr, "WARNING: Could not fetch model info for '%s': %v\n", model, err)
return info
}
info.ParameterSize = resp.Details.ParameterSize
info.QuantizationLevel = resp.Details.QuantizationLevel
info.Family = resp.Details.Family
return info
}
func fetchMemoryUsage(ctx context.Context, client *api.Client, model string) (size, vram int64) {
resp, err := client.ListRunning(ctx)
if err != nil {
if debug := os.Getenv("OLLAMA_DEBUG"); debug != "" {
fmt.Fprintf(os.Stderr, "WARNING: Could not fetch memory usage: %v\n", err)
}
return 0, 0
}
for _, m := range resp.Models {
if m.Name == model || m.Model == model {
return m.Size, m.SizeVRAM
}
}
// Try prefix match (model names may include :latest or tags)
for _, m := range resp.Models {
if strings.HasPrefix(m.Name, model) || strings.HasPrefix(m.Model, model) {
return m.Size, m.SizeVRAM
}
}
return 0, 0
}
func outputFormatHeader(w io.Writer, format string, verbose bool) {
switch format {
case "benchstat":
if verbose {
fmt.Fprintf(w, "goos: %s\n", runtime.GOOS)
fmt.Fprintf(w, "goarch: %s\n", runtime.GOARCH)
}
case "csv":
headings := []string{"NAME", "STEP", "COUNT", "NS_PER_COUNT", "TOKEN_PER_SEC"}
fmt.Fprintln(w, strings.Join(headings, ","))
}
}
func outputModelInfo(w io.Writer, format string, info ModelInfo) {
params := cmp.Or(info.ParameterSize, "unknown")
quant := cmp.Or(info.QuantizationLevel, "unknown")
family := cmp.Or(info.Family, "unknown")
memStr := ""
if info.SizeBytes > 0 {
memStr = fmt.Sprintf(" | Size: %d | VRAM: %d", info.SizeBytes, info.VRAMBytes)
}
fmt.Fprintf(w, "# Model: %s | Params: %s | Quant: %s | Family: %s%s\n",
info.Name, params, quant, family, memStr)
}
func OutputMetrics(w io.Writer, format string, metrics []Metrics, verbose bool) {
switch format {
case "benchstat":
if verbose {
printHeader := func() {
fmt.Fprintf(w, "sysname: %s\n", runtime.GOOS)
fmt.Fprintf(w, "machine: %s\n", runtime.GOARCH)
}
once.Do(printHeader)
}
for _, m := range metrics {
if m.Step == "generate" || m.Step == "prefill" {
if m.Count > 0 {
nsPerToken := float64(m.Duration.Nanoseconds()) / float64(m.Count)
tokensPerSec := float64(m.Count) / (float64(m.Duration.Nanoseconds()) + 1e-12) * 1e9
fmt.Fprintf(w, "BenchmarkModel/name=%s/step=%s %d %.2f ns/token %.2f token/sec\n",
m.Model, m.Step, m.Count, nsPerToken, tokensPerSec)
fmt.Fprintf(w, "BenchmarkModel/name=%s/step=%s 1 %.2f ns/token %.2f token/sec\n",
m.Model, m.Step, nsPerToken, tokensPerSec)
} else {
fmt.Fprintf(w, "BenchmarkModel/name=%s/step=%s %d 0 ns/token 0 token/sec\n",
m.Model, m.Step, m.Count)
fmt.Fprintf(w, "BenchmarkModel/name=%s/step=%s 1 0 ns/token 0 token/sec\n",
m.Model, m.Step)
}
} else if m.Step == "ttft" {
fmt.Fprintf(w, "BenchmarkModel/name=%s/step=ttft 1 %d ns/op\n",
m.Model, m.Duration.Nanoseconds())
} else {
var suffix string
if m.Step == "load" {
suffix = "/step=load"
}
fmt.Fprintf(w, "BenchmarkModel/name=%s%s 1 %d ns/request\n",
m.Model, suffix, m.Duration.Nanoseconds())
fmt.Fprintf(w, "BenchmarkModel/name=%s/step=%s 1 %d ns/op\n",
m.Model, m.Step, m.Duration.Nanoseconds())
}
}
case "csv":
printHeader := func() {
headings := []string{"NAME", "STEP", "COUNT", "NS_PER_COUNT", "TOKEN_PER_SEC"}
fmt.Fprintln(w, strings.Join(headings, ","))
}
once.Do(printHeader)
for _, m := range metrics {
if m.Step == "generate" || m.Step == "prefill" {
var nsPerToken float64
@@ -94,39 +217,14 @@ func OutputMetrics(w io.Writer, format string, metrics []Metrics, verbose bool)
fmt.Fprintf(w, "%s,%s,1,%d,0\n", m.Model, m.Step, m.Duration.Nanoseconds())
}
}
case "markdown":
printHeader := func() {
fmt.Fprintln(w, "| Model | Step | Count | Duration | nsPerToken | tokensPerSec |")
fmt.Fprintln(w, "|-------|------|-------|----------|------------|--------------|")
}
once.Do(printHeader)
for _, m := range metrics {
var nsPerToken, tokensPerSec float64
var nsPerTokenStr, tokensPerSecStr string
if m.Step == "generate" || m.Step == "prefill" {
nsPerToken = float64(m.Duration.Nanoseconds()) / float64(m.Count)
tokensPerSec = float64(m.Count) / (float64(m.Duration.Nanoseconds()) + 1e-12) * 1e9
nsPerTokenStr = fmt.Sprintf("%.2f", nsPerToken)
tokensPerSecStr = fmt.Sprintf("%.2f", tokensPerSec)
} else {
nsPerTokenStr = "-"
tokensPerSecStr = "-"
}
fmt.Fprintf(w, "| %s | %s | %d | %v | %s | %s |\n",
m.Model, m.Step, m.Count, m.Duration, nsPerTokenStr, tokensPerSecStr)
}
default:
fmt.Fprintf(os.Stderr, "Unknown output format '%s'\n", format)
}
}
func BenchmarkChat(fOpt flagOptions) error {
func BenchmarkModel(fOpt flagOptions) error {
models := strings.Split(*fOpt.models, ",")
// todo - add multi-image support
var imgData api.ImageData
var err error
if *fOpt.imageFile != "" {
@@ -158,71 +256,124 @@ func BenchmarkChat(fOpt flagOptions) error {
out = f
}
outputFormatHeader(out, *fOpt.format, *fOpt.verbose)
// Log prompt-tokens info in debug mode
if *fOpt.debug && *fOpt.promptTokens > 0 {
prompt := generatePromptForTokenCount(*fOpt.promptTokens, 0)
wordCount := len(strings.Fields(prompt))
fmt.Fprintf(os.Stderr, "Generated prompt targeting ~%d tokens (%d words, varied per epoch)\n", *fOpt.promptTokens, wordCount)
}
for _, model := range models {
for range *fOpt.epochs {
options := make(map[string]interface{})
if *fOpt.maxTokens > 0 {
options["num_predict"] = *fOpt.maxTokens
}
options["temperature"] = *fOpt.temperature
if fOpt.seed != nil && *fOpt.seed > 0 {
options["seed"] = *fOpt.seed
}
var keepAliveDuration *api.Duration
if *fOpt.keepAlive > 0 {
duration := api.Duration{Duration: time.Duration(*fOpt.keepAlive * float64(time.Second))}
keepAliveDuration = &duration
}
req := &api.ChatRequest{
Model: model,
Messages: []api.Message{
{
Role: "user",
Content: *fOpt.prompt,
},
},
Options: options,
KeepAlive: keepAliveDuration,
}
if imgData != nil {
req.Messages[0].Images = []api.ImageData{imgData}
}
var responseMetrics *api.Metrics
// Fetch model info
infoCtx, infoCancel := context.WithTimeout(context.Background(), 10*time.Second)
info := fetchModelInfo(infoCtx, client, model)
infoCancel()
// Warmup phase (uses negative epoch numbers to avoid colliding with timed epochs)
for i := range *fOpt.warmup {
req := buildGenerateRequest(model, fOpt, imgData, -(i + 1))
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(*fOpt.timeout)*time.Second)
defer cancel()
err = client.Chat(ctx, req, func(resp api.ChatResponse) error {
if *fOpt.debug {
fmt.Fprintf(os.Stderr, "%s", cmp.Or(resp.Message.Thinking, resp.Message.Content))
}
if resp.Done {
responseMetrics = &resp.Metrics
}
err = client.Generate(ctx, req, func(resp api.GenerateResponse) error {
return nil
})
if *fOpt.debug {
fmt.Fprintln(os.Stderr)
}
cancel()
if err != nil {
if ctx.Err() == context.DeadlineExceeded {
fmt.Fprintf(os.Stderr, "ERROR: Chat request timed out with model '%s' after %vs\n", model, 1)
continue
fmt.Fprintf(os.Stderr, "WARNING: Warmup %d/%d for %s failed: %v\n", i+1, *fOpt.warmup, model, err)
} else if *fOpt.debug {
fmt.Fprintf(os.Stderr, "Warmup %d/%d for %s complete\n", i+1, *fOpt.warmup, model)
}
}
// Fetch memory usage once after warmup (model is loaded and stable)
memCtx, memCancel := context.WithTimeout(context.Background(), 5*time.Second)
info.SizeBytes, info.VRAMBytes = fetchMemoryUsage(memCtx, client, model)
memCancel()
outputModelInfo(out, *fOpt.format, info)
// Timed epoch loop
shortCount := 0
for epoch := range *fOpt.epochs {
var responseMetrics *api.Metrics
var ttft time.Duration
short := false
// Retry loop: if the model hits a stop token before max-tokens,
// retry with a different prompt (up to maxRetries times).
const maxRetries = 3
for attempt := range maxRetries + 1 {
responseMetrics = nil
ttft = 0
var ttftOnce sync.Once
req := buildGenerateRequest(model, fOpt, imgData, epoch+attempt*1000)
requestStart := time.Now()
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(*fOpt.timeout)*time.Second)
err = client.Generate(ctx, req, func(resp api.GenerateResponse) error {
if *fOpt.debug {
fmt.Fprintf(os.Stderr, "%s", cmp.Or(resp.Thinking, resp.Response))
}
// Capture TTFT on first content
ttftOnce.Do(func() {
if resp.Response != "" || resp.Thinking != "" {
ttft = time.Since(requestStart)
}
})
if resp.Done {
responseMetrics = &resp.Metrics
}
return nil
})
cancel()
if *fOpt.debug {
fmt.Fprintln(os.Stderr)
}
fmt.Fprintf(os.Stderr, "ERROR: Couldn't chat with model '%s': %v\n", model, err)
if err != nil {
if ctx.Err() == context.DeadlineExceeded {
fmt.Fprintf(os.Stderr, "ERROR: Request timed out with model '%s' after %vs\n", model, *fOpt.timeout)
} else {
fmt.Fprintf(os.Stderr, "ERROR: Couldn't generate with model '%s': %v\n", model, err)
}
break
}
if responseMetrics == nil {
fmt.Fprintf(os.Stderr, "ERROR: No metrics received for model '%s'\n", model)
break
}
// Check if the response was shorter than requested
short = *fOpt.maxTokens > 0 && responseMetrics.EvalCount < *fOpt.maxTokens
if !short || attempt == maxRetries {
break
}
if *fOpt.debug {
fmt.Fprintf(os.Stderr, "Short response (%d/%d tokens), retrying with different prompt (attempt %d/%d)\n",
responseMetrics.EvalCount, *fOpt.maxTokens, attempt+1, maxRetries)
}
}
if err != nil || responseMetrics == nil {
continue
}
if responseMetrics == nil {
fmt.Fprintf(os.Stderr, "ERROR: No metrics received for model '%s'\n", model)
continue
if short {
shortCount++
if *fOpt.debug {
fmt.Fprintf(os.Stderr, "WARNING: Short response (%d/%d tokens) after %d retries for epoch %d\n",
responseMetrics.EvalCount, *fOpt.maxTokens, maxRetries, epoch+1)
}
}
metrics := []Metrics{
@@ -238,6 +389,12 @@ func BenchmarkChat(fOpt flagOptions) error {
Count: responseMetrics.EvalCount,
Duration: responseMetrics.EvalDuration,
},
{
Model: model,
Step: "ttft",
Count: 1,
Duration: ttft,
},
{
Model: model,
Step: "load",
@@ -254,15 +411,42 @@ func BenchmarkChat(fOpt flagOptions) error {
OutputMetrics(out, *fOpt.format, metrics, *fOpt.verbose)
if *fOpt.debug && *fOpt.promptTokens > 0 {
fmt.Fprintf(os.Stderr, "Generated prompt targeting ~%d tokens (actual: %d)\n",
*fOpt.promptTokens, responseMetrics.PromptEvalCount)
}
if *fOpt.keepAlive > 0 {
time.Sleep(time.Duration(*fOpt.keepAlive*float64(time.Second)) + 200*time.Millisecond)
}
}
if shortCount > 0 {
fmt.Fprintf(os.Stderr, "WARNING: %d/%d epochs for '%s' had short responses (<%d tokens). Generation metrics may be unreliable.\n",
shortCount, *fOpt.epochs, model, *fOpt.maxTokens)
}
// Unload model before moving to the next one
unloadModel(client, model, *fOpt.timeout)
}
return nil
}
func unloadModel(client *api.Client, model string, timeout int) {
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(timeout)*time.Second)
defer cancel()
zero := api.Duration{Duration: 0}
req := &api.GenerateRequest{
Model: model,
KeepAlive: &zero,
}
_ = client.Generate(ctx, req, func(resp api.GenerateResponse) error {
return nil
})
}
func readImage(filePath string) (api.ImageData, error) {
file, err := os.Open(filePath)
if err != nil {
@@ -280,19 +464,21 @@ func readImage(filePath string) (api.ImageData, error) {
func main() {
fOpt := flagOptions{
models: flag.String("model", "", "Model to benchmark"),
epochs: flag.Int("epochs", 6, "Number of epochs (iterations) per model"),
maxTokens: flag.Int("max-tokens", 200, "Maximum tokens for model response"),
temperature: flag.Float64("temperature", 0, "Temperature parameter"),
seed: flag.Int("seed", 0, "Random seed"),
timeout: flag.Int("timeout", 60*5, "Timeout in seconds (default 300s)"),
prompt: flag.String("p", DefaultPrompt, "Prompt to use"),
imageFile: flag.String("image", "", "Filename for an image to include"),
keepAlive: flag.Float64("k", 0, "Keep alive duration in seconds"),
format: flag.String("format", "markdown", "Output format [benchstat|csv] (default benchstat)"),
outputFile: flag.String("output", "", "Output file for results (stdout if empty)"),
verbose: flag.Bool("v", false, "Show system information"),
debug: flag.Bool("debug", false, "Show debug information"),
models: flag.String("model", "", "Model to benchmark"),
epochs: flag.Int("epochs", 6, "Number of epochs (iterations) per model"),
maxTokens: flag.Int("max-tokens", 200, "Maximum tokens for model response"),
temperature: flag.Float64("temperature", 0, "Temperature parameter"),
seed: flag.Int("seed", 0, "Random seed"),
timeout: flag.Int("timeout", 60*5, "Timeout in seconds (default 300s)"),
prompt: flag.String("p", DefaultPrompt, "Prompt to use"),
imageFile: flag.String("image", "", "Filename for an image to include"),
keepAlive: flag.Float64("k", 0, "Keep alive duration in seconds"),
format: flag.String("format", "benchstat", "Output format [benchstat|csv]"),
outputFile: flag.String("output", "", "Output file for results (stdout if empty)"),
verbose: flag.Bool("v", false, "Show system information"),
debug: flag.Bool("debug", false, "Show debug information"),
warmup: flag.Int("warmup", 1, "Number of warmup requests before timing"),
promptTokens: flag.Int("prompt-tokens", 0, "Generate prompt targeting ~N tokens (0 = use -p prompt)"),
}
flag.Usage = func() {
@@ -302,11 +488,12 @@ func main() {
fmt.Fprintf(os.Stderr, "Options:\n")
flag.PrintDefaults()
fmt.Fprintf(os.Stderr, "\nExamples:\n")
fmt.Fprintf(os.Stderr, " bench -model gpt-oss:20b -epochs 3 -temperature 0.7\n")
fmt.Fprintf(os.Stderr, " bench -model gemma3,llama3 -epochs 6\n")
fmt.Fprintf(os.Stderr, " bench -model gemma3 -epochs 6 -prompt-tokens 512 -format csv\n")
}
flag.Parse()
if !slices.Contains([]string{"markdown", "benchstat", "csv"}, *fOpt.format) {
if !slices.Contains([]string{"benchstat", "csv"}, *fOpt.format) {
fmt.Fprintf(os.Stderr, "ERROR: Unknown format '%s'\n", *fOpt.format)
os.Exit(1)
}
@@ -317,5 +504,5 @@ func main() {
return
}
BenchmarkChat(fOpt)
BenchmarkModel(fOpt)
}

File diff suppressed because it is too large Load Diff

View File

@@ -11,6 +11,7 @@ import (
"fmt"
"io"
"log"
"log/slog"
"math"
"net"
"net/http"
@@ -38,9 +39,12 @@ import (
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/cmd/config"
"github.com/ollama/ollama/cmd/launch"
"github.com/ollama/ollama/cmd/tui"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/format"
"github.com/ollama/ollama/internal/modelref"
"github.com/ollama/ollama/logutil"
"github.com/ollama/ollama/parser"
"github.com/ollama/ollama/progress"
"github.com/ollama/ollama/readline"
@@ -57,36 +61,42 @@ import (
func init() {
// Override default selectors to use Bubbletea TUI instead of raw terminal I/O.
config.DefaultSingleSelector = func(title string, items []config.ModelItem, current string) (string, error) {
launch.DefaultSingleSelector = func(title string, items []launch.ModelItem, current string) (string, error) {
if !term.IsTerminal(int(os.Stdin.Fd())) || !term.IsTerminal(int(os.Stdout.Fd())) {
return "", fmt.Errorf("model selection requires an interactive terminal; use --model to run in headless mode")
}
tuiItems := tui.ReorderItems(tui.ConvertItems(items))
result, err := tui.SelectSingle(title, tuiItems, current)
if errors.Is(err, tui.ErrCancelled) {
return "", config.ErrCancelled
return "", launch.ErrCancelled
}
return result, err
}
config.DefaultMultiSelector = func(title string, items []config.ModelItem, preChecked []string) ([]string, error) {
launch.DefaultMultiSelector = func(title string, items []launch.ModelItem, preChecked []string) ([]string, error) {
if !term.IsTerminal(int(os.Stdin.Fd())) || !term.IsTerminal(int(os.Stdout.Fd())) {
return nil, fmt.Errorf("model selection requires an interactive terminal; use --model to run in headless mode")
}
tuiItems := tui.ReorderItems(tui.ConvertItems(items))
result, err := tui.SelectMultiple(title, tuiItems, preChecked)
if errors.Is(err, tui.ErrCancelled) {
return nil, config.ErrCancelled
return nil, launch.ErrCancelled
}
return result, err
}
config.DefaultSignIn = func(modelName, signInURL string) (string, error) {
launch.DefaultSignIn = func(modelName, signInURL string) (string, error) {
userName, err := tui.RunSignIn(modelName, signInURL)
if errors.Is(err, tui.ErrCancelled) {
return "", config.ErrCancelled
return "", launch.ErrCancelled
}
return userName, err
}
config.DefaultConfirmPrompt = func(prompt string) (bool, error) {
launch.DefaultConfirmPrompt = func(prompt string) (bool, error) {
ok, err := tui.RunConfirm(prompt)
if errors.Is(err, tui.ErrCancelled) {
return false, config.ErrCancelled
return false, launch.ErrCancelled
}
return ok, err
}
@@ -131,6 +141,17 @@ func getModelfileName(cmd *cobra.Command) (string, error) {
return absName, nil
}
// isLocalhost returns true if the configured Ollama host is a loopback or unspecified address.
func isLocalhost() bool {
host := envconfig.Host()
h, _, _ := net.SplitHostPort(host.Host)
if h == "localhost" {
return true
}
ip := net.ParseIP(h)
return ip != nil && (ip.IsLoopback() || ip.IsUnspecified())
}
func CreateHandler(cmd *cobra.Command, args []string) error {
p := progress.NewProgress(os.Stderr)
defer p.Stop()
@@ -145,6 +166,9 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
// Check for --experimental flag for safetensors model creation
experimental, _ := cmd.Flags().GetBool("experimental")
if experimental {
if !isLocalhost() {
return errors.New("remote safetensor model creation not yet supported")
}
// Get Modelfile content - either from -f flag or default to "FROM ."
var reader io.Reader
filename, err := getModelfileName(cmd)
@@ -168,29 +192,9 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
return fmt.Errorf("failed to parse Modelfile: %w", err)
}
// Extract FROM path and configuration
var modelDir string
mfConfig := &xcreateclient.ModelfileConfig{}
for _, cmd := range modelfile.Commands {
switch cmd.Name {
case "model":
modelDir = cmd.Args
case "template":
mfConfig.Template = cmd.Args
case "system":
mfConfig.System = cmd.Args
case "license":
mfConfig.License = cmd.Args
case "parser":
mfConfig.Parser = cmd.Args
case "renderer":
mfConfig.Renderer = cmd.Args
}
}
if modelDir == "" {
modelDir = "."
modelDir, mfConfig, err := xcreateclient.ConfigFromModelfile(modelfile)
if err != nil {
return err
}
// Resolve relative paths based on Modelfile location
@@ -214,6 +218,9 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
if filename == "" {
// No Modelfile found - check if current directory is an image gen model
if create.IsTensorModelDir(".") {
if !isLocalhost() {
return errors.New("remote safetensor model creation not yet supported")
}
quantize, _ := cmd.Flags().GetString("quantize")
return xcreateclient.CreateModel(xcreateclient.CreateOptions{
ModelName: modelName,
@@ -406,12 +413,14 @@ func loadOrUnloadModel(cmd *cobra.Command, opts *runOptions) error {
return err
}
requestedCloud := modelref.HasExplicitCloudSource(opts.Model)
if info, err := client.Show(cmd.Context(), &api.ShowRequest{Model: opts.Model}); err != nil {
return err
} else if info.RemoteHost != "" {
} else if info.RemoteHost != "" || requestedCloud {
// Cloud model, no need to load/unload
isCloud := strings.HasPrefix(info.RemoteHost, "https://ollama.com")
isCloud := requestedCloud || strings.HasPrefix(info.RemoteHost, "https://ollama.com")
// Check if user is signed in for ollama.com cloud models
if isCloud {
@@ -422,10 +431,14 @@ func loadOrUnloadModel(cmd *cobra.Command, opts *runOptions) error {
if opts.ShowConnect {
p.StopAndClear()
remoteModel := info.RemoteModel
if remoteModel == "" {
remoteModel = opts.Model
}
if isCloud {
fmt.Fprintf(os.Stderr, "Connecting to '%s' on 'ollama.com' ⚡\n", info.RemoteModel)
fmt.Fprintf(os.Stderr, "Connecting to '%s' on 'ollama.com' ⚡\n", remoteModel)
} else {
fmt.Fprintf(os.Stderr, "Connecting to '%s' on '%s'\n", info.RemoteModel, info.RemoteHost)
fmt.Fprintf(os.Stderr, "Connecting to '%s' on '%s'\n", remoteModel, info.RemoteHost)
}
}
@@ -497,6 +510,64 @@ func generateEmbedding(cmd *cobra.Command, modelName, input string, keepAlive *a
return nil
}
// TODO(parthsareen): consolidate with TUI signin flow
func handleCloudAuthorizationError(err error) bool {
var authErr api.AuthorizationError
if errors.As(err, &authErr) && authErr.StatusCode == http.StatusUnauthorized {
fmt.Printf("You need to be signed in to Ollama to run Cloud models.\n\n")
if authErr.SigninURL != "" {
fmt.Printf(ConnectInstructions, authErr.SigninURL)
}
return true
}
return false
}
// TEMP(drifkin): To match legacy `ollama run some-model:cloud` behavior, we
// best-effort pull cloud stub files for any explicit cloud source models.
// Remove this once `/api/tags` is cloud-aware.
func ensureCloudStub(ctx context.Context, client *api.Client, modelName string) {
if !modelref.HasExplicitCloudSource(modelName) {
return
}
normalizedName, _, err := modelref.NormalizePullName(modelName)
if err != nil {
slog.Warn("failed to normalize pull name", "model", modelName, "error", err, "normalizedName", normalizedName)
return
}
listResp, err := client.List(ctx)
if err != nil {
slog.Warn("failed to list models", "error", err)
return
}
if hasListedModelName(listResp.Models, modelName) || hasListedModelName(listResp.Models, normalizedName) {
return
}
logutil.Trace("pulling cloud stub", "model", modelName, "normalizedName", normalizedName)
err = client.Pull(ctx, &api.PullRequest{
Model: normalizedName,
}, func(api.ProgressResponse) error {
return nil
})
if err != nil {
slog.Warn("failed to pull cloud stub", "model", modelName, "error", err)
}
}
func hasListedModelName(models []api.ListModelResponse, name string) bool {
for _, m := range models {
if strings.EqualFold(m.Name, name) || strings.EqualFold(m.Model, name) {
return true
}
}
return false
}
func RunHandler(cmd *cobra.Command, args []string) error {
interactive := true
@@ -585,17 +656,6 @@ func RunHandler(cmd *cobra.Command, args []string) error {
}
opts.WordWrap = !nowrap
useImagegen := false
if cmd.Flags().Lookup("imagegen") != nil {
useImagegen, err = cmd.Flags().GetBool("imagegen")
if err != nil {
return err
}
}
if useImagegen {
opts.Options["use_imagegen_runner"] = true
}
// Fill out the rest of the options based on information about the
// model.
client, err := api.ClientFromEnvironment()
@@ -604,12 +664,16 @@ func RunHandler(cmd *cobra.Command, args []string) error {
}
name := args[0]
requestedCloud := modelref.HasExplicitCloudSource(name)
info, err := func() (*api.ShowResponse, error) {
showReq := &api.ShowRequest{Name: name}
info, err := client.Show(cmd.Context(), showReq)
var se api.StatusError
if errors.As(err, &se) && se.StatusCode == http.StatusNotFound {
if requestedCloud {
return nil, err
}
if err := PullHandler(cmd, []string{name}); err != nil {
return nil, err
}
@@ -618,9 +682,14 @@ func RunHandler(cmd *cobra.Command, args []string) error {
return info, err
}()
if err != nil {
if handleCloudAuthorizationError(err) {
return nil
}
return err
}
ensureCloudStub(cmd.Context(), client, name)
opts.Think, err = inferThinkingOption(&info.Capabilities, &opts, thinkFlag.Changed)
if err != nil {
return err
@@ -712,7 +781,13 @@ func RunHandler(cmd *cobra.Command, args []string) error {
return generateInteractive(cmd, opts)
}
return generate(cmd, opts)
if err := generate(cmd, opts); err != nil {
if handleCloudAuthorizationError(err) {
return nil
}
return err
}
return nil
}
func SigninHandler(cmd *cobra.Command, args []string) error {
@@ -1892,6 +1967,24 @@ func ensureServerRunning(ctx context.Context) error {
}
}
func launchInteractiveModel(cmd *cobra.Command, modelName string) error {
opts := runOptions{
Model: modelName,
WordWrap: os.Getenv("TERM") == "xterm-256color",
Options: map[string]any{},
ShowConnect: true,
}
// loadOrUnloadModel is cloud-safe here: remote/cloud models skip local preload
// and only validate auth/connectivity before interactive chat starts.
if err := loadOrUnloadModel(cmd, &opts); err != nil {
return fmt.Errorf("error loading model: %w", err)
}
if err := generateInteractive(cmd, opts); err != nil {
return fmt.Errorf("error running model: %w", err)
}
return nil
}
// runInteractiveTUI runs the main interactive TUI menu.
func runInteractiveTUI(cmd *cobra.Command) {
// Ensure the server is running before showing the TUI
@@ -1900,171 +1993,81 @@ func runInteractiveTUI(cmd *cobra.Command) {
return
}
// Selector adapters for tui
singleSelector := func(title string, items []config.ModelItem, current string) (string, error) {
tuiItems := tui.ReorderItems(tui.ConvertItems(items))
result, err := tui.SelectSingle(title, tuiItems, current)
if errors.Is(err, tui.ErrCancelled) {
return "", config.ErrCancelled
}
return result, err
}
multiSelector := func(title string, items []config.ModelItem, preChecked []string) ([]string, error) {
tuiItems := tui.ReorderItems(tui.ConvertItems(items))
result, err := tui.SelectMultiple(title, tuiItems, preChecked)
if errors.Is(err, tui.ErrCancelled) {
return nil, config.ErrCancelled
}
return result, err
deps := launcherDeps{
buildState: launch.BuildLauncherState,
runMenu: tui.RunMenu,
resolveRunModel: launch.ResolveRunModel,
launchIntegration: launch.LaunchIntegration,
runModel: launchInteractiveModel,
}
for {
result, err := tui.Run()
continueLoop, err := runInteractiveTUIStep(cmd, deps)
if err != nil {
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
}
if !continueLoop {
return
}
}
}
runModel := func(modelName string) {
client, err := api.ClientFromEnvironment()
if err != nil {
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
return
}
if err := config.ShowOrPull(cmd.Context(), client, modelName); err != nil {
if errors.Is(err, config.ErrCancelled) {
return
}
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
return
}
_ = config.SetLastModel(modelName)
opts := runOptions{
Model: modelName,
WordWrap: os.Getenv("TERM") == "xterm-256color",
Options: map[string]any{},
ShowConnect: true,
}
if err := loadOrUnloadModel(cmd, &opts); err != nil {
fmt.Fprintf(os.Stderr, "Error loading model: %v\n", err)
return
}
if err := generateInteractive(cmd, opts); err != nil {
fmt.Fprintf(os.Stderr, "Error running model: %v\n", err)
}
}
type launcherDeps struct {
buildState func(context.Context) (*launch.LauncherState, error)
runMenu func(*launch.LauncherState) (tui.TUIAction, error)
resolveRunModel func(context.Context, launch.RunModelRequest) (string, error)
launchIntegration func(context.Context, launch.IntegrationLaunchRequest) error
runModel func(*cobra.Command, string) error
}
launchIntegration := func(name string) bool {
// If not configured or model no longer exists, prompt for model selection
configuredModel := config.IntegrationModel(name)
if configuredModel == "" || !config.ModelExists(cmd.Context(), configuredModel) || config.IsCloudModelDisabled(cmd.Context(), configuredModel) {
err := config.ConfigureIntegrationWithSelectors(cmd.Context(), name, singleSelector, multiSelector)
if errors.Is(err, config.ErrCancelled) {
return false // Return to main menu
}
if err != nil {
fmt.Fprintf(os.Stderr, "Error configuring %s: %v\n", name, err)
return true
}
}
if err := config.LaunchIntegration(name); err != nil {
fmt.Fprintf(os.Stderr, "Error launching %s: %v\n", name, err)
}
return true
}
func runInteractiveTUIStep(cmd *cobra.Command, deps launcherDeps) (bool, error) {
state, err := deps.buildState(cmd.Context())
if err != nil {
return false, fmt.Errorf("build launcher state: %w", err)
}
switch result.Selection {
case tui.SelectionNone:
// User quit
return
case tui.SelectionRunModel:
_ = config.SetLastSelection("run")
if modelName := config.LastModel(); modelName != "" && !config.IsCloudModelDisabled(cmd.Context(), modelName) {
runModel(modelName)
} else {
modelName, err := config.SelectModelWithSelector(cmd.Context(), singleSelector)
if errors.Is(err, config.ErrCancelled) {
continue // Return to main menu
}
if err != nil {
fmt.Fprintf(os.Stderr, "Error selecting model: %v\n", err)
continue
}
runModel(modelName)
}
case tui.SelectionChangeRunModel:
_ = config.SetLastSelection("run")
// Use model from modal if selected, otherwise show picker
modelName := result.Model
if modelName == "" {
var err error
modelName, err = config.SelectModelWithSelector(cmd.Context(), singleSelector)
if errors.Is(err, config.ErrCancelled) {
continue // Return to main menu
}
if err != nil {
fmt.Fprintf(os.Stderr, "Error selecting model: %v\n", err)
continue
}
}
if config.IsCloudModelDisabled(cmd.Context(), modelName) {
continue // Return to main menu
}
runModel(modelName)
case tui.SelectionIntegration:
_ = config.SetLastSelection(result.Integration)
if !launchIntegration(result.Integration) {
continue // Return to main menu
}
case tui.SelectionChangeIntegration:
_ = config.SetLastSelection(result.Integration)
if len(result.Models) > 0 {
// Filter out cloud-disabled models
var filtered []string
for _, m := range result.Models {
if !config.IsCloudModelDisabled(cmd.Context(), m) {
filtered = append(filtered, m)
}
}
if len(filtered) == 0 {
continue
}
result.Models = filtered
// Multi-select from modal (Editor integrations)
if err := config.SaveAndEditIntegration(result.Integration, result.Models); err != nil {
fmt.Fprintf(os.Stderr, "Error configuring %s: %v\n", result.Integration, err)
continue
}
if err := config.LaunchIntegrationWithModel(result.Integration, result.Models[0]); err != nil {
fmt.Fprintf(os.Stderr, "Error launching %s: %v\n", result.Integration, err)
}
} else if result.Model != "" {
if config.IsCloudModelDisabled(cmd.Context(), result.Model) {
continue
}
// Single-select from modal - save and launch
if err := config.SaveIntegration(result.Integration, []string{result.Model}); err != nil {
fmt.Fprintf(os.Stderr, "Error saving config: %v\n", err)
continue
}
if err := config.LaunchIntegrationWithModel(result.Integration, result.Model); err != nil {
fmt.Fprintf(os.Stderr, "Error launching %s: %v\n", result.Integration, err)
}
} else {
err := config.ConfigureIntegrationWithSelectors(cmd.Context(), result.Integration, singleSelector, multiSelector)
if errors.Is(err, config.ErrCancelled) {
continue // Return to main menu
}
if err != nil {
fmt.Fprintf(os.Stderr, "Error configuring %s: %v\n", result.Integration, err)
continue
}
if err := config.LaunchIntegration(result.Integration); err != nil {
fmt.Fprintf(os.Stderr, "Error launching %s: %v\n", result.Integration, err)
}
}
action, err := deps.runMenu(state)
if err != nil {
return false, fmt.Errorf("run launcher menu: %w", err)
}
return runLauncherAction(cmd, action, deps)
}
func saveLauncherSelection(action tui.TUIAction) {
// Best effort only: this affects menu recall, not launch correctness.
_ = config.SetLastSelection(action.LastSelection())
}
func runLauncherAction(cmd *cobra.Command, action tui.TUIAction, deps launcherDeps) (bool, error) {
switch action.Kind {
case tui.TUIActionNone:
return false, nil
case tui.TUIActionRunModel:
saveLauncherSelection(action)
modelName, err := deps.resolveRunModel(cmd.Context(), action.RunModelRequest())
if errors.Is(err, launch.ErrCancelled) {
return true, nil
}
if err != nil {
return true, fmt.Errorf("selecting model: %w", err)
}
if err := deps.runModel(cmd, modelName); err != nil {
return true, err
}
return true, nil
case tui.TUIActionLaunchIntegration:
saveLauncherSelection(action)
err := deps.launchIntegration(cmd.Context(), action.IntegrationLaunchRequest())
if errors.Is(err, launch.ErrCancelled) {
return true, nil
}
if err != nil {
return true, fmt.Errorf("launching %s: %w", action.Integration, err)
}
return true, nil
default:
return false, fmt.Errorf("unknown launcher action: %d", action.Kind)
}
}
@@ -2334,7 +2337,7 @@ func NewCLI() *cobra.Command {
copyCmd,
deleteCmd,
runnerCmd,
config.LaunchCmd(checkServerHeartbeat, runInteractiveTUI),
launch.LaunchCmd(checkServerHeartbeat, runInteractiveTUI),
)
return rootCmd

233
cmd/cmd_launcher_test.go Normal file
View File

@@ -0,0 +1,233 @@
package cmd
import (
"context"
"testing"
"github.com/spf13/cobra"
"github.com/ollama/ollama/cmd/config"
"github.com/ollama/ollama/cmd/launch"
"github.com/ollama/ollama/cmd/tui"
)
func setCmdTestHome(t *testing.T, dir string) {
t.Helper()
t.Setenv("HOME", dir)
t.Setenv("USERPROFILE", dir)
}
func unexpectedRunModelResolution(t *testing.T) func(context.Context, launch.RunModelRequest) (string, error) {
t.Helper()
return func(ctx context.Context, req launch.RunModelRequest) (string, error) {
t.Fatalf("did not expect run-model resolution: %+v", req)
return "", nil
}
}
func unexpectedIntegrationLaunch(t *testing.T) func(context.Context, launch.IntegrationLaunchRequest) error {
t.Helper()
return func(ctx context.Context, req launch.IntegrationLaunchRequest) error {
t.Fatalf("did not expect integration launch: %+v", req)
return nil
}
}
func unexpectedModelLaunch(t *testing.T) func(*cobra.Command, string) error {
t.Helper()
return func(cmd *cobra.Command, model string) error {
t.Fatalf("did not expect chat launch: %s", model)
return nil
}
}
func TestRunInteractiveTUI_RunModelActionsUseResolveRunModel(t *testing.T) {
tests := []struct {
name string
action tui.TUIAction
wantForce bool
wantModel string
}{
{
name: "enter uses saved model flow",
action: tui.TUIAction{Kind: tui.TUIActionRunModel},
wantModel: "qwen3:8b",
},
{
name: "right forces picker",
action: tui.TUIAction{Kind: tui.TUIActionRunModel, ForceConfigure: true},
wantForce: true,
wantModel: "glm-5:cloud",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
setCmdTestHome(t, t.TempDir())
var menuCalls int
runMenu := func(state *launch.LauncherState) (tui.TUIAction, error) {
menuCalls++
if menuCalls == 1 {
return tt.action, nil
}
return tui.TUIAction{Kind: tui.TUIActionNone}, nil
}
var gotReq launch.RunModelRequest
var launched string
deps := launcherDeps{
buildState: func(ctx context.Context) (*launch.LauncherState, error) {
return &launch.LauncherState{}, nil
},
runMenu: runMenu,
resolveRunModel: func(ctx context.Context, req launch.RunModelRequest) (string, error) {
gotReq = req
return tt.wantModel, nil
},
launchIntegration: unexpectedIntegrationLaunch(t),
runModel: func(cmd *cobra.Command, model string) error {
launched = model
return nil
},
}
cmd := &cobra.Command{}
cmd.SetContext(context.Background())
for {
continueLoop, err := runInteractiveTUIStep(cmd, deps)
if err != nil {
t.Fatalf("unexpected step error: %v", err)
}
if !continueLoop {
break
}
}
if gotReq.ForcePicker != tt.wantForce {
t.Fatalf("expected ForcePicker=%v, got %v", tt.wantForce, gotReq.ForcePicker)
}
if launched != tt.wantModel {
t.Fatalf("expected interactive launcher to run %q, got %q", tt.wantModel, launched)
}
if got := config.LastSelection(); got != "run" {
t.Fatalf("expected last selection to be run, got %q", got)
}
})
}
}
func TestRunInteractiveTUI_IntegrationActionsUseLaunchIntegration(t *testing.T) {
tests := []struct {
name string
action tui.TUIAction
wantForce bool
}{
{
name: "enter launches integration",
action: tui.TUIAction{Kind: tui.TUIActionLaunchIntegration, Integration: "claude"},
},
{
name: "right forces configure",
action: tui.TUIAction{Kind: tui.TUIActionLaunchIntegration, Integration: "claude", ForceConfigure: true},
wantForce: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
setCmdTestHome(t, t.TempDir())
var menuCalls int
runMenu := func(state *launch.LauncherState) (tui.TUIAction, error) {
menuCalls++
if menuCalls == 1 {
return tt.action, nil
}
return tui.TUIAction{Kind: tui.TUIActionNone}, nil
}
var gotReq launch.IntegrationLaunchRequest
deps := launcherDeps{
buildState: func(ctx context.Context) (*launch.LauncherState, error) {
return &launch.LauncherState{}, nil
},
runMenu: runMenu,
resolveRunModel: unexpectedRunModelResolution(t),
launchIntegration: func(ctx context.Context, req launch.IntegrationLaunchRequest) error {
gotReq = req
return nil
},
runModel: unexpectedModelLaunch(t),
}
cmd := &cobra.Command{}
cmd.SetContext(context.Background())
for {
continueLoop, err := runInteractiveTUIStep(cmd, deps)
if err != nil {
t.Fatalf("unexpected step error: %v", err)
}
if !continueLoop {
break
}
}
if gotReq.Name != "claude" {
t.Fatalf("expected integration name to be passed through, got %q", gotReq.Name)
}
if gotReq.ForceConfigure != tt.wantForce {
t.Fatalf("expected ForceConfigure=%v, got %v", tt.wantForce, gotReq.ForceConfigure)
}
if got := config.LastSelection(); got != "claude" {
t.Fatalf("expected last selection to be claude, got %q", got)
}
})
}
}
func TestRunLauncherAction_RunModelContinuesAfterCancellation(t *testing.T) {
setCmdTestHome(t, t.TempDir())
cmd := &cobra.Command{}
cmd.SetContext(context.Background())
continueLoop, err := runLauncherAction(cmd, tui.TUIAction{Kind: tui.TUIActionRunModel}, launcherDeps{
buildState: nil,
runMenu: nil,
resolveRunModel: func(ctx context.Context, req launch.RunModelRequest) (string, error) {
return "", launch.ErrCancelled
},
launchIntegration: unexpectedIntegrationLaunch(t),
runModel: unexpectedModelLaunch(t),
})
if err != nil {
t.Fatalf("expected nil error on cancellation, got %v", err)
}
if !continueLoop {
t.Fatal("expected cancellation to continue the menu loop")
}
}
func TestRunLauncherAction_IntegrationContinuesAfterCancellation(t *testing.T) {
setCmdTestHome(t, t.TempDir())
cmd := &cobra.Command{}
cmd.SetContext(context.Background())
continueLoop, err := runLauncherAction(cmd, tui.TUIAction{Kind: tui.TUIActionLaunchIntegration, Integration: "claude"}, launcherDeps{
buildState: nil,
runMenu: nil,
resolveRunModel: unexpectedRunModelResolution(t),
launchIntegration: func(ctx context.Context, req launch.IntegrationLaunchRequest) error {
return launch.ErrCancelled
},
runModel: unexpectedModelLaunch(t),
})
if err != nil {
t.Fatalf("expected nil error on cancellation, got %v", err)
}
if !continueLoop {
t.Fatal("expected cancellation to continue the menu loop")
}
}

View File

@@ -705,6 +705,347 @@ func TestRunEmbeddingModelNoInput(t *testing.T) {
}
}
func TestRunHandler_CloudAuthErrorOnShow_PrintsSigninMessage(t *testing.T) {
var generateCalled bool
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch {
case r.URL.Path == "/api/show" && r.Method == http.MethodPost:
w.WriteHeader(http.StatusUnauthorized)
if err := json.NewEncoder(w).Encode(map[string]string{
"error": "unauthorized",
"signin_url": "https://ollama.com/signin",
}); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
}
return
case r.URL.Path == "/api/generate" && r.Method == http.MethodPost:
generateCalled = true
w.WriteHeader(http.StatusOK)
if err := json.NewEncoder(w).Encode(api.GenerateResponse{Done: true}); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
}
return
default:
http.NotFound(w, r)
}
}))
t.Setenv("OLLAMA_HOST", mockServer.URL)
t.Cleanup(mockServer.Close)
cmd := &cobra.Command{}
cmd.SetContext(t.Context())
cmd.Flags().String("keepalive", "", "")
cmd.Flags().Bool("truncate", false, "")
cmd.Flags().Int("dimensions", 0, "")
cmd.Flags().Bool("verbose", false, "")
cmd.Flags().Bool("insecure", false, "")
cmd.Flags().Bool("nowordwrap", false, "")
cmd.Flags().String("format", "", "")
cmd.Flags().String("think", "", "")
cmd.Flags().Bool("hidethinking", false, "")
oldStdout := os.Stdout
readOut, writeOut, _ := os.Pipe()
os.Stdout = writeOut
t.Cleanup(func() { os.Stdout = oldStdout })
err := RunHandler(cmd, []string{"gpt-oss:20b:cloud", "hi"})
_ = writeOut.Close()
var out bytes.Buffer
_, _ = io.Copy(&out, readOut)
if err != nil {
t.Fatalf("RunHandler returned error: %v", err)
}
if generateCalled {
t.Fatal("expected run to stop before /api/generate after unauthorized /api/show")
}
if !strings.Contains(out.String(), "You need to be signed in to Ollama to run Cloud models.") {
t.Fatalf("expected sign-in guidance message, got %q", out.String())
}
if !strings.Contains(out.String(), "https://ollama.com/signin") {
t.Fatalf("expected signin_url in output, got %q", out.String())
}
}
func TestRunHandler_CloudAuthErrorOnGenerate_PrintsSigninMessage(t *testing.T) {
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch {
case r.URL.Path == "/api/show" && r.Method == http.MethodPost:
w.WriteHeader(http.StatusOK)
if err := json.NewEncoder(w).Encode(api.ShowResponse{
Capabilities: []model.Capability{model.CapabilityCompletion},
}); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
}
return
case r.URL.Path == "/api/generate" && r.Method == http.MethodPost:
w.WriteHeader(http.StatusUnauthorized)
if err := json.NewEncoder(w).Encode(map[string]string{
"error": "unauthorized",
"signin_url": "https://ollama.com/signin",
}); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
}
return
default:
http.NotFound(w, r)
}
}))
t.Setenv("OLLAMA_HOST", mockServer.URL)
t.Cleanup(mockServer.Close)
cmd := &cobra.Command{}
cmd.SetContext(t.Context())
cmd.Flags().String("keepalive", "", "")
cmd.Flags().Bool("truncate", false, "")
cmd.Flags().Int("dimensions", 0, "")
cmd.Flags().Bool("verbose", false, "")
cmd.Flags().Bool("insecure", false, "")
cmd.Flags().Bool("nowordwrap", false, "")
cmd.Flags().String("format", "", "")
cmd.Flags().String("think", "", "")
cmd.Flags().Bool("hidethinking", false, "")
oldStdout := os.Stdout
readOut, writeOut, _ := os.Pipe()
os.Stdout = writeOut
t.Cleanup(func() { os.Stdout = oldStdout })
err := RunHandler(cmd, []string{"gpt-oss:20b:cloud", "hi"})
_ = writeOut.Close()
var out bytes.Buffer
_, _ = io.Copy(&out, readOut)
if err != nil {
t.Fatalf("RunHandler returned error: %v", err)
}
if !strings.Contains(out.String(), "You need to be signed in to Ollama to run Cloud models.") {
t.Fatalf("expected sign-in guidance message, got %q", out.String())
}
if !strings.Contains(out.String(), "https://ollama.com/signin") {
t.Fatalf("expected signin_url in output, got %q", out.String())
}
}
func TestRunHandler_ExplicitCloudStubMissing_PullsNormalizedNameTEMP(t *testing.T) {
var pulledModel string
var generateCalled bool
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch {
case r.URL.Path == "/api/show" && r.Method == http.MethodPost:
w.WriteHeader(http.StatusOK)
if err := json.NewEncoder(w).Encode(api.ShowResponse{
Capabilities: []model.Capability{model.CapabilityCompletion},
RemoteModel: "gpt-oss:20b",
}); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
}
return
case r.URL.Path == "/api/tags" && r.Method == http.MethodGet:
w.WriteHeader(http.StatusOK)
if err := json.NewEncoder(w).Encode(api.ListResponse{Models: nil}); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
}
return
case r.URL.Path == "/api/pull" && r.Method == http.MethodPost:
var req api.PullRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
pulledModel = req.Model
w.WriteHeader(http.StatusOK)
if err := json.NewEncoder(w).Encode(api.ProgressResponse{Status: "success"}); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
}
return
case r.URL.Path == "/api/generate" && r.Method == http.MethodPost:
generateCalled = true
w.WriteHeader(http.StatusOK)
if err := json.NewEncoder(w).Encode(api.GenerateResponse{Done: true}); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
}
return
default:
http.NotFound(w, r)
}
}))
t.Setenv("OLLAMA_HOST", mockServer.URL)
t.Cleanup(mockServer.Close)
cmd := &cobra.Command{}
cmd.SetContext(t.Context())
cmd.Flags().String("keepalive", "", "")
cmd.Flags().Bool("truncate", false, "")
cmd.Flags().Int("dimensions", 0, "")
cmd.Flags().Bool("verbose", false, "")
cmd.Flags().Bool("insecure", false, "")
cmd.Flags().Bool("nowordwrap", false, "")
cmd.Flags().String("format", "", "")
cmd.Flags().String("think", "", "")
cmd.Flags().Bool("hidethinking", false, "")
err := RunHandler(cmd, []string{"gpt-oss:20b:cloud", "hi"})
if err != nil {
t.Fatalf("RunHandler returned error: %v", err)
}
if pulledModel != "gpt-oss:20b-cloud" {
t.Fatalf("expected normalized pull model %q, got %q", "gpt-oss:20b-cloud", pulledModel)
}
if !generateCalled {
t.Fatal("expected /api/generate to be called")
}
}
func TestRunHandler_ExplicitCloudStubPresent_SkipsPullTEMP(t *testing.T) {
var pullCalled bool
var generateCalled bool
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch {
case r.URL.Path == "/api/show" && r.Method == http.MethodPost:
w.WriteHeader(http.StatusOK)
if err := json.NewEncoder(w).Encode(api.ShowResponse{
Capabilities: []model.Capability{model.CapabilityCompletion},
RemoteModel: "gpt-oss:20b",
}); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
}
return
case r.URL.Path == "/api/tags" && r.Method == http.MethodGet:
w.WriteHeader(http.StatusOK)
if err := json.NewEncoder(w).Encode(api.ListResponse{
Models: []api.ListModelResponse{{Name: "gpt-oss:20b-cloud"}},
}); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
}
return
case r.URL.Path == "/api/pull" && r.Method == http.MethodPost:
pullCalled = true
w.WriteHeader(http.StatusOK)
if err := json.NewEncoder(w).Encode(api.ProgressResponse{Status: "success"}); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
}
return
case r.URL.Path == "/api/generate" && r.Method == http.MethodPost:
generateCalled = true
w.WriteHeader(http.StatusOK)
if err := json.NewEncoder(w).Encode(api.GenerateResponse{Done: true}); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
}
return
default:
http.NotFound(w, r)
}
}))
t.Setenv("OLLAMA_HOST", mockServer.URL)
t.Cleanup(mockServer.Close)
cmd := &cobra.Command{}
cmd.SetContext(t.Context())
cmd.Flags().String("keepalive", "", "")
cmd.Flags().Bool("truncate", false, "")
cmd.Flags().Int("dimensions", 0, "")
cmd.Flags().Bool("verbose", false, "")
cmd.Flags().Bool("insecure", false, "")
cmd.Flags().Bool("nowordwrap", false, "")
cmd.Flags().String("format", "", "")
cmd.Flags().String("think", "", "")
cmd.Flags().Bool("hidethinking", false, "")
err := RunHandler(cmd, []string{"gpt-oss:20b:cloud", "hi"})
if err != nil {
t.Fatalf("RunHandler returned error: %v", err)
}
if pullCalled {
t.Fatal("expected /api/pull not to be called when cloud stub already exists")
}
if !generateCalled {
t.Fatal("expected /api/generate to be called")
}
}
func TestRunHandler_ExplicitCloudStubPullFailure_IsBestEffortTEMP(t *testing.T) {
var generateCalled bool
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch {
case r.URL.Path == "/api/show" && r.Method == http.MethodPost:
w.WriteHeader(http.StatusOK)
if err := json.NewEncoder(w).Encode(api.ShowResponse{
Capabilities: []model.Capability{model.CapabilityCompletion},
RemoteModel: "gpt-oss:20b",
}); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
}
return
case r.URL.Path == "/api/tags" && r.Method == http.MethodGet:
w.WriteHeader(http.StatusOK)
if err := json.NewEncoder(w).Encode(api.ListResponse{Models: nil}); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
}
return
case r.URL.Path == "/api/pull" && r.Method == http.MethodPost:
w.WriteHeader(http.StatusInternalServerError)
if err := json.NewEncoder(w).Encode(map[string]string{"error": "pull failed"}); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
}
return
case r.URL.Path == "/api/generate" && r.Method == http.MethodPost:
generateCalled = true
w.WriteHeader(http.StatusOK)
if err := json.NewEncoder(w).Encode(api.GenerateResponse{Done: true}); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
}
return
default:
http.NotFound(w, r)
}
}))
t.Setenv("OLLAMA_HOST", mockServer.URL)
t.Cleanup(mockServer.Close)
cmd := &cobra.Command{}
cmd.SetContext(t.Context())
cmd.Flags().String("keepalive", "", "")
cmd.Flags().Bool("truncate", false, "")
cmd.Flags().Int("dimensions", 0, "")
cmd.Flags().Bool("verbose", false, "")
cmd.Flags().Bool("insecure", false, "")
cmd.Flags().Bool("nowordwrap", false, "")
cmd.Flags().String("format", "", "")
cmd.Flags().String("think", "", "")
cmd.Flags().Bool("hidethinking", false, "")
err := RunHandler(cmd, []string{"gpt-oss:20b:cloud", "hi"})
if err != nil {
t.Fatalf("RunHandler returned error: %v", err)
}
if !generateCalled {
t.Fatal("expected /api/generate to be called despite pull failure")
}
}
func TestGetModelfileName(t *testing.T) {
tests := []struct {
name string
@@ -1212,6 +1553,20 @@ func TestNewCreateRequest(t *testing.T) {
Model: "newmodel",
},
},
{
"explicit cloud model preserves source when parent lacks it",
"newmodel",
runOptions{
Model: "qwen3.5:cloud",
ParentModel: "qwen3.5",
Messages: []api.Message{},
WordWrap: true,
},
&api.CreateRequest{
From: "qwen3.5:cloud",
Model: "newmodel",
},
},
{
"parent model as filepath test",
"newmodel",
@@ -1663,31 +2018,81 @@ func TestRunOptions_Copy_Independence(t *testing.T) {
func TestLoadOrUnloadModel_CloudModelAuth(t *testing.T) {
tests := []struct {
name string
remoteHost string
whoamiStatus int
whoamiResp any
expectedError string
name string
model string
showStatus int
remoteHost string
remoteModel string
whoamiStatus int
whoamiResp any
expectWhoami bool
expectedError string
expectAuthError bool
}{
{
name: "ollama.com cloud model - user signed in",
model: "test-cloud-model",
remoteHost: "https://ollama.com",
remoteModel: "test-model",
whoamiStatus: http.StatusOK,
whoamiResp: api.UserResponse{Name: "testuser"},
expectWhoami: true,
},
{
name: "ollama.com cloud model - user not signed in",
model: "test-cloud-model",
remoteHost: "https://ollama.com",
remoteModel: "test-model",
whoamiStatus: http.StatusUnauthorized,
whoamiResp: map[string]string{
"error": "unauthorized",
"signin_url": "https://ollama.com/signin",
},
expectedError: "unauthorized",
expectWhoami: true,
expectedError: "unauthorized",
expectAuthError: true,
},
{
name: "non-ollama.com remote - no auth check",
model: "test-cloud-model",
remoteHost: "https://other-remote.com",
remoteModel: "test-model",
whoamiStatus: http.StatusUnauthorized, // should not be called
whoamiResp: nil,
},
{
name: "explicit :cloud model - auth check without remote metadata",
model: "kimi-k2.5:cloud",
remoteHost: "",
remoteModel: "",
whoamiStatus: http.StatusOK,
whoamiResp: api.UserResponse{Name: "testuser"},
expectWhoami: true,
},
{
name: "explicit :cloud model without local stub returns not found by default",
model: "minimax-m2.5:cloud",
showStatus: http.StatusNotFound,
whoamiStatus: http.StatusOK,
whoamiResp: api.UserResponse{Name: "testuser"},
expectedError: "not found",
expectWhoami: false,
expectAuthError: false,
},
{
name: "explicit -cloud model - auth check without remote metadata",
model: "kimi-k2.5:latest-cloud",
remoteHost: "",
remoteModel: "",
whoamiStatus: http.StatusOK,
whoamiResp: api.UserResponse{Name: "testuser"},
expectWhoami: true,
},
{
name: "dash cloud-like name without explicit source does not require auth",
model: "test-cloud-model",
remoteHost: "",
remoteModel: "",
whoamiStatus: http.StatusUnauthorized, // should not be called
whoamiResp: nil,
},
@@ -1699,10 +2104,15 @@ func TestLoadOrUnloadModel_CloudModelAuth(t *testing.T) {
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/api/show":
if tt.showStatus != 0 && tt.showStatus != http.StatusOK {
w.WriteHeader(tt.showStatus)
_ = json.NewEncoder(w).Encode(map[string]string{"error": "not found"})
return
}
w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(api.ShowResponse{
RemoteHost: tt.remoteHost,
RemoteModel: "test-model",
RemoteModel: tt.remoteModel,
}); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
}
@@ -1715,6 +2125,8 @@ func TestLoadOrUnloadModel_CloudModelAuth(t *testing.T) {
http.Error(w, err.Error(), http.StatusInternalServerError)
}
}
case "/api/generate":
w.WriteHeader(http.StatusOK)
default:
http.NotFound(w, r)
}
@@ -1727,29 +2139,28 @@ func TestLoadOrUnloadModel_CloudModelAuth(t *testing.T) {
cmd.SetContext(t.Context())
opts := &runOptions{
Model: "test-cloud-model",
Model: tt.model,
ShowConnect: false,
}
err := loadOrUnloadModel(cmd, opts)
if strings.HasPrefix(tt.remoteHost, "https://ollama.com") {
if !whoamiCalled {
t.Error("expected whoami to be called for ollama.com cloud model")
}
} else {
if whoamiCalled {
t.Error("whoami should not be called for non-ollama.com remote")
}
if whoamiCalled != tt.expectWhoami {
t.Errorf("whoami called = %v, want %v", whoamiCalled, tt.expectWhoami)
}
if tt.expectedError != "" {
if err == nil {
t.Errorf("expected error containing %q, got nil", tt.expectedError)
} else {
var authErr api.AuthorizationError
if !errors.As(err, &authErr) {
t.Errorf("expected AuthorizationError, got %T: %v", err, err)
if !tt.expectAuthError && !strings.Contains(strings.ToLower(err.Error()), strings.ToLower(tt.expectedError)) {
t.Errorf("expected error containing %q, got %v", tt.expectedError, err)
}
if tt.expectAuthError {
var authErr api.AuthorizationError
if !errors.As(err, &authErr) {
t.Errorf("expected AuthorizationError, got %T: %v", err, err)
}
}
}
} else {
@@ -1760,3 +2171,38 @@ func TestLoadOrUnloadModel_CloudModelAuth(t *testing.T) {
})
}
}
func TestIsLocalhost(t *testing.T) {
tests := []struct {
name string
host string
expected bool
}{
{"default empty", "", true},
{"localhost no port", "localhost", true},
{"localhost with port", "localhost:11435", true},
{"127.0.0.1 no port", "127.0.0.1", true},
{"127.0.0.1 with port", "127.0.0.1:11434", true},
{"0.0.0.0 no port", "0.0.0.0", true},
{"0.0.0.0 with port", "0.0.0.0:11434", true},
{"::1 no port", "::1", true},
{"[::1] with port", "[::1]:11434", true},
{"loopback with scheme", "http://localhost:11434", true},
{"remote hostname", "example.com", false},
{"remote hostname with port", "example.com:11434", false},
{"remote IP", "192.168.1.1", false},
{"remote IP with port", "192.168.1.1:11434", false},
{"remote with scheme", "http://example.com:11434", false},
{"https remote", "https://example.com:443", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Setenv("OLLAMA_HOST", tt.host)
got := isLocalhost()
if got != tt.expected {
t.Errorf("isLocalhost() with OLLAMA_HOST=%q = %v, want %v", tt.host, got, tt.expected)
}
})
}
}

View File

@@ -1,192 +0,0 @@
package config
import (
"context"
"fmt"
"os"
"os/exec"
"path/filepath"
"runtime"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/envconfig"
)
// Claude implements Runner and AliasConfigurer for Claude Code integration
type Claude struct{}
// Compile-time check that Claude implements AliasConfigurer
var _ AliasConfigurer = (*Claude)(nil)
func (c *Claude) String() string { return "Claude Code" }
func (c *Claude) args(model string, extra []string) []string {
var args []string
if model != "" {
args = append(args, "--model", model)
}
args = append(args, extra...)
return args
}
func (c *Claude) findPath() (string, error) {
if p, err := exec.LookPath("claude"); err == nil {
return p, nil
}
home, err := os.UserHomeDir()
if err != nil {
return "", err
}
name := "claude"
if runtime.GOOS == "windows" {
name = "claude.exe"
}
fallback := filepath.Join(home, ".claude", "local", name)
if _, err := os.Stat(fallback); err != nil {
return "", err
}
return fallback, nil
}
func (c *Claude) Run(model string, args []string) error {
claudePath, err := c.findPath()
if err != nil {
return fmt.Errorf("claude is not installed, install from https://code.claude.com/docs/en/quickstart")
}
cmd := exec.Command(claudePath, c.args(model, args)...)
cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
env := append(os.Environ(),
"ANTHROPIC_BASE_URL="+envconfig.Host().String(),
"ANTHROPIC_API_KEY=",
"ANTHROPIC_AUTH_TOKEN=ollama",
)
env = append(env, c.modelEnvVars(model)...)
cmd.Env = env
return cmd.Run()
}
// modelEnvVars returns Claude Code env vars that route all model tiers through Ollama.
func (c *Claude) modelEnvVars(model string) []string {
primary := model
fast := model
if cfg, err := loadIntegration("claude"); err == nil && cfg.Aliases != nil {
if p := cfg.Aliases["primary"]; p != "" {
primary = p
}
if f := cfg.Aliases["fast"]; f != "" {
fast = f
}
}
return []string{
"ANTHROPIC_DEFAULT_OPUS_MODEL=" + primary,
"ANTHROPIC_DEFAULT_SONNET_MODEL=" + primary,
"ANTHROPIC_DEFAULT_HAIKU_MODEL=" + fast,
"CLAUDE_CODE_SUBAGENT_MODEL=" + primary,
}
}
// ConfigureAliases sets up model aliases for Claude Code.
// model: the model to use (if empty, user will be prompted to select)
// aliases: existing alias configuration to preserve/update
// Cloud-only: subagent routing (fast model) is gated to cloud models only until
// there is a better strategy for prompt caching on local models.
func (c *Claude) ConfigureAliases(ctx context.Context, model string, existingAliases map[string]string, force bool) (map[string]string, bool, error) {
aliases := make(map[string]string)
for k, v := range existingAliases {
aliases[k] = v
}
if model != "" {
aliases["primary"] = model
}
if !force && aliases["primary"] != "" {
client, _ := api.ClientFromEnvironment()
if isCloudModel(ctx, client, aliases["primary"]) {
if isCloudModel(ctx, client, aliases["fast"]) {
return aliases, false, nil
}
} else {
delete(aliases, "fast")
return aliases, false, nil
}
}
items, existingModels, cloudModels, client, err := listModels(ctx)
if err != nil {
return nil, false, err
}
fmt.Fprintf(os.Stderr, "\n%sModel Configuration%s\n\n", ansiBold, ansiReset)
if aliases["primary"] == "" || force {
primary, err := DefaultSingleSelector("Select model:", items, aliases["primary"])
if err != nil {
return nil, false, err
}
if err := pullIfNeeded(ctx, client, existingModels, primary); err != nil {
return nil, false, err
}
if err := ensureAuth(ctx, client, cloudModels, []string{primary}); err != nil {
return nil, false, err
}
aliases["primary"] = primary
}
if isCloudModel(ctx, client, aliases["primary"]) {
if aliases["fast"] == "" || !isCloudModel(ctx, client, aliases["fast"]) {
aliases["fast"] = aliases["primary"]
}
} else {
delete(aliases, "fast")
}
return aliases, true, nil
}
// SetAliases syncs the configured aliases to the Ollama server using prefix matching.
// Cloud-only: for local models (fast is empty), we delete any existing aliases to
// prevent stale routing to a previous cloud model.
func (c *Claude) SetAliases(ctx context.Context, aliases map[string]string) error {
client, err := api.ClientFromEnvironment()
if err != nil {
return err
}
prefixes := []string{"claude-sonnet-", "claude-haiku-"}
if aliases["fast"] == "" {
for _, prefix := range prefixes {
_ = client.DeleteAliasExperimental(ctx, &api.AliasDeleteRequest{Alias: prefix})
}
return nil
}
prefixAliases := map[string]string{
"claude-sonnet-": aliases["primary"],
"claude-haiku-": aliases["fast"],
}
var errs []string
for prefix, target := range prefixAliases {
req := &api.AliasRequest{
Alias: prefix,
Target: target,
PrefixMatching: true,
}
if err := client.SetAliasExperimental(ctx, req); err != nil {
errs = append(errs, prefix)
}
}
if len(errs) > 0 {
return fmt.Errorf("failed to set aliases: %v", errs)
}
return nil
}

View File

@@ -3,7 +3,6 @@
package config
import (
"context"
"encoding/json"
"errors"
"fmt"
@@ -11,14 +10,18 @@ import (
"path/filepath"
"strings"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/cmd/internal/fileutil"
)
type integration struct {
Models []string `json:"models"`
Aliases map[string]string `json:"aliases,omitempty"`
Models []string `json:"models"`
Aliases map[string]string `json:"aliases,omitempty"`
Onboarded bool `json:"onboarded,omitempty"`
}
// IntegrationConfig is the persisted config for one integration.
type IntegrationConfig = integration
type config struct {
Integrations map[string]*integration `json:"integrations"`
LastModel string `json:"last_model,omitempty"`
@@ -123,7 +126,7 @@ func save(cfg *config) error {
return err
}
return writeWithBackup(path, data)
return fileutil.WriteWithBackup(path, data)
}
func SaveIntegration(appName string, models []string) error {
@@ -139,34 +142,54 @@ func SaveIntegration(appName string, models []string) error {
key := strings.ToLower(appName)
existing := cfg.Integrations[key]
var aliases map[string]string
if existing != nil && existing.Aliases != nil {
var onboarded bool
if existing != nil {
aliases = existing.Aliases
onboarded = existing.Onboarded
}
cfg.Integrations[key] = &integration{
Models: models,
Aliases: aliases,
Models: models,
Aliases: aliases,
Onboarded: onboarded,
}
return save(cfg)
}
// MarkIntegrationOnboarded marks an integration as onboarded in Ollama's config.
func MarkIntegrationOnboarded(appName string) error {
cfg, err := load()
if err != nil {
return err
}
key := strings.ToLower(appName)
existing := cfg.Integrations[key]
if existing == nil {
existing = &integration{}
}
existing.Onboarded = true
cfg.Integrations[key] = existing
return save(cfg)
}
// IntegrationModel returns the first configured model for an integration, or empty string if not configured.
func IntegrationModel(appName string) string {
ic, err := loadIntegration(appName)
if err != nil || len(ic.Models) == 0 {
integrationConfig, err := LoadIntegration(appName)
if err != nil || len(integrationConfig.Models) == 0 {
return ""
}
return ic.Models[0]
return integrationConfig.Models[0]
}
// IntegrationModels returns all configured models for an integration, or nil.
func IntegrationModels(appName string) []string {
ic, err := loadIntegration(appName)
if err != nil || len(ic.Models) == 0 {
integrationConfig, err := LoadIntegration(appName)
if err != nil || len(integrationConfig.Models) == 0 {
return nil
}
return ic.Models
return integrationConfig.Models
}
// LastModel returns the last model that was run, or empty string if none.
@@ -207,42 +230,23 @@ func SetLastSelection(selection string) error {
return save(cfg)
}
// ModelExists checks if a model exists on the Ollama server.
func ModelExists(ctx context.Context, name string) bool {
if name == "" {
return false
}
client, err := api.ClientFromEnvironment()
if err != nil {
return false
}
models, err := client.List(ctx)
if err != nil {
return false
}
for _, m := range models.Models {
if m.Name == name || strings.HasPrefix(m.Name, name+":") {
return true
}
}
return false
}
func loadIntegration(appName string) (*integration, error) {
// LoadIntegration returns the saved config for one integration.
func LoadIntegration(appName string) (*integration, error) {
cfg, err := load()
if err != nil {
return nil, err
}
ic, ok := cfg.Integrations[strings.ToLower(appName)]
integrationConfig, ok := cfg.Integrations[strings.ToLower(appName)]
if !ok {
return nil, os.ErrNotExist
}
return ic, nil
return integrationConfig, nil
}
func saveAliases(appName string, aliases map[string]string) error {
// SaveAliases replaces the saved aliases for one integration.
func SaveAliases(appName string, aliases map[string]string) error {
if appName == "" {
return errors.New("app name cannot be empty")
}
@@ -272,8 +276,8 @@ func listIntegrations() ([]integration, error) {
}
result := make([]integration, 0, len(cfg.Integrations))
for _, ic := range cfg.Integrations {
result = append(result, *ic)
for _, integrationConfig := range cfg.Integrations {
result = append(result, *integrationConfig)
}
return result, nil

View File

@@ -1,7 +1,6 @@
package config
import (
"context"
"errors"
"os"
"path/filepath"
@@ -45,12 +44,12 @@ func TestSaveAliases_ReplacesNotMerges(t *testing.T) {
"primary": "cloud-model",
"fast": "cloud-model",
}
if err := saveAliases("claude", initial); err != nil {
if err := SaveAliases("claude", initial); err != nil {
t.Fatalf("failed to save initial aliases: %v", err)
}
// Verify both are saved
loaded, err := loadIntegration("claude")
loaded, err := LoadIntegration("claude")
if err != nil {
t.Fatalf("failed to load: %v", err)
}
@@ -63,12 +62,12 @@ func TestSaveAliases_ReplacesNotMerges(t *testing.T) {
"primary": "local-model",
// fast intentionally missing
}
if err := saveAliases("claude", updated); err != nil {
if err := SaveAliases("claude", updated); err != nil {
t.Fatalf("failed to save updated aliases: %v", err)
}
// Verify fast is GONE (not merged/preserved)
loaded, err = loadIntegration("claude")
loaded, err = LoadIntegration("claude")
if err != nil {
t.Fatalf("failed to load after update: %v", err)
}
@@ -91,12 +90,12 @@ func TestSaveAliases_PreservesModels(t *testing.T) {
// Then update aliases
aliases := map[string]string{"primary": "new-model"}
if err := saveAliases("claude", aliases); err != nil {
if err := SaveAliases("claude", aliases); err != nil {
t.Fatalf("failed to save aliases: %v", err)
}
// Verify models are preserved
loaded, err := loadIntegration("claude")
loaded, err := LoadIntegration("claude")
if err != nil {
t.Fatalf("failed to load: %v", err)
}
@@ -111,16 +110,16 @@ func TestSaveAliases_EmptyMap(t *testing.T) {
setTestHome(t, tmpDir)
// Save with aliases
if err := saveAliases("claude", map[string]string{"primary": "model", "fast": "model"}); err != nil {
if err := SaveAliases("claude", map[string]string{"primary": "model", "fast": "model"}); err != nil {
t.Fatalf("failed to save: %v", err)
}
// Save empty map
if err := saveAliases("claude", map[string]string{}); err != nil {
if err := SaveAliases("claude", map[string]string{}); err != nil {
t.Fatalf("failed to save empty: %v", err)
}
loaded, err := loadIntegration("claude")
loaded, err := LoadIntegration("claude")
if err != nil {
t.Fatalf("failed to load: %v", err)
}
@@ -135,16 +134,16 @@ func TestSaveAliases_NilMap(t *testing.T) {
setTestHome(t, tmpDir)
// Save with aliases first
if err := saveAliases("claude", map[string]string{"primary": "model"}); err != nil {
if err := SaveAliases("claude", map[string]string{"primary": "model"}); err != nil {
t.Fatalf("failed to save: %v", err)
}
// Save nil map - should clear aliases
if err := saveAliases("claude", nil); err != nil {
if err := SaveAliases("claude", nil); err != nil {
t.Fatalf("failed to save nil: %v", err)
}
loaded, err := loadIntegration("claude")
loaded, err := LoadIntegration("claude")
if err != nil {
t.Fatalf("failed to load: %v", err)
}
@@ -155,7 +154,7 @@ func TestSaveAliases_NilMap(t *testing.T) {
// TestSaveAliases_EmptyAppName returns error
func TestSaveAliases_EmptyAppName(t *testing.T) {
err := saveAliases("", map[string]string{"primary": "model"})
err := SaveAliases("", map[string]string{"primary": "model"})
if err == nil {
t.Error("expected error for empty app name")
}
@@ -165,12 +164,12 @@ func TestSaveAliases_CaseInsensitive(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
if err := saveAliases("Claude", map[string]string{"primary": "model1"}); err != nil {
if err := SaveAliases("Claude", map[string]string{"primary": "model1"}); err != nil {
t.Fatalf("failed to save: %v", err)
}
// Load with different case
loaded, err := loadIntegration("claude")
loaded, err := LoadIntegration("claude")
if err != nil {
t.Fatalf("failed to load: %v", err)
}
@@ -179,11 +178,11 @@ func TestSaveAliases_CaseInsensitive(t *testing.T) {
}
// Update with different case
if err := saveAliases("CLAUDE", map[string]string{"primary": "model2"}); err != nil {
if err := SaveAliases("CLAUDE", map[string]string{"primary": "model2"}); err != nil {
t.Fatalf("failed to update: %v", err)
}
loaded, err = loadIntegration("claude")
loaded, err = LoadIntegration("claude")
if err != nil {
t.Fatalf("failed to load after update: %v", err)
}
@@ -198,11 +197,11 @@ func TestSaveAliases_CreatesIntegration(t *testing.T) {
setTestHome(t, tmpDir)
// Save aliases for non-existent integration
if err := saveAliases("newintegration", map[string]string{"primary": "model"}); err != nil {
if err := SaveAliases("newintegration", map[string]string{"primary": "model"}); err != nil {
t.Fatalf("failed to save: %v", err)
}
loaded, err := loadIntegration("newintegration")
loaded, err := LoadIntegration("newintegration")
if err != nil {
t.Fatalf("failed to load: %v", err)
}
@@ -371,12 +370,12 @@ func TestAtomicUpdate_ServerSucceedsConfigSaved(t *testing.T) {
t.Fatal("server should succeed")
}
if err := saveAliases("claude", map[string]string{"primary": "model"}); err != nil {
if err := SaveAliases("claude", map[string]string{"primary": "model"}); err != nil {
t.Fatalf("saveAliases failed: %v", err)
}
// Verify it was actually saved
loaded, err := loadIntegration("claude")
loaded, err := LoadIntegration("claude")
if err != nil {
t.Fatalf("failed to load: %v", err)
}
@@ -408,7 +407,7 @@ func TestConfigFile_PreservesUnknownFields(t *testing.T) {
os.WriteFile(configPath, []byte(initialConfig), 0o644)
// Update aliases
if err := saveAliases("claude", map[string]string{"primary": "model2"}); err != nil {
if err := SaveAliases("claude", map[string]string{"primary": "model2"}); err != nil {
t.Fatalf("failed to save: %v", err)
}
@@ -440,11 +439,6 @@ func containsHelper(s, substr string) bool {
return false
}
func TestClaudeImplementsAliasConfigurer(t *testing.T) {
c := &Claude{}
var _ AliasConfigurer = c // Compile-time check
}
func TestModelNameEdgeCases(t *testing.T) {
testCases := []struct {
name string
@@ -464,11 +458,11 @@ func TestModelNameEdgeCases(t *testing.T) {
setTestHome(t, tmpDir)
aliases := map[string]string{"primary": tc.model}
if err := saveAliases("claude", aliases); err != nil {
if err := SaveAliases("claude", aliases); err != nil {
t.Fatalf("failed to save model %q: %v", tc.model, err)
}
loaded, err := loadIntegration("claude")
loaded, err := LoadIntegration("claude")
if err != nil {
t.Fatalf("failed to load: %v", err)
}
@@ -485,7 +479,7 @@ func TestSwitchingScenarios(t *testing.T) {
setTestHome(t, tmpDir)
// Initial cloud config
if err := saveAliases("claude", map[string]string{
if err := SaveAliases("claude", map[string]string{
"primary": "cloud-model",
"fast": "cloud-model",
}); err != nil {
@@ -493,13 +487,13 @@ func TestSwitchingScenarios(t *testing.T) {
}
// Switch to local (no fast)
if err := saveAliases("claude", map[string]string{
if err := SaveAliases("claude", map[string]string{
"primary": "local-model",
}); err != nil {
t.Fatal(err)
}
loaded, _ := loadIntegration("claude")
loaded, _ := LoadIntegration("claude")
if loaded.Aliases["fast"] != "" {
t.Errorf("fast should be removed, got %q", loaded.Aliases["fast"])
}
@@ -513,21 +507,21 @@ func TestSwitchingScenarios(t *testing.T) {
setTestHome(t, tmpDir)
// Initial local config
if err := saveAliases("claude", map[string]string{
if err := SaveAliases("claude", map[string]string{
"primary": "local-model",
}); err != nil {
t.Fatal(err)
}
// Switch to cloud (with fast)
if err := saveAliases("claude", map[string]string{
if err := SaveAliases("claude", map[string]string{
"primary": "cloud-model",
"fast": "cloud-model",
}); err != nil {
t.Fatal(err)
}
loaded, _ := loadIntegration("claude")
loaded, _ := LoadIntegration("claude")
if loaded.Aliases["fast"] != "cloud-model" {
t.Errorf("fast should be cloud-model, got %q", loaded.Aliases["fast"])
}
@@ -538,7 +532,7 @@ func TestSwitchingScenarios(t *testing.T) {
setTestHome(t, tmpDir)
// Initial cloud config
if err := saveAliases("claude", map[string]string{
if err := SaveAliases("claude", map[string]string{
"primary": "cloud-model-1",
"fast": "cloud-model-1",
}); err != nil {
@@ -546,14 +540,14 @@ func TestSwitchingScenarios(t *testing.T) {
}
// Switch to different cloud
if err := saveAliases("claude", map[string]string{
if err := SaveAliases("claude", map[string]string{
"primary": "cloud-model-2",
"fast": "cloud-model-2",
}); err != nil {
t.Fatal(err)
}
loaded, _ := loadIntegration("claude")
loaded, _ := LoadIntegration("claude")
if loaded.Aliases["primary"] != "cloud-model-2" {
t.Errorf("primary should be cloud-model-2, got %q", loaded.Aliases["primary"])
}
@@ -563,43 +557,13 @@ func TestSwitchingScenarios(t *testing.T) {
})
}
func TestToolCapabilityFiltering(t *testing.T) {
t.Run("all models checked for tool capability", func(t *testing.T) {
// Both cloud and local models are checked for tool capability via Show API
// Only models with "tools" in capabilities are included
m := modelInfo{Name: "tool-model", Remote: false, ToolCapable: true}
if !m.ToolCapable {
t.Error("tool capable model should be marked as such")
}
})
t.Run("modelInfo includes ToolCapable field", func(t *testing.T) {
m := modelInfo{Name: "test", Remote: true, ToolCapable: true}
if !m.ToolCapable {
t.Error("ToolCapable field should be accessible")
}
})
}
func TestIsCloudModel_RequiresClient(t *testing.T) {
t.Run("nil client always returns false", func(t *testing.T) {
// isCloudModel now only uses Show API, no suffix detection
if isCloudModel(context.Background(), nil, "model:cloud") {
t.Error("nil client should return false regardless of suffix")
}
if isCloudModel(context.Background(), nil, "local-model") {
t.Error("nil client should return false")
}
})
}
func TestModelsAndAliasesMustStayInSync(t *testing.T) {
t.Run("saveAliases followed by saveIntegration keeps them in sync", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
// Save aliases with one model
if err := saveAliases("claude", map[string]string{"primary": "model-a"}); err != nil {
if err := SaveAliases("claude", map[string]string{"primary": "model-a"}); err != nil {
t.Fatal(err)
}
@@ -608,7 +572,7 @@ func TestModelsAndAliasesMustStayInSync(t *testing.T) {
t.Fatal(err)
}
loaded, _ := loadIntegration("claude")
loaded, _ := LoadIntegration("claude")
if loaded.Aliases["primary"] != loaded.Models[0] {
t.Errorf("aliases.primary (%q) != models[0] (%q)", loaded.Aliases["primary"], loaded.Models[0])
}
@@ -622,11 +586,11 @@ func TestModelsAndAliasesMustStayInSync(t *testing.T) {
if err := SaveIntegration("claude", []string{"old-model"}); err != nil {
t.Fatal(err)
}
if err := saveAliases("claude", map[string]string{"primary": "new-model"}); err != nil {
if err := SaveAliases("claude", map[string]string{"primary": "new-model"}); err != nil {
t.Fatal(err)
}
loaded, _ := loadIntegration("claude")
loaded, _ := LoadIntegration("claude")
// They should be different (this is the bug state)
if loaded.Models[0] == loaded.Aliases["primary"] {
@@ -638,7 +602,7 @@ func TestModelsAndAliasesMustStayInSync(t *testing.T) {
t.Fatal(err)
}
loaded, _ = loadIntegration("claude")
loaded, _ = LoadIntegration("claude")
if loaded.Models[0] != loaded.Aliases["primary"] {
t.Errorf("after fix: models[0] (%q) should equal aliases.primary (%q)",
loaded.Models[0], loaded.Aliases["primary"])
@@ -653,20 +617,20 @@ func TestModelsAndAliasesMustStayInSync(t *testing.T) {
if err := SaveIntegration("claude", []string{"initial-model"}); err != nil {
t.Fatal(err)
}
if err := saveAliases("claude", map[string]string{"primary": "initial-model"}); err != nil {
if err := SaveAliases("claude", map[string]string{"primary": "initial-model"}); err != nil {
t.Fatal(err)
}
// Update aliases AND models together
newAliases := map[string]string{"primary": "updated-model"}
if err := saveAliases("claude", newAliases); err != nil {
if err := SaveAliases("claude", newAliases); err != nil {
t.Fatal(err)
}
if err := SaveIntegration("claude", []string{newAliases["primary"]}); err != nil {
t.Fatal(err)
}
loaded, _ := loadIntegration("claude")
loaded, _ := LoadIntegration("claude")
if loaded.Models[0] != "updated-model" {
t.Errorf("models[0] should be updated-model, got %q", loaded.Models[0])
}

View File

@@ -10,17 +10,10 @@ import (
// setTestHome sets both HOME (Unix) and USERPROFILE (Windows) for cross-platform tests
func setTestHome(t *testing.T, dir string) {
t.Setenv("HOME", dir)
t.Setenv("TMPDIR", dir)
t.Setenv("USERPROFILE", dir)
}
// editorPaths is a test helper that safely calls Paths if the runner implements Editor
func editorPaths(r Runner) []string {
if editor, ok := r.(Editor); ok {
return editor.Paths()
}
return nil
}
func TestIntegrationConfig(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
@@ -31,7 +24,7 @@ func TestIntegrationConfig(t *testing.T) {
t.Fatal(err)
}
config, err := loadIntegration("claude")
config, err := LoadIntegration("claude")
if err != nil {
t.Fatal(err)
}
@@ -55,11 +48,11 @@ func TestIntegrationConfig(t *testing.T) {
"primary": "llama3.2:70b",
"fast": "llama3.2:8b",
}
if err := saveAliases("claude", aliases); err != nil {
if err := SaveAliases("claude", aliases); err != nil {
t.Fatal(err)
}
config, err := loadIntegration("claude")
config, err := LoadIntegration("claude")
if err != nil {
t.Fatal(err)
}
@@ -77,14 +70,14 @@ func TestIntegrationConfig(t *testing.T) {
if err := SaveIntegration("claude", []string{"model-a"}); err != nil {
t.Fatal(err)
}
if err := saveAliases("claude", map[string]string{"primary": "model-a", "fast": "model-small"}); err != nil {
if err := SaveAliases("claude", map[string]string{"primary": "model-a", "fast": "model-small"}); err != nil {
t.Fatal(err)
}
if err := SaveIntegration("claude", []string{"model-b"}); err != nil {
t.Fatal(err)
}
config, err := loadIntegration("claude")
config, err := LoadIntegration("claude")
if err != nil {
t.Fatal(err)
}
@@ -96,7 +89,7 @@ func TestIntegrationConfig(t *testing.T) {
t.Run("defaultModel returns first model", func(t *testing.T) {
SaveIntegration("codex", []string{"model-a", "model-b"})
config, _ := loadIntegration("codex")
config, _ := LoadIntegration("codex")
defaultModel := ""
if len(config.Models) > 0 {
defaultModel = config.Models[0]
@@ -120,7 +113,7 @@ func TestIntegrationConfig(t *testing.T) {
t.Run("app name is case-insensitive", func(t *testing.T) {
SaveIntegration("Claude", []string{"model-x"})
config, err := loadIntegration("claude")
config, err := LoadIntegration("claude")
if err != nil {
t.Fatal(err)
}
@@ -137,8 +130,8 @@ func TestIntegrationConfig(t *testing.T) {
SaveIntegration("app1", []string{"model-1"})
SaveIntegration("app2", []string{"model-2"})
config1, _ := loadIntegration("app1")
config2, _ := loadIntegration("app2")
config1, _ := LoadIntegration("app1")
config2, _ := LoadIntegration("app2")
defaultModel1 := ""
if len(config1.Models) > 0 {
@@ -185,64 +178,6 @@ func TestListIntegrations(t *testing.T) {
})
}
func TestEditorPaths(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
t.Run("returns empty for claude (no Editor)", func(t *testing.T) {
r := integrations["claude"]
paths := editorPaths(r)
if len(paths) != 0 {
t.Errorf("expected no paths for claude, got %v", paths)
}
})
t.Run("returns empty for codex (no Editor)", func(t *testing.T) {
r := integrations["codex"]
paths := editorPaths(r)
if len(paths) != 0 {
t.Errorf("expected no paths for codex, got %v", paths)
}
})
t.Run("returns empty for droid when no config exists", func(t *testing.T) {
r := integrations["droid"]
paths := editorPaths(r)
if len(paths) != 0 {
t.Errorf("expected no paths, got %v", paths)
}
})
t.Run("returns path for droid when config exists", func(t *testing.T) {
settingsDir, _ := os.UserHomeDir()
settingsDir = filepath.Join(settingsDir, ".factory")
os.MkdirAll(settingsDir, 0o755)
os.WriteFile(filepath.Join(settingsDir, "settings.json"), []byte(`{}`), 0o644)
r := integrations["droid"]
paths := editorPaths(r)
if len(paths) != 1 {
t.Errorf("expected 1 path, got %d", len(paths))
}
})
t.Run("returns paths for opencode when configs exist", func(t *testing.T) {
home, _ := os.UserHomeDir()
configDir := filepath.Join(home, ".config", "opencode")
stateDir := filepath.Join(home, ".local", "state", "opencode")
os.MkdirAll(configDir, 0o755)
os.MkdirAll(stateDir, 0o755)
os.WriteFile(filepath.Join(configDir, "opencode.json"), []byte(`{}`), 0o644)
os.WriteFile(filepath.Join(stateDir, "model.json"), []byte(`{}`), 0o644)
r := integrations["opencode"]
paths := editorPaths(r)
if len(paths) != 2 {
t.Errorf("expected 2 paths, got %d: %v", len(paths), paths)
}
})
}
func TestLoadIntegration_CorruptedJSON(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
@@ -251,7 +186,7 @@ func TestLoadIntegration_CorruptedJSON(t *testing.T) {
os.MkdirAll(dir, 0o755)
os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{corrupted json`), 0o644)
_, err := loadIntegration("test")
_, err := LoadIntegration("test")
if err == nil {
t.Error("expected error for nonexistent integration in corrupted file")
}
@@ -265,7 +200,7 @@ func TestSaveIntegration_NilModels(t *testing.T) {
t.Fatalf("saveIntegration with nil models failed: %v", err)
}
config, err := loadIntegration("test")
config, err := LoadIntegration("test")
if err != nil {
t.Fatalf("loadIntegration failed: %v", err)
}
@@ -294,7 +229,7 @@ func TestLoadIntegration_NonexistentIntegration(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
_, err := loadIntegration("nonexistent")
_, err := LoadIntegration("nonexistent")
if err == nil {
t.Error("expected error for nonexistent integration, got nil")
}

File diff suppressed because it is too large Load Diff

View File

@@ -1,264 +0,0 @@
package config
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"os"
"os/exec"
"path/filepath"
"strings"
"github.com/ollama/ollama/envconfig"
)
type Openclaw struct{}
func (c *Openclaw) String() string { return "OpenClaw" }
func (c *Openclaw) Run(model string, args []string) error {
bin := "openclaw"
if _, err := exec.LookPath(bin); err != nil {
bin = "clawdbot"
if _, err := exec.LookPath(bin); err != nil {
return fmt.Errorf("openclaw is not installed, install from https://docs.openclaw.ai")
}
}
models := []string{model}
if config, err := loadIntegration("openclaw"); err == nil && len(config.Models) > 0 {
models = config.Models
} else if config, err := loadIntegration("clawdbot"); err == nil && len(config.Models) > 0 {
models = config.Models
}
var err error
models, err = resolveEditorModels("openclaw", models, func() ([]string, error) {
return selectModels(context.Background(), "openclaw", "")
})
if errors.Is(err, errCancelled) {
return nil
}
if err != nil {
return err
}
if err := c.Edit(models); err != nil {
return fmt.Errorf("setup failed: %w", err)
}
if !c.onboarded() {
// Onboarding not completed: run it (model already set via Edit)
// Use "ollama" as gateway token for simple local access
cmd := exec.Command(bin, "onboard",
"--auth-choice", "skip",
"--gateway-token", "ollama",
)
cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
return cmd.Run()
}
// Onboarding completed: run gateway
cmd := exec.Command(bin, append([]string{"gateway"}, args...)...)
cmd.Stdin = os.Stdin
// Capture output to detect "already running" message
var outputBuf bytes.Buffer
cmd.Stdout = io.MultiWriter(os.Stdout, &outputBuf)
cmd.Stderr = io.MultiWriter(os.Stderr, &outputBuf)
err = cmd.Run()
if err != nil && strings.Contains(outputBuf.String(), "Gateway already running") {
fmt.Fprintf(os.Stderr, "%sOpenClaw has been configured with Ollama. Gateway is already running.%s\n", ansiGreen, ansiReset)
return nil
}
return err
}
// onboarded checks if OpenClaw onboarding wizard was completed
// by looking for the wizard.lastRunAt marker in the config
func (c *Openclaw) onboarded() bool {
home, err := os.UserHomeDir()
if err != nil {
return false
}
configPath := filepath.Join(home, ".openclaw", "openclaw.json")
legacyPath := filepath.Join(home, ".clawdbot", "clawdbot.json")
config := make(map[string]any)
if data, err := os.ReadFile(configPath); err == nil {
_ = json.Unmarshal(data, &config)
} else if data, err := os.ReadFile(legacyPath); err == nil {
_ = json.Unmarshal(data, &config)
} else {
return false
}
// Check for wizard.lastRunAt marker (set when onboarding completes)
wizard, _ := config["wizard"].(map[string]any)
if wizard == nil {
return false
}
lastRunAt, _ := wizard["lastRunAt"].(string)
return lastRunAt != ""
}
func (c *Openclaw) Paths() []string {
home, _ := os.UserHomeDir()
p := filepath.Join(home, ".openclaw", "openclaw.json")
if _, err := os.Stat(p); err == nil {
return []string{p}
}
legacy := filepath.Join(home, ".clawdbot", "clawdbot.json")
if _, err := os.Stat(legacy); err == nil {
return []string{legacy}
}
return nil
}
func (c *Openclaw) Edit(models []string) error {
if len(models) == 0 {
return nil
}
home, err := os.UserHomeDir()
if err != nil {
return err
}
configPath := filepath.Join(home, ".openclaw", "openclaw.json")
legacyPath := filepath.Join(home, ".clawdbot", "clawdbot.json")
if err := os.MkdirAll(filepath.Dir(configPath), 0o755); err != nil {
return err
}
// Read into map[string]any to preserve unknown fields
config := make(map[string]any)
if data, err := os.ReadFile(configPath); err == nil {
_ = json.Unmarshal(data, &config)
} else if data, err := os.ReadFile(legacyPath); err == nil {
_ = json.Unmarshal(data, &config)
}
// Navigate/create: models.providers.ollama (preserving other providers)
modelsSection, _ := config["models"].(map[string]any)
if modelsSection == nil {
modelsSection = make(map[string]any)
}
providers, _ := modelsSection["providers"].(map[string]any)
if providers == nil {
providers = make(map[string]any)
}
ollama, _ := providers["ollama"].(map[string]any)
if ollama == nil {
ollama = make(map[string]any)
}
ollama["baseUrl"] = envconfig.Host().String() + "/v1"
// needed to register provider
ollama["apiKey"] = "ollama-local"
// TODO(parthsareen): potentially move to responses
ollama["api"] = "openai-completions"
// Build map of existing models to preserve user customizations
existingModels, _ := ollama["models"].([]any)
existingByID := make(map[string]map[string]any)
for _, m := range existingModels {
if entry, ok := m.(map[string]any); ok {
if id, ok := entry["id"].(string); ok {
existingByID[id] = entry
}
}
}
var newModels []any
for _, model := range models {
entry := map[string]any{
"id": model,
"name": model,
"reasoning": false,
"input": []any{"text"},
"cost": map[string]any{
"input": 0,
"output": 0,
"cacheRead": 0,
"cacheWrite": 0,
},
// TODO(parthsareen): get these values from API
"contextWindow": 131072,
"maxTokens": 16384,
}
// Merge existing fields (user customizations)
if existing, ok := existingByID[model]; ok {
for k, v := range existing {
if _, isNew := entry[k]; !isNew {
entry[k] = v
}
}
}
newModels = append(newModels, entry)
}
ollama["models"] = newModels
providers["ollama"] = ollama
modelsSection["providers"] = providers
config["models"] = modelsSection
// Update agents.defaults.model.primary (preserving other agent settings)
agents, _ := config["agents"].(map[string]any)
if agents == nil {
agents = make(map[string]any)
}
defaults, _ := agents["defaults"].(map[string]any)
if defaults == nil {
defaults = make(map[string]any)
}
modelConfig, _ := defaults["model"].(map[string]any)
if modelConfig == nil {
modelConfig = make(map[string]any)
}
modelConfig["primary"] = "ollama/" + models[0]
defaults["model"] = modelConfig
agents["defaults"] = defaults
config["agents"] = agents
data, err := json.MarshalIndent(config, "", " ")
if err != nil {
return err
}
return writeWithBackup(configPath, data)
}
func (c *Openclaw) Models() []string {
home, err := os.UserHomeDir()
if err != nil {
return nil
}
config, err := readJSONFile(filepath.Join(home, ".openclaw", "openclaw.json"))
if err != nil {
config, err = readJSONFile(filepath.Join(home, ".clawdbot", "clawdbot.json"))
if err != nil {
return nil
}
}
modelsSection, _ := config["models"].(map[string]any)
providers, _ := modelsSection["providers"].(map[string]any)
ollama, _ := providers["ollama"].(map[string]any)
modelList, _ := ollama["models"].([]any)
var result []string
for _, m := range modelList {
if entry, ok := m.(map[string]any); ok {
if id, ok := entry["id"].(string); ok {
result = append(result, id)
}
}
}
return result
}

View File

@@ -1,878 +0,0 @@
package config
import (
"encoding/json"
"fmt"
"os"
"path/filepath"
"testing"
)
func TestOpenclawIntegration(t *testing.T) {
c := &Openclaw{}
t.Run("String", func(t *testing.T) {
if got := c.String(); got != "OpenClaw" {
t.Errorf("String() = %q, want %q", got, "OpenClaw")
}
})
t.Run("implements Runner", func(t *testing.T) {
var _ Runner = c
})
t.Run("implements Editor", func(t *testing.T) {
var _ Editor = c
})
}
func TestOpenclawEdit(t *testing.T) {
c := &Openclaw{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".openclaw")
configPath := filepath.Join(configDir, "openclaw.json")
cleanup := func() { os.RemoveAll(configDir) }
t.Run("fresh install", func(t *testing.T) {
cleanup()
if err := c.Edit([]string{"llama3.2"}); err != nil {
t.Fatal(err)
}
assertOpenclawModelExists(t, configPath, "llama3.2")
assertOpenclawPrimaryModel(t, configPath, "ollama/llama3.2")
})
t.Run("multiple models - first is primary", func(t *testing.T) {
cleanup()
if err := c.Edit([]string{"llama3.2", "mistral"}); err != nil {
t.Fatal(err)
}
assertOpenclawModelExists(t, configPath, "llama3.2")
assertOpenclawModelExists(t, configPath, "mistral")
assertOpenclawPrimaryModel(t, configPath, "ollama/llama3.2")
})
t.Run("preserve other providers", func(t *testing.T) {
cleanup()
os.MkdirAll(configDir, 0o755)
os.WriteFile(configPath, []byte(`{"models":{"providers":{"anthropic":{"apiKey":"xxx"}}}}`), 0o644)
if err := c.Edit([]string{"llama3.2"}); err != nil {
t.Fatal(err)
}
data, _ := os.ReadFile(configPath)
var cfg map[string]any
json.Unmarshal(data, &cfg)
models := cfg["models"].(map[string]any)
providers := models["providers"].(map[string]any)
if providers["anthropic"] == nil {
t.Error("anthropic provider was removed")
}
})
t.Run("preserve top-level keys", func(t *testing.T) {
cleanup()
os.MkdirAll(configDir, 0o755)
os.WriteFile(configPath, []byte(`{"theme":"dark","mcp":{"servers":{}}}`), 0o644)
if err := c.Edit([]string{"llama3.2"}); err != nil {
t.Fatal(err)
}
data, _ := os.ReadFile(configPath)
var cfg map[string]any
json.Unmarshal(data, &cfg)
if cfg["theme"] != "dark" {
t.Error("theme was removed")
}
if cfg["mcp"] == nil {
t.Error("mcp was removed")
}
})
t.Run("preserve user customizations on models", func(t *testing.T) {
cleanup()
c.Edit([]string{"llama3.2"})
// User adds custom field
data, _ := os.ReadFile(configPath)
var cfg map[string]any
json.Unmarshal(data, &cfg)
models := cfg["models"].(map[string]any)
providers := models["providers"].(map[string]any)
ollama := providers["ollama"].(map[string]any)
modelList := ollama["models"].([]any)
entry := modelList[0].(map[string]any)
entry["customField"] = "user-value"
configData, _ := json.MarshalIndent(cfg, "", " ")
os.WriteFile(configPath, configData, 0o644)
// Re-run Edit
c.Edit([]string{"llama3.2"})
data, _ = os.ReadFile(configPath)
json.Unmarshal(data, &cfg)
models = cfg["models"].(map[string]any)
providers = models["providers"].(map[string]any)
ollama = providers["ollama"].(map[string]any)
modelList = ollama["models"].([]any)
entry = modelList[0].(map[string]any)
if entry["customField"] != "user-value" {
t.Error("custom field was lost")
}
})
t.Run("edit replaces models list", func(t *testing.T) {
cleanup()
c.Edit([]string{"llama3.2", "mistral"})
c.Edit([]string{"llama3.2"})
assertOpenclawModelExists(t, configPath, "llama3.2")
assertOpenclawModelNotExists(t, configPath, "mistral")
})
t.Run("empty models is no-op", func(t *testing.T) {
cleanup()
os.MkdirAll(configDir, 0o755)
original := `{"existing":"data"}`
os.WriteFile(configPath, []byte(original), 0o644)
c.Edit([]string{})
data, _ := os.ReadFile(configPath)
if string(data) != original {
t.Error("empty models should not modify file")
}
})
t.Run("corrupted JSON treated as empty", func(t *testing.T) {
cleanup()
os.MkdirAll(configDir, 0o755)
os.WriteFile(configPath, []byte(`{corrupted`), 0o644)
if err := c.Edit([]string{"llama3.2"}); err != nil {
t.Fatal(err)
}
data, _ := os.ReadFile(configPath)
var cfg map[string]any
if err := json.Unmarshal(data, &cfg); err != nil {
t.Error("result should be valid JSON")
}
})
t.Run("wrong type models section", func(t *testing.T) {
cleanup()
os.MkdirAll(configDir, 0o755)
os.WriteFile(configPath, []byte(`{"models":"not a map"}`), 0o644)
if err := c.Edit([]string{"llama3.2"}); err != nil {
t.Fatal(err)
}
assertOpenclawModelExists(t, configPath, "llama3.2")
})
}
func TestOpenclawModels(t *testing.T) {
c := &Openclaw{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
t.Run("no config returns nil", func(t *testing.T) {
if models := c.Models(); len(models) > 0 {
t.Errorf("expected nil/empty, got %v", models)
}
})
t.Run("returns all ollama models", func(t *testing.T) {
configDir := filepath.Join(tmpDir, ".openclaw")
os.MkdirAll(configDir, 0o755)
os.WriteFile(filepath.Join(configDir, "openclaw.json"), []byte(`{
"models":{"providers":{"ollama":{"models":[
{"id":"llama3.2"},
{"id":"mistral"}
]}}}
}`), 0o644)
models := c.Models()
if len(models) != 2 {
t.Errorf("expected 2 models, got %v", models)
}
})
}
// Helper functions
func assertOpenclawModelExists(t *testing.T, path, model string) {
t.Helper()
data, _ := os.ReadFile(path)
var cfg map[string]any
json.Unmarshal(data, &cfg)
models := cfg["models"].(map[string]any)
providers := models["providers"].(map[string]any)
ollama := providers["ollama"].(map[string]any)
modelList := ollama["models"].([]any)
for _, m := range modelList {
if entry, ok := m.(map[string]any); ok {
if entry["id"] == model {
return
}
}
}
t.Errorf("model %s not found", model)
}
func assertOpenclawModelNotExists(t *testing.T, path, model string) {
t.Helper()
data, _ := os.ReadFile(path)
var cfg map[string]any
json.Unmarshal(data, &cfg)
models, _ := cfg["models"].(map[string]any)
providers, _ := models["providers"].(map[string]any)
ollama, _ := providers["ollama"].(map[string]any)
modelList, _ := ollama["models"].([]any)
for _, m := range modelList {
if entry, ok := m.(map[string]any); ok {
if entry["id"] == model {
t.Errorf("model %s should not exist", model)
}
}
}
}
func assertOpenclawPrimaryModel(t *testing.T, path, expected string) {
t.Helper()
data, _ := os.ReadFile(path)
var cfg map[string]any
json.Unmarshal(data, &cfg)
agents := cfg["agents"].(map[string]any)
defaults := agents["defaults"].(map[string]any)
model := defaults["model"].(map[string]any)
if model["primary"] != expected {
t.Errorf("primary model = %v, want %v", model["primary"], expected)
}
}
func TestOpenclawPaths(t *testing.T) {
c := &Openclaw{}
t.Run("returns path when config exists", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".openclaw")
os.MkdirAll(configDir, 0o755)
os.WriteFile(filepath.Join(configDir, "openclaw.json"), []byte(`{}`), 0o644)
paths := c.Paths()
if len(paths) != 1 {
t.Errorf("expected 1 path, got %d", len(paths))
}
})
t.Run("returns nil when config missing", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
if paths := c.Paths(); paths != nil {
t.Errorf("expected nil, got %v", paths)
}
})
}
func TestOpenclawModelsEdgeCases(t *testing.T) {
c := &Openclaw{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".openclaw")
configPath := filepath.Join(configDir, "openclaw.json")
cleanup := func() { os.RemoveAll(configDir) }
t.Run("corrupted JSON returns nil", func(t *testing.T) {
cleanup()
os.MkdirAll(configDir, 0o755)
os.WriteFile(configPath, []byte(`{corrupted`), 0o644)
if models := c.Models(); models != nil {
t.Errorf("expected nil, got %v", models)
}
})
t.Run("wrong type at models level", func(t *testing.T) {
cleanup()
os.MkdirAll(configDir, 0o755)
os.WriteFile(configPath, []byte(`{"models":"string"}`), 0o644)
if models := c.Models(); models != nil {
t.Errorf("expected nil, got %v", models)
}
})
t.Run("wrong type at providers level", func(t *testing.T) {
cleanup()
os.MkdirAll(configDir, 0o755)
os.WriteFile(configPath, []byte(`{"models":{"providers":"string"}}`), 0o644)
if models := c.Models(); models != nil {
t.Errorf("expected nil, got %v", models)
}
})
t.Run("wrong type at ollama level", func(t *testing.T) {
cleanup()
os.MkdirAll(configDir, 0o755)
os.WriteFile(configPath, []byte(`{"models":{"providers":{"ollama":"string"}}}`), 0o644)
if models := c.Models(); models != nil {
t.Errorf("expected nil, got %v", models)
}
})
t.Run("model entry missing id", func(t *testing.T) {
cleanup()
os.MkdirAll(configDir, 0o755)
os.WriteFile(configPath, []byte(`{"models":{"providers":{"ollama":{"models":[{"name":"test"}]}}}}`), 0o644)
if len(c.Models()) != 0 {
t.Error("expected empty for missing id")
}
})
t.Run("model id is not string", func(t *testing.T) {
cleanup()
os.MkdirAll(configDir, 0o755)
os.WriteFile(configPath, []byte(`{"models":{"providers":{"ollama":{"models":[{"id":123}]}}}}`), 0o644)
if len(c.Models()) != 0 {
t.Error("expected empty for non-string id")
}
})
}
func TestOpenclawEditSchemaFields(t *testing.T) {
c := &Openclaw{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configPath := filepath.Join(tmpDir, ".openclaw", "openclaw.json")
if err := c.Edit([]string{"llama3.2"}); err != nil {
t.Fatal(err)
}
data, _ := os.ReadFile(configPath)
var cfg map[string]any
json.Unmarshal(data, &cfg)
models := cfg["models"].(map[string]any)
providers := models["providers"].(map[string]any)
ollama := providers["ollama"].(map[string]any)
modelList := ollama["models"].([]any)
entry := modelList[0].(map[string]any)
// Verify required schema fields
if entry["reasoning"] != false {
t.Error("reasoning should be false")
}
if entry["input"] == nil {
t.Error("input should be set")
}
if entry["contextWindow"] == nil {
t.Error("contextWindow should be set")
}
if entry["maxTokens"] == nil {
t.Error("maxTokens should be set")
}
cost := entry["cost"].(map[string]any)
if cost["cacheRead"] == nil {
t.Error("cost.cacheRead should be set")
}
if cost["cacheWrite"] == nil {
t.Error("cost.cacheWrite should be set")
}
}
func TestOpenclawEditModelNames(t *testing.T) {
c := &Openclaw{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configPath := filepath.Join(tmpDir, ".openclaw", "openclaw.json")
cleanup := func() { os.RemoveAll(filepath.Join(tmpDir, ".openclaw")) }
t.Run("model with colon tag", func(t *testing.T) {
cleanup()
if err := c.Edit([]string{"llama3.2:70b"}); err != nil {
t.Fatal(err)
}
assertOpenclawModelExists(t, configPath, "llama3.2:70b")
assertOpenclawPrimaryModel(t, configPath, "ollama/llama3.2:70b")
})
t.Run("model with slash", func(t *testing.T) {
cleanup()
if err := c.Edit([]string{"library/model:tag"}); err != nil {
t.Fatal(err)
}
assertOpenclawModelExists(t, configPath, "library/model:tag")
assertOpenclawPrimaryModel(t, configPath, "ollama/library/model:tag")
})
t.Run("model with hyphen", func(t *testing.T) {
cleanup()
if err := c.Edit([]string{"test-model"}); err != nil {
t.Fatal(err)
}
assertOpenclawModelExists(t, configPath, "test-model")
})
}
func TestOpenclawEditAgentsPreservation(t *testing.T) {
c := &Openclaw{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".openclaw")
configPath := filepath.Join(configDir, "openclaw.json")
cleanup := func() { os.RemoveAll(configDir) }
t.Run("preserve other agent defaults", func(t *testing.T) {
cleanup()
os.MkdirAll(configDir, 0o755)
os.WriteFile(configPath, []byte(`{"agents":{"defaults":{"model":{"primary":"old"},"temperature":0.7}}}`), 0o644)
c.Edit([]string{"llama3.2"})
data, _ := os.ReadFile(configPath)
var cfg map[string]any
json.Unmarshal(data, &cfg)
agents := cfg["agents"].(map[string]any)
defaults := agents["defaults"].(map[string]any)
if defaults["temperature"] != 0.7 {
t.Error("temperature setting was lost")
}
})
t.Run("preserve other agents besides defaults", func(t *testing.T) {
cleanup()
os.MkdirAll(configDir, 0o755)
os.WriteFile(configPath, []byte(`{"agents":{"defaults":{},"custom-agent":{"foo":"bar"}}}`), 0o644)
c.Edit([]string{"llama3.2"})
data, _ := os.ReadFile(configPath)
var cfg map[string]any
json.Unmarshal(data, &cfg)
agents := cfg["agents"].(map[string]any)
if agents["custom-agent"] == nil {
t.Error("custom-agent was lost")
}
})
}
const testOpenclawFixture = `{
"theme": "dark",
"mcp": {"servers": {"custom": {"enabled": true}}},
"models": {
"providers": {
"anthropic": {"apiKey": "xxx"},
"ollama": {
"baseUrl": "http://127.0.0.1:11434/v1",
"models": [{"id": "old-model", "customField": "preserved"}]
}
}
},
"agents": {
"defaults": {"model": {"primary": "old"}, "temperature": 0.7},
"custom-agent": {"foo": "bar"}
}
}`
func TestOpenclawEdit_RoundTrip(t *testing.T) {
c := &Openclaw{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".openclaw")
configPath := filepath.Join(configDir, "openclaw.json")
os.MkdirAll(configDir, 0o755)
os.WriteFile(configPath, []byte(testOpenclawFixture), 0o644)
if err := c.Edit([]string{"llama3.2", "mistral"}); err != nil {
t.Fatal(err)
}
data, _ := os.ReadFile(configPath)
var cfg map[string]any
json.Unmarshal(data, &cfg)
// Verify top-level preserved
if cfg["theme"] != "dark" {
t.Error("theme not preserved")
}
mcp := cfg["mcp"].(map[string]any)
servers := mcp["servers"].(map[string]any)
if servers["custom"] == nil {
t.Error("mcp.servers.custom not preserved")
}
// Verify other providers preserved
models := cfg["models"].(map[string]any)
providers := models["providers"].(map[string]any)
if providers["anthropic"] == nil {
t.Error("anthropic provider not preserved")
}
// Verify agents preserved
agents := cfg["agents"].(map[string]any)
if agents["custom-agent"] == nil {
t.Error("custom-agent not preserved")
}
defaults := agents["defaults"].(map[string]any)
if defaults["temperature"] != 0.7 {
t.Error("temperature not preserved")
}
}
func TestOpenclawEdit_Idempotent(t *testing.T) {
c := &Openclaw{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".openclaw")
configPath := filepath.Join(configDir, "openclaw.json")
os.MkdirAll(configDir, 0o755)
os.WriteFile(configPath, []byte(testOpenclawFixture), 0o644)
c.Edit([]string{"llama3.2", "mistral"})
firstData, _ := os.ReadFile(configPath)
c.Edit([]string{"llama3.2", "mistral"})
secondData, _ := os.ReadFile(configPath)
if string(firstData) != string(secondData) {
t.Error("repeated edits with same models produced different results")
}
}
func TestOpenclawEdit_MultipleConsecutiveEdits(t *testing.T) {
c := &Openclaw{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".openclaw")
configPath := filepath.Join(configDir, "openclaw.json")
os.MkdirAll(configDir, 0o755)
os.WriteFile(configPath, []byte(testOpenclawFixture), 0o644)
for i := range 10 {
models := []string{"model-a", "model-b"}
if i%2 == 0 {
models = []string{"model-x", "model-y", "model-z"}
}
if err := c.Edit(models); err != nil {
t.Fatalf("edit %d failed: %v", i, err)
}
}
data, _ := os.ReadFile(configPath)
var cfg map[string]any
if err := json.Unmarshal(data, &cfg); err != nil {
t.Fatalf("file is not valid JSON after multiple edits: %v", err)
}
if cfg["theme"] != "dark" {
t.Error("theme lost after multiple edits")
}
}
func TestOpenclawEdit_BackupCreated(t *testing.T) {
c := &Openclaw{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".openclaw")
configPath := filepath.Join(configDir, "openclaw.json")
backupDir := filepath.Join(os.TempDir(), "ollama-backups")
os.MkdirAll(configDir, 0o755)
uniqueMarker := fmt.Sprintf("test-marker-%d", os.Getpid())
original := fmt.Sprintf(`{"theme": "%s"}`, uniqueMarker)
os.WriteFile(configPath, []byte(original), 0o644)
if err := c.Edit([]string{"model-a"}); err != nil {
t.Fatal(err)
}
backups, _ := filepath.Glob(filepath.Join(backupDir, "openclaw.json.*"))
foundBackup := false
for _, backup := range backups {
data, _ := os.ReadFile(backup)
if string(data) == original {
foundBackup = true
break
}
}
if !foundBackup {
t.Error("backup with original content not found")
}
}
func TestOpenclawClawdbotAlias(t *testing.T) {
for _, alias := range []string{"clawdbot", "moltbot"} {
t.Run(alias+" alias resolves to Openclaw runner", func(t *testing.T) {
r, ok := integrations[alias]
if !ok {
t.Fatalf("%s not found in integrations", alias)
}
if _, ok := r.(*Openclaw); !ok {
t.Errorf("%s integration is %T, want *Openclaw", alias, r)
}
})
t.Run(alias+" is hidden from selector", func(t *testing.T) {
if !integrationAliases[alias] {
t.Errorf("%s should be in integrationAliases", alias)
}
})
}
}
func TestOpenclawLegacyPaths(t *testing.T) {
c := &Openclaw{}
t.Run("falls back to legacy clawdbot path", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
legacyDir := filepath.Join(tmpDir, ".clawdbot")
os.MkdirAll(legacyDir, 0o755)
os.WriteFile(filepath.Join(legacyDir, "clawdbot.json"), []byte(`{}`), 0o644)
paths := c.Paths()
if len(paths) != 1 {
t.Fatalf("expected 1 path, got %d", len(paths))
}
if paths[0] != filepath.Join(legacyDir, "clawdbot.json") {
t.Errorf("expected legacy path, got %s", paths[0])
}
})
t.Run("prefers new path over legacy", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
newDir := filepath.Join(tmpDir, ".openclaw")
legacyDir := filepath.Join(tmpDir, ".clawdbot")
os.MkdirAll(newDir, 0o755)
os.MkdirAll(legacyDir, 0o755)
os.WriteFile(filepath.Join(newDir, "openclaw.json"), []byte(`{}`), 0o644)
os.WriteFile(filepath.Join(legacyDir, "clawdbot.json"), []byte(`{}`), 0o644)
paths := c.Paths()
if len(paths) != 1 {
t.Fatalf("expected 1 path, got %d", len(paths))
}
if paths[0] != filepath.Join(newDir, "openclaw.json") {
t.Errorf("expected new path, got %s", paths[0])
}
})
t.Run("Models reads from legacy path", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
legacyDir := filepath.Join(tmpDir, ".clawdbot")
os.MkdirAll(legacyDir, 0o755)
os.WriteFile(filepath.Join(legacyDir, "clawdbot.json"), []byte(`{
"models":{"providers":{"ollama":{"models":[{"id":"llama3.2"}]}}}
}`), 0o644)
models := c.Models()
if len(models) != 1 || models[0] != "llama3.2" {
t.Errorf("expected [llama3.2], got %v", models)
}
})
t.Run("Models prefers new path over legacy", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
newDir := filepath.Join(tmpDir, ".openclaw")
legacyDir := filepath.Join(tmpDir, ".clawdbot")
os.MkdirAll(newDir, 0o755)
os.MkdirAll(legacyDir, 0o755)
os.WriteFile(filepath.Join(newDir, "openclaw.json"), []byte(`{
"models":{"providers":{"ollama":{"models":[{"id":"new-model"}]}}}
}`), 0o644)
os.WriteFile(filepath.Join(legacyDir, "clawdbot.json"), []byte(`{
"models":{"providers":{"ollama":{"models":[{"id":"legacy-model"}]}}}
}`), 0o644)
models := c.Models()
if len(models) != 1 || models[0] != "new-model" {
t.Errorf("expected [new-model], got %v", models)
}
})
t.Run("Edit reads new path over legacy when both exist", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
newDir := filepath.Join(tmpDir, ".openclaw")
legacyDir := filepath.Join(tmpDir, ".clawdbot")
os.MkdirAll(newDir, 0o755)
os.MkdirAll(legacyDir, 0o755)
os.WriteFile(filepath.Join(newDir, "openclaw.json"), []byte(`{"theme":"new"}`), 0o644)
os.WriteFile(filepath.Join(legacyDir, "clawdbot.json"), []byte(`{"theme":"legacy"}`), 0o644)
if err := c.Edit([]string{"llama3.2"}); err != nil {
t.Fatal(err)
}
data, _ := os.ReadFile(filepath.Join(newDir, "openclaw.json"))
var cfg map[string]any
json.Unmarshal(data, &cfg)
if cfg["theme"] != "new" {
t.Errorf("expected theme from new config, got %v", cfg["theme"])
}
})
t.Run("Edit migrates from legacy config", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
legacyDir := filepath.Join(tmpDir, ".clawdbot")
os.MkdirAll(legacyDir, 0o755)
os.WriteFile(filepath.Join(legacyDir, "clawdbot.json"), []byte(`{"theme":"dark"}`), 0o644)
if err := c.Edit([]string{"llama3.2"}); err != nil {
t.Fatal(err)
}
// Should write to new path
newPath := filepath.Join(tmpDir, ".openclaw", "openclaw.json")
data, err := os.ReadFile(newPath)
if err != nil {
t.Fatal("expected new config file to be created")
}
var cfg map[string]any
json.Unmarshal(data, &cfg)
if cfg["theme"] != "dark" {
t.Error("legacy theme setting was not migrated")
}
})
}
func TestOpenclawEdit_CreatesDirectoryIfMissing(t *testing.T) {
c := &Openclaw{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".openclaw")
if _, err := os.Stat(configDir); !os.IsNotExist(err) {
t.Fatal("directory should not exist before test")
}
if err := c.Edit([]string{"model-a"}); err != nil {
t.Fatal(err)
}
if _, err := os.Stat(configDir); os.IsNotExist(err) {
t.Fatal("directory was not created")
}
}
func TestOpenclawOnboarded(t *testing.T) {
c := &Openclaw{}
t.Run("returns false when no config exists", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
if c.onboarded() {
t.Error("expected false when no config exists")
}
})
t.Run("returns false when config exists but no wizard section", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".openclaw")
os.MkdirAll(configDir, 0o755)
os.WriteFile(filepath.Join(configDir, "openclaw.json"), []byte(`{"theme":"dark"}`), 0o644)
if c.onboarded() {
t.Error("expected false when no wizard section")
}
})
t.Run("returns false when wizard section exists but no lastRunAt", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".openclaw")
os.MkdirAll(configDir, 0o755)
os.WriteFile(filepath.Join(configDir, "openclaw.json"), []byte(`{"wizard":{}}`), 0o644)
if c.onboarded() {
t.Error("expected false when wizard.lastRunAt is missing")
}
})
t.Run("returns false when wizard.lastRunAt is empty string", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".openclaw")
os.MkdirAll(configDir, 0o755)
os.WriteFile(filepath.Join(configDir, "openclaw.json"), []byte(`{"wizard":{"lastRunAt":""}}`), 0o644)
if c.onboarded() {
t.Error("expected false when wizard.lastRunAt is empty")
}
})
t.Run("returns true when wizard.lastRunAt is set", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".openclaw")
os.MkdirAll(configDir, 0o755)
os.WriteFile(filepath.Join(configDir, "openclaw.json"), []byte(`{"wizard":{"lastRunAt":"2024-01-01T00:00:00Z"}}`), 0o644)
if !c.onboarded() {
t.Error("expected true when wizard.lastRunAt is set")
}
})
t.Run("checks legacy clawdbot path", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
legacyDir := filepath.Join(tmpDir, ".clawdbot")
os.MkdirAll(legacyDir, 0o755)
os.WriteFile(filepath.Join(legacyDir, "clawdbot.json"), []byte(`{"wizard":{"lastRunAt":"2024-01-01T00:00:00Z"}}`), 0o644)
if !c.onboarded() {
t.Error("expected true when legacy config has wizard.lastRunAt")
}
})
t.Run("prefers new path over legacy", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
newDir := filepath.Join(tmpDir, ".openclaw")
legacyDir := filepath.Join(tmpDir, ".clawdbot")
os.MkdirAll(newDir, 0o755)
os.MkdirAll(legacyDir, 0o755)
// New path has no wizard marker
os.WriteFile(filepath.Join(newDir, "openclaw.json"), []byte(`{}`), 0o644)
// Legacy has wizard marker
os.WriteFile(filepath.Join(legacyDir, "clawdbot.json"), []byte(`{"wizard":{"lastRunAt":"2024-01-01T00:00:00Z"}}`), 0o644)
if c.onboarded() {
t.Error("expected false - should prefer new path which has no wizard marker")
}
})
t.Run("handles corrupted JSON gracefully", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".openclaw")
os.MkdirAll(configDir, 0o755)
os.WriteFile(filepath.Join(configDir, "openclaw.json"), []byte(`{corrupted`), 0o644)
if c.onboarded() {
t.Error("expected false for corrupted JSON")
}
})
t.Run("handles wrong type for wizard section", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".openclaw")
os.MkdirAll(configDir, 0o755)
os.WriteFile(filepath.Join(configDir, "openclaw.json"), []byte(`{"wizard":"not a map"}`), 0o644)
if c.onboarded() {
t.Error("expected false when wizard is wrong type")
}
})
}

View File

@@ -1,58 +0,0 @@
package config
import (
"errors"
"fmt"
"os"
"golang.org/x/term"
)
// ANSI escape sequences for terminal formatting.
const (
ansiBold = "\033[1m"
ansiReset = "\033[0m"
ansiGray = "\033[37m"
ansiGreen = "\033[32m"
)
// ErrCancelled is returned when the user cancels a selection.
var ErrCancelled = errors.New("cancelled")
// errCancelled is kept as an alias for backward compatibility within the package.
var errCancelled = ErrCancelled
// DefaultConfirmPrompt provides a TUI-based confirmation prompt.
// When set, confirmPrompt delegates to it instead of using raw terminal I/O.
var DefaultConfirmPrompt func(prompt string) (bool, error)
func confirmPrompt(prompt string) (bool, error) {
if DefaultConfirmPrompt != nil {
return DefaultConfirmPrompt(prompt)
}
fd := int(os.Stdin.Fd())
oldState, err := term.MakeRaw(fd)
if err != nil {
return false, err
}
defer term.Restore(fd, oldState)
fmt.Fprintf(os.Stderr, "%s (\033[1my\033[0m/n) ", prompt)
buf := make([]byte, 1)
for {
if _, err := os.Stdin.Read(buf); err != nil {
return false, err
}
switch buf[0] {
case 'Y', 'y', 13:
fmt.Fprintf(os.Stderr, "yes\r\n")
return true, nil
case 'N', 'n', 27, 3:
fmt.Fprintf(os.Stderr, "no\r\n")
return false, nil
}
}
}

View File

@@ -1,19 +0,0 @@
package config
import (
"testing"
)
func TestErrCancelled(t *testing.T) {
t.Run("NotNil", func(t *testing.T) {
if errCancelled == nil {
t.Error("errCancelled should not be nil")
}
})
t.Run("Message", func(t *testing.T) {
if errCancelled.Error() != "cancelled" {
t.Errorf("expected 'cancelled', got %q", errCancelled.Error())
}
})
}

View File

@@ -17,6 +17,7 @@ import (
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/internal/modelref"
"github.com/ollama/ollama/readline"
"github.com/ollama/ollama/types/errtypes"
"github.com/ollama/ollama/types/model"
@@ -540,6 +541,13 @@ func NewCreateRequest(name string, opts runOptions) *api.CreateRequest {
parentModel = ""
}
// Preserve explicit cloud intent for sessions started with `:cloud`.
// Cloud model metadata can return a source-less parent_model (for example
// "qwen3.5"), which would otherwise make `/save` create a local derivative.
if modelref.HasExplicitCloudSource(opts.Model) && !modelref.HasExplicitCloudSource(parentModel) {
parentModel = ""
}
req := &api.CreateRequest{
Model: name,
From: cmp.Or(parentModel, opts.Model),

View File

@@ -1,4 +1,6 @@
package config
// Package fileutil provides small shared helpers for reading JSON files
// and writing config files with backup-on-overwrite semantics.
package fileutil
import (
"bytes"
@@ -9,7 +11,8 @@ import (
"time"
)
func readJSONFile(path string) (map[string]any, error) {
// ReadJSON reads a JSON object file into a generic map.
func ReadJSON(path string) (map[string]any, error) {
data, err := os.ReadFile(path)
if err != nil {
return nil, err
@@ -33,12 +36,13 @@ func copyFile(src, dst string) error {
return os.WriteFile(dst, data, info.Mode().Perm())
}
func backupDir() string {
// BackupDir returns the shared backup directory used before overwriting files.
func BackupDir() string {
return filepath.Join(os.TempDir(), "ollama-backups")
}
func backupToTmp(srcPath string) (string, error) {
dir := backupDir()
dir := BackupDir()
if err := os.MkdirAll(dir, 0o755); err != nil {
return "", err
}
@@ -50,8 +54,8 @@ func backupToTmp(srcPath string) (string, error) {
return backupPath, nil
}
// writeWithBackup writes data to path via temp file + rename, backing up any existing file first
func writeWithBackup(path string, data []byte) error {
// WriteWithBackup writes data to path via temp file + rename, backing up any existing file first.
func WriteWithBackup(path string, data []byte) error {
var backupPath string
// backup must be created before any writes to the target file
if existingContent, err := os.ReadFile(path); err == nil {

View File

@@ -1,4 +1,4 @@
package config
package fileutil
import (
"encoding/json"
@@ -9,6 +9,21 @@ import (
"testing"
)
func TestMain(m *testing.M) {
tmpRoot, err := os.MkdirTemp("", "fileutil-test-*")
if err != nil {
panic(err)
}
if err := os.Setenv("TMPDIR", tmpRoot); err != nil {
panic(err)
}
code := m.Run()
_ = os.RemoveAll(tmpRoot)
os.Exit(code)
}
func mustMarshal(t *testing.T, v any) []byte {
t.Helper()
data, err := json.MarshalIndent(v, "", " ")
@@ -18,14 +33,19 @@ func mustMarshal(t *testing.T, v any) []byte {
return data
}
func isolatedTempDir(t *testing.T) string {
t.Helper()
return t.TempDir()
}
func TestWriteWithBackup(t *testing.T) {
tmpDir := t.TempDir()
tmpDir := isolatedTempDir(t)
t.Run("creates file", func(t *testing.T) {
path := filepath.Join(tmpDir, "new.json")
data := mustMarshal(t, map[string]string{"key": "value"})
if err := writeWithBackup(path, data); err != nil {
if err := WriteWithBackup(path, data); err != nil {
t.Fatal(err)
}
@@ -43,17 +63,17 @@ func TestWriteWithBackup(t *testing.T) {
}
})
t.Run("creates backup in /tmp/ollama-backups", func(t *testing.T) {
t.Run("creates backup in the temp backup directory", func(t *testing.T) {
path := filepath.Join(tmpDir, "backup.json")
os.WriteFile(path, []byte(`{"original": true}`), 0o644)
data := mustMarshal(t, map[string]bool{"updated": true})
if err := writeWithBackup(path, data); err != nil {
if err := WriteWithBackup(path, data); err != nil {
t.Fatal(err)
}
entries, err := os.ReadDir(backupDir())
entries, err := os.ReadDir(BackupDir())
if err != nil {
t.Fatal("backup directory not created")
}
@@ -63,7 +83,7 @@ func TestWriteWithBackup(t *testing.T) {
if filepath.Ext(entry.Name()) != ".json" {
name := entry.Name()
if len(name) > len("backup.json.") && name[:len("backup.json.")] == "backup.json." {
backupPath := filepath.Join(backupDir(), name)
backupPath := filepath.Join(BackupDir(), name)
backup, err := os.ReadFile(backupPath)
if err == nil {
var backupData map[string]bool
@@ -79,7 +99,7 @@ func TestWriteWithBackup(t *testing.T) {
}
if !foundBackup {
t.Error("backup file not created in /tmp/ollama-backups")
t.Error("backup file not created in backup directory")
}
current, _ := os.ReadFile(path)
@@ -94,11 +114,11 @@ func TestWriteWithBackup(t *testing.T) {
path := filepath.Join(tmpDir, "nobak.json")
data := mustMarshal(t, map[string]string{"new": "file"})
if err := writeWithBackup(path, data); err != nil {
if err := WriteWithBackup(path, data); err != nil {
t.Fatal(err)
}
entries, _ := os.ReadDir(backupDir())
entries, _ := os.ReadDir(BackupDir())
for _, entry := range entries {
if len(entry.Name()) > len("nobak.json.") && entry.Name()[:len("nobak.json.")] == "nobak.json." {
t.Error("backup should not exist for new file")
@@ -111,11 +131,11 @@ func TestWriteWithBackup(t *testing.T) {
data := mustMarshal(t, map[string]string{"key": "value"})
if err := writeWithBackup(path, data); err != nil {
if err := WriteWithBackup(path, data); err != nil {
t.Fatal(err)
}
entries1, _ := os.ReadDir(backupDir())
entries1, _ := os.ReadDir(BackupDir())
countBefore := 0
for _, e := range entries1 {
if len(e.Name()) > len("unchanged.json.") && e.Name()[:len("unchanged.json.")] == "unchanged.json." {
@@ -123,11 +143,11 @@ func TestWriteWithBackup(t *testing.T) {
}
}
if err := writeWithBackup(path, data); err != nil {
if err := WriteWithBackup(path, data); err != nil {
t.Fatal(err)
}
entries2, _ := os.ReadDir(backupDir())
entries2, _ := os.ReadDir(BackupDir())
countAfter := 0
for _, e := range entries2 {
if len(e.Name()) > len("unchanged.json.") && e.Name()[:len("unchanged.json.")] == "unchanged.json." {
@@ -145,11 +165,11 @@ func TestWriteWithBackup(t *testing.T) {
os.WriteFile(path, []byte(`{"v": 1}`), 0o644)
data := mustMarshal(t, map[string]int{"v": 2})
if err := writeWithBackup(path, data); err != nil {
if err := WriteWithBackup(path, data); err != nil {
t.Fatal(err)
}
entries, _ := os.ReadDir(backupDir())
entries, _ := os.ReadDir(BackupDir())
var found bool
for _, entry := range entries {
name := entry.Name()
@@ -161,7 +181,7 @@ func TestWriteWithBackup(t *testing.T) {
}
}
found = true
os.Remove(filepath.Join(backupDir(), name))
os.Remove(filepath.Join(BackupDir(), name))
break
}
}
@@ -180,7 +200,7 @@ func TestWriteWithBackup_FailsIfBackupFails(t *testing.T) {
t.Skip("permission tests unreliable on Windows")
}
tmpDir := t.TempDir()
tmpDir := isolatedTempDir(t)
path := filepath.Join(tmpDir, "config.json")
// Create original file
@@ -188,13 +208,13 @@ func TestWriteWithBackup_FailsIfBackupFails(t *testing.T) {
os.WriteFile(path, originalContent, 0o644)
// Make backup directory read-only to force backup failure
backupDir := backupDir()
backupDir := BackupDir()
os.MkdirAll(backupDir, 0o755)
os.Chmod(backupDir, 0o444) // Read-only
defer os.Chmod(backupDir, 0o755)
newContent := []byte(`{"updated": true}`)
err := writeWithBackup(path, newContent)
err := WriteWithBackup(path, newContent)
// Should fail because backup couldn't be created
if err == nil {
@@ -215,7 +235,7 @@ func TestWriteWithBackup_PermissionDenied(t *testing.T) {
t.Skip("permission tests unreliable on Windows")
}
tmpDir := t.TempDir()
tmpDir := isolatedTempDir(t)
// Create a read-only directory
readOnlyDir := filepath.Join(tmpDir, "readonly")
@@ -224,7 +244,7 @@ func TestWriteWithBackup_PermissionDenied(t *testing.T) {
defer os.Chmod(readOnlyDir, 0o755)
path := filepath.Join(readOnlyDir, "config.json")
err := writeWithBackup(path, []byte(`{"test": true}`))
err := WriteWithBackup(path, []byte(`{"test": true}`))
if err == nil {
t.Error("expected permission error, got nil")
@@ -234,10 +254,10 @@ func TestWriteWithBackup_PermissionDenied(t *testing.T) {
// TestWriteWithBackup_DirectoryDoesNotExist verifies behavior when target directory doesn't exist.
// writeWithBackup doesn't create directories - caller is responsible.
func TestWriteWithBackup_DirectoryDoesNotExist(t *testing.T) {
tmpDir := t.TempDir()
tmpDir := isolatedTempDir(t)
path := filepath.Join(tmpDir, "nonexistent", "subdir", "config.json")
err := writeWithBackup(path, []byte(`{"test": true}`))
err := WriteWithBackup(path, []byte(`{"test": true}`))
// Should fail because directory doesn't exist
if err == nil {
@@ -252,7 +272,7 @@ func TestWriteWithBackup_SymlinkTarget(t *testing.T) {
t.Skip("symlink tests may require admin on Windows")
}
tmpDir := t.TempDir()
tmpDir := isolatedTempDir(t)
realFile := filepath.Join(tmpDir, "real.json")
symlink := filepath.Join(tmpDir, "link.json")
@@ -261,7 +281,7 @@ func TestWriteWithBackup_SymlinkTarget(t *testing.T) {
os.Symlink(realFile, symlink)
// Write through symlink
err := writeWithBackup(symlink, []byte(`{"v": 2}`))
err := WriteWithBackup(symlink, []byte(`{"v": 2}`))
if err != nil {
t.Fatalf("writeWithBackup through symlink failed: %v", err)
}
@@ -276,7 +296,7 @@ func TestWriteWithBackup_SymlinkTarget(t *testing.T) {
// TestBackupToTmp_SpecialCharsInFilename verifies backup works with special characters.
// User may have config files with unusual names.
func TestBackupToTmp_SpecialCharsInFilename(t *testing.T) {
tmpDir := t.TempDir()
tmpDir := isolatedTempDir(t)
// File with spaces and special chars
path := filepath.Join(tmpDir, "my config (backup).json")
@@ -305,7 +325,7 @@ func TestCopyFile_PreservesPermissions(t *testing.T) {
t.Skip("permission preservation tests unreliable on Windows")
}
tmpDir := t.TempDir()
tmpDir := isolatedTempDir(t)
src := filepath.Join(tmpDir, "src.json")
dst := filepath.Join(tmpDir, "dst.json")
@@ -327,7 +347,7 @@ func TestCopyFile_PreservesPermissions(t *testing.T) {
// TestCopyFile_SourceNotFound verifies clear error when source doesn't exist.
func TestCopyFile_SourceNotFound(t *testing.T) {
tmpDir := t.TempDir()
tmpDir := isolatedTempDir(t)
src := filepath.Join(tmpDir, "nonexistent.json")
dst := filepath.Join(tmpDir, "dst.json")
@@ -339,11 +359,11 @@ func TestCopyFile_SourceNotFound(t *testing.T) {
// TestWriteWithBackup_TargetIsDirectory verifies error when path points to a directory.
func TestWriteWithBackup_TargetIsDirectory(t *testing.T) {
tmpDir := t.TempDir()
tmpDir := isolatedTempDir(t)
dirPath := filepath.Join(tmpDir, "actualdir")
os.MkdirAll(dirPath, 0o755)
err := writeWithBackup(dirPath, []byte(`{"test": true}`))
err := WriteWithBackup(dirPath, []byte(`{"test": true}`))
if err == nil {
t.Error("expected error when target is a directory, got nil")
}
@@ -351,10 +371,10 @@ func TestWriteWithBackup_TargetIsDirectory(t *testing.T) {
// TestWriteWithBackup_EmptyData verifies writing zero bytes works correctly.
func TestWriteWithBackup_EmptyData(t *testing.T) {
tmpDir := t.TempDir()
tmpDir := isolatedTempDir(t)
path := filepath.Join(tmpDir, "empty.json")
err := writeWithBackup(path, []byte{})
err := WriteWithBackup(path, []byte{})
if err != nil {
t.Fatalf("writeWithBackup with empty data failed: %v", err)
}
@@ -375,7 +395,7 @@ func TestWriteWithBackup_FileUnreadableButDirWritable(t *testing.T) {
t.Skip("permission tests unreliable on Windows")
}
tmpDir := t.TempDir()
tmpDir := isolatedTempDir(t)
path := filepath.Join(tmpDir, "unreadable.json")
// Create file and make it unreadable
@@ -384,7 +404,7 @@ func TestWriteWithBackup_FileUnreadableButDirWritable(t *testing.T) {
defer os.Chmod(path, 0o644)
// Should fail because we can't read the file to compare/backup
err := writeWithBackup(path, []byte(`{"updated": true}`))
err := WriteWithBackup(path, []byte(`{"updated": true}`))
if err == nil {
t.Error("expected error when file is unreadable, got nil")
}
@@ -393,7 +413,7 @@ func TestWriteWithBackup_FileUnreadableButDirWritable(t *testing.T) {
// TestWriteWithBackup_RapidSuccessiveWrites verifies backup works with multiple writes
// within the same second (timestamp collision scenario).
func TestWriteWithBackup_RapidSuccessiveWrites(t *testing.T) {
tmpDir := t.TempDir()
tmpDir := isolatedTempDir(t)
path := filepath.Join(tmpDir, "rapid.json")
// Create initial file
@@ -402,7 +422,7 @@ func TestWriteWithBackup_RapidSuccessiveWrites(t *testing.T) {
// Rapid successive writes
for i := 1; i <= 3; i++ {
data := []byte(fmt.Sprintf(`{"v": %d}`, i))
if err := writeWithBackup(path, data); err != nil {
if err := WriteWithBackup(path, data); err != nil {
t.Fatalf("write %d failed: %v", i, err)
}
}
@@ -414,7 +434,7 @@ func TestWriteWithBackup_RapidSuccessiveWrites(t *testing.T) {
}
// Verify at least one backup exists
entries, _ := os.ReadDir(backupDir())
entries, _ := os.ReadDir(BackupDir())
var backupCount int
for _, e := range entries {
if len(e.Name()) > len("rapid.json.") && e.Name()[:len("rapid.json.")] == "rapid.json." {
@@ -432,8 +452,9 @@ func TestWriteWithBackup_BackupDirIsFile(t *testing.T) {
t.Skip("test modifies system temp directory")
}
tmpDir := isolatedTempDir(t)
// Create a file at the backup directory path
backupPath := backupDir()
backupPath := BackupDir()
// Clean up any existing directory first
os.RemoveAll(backupPath)
// Create a file instead of directory
@@ -443,11 +464,10 @@ func TestWriteWithBackup_BackupDirIsFile(t *testing.T) {
os.MkdirAll(backupPath, 0o755)
}()
tmpDir := t.TempDir()
path := filepath.Join(tmpDir, "test.json")
os.WriteFile(path, []byte(`{"original": true}`), 0o644)
err := writeWithBackup(path, []byte(`{"updated": true}`))
err := WriteWithBackup(path, []byte(`{"updated": true}`))
if err == nil {
t.Error("expected error when backup dir is a file, got nil")
}
@@ -459,7 +479,7 @@ func TestWriteWithBackup_NoOrphanTempFiles(t *testing.T) {
t.Skip("permission tests unreliable on Windows")
}
tmpDir := t.TempDir()
tmpDir := isolatedTempDir(t)
// Count existing temp files
countTempFiles := func() int {
@@ -493,7 +513,7 @@ func TestWriteWithBackup_NoOrphanTempFiles(t *testing.T) {
badPath := filepath.Join(tmpDir, "isdir")
os.MkdirAll(badPath, 0o755)
_ = writeWithBackup(badPath, []byte(`{"test": true}`))
_ = WriteWithBackup(badPath, []byte(`{"test": true}`))
after := countTempFiles()
if after > before {

86
cmd/launch/claude.go Normal file
View File

@@ -0,0 +1,86 @@
package launch
import (
"fmt"
"os"
"os/exec"
"path/filepath"
"runtime"
"strconv"
"github.com/ollama/ollama/envconfig"
)
// Claude implements Runner for Claude Code integration.
type Claude struct{}
func (c *Claude) String() string { return "Claude Code" }
func (c *Claude) args(model string, extra []string) []string {
var args []string
if model != "" {
args = append(args, "--model", model)
}
args = append(args, extra...)
return args
}
func (c *Claude) findPath() (string, error) {
if p, err := exec.LookPath("claude"); err == nil {
return p, nil
}
home, err := os.UserHomeDir()
if err != nil {
return "", err
}
name := "claude"
if runtime.GOOS == "windows" {
name = "claude.exe"
}
fallback := filepath.Join(home, ".claude", "local", name)
if _, err := os.Stat(fallback); err != nil {
return "", err
}
return fallback, nil
}
func (c *Claude) Run(model string, args []string) error {
claudePath, err := c.findPath()
if err != nil {
return fmt.Errorf("claude is not installed, install from https://code.claude.com/docs/en/quickstart")
}
cmd := exec.Command(claudePath, c.args(model, args)...)
cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
env := append(os.Environ(),
"ANTHROPIC_BASE_URL="+envconfig.Host().String(),
"ANTHROPIC_API_KEY=",
"ANTHROPIC_AUTH_TOKEN=ollama",
)
env = append(env, c.modelEnvVars(model)...)
cmd.Env = env
return cmd.Run()
}
// modelEnvVars returns Claude Code env vars that route all model tiers through Ollama.
func (c *Claude) modelEnvVars(model string) []string {
env := []string{
"ANTHROPIC_DEFAULT_OPUS_MODEL=" + model,
"ANTHROPIC_DEFAULT_SONNET_MODEL=" + model,
"ANTHROPIC_DEFAULT_HAIKU_MODEL=" + model,
"CLAUDE_CODE_SUBAGENT_MODEL=" + model,
}
if isCloudModelName(model) {
if l, ok := lookupCloudModelLimit(model); ok {
env = append(env, "CLAUDE_CODE_AUTO_COMPACT_WINDOW="+strconv.Itoa(l.Context))
}
}
return env
}

View File

@@ -1,4 +1,4 @@
package config
package launch
import (
"os"
@@ -117,10 +117,7 @@ func TestClaudeModelEnvVars(t *testing.T) {
return m
}
t.Run("falls back to model param when no aliases saved", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
t.Run("maps all Claude model env vars to the provided model", func(t *testing.T) {
got := envMap(c.modelEnvVars("llama3.2"))
if got["ANTHROPIC_DEFAULT_OPUS_MODEL"] != "llama3.2" {
t.Errorf("OPUS = %q, want llama3.2", got["ANTHROPIC_DEFAULT_OPUS_MODEL"])
@@ -134,65 +131,41 @@ func TestClaudeModelEnvVars(t *testing.T) {
if got["CLAUDE_CODE_SUBAGENT_MODEL"] != "llama3.2" {
t.Errorf("SUBAGENT = %q, want llama3.2", got["CLAUDE_CODE_SUBAGENT_MODEL"])
}
})
t.Run("uses primary alias for opus sonnet and subagent", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
SaveIntegration("claude", []string{"qwen3:8b"})
saveAliases("claude", map[string]string{"primary": "qwen3:8b"})
got := envMap(c.modelEnvVars("qwen3:8b"))
if got["ANTHROPIC_DEFAULT_OPUS_MODEL"] != "qwen3:8b" {
t.Errorf("OPUS = %q, want qwen3:8b", got["ANTHROPIC_DEFAULT_OPUS_MODEL"])
}
if got["ANTHROPIC_DEFAULT_SONNET_MODEL"] != "qwen3:8b" {
t.Errorf("SONNET = %q, want qwen3:8b", got["ANTHROPIC_DEFAULT_SONNET_MODEL"])
}
if got["ANTHROPIC_DEFAULT_HAIKU_MODEL"] != "qwen3:8b" {
t.Errorf("HAIKU = %q, want qwen3:8b (no fast alias)", got["ANTHROPIC_DEFAULT_HAIKU_MODEL"])
}
if got["CLAUDE_CODE_SUBAGENT_MODEL"] != "qwen3:8b" {
t.Errorf("SUBAGENT = %q, want qwen3:8b", got["CLAUDE_CODE_SUBAGENT_MODEL"])
if got["CLAUDE_CODE_AUTO_COMPACT_WINDOW"] != "" {
t.Errorf("AUTO_COMPACT_WINDOW = %q, want empty for local models", got["CLAUDE_CODE_AUTO_COMPACT_WINDOW"])
}
})
t.Run("uses fast alias for haiku", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
SaveIntegration("claude", []string{"llama3.2:70b"})
saveAliases("claude", map[string]string{
"primary": "llama3.2:70b",
"fast": "llama3.2:8b",
})
got := envMap(c.modelEnvVars("llama3.2:70b"))
if got["ANTHROPIC_DEFAULT_OPUS_MODEL"] != "llama3.2:70b" {
t.Errorf("OPUS = %q, want llama3.2:70b", got["ANTHROPIC_DEFAULT_OPUS_MODEL"])
t.Run("supports empty model", func(t *testing.T) {
got := envMap(c.modelEnvVars(""))
if got["ANTHROPIC_DEFAULT_OPUS_MODEL"] != "" {
t.Errorf("OPUS = %q, want empty", got["ANTHROPIC_DEFAULT_OPUS_MODEL"])
}
if got["ANTHROPIC_DEFAULT_SONNET_MODEL"] != "llama3.2:70b" {
t.Errorf("SONNET = %q, want llama3.2:70b", got["ANTHROPIC_DEFAULT_SONNET_MODEL"])
if got["ANTHROPIC_DEFAULT_SONNET_MODEL"] != "" {
t.Errorf("SONNET = %q, want empty", got["ANTHROPIC_DEFAULT_SONNET_MODEL"])
}
if got["ANTHROPIC_DEFAULT_HAIKU_MODEL"] != "llama3.2:8b" {
t.Errorf("HAIKU = %q, want llama3.2:8b", got["ANTHROPIC_DEFAULT_HAIKU_MODEL"])
if got["ANTHROPIC_DEFAULT_HAIKU_MODEL"] != "" {
t.Errorf("HAIKU = %q, want empty", got["ANTHROPIC_DEFAULT_HAIKU_MODEL"])
}
if got["CLAUDE_CODE_SUBAGENT_MODEL"] != "llama3.2:70b" {
t.Errorf("SUBAGENT = %q, want llama3.2:70b", got["CLAUDE_CODE_SUBAGENT_MODEL"])
if got["CLAUDE_CODE_SUBAGENT_MODEL"] != "" {
t.Errorf("SUBAGENT = %q, want empty", got["CLAUDE_CODE_SUBAGENT_MODEL"])
}
if got["CLAUDE_CODE_AUTO_COMPACT_WINDOW"] != "" {
t.Errorf("AUTO_COMPACT_WINDOW = %q, want empty", got["CLAUDE_CODE_AUTO_COMPACT_WINDOW"])
}
})
t.Run("alias primary overrides model param", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
t.Run("sets auto compact window for known cloud models", func(t *testing.T) {
got := envMap(c.modelEnvVars("glm-5:cloud"))
if got["CLAUDE_CODE_AUTO_COMPACT_WINDOW"] != "202752" {
t.Errorf("AUTO_COMPACT_WINDOW = %q, want 202752", got["CLAUDE_CODE_AUTO_COMPACT_WINDOW"])
}
})
SaveIntegration("claude", []string{"saved-model"})
saveAliases("claude", map[string]string{"primary": "saved-model"})
got := envMap(c.modelEnvVars("different-model"))
if got["ANTHROPIC_DEFAULT_OPUS_MODEL"] != "saved-model" {
t.Errorf("OPUS = %q, want saved-model", got["ANTHROPIC_DEFAULT_OPUS_MODEL"])
t.Run("does not set auto compact window for unknown cloud models", func(t *testing.T) {
got := envMap(c.modelEnvVars("unknown-model:cloud"))
if got["CLAUDE_CODE_AUTO_COMPACT_WINDOW"] != "" {
t.Errorf("AUTO_COMPACT_WINDOW = %q, want empty", got["CLAUDE_CODE_AUTO_COMPACT_WINDOW"])
}
})
}

View File

@@ -1,14 +1,13 @@
package config
package launch
import (
"context"
"encoding/json"
"errors"
"fmt"
"os"
"os/exec"
"path/filepath"
"github.com/ollama/ollama/cmd/internal/fileutil"
"github.com/ollama/ollama/envconfig"
)
@@ -22,24 +21,6 @@ func (c *Cline) Run(model string, args []string) error {
return fmt.Errorf("cline is not installed, install with: npm install -g cline")
}
models := []string{model}
if config, err := loadIntegration("cline"); err == nil && len(config.Models) > 0 {
models = config.Models
}
var err error
models, err = resolveEditorModels("cline", models, func() ([]string, error) {
return selectModels(context.Background(), "cline", "")
})
if errors.Is(err, errCancelled) {
return nil
}
if err != nil {
return err
}
if err := c.Edit(models); err != nil {
return fmt.Errorf("setup failed: %w", err)
}
cmd := exec.Command("cline", args...)
cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout
@@ -97,7 +78,7 @@ func (c *Cline) Edit(models []string) error {
if err != nil {
return err
}
return writeWithBackup(configPath, data)
return fileutil.WriteWithBackup(configPath, data)
}
func (c *Cline) Models() []string {
@@ -106,7 +87,7 @@ func (c *Cline) Models() []string {
return nil
}
config, err := readJSONFile(filepath.Join(home, ".cline", "data", "globalState.json"))
config, err := fileutil.ReadJSON(filepath.Join(home, ".cline", "data", "globalState.json"))
if err != nil {
return nil
}

View File

@@ -1,4 +1,4 @@
package config
package launch
import (
"encoding/json"

View File

@@ -1,4 +1,4 @@
package config
package launch
import (
"fmt"
@@ -6,6 +6,7 @@ import (
"os/exec"
"strings"
"github.com/ollama/ollama/envconfig"
"golang.org/x/mod/semver"
)
@@ -32,6 +33,10 @@ func (c *Codex) Run(model string, args []string) error {
cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
cmd.Env = append(os.Environ(),
"OPENAI_BASE_URL="+envconfig.Host().String()+"/v1/",
"OPENAI_API_KEY=ollama",
)
return cmd.Run()
}

View File

@@ -1,4 +1,4 @@
package config
package launch
import (
"slices"
@@ -16,7 +16,7 @@ func TestCodexArgs(t *testing.T) {
}{
{"with model", "llama3.2", nil, []string{"--oss", "-m", "llama3.2"}},
{"empty model", "", nil, []string{"--oss"}},
{"with model and profile", "qwen3-coder", []string{"-p", "myprofile"}, []string{"--oss", "-m", "qwen3-coder", "-p", "myprofile"}},
{"with model and profile", "qwen3.5", []string{"-p", "myprofile"}, []string{"--oss", "-m", "qwen3.5", "-p", "myprofile"}},
{"with sandbox flag", "llama3.2", []string{"--sandbox", "workspace-write"}, []string{"--oss", "-m", "llama3.2", "--sandbox", "workspace-write"}},
}

598
cmd/launch/command_test.go Normal file
View File

@@ -0,0 +1,598 @@
package launch
import (
"bytes"
"fmt"
"io"
"net/http"
"net/http/httptest"
"os"
"strings"
"testing"
"github.com/google/go-cmp/cmp"
"github.com/ollama/ollama/cmd/config"
"github.com/spf13/cobra"
)
func captureStderr(t *testing.T, fn func()) string {
t.Helper()
oldStderr := os.Stderr
r, w, err := os.Pipe()
if err != nil {
t.Fatalf("failed to create stderr pipe: %v", err)
}
os.Stderr = w
defer func() {
os.Stderr = oldStderr
}()
done := make(chan string, 1)
go func() {
var buf bytes.Buffer
_, _ = io.Copy(&buf, r)
done <- buf.String()
}()
fn()
_ = w.Close()
return <-done
}
func TestLaunchCmd(t *testing.T) {
mockCheck := func(cmd *cobra.Command, args []string) error {
return nil
}
mockTUI := func(cmd *cobra.Command) {}
cmd := LaunchCmd(mockCheck, mockTUI)
t.Run("command structure", func(t *testing.T) {
if cmd.Use != "launch [INTEGRATION] [-- [EXTRA_ARGS...]]" {
t.Errorf("Use = %q, want %q", cmd.Use, "launch [INTEGRATION] [-- [EXTRA_ARGS...]]")
}
if cmd.Short == "" {
t.Error("Short description should not be empty")
}
if cmd.Long == "" {
t.Error("Long description should not be empty")
}
})
t.Run("flags exist", func(t *testing.T) {
if cmd.Flags().Lookup("model") == nil {
t.Error("--model flag should exist")
}
if cmd.Flags().Lookup("config") == nil {
t.Error("--config flag should exist")
}
if cmd.Flags().Lookup("yes") == nil {
t.Error("--yes flag should exist")
}
})
t.Run("PreRunE is set", func(t *testing.T) {
if cmd.PreRunE == nil {
t.Error("PreRunE should be set to checkServerHeartbeat")
}
})
}
func TestLaunchCmdTUICallback(t *testing.T) {
mockCheck := func(cmd *cobra.Command, args []string) error {
return nil
}
t.Run("no args calls TUI", func(t *testing.T) {
tuiCalled := false
mockTUI := func(cmd *cobra.Command) {
tuiCalled = true
}
cmd := LaunchCmd(mockCheck, mockTUI)
cmd.SetArgs([]string{})
_ = cmd.Execute()
if !tuiCalled {
t.Error("TUI callback should be called when no args provided")
}
})
t.Run("integration arg bypasses TUI", func(t *testing.T) {
srv := httptest.NewServer(http.NotFoundHandler())
defer srv.Close()
t.Setenv("OLLAMA_HOST", srv.URL)
tuiCalled := false
mockTUI := func(cmd *cobra.Command) {
tuiCalled = true
}
cmd := LaunchCmd(mockCheck, mockTUI)
cmd.SetArgs([]string{"claude"})
_ = cmd.Execute()
if tuiCalled {
t.Error("TUI callback should NOT be called when integration arg provided")
}
})
t.Run("--model flag without integration returns error", func(t *testing.T) {
tuiCalled := false
mockTUI := func(cmd *cobra.Command) {
tuiCalled = true
}
cmd := LaunchCmd(mockCheck, mockTUI)
cmd.SetArgs([]string{"--model", "test-model"})
err := cmd.Execute()
if err == nil {
t.Fatal("expected --model without an integration to fail")
}
if !strings.Contains(err.Error(), "require an integration name") {
t.Fatalf("expected integration-name guidance, got %v", err)
}
if tuiCalled {
t.Error("TUI callback should NOT be called when --model is provided without an integration")
}
})
t.Run("--config flag without integration returns error", func(t *testing.T) {
tuiCalled := false
mockTUI := func(cmd *cobra.Command) {
tuiCalled = true
}
cmd := LaunchCmd(mockCheck, mockTUI)
cmd.SetArgs([]string{"--config"})
err := cmd.Execute()
if err == nil {
t.Fatal("expected --config without an integration to fail")
}
if !strings.Contains(err.Error(), "require an integration name") {
t.Fatalf("expected integration-name guidance, got %v", err)
}
if tuiCalled {
t.Error("TUI callback should NOT be called when --config is provided without an integration")
}
})
t.Run("--yes flag without integration returns error", func(t *testing.T) {
tuiCalled := false
mockTUI := func(cmd *cobra.Command) {
tuiCalled = true
}
cmd := LaunchCmd(mockCheck, mockTUI)
cmd.SetArgs([]string{"--yes"})
err := cmd.Execute()
if err == nil {
t.Fatal("expected --yes without an integration to fail")
}
if !strings.Contains(err.Error(), "require an integration name") {
t.Fatalf("expected integration-name guidance, got %v", err)
}
if tuiCalled {
t.Error("TUI callback should NOT be called when --yes is provided without an integration")
}
})
t.Run("extra args without integration return error", func(t *testing.T) {
tuiCalled := false
mockTUI := func(cmd *cobra.Command) {
tuiCalled = true
}
cmd := LaunchCmd(mockCheck, mockTUI)
cmd.SetArgs([]string{"--model", "test-model", "--", "--sandbox", "workspace-write"})
err := cmd.Execute()
if err == nil {
t.Fatal("expected flags and extra args without an integration to fail")
}
if !strings.Contains(err.Error(), "require an integration name") {
t.Fatalf("expected integration-name guidance, got %v", err)
}
if tuiCalled {
t.Error("TUI callback should NOT be called when flags or extra args are provided without an integration")
}
})
}
func TestLaunchCmdNilHeartbeat(t *testing.T) {
cmd := LaunchCmd(nil, nil)
if cmd == nil {
t.Fatal("LaunchCmd returned nil")
}
if cmd.PreRunE != nil {
t.Log("Note: PreRunE is set even when nil is passed (acceptable)")
}
}
func TestLaunchCmdModelFlagFiltersDisabledCloudFromSavedConfig(t *testing.T) {
tmpDir := t.TempDir()
setLaunchTestHome(t, tmpDir)
if err := config.SaveIntegration("stubeditor", []string{"glm-5:cloud"}); err != nil {
t.Fatalf("failed to seed saved config: %v", err)
}
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/api/status":
fmt.Fprintf(w, `{"cloud":{"disabled":true,"source":"config"}}`)
case "/api/show":
fmt.Fprintf(w, `{"model":"llama3.2"}`)
default:
w.WriteHeader(http.StatusNotFound)
}
}))
defer srv.Close()
t.Setenv("OLLAMA_HOST", srv.URL)
stub := &launcherEditorRunner{}
restore := OverrideIntegration("stubeditor", stub)
defer restore()
cmd := LaunchCmd(func(cmd *cobra.Command, args []string) error { return nil }, func(cmd *cobra.Command) {})
cmd.SetArgs([]string{"stubeditor", "--model", "llama3.2"})
if err := cmd.Execute(); err != nil {
t.Fatalf("launch command failed: %v", err)
}
saved, err := config.LoadIntegration("stubeditor")
if err != nil {
t.Fatalf("failed to reload integration config: %v", err)
}
if diff := cmp.Diff([]string{"llama3.2"}, saved.Models); diff != "" {
t.Fatalf("saved models mismatch (-want +got):\n%s", diff)
}
if diff := cmp.Diff([][]string{{"llama3.2"}}, stub.edited); diff != "" {
t.Fatalf("editor models mismatch (-want +got):\n%s", diff)
}
if stub.ranModel != "llama3.2" {
t.Fatalf("expected launch to run with llama3.2, got %q", stub.ranModel)
}
}
func TestLaunchCmdModelFlagClearsDisabledCloudOverride(t *testing.T) {
tmpDir := t.TempDir()
setLaunchTestHome(t, tmpDir)
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/api/status":
fmt.Fprintf(w, `{"cloud":{"disabled":true,"source":"config"}}`)
case "/api/tags":
fmt.Fprint(w, `{"models":[{"name":"llama3.2"}]}`)
case "/api/show":
fmt.Fprint(w, `{"model":"llama3.2"}`)
default:
w.WriteHeader(http.StatusNotFound)
}
}))
defer srv.Close()
t.Setenv("OLLAMA_HOST", srv.URL)
stub := &launcherSingleRunner{}
restore := OverrideIntegration("stubapp", stub)
defer restore()
oldSelector := DefaultSingleSelector
defer func() { DefaultSingleSelector = oldSelector }()
var selectorCalls int
var gotCurrent string
DefaultSingleSelector = func(title string, items []ModelItem, current string) (string, error) {
selectorCalls++
gotCurrent = current
return "llama3.2", nil
}
cmd := LaunchCmd(func(cmd *cobra.Command, args []string) error { return nil }, func(cmd *cobra.Command) {})
cmd.SetArgs([]string{"stubapp", "--model", "glm-5:cloud"})
stderr := captureStderr(t, func() {
if err := cmd.Execute(); err != nil {
t.Fatalf("launch command failed: %v", err)
}
})
if selectorCalls != 1 {
t.Fatalf("expected disabled cloud override to fall back to selector, got %d calls", selectorCalls)
}
if gotCurrent != "" {
t.Fatalf("expected disabled override to be cleared before selection, got current %q", gotCurrent)
}
if stub.ranModel != "llama3.2" {
t.Fatalf("expected launch to run with replacement local model, got %q", stub.ranModel)
}
if !strings.Contains(stderr, "Warning: ignoring --model glm-5:cloud because cloud is disabled") {
t.Fatalf("expected disabled-cloud warning, got stderr: %q", stderr)
}
saved, err := config.LoadIntegration("stubapp")
if err != nil {
t.Fatalf("failed to reload integration config: %v", err)
}
if diff := cmp.Diff([]string{"llama3.2"}, saved.Models); diff != "" {
t.Fatalf("saved models mismatch (-want +got):\n%s", diff)
}
}
func TestLaunchCmdYes_AutoConfirmsLaunchPromptPath(t *testing.T) {
tmpDir := t.TempDir()
setLaunchTestHome(t, tmpDir)
withLauncherHooks(t)
withInteractiveSession(t, false)
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/api/show":
fmt.Fprint(w, `{"model":"llama3.2"}`)
case "/api/status":
w.WriteHeader(http.StatusNotFound)
fmt.Fprint(w, `{"error":"not found"}`)
default:
w.WriteHeader(http.StatusNotFound)
}
}))
defer srv.Close()
t.Setenv("OLLAMA_HOST", srv.URL)
stub := &launcherEditorRunner{paths: []string{"/tmp/stubeditor.json"}}
restore := OverrideIntegration("stubeditor", stub)
defer restore()
DefaultConfirmPrompt = func(prompt string) (bool, error) {
t.Fatalf("unexpected prompt with --yes: %q", prompt)
return false, nil
}
cmd := LaunchCmd(func(cmd *cobra.Command, args []string) error { return nil }, func(cmd *cobra.Command) {})
cmd.SetArgs([]string{"stubeditor", "--model", "llama3.2", "--yes"})
if err := cmd.Execute(); err != nil {
t.Fatalf("launch command with --yes failed: %v", err)
}
if diff := cmp.Diff([][]string{{"llama3.2"}}, stub.edited); diff != "" {
t.Fatalf("editor models mismatch (-want +got):\n%s", diff)
}
if stub.ranModel != "llama3.2" {
t.Fatalf("expected launch to run with llama3.2, got %q", stub.ranModel)
}
}
func TestLaunchCmdHeadlessWithYes_AutoPullsMissingLocalModel(t *testing.T) {
tmpDir := t.TempDir()
setLaunchTestHome(t, tmpDir)
withLauncherHooks(t)
withInteractiveSession(t, false)
var pullCalled bool
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/api/show":
w.WriteHeader(http.StatusNotFound)
fmt.Fprint(w, `{"error":"model not found"}`)
case "/api/pull":
pullCalled = true
w.WriteHeader(http.StatusOK)
fmt.Fprint(w, `{"status":"success"}`)
default:
w.WriteHeader(http.StatusNotFound)
}
}))
defer srv.Close()
t.Setenv("OLLAMA_HOST", srv.URL)
stub := &launcherSingleRunner{}
restore := OverrideIntegration("stubapp", stub)
defer restore()
DefaultConfirmPrompt = func(prompt string) (bool, error) {
t.Fatalf("unexpected prompt with --yes in headless autopull path: %q", prompt)
return false, nil
}
cmd := LaunchCmd(func(cmd *cobra.Command, args []string) error { return nil }, func(cmd *cobra.Command) {})
cmd.SetArgs([]string{"stubapp", "--model", "missing-model", "--yes"})
if err := cmd.Execute(); err != nil {
t.Fatalf("launch command with --yes failed: %v", err)
}
if !pullCalled {
t.Fatal("expected missing local model to be auto-pulled with --yes in headless mode")
}
if stub.ranModel != "missing-model" {
t.Fatalf("expected launch to run with pulled model, got %q", stub.ranModel)
}
}
func TestLaunchCmdHeadlessWithoutYes_ReturnsActionableConfirmError(t *testing.T) {
tmpDir := t.TempDir()
setLaunchTestHome(t, tmpDir)
withLauncherHooks(t)
withInteractiveSession(t, false)
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/api/show":
fmt.Fprint(w, `{"model":"llama3.2"}`)
case "/api/status":
w.WriteHeader(http.StatusNotFound)
fmt.Fprint(w, `{"error":"not found"}`)
default:
w.WriteHeader(http.StatusNotFound)
}
}))
defer srv.Close()
t.Setenv("OLLAMA_HOST", srv.URL)
stub := &launcherEditorRunner{paths: []string{"/tmp/stubeditor.json"}}
restore := OverrideIntegration("stubeditor", stub)
defer restore()
DefaultConfirmPrompt = func(prompt string) (bool, error) {
t.Fatalf("unexpected prompt in headless non-yes mode: %q", prompt)
return false, nil
}
cmd := LaunchCmd(func(cmd *cobra.Command, args []string) error { return nil }, func(cmd *cobra.Command) {})
cmd.SetArgs([]string{"stubeditor", "--model", "llama3.2"})
err := cmd.Execute()
if err == nil {
t.Fatal("expected launch command to fail without --yes in headless mode")
}
if !strings.Contains(err.Error(), "re-run with --yes") {
t.Fatalf("expected actionable --yes guidance, got %v", err)
}
if len(stub.edited) != 0 {
t.Fatalf("expected no editor writes when confirmation is blocked, got %v", stub.edited)
}
if stub.ranModel != "" {
t.Fatalf("expected launch to abort before run, got %q", stub.ranModel)
}
}
func TestLaunchCmdIntegrationArgPromptsForModelWithSavedSelection(t *testing.T) {
tmpDir := t.TempDir()
setLaunchTestHome(t, tmpDir)
if err := config.SaveIntegration("stubapp", []string{"llama3.2"}); err != nil {
t.Fatalf("failed to seed saved config: %v", err)
}
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/api/tags":
fmt.Fprint(w, `{"models":[{"name":"llama3.2"},{"name":"qwen3:8b"}]}`)
case "/api/show":
fmt.Fprint(w, `{"model":"qwen3:8b"}`)
default:
w.WriteHeader(http.StatusNotFound)
}
}))
defer srv.Close()
t.Setenv("OLLAMA_HOST", srv.URL)
stub := &launcherSingleRunner{}
restore := OverrideIntegration("stubapp", stub)
defer restore()
oldSelector := DefaultSingleSelector
defer func() { DefaultSingleSelector = oldSelector }()
var gotCurrent string
DefaultSingleSelector = func(title string, items []ModelItem, current string) (string, error) {
gotCurrent = current
return "qwen3:8b", nil
}
cmd := LaunchCmd(func(cmd *cobra.Command, args []string) error { return nil }, func(cmd *cobra.Command) {})
cmd.SetArgs([]string{"stubapp"})
if err := cmd.Execute(); err != nil {
t.Fatalf("launch command failed: %v", err)
}
if gotCurrent != "llama3.2" {
t.Fatalf("expected selector current model to be saved model llama3.2, got %q", gotCurrent)
}
if stub.ranModel != "qwen3:8b" {
t.Fatalf("expected launch to run selected model qwen3:8b, got %q", stub.ranModel)
}
saved, err := config.LoadIntegration("stubapp")
if err != nil {
t.Fatalf("failed to reload integration config: %v", err)
}
if diff := cmp.Diff([]string{"qwen3:8b"}, saved.Models); diff != "" {
t.Fatalf("saved models mismatch (-want +got):\n%s", diff)
}
}
func TestLaunchCmdHeadlessYes_IntegrationRequiresModelEvenWhenSaved(t *testing.T) {
tmpDir := t.TempDir()
setLaunchTestHome(t, tmpDir)
withLauncherHooks(t)
withInteractiveSession(t, false)
if err := config.SaveIntegration("stubapp", []string{"llama3.2"}); err != nil {
t.Fatalf("failed to seed saved config: %v", err)
}
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/api/show":
fmt.Fprint(w, `{"model":"llama3.2"}`)
default:
w.WriteHeader(http.StatusNotFound)
}
}))
defer srv.Close()
t.Setenv("OLLAMA_HOST", srv.URL)
stub := &launcherSingleRunner{}
restore := OverrideIntegration("stubapp", stub)
defer restore()
oldSelector := DefaultSingleSelector
defer func() { DefaultSingleSelector = oldSelector }()
DefaultSingleSelector = func(title string, items []ModelItem, current string) (string, error) {
t.Fatal("selector should not be called for headless --yes saved-model launch")
return "", nil
}
cmd := LaunchCmd(func(cmd *cobra.Command, args []string) error { return nil }, func(cmd *cobra.Command) {})
cmd.SetArgs([]string{"stubapp", "--yes"})
err := cmd.Execute()
if err == nil {
t.Fatal("expected launch command to fail when --yes is used headlessly without --model")
}
if !strings.Contains(err.Error(), "requires --model <model>") {
t.Fatalf("expected actionable --model guidance, got %v", err)
}
if stub.ranModel != "" {
t.Fatalf("expected launch to abort before run, got %q", stub.ranModel)
}
}
func TestLaunchCmdHeadlessYes_IntegrationWithoutSavedModelReturnsError(t *testing.T) {
tmpDir := t.TempDir()
setLaunchTestHome(t, tmpDir)
withLauncherHooks(t)
withInteractiveSession(t, false)
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNotFound)
}))
defer srv.Close()
t.Setenv("OLLAMA_HOST", srv.URL)
stub := &launcherSingleRunner{}
restore := OverrideIntegration("stubapp", stub)
defer restore()
oldSelector := DefaultSingleSelector
defer func() { DefaultSingleSelector = oldSelector }()
DefaultSingleSelector = func(title string, items []ModelItem, current string) (string, error) {
t.Fatal("selector should not be called for headless --yes without saved model")
return "", nil
}
cmd := LaunchCmd(func(cmd *cobra.Command, args []string) error { return nil }, func(cmd *cobra.Command) {})
cmd.SetArgs([]string{"stubapp", "--yes"})
err := cmd.Execute()
if err == nil {
t.Fatal("expected launch command to fail when --yes is used headlessly without --model")
}
if !strings.Contains(err.Error(), "requires --model <model>") {
t.Fatalf("expected actionable --model guidance, got %v", err)
}
if stub.ranModel != "" {
t.Fatalf("expected launch to abort before run, got %q", stub.ranModel)
}
}

View File

@@ -1,16 +1,14 @@
package config
package launch
import (
"context"
"encoding/json"
"errors"
"fmt"
"os"
"os/exec"
"path/filepath"
"slices"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/cmd/internal/fileutil"
"github.com/ollama/ollama/envconfig"
)
@@ -47,25 +45,6 @@ func (d *Droid) Run(model string, args []string) error {
return fmt.Errorf("droid is not installed, install from https://docs.factory.ai/cli/getting-started/quickstart")
}
// Call Edit() to ensure config is up-to-date before launch
models := []string{model}
if config, err := loadIntegration("droid"); err == nil && len(config.Models) > 0 {
models = config.Models
}
var err error
models, err = resolveEditorModels("droid", models, func() ([]string, error) {
return selectModels(context.Background(), "droid", "")
})
if errors.Is(err, errCancelled) {
return nil
}
if err != nil {
return err
}
if err := d.Edit(models); err != nil {
return fmt.Errorf("setup failed: %w", err)
}
cmd := exec.Command("droid", args...)
cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout
@@ -111,6 +90,16 @@ func (d *Droid) Edit(models []string) error {
json.Unmarshal(data, &settings) // ignore error, zero values are fine
}
settingsMap = updateDroidSettings(settingsMap, settings, models)
data, err := json.MarshalIndent(settingsMap, "", " ")
if err != nil {
return err
}
return fileutil.WriteWithBackup(settingsPath, data)
}
func updateDroidSettings(settingsMap map[string]any, settings droidSettings, models []string) map[string]any {
// Keep only non-Ollama models from the raw map (preserves extra fields)
// Rebuild Ollama models
var nonOllamaModels []any
@@ -125,13 +114,12 @@ func (d *Droid) Edit(models []string) error {
}
// Build new Ollama model entries with sequential indices (0, 1, 2, ...)
client, _ := api.ClientFromEnvironment()
var newModels []any
var defaultModelID string
for i, model := range models {
maxOutput := 64000
if isCloudModel(context.Background(), client, model) {
if isCloudModelName(model) {
if l, ok := lookupCloudModelLimit(model); ok {
maxOutput = l.Output
}
@@ -167,12 +155,7 @@ func (d *Droid) Edit(models []string) error {
}
settingsMap["sessionDefaultSettings"] = sessionSettings
data, err := json.MarshalIndent(settingsMap, "", " ")
if err != nil {
return err
}
return writeWithBackup(settingsPath, data)
return settingsMap
}
func (d *Droid) Models() []string {

View File

@@ -1,4 +1,4 @@
package config
package launch
import (
"encoding/json"
@@ -6,6 +6,8 @@ import (
"os"
"path/filepath"
"testing"
"github.com/ollama/ollama/cmd/internal/fileutil"
)
func TestDroidIntegration(t *testing.T) {
@@ -362,7 +364,7 @@ func TestDroidEdit_DuplicateModels(t *testing.T) {
t.Fatalf("Edit with duplicates failed: %v", err)
}
settings, err := readJSONFile(settingsPath)
settings, err := fileutil.ReadJSON(settingsPath)
if err != nil {
t.Fatalf("readJSONFile failed: %v", err)
}
@@ -392,7 +394,7 @@ func TestDroidEdit_MalformedModelEntry(t *testing.T) {
}
// Malformed entries (non-object) are dropped - only valid model objects are preserved
settings, _ := readJSONFile(settingsPath)
settings, _ := fileutil.ReadJSON(settingsPath)
customModels, _ := settings["customModels"].([]any)
// Should have: 1 new Ollama model only (malformed entries dropped)
@@ -419,7 +421,7 @@ func TestDroidEdit_WrongTypeSessionSettings(t *testing.T) {
}
// Should create proper sessionDefaultSettings
settings, _ := readJSONFile(settingsPath)
settings, _ := fileutil.ReadJSON(settingsPath)
session, ok := settings["sessionDefaultSettings"].(map[string]any)
if !ok {
t.Fatalf("sessionDefaultSettings should be map after setup, got %T", settings["sessionDefaultSettings"])
@@ -1008,34 +1010,34 @@ func TestDroidEdit_ModelNamesWithSpecialCharacters(t *testing.T) {
}
func TestDroidEdit_MissingCustomModelsKey(t *testing.T) {
d := &Droid{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
settingsDir := filepath.Join(tmpDir, ".factory")
settingsPath := filepath.Join(settingsDir, "settings.json")
os.MkdirAll(settingsDir, 0o755)
// No customModels key at all
original := `{
"diffMode": "github",
"sessionDefaultSettings": {"autonomyMode": "auto-high"}
}`
os.WriteFile(settingsPath, []byte(original), 0o644)
if err := d.Edit([]string{"model-a"}); err != nil {
var settingsStruct droidSettings
var settings map[string]any
if err := json.Unmarshal([]byte(original), &settings); err != nil {
t.Fatal(err)
}
if err := json.Unmarshal([]byte(original), &settingsStruct); err != nil {
t.Fatal(err)
}
data, _ := os.ReadFile(settingsPath)
var settings map[string]any
json.Unmarshal(data, &settings)
settings = updateDroidSettings(settings, settingsStruct, []string{"model-a"})
// Original fields preserved
if settings["diffMode"] != "github" {
t.Error("diffMode not preserved")
}
session, ok := settings["sessionDefaultSettings"].(map[string]any)
if !ok {
t.Fatal("sessionDefaultSettings not preserved")
}
if session["autonomyMode"] != "auto-high" {
t.Error("sessionDefaultSettings.autonomyMode not preserved")
}
// customModels created
models, ok := settings["customModels"].([]any)
@@ -1276,25 +1278,17 @@ func TestDroidEdit_LocalModelDefaultMaxOutput(t *testing.T) {
func TestDroidEdit_CloudModelLimitsUsed(t *testing.T) {
// Verify that every cloud model in cloudModelLimits has a valid output
// value that would be used for maxOutputTokens when isCloudModel returns true.
// :cloud suffix stripping must also work since that's how users specify them.
// value that would be used for maxOutputTokens when the selected model uses
// the explicit :cloud source tag.
for name, expected := range cloudModelLimits {
t.Run(name, func(t *testing.T) {
l, ok := lookupCloudModelLimit(name)
if !ok {
t.Fatalf("lookupCloudModelLimit(%q) returned false", name)
}
if l.Output != expected.Output {
t.Errorf("output = %d, want %d", l.Output, expected.Output)
}
// Also verify :cloud suffix lookup
cloudName := name + ":cloud"
l2, ok := lookupCloudModelLimit(cloudName)
l, ok := lookupCloudModelLimit(cloudName)
if !ok {
t.Fatalf("lookupCloudModelLimit(%q) returned false", cloudName)
}
if l2.Output != expected.Output {
t.Errorf(":cloud output = %d, want %d", l2.Output, expected.Output)
if l.Output != expected.Output {
t.Errorf("output = %d, want %d", l.Output, expected.Output)
}
})
}

842
cmd/launch/launch.go Normal file
View File

@@ -0,0 +1,842 @@
package launch
import (
"context"
"errors"
"fmt"
"net/http"
"os"
"strings"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/cmd/config"
"github.com/spf13/cobra"
"golang.org/x/term"
)
// LauncherState is the launch-owned snapshot used to render the root launcher menu.
type LauncherState struct {
LastSelection string
RunModel string
RunModelUsable bool
Integrations map[string]LauncherIntegrationState
}
// LauncherIntegrationState is the launch-owned status for one launcher integration.
type LauncherIntegrationState struct {
Name string
DisplayName string
Description string
Installed bool
AutoInstallable bool
Selectable bool
Changeable bool
CurrentModel string
ModelUsable bool
InstallHint string
Editor bool
}
// RunModelRequest controls how the root launcher resolves the chat model.
type RunModelRequest struct {
ForcePicker bool
Policy *LaunchPolicy
}
// LaunchConfirmMode controls confirmation behavior across launch flows.
type LaunchConfirmMode int
const (
// LaunchConfirmPrompt prompts the user for confirmation.
LaunchConfirmPrompt LaunchConfirmMode = iota
// LaunchConfirmAutoApprove skips prompts and treats confirmation as accepted.
LaunchConfirmAutoApprove
// LaunchConfirmRequireYes rejects confirmation requests with a --yes hint.
LaunchConfirmRequireYes
)
// LaunchMissingModelMode controls local missing-model handling in launch flows.
type LaunchMissingModelMode int
const (
// LaunchMissingModelPromptToPull prompts to pull a missing local model.
LaunchMissingModelPromptToPull LaunchMissingModelMode = iota
// LaunchMissingModelAutoPull pulls a missing local model without prompting.
LaunchMissingModelAutoPull
// LaunchMissingModelFail fails immediately when a local model is missing.
LaunchMissingModelFail
)
// LaunchPolicy controls launch behavior that may vary by caller context.
type LaunchPolicy struct {
Confirm LaunchConfirmMode
MissingModel LaunchMissingModelMode
}
func defaultLaunchPolicy(interactive bool, yes bool) LaunchPolicy {
policy := LaunchPolicy{
Confirm: LaunchConfirmPrompt,
MissingModel: LaunchMissingModelPromptToPull,
}
switch {
case yes:
// if yes flag is set, auto approve and auto pull
policy.Confirm = LaunchConfirmAutoApprove
policy.MissingModel = LaunchMissingModelAutoPull
case !interactive:
// otherwise make sure to stop when needed
policy.Confirm = LaunchConfirmRequireYes
policy.MissingModel = LaunchMissingModelFail
}
return policy
}
func (p LaunchPolicy) confirmPolicy() launchConfirmPolicy {
switch p.Confirm {
case LaunchConfirmAutoApprove:
return launchConfirmPolicy{yes: true}
case LaunchConfirmRequireYes:
return launchConfirmPolicy{requireYesMessage: true}
default:
return launchConfirmPolicy{}
}
}
func (p LaunchPolicy) missingModelPolicy() missingModelPolicy {
switch p.MissingModel {
case LaunchMissingModelAutoPull:
return missingModelAutoPull
case LaunchMissingModelFail:
return missingModelFail
default:
return missingModelPromptPull
}
}
// IntegrationLaunchRequest controls the canonical integration launcher flow.
type IntegrationLaunchRequest struct {
Name string
ModelOverride string
ForceConfigure bool
ConfigureOnly bool
ExtraArgs []string
Policy *LaunchPolicy
}
var isInteractiveSession = func() bool {
return term.IsTerminal(int(os.Stdin.Fd())) && term.IsTerminal(int(os.Stdout.Fd()))
}
// Runner executes a model with an integration.
type Runner interface {
Run(model string, args []string) error
String() string
}
// Editor can edit config files for integrations that support model configuration.
type Editor interface {
Paths() []string
Edit(models []string) error
Models() []string
}
type modelInfo struct {
Name string
Remote bool
ToolCapable bool
}
// ModelInfo re-exports launcher model inventory details for callers.
type ModelInfo = modelInfo
// ModelItem represents a model for selection UIs.
type ModelItem struct {
Name string
Description string
Recommended bool
}
// LaunchCmd returns the cobra command for launching integrations.
// The runTUI callback is called when the root launcher UI should be shown.
func LaunchCmd(checkServerHeartbeat func(cmd *cobra.Command, args []string) error, runTUI func(cmd *cobra.Command)) *cobra.Command {
var modelFlag string
var configFlag bool
var yesFlag bool
cmd := &cobra.Command{
Use: "launch [INTEGRATION] [-- [EXTRA_ARGS...]]",
Short: "Launch the Ollama menu or an integration",
Long: `Launch the Ollama interactive menu, or directly launch a specific integration.
Without arguments, this is equivalent to running 'ollama' directly.
Flags and extra arguments require an integration name.
Supported integrations:
claude Claude Code
cline Cline
codex Codex
droid Droid
opencode OpenCode
openclaw OpenClaw (aliases: clawdbot, moltbot)
pi Pi
Examples:
ollama launch
ollama launch claude
ollama launch claude --model <model>
ollama launch droid --config (does not auto-launch)
ollama launch codex -- -p myprofile (pass extra args to integration)
ollama launch codex -- --sandbox workspace-write`,
Args: cobra.ArbitraryArgs,
PreRunE: checkServerHeartbeat,
RunE: func(cmd *cobra.Command, args []string) error {
policy := defaultLaunchPolicy(isInteractiveSession(), yesFlag)
// reset when done to make sure state doens't leak between launches
restoreConfirmPolicy := withLaunchConfirmPolicy(policy.confirmPolicy())
defer restoreConfirmPolicy()
var name string
var passArgs []string
dashIdx := cmd.ArgsLenAtDash()
if dashIdx == -1 {
if len(args) > 1 {
return fmt.Errorf("unexpected arguments: %v\nUse '--' to pass extra arguments to the integration", args[1:])
}
if len(args) == 1 {
name = args[0]
}
} else {
if dashIdx > 1 {
return fmt.Errorf("expected at most 1 integration name before '--', got %d", dashIdx)
}
if dashIdx == 1 {
name = args[0]
}
passArgs = args[dashIdx:]
}
if name == "" {
if cmd.Flags().Changed("model") || cmd.Flags().Changed("config") || cmd.Flags().Changed("yes") || len(passArgs) > 0 {
return fmt.Errorf("flags and extra args require an integration name, for example: 'ollama launch claude --model qwen3.5'")
}
runTUI(cmd)
return nil
}
if modelFlag != "" && isCloudModelName(modelFlag) {
if client, err := api.ClientFromEnvironment(); err == nil {
if disabled, _ := cloudStatusDisabled(cmd.Context(), client); disabled {
fmt.Fprintf(os.Stderr, "Warning: ignoring --model %s because cloud is disabled\n", modelFlag)
modelFlag = ""
}
}
}
headlessYes := yesFlag && !isInteractiveSession()
err := LaunchIntegration(cmd.Context(), IntegrationLaunchRequest{
Name: name,
ModelOverride: modelFlag,
ForceConfigure: configFlag || (modelFlag == "" && !headlessYes),
ConfigureOnly: configFlag,
ExtraArgs: passArgs,
Policy: &policy,
})
if errors.Is(err, ErrCancelled) {
return nil
}
return err
},
}
cmd.Flags().StringVar(&modelFlag, "model", "", "Model to use")
cmd.Flags().BoolVar(&configFlag, "config", false, "Configure without launching")
cmd.Flags().BoolVarP(&yesFlag, "yes", "y", false, "Automatically answer yes to confirmation prompts")
return cmd
}
type launcherClient struct {
apiClient *api.Client
modelInventory []ModelInfo
inventoryLoaded bool
policy LaunchPolicy
}
func newLauncherClient(policy LaunchPolicy) (*launcherClient, error) {
apiClient, err := api.ClientFromEnvironment()
if err != nil {
return nil, err
}
return &launcherClient{
apiClient: apiClient,
policy: policy,
}, nil
}
// BuildLauncherState returns the launch-owned root launcher menu snapshot.
func BuildLauncherState(ctx context.Context) (*LauncherState, error) {
launchClient, err := newLauncherClient(defaultLaunchPolicy(isInteractiveSession(), false))
if err != nil {
return nil, err
}
return launchClient.buildLauncherState(ctx)
}
// ResolveRunModel returns the model that should be used for interactive chat.
func ResolveRunModel(ctx context.Context, req RunModelRequest) (string, error) {
// Called by the launcher TUI "Run a model" action (cmd/runLauncherAction),
// which resolves models separately from LaunchIntegration. Callers can pass
// Policy directly; otherwise we fall back to ambient --yes/session defaults.
policy := defaultLaunchPolicy(isInteractiveSession(), currentLaunchConfirmPolicy.yes)
if req.Policy != nil {
policy = *req.Policy
}
launchClient, err := newLauncherClient(policy)
if err != nil {
return "", err
}
return launchClient.resolveRunModel(ctx, req)
}
// LaunchIntegration runs the canonical launcher flow for one integration.
func LaunchIntegration(ctx context.Context, req IntegrationLaunchRequest) error {
name, runner, err := LookupIntegration(req.Name)
if err != nil {
return err
}
if !req.ConfigureOnly {
if err := EnsureIntegrationInstalled(name, runner); err != nil {
return err
}
}
var policy LaunchPolicy
// TUI does not set a policy, whereas ollama launch <app> does as it can have flags which change the behavior
if req.Policy == nil {
policy = defaultLaunchPolicy(isInteractiveSession(), false)
} else {
policy = *req.Policy
}
launchClient, err := newLauncherClient(policy)
if err != nil {
return err
}
saved, _ := loadStoredIntegrationConfig(name)
// In headless --yes mode we cannot prompt, so require an explicit --model.
if policy.Confirm == LaunchConfirmAutoApprove && !isInteractiveSession() && req.ModelOverride == "" {
return fmt.Errorf("headless --yes launch for %s requires --model <model>", name)
}
if editor, ok := runner.(Editor); ok {
return launchClient.launchEditorIntegration(ctx, name, runner, editor, saved, req)
}
return launchClient.launchSingleIntegration(ctx, name, runner, saved, req)
}
func (c *launcherClient) buildLauncherState(ctx context.Context) (*LauncherState, error) {
_ = c.loadModelInventoryOnce(ctx)
state := &LauncherState{
LastSelection: config.LastSelection(),
RunModel: config.LastModel(),
Integrations: make(map[string]LauncherIntegrationState),
}
runModelUsable, err := c.savedModelUsable(ctx, state.RunModel)
if err != nil {
runModelUsable = false
}
state.RunModelUsable = runModelUsable
for _, info := range ListIntegrationInfos() {
integrationState, err := c.buildLauncherIntegrationState(ctx, info)
if err != nil {
return nil, err
}
state.Integrations[info.Name] = integrationState
}
return state, nil
}
func (c *launcherClient) buildLauncherIntegrationState(ctx context.Context, info IntegrationInfo) (LauncherIntegrationState, error) {
integration, err := integrationFor(info.Name)
if err != nil {
return LauncherIntegrationState{}, err
}
currentModel, usable, err := c.launcherModelState(ctx, info.Name, integration.editor)
if err != nil {
return LauncherIntegrationState{}, err
}
return LauncherIntegrationState{
Name: info.Name,
DisplayName: info.DisplayName,
Description: info.Description,
Installed: integration.installed,
AutoInstallable: integration.autoInstallable,
Selectable: integration.installed || integration.autoInstallable,
Changeable: integration.installed || integration.autoInstallable,
CurrentModel: currentModel,
ModelUsable: usable,
InstallHint: integration.installHint,
Editor: integration.editor,
}, nil
}
func (c *launcherClient) launcherModelState(ctx context.Context, name string, isEditor bool) (string, bool, error) {
cfg, loadErr := loadStoredIntegrationConfig(name)
hasModels := loadErr == nil && len(cfg.Models) > 0
if !hasModels {
return "", false, nil
}
if isEditor {
filtered := c.filterDisabledCloudModels(ctx, cfg.Models)
if len(filtered) > 0 {
return filtered[0], true, nil
}
return cfg.Models[0], false, nil
}
model := cfg.Models[0]
usable, usableErr := c.savedModelUsable(ctx, model)
return model, usableErr == nil && usable, nil
}
func (c *launcherClient) resolveRunModel(ctx context.Context, req RunModelRequest) (string, error) {
current := config.LastModel()
if !req.ForcePicker && current != "" && c.policy.Confirm == LaunchConfirmAutoApprove && !isInteractiveSession() {
if err := c.ensureModelsReady(ctx, []string{current}); err != nil {
return "", err
}
fmt.Fprintf(os.Stderr, "Headless mode: auto-selected last used model %q\n", current)
if err := config.SetLastModel(current); err != nil {
return "", err
}
return current, nil
}
if !req.ForcePicker {
usable, err := c.savedModelUsable(ctx, current)
if err != nil {
return "", err
}
if usable {
if err := c.ensureModelsReady(ctx, []string{current}); err != nil {
return "", err
}
if err := config.SetLastModel(current); err != nil {
return "", err
}
return current, nil
}
}
model, err := c.selectSingleModelWithSelector(ctx, "Select model to run:", current, DefaultSingleSelector)
if err != nil {
return "", err
}
if err := config.SetLastModel(model); err != nil {
return "", err
}
return model, nil
}
func (c *launcherClient) launchSingleIntegration(ctx context.Context, name string, runner Runner, saved *config.IntegrationConfig, req IntegrationLaunchRequest) error {
current := primaryModelFromConfig(saved)
target := req.ModelOverride
needsConfigure := req.ForceConfigure
if target == "" {
target = current
usable, err := c.savedModelUsable(ctx, target)
if err != nil {
return err
}
if !usable {
needsConfigure = true
}
}
if needsConfigure {
selected, err := c.selectSingleModelWithSelector(ctx, fmt.Sprintf("Select model for %s:", runner), target, DefaultSingleSelector)
if err != nil {
return err
}
target = selected
} else if err := c.ensureModelsReady(ctx, []string{target}); err != nil {
return err
}
if target == "" {
return nil
}
if err := config.SaveIntegration(name, []string{target}); err != nil {
return fmt.Errorf("failed to save: %w", err)
}
return launchAfterConfiguration(name, runner, target, req)
}
func (c *launcherClient) launchEditorIntegration(ctx context.Context, name string, runner Runner, editor Editor, saved *config.IntegrationConfig, req IntegrationLaunchRequest) error {
models, needsConfigure := c.resolveEditorLaunchModels(ctx, saved, req)
if needsConfigure {
selected, err := c.selectMultiModelsForIntegration(ctx, runner, models)
if err != nil {
return err
}
models = selected
} else if err := c.ensureModelsReady(ctx, models); err != nil {
return err
}
if len(models) == 0 {
return nil
}
if needsConfigure || req.ModelOverride != "" {
if err := prepareEditorIntegration(name, runner, editor, models); err != nil {
return err
}
}
return launchAfterConfiguration(name, runner, models[0], req)
}
func (c *launcherClient) selectSingleModelWithSelector(ctx context.Context, title, current string, selector SingleSelector) (string, error) {
if selector == nil {
return "", fmt.Errorf("no selector configured")
}
items, _, err := c.loadSelectableModels(ctx, nil, current, "no models available, run 'ollama pull <model>' first")
if err != nil {
return "", err
}
selected, err := selector(title, items, current)
if err != nil {
return "", err
}
if err := c.ensureModelsReady(ctx, []string{selected}); err != nil {
return "", err
}
return selected, nil
}
func (c *launcherClient) selectMultiModelsForIntegration(ctx context.Context, runner Runner, preChecked []string) ([]string, error) {
if DefaultMultiSelector == nil {
return nil, fmt.Errorf("no selector configured")
}
current := firstModel(preChecked)
items, orderedChecked, err := c.loadSelectableModels(ctx, preChecked, current, "no models available")
if err != nil {
return nil, err
}
if len(preChecked) > 0 {
// Keep list order stable in multi-select even when there are existing checks.
// checked/default state still comes from orderedChecked.
stableItems, _, stableErr := c.loadSelectableModels(ctx, nil, current, "no models available")
if stableErr != nil {
return nil, stableErr
}
items = stableItems
}
selected, err := DefaultMultiSelector(fmt.Sprintf("Select models for %s:", runner), items, orderedChecked)
if err != nil {
return nil, err
}
if err := c.ensureModelsReady(ctx, selected); err != nil {
return nil, err
}
return selected, nil
}
func (c *launcherClient) loadSelectableModels(ctx context.Context, preChecked []string, current, emptyMessage string) ([]ModelItem, []string, error) {
if err := c.loadModelInventoryOnce(ctx); err != nil {
return nil, nil, err
}
cloudDisabled, _ := cloudStatusDisabled(ctx, c.apiClient)
items, orderedChecked, _, _ := buildModelList(c.modelInventory, preChecked, current)
if cloudDisabled {
items = filterCloudItems(items)
orderedChecked = c.filterDisabledCloudModels(ctx, orderedChecked)
}
if len(items) == 0 {
return nil, nil, errors.New(emptyMessage)
}
return items, orderedChecked, nil
}
func (c *launcherClient) ensureModelsReady(ctx context.Context, models []string) error {
var deduped []string
seen := make(map[string]bool, len(models))
for _, model := range models {
if model == "" || seen[model] {
continue
}
seen[model] = true
deduped = append(deduped, model)
}
models = deduped
if len(models) == 0 {
return nil
}
cloudModels := make(map[string]bool, len(models))
for _, model := range models {
isCloudModel := isCloudModelName(model)
if isCloudModel {
cloudModels[model] = true
}
if err := showOrPullWithPolicy(ctx, c.apiClient, model, c.policy.missingModelPolicy(), isCloudModel); err != nil {
return err
}
}
return ensureAuth(ctx, c.apiClient, cloudModels, models)
}
func (c *launcherClient) resolveEditorLaunchModels(ctx context.Context, saved *config.IntegrationConfig, req IntegrationLaunchRequest) ([]string, bool) {
if req.ForceConfigure {
return editorPreCheckedModels(saved, req.ModelOverride), true
}
if req.ModelOverride != "" {
models := append([]string{req.ModelOverride}, additionalSavedModels(saved, req.ModelOverride)...)
models = c.filterDisabledCloudModels(ctx, models)
return models, len(models) == 0
}
if saved == nil || len(saved.Models) == 0 {
return nil, true
}
models := c.filterDisabledCloudModels(ctx, saved.Models)
return models, len(models) == 0
}
func (c *launcherClient) filterDisabledCloudModels(ctx context.Context, models []string) []string {
// if connection cannot be established or there is a 404, cloud models will continue to be displayed
cloudDisabled, _ := cloudStatusDisabled(ctx, c.apiClient)
if !cloudDisabled {
return append([]string(nil), models...)
}
filtered := make([]string, 0, len(models))
for _, model := range models {
if !isCloudModelName(model) {
filtered = append(filtered, model)
}
}
return filtered
}
func (c *launcherClient) savedModelUsable(ctx context.Context, name string) (bool, error) {
if err := c.loadModelInventoryOnce(ctx); err != nil {
return c.showBasedModelUsable(ctx, name)
}
return c.singleModelUsable(ctx, name), nil
}
func (c *launcherClient) showBasedModelUsable(ctx context.Context, name string) (bool, error) {
if name == "" {
return false, nil
}
info, err := c.apiClient.Show(ctx, &api.ShowRequest{Model: name})
if err != nil {
var statusErr api.StatusError
if errors.As(err, &statusErr) && statusErr.StatusCode == http.StatusNotFound {
return false, nil
}
return false, err
}
if isCloudModelName(name) || info.RemoteModel != "" {
cloudDisabled, _ := cloudStatusDisabled(ctx, c.apiClient)
return !cloudDisabled, nil
}
return true, nil
}
func (c *launcherClient) singleModelUsable(ctx context.Context, name string) bool {
if name == "" {
return false
}
if isCloudModelName(name) {
cloudDisabled, _ := cloudStatusDisabled(ctx, c.apiClient)
return !cloudDisabled
}
return c.hasLocalModel(name)
}
func (c *launcherClient) hasLocalModel(name string) bool {
for _, model := range c.modelInventory {
if model.Remote {
continue
}
if model.Name == name || strings.HasPrefix(model.Name, name+":") {
return true
}
}
return false
}
func (c *launcherClient) loadModelInventoryOnce(ctx context.Context) error {
if c.inventoryLoaded {
return nil
}
resp, err := c.apiClient.List(ctx)
if err != nil {
return err
}
c.modelInventory = c.modelInventory[:0]
for _, model := range resp.Models {
c.modelInventory = append(c.modelInventory, ModelInfo{
Name: model.Name,
Remote: model.RemoteModel != "",
})
}
cloudDisabled, _ := cloudStatusDisabled(ctx, c.apiClient)
if cloudDisabled {
c.modelInventory = filterCloudModels(c.modelInventory)
}
c.inventoryLoaded = true
return nil
}
func runIntegration(runner Runner, modelName string, args []string) error {
fmt.Fprintf(os.Stderr, "\nLaunching %s with %s...\n", runner, modelName)
return runner.Run(modelName, args)
}
func launchAfterConfiguration(name string, runner Runner, model string, req IntegrationLaunchRequest) error {
if req.ConfigureOnly {
launch, err := ConfirmPrompt(fmt.Sprintf("Launch %s now?", runner))
if err != nil {
return err
}
if !launch {
return nil
}
}
if err := EnsureIntegrationInstalled(name, runner); err != nil {
return err
}
return runIntegration(runner, model, req.ExtraArgs)
}
func loadStoredIntegrationConfig(name string) (*config.IntegrationConfig, error) {
cfg, err := config.LoadIntegration(name)
if err == nil {
return cfg, nil
}
if !errors.Is(err, os.ErrNotExist) {
return nil, err
}
spec, specErr := LookupIntegrationSpec(name)
if specErr != nil {
return nil, err
}
for _, alias := range spec.Aliases {
legacy, legacyErr := config.LoadIntegration(alias)
if legacyErr == nil {
migrateLegacyIntegrationConfig(spec.Name, legacy)
if migrated, migratedErr := config.LoadIntegration(spec.Name); migratedErr == nil {
return migrated, nil
}
return legacy, nil
}
if legacyErr != nil && !errors.Is(legacyErr, os.ErrNotExist) {
return nil, legacyErr
}
}
return nil, err
}
func migrateLegacyIntegrationConfig(canonical string, legacy *config.IntegrationConfig) {
if legacy == nil {
return
}
_ = config.SaveIntegration(canonical, append([]string(nil), legacy.Models...))
if len(legacy.Aliases) > 0 {
_ = config.SaveAliases(canonical, cloneAliases(legacy.Aliases))
}
if legacy.Onboarded {
_ = config.MarkIntegrationOnboarded(canonical)
}
}
func primaryModelFromConfig(cfg *config.IntegrationConfig) string {
if cfg == nil || len(cfg.Models) == 0 {
return ""
}
return cfg.Models[0]
}
func cloneAliases(aliases map[string]string) map[string]string {
if len(aliases) == 0 {
return make(map[string]string)
}
cloned := make(map[string]string, len(aliases))
for key, value := range aliases {
cloned[key] = value
}
return cloned
}
func singleModelPrechecked(current string) []string {
if current == "" {
return nil
}
return []string{current}
}
func firstModel(models []string) string {
if len(models) == 0 {
return ""
}
return models[0]
}
func editorPreCheckedModels(saved *config.IntegrationConfig, override string) []string {
if override == "" {
if saved == nil {
return nil
}
return append([]string(nil), saved.Models...)
}
return append([]string{override}, additionalSavedModels(saved, override)...)
}
func additionalSavedModels(saved *config.IntegrationConfig, exclude string) []string {
if saved == nil {
return nil
}
var models []string
for _, model := range saved.Models {
if model != exclude {
models = append(models, model)
}
}
return models
}

1498
cmd/launch/launch_test.go Normal file

File diff suppressed because it is too large Load Diff

490
cmd/launch/models.go Normal file
View File

@@ -0,0 +1,490 @@
package launch
import (
"context"
"errors"
"fmt"
"net/http"
"os"
"os/exec"
"runtime"
"slices"
"strings"
"time"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/cmd/config"
"github.com/ollama/ollama/cmd/internal/fileutil"
internalcloud "github.com/ollama/ollama/internal/cloud"
"github.com/ollama/ollama/internal/modelref"
"github.com/ollama/ollama/progress"
)
var recommendedModels = []ModelItem{
{Name: "kimi-k2.5:cloud", Description: "Multimodal reasoning with subagents", Recommended: true},
{Name: "qwen3.5:cloud", Description: "Reasoning, coding, and agentic tool use with vision", Recommended: true},
{Name: "glm-5:cloud", Description: "Reasoning and code generation", Recommended: true},
{Name: "minimax-m2.5:cloud", Description: "Fast, efficient coding and real-world productivity", Recommended: true},
{Name: "glm-4.7-flash", Description: "Reasoning and code generation locally", Recommended: true},
{Name: "qwen3.5", Description: "Reasoning, coding, and visual understanding locally", Recommended: true},
}
var recommendedVRAM = map[string]string{
"glm-4.7-flash": "~25GB",
"qwen3.5": "~11GB",
}
// cloudModelLimit holds context and output token limits for a cloud model.
type cloudModelLimit struct {
Context int
Output int
}
// cloudModelLimits maps cloud model base names to their token limits.
// TODO(parthsareen): grab context/output limits from model info instead of hardcoding
var cloudModelLimits = map[string]cloudModelLimit{
"minimax-m2.5": {Context: 204_800, Output: 128_000},
"cogito-2.1:671b": {Context: 163_840, Output: 65_536},
"deepseek-v3.1:671b": {Context: 163_840, Output: 163_840},
"deepseek-v3.2": {Context: 163_840, Output: 65_536},
"glm-4.6": {Context: 202_752, Output: 131_072},
"glm-4.7": {Context: 202_752, Output: 131_072},
"glm-5": {Context: 202_752, Output: 131_072},
"gpt-oss:120b": {Context: 131_072, Output: 131_072},
"gpt-oss:20b": {Context: 131_072, Output: 131_072},
"kimi-k2:1t": {Context: 262_144, Output: 262_144},
"kimi-k2.5": {Context: 262_144, Output: 262_144},
"kimi-k2-thinking": {Context: 262_144, Output: 262_144},
"nemotron-3-nano:30b": {Context: 1_048_576, Output: 131_072},
"qwen3-coder:480b": {Context: 262_144, Output: 65_536},
"qwen3-coder-next": {Context: 262_144, Output: 32_768},
"qwen3-next:80b": {Context: 262_144, Output: 32_768},
"qwen3.5": {Context: 262_144, Output: 32_768},
}
// lookupCloudModelLimit returns the token limits for a cloud model.
// It normalizes explicit cloud source suffixes before checking the shared limit map.
func lookupCloudModelLimit(name string) (cloudModelLimit, bool) {
base, stripped := modelref.StripCloudSourceTag(name)
if stripped {
if l, ok := cloudModelLimits[base]; ok {
return l, true
}
}
return cloudModelLimit{}, false
}
// missingModelPolicy controls how model-not-found errors should be handled.
type missingModelPolicy int
const (
// missingModelPromptPull prompts the user to download missing local models.
missingModelPromptPull missingModelPolicy = iota
// missingModelAutoPull downloads missing local models without prompting.
missingModelAutoPull
// missingModelFail returns an error for missing local models without prompting.
missingModelFail
)
// OpenBrowser opens the URL in the user's browser.
func OpenBrowser(url string) {
switch runtime.GOOS {
case "darwin":
_ = exec.Command("open", url).Start()
case "linux":
_ = exec.Command("xdg-open", url).Start()
case "windows":
_ = exec.Command("rundll32", "url.dll,FileProtocolHandler", url).Start()
}
}
// ensureAuth ensures the user is signed in before cloud-backed models run.
func ensureAuth(ctx context.Context, client *api.Client, cloudModels map[string]bool, selected []string) error {
var selectedCloudModels []string
for _, m := range selected {
if cloudModels[m] {
selectedCloudModels = append(selectedCloudModels, m)
}
}
if len(selectedCloudModels) == 0 {
return nil
}
if disabled, known := cloudStatusDisabled(ctx, client); known && disabled {
return errors.New(internalcloud.DisabledError("remote inference is unavailable"))
}
user, err := client.Whoami(ctx)
if err == nil && user != nil && user.Name != "" {
return nil
}
var aErr api.AuthorizationError
if !errors.As(err, &aErr) || aErr.SigninURL == "" {
return err
}
modelList := strings.Join(selectedCloudModels, ", ")
if DefaultSignIn != nil {
_, err := DefaultSignIn(modelList, aErr.SigninURL)
if errors.Is(err, ErrCancelled) {
return ErrCancelled
}
if err != nil {
return fmt.Errorf("%s requires sign in", modelList)
}
return nil
}
yes, err := ConfirmPrompt(fmt.Sprintf("sign in to use %s?", modelList))
if errors.Is(err, ErrCancelled) {
return ErrCancelled
}
if err != nil {
return err
}
if !yes {
return ErrCancelled
}
fmt.Fprintf(os.Stderr, "\nTo sign in, navigate to:\n %s\n\n", aErr.SigninURL)
OpenBrowser(aErr.SigninURL)
spinnerFrames := []string{"|", "/", "-", "\\"}
frame := 0
fmt.Fprintf(os.Stderr, "\033[90mwaiting for sign in to complete... %s\033[0m", spinnerFrames[0])
ticker := time.NewTicker(200 * time.Millisecond)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
fmt.Fprintf(os.Stderr, "\r\033[K")
return ctx.Err()
case <-ticker.C:
frame++
fmt.Fprintf(os.Stderr, "\r\033[90mwaiting for sign in to complete... %s\033[0m", spinnerFrames[frame%len(spinnerFrames)])
if frame%10 == 0 {
u, err := client.Whoami(ctx)
if err == nil && u != nil && u.Name != "" {
fmt.Fprintf(os.Stderr, "\r\033[K\033[A\r\033[K\033[1msigned in:\033[0m %s\n", u.Name)
return nil
}
}
}
}
}
// showOrPullWithPolicy checks if a model exists and applies the provided missing-model policy.
func showOrPullWithPolicy(ctx context.Context, client *api.Client, model string, policy missingModelPolicy, isCloudModel bool) error {
if _, err := client.Show(ctx, &api.ShowRequest{Model: model}); err == nil {
return nil
} else {
var statusErr api.StatusError
if !errors.As(err, &statusErr) || statusErr.StatusCode != http.StatusNotFound {
return err
}
}
if isCloudModel {
if disabled, known := cloudStatusDisabled(ctx, client); known && disabled {
return errors.New(internalcloud.DisabledError("remote inference is unavailable"))
}
return fmt.Errorf("model %q not found", model)
}
switch policy {
case missingModelAutoPull:
return pullMissingModel(ctx, client, model)
case missingModelFail:
return fmt.Errorf("model %q not found; run 'ollama pull %s' first, or use --yes to auto-pull", model, model)
default:
return confirmAndPull(ctx, client, model)
}
}
func confirmAndPull(ctx context.Context, client *api.Client, model string) error {
if ok, err := ConfirmPrompt(fmt.Sprintf("Download %s?", model)); err != nil {
return err
} else if !ok {
return errCancelled
}
fmt.Fprintf(os.Stderr, "\n")
return pullMissingModel(ctx, client, model)
}
func pullMissingModel(ctx context.Context, client *api.Client, model string) error {
if err := pullModel(ctx, client, model, false); err != nil {
return fmt.Errorf("failed to pull %s: %w", model, err)
}
return nil
}
// prepareEditorIntegration persists models and applies editor-managed config files.
func prepareEditorIntegration(name string, runner Runner, editor Editor, models []string) error {
if ok, err := confirmEditorEdit(runner, editor); err != nil {
return err
} else if !ok {
return errCancelled
}
if err := editor.Edit(models); err != nil {
return fmt.Errorf("setup failed: %w", err)
}
if err := config.SaveIntegration(name, models); err != nil {
return fmt.Errorf("failed to save: %w", err)
}
return nil
}
func confirmEditorEdit(runner Runner, editor Editor) (bool, error) {
paths := editor.Paths()
if len(paths) == 0 {
return true, nil
}
fmt.Fprintf(os.Stderr, "This will modify your %s configuration:\n", runner)
for _, path := range paths {
fmt.Fprintf(os.Stderr, " %s\n", path)
}
fmt.Fprintf(os.Stderr, "Backups will be saved to %s/\n\n", fileutil.BackupDir())
return ConfirmPrompt("Proceed?")
}
// buildModelList merges existing models with recommendations for selection UIs.
func buildModelList(existing []modelInfo, preChecked []string, current string) (items []ModelItem, orderedChecked []string, existingModels, cloudModels map[string]bool) {
existingModels = make(map[string]bool)
cloudModels = make(map[string]bool)
recommended := make(map[string]bool)
var hasLocalModel, hasCloudModel bool
recDesc := make(map[string]string)
for _, rec := range recommendedModels {
recommended[rec.Name] = true
recDesc[rec.Name] = rec.Description
}
for _, m := range existing {
existingModels[m.Name] = true
if m.Remote {
cloudModels[m.Name] = true
hasCloudModel = true
} else {
hasLocalModel = true
}
displayName := strings.TrimSuffix(m.Name, ":latest")
existingModels[displayName] = true
item := ModelItem{Name: displayName, Recommended: recommended[displayName], Description: recDesc[displayName]}
items = append(items, item)
}
for _, rec := range recommendedModels {
if existingModels[rec.Name] || existingModels[rec.Name+":latest"] {
continue
}
items = append(items, rec)
if isCloudModelName(rec.Name) {
cloudModels[rec.Name] = true
}
}
checked := make(map[string]bool, len(preChecked))
for _, n := range preChecked {
checked[n] = true
}
if current != "" {
matchedCurrent := false
for _, item := range items {
if item.Name == current {
current = item.Name
matchedCurrent = true
break
}
}
if !matchedCurrent {
for _, item := range items {
if strings.HasPrefix(item.Name, current+":") {
current = item.Name
break
}
}
}
}
if checked[current] {
preChecked = append([]string{current}, slices.DeleteFunc(preChecked, func(m string) bool { return m == current })...)
}
notInstalled := make(map[string]bool)
for i := range items {
if !existingModels[items[i].Name] && !cloudModels[items[i].Name] {
notInstalled[items[i].Name] = true
var parts []string
if items[i].Description != "" {
parts = append(parts, items[i].Description)
}
if vram := recommendedVRAM[items[i].Name]; vram != "" {
parts = append(parts, vram)
}
parts = append(parts, "(not downloaded)")
items[i].Description = strings.Join(parts, ", ")
}
}
recRank := make(map[string]int)
for i, rec := range recommendedModels {
recRank[rec.Name] = i + 1
}
onlyLocal := hasLocalModel && !hasCloudModel
if hasLocalModel || hasCloudModel {
slices.SortStableFunc(items, func(a, b ModelItem) int {
ac, bc := checked[a.Name], checked[b.Name]
aNew, bNew := notInstalled[a.Name], notInstalled[b.Name]
aRec, bRec := recRank[a.Name] > 0, recRank[b.Name] > 0
aCloud, bCloud := cloudModels[a.Name], cloudModels[b.Name]
if ac != bc {
if ac {
return -1
}
return 1
}
if aRec != bRec {
if aRec {
return -1
}
return 1
}
if aRec && bRec {
if aCloud != bCloud {
if onlyLocal {
if aCloud {
return 1
}
return -1
}
if aCloud {
return -1
}
return 1
}
return recRank[a.Name] - recRank[b.Name]
}
if aNew != bNew {
if aNew {
return 1
}
return -1
}
return strings.Compare(strings.ToLower(a.Name), strings.ToLower(b.Name))
})
}
return items, preChecked, existingModels, cloudModels
}
// isCloudModelName reports whether the model name has an explicit cloud source.
func isCloudModelName(name string) bool {
return modelref.HasExplicitCloudSource(name)
}
// filterCloudModels drops remote-only models from the given inventory.
func filterCloudModels(existing []modelInfo) []modelInfo {
filtered := existing[:0]
for _, m := range existing {
if !m.Remote {
filtered = append(filtered, m)
}
}
return filtered
}
// filterCloudItems removes cloud models from selection items.
func filterCloudItems(items []ModelItem) []ModelItem {
filtered := items[:0]
for _, item := range items {
if !isCloudModelName(item.Name) {
filtered = append(filtered, item)
}
}
return filtered
}
func isCloudModel(ctx context.Context, client *api.Client, name string) bool {
if client == nil {
return false
}
resp, err := client.Show(ctx, &api.ShowRequest{Model: name})
if err != nil {
return false
}
return resp.RemoteModel != ""
}
// cloudStatusDisabled returns whether cloud usage is currently disabled.
func cloudStatusDisabled(ctx context.Context, client *api.Client) (disabled bool, known bool) {
status, err := client.CloudStatusExperimental(ctx)
if err != nil {
var statusErr api.StatusError
if errors.As(err, &statusErr) && statusErr.StatusCode == http.StatusNotFound {
return false, false
}
return false, false
}
return status.Cloud.Disabled, true
}
// TODO(parthsareen): this duplicates the pull progress UI in cmd.PullHandler.
// Move the shared pull rendering to a small utility once the package boundary settles.
func pullModel(ctx context.Context, client *api.Client, model string, insecure bool) error {
p := progress.NewProgress(os.Stderr)
defer p.Stop()
bars := make(map[string]*progress.Bar)
var status string
var spinner *progress.Spinner
fn := func(resp api.ProgressResponse) error {
if resp.Digest != "" {
if resp.Completed == 0 {
return nil
}
if spinner != nil {
spinner.Stop()
}
bar, ok := bars[resp.Digest]
if !ok {
name, isDigest := strings.CutPrefix(resp.Digest, "sha256:")
name = strings.TrimSpace(name)
if isDigest {
name = name[:min(12, len(name))]
}
bar = progress.NewBar(fmt.Sprintf("pulling %s:", name), resp.Total, resp.Completed)
bars[resp.Digest] = bar
p.Add(resp.Digest, bar)
}
bar.Set(resp.Completed)
} else if status != resp.Status {
if spinner != nil {
spinner.Stop()
}
status = resp.Status
spinner = progress.NewSpinner(status)
p.Add(status, spinner)
}
return nil
}
request := api.PullRequest{Name: model, Insecure: insecure}
return client.Pull(ctx, &request, fn)
}

890
cmd/launch/openclaw.go Normal file
View File

@@ -0,0 +1,890 @@
package launch
import (
"context"
"encoding/json"
"fmt"
"net"
"net/url"
"os"
"os/exec"
"path/filepath"
"runtime"
"slices"
"strings"
"time"
"golang.org/x/mod/semver"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/cmd/internal/fileutil"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/types/model"
)
const defaultGatewayPort = 18789
// Bound model capability probing so launch/config cannot hang on slow/unreachable API calls.
var openclawModelShowTimeout = 5 * time.Second
// openclawFreshInstall is set to true when ensureOpenclawInstalled performs an install
var openclawFreshInstall bool
type Openclaw struct{}
func (c *Openclaw) String() string { return "OpenClaw" }
func (c *Openclaw) Run(model string, args []string) error {
bin, err := ensureOpenclawInstalled()
if err != nil {
return err
}
firstLaunch := !c.onboarded()
if firstLaunch {
fmt.Fprintf(os.Stderr, "\n%sSecurity%s\n\n", ansiBold, ansiReset)
fmt.Fprintf(os.Stderr, " OpenClaw can read files and run actions when tools are enabled.\n")
fmt.Fprintf(os.Stderr, " A bad prompt can trick it into doing unsafe things.\n\n")
fmt.Fprintf(os.Stderr, "%s Learn more: https://docs.openclaw.ai/gateway/security%s\n\n", ansiGray, ansiReset)
ok, err := ConfirmPrompt("I understand the risks. Continue?")
if err != nil {
return err
}
if !ok {
return nil
}
// Ensure the latest version is installed before onboarding so we get
// the newest wizard flags (e.g. --auth-choice ollama).
if !openclawFreshInstall {
update := exec.Command(bin, "update")
update.Stdout = os.Stdout
update.Stderr = os.Stderr
_ = update.Run() // best-effort; continue even if update fails
}
fmt.Fprintf(os.Stderr, "\n%sSetting up OpenClaw with Ollama...%s\n", ansiGreen, ansiReset)
fmt.Fprintf(os.Stderr, "%s Model: %s%s\n\n", ansiGray, model, ansiReset)
onboardArgs := []string{
"onboard",
"--non-interactive",
"--accept-risk",
"--auth-choice", "ollama",
"--custom-base-url", envconfig.Host().String(),
"--custom-model-id", model,
"--skip-channels",
"--skip-skills",
}
if canInstallDaemon() {
onboardArgs = append(onboardArgs, "--install-daemon")
}
cmd := exec.Command(bin, onboardArgs...)
cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
if err := cmd.Run(); err != nil {
return windowsHint(fmt.Errorf("openclaw onboarding failed: %w\n\nTry running: openclaw onboard", err))
}
patchDeviceScopes()
}
if ensureWebSearchPlugin() {
registerWebSearchPlugin()
}
fmt.Fprintf(os.Stderr, "\n%sStarting your assistant — this may take a moment...%s\n\n", ansiGray, ansiReset)
// When extra args are passed through, run exactly what the user asked for
// after setup and skip the built-in gateway+TUI convenience flow.
if len(args) > 0 {
cmd := exec.Command(bin, args...)
cmd.Env = openclawEnv()
cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
if err := cmd.Run(); err != nil {
return windowsHint(err)
}
return nil
}
token, port := c.gatewayInfo()
addr := fmt.Sprintf("localhost:%d", port)
// If the gateway is already running (e.g. via the daemon), restart it
// so it picks up any config changes (model, provider, etc.).
if portOpen(addr) {
restart := exec.Command(bin, "daemon", "restart")
restart.Env = openclawEnv()
if err := restart.Run(); err != nil {
fmt.Fprintf(os.Stderr, "%s Warning: daemon restart failed: %v%s\n", ansiYellow, err, ansiReset)
}
if !waitForPort(addr, 10*time.Second) {
fmt.Fprintf(os.Stderr, "%s Warning: gateway did not come back after restart%s\n", ansiYellow, ansiReset)
}
}
// If the gateway isn't running, start it as a background child process.
if !portOpen(addr) {
gw := exec.Command(bin, "gateway", "run", "--force")
gw.Env = openclawEnv()
if err := gw.Start(); err != nil {
return windowsHint(fmt.Errorf("failed to start gateway: %w", err))
}
defer func() {
if gw.Process != nil {
_ = gw.Process.Kill()
_ = gw.Wait()
}
}()
}
fmt.Fprintf(os.Stderr, "%sStarting gateway...%s\n", ansiGray, ansiReset)
if !waitForPort(addr, 30*time.Second) {
return windowsHint(fmt.Errorf("gateway did not start on %s", addr))
}
printOpenclawReady(bin, token, port, firstLaunch)
tuiArgs := []string{"tui"}
if firstLaunch {
tuiArgs = append(tuiArgs, "--message", "Wake up, my friend!")
}
tui := exec.Command(bin, tuiArgs...)
tui.Env = openclawEnv()
tui.Stdin = os.Stdin
tui.Stdout = os.Stdout
tui.Stderr = os.Stderr
if err := tui.Run(); err != nil {
return windowsHint(err)
}
return nil
}
// gatewayInfo reads the gateway auth token and port from the OpenClaw config.
func (c *Openclaw) gatewayInfo() (token string, port int) {
port = defaultGatewayPort
home, err := os.UserHomeDir()
if err != nil {
return "", port
}
for _, path := range []string{
filepath.Join(home, ".openclaw", "openclaw.json"),
filepath.Join(home, ".clawdbot", "clawdbot.json"),
} {
data, err := os.ReadFile(path)
if err != nil {
continue
}
var config map[string]any
if json.Unmarshal(data, &config) != nil {
continue
}
gw, _ := config["gateway"].(map[string]any)
if p, ok := gw["port"].(float64); ok && p > 0 {
port = int(p)
}
auth, _ := gw["auth"].(map[string]any)
if t, _ := auth["token"].(string); t != "" {
token = t
}
return token, port
}
return "", port
}
func printOpenclawReady(bin, token string, port int, firstLaunch bool) {
u := fmt.Sprintf("http://localhost:%d", port)
if token != "" {
u += "/#token=" + url.QueryEscape(token)
}
fmt.Fprintf(os.Stderr, "\n%s✓ OpenClaw is running%s\n\n", ansiGreen, ansiReset)
fmt.Fprintf(os.Stderr, " Open the Web UI:\n")
fmt.Fprintf(os.Stderr, " %s\n\n", hyperlink(u, u))
if firstLaunch {
fmt.Fprintf(os.Stderr, "%s Quick start:%s\n", ansiBold, ansiReset)
fmt.Fprintf(os.Stderr, "%s /help see all commands%s\n", ansiGray, ansiReset)
fmt.Fprintf(os.Stderr, "%s %s configure --section channels connect WhatsApp, Telegram, etc.%s\n", ansiGray, bin, ansiReset)
fmt.Fprintf(os.Stderr, "%s %s skills browse and install skills%s\n\n", ansiGray, bin, ansiReset)
fmt.Fprintf(os.Stderr, "%s The OpenClaw gateway is running in the background.%s\n", ansiYellow, ansiReset)
fmt.Fprintf(os.Stderr, "%s Stop it with: %s gateway stop%s\n\n", ansiYellow, bin, ansiReset)
} else {
fmt.Fprintf(os.Stderr, "%sTip: connect WhatsApp, Telegram, and more with: %s configure --section channels%s\n", ansiGray, bin, ansiReset)
}
}
// openclawEnv returns the current environment with provider API keys cleared
// so openclaw only uses the Ollama gateway, not keys from the user's shell.
func openclawEnv() []string {
clear := map[string]bool{
"ANTHROPIC_API_KEY": true,
"ANTHROPIC_OAUTH_TOKEN": true,
"OPENAI_API_KEY": true,
"GEMINI_API_KEY": true,
"MISTRAL_API_KEY": true,
"GROQ_API_KEY": true,
"XAI_API_KEY": true,
"OPENROUTER_API_KEY": true,
}
var env []string
for _, e := range os.Environ() {
key, _, _ := strings.Cut(e, "=")
if !clear[key] {
env = append(env, e)
}
}
return env
}
// portOpen checks if a TCP port is currently accepting connections.
func portOpen(addr string) bool {
conn, err := net.DialTimeout("tcp", addr, 500*time.Millisecond)
if err != nil {
return false
}
conn.Close()
return true
}
func waitForPort(addr string, timeout time.Duration) bool {
deadline := time.Now().Add(timeout)
for time.Now().Before(deadline) {
conn, err := net.DialTimeout("tcp", addr, 500*time.Millisecond)
if err == nil {
conn.Close()
return true
}
time.Sleep(250 * time.Millisecond)
}
return false
}
func windowsHint(err error) error {
if runtime.GOOS != "windows" {
return err
}
return fmt.Errorf("%w\n\n"+
"OpenClaw runs best on WSL2.\n"+
"Quick setup: wsl --install\n"+
"Guide: https://docs.openclaw.ai/windows", err)
}
// onboarded checks if OpenClaw onboarding wizard was completed
// by looking for the wizard.lastRunAt marker in the config
func (c *Openclaw) onboarded() bool {
home, err := os.UserHomeDir()
if err != nil {
return false
}
configPath := filepath.Join(home, ".openclaw", "openclaw.json")
legacyPath := filepath.Join(home, ".clawdbot", "clawdbot.json")
config := make(map[string]any)
if data, err := os.ReadFile(configPath); err == nil {
_ = json.Unmarshal(data, &config)
} else if data, err := os.ReadFile(legacyPath); err == nil {
_ = json.Unmarshal(data, &config)
} else {
return false
}
// Check for wizard.lastRunAt marker (set when onboarding completes)
wizard, _ := config["wizard"].(map[string]any)
if wizard == nil {
return false
}
lastRunAt, _ := wizard["lastRunAt"].(string)
return lastRunAt != ""
}
// patchDeviceScopes upgrades the local CLI device's paired scopes to include
// operator.admin. Only patches the local device, not remote ones.
// Best-effort: silently returns on any error.
func patchDeviceScopes() {
home, err := os.UserHomeDir()
if err != nil {
return
}
deviceID := readLocalDeviceID(home)
if deviceID == "" {
return
}
path := filepath.Join(home, ".openclaw", "devices", "paired.json")
data, err := os.ReadFile(path)
if err != nil {
return
}
var devices map[string]map[string]any
if err := json.Unmarshal(data, &devices); err != nil {
return
}
dev, ok := devices[deviceID]
if !ok {
return
}
required := []string{
"operator.read",
"operator.admin",
"operator.approvals",
"operator.pairing",
}
changed := patchScopes(dev, "scopes", required)
if tokens, ok := dev["tokens"].(map[string]any); ok {
for _, tok := range tokens {
if tokenMap, ok := tok.(map[string]any); ok {
if patchScopes(tokenMap, "scopes", required) {
changed = true
}
}
}
}
if !changed {
return
}
out, err := json.MarshalIndent(devices, "", " ")
if err != nil {
return
}
_ = os.WriteFile(path, out, 0o600)
}
// readLocalDeviceID reads the local device ID from openclaw's identity file.
func readLocalDeviceID(home string) string {
data, err := os.ReadFile(filepath.Join(home, ".openclaw", "identity", "device-auth.json"))
if err != nil {
return ""
}
var auth map[string]any
if err := json.Unmarshal(data, &auth); err != nil {
return ""
}
id, _ := auth["deviceId"].(string)
return id
}
// patchScopes ensures obj[key] contains all required scopes. Returns true if
// any scopes were added.
func patchScopes(obj map[string]any, key string, required []string) bool {
existing, _ := obj[key].([]any)
have := make(map[string]bool, len(existing))
for _, s := range existing {
if str, ok := s.(string); ok {
have[str] = true
}
}
added := false
for _, s := range required {
if !have[s] {
existing = append(existing, s)
added = true
}
}
if added {
obj[key] = existing
}
return added
}
// canInstallDaemon reports whether the openclaw daemon can be installed as a
// background service. Returns false on Linux when systemd is absent (e.g.
// containers) so that --install-daemon is omitted and the gateway is started
// as a foreground child process instead. Returns true in all other cases.
func canInstallDaemon() bool {
if runtime.GOOS != "linux" {
return true
}
// /run/systemd/system exists as a directory when systemd is the init system.
// This is absent in most containers.
fi, err := os.Stat("/run/systemd/system")
if err != nil || !fi.IsDir() {
return false
}
// Even when systemd is the init system, user services require a user
// manager instance. XDG_RUNTIME_DIR being set is a prerequisite.
return os.Getenv("XDG_RUNTIME_DIR") != ""
}
func ensureOpenclawInstalled() (string, error) {
if _, err := exec.LookPath("openclaw"); err == nil {
return "openclaw", nil
}
if _, err := exec.LookPath("clawdbot"); err == nil {
return "clawdbot", nil
}
if _, err := exec.LookPath("npm"); err != nil {
return "", fmt.Errorf("openclaw is not installed and npm was not found\n\n" +
"Install Node.js first:\n" +
" https://nodejs.org/\n\n" +
"Then rerun:\n" +
" ollama launch\n" +
"and select OpenClaw")
}
ok, err := ConfirmPrompt("OpenClaw is not installed. Install with npm?")
if err != nil {
return "", err
}
if !ok {
return "", fmt.Errorf("openclaw installation cancelled")
}
fmt.Fprintf(os.Stderr, "\nInstalling OpenClaw...\n")
cmd := exec.Command("npm", "install", "-g", "openclaw@latest")
cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
if err := cmd.Run(); err != nil {
return "", fmt.Errorf("failed to install openclaw: %w", err)
}
if _, err := exec.LookPath("openclaw"); err != nil {
return "", fmt.Errorf("openclaw was installed but the binary was not found on PATH\n\nYou may need to restart your shell")
}
fmt.Fprintf(os.Stderr, "%sOpenClaw installed successfully%s\n\n", ansiGreen, ansiReset)
openclawFreshInstall = true
return "openclaw", nil
}
func (c *Openclaw) Paths() []string {
home, _ := os.UserHomeDir()
p := filepath.Join(home, ".openclaw", "openclaw.json")
if _, err := os.Stat(p); err == nil {
return []string{p}
}
legacy := filepath.Join(home, ".clawdbot", "clawdbot.json")
if _, err := os.Stat(legacy); err == nil {
return []string{legacy}
}
return nil
}
func (c *Openclaw) Edit(models []string) error {
if len(models) == 0 {
return nil
}
home, err := os.UserHomeDir()
if err != nil {
return err
}
configPath := filepath.Join(home, ".openclaw", "openclaw.json")
legacyPath := filepath.Join(home, ".clawdbot", "clawdbot.json")
if err := os.MkdirAll(filepath.Dir(configPath), 0o755); err != nil {
return err
}
// Read into map[string]any to preserve unknown fields
config := make(map[string]any)
if data, err := os.ReadFile(configPath); err == nil {
_ = json.Unmarshal(data, &config)
} else if data, err := os.ReadFile(legacyPath); err == nil {
_ = json.Unmarshal(data, &config)
}
// Navigate/create: models.providers.ollama (preserving other providers)
modelsSection, _ := config["models"].(map[string]any)
if modelsSection == nil {
modelsSection = make(map[string]any)
}
providers, _ := modelsSection["providers"].(map[string]any)
if providers == nil {
providers = make(map[string]any)
}
ollama, _ := providers["ollama"].(map[string]any)
if ollama == nil {
ollama = make(map[string]any)
}
ollama["baseUrl"] = envconfig.Host().String()
// needed to register provider
ollama["apiKey"] = "ollama-local"
ollama["api"] = "ollama"
// Build map of existing models to preserve user customizations
existingModels, _ := ollama["models"].([]any)
existingByID := make(map[string]map[string]any)
for _, m := range existingModels {
if entry, ok := m.(map[string]any); ok {
if id, ok := entry["id"].(string); ok {
existingByID[id] = entry
}
}
}
client, _ := api.ClientFromEnvironment()
var newModels []any
for _, m := range models {
entry, _ := openclawModelConfig(context.Background(), client, m)
// Merge existing fields (user customizations)
if existing, ok := existingByID[m]; ok {
for k, v := range existing {
if _, isNew := entry[k]; !isNew {
entry[k] = v
}
}
}
newModels = append(newModels, entry)
}
ollama["models"] = newModels
providers["ollama"] = ollama
modelsSection["providers"] = providers
config["models"] = modelsSection
// Update agents.defaults.model.primary (preserving other agent settings)
agents, _ := config["agents"].(map[string]any)
if agents == nil {
agents = make(map[string]any)
}
defaults, _ := agents["defaults"].(map[string]any)
if defaults == nil {
defaults = make(map[string]any)
}
modelConfig, _ := defaults["model"].(map[string]any)
if modelConfig == nil {
modelConfig = make(map[string]any)
}
modelConfig["primary"] = "ollama/" + models[0]
defaults["model"] = modelConfig
agents["defaults"] = defaults
config["agents"] = agents
data, err := json.MarshalIndent(config, "", " ")
if err != nil {
return err
}
if err := fileutil.WriteWithBackup(configPath, data); err != nil {
return err
}
// Clear any per-session model overrides so the new primary takes effect
// immediately rather than being shadowed by a cached modelOverride.
clearSessionModelOverride(models[0])
return nil
}
// clearSessionModelOverride removes per-session model overrides from the main
// agent session so the global primary model takes effect on the next TUI launch.
func clearSessionModelOverride(primary string) {
home, err := os.UserHomeDir()
if err != nil {
return
}
path := filepath.Join(home, ".openclaw", "agents", "main", "sessions", "sessions.json")
data, err := os.ReadFile(path)
if err != nil {
return
}
var sessions map[string]map[string]any
if json.Unmarshal(data, &sessions) != nil {
return
}
changed := false
for _, sess := range sessions {
if override, _ := sess["modelOverride"].(string); override != "" && override != primary {
delete(sess, "modelOverride")
delete(sess, "providerOverride")
sess["model"] = primary
changed = true
}
}
if !changed {
return
}
out, err := json.MarshalIndent(sessions, "", " ")
if err != nil {
return
}
_ = os.WriteFile(path, out, 0o600)
}
const (
webSearchNpmPackage = "@ollama/openclaw-web-search"
webSearchMinVersion = "0.2.1"
)
// ensureWebSearchPlugin installs the openclaw-web-search extension into the
// user-level extensions directory (~/.openclaw/extensions/) if it isn't already
// present, or re-installs if the installed version is older than webSearchMinVersion.
// Returns true if the extension is available.
func ensureWebSearchPlugin() bool {
home, err := os.UserHomeDir()
if err != nil {
return false
}
pluginDir := filepath.Join(home, ".openclaw", "extensions", "openclaw-web-search")
if webSearchPluginUpToDate(pluginDir) {
return true
}
npmBin, err := exec.LookPath("npm")
if err != nil {
return false
}
if err := os.MkdirAll(pluginDir, 0o755); err != nil {
return false
}
// Download the tarball via `npm pack`, extract it flat into the plugin dir.
pack := exec.Command(npmBin, "pack", webSearchNpmPackage, "--pack-destination", pluginDir)
out, err := pack.Output()
if err != nil {
fmt.Fprintf(os.Stderr, "%s Warning: could not download web search plugin: %v%s\n", ansiYellow, err, ansiReset)
return false
}
tgzName := strings.TrimSpace(string(out))
tgzPath := filepath.Join(pluginDir, tgzName)
defer os.Remove(tgzPath)
tar := exec.Command("tar", "xzf", tgzPath, "--strip-components=1", "-C", pluginDir)
if err := tar.Run(); err != nil {
fmt.Fprintf(os.Stderr, "%s Warning: could not extract web search plugin: %v%s\n", ansiYellow, err, ansiReset)
return false
}
fmt.Fprintf(os.Stderr, "%s ✓ Installed web search plugin%s\n", ansiGreen, ansiReset)
return true
}
// webSearchPluginUpToDate returns true if the plugin is installed and its
// package.json version is >= webSearchMinVersion.
func webSearchPluginUpToDate(pluginDir string) bool {
data, err := os.ReadFile(filepath.Join(pluginDir, "package.json"))
if err != nil {
return false
}
var pkg struct {
Version string `json:"version"`
}
if json.Unmarshal(data, &pkg) != nil || pkg.Version == "" {
return false
}
return !versionLessThan(pkg.Version, webSearchMinVersion)
}
// versionLessThan compares two semver version strings (major.minor.patch).
// Inputs may omit the "v" prefix; it is added automatically for semver.Compare.
func versionLessThan(a, b string) bool {
if !strings.HasPrefix(a, "v") {
a = "v" + a
}
if !strings.HasPrefix(b, "v") {
b = "v" + b
}
return semver.Compare(a, b) < 0
}
// registerWebSearchPlugin adds plugins.entries.openclaw-web-search to the OpenClaw
// config so the gateway activates it on next start. Best-effort; silently returns
// on any error.
func registerWebSearchPlugin() {
home, err := os.UserHomeDir()
if err != nil {
return
}
configPath := filepath.Join(home, ".openclaw", "openclaw.json")
data, err := os.ReadFile(configPath)
if err != nil {
return
}
var config map[string]any
if json.Unmarshal(data, &config) != nil {
return
}
plugins, _ := config["plugins"].(map[string]any)
if plugins == nil {
plugins = make(map[string]any)
}
entries, _ := plugins["entries"].(map[string]any)
if entries == nil {
entries = make(map[string]any)
}
entries["openclaw-web-search"] = map[string]any{"enabled": true}
plugins["entries"] = entries
// Pin trust so the gateway doesn't warn about untracked plugins.
allow, _ := plugins["allow"].([]any)
hasAllow := false
for _, v := range allow {
if s, ok := v.(string); ok && s == "openclaw-web-search" {
hasAllow = true
break
}
}
if !hasAllow {
allow = append(allow, "openclaw-web-search")
}
plugins["allow"] = allow
// Record install provenance so the loader can verify the plugin origin.
installs, _ := plugins["installs"].(map[string]any)
if installs == nil {
installs = make(map[string]any)
}
pluginDir := filepath.Join(home, ".openclaw", "extensions", "openclaw-web-search")
installs["openclaw-web-search"] = map[string]any{
"source": "npm",
"spec": webSearchNpmPackage,
"installPath": pluginDir,
}
plugins["installs"] = installs
config["plugins"] = plugins
// Add plugin tools to tools.alsoAllow so they survive the coding profile's
// policy pipeline (which has an explicit allow list of core tools only).
tools, _ := config["tools"].(map[string]any)
if tools == nil {
tools = make(map[string]any)
}
alsoAllow, _ := tools["alsoAllow"].([]any)
needed := []string{"ollama_web_search", "ollama_web_fetch"}
have := make(map[string]bool, len(alsoAllow))
for _, v := range alsoAllow {
if s, ok := v.(string); ok {
have[s] = true
}
}
for _, name := range needed {
if !have[name] {
alsoAllow = append(alsoAllow, name)
}
}
tools["alsoAllow"] = alsoAllow
// Disable built-in web search/fetch since our plugin replaces them.
web, _ := tools["web"].(map[string]any)
if web == nil {
web = make(map[string]any)
}
web["search"] = map[string]any{"enabled": false}
web["fetch"] = map[string]any{"enabled": false}
tools["web"] = web
config["tools"] = tools
out, err := json.MarshalIndent(config, "", " ")
if err != nil {
return
}
_ = os.WriteFile(configPath, out, 0o600)
}
// openclawModelConfig builds an OpenClaw model config entry with capability detection.
// The second return value indicates whether the model is a cloud (remote) model.
func openclawModelConfig(ctx context.Context, client *api.Client, modelID string) (map[string]any, bool) {
entry := map[string]any{
"id": modelID,
"name": modelID,
"input": []any{"text"},
"cost": map[string]any{
"input": 0,
"output": 0,
"cacheRead": 0,
"cacheWrite": 0,
},
}
if client == nil {
return entry, false
}
showCtx := ctx
if _, hasDeadline := ctx.Deadline(); !hasDeadline {
var cancel context.CancelFunc
showCtx, cancel = context.WithTimeout(ctx, openclawModelShowTimeout)
defer cancel()
}
resp, err := client.Show(showCtx, &api.ShowRequest{Model: modelID})
if err != nil {
return entry, false
}
// Set input types based on vision capability
if slices.Contains(resp.Capabilities, model.CapabilityVision) {
entry["input"] = []any{"text", "image"}
}
// Set reasoning based on thinking capability
if slices.Contains(resp.Capabilities, model.CapabilityThinking) {
entry["reasoning"] = true
}
// Cloud models: use hardcoded limits for context/output tokens.
// Capability detection above still applies (vision, thinking).
if resp.RemoteModel != "" {
if l, ok := lookupCloudModelLimit(modelID); ok {
entry["contextWindow"] = l.Context
entry["maxTokens"] = l.Output
}
return entry, true
}
// Extract context window from ModelInfo (local models only)
for key, val := range resp.ModelInfo {
if strings.HasSuffix(key, ".context_length") {
if ctxLen, ok := val.(float64); ok && ctxLen > 0 {
entry["contextWindow"] = int(ctxLen)
}
break
}
}
return entry, false
}
func (c *Openclaw) Models() []string {
home, err := os.UserHomeDir()
if err != nil {
return nil
}
config, err := fileutil.ReadJSON(filepath.Join(home, ".openclaw", "openclaw.json"))
if err != nil {
config, err = fileutil.ReadJSON(filepath.Join(home, ".clawdbot", "clawdbot.json"))
if err != nil {
return nil
}
}
modelsSection, _ := config["models"].(map[string]any)
providers, _ := modelsSection["providers"].(map[string]any)
ollama, _ := providers["ollama"].(map[string]any)
modelList, _ := ollama["models"].([]any)
var result []string
for _, m := range modelList {
if entry, ok := m.(map[string]any); ok {
if id, ok := entry["id"].(string); ok {
result = append(result, id)
}
}
}
return result
}

1770
cmd/launch/openclaw_test.go Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -1,9 +1,7 @@
package config
package launch
import (
"context"
"encoding/json"
"errors"
"fmt"
"maps"
"os"
@@ -12,34 +10,13 @@ import (
"slices"
"strings"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/cmd/internal/fileutil"
"github.com/ollama/ollama/envconfig"
)
// OpenCode implements Runner and Editor for OpenCode integration
type OpenCode struct{}
// cloudModelLimit holds context and output token limits for a cloud model.
type cloudModelLimit struct {
Context int
Output int
}
// lookupCloudModelLimit returns the token limits for a cloud model.
// It tries the exact name first, then strips the ":cloud" suffix.
func lookupCloudModelLimit(name string) (cloudModelLimit, bool) {
if l, ok := cloudModelLimits[name]; ok {
return l, true
}
base := strings.TrimSuffix(name, ":cloud")
if base != name {
if l, ok := cloudModelLimits[base]; ok {
return l, true
}
}
return cloudModelLimit{}, false
}
func (o *OpenCode) String() string { return "OpenCode" }
func (o *OpenCode) Run(model string, args []string) error {
@@ -47,25 +24,6 @@ func (o *OpenCode) Run(model string, args []string) error {
return fmt.Errorf("opencode is not installed, install from https://opencode.ai")
}
// Call Edit() to ensure config is up-to-date before launch
models := []string{model}
if config, err := loadIntegration("opencode"); err == nil && len(config.Models) > 0 {
models = config.Models
}
var err error
models, err = resolveEditorModels("opencode", models, func() ([]string, error) {
return selectModels(context.Background(), "opencode", "")
})
if errors.Is(err, errCancelled) {
return nil
}
if err != nil {
return err
}
if err := o.Edit(models); err != nil {
return fmt.Errorf("setup failed: %w", err)
}
cmd := exec.Command("opencode", args...)
cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout
@@ -122,13 +80,18 @@ func (o *OpenCode) Edit(modelList []string) error {
if !ok {
ollama = map[string]any{
"npm": "@ai-sdk/openai-compatible",
"name": "Ollama (local)",
"name": "Ollama",
"options": map[string]any{
"baseURL": envconfig.Host().String() + "/v1",
},
}
}
// Migrate legacy provider name
if name, _ := ollama["name"].(string); name == "Ollama (local)" {
ollama["name"] = "Ollama"
}
models, ok := ollama["models"].(map[string]any)
if !ok {
models = make(map[string]any)
@@ -147,8 +110,6 @@ func (o *OpenCode) Edit(modelList []string) error {
}
}
client, _ := api.ClientFromEnvironment()
for _, model := range modelList {
if existing, ok := models[model].(map[string]any); ok {
// migrate existing models without _launch marker
@@ -158,7 +119,7 @@ func (o *OpenCode) Edit(modelList []string) error {
existing["name"] = strings.TrimSuffix(name, " [Ollama]")
}
}
if isCloudModel(context.Background(), client, model) {
if isCloudModelName(model) {
if l, ok := lookupCloudModelLimit(model); ok {
existing["limit"] = map[string]any{
"context": l.Context,
@@ -172,7 +133,7 @@ func (o *OpenCode) Edit(modelList []string) error {
"name": model,
"_launch": true,
}
if isCloudModel(context.Background(), client, model) {
if isCloudModelName(model) {
if l, ok := lookupCloudModelLimit(model); ok {
entry["limit"] = map[string]any{
"context": l.Context,
@@ -191,7 +152,7 @@ func (o *OpenCode) Edit(modelList []string) error {
if err != nil {
return err
}
if err := writeWithBackup(configPath, configData); err != nil {
if err := fileutil.WriteWithBackup(configPath, configData); err != nil {
return err
}
@@ -243,7 +204,7 @@ func (o *OpenCode) Edit(modelList []string) error {
if err != nil {
return err
}
return writeWithBackup(statePath, stateData)
return fileutil.WriteWithBackup(statePath, stateData)
}
func (o *OpenCode) Models() []string {
@@ -251,7 +212,7 @@ func (o *OpenCode) Models() []string {
if err != nil {
return nil
}
config, err := readJSONFile(filepath.Join(home, ".config", "opencode", "opencode.json"))
config, err := fileutil.ReadJSON(filepath.Join(home, ".config", "opencode", "opencode.json"))
if err != nil {
return nil
}

View File

@@ -1,8 +1,10 @@
package config
package launch
import (
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"testing"
@@ -232,6 +234,44 @@ func TestOpenCodeEdit(t *testing.T) {
}
})
t.Run("migrate Ollama (local) provider name", func(t *testing.T) {
cleanup()
os.MkdirAll(configDir, 0o755)
os.WriteFile(configPath, []byte(`{"provider":{"ollama":{"name":"Ollama (local)","npm":"@ai-sdk/openai-compatible","options":{"baseURL":"http://localhost:11434/v1"}}}}`), 0o644)
if err := o.Edit([]string{"llama3.2"}); err != nil {
t.Fatal(err)
}
data, _ := os.ReadFile(configPath)
var cfg map[string]any
json.Unmarshal(data, &cfg)
provider := cfg["provider"].(map[string]any)
ollama := provider["ollama"].(map[string]any)
if ollama["name"] != "Ollama" {
t.Errorf("provider name not migrated: got %q, want %q", ollama["name"], "Ollama")
}
})
t.Run("preserve custom provider name", func(t *testing.T) {
cleanup()
os.MkdirAll(configDir, 0o755)
os.WriteFile(configPath, []byte(`{"provider":{"ollama":{"name":"My Custom Ollama","npm":"@ai-sdk/openai-compatible","options":{"baseURL":"http://localhost:11434/v1"}}}}`), 0o644)
if err := o.Edit([]string{"llama3.2"}); err != nil {
t.Fatal(err)
}
data, _ := os.ReadFile(configPath)
var cfg map[string]any
json.Unmarshal(data, &cfg)
provider := cfg["provider"].(map[string]any)
ollama := provider["ollama"].(map[string]any)
if ollama["name"] != "My Custom Ollama" {
t.Errorf("custom provider name was changed: got %q, want %q", ollama["name"], "My Custom Ollama")
}
})
t.Run("remove model preserves non-ollama models", func(t *testing.T) {
cleanup()
os.MkdirAll(configDir, 0o755)
@@ -619,6 +659,54 @@ func TestOpenCodeEdit_CloudModelLimitStructure(t *testing.T) {
}
}
func TestOpenCodeEdit_BackfillsCloudModelLimitOnExistingEntry(t *testing.T) {
o := &OpenCode{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/api/show" {
fmt.Fprintf(w, `{"capabilities":[],"model_info":{},"remote_model":"glm-5"}`)
return
}
w.WriteHeader(http.StatusNotFound)
}))
defer srv.Close()
t.Setenv("OLLAMA_HOST", srv.URL)
configDir := filepath.Join(tmpDir, ".config", "opencode")
configPath := filepath.Join(configDir, "opencode.json")
os.MkdirAll(configDir, 0o755)
os.WriteFile(configPath, []byte(`{
"provider": {
"ollama": {
"models": {
"glm-5:cloud": {
"name": "glm-5:cloud",
"_launch": true
}
}
}
}
}`), 0o644)
if err := o.Edit([]string{"glm-5:cloud"}); err != nil {
t.Fatal(err)
}
entry := readOpenCodeModel(t, configPath, "glm-5:cloud")
limit, ok := entry["limit"].(map[string]any)
if !ok {
t.Fatal("cloud model limit was not added on re-edit")
}
if limit["context"] != float64(202_752) {
t.Errorf("context = %v, want 202752", limit["context"])
}
if limit["output"] != float64(131_072) {
t.Errorf("output = %v, want 131072", limit["output"])
}
}
func TestLookupCloudModelLimit(t *testing.T) {
tests := []struct {
name string
@@ -626,13 +714,19 @@ func TestLookupCloudModelLimit(t *testing.T) {
wantContext int
wantOutput int
}{
{"glm-4.7", true, 202_752, 131_072},
{"glm-4.7", false, 0, 0},
{"glm-4.7:cloud", true, 202_752, 131_072},
{"kimi-k2.5", true, 262_144, 262_144},
{"glm-5:cloud", true, 202_752, 131_072},
{"gpt-oss:120b-cloud", true, 131_072, 131_072},
{"gpt-oss:20b-cloud", true, 131_072, 131_072},
{"kimi-k2.5", false, 0, 0},
{"kimi-k2.5:cloud", true, 262_144, 262_144},
{"deepseek-v3.2", true, 163_840, 65_536},
{"deepseek-v3.2", false, 0, 0},
{"deepseek-v3.2:cloud", true, 163_840, 65_536},
{"qwen3-coder:480b", true, 262_144, 65_536},
{"qwen3.5", false, 0, 0},
{"qwen3.5:cloud", true, 262_144, 32_768},
{"qwen3-coder:480b", false, 0, 0},
{"qwen3-coder:480b:cloud", true, 262_144, 65_536},
{"qwen3-coder-next:cloud", true, 262_144, 32_768},
{"llama3.2", false, 0, 0},
{"unknown-model:cloud", false, 0, 0},

View File

@@ -1,4 +1,4 @@
package config
package launch
import (
"context"
@@ -12,6 +12,7 @@ import (
"strings"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/cmd/internal/fileutil"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/types/model"
)
@@ -26,15 +27,6 @@ func (p *Pi) Run(model string, args []string) error {
return fmt.Errorf("pi is not installed, install with: npm install -g @mariozechner/pi-coding-agent")
}
// Call Edit() to ensure config is up-to-date before launch
models := []string{model}
if config, err := loadIntegration("pi"); err == nil && len(config.Models) > 0 {
models = config.Models
}
if err := p.Edit(models); err != nil {
return fmt.Errorf("setup failed: %w", err)
}
cmd := exec.Command("pi", args...)
cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout
@@ -107,7 +99,8 @@ func (p *Pi) Edit(models []string) error {
// Build new models list:
// 1. Keep user-managed models (no _launch marker) - untouched
// 2. Keep ollama-managed models (_launch marker) that are still selected
// 2. Keep ollama-managed models (_launch marker) that are still selected,
// except stale cloud entries that should be rebuilt below
// 3. Add new ollama-managed models
var newModels []any
for _, m := range existingModels {
@@ -117,7 +110,13 @@ func (p *Pi) Edit(models []string) error {
if !isPiOllamaModel(modelObj) {
newModels = append(newModels, m)
} else if selectedSet[id] {
// Ollama-managed and still selected - keep it
// Rebuild stale managed cloud entries so createConfig refreshes
// the whole entry instead of patching it in place.
if !hasContextWindow(modelObj) {
if _, ok := lookupCloudModelLimit(id); ok {
continue
}
}
newModels = append(newModels, m)
selectedSet[id] = false
}
@@ -142,7 +141,7 @@ func (p *Pi) Edit(models []string) error {
if err != nil {
return err
}
if err := writeWithBackup(configPath, configData); err != nil {
if err := fileutil.WriteWithBackup(configPath, configData); err != nil {
return err
}
@@ -160,7 +159,7 @@ func (p *Pi) Edit(models []string) error {
if err != nil {
return err
}
return writeWithBackup(settingsPath, settingsData)
return fileutil.WriteWithBackup(settingsPath, settingsData)
}
func (p *Pi) Models() []string {
@@ -170,7 +169,7 @@ func (p *Pi) Models() []string {
}
configPath := filepath.Join(home, ".pi", "agent", "models.json")
config, err := readJSONFile(configPath)
config, err := fileutil.ReadJSON(configPath)
if err != nil {
return nil
}
@@ -199,15 +198,38 @@ func isPiOllamaModel(cfg map[string]any) bool {
return false
}
func hasContextWindow(cfg map[string]any) bool {
switch v := cfg["contextWindow"].(type) {
case float64:
return v > 0
case int:
return v > 0
case int64:
return v > 0
default:
return false
}
}
// createConfig builds Pi model config with capability detection
func createConfig(ctx context.Context, client *api.Client, modelID string) map[string]any {
cfg := map[string]any{
"id": modelID,
"_launch": true,
}
if l, ok := lookupCloudModelLimit(modelID); ok {
cfg["contextWindow"] = l.Context
}
applyCloudContextFallback := func() {
if l, ok := lookupCloudModelLimit(modelID); ok {
cfg["contextWindow"] = l.Context
}
}
resp, err := client.Show(ctx, &api.ShowRequest{Model: modelID})
if err != nil {
applyCloudContextFallback()
return cfg
}
@@ -223,15 +245,21 @@ func createConfig(ctx context.Context, client *api.Client, modelID string) map[s
cfg["reasoning"] = true
}
// Extract context window from ModelInfo
// Extract context window from ModelInfo. For known cloud models, the
// pre-filled shared limit remains unless the server provides a positive value.
hasContextWindow := false
for key, val := range resp.ModelInfo {
if strings.HasSuffix(key, ".context_length") {
if ctxLen, ok := val.(float64); ok && ctxLen > 0 {
cfg["contextWindow"] = int(ctxLen)
hasContextWindow = true
}
break
}
}
if !hasContextWindow {
applyCloudContextFallback()
}
return cfg
}

View File

@@ -1,4 +1,4 @@
package config
package launch
import (
"context"
@@ -192,6 +192,48 @@ func TestPiEdit(t *testing.T) {
}
})
t.Run("rebuilds stale existing managed cloud model", func(t *testing.T) {
cleanup()
os.MkdirAll(configDir, 0o755)
existingConfig := `{
"providers": {
"ollama": {
"baseUrl": "http://localhost:11434/v1",
"api": "openai-completions",
"apiKey": "ollama",
"models": [
{"id": "glm-5:cloud", "_launch": true, "legacyField": "stale"}
]
}
}
}`
if err := os.WriteFile(configPath, []byte(existingConfig), 0o644); err != nil {
t.Fatal(err)
}
if err := pi.Edit([]string{"glm-5:cloud"}); err != nil {
t.Fatalf("Edit() error = %v", err)
}
cfg := readConfig()
providers := cfg["providers"].(map[string]any)
ollama := providers["ollama"].(map[string]any)
modelsArray := ollama["models"].([]any)
modelEntry := modelsArray[0].(map[string]any)
if modelEntry["contextWindow"] != float64(202_752) {
t.Errorf("contextWindow = %v, want 202752", modelEntry["contextWindow"])
}
input, ok := modelEntry["input"].([]any)
if !ok || len(input) != 1 || input[0] != "text" {
t.Errorf("input = %v, want [text]", modelEntry["input"])
}
if _, ok := modelEntry["legacyField"]; ok {
t.Error("legacyField should be removed when stale managed cloud entry is rebuilt")
}
})
t.Run("replaces old models with new ones", func(t *testing.T) {
cleanup()
os.MkdirAll(configDir, 0o755)
@@ -798,6 +840,59 @@ func TestCreateConfig(t *testing.T) {
}
})
t.Run("cloud model falls back to hardcoded context when show fails", func(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNotFound)
fmt.Fprintf(w, `{"error":"model not found"}`)
}))
defer srv.Close()
u, _ := url.Parse(srv.URL)
client := api.NewClient(u, srv.Client())
cfg := createConfig(context.Background(), client, "kimi-k2.5:cloud")
if cfg["contextWindow"] != 262_144 {
t.Errorf("contextWindow = %v, want 262144", cfg["contextWindow"])
}
})
t.Run("cloud model falls back to hardcoded context when show omits model info", func(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/api/show" {
fmt.Fprintf(w, `{"capabilities":[],"model_info":{}}`)
return
}
w.WriteHeader(http.StatusNotFound)
}))
defer srv.Close()
u, _ := url.Parse(srv.URL)
client := api.NewClient(u, srv.Client())
cfg := createConfig(context.Background(), client, "glm-5:cloud")
if cfg["contextWindow"] != 202_752 {
t.Errorf("contextWindow = %v, want 202752", cfg["contextWindow"])
}
})
t.Run("cloud model with dash suffix falls back to hardcoded context", func(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNotFound)
fmt.Fprintf(w, `{"error":"model not found"}`)
}))
defer srv.Close()
u, _ := url.Parse(srv.URL)
client := api.NewClient(u, srv.Client())
cfg := createConfig(context.Background(), client, "gpt-oss:120b-cloud")
if cfg["contextWindow"] != 131_072 {
t.Errorf("contextWindow = %v, want 131072", cfg["contextWindow"])
}
})
t.Run("skips zero context length", func(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/api/show" {

355
cmd/launch/registry.go Normal file
View File

@@ -0,0 +1,355 @@
package launch
import (
"fmt"
"os"
"os/exec"
"slices"
"strings"
)
// IntegrationInstallSpec describes how launcher should detect and guide installation.
type IntegrationInstallSpec struct {
CheckInstalled func() bool
EnsureInstalled func() error
URL string
Command []string
}
// IntegrationSpec is the canonical registry entry for one integration.
type IntegrationSpec struct {
Name string
Runner Runner
Aliases []string
Hidden bool
Description string
Install IntegrationInstallSpec
}
// IntegrationInfo contains display information about a registered integration.
type IntegrationInfo struct {
Name string
DisplayName string
Description string
}
var launcherIntegrationOrder = []string{"opencode", "droid", "pi", "cline"}
var integrationSpecs = []*IntegrationSpec{
{
Name: "claude",
Runner: &Claude{},
Description: "Anthropic's coding tool with subagents",
Install: IntegrationInstallSpec{
CheckInstalled: func() bool {
_, err := (&Claude{}).findPath()
return err == nil
},
URL: "https://code.claude.com/docs/en/quickstart",
},
},
{
Name: "cline",
Runner: &Cline{},
Description: "Autonomous coding agent with parallel execution",
Install: IntegrationInstallSpec{
CheckInstalled: func() bool {
_, err := exec.LookPath("cline")
return err == nil
},
Command: []string{"npm", "install", "-g", "cline"},
},
},
{
Name: "codex",
Runner: &Codex{},
Description: "OpenAI's open-source coding agent",
Install: IntegrationInstallSpec{
CheckInstalled: func() bool {
_, err := exec.LookPath("codex")
return err == nil
},
URL: "https://developers.openai.com/codex/cli/",
Command: []string{"npm", "install", "-g", "@openai/codex"},
},
},
{
Name: "droid",
Runner: &Droid{},
Description: "Factory's coding agent across terminal and IDEs",
Install: IntegrationInstallSpec{
CheckInstalled: func() bool {
_, err := exec.LookPath("droid")
return err == nil
},
URL: "https://docs.factory.ai/cli/getting-started/quickstart",
},
},
{
Name: "opencode",
Runner: &OpenCode{},
Description: "Anomaly's open-source coding agent",
Install: IntegrationInstallSpec{
CheckInstalled: func() bool {
_, err := exec.LookPath("opencode")
return err == nil
},
URL: "https://opencode.ai",
},
},
{
Name: "openclaw",
Runner: &Openclaw{},
Aliases: []string{"clawdbot", "moltbot"},
Description: "Personal AI with 100+ skills",
Install: IntegrationInstallSpec{
CheckInstalled: func() bool {
if _, err := exec.LookPath("openclaw"); err == nil {
return true
}
if _, err := exec.LookPath("clawdbot"); err == nil {
return true
}
return false
},
EnsureInstalled: func() error {
_, err := ensureOpenclawInstalled()
return err
},
URL: "https://docs.openclaw.ai",
},
},
{
Name: "pi",
Runner: &Pi{},
Description: "Minimal AI agent toolkit with plugin support",
Install: IntegrationInstallSpec{
CheckInstalled: func() bool {
_, err := exec.LookPath("pi")
return err == nil
},
Command: []string{"npm", "install", "-g", "@mariozechner/pi-coding-agent"},
},
},
}
var integrationSpecsByName map[string]*IntegrationSpec
func init() {
rebuildIntegrationSpecIndexes()
}
func hyperlink(url, text string) string {
return fmt.Sprintf("\033]8;;%s\033\\%s\033]8;;\033\\", url, text)
}
func rebuildIntegrationSpecIndexes() {
integrationSpecsByName = make(map[string]*IntegrationSpec, len(integrationSpecs))
canonical := make(map[string]bool, len(integrationSpecs))
for _, spec := range integrationSpecs {
key := strings.ToLower(spec.Name)
if key == "" {
panic("launch: integration spec missing name")
}
if canonical[key] {
panic(fmt.Sprintf("launch: duplicate integration name %q", key))
}
canonical[key] = true
integrationSpecsByName[key] = spec
}
seenAliases := make(map[string]string)
for _, spec := range integrationSpecs {
for _, alias := range spec.Aliases {
key := strings.ToLower(alias)
if key == "" {
panic(fmt.Sprintf("launch: integration %q has empty alias", spec.Name))
}
if canonical[key] {
panic(fmt.Sprintf("launch: alias %q collides with canonical integration name", key))
}
if owner, exists := seenAliases[key]; exists {
panic(fmt.Sprintf("launch: alias %q collides between %q and %q", key, owner, spec.Name))
}
seenAliases[key] = spec.Name
integrationSpecsByName[key] = spec
}
}
orderSeen := make(map[string]bool, len(launcherIntegrationOrder))
for _, name := range launcherIntegrationOrder {
key := strings.ToLower(name)
if orderSeen[key] {
panic(fmt.Sprintf("launch: duplicate launcher order entry %q", key))
}
orderSeen[key] = true
spec, ok := integrationSpecsByName[key]
if !ok {
panic(fmt.Sprintf("launch: unknown launcher order entry %q", key))
}
if spec.Name != key {
panic(fmt.Sprintf("launch: launcher order entry %q must use canonical name, not alias", key))
}
if spec.Hidden {
panic(fmt.Sprintf("launch: hidden integration %q cannot appear in launcher order", key))
}
}
}
// LookupIntegrationSpec resolves either a canonical integration name or alias to its spec.
func LookupIntegrationSpec(name string) (*IntegrationSpec, error) {
spec, ok := integrationSpecsByName[strings.ToLower(name)]
if !ok {
return nil, fmt.Errorf("unknown integration: %s", name)
}
return spec, nil
}
// LookupIntegration resolves a registry name to the canonical key and runner.
func LookupIntegration(name string) (string, Runner, error) {
spec, err := LookupIntegrationSpec(name)
if err != nil {
return "", nil, err
}
return spec.Name, spec.Runner, nil
}
// ListVisibleIntegrationSpecs returns the canonical integrations that should appear in interactive UIs.
func ListVisibleIntegrationSpecs() []IntegrationSpec {
visible := make([]IntegrationSpec, 0, len(integrationSpecs))
for _, spec := range integrationSpecs {
if spec.Hidden {
continue
}
visible = append(visible, *spec)
}
orderRank := make(map[string]int, len(launcherIntegrationOrder))
for i, name := range launcherIntegrationOrder {
orderRank[name] = i + 1
}
slices.SortFunc(visible, func(a, b IntegrationSpec) int {
aRank, bRank := orderRank[a.Name], orderRank[b.Name]
if aRank > 0 && bRank > 0 {
return aRank - bRank
}
if aRank > 0 {
return 1
}
if bRank > 0 {
return -1
}
return strings.Compare(a.Name, b.Name)
})
return visible
}
// ListIntegrationInfos returns the registered integrations in launcher display order.
func ListIntegrationInfos() []IntegrationInfo {
visible := ListVisibleIntegrationSpecs()
infos := make([]IntegrationInfo, 0, len(visible))
for _, spec := range visible {
infos = append(infos, IntegrationInfo{
Name: spec.Name,
DisplayName: spec.Runner.String(),
Description: spec.Description,
})
}
return infos
}
// IntegrationSelectionItems returns the sorted integration items shown by launcher selection UIs.
func IntegrationSelectionItems() ([]ModelItem, error) {
visible := ListVisibleIntegrationSpecs()
if len(visible) == 0 {
return nil, fmt.Errorf("no integrations available")
}
items := make([]ModelItem, 0, len(visible))
for _, spec := range visible {
description := spec.Runner.String()
if conn, err := loadStoredIntegrationConfig(spec.Name); err == nil && len(conn.Models) > 0 {
description = fmt.Sprintf("%s (%s)", spec.Runner.String(), conn.Models[0])
}
items = append(items, ModelItem{Name: spec.Name, Description: description})
}
return items, nil
}
// IsIntegrationInstalled checks if an integration binary is installed.
func IsIntegrationInstalled(name string) bool {
integration, err := integrationFor(name)
if err != nil {
fmt.Fprintf(os.Stderr, "Ollama couldn't find integration %q, so it'll show up as not installed.\n", name)
return false
}
return integration.installed
}
// integration is resolved registry metadata used by launcher state and install checks.
// It combines immutable registry spec data with computed runtime traits.
type integration struct {
spec *IntegrationSpec
installed bool
autoInstallable bool
editor bool
installHint string
}
// integrationFor resolves an integration name into the canonical spec plus
// derived launcher/install traits used across registry and launch flows.
func integrationFor(name string) (integration, error) {
spec, err := LookupIntegrationSpec(name)
if err != nil {
return integration{}, err
}
installed := true
if spec.Install.CheckInstalled != nil {
installed = spec.Install.CheckInstalled()
}
_, editor := spec.Runner.(Editor)
hint := ""
if spec.Install.URL != "" {
hint = "Install from " + hyperlink(spec.Install.URL, spec.Install.URL)
} else if len(spec.Install.Command) > 0 {
hint = "Install with: " + strings.Join(spec.Install.Command, " ")
}
return integration{
spec: spec,
installed: installed,
autoInstallable: spec.Install.EnsureInstalled != nil,
editor: editor,
installHint: hint,
}, nil
}
// EnsureIntegrationInstalled installs auto-installable integrations when missing.
func EnsureIntegrationInstalled(name string, runner Runner) error {
integration, err := integrationFor(name)
if err != nil {
return fmt.Errorf("%s is not installed", runner)
}
if integration.installed {
return nil
}
if integration.autoInstallable {
return integration.spec.Install.EnsureInstalled()
}
switch {
case integration.spec.Install.URL != "":
return fmt.Errorf("%s is not installed, install from %s", integration.spec.Name, integration.spec.Install.URL)
case len(integration.spec.Install.Command) > 0:
return fmt.Errorf("%s is not installed, install with: %s", integration.spec.Name, strings.Join(integration.spec.Install.Command, " "))
default:
return fmt.Errorf("%s is not installed", runner)
}
}

View File

@@ -0,0 +1,21 @@
package launch
import "strings"
// OverrideIntegration replaces one registry entry's runner for tests and returns a restore function.
func OverrideIntegration(name string, runner Runner) func() {
spec, err := LookupIntegrationSpec(name)
if err != nil {
key := strings.ToLower(name)
integrationSpecsByName[key] = &IntegrationSpec{Name: key, Runner: runner}
return func() {
delete(integrationSpecsByName, key)
}
}
original := spec.Runner
spec.Runner = runner
return func() {
spec.Runner = original
}
}

View File

@@ -0,0 +1,68 @@
package launch
import (
"os"
"path/filepath"
"testing"
)
func TestEditorRunsDoNotRewriteConfig(t *testing.T) {
tests := []struct {
name string
binary string
runner Runner
checkPath func(home string) string
}{
{
name: "droid",
binary: "droid",
runner: &Droid{},
checkPath: func(home string) string {
return filepath.Join(home, ".factory", "settings.json")
},
},
{
name: "opencode",
binary: "opencode",
runner: &OpenCode{},
checkPath: func(home string) string {
return filepath.Join(home, ".config", "opencode", "opencode.json")
},
},
{
name: "cline",
binary: "cline",
runner: &Cline{},
checkPath: func(home string) string {
return filepath.Join(home, ".cline", "data", "globalState.json")
},
},
{
name: "pi",
binary: "pi",
runner: &Pi{},
checkPath: func(home string) string {
return filepath.Join(home, ".pi", "agent", "models.json")
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
home := t.TempDir()
setTestHome(t, home)
binDir := t.TempDir()
writeFakeBinary(t, binDir, tt.binary)
t.Setenv("PATH", binDir)
configPath := tt.checkPath(home)
if err := tt.runner.Run("llama3.2", nil); err != nil {
t.Fatalf("Run returned error: %v", err)
}
if _, err := os.Stat(configPath); !os.IsNotExist(err) {
t.Fatalf("expected Run to leave %s untouched, got err=%v", configPath, err)
}
})
}
}

View File

@@ -0,0 +1,103 @@
package launch
import (
"errors"
"fmt"
"os"
"golang.org/x/term"
)
// ANSI escape sequences for terminal formatting.
const (
ansiBold = "\033[1m"
ansiReset = "\033[0m"
ansiGray = "\033[37m"
ansiGreen = "\033[32m"
ansiYellow = "\033[33m"
)
// ErrCancelled is returned when the user cancels a selection.
var ErrCancelled = errors.New("cancelled")
// errCancelled is kept as an internal alias for existing call sites.
var errCancelled = ErrCancelled
// DefaultConfirmPrompt provides a TUI-based confirmation prompt.
// When set, ConfirmPrompt delegates to it instead of using raw terminal I/O.
var DefaultConfirmPrompt func(prompt string) (bool, error)
// SingleSelector is a function type for single item selection.
// current is the name of the previously selected item to highlight; empty means no pre-selection.
type SingleSelector func(title string, items []ModelItem, current string) (string, error)
// MultiSelector is a function type for multi item selection.
type MultiSelector func(title string, items []ModelItem, preChecked []string) ([]string, error)
// DefaultSingleSelector is the default single-select implementation.
var DefaultSingleSelector SingleSelector
// DefaultMultiSelector is the default multi-select implementation.
var DefaultMultiSelector MultiSelector
// DefaultSignIn provides a TUI-based sign-in flow.
// When set, ensureAuth uses it instead of plain text prompts.
// Returns the signed-in username or an error.
var DefaultSignIn func(modelName, signInURL string) (string, error)
type launchConfirmPolicy struct {
yes bool
requireYesMessage bool
}
var currentLaunchConfirmPolicy launchConfirmPolicy
func withLaunchConfirmPolicy(policy launchConfirmPolicy) func() {
old := currentLaunchConfirmPolicy
currentLaunchConfirmPolicy = policy
return func() {
currentLaunchConfirmPolicy = old
}
}
// ConfirmPrompt is the shared confirmation gate for launch flows (integration
// edits, missing-model pulls, sign-in prompts, OpenClaw install/security, etc).
// Behavior is controlled by currentLaunchConfirmPolicy, typically scoped by
// withLaunchConfirmPolicy in LaunchCmd (e.g. auto-approve with --yes).
func ConfirmPrompt(prompt string) (bool, error) {
if currentLaunchConfirmPolicy.yes {
return true, nil
}
if currentLaunchConfirmPolicy.requireYesMessage {
return false, fmt.Errorf("%s requires confirmation; re-run with --yes to continue", prompt)
}
if DefaultConfirmPrompt != nil {
return DefaultConfirmPrompt(prompt)
}
fd := int(os.Stdin.Fd())
oldState, err := term.MakeRaw(fd)
if err != nil {
return false, err
}
defer term.Restore(fd, oldState)
fmt.Fprintf(os.Stderr, "%s (\033[1my\033[0m/n) ", prompt)
buf := make([]byte, 1)
for {
if _, err := os.Stdin.Read(buf); err != nil {
return false, err
}
switch buf[0] {
case 'Y', 'y', 13:
fmt.Fprintf(os.Stderr, "yes\r\n")
return true, nil
case 'N', 'n', 27, 3:
fmt.Fprintf(os.Stderr, "no\r\n")
return false, nil
}
}
}

View File

@@ -0,0 +1,76 @@
package launch
import (
"strings"
"testing"
)
func TestErrCancelled(t *testing.T) {
t.Run("NotNil", func(t *testing.T) {
if errCancelled == nil {
t.Error("errCancelled should not be nil")
}
})
t.Run("Message", func(t *testing.T) {
if errCancelled.Error() != "cancelled" {
t.Errorf("expected 'cancelled', got %q", errCancelled.Error())
}
})
}
func TestWithLaunchConfirmPolicy_ScopesAndRestores(t *testing.T) {
oldPolicy := currentLaunchConfirmPolicy
oldHook := DefaultConfirmPrompt
t.Cleanup(func() {
currentLaunchConfirmPolicy = oldPolicy
DefaultConfirmPrompt = oldHook
})
currentLaunchConfirmPolicy = launchConfirmPolicy{}
var hookCalls int
DefaultConfirmPrompt = func(prompt string) (bool, error) {
hookCalls++
return true, nil
}
restoreOuter := withLaunchConfirmPolicy(launchConfirmPolicy{requireYesMessage: true})
restoreInner := withLaunchConfirmPolicy(launchConfirmPolicy{yes: true})
ok, err := ConfirmPrompt("test prompt")
if err != nil {
t.Fatalf("expected --yes policy to allow prompt, got error: %v", err)
}
if !ok {
t.Fatal("expected --yes policy to auto-accept prompt")
}
if hookCalls != 0 {
t.Fatalf("expected --yes to skip hook, got %d hook calls", hookCalls)
}
restoreInner()
_, err = ConfirmPrompt("test prompt")
if err == nil {
t.Fatal("expected requireYesMessage policy to block prompt")
}
if !strings.Contains(err.Error(), "re-run with --yes") {
t.Fatalf("expected actionable --yes error, got: %v", err)
}
if hookCalls != 0 {
t.Fatalf("expected blocking policy to skip hook, got %d hook calls", hookCalls)
}
restoreOuter()
ok, err = ConfirmPrompt("test prompt")
if err != nil {
t.Fatalf("expected restored default behavior to use hook, got error: %v", err)
}
if !ok {
t.Fatal("expected hook to return true")
}
if hookCalls != 1 {
t.Fatalf("expected one hook call after restore, got %d", hookCalls)
}
}

View File

@@ -0,0 +1,82 @@
package launch
import (
"strings"
"testing"
"github.com/ollama/ollama/cmd/config"
)
var (
integrations map[string]Runner
integrationAliases map[string]bool
integrationOrder = launcherIntegrationOrder
)
func init() {
integrations = buildTestIntegrations()
integrationAliases = buildTestIntegrationAliases()
}
func buildTestIntegrations() map[string]Runner {
result := make(map[string]Runner, len(integrationSpecsByName))
for name, spec := range integrationSpecsByName {
result[strings.ToLower(name)] = spec.Runner
}
return result
}
func buildTestIntegrationAliases() map[string]bool {
result := make(map[string]bool)
for _, spec := range integrationSpecs {
for _, alias := range spec.Aliases {
result[strings.ToLower(alias)] = true
}
}
return result
}
func setTestHome(t *testing.T, dir string) {
t.Helper()
setLaunchTestHome(t, dir)
}
func SaveIntegration(appName string, models []string) error {
return config.SaveIntegration(appName, models)
}
func LoadIntegration(appName string) (*config.IntegrationConfig, error) {
return config.LoadIntegration(appName)
}
func SaveAliases(appName string, aliases map[string]string) error {
return config.SaveAliases(appName, aliases)
}
func LastModel() string {
return config.LastModel()
}
func SetLastModel(model string) error {
return config.SetLastModel(model)
}
func LastSelection() string {
return config.LastSelection()
}
func SetLastSelection(selection string) error {
return config.SetLastSelection(selection)
}
func IntegrationModel(appName string) string {
return config.IntegrationModel(appName)
}
func IntegrationModels(appName string) []string {
return config.IntegrationModels(appName)
}
func integrationOnboarded(appName string) error {
return config.MarkIntegrationOnboarded(appName)
}

View File

@@ -7,7 +7,7 @@ import (
tea "github.com/charmbracelet/bubbletea"
"github.com/charmbracelet/lipgloss"
"github.com/ollama/ollama/cmd/config"
"github.com/ollama/ollama/cmd/launch"
)
var (
@@ -64,8 +64,8 @@ type SelectItem struct {
Recommended bool
}
// ConvertItems converts config.ModelItem slice to SelectItem slice.
func ConvertItems(items []config.ModelItem) []SelectItem {
// ConvertItems converts launch.ModelItem slice to SelectItem slice.
func ConvertItems(items []launch.ModelItem) []SelectItem {
out := make([]SelectItem, len(items))
for i, item := range items {
out[i] = SelectItem{Name: item.Name, Description: item.Description, Recommended: item.Recommended}
@@ -101,6 +101,16 @@ type selectorModel struct {
width int
}
func selectorModelWithCurrent(title string, items []SelectItem, current string) selectorModel {
m := selectorModel{
title: title,
items: items,
cursor: cursorForCurrent(items, current),
}
m.updateScroll(m.otherStart())
return m
}
func (m selectorModel) filteredItems() []SelectItem {
if m.filter == "" {
return m.items
@@ -367,13 +377,24 @@ func (m selectorModel) View() string {
// cursorForCurrent returns the item index matching current, or 0 if not found.
func cursorForCurrent(items []SelectItem, current string) int {
if current != "" {
for i, item := range items {
if item.Name == current || strings.HasPrefix(item.Name, current+":") || strings.HasPrefix(current, item.Name+":") {
return i
}
if current == "" {
return 0
}
// Prefer exact name matches before tag-prefix fallback so "qwen3.5" does not
// incorrectly select "qwen3.5:cloud" (and vice versa) based on list order.
for i, item := range items {
if item.Name == current {
return i
}
}
for i, item := range items {
if strings.HasPrefix(item.Name, current+":") || strings.HasPrefix(current, item.Name+":") {
return i
}
}
return 0
}
@@ -382,11 +403,7 @@ func SelectSingle(title string, items []SelectItem, current string) (string, err
return "", fmt.Errorf("no items to select from")
}
m := selectorModel{
title: title,
items: items,
cursor: cursorForCurrent(items, current),
}
m := selectorModelWithCurrent(title, items, current)
p := tea.NewProgram(m)
finalModel, err := p.Run()
@@ -415,6 +432,12 @@ type multiSelectorModel struct {
cancelled bool
confirmed bool
width int
// multi enables full multi-select editing mode. The zero value (false)
// shows a single-select picker where Enter adds the chosen model to
// the existing list. Tab toggles between modes.
multi bool
singleAdd string // model picked in single mode
}
func newMultiSelectorModel(title string, items []SelectItem, preChecked []string) multiSelectorModel {
@@ -429,13 +452,23 @@ func newMultiSelectorModel(title string, items []SelectItem, preChecked []string
m.itemIndex[item.Name] = i
}
for _, name := range preChecked {
if idx, ok := m.itemIndex[name]; ok {
// Reverse order so preChecked[0] (the current default) ends up last
// in checkOrder, matching the "last checked = default" convention.
for i := len(preChecked) - 1; i >= 0; i-- {
if idx, ok := m.itemIndex[preChecked[i]]; ok {
m.checked[idx] = true
m.checkOrder = append(m.checkOrder, idx)
}
}
// Position cursor on the current default model
if len(preChecked) > 0 {
if idx, ok := m.itemIndex[preChecked[0]]; ok {
m.cursor = idx
m.updateScroll(m.otherStart())
}
}
return m
}
@@ -507,6 +540,7 @@ func (m *multiSelectorModel) toggleItem() {
origIdx := m.itemIndex[item.Name]
if m.checked[origIdx] {
wasDefault := len(m.checkOrder) > 0 && m.checkOrder[len(m.checkOrder)-1] == origIdx
delete(m.checked, origIdx)
for i, idx := range m.checkOrder {
if idx == origIdx {
@@ -514,6 +548,34 @@ func (m *multiSelectorModel) toggleItem() {
break
}
}
if wasDefault {
// When removing the default, pick the nearest checked model above it
// (or below if none above) so default fallback follows list order.
newDefault := -1
for i := origIdx - 1; i >= 0; i-- {
if m.checked[i] {
newDefault = i
break
}
}
if newDefault == -1 {
for i := origIdx + 1; i < len(m.items); i++ {
if m.checked[i] {
newDefault = i
break
}
}
}
if newDefault != -1 {
for i, idx := range m.checkOrder {
if idx == newDefault {
m.checkOrder = append(m.checkOrder[:i], m.checkOrder[i+1:]...)
break
}
}
m.checkOrder = append(m.checkOrder, newDefault)
}
}
} else {
m.checked[origIdx] = true
m.checkOrder = append(m.checkOrder, origIdx)
@@ -546,14 +608,25 @@ func (m multiSelectorModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
m.cancelled = true
return m, tea.Quit
case tea.KeyTab:
m.multi = !m.multi
case tea.KeyEnter:
if len(m.checkOrder) > 0 {
if !m.multi {
if len(filtered) > 0 && m.cursor < len(filtered) {
m.singleAdd = filtered[m.cursor].Name
m.confirmed = true
return m, tea.Quit
}
} else if len(m.checkOrder) > 0 {
m.confirmed = true
return m, tea.Quit
}
case tea.KeySpace:
m.toggleItem()
if m.multi {
m.toggleItem()
}
case tea.KeyUp:
if m.cursor > 0 {
@@ -592,7 +665,9 @@ func (m multiSelectorModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
// On some terminals (e.g. Windows PowerShell), space arrives as
// KeyRunes instead of KeySpace. Intercept it so toggle still works.
if len(msg.Runes) == 1 && msg.Runes[0] == ' ' {
m.toggleItem()
if m.multi {
m.toggleItem()
}
} else {
m.filter += string(msg.Runes)
m.cursor = 0
@@ -604,6 +679,19 @@ func (m multiSelectorModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
return m, nil
}
func (m multiSelectorModel) renderSingleItem(s *strings.Builder, item SelectItem, idx int) {
if idx == m.cursor {
s.WriteString(selectorSelectedItemStyle.Render("▸ " + item.Name))
} else {
s.WriteString(selectorItemStyle.Render(item.Name))
}
s.WriteString("\n")
if item.Description != "" {
s.WriteString(selectorDescLineStyle.Render(item.Description))
s.WriteString("\n")
}
}
func (m multiSelectorModel) renderMultiItem(s *strings.Builder, item SelectItem, idx int) {
origIdx := m.itemIndex[item.Name]
@@ -615,7 +703,7 @@ func (m multiSelectorModel) renderMultiItem(s *strings.Builder, item SelectItem,
}
suffix := ""
if len(m.checkOrder) > 0 && m.checkOrder[0] == origIdx {
if len(m.checkOrder) > 0 && m.checkOrder[len(m.checkOrder)-1] == origIdx {
suffix = " " + selectorDefaultTagStyle.Render("(default)")
}
@@ -637,6 +725,11 @@ func (m multiSelectorModel) View() string {
return ""
}
renderItem := m.renderSingleItem
if m.multi {
renderItem = m.renderMultiItem
}
var s strings.Builder
s.WriteString(selectorTitleStyle.Render(m.title))
@@ -661,7 +754,7 @@ func (m multiSelectorModel) View() string {
if idx >= len(filtered) {
break
}
m.renderMultiItem(&s, filtered[idx], idx)
renderItem(&s, filtered[idx], idx)
}
if remaining := len(filtered) - m.scrollOffset - displayCount; remaining > 0 {
@@ -684,7 +777,7 @@ func (m multiSelectorModel) View() string {
s.WriteString(sectionHeaderStyle.Render("Recommended"))
s.WriteString("\n")
for _, idx := range recItems {
m.renderMultiItem(&s, filtered[idx], idx)
renderItem(&s, filtered[idx], idx)
}
}
@@ -704,7 +797,7 @@ func (m multiSelectorModel) View() string {
if idx >= len(otherItems) {
break
}
m.renderMultiItem(&s, filtered[otherItems[idx]], otherItems[idx])
renderItem(&s, filtered[otherItems[idx]], otherItems[idx])
}
if remaining := len(otherItems) - m.scrollOffset - displayCount; remaining > 0 {
@@ -716,15 +809,18 @@ func (m multiSelectorModel) View() string {
s.WriteString("\n")
count := m.selectedCount()
if count == 0 {
s.WriteString(selectorDescStyle.Render(" Select at least one model."))
if !m.multi {
s.WriteString(selectorHelpStyle.Render("↑/↓ navigate • enter select • tab add multiple • esc cancel"))
} else {
s.WriteString(selectorDescStyle.Render(fmt.Sprintf(" %d selected - press enter to continue", count)))
count := m.selectedCount()
if count == 0 {
s.WriteString(selectorDescStyle.Render(" Select at least one model."))
} else {
s.WriteString(selectorDescStyle.Render(fmt.Sprintf(" %d selected - press enter to continue", count)))
}
s.WriteString("\n\n")
s.WriteString(selectorHelpStyle.Render("↑/↓ navigate • space toggle • tab select single • enter confirm • esc cancel"))
}
s.WriteString("\n\n")
s.WriteString(selectorHelpStyle.Render("↑/↓ navigate • space toggle • enter confirm • esc cancel"))
result := s.String()
if m.width > 0 {
@@ -747,18 +843,28 @@ func SelectMultiple(title string, items []SelectItem, preChecked []string) ([]st
}
fm := finalModel.(multiSelectorModel)
if fm.cancelled {
if fm.cancelled || !fm.confirmed {
return nil, ErrCancelled
}
if !fm.confirmed {
return nil, ErrCancelled
// Single-add mode: prepend the picked model, keep existing models deduped
if fm.singleAdd != "" {
result := []string{fm.singleAdd}
for _, name := range preChecked {
if name != fm.singleAdd {
result = append(result, name)
}
}
return result, nil
}
var result []string
// Multi-edit mode: last checked is default (first in result)
last := fm.checkOrder[len(fm.checkOrder)-1]
result := []string{fm.items[last].Name}
for _, idx := range fm.checkOrder {
result = append(result, fm.items[idx].Name)
if idx != last {
result = append(result, fm.items[idx].Name)
}
}
return result, nil
}

View File

@@ -216,6 +216,41 @@ func TestUpdateScroll(t *testing.T) {
}
}
func TestSelectorModelWithCurrent_ScrollsToCurrentInMoreSection(t *testing.T) {
m := selectorModelWithCurrent("Pick:", mixedItems(), "other-10")
if m.cursor != 11 {
t.Fatalf("cursor = %d, want 11", m.cursor)
}
if m.scrollOffset == 0 {
t.Fatal("scrollOffset should move to reveal current item in More section")
}
content := m.renderContent()
if !strings.Contains(content, "▸ other-10") {
t.Fatalf("expected current item to be visible and highlighted\n%s", content)
}
}
func TestSelectorModelWithCurrent_HighlightsExactLocalWhenCloudVariantExists(t *testing.T) {
m := selectorModelWithCurrent("Pick:", []SelectItem{
{Name: "qwen3.5:cloud", Recommended: true},
{Name: "qwen3.5", Recommended: true},
}, "qwen3.5")
if m.cursor != 1 {
t.Fatalf("cursor = %d, want 1", m.cursor)
}
content := m.renderContent()
if !strings.Contains(content, "▸ qwen3.5") {
t.Fatalf("expected local qwen3.5 to be highlighted\n%s", content)
}
if strings.Contains(content, "▸ qwen3.5:cloud") {
t.Fatalf("did not expect cloud qwen3.5:cloud to be highlighted\n%s", content)
}
}
func TestRenderContent_SectionHeaders(t *testing.T) {
m := selectorModel{
title: "Pick:",
@@ -418,6 +453,28 @@ func TestCursorForCurrent(t *testing.T) {
}
}
func TestCursorForCurrent_PrefersExactLocalOverCloudPrefix(t *testing.T) {
testItems := []SelectItem{
{Name: "qwen3.5:cloud", Recommended: true},
{Name: "qwen3.5", Recommended: true},
}
if got := cursorForCurrent(testItems, "qwen3.5"); got != 1 {
t.Errorf("cursorForCurrent(%q) = %d, want %d", "qwen3.5", got, 1)
}
}
func TestCursorForCurrent_PrefersExactCloudOverLocalPrefix(t *testing.T) {
testItems := []SelectItem{
{Name: "qwen3.5", Recommended: true},
{Name: "qwen3.5:cloud", Recommended: true},
}
if got := cursorForCurrent(testItems, "qwen3.5:cloud"); got != 1 {
t.Errorf("cursorForCurrent(%q) = %d, want %d", "qwen3.5:cloud", got, 1)
}
}
// --- ReorderItems ---
func TestReorderItems(t *testing.T) {
@@ -539,6 +596,7 @@ func TestMultiView_CursorIndicator(t *testing.T) {
func TestMultiView_CheckedItemShowsX(t *testing.T) {
m := newMultiSelectorModel("Pick:", items("a", "b"), []string{"a"})
m.multi = true
content := m.View()
if !strings.Contains(content, "[x]") {
@@ -550,11 +608,18 @@ func TestMultiView_CheckedItemShowsX(t *testing.T) {
}
func TestMultiView_DefaultTag(t *testing.T) {
m := newMultiSelectorModel("Pick:", items("a", "b"), []string{"a"})
m := newMultiSelectorModel("Pick:", items("a", "b", "c"), []string{"a", "b"})
m.multi = true
content := m.View()
if !strings.Contains(content, "(default)") {
t.Error("first checked item should have (default) tag")
t.Error("should have (default) tag")
}
// preChecked[0] ("a") should be the default (last in checkOrder)
aIdx := strings.Index(content, "a")
defaultIdx := strings.Index(content, "(default)")
if defaultIdx < aIdx {
t.Error("(default) tag should appear after 'a' (the current default)")
}
}
@@ -585,6 +650,7 @@ func TestMultiView_OverflowIndicator(t *testing.T) {
func TestMultiUpdate_SpaceTogglesItem(t *testing.T) {
m := newMultiSelectorModel("Pick:", items("a", "b", "c"), nil)
m.multi = true
m.cursor = 1
// Simulate space delivered as tea.KeySpace
@@ -601,6 +667,7 @@ func TestMultiUpdate_SpaceTogglesItem(t *testing.T) {
func TestMultiUpdate_SpaceRuneTogglesItem(t *testing.T) {
m := newMultiSelectorModel("Pick:", items("a", "b", "c"), nil)
m.multi = true
m.cursor = 1
// Simulate space delivered as tea.KeyRunes (Windows PowerShell behavior)
@@ -618,6 +685,189 @@ func TestMultiUpdate_SpaceRuneTogglesItem(t *testing.T) {
}
}
// --- Single-add mode ---
func TestMulti_StartsInSingleMode(t *testing.T) {
m := newMultiSelectorModel("Pick:", items("a", "b"), nil)
if m.multi {
t.Error("should start in single mode (multi=false)")
}
}
func TestMulti_SingleModeNoCheckboxes(t *testing.T) {
m := newMultiSelectorModel("Pick:", items("a", "b"), nil)
content := m.View()
if strings.Contains(content, "[x]") || strings.Contains(content, "[ ]") {
t.Error("single mode should not show checkboxes")
}
if !strings.Contains(content, "▸") {
t.Error("single mode should show cursor indicator")
}
}
func TestMulti_SingleModeEnterPicksItem(t *testing.T) {
m := newMultiSelectorModel("Pick:", items("a", "b", "c"), nil)
m.cursor = 1
updated, _ := m.Update(tea.KeyMsg{Type: tea.KeyEnter})
m = updated.(multiSelectorModel)
if m.singleAdd != "b" {
t.Errorf("enter in single mode should pick cursor item, got %q", m.singleAdd)
}
if !m.confirmed {
t.Error("should set confirmed")
}
}
func TestMulti_SingleModeSpaceIsNoop(t *testing.T) {
m := newMultiSelectorModel("Pick:", items("a", "b"), nil)
m.cursor = 0
updated, _ := m.Update(tea.KeyMsg{Type: tea.KeySpace})
m = updated.(multiSelectorModel)
if len(m.checked) != 0 {
t.Error("space in single mode should not toggle items")
}
}
func TestMulti_SingleModeSpaceRuneIsNoop(t *testing.T) {
m := newMultiSelectorModel("Pick:", items("a", "b"), nil)
m.cursor = 0
updated, _ := m.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune{' '}})
m = updated.(multiSelectorModel)
if len(m.checked) != 0 {
t.Error("space rune in single mode should not toggle items")
}
if m.filter != "" {
t.Error("space rune in single mode should not add to filter")
}
}
func TestMulti_TabTogglesMode(t *testing.T) {
m := newMultiSelectorModel("Pick:", items("a", "b"), nil)
if m.multi {
t.Fatal("should start in single mode")
}
updated, _ := m.Update(tea.KeyMsg{Type: tea.KeyTab})
m = updated.(multiSelectorModel)
if !m.multi {
t.Error("tab should switch to multi mode")
}
updated, _ = m.Update(tea.KeyMsg{Type: tea.KeyTab})
m = updated.(multiSelectorModel)
if m.multi {
t.Error("tab should switch back to single mode")
}
}
func TestMulti_SingleModeHelpText(t *testing.T) {
m := newMultiSelectorModel("Pick:", items("a"), nil)
content := m.View()
if !strings.Contains(content, "tab add multiple") {
t.Error("single mode should show 'tab add multiple' in help")
}
}
func TestMulti_MultiModeHelpText(t *testing.T) {
m := newMultiSelectorModel("Pick:", items("a"), nil)
m.multi = true
content := m.View()
if !strings.Contains(content, "tab select single") {
t.Error("multi mode should show 'tab select single' in help")
}
}
// --- preChecked initialization order ---
func TestMulti_PreCheckedDefaultIsLast(t *testing.T) {
// preChecked[0] ("a") is the current default and should end up
// last in checkOrder so it gets the (default) tag.
m := newMultiSelectorModel("Pick:", items("a", "b", "c"), []string{"a", "b", "c"})
if len(m.checkOrder) != 3 {
t.Fatalf("expected 3 in checkOrder, got %d", len(m.checkOrder))
}
lastIdx := m.checkOrder[len(m.checkOrder)-1]
if m.items[lastIdx].Name != "a" {
t.Errorf("preChecked[0] should be last in checkOrder, got %q", m.items[lastIdx].Name)
}
}
func TestMulti_CursorOnDefaultModel(t *testing.T) {
// preChecked[0] ("b") is the default; cursor should start on it
m := newMultiSelectorModel("Pick:", items("a", "b", "c"), []string{"b", "c"})
if m.cursor != 1 {
t.Errorf("cursor should be on preChecked[0] ('b') at index 1, got %d", m.cursor)
}
}
// --- Multi-mode last-checked is default ---
func TestMulti_LastCheckedIsDefault(t *testing.T) {
m := newMultiSelectorModel("Pick:", items("alpha", "beta", "gamma"), nil)
m.multi = true
// Check "alpha" then "gamma"
m.cursor = 0
m.toggleItem()
m.cursor = 2
m.toggleItem()
// Last checked ("gamma") should be at the end of checkOrder
lastIdx := m.checkOrder[len(m.checkOrder)-1]
if m.items[lastIdx].Name != "gamma" {
t.Errorf("last checked should be 'gamma', got %q", m.items[lastIdx].Name)
}
// The (default) tag renders based on checkOrder[len-1]
content := m.View()
if !strings.Contains(content, "(default)") {
t.Fatal("should show (default) tag")
}
// "alpha" line should NOT have the default tag
for _, line := range strings.Split(content, "\n") {
if strings.Contains(line, "alpha") && strings.Contains(line, "(default)") {
t.Error("'alpha' (first checked) should not have (default) tag")
}
}
}
func TestMulti_UncheckingDefaultFallsBackToNearestCheckedAbove(t *testing.T) {
// Default is "b", and checked models are "a", "b", "c".
// Unticking default should make "a" (the nearest checked item above) default.
m := newMultiSelectorModel("Pick:", items("a", "b", "c"), []string{"b", "c", "a"})
m.multi = true
m.cursor = 1 // "b"
m.toggleItem()
lastIdx := m.checkOrder[len(m.checkOrder)-1]
if m.items[lastIdx].Name != "a" {
t.Fatalf("expected default to fall back to 'a', got %q", m.items[lastIdx].Name)
}
}
func TestMulti_UncheckingTopDefaultFallsBackToNearestCheckedBelow(t *testing.T) {
// Default is top item "a". With no checked item above, fallback should pick
// the nearest checked item below ("b").
m := newMultiSelectorModel("Pick:", items("a", "b", "c"), []string{"a", "c", "b"})
m.multi = true
m.cursor = 0 // "a"
m.toggleItem()
lastIdx := m.checkOrder[len(m.checkOrder)-1]
if m.items[lastIdx].Name != "b" {
t.Fatalf("expected default to fall back to 'b', got %q", m.items[lastIdx].Name)
}
}
// Key message helpers for testing
type keyType = int

View File

@@ -1,15 +1,24 @@
package tui
import (
"context"
"fmt"
"strings"
"time"
tea "github.com/charmbracelet/bubbletea"
"github.com/charmbracelet/lipgloss"
"github.com/ollama/ollama/cmd/config"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/cmd/launch"
)
type signInTickMsg struct{}
type signInCheckMsg struct {
signedIn bool
userName string
}
type signInModel struct {
modelName string
signInURL string
@@ -104,9 +113,21 @@ func renderSignIn(modelName, signInURL string, spinner, width int) string {
return lipgloss.NewStyle().PaddingLeft(2).Render(s.String())
}
func checkSignIn() tea.Msg {
client, err := api.ClientFromEnvironment()
if err != nil {
return signInCheckMsg{signedIn: false}
}
user, err := client.Whoami(context.Background())
if err == nil && user != nil && user.Name != "" {
return signInCheckMsg{signedIn: true, userName: user.Name}
}
return signInCheckMsg{signedIn: false}
}
// RunSignIn shows a bubbletea sign-in dialog and polls until the user signs in or cancels.
func RunSignIn(modelName, signInURL string) (string, error) {
config.OpenBrowser(signInURL)
launch.OpenBrowser(signInURL)
m := signInModel{
modelName: modelName,

View File

@@ -1,16 +1,11 @@
package tui
import (
"context"
"errors"
"fmt"
"strings"
"time"
tea "github.com/charmbracelet/bubbletea"
"github.com/charmbracelet/lipgloss"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/cmd/config"
"github.com/ollama/ollama/cmd/launch"
"github.com/ollama/ollama/version"
)
@@ -45,7 +40,7 @@ var (
type menuItem struct {
title string
description string
integration string // integration name for loading model config, empty if not an integration
integration string
isRunModel bool
isOthers bool
}
@@ -57,18 +52,12 @@ var mainMenuItems = []menuItem{
isRunModel: true,
},
{
title: "Launch Claude Code",
description: "Agentic coding across large codebases",
integration: "claude",
},
{
title: "Launch Codex",
description: "OpenAI's open-source coding agent",
integration: "codex",
},
{
title: "Launch OpenClaw",
description: "Personal AI with 100+ skills",
integration: "openclaw",
},
}
@@ -79,277 +68,106 @@ var othersMenuItem = menuItem{
isOthers: true,
}
// getOtherIntegrations dynamically builds the "Others" list from the integration
// registry, excluding any integrations already present in the pinned mainMenuItems.
func getOtherIntegrations() []menuItem {
pinned := map[string]bool{
"run": true, // not an integration but in the pinned list
type model struct {
state *launch.LauncherState
items []menuItem
cursor int
showOthers bool
width int
quitting bool
selected bool
action TUIAction
}
func newModel(state *launch.LauncherState) model {
m := model{
state: state,
}
m.showOthers = shouldExpandOthers(state)
m.items = buildMenuItems(state, m.showOthers)
m.cursor = initialCursor(state, m.items)
return m
}
func shouldExpandOthers(state *launch.LauncherState) bool {
if state == nil {
return false
}
for _, item := range otherIntegrationItems(state) {
if item.integration == state.LastSelection {
return true
}
}
return false
}
func buildMenuItems(state *launch.LauncherState, showOthers bool) []menuItem {
items := make([]menuItem, 0, len(mainMenuItems)+1)
for _, item := range mainMenuItems {
if item.integration != "" {
pinned[item.integration] = true
if item.integration == "" {
items = append(items, item)
continue
}
if integrationState, ok := state.Integrations[item.integration]; ok {
items = append(items, integrationMenuItem(integrationState))
}
}
var others []menuItem
for _, info := range config.ListIntegrationInfos() {
if showOthers {
items = append(items, otherIntegrationItems(state)...)
} else {
items = append(items, othersMenuItem)
}
return items
}
func integrationMenuItem(state launch.LauncherIntegrationState) menuItem {
description := state.Description
if description == "" {
description = "Open " + state.DisplayName + " integration"
}
return menuItem{
title: "Launch " + state.DisplayName,
description: description,
integration: state.Name,
}
}
func otherIntegrationItems(state *launch.LauncherState) []menuItem {
pinned := map[string]bool{
"claude": true,
"codex": true,
"openclaw": true,
}
var items []menuItem
for _, info := range launch.ListIntegrationInfos() {
if pinned[info.Name] {
continue
}
desc := info.Description
if desc == "" {
desc = "Open " + info.DisplayName + " integration"
}
others = append(others, menuItem{
title: "Launch " + info.DisplayName,
description: desc,
integration: info.Name,
})
}
return others
}
type model struct {
items []menuItem
cursor int
quitting bool
selected bool
changeModel bool
changeModels []string // multi-select result for Editor integrations
showOthers bool
availableModels map[string]bool
err error
showingModal bool
modalSelector selectorModel
modalItems []SelectItem
showingMultiModal bool
multiModalSelector multiSelectorModel
showingSignIn bool
signInURL string
signInModel string
signInSpinner int
signInFromModal bool // true if sign-in was triggered from modal (not main menu)
width int // terminal width from WindowSizeMsg
statusMsg string // temporary status message shown near help text
}
type signInTickMsg struct{}
type signInCheckMsg struct {
signedIn bool
userName string
}
type clearStatusMsg struct{}
func (m *model) modelExists(name string) bool {
if m.availableModels == nil || name == "" {
return false
}
if m.availableModels[name] {
return true
}
// Check for prefix match (e.g., "llama2" matches "llama2:latest")
for modelName := range m.availableModels {
if strings.HasPrefix(modelName, name+":") {
return true
}
}
return false
}
func (m *model) buildModalItems() []SelectItem {
modelItems, _ := config.GetModelItems(context.Background())
return ReorderItems(ConvertItems(modelItems))
}
func (m *model) openModelModal(currentModel string) {
m.modalItems = m.buildModalItems()
cursor := 0
if currentModel != "" {
for i, item := range m.modalItems {
if item.Name == currentModel || strings.HasPrefix(item.Name, currentModel+":") || strings.HasPrefix(currentModel, item.Name+":") {
cursor = i
break
}
}
}
m.modalSelector = selectorModel{
title: "Select model:",
items: m.modalItems,
cursor: cursor,
helpText: "↑/↓ navigate • enter select • ← back",
}
m.modalSelector.updateScroll(m.modalSelector.otherStart())
m.showingModal = true
}
func (m *model) openMultiModelModal(integration string) {
items := m.buildModalItems()
var preChecked []string
if models := config.IntegrationModels(integration); len(models) > 0 {
preChecked = models
}
m.multiModalSelector = newMultiSelectorModel("Select models:", items, preChecked)
// Set cursor to the first pre-checked (last used) model
if len(preChecked) > 0 {
for i, item := range items {
if item.Name == preChecked[0] {
m.multiModalSelector.cursor = i
m.multiModalSelector.updateScroll(m.multiModalSelector.otherStart())
break
}
}
}
m.showingMultiModal = true
}
func isCloudModel(name string) bool {
return strings.HasSuffix(name, ":cloud") || strings.HasSuffix(name, "-cloud")
}
func cloudStatusDisabled(client *api.Client) bool {
status, err := client.CloudStatusExperimental(context.Background())
if err != nil {
return false
}
return status.Cloud.Disabled
}
func cloudModelDisabled(name string) bool {
if !isCloudModel(name) {
return false
}
client, err := api.ClientFromEnvironment()
if err != nil {
return false
}
return cloudStatusDisabled(client)
}
// checkCloudSignIn checks if a cloud model needs sign-in.
// Returns a command to start sign-in if needed, or nil if already signed in.
func (m *model) checkCloudSignIn(modelName string, fromModal bool) tea.Cmd {
if modelName == "" || !isCloudModel(modelName) {
return nil
}
client, err := api.ClientFromEnvironment()
if err != nil {
return nil
}
if cloudStatusDisabled(client) {
return nil
}
user, err := client.Whoami(context.Background())
if err == nil && user != nil && user.Name != "" {
return nil
}
var aErr api.AuthorizationError
if errors.As(err, &aErr) && aErr.SigninURL != "" {
return m.startSignIn(modelName, aErr.SigninURL, fromModal)
}
return nil
}
// startSignIn initiates the sign-in flow for a cloud model.
// fromModal indicates if this was triggered from the model picker modal.
func (m *model) startSignIn(modelName, signInURL string, fromModal bool) tea.Cmd {
m.showingModal = false
m.showingSignIn = true
m.signInURL = signInURL
m.signInModel = modelName
m.signInSpinner = 0
m.signInFromModal = fromModal
config.OpenBrowser(signInURL)
return tea.Tick(200*time.Millisecond, func(t time.Time) tea.Msg {
return signInTickMsg{}
})
}
func checkSignIn() tea.Msg {
client, err := api.ClientFromEnvironment()
if err != nil {
return signInCheckMsg{signedIn: false}
}
user, err := client.Whoami(context.Background())
if err == nil && user != nil && user.Name != "" {
return signInCheckMsg{signedIn: true, userName: user.Name}
}
return signInCheckMsg{signedIn: false}
}
func (m *model) loadAvailableModels() {
m.availableModels = make(map[string]bool)
client, err := api.ClientFromEnvironment()
if err != nil {
return
}
models, err := client.List(context.Background())
if err != nil {
return
}
cloudDisabled := cloudStatusDisabled(client)
for _, mdl := range models.Models {
if cloudDisabled && mdl.RemoteModel != "" {
integrationState, ok := state.Integrations[info.Name]
if !ok {
continue
}
m.availableModels[mdl.Name] = true
items = append(items, integrationMenuItem(integrationState))
}
return items
}
func (m *model) buildItems() {
others := getOtherIntegrations()
m.items = make([]menuItem, 0, len(mainMenuItems)+1+len(others))
m.items = append(m.items, mainMenuItems...)
if m.showOthers {
m.items = append(m.items, others...)
} else {
m.items = append(m.items, othersMenuItem)
func initialCursor(state *launch.LauncherState, items []menuItem) int {
if state == nil || state.LastSelection == "" {
return 0
}
}
func isOthersIntegration(name string) bool {
for _, item := range getOtherIntegrations() {
if item.integration == name {
return true
for i, item := range items {
if state.LastSelection == "run" && item.isRunModel {
return i
}
if item.integration == state.LastSelection {
return i
}
}
return false
}
func initialModel() model {
m := model{
cursor: 0,
}
m.loadAvailableModels()
lastSelection := config.LastSelection()
if isOthersIntegration(lastSelection) {
m.showOthers = true
}
m.buildItems()
if lastSelection != "" {
for i, item := range m.items {
if lastSelection == "run" && item.isRunModel {
m.cursor = i
break
} else if item.integration == lastSelection {
m.cursor = i
break
}
}
}
return m
return 0
}
func (m model) Init() tea.Cmd {
@@ -357,127 +175,11 @@ func (m model) Init() tea.Cmd {
}
func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
if wmsg, ok := msg.(tea.WindowSizeMsg); ok {
wasSet := m.width > 0
m.width = wmsg.Width
if wasSet {
return m, tea.EnterAltScreen
}
return m, nil
}
if _, ok := msg.(clearStatusMsg); ok {
m.statusMsg = ""
return m, nil
}
if m.showingSignIn {
switch msg := msg.(type) {
case tea.KeyMsg:
switch msg.Type {
case tea.KeyCtrlC, tea.KeyEsc:
m.showingSignIn = false
if m.signInFromModal {
m.showingModal = true
}
return m, nil
}
case signInTickMsg:
m.signInSpinner++
// Check sign-in status every 5th tick (~1 second)
if m.signInSpinner%5 == 0 {
return m, tea.Batch(
tea.Tick(200*time.Millisecond, func(t time.Time) tea.Msg {
return signInTickMsg{}
}),
checkSignIn,
)
}
return m, tea.Tick(200*time.Millisecond, func(t time.Time) tea.Msg {
return signInTickMsg{}
})
case signInCheckMsg:
if msg.signedIn {
if m.signInFromModal {
m.modalSelector.selected = m.signInModel
m.changeModel = true
} else {
m.selected = true
}
m.quitting = true
return m, tea.Quit
}
}
return m, nil
}
if m.showingMultiModal {
switch msg := msg.(type) {
case tea.KeyMsg:
if msg.Type == tea.KeyLeft {
m.showingMultiModal = false
return m, nil
}
updated, cmd := m.multiModalSelector.Update(msg)
m.multiModalSelector = updated.(multiSelectorModel)
if m.multiModalSelector.cancelled {
m.showingMultiModal = false
return m, nil
}
if m.multiModalSelector.confirmed {
var selected []string
for _, idx := range m.multiModalSelector.checkOrder {
selected = append(selected, m.multiModalSelector.items[idx].Name)
}
if len(selected) > 0 {
m.changeModels = selected
m.changeModel = true
m.quitting = true
return m, tea.Quit
}
m.multiModalSelector.confirmed = false
return m, nil
}
return m, cmd
}
return m, nil
}
if m.showingModal {
switch msg := msg.(type) {
case tea.KeyMsg:
switch msg.Type {
case tea.KeyCtrlC, tea.KeyEsc, tea.KeyLeft:
m.showingModal = false
return m, nil
case tea.KeyEnter:
filtered := m.modalSelector.filteredItems()
if len(filtered) > 0 && m.modalSelector.cursor < len(filtered) {
m.modalSelector.selected = filtered[m.modalSelector.cursor].Name
}
if m.modalSelector.selected != "" {
if cmd := m.checkCloudSignIn(m.modalSelector.selected, true); cmd != nil {
return m, cmd
}
m.changeModel = true
m.quitting = true
return m, tea.Quit
}
return m, nil
default:
// Delegate navigation (up/down/pgup/pgdown/filter/backspace) to selectorModel
m.modalSelector.updateNavigation(msg)
}
}
return m, nil
}
switch msg := msg.(type) {
case tea.WindowSizeMsg:
m.width = msg.Width
return m, nil
case tea.KeyMsg:
switch msg.String() {
case "ctrl+c", "q", "esc":
@@ -488,150 +190,78 @@ func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
if m.cursor > 0 {
m.cursor--
}
// Auto-collapse "Others" when cursor moves back into pinned items
if m.showOthers && m.cursor < len(mainMenuItems) {
m.showOthers = false
m.buildItems()
m.items = buildMenuItems(m.state, false)
m.cursor = min(m.cursor, len(m.items)-1)
}
return m, nil
case "down", "j":
if m.cursor < len(m.items)-1 {
m.cursor++
}
// Auto-expand "Others..." when cursor lands on it
if m.cursor < len(m.items) && m.items[m.cursor].isOthers && !m.showOthers {
m.showOthers = true
m.buildItems()
// cursor now points at the first "other" integration
m.items = buildMenuItems(m.state, true)
}
return m, nil
case "enter", " ":
item := m.items[m.cursor]
if item.integration != "" && !config.IsIntegrationInstalled(item.integration) {
return m, nil
if m.selectableItem(m.items[m.cursor]) {
m.selected = true
m.action = actionForMenuItem(m.items[m.cursor], false)
m.quitting = true
return m, tea.Quit
}
var configuredModel string
if item.isRunModel {
configuredModel = config.LastModel()
} else if item.integration != "" {
configuredModel = config.IntegrationModel(item.integration)
}
if cmd := m.checkCloudSignIn(configuredModel, false); cmd != nil {
return m, cmd
}
if configuredModel != "" && isCloudModel(configuredModel) && cloudModelDisabled(configuredModel) {
if item.integration != "" && config.IsEditorIntegration(item.integration) {
m.openMultiModelModal(item.integration)
} else {
m.openModelModal(configuredModel)
}
return m, nil
}
m.selected = true
m.quitting = true
return m, tea.Quit
return m, nil
case "right", "l":
item := m.items[m.cursor]
if item.integration != "" || item.isRunModel {
if item.integration != "" && !config.IsIntegrationInstalled(item.integration) {
return m, nil
}
if item.integration != "" && config.IsEditorIntegration(item.integration) {
m.openMultiModelModal(item.integration)
} else {
var currentModel string
if item.isRunModel {
currentModel = config.LastModel()
} else if item.integration != "" {
currentModel = config.IntegrationModel(item.integration)
}
m.openModelModal(currentModel)
}
if item.isRunModel || m.changeableItem(item) {
m.selected = true
m.action = actionForMenuItem(item, true)
m.quitting = true
return m, tea.Quit
}
return m, nil
}
}
return m, nil
}
func (m model) selectableItem(item menuItem) bool {
if item.isRunModel {
return true
}
if item.integration == "" || item.isOthers {
return false
}
state, ok := m.state.Integrations[item.integration]
return ok && state.Selectable
}
func (m model) changeableItem(item menuItem) bool {
if item.integration == "" || item.isOthers {
return false
}
state, ok := m.state.Integrations[item.integration]
return ok && state.Changeable
}
func (m model) View() string {
if m.quitting {
return ""
}
if m.showingSignIn {
return m.renderSignInDialog()
}
if m.showingMultiModal {
return m.multiModalSelector.View()
}
if m.showingModal {
return m.renderModal()
}
s := selectorTitleStyle.Render("Ollama "+versionStyle.Render(version.Version)) + "\n\n"
for i, item := range m.items {
cursor := ""
style := menuItemStyle
isInstalled := true
if item.integration != "" {
isInstalled = config.IsIntegrationInstalled(item.integration)
}
if m.cursor == i {
cursor = "▸ "
if isInstalled {
style = menuSelectedItemStyle
} else {
style = greyedSelectedStyle
}
} else if !isInstalled && item.integration != "" {
style = greyedStyle
}
title := item.title
var modelSuffix string
if item.integration != "" {
if !isInstalled {
title += " " + notInstalledStyle.Render("(not installed)")
} else if m.cursor == i {
if mdl := config.IntegrationModel(item.integration); mdl != "" && m.modelExists(mdl) {
modelSuffix = " " + modelStyle.Render("("+mdl+")")
}
}
} else if item.isRunModel && m.cursor == i {
if mdl := config.LastModel(); mdl != "" && m.modelExists(mdl) {
modelSuffix = " " + modelStyle.Render("("+mdl+")")
}
}
s += style.Render(cursor+title) + modelSuffix + "\n"
desc := item.description
if !isInstalled && item.integration != "" && m.cursor == i {
if hint := config.IntegrationInstallHint(item.integration); hint != "" {
desc = hint
} else {
desc = "not installed"
}
}
s += menuDescStyle.Render(desc) + "\n\n"
s += m.renderMenuItem(i, item)
}
if m.statusMsg != "" {
s += "\n" + lipgloss.NewStyle().Foreground(lipgloss.AdaptiveColor{Light: "124", Dark: "210"}).Render(m.statusMsg) + "\n"
}
s += "\n" + selectorHelpStyle.Render("↑/↓ navigate • enter launch • → change model • esc quit")
s += "\n" + selectorHelpStyle.Render("↑/↓ navigate • enter launch • → configure • esc quit")
if m.width > 0 {
return lipgloss.NewStyle().MaxWidth(m.width).Render(s)
@@ -639,80 +269,125 @@ func (m model) View() string {
return s
}
func (m model) renderModal() string {
modalStyle := lipgloss.NewStyle().
PaddingBottom(1).
PaddingRight(2)
func (m model) renderMenuItem(index int, item menuItem) string {
cursor := ""
style := menuItemStyle
title := item.title
description := item.description
modelSuffix := ""
s := modalStyle.Render(m.modalSelector.renderContent())
if m.width > 0 {
return lipgloss.NewStyle().MaxWidth(m.width).Render(s)
}
return s
}
func (m model) renderSignInDialog() string {
return renderSignIn(m.signInModel, m.signInURL, m.signInSpinner, m.width)
}
type Selection int
const (
SelectionNone Selection = iota
SelectionRunModel
SelectionChangeRunModel
SelectionIntegration // Generic integration selection
SelectionChangeIntegration // Generic change model for integration
)
type Result struct {
Selection Selection
Integration string // integration name if applicable
Model string // model name if selected from single-select modal
Models []string // models selected from multi-select modal (Editor integrations)
}
func Run() (Result, error) {
m := initialModel()
p := tea.NewProgram(m)
finalModel, err := p.Run()
if err != nil {
return Result{Selection: SelectionNone}, fmt.Errorf("error running TUI: %w", err)
}
fm := finalModel.(model)
if fm.err != nil {
return Result{Selection: SelectionNone}, fm.err
}
if !fm.selected && !fm.changeModel {
return Result{Selection: SelectionNone}, nil
}
item := fm.items[fm.cursor]
if fm.changeModel {
if item.isRunModel {
return Result{
Selection: SelectionChangeRunModel,
Model: fm.modalSelector.selected,
}, nil
}
return Result{
Selection: SelectionChangeIntegration,
Integration: item.integration,
Model: fm.modalSelector.selected,
Models: fm.changeModels,
}, nil
if m.cursor == index {
cursor = "▸ "
}
if item.isRunModel {
return Result{Selection: SelectionRunModel}, nil
if m.cursor == index && m.state.RunModel != "" {
modelSuffix = " " + modelStyle.Render("("+m.state.RunModel+")")
}
if m.cursor == index {
style = menuSelectedItemStyle
}
} else if item.isOthers {
if m.cursor == index {
style = menuSelectedItemStyle
}
} else {
integrationState := m.state.Integrations[item.integration]
if !integrationState.Selectable {
if m.cursor == index {
style = greyedSelectedStyle
} else {
style = greyedStyle
}
} else if m.cursor == index {
style = menuSelectedItemStyle
}
if m.cursor == index && integrationState.CurrentModel != "" {
modelSuffix = " " + modelStyle.Render("("+integrationState.CurrentModel+")")
}
if !integrationState.Installed {
if integrationState.AutoInstallable {
title += " " + notInstalledStyle.Render("(install)")
} else {
title += " " + notInstalledStyle.Render("(not installed)")
}
if m.cursor == index {
if integrationState.AutoInstallable {
description = "Press enter to install"
} else if integrationState.InstallHint != "" {
description = integrationState.InstallHint
} else {
description = "not installed"
}
}
}
}
return Result{
Selection: SelectionIntegration,
Integration: item.integration,
}, nil
return style.Render(cursor+title) + modelSuffix + "\n" + menuDescStyle.Render(description) + "\n\n"
}
type TUIActionKind int
const (
TUIActionNone TUIActionKind = iota
TUIActionRunModel
TUIActionLaunchIntegration
)
type TUIAction struct {
Kind TUIActionKind
Integration string
ForceConfigure bool
}
func (a TUIAction) LastSelection() string {
switch a.Kind {
case TUIActionRunModel:
return "run"
case TUIActionLaunchIntegration:
return a.Integration
default:
return ""
}
}
func (a TUIAction) RunModelRequest() launch.RunModelRequest {
return launch.RunModelRequest{ForcePicker: a.ForceConfigure}
}
func (a TUIAction) IntegrationLaunchRequest() launch.IntegrationLaunchRequest {
return launch.IntegrationLaunchRequest{
Name: a.Integration,
ForceConfigure: a.ForceConfigure,
}
}
func actionForMenuItem(item menuItem, forceConfigure bool) TUIAction {
switch {
case item.isRunModel:
return TUIAction{Kind: TUIActionRunModel, ForceConfigure: forceConfigure}
case item.integration != "":
return TUIAction{Kind: TUIActionLaunchIntegration, Integration: item.integration, ForceConfigure: forceConfigure}
default:
return TUIAction{Kind: TUIActionNone}
}
}
func RunMenu(state *launch.LauncherState) (TUIAction, error) {
menu := newModel(state)
program := tea.NewProgram(menu)
finalModel, err := program.Run()
if err != nil {
return TUIAction{Kind: TUIActionNone}, fmt.Errorf("error running TUI: %w", err)
}
finalMenu := finalModel.(model)
if !finalMenu.selected {
return TUIAction{Kind: TUIActionNone}, nil
}
return finalMenu.action, nil
}

178
cmd/tui/tui_test.go Normal file
View File

@@ -0,0 +1,178 @@
package tui
import (
"strings"
"testing"
tea "github.com/charmbracelet/bubbletea"
"github.com/ollama/ollama/cmd/launch"
)
func launcherTestState() *launch.LauncherState {
return &launch.LauncherState{
LastSelection: "run",
RunModel: "qwen3:8b",
Integrations: map[string]launch.LauncherIntegrationState{
"claude": {
Name: "claude",
DisplayName: "Claude Code",
Description: "Anthropic's coding tool with subagents",
Selectable: true,
Changeable: true,
CurrentModel: "glm-5:cloud",
},
"codex": {
Name: "codex",
DisplayName: "Codex",
Description: "OpenAI's open-source coding agent",
Selectable: true,
Changeable: true,
},
"openclaw": {
Name: "openclaw",
DisplayName: "OpenClaw",
Description: "Personal AI with 100+ skills",
Selectable: true,
Changeable: true,
AutoInstallable: true,
},
"droid": {
Name: "droid",
DisplayName: "Droid",
Description: "Factory's coding agent across terminal and IDEs",
Selectable: true,
Changeable: true,
},
"pi": {
Name: "pi",
DisplayName: "Pi",
Description: "Minimal AI agent toolkit with plugin support",
Selectable: true,
Changeable: true,
},
},
}
}
func TestMenuRendersPinnedItemsAndMore(t *testing.T) {
view := newModel(launcherTestState()).View()
for _, want := range []string{"Run a model", "Launch Claude Code", "Launch Codex", "Launch OpenClaw", "More..."} {
if !strings.Contains(view, want) {
t.Fatalf("expected menu view to contain %q\n%s", want, view)
}
}
}
func TestMenuExpandsOthersFromLastSelection(t *testing.T) {
state := launcherTestState()
state.LastSelection = "pi"
menu := newModel(state)
if !menu.showOthers {
t.Fatal("expected others section to expand when last selection is in the overflow list")
}
view := menu.View()
if !strings.Contains(view, "Launch Pi") {
t.Fatalf("expected expanded view to contain overflow integration\n%s", view)
}
if strings.Contains(view, "More...") {
t.Fatalf("expected expanded view to replace More... item\n%s", view)
}
}
func TestMenuEnterOnRunSelectsRun(t *testing.T) {
menu := newModel(launcherTestState())
updated, _ := menu.Update(tea.KeyMsg{Type: tea.KeyEnter})
got := updated.(model)
want := TUIAction{Kind: TUIActionRunModel}
if !got.selected || got.action != want {
t.Fatalf("expected enter on run to select run action, got selected=%v action=%v", got.selected, got.action)
}
}
func TestMenuRightOnRunSelectsChangeRun(t *testing.T) {
menu := newModel(launcherTestState())
updated, _ := menu.Update(tea.KeyMsg{Type: tea.KeyRight})
got := updated.(model)
want := TUIAction{Kind: TUIActionRunModel, ForceConfigure: true}
if !got.selected || got.action != want {
t.Fatalf("expected right on run to select change-run action, got selected=%v action=%v", got.selected, got.action)
}
}
func TestMenuEnterOnIntegrationSelectsLaunch(t *testing.T) {
menu := newModel(launcherTestState())
menu.cursor = 1
updated, _ := menu.Update(tea.KeyMsg{Type: tea.KeyEnter})
got := updated.(model)
want := TUIAction{Kind: TUIActionLaunchIntegration, Integration: "claude"}
if !got.selected || got.action != want {
t.Fatalf("expected enter on integration to launch, got selected=%v action=%v", got.selected, got.action)
}
}
func TestMenuRightOnIntegrationSelectsConfigure(t *testing.T) {
menu := newModel(launcherTestState())
menu.cursor = 1
updated, _ := menu.Update(tea.KeyMsg{Type: tea.KeyRight})
got := updated.(model)
want := TUIAction{Kind: TUIActionLaunchIntegration, Integration: "claude", ForceConfigure: true}
if !got.selected || got.action != want {
t.Fatalf("expected right on integration to configure, got selected=%v action=%v", got.selected, got.action)
}
}
func TestMenuIgnoresDisabledActions(t *testing.T) {
state := launcherTestState()
claude := state.Integrations["claude"]
claude.Selectable = false
claude.Changeable = false
state.Integrations["claude"] = claude
menu := newModel(state)
menu.cursor = 1
updatedEnter, _ := menu.Update(tea.KeyMsg{Type: tea.KeyEnter})
if updatedEnter.(model).selected {
t.Fatal("expected non-selectable integration to ignore enter")
}
updatedRight, _ := menu.Update(tea.KeyMsg{Type: tea.KeyRight})
if updatedRight.(model).selected {
t.Fatal("expected non-changeable integration to ignore right")
}
}
func TestMenuShowsCurrentModelSuffixes(t *testing.T) {
menu := newModel(launcherTestState())
runView := menu.View()
if !strings.Contains(runView, "(qwen3:8b)") {
t.Fatalf("expected run row to show current model suffix\n%s", runView)
}
menu.cursor = 1
integrationView := menu.View()
if !strings.Contains(integrationView, "(glm-5:cloud)") {
t.Fatalf("expected integration row to show current model suffix\n%s", integrationView)
}
}
func TestMenuShowsInstallStatusAndHint(t *testing.T) {
state := launcherTestState()
codex := state.Integrations["codex"]
codex.Installed = false
codex.Selectable = false
codex.Changeable = false
codex.InstallHint = "Install from https://example.com/codex"
state.Integrations["codex"] = codex
menu := newModel(state)
menu.cursor = 2
view := menu.View()
if !strings.Contains(view, "(not installed)") {
t.Fatalf("expected not-installed marker\n%s", view)
}
if !strings.Contains(view, codex.InstallHint) {
t.Fatalf("expected install hint in description\n%s", view)
}
}

View File

@@ -257,10 +257,11 @@ func LoadModelMetadata(fsys fs.FS) (ModelKV, *Tokenizer, error) {
if err != nil {
return nil, nil, err
}
bts = sanitizeNonFiniteJSON(bts)
var p ModelParameters
if err := json.Unmarshal(bts, &p); err != nil {
return nil, nil, err
return nil, nil, fmt.Errorf("parse config.json: %w", err)
}
if len(p.Architectures) < 1 {
@@ -315,16 +316,20 @@ func LoadModelMetadata(fsys fs.FS) (ModelKV, *Tokenizer, error) {
conv = &glm4MoeLiteModel{}
case "GlmOcrForConditionalGeneration":
conv = &glmOcrModel{}
case "Lfm2ForCausalLM":
case "Lfm2ForCausalLM", "Lfm2MoeForCausalLM":
conv = &lfm2Model{}
case "Qwen3NextForCausalLM":
case "Lfm2VlForConditionalGeneration":
conv = &lfm2VLTextModel{}
case "Qwen3NextForCausalLM", "Qwen3_5ForConditionalGeneration", "Qwen3_5MoeForConditionalGeneration":
conv = &qwen3NextModel{}
case "NemotronHForCausalLM":
conv = &nemotronHModel{}
default:
return nil, nil, fmt.Errorf("unsupported architecture %q", p.Architectures[0])
}
if err := json.Unmarshal(bts, conv); err != nil {
return nil, nil, err
return nil, nil, fmt.Errorf("parse config.json for %q: %w", p.Architectures[0], err)
}
if t, ok := conv.(moreParser); ok {

View File

@@ -1,6 +1,8 @@
package convert
import (
"cmp"
"fmt"
"slices"
"strings"
@@ -13,42 +15,149 @@ type lfm2Model struct {
NumHiddenLayers uint32 `json:"num_hidden_layers"`
MaxPositionEmbeddings uint32 `json:"max_position_embeddings"`
IntermediateSize uint32 `json:"intermediate_size"`
BlockFFDim uint32 `json:"block_ff_dim"`
BlockMultipleOf uint32 `json:"block_multiple_of"`
BlockAutoAdjustFFDim bool `json:"block_auto_adjust_ff_dim"`
BlockFFNDimMultiplier float32 `json:"block_ffn_dim_multiplier"`
NumAttentionHeads uint32 `json:"num_attention_heads"`
NumKeyValueHeads uint32 `json:"num_key_value_heads"`
RopeTheta float32 `json:"rope_theta"`
NormEps float32 `json:"norm_eps"`
ConvLCache uint32 `json:"conv_L_cache"`
MoEIntermediateSize uint32 `json:"moe_intermediate_size"`
NumExperts uint32 `json:"num_experts"`
NumLocalExperts uint32 `json:"num_local_experts"`
NumExpertsPerToken uint32 `json:"num_experts_per_tok"`
NumDenseLayers uint32 `json:"num_dense_layers"`
RoutedScalingFactor float32 `json:"routed_scaling_factor"`
LayerTypes []string `json:"layer_types"`
TieEmbedding bool `json:"tie_embedding"`
RopeParameters struct {
RopeTheta float32 `json:"rope_theta"`
} `json:"rope_parameters"`
}
var _ ModelConverter = (*lfm2Model)(nil)
const (
defaultMaxPositionEmbeddings = uint32(128_000)
fallbackContextLength = uint32(32_768)
)
func (p *lfm2Model) isMoE() bool {
return p.ModelType == "lfm2_moe" || p.expertCount() > 0
}
func (p *lfm2Model) ropeFreqBase() float32 {
if p.RopeTheta != 0 {
return p.RopeTheta
}
return p.RopeParameters.RopeTheta
}
func (p *lfm2Model) expertCount() uint32 {
if p.NumLocalExperts > 0 {
return p.NumLocalExperts
}
return p.NumExperts
}
func (p *lfm2Model) feedForwardLength() uint32 {
ff := p.IntermediateSize
if p.BlockFFDim != 0 {
ff = p.BlockFFDim
}
if !p.BlockAutoAdjustFFDim || p.BlockMultipleOf == 0 {
return ff
}
ff = (2 * ff) / 3
// Keep default multiplier behavior consistent with llama.cpp conversion.
if p.BlockFFNDimMultiplier != 0 {
ff = uint32(float32(ff) * p.BlockFFNDimMultiplier)
}
m := p.BlockMultipleOf
return m * ((ff + m - 1) / m)
}
func (p *lfm2Model) hasKnownContextLengthFallbackSignature() bool {
return p.isMoE() &&
p.VocabSize == 65536 &&
p.HiddenSize == 2048 &&
p.NumHiddenLayers == 40 &&
p.IntermediateSize == 11776 &&
p.NumAttentionHeads == 32 &&
p.NumKeyValueHeads == 8 &&
p.NumDenseLayers == 2 &&
p.expertCount() == 64 &&
p.NumExpertsPerToken == 4 &&
p.MoEIntermediateSize == 1536
}
func (p *lfm2Model) contextLength() uint32 {
if p.MaxPositionEmbeddings == defaultMaxPositionEmbeddings && p.hasKnownContextLengthFallbackSignature() {
return fallbackContextLength
}
return p.MaxPositionEmbeddings
}
func (p *lfm2Model) KV(t *Tokenizer) KV {
architecture := "lfm2"
if p.isMoE() {
architecture = "lfm2moe"
}
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "lfm2"
kv["lfm2.vocab_size"] = p.VocabSize
kv["lfm2.block_count"] = p.NumHiddenLayers
kv["lfm2.embedding_length"] = p.HiddenSize
kv["lfm2.feed_forward_length"] = p.IntermediateSize
kv["lfm2.context_length"] = p.MaxPositionEmbeddings
kv["general.architecture"] = architecture
kv["tokenizer.ggml.pre"] = "lfm2"
kv["vocab_size"] = p.VocabSize
kv["block_count"] = p.NumHiddenLayers
kv["embedding_length"] = p.HiddenSize
kv["feed_forward_length"] = p.feedForwardLength()
kv["context_length"] = p.contextLength()
// Build per-layer KV head count array based on layer_types
// (0 = shortconv layer, non-zero = attention layer with that many KV heads)
// (0 = shortconv layer, non-zero = attention layer with that many KV heads).
//
// Dense LFM2 in HF defaults to all attention layers when layer_types is absent.
// Preserve that behavior to avoid accidentally emitting all-conv metadata.
kvHeadCounts := make([]uint32, p.NumHiddenLayers)
for i := range p.NumHiddenLayers {
if int(i) < len(p.LayerTypes) && p.LayerTypes[i] == "full_attention" {
if len(p.LayerTypes) == 0 {
for i := range p.NumHiddenLayers {
kvHeadCounts[i] = p.NumKeyValueHeads
}
} else {
for i := range p.NumHiddenLayers {
if int(i) < len(p.LayerTypes) && p.LayerTypes[i] == "full_attention" {
kvHeadCounts[i] = p.NumKeyValueHeads
}
}
}
kv["lfm2.attention.head_count"] = p.NumAttentionHeads
kv["lfm2.attention.head_count_kv"] = kvHeadCounts
kv["lfm2.attention.key_length"] = p.HiddenSize / p.NumAttentionHeads
kv["lfm2.attention.value_length"] = p.HiddenSize / p.NumAttentionHeads
kv["lfm2.attention.layer_norm_rms_epsilon"] = p.NormEps
kv["lfm2.rope.freq_base"] = p.RopeTheta
kv["lfm2.shortconv.l_cache"] = p.ConvLCache
kv["attention.head_count"] = p.NumAttentionHeads
kv["attention.head_count_kv"] = kvHeadCounts
kv["attention.key_length"] = p.HiddenSize / p.NumAttentionHeads
kv["attention.value_length"] = p.HiddenSize / p.NumAttentionHeads
kv["attention.layer_norm_rms_epsilon"] = p.NormEps
kv["shortconv.l_cache"] = p.ConvLCache
if ropeFreqBase := p.ropeFreqBase(); ropeFreqBase != 0 {
kv["rope.freq_base"] = ropeFreqBase
}
if p.isMoE() {
kv["expert_count"] = p.expertCount()
kv["expert_used_count"] = p.NumExpertsPerToken
kv["expert_feed_forward_length"] = p.MoEIntermediateSize
kv["leading_dense_block_count"] = p.NumDenseLayers
kv["expert_gating_func"] = uint32(2) // sigmoid
kv["expert_weights_scale"] = cmp.Or(p.RoutedScalingFactor, float32(1.0))
}
return kv
}
@@ -56,6 +165,30 @@ func (p *lfm2Model) KV(t *Tokenizer) KV {
func (p *lfm2Model) Tensors(ts []Tensor) []*ggml.Tensor {
var out []*ggml.Tensor
if p.isMoE() {
merges := make([]merge, 0, p.NumHiddenLayers*3)
for i := range p.NumHiddenLayers {
if i < p.NumDenseLayers {
continue
}
merges = append(merges, merge{
fmt.Sprintf("blk.%d.feed_forward.experts.*.w1.weight", i),
fmt.Sprintf("blk.%d.ffn_gate_exps.weight", i),
}, merge{
fmt.Sprintf("blk.%d.feed_forward.experts.*.w2.weight", i),
fmt.Sprintf("blk.%d.ffn_down_exps.weight", i),
}, merge{
fmt.Sprintf("blk.%d.feed_forward.experts.*.w3.weight", i),
fmt.Sprintf("blk.%d.ffn_up_exps.weight", i),
})
}
merged, remaining := mergeTensors(ts, merges...)
out = append(out, merged...)
ts = remaining
}
for _, t := range ts {
shape := t.Shape()
@@ -80,7 +213,7 @@ func (p *lfm2Model) Tensors(ts []Tensor) []*ggml.Tensor {
func (p *lfm2Model) Replacements() []string {
return []string{
"model.embed_tokens", "token_embd",
"model.embedding_norm", "output_norm",
"model.embedding_norm", "token_embd_norm",
"model.layers", "blk",
"operator_norm", "attn_norm",
"self_attn.q_proj", "attn_q",
@@ -92,6 +225,8 @@ func (p *lfm2Model) Replacements() []string {
"conv.conv", "shortconv.conv",
"conv.in_proj", "shortconv.in_proj",
"conv.out_proj", "shortconv.out_proj",
"feed_forward.gate", "ffn_gate_inp",
"feed_forward.expert_bias", "exp_probs_b.bias",
"feed_forward.w1", "ffn_gate",
"feed_forward.w2", "ffn_down",
"feed_forward.w3", "ffn_up",

View File

@@ -0,0 +1,271 @@
package convert
import (
"io"
"slices"
"strings"
"testing"
)
type lfm2StubTensor struct {
tensorBase
}
func newLFM2StubTensor(name string, shape []uint64) *lfm2StubTensor {
return &lfm2StubTensor{
tensorBase: tensorBase{
name: name,
shape: shape,
},
}
}
func (t *lfm2StubTensor) WriteTo(io.Writer) (int64, error) {
return 0, nil
}
func (t *lfm2StubTensor) Clone() Tensor {
return &lfm2StubTensor{
tensorBase: tensorBase{
name: t.name,
shape: slices.Clone(t.shape),
},
}
}
func TestLFM2MoEKV(t *testing.T) {
var p lfm2Model
p.ModelParameters.ModelType = "lfm2_moe"
p.VocabSize = 65536
p.HiddenSize = 2048
p.NumHiddenLayers = 4
p.MaxPositionEmbeddings = 128000
p.IntermediateSize = 11776
p.NumAttentionHeads = 32
p.NumKeyValueHeads = 8
p.LayerTypes = []string{"conv", "full_attention", "conv", "full_attention"}
p.NormEps = 1e-5
p.ConvLCache = 3
p.MoEIntermediateSize = 1536
p.NumExperts = 64
p.NumExpertsPerToken = 4
p.NumDenseLayers = 2
p.RopeParameters.RopeTheta = 1_000_000
kv := p.KV(&Tokenizer{Vocabulary: &Vocabulary{Model: "gpt2"}})
if got, want := kv["general.architecture"], "lfm2moe"; got != want {
t.Fatalf("general.architecture = %v, want %v", got, want)
}
if got, want := kv["tokenizer.ggml.pre"], "lfm2"; got != want {
t.Fatalf("tokenizer.ggml.pre = %v, want %v", got, want)
}
if got, want := kv["expert_count"], uint32(64); got != want {
t.Fatalf("expert_count = %v, want %v", got, want)
}
if got, want := kv["expert_used_count"], uint32(4); got != want {
t.Fatalf("expert_used_count = %v, want %v", got, want)
}
if got, want := kv["expert_feed_forward_length"], uint32(1536); got != want {
t.Fatalf("expert_feed_forward_length = %v, want %v", got, want)
}
if got, want := kv["leading_dense_block_count"], uint32(2); got != want {
t.Fatalf("leading_dense_block_count = %v, want %v", got, want)
}
if got, want := kv["expert_gating_func"], uint32(2); got != want {
t.Fatalf("expert_gating_func = %v, want %v", got, want)
}
gotHeadCounts, ok := kv["attention.head_count_kv"].([]uint32)
if !ok {
t.Fatalf("attention.head_count_kv has unexpected type %T", kv["attention.head_count_kv"])
}
wantHeadCounts := []uint32{0, 8, 0, 8}
if !slices.Equal(gotHeadCounts, wantHeadCounts) {
t.Fatalf("attention.head_count_kv = %v, want %v", gotHeadCounts, wantHeadCounts)
}
if got, want := kv["rope.freq_base"], float32(1_000_000); got != want {
t.Fatalf("rope.freq_base = %v, want %v", got, want)
}
}
func TestLFM2DenseKV(t *testing.T) {
p := lfm2Model{
ModelParameters: ModelParameters{ModelType: "lfm2", VocabSize: 32000},
HiddenSize: 1024,
NumHiddenLayers: 2,
MaxPositionEmbeddings: 32768,
IntermediateSize: 4096,
NumAttentionHeads: 16,
NumKeyValueHeads: 4,
LayerTypes: []string{"conv", "full_attention"},
NormEps: 1e-5,
ConvLCache: 3,
RopeTheta: 10000,
}
kv := p.KV(&Tokenizer{Vocabulary: &Vocabulary{Model: "gpt2"}})
if got, want := kv["general.architecture"], "lfm2"; got != want {
t.Fatalf("general.architecture = %v, want %v", got, want)
}
if got, want := kv["tokenizer.ggml.pre"], "lfm2"; got != want {
t.Fatalf("tokenizer.ggml.pre = %v, want %v", got, want)
}
if _, ok := kv["expert_count"]; ok {
t.Fatalf("expert_count should not be set for dense lfm2")
}
}
func TestLFM2MoETensors(t *testing.T) {
p := lfm2Model{
ModelParameters: ModelParameters{ModelType: "lfm2_moe"},
NumHiddenLayers: 4,
NumDenseLayers: 2,
}
in := []Tensor{
newLFM2StubTensor("blk.2.feed_forward.experts.0.w1.weight", []uint64{1536, 2048}),
newLFM2StubTensor("blk.2.feed_forward.experts.1.w1.weight", []uint64{1536, 2048}),
newLFM2StubTensor("blk.2.feed_forward.experts.0.w2.weight", []uint64{2048, 1536}),
newLFM2StubTensor("blk.2.feed_forward.experts.1.w2.weight", []uint64{2048, 1536}),
newLFM2StubTensor("blk.2.feed_forward.experts.0.w3.weight", []uint64{1536, 2048}),
newLFM2StubTensor("blk.2.feed_forward.experts.1.w3.weight", []uint64{1536, 2048}),
newLFM2StubTensor("blk.0.shortconv.conv.weight", []uint64{2048, 1, 3}),
}
out := p.Tensors(in)
byName := make(map[string][]uint64, len(out))
for _, tns := range out {
byName[tns.Name] = tns.Shape
}
if got, ok := byName["blk.2.ffn_gate_exps.weight"]; !ok {
t.Fatalf("missing merged tensor blk.2.ffn_gate_exps.weight")
} else if !slices.Equal(got, []uint64{2, 1536, 2048}) {
t.Fatalf("blk.2.ffn_gate_exps.weight shape = %v, want [2 1536 2048]", got)
}
if got, ok := byName["blk.2.ffn_down_exps.weight"]; !ok {
t.Fatalf("missing merged tensor blk.2.ffn_down_exps.weight")
} else if !slices.Equal(got, []uint64{2, 2048, 1536}) {
t.Fatalf("blk.2.ffn_down_exps.weight shape = %v, want [2 2048 1536]", got)
}
if got, ok := byName["blk.2.ffn_up_exps.weight"]; !ok {
t.Fatalf("missing merged tensor blk.2.ffn_up_exps.weight")
} else if !slices.Equal(got, []uint64{2, 1536, 2048}) {
t.Fatalf("blk.2.ffn_up_exps.weight shape = %v, want [2 1536 2048]", got)
}
if got, ok := byName["blk.0.shortconv.conv.weight"]; !ok {
t.Fatalf("missing shortconv tensor")
} else if !slices.Equal(got, []uint64{2048, 3}) {
t.Fatalf("blk.0.shortconv.conv.weight shape = %v, want [2048 3]", got)
}
if _, ok := byName["blk.2.feed_forward.experts.0.w1.weight"]; ok {
t.Fatalf("unmerged expert tensor should not be present")
}
}
func TestLFM2MoEReplacements(t *testing.T) {
p := lfm2Model{}
replacer := strings.NewReplacer(p.Replacements()...)
if got, want := replacer.Replace("model.layers.2.feed_forward.expert_bias"), "blk.2.exp_probs_b.bias"; got != want {
t.Fatalf("expert bias replacement = %q, want %q", got, want)
}
if got, want := replacer.Replace("model.layers.2.feed_forward.gate.weight"), "blk.2.ffn_gate_inp.weight"; got != want {
t.Fatalf("gate replacement = %q, want %q", got, want)
}
}
func TestLFM2KVContextLengthEdgeCaseFallbackOverride(t *testing.T) {
p := lfm2Model{
ModelParameters: ModelParameters{ModelType: "lfm2_moe", VocabSize: 65536},
HiddenSize: 2048,
NumHiddenLayers: 40,
MaxPositionEmbeddings: 128000,
IntermediateSize: 11776,
NumAttentionHeads: 32,
NumKeyValueHeads: 8,
LayerTypes: make([]string, 40),
NormEps: 1e-5,
ConvLCache: 3,
MoEIntermediateSize: 1536,
NumExperts: 64,
NumExpertsPerToken: 4,
NumDenseLayers: 2,
}
for i := 0; i < len(p.LayerTypes); i++ {
p.LayerTypes[i] = "conv"
}
p.LayerTypes[2] = "full_attention"
kv := p.KV(&Tokenizer{Vocabulary: &Vocabulary{Model: "gpt2"}})
if got, want := kv["context_length"], uint32(32768); got != want {
t.Fatalf("context_length = %v, want %v", got, want)
}
}
func TestLFM2KVContextLengthNoOverride(t *testing.T) {
p := lfm2Model{
ModelParameters: ModelParameters{ModelType: "lfm2_moe", VocabSize: 65536},
HiddenSize: 2048,
NumHiddenLayers: 39, // mismatch: should not trigger edge case
MaxPositionEmbeddings: 128000,
IntermediateSize: 11776,
NumAttentionHeads: 32,
NumKeyValueHeads: 8,
LayerTypes: []string{"conv", "full_attention"},
NormEps: 1e-5,
ConvLCache: 3,
MoEIntermediateSize: 1536,
NumExperts: 64,
NumExpertsPerToken: 4,
NumDenseLayers: 2,
}
kv := p.KV(&Tokenizer{Vocabulary: &Vocabulary{Model: "gpt2"}})
if got, want := kv["context_length"], uint32(128000); got != want {
t.Fatalf("context_length = %v, want %v", got, want)
}
}
func TestLFM2KVFeedForwardLengthAutoAdjust(t *testing.T) {
p := lfm2Model{
ModelParameters: ModelParameters{ModelType: "lfm2", VocabSize: 65536},
HiddenSize: 2048,
NumHiddenLayers: 16,
MaxPositionEmbeddings: 128000,
IntermediateSize: 12288, // should be ignored when block_ff_dim is set
BlockFFDim: 12288,
BlockAutoAdjustFFDim: true,
BlockMultipleOf: 256,
BlockFFNDimMultiplier: 1.0,
NumAttentionHeads: 32,
NumKeyValueHeads: 8,
LayerTypes: []string{"conv", "full_attention"},
NormEps: 1e-5,
ConvLCache: 3,
}
kv := p.KV(&Tokenizer{Vocabulary: &Vocabulary{Model: "gpt2"}})
if got, want := kv["feed_forward_length"], uint32(8192); got != want {
t.Fatalf("feed_forward_length = %v, want %v", got, want)
}
}

417
convert/convert_lfm2_vl.go Normal file
View File

@@ -0,0 +1,417 @@
package convert
import (
"cmp"
"encoding/json"
"errors"
"fmt"
"io/fs"
"slices"
"strings"
"github.com/ollama/ollama/fs/ggml"
)
// lfm2VLTextModel converts the language model component of LFM2 VL checkpoints.
type lfm2VLTextModel struct {
TextConfig lfm2Model `json:"text_config"`
DoImageSplitting *bool `json:"do_image_splitting"`
DownsampleFactor uint32 `json:"downsample_factor"`
EncoderPatchSize uint32 `json:"encoder_patch_size"`
ImageTokenID uint32 `json:"image_token_id"`
MaxImageTokens uint32 `json:"max_image_tokens"`
MinImageTokens uint32 `json:"min_image_tokens"`
MaxTiles uint32 `json:"max_tiles"`
MinTiles uint32 `json:"min_tiles"`
TileSize uint32 `json:"tile_size"`
MaxPixelsTolerance float32 `json:"max_pixels_tolerance"`
ProjectorUseLayernorm bool `json:"projector_use_layernorm"`
ProjectorHiddenSize uint32 `json:"projector_hidden_size"`
ProjectorHiddenAct string `json:"projector_hidden_act"`
UseImageSpecialTokens *bool `json:"use_image_special_tokens"`
UseThumbnail *bool `json:"use_thumbnail"`
VisionConfig struct {
HiddenSize uint32 `json:"hidden_size"`
IntermediateSize uint32 `json:"intermediate_size"`
NumAttentionHeads uint32 `json:"num_attention_heads"`
NumHiddenLayers uint32 `json:"num_hidden_layers"`
NumChannels uint32 `json:"num_channels"`
PatchSize uint32 `json:"patch_size"`
LayerNormEpsilon float32 `json:"layer_norm_eps"`
} `json:"vision_config"`
Processor struct {
ImageProcessor struct {
DoImageSplitting *bool `json:"do_image_splitting"`
DownsampleFactor uint32 `json:"downsample_factor"`
MaxImageTokens uint32 `json:"max_image_tokens"`
MinImageTokens uint32 `json:"min_image_tokens"`
MaxTiles uint32 `json:"max_tiles"`
MinTiles uint32 `json:"min_tiles"`
MaxPixelsTol float32 `json:"max_pixels_tolerance"`
TileSize uint32 `json:"tile_size"`
UseThumbnail *bool `json:"use_thumbnail"`
ImageMean []float32 `json:"image_mean"`
ImageStd []float32 `json:"image_std"`
Size struct {
Height uint32 `json:"height"`
Width uint32 `json:"width"`
} `json:"size"`
} `json:"image_processor"`
}
}
func (p *lfm2VLTextModel) textModel() *lfm2Model {
return &p.TextConfig
}
func (p *lfm2VLTextModel) specialTokenTypes() []string {
return p.textModel().specialTokenTypes()
}
func (p *lfm2VLTextModel) parseMore(fsys fs.FS) error {
bts, err := fs.ReadFile(fsys, "processor_config.json")
if err != nil {
if errors.Is(err, fs.ErrNotExist) {
return nil
}
return err
}
return json.Unmarshal(bts, &p.Processor)
}
func (p *lfm2VLTextModel) visionImageSize() uint32 {
// LFM2-VL image processor operates on 512 tiles and downsamples by factor 2
// before projection. Keep a fixed square image size compatible with position
// embeddings and the simplified runtime image pipeline.
tile := cmp.Or(
p.Processor.ImageProcessor.TileSize,
p.Processor.ImageProcessor.Size.Height,
p.Processor.ImageProcessor.Size.Width,
uint32(512),
)
downsample := cmp.Or(p.DownsampleFactor, p.Processor.ImageProcessor.DownsampleFactor, uint32(2))
if downsample == 0 {
return tile
}
return max(uint32(1), tile/downsample)
}
func (p *lfm2VLTextModel) KV(t *Tokenizer) KV {
kv := p.textModel().KV(t)
boolOr := func(defaultValue bool, values ...*bool) bool {
for _, v := range values {
if v != nil {
return *v
}
}
return defaultValue
}
kv["vision.block_count"] = cmp.Or(p.VisionConfig.NumHiddenLayers, uint32(27))
kv["vision.embedding_length"] = cmp.Or(p.VisionConfig.HiddenSize, uint32(1152))
kv["vision.feed_forward_length"] = cmp.Or(p.VisionConfig.IntermediateSize, uint32(4304))
kv["vision.attention.head_count"] = cmp.Or(p.VisionConfig.NumAttentionHeads, uint32(16))
kv["vision.attention.layer_norm_epsilon"] = cmp.Or(p.VisionConfig.LayerNormEpsilon, float32(1e-6))
kv["vision.patch_size"] = cmp.Or(p.VisionConfig.PatchSize, p.EncoderPatchSize, uint32(16))
kv["vision.num_channels"] = cmp.Or(p.VisionConfig.NumChannels, uint32(3))
kv["vision.image_size"] = p.visionImageSize()
kv["vision.projector.scale_factor"] = cmp.Or(p.DownsampleFactor, p.Processor.ImageProcessor.DownsampleFactor, uint32(2))
kv["vision.projector.use_layernorm"] = p.ProjectorUseLayernorm
kv["vision.do_image_splitting"] = boolOr(true, p.DoImageSplitting, p.Processor.ImageProcessor.DoImageSplitting)
kv["vision.min_tiles"] = cmp.Or(p.MinTiles, p.Processor.ImageProcessor.MinTiles, uint32(2))
kv["vision.max_tiles"] = cmp.Or(p.MaxTiles, p.Processor.ImageProcessor.MaxTiles, uint32(10))
kv["vision.tile_size"] = cmp.Or(p.TileSize, p.Processor.ImageProcessor.TileSize, uint32(512))
kv["vision.min_image_tokens"] = cmp.Or(p.MinImageTokens, p.Processor.ImageProcessor.MinImageTokens, uint32(64))
kv["vision.max_image_tokens"] = cmp.Or(p.MaxImageTokens, p.Processor.ImageProcessor.MaxImageTokens, uint32(256))
kv["vision.max_pixels_tolerance"] = cmp.Or(p.MaxPixelsTolerance, p.Processor.ImageProcessor.MaxPixelsTol, float32(2.0))
kv["vision.use_thumbnail"] = boolOr(true, p.UseThumbnail, p.Processor.ImageProcessor.UseThumbnail)
kv["vision.use_image_special_tokens"] = boolOr(true, p.UseImageSpecialTokens)
kv["vision.image_mean"] = slices.Clone(defaultFloat32Slice(p.Processor.ImageProcessor.ImageMean, []float32{0.5, 0.5, 0.5}))
kv["vision.image_std"] = slices.Clone(defaultFloat32Slice(p.Processor.ImageProcessor.ImageStd, []float32{0.5, 0.5, 0.5}))
kv["vision.image_token_id"] = cmp.Or(p.ImageTokenID, uint32(396))
setVisionTokenID := func(k, token string) {
if t == nil || t.Vocabulary == nil {
return
}
for i, v := range t.Vocabulary.Tokens {
if v == token {
kv[k] = uint32(i)
return
}
}
}
setVisionTokenID("vision.image_start_token_id", "<|image_start|>")
setVisionTokenID("vision.image_end_token_id", "<|image_end|>")
setVisionTokenID("vision.image_thumbnail_token_id", "<|img_thumbnail|>")
return kv
}
func (p *lfm2VLTextModel) Tensors(ts []Tensor) []*ggml.Tensor {
patchSize := int(cmp.Or(p.VisionConfig.PatchSize, p.EncoderPatchSize, uint32(16)))
numChannels := int(cmp.Or(p.VisionConfig.NumChannels, uint32(3)))
for _, t := range ts {
if t.Name() == "v.patch_embd.weight" {
shape := t.Shape()
if len(shape) == 2 {
inputDim := uint64(numChannels * patchSize * patchSize)
if shape[1] == inputDim {
channels := numChannels
patch := patchSize
t.SetRepacker(func(_ string, data []float32, srcShape []uint64) ([]float32, error) {
return repackPatchEmbeddingWeight(data, srcShape, channels, patch)
})
}
}
}
}
out := p.textModel().Tensors(ts)
for _, t := range out {
if t.Name == "v.patch_embd.weight" && len(t.Shape) == 2 {
t.Shape = []uint64{t.Shape[0], uint64(numChannels), uint64(patchSize), uint64(patchSize)}
}
}
return out
}
func (p *lfm2VLTextModel) Replacements() []string {
out := make([]string, 0, 96)
addText := func(from, to string) {
out = append(out, from, to)
if strings.HasPrefix(from, "model.") {
suffix := strings.TrimPrefix(from, "model.")
out = append(out,
"model.language_model."+suffix, to,
"model.language_model.model."+suffix, to,
)
}
}
base := p.textModel().Replacements()
for i := 0; i+1 < len(base); i += 2 {
addText(base[i], base[i+1])
}
// Vision tower + multimodal projector tensors (single-file conversion).
out = append(out,
"model.vision_tower.vision_model.embeddings.patch_embedding", "v.patch_embd",
"model.vision_tower.vision_model.embeddings.position_embedding", "v.position_embd",
"model.vision_tower.vision_model.encoder.layers", "v.blk",
"model.vision_tower.vision_model.post_layernorm", "v.post_ln",
"model.multi_modal_projector.layer_norm", "mm.layer_norm",
"model.multi_modal_projector.linear_1", "mm.1",
"model.multi_modal_projector.linear_2", "mm.2",
"self_attn.q_proj", "attn_q",
"self_attn.k_proj", "attn_k",
"self_attn.v_proj", "attn_v",
"self_attn.out_proj", "attn_out",
"layer_norm1", "ln1",
"layer_norm2", "ln2",
"mlp.fc1", "ffn_up",
"mlp.fc2", "ffn_down",
)
return out
}
// lfm2VLProjectorModel converts the vision encoder + projector component of LFM2 VL checkpoints.
type lfm2VLProjectorModel struct {
ModelParameters
DownsampleFactor uint32 `json:"downsample_factor"`
ProjectorHiddenDim uint32 `json:"projector_hidden_size"`
VisionModel struct {
HiddenSize uint32 `json:"hidden_size"`
IntermediateSize uint32 `json:"intermediate_size"`
NumAttentionHeads uint32 `json:"num_attention_heads"`
NumHiddenLayers uint32 `json:"num_hidden_layers"`
NumChannels uint32 `json:"num_channels"`
PatchSize uint32 `json:"patch_size"`
LayerNormEpsilon float32 `json:"layer_norm_eps"`
ImageSize uint32 `json:"image_size"`
} `json:"vision_config"`
Processor struct {
ImageProcessor struct {
DownsampleFactor uint32 `json:"downsample_factor"`
TileSize uint32 `json:"tile_size"`
ImageMean []float32 `json:"image_mean"`
ImageStd []float32 `json:"image_std"`
Size struct {
Height uint32 `json:"height"`
Width uint32 `json:"width"`
} `json:"size"`
} `json:"image_processor"`
}
}
var (
_ ModelConverter = (*lfm2VLTextModel)(nil)
_ ModelConverter = (*lfm2VLProjectorModel)(nil)
_ moreParser = (*lfm2VLTextModel)(nil)
_ moreParser = (*lfm2VLProjectorModel)(nil)
)
func (p *lfm2VLProjectorModel) parseMore(fsys fs.FS) error {
bts, err := fs.ReadFile(fsys, "processor_config.json")
if err != nil {
if errors.Is(err, fs.ErrNotExist) {
return nil
}
return err
}
return json.Unmarshal(bts, &p.Processor)
}
func (p *lfm2VLProjectorModel) imageSize() uint32 {
if p.VisionModel.ImageSize > 0 {
return p.VisionModel.ImageSize
}
downsample := cmp.Or(p.DownsampleFactor, p.Processor.ImageProcessor.DownsampleFactor, uint32(2))
baseSize := cmp.Or(
p.Processor.ImageProcessor.TileSize,
p.Processor.ImageProcessor.Size.Height,
p.Processor.ImageProcessor.Size.Width,
uint32(256),
)
if downsample == 0 {
return baseSize
}
return max(uint32(1), baseSize/downsample)
}
func (p *lfm2VLProjectorModel) KV(_ *Tokenizer) KV {
kv := KV{
"general.architecture": "clip",
"general.type": "mmproj",
"general.file_type": uint32(1),
"general.quantization_version": uint32(2),
"clip.has_vision_encoder": true,
"clip.projector_type": "lfm2",
"clip.use_gelu": true,
}
kv["clip.vision.block_count"] = cmp.Or(p.VisionModel.NumHiddenLayers, uint32(27))
kv["clip.vision.embedding_length"] = cmp.Or(p.VisionModel.HiddenSize, uint32(1152))
kv["clip.vision.feed_forward_length"] = cmp.Or(p.VisionModel.IntermediateSize, uint32(4304))
kv["clip.vision.attention.head_count"] = cmp.Or(p.VisionModel.NumAttentionHeads, uint32(16))
kv["clip.vision.attention.layer_norm_epsilon"] = cmp.Or(p.VisionModel.LayerNormEpsilon, float32(1e-6))
kv["clip.vision.patch_size"] = cmp.Or(p.VisionModel.PatchSize, uint32(16))
kv["clip.vision.image_size"] = p.imageSize()
kv["clip.vision.projection_dim"] = cmp.Or(p.ProjectorHiddenDim, uint32(2048))
kv["clip.vision.projector.scale_factor"] = cmp.Or(p.DownsampleFactor, p.Processor.ImageProcessor.DownsampleFactor, uint32(2))
kv["clip.vision.image_mean"] = slices.Clone(defaultFloat32Slice(p.Processor.ImageProcessor.ImageMean, []float32{0.5, 0.5, 0.5}))
kv["clip.vision.image_std"] = slices.Clone(defaultFloat32Slice(p.Processor.ImageProcessor.ImageStd, []float32{0.5, 0.5, 0.5}))
return kv
}
func defaultFloat32Slice(v, fallback []float32) []float32 {
if len(v) > 0 {
return v
}
return fallback
}
func (p *lfm2VLProjectorModel) Tensors(ts []Tensor) []*ggml.Tensor {
var out []*ggml.Tensor
numChannels := cmp.Or(p.VisionModel.NumChannels, uint32(3))
patchSize := cmp.Or(p.VisionModel.PatchSize, uint32(16))
for _, t := range ts {
name := t.Name()
if !(strings.HasPrefix(name, "v.") || strings.HasPrefix(name, "mm.")) {
continue
}
shape := t.Shape()
if name == "v.patch_embd.weight" && len(shape) == 2 {
inputDim := uint64(numChannels * patchSize * patchSize)
if shape[1] == inputDim {
shape = []uint64{shape[0], uint64(numChannels), uint64(patchSize), uint64(patchSize)}
channels := int(numChannels)
patch := int(patchSize)
t.SetRepacker(func(_ string, data []float32, srcShape []uint64) ([]float32, error) {
return repackPatchEmbeddingWeight(data, srcShape, channels, patch)
})
}
}
out = append(out, &ggml.Tensor{
Name: name,
Kind: t.Kind(),
Shape: slices.Clone(shape),
WriterTo: t,
})
}
return out
}
func (p *lfm2VLProjectorModel) Replacements() []string {
return []string{
"model.multi_modal_projector.linear_1", "mm.1",
"model.multi_modal_projector.linear_2", "mm.2",
"model.vision_tower.vision_model.embeddings.patch_embedding", "v.patch_embd",
"model.vision_tower.vision_model.embeddings.position_embedding", "v.position_embd",
"model.vision_tower.vision_model.encoder.layers", "v.blk",
"self_attn.q_proj", "attn_q",
"self_attn.k_proj", "attn_k",
"self_attn.v_proj", "attn_v",
"self_attn.out_proj", "attn_out",
"layer_norm1", "ln1",
"layer_norm2", "ln2",
"mlp.fc1", "ffn_up",
"mlp.fc2", "ffn_down",
"model.vision_tower.vision_model.post_layernorm", "v.post_ln",
}
}
func repackPatchEmbeddingWeight(data []float32, srcShape []uint64, channels, patch int) ([]float32, error) {
if len(srcShape) != 2 {
return nil, fmt.Errorf("invalid patch embedding shape rank: %d", len(srcShape))
}
outDim := int(srcShape[0])
flatInputDim := int(srcShape[1])
expectedInputDim := channels * patch * patch
if flatInputDim != expectedInputDim {
return nil, fmt.Errorf("invalid patch embedding input dim: got %d, want %d", flatInputDim, expectedInputDim)
}
expectedSize := outDim * flatInputDim
if len(data) != expectedSize {
return nil, fmt.Errorf("invalid patch embedding data size: got %d, want %d", len(data), expectedSize)
}
repacked := make([]float32, len(data))
perChannel := patch * patch
for o := range outDim {
inBase := o * flatInputDim
outBase := o * flatInputDim
for y := range patch {
for x := range patch {
inPixelBase := inBase + (y*patch+x)*channels
for c := range channels {
src := inPixelBase + c
dst := outBase + c*perChannel + y*patch + x
repacked[dst] = data[src]
}
}
}
}
return repacked, nil
}

View File

@@ -0,0 +1,249 @@
package convert
import (
"slices"
"strings"
"testing"
)
func TestLFM2VLTextModelKVUsesTextConfig(t *testing.T) {
p := lfm2VLTextModel{
TextConfig: lfm2Model{
ModelParameters: ModelParameters{ModelType: "lfm2", VocabSize: 65536},
HiddenSize: 2048,
NumHiddenLayers: 16,
MaxPositionEmbeddings: 128000,
IntermediateSize: 12288,
BlockFFDim: 12288,
BlockAutoAdjustFFDim: true,
BlockMultipleOf: 256,
BlockFFNDimMultiplier: 1.0,
NumAttentionHeads: 32,
NumKeyValueHeads: 8,
LayerTypes: []string{"conv", "full_attention"},
NormEps: 1e-5,
ConvLCache: 3,
},
DownsampleFactor: 2,
VisionConfig: struct {
HiddenSize uint32 `json:"hidden_size"`
IntermediateSize uint32 `json:"intermediate_size"`
NumAttentionHeads uint32 `json:"num_attention_heads"`
NumHiddenLayers uint32 `json:"num_hidden_layers"`
NumChannels uint32 `json:"num_channels"`
PatchSize uint32 `json:"patch_size"`
LayerNormEpsilon float32 `json:"layer_norm_eps"`
}{
HiddenSize: 1152,
IntermediateSize: 4304,
NumAttentionHeads: 16,
NumHiddenLayers: 27,
NumChannels: 3,
PatchSize: 16,
LayerNormEpsilon: 1e-6,
},
}
p.Processor.ImageProcessor.TileSize = 512
p.Processor.ImageProcessor.ImageMean = []float32{0.5, 0.5, 0.5}
p.Processor.ImageProcessor.ImageStd = []float32{0.5, 0.5, 0.5}
kv := p.KV(&Tokenizer{
Vocabulary: &Vocabulary{
Model: "gpt2",
Tokens: []string{"<|pad|>", "<image>", "<|image_start|>", "<|image_end|>", "<|img_thumbnail|>"},
},
})
if got, want := kv["general.architecture"], "lfm2"; got != want {
t.Fatalf("general.architecture = %v, want %v", got, want)
}
if got, want := kv["feed_forward_length"], uint32(8192); got != want {
t.Fatalf("feed_forward_length = %v, want %v", got, want)
}
if got, want := kv["vision.block_count"], uint32(27); got != want {
t.Fatalf("vision.block_count = %v, want %v", got, want)
}
if got, want := kv["vision.image_size"], uint32(256); got != want {
t.Fatalf("vision.image_size = %v, want %v", got, want)
}
if got, want := kv["vision.image_token_id"], uint32(396); got != want {
t.Fatalf("vision.image_token_id = %v, want %v", got, want)
}
if got, want := kv["vision.image_start_token_id"], uint32(2); got != want {
t.Fatalf("vision.image_start_token_id = %v, want %v", got, want)
}
if got, want := kv["vision.do_image_splitting"], true; got != want {
t.Fatalf("vision.do_image_splitting = %v, want %v", got, want)
}
if got, want := kv["vision.min_tiles"], uint32(2); got != want {
t.Fatalf("vision.min_tiles = %v, want %v", got, want)
}
if got, want := kv["vision.max_tiles"], uint32(10); got != want {
t.Fatalf("vision.max_tiles = %v, want %v", got, want)
}
if got, want := kv["vision.tile_size"], uint32(512); got != want {
t.Fatalf("vision.tile_size = %v, want %v", got, want)
}
if got, want := kv["vision.use_thumbnail"], true; got != want {
t.Fatalf("vision.use_thumbnail = %v, want %v", got, want)
}
if got, want := kv["vision.use_image_special_tokens"], true; got != want {
t.Fatalf("vision.use_image_special_tokens = %v, want %v", got, want)
}
}
func TestLFM2VLTextModelTensorsIncludeVision(t *testing.T) {
p := lfm2VLTextModel{}
p.VisionConfig.PatchSize = 16
p.VisionConfig.NumChannels = 3
input := []Tensor{
newLFM2StubTensor("model.embed_tokens.weight", []uint64{65536, 2048}),
newLFM2StubTensor("model.layers.0.ffn_norm.weight", []uint64{2048}),
newLFM2StubTensor("v.patch_embd.weight", []uint64{1152, 768}),
newLFM2StubTensor("v.blk.0.attn_q.weight", []uint64{1152, 1152}),
newLFM2StubTensor("mm.1.weight", []uint64{2048, 4608}),
}
out := p.Tensors(input)
if len(out) == 0 {
t.Fatal("expected non-empty tensor list")
}
foundPatch := false
foundVision := false
for _, tns := range out {
if tns.Name == "v.patch_embd.weight" {
foundPatch = true
if !slices.Equal(tns.Shape, []uint64{1152, 3, 16, 16}) {
t.Fatalf("v.patch_embd.weight shape = %v, want [1152 3 16 16]", tns.Shape)
}
}
if strings.HasPrefix(tns.Name, "v.") || strings.HasPrefix(tns.Name, "mm.") {
foundVision = true
}
}
if !foundPatch {
t.Fatal("expected v.patch_embd.weight in output tensors")
}
if !foundVision {
t.Fatal("expected at least one vision/projector tensor in output")
}
}
func TestLFM2VLTextModelReplacements(t *testing.T) {
p := lfm2VLTextModel{}
r := strings.NewReplacer(p.Replacements()...)
tests := []struct {
name string
in string
want string
}{
{
name: "language_model_embed_tokens",
in: "model.language_model.embed_tokens.weight",
want: "token_embd.weight",
},
{
name: "language_model_layers",
in: "model.language_model.layers.2.self_attn.q_proj.weight",
want: "blk.2.attn_q.weight",
},
{
name: "nested_language_model_prefix",
in: "model.language_model.model.embedding_norm.weight",
want: "token_embd_norm.weight",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := r.Replace(tt.in); got != tt.want {
t.Fatalf("replacement(%q) = %q, want %q", tt.in, got, tt.want)
}
})
}
}
func TestLFM2VLProjectorKV(t *testing.T) {
p := lfm2VLProjectorModel{
DownsampleFactor: 2,
ProjectorHiddenDim: 2048,
}
p.VisionModel.NumHiddenLayers = 27
p.VisionModel.HiddenSize = 1152
p.VisionModel.IntermediateSize = 4304
p.VisionModel.NumAttentionHeads = 16
p.VisionModel.PatchSize = 16
p.VisionModel.LayerNormEpsilon = 1e-6
p.Processor.ImageProcessor.TileSize = 512
p.Processor.ImageProcessor.ImageMean = []float32{0.5, 0.5, 0.5}
p.Processor.ImageProcessor.ImageStd = []float32{0.5, 0.5, 0.5}
kv := p.KV(nil)
if got, want := kv["general.architecture"], "clip"; got != want {
t.Fatalf("general.architecture = %v, want %v", got, want)
}
if got, want := kv["clip.projector_type"], "lfm2"; got != want {
t.Fatalf("clip.projector_type = %v, want %v", got, want)
}
if got, want := kv["clip.vision.image_size"], uint32(256); got != want {
t.Fatalf("clip.vision.image_size = %v, want %v", got, want)
}
}
func TestLFM2VLProjectorTensorsPatchReshape(t *testing.T) {
p := lfm2VLProjectorModel{}
p.VisionModel.NumChannels = 3
p.VisionModel.PatchSize = 16
input := []Tensor{
newLFM2StubTensor("v.patch_embd.weight", []uint64{1152, 768}),
newLFM2StubTensor("mm.1.weight", []uint64{2048, 4608}),
newLFM2StubTensor("model.embed_tokens.weight", []uint64{65536, 2048}),
}
out := p.Tensors(input)
if len(out) != 2 {
t.Fatalf("expected 2 tensors, got %d", len(out))
}
var patchShape []uint64
for _, tns := range out {
if tns.Name == "v.patch_embd.weight" {
patchShape = tns.Shape
break
}
}
if !slices.Equal(patchShape, []uint64{1152, 3, 16, 16}) {
t.Fatalf("v.patch_embd.weight shape = %v, want [1152 3 16 16]", patchShape)
}
}
func TestRepackPatchEmbeddingWeight(t *testing.T) {
data := []float32{
0, 1, // y=0,x=0
2, 3, // y=0,x=1
4, 5, // y=1,x=0
6, 7, // y=1,x=1
}
got, err := repackPatchEmbeddingWeight(data, []uint64{1, 8}, 2, 2)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
want := []float32{0, 2, 4, 6, 1, 3, 5, 7}
if !slices.Equal(got, want) {
t.Fatalf("repacked data = %v, want %v", got, want)
}
}

View File

@@ -0,0 +1,385 @@
package convert
import (
"cmp"
"encoding/json"
"fmt"
"io/fs"
"math"
"slices"
"strings"
"github.com/ollama/ollama/fs/ggml"
)
type hybridPattern string
func (p *hybridPattern) UnmarshalJSON(data []byte) error {
if string(data) == "null" {
*p = ""
return nil
}
var single string
if err := json.Unmarshal(data, &single); err == nil {
*p = hybridPattern(strings.TrimSpace(single))
return nil
}
var parts []string
if err := json.Unmarshal(data, &parts); err == nil {
*p = hybridPattern(strings.Join(parts, ""))
return nil
}
return fmt.Errorf("hybrid_override_pattern must be a string or string array")
}
type nemotronHModel struct {
ModelParameters
MaxPositionEmbeddings uint32 `json:"max_position_embeddings"`
HiddenSize uint32 `json:"hidden_size"`
NumHiddenLayers uint32 `json:"num_hidden_layers"`
NumAttentionHeads uint32 `json:"num_attention_heads"`
NumKeyValueHeads uint32 `json:"num_key_value_heads"`
HeadDim uint32 `json:"head_dim"`
LayerNormEpsilon float32 `json:"layer_norm_epsilon"`
NormEpsilon float32 `json:"norm_eps"`
RopeTheta float32 `json:"rope_theta"`
PartialRotaryFactor float32 `json:"partial_rotary_factor"`
ConvKernel uint32 `json:"conv_kernel"`
SSMStateSize uint32 `json:"ssm_state_size"`
MambaNumHeads uint32 `json:"mamba_num_heads"`
MambaHeadDim uint32 `json:"mamba_head_dim"`
NGroups uint32 `json:"n_groups"`
IntermediateSize uint32 `json:"intermediate_size"`
HybridOverridePattern hybridPattern `json:"hybrid_override_pattern"`
// MoE
NumExperts uint32 `json:"num_experts"`
NumSharedExperts uint32 `json:"num_shared_experts"`
NRoutedExperts uint32 `json:"n_routed_experts"`
NSharedExperts uint32 `json:"n_shared_experts"`
NumExpertsPerTok uint32 `json:"num_experts_per_tok"`
MoEIntermediateSize uint32 `json:"moe_intermediate_size"`
MoESharedExpertIntermediate uint32 `json:"moe_shared_expert_intermediate_size"`
NormTopKProb bool `json:"norm_topk_prob"`
RoutedScalingFactor float32 `json:"routed_scaling_factor"`
ExpertGroupCount uint32 `json:"n_group"`
ExpertGroupUsedCount uint32 `json:"topk_group"`
}
var _ ModelConverter = (*nemotronHModel)(nil)
func (n *nemotronHModel) parseMore(_ fs.FS) error {
if n.NumHiddenLayers == 0 {
return fmt.Errorf("nemotron_h: num_hidden_layers must be set")
}
if n.HiddenSize == 0 {
return fmt.Errorf("nemotron_h: hidden_size must be set")
}
if n.NumAttentionHeads == 0 {
return fmt.Errorf("nemotron_h: num_attention_heads must be set")
}
if n.HeadDim == 0 {
if n.HiddenSize%n.NumAttentionHeads != 0 {
return fmt.Errorf("nemotron_h: hidden_size (%d) must be divisible by num_attention_heads (%d)", n.HiddenSize, n.NumAttentionHeads)
}
n.HeadDim = n.HiddenSize / n.NumAttentionHeads
}
if n.NumKeyValueHeads == 0 {
n.NumKeyValueHeads = n.NumAttentionHeads
}
if n.ConvKernel == 0 {
return fmt.Errorf("nemotron_h: conv_kernel must be set")
}
if n.SSMStateSize == 0 {
return fmt.Errorf("nemotron_h: ssm_state_size must be set")
}
if n.ssmHeadCount() == 0 {
return fmt.Errorf("nemotron_h: mamba_num_heads must be set")
}
if n.MambaHeadDim == 0 {
return fmt.Errorf("nemotron_h: mamba_head_dim must be set")
}
if n.NGroups == 0 {
n.NGroups = 1
}
if _, _, err := n.layerArrays(); err != nil {
return err
}
if n.isMoE() {
if n.routedExpertCount() == 0 {
return fmt.Errorf("nemotron_h: routed expert count must be set for MoE models")
}
if n.NumExpertsPerTok == 0 {
return fmt.Errorf("nemotron_h: num_experts_per_tok must be set for MoE models")
}
if n.NumExpertsPerTok > n.routedExpertCount() {
return fmt.Errorf("nemotron_h: num_experts_per_tok (%d) cannot exceed expert_count (%d)", n.NumExpertsPerTok, n.routedExpertCount())
}
if n.moeIntermediateSize() == 0 {
return fmt.Errorf("nemotron_h: moe_intermediate_size must be set for MoE models")
}
}
return nil
}
func (n *nemotronHModel) isMoE() bool {
return cmp.Or(n.routedExpertCount(), n.NumExpertsPerTok, n.MoEIntermediateSize) > 0
}
func (n *nemotronHModel) routedExpertCount() uint32 {
return cmp.Or(n.NRoutedExperts, n.NumExperts)
}
func (n *nemotronHModel) sharedExpertCount() uint32 {
return cmp.Or(n.NSharedExperts, n.NumSharedExperts)
}
func (n *nemotronHModel) ssmHeadCount() uint32 {
return n.MambaNumHeads
}
func (n *nemotronHModel) ssmInnerSize() uint32 {
return n.MambaHeadDim * n.ssmHeadCount()
}
func (n *nemotronHModel) epsilon() float32 {
return cmp.Or(n.NormEpsilon, n.LayerNormEpsilon, float32(1e-5))
}
func (n *nemotronHModel) moeIntermediateSize() uint32 {
return cmp.Or(n.MoEIntermediateSize, n.IntermediateSize)
}
func (n *nemotronHModel) denseIntermediateSize() uint32 {
return cmp.Or(n.IntermediateSize, n.MoEIntermediateSize)
}
func (n *nemotronHModel) layerArrays() (headCountKV []uint32, ffnLengths []uint32, err error) {
pattern := strings.TrimSpace(string(n.HybridOverridePattern))
if pattern == "" {
return nil, nil, fmt.Errorf("nemotron_h: hybrid_override_pattern must be set")
}
runes := []rune(pattern)
if len(runes) != int(n.NumHiddenLayers) {
return nil, nil, fmt.Errorf("nemotron_h: hybrid_override_pattern length (%d) must match num_hidden_layers (%d)", len(runes), n.NumHiddenLayers)
}
headCountKV = make([]uint32, n.NumHiddenLayers)
ffnLengths = make([]uint32, n.NumHiddenLayers)
attnKVHeads := cmp.Or(n.NumKeyValueHeads, n.NumAttentionHeads)
moeFFN := n.moeIntermediateSize()
denseFFN := n.denseIntermediateSize()
for i, layerType := range runes {
switch layerType {
case 'M':
// Recurrent layer: no KV heads and no FFN.
case '*', 'A':
// Attention-only layer.
headCountKV[i] = attnKVHeads
case 'E':
// MoE layer.
if moeFFN == 0 {
return nil, nil, fmt.Errorf("nemotron_h: moe layer at index %d but moe_intermediate_size is zero", i)
}
ffnLengths[i] = moeFFN
case '-':
// Dense FFN layer.
if denseFFN == 0 {
return nil, nil, fmt.Errorf("nemotron_h: dense FFN layer at index %d but intermediate_size is zero", i)
}
ffnLengths[i] = denseFFN
default:
return nil, nil, fmt.Errorf("nemotron_h: unsupported layer type %q in hybrid_override_pattern at index %d", layerType, i)
}
}
return headCountKV, ffnLengths, nil
}
func (n *nemotronHModel) KV(t *Tokenizer) KV {
kv := n.ModelParameters.KV(t)
arch := "nemotron_h"
if n.isMoE() {
arch = "nemotron_h_moe"
}
kv["general.architecture"] = arch
kv["block_count"] = n.NumHiddenLayers
kv["context_length"] = n.MaxPositionEmbeddings
kv["embedding_length"] = n.HiddenSize
kv["attention.head_count"] = n.NumAttentionHeads
kv["attention.key_length"] = n.HeadDim
kv["attention.value_length"] = n.HeadDim
kv["attention.layer_norm_epsilon"] = n.epsilon()
kv["attention.layer_norm_rms_epsilon"] = n.epsilon()
kv["rope.freq_base"] = cmp.Or(n.RopeTheta, float32(10000))
if n.PartialRotaryFactor > 0 && n.PartialRotaryFactor <= 1 {
kv["rope.dimension_count"] = uint32(float32(n.HeadDim) * n.PartialRotaryFactor)
}
if headCountKV, ffnLengths, err := n.layerArrays(); err == nil {
kv["attention.head_count_kv"] = headCountKV
kv["feed_forward_length"] = ffnLengths
}
kv["ssm.conv_kernel"] = n.ConvKernel
kv["ssm.inner_size"] = n.ssmInnerSize()
kv["ssm.state_size"] = n.SSMStateSize
kv["ssm.group_count"] = n.NGroups
kv["ssm.time_step_rank"] = n.ssmHeadCount()
if n.isMoE() {
kv["expert_count"] = n.routedExpertCount()
kv["expert_used_count"] = n.NumExpertsPerTok
kv["expert_feed_forward_length"] = n.moeIntermediateSize()
if n.sharedExpertCount() > 0 {
kv["expert_shared_count"] = n.sharedExpertCount()
}
if n.MoESharedExpertIntermediate > 0 {
kv["expert_shared_feed_forward_length"] = n.MoESharedExpertIntermediate
}
kv["expert_weights_norm"] = n.NormTopKProb
kv["expert_weights_scale"] = n.RoutedScalingFactor
if n.ExpertGroupCount > 0 {
kv["expert_group_count"] = n.ExpertGroupCount
}
if n.ExpertGroupUsedCount > 0 {
kv["expert_group_used_count"] = n.ExpertGroupUsedCount
}
}
return kv
}
func normalizeVectorShapeToColumn(shape []uint64) []uint64 {
switch len(shape) {
case 1:
return []uint64{shape[0], 1}
case 2:
if shape[0] == 1 && shape[1] > 1 {
return []uint64{shape[1], 1}
}
if shape[1] == 1 && shape[0] > 1 {
return []uint64{shape[0], 1}
}
}
return slices.Clone(shape)
}
func (n *nemotronHModel) Tensors(ts []Tensor) []*ggml.Tensor {
var out []*ggml.Tensor
remaining := ts
if n.isMoE() {
merges := make([]merge, 0, n.NumHiddenLayers*2)
for i := range n.NumHiddenLayers {
merges = append(merges, merge{
fmt.Sprintf("blk.%d.mixer.experts.*.up_proj.weight", i),
fmt.Sprintf("blk.%d.ffn_up_exps.weight", i),
}, merge{
fmt.Sprintf("blk.%d.mixer.experts.*.down_proj.weight", i),
fmt.Sprintf("blk.%d.ffn_down_exps.weight", i),
})
}
merged, rest := mergeTensors(ts, merges...)
out = append(out, merged...)
remaining = rest
}
nGroups := uint64(cmp.Or(n.NGroups, uint32(1)))
for _, t := range remaining {
name := t.Name()
shape := slices.Clone(t.Shape())
switch {
case strings.HasSuffix(name, ".ssm_a"):
shape = normalizeVectorShapeToColumn(shape)
t.SetRepacker(func(_ string, data []float32, _ []uint64) ([]float32, error) {
out := make([]float32, len(data))
for i, v := range data {
out[i] = -float32(math.Exp(float64(v)))
}
return out, nil
})
case strings.HasSuffix(name, ".ssm_d"):
shape = normalizeVectorShapeToColumn(shape)
case strings.HasSuffix(name, ".ssm_norm.weight"):
switch len(shape) {
case 1:
if nGroups > 0 && shape[0]%nGroups == 0 {
shape = []uint64{nGroups, shape[0] / nGroups}
}
case 2:
if shape[0] == 1 && nGroups > 0 && shape[1]%nGroups == 0 {
shape = []uint64{nGroups, shape[1] / nGroups}
}
}
case strings.HasSuffix(name, ".ssm_conv1d.weight"):
if len(shape) == 3 {
if shape[0] == 1 {
shape = []uint64{shape[1], shape[2]}
} else if shape[1] == 1 {
shape = []uint64{shape[0], shape[2]}
}
}
}
out = append(out, &ggml.Tensor{
Name: name,
Kind: t.Kind(),
Shape: shape,
WriterTo: t,
})
}
return out
}
func (n *nemotronHModel) Replacements() []string {
return []string{
// Embedding and output
"lm_head", "output",
"backbone.embeddings", "token_embd",
"backbone.norm_f", "output_norm",
"backbone.layers", "blk",
// Recurrent (Mamba2) tensors
"mixer.in_proj", "ssm_in",
"mixer.out_proj", "ssm_out",
"mixer.dt_bias", "ssm_dt.bias",
"mixer.A_log", "ssm_a",
"mixer.D", "ssm_d",
"mixer.conv1d", "ssm_conv1d",
"mixer.norm.weight", "ssm_norm.weight",
// Attention tensors
"mixer.q_proj", "attn_q",
"mixer.k_proj", "attn_k",
"mixer.v_proj", "attn_v",
"mixer.o_proj", "attn_output",
// FFN / MoE tensors
"mixer.gate.e_score_correction_bias", "exp_probs_b.bias",
"mixer.gate", "ffn_gate_inp",
"mixer.fc1_latent_proj", "ffn_latent_in",
"mixer.fc2_latent_proj", "ffn_latent_out",
"mixer.shared_experts.up_proj", "ffn_up_shexp",
"mixer.shared_experts.down_proj", "ffn_down_shexp",
"mixer.up_proj", "ffn_up",
"mixer.down_proj", "ffn_down",
// Per-layer pre-norm
".norm.weight", ".attn_norm.weight",
}
}

View File

@@ -0,0 +1,230 @@
package convert
import (
"bytes"
"encoding/binary"
"encoding/json"
"io"
"os"
"path/filepath"
"slices"
"strings"
"testing"
)
func TestHybridPatternUnmarshal(t *testing.T) {
t.Run("string", func(t *testing.T) {
var p hybridPattern
if err := json.Unmarshal([]byte(`"MEM*"`), &p); err != nil {
t.Fatal(err)
}
if got, want := string(p), "MEM*"; got != want {
t.Fatalf("unexpected pattern: got %q want %q", got, want)
}
})
t.Run("array", func(t *testing.T) {
var p hybridPattern
if err := json.Unmarshal([]byte(`["M","E","M","*"]`), &p); err != nil {
t.Fatal(err)
}
if got, want := string(p), "MEM*"; got != want {
t.Fatalf("unexpected pattern: got %q want %q", got, want)
}
})
}
func TestNemotronHLayerArrays(t *testing.T) {
m := &nemotronHModel{
NumHiddenLayers: 5,
NumAttentionHeads: 32,
NumKeyValueHeads: 8,
HybridOverridePattern: "MEM*E",
NRoutedExperts: 128,
NumExpertsPerTok: 6,
MoEIntermediateSize: 1856,
}
headsKV, ffn, err := m.layerArrays()
if err != nil {
t.Fatal(err)
}
if got, want := headsKV, []uint32{0, 0, 0, 8, 0}; !slices.Equal(got, want) {
t.Fatalf("unexpected head_count_kv: got %v want %v", got, want)
}
if got, want := ffn, []uint32{0, 1856, 0, 0, 1856}; !slices.Equal(got, want) {
t.Fatalf("unexpected feed_forward_length: got %v want %v", got, want)
}
}
func TestNemotronHKV(t *testing.T) {
m := &nemotronHModel{
MaxPositionEmbeddings: 1048576,
HiddenSize: 2688,
NumHiddenLayers: 5,
NumAttentionHeads: 32,
NumKeyValueHeads: 2,
HeadDim: 128,
LayerNormEpsilon: 1e-5,
RopeTheta: 10000,
PartialRotaryFactor: 0.5,
ConvKernel: 4,
SSMStateSize: 128,
MambaNumHeads: 64,
MambaHeadDim: 64,
NGroups: 8,
HybridOverridePattern: "MEM*E",
NRoutedExperts: 128,
NSharedExperts: 1,
NumExpertsPerTok: 6,
MoEIntermediateSize: 1856,
MoESharedExpertIntermediate: 3712,
NormTopKProb: true,
RoutedScalingFactor: 2.5,
}
if err := m.parseMore(nil); err != nil {
t.Fatal(err)
}
kv := m.KV(&Tokenizer{Vocabulary: &Vocabulary{}})
if got, want := kv["general.architecture"], "nemotron_h_moe"; got != want {
t.Fatalf("unexpected architecture: got %v want %v", got, want)
}
headCountKV, ok := kv["attention.head_count_kv"].([]uint32)
if !ok {
t.Fatalf("attention.head_count_kv has unexpected type: %T", kv["attention.head_count_kv"])
}
if got, want := headCountKV, []uint32{0, 0, 0, 2, 0}; !slices.Equal(got, want) {
t.Fatalf("unexpected attention.head_count_kv: got %v want %v", got, want)
}
ffnLength, ok := kv["feed_forward_length"].([]uint32)
if !ok {
t.Fatalf("feed_forward_length has unexpected type: %T", kv["feed_forward_length"])
}
if got, want := ffnLength, []uint32{0, 1856, 0, 0, 1856}; !slices.Equal(got, want) {
t.Fatalf("unexpected feed_forward_length: got %v want %v", got, want)
}
}
func TestNemotronHTensorsTransforms(t *testing.T) {
m := &nemotronHModel{NGroups: 8}
in := []Tensor{
&fakeTensor{
name: "blk.0.ssm_a",
shape: []uint64{4},
data: []float32{0, 1, 2, 3},
},
&fakeTensor{
name: "blk.0.ssm_d",
shape: []uint64{4},
data: []float32{0, 1, 2, 3},
},
&fakeTensor{
name: "blk.0.ssm_norm.weight",
shape: []uint64{16},
data: make([]float32, 16),
},
&fakeTensor{
name: "blk.0.ssm_conv1d.weight",
shape: []uint64{10, 1, 4},
data: make([]float32, 40),
},
}
out := m.Tensors(in)
if len(out) != len(in) {
t.Fatalf("unexpected output tensor count: got %d want %d", len(out), len(in))
}
got := map[string]struct {
shape []uint64
writer io.WriterTo
}{}
for _, t := range out {
got[t.Name] = struct {
shape []uint64
writer io.WriterTo
}{shape: t.Shape, writer: t.WriterTo}
}
if shape := got["blk.0.ssm_a"].shape; !slices.Equal(shape, []uint64{4, 1}) {
t.Fatalf("unexpected ssm_a shape: %v", shape)
}
if shape := got["blk.0.ssm_d"].shape; !slices.Equal(shape, []uint64{4, 1}) {
t.Fatalf("unexpected ssm_d shape: %v", shape)
}
if shape := got["blk.0.ssm_norm.weight"].shape; !slices.Equal(shape, []uint64{8, 2}) {
t.Fatalf("unexpected ssm_norm shape: %v", shape)
}
if shape := got["blk.0.ssm_conv1d.weight"].shape; !slices.Equal(shape, []uint64{10, 4}) {
t.Fatalf("unexpected ssm_conv1d shape: %v", shape)
}
var b bytes.Buffer
if _, err := got["blk.0.ssm_a"].writer.WriteTo(&b); err != nil {
t.Fatal(err)
}
values := make([]float32, 4)
if err := binary.Read(&b, binary.LittleEndian, &values); err != nil {
t.Fatal(err)
}
// 0 -> -exp(0) == -1
if values[0] != -1 {
t.Fatalf("unexpected transformed ssm_a[0]: got %v want -1", values[0])
}
}
func TestNemotronHLoadModelMetadata(t *testing.T) {
tempDir := t.TempDir()
config := `{
"architectures": ["NemotronHForCausalLM"],
"model_type": "nemotron_h",
"num_hidden_layers": 4,
"hidden_size": 512,
"max_position_embeddings": 32768,
"num_attention_heads": 8,
"num_key_value_heads": 2,
"head_dim": 64,
"layer_norm_epsilon": 1e-5,
"conv_kernel": 4,
"ssm_state_size": 128,
"mamba_num_heads": 16,
"mamba_head_dim": 32,
"n_groups": 8,
"hybrid_override_pattern": "ME*M",
"n_routed_experts": 16,
"num_experts_per_tok": 4,
"moe_intermediate_size": 256
}`
if err := os.WriteFile(filepath.Join(tempDir, "config.json"), []byte(config), 0o644); err != nil {
t.Fatal(err)
}
if err := os.WriteFile(filepath.Join(tempDir, "tokenizer.json"), []byte(`{}`), 0o644); err != nil {
t.Fatal(err)
}
kv, _, err := LoadModelMetadata(os.DirFS(tempDir))
if err != nil {
t.Fatal(err)
}
if _, ok := kv.(*nemotronHModel); !ok {
t.Fatalf("unexpected converter type: %T", kv)
}
}
func TestNemotronHReplacementsLatentProjections(t *testing.T) {
m := &nemotronHModel{}
r := strings.NewReplacer(m.Replacements()...)
if got, want := r.Replace("backbone.layers.1.mixer.fc1_latent_proj.weight"), "blk.1.ffn_latent_in.weight"; got != want {
t.Fatalf("unexpected fc1 replacement: got %q want %q", got, want)
}
if got, want := r.Replace("backbone.layers.1.mixer.fc2_latent_proj.weight"), "blk.1.ffn_latent_out.weight"; got != want {
t.Fatalf("unexpected fc2 replacement: got %q want %q", got, want)
}
}

View File

@@ -1,6 +1,7 @@
package convert
import (
"encoding/json"
"fmt"
"io/fs"
"math"
@@ -13,8 +14,21 @@ import (
"github.com/ollama/ollama/fs/ggml"
)
type qwen3NextModel struct {
ModelParameters
type qwen3NextRopeScaling struct {
Type string `json:"type"`
Factor ropeFactor `json:"factor"`
MropeSection []int32 `json:"mrope_section"`
}
type qwen3NextRopeParams struct {
MRopeInterleaved bool `json:"mrope_interleaved"`
MropeSection []int32 `json:"mrope_section"`
RopeType string `json:"rope_type"`
RopeTheta float32 `json:"rope_theta"`
PartialRotaryFactor float32 `json:"partial_rotary_factor"`
}
type qwen3NextTextConfig struct {
MaxPositionEmbeddings uint32 `json:"max_position_embeddings"`
HiddenSize uint32 `json:"hidden_size"`
NumHiddenLayers uint32 `json:"num_hidden_layers"`
@@ -28,12 +42,13 @@ type qwen3NextModel struct {
// MoE config
NumExperts uint32 `json:"num_experts"`
NumExpertsPerToken uint32 `json:"num_experts_per_tok"`
NormTopkProb bool `json:"norm_topk_prob"`
NormTopkProb *bool `json:"norm_topk_prob"`
MoEIntermediateSize uint32 `json:"moe_intermediate_size"`
SharedExpertIntermSize uint32 `json:"shared_expert_intermediate_size"`
// Hybrid attention config
FullAttentionInterval uint32 `json:"full_attention_interval"`
FullAttentionInterval uint32 `json:"full_attention_interval"`
LayerTypes []string `json:"layer_types"`
// Linear attention (Gated Delta Net) config
LinearConvKernelDim uint32 `json:"linear_conv_kernel_dim"`
@@ -43,16 +58,102 @@ type qwen3NextModel struct {
LinearValueHeadDim uint32 `json:"linear_value_head_dim"`
// RoPE config
PartialRotaryFactor float32 `json:"partial_rotary_factor"`
RopeScaling struct {
Type string `json:"type"`
Factor ropeFactor `json:"factor"`
} `json:"rope_scaling"`
PartialRotaryFactor float32 `json:"partial_rotary_factor"`
RopeScaling qwen3NextRopeScaling `json:"rope_scaling"`
RopeParameters qwen3NextRopeParams `json:"rope_parameters"`
}
type qwen3NextVisionConfig struct {
Depth uint32 `json:"depth"`
HiddenSize uint32 `json:"hidden_size"`
NumHeads uint32 `json:"num_heads"`
InChannels uint32 `json:"in_channels"`
PatchSize uint32 `json:"patch_size"`
SpatialMergeSize uint32 `json:"spatial_merge_size"`
RMSNormEps float32 `json:"layer_norm_epsilon"`
RopeTheta float32 `json:"rope_theta"`
TemporalPatchSize uint32 `json:"temporal_patch_size"`
DeepstackVisualIndexes []int32 `json:"deepstack_visual_indexes"`
Size struct {
ShortestEdge uint32 `json:"shortest_edge"`
LongestEdge uint32 `json:"longest_edge"`
} `json:"size"`
ImageMean []float32 `json:"image_mean"`
ImageStd []float32 `json:"image_std"`
}
type qwen3NextModel struct {
ModelParameters
qwen3NextTextConfig
TextConfig *qwen3NextTextConfig `json:"text_config"`
VisionModel qwen3NextVisionConfig `json:"vision_config"`
ImageTokenID uint32 `json:"image_token_id"`
VisionStartTokenID uint32 `json:"vision_start_token_id"`
VisionEndTokenID uint32 `json:"vision_end_token_id"`
}
var _ ModelConverter = (*qwen3NextModel)(nil)
func (q *qwen3NextModel) parseMore(_ fs.FS) error {
func (q *qwen3NextModel) parseMore(fsys fs.FS) error {
if q.TextConfig != nil {
q.qwen3NextTextConfig = *q.TextConfig
}
if q.RopeTheta == 0 {
q.RopeTheta = q.RopeParameters.RopeTheta
}
if q.PartialRotaryFactor == 0 {
q.PartialRotaryFactor = q.RopeParameters.PartialRotaryFactor
}
if q.RopeScaling.Type == "" && q.RopeParameters.RopeType != "" {
q.RopeScaling.Type = q.RopeParameters.RopeType
}
// Pull vision preprocessing fields when present.
if q.VisionModel.Depth > 0 {
if bts, err := fs.ReadFile(fsys, "preprocessor_config.json"); err == nil {
var pre struct {
Size struct {
ShortestEdge uint32 `json:"shortest_edge"`
LongestEdge uint32 `json:"longest_edge"`
} `json:"size"`
PatchSize uint32 `json:"patch_size"`
TemporalPatchSize uint32 `json:"temporal_patch_size"`
MergeSize uint32 `json:"merge_size"`
ImageMean []float32 `json:"image_mean"`
ImageStd []float32 `json:"image_std"`
}
if json.Unmarshal(bts, &pre) == nil {
if q.VisionModel.PatchSize == 0 {
q.VisionModel.PatchSize = pre.PatchSize
}
if q.VisionModel.TemporalPatchSize == 0 {
q.VisionModel.TemporalPatchSize = pre.TemporalPatchSize
}
if q.VisionModel.SpatialMergeSize == 0 {
q.VisionModel.SpatialMergeSize = pre.MergeSize
}
if q.VisionModel.Size.ShortestEdge == 0 {
q.VisionModel.Size.ShortestEdge = pre.Size.ShortestEdge
}
if q.VisionModel.Size.LongestEdge == 0 {
q.VisionModel.Size.LongestEdge = pre.Size.LongestEdge
}
if len(q.VisionModel.ImageMean) == 0 {
q.VisionModel.ImageMean = pre.ImageMean
}
if len(q.VisionModel.ImageStd) == 0 {
q.VisionModel.ImageStd = pre.ImageStd
}
}
}
}
if q.NumHiddenLayers == 0 {
return fmt.Errorf("qwen3next: num_hidden_layers must be set")
}
@@ -74,36 +175,96 @@ func (q *qwen3NextModel) parseMore(_ fs.FS) error {
if q.LinearNumKeyHeads == 0 || q.LinearNumValueHeads == 0 || q.LinearKeyHeadDim == 0 || q.LinearValueHeadDim == 0 {
return fmt.Errorf("qwen3next: linear attention config must be set (linear_num_key_heads, linear_num_value_heads, linear_key_head_dim, linear_value_head_dim)")
}
if q.FullAttentionInterval == 0 {
return fmt.Errorf("qwen3next: full_attention_interval must be set")
}
if q.FullAttentionInterval > q.NumHiddenLayers {
return fmt.Errorf("qwen3next: full_attention_interval (%d) exceeds num_hidden_layers (%d)", q.FullAttentionInterval, q.NumHiddenLayers)
}
hasFull := false
for i := range q.NumHiddenLayers {
if (i+1)%q.FullAttentionInterval == 0 {
hasFull = true
break
}
}
if !hasFull {
return fmt.Errorf("qwen3next: head_count_kv would be all zeros (full_attention_interval=%d, num_hidden_layers=%d)", q.FullAttentionInterval, q.NumHiddenLayers)
if _, err := q.kvHeadCounts(); err != nil {
return err
}
return nil
}
func (q *qwen3NextModel) kvHeadCounts() ([]uint32, error) {
if len(q.LayerTypes) > 0 {
kv := make([]uint32, q.NumHiddenLayers)
hasFull := false
hasRecurrent := false
for i := range q.NumHiddenLayers {
layerType := ""
if i < uint32(len(q.LayerTypes)) {
layerType = q.LayerTypes[i]
}
if layerType == "full_attention" {
kv[i] = q.NumKeyValueHeads
hasFull = true
} else {
hasRecurrent = true
}
}
if !hasFull || !hasRecurrent {
return nil, fmt.Errorf("qwen3next: layer_types must include both full_attention and linear_attention")
}
return kv, nil
}
if q.FullAttentionInterval == 0 {
return nil, fmt.Errorf("qwen3next: full_attention_interval must be set")
}
if q.FullAttentionInterval > q.NumHiddenLayers {
return nil, fmt.Errorf("qwen3next: full_attention_interval (%d) exceeds num_hidden_layers (%d)", q.FullAttentionInterval, q.NumHiddenLayers)
}
kv := make([]uint32, q.NumHiddenLayers)
hasFull := false
for i := range q.NumHiddenLayers {
if (i+1)%q.FullAttentionInterval == 0 {
kv[i] = q.NumKeyValueHeads
hasFull = true
}
}
if !hasFull {
return nil, fmt.Errorf("qwen3next: head_count_kv would be all zeros (full_attention_interval=%d, num_hidden_layers=%d)", q.FullAttentionInterval, q.NumHiddenLayers)
}
return kv, nil
}
func (q *qwen3NextModel) ropeSections() []int32 {
if len(q.RopeParameters.MropeSection) > 0 {
return q.RopeParameters.MropeSection
}
return q.RopeScaling.MropeSection
}
func (q *qwen3NextModel) shouldReorderVHeads() bool {
modelType := strings.ToLower(q.ModelType)
if strings.Contains(modelType, "qwen3_next") || strings.Contains(modelType, "qwen3next") {
return false
}
for _, arch := range q.Architectures {
arch = strings.ToLower(arch)
if strings.Contains(arch, "qwen3next") || strings.Contains(arch, "qwen3_next") {
return false
}
}
// Default to qwen3.5 layout for all other qwen3next-family imports.
return true
}
func (q *qwen3NextModel) KV(t *Tokenizer) KV {
kv := q.ModelParameters.KV(t)
kv["general.architecture"] = "qwen3next"
kv["tokenizer.ggml.pre"] = "qwen2"
arch := "qwen35"
if q.NumExperts > 0 {
arch = "qwen35moe"
}
kv["general.architecture"] = arch
kv["tokenizer.ggml.pre"] = "qwen35"
kv["block_count"] = q.NumHiddenLayers
kv["context_length"] = q.MaxPositionEmbeddings
kv["embedding_length"] = q.HiddenSize
kv["feed_forward_length"] = q.IntermediateSize
kv["attention.head_count"] = q.NumAttentionHeads
headDim := q.HeadDim
if headDim == 0 && q.NumAttentionHeads > 0 {
headDim = q.HiddenSize / q.NumAttentionHeads
@@ -113,18 +274,31 @@ func (q *qwen3NextModel) KV(t *Tokenizer) KV {
kv["attention.layer_norm_rms_epsilon"] = q.RMSNormEPS
kv["rope.freq_base"] = q.RopeTheta
// RoPE dimension count (partial rotary)
// partial_rotary_factor = 0.25 means only 25% of head_dim uses RoPE
partialRotary := q.PartialRotaryFactor
if partialRotary > 0 && partialRotary <= 1 {
kv["rope.dimension_count"] = uint32(float32(headDim) * partialRotary)
}
// MoE config
if sections := q.ropeSections(); len(sections) > 0 {
kv["mrope_sections"] = sections
kv["rope.mrope_section"] = sections
kv["rope.dimension_sections"] = sections
}
if q.RopeParameters.MRopeInterleaved {
kv["rope.mrope_interleaved"] = true
}
if q.RopeScaling.Type != "" && q.RopeScaling.Type != "default" {
kv["rope.scaling.type"] = q.RopeScaling.Type
kv["rope.scaling.factor"] = q.RopeScaling.Factor
}
if q.NumExperts > 0 {
kv["expert_count"] = q.NumExperts
kv["expert_used_count"] = q.NumExpertsPerToken
kv["norm_top_k_prob"] = q.NormTopkProb
if q.NormTopkProb != nil {
kv["norm_top_k_prob"] = *q.NormTopkProb
}
if q.MoEIntermediateSize > 0 {
kv["expert_feed_forward_length"] = q.MoEIntermediateSize
}
@@ -133,33 +307,66 @@ func (q *qwen3NextModel) KV(t *Tokenizer) KV {
}
}
// SSM/Linear attention config
// d_inner = linear_value_head_dim * linear_num_value_heads
dInner := q.LinearValueHeadDim * q.LinearNumValueHeads
kv["ssm.inner_size"] = dInner
kv["ssm.state_size"] = q.LinearKeyHeadDim // head_k_dim
kv["ssm.group_count"] = q.LinearNumKeyHeads // num_k_heads
kv["ssm.time_step_rank"] = q.LinearNumValueHeads // num_v_heads
kv["ssm.state_size"] = q.LinearKeyHeadDim
kv["ssm.group_count"] = q.LinearNumKeyHeads
kv["ssm.time_step_rank"] = q.LinearNumValueHeads
kv["ssm.conv_kernel"] = q.LinearConvKernelDim
interval := q.FullAttentionInterval
kv["full_attention_interval"] = interval
// Build per-layer KV head count array to identify layer types
// 0 = recurrent (linear attention), non-zero = full attention
kvHeadCounts := make([]uint32, q.NumHiddenLayers)
for i := range q.NumHiddenLayers {
// Full attention every full_attention_interval layers (starting at interval-1)
if interval > 0 && (i+1)%interval == 0 {
kvHeadCounts[i] = q.NumKeyValueHeads
}
// else stays 0 (recurrent layer)
if q.shouldReorderVHeads() {
kv["ssm.v_head_reordered"] = true
}
if q.FullAttentionInterval > 0 {
kv["full_attention_interval"] = q.FullAttentionInterval
}
kv["attention.head_count_kv"] = kvHeadCounts
// RoPE scaling
if q.RopeScaling.Type != "" {
kv["rope.scaling.type"] = q.RopeScaling.Type
kv["rope.scaling.factor"] = q.RopeScaling.Factor
if headCounts, err := q.kvHeadCounts(); err == nil {
kv["attention.head_count_kv"] = headCounts
}
if q.VisionModel.Depth > 0 {
kv["vision.block_count"] = q.VisionModel.Depth
kv["vision.embedding_length"] = q.VisionModel.HiddenSize
kv["vision.attention.head_count"] = q.VisionModel.NumHeads
kv["vision.num_channels"] = q.VisionModel.InChannels
if q.VisionModel.PatchSize > 0 {
kv["vision.patch_size"] = q.VisionModel.PatchSize
}
if q.VisionModel.SpatialMergeSize > 0 {
kv["vision.spatial_merge_size"] = q.VisionModel.SpatialMergeSize
}
if q.VisionModel.RMSNormEps > 0 {
kv["vision.attention.layer_norm_epsilon"] = q.VisionModel.RMSNormEps
}
if q.VisionModel.RopeTheta > 0 {
kv["vision.rope.freq_base"] = q.VisionModel.RopeTheta
}
if q.VisionModel.TemporalPatchSize > 0 {
kv["vision.temporal_patch_size"] = q.VisionModel.TemporalPatchSize
}
kv["vision.deepstack_visual_indexes"] = q.VisionModel.DeepstackVisualIndexes
if q.VisionModel.Size.ShortestEdge > 0 {
kv["vision.shortest_edge"] = q.VisionModel.Size.ShortestEdge
}
if q.VisionModel.Size.LongestEdge > 0 {
kv["vision.longest_edge"] = q.VisionModel.Size.LongestEdge
}
if len(q.VisionModel.ImageMean) > 0 {
kv["vision.image_mean"] = q.VisionModel.ImageMean
}
if len(q.VisionModel.ImageStd) > 0 {
kv["vision.image_std"] = q.VisionModel.ImageStd
}
}
if q.ImageTokenID > 0 {
kv["image_token_id"] = q.ImageTokenID
}
if q.VisionStartTokenID > 0 {
kv["vision_start_token_id"] = q.VisionStartTokenID
}
if q.VisionEndTokenID > 0 {
kv["vision_end_token_id"] = q.VisionEndTokenID
}
return kv
@@ -168,7 +375,6 @@ func (q *qwen3NextModel) KV(t *Tokenizer) KV {
func (q *qwen3NextModel) Tensors(ts []Tensor) []*ggml.Tensor {
var out []*ggml.Tensor
// Create merges for expert tensors - stack individual experts into batched tensors
merges := make([]merge, q.NumHiddenLayers*3)
for i := range q.NumHiddenLayers {
merges[i*3+0] = merge{
@@ -185,16 +391,13 @@ func (q *qwen3NextModel) Tensors(ts []Tensor) []*ggml.Tensor {
}
}
// Merge expert tensors
merged, remaining := mergeTensors(ts, merges...)
out = append(out, merged...)
// Process remaining tensors
for _, t := range remaining {
name := t.Name()
shape := t.Shape()
// Split linear_attn.in_proj_qkvz (ssm_in) into attn_qkv + attn_gate when possible
if strings.HasSuffix(name, ".ssm_in.weight") {
if qkv, gate, ok := q.splitQKVZTensor(t); ok {
out = append(out, qkv, gate)
@@ -204,84 +407,299 @@ func (q *qwen3NextModel) Tensors(ts []Tensor) []*ggml.Tensor {
}
switch {
// Add 1 to norm weights (except ssm_norm which is linear_attn.norm)
// This matches the Python converter behavior for qwen3next
case strings.Contains(name, ".mlp.experts.gate_up_proj"):
out = append(out, slices.Collect(splitDim(t, 1,
split{Replacer: strings.NewReplacer(".mlp.experts.gate_up_proj", ".ffn_gate_exps.weight")},
split{Replacer: strings.NewReplacer(".mlp.experts.gate_up_proj", ".ffn_up_exps.weight")},
))...)
case strings.Contains(name, ".mlp.experts.down_proj"):
out = append(out, &ggml.Tensor{
Name: strings.NewReplacer(".mlp.experts.down_proj", ".ffn_down_exps.weight").Replace(name),
Kind: t.Kind(),
Shape: slices.Clone(shape),
WriterTo: t,
})
case strings.HasPrefix(name, "v.blk.") && strings.Contains(name, ".attn_qkv"):
out = append(out, slices.Collect(splitDim(t, 0,
split{Replacer: strings.NewReplacer("attn_qkv", "attn_q")},
split{Replacer: strings.NewReplacer("attn_qkv", "attn_k")},
split{Replacer: strings.NewReplacer("attn_qkv", "attn_v")},
))...)
case strings.Contains(name, "patch_embed") && strings.HasSuffix(name, "weight"):
out = append(out, &ggml.Tensor{
Name: name,
Kind: t.Kind(),
Shape: append([]uint64{shape[0] * shape[1]}, shape[2:]...),
WriterTo: t,
})
case strings.HasSuffix(name, "_norm.weight") && !strings.HasSuffix(name, ".ssm_norm.weight"):
t.SetRepacker(q.addOne)
out = append(out, &ggml.Tensor{
Name: name,
Kind: t.Kind(),
Shape: slices.Clone(shape),
WriterTo: t,
})
out = append(out, &ggml.Tensor{Name: name, Kind: t.Kind(), Shape: slices.Clone(shape), WriterTo: t})
// Handle linear attention A_log -> ssm_a (negate and exp)
// Note: name has already been transformed by Replacements at this point
case strings.HasSuffix(name, ".ssm_a"):
t.SetRepacker(func(_ string, data []float32, shape []uint64) ([]float32, error) {
// Compute -exp(A_log)
result := make([]float32, len(data))
for i, v := range data {
// -exp(v)
result[i] = -float32(math.Exp(float64(v)))
}
return result, nil
})
out = append(out, &ggml.Tensor{
Name: name,
Kind: t.Kind(),
Shape: slices.Clone(shape),
WriterTo: t,
})
t.SetRepacker(q.repackSSMA())
out = append(out, &ggml.Tensor{Name: name, Kind: t.Kind(), Shape: slices.Clone(shape), WriterTo: t})
case strings.HasSuffix(name, ".attn_qkv.weight"):
if q.shouldReorderVHeads() {
t.SetRepacker(q.repackAttnQKV())
}
out = append(out, &ggml.Tensor{Name: name, Kind: t.Kind(), Shape: slices.Clone(shape), WriterTo: t})
case strings.HasSuffix(name, ".attn_gate.weight"):
if q.shouldReorderVHeads() {
// HF tensor layout is [out_features, in_features]; reorder rows.
t.SetRepacker(q.repackReorderDim(0, int(q.LinearValueHeadDim)))
}
out = append(out, &ggml.Tensor{Name: name, Kind: t.Kind(), Shape: slices.Clone(shape), WriterTo: t})
case strings.HasSuffix(name, ".ssm_beta.weight"), strings.HasSuffix(name, ".ssm_alpha.weight"):
if q.shouldReorderVHeads() {
// HF tensor layout is [out_features, in_features]; reorder rows.
t.SetRepacker(q.repackReorderDim(0, 1))
}
out = append(out, &ggml.Tensor{Name: name, Kind: t.Kind(), Shape: slices.Clone(shape), WriterTo: t})
case strings.HasSuffix(name, ".ssm_dt"):
if q.shouldReorderVHeads() {
t.SetRepacker(q.repackReorderDim(0, 1))
}
out = append(out, &ggml.Tensor{Name: name, Kind: t.Kind(), Shape: slices.Clone(shape), WriterTo: t})
case strings.HasSuffix(name, ".ssm_out.weight"):
if q.shouldReorderVHeads() {
// HF out_proj layout is [out_features, in_features]; reorder columns.
t.SetRepacker(q.repackReorderDim(1, int(q.LinearValueHeadDim)))
}
out = append(out, &ggml.Tensor{Name: name, Kind: t.Kind(), Shape: slices.Clone(shape), WriterTo: t})
// Squeeze conv1d weights: [1, D, K] or [D, 1, K] -> [D, K]
case strings.HasSuffix(name, ".ssm_conv1d.weight"):
newShape := slices.Clone(shape)
if len(shape) == 3 {
if shape[0] == 1 {
// [1, D, K] -> [D, K]
newShape = []uint64{shape[1], shape[2]}
} else if shape[1] == 1 {
// [D, 1, K] -> [D, K]
newShape = []uint64{shape[0], shape[2]}
}
}
out = append(out, &ggml.Tensor{
Name: name,
Kind: t.Kind(),
Shape: newShape,
WriterTo: t,
})
// Squeeze shared expert gate: [D, 1] or [1, D] -> [D]
case strings.HasSuffix(name, ".ffn_gate_inp_shexp.weight"):
newShape := slices.Clone(shape)
if len(shape) == 2 {
if shape[0] == 1 && shape[1] > 1 {
newShape = []uint64{shape[1]}
} else if shape[1] == 1 && shape[0] > 1 {
newShape = []uint64{shape[0]}
}
if q.shouldReorderVHeads() {
t.SetRepacker(q.repackConv1D())
}
out = append(out, &ggml.Tensor{
Name: name,
Kind: t.Kind(),
Shape: newShape,
WriterTo: t,
})
out = append(out, &ggml.Tensor{Name: name, Kind: t.Kind(), Shape: newShape, WriterTo: t})
default:
out = append(out, &ggml.Tensor{
Name: name,
Kind: t.Kind(),
Shape: slices.Clone(shape),
WriterTo: t,
})
out = append(out, &ggml.Tensor{Name: name, Kind: t.Kind(), Shape: slices.Clone(shape), WriterTo: t})
}
}
return out
}
func (q *qwen3NextModel) repackReorderDim(dim, headDim int) Repacker {
return func(_ string, data []float32, shape []uint64) ([]float32, error) {
if !q.shouldReorderVHeads() {
return data, nil
}
numK := int(q.LinearNumKeyHeads)
numVPerK := int(q.LinearNumValueHeads / q.LinearNumKeyHeads)
return reorderHeadLayout(data, shape, dim, numK, numVPerK, headDim)
}
}
func (q *qwen3NextModel) repackAttnQKV() Repacker {
return func(_ string, data []float32, shape []uint64) ([]float32, error) {
if !q.shouldReorderVHeads() || len(shape) != 2 {
return data, nil
}
rows := int(shape[0])
cols := int(shape[1])
numK := int(q.LinearNumKeyHeads)
numV := int(q.LinearNumValueHeads)
headK := int(q.LinearKeyHeadDim)
headV := int(q.LinearValueHeadDim)
qDim := headK * numK
kDim := headK * numK
vDim := headV * numV
qkvDim := qDim + kDim + vDim
switch {
case rows == qkvDim:
// HF layout: [out_features, in_features]. Keep Q/K rows unchanged and
// reorder only V rows from grouped -> tiled head layout.
out := make([]float32, len(data))
qkRows := qDim + kDim
qkSize := qkRows * cols
copy(out[:qkSize], data[:qkSize])
vStart := qkSize
vEnd := vStart + vDim*cols
reorderedV, err := reorderHeadLayout(data[vStart:vEnd], []uint64{uint64(vDim), uint64(cols)}, 0, numK, numV/numK, headV)
if err != nil {
return nil, err
}
copy(out[vStart:vEnd], reorderedV)
copy(out[vEnd:], data[vEnd:])
return out, nil
case cols == qkvDim:
// Fallback for already-transposed [in_features, out_features] tensors.
out := make([]float32, len(data))
copy(out, data)
for r := range rows {
base := r * cols
vStart := base + qDim + kDim
vEnd := vStart + vDim
reorderedV, err := reorderHeadLayout(out[vStart:vEnd], []uint64{uint64(vDim)}, 0, numK, numV/numK, headV)
if err != nil {
return nil, err
}
copy(out[vStart:vEnd], reorderedV)
}
return out, nil
default:
return data, nil
}
}
}
func (q *qwen3NextModel) repackConv1D() Repacker {
return func(_ string, data []float32, shape []uint64) ([]float32, error) {
if !q.shouldReorderVHeads() {
return data, nil
}
normShape := slices.Clone(shape)
if len(shape) == 3 {
if shape[0] == 1 {
normShape = []uint64{shape[1], shape[2]}
} else if shape[1] == 1 {
normShape = []uint64{shape[0], shape[2]}
}
}
if len(normShape) != 2 {
return data, nil
}
rows := int(normShape[0])
cols := int(normShape[1])
numK := int(q.LinearNumKeyHeads)
numV := int(q.LinearNumValueHeads)
headK := int(q.LinearKeyHeadDim)
headV := int(q.LinearValueHeadDim)
qkChannels := 2 * headK * numK
totalChannels := qkChannels + headV*numV
if qkChannels <= 0 {
return data, nil
}
switch {
case rows == totalChannels:
// HF layout after squeeze: [channels, kernel]
out := make([]float32, len(data))
prefix := qkChannels * cols
copy(out[:prefix], data[:prefix])
reorderedV, err := reorderHeadLayout(data[prefix:], []uint64{uint64(totalChannels - qkChannels), uint64(cols)}, 0, numK, numV/numK, headV)
if err != nil {
return nil, err
}
copy(out[prefix:], reorderedV)
return out, nil
case cols == totalChannels:
// Fallback for transposed [kernel, channels]
out := make([]float32, len(data))
copy(out, data)
vChannels := totalChannels - qkChannels
for r := range rows {
base := r * cols
vStart := base + qkChannels
vEnd := vStart + vChannels
reorderedV, err := reorderHeadLayout(out[vStart:vEnd], []uint64{uint64(vChannels)}, 0, numK, numV/numK, headV)
if err != nil {
return nil, err
}
copy(out[vStart:vEnd], reorderedV)
}
return out, nil
default:
return data, nil
}
}
}
func (q *qwen3NextModel) repackSSMA() Repacker {
return func(_ string, data []float32, shape []uint64) ([]float32, error) {
result := make([]float32, len(data))
for i, v := range data {
result[i] = -float32(math.Exp(float64(v)))
}
if !q.shouldReorderVHeads() {
return result, nil
}
numK := int(q.LinearNumKeyHeads)
numVPerK := int(q.LinearNumValueHeads / q.LinearNumKeyHeads)
return reorderHeadLayout(result, shape, 0, numK, numVPerK, 1)
}
}
func reorderHeadLayout(data []float32, shape []uint64, dim int, numKHeads, numVPerK, headDim int) ([]float32, error) {
if len(shape) == 0 || numKHeads <= 0 || numVPerK <= 0 || headDim <= 0 {
return data, nil
}
dims := make([]int, len(shape))
for i := range shape {
dims[i] = int(shape[i])
}
if dim < 0 {
dim += len(dims)
}
if dim < 0 || dim >= len(dims) {
return data, nil
}
expected := numKHeads * numVPerK * headDim
if dims[dim] != expected {
return data, nil
}
newShape := make([]int, 0, len(dims)+2)
newShape = append(newShape, dims[:dim]...)
newShape = append(newShape, numKHeads, numVPerK, headDim)
newShape = append(newShape, dims[dim+1:]...)
var tt tensor.Tensor = tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data))
if err := tt.Reshape(newShape...); err != nil {
return nil, err
}
perm := make([]int, len(newShape))
for i := range perm {
perm[i] = i
}
perm[dim], perm[dim+1] = perm[dim+1], perm[dim]
tt, err := tensor.Transpose(tt, perm...)
if err != nil {
return nil, err
}
tt = tensor.Materialize(tt)
total := 1
for _, d := range dims {
total *= d
}
if err := tt.Reshape(total); err != nil {
return nil, err
}
return native.VectorF32(tt.(*tensor.Dense))
}
type qkvzSplitSpec struct {
hidden int
headKDim int
@@ -369,7 +787,6 @@ func (q *qwen3NextModel) repackQKVZ(spec qkvzSplitSpec, extractGate bool) Repack
var tt tensor.Tensor = tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data))
var err error
// Convert to [hidden, out_features] layout for slicing
tt, err = tensor.Transpose(tt, 1, 0)
if err != nil {
return nil, err
@@ -444,7 +861,6 @@ func (q *qwen3NextModel) repackQKVZ(spec qkvzSplitSpec, extractGate bool) Repack
}
}
// addOne adds 1.0 to all elements in the tensor (for norm weights)
func (*qwen3NextModel) addOne(_ string, data []float32, shape []uint64) ([]float32, error) {
n := tensor.New(tensor.WithShape(int(shape[0])), tensor.WithBacking(data))
ones := tensor.Ones(tensor.Float32, int(shape[0]))
@@ -471,10 +887,21 @@ func (q *qwen3NextModel) Replacements() []string {
return []string{
// Embeddings and output
"lm_head", "output",
"model.language_model.embed_tokens", "token_embd",
"model.language_model.norm", "output_norm",
"model.language_model.layers", "blk",
"model.embed_tokens", "token_embd",
"model.norm", "output_norm",
"model.layers", "blk",
// Vision
"model.visual", "v",
"patch_embed.proj", "patch_embed",
"blocks", "blk",
"attn.qkv", "attn_qkv",
"attn.proj", "attn_out",
"deepstack_merger_list", "deepstack_merger",
// Layer norms
"input_layernorm", "attn_norm",
"post_attention_layernorm", "post_attention_norm",
@@ -487,9 +914,16 @@ func (q *qwen3NextModel) Replacements() []string {
"self_attn.v_proj", "attn_v",
"self_attn.o_proj", "attn_output",
// Linear attention (Gated Delta Net)
// Linear attention (legacy qwen3next)
"linear_attn.in_proj_qkvz", "ssm_in",
"linear_attn.in_proj_ba", "ssm_ba",
// Linear attention (qwen35)
"linear_attn.in_proj_qkv", "attn_qkv",
"linear_attn.in_proj_z", "attn_gate",
"linear_attn.in_proj_a", "ssm_alpha",
"linear_attn.in_proj_b", "ssm_beta",
"linear_attn.conv1d", "ssm_conv1d",
"linear_attn.dt_bias", "ssm_dt",
"linear_attn.dt_proj", "ssm_dt",
@@ -497,14 +931,14 @@ func (q *qwen3NextModel) Replacements() []string {
"linear_attn.norm", "ssm_norm",
"linear_attn.out_proj", "ssm_out",
// MoE (experts are stacked via mergeTensors, not replaced here)
// MoE
"mlp.gate.weight", "ffn_gate_inp.weight",
"mlp.shared_expert.down_proj", "ffn_down_shexp",
"mlp.shared_expert.gate_proj", "ffn_gate_shexp",
"mlp.shared_expert.up_proj", "ffn_up_shexp",
"mlp.shared_expert_gate", "ffn_gate_inp_shexp",
// Dense FFN (if any layers use it)
// Dense FFN
"mlp.down_proj", "ffn_down",
"mlp.gate_proj", "ffn_gate",
"mlp.up_proj", "ffn_up",

View File

@@ -0,0 +1,563 @@
package convert
import (
"bytes"
"encoding/binary"
"os"
"slices"
"strings"
"testing"
"github.com/ollama/ollama/fs/ggml"
)
func boolPtr(v bool) *bool {
return &v
}
func readTensorData(t *testing.T, tensor *ggml.Tensor) []float32 {
t.Helper()
var b bytes.Buffer
if _, err := tensor.WriteTo(&b); err != nil {
t.Fatal(err)
}
numel := 1
for _, d := range tensor.Shape {
numel *= int(d)
}
values := make([]float32, numel)
if err := binary.Read(&b, binary.LittleEndian, &values); err != nil {
t.Fatal(err)
}
return values
}
func TestQwen3NextLegacyModelTypeDisablesReorder(t *testing.T) {
m := &qwen3NextModel{
ModelParameters: ModelParameters{
ModelType: "qwen3_next",
},
}
if m.shouldReorderVHeads() {
t.Fatalf("legacy qwen3_next model_type should not reorder v-head layout")
}
}
func TestQwen3NextLegacyArchitectureDisablesReorder(t *testing.T) {
m := &qwen3NextModel{
ModelParameters: ModelParameters{
Architectures: []string{"Qwen3NextForCausalLM"},
},
}
if m.shouldReorderVHeads() {
t.Fatalf("legacy Qwen3Next architecture should not reorder v-head layout")
}
}
func TestQwen3NextKVLegacyConfig(t *testing.T) {
m := &qwen3NextModel{
ModelParameters: ModelParameters{
ModelType: "qwen3_next",
},
qwen3NextTextConfig: qwen3NextTextConfig{
MaxPositionEmbeddings: 8192,
HiddenSize: 512,
NumHiddenLayers: 4,
IntermediateSize: 2048,
NumAttentionHeads: 8,
NumKeyValueHeads: 2,
HeadDim: 64,
RopeTheta: 1_000_000,
RMSNormEPS: 1e-6,
NumExperts: 8,
NumExpertsPerToken: 2,
NormTopkProb: boolPtr(true),
MoEIntermediateSize: 256,
SharedExpertIntermSize: 512,
FullAttentionInterval: 2,
LinearConvKernelDim: 4,
LinearKeyHeadDim: 64,
LinearNumKeyHeads: 2,
LinearNumValueHeads: 4,
LinearValueHeadDim: 64,
PartialRotaryFactor: 0.25,
},
}
if err := m.parseMore(os.DirFS(t.TempDir())); err != nil {
t.Fatal(err)
}
kv := m.KV(&Tokenizer{Vocabulary: &Vocabulary{}})
if got, want := kv["general.architecture"], "qwen35moe"; got != want {
t.Fatalf("unexpected architecture: got %v want %v", got, want)
}
if got, want := kv["tokenizer.ggml.pre"], "qwen35"; got != want {
t.Fatalf("unexpected tokenizer pre: got %v want %v", got, want)
}
headCountKV, ok := kv["attention.head_count_kv"].([]uint32)
if !ok {
t.Fatalf("attention.head_count_kv has unexpected type: %T", kv["attention.head_count_kv"])
}
if got, want := headCountKV, []uint32{0, 2, 0, 2}; !slices.Equal(got, want) {
t.Fatalf("unexpected attention.head_count_kv: got %v want %v", got, want)
}
if _, ok := kv["ssm.v_head_reordered"]; ok {
t.Fatalf("legacy qwen3next should not enable ssm.v_head_reordered")
}
if got, want := kv["norm_top_k_prob"], true; got != want {
t.Fatalf("unexpected norm_top_k_prob: got %v want %v", got, want)
}
}
func TestQwen35MoeOmitsNormTopKProbWhenUnset(t *testing.T) {
m := &qwen3NextModel{
ModelParameters: ModelParameters{
ModelType: "qwen3_5",
},
qwen3NextTextConfig: qwen3NextTextConfig{
MaxPositionEmbeddings: 4096,
HiddenSize: 512,
NumHiddenLayers: 4,
IntermediateSize: 2048,
NumAttentionHeads: 8,
NumKeyValueHeads: 2,
HeadDim: 64,
RopeTheta: 1_000_000,
RMSNormEPS: 1e-6,
NumExperts: 8,
NumExpertsPerToken: 2,
FullAttentionInterval: 2,
LinearConvKernelDim: 4,
LinearKeyHeadDim: 64,
LinearNumKeyHeads: 2,
LinearNumValueHeads: 4,
LinearValueHeadDim: 64,
PartialRotaryFactor: 0.25,
},
}
if err := m.parseMore(os.DirFS(t.TempDir())); err != nil {
t.Fatal(err)
}
kv := m.KV(&Tokenizer{Vocabulary: &Vocabulary{}})
if _, ok := kv["norm_top_k_prob"]; ok {
t.Fatalf("expected norm_top_k_prob to be omitted when not set in config")
}
}
func TestQwen35KVFromTextConfig(t *testing.T) {
m := &qwen3NextModel{
ModelParameters: ModelParameters{
ModelType: "qwen3_5",
},
TextConfig: &qwen3NextTextConfig{
MaxPositionEmbeddings: 16384,
HiddenSize: 1024,
NumHiddenLayers: 4,
IntermediateSize: 4096,
NumAttentionHeads: 8,
NumKeyValueHeads: 4,
HeadDim: 128,
RMSNormEPS: 1e-6,
LayerTypes: []string{
"linear_attention",
"full_attention",
"linear_attention",
"full_attention",
},
LinearConvKernelDim: 4,
LinearKeyHeadDim: 128,
LinearNumKeyHeads: 2,
LinearNumValueHeads: 4,
LinearValueHeadDim: 128,
RopeParameters: qwen3NextRopeParams{
MRopeInterleaved: true,
MropeSection: []int32{11, 11, 10},
RopeType: "default",
RopeTheta: 10_000_000,
PartialRotaryFactor: 0.25,
},
},
VisionModel: qwen3NextVisionConfig{
Depth: 2,
HiddenSize: 128,
NumHeads: 4,
InChannels: 3,
PatchSize: 16,
SpatialMergeSize: 2,
RMSNormEps: 1e-6,
RopeTheta: 10_000,
TemporalPatchSize: 2,
DeepstackVisualIndexes: []int32{1},
},
ImageTokenID: 1001,
VisionStartTokenID: 1002,
VisionEndTokenID: 1003,
}
m.VisionModel.Size.ShortestEdge = 224
m.VisionModel.Size.LongestEdge = 4096
m.VisionModel.ImageMean = []float32{0.5, 0.5, 0.5}
m.VisionModel.ImageStd = []float32{0.2, 0.2, 0.2}
if err := m.parseMore(os.DirFS(t.TempDir())); err != nil {
t.Fatal(err)
}
kv := m.KV(&Tokenizer{Vocabulary: &Vocabulary{}})
if got, want := kv["general.architecture"], "qwen35"; got != want {
t.Fatalf("unexpected architecture: got %v want %v", got, want)
}
headCountKV, ok := kv["attention.head_count_kv"].([]uint32)
if !ok {
t.Fatalf("attention.head_count_kv has unexpected type: %T", kv["attention.head_count_kv"])
}
if got, want := headCountKV, []uint32{0, 4, 0, 4}; !slices.Equal(got, want) {
t.Fatalf("unexpected attention.head_count_kv: got %v want %v", got, want)
}
if got, ok := kv["ssm.v_head_reordered"].(bool); !ok || !got {
t.Fatalf("expected ssm.v_head_reordered=true, got %v (%T)", kv["ssm.v_head_reordered"], kv["ssm.v_head_reordered"])
}
mrope, ok := kv["mrope_sections"].([]int32)
if !ok {
t.Fatalf("mrope_sections has unexpected type: %T", kv["mrope_sections"])
}
if got, want := mrope, []int32{11, 11, 10}; !slices.Equal(got, want) {
t.Fatalf("unexpected mrope_sections: got %v want %v", got, want)
}
ropeSections, ok := kv["rope.dimension_sections"].([]int32)
if !ok {
t.Fatalf("rope.dimension_sections has unexpected type: %T", kv["rope.dimension_sections"])
}
if got, want := ropeSections, []int32{11, 11, 10}; !slices.Equal(got, want) {
t.Fatalf("unexpected rope.dimension_sections: got %v want %v", got, want)
}
if got, ok := kv["rope.mrope_interleaved"].(bool); !ok || !got {
t.Fatalf("expected rope.mrope_interleaved=true, got %v (%T)", kv["rope.mrope_interleaved"], kv["rope.mrope_interleaved"])
}
if got, want := kv["vision.block_count"], uint32(2); got != want {
t.Fatalf("unexpected vision.block_count: got %v want %v", got, want)
}
}
func TestQwen3NextReplacements(t *testing.T) {
r := strings.NewReplacer((&qwen3NextModel{}).Replacements()...)
if got, want := r.Replace("model.language_model.layers.1.linear_attn.in_proj_qkv.weight"), "blk.1.attn_qkv.weight"; got != want {
t.Fatalf("unexpected language-model replacement: got %q want %q", got, want)
}
if got, want := r.Replace("model.visual.blocks.0.attn.qkv.weight"), "v.blk.0.attn_qkv.weight"; got != want {
t.Fatalf("unexpected vision replacement: got %q want %q", got, want)
}
if got, want := r.Replace("model.layers.1.linear_attn.in_proj_qkvz.weight"), "blk.1.ssm_in.weight"; got != want {
t.Fatalf("unexpected legacy replacement: got %q want %q", got, want)
}
}
func TestQwen35ReordersVHeads(t *testing.T) {
m := &qwen3NextModel{
ModelParameters: ModelParameters{
ModelType: "qwen3_5",
},
qwen3NextTextConfig: qwen3NextTextConfig{
LinearNumKeyHeads: 2,
LinearNumValueHeads: 4,
LinearValueHeadDim: 1,
},
}
out := m.Tensors([]Tensor{
&fakeTensor{
name: "blk.0.attn_gate.weight",
shape: []uint64{4, 2},
data: []float32{0, 1, 2, 3, 4, 5, 6, 7},
},
})
if len(out) != 1 {
t.Fatalf("unexpected output tensor count: got %d want 1", len(out))
}
if got, want := readTensorData(t, out[0]), []float32{0, 1, 4, 5, 2, 3, 6, 7}; !slices.Equal(got, want) {
t.Fatalf("unexpected data: got %v want %v", got, want)
}
}
func TestQwen35ReordersAttnQKVOutputDim(t *testing.T) {
m := &qwen3NextModel{
ModelParameters: ModelParameters{
ModelType: "qwen3_5",
},
qwen3NextTextConfig: qwen3NextTextConfig{
LinearNumKeyHeads: 2,
LinearNumValueHeads: 4,
LinearKeyHeadDim: 1,
LinearValueHeadDim: 1,
},
}
out := m.Tensors([]Tensor{
&fakeTensor{
name: "blk.0.attn_qkv.weight",
shape: []uint64{8, 2}, // [out_features, in_features] (HF layout)
data: []float32{
0, 1, // q0
2, 3, // q1
4, 5, // k0
6, 7, // k1
10, 11, // v(k0,v0)
12, 13, // v(k0,v1)
20, 21, // v(k1,v0)
22, 23, // v(k1,v1)
},
},
})
if len(out) != 1 {
t.Fatalf("unexpected output tensor count: got %d want 1", len(out))
}
if got, want := readTensorData(t, out[0]), []float32{
0, 1, 2, 3, 4, 5, 6, 7,
10, 11, 20, 21, 12, 13, 22, 23,
}; !slices.Equal(got, want) {
t.Fatalf("unexpected qkv data: got %v want %v", got, want)
}
}
func TestQwen35ReordersSsmOutInputDim(t *testing.T) {
m := &qwen3NextModel{
ModelParameters: ModelParameters{
ModelType: "qwen3_5",
},
qwen3NextTextConfig: qwen3NextTextConfig{
LinearNumKeyHeads: 2,
LinearNumValueHeads: 4,
LinearValueHeadDim: 1,
},
}
out := m.Tensors([]Tensor{
&fakeTensor{
name: "blk.0.ssm_out.weight",
shape: []uint64{2, 4},
data: []float32{0, 1, 2, 3, 4, 5, 6, 7},
},
})
if len(out) != 1 {
t.Fatalf("unexpected output tensor count: got %d want 1", len(out))
}
if got, want := readTensorData(t, out[0]), []float32{0, 2, 1, 3, 4, 6, 5, 7}; !slices.Equal(got, want) {
t.Fatalf("unexpected ssm_out data: got %v want %v", got, want)
}
}
func TestQwen35ReordersSsmBetaRows(t *testing.T) {
m := &qwen3NextModel{
ModelParameters: ModelParameters{
ModelType: "qwen3_5",
},
qwen3NextTextConfig: qwen3NextTextConfig{
LinearNumKeyHeads: 2,
LinearNumValueHeads: 4,
},
}
out := m.Tensors([]Tensor{
&fakeTensor{
name: "blk.0.ssm_beta.weight",
shape: []uint64{4, 2},
data: []float32{0, 1, 2, 3, 4, 5, 6, 7},
},
})
if len(out) != 1 {
t.Fatalf("unexpected output tensor count: got %d want 1", len(out))
}
if got, want := readTensorData(t, out[0]), []float32{0, 1, 4, 5, 2, 3, 6, 7}; !slices.Equal(got, want) {
t.Fatalf("unexpected ssm_beta data: got %v want %v", got, want)
}
}
func TestQwen35ReordersConv1DChannelDim(t *testing.T) {
m := &qwen3NextModel{
ModelParameters: ModelParameters{
ModelType: "qwen3_5",
},
qwen3NextTextConfig: qwen3NextTextConfig{
LinearNumKeyHeads: 2,
LinearNumValueHeads: 4,
LinearKeyHeadDim: 1,
LinearValueHeadDim: 1,
},
}
out := m.Tensors([]Tensor{
&fakeTensor{
name: "blk.0.ssm_conv1d.weight",
shape: []uint64{8, 2}, // [channels, kernel] after squeeze
data: []float32{
0, 1, // q0
2, 3, // q1
4, 5, // k0
6, 7, // k1
10, 11, // v(k0,v0)
12, 13, // v(k0,v1)
20, 21, // v(k1,v0)
22, 23, // v(k1,v1)
},
},
})
if len(out) != 1 {
t.Fatalf("unexpected output tensor count: got %d want 1", len(out))
}
if got, want := readTensorData(t, out[0]), []float32{
0, 1, 2, 3, 4, 5, 6, 7,
10, 11, 20, 21, 12, 13, 22, 23,
}; !slices.Equal(got, want) {
t.Fatalf("unexpected conv1d data: got %v want %v", got, want)
}
}
func TestLegacyQwen3NextDoesNotReorderVHeads(t *testing.T) {
m := &qwen3NextModel{
ModelParameters: ModelParameters{
ModelType: "qwen3_next",
},
qwen3NextTextConfig: qwen3NextTextConfig{
LinearNumKeyHeads: 2,
LinearNumValueHeads: 4,
LinearValueHeadDim: 1,
},
}
out := m.Tensors([]Tensor{
&fakeTensor{
name: "blk.0.attn_gate.weight",
shape: []uint64{4, 1},
data: []float32{0, 1, 2, 3},
},
})
if len(out) != 1 {
t.Fatalf("unexpected output tensor count: got %d want 1", len(out))
}
if got, want := readTensorData(t, out[0]), []float32{0, 1, 2, 3}; !slices.Equal(got, want) {
t.Fatalf("unexpected data for legacy qwen3next: got %v want %v", got, want)
}
}
func TestQwen35MoePackedExperts(t *testing.T) {
m := &qwen3NextModel{
qwen3NextTextConfig: qwen3NextTextConfig{
NumHiddenLayers: 1,
},
}
out := m.Tensors([]Tensor{
&fakeTensor{
name: "blk.0.mlp.experts.gate_up_proj",
shape: []uint64{2, 4, 3},
data: []float32{
0, 1, 2,
3, 4, 5,
6, 7, 8,
9, 10, 11,
12, 13, 14,
15, 16, 17,
18, 19, 20,
21, 22, 23,
},
},
&fakeTensor{
name: "blk.0.mlp.experts.down_proj",
shape: []uint64{2, 5, 3},
data: make([]float32, 2*5*3),
},
})
get := func(name string) *ggml.Tensor {
for _, tensor := range out {
if tensor.Name == name {
return tensor
}
}
return nil
}
gate := get("blk.0.ffn_gate_exps.weight")
if gate == nil {
t.Fatalf("missing tensor %q", "blk.0.ffn_gate_exps.weight")
}
if got, want := gate.Shape, []uint64{2, 2, 3}; !slices.Equal(got, want) {
t.Fatalf("unexpected gate shape: got %v want %v", got, want)
}
if got, want := readTensorData(t, gate), []float32{
0, 1, 2, 3, 4, 5,
12, 13, 14, 15, 16, 17,
}; !slices.Equal(got, want) {
t.Fatalf("unexpected gate values: got %v want %v", got, want)
}
up := get("blk.0.ffn_up_exps.weight")
if up == nil {
t.Fatalf("missing tensor %q", "blk.0.ffn_up_exps.weight")
}
if got, want := up.Shape, []uint64{2, 2, 3}; !slices.Equal(got, want) {
t.Fatalf("unexpected up shape: got %v want %v", got, want)
}
if got, want := readTensorData(t, up), []float32{
6, 7, 8, 9, 10, 11,
18, 19, 20, 21, 22, 23,
}; !slices.Equal(got, want) {
t.Fatalf("unexpected up values: got %v want %v", got, want)
}
down := get("blk.0.ffn_down_exps.weight")
if down == nil {
t.Fatalf("missing tensor %q", "blk.0.ffn_down_exps.weight")
}
if got, want := down.Shape, []uint64{2, 5, 3}; !slices.Equal(got, want) {
t.Fatalf("unexpected down shape: got %v want %v", got, want)
}
}
func TestQwen35SharedExpertGateKeepsMatrixShape(t *testing.T) {
m := &qwen3NextModel{}
out := m.Tensors([]Tensor{
&fakeTensor{
name: "blk.0.ffn_gate_inp_shexp.weight",
shape: []uint64{1, 4},
data: []float32{0, 1, 2, 3},
},
})
if len(out) != 1 {
t.Fatalf("unexpected output tensor count: got %d want 1", len(out))
}
if got, want := out[0].Shape, []uint64{1, 4}; !slices.Equal(got, want) {
t.Fatalf("unexpected shared gate shape: got %v want %v", got, want)
}
}

97
convert/json_compat.go Normal file
View File

@@ -0,0 +1,97 @@
package convert
// sanitizeNonFiniteJSON rewrites non-standard JSON numeric tokens that some
// HF configs emit (Infinity, -Infinity, NaN) into standard JSON numbers.
//
// This is intentionally conservative:
// - only runs outside quoted strings
// - only rewrites full tokens
//
// We map these values to 0 because encoding/json rejects non-finite values,
// and these fields are typically model-side metadata not consumed by the
// converter.
func sanitizeNonFiniteJSON(in []byte) []byte {
if len(in) == 0 {
return in
}
out := make([]byte, 0, len(in))
inString := false
escape := false
for i := 0; i < len(in); {
c := in[i]
if inString {
out = append(out, c)
if escape {
escape = false
} else if c == '\\' {
escape = true
} else if c == '"' {
inString = false
}
i++
continue
}
if c == '"' {
inString = true
out = append(out, c)
i++
continue
}
if hasToken(in, i, "-Infinity") {
out = append(out, '0')
i += len("-Infinity")
continue
}
if hasToken(in, i, "Infinity") {
out = append(out, '0')
i += len("Infinity")
continue
}
if hasToken(in, i, "NaN") {
out = append(out, '0')
i += len("NaN")
continue
}
out = append(out, c)
i++
}
return out
}
func hasToken(in []byte, at int, tok string) bool {
end := at + len(tok)
if at < 0 || end > len(in) {
return false
}
if string(in[at:end]) != tok {
return false
}
if at > 0 && !isJSONValuePrefixBoundary(in[at-1]) {
return false
}
if end < len(in) && !isJSONValueSuffixBoundary(in[end]) {
return false
}
return true
}
func isJSONWhitespace(b byte) bool {
return b == ' ' || b == '\t' || b == '\n' || b == '\r'
}
func isJSONValuePrefixBoundary(b byte) bool {
return isJSONWhitespace(b) || b == ':' || b == ',' || b == '['
}
func isJSONValueSuffixBoundary(b byte) bool {
return isJSONWhitespace(b) || b == ',' || b == ']' || b == '}'
}

View File

@@ -0,0 +1,46 @@
package convert
import "testing"
func TestSanitizeNonFiniteJSON(t *testing.T) {
tests := []struct {
name string
in string
want string
}{
{
name: "infinity token",
in: `{"a":[0,Infinity,1]}`,
want: `{"a":[0,0,1]}`,
},
{
name: "negative infinity token",
in: `{"a":-Infinity}`,
want: `{"a":0}`,
},
{
name: "nan token",
in: `{"a":NaN}`,
want: `{"a":0}`,
},
{
name: "tokens inside strings untouched",
in: `{"a":"Infinity -Infinity NaN","b":Infinity}`,
want: `{"a":"Infinity -Infinity NaN","b":0}`,
},
{
name: "identifier-like token untouched",
in: `{"a":InfinityValue}`,
want: `{"a":InfinityValue}`,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := string(sanitizeNonFiniteJSON([]byte(tt.in)))
if got != tt.want {
t.Fatalf("sanitizeNonFiniteJSON() = %q, want %q", got, tt.want)
}
})
}
}

View File

@@ -101,6 +101,8 @@ func parseTokenizer(fsys fs.FS, specialTokenTypes []string) (*Tokenizer, error)
t.Pre = "deepseek-coder"
case "1ff7f41064896984db5d1bb6ff64fa4bc29007d08c1b439e505b7392777a319e":
t.Pre = "qwen2"
case "00431aed57e696b747435f734d1e3b9b1bfd931a121fb5cac7129e97c181e9ba":
t.Pre = "qwen35"
case "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855":
// noop, empty pretokenizer
default:
@@ -212,8 +214,13 @@ type tokenizer struct {
PreTokenizer struct {
PreTokenizers []struct {
Type string `json:"type"`
Pattern struct {
Type string `json:"type"`
Behavior string `json:"behavior"`
Invert bool `json:"invert"`
AddPrefixSpace bool `json:"add_prefix_space"`
TrimOffsets bool `json:"trim_offsets"`
UseRegex bool `json:"use_regex"`
Pattern struct {
Regex string `json:"Regex"`
} `json:"pattern"`
} `json:"pretokenizers"`

View File

@@ -191,6 +191,84 @@ func TestParseTokenizer(t *testing.T) {
Pre: "default",
},
},
{
name: "llama-bpe pretokenizer and control tokens",
fsys: createTokenizerFS(t, t.TempDir(), map[string]io.Reader{
"tokenizer.json": strings.NewReader(`{
"added_tokens": [
{"id": 1, "content": "<|startoftext|>", "special": true},
{"id": 6, "content": "<|im_start|>", "special": true},
{"id": 7, "content": "<|im_end|>", "special": true},
{"id": 8, "content": "<|tool_list_start|>", "special": true},
{"id": 9, "content": "<|tool_list_end|>", "special": true},
{"id": 10, "content": "<|tool_call_start|>", "special": true},
{"id": 11, "content": "<|tool_call_end|>", "special": true},
{"id": 12, "content": "<|tool_response_start|>", "special": true},
{"id": 13, "content": "<|tool_response_end|>", "special": true},
{"id": 396, "content": "<image>", "special": true},
{"id": 64400, "content": "<think>", "special": true},
{"id": 64401, "content": "</think>", "special": true}
],
"model": {
"vocab": {
"<|startoftext|>": 1,
"<|im_start|>": 6,
"<|im_end|>": 7,
"<|tool_list_start|>": 8,
"<|tool_list_end|>": 9,
"<|tool_call_start|>": 10,
"<|tool_call_end|>": 11,
"<|tool_response_start|>": 12,
"<|tool_response_end|>": 13,
"<image>": 396,
"<think>": 64400,
"</think>": 64401
}
},
"pre_tokenizer": {
"type": "Sequence",
"pretokenizers": [
{
"type": "Split",
"pattern": {
"Regex": "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+"
},
"behavior": "Isolated",
"invert": false
},
{
"type": "ByteLevel",
"add_prefix_space": false,
"trim_offsets": true,
"use_regex": false
}
]
}
}`),
}),
want: &Tokenizer{
Vocabulary: &Vocabulary{
Model: "gpt2",
Tokens: []string{
"<|startoftext|>",
"<|im_start|>",
"<|im_end|>",
"<|tool_list_start|>",
"<|tool_list_end|>",
"<|tool_call_start|>",
"<|tool_call_end|>",
"<|tool_response_start|>",
"<|tool_response_end|>",
"<image>",
"<think>",
"</think>",
},
Scores: []float32{1, 6, 7, 8, 9, 10, 11, 12, 13, 396, 64400, 64401},
Types: []int32{3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3},
},
Pre: "llama-bpe",
},
},
{
name: "list string merges",
fsys: createTokenizerFS(t, t.TempDir(), map[string]io.Reader{
@@ -308,6 +386,28 @@ func TestParseTokenizer(t *testing.T) {
Pre: "default",
},
},
{
name: "qwen35 pretokenizer",
fsys: createTokenizerFS(t, t.TempDir(), map[string]io.Reader{
"tokenizer.json": strings.NewReader(`{
"pre_tokenizer": {
"type": "Sequence",
"pretokenizers": [
{
"type": "Split",
"pattern": {
"Regex": "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?[\\p{L}\\p{M}]+|\\p{N}| ?[^\\s\\p{L}\\p{M}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+"
}
}
]
}
}`),
}),
want: &Tokenizer{
Vocabulary: &Vocabulary{Model: "gpt2"},
Pre: "qwen35",
},
},
}
for _, tt := range cases {

View File

@@ -12,7 +12,6 @@ To use Ollama with tools that expect the Anthropic API (like Claude Code), set t
```shell
export ANTHROPIC_AUTH_TOKEN=ollama # required but ignored
export ANTHROPIC_API_KEY="" # required but ignored
export ANTHROPIC_BASE_URL=http://localhost:11434
```
@@ -269,7 +268,7 @@ ollama launch claude --config
Set the environment variables and run Claude Code:
```shell
ANTHROPIC_AUTH_TOKEN=ollama ANTHROPIC_BASE_URL=http://localhost:11434 ANTHROPIC_API_KEY="" claude --model qwen3-coder
ANTHROPIC_AUTH_TOKEN=ollama ANTHROPIC_BASE_URL=http://localhost:11434 claude --model qwen3-coder
```
Or set the environment variables in your shell profile:
@@ -277,7 +276,6 @@ Or set the environment variables in your shell profile:
```shell
export ANTHROPIC_AUTH_TOKEN=ollama
export ANTHROPIC_BASE_URL=http://localhost:11434
export ANTHROPIC_API_KEY=""
```
Then run Claude Code with any Ollama model:

View File

@@ -6,7 +6,7 @@ Ollama provides compatibility with parts of the [OpenAI API](https://platform.op
## Usage
### Simple `v1/chat/completions` example
### Simple `/v1/chat/completions` example
<CodeGroup dropdown>
@@ -57,7 +57,7 @@ curl -X POST http://localhost:11434/v1/chat/completions \
</CodeGroup>
### Simple `v1/responses` example
### Simple `/v1/responses` example
<CodeGroup dropdown>
@@ -103,7 +103,7 @@ curl -X POST http://localhost:11434/v1/responses \
</CodeGroup>
### v1/chat/completions with vision example
### `/v1/chat/completions` with vision example
<CodeGroup dropdown>
@@ -184,6 +184,7 @@ curl -X POST http://localhost:11434/v1/chat/completions \
- [x] Reproducible outputs
- [x] Vision
- [x] Tools
- [x] Reasoning/thinking control (for thinking models)
- [ ] Logprobs
#### Supported request fields
@@ -207,6 +208,9 @@ curl -X POST http://localhost:11434/v1/chat/completions \
- [x] `top_p`
- [x] `max_tokens`
- [x] `tools`
- [x] `reasoning_effort` (`"high"`, `"medium"`, `"low"`, `"none"`)
- [x] `reasoning`
- [x] `effort` (`"high"`, `"medium"`, `"low"`, `"none"`)
- [ ] `tool_choice`
- [ ] `logit_bias`
- [ ] `user`

View File

@@ -40,7 +40,7 @@ ollama launch claude
Launch with a specific model:
```
ollama launch claude --model qwen3-coder
ollama launch claude --model qwen3.5
```
Configure without launching:

View File

@@ -51,6 +51,9 @@ Install prerequisites:
- [CUDA SDK](https://developer.nvidia.com/cuda-downloads?target_os=Windows&target_arch=x86_64&target_version=11&target_type=exe_network)
- (Optional) VULKAN GPU support
- [VULKAN SDK](https://vulkan.lunarg.com/sdk/home) - useful for AMD/Intel GPUs
- (Optional) MLX engine support
- [CUDA 13+ SDK](https://developer.nvidia.com/cuda-downloads)
- [cuDNN 9+](https://developer.nvidia.com/cudnn)
Then, configure and build the project:
@@ -101,6 +104,10 @@ Install prerequisites:
- (Optional) VULKAN GPU support
- [VULKAN SDK](https://vulkan.lunarg.com/sdk/home) - useful for AMD/Intel GPUs
- Or install via package manager: `sudo apt install vulkan-sdk` (Ubuntu/Debian) or `sudo dnf install vulkan-sdk` (Fedora/CentOS)
- (Optional) MLX engine support
- [CUDA 13+ SDK](https://developer.nvidia.com/cuda-downloads)
- [cuDNN 9+](https://developer.nvidia.com/cudnn)
- OpenBLAS/LAPACK: `sudo apt install libopenblas-dev liblapack-dev liblapacke-dev` (Ubuntu/Debian)
> [!IMPORTANT]
> Ensure prerequisites are in `PATH` before running CMake.
@@ -118,6 +125,67 @@ Lastly, run Ollama:
go run . serve
```
## MLX Engine (Optional)
The MLX engine enables running safetensor based models. It requires building the [MLX](https://github.com/ml-explore/mlx) and [MLX-C](https://github.com/ml-explore/mlx-c) shared libraries separately via CMake. On MacOS, MLX leverages the Metal library to run on the GPU, and on Windows and Linux, runs on NVIDIA GPUs via CUDA v13.
### macOS (Apple Silicon)
Requires the Metal toolchain. Install [Xcode](https://developer.apple.com/xcode/) first, then:
```shell
xcodebuild -downloadComponent MetalToolchain
```
Verify it's installed correctly (should print "no input files"):
```shell
xcrun metal
```
Then build:
```shell
cmake -B build --preset MLX
cmake --build build --preset MLX --parallel
cmake --install build --component MLX
```
> [!NOTE]
> Without the Metal toolchain, cmake will silently complete with Metal disabled. Check the cmake output for `Setting MLX_BUILD_METAL=OFF` which indicates the toolchain is missing.
### Windows / Linux (CUDA)
Requires CUDA 13+ and [cuDNN](https://developer.nvidia.com/cudnn) 9+.
```shell
cmake -B build --preset "MLX CUDA 13"
cmake --build build --target mlx --target mlxc --config Release --parallel
cmake --install build --component MLX --strip
```
### Local MLX source overrides
To build against a local checkout of MLX and/or MLX-C (useful for development), set environment variables before running CMake:
```shell
export OLLAMA_MLX_SOURCE=/path/to/mlx
export OLLAMA_MLX_C_SOURCE=/path/to/mlx-c
```
For example, using the helper scripts with local mlx and mlx-c repos:
```shell
OLLAMA_MLX_SOURCE=../mlx OLLAMA_MLX_C_SOURCE=../mlx-c ./scripts/build_linux.sh
OLLAMA_MLX_SOURCE=../mlx OLLAMA_MLX_C_SOURCE=../mlx-c ./scripts/build_darwin.sh
```
```powershell
$env:OLLAMA_MLX_SOURCE="../mlx"
$env:OLLAMA_MLX_C_SOURCE="../mlx-c"
./scripts/build_darwin.ps1
```
## Docker
```shell

View File

@@ -61,11 +61,17 @@ Ollama supports the following AMD GPUs via the ROCm library:
### Linux Support
| Family | Cards and accelerators |
| -------------- | ---------------------------------------------------------------------------------------------------------------------------------------------- |
| AMD Radeon RX | `7900 XTX` `7900 XT` `7900 GRE` `7800 XT` `7700 XT` `7600 XT` `7600` `6950 XT` `6900 XTX` `6900XT` `6800 XT` `6800` `Vega 64` |
| AMD Radeon PRO | `W7900` `W7800` `W7700` `W7600` `W7500` `W6900X` `W6800X Duo` `W6800X` `W6800` `V620` `V420` `V340` `V320` `Vega II Duo` `Vega II` `SSG` |
| AMD Instinct | `MI300X` `MI300A` `MI300` `MI250X` `MI250` `MI210` `MI200` `MI100` `MI60` |
Ollama requires the AMD ROCm v7 driver on Linux. You can install or upgrade
using the `amdgpu-install` utility from
[AMD's ROCm documentation](https://rocm.docs.amd.com/projects/install-on-linux/en/latest/).
| Family | Cards and accelerators |
| -------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| AMD Radeon RX | `9070 XT` `9070 GRE` `9070` `9060 XT` `9060 XT LP` `9060` `7900 XTX` `7900 XT` `7900 GRE` `7800 XT` `7700 XT` `7700` `7600 XT` `7600` `6950 XT` `6900 XTX` `6900XT` `6800 XT` `6800` `5700 XT` `5700` `5600 XT` `5500 XT` |
| AMD Radeon AI PRO | `R9700` `R9600D` |
| AMD Radeon PRO | `W7900` `W7800` `W7700` `W7600` `W7500` `W6900X` `W6800X Duo` `W6800X` `W6800` `V620` |
| AMD Ryzen AI | `Ryzen AI Max+ 395` `Ryzen AI Max 390` `Ryzen AI Max 385` `Ryzen AI 9 HX 475` `Ryzen AI 9 HX 470` `Ryzen AI 9 465` `Ryzen AI 9 HX 375` `Ryzen AI 9 HX 370` `Ryzen AI 9 365` |
| AMD Instinct | `MI350X` `MI300X` `MI300A` `MI250X` `MI250` `MI210` `MI100` |
### Windows Support
@@ -97,17 +103,20 @@ This table shows some example GPUs that map to these LLVM targets:
| **LLVM Target** | **An Example GPU** |
|-----------------|---------------------|
| gfx908 | Radeon Instinct MI100 |
| gfx90a | Radeon Instinct MI210 |
| gfx940 | Radeon Instinct MI300 |
| gfx941 | |
| gfx942 | |
| gfx90a | Radeon Instinct MI210/MI250 |
| gfx942 | Radeon Instinct MI300X/MI300A |
| gfx950 | Radeon Instinct MI350X |
| gfx1010 | Radeon RX 5700 XT |
| gfx1012 | Radeon RX 5500 XT |
| gfx1030 | Radeon PRO V620 |
| gfx1100 | Radeon PRO W7900 |
| gfx1101 | Radeon PRO W7700 |
| gfx1102 | Radeon RX 7600 |
AMD is working on enhancing ROCm v6 to broaden support for families of GPUs in a
future release which should increase support for more GPUs.
| gfx1103 | Radeon 780M |
| gfx1150 | Ryzen AI 9 HX 375 |
| gfx1151 | Ryzen AI Max+ 395 |
| gfx1200 | Radeon RX 9070 |
| gfx1201 | Radeon RX 9070 XT |
Reach out on [Discord](https://discord.gg/ollama) or file an
[issue](https://github.com/ollama/ollama/issues) for additional help.

View File

@@ -4,7 +4,7 @@ title: Claude Code
Claude Code is Anthropic's agentic coding tool that can read, modify, and execute code in your working directory.
Open models can be used with Claude Code through Ollama's Anthropic-compatible API, enabling you to use models such as `glm-4.7`, `qwen3-coder`, `gpt-oss`.
Open models can be used with Claude Code through Ollama's Anthropic-compatible API, enabling you to use models such as `qwen3.5`, `glm-5:cloud`, `kimi-k2.5:cloud`.
![Claude Code with Ollama](https://files.ollama.com/claude-code.png)
@@ -32,13 +32,57 @@ irm https://claude.ai/install.ps1 | iex
ollama launch claude
```
To configure without launching:
### Run directly with a model
```shell
ollama launch claude --config
ollama launch claude --model kimi-k2.5:cloud
```
### Manual setup
## Recommended Models
- `kimi-k2.5:cloud`
- `glm-5:cloud`
- `minimax-m2.5:cloud`
- `qwen3.5:cloud`
- `glm-4.7-flash`
- `qwen3.5`
Cloud models are also available at [ollama.com/search?c=cloud](https://ollama.com/search?c=cloud).
## Scheduled Tasks with `/loop`
The `/loop` command runs a prompt or slash command on a recurring schedule inside Claude Code. This is useful for automating repetitive tasks like checking PRs, running research, or setting reminders.
```
/loop <interval> <prompt or /command>
```
### Examples
**Check in on your PRs**
```
/loop 30m Check my open PRs and summarize their status
```
**Automate research tasks**
```
/loop 1h Research the latest AI news and summarize key developments
```
**Automate bug reporting and triaging**
```
/loop 15m Check for new GitHub issues and triage by priority
```
**Set reminders**
```
/loop 1h Remind me to review the deploy status
```
## Manual setup
Claude Code connects to Ollama using the Anthropic-compatible API.
@@ -53,23 +97,14 @@ export ANTHROPIC_BASE_URL=http://localhost:11434
2. Run Claude Code with an Ollama model:
```shell
claude --model gpt-oss:20b
claude --model qwen3.5
```
Or run with environment variables inline:
```shell
ANTHROPIC_AUTH_TOKEN=ollama ANTHROPIC_BASE_URL=http://localhost:11434 ANTHROPIC_API_KEY="" claude --model qwen3-coder
ANTHROPIC_AUTH_TOKEN=ollama ANTHROPIC_BASE_URL=http://localhost:11434 ANTHROPIC_API_KEY="" claude --model glm-5:cloud
```
**Note:** Claude Code requires a large context window. We recommend at least 64k tokens. See the [context length documentation](/context-length) for how to adjust context length in Ollama.
## Recommended Models
- `qwen3-coder`
- `glm-4.7`
- `gpt-oss:20b`
- `gpt-oss:120b`
Cloud models are also available at [ollama.com/search?c=cloud](https://ollama.com/search?c=cloud).

View File

@@ -4,47 +4,65 @@ title: OpenClaw
OpenClaw is a personal AI assistant that runs on your own devices. It bridges messaging services (WhatsApp, Telegram, Slack, Discord, iMessage, and more) to AI coding agents through a centralized gateway.
## Install
Install [OpenClaw](https://openclaw.ai/)
```bash
npm install -g openclaw@latest
```
Then run the onboarding wizard:
```bash
openclaw onboard --install-daemon
```
<Note>OpenClaw requires a larger context window. It is recommended to use a context window of at least 64k tokens. See [Context length](/context-length) for more information.</Note>
## Usage with Ollama
### Quick setup
## Quick start
```bash
ollama launch openclaw
```
Ollama handles everything automatically:
1. **Install** — If OpenClaw isn't installed, Ollama prompts to install it via npm
2. **Security** — On the first launch, a security notice explains the risks of tool access
3. **Model** — Pick a model from the selector (local or cloud)
4. **Onboarding** — Ollama configures the provider, installs the gateway daemon, and sets your model as the primary
5. **Gateway** — Starts in the background and opens the OpenClaw TUI
<Note>OpenClaw requires a larger context window. It is recommended to use a context window of at least 64k tokens if using local models. See [Context length](/context-length) for more information.</Note>
<Note>Previously known as Clawdbot. `ollama launch clawdbot` still works as an alias.</Note>
This configures OpenClaw to use Ollama and starts the gateway.
If the gateway is already running, no changes need to be made as the gateway will auto-reload the changes.
## Configure without launching
To change the model without starting the gateway and TUI:
To configure without launching:
```shell
```bash
ollama launch openclaw --config
```
## Recommended Models
To use a specific model directly:
- `qwen3-coder`
- `glm-4.7`
- `gpt-oss:20b`
- `gpt-oss:120b`
```bash
ollama launch openclaw --model kimi-k2.5:cloud
```
If the gateway is already running, it restarts automatically to pick up the new model.
## Recommended models
**Cloud models**:
- `kimi-k2.5:cloud` — Multimodal reasoning with subagents
- `minimax-m2.5:cloud` — Fast, efficient coding and real-world productivity
- `glm-5:cloud` — Reasoning and code generation
**Local models:**
- `glm-4.7-flash` — Reasoning and code generation locally (~25 GB VRAM)
More models at [ollama.com/search](https://ollama.com/search?c=cloud).
## Connect messaging apps
```bash
openclaw configure --section channels
```
Link WhatsApp, Telegram, Slack, Discord, or iMessage to chat with your local models from anywhere.
## Stopping the gateway
```bash
openclaw gateway stop
```
Cloud models are also available at [ollama.com/search?c=cloud](https://ollama.com/search?c=cloud).

Some files were not shown because too many files have changed in this diff Show More