Compare commits

...

43 Commits

Author SHA1 Message Date
jmorganca
4fdeb59325 convert: handle layers_block_type config field 2026-03-10 20:51:52 -07:00
Parth Sareen
61086083eb server: add experimental web search and web fetch routes (#14753) 2026-03-09 21:52:12 -07:00
Daniel Hiltgen
62d1f01ab4 ci: Fix windows build (#14754)
Instead of relying on sh for wildcard expansion, do it in Go for better
Windows compatibility.
2026-03-09 19:27:59 -07:00
Daniel Hiltgen
10e51c5177 MLX: add header vendoring and remove go build tag (#14642)
* prefer rocm v6 on windows

Avoid building with v7 - more changes are needed

* MLX: add header vendoring and remove go build tag

This switches to using a vendoring approach for the mlx-c headers so that Go
can build without requiring a cmake run first. This enables building the new
MLX-based code by default. Every time cmake runs, the headers are refreshed, so
we can easily keep them in sync when we bump mlx versions. Basic Windows and
Linux support is verified.

* ci: harden for flaky choco repo servers

CI sometimes fails due to choco not actually installing ccache. Since it just speeds up the build, we can proceed without it.

* review comments
2026-03-09 17:24:45 -07:00
Patrick Devine
3e06bde643 mlx: get parameters from modelfile during model creation (#14747) 2026-03-09 15:33:24 -07:00
Eva H
6be2de8214 app: auto update should be enabled when reset to defaults (#14741) 2026-03-09 15:02:36 -04:00
Daniel Hiltgen
ebb1b9ec14 rocm: update linux to v7.2 (#14391)
* rocm: update linux to v7.2

* review comments
2026-03-09 08:26:55 -07:00
Patrick Devine
d126467d5d x/mlxrunner: replace sampler interface chain with single stateful Sampler (#14652)
- Collapse MLX sampling state into a single sample.Sampler struct (options + history).
- Replace interface-based sampler chain (TopP, TopK, penalty, etc.) with function-based transforms.
- Update request/pipeline wiring to use *sample.Sampler, seed history from prompt tokens, and append generated tokens each step.
- Implement top_p, min_p, repeat_penalty, and frequency_penalty
2026-03-07 17:50:57 -08:00
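
A minimal sketch of what a single stateful sampler with function-based transforms might look like; `Sampler`, `transform`, and `repeatPenalty` here are illustrative stand-ins, not the actual `x/mlxrunner` types, and the penalty math is one common formulation rather than the exact implementation:

```go
package main

import "fmt"

// transform rewrites logits given the token history; a stand-in for the
// function-based transforms that replaced the interface-based sampler chain.
type transform func(logits []float32, history []int32) []float32

// Sampler is a hypothetical single stateful sampler: options plus history.
type Sampler struct {
	history    []int32
	transforms []transform
}

// repeatPenalty dampens logits of previously seen tokens (one common
// formulation of repeat_penalty; the real math may differ).
func repeatPenalty(p float32) transform {
	return func(logits []float32, history []int32) []float32 {
		for _, tok := range history {
			if int(tok) < len(logits) {
				logits[tok] /= p
			}
		}
		return logits
	}
}

// Sample applies each transform in order, picks a token (greedy here for
// brevity), and appends it to the history so later steps observe it.
func (s *Sampler) Sample(logits []float32) int32 {
	for _, t := range s.transforms {
		logits = t(logits, s.history)
	}
	best := 0
	for i := range logits {
		if logits[i] > logits[best] {
			best = i
		}
	}
	s.history = append(s.history, int32(best))
	return int32(best)
}

func main() {
	s := &Sampler{transforms: []transform{repeatPenalty(1.1)}}
	s.history = []int32{1} // seeded from prompt tokens, as described above
	fmt.Println(s.Sample([]float32{0.1, 2.0, 0.5}))
}
```
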
Devon Rifkin
afb4c62fbf cloud_proxy: handle stream disconnects gracefully (#14685)
Previously we were printing out bad errors for expected cases like
clients disconnecting. Now we only debug log when that happens (which
still might help in cases where we're figuring out why an integration
isn't working). For other errors, we now print out a proper warning.
2026-03-06 19:18:52 -08:00
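
A sketch of the described triage under assumed names (the actual cloud_proxy code and its error cases differ):

```go
package main

import (
	"context"
	"errors"
	"io"
	"log/slog"
)

// logStreamErr demotes expected client disconnects to debug logs while
// keeping a proper warning for everything else.
func logStreamErr(err error) {
	switch {
	case err == nil:
	case errors.Is(err, context.Canceled), errors.Is(err, io.ErrUnexpectedEOF):
		// Expected when a client goes away; still useful when debugging
		// why an integration isn't working.
		slog.Debug("client disconnected", "error", err)
	default:
		slog.Warn("cloud proxy stream error", "error", err)
	}
}

func main() {
	logStreamErr(context.Canceled)
	logStreamErr(errors.New("upstream returned garbage"))
}
```
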
Patrick Devine
e790dc435b mlx: int4 groupsize 64 (#14682)
Change affine 4bit integers to use groupsize 64
2026-03-06 16:39:47 -08:00
Daniel Hiltgen
288077c3a3 build: smarter docker parallelism (#14653)
Our Dockerfile leverages parallel stages for more efficient builds.  However,
our old parallel settings were naive and led to under- or over-utilization
depending on the capabilities of your build system.

This change switches to using Ninja for all our docker cmake builds to leverage
its smarter parallel logic.  We tell Ninja to target a load of nproc so each of
the build stages will share the load on the system, aiming for full CPU use
without oversaturation.

The GPU parallelism settings are also adjusted to 4 to avoid a long tail for
the last few GPU targets as they work through the long list of GPU
architectures.

This also fixes the Dockerfile to move Vulkan install to just the stage that
needs it instead of blocking most other GPU installs.  This should speed up CI
which always has a clean build cache.
2026-03-06 16:36:22 -08:00
Daniel Hiltgen
4425c54eda create: fix localhost handling (#14681) 2026-03-06 16:35:58 -08:00
Michael Yang
778899a5d2 docs: format compat docs (#14678) 2026-03-06 14:53:17 -08:00
Jeffrey Morgan
4eab60c1e2 Reapply "don't require pulling stubs for cloud models" again (#14608)
* Revert "Revert "Reapply "don't require pulling stubs for cloud models"" (#14606)"

This reverts commit 39982a954e.

* fix test + do cloud lookup only when seeing cloud models

---------

Co-authored-by: ParthSareen <parth.sareen@ollama.com>
2026-03-06 14:27:47 -08:00
Bruce MacDonald
1af850e6e3 parsers: repair unclosed arg_value tags in GLM tool calls (#14656)
GLM models sometimes omit </arg_value> closing tags in tool call XML, causing xml.Unmarshal to fail with "element <arg_value> closed by </tool_call>".

This is a known issue across the GLM family.

Sanitize the input to repair the missing closing tags so encoding/xml can handle it.
2026-03-06 14:08:34 -08:00
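
As an illustration only (not the actual parser code), a regexp-based sanitizer could close a dangling `<arg_value>` before the next `<arg_key>` or `</tool_call>`:

```go
package main

import (
	"fmt"
	"regexp"
)

// unclosed matches an <arg_value> whose text runs straight into the next
// <arg_key> or </tool_call> with no closing tag. The [^<]* body means this
// sketch only handles simple text values.
var unclosed = regexp.MustCompile(`(<arg_value>[^<]*)(<arg_key>|</tool_call>)`)

func repairArgValues(s string) string {
	return unclosed.ReplaceAllString(s, `$1</arg_value>$2`)
}

func main() {
	in := `<tool_call>get_weather<arg_key>city</arg_key><arg_value>Paris</tool_call>`
	fmt.Println(repairArgValues(in))
	// <tool_call>get_weather<arg_key>city</arg_key><arg_value>Paris</arg_value></tool_call>
}
```
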
Parth Sareen
9b0c7cc7b9 cmd: override stale entries for context window pi (#14655) 2026-03-05 16:30:24 -08:00
Daniel Hiltgen
6928630601 mlx: prevent remote creation mismatch (#14651)
If the user is pointing at a remote OLLAMA_HOST, fail experimental
safetensor-based create operations, as we only support local creation currently.
2026-03-05 14:59:00 -08:00
Parth Sareen
9896e3627f cmd/config: fix cloud model limit lookups in integrations (#14650) 2026-03-05 13:57:28 -08:00
Bruce MacDonald
15732f0ea7 cmd: use native Ollama API endpoint for OpenClaw (#14649)
Remove the /v1 suffix from the OpenClaw provider baseUrl so it uses
the native Ollama API instead of the OpenAI-compatible endpoint. The
/v1 endpoint may break tool calling in OpenClaw.
2026-03-05 13:29:17 -08:00
Parth Sareen
562c76d7cc cmd: add qwen3.5 context length for launch (#14626) 2026-03-04 14:10:52 -08:00
Parth Sareen
122c68c151 server: loosen thinking level constraint (#14625) 2026-03-04 13:42:18 -08:00
Jeffrey Morgan
82848a7806 model: fix renderer and parser for qwen3.5 (#14605) 2026-03-03 20:58:29 -08:00
Jeffrey Morgan
39982a954e Revert "Reapply "don't require pulling stubs for cloud models"" (#14606)
This reverts commit 799e51d419.
2026-03-03 20:56:10 -08:00
Patrick Devine
e9f6ea232f Add qwen3.5-next-moe support to MLX runner and models (#14417)
This change adds support for qwen3.5-next-moe models (qwen3-next/qwen3.5-next/qwen3-coder) to the MLX runner. It also:

* introduces recurrent cache support and related MLX ops
* updates pipeline/runner integration and adds tests
* properly quantizes stacked expert tensors
* adds a Gated Delta Metal kernel for fast SSM inference
* adds new MLX calls for Conv1d, DepthwiseConv1d, Contiguous, Exp, Log, SoftmaxAxis
2026-03-03 16:39:22 -08:00
Patrick Devine
110eff01a9 chore: remove old imagegen LLM models (#14597)
These models are implemented in the x/mlxrunner instead.
2026-03-03 13:23:40 -08:00
Jeffrey Morgan
799e51d419 Reapply "don't require pulling stubs for cloud models"
This reverts commit 97d2f05a6d.
2026-03-03 13:17:10 -08:00
Victor-Quqi
e8fcb29586 model/renderers: fix glm-ocr image tags in renderer prompts (#14584) 2026-03-03 12:51:34 -08:00
Jeffrey Morgan
97d2f05a6d Revert "don't require pulling stubs for cloud models (#14574)" (#14596)
This reverts commit 8207e55ec7.
2026-03-03 12:51:23 -08:00
Devon Rifkin
8207e55ec7 don't require pulling stubs for cloud models (#14574)
* don't require pulling stubs for cloud models

This is the first in a series of PRs that will better integrate Ollama's
cloud into the API and CLI. Previously there was a layer of
indirection where you'd first have to pull a "stub" model that contains
a reference to a cloud model. With this change, you don't have to pull
first; you can just use a cloud model in various routes like `/api/chat`
and `/api/show`. This change respects
<https://github.com/ollama/ollama/pull/14221>, so if cloud is disabled,
these models won't be accessible.

There's also a new, simpler pass-through proxy that doesn't convert
requests before they hit the cloud models, since the cloud models
already support various formats (e.g., `v1/chat/completions`, Open
Responses, etc.). This will help prevent issues caused by double
conversion (e.g., `v1/chat/completions` converted to `api/chat` on the
client, then calling cloud and converting back to a
`v1/chat/completions` response instead of the cloud model handling the
original `v1/chat/completions` request directly).

There's now a notion of "source tags", which can be mixed with existing
tags. So instead of having different formats like `gpt-oss:20b-cloud` vs.
`kimi-k2.5:cloud` (`-cloud` suffix vs. `:cloud`), you can now specify
cloud by simply appending `:cloud`. This PR doesn't change model
resolution yet, but sets us up to allow for things like omitting the
non-source tag, which would make something like `ollama run
gpt-oss:cloud` work the same way that `ollama run gpt-oss` already works
today.
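
For illustration, a simplified selector along these lines might strip trailing source tags and detect conflicts; this is a hypothetical sketch, while the real parser lives in `types/modelselector` and also accepts source tags in any position plus the legacy `-cloud` suffix:

```go
package main

import (
	"fmt"
	"strings"
)

// parseSource strips trailing ":cloud"/":local" source tags from a model
// name and reports which source was requested, erroring on conflicts.
func parseSource(name string) (base, source string, err error) {
	base = name
	for {
		var tag string
		switch {
		case strings.HasSuffix(base, ":cloud"):
			tag = "cloud"
		case strings.HasSuffix(base, ":local"):
			tag = "local"
		default:
			return base, source, nil
		}
		if source != "" && source != tag {
			return "", "", fmt.Errorf("conflicting source tags in %q", name)
		}
		source = tag
		base = strings.TrimSuffix(base, ":"+tag)
	}
}

func main() {
	base, source, _ := parseSource("gpt-oss:20b:cloud")
	fmt.Println(base, source) // gpt-oss:20b cloud
}
```
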

More detailed changes:

- Added a shared model selector parser in `types/modelselector`:
  - supports `:cloud` and `:local`
  - accepts source tags in any position
  - supports legacy `:<tag>-cloud`
  - rejects conflicting source tags
- Integrated selector handling across server inference/show routes:
  - `GenerateHandler`, `ChatHandler`, `EmbedHandler`,
    `EmbeddingsHandler`, `ShowHandler`
- Added explicit-cloud passthrough proxy for ollama.com:
  - same-endpoint forwarding for `/api/*`, `/v1/*`, and `/v1/messages`
  - normalizes `model` (and `name` for `/api/show`) before forwarding
  - forwards request headers except hop-by-hop/proxy-managed headers
  - uses bounded response-header timeout
  - handles auth failures in a friendly way
- Preserved cloud-disable behavior (`OLLAMA_NO_CLOUD`)
- Updated create flow to support `FROM ...:cloud` model sources (though
  this flow uses the legacy proxy still, supporting Modelfile overrides
  is more complicated with the direct proxy approach)
- Updated CLI/TUI/config cloud detection to use shared selector logic
- Updated CLI preflight behavior so explicit cloud requests do not
  auto-pull local stubs

What's next?

- Cloud discovery/listing and cache-backed `ollama ls` / `/api/tags`
- Modelfile overlay support for virtual cloud models on OpenAI/Anthropic
  request families
- Recommender/default-selection behavior for ambiguous model families
- Fully remove the legacy flow

Fixes: https://github.com/ollama/ollama/issues/13801

* consolidate pull logic into confirmAndPull helper

pullIfNeeded and ShowOrPull shared identical confirm-and-pull logic.
Extract confirmAndPull to eliminate the duplication.

* skip local existence checks for cloud models

ModelExists and the TUI's modelExists both check the local model list,
which causes cloud models to appear missing. Return true early for
explicit cloud models so the TUI displays them beside the integration
name and skips re-prompting the model picker on relaunch.

* support optionally pulling stubs for new-style names

We now normalize names like `<family>:<size>:cloud` into legacy-style
names like `<family>:<size>-cloud` for pulling and deleting (this also
supports stripping `:local`). Support for pulling cloud models is
temporary, once we integrate properly into `/api/tags` we won't need
this anymore.

* Fix server alias syncing

* Update cmd/cmd.go

Co-authored-by: Parth Sareen <parth.sareen@ollama.com>

* address comments

* improve some naming

---------

Co-authored-by: ParthSareen <parth.sareen@ollama.com>
2026-03-03 10:46:33 -08:00
Jesse Gross
ad16bffc7d mlx: Remove peak memory from the API
This is still in flux, so it is better to just log it for now.
2026-03-02 15:56:18 -08:00
Jesse Gross
c1e3ef4bcc mlxrunner: Refcount pinned tensors
Otherwise, it is error prone to manage multiple components working
with the same tensor.
2026-03-02 15:56:06 -08:00
Parth Sareen
a3093cd5e5 cmd/opencode: rename provider from "Ollama (local)" to "Ollama" (#14566)
The "(local)" qualifier is unnecessary since there's only one Ollama
provider. Existing configs with the old name are migrated automatically;
custom names are left unchanged.
2026-03-02 14:17:18 -08:00
Bruce MacDonald
23d4cad1a2 server: verify digest is not empty on create (#14555)
An empty digest is not a valid digest for an incoming create request. Reject empty digests at the API level.
2026-03-02 13:43:35 -08:00
Jeffrey Morgan
86513cb697 runner: add token history sampling parameters to ollama runner (#14537) 2026-03-01 19:16:07 -08:00
Jeffrey Morgan
3490e9590b model/qwen3next: avoid crash in DeltaNet when offloading (#14541)
Co-authored-by: Yossi Ovadia <jabadia@gmail.com>
2026-03-01 18:44:04 -08:00
Jeffrey Morgan
8da09b1e7e qwen3next: add compatibility with imported GGUF models (#14517) 2026-02-28 14:21:42 -08:00
Jesse Gross
a60b9adcce mlxrunner: Fix prompt eval timing and count metrics
Previously, only the last token's processing time was included in prompt
processing, giving an artificially high rate. In addition, the token
count included only the tokens that missed the cache, instead of our
historic total token count.
2026-02-27 17:29:47 -08:00
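
A sketch of the corrected accounting under assumed names: time every prompt batch and count all prompt tokens, not only cache misses:

```go
package main

import (
	"fmt"
	"time"
)

type metrics struct {
	promptEvalCount    int
	promptEvalDuration time.Duration
}

// recordPrompt accumulates the full processing time and the total token
// count (cache hits included), matching the historic metrics.
func (m *metrics) recordPrompt(tokens []int32, eval func([]int32)) {
	start := time.Now()
	eval(tokens)
	m.promptEvalDuration += time.Since(start)
	m.promptEvalCount += len(tokens)
}

func main() {
	var m metrics
	m.recordPrompt(make([]int32, 8), func([]int32) { time.Sleep(10 * time.Millisecond) })
	fmt.Printf("%d tokens in %v (%.0f tok/s)\n",
		m.promptEvalCount, m.promptEvalDuration,
		float64(m.promptEvalCount)/m.promptEvalDuration.Seconds())
}
```
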
Jesse Gross
a16f96658b mlxrunner: Enforce model context limit
Currently, context length is unbounded - the cache will keep
growing forever, independent of the model's trained context
length. This caps it and enforces semantics similar to most
cloud services:
 - Long prompts will result in an error, not truncation.
 - Generation that exceeds the context will be stopped.
2026-02-27 17:29:47 -08:00
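
A sketch of the enforcement, with hypothetical names for the limit check and stop sentinel:

```go
package main

import (
	"errors"
	"fmt"
)

// errContextFull is a hypothetical sentinel used to stop generation once
// the model's trained context length is reached.
var errContextFull = errors.New("context limit reached")

// checkContext rejects over-long prompts outright (error, not truncation)
// and signals when generation must stop.
func checkContext(promptLen, generated, ctxLimit int) error {
	if promptLen > ctxLimit {
		return fmt.Errorf("prompt length %d exceeds context limit %d", promptLen, ctxLimit)
	}
	if promptLen+generated >= ctxLimit {
		return errContextFull
	}
	return nil
}

func main() {
	fmt.Println(checkContext(4096, 0, 2048))    // prompt too long: error
	fmt.Println(checkContext(1024, 1024, 2048)) // generation hits the cap
}
```
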
Jesse Gross
18ab09b431 mlxrunner: Propagate pipeline errors to client via api.StatusError
Errors that occur during pipeline processing are currently only
logged but not sent back to the client. Rather than using HTTP
status codes as we have historically done, this serializes errors
as messages to allow sending them at any time during the stream.
2026-02-27 17:29:47 -08:00
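
For illustration, errors could be serialized in-band like this (the field names are assumptions, not the actual wire format); because the HTTP headers are already sent mid-stream, failures have to be reported inside the stream itself:

```go
package main

import (
	"encoding/json"
	"errors"
	"os"
)

// streamMsg carries either a token chunk or an error as a stream message.
type streamMsg struct {
	Response string `json:"response,omitempty"`
	Error    string `json:"error,omitempty"`
}

func main() {
	enc := json.NewEncoder(os.Stdout)
	enc.Encode(streamMsg{Response: "partial output"})
	// A late failure rides the same stream instead of an HTTP status code;
	// the client can then surface it, e.g. as an api.StatusError.
	enc.Encode(streamMsg{Error: errors.New("pipeline failed").Error()})
}
```
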
Jesse Gross
638faeac54 mlxrunner: Report actual memory usage from runner
The MLX runner previously reported a static VRAM estimate that was
computed at load time and consisted only of the weights. This is
strictly less than the actual memory usage, as it does not include
the KV cache or compute graph.
2026-02-27 17:29:47 -08:00
Jesse Gross
dd5eb6337d mlxrunner: Fix panic on full KV cache hit
When the entire prompt was already cached (e.g. repeated prompt),
findRemaining returned an empty slice, causing FromValues to panic
on an index-out-of-range accessing a zero-length byte slice.

Fix by always keeping at least one token to re-evaluate so the
pipeline can seed token generation. Also reject empty prompts
early rather than panicking.
2026-02-27 11:07:03 -08:00
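
A sketch of the fix under assumed names (the actual runner code differs):

```go
package main

import (
	"errors"
	"fmt"
)

// findRemaining returns the suffix of the prompt that still needs
// evaluation. Even on a full cache hit it keeps the final token, and it
// rejects empty prompts early instead of panicking.
func findRemaining(prompt []int32, cached int) ([]int32, error) {
	if len(prompt) == 0 {
		return nil, errors.New("empty prompt")
	}
	if cached >= len(prompt) {
		cached = len(prompt) - 1 // re-evaluate at least one token to seed generation
	}
	return prompt[cached:], nil
}

func main() {
	remaining, _ := findRemaining([]int32{1, 2, 3}, 3) // repeated prompt: full cache hit
	fmt.Println(remaining)                             // [3], never an empty slice
}
```
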
Patrick Devine
79917cf80b show peak memory usage (#14485) 2026-02-26 18:38:27 -08:00
Parth Sareen
cc90a035a0 model/parsers: add stable tool call indexing for glm47 and qwen3 parsers (#14484) 2026-02-26 18:14:29 -08:00
229 changed files with 14398 additions and 5565 deletions

View File

@@ -117,6 +117,25 @@ jobs:
install: https://sdk.lunarg.com/sdk/download/1.4.321.1/windows/vulkansdk-windows-X64-1.4.321.1.exe
flags: ''
runner_dir: 'vulkan'
- os: windows
arch: amd64
preset: 'MLX CUDA 13'
install: https://developer.download.nvidia.com/compute/cuda/13.0.0/local_installers/cuda_13.0.0_windows.exe
cudnn-install: https://developer.download.nvidia.com/compute/cudnn/redist/cudnn/windows-x86_64/cudnn-windows-x86_64-9.18.1.3_cuda13-archive.zip
cuda-components:
- '"cudart"'
- '"nvcc"'
- '"cublas"'
- '"cublas_dev"'
- '"cufft"'
- '"cufft_dev"'
- '"nvrtc"'
- '"nvrtc_dev"'
- '"crt"'
- '"nvvm"'
- '"nvptxcompiler"'
cuda-version: '13.0'
flags: ''
runs-on: ${{ matrix.arch == 'arm64' && format('{0}-{1}', matrix.os, matrix.arch) || matrix.os }}
environment: release
env:
@@ -125,8 +144,10 @@ jobs:
- name: Install system dependencies
run: |
choco install -y --no-progress ccache ninja
ccache -o cache_dir=${{ github.workspace }}\.ccache
- if: startsWith(matrix.preset, 'CUDA ') || startsWith(matrix.preset, 'ROCm ') || startsWith(matrix.preset, 'Vulkan')
if (Get-Command ccache -ErrorAction SilentlyContinue) {
ccache -o cache_dir=${{ github.workspace }}\.ccache
}
- if: startsWith(matrix.preset, 'CUDA ') || startsWith(matrix.preset, 'ROCm ') || startsWith(matrix.preset, 'Vulkan') || startsWith(matrix.preset, 'MLX ')
id: cache-install
uses: actions/cache/restore@v4
with:
@@ -134,8 +155,9 @@ jobs:
C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA
C:\Program Files\AMD\ROCm
C:\VulkanSDK
key: ${{ matrix.install }}
- if: startsWith(matrix.preset, 'CUDA ')
C:\Program Files\NVIDIA\CUDNN
key: ${{ matrix.install }}-${{ matrix.cudnn-install }}
- if: startsWith(matrix.preset, 'CUDA ') || startsWith(matrix.preset, 'MLX ')
name: Install CUDA ${{ matrix.cuda-version }}
run: |
$ErrorActionPreference = "Stop"
@@ -179,6 +201,23 @@ jobs:
run: |
echo "CC=clang.exe" | Out-File -FilePath $env:GITHUB_ENV -Append
echo "CXX=clang++.exe" | Out-File -FilePath $env:GITHUB_ENV -Append
- if: startsWith(matrix.preset, 'MLX ')
name: Install cuDNN for MLX
run: |
$ErrorActionPreference = "Stop"
$cudnnRoot = "C:\Program Files\NVIDIA\CUDNN"
if ("${{ steps.cache-install.outputs.cache-hit }}" -ne 'true') {
Invoke-WebRequest -Uri "${{ matrix.cudnn-install }}" -OutFile "cudnn.zip"
Expand-Archive -Path cudnn.zip -DestinationPath cudnn-extracted
$cudnnDir = (Get-ChildItem -Path cudnn-extracted -Directory)[0].FullName
New-Item -ItemType Directory -Force -Path $cudnnRoot
Copy-Item -Path "$cudnnDir\*" -Destination "$cudnnRoot\" -Recurse
}
echo "CUDNN_ROOT_DIR=$cudnnRoot" | Out-File -FilePath $env:GITHUB_ENV -Append
echo "CUDNN_INCLUDE_PATH=$cudnnRoot\include" | Out-File -FilePath $env:GITHUB_ENV -Append
echo "CUDNN_LIBRARY_PATH=$cudnnRoot\lib\x64" | Out-File -FilePath $env:GITHUB_ENV -Append
echo "$cudnnRoot\bin\x64" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
- if: ${{ !cancelled() && steps.cache-install.outputs.cache-hit != 'true' }}
uses: actions/cache/save@v4
with:
@@ -186,7 +225,8 @@ jobs:
C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA
C:\Program Files\AMD\ROCm
C:\VulkanSDK
key: ${{ matrix.install }}
C:\Program Files\NVIDIA\CUDNN
key: ${{ matrix.install }}-${{ matrix.cudnn-install }}
- uses: actions/checkout@v4
- uses: actions/cache@v4
with:
@@ -198,7 +238,7 @@ jobs:
Enter-VsDevShell -VsInstallPath 'C:\Program Files\Microsoft Visual Studio\2022\Enterprise' -SkipAutomaticLocation -DevCmdArguments '-arch=x64 -no_logo'
cmake --preset "${{ matrix.preset }}" ${{ matrix.flags }} --install-prefix "$((pwd).Path)\dist\${{ matrix.os }}-${{ matrix.arch }}"
cmake --build --parallel ([Environment]::ProcessorCount) --preset "${{ matrix.preset }}"
cmake --install build --component "${{ startsWith(matrix.preset, 'CUDA ') && 'CUDA' || startsWith(matrix.preset, 'ROCm ') && 'HIP' || startsWith(matrix.preset, 'Vulkan') && 'Vulkan' || 'CPU' }}" --strip
cmake --install build --component "${{ startsWith(matrix.preset, 'MLX ') && 'MLX' || startsWith(matrix.preset, 'CUDA ') && 'CUDA' || startsWith(matrix.preset, 'ROCm ') && 'HIP' || startsWith(matrix.preset, 'Vulkan') && 'Vulkan' || 'CPU' }}" --strip
Remove-Item -Path dist\lib\ollama\rocm\rocblas\library\*gfx906* -ErrorAction SilentlyContinue
env:
CMAKE_GENERATOR: Ninja

View File

@@ -37,7 +37,7 @@ jobs:
| xargs python3 -c "import sys; from pathlib import Path; print(any(Path(x).match(glob) for x in sys.argv[1:] for glob in '$*'.split(' ')))"
}
echo changed=$(changed 'llama/llama.cpp/**/*' 'ml/backend/ggml/ggml/**/*') | tee -a $GITHUB_OUTPUT
echo changed=$(changed 'llama/llama.cpp/**/*' 'ml/backend/ggml/ggml/**/*' '.github/**/*') | tee -a $GITHUB_OUTPUT
echo vendorsha=$(make -f Makefile.sync print-base) | tee -a $GITHUB_OUTPUT
linux:
@@ -51,7 +51,7 @@ jobs:
container: nvidia/cuda:13.0.0-devel-ubuntu22.04
flags: '-DCMAKE_CUDA_ARCHITECTURES=87'
- preset: ROCm
container: rocm/dev-ubuntu-22.04:6.1.2
container: rocm/dev-ubuntu-22.04:7.2
extra-packages: rocm-libs
flags: '-DAMDGPU_TARGETS=gfx1010 -DCMAKE_PREFIX_PATH=/opt/rocm'
- preset: Vulkan
@@ -60,6 +60,10 @@ jobs:
mesa-vulkan-drivers vulkan-tools
libvulkan1 libvulkan-dev
vulkan-sdk cmake ccache g++ make
- preset: 'MLX CUDA 13'
container: nvidia/cuda:13.0.0-devel-ubuntu22.04
extra-packages: libcudnn9-dev-cuda-13 libopenblas-dev liblapack-dev liblapacke-dev git curl
flags: '-DCMAKE_CUDA_ARCHITECTURES=87 -DBLAS_INCLUDE_DIRS=/usr/include/x86_64-linux-gnu -DLAPACK_INCLUDE_DIRS=/usr/include/x86_64-linux-gnu'
runs-on: linux
container: ${{ matrix.container }}
steps:
@@ -76,6 +80,10 @@ jobs:
$sudo apt-get update
fi
$sudo apt-get install -y cmake ccache ${{ matrix.extra-packages }}
# MLX requires CMake 3.25+, install from official releases
if [ "${{ matrix.preset }}" = "MLX CUDA 13" ]; then
curl -fsSL https://github.com/Kitware/CMake/releases/download/v3.31.2/cmake-3.31.2-linux-$(uname -m).tar.gz | $sudo tar xz -C /usr/local --strip-components 1
fi
# Export VULKAN_SDK if provided by LunarG package (defensive)
if [ -d "/usr/lib/x86_64-linux-gnu/vulkan" ] && [ "${{ matrix.preset }}" = "Vulkan" ]; then
echo "VULKAN_SDK=/usr" >> $GITHUB_ENV
@@ -87,8 +95,8 @@ jobs:
path: /github/home/.cache/ccache
key: ccache-${{ runner.os }}-${{ runner.arch }}-${{ matrix.preset }}-${{ needs.changes.outputs.vendorsha }}
- run: |
cmake --preset ${{ matrix.preset }} ${{ matrix.flags }}
cmake --build --preset ${{ matrix.preset }} --parallel
cmake --preset "${{ matrix.preset }}" ${{ matrix.flags }}
cmake --build --preset "${{ matrix.preset }}" --parallel
windows:
needs: [changes]
@@ -114,12 +122,31 @@ jobs:
flags: '-DAMDGPU_TARGETS=gfx1010 -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -DCMAKE_C_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma" -DCMAKE_CXX_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma"'
- preset: Vulkan
install: https://sdk.lunarg.com/sdk/download/1.4.321.1/windows/vulkansdk-windows-X64-1.4.321.1.exe
- preset: 'MLX CUDA 13'
install: https://developer.download.nvidia.com/compute/cuda/13.0.0/local_installers/cuda_13.0.0_windows.exe
cudnn-install: https://developer.download.nvidia.com/compute/cudnn/redist/cudnn/windows-x86_64/cudnn-windows-x86_64-9.18.1.3_cuda13-archive.zip
flags: '-DCMAKE_CUDA_ARCHITECTURES=80'
cuda-components:
- '"cudart"'
- '"nvcc"'
- '"cublas"'
- '"cublas_dev"'
- '"cufft"'
- '"cufft_dev"'
- '"nvrtc"'
- '"nvrtc_dev"'
- '"crt"'
- '"nvvm"'
- '"nvptxcompiler"'
cuda-version: '13.0'
runs-on: windows
steps:
- run: |
choco install -y --no-progress ccache ninja
ccache -o cache_dir=${{ github.workspace }}\.ccache
- if: matrix.preset == 'CUDA' || matrix.preset == 'ROCm' || matrix.preset == 'Vulkan'
if (Get-Command ccache -ErrorAction SilentlyContinue) {
ccache -o cache_dir=${{ github.workspace }}\.ccache
}
- if: matrix.preset == 'CUDA' || matrix.preset == 'ROCm' || matrix.preset == 'Vulkan' || matrix.preset == 'MLX CUDA 13'
id: cache-install
uses: actions/cache/restore@v4
with:
@@ -127,8 +154,9 @@ jobs:
C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA
C:\Program Files\AMD\ROCm
C:\VulkanSDK
key: ${{ matrix.install }}
- if: matrix.preset == 'CUDA'
C:\Program Files\NVIDIA\CUDNN
key: ${{ matrix.install }}-${{ matrix.cudnn-install }}
- if: matrix.preset == 'CUDA' || matrix.preset == 'MLX CUDA 13'
name: Install CUDA ${{ matrix.cuda-version }}
run: |
$ErrorActionPreference = "Stop"
@@ -168,6 +196,23 @@ jobs:
$vulkanPath = (Resolve-Path "C:\VulkanSDK\*").path
echo "$vulkanPath\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
echo "VULKAN_SDK=$vulkanPath" >> $env:GITHUB_ENV
- if: matrix.preset == 'MLX CUDA 13'
name: Install cuDNN for MLX
run: |
$ErrorActionPreference = "Stop"
$cudnnRoot = "C:\Program Files\NVIDIA\CUDNN"
if ("${{ steps.cache-install.outputs.cache-hit }}" -ne 'true') {
Invoke-WebRequest -Uri "${{ matrix.cudnn-install }}" -OutFile "cudnn.zip"
Expand-Archive -Path cudnn.zip -DestinationPath cudnn-extracted
$cudnnDir = (Get-ChildItem -Path cudnn-extracted -Directory)[0].FullName
New-Item -ItemType Directory -Force -Path $cudnnRoot
Copy-Item -Path "$cudnnDir\*" -Destination "$cudnnRoot\" -Recurse
}
echo "CUDNN_ROOT_DIR=$cudnnRoot" | Out-File -FilePath $env:GITHUB_ENV -Append
echo "CUDNN_INCLUDE_PATH=$cudnnRoot\include" | Out-File -FilePath $env:GITHUB_ENV -Append
echo "CUDNN_LIBRARY_PATH=$cudnnRoot\lib\x64" | Out-File -FilePath $env:GITHUB_ENV -Append
echo "$cudnnRoot\bin\x64" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
- if: ${{ !cancelled() && steps.cache-install.outputs.cache-hit != 'true' }}
uses: actions/cache/save@v4
with:
@@ -175,7 +220,8 @@ jobs:
C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA
C:\Program Files\AMD\ROCm
C:\VulkanSDK
key: ${{ matrix.install }}
C:\Program Files\NVIDIA\CUDNN
key: ${{ matrix.install }}-${{ matrix.cudnn-install }}
- uses: actions/checkout@v4
- uses: actions/cache@v4
with:

View File

@@ -64,10 +64,15 @@ set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${OLLAMA_BUILD_DIR})
set(CMAKE_LIBRARY_OUTPUT_DIRECTORY_DEBUG ${OLLAMA_BUILD_DIR})
set(CMAKE_LIBRARY_OUTPUT_DIRECTORY_RELEASE ${OLLAMA_BUILD_DIR})
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src)
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/include)
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-cpu)
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-cpu/amx)
# Store ggml include paths for use with target_include_directories later.
# We avoid global include_directories() to prevent polluting the include path
# for other projects like MLX (whose openblas dependency has its own common.h).
set(GGML_INCLUDE_DIRS
${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src
${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/include
${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-cpu
${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-cpu/amx
)
add_compile_definitions(NDEBUG GGML_VERSION=0x0 GGML_COMMIT=0x0)
@@ -87,6 +92,14 @@ if(NOT CPU_VARIANTS)
set(CPU_VARIANTS "ggml-cpu")
endif()
# Apply ggml include directories to ggml targets only (not globally)
target_include_directories(ggml-base PRIVATE ${GGML_INCLUDE_DIRS})
foreach(variant ${CPU_VARIANTS})
if(TARGET ${variant})
target_include_directories(${variant} PRIVATE ${GGML_INCLUDE_DIRS})
endif()
endforeach()
install(TARGETS ggml-base ${CPU_VARIANTS}
RUNTIME_DEPENDENCIES
PRE_EXCLUDE_REGEXES ".*"
@@ -103,6 +116,7 @@ if(CMAKE_CUDA_COMPILER)
find_package(CUDAToolkit)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-cuda)
target_include_directories(ggml-cuda PRIVATE ${GGML_INCLUDE_DIRS})
install(TARGETS ggml-cuda
RUNTIME_DEPENDENCIES
DIRECTORIES ${CUDAToolkit_BIN_DIR} ${CUDAToolkit_BIN_DIR}/x64 ${CUDAToolkit_LIBRARY_DIR}
@@ -134,6 +148,7 @@ if(CMAKE_HIP_COMPILER)
if(AMDGPU_TARGETS)
find_package(hip REQUIRED)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-hip)
target_include_directories(ggml-hip PRIVATE ${GGML_INCLUDE_DIRS})
if (WIN32)
target_compile_definitions(ggml-hip PRIVATE GGML_CUDA_NO_PEER_COPY)
@@ -148,7 +163,7 @@ if(CMAKE_HIP_COMPILER)
)
install(RUNTIME_DEPENDENCY_SET rocm
DIRECTORIES ${HIP_BIN_INSTALL_DIR} ${HIP_LIB_INSTALL_DIR}
PRE_INCLUDE_REGEXES hipblas rocblas amdhip64 rocsolver amd_comgr hsa-runtime64 rocsparse tinfo rocprofiler-register drm drm_amdgpu numa elf
PRE_INCLUDE_REGEXES hipblas rocblas amdhip64 rocsolver amd_comgr hsa-runtime64 rocsparse tinfo rocprofiler-register roctx64 rocroller drm drm_amdgpu numa elf
PRE_EXCLUDE_REGEXES ".*"
POST_EXCLUDE_REGEXES "system32"
RUNTIME DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT HIP
@@ -168,6 +183,7 @@ if(NOT APPLE)
find_package(Vulkan)
if(Vulkan_FOUND)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-vulkan)
target_include_directories(ggml-vulkan PRIVATE ${GGML_INCLUDE_DIRS})
install(TARGETS ggml-vulkan
RUNTIME_DEPENDENCIES
PRE_INCLUDE_REGEXES vulkan
@@ -179,7 +195,6 @@ if(NOT APPLE)
endif()
option(MLX_ENGINE "Enable MLX backend" OFF)
if(MLX_ENGINE)
message(STATUS "Setting up MLX (this takes a while...)")
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/x/imagegen/mlx)
@@ -187,10 +202,36 @@ if(MLX_ENGINE)
# Find CUDA toolkit if MLX is built with CUDA support
find_package(CUDAToolkit)
# Build list of directories for runtime dependency resolution
set(MLX_RUNTIME_DIRS ${CUDAToolkit_BIN_DIR} ${CUDAToolkit_BIN_DIR}/x64 ${CUDAToolkit_LIBRARY_DIR})
# Add cuDNN bin paths for DLLs (Windows MLX CUDA builds)
# CUDNN_ROOT_DIR is the standard CMake variable for cuDNN location
if(DEFINED ENV{CUDNN_ROOT_DIR})
# cuDNN 9.x has versioned subdirectories under bin/ (e.g., bin/13.0/)
file(GLOB CUDNN_BIN_SUBDIRS "$ENV{CUDNN_ROOT_DIR}/bin/*")
list(APPEND MLX_RUNTIME_DIRS ${CUDNN_BIN_SUBDIRS})
endif()
# Add build output directory and MLX dependency build directories
list(APPEND MLX_RUNTIME_DIRS ${OLLAMA_BUILD_DIR})
# OpenBLAS DLL location (pre-built zip extracts into openblas-src/bin/)
list(APPEND MLX_RUNTIME_DIRS ${CMAKE_BINARY_DIR}/_deps/openblas-src/bin)
# NCCL: on Linux, if real NCCL is found, cmake bundles libnccl.so via the
# regex below. If NCCL is not found, MLX links a static stub (OBJECT lib)
# so there is no runtime dependency. This path covers the stub build dir
# for windows so we include the DLL in our dependencies.
list(APPEND MLX_RUNTIME_DIRS ${CMAKE_BINARY_DIR}/_deps/mlx-build/mlx/distributed/nccl/nccl_stub-prefix/src/nccl_stub-build/Release)
# Base regexes for runtime dependencies (cross-platform)
set(MLX_INCLUDE_REGEXES cublas cublasLt cudart cufft nvrtc nvrtc-builtins cudnn nccl openblas gfortran)
# On Windows, also include dl.dll (dlfcn-win32 POSIX emulation layer)
if(WIN32)
list(APPEND MLX_INCLUDE_REGEXES "^dl\\.dll$")
endif()
install(TARGETS mlx mlxc
RUNTIME_DEPENDENCIES
DIRECTORIES ${CUDAToolkit_BIN_DIR} ${CUDAToolkit_BIN_DIR}/x64 ${CUDAToolkit_LIBRARY_DIR}
PRE_INCLUDE_REGEXES cublas cublasLt cudart nvrtc nvrtc-builtins cudnn nccl openblas gfortran
DIRECTORIES ${MLX_RUNTIME_DIRS}
PRE_INCLUDE_REGEXES ${MLX_INCLUDE_REGEXES}
PRE_EXCLUDE_REGEXES ".*"
RUNTIME DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT MLX
LIBRARY DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT MLX
@@ -205,13 +246,54 @@ if(MLX_ENGINE)
COMPONENT MLX)
endif()
# Manually install cudart and cublas since they might not be picked up as direct dependencies
# Install CCCL headers for NVRTC JIT compilation at runtime.
# MLX's own install rules use the default component so they get skipped by
# --component MLX. Headers are installed alongside libmlx in OLLAMA_INSTALL_DIR.
# On Linux, MLX's jit_module.cpp resolves CCCL via
# current_binary_dir().parent_path() / "include" / "cccl", so we create a
# symlink from lib/ollama/include -> ${OLLAMA_RUNNER_DIR}/include
# This will need refinement if we add multiple CUDA versions for MLX in the future.
if(EXISTS ${CMAKE_BINARY_DIR}/_deps/cccl-src/include/cuda)
install(DIRECTORY ${CMAKE_BINARY_DIR}/_deps/cccl-src/include/cuda
DESTINATION ${OLLAMA_INSTALL_DIR}/include/cccl
COMPONENT MLX)
install(DIRECTORY ${CMAKE_BINARY_DIR}/_deps/cccl-src/include/nv
DESTINATION ${OLLAMA_INSTALL_DIR}/include/cccl
COMPONENT MLX)
if(NOT WIN32 AND NOT APPLE)
install(CODE "
set(_link \"${CMAKE_INSTALL_PREFIX}/lib/ollama/include\")
set(_target \"${OLLAMA_RUNNER_DIR}/include\")
if(NOT EXISTS \${_link})
execute_process(COMMAND \${CMAKE_COMMAND} -E create_symlink \${_target} \${_link})
endif()
" COMPONENT MLX)
endif()
endif()
# On Windows, explicitly install dl.dll (dlfcn-win32 POSIX dlopen emulation)
# RUNTIME_DEPENDENCIES auto-excludes it via POST_EXCLUDE_FILES_STRICT because
# dlfcn-win32 is a known CMake target with its own install rules (which install
# to the wrong destination). We must install it explicitly here.
if(WIN32)
install(FILES ${OLLAMA_BUILD_DIR}/dl.dll
DESTINATION ${OLLAMA_INSTALL_DIR}
COMPONENT MLX)
endif()
# Manually install CUDA runtime libraries that MLX loads via dlopen
# (not detected by RUNTIME_DEPENDENCIES since they aren't link-time deps)
if(CUDAToolkit_FOUND)
file(GLOB CUDART_LIBS
file(GLOB MLX_CUDA_LIBS
"${CUDAToolkit_LIBRARY_DIR}/libcudart.so*"
"${CUDAToolkit_LIBRARY_DIR}/libcublas.so*")
if(CUDART_LIBS)
install(FILES ${CUDART_LIBS}
"${CUDAToolkit_LIBRARY_DIR}/libcublas.so*"
"${CUDAToolkit_LIBRARY_DIR}/libcublasLt.so*"
"${CUDAToolkit_LIBRARY_DIR}/libnvrtc.so*"
"${CUDAToolkit_LIBRARY_DIR}/libnvrtc-builtins.so*"
"${CUDAToolkit_LIBRARY_DIR}/libcufft.so*"
"${CUDAToolkit_LIBRARY_DIR}/libcudnn.so*")
if(MLX_CUDA_LIBS)
install(FILES ${MLX_CUDA_LIBS}
DESTINATION ${OLLAMA_INSTALL_DIR}
COMPONENT MLX)
endif()

View File

@@ -77,6 +77,15 @@
"OLLAMA_RUNNER_DIR": "rocm"
}
},
{
"name": "ROCm 7",
"inherits": [ "ROCm" ],
"cacheVariables": {
"CMAKE_HIP_FLAGS": "-parallel-jobs=4",
"AMDGPU_TARGETS": "gfx942;gfx950;gfx1010;gfx1012;gfx1030;gfx1100;gfx1101;gfx1102;gfx1103;gfx1150;gfx1151;gfx1200;gfx1201;gfx908:xnack-;gfx90a:xnack+;gfx90a:xnack-",
"OLLAMA_RUNNER_DIR": "rocm"
}
},
{
"name": "Vulkan",
"inherits": [ "Default" ],
@@ -103,6 +112,7 @@
"name": "MLX CUDA 13",
"inherits": [ "MLX", "CUDA 13" ],
"cacheVariables": {
"MLX_CUDA_ARCHITECTURES": "86;89;90;90a;100;103;75-virtual;80-virtual;110-virtual;120-virtual;121-virtual",
"OLLAMA_RUNNER_DIR": "mlx_cuda_v13"
}
}
@@ -158,6 +168,11 @@
"inherits": [ "ROCm" ],
"configurePreset": "ROCm 6"
},
{
"name": "ROCm 7",
"inherits": [ "ROCm" ],
"configurePreset": "ROCm 7"
},
{
"name": "Vulkan",
"targets": [ "ggml-vulkan" ],

View File

@@ -1,28 +1,23 @@
# vim: filetype=dockerfile
ARG FLAVOR=${TARGETARCH}
ARG PARALLEL=8
ARG ROCMVERSION=6.3.3
ARG ROCMVERSION=7.2
ARG JETPACK5VERSION=r35.4.1
ARG JETPACK6VERSION=r36.4.0
ARG CMAKEVERSION=3.31.2
ARG NINJAVERSION=1.12.1
ARG VULKANVERSION=1.4.321.1
# Default empty stages for local MLX source overrides.
# Override with: docker build --build-context local-mlx=../mlx --build-context local-mlx-c=../mlx-c
FROM scratch AS local-mlx
FROM scratch AS local-mlx-c
FROM --platform=linux/amd64 rocm/dev-almalinux-8:${ROCMVERSION}-complete AS base-amd64
RUN dnf install -y yum-utils ccache gcc-toolset-11-gcc gcc-toolset-11-gcc-c++ gcc-toolset-11-binutils \
&& yum-config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel8/x86_64/cuda-rhel8.repo
ENV PATH=/opt/rh/gcc-toolset-11/root/usr/bin:$PATH
ARG VULKANVERSION
RUN wget https://sdk.lunarg.com/sdk/download/${VULKANVERSION}/linux/vulkansdk-linux-x86_64-${VULKANVERSION}.tar.xz -O /tmp/vulkansdk-linux-x86_64-${VULKANVERSION}.tar.xz \
&& tar xvf /tmp/vulkansdk-linux-x86_64-${VULKANVERSION}.tar.xz \
&& dnf -y install ninja-build \
&& ln -s /usr/bin/python3 /usr/bin/python \
&& /${VULKANVERSION}/vulkansdk -j 8 vulkan-headers \
&& /${VULKANVERSION}/vulkansdk -j 8 shaderc
RUN cp -r /${VULKANVERSION}/x86_64/include/* /usr/local/include/ \
&& cp -r /${VULKANVERSION}/x86_64/lib/* /usr/local/lib
ENV PATH=/${VULKANVERSION}/x86_64/bin:$PATH
FROM --platform=linux/arm64 almalinux:8 AS base-arm64
# install epel-release for ccache
@@ -33,100 +28,119 @@ ENV CC=clang CXX=clang++
FROM base-${TARGETARCH} AS base
ARG CMAKEVERSION
ARG NINJAVERSION
RUN curl -fsSL https://github.com/Kitware/CMake/releases/download/v${CMAKEVERSION}/cmake-${CMAKEVERSION}-linux-$(uname -m).tar.gz | tar xz -C /usr/local --strip-components 1
RUN dnf install -y unzip \
&& curl -fsSL -o /tmp/ninja.zip https://github.com/ninja-build/ninja/releases/download/v${NINJAVERSION}/ninja-linux$([ "$(uname -m)" = "aarch64" ] && echo "-aarch64").zip \
&& unzip /tmp/ninja.zip -d /usr/local/bin \
&& rm /tmp/ninja.zip
ENV CMAKE_GENERATOR=Ninja
ENV LDFLAGS=-s
FROM base AS cpu
RUN dnf install -y gcc-toolset-11-gcc gcc-toolset-11-gcc-c++
ENV PATH=/opt/rh/gcc-toolset-11/root/usr/bin:$PATH
ARG PARALLEL
COPY CMakeLists.txt CMakePresets.json .
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
RUN --mount=type=cache,target=/root/.ccache \
cmake --preset 'CPU' \
&& cmake --build --parallel ${PARALLEL} --preset 'CPU' \
&& cmake --install build --component CPU --strip --parallel ${PARALLEL}
&& cmake --build --preset 'CPU' -- -l $(nproc) \
&& cmake --install build --component CPU --strip
FROM base AS cuda-11
ARG CUDA11VERSION=11.8
RUN dnf install -y cuda-toolkit-${CUDA11VERSION//./-}
ENV PATH=/usr/local/cuda-11/bin:$PATH
ARG PARALLEL
COPY CMakeLists.txt CMakePresets.json .
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
RUN --mount=type=cache,target=/root/.ccache \
cmake --preset 'CUDA 11' \
&& cmake --build --parallel ${PARALLEL} --preset 'CUDA 11' \
&& cmake --install build --component CUDA --strip --parallel ${PARALLEL}
&& cmake --build --preset 'CUDA 11' -- -l $(nproc) \
&& cmake --install build --component CUDA --strip
FROM base AS cuda-12
ARG CUDA12VERSION=12.8
RUN dnf install -y cuda-toolkit-${CUDA12VERSION//./-}
ENV PATH=/usr/local/cuda-12/bin:$PATH
ARG PARALLEL
COPY CMakeLists.txt CMakePresets.json .
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
RUN --mount=type=cache,target=/root/.ccache \
cmake --preset 'CUDA 12' \
&& cmake --build --parallel ${PARALLEL} --preset 'CUDA 12' \
&& cmake --install build --component CUDA --strip --parallel ${PARALLEL}
&& cmake --build --preset 'CUDA 12' -- -l $(nproc) \
&& cmake --install build --component CUDA --strip
FROM base AS cuda-13
ARG CUDA13VERSION=13.0
RUN dnf install -y cuda-toolkit-${CUDA13VERSION//./-}
ENV PATH=/usr/local/cuda-13/bin:$PATH
ARG PARALLEL
COPY CMakeLists.txt CMakePresets.json .
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
RUN --mount=type=cache,target=/root/.ccache \
cmake --preset 'CUDA 13' \
&& cmake --build --parallel ${PARALLEL} --preset 'CUDA 13' \
&& cmake --install build --component CUDA --strip --parallel ${PARALLEL}
&& cmake --build --preset 'CUDA 13' -- -l $(nproc) \
&& cmake --install build --component CUDA --strip
FROM base AS rocm-6
FROM base AS rocm-7
ENV PATH=/opt/rocm/hcc/bin:/opt/rocm/hip/bin:/opt/rocm/bin:/opt/rocm/hcc/bin:$PATH
ARG PARALLEL
COPY CMakeLists.txt CMakePresets.json .
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
RUN --mount=type=cache,target=/root/.ccache \
cmake --preset 'ROCm 6' \
&& cmake --build --parallel ${PARALLEL} --preset 'ROCm 6' \
&& cmake --install build --component HIP --strip --parallel ${PARALLEL}
cmake --preset 'ROCm 7' \
&& cmake --build --preset 'ROCm 7' -- -l $(nproc) \
&& cmake --install build --component HIP --strip
RUN rm -f dist/lib/ollama/rocm/rocblas/library/*gfx90[06]*
FROM --platform=linux/arm64 nvcr.io/nvidia/l4t-jetpack:${JETPACK5VERSION} AS jetpack-5
ARG CMAKEVERSION
RUN apt-get update && apt-get install -y curl ccache \
&& curl -fsSL https://github.com/Kitware/CMake/releases/download/v${CMAKEVERSION}/cmake-${CMAKEVERSION}-linux-$(uname -m).tar.gz | tar xz -C /usr/local --strip-components 1
ARG NINJAVERSION
RUN apt-get update && apt-get install -y curl ccache unzip \
&& curl -fsSL https://github.com/Kitware/CMake/releases/download/v${CMAKEVERSION}/cmake-${CMAKEVERSION}-linux-$(uname -m).tar.gz | tar xz -C /usr/local --strip-components 1 \
&& curl -fsSL -o /tmp/ninja.zip https://github.com/ninja-build/ninja/releases/download/v${NINJAVERSION}/ninja-linux-aarch64.zip \
&& unzip /tmp/ninja.zip -d /usr/local/bin \
&& rm /tmp/ninja.zip
ENV CMAKE_GENERATOR=Ninja
COPY CMakeLists.txt CMakePresets.json .
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
ARG PARALLEL
RUN --mount=type=cache,target=/root/.ccache \
cmake --preset 'JetPack 5' \
&& cmake --build --parallel ${PARALLEL} --preset 'JetPack 5' \
&& cmake --install build --component CUDA --strip --parallel ${PARALLEL}
&& cmake --build --preset 'JetPack 5' -- -l $(nproc) \
&& cmake --install build --component CUDA --strip
FROM --platform=linux/arm64 nvcr.io/nvidia/l4t-jetpack:${JETPACK6VERSION} AS jetpack-6
ARG CMAKEVERSION
RUN apt-get update && apt-get install -y curl ccache \
&& curl -fsSL https://github.com/Kitware/CMake/releases/download/v${CMAKEVERSION}/cmake-${CMAKEVERSION}-linux-$(uname -m).tar.gz | tar xz -C /usr/local --strip-components 1
ARG NINJAVERSION
RUN apt-get update && apt-get install -y curl ccache unzip \
&& curl -fsSL https://github.com/Kitware/CMake/releases/download/v${CMAKEVERSION}/cmake-${CMAKEVERSION}-linux-$(uname -m).tar.gz | tar xz -C /usr/local --strip-components 1 \
&& curl -fsSL -o /tmp/ninja.zip https://github.com/ninja-build/ninja/releases/download/v${NINJAVERSION}/ninja-linux-aarch64.zip \
&& unzip /tmp/ninja.zip -d /usr/local/bin \
&& rm /tmp/ninja.zip
ENV CMAKE_GENERATOR=Ninja
COPY CMakeLists.txt CMakePresets.json .
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
ARG PARALLEL
RUN --mount=type=cache,target=/root/.ccache \
cmake --preset 'JetPack 6' \
&& cmake --build --parallel ${PARALLEL} --preset 'JetPack 6' \
&& cmake --install build --component CUDA --strip --parallel ${PARALLEL}
&& cmake --build --preset 'JetPack 6' -- -l $(nproc) \
&& cmake --install build --component CUDA --strip
FROM base AS vulkan
ARG VULKANVERSION
RUN ln -s /usr/bin/python3 /usr/bin/python \
&& wget https://sdk.lunarg.com/sdk/download/${VULKANVERSION}/linux/vulkansdk-linux-x86_64-${VULKANVERSION}.tar.xz -O /tmp/vulkansdk.tar.xz \
&& tar xvf /tmp/vulkansdk.tar.xz -C /tmp \
&& /tmp/${VULKANVERSION}/vulkansdk -j 8 vulkan-headers \
&& /tmp/${VULKANVERSION}/vulkansdk -j 8 shaderc \
&& cp -r /tmp/${VULKANVERSION}/x86_64/include/* /usr/local/include/ \
&& cp -r /tmp/${VULKANVERSION}/x86_64/lib/* /usr/local/lib \
&& cp -r /tmp/${VULKANVERSION}/x86_64/bin/* /usr/local/bin/ \
&& rm -rf /tmp/${VULKANVERSION} /tmp/vulkansdk.tar.xz
COPY CMakeLists.txt CMakePresets.json .
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
RUN --mount=type=cache,target=/root/.ccache \
cmake --preset 'Vulkan' \
&& cmake --build --parallel --preset 'Vulkan' \
&& cmake --install build --component Vulkan --strip --parallel 8
&& cmake --build --preset 'Vulkan' -- -l $(nproc) \
&& cmake --install build --component Vulkan --strip
FROM base AS mlx
ARG CUDA13VERSION=13.0
@@ -138,20 +152,27 @@ ENV PATH=/usr/local/cuda-13/bin:$PATH
ENV BLAS_INCLUDE_DIRS=/usr/include/openblas
ENV LAPACK_INCLUDE_DIRS=/usr/include/openblas
ENV CGO_LDFLAGS="-L/usr/local/cuda-13/lib64 -L/usr/local/cuda-13/targets/x86_64-linux/lib/stubs"
ARG PARALLEL
WORKDIR /go/src/github.com/ollama/ollama
COPY CMakeLists.txt CMakePresets.json .
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
COPY x/imagegen/mlx x/imagegen/mlx
COPY go.mod go.sum .
COPY MLX_VERSION .
COPY MLX_VERSION MLX_CORE_VERSION .
RUN curl -fsSL https://golang.org/dl/go$(awk '/^go/ { print $2 }' go.mod).linux-$(case $(uname -m) in x86_64) echo amd64 ;; aarch64) echo arm64 ;; esac).tar.gz | tar xz -C /usr/local
ENV PATH=/usr/local/go/bin:$PATH
RUN go mod download
RUN --mount=type=cache,target=/root/.ccache \
cmake --preset 'MLX CUDA 13' -DBLAS_INCLUDE_DIRS=/usr/include/openblas -DLAPACK_INCLUDE_DIRS=/usr/include/openblas \
&& cmake --build --parallel ${PARALLEL} --preset 'MLX CUDA 13' \
&& cmake --install build --component MLX --strip --parallel ${PARALLEL}
--mount=type=bind,from=local-mlx,target=/tmp/local-mlx \
--mount=type=bind,from=local-mlx-c,target=/tmp/local-mlx-c \
if [ -f /tmp/local-mlx/CMakeLists.txt ]; then \
export OLLAMA_MLX_SOURCE=/tmp/local-mlx; \
fi \
&& if [ -f /tmp/local-mlx-c/CMakeLists.txt ]; then \
export OLLAMA_MLX_C_SOURCE=/tmp/local-mlx-c; \
fi \
&& cmake --preset 'MLX CUDA 13' -DBLAS_INCLUDE_DIRS=/usr/include/openblas -DLAPACK_INCLUDE_DIRS=/usr/include/openblas \
&& cmake --build --preset 'MLX CUDA 13' -- -l $(nproc) \
&& cmake --install build --component MLX --strip
FROM base AS build
WORKDIR /go/src/github.com/ollama/ollama
@@ -160,16 +181,14 @@ RUN curl -fsSL https://golang.org/dl/go$(awk '/^go/ { print $2 }' go.mod).linux-
ENV PATH=/usr/local/go/bin:$PATH
RUN go mod download
COPY . .
# Clone mlx-c headers for CGO (version from MLX_VERSION file)
RUN git clone --depth 1 --branch "$(cat MLX_VERSION)" https://github.com/ml-explore/mlx-c.git build/_deps/mlx-c-src
ARG GOFLAGS="'-ldflags=-w -s'"
ENV CGO_ENABLED=1
ARG CGO_CFLAGS
ARG CGO_CXXFLAGS
ENV CGO_CFLAGS="${CGO_CFLAGS} -I/go/src/github.com/ollama/ollama/build/_deps/mlx-c-src"
ENV CGO_CFLAGS="${CGO_CFLAGS}"
ENV CGO_CXXFLAGS="${CGO_CXXFLAGS}"
RUN --mount=type=cache,target=/root/.cache/go-build \
go build -tags mlx -trimpath -buildmode=pie -o /bin/ollama .
go build -trimpath -buildmode=pie -o /bin/ollama .
FROM --platform=linux/amd64 scratch AS amd64
# COPY --from=cuda-11 dist/lib/ollama/ /lib/ollama/
@@ -186,10 +205,9 @@ COPY --from=jetpack-5 dist/lib/ollama/ /lib/ollama/
COPY --from=jetpack-6 dist/lib/ollama/ /lib/ollama/
FROM scratch AS rocm
COPY --from=rocm-6 dist/lib/ollama /lib/ollama
COPY --from=rocm-7 dist/lib/ollama /lib/ollama
FROM ${FLAVOR} AS archive
ARG VULKANVERSION
COPY --from=cpu dist/lib/ollama /lib/ollama
COPY --from=build /bin/ollama /bin/ollama

MLX_CORE_VERSION Normal file
View File

@@ -0,0 +1 @@
v0.30.6

View File

@@ -1063,7 +1063,7 @@ func DefaultOptions() Options {
TopP: 0.9,
TypicalP: 1.0,
RepeatLastN: 64,
RepeatPenalty: 1.1,
RepeatPenalty: 1.0,
PresencePenalty: 0.0,
FrequencyPenalty: 0.0,
Seed: -1,

View File

@@ -214,6 +214,7 @@ export default function Settings() {
Agent: false,
Tools: false,
ContextLength: 0,
AutoUpdateEnabled: true,
});
updateSettingsMutation.mutate(defaultSettings);
}

View File

@@ -41,6 +41,7 @@ import (
"github.com/ollama/ollama/cmd/tui"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/format"
"github.com/ollama/ollama/internal/modelref"
"github.com/ollama/ollama/parser"
"github.com/ollama/ollama/progress"
"github.com/ollama/ollama/readline"
@@ -131,6 +132,17 @@ func getModelfileName(cmd *cobra.Command) (string, error) {
return absName, nil
}
// isLocalhost returns true if the configured Ollama host is a loopback or unspecified address.
func isLocalhost() bool {
host := envconfig.Host()
h, _, _ := net.SplitHostPort(host.Host)
if h == "localhost" {
return true
}
ip := net.ParseIP(h)
return ip != nil && (ip.IsLoopback() || ip.IsUnspecified())
}
func CreateHandler(cmd *cobra.Command, args []string) error {
p := progress.NewProgress(os.Stderr)
defer p.Stop()
@@ -145,6 +157,9 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
// Check for --experimental flag for safetensors model creation
experimental, _ := cmd.Flags().GetBool("experimental")
if experimental {
if !isLocalhost() {
return errors.New("remote safetensor model creation not yet supported")
}
// Get Modelfile content - either from -f flag or default to "FROM ."
var reader io.Reader
filename, err := getModelfileName(cmd)
@@ -168,29 +183,9 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
return fmt.Errorf("failed to parse Modelfile: %w", err)
}
// Extract FROM path and configuration
var modelDir string
mfConfig := &xcreateclient.ModelfileConfig{}
for _, cmd := range modelfile.Commands {
switch cmd.Name {
case "model":
modelDir = cmd.Args
case "template":
mfConfig.Template = cmd.Args
case "system":
mfConfig.System = cmd.Args
case "license":
mfConfig.License = cmd.Args
case "parser":
mfConfig.Parser = cmd.Args
case "renderer":
mfConfig.Renderer = cmd.Args
}
}
if modelDir == "" {
modelDir = "."
modelDir, mfConfig, err := xcreateclient.ConfigFromModelfile(modelfile)
if err != nil {
return err
}
// Resolve relative paths based on Modelfile location
@@ -214,6 +209,9 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
if filename == "" {
// No Modelfile found - check if current directory is an image gen model
if create.IsTensorModelDir(".") {
if !isLocalhost() {
return errors.New("remote safetensor model creation not yet supported")
}
quantize, _ := cmd.Flags().GetString("quantize")
return xcreateclient.CreateModel(xcreateclient.CreateOptions{
ModelName: modelName,
@@ -406,12 +404,14 @@ func loadOrUnloadModel(cmd *cobra.Command, opts *runOptions) error {
return err
}
requestedCloud := modelref.HasExplicitCloudSource(opts.Model)
if info, err := client.Show(cmd.Context(), &api.ShowRequest{Model: opts.Model}); err != nil {
return err
} else if info.RemoteHost != "" {
} else if info.RemoteHost != "" || requestedCloud {
// Cloud model, no need to load/unload
isCloud := strings.HasPrefix(info.RemoteHost, "https://ollama.com")
isCloud := requestedCloud || strings.HasPrefix(info.RemoteHost, "https://ollama.com")
// Check if user is signed in for ollama.com cloud models
if isCloud {
@@ -422,10 +422,14 @@ func loadOrUnloadModel(cmd *cobra.Command, opts *runOptions) error {
if opts.ShowConnect {
p.StopAndClear()
remoteModel := info.RemoteModel
if remoteModel == "" {
remoteModel = opts.Model
}
if isCloud {
fmt.Fprintf(os.Stderr, "Connecting to '%s' on 'ollama.com' ⚡\n", info.RemoteModel)
fmt.Fprintf(os.Stderr, "Connecting to '%s' on 'ollama.com' ⚡\n", remoteModel)
} else {
fmt.Fprintf(os.Stderr, "Connecting to '%s' on '%s'\n", info.RemoteModel, info.RemoteHost)
fmt.Fprintf(os.Stderr, "Connecting to '%s' on '%s'\n", remoteModel, info.RemoteHost)
}
}
@@ -497,6 +501,20 @@ func generateEmbedding(cmd *cobra.Command, modelName, input string, keepAlive *a
return nil
}
// TODO(parthsareen): consolidate with TUI signin flow
func handleCloudAuthorizationError(err error) bool {
var authErr api.AuthorizationError
if errors.As(err, &authErr) && authErr.StatusCode == http.StatusUnauthorized {
fmt.Printf("You need to be signed in to Ollama to run Cloud models.\n\n")
if authErr.SigninURL != "" {
fmt.Printf(ConnectInstructions, authErr.SigninURL)
}
return true
}
return false
}
func RunHandler(cmd *cobra.Command, args []string) error {
interactive := true
@@ -585,17 +603,6 @@ func RunHandler(cmd *cobra.Command, args []string) error {
}
opts.WordWrap = !nowrap
useImagegen := false
if cmd.Flags().Lookup("imagegen") != nil {
useImagegen, err = cmd.Flags().GetBool("imagegen")
if err != nil {
return err
}
}
if useImagegen {
opts.Options["use_imagegen_runner"] = true
}
// Fill out the rest of the options based on information about the
// model.
client, err := api.ClientFromEnvironment()
@@ -604,12 +611,16 @@ func RunHandler(cmd *cobra.Command, args []string) error {
}
name := args[0]
requestedCloud := modelref.HasExplicitCloudSource(name)
info, err := func() (*api.ShowResponse, error) {
showReq := &api.ShowRequest{Name: name}
info, err := client.Show(cmd.Context(), showReq)
var se api.StatusError
if errors.As(err, &se) && se.StatusCode == http.StatusNotFound {
if requestedCloud {
return nil, err
}
if err := PullHandler(cmd, []string{name}); err != nil {
return nil, err
}
@@ -618,6 +629,9 @@ func RunHandler(cmd *cobra.Command, args []string) error {
return info, err
}()
if err != nil {
if handleCloudAuthorizationError(err) {
return nil
}
return err
}
@@ -712,7 +726,13 @@ func RunHandler(cmd *cobra.Command, args []string) error {
return generateInteractive(cmd, opts)
}
return generate(cmd, opts)
if err := generate(cmd, opts); err != nil {
if handleCloudAuthorizationError(err) {
return nil
}
return err
}
return nil
}
func SigninHandler(cmd *cobra.Command, args []string) error {

View File

@@ -18,6 +18,7 @@ import (
"github.com/spf13/cobra"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/internal/modelref"
"github.com/ollama/ollama/types/model"
)
@@ -705,6 +706,139 @@ func TestRunEmbeddingModelNoInput(t *testing.T) {
}
}
func TestRunHandler_CloudAuthErrorOnShow_PrintsSigninMessage(t *testing.T) {
var generateCalled bool
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch {
case r.URL.Path == "/api/show" && r.Method == http.MethodPost:
w.WriteHeader(http.StatusUnauthorized)
if err := json.NewEncoder(w).Encode(map[string]string{
"error": "unauthorized",
"signin_url": "https://ollama.com/signin",
}); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
}
return
case r.URL.Path == "/api/generate" && r.Method == http.MethodPost:
generateCalled = true
w.WriteHeader(http.StatusOK)
if err := json.NewEncoder(w).Encode(api.GenerateResponse{Done: true}); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
}
return
default:
http.NotFound(w, r)
}
}))
t.Setenv("OLLAMA_HOST", mockServer.URL)
t.Cleanup(mockServer.Close)
cmd := &cobra.Command{}
cmd.SetContext(t.Context())
cmd.Flags().String("keepalive", "", "")
cmd.Flags().Bool("truncate", false, "")
cmd.Flags().Int("dimensions", 0, "")
cmd.Flags().Bool("verbose", false, "")
cmd.Flags().Bool("insecure", false, "")
cmd.Flags().Bool("nowordwrap", false, "")
cmd.Flags().String("format", "", "")
cmd.Flags().String("think", "", "")
cmd.Flags().Bool("hidethinking", false, "")
oldStdout := os.Stdout
readOut, writeOut, _ := os.Pipe()
os.Stdout = writeOut
t.Cleanup(func() { os.Stdout = oldStdout })
err := RunHandler(cmd, []string{"gpt-oss:20b:cloud", "hi"})
_ = writeOut.Close()
var out bytes.Buffer
_, _ = io.Copy(&out, readOut)
if err != nil {
t.Fatalf("RunHandler returned error: %v", err)
}
if generateCalled {
t.Fatal("expected run to stop before /api/generate after unauthorized /api/show")
}
if !strings.Contains(out.String(), "You need to be signed in to Ollama to run Cloud models.") {
t.Fatalf("expected sign-in guidance message, got %q", out.String())
}
if !strings.Contains(out.String(), "https://ollama.com/signin") {
t.Fatalf("expected signin_url in output, got %q", out.String())
}
}
func TestRunHandler_CloudAuthErrorOnGenerate_PrintsSigninMessage(t *testing.T) {
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch {
case r.URL.Path == "/api/show" && r.Method == http.MethodPost:
w.WriteHeader(http.StatusOK)
if err := json.NewEncoder(w).Encode(api.ShowResponse{
Capabilities: []model.Capability{model.CapabilityCompletion},
}); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
}
return
case r.URL.Path == "/api/generate" && r.Method == http.MethodPost:
w.WriteHeader(http.StatusUnauthorized)
if err := json.NewEncoder(w).Encode(map[string]string{
"error": "unauthorized",
"signin_url": "https://ollama.com/signin",
}); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
}
return
default:
http.NotFound(w, r)
}
}))
t.Setenv("OLLAMA_HOST", mockServer.URL)
t.Cleanup(mockServer.Close)
cmd := &cobra.Command{}
cmd.SetContext(t.Context())
cmd.Flags().String("keepalive", "", "")
cmd.Flags().Bool("truncate", false, "")
cmd.Flags().Int("dimensions", 0, "")
cmd.Flags().Bool("verbose", false, "")
cmd.Flags().Bool("insecure", false, "")
cmd.Flags().Bool("nowordwrap", false, "")
cmd.Flags().String("format", "", "")
cmd.Flags().String("think", "", "")
cmd.Flags().Bool("hidethinking", false, "")
oldStdout := os.Stdout
readOut, writeOut, _ := os.Pipe()
os.Stdout = writeOut
t.Cleanup(func() { os.Stdout = oldStdout })
err := RunHandler(cmd, []string{"gpt-oss:20b:cloud", "hi"})
_ = writeOut.Close()
var out bytes.Buffer
_, _ = io.Copy(&out, readOut)
if err != nil {
t.Fatalf("RunHandler returned error: %v", err)
}
if !strings.Contains(out.String(), "You need to be signed in to Ollama to run Cloud models.") {
t.Fatalf("expected sign-in guidance message, got %q", out.String())
}
if !strings.Contains(out.String(), "https://ollama.com/signin") {
t.Fatalf("expected signin_url in output, got %q", out.String())
}
}
func TestGetModelfileName(t *testing.T) {
tests := []struct {
name string
@@ -1664,20 +1798,26 @@ func TestRunOptions_Copy_Independence(t *testing.T) {
func TestLoadOrUnloadModel_CloudModelAuth(t *testing.T) {
tests := []struct {
name string
model string
remoteHost string
remoteModel string
whoamiStatus int
whoamiResp any
expectedError string
}{
{
name: "ollama.com cloud model - user signed in",
model: "test-cloud-model",
remoteHost: "https://ollama.com",
remoteModel: "test-model",
whoamiStatus: http.StatusOK,
whoamiResp: api.UserResponse{Name: "testuser"},
},
{
name: "ollama.com cloud model - user not signed in",
model: "test-cloud-model",
remoteHost: "https://ollama.com",
remoteModel: "test-model",
whoamiStatus: http.StatusUnauthorized,
whoamiResp: map[string]string{
"error": "unauthorized",
@@ -1687,7 +1827,33 @@ func TestLoadOrUnloadModel_CloudModelAuth(t *testing.T) {
},
{
name: "non-ollama.com remote - no auth check",
model: "test-cloud-model",
remoteHost: "https://other-remote.com",
remoteModel: "test-model",
whoamiStatus: http.StatusUnauthorized, // should not be called
whoamiResp: nil,
},
{
name: "explicit :cloud model - auth check without remote metadata",
model: "kimi-k2.5:cloud",
remoteHost: "",
remoteModel: "",
whoamiStatus: http.StatusOK,
whoamiResp: api.UserResponse{Name: "testuser"},
},
{
name: "explicit -cloud model - auth check without remote metadata",
model: "kimi-k2.5:latest-cloud",
remoteHost: "",
remoteModel: "",
whoamiStatus: http.StatusOK,
whoamiResp: api.UserResponse{Name: "testuser"},
},
{
name: "dash cloud-like name without explicit source does not require auth",
model: "test-cloud-model",
remoteHost: "",
remoteModel: "",
whoamiStatus: http.StatusUnauthorized, // should not be called
whoamiResp: nil,
},
@@ -1702,7 +1868,7 @@ func TestLoadOrUnloadModel_CloudModelAuth(t *testing.T) {
w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(api.ShowResponse{
RemoteHost: tt.remoteHost,
RemoteModel: "test-model",
RemoteModel: tt.remoteModel,
}); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
}
@@ -1715,6 +1881,8 @@ func TestLoadOrUnloadModel_CloudModelAuth(t *testing.T) {
http.Error(w, err.Error(), http.StatusInternalServerError)
}
}
case "/api/generate":
w.WriteHeader(http.StatusOK)
default:
http.NotFound(w, r)
}
@@ -1727,13 +1895,13 @@ func TestLoadOrUnloadModel_CloudModelAuth(t *testing.T) {
cmd.SetContext(t.Context())
opts := &runOptions{
Model: "test-cloud-model",
Model: tt.model,
ShowConnect: false,
}
err := loadOrUnloadModel(cmd, opts)
if strings.HasPrefix(tt.remoteHost, "https://ollama.com") {
if strings.HasPrefix(tt.remoteHost, "https://ollama.com") || modelref.HasExplicitCloudSource(tt.model) {
if !whoamiCalled {
t.Error("expected whoami to be called for ollama.com cloud model")
}
@@ -1760,3 +1928,38 @@ func TestLoadOrUnloadModel_CloudModelAuth(t *testing.T) {
})
}
}
func TestIsLocalhost(t *testing.T) {
tests := []struct {
name string
host string
expected bool
}{
{"default empty", "", true},
{"localhost no port", "localhost", true},
{"localhost with port", "localhost:11435", true},
{"127.0.0.1 no port", "127.0.0.1", true},
{"127.0.0.1 with port", "127.0.0.1:11434", true},
{"0.0.0.0 no port", "0.0.0.0", true},
{"0.0.0.0 with port", "0.0.0.0:11434", true},
{"::1 no port", "::1", true},
{"[::1] with port", "[::1]:11434", true},
{"loopback with scheme", "http://localhost:11434", true},
{"remote hostname", "example.com", false},
{"remote hostname with port", "example.com:11434", false},
{"remote IP", "192.168.1.1", false},
{"remote IP with port", "192.168.1.1:11434", false},
{"remote with scheme", "http://example.com:11434", false},
{"https remote", "https://example.com:443", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Setenv("OLLAMA_HOST", tt.host)
got := isLocalhost()
if got != tt.expected {
t.Errorf("isLocalhost() with OLLAMA_HOST=%q = %v, want %v", tt.host, got, tt.expected)
}
})
}
}
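
The table above pins down the expected behavior. A sketch consistent with it (illustrative only; the name and structure here are assumptions, not the actual cmd helper):

```go
import (
	"net"
	"strings"
)

// isLocalhostSketch mirrors the table: empty means the default local host;
// otherwise strip any scheme and port, then accept localhost, the 0.0.0.0
// wildcard, and loopback IPs.
func isLocalhostSketch(host string) bool {
	if host == "" {
		return true
	}
	if i := strings.Index(host, "://"); i >= 0 {
		host = host[i+3:] // drop "http://" or "https://"
	}
	if h, _, err := net.SplitHostPort(host); err == nil {
		host = h // drops ":11434" and unwraps "[::1]:11434"
	}
	if host == "localhost" || host == "0.0.0.0" {
		return true
	}
	ip := net.ParseIP(host)
	return ip != nil && ip.IsLoopback()
}
```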

View File

@@ -107,15 +107,12 @@ func (c *Claude) ConfigureAliases(ctx context.Context, model string, existingAli
}
if !force && aliases["primary"] != "" {
client, _ := api.ClientFromEnvironment()
if isCloudModel(ctx, client, aliases["primary"]) {
if isCloudModel(ctx, client, aliases["fast"]) {
return aliases, false, nil
}
} else {
delete(aliases, "fast")
if isCloudModelName(aliases["primary"]) {
aliases["fast"] = aliases["primary"]
return aliases, false, nil
}
delete(aliases, "fast")
return aliases, false, nil
}
items, existingModels, cloudModels, client, err := listModels(ctx)
@@ -139,10 +136,8 @@ func (c *Claude) ConfigureAliases(ctx context.Context, model string, existingAli
aliases["primary"] = primary
}
if isCloudModel(ctx, client, aliases["primary"]) {
if aliases["fast"] == "" || !isCloudModel(ctx, client, aliases["fast"]) {
aliases["fast"] = aliases["primary"]
}
if isCloudModelName(aliases["primary"]) {
aliases["fast"] = aliases["primary"]
} else {
delete(aliases, "fast")
}

View File

@@ -233,6 +233,9 @@ func ModelExists(ctx context.Context, name string) bool {
if name == "" {
return false
}
if isCloudModelName(name) {
return true
}
client, err := api.ClientFromEnvironment()
if err != nil {
return false

View File

@@ -10,7 +10,6 @@ import (
"path/filepath"
"slices"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/envconfig"
)
@@ -125,13 +124,12 @@ func (d *Droid) Edit(models []string) error {
}
// Build new Ollama model entries with sequential indices (0, 1, 2, ...)
client, _ := api.ClientFromEnvironment()
var newModels []any
var defaultModelID string
for i, model := range models {
maxOutput := 64000
if isCloudModel(context.Background(), client, model) {
if isCloudModelName(model) {
if l, ok := lookupCloudModelLimit(model); ok {
maxOutput = l.Output
}

View File

@@ -1276,25 +1276,17 @@ func TestDroidEdit_LocalModelDefaultMaxOutput(t *testing.T) {
func TestDroidEdit_CloudModelLimitsUsed(t *testing.T) {
// Verify that every cloud model in cloudModelLimits has a valid output
// value that would be used for maxOutputTokens when isCloudModel returns true.
// :cloud suffix stripping must also work since that's how users specify them.
// value that would be used for maxOutputTokens when the selected model uses
// the explicit :cloud source tag.
for name, expected := range cloudModelLimits {
t.Run(name, func(t *testing.T) {
l, ok := lookupCloudModelLimit(name)
if !ok {
t.Fatalf("lookupCloudModelLimit(%q) returned false", name)
}
if l.Output != expected.Output {
t.Errorf("output = %d, want %d", l.Output, expected.Output)
}
// Also verify :cloud suffix lookup
cloudName := name + ":cloud"
l2, ok := lookupCloudModelLimit(cloudName)
l, ok := lookupCloudModelLimit(cloudName)
if !ok {
t.Fatalf("lookupCloudModelLimit(%q) returned false", cloudName)
}
if l2.Output != expected.Output {
t.Errorf(":cloud output = %d, want %d", l2.Output, expected.Output)
if l.Output != expected.Output {
t.Errorf("output = %d, want %d", l.Output, expected.Output)
}
})
}

View File

@@ -14,6 +14,7 @@ import (
"github.com/ollama/ollama/api"
internalcloud "github.com/ollama/ollama/internal/cloud"
"github.com/ollama/ollama/internal/modelref"
"github.com/ollama/ollama/progress"
"github.com/spf13/cobra"
)
@@ -81,6 +82,7 @@ var cloudModelLimits = map[string]cloudModelLimit{
"deepseek-v3.2": {Context: 163_840, Output: 65_536},
"glm-4.6": {Context: 202_752, Output: 131_072},
"glm-4.7": {Context: 202_752, Output: 131_072},
"glm-5": {Context: 202_752, Output: 131_072},
"gpt-oss:120b": {Context: 131_072, Output: 131_072},
"gpt-oss:20b": {Context: 131_072, Output: 131_072},
"kimi-k2:1t": {Context: 262_144, Output: 262_144},
@@ -90,6 +92,7 @@ var cloudModelLimits = map[string]cloudModelLimit{
"qwen3-coder:480b": {Context: 262_144, Output: 65_536},
"qwen3-coder-next": {Context: 262_144, Output: 32_768},
"qwen3-next:80b": {Context: 262_144, Output: 32_768},
"qwen3.5": {Context: 262_144, Output: 32_768},
}
// recommendedVRAM maps local recommended models to their approximate VRAM requirement.
@@ -324,12 +327,7 @@ func SelectModelWithSelector(ctx context.Context, selector SingleSelector) (stri
// If the selected model isn't installed, pull it first
if !existingModels[selected] {
if cloudModels[selected] {
// Cloud models only pull a small manifest; no confirmation needed
if err := pullModel(ctx, client, selected); err != nil {
return "", fmt.Errorf("failed to pull %s: %w", selected, err)
}
} else {
if !isCloudModelName(selected) {
msg := fmt.Sprintf("Download %s?", selected)
if ok, err := confirmPrompt(msg); err != nil {
return "", err
@@ -524,7 +522,7 @@ func selectModelsWithSelectors(ctx context.Context, name, current string, single
var toPull []string
for _, m := range selected {
if !existingModels[m] {
if !existingModels[m] && !isCloudModelName(m) {
toPull = append(toPull, m)
}
}
@@ -550,12 +548,28 @@ func selectModelsWithSelectors(ctx context.Context, name, current string, single
return selected, nil
}
// TODO(parthsareen): consolidate pull logic from call sites
func pullIfNeeded(ctx context.Context, client *api.Client, existingModels map[string]bool, model string) error {
if existingModels[model] {
if isCloudModelName(model) || existingModels[model] {
return nil
}
msg := fmt.Sprintf("Download %s?", model)
if ok, err := confirmPrompt(msg); err != nil {
return confirmAndPull(ctx, client, model)
}
// TODO(parthsareen): pull this out to tui package
// ShowOrPull checks if a model exists via client.Show and offers to pull it if not found.
func ShowOrPull(ctx context.Context, client *api.Client, model string) error {
if _, err := client.Show(ctx, &api.ShowRequest{Model: model}); err == nil {
return nil
}
if isCloudModelName(model) {
return nil
}
return confirmAndPull(ctx, client, model)
}
func confirmAndPull(ctx context.Context, client *api.Client, model string) error {
if ok, err := confirmPrompt(fmt.Sprintf("Download %s?", model)); err != nil {
return err
} else if !ok {
return errCancelled
@@ -567,26 +581,6 @@ func pullIfNeeded(ctx context.Context, client *api.Client, existingModels map[st
return nil
}
// TODO(parthsareen): pull this out to tui package
// ShowOrPull checks if a model exists via client.Show and offers to pull it if not found.
func ShowOrPull(ctx context.Context, client *api.Client, model string) error {
if _, err := client.Show(ctx, &api.ShowRequest{Model: model}); err == nil {
return nil
}
// Cloud models only pull a small manifest; skip the download confirmation
// TODO(parthsareen): consolidate with cloud config changes
if strings.HasSuffix(model, "cloud") {
return pullModel(ctx, client, model)
}
if ok, err := confirmPrompt(fmt.Sprintf("Download %s?", model)); err != nil {
return err
} else if !ok {
return errCancelled
}
fmt.Fprintf(os.Stderr, "\n")
return pullModel(ctx, client, model)
}
func listModels(ctx context.Context) ([]ModelItem, map[string]bool, map[string]bool, *api.Client, error) {
client, err := api.ClientFromEnvironment()
if err != nil {
@@ -731,10 +725,8 @@ func syncAliases(ctx context.Context, client *api.Client, ac AliasConfigurer, na
}
aliases["primary"] = model
if isCloudModel(ctx, client, model) {
if aliases["fast"] == "" || !isCloudModel(ctx, client, aliases["fast"]) {
aliases["fast"] = model
}
if isCloudModelName(model) {
aliases["fast"] = model
} else {
delete(aliases, "fast")
}
@@ -1020,7 +1012,7 @@ Examples:
existingAliases = aliases
// Ensure cloud models are authenticated
if isCloudModel(cmd.Context(), client, model) {
if isCloudModelName(model) {
if err := ensureAuth(cmd.Context(), client, map[string]bool{model: true}, []string{model}); err != nil {
return err
}
@@ -1209,7 +1201,7 @@ func buildModelList(existing []modelInfo, preChecked []string, current string) (
// When user has no models, preserve recommended order.
notInstalled := make(map[string]bool)
for i := range items {
if !existingModels[items[i].Name] {
if !existingModels[items[i].Name] && !cloudModels[items[i].Name] {
notInstalled[items[i].Name] = true
var parts []string
if items[i].Description != "" {
@@ -1303,7 +1295,8 @@ func IsCloudModelDisabled(ctx context.Context, name string) bool {
}
func isCloudModelName(name string) bool {
return strings.HasSuffix(name, ":cloud") || strings.HasSuffix(name, "-cloud")
// TODO(drifkin): Replace this wrapper with inlining once things stabilize a bit
return modelref.HasExplicitCloudSource(name)
}
func filterCloudModels(existing []modelInfo) []modelInfo {

View File

@@ -426,8 +426,14 @@ func TestBuildModelList_NoExistingModels(t *testing.T) {
}
for _, item := range items {
if !strings.HasSuffix(item.Description, "(not downloaded)") {
t.Errorf("item %q should have description ending with '(not downloaded)', got %q", item.Name, item.Description)
if strings.HasSuffix(item.Name, ":cloud") {
if strings.HasSuffix(item.Description, "(not downloaded)") {
t.Errorf("cloud model %q should not have '(not downloaded)' suffix, got %q", item.Name, item.Description)
}
} else {
if !strings.HasSuffix(item.Description, "(not downloaded)") {
t.Errorf("item %q should have description ending with '(not downloaded)', got %q", item.Name, item.Description)
}
}
}
}
@@ -492,10 +498,14 @@ func TestBuildModelList_ExistingRecommendedMarked(t *testing.T) {
if strings.HasSuffix(item.Description, "(not downloaded)") {
t.Errorf("installed recommended %q should not have '(not downloaded)' suffix, got %q", item.Name, item.Description)
}
case "minimax-m2.5:cloud", "kimi-k2.5:cloud", "qwen3:8b":
case "qwen3:8b":
if !strings.HasSuffix(item.Description, "(not downloaded)") {
t.Errorf("non-installed recommended %q should have '(not downloaded)' suffix, got %q", item.Name, item.Description)
}
case "minimax-m2.5:cloud", "kimi-k2.5:cloud":
if strings.HasSuffix(item.Description, "(not downloaded)") {
t.Errorf("cloud model %q should not have '(not downloaded)' suffix, got %q", item.Name, item.Description)
}
}
}
}
@@ -536,7 +546,13 @@ func TestBuildModelList_HasRecommendedCloudModel_OnlyNonInstalledAtBottom(t *tes
}
for _, item := range items {
if !slices.Contains([]string{"kimi-k2.5:cloud", "llama3.2"}, item.Name) {
isCloud := strings.HasSuffix(item.Name, ":cloud")
isInstalled := slices.Contains([]string{"kimi-k2.5:cloud", "llama3.2"}, item.Name)
if isInstalled || isCloud {
if strings.HasSuffix(item.Description, "(not downloaded)") {
t.Errorf("installed or cloud model %q should not have '(not downloaded)' suffix, got %q", item.Name, item.Description)
}
} else {
if !strings.HasSuffix(item.Description, "(not downloaded)") {
t.Errorf("non-installed %q should have '(not downloaded)' suffix, got %q", item.Name, item.Description)
}
@@ -1000,8 +1016,8 @@ func TestShowOrPull_ModelNotFound_ConfirmNo_Cancelled(t *testing.T) {
}
}
func TestShowOrPull_CloudModel_SkipsConfirmation(t *testing.T) {
// Confirm prompt should NOT be called for cloud models
func TestShowOrPull_CloudModel_DoesNotPull(t *testing.T) {
// Confirm prompt should NOT be called for explicit cloud models
oldHook := DefaultConfirmPrompt
DefaultConfirmPrompt = func(prompt string) (bool, error) {
t.Error("confirm prompt should not be called for cloud models")
@@ -1032,8 +1048,115 @@ func TestShowOrPull_CloudModel_SkipsConfirmation(t *testing.T) {
if err != nil {
t.Errorf("ShowOrPull should succeed for cloud model, got: %v", err)
}
if !pullCalled {
t.Error("expected pull to be called for cloud model without confirmation")
if pullCalled {
t.Error("expected pull not to be called for cloud model")
}
}
func TestShowOrPull_CloudLegacySuffix_DoesNotPull(t *testing.T) {
// Confirm prompt should NOT be called for explicit cloud models
oldHook := DefaultConfirmPrompt
DefaultConfirmPrompt = func(prompt string) (bool, error) {
t.Error("confirm prompt should not be called for cloud models")
return false, nil
}
defer func() { DefaultConfirmPrompt = oldHook }()
var pullCalled bool
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/api/show":
w.WriteHeader(http.StatusNotFound)
fmt.Fprintf(w, `{"error":"model not found"}`)
case "/api/pull":
pullCalled = true
w.WriteHeader(http.StatusOK)
fmt.Fprintf(w, `{"status":"success"}`)
default:
w.WriteHeader(http.StatusNotFound)
}
}))
defer srv.Close()
u, _ := url.Parse(srv.URL)
client := api.NewClient(u, srv.Client())
err := ShowOrPull(context.Background(), client, "gpt-oss:20b-cloud")
if err != nil {
t.Errorf("ShowOrPull should succeed for cloud model, got: %v", err)
}
if pullCalled {
t.Error("expected pull not to be called for cloud model")
}
}
func TestPullIfNeeded_CloudModel_DoesNotPull(t *testing.T) {
oldHook := DefaultConfirmPrompt
DefaultConfirmPrompt = func(prompt string) (bool, error) {
t.Error("confirm prompt should not be called for cloud models")
return false, nil
}
defer func() { DefaultConfirmPrompt = oldHook }()
err := pullIfNeeded(context.Background(), nil, map[string]bool{}, "glm-5:cloud")
if err != nil {
t.Fatalf("expected no error for cloud model, got %v", err)
}
err = pullIfNeeded(context.Background(), nil, map[string]bool{}, "gpt-oss:20b-cloud")
if err != nil {
t.Fatalf("expected no error for cloud model with legacy suffix, got %v", err)
}
}
func TestSelectModelsWithSelectors_CloudSelection_DoesNotPull(t *testing.T) {
var pullCalled bool
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/api/status":
w.WriteHeader(http.StatusNotFound)
fmt.Fprintf(w, `{"error":"not found"}`)
case "/api/tags":
w.WriteHeader(http.StatusOK)
fmt.Fprintf(w, `{"models":[]}`)
case "/api/me":
w.WriteHeader(http.StatusOK)
fmt.Fprintf(w, `{"name":"test-user"}`)
case "/api/pull":
pullCalled = true
w.WriteHeader(http.StatusOK)
fmt.Fprintf(w, `{"status":"success"}`)
default:
w.WriteHeader(http.StatusNotFound)
fmt.Fprintf(w, `{"error":"not found"}`)
}
}))
defer srv.Close()
t.Setenv("OLLAMA_HOST", srv.URL)
single := func(title string, items []ModelItem, current string) (string, error) {
for _, item := range items {
if item.Name == "glm-5:cloud" {
return item.Name, nil
}
}
t.Fatalf("expected glm-5:cloud in selector items, got %v", items)
return "", nil
}
multi := func(title string, items []ModelItem, preChecked []string) ([]string, error) {
return nil, fmt.Errorf("multi selector should not be called")
}
selected, err := selectModelsWithSelectors(context.Background(), "codex", "", single, multi)
if err != nil {
t.Fatalf("selectModelsWithSelectors returned error: %v", err)
}
if !slices.Equal(selected, []string{"glm-5:cloud"}) {
t.Fatalf("unexpected selected models: %v", selected)
}
if pullCalled {
t.Fatal("expected cloud selection to skip pull")
}
}

View File

@@ -502,7 +502,7 @@ func (c *Openclaw) Edit(models []string) error {
ollama = make(map[string]any)
}
ollama["baseUrl"] = envconfig.Host().String() + "/v1"
ollama["baseUrl"] = envconfig.Host().String()
// needed to register provider
ollama["apiKey"] = "ollama-local"
ollama["api"] = "ollama"

View File

@@ -589,7 +589,7 @@ const testOpenclawFixture = `{
"providers": {
"anthropic": {"apiKey": "xxx"},
"ollama": {
"baseUrl": "http://127.0.0.1:11434/v1",
"baseUrl": "http://127.0.0.1:11434",
"models": [{"id": "old-model", "customField": "preserved"}]
}
}

View File

@@ -12,8 +12,8 @@ import (
"slices"
"strings"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/internal/modelref"
)
// OpenCode implements Runner and Editor for OpenCode integration
@@ -26,13 +26,10 @@ type cloudModelLimit struct {
}
// lookupCloudModelLimit returns the token limits for a cloud model.
// It tries the exact name first, then strips the ":cloud" suffix.
// It normalizes explicit cloud source suffixes before checking the shared limit map.
func lookupCloudModelLimit(name string) (cloudModelLimit, bool) {
if l, ok := cloudModelLimits[name]; ok {
return l, true
}
base := strings.TrimSuffix(name, ":cloud")
if base != name {
base, stripped := modelref.StripCloudSourceTag(name)
if stripped {
if l, ok := cloudModelLimits[base]; ok {
return l, true
}
@@ -122,13 +119,18 @@ func (o *OpenCode) Edit(modelList []string) error {
if !ok {
ollama = map[string]any{
"npm": "@ai-sdk/openai-compatible",
"name": "Ollama (local)",
"name": "Ollama",
"options": map[string]any{
"baseURL": envconfig.Host().String() + "/v1",
},
}
}
// Migrate legacy provider name
if name, _ := ollama["name"].(string); name == "Ollama (local)" {
ollama["name"] = "Ollama"
}
models, ok := ollama["models"].(map[string]any)
if !ok {
models = make(map[string]any)
@@ -147,8 +149,6 @@ func (o *OpenCode) Edit(modelList []string) error {
}
}
client, _ := api.ClientFromEnvironment()
for _, model := range modelList {
if existing, ok := models[model].(map[string]any); ok {
// migrate existing models without _launch marker
@@ -158,7 +158,7 @@ func (o *OpenCode) Edit(modelList []string) error {
existing["name"] = strings.TrimSuffix(name, " [Ollama]")
}
}
if isCloudModel(context.Background(), client, model) {
if isCloudModelName(model) {
if l, ok := lookupCloudModelLimit(model); ok {
existing["limit"] = map[string]any{
"context": l.Context,
@@ -172,7 +172,7 @@ func (o *OpenCode) Edit(modelList []string) error {
"name": model,
"_launch": true,
}
if isCloudModel(context.Background(), client, model) {
if isCloudModelName(model) {
if l, ok := lookupCloudModelLimit(model); ok {
entry["limit"] = map[string]any{
"context": l.Context,

View File

@@ -3,6 +3,8 @@ package config
import (
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"testing"
@@ -232,6 +234,44 @@ func TestOpenCodeEdit(t *testing.T) {
}
})
t.Run("migrate Ollama (local) provider name", func(t *testing.T) {
cleanup()
os.MkdirAll(configDir, 0o755)
os.WriteFile(configPath, []byte(`{"provider":{"ollama":{"name":"Ollama (local)","npm":"@ai-sdk/openai-compatible","options":{"baseURL":"http://localhost:11434/v1"}}}}`), 0o644)
if err := o.Edit([]string{"llama3.2"}); err != nil {
t.Fatal(err)
}
data, _ := os.ReadFile(configPath)
var cfg map[string]any
json.Unmarshal(data, &cfg)
provider := cfg["provider"].(map[string]any)
ollama := provider["ollama"].(map[string]any)
if ollama["name"] != "Ollama" {
t.Errorf("provider name not migrated: got %q, want %q", ollama["name"], "Ollama")
}
})
t.Run("preserve custom provider name", func(t *testing.T) {
cleanup()
os.MkdirAll(configDir, 0o755)
os.WriteFile(configPath, []byte(`{"provider":{"ollama":{"name":"My Custom Ollama","npm":"@ai-sdk/openai-compatible","options":{"baseURL":"http://localhost:11434/v1"}}}}`), 0o644)
if err := o.Edit([]string{"llama3.2"}); err != nil {
t.Fatal(err)
}
data, _ := os.ReadFile(configPath)
var cfg map[string]any
json.Unmarshal(data, &cfg)
provider := cfg["provider"].(map[string]any)
ollama := provider["ollama"].(map[string]any)
if ollama["name"] != "My Custom Ollama" {
t.Errorf("custom provider name was changed: got %q, want %q", ollama["name"], "My Custom Ollama")
}
})
t.Run("remove model preserves non-ollama models", func(t *testing.T) {
cleanup()
os.MkdirAll(configDir, 0o755)
@@ -619,6 +659,54 @@ func TestOpenCodeEdit_CloudModelLimitStructure(t *testing.T) {
}
}
func TestOpenCodeEdit_BackfillsCloudModelLimitOnExistingEntry(t *testing.T) {
o := &OpenCode{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/api/show" {
fmt.Fprintf(w, `{"capabilities":[],"model_info":{},"remote_model":"glm-5"}`)
return
}
w.WriteHeader(http.StatusNotFound)
}))
defer srv.Close()
t.Setenv("OLLAMA_HOST", srv.URL)
configDir := filepath.Join(tmpDir, ".config", "opencode")
configPath := filepath.Join(configDir, "opencode.json")
os.MkdirAll(configDir, 0o755)
os.WriteFile(configPath, []byte(`{
"provider": {
"ollama": {
"models": {
"glm-5:cloud": {
"name": "glm-5:cloud",
"_launch": true
}
}
}
}
}`), 0o644)
if err := o.Edit([]string{"glm-5:cloud"}); err != nil {
t.Fatal(err)
}
entry := readOpenCodeModel(t, configPath, "glm-5:cloud")
limit, ok := entry["limit"].(map[string]any)
if !ok {
t.Fatal("cloud model limit was not added on re-edit")
}
if limit["context"] != float64(202_752) {
t.Errorf("context = %v, want 202752", limit["context"])
}
if limit["output"] != float64(131_072) {
t.Errorf("output = %v, want 131072", limit["output"])
}
}
func TestLookupCloudModelLimit(t *testing.T) {
tests := []struct {
name string
@@ -626,13 +714,17 @@ func TestLookupCloudModelLimit(t *testing.T) {
wantContext int
wantOutput int
}{
{"glm-4.7", true, 202_752, 131_072},
{"glm-4.7", false, 0, 0},
{"glm-4.7:cloud", true, 202_752, 131_072},
{"kimi-k2.5", true, 262_144, 262_144},
{"glm-5:cloud", true, 202_752, 131_072},
{"gpt-oss:120b-cloud", true, 131_072, 131_072},
{"gpt-oss:20b-cloud", true, 131_072, 131_072},
{"kimi-k2.5", false, 0, 0},
{"kimi-k2.5:cloud", true, 262_144, 262_144},
{"deepseek-v3.2", true, 163_840, 65_536},
{"deepseek-v3.2", false, 0, 0},
{"deepseek-v3.2:cloud", true, 163_840, 65_536},
{"qwen3-coder:480b", true, 262_144, 65_536},
{"qwen3-coder:480b", false, 0, 0},
{"qwen3-coder:480b:cloud", true, 262_144, 65_536},
{"qwen3-coder-next:cloud", true, 262_144, 32_768},
{"llama3.2", false, 0, 0},
{"unknown-model:cloud", false, 0, 0},

View File

@@ -107,7 +107,8 @@ func (p *Pi) Edit(models []string) error {
// Build new models list:
// 1. Keep user-managed models (no _launch marker) - untouched
// 2. Keep ollama-managed models (_launch marker) that are still selected
// 2. Keep ollama-managed models (_launch marker) that are still selected,
// except stale cloud entries that should be rebuilt below
// 3. Add new ollama-managed models
var newModels []any
for _, m := range existingModels {
@@ -117,7 +118,13 @@ func (p *Pi) Edit(models []string) error {
if !isPiOllamaModel(modelObj) {
newModels = append(newModels, m)
} else if selectedSet[id] {
// Ollama-managed and still selected - keep it
// Rebuild stale managed cloud entries so createConfig refreshes
// the whole entry instead of patching it in place.
if !hasContextWindow(modelObj) {
if _, ok := lookupCloudModelLimit(id); ok {
continue
}
}
newModels = append(newModels, m)
selectedSet[id] = false
}
@@ -199,12 +206,28 @@ func isPiOllamaModel(cfg map[string]any) bool {
return false
}
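// hasContextWindow reports whether the entry already carries a positive
// contextWindow. JSON-decoded configs yield float64; the int cases cover
// values set in-process (createConfig stores ints).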
func hasContextWindow(cfg map[string]any) bool {
switch v := cfg["contextWindow"].(type) {
case float64:
return v > 0
case int:
return v > 0
case int64:
return v > 0
default:
return false
}
}
// createConfig builds Pi model config with capability detection
func createConfig(ctx context.Context, client *api.Client, modelID string) map[string]any {
cfg := map[string]any{
"id": modelID,
"_launch": true,
}
if l, ok := lookupCloudModelLimit(modelID); ok {
cfg["contextWindow"] = l.Context
}
resp, err := client.Show(ctx, &api.ShowRequest{Model: modelID})
if err != nil {
@@ -223,7 +246,8 @@ func createConfig(ctx context.Context, client *api.Client, modelID string) map[s
cfg["reasoning"] = true
}
// Extract context window from ModelInfo
// Extract context window from ModelInfo. For known cloud models, the
// pre-filled shared limit remains unless the server provides a positive value.
for key, val := range resp.ModelInfo {
if strings.HasSuffix(key, ".context_length") {
if ctxLen, ok := val.(float64); ok && ctxLen > 0 {

View File

@@ -192,6 +192,48 @@ func TestPiEdit(t *testing.T) {
}
})
t.Run("rebuilds stale existing managed cloud model", func(t *testing.T) {
cleanup()
os.MkdirAll(configDir, 0o755)
existingConfig := `{
"providers": {
"ollama": {
"baseUrl": "http://localhost:11434/v1",
"api": "openai-completions",
"apiKey": "ollama",
"models": [
{"id": "glm-5:cloud", "_launch": true, "legacyField": "stale"}
]
}
}
}`
if err := os.WriteFile(configPath, []byte(existingConfig), 0o644); err != nil {
t.Fatal(err)
}
if err := pi.Edit([]string{"glm-5:cloud"}); err != nil {
t.Fatalf("Edit() error = %v", err)
}
cfg := readConfig()
providers := cfg["providers"].(map[string]any)
ollama := providers["ollama"].(map[string]any)
modelsArray := ollama["models"].([]any)
modelEntry := modelsArray[0].(map[string]any)
if modelEntry["contextWindow"] != float64(202_752) {
t.Errorf("contextWindow = %v, want 202752", modelEntry["contextWindow"])
}
input, ok := modelEntry["input"].([]any)
if !ok || len(input) != 1 || input[0] != "text" {
t.Errorf("input = %v, want [text]", modelEntry["input"])
}
if _, ok := modelEntry["legacyField"]; ok {
t.Error("legacyField should be removed when stale managed cloud entry is rebuilt")
}
})
t.Run("replaces old models with new ones", func(t *testing.T) {
cleanup()
os.MkdirAll(configDir, 0o755)
@@ -798,6 +840,60 @@ func TestCreateConfig(t *testing.T) {
}
})
t.Run("falls back to cloud context when show fails", func(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNotFound)
fmt.Fprintf(w, `{"error":"model not found"}`)
}))
defer srv.Close()
u, _ := url.Parse(srv.URL)
client := api.NewClient(u, srv.Client())
cfg := createConfig(context.Background(), client, "kimi-k2.5:cloud")
if cfg["contextWindow"] != 262_144 {
t.Errorf("contextWindow = %v, want 262144", cfg["contextWindow"])
}
})
t.Run("falls back to cloud context when model info is empty", func(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/api/show" {
fmt.Fprintf(w, `{"capabilities":[],"model_info":{}}`)
return
}
w.WriteHeader(http.StatusNotFound)
}))
defer srv.Close()
u, _ := url.Parse(srv.URL)
client := api.NewClient(u, srv.Client())
cfg := createConfig(context.Background(), client, "glm-5:cloud")
if cfg["contextWindow"] != 202_752 {
t.Errorf("contextWindow = %v, want 202752", cfg["contextWindow"])
}
})
t.Run("falls back to cloud context for dash cloud suffix", func(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNotFound)
fmt.Fprintf(w, `{"error":"model not found"}`)
}))
defer srv.Close()
u, _ := url.Parse(srv.URL)
client := api.NewClient(u, srv.Client())
cfg := createConfig(context.Background(), client, "gpt-oss:120b-cloud")
if cfg["contextWindow"] != 131_072 {
t.Errorf("contextWindow = %v, want 131072", cfg["contextWindow"])
}
})
t.Run("skips zero context length", func(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/api/show" {

View File

@@ -11,6 +11,7 @@ import (
"github.com/charmbracelet/lipgloss"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/cmd/config"
"github.com/ollama/ollama/internal/modelref"
"github.com/ollama/ollama/version"
)
@@ -147,7 +148,13 @@ type signInCheckMsg struct {
type clearStatusMsg struct{}
func (m *model) modelExists(name string) bool {
if m.availableModels == nil || name == "" {
if name == "" {
return false
}
if modelref.HasExplicitCloudSource(name) {
return true
}
if m.availableModels == nil {
return false
}
if m.availableModels[name] {
@@ -209,7 +216,7 @@ func (m *model) openMultiModelModal(integration string) {
}
func isCloudModel(name string) bool {
return strings.HasSuffix(name, ":cloud") || strings.HasSuffix(name, "-cloud")
return modelref.HasExplicitCloudSource(name)
}
func cloudStatusDisabled(client *api.Client) bool {

View File

@@ -54,6 +54,7 @@ type nemotronHModel struct {
NGroups uint32 `json:"n_groups"`
IntermediateSize uint32 `json:"intermediate_size"`
HybridOverridePattern hybridPattern `json:"hybrid_override_pattern"`
LayersBlockType []string `json:"layers_block_type"`
// MoE
NumExperts uint32 `json:"num_experts"`
@@ -162,8 +163,27 @@ func (n *nemotronHModel) denseIntermediateSize() uint32 {
func (n *nemotronHModel) layerArrays() (headCountKV []uint32, ffnLengths []uint32, err error) {
pattern := strings.TrimSpace(string(n.HybridOverridePattern))
// Convert layers_block_type array to pattern string if hybrid_override_pattern is not set
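// e.g. layers_block_type ["mamba", "mamba", "moe", "attention"] becomes "MMEA".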
if pattern == "" && len(n.LayersBlockType) > 0 {
var sb strings.Builder
for _, blockType := range n.LayersBlockType {
switch strings.ToLower(blockType) {
case "mamba":
sb.WriteRune('M')
case "moe":
sb.WriteRune('E')
case "attention":
sb.WriteRune('A')
default:
return nil, nil, fmt.Errorf("nemotron_h: unsupported block type %q in layers_block_type", blockType)
}
}
pattern = sb.String()
}
if pattern == "" {
return nil, nil, fmt.Errorf("nemotron_h: hybrid_override_pattern must be set")
return nil, nil, fmt.Errorf("nemotron_h: hybrid_override_pattern or layers_block_type must be set")
}
runes := []rune(pattern)

View File

@@ -12,7 +12,6 @@ To use Ollama with tools that expect the Anthropic API (like Claude Code), set t
```shell
export ANTHROPIC_AUTH_TOKEN=ollama # required but ignored
export ANTHROPIC_API_KEY="" # required but ignored
export ANTHROPIC_BASE_URL=http://localhost:11434
```
@@ -269,7 +268,7 @@ ollama launch claude --config
Set the environment variables and run Claude Code:
```shell
ANTHROPIC_AUTH_TOKEN=ollama ANTHROPIC_BASE_URL=http://localhost:11434 ANTHROPIC_API_KEY="" claude --model qwen3-coder
ANTHROPIC_AUTH_TOKEN=ollama ANTHROPIC_BASE_URL=http://localhost:11434 claude --model qwen3-coder
```
Or set the environment variables in your shell profile:
@@ -277,7 +276,6 @@ Or set the environment variables in your shell profile:
```shell
export ANTHROPIC_AUTH_TOKEN=ollama
export ANTHROPIC_BASE_URL=http://localhost:11434
export ANTHROPIC_API_KEY=""
```
Then run Claude Code with any Ollama model:

View File

@@ -6,7 +6,7 @@ Ollama provides compatibility with parts of the [OpenAI API](https://platform.op
## Usage
### Simple `v1/chat/completions` example
### Simple `/v1/chat/completions` example
<CodeGroup dropdown>
@@ -57,7 +57,7 @@ curl -X POST http://localhost:11434/v1/chat/completions \
</CodeGroup>
### Simple `v1/responses` example
### Simple `/v1/responses` example
<CodeGroup dropdown>
@@ -103,7 +103,7 @@ curl -X POST http://localhost:11434/v1/responses \
</CodeGroup>
### v1/chat/completions with vision example
### `/v1/chat/completions` with vision example
<CodeGroup dropdown>

View File

@@ -51,6 +51,9 @@ Install prerequisites:
- [CUDA SDK](https://developer.nvidia.com/cuda-downloads?target_os=Windows&target_arch=x86_64&target_version=11&target_type=exe_network)
- (Optional) VULKAN GPU support
- [VULKAN SDK](https://vulkan.lunarg.com/sdk/home) - useful for AMD/Intel GPUs
- (Optional) MLX engine support
- [CUDA 13+ SDK](https://developer.nvidia.com/cuda-downloads)
- [cuDNN 9+](https://developer.nvidia.com/cudnn)
Then, configure and build the project:
@@ -101,6 +104,10 @@ Install prerequisites:
- (Optional) VULKAN GPU support
- [VULKAN SDK](https://vulkan.lunarg.com/sdk/home) - useful for AMD/Intel GPUs
- Or install via package manager: `sudo apt install vulkan-sdk` (Ubuntu/Debian) or `sudo dnf install vulkan-sdk` (Fedora/CentOS)
- (Optional) MLX engine support
- [CUDA 13+ SDK](https://developer.nvidia.com/cuda-downloads)
- [cuDNN 9+](https://developer.nvidia.com/cudnn)
- OpenBLAS/LAPACK: `sudo apt install libopenblas-dev liblapack-dev liblapacke-dev` (Ubuntu/Debian)
> [!IMPORTANT]
> Ensure prerequisites are in `PATH` before running CMake.
@@ -118,6 +125,67 @@ Lastly, run Ollama:
go run . serve
```
## MLX Engine (Optional)
The MLX engine enables running safetensors-based models. It requires building the [MLX](https://github.com/ml-explore/mlx) and [MLX-C](https://github.com/ml-explore/mlx-c) shared libraries separately via CMake. On macOS, MLX uses the Metal library to run on the GPU; on Windows and Linux, it runs on NVIDIA GPUs via CUDA 13.
### macOS (Apple Silicon)
Requires the Metal toolchain. Install [Xcode](https://developer.apple.com/xcode/) first, then:
```shell
xcodebuild -downloadComponent MetalToolchain
```
Verify it's installed correctly (should print "no input files"):
```shell
xcrun metal
```
Then build:
```shell
cmake -B build --preset MLX
cmake --build build --preset MLX --parallel
cmake --install build --component MLX
```
> [!NOTE]
> Without the Metal toolchain, cmake will silently complete with Metal disabled. Check the cmake output for `Setting MLX_BUILD_METAL=OFF`, which indicates the toolchain is missing.
### Windows / Linux (CUDA)
Requires CUDA 13+ and [cuDNN](https://developer.nvidia.com/cudnn) 9+.
```shell
cmake -B build --preset "MLX CUDA 13"
cmake --build build --target mlx --target mlxc --config Release --parallel
cmake --install build --component MLX --strip
```
### Local MLX source overrides
To build against a local checkout of MLX and/or MLX-C (useful for development), set environment variables before running CMake:
```shell
export OLLAMA_MLX_SOURCE=/path/to/mlx
export OLLAMA_MLX_C_SOURCE=/path/to/mlx-c
```
For example, using the helper scripts with local mlx and mlx-c repos:
```shell
OLLAMA_MLX_SOURCE=../mlx OLLAMA_MLX_C_SOURCE=../mlx-c ./scripts/build_linux.sh
OLLAMA_MLX_SOURCE=../mlx OLLAMA_MLX_C_SOURCE=../mlx-c ./scripts/build_darwin.sh
```
```powershell
$env:OLLAMA_MLX_SOURCE="../mlx"
$env:OLLAMA_MLX_C_SOURCE="../mlx-c"
./scripts/build_windows.ps1
```
## Docker
```shell

View File

@@ -61,11 +61,13 @@ Ollama supports the following AMD GPUs via the ROCm library:
### Linux Support
| Family | Cards and accelerators |
| -------------- | ---------------------------------------------------------------------------------------------------------------------------------------------- |
| AMD Radeon RX | `7900 XTX` `7900 XT` `7900 GRE` `7800 XT` `7700 XT` `7600 XT` `7600` `6950 XT` `6900 XTX` `6900XT` `6800 XT` `6800` `Vega 64` |
| AMD Radeon PRO | `W7900` `W7800` `W7700` `W7600` `W7500` `W6900X` `W6800X Duo` `W6800X` `W6800` `V620` `V420` `V340` `V320` `Vega II Duo` `Vega II` `SSG` |
| AMD Instinct | `MI300X` `MI300A` `MI300` `MI250X` `MI250` `MI210` `MI200` `MI100` `MI60` |
| Family | Cards and accelerators |
| -------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| AMD Radeon RX | `9070 XT` `9070 GRE` `9070` `9060 XT` `9060 XT LP` `9060` `7900 XTX` `7900 XT` `7900 GRE` `7800 XT` `7700 XT` `7700` `7600 XT` `7600` `6950 XT` `6900 XTX` `6900XT` `6800 XT` `6800` `5700 XT` `5700` `5600 XT` `5500 XT` |
| AMD Radeon AI PRO | `R9700` `R9600D` |
| AMD Radeon PRO | `W7900` `W7800` `W7700` `W7600` `W7500` `W6900X` `W6800X Duo` `W6800X` `W6800` `V620` |
| AMD Ryzen AI | `Ryzen AI Max+ 395` `Ryzen AI Max 390` `Ryzen AI Max 385` `Ryzen AI 9 HX 475` `Ryzen AI 9 HX 470` `Ryzen AI 9 465` `Ryzen AI 9 HX 375` `Ryzen AI 9 HX 370` `Ryzen AI 9 365` |
| AMD Instinct | `MI350X` `MI300X` `MI300A` `MI250X` `MI250` `MI210` `MI100` |
### Windows Support
@@ -97,17 +99,20 @@ This table shows some example GPUs that map to these LLVM targets:
| **LLVM Target** | **An Example GPU** |
|-----------------|---------------------|
| gfx908 | Radeon Instinct MI100 |
| gfx90a | Radeon Instinct MI210 |
| gfx940 | Radeon Instinct MI300 |
| gfx941 | |
| gfx942 | |
| gfx90a | Radeon Instinct MI210/MI250 |
| gfx942 | Radeon Instinct MI300X/MI300A |
| gfx950 | Radeon Instinct MI350X |
| gfx1010 | Radeon RX 5700 XT |
| gfx1012 | Radeon RX 5500 XT |
| gfx1030 | Radeon PRO V620 |
| gfx1100 | Radeon PRO W7900 |
| gfx1101 | Radeon PRO W7700 |
| gfx1102 | Radeon RX 7600 |
AMD is working on enhancing ROCm v6 to broaden support for families of GPUs in a
future release which should increase support for more GPUs.
| gfx1103 | Radeon 780M |
| gfx1150 | Ryzen AI 9 HX 375 |
| gfx1151 | Ryzen AI Max+ 395 |
| gfx1200 | Radeon RX 9070 |
| gfx1201 | Radeon RX 9070 XT |
Reach out on [Discord](https://discord.gg/ollama) or file an
[issue](https://github.com/ollama/ollama/issues) for additional help.

View File

@@ -101,7 +101,7 @@ nvidia-smi
### Install AMD ROCm drivers (optional)
[Download and Install](https://rocm.docs.amd.com/projects/install-on-linux/en/latest/tutorial/quick-start.html) ROCm v6.
[Download and Install](https://rocm.docs.amd.com/projects/install-on-linux/en/latest/tutorial/quick-start.html) ROCm v7.
### Start Ollama

View File

@@ -152,7 +152,9 @@ PARAMETER <parameter> <parametervalue>
| -------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ---------- | -------------------- |
| num_ctx | Sets the size of the context window used to generate the next token. (Default: 2048) | int | num_ctx 4096 |
| repeat_last_n | Sets how far back for the model to look back to prevent repetition. (Default: 64, 0 = disabled, -1 = num_ctx) | int | repeat_last_n 64 |
| repeat_penalty | Sets how strongly to penalize repetitions. A higher value (e.g., 1.5) will penalize repetitions more strongly, while a lower value (e.g., 0.9) will be more lenient. (Default: 1.1) | float | repeat_penalty 1.1 |
| repeat_penalty | Sets how strongly to penalize repetitions. A higher value (e.g., 1.5) will penalize repetitions more strongly, while a lower value (e.g., 0.9) will be more lenient. (Default: 1.0) | float | repeat_penalty 1.0 |
| presence_penalty | Penalizes tokens that have already appeared in the generated text to reduce repetition. (Default: 0.0) | float | presence_penalty 1.5 |
| frequency_penalty | Penalizes tokens based on how often they have appeared in the generated text. (Default: 0.0) | float | frequency_penalty 1.0 |
| temperature | The temperature of the model. Increasing the temperature will make the model answer more creatively. (Default: 0.8) | float | temperature 0.7 |
| seed | Sets the random number seed to use for generation. Setting this to a specific number will make the model generate the same text for the same prompt. (Default: 0) | int | seed 42 |
| stop | Sets the stop sequences to use. When this pattern is encountered the LLM will stop generating text and return. Multiple stop patterns may be set by specifying multiple separate `stop` parameters in a modelfile. | string | stop "AI assistant:" |
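
For intuition, presence and frequency penalties in this style are applied to the logits before sampling. A minimal sketch of the common formulation (illustrative; not necessarily Ollama's exact sampler code):

```go
// applyPenalties adjusts logits in place: presence is subtracted once for any
// token that has already appeared, frequency is subtracted per occurrence.
func applyPenalties(logits []float32, counts map[int]int, presence, frequency float32) {
	for token, n := range counts {
		logits[token] -= presence + frequency*float32(n)
	}
}
```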

View File

@@ -0,0 +1,115 @@
package modelref
import (
"errors"
"fmt"
"strings"
)
type ModelSource uint8
const (
ModelSourceUnspecified ModelSource = iota
ModelSourceLocal
ModelSourceCloud
)
var (
ErrConflictingSourceSuffix = errors.New("use either :local or :cloud, not both")
ErrModelRequired = errors.New("model is required")
)
type ParsedRef struct {
Original string
Base string
Source ModelSource
}
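// ParseRef trims whitespace and splits an optional source suffix (":cloud",
// ":local", or the legacy "-cloud" tag form) off a model name. Empty names
// return ErrModelRequired; stacked suffixes like "foo:cloud:local" return
// ErrConflictingSourceSuffix.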
func ParseRef(raw string) (ParsedRef, error) {
var zero ParsedRef
raw = strings.TrimSpace(raw)
if raw == "" {
return zero, ErrModelRequired
}
base, source, explicit := parseSourceSuffix(raw)
if explicit {
if _, _, nested := parseSourceSuffix(base); nested {
return zero, fmt.Errorf("%w: %q", ErrConflictingSourceSuffix, raw)
}
}
return ParsedRef{
Original: raw,
Base: base,
Source: source,
}, nil
}
func HasExplicitCloudSource(raw string) bool {
parsedRef, err := ParseRef(raw)
return err == nil && parsedRef.Source == ModelSourceCloud
}
func HasExplicitLocalSource(raw string) bool {
parsedRef, err := ParseRef(raw)
return err == nil && parsedRef.Source == ModelSourceLocal
}
func StripCloudSourceTag(raw string) (string, bool) {
parsedRef, err := ParseRef(raw)
if err != nil || parsedRef.Source != ModelSourceCloud {
return strings.TrimSpace(raw), false
}
return parsedRef.Base, true
}
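// NormalizePullName converts a reference into the name used for pulling:
// non-cloud refs keep their base name, while cloud refs map to the legacy pull
// form ("-cloud" appended to an existing tag, ":cloud" otherwise). The bool
// reports whether the ref was an explicit cloud reference.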
func NormalizePullName(raw string) (string, bool, error) {
parsedRef, err := ParseRef(raw)
if err != nil {
return "", false, err
}
if parsedRef.Source != ModelSourceCloud {
return parsedRef.Base, false, nil
}
return toLegacyCloudPullName(parsedRef.Base), true, nil
}
func toLegacyCloudPullName(base string) string {
if hasExplicitTag(base) {
return base + "-cloud"
}
return base + ":cloud"
}
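// hasExplicitTag reports whether name ends in a tag, i.e. its last colon comes
// after the last slash: "gpt-oss:20b" has one, while in
// "localhost:11434/library/foo" the colon belongs to the host port.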
func hasExplicitTag(name string) bool {
lastSlash := strings.LastIndex(name, "/")
lastColon := strings.LastIndex(name, ":")
return lastColon > lastSlash
}
func parseSourceSuffix(raw string) (string, ModelSource, bool) {
idx := strings.LastIndex(raw, ":")
if idx >= 0 {
suffixRaw := strings.TrimSpace(raw[idx+1:])
suffix := strings.ToLower(suffixRaw)
switch suffix {
case "cloud":
return raw[:idx], ModelSourceCloud, true
case "local":
return raw[:idx], ModelSourceLocal, true
}
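// Legacy form: "-cloud" appended to the tag segment, e.g. "gpt-oss:20b-cloud".
// The slash check prevents host/path text like "foo:bar-cloud/baz" from matching.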
if !strings.Contains(suffixRaw, "/") && strings.HasSuffix(suffix, "-cloud") {
return raw[:idx+1] + suffixRaw[:len(suffixRaw)-len("-cloud")], ModelSourceCloud, true
}
}
return raw, ModelSourceUnspecified, false
}

View File

@@ -0,0 +1,268 @@
package modelref
import (
"errors"
"testing"
)
func TestParseRef(t *testing.T) {
tests := []struct {
name string
input string
wantBase string
wantSource ModelSource
wantErr error
wantCloud bool
wantLocal bool
wantStripped string
wantStripOK bool
}{
{
name: "cloud suffix",
input: "gpt-oss:20b:cloud",
wantBase: "gpt-oss:20b",
wantSource: ModelSourceCloud,
wantCloud: true,
wantStripped: "gpt-oss:20b",
wantStripOK: true,
},
{
name: "legacy cloud suffix",
input: "gpt-oss:20b-cloud",
wantBase: "gpt-oss:20b",
wantSource: ModelSourceCloud,
wantCloud: true,
wantStripped: "gpt-oss:20b",
wantStripOK: true,
},
{
name: "local suffix",
input: "qwen3:8b:local",
wantBase: "qwen3:8b",
wantSource: ModelSourceLocal,
wantLocal: true,
wantStripped: "qwen3:8b:local",
},
{
name: "no source suffix",
input: "llama3.2",
wantBase: "llama3.2",
wantSource: ModelSourceUnspecified,
wantStripped: "llama3.2",
},
{
name: "bare cloud name is not explicit cloud",
input: "my-cloud-model",
wantBase: "my-cloud-model",
wantSource: ModelSourceUnspecified,
wantStripped: "my-cloud-model",
},
{
name: "slash in suffix blocks legacy cloud parsing",
input: "foo:bar-cloud/baz",
wantBase: "foo:bar-cloud/baz",
wantSource: ModelSourceUnspecified,
wantStripped: "foo:bar-cloud/baz",
},
{
name: "conflicting source suffixes",
input: "foo:cloud:local",
wantErr: ErrConflictingSourceSuffix,
wantSource: ModelSourceUnspecified,
},
{
name: "empty input",
input: " ",
wantErr: ErrModelRequired,
wantSource: ModelSourceUnspecified,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := ParseRef(tt.input)
if tt.wantErr != nil {
if !errors.Is(err, tt.wantErr) {
t.Fatalf("ParseRef(%q) error = %v, want %v", tt.input, err, tt.wantErr)
}
return
}
if err != nil {
t.Fatalf("ParseRef(%q) returned error: %v", tt.input, err)
}
if got.Base != tt.wantBase {
t.Fatalf("base = %q, want %q", got.Base, tt.wantBase)
}
if got.Source != tt.wantSource {
t.Fatalf("source = %v, want %v", got.Source, tt.wantSource)
}
if HasExplicitCloudSource(tt.input) != tt.wantCloud {
t.Fatalf("HasExplicitCloudSource(%q) = %v, want %v", tt.input, HasExplicitCloudSource(tt.input), tt.wantCloud)
}
if HasExplicitLocalSource(tt.input) != tt.wantLocal {
t.Fatalf("HasExplicitLocalSource(%q) = %v, want %v", tt.input, HasExplicitLocalSource(tt.input), tt.wantLocal)
}
stripped, ok := StripCloudSourceTag(tt.input)
if ok != tt.wantStripOK {
t.Fatalf("StripCloudSourceTag(%q) ok = %v, want %v", tt.input, ok, tt.wantStripOK)
}
if stripped != tt.wantStripped {
t.Fatalf("StripCloudSourceTag(%q) base = %q, want %q", tt.input, stripped, tt.wantStripped)
}
})
}
}
func TestNormalizePullName(t *testing.T) {
tests := []struct {
name string
input string
wantName string
wantCloud bool
wantErr error
}{
{
name: "explicit local strips source",
input: "gpt-oss:20b:local",
wantName: "gpt-oss:20b",
},
{
name: "explicit cloud with size maps to legacy dash cloud tag",
input: "gpt-oss:20b:cloud",
wantName: "gpt-oss:20b-cloud",
wantCloud: true,
},
{
name: "legacy cloud with size remains stable",
input: "gpt-oss:20b-cloud",
wantName: "gpt-oss:20b-cloud",
wantCloud: true,
},
{
name: "explicit cloud without tag maps to cloud tag",
input: "qwen3:cloud",
wantName: "qwen3:cloud",
wantCloud: true,
},
{
name: "host port without tag keeps host port and appends cloud tag",
input: "localhost:11434/library/foo:cloud",
wantName: "localhost:11434/library/foo:cloud",
wantCloud: true,
},
{
name: "conflicting source suffixes fail",
input: "foo:cloud:local",
wantErr: ErrConflictingSourceSuffix,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
gotName, gotCloud, err := NormalizePullName(tt.input)
if tt.wantErr != nil {
if !errors.Is(err, tt.wantErr) {
t.Fatalf("NormalizePullName(%q) error = %v, want %v", tt.input, err, tt.wantErr)
}
return
}
if err != nil {
t.Fatalf("NormalizePullName(%q) returned error: %v", tt.input, err)
}
if gotName != tt.wantName {
t.Fatalf("normalized name = %q, want %q", gotName, tt.wantName)
}
if gotCloud != tt.wantCloud {
t.Fatalf("cloud = %v, want %v", gotCloud, tt.wantCloud)
}
})
}
}
func TestParseSourceSuffix(t *testing.T) {
tests := []struct {
name string
input string
wantBase string
wantSource ModelSource
wantExplicit bool
}{
{
name: "explicit cloud suffix",
input: "gpt-oss:20b:cloud",
wantBase: "gpt-oss:20b",
wantSource: ModelSourceCloud,
wantExplicit: true,
},
{
name: "explicit local suffix",
input: "qwen3:8b:local",
wantBase: "qwen3:8b",
wantSource: ModelSourceLocal,
wantExplicit: true,
},
{
name: "legacy cloud suffix on tag",
input: "gpt-oss:20b-cloud",
wantBase: "gpt-oss:20b",
wantSource: ModelSourceCloud,
wantExplicit: true,
},
{
name: "legacy cloud suffix does not match model segment",
input: "my-cloud-model",
wantBase: "my-cloud-model",
wantSource: ModelSourceUnspecified,
wantExplicit: false,
},
{
name: "legacy cloud suffix blocked when suffix includes slash",
input: "foo:bar-cloud/baz",
wantBase: "foo:bar-cloud/baz",
wantSource: ModelSourceUnspecified,
wantExplicit: false,
},
{
name: "unknown suffix is not explicit source",
input: "gpt-oss:clod",
wantBase: "gpt-oss:clod",
wantSource: ModelSourceUnspecified,
wantExplicit: false,
},
{
name: "uppercase suffix is accepted",
input: "gpt-oss:20b:CLOUD",
wantBase: "gpt-oss:20b",
wantSource: ModelSourceCloud,
wantExplicit: true,
},
{
name: "no suffix",
input: "llama3.2",
wantBase: "llama3.2",
wantSource: ModelSourceUnspecified,
wantExplicit: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
gotBase, gotSource, gotExplicit := parseSourceSuffix(tt.input)
if gotBase != tt.wantBase {
t.Fatalf("base = %q, want %q", gotBase, tt.wantBase)
}
if gotSource != tt.wantSource {
t.Fatalf("source = %v, want %v", gotSource, tt.wantSource)
}
if gotExplicit != tt.wantExplicit {
t.Fatalf("explicit = %v, want %v", gotExplicit, tt.wantExplicit)
}
})
}
}

View File

@@ -74,8 +74,7 @@ type LlamaServer interface {
Tokenize(ctx context.Context, content string) ([]int, error)
Detokenize(ctx context.Context, tokens []int) (string, error)
Close() error
VRAMSize() uint64 // Total VRAM across all GPUs
TotalSize() uint64
MemorySize() (total, vram uint64)
VRAMByGPU(id ml.DeviceID) uint64
Pid() int
GetPort() int
@@ -685,8 +684,9 @@ func (s *llamaServer) Load(ctx context.Context, systemInfo ml.SystemInfo, system
// Windows CUDA should not use mmap for best performance
// Linux with a model larger than free space, mmap leads to thrashing
// For CPU loads we want the memory to be allocated, not FS cache
totalSize, _ := s.MemorySize()
if (runtime.GOOS == "windows" && len(gpus) > 0 && gpus[0].Library == "CUDA" && s.options.UseMMap == nil) ||
(runtime.GOOS == "linux" && systemInfo.FreeMemory < s.TotalSize() && s.options.UseMMap == nil) ||
(runtime.GOOS == "linux" && systemInfo.FreeMemory < totalSize && s.options.UseMMap == nil) ||
(len(gpus) == 0 && s.options.UseMMap == nil) ||
(len(gpus) > 0 && gpus[0].Library == "Vulkan" && s.options.UseMMap == nil) ||
(s.options.UseMMap != nil && !*s.options.UseMMap) {
@@ -1848,17 +1848,17 @@ func (s *llamaServer) GetDeviceInfos(ctx context.Context) []ml.DeviceInfo {
return nil
}
func (s *llmServer) VRAMSize() uint64 {
func (s *llmServer) MemorySize() (total, vram uint64) {
if s.mem == nil {
return 0
return 0, 0
}
var mem uint64
for _, g := range s.mem.GPUs {
mem += g.Size()
vram += g.Size()
}
total = s.mem.InputWeights + s.mem.CPU.Size() + vram
// Some elements are always on CPU. However, if we have allocated all layers
// on the GPU then include the CPU components as well, to represent complete offloading.
noCPULayers := true
@@ -1869,25 +1869,11 @@ func (s *llmServer) VRAMSize() uint64 {
}
}
if noCPULayers {
mem += s.mem.InputWeights
mem += s.mem.CPU.Graph
vram += s.mem.InputWeights
vram += s.mem.CPU.Graph
}
return mem
}
func (s *llmServer) TotalSize() uint64 {
if s.mem == nil {
return 0
}
mem := s.mem.InputWeights
mem += s.mem.CPU.Size()
for _, g := range s.mem.GPUs {
mem += g.Size()
}
return mem
return total, vram
}
func (s *llmServer) VRAMByGPU(id ml.DeviceID) uint64 {

View File

@@ -17,6 +17,7 @@ import (
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/envconfig"
internalcloud "github.com/ollama/ollama/internal/cloud"
"github.com/ollama/ollama/internal/modelref"
"github.com/ollama/ollama/logutil"
)
@@ -919,7 +920,7 @@ func hasWebSearchTool(tools []anthropic.Tool) bool {
}
func isCloudModelName(name string) bool {
return strings.HasSuffix(name, ":cloud") || strings.HasSuffix(name, "-cloud")
return modelref.HasExplicitCloudSource(name)
}
// extractQueryFromToolCall extracts the search query from a web_search tool call

View File

@@ -41,8 +41,8 @@ type GatedDeltaNet struct {
SSMBeta *nn.Linear `gguf:"ssm_beta"` // -> beta (qwen35)
SSMAlpha *nn.Linear `gguf:"ssm_alpha"` // -> alpha (qwen35)
SSMConv1D *convKernel `gguf:"ssm_conv1d"`
SSMDT ml.Tensor `gguf:"ssm_dt"` // alpha bias
SSMA ml.Tensor `gguf:"ssm_a"` // -A_log.exp()
SSMDT ml.Tensor `gguf:"ssm_dt,alt:ssm_dt.bias"` // alpha bias
SSMA ml.Tensor `gguf:"ssm_a"` // -A_log.exp()
SSMNorm *nn.RMSNorm `gguf:"ssm_norm"`
SSMOut *nn.Linear `gguf:"ssm_out"`
@@ -135,6 +135,18 @@ func (gdn *GatedDeltaNet) Forward(ctx ml.Context, hiddenStates, _ ml.Tensor, cac
default:
return nil, errors.New("qwen3next: missing linear attention beta/alpha projections")
}
if gdn.SSMDT == nil {
return nil, errors.New("qwen3next: missing linear attention ssm_dt tensor")
}
if gdn.SSMA == nil {
return nil, errors.New("qwen3next: missing linear attention ssm_a tensor")
}
if gdn.SSMConv1D == nil || gdn.SSMConv1D.Weight == nil {
return nil, errors.New("qwen3next: missing linear attention ssm_conv1d tensor")
}
if gdn.SSMNorm == nil || gdn.SSMOut == nil {
return nil, errors.New("qwen3next: missing linear attention ssm_norm/ssm_out projections")
}
// Compute gate: softplus(alpha + dt_bias) * -A
alphaBiased := alpha.Add(ctx, gdn.SSMDT)
@@ -442,6 +454,10 @@ func (gdn *GatedDeltaNet) deltaNetChunked(
vT := v.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx, chunkSize, headVDim, nChunks, numVHeads*nSeqs)
stateT := state.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx, headVDim, headVDim, 1, numVHeads*nSeqs)
// Collect chunk outputs and concatenate at the end.
// Avoids SET on buffer-less intermediates under partial offload.
chunks := make([]ml.Tensor, nChunks)
for chunk := range nChunks {
qChunk := q.Slice(ctx, 2, chunk, chunk+1, 1)
vTChunk := vT.Slice(ctx, 2, chunk, chunk+1, 1)
@@ -463,14 +479,7 @@ func (gdn *GatedDeltaNet) deltaNetChunked(
vAttn := vTNewChunk.Mulmat(ctx, attnChunk)
coreAttnOutChunk := attnInter.Add(ctx, vAttn)
v = v.SetInplace(
ctx,
coreAttnOutChunk,
v.Stride(1),
v.Stride(2),
v.Stride(3),
chunk*v.Stride(2),
)
chunks[chunk] = coreAttnOutChunk
// Update state for next chunk
gExpLastChunk := gLastExp.Slice(ctx, 2, chunk, chunk+1, 1)
@@ -483,6 +492,20 @@ func (gdn *GatedDeltaNet) deltaNetChunked(
stateT = stateT.Add(ctx, kgdMulVNew)
}
// Use a balanced concat tree so concat work does not balloon on long prompts.
for len(chunks) > 1 {
merged := make([]ml.Tensor, 0, (len(chunks)+1)/2)
for i := 0; i < len(chunks); i += 2 {
if i+1 < len(chunks) {
merged = append(merged, chunks[i].Concat(ctx, chunks[i+1], 2))
} else {
merged = append(merged, chunks[i])
}
}
chunks = merged
}
v = chunks[0]
// Final reshape
coreAttnOut := v.Contiguous(ctx, headVDim, chunkSize*nChunks, numVHeads, nSeqs)
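
The loop above merges chunk outputs pairwise, so n chunks produce a concat tree of depth ⌈log2 n⌉ instead of a left-leaning chain of n-1 concats. The same reduction pattern in isolation (a sketch over plain values rather than ml.Tensor):

```go
// reduceBalanced merges items with a balanced tree of merge calls, mirroring
// the chunk-concat loop above: each pass halves the slice, carrying an odd
// trailing element over to the next pass. Assumes len(items) >= 1.
func reduceBalanced[T any](items []T, merge func(a, b T) T) T {
	for len(items) > 1 {
		merged := make([]T, 0, (len(items)+1)/2)
		for i := 0; i < len(items); i += 2 {
			if i+1 < len(items) {
				merged = append(merged, merge(items[i], items[i+1]))
			} else {
				merged = append(merged, items[i])
			}
		}
		items = merged
	}
	return items[0]
}
```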

View File

@@ -437,6 +437,46 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
return m.Output.Forward(ctx, hiddenStates), nil
}
func (m *Model) Validate() error {
if m.Options == nil {
return fmt.Errorf("qwen3next: missing model options")
}
if len(m.Layers) != len(m.Options.isRecurrent) {
return fmt.Errorf("qwen3next: layer config mismatch: have %d layers, %d recurrent flags", len(m.Layers), len(m.Options.isRecurrent))
}
for i, layer := range m.Layers {
if !m.Options.isRecurrent[i] {
continue
}
gdn, ok := layer.Operator.(*GatedDeltaNet)
if !ok || gdn == nil {
return fmt.Errorf("qwen3next: layer %d expected recurrent operator", i)
}
if gdn.SSMQKV == nil || gdn.SSMQKVGate == nil {
return fmt.Errorf("qwen3next: layer %d missing attn_qkv/attn_gate projections", i)
}
if gdn.SSMBetaAlpha == nil && (gdn.SSMBeta == nil || gdn.SSMAlpha == nil) {
return fmt.Errorf("qwen3next: layer %d missing linear attention beta/alpha projections", i)
}
if gdn.SSMDT == nil {
return fmt.Errorf("qwen3next: layer %d missing ssm_dt tensor", i)
}
if gdn.SSMA == nil {
return fmt.Errorf("qwen3next: layer %d missing ssm_a tensor", i)
}
if gdn.SSMConv1D == nil || gdn.SSMConv1D.Weight == nil {
return fmt.Errorf("qwen3next: layer %d missing ssm_conv1d tensor", i)
}
if gdn.SSMNorm == nil || gdn.SSMOut == nil {
return fmt.Errorf("qwen3next: layer %d missing ssm_norm/ssm_out projections", i)
}
}
return nil
}
func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
m.positionCache = nil
if len(m.mropeSections) > 0 {
@@ -450,6 +490,64 @@ var (
_ model.MultimodalProcessor = (*Model)(nil)
)
func defaultVHeadReordered(arch string) bool {
return arch == "qwen35" || arch == "qwen35moe"
}
func inferRecurrentLayers(headCountKV []uint64, numLayers int, fullAttentionInterval uint32) ([]bool, error) {
isRecurrent := make([]bool, numLayers)
hasZero := false
hasFull := false
for i := range numLayers {
if i >= len(headCountKV) {
continue
}
if headCountKV[i] == 0 {
isRecurrent[i] = true
hasZero = true
} else {
hasFull = true
}
}
if hasZero && hasFull {
return isRecurrent, nil
}
if !hasFull {
return nil, fmt.Errorf("qwen3next: attention.head_count_kv must include at least one non-zero value")
}
// Compatibility path: older imports store a scalar KV head count and omit
// per-layer recurrent flags. Derive the hybrid layout from the interval.
interval := int(fullAttentionInterval)
if interval == 0 {
interval = min(4, numLayers)
}
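// e.g. with interval 4, layers 4, 8, 12, ... (1-based) use full attention and
// the rest stay recurrent.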
if interval <= 0 {
return nil, fmt.Errorf("qwen3next: invalid block_count (%d)", numLayers)
}
if interval > numLayers {
return nil, fmt.Errorf("qwen3next: full_attention_interval (%d) exceeds block_count (%d)", interval, numLayers)
}
hasZero = false
hasFull = false
for i := range numLayers {
isRecurrent[i] = (i+1)%interval != 0
if isRecurrent[i] {
hasZero = true
} else {
hasFull = true
}
}
if !hasZero || !hasFull {
return nil, fmt.Errorf("qwen3next: full_attention_interval (%d) does not produce a mixed recurrent/full layout", interval)
}
return isRecurrent, nil
}
func New(c fs.Config) (model.Model, error) {
numLayers := int(c.Uint("block_count"))
layers := make([]Layer, numLayers)
@@ -460,26 +558,14 @@ func New(c fs.Config) (model.Model, error) {
HeadCountKV() []uint64
}
var isRecurrent []bool
var headCountKV []uint64
if hc, ok := c.(headCounts); ok {
headCountKV = hc.HeadCountKV()
}
isRecurrent = make([]bool, numLayers)
hasZero := false
hasFull := false
for i := range numLayers {
// If KV head count is 0, it's a recurrent layer
if i < len(headCountKV) && headCountKV[i] == 0 {
isRecurrent[i] = true
hasZero = true
} else if i < len(headCountKV) && headCountKV[i] > 0 {
hasFull = true
}
}
if !hasZero || !hasFull {
return nil, fmt.Errorf("qwen3next: invalid attention.head_count_kv array; expected mix of zero and non-zero values")
isRecurrent, err := inferRecurrentLayers(headCountKV, numLayers, c.Uint("full_attention_interval"))
if err != nil {
return nil, err
}
// Determine if MoE
@@ -543,7 +629,7 @@ func New(c fs.Config) (model.Model, error) {
ssmNGroup: int(c.Uint("ssm.group_count")),
ssmDtRank: int(c.Uint("ssm.time_step_rank")),
convKernelSize: int(c.Uint("ssm.conv_kernel")),
vHeadReordered: c.Bool("ssm.v_head_reordered", false),
vHeadReordered: c.Bool("ssm.v_head_reordered", defaultVHeadReordered(c.Architecture())),
isRecurrent: isRecurrent,
mropeSections: slices.Collect(func(yield func(int) bool) {
for _, section := range mropeSections {
@@ -555,7 +641,7 @@ func New(c fs.Config) (model.Model, error) {
mropeInterleaved: c.Bool("rope.mrope_interleaved", c.Bool("mrope_interleaved", false)),
}
if opts.numKVHeads == 0 {
return nil, fmt.Errorf("qwen3next: attention.head_count_kv array must include at least one non-zero value")
return nil, fmt.Errorf("qwen3next: attention.head_count_kv must include at least one non-zero value")
}
// Calculate cache dimensions

View File

@@ -0,0 +1,65 @@
package qwen3next
import (
"slices"
"strings"
"testing"
)
func TestInferRecurrentLayersMixedKVArray(t *testing.T) {
got, err := inferRecurrentLayers([]uint64{0, 2, 0, 2}, 4, 0)
if err != nil {
t.Fatalf("inferRecurrentLayers() error = %v", err)
}
want := []bool{true, false, true, false}
if !slices.Equal(got, want) {
t.Fatalf("inferRecurrentLayers() = %v, want %v", got, want)
}
}
func TestInferRecurrentLayersScalarKVDefaultInterval(t *testing.T) {
got, err := inferRecurrentLayers([]uint64{2, 2, 2, 2, 2, 2, 2, 2}, 8, 0)
if err != nil {
t.Fatalf("inferRecurrentLayers() error = %v", err)
}
want := []bool{true, true, true, false, true, true, true, false}
if !slices.Equal(got, want) {
t.Fatalf("inferRecurrentLayers() = %v, want %v", got, want)
}
}
func TestInferRecurrentLayersScalarKVConfiguredInterval(t *testing.T) {
got, err := inferRecurrentLayers([]uint64{2, 2, 2, 2, 2, 2}, 6, 3)
if err != nil {
t.Fatalf("inferRecurrentLayers() error = %v", err)
}
want := []bool{true, true, false, true, true, false}
if !slices.Equal(got, want) {
t.Fatalf("inferRecurrentLayers() = %v, want %v", got, want)
}
}
func TestInferRecurrentLayersAllZeroRejects(t *testing.T) {
_, err := inferRecurrentLayers([]uint64{0, 0, 0, 0}, 4, 0)
if err == nil {
t.Fatal("inferRecurrentLayers() expected error, got nil")
}
if !strings.Contains(err.Error(), "must include at least one non-zero value") {
t.Fatalf("unexpected error = %v", err)
}
}
func TestDefaultVHeadReordered(t *testing.T) {
if !defaultVHeadReordered("qwen35") {
t.Fatal("defaultVHeadReordered(qwen35) = false, want true")
}
if !defaultVHeadReordered("qwen35moe") {
t.Fatal("defaultVHeadReordered(qwen35moe) = false, want true")
}
if defaultVHeadReordered("qwen3next") {
t.Fatal("defaultVHeadReordered(qwen3next) = true, want false")
}
}

View File

@@ -0,0 +1,45 @@
package qwen3next
import (
"strings"
"testing"
"github.com/ollama/ollama/ml/nn"
)
func TestValidateRecurrentLayerRequiresSSMDT(t *testing.T) {
m := &Model{
Layers: []Layer{{
Operator: &GatedDeltaNet{
SSMQKV: &nn.Linear{},
SSMQKVGate: &nn.Linear{},
SSMBeta: &nn.Linear{},
SSMAlpha: &nn.Linear{},
},
}},
Options: &Options{
isRecurrent: []bool{true},
},
}
err := m.Validate()
if err == nil {
t.Fatal("Validate() expected error, got nil")
}
if !strings.Contains(err.Error(), "missing ssm_dt") {
t.Fatalf("unexpected error = %v", err)
}
}
func TestValidateNonRecurrentSkipsLinearChecks(t *testing.T) {
m := &Model{
Layers: []Layer{{Operator: &FullAttention{}}},
Options: &Options{
isRecurrent: []bool{false},
},
}
if err := m.Validate(); err != nil {
t.Fatalf("Validate() error = %v", err)
}
}

View File

@@ -32,9 +32,10 @@ const (
)
type GLM46Parser struct {
state glm46ParserState
buffer strings.Builder
tools []api.Tool
state glm46ParserState
buffer strings.Builder
tools []api.Tool
callIndex int
}
func (p *GLM46Parser) HasToolSupport() bool {
@@ -48,6 +49,7 @@ func (p *GLM46Parser) HasThinkingSupport() bool {
// func (p *GLM46Parser) Init(tools []api.Tool, lastMessage *api.Message) []api.Tool {
func (p *GLM46Parser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
p.tools = tools
p.callIndex = 0
return tools
}
@@ -89,6 +91,8 @@ func (p *GLM46Parser) Add(s string, done bool) (content string, thinking string,
slog.Warn("glm-4.6 tool call parsing failed", "error", err)
return "", "", nil, err
}
toolCall.Function.Index = p.callIndex
p.callIndex++
toolCalls = append(toolCalls, toolCall)
case glm46EventThinkingContent:
thinkingSb.WriteString(event.content)
@@ -341,6 +345,47 @@ func escapeGLM46Content(s string) string {
return result.String()
}
// repairUnclosedArgValues inserts missing </arg_value> closing tags.
// GLM models sometimes omit the closing tag, producing XML like:
//
// <arg_value>value</tool_call>
//
// instead of:
//
// <arg_value>value</arg_value></tool_call>
func repairUnclosedArgValues(s string) string {
var result strings.Builder
for {
openIdx := strings.Index(s, "<arg_value>")
if openIdx == -1 {
result.WriteString(s)
break
}
afterOpen := openIdx + len("<arg_value>")
closeIdx := strings.Index(s[afterOpen:], "</arg_value>")
nextKeyIdx := strings.Index(s[afterOpen:], "<arg_key>")
// Check if properly closed before the next <arg_key> (or no next key)
if closeIdx != -1 && (nextKeyIdx == -1 || closeIdx < nextKeyIdx) {
end := afterOpen + closeIdx + len("</arg_value>")
result.WriteString(s[:end])
s = s[end:]
continue
}
// Unclosed — insert </arg_value> before the next <arg_key> or at end
if nextKeyIdx != -1 {
insertAt := afterOpen + nextKeyIdx
result.WriteString(s[:insertAt])
result.WriteString("</arg_value>")
s = s[insertAt:]
} else {
result.WriteString(s)
result.WriteString("</arg_value>")
break
}
}
return result.String()
}
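Note that the repair runs on the inner tool-call content before the <tool_call> wrapper is re-added, so it never sees the wrapper tags themselves. A minimal illustration, same shape as the tests below:
	fixed := repairUnclosedArgValues("<arg_key>city</arg_key><arg_value>Paris")
	// fixed == "<arg_key>city</arg_key><arg_value>Paris</arg_value>"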
func parseGLM46ToolCall(raw glm46EventRawToolCall, tools []api.Tool) (api.ToolCall, error) {
// Escape any unescaped entities in text content
// We need to escape text between tags, but not the tags themselves
@@ -349,10 +394,14 @@ func parseGLM46ToolCall(raw glm46EventRawToolCall, tools []api.Tool) (api.ToolCa
// Wrap the content in a root element to make it valid XML
xmlString := "<tool_call>" + escaped + "</tool_call>"
// Parse XML into struct
// Parse XML into struct, retrying once with repaired XML if it fails
var parsed GLMToolCallXML
if err := xml.Unmarshal([]byte(xmlString), &parsed); err != nil {
return api.ToolCall{}, fmt.Errorf("failed to parse XML: %w", err)
parsed = GLMToolCallXML{}
repaired := "<tool_call>" + repairUnclosedArgValues(escaped) + "</tool_call>"
if err2 := xml.Unmarshal([]byte(repaired), &parsed); err2 != nil {
// surface the original parse error; the repair pass is best-effort
return api.ToolCall{}, fmt.Errorf("failed to parse XML: %w", err)
}
}
// Extract and trim function name

View File

@@ -846,6 +846,47 @@ line3</arg_value>`,
},
},
},
{
name: "unclosed arg_value at end",
tools: []api.Tool{},
rawToolCall: `get-weather
<arg_key>city</arg_key>
<arg_value>Paris`,
wantToolCall: api.ToolCall{
Function: api.ToolCallFunction{
Name: "get-weather",
Arguments: args(`{"city": "Paris"}`),
},
},
},
{
name: "unclosed arg_value before next arg_key",
tools: []api.Tool{},
rawToolCall: `get-weather
<arg_key>city</arg_key>
<arg_value>Paris<arg_key>unit</arg_key>
<arg_value>celsius</arg_value>`,
wantToolCall: api.ToolCall{
Function: api.ToolCallFunction{
Name: "get-weather",
Arguments: args(`{"city": "Paris", "unit": "celsius"}`),
},
},
},
{
name: "multiple unclosed arg_values",
tools: []api.Tool{},
rawToolCall: `get-weather
<arg_key>city</arg_key>
<arg_value>Paris<arg_key>unit</arg_key>
<arg_value>celsius`,
wantToolCall: api.ToolCall{
Function: api.ToolCallFunction{
Name: "get-weather",
Arguments: args(`{"city": "Paris", "unit": "celsius"}`),
},
},
},
}
for i, tc := range cases {
@@ -860,3 +901,45 @@ line3</arg_value>`,
})
}
}
func TestRepairUnclosedArgValues(t *testing.T) {
cases := []struct {
name string
input string
want string
}{
{
name: "already valid",
input: `<arg_key>k</arg_key><arg_value>v</arg_value>`,
want: `<arg_key>k</arg_key><arg_value>v</arg_value>`,
},
{
name: "unclosed at end",
input: `<arg_key>k</arg_key><arg_value>v`,
want: `<arg_key>k</arg_key><arg_value>v</arg_value>`,
},
{
name: "unclosed before next arg_key",
input: `<arg_key>a</arg_key><arg_value>1<arg_key>b</arg_key><arg_value>2</arg_value>`,
want: `<arg_key>a</arg_key><arg_value>1</arg_value><arg_key>b</arg_key><arg_value>2</arg_value>`,
},
{
name: "no arg_value tags",
input: `just plain text`,
want: `just plain text`,
},
{
name: "multiple unclosed",
input: `<arg_key>a</arg_key><arg_value>1<arg_key>b</arg_key><arg_value>2`,
want: `<arg_key>a</arg_key><arg_value>1</arg_value><arg_key>b</arg_key><arg_value>2</arg_value>`,
},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
got := repairUnclosedArgValues(tc.input)
if got != tc.want {
t.Errorf("got %q, want %q", got, tc.want)
}
})
}
}

View File

@@ -11,6 +11,7 @@ type GLM47Parser struct {
func (p *GLM47Parser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
p.tools = tools
p.callIndex = 0
// When thinking is enabled (nil or true), the prompt ends with <think>,
// so model output starts directly with thinking content (no opening tag).
if thinkValue == nil || thinkValue.Bool() {

View File

@@ -97,3 +97,91 @@ func TestGLM47ParserToolCallEscaping(t *testing.T) {
t.Fatalf("expected %#v, got %#v", expected, toolCall)
}
}
func TestGLM47ParserToolCallIndexing(t *testing.T) {
parser := GLM47Parser{}
parser.Init(nil, nil, nil)
input := `plan</think>
<tool_call>first<arg_key>a</arg_key><arg_value>1</arg_value></tool_call>
<tool_call>second<arg_key>b</arg_key><arg_value>2</arg_value></tool_call>
<tool_call>third<arg_key>c</arg_key><arg_value>3</arg_value></tool_call>`
_, _, calls, err := parser.Add(input, true)
if err != nil {
t.Fatalf("parse failed: %v", err)
}
want := []api.ToolCall{
{Function: api.ToolCallFunction{Name: "first", Arguments: args(`{"a":"1"}`), Index: 0}},
{Function: api.ToolCallFunction{Name: "second", Arguments: args(`{"b":"2"}`), Index: 1}},
{Function: api.ToolCallFunction{Name: "third", Arguments: args(`{"c":"3"}`), Index: 2}},
}
if len(calls) != len(want) {
t.Fatalf("expected %d calls, got %d", len(want), len(calls))
}
for i := range want {
if !toolCallEqual(calls[i], want[i]) {
t.Fatalf("call %d mismatch: got %#v, want %#v", i, calls[i], want[i])
}
}
}
func TestGLM47ParserToolCallIndexingStreaming(t *testing.T) {
parser := GLM47Parser{}
parser.Init(nil, nil, nil)
var all []api.ToolCall
_, _, calls, err := parser.Add("plan</think><tool_call>first<arg_key>a</arg_key><arg_value>1</arg_value></tool_call><tool_call>second<arg_key>b</arg_key>", false)
if err != nil {
t.Fatalf("step 1 parse failed: %v", err)
}
all = append(all, calls...)
_, _, calls, err = parser.Add("<arg_value>2</arg_value></tool_call><tool_call>third<arg_key>c</arg_key><arg_value>3</arg_value></tool_call>", true)
if err != nil {
t.Fatalf("step 2 parse failed: %v", err)
}
all = append(all, calls...)
want := []api.ToolCall{
{Function: api.ToolCallFunction{Name: "first", Arguments: args(`{"a":"1"}`), Index: 0}},
{Function: api.ToolCallFunction{Name: "second", Arguments: args(`{"b":"2"}`), Index: 1}},
{Function: api.ToolCallFunction{Name: "third", Arguments: args(`{"c":"3"}`), Index: 2}},
}
if len(all) != len(want) {
t.Fatalf("expected %d calls, got %d", len(want), len(all))
}
for i := range want {
if !toolCallEqual(all[i], want[i]) {
t.Fatalf("call %d mismatch: got %#v, want %#v", i, all[i], want[i])
}
}
}
func TestGLM47ParserToolCallIndexResetOnInit(t *testing.T) {
parser := GLM47Parser{}
parser.Init(nil, nil, nil)
_, _, _, err := parser.Add("plan</think><tool_call>first<arg_key>a</arg_key><arg_value>1</arg_value></tool_call>", true)
if err != nil {
t.Fatalf("first parse failed: %v", err)
}
parser.Init(nil, nil, nil)
_, _, calls, err := parser.Add("plan</think><tool_call>second<arg_key>b</arg_key><arg_value>2</arg_value></tool_call>", true)
if err != nil {
t.Fatalf("second parse failed: %v", err)
}
want := api.ToolCall{
Function: api.ToolCallFunction{Name: "second", Arguments: args(`{"b":"2"}`), Index: 0},
}
if len(calls) != 1 {
t.Fatalf("expected 1 call, got %d", len(calls))
}
if !toolCallEqual(calls[0], want) {
t.Fatalf("got %#v, want %#v", calls[0], want)
}
}

View File

@@ -50,7 +50,7 @@ func ParserForName(name string) Parser {
case "qwen3-thinking":
p = &Qwen3Parser{hasThinkingSupport: true, defaultThinking: true}
case "qwen3.5":
p = &Qwen3Parser{hasThinkingSupport: true, defaultThinking: true}
p = &Qwen35Parser{}
case "qwen3-coder":
p = &Qwen3CoderParser{}
case "qwen3-vl-instruct":

View File

@@ -38,6 +38,7 @@ type Qwen3Parser struct {
state qwen3ParserState
buffer strings.Builder
tools []api.Tool
callIndex int
hasThinkingSupport bool
defaultThinking bool
maybeThinkingOpenAtBOL bool
@@ -54,6 +55,7 @@ func (p *Qwen3Parser) HasThinkingSupport() bool {
func (p *Qwen3Parser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
p.tools = tools
p.buffer.Reset()
p.callIndex = 0
thinkingEnabled := thinkValue != nil && thinkValue.Bool()
if thinkValue == nil {
@@ -106,6 +108,8 @@ func (p *Qwen3Parser) Add(s string, done bool) (content string, thinking string,
slog.Warn("qwen3 tool call parsing failed", "error", err)
return "", "", nil, err
}
toolCall.Function.Index = p.callIndex
p.callIndex++
calls = append(calls, toolCall)
case qwen3EventThinkingContent:
thinkingSb.WriteString(event.content)

model/parsers/qwen35.go Normal file
View File

@@ -0,0 +1,238 @@
package parsers
import (
"context"
"log/slog"
"strings"
"unicode"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/logutil"
)
type qwen35ParserState int
const (
// accumulating reasoning until </think> arrives
qwen35ParserStateCollectingThinking qwen35ParserState = iota
// swallowing whitespace between </think> and the first content byte
qwen35ParserStateThinkingDoneEatingWhitespace
// passing everything through to the delegated tool-call parser
qwen35ParserStateCollectingContent
)
const (
qwen35ThinkingOpenTag = "<think>"
qwen35ThinkingCloseTag = "</think>"
)
// Qwen35Parser handles qwen3.5 reasoning extraction and delegates post-thinking
// content (including XML tool calls) to Qwen3CoderParser.
type Qwen35Parser struct {
toolParser Qwen3CoderParser
state qwen35ParserState
buffer strings.Builder
// Some checkpoints may emit an explicit leading <think> even when the
// prompt already opened thinking. Strip at most one such tag.
allowLeadingThinkOpenTag bool
}
func (p *Qwen35Parser) HasToolSupport() bool {
return true
}
func (p *Qwen35Parser) HasThinkingSupport() bool {
return true
}
func (p *Qwen35Parser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
p.buffer.Reset()
p.toolParser = Qwen3CoderParser{}
p.toolParser.Init(tools, nil, nil)
thinkingEnabled := thinkValue != nil && thinkValue.Bool()
if thinkValue == nil {
thinkingEnabled = true
}
assistantPrefill := lastMessage != nil && lastMessage.Role == "assistant" && lastMessage.Content != ""
if thinkingEnabled && !assistantPrefill {
p.state = qwen35ParserStateCollectingThinking
p.allowLeadingThinkOpenTag = true
} else {
p.state = qwen35ParserStateCollectingContent
p.allowLeadingThinkOpenTag = false
}
return tools
}
type qwen35Event interface {
isQwen35Event()
}
type qwen35EventContent struct {
content string
}
func (qwen35EventContent) isQwen35Event() {}
type qwen35EventThinkingContent struct {
content string
}
func (qwen35EventThinkingContent) isQwen35Event() {}
func (p *Qwen35Parser) Add(s string, done bool) (content string, thinking string, calls []api.ToolCall, err error) {
p.buffer.WriteString(s)
events := p.parseEvents()
var contentSb strings.Builder
var thinkingSb strings.Builder
for _, event := range events {
switch event := event.(type) {
case qwen35EventContent:
parsedContent, _, parsedCalls, err := p.toolParser.Add(event.content, done)
if err != nil {
slog.Warn("qwen3.5 tool call parsing failed", "error", err)
return "", "", nil, err
}
contentSb.WriteString(parsedContent)
calls = append(calls, parsedCalls...)
case qwen35EventThinkingContent:
thinkingSb.WriteString(event.content)
}
}
return contentSb.String(), thinkingSb.String(), calls, nil
}
func (p *Qwen35Parser) parseEvents() []qwen35Event {
var all []qwen35Event
keepLooping := true
for keepLooping {
var events []qwen35Event
events, keepLooping = p.eat()
if len(events) > 0 {
all = append(all, events...)
}
}
if len(all) > 0 {
slog.Log(context.TODO(), logutil.LevelTrace, "qwen3.5 events parsed", "events", all, "state", p.state, "buffer", p.buffer.String())
}
return all
}
func (p *Qwen35Parser) splitAtTag(tag string, trimAfter bool) (string, string) {
return splitAtTag(&p.buffer, tag, trimAfter)
}
func (p *Qwen35Parser) eatLeadingWhitespaceAndTransitionTo(nextState qwen35ParserState) ([]qwen35Event, bool) {
trimmed := strings.TrimLeftFunc(p.buffer.String(), unicode.IsSpace)
p.buffer.Reset()
if trimmed == "" {
return nil, false
}
p.state = nextState
p.buffer.WriteString(trimmed)
return nil, true
}
// maybeConsumeLeadingThinkOpenTag handles a single optional leading <think> tag.
// Returns (handled, shouldContinueParsingNow).
func (p *Qwen35Parser) maybeConsumeLeadingThinkOpenTag(acc string) (bool, bool) {
if !p.allowLeadingThinkOpenTag {
return false, false
}
trimmed := strings.TrimLeftFunc(acc, unicode.IsSpace)
if strings.HasPrefix(trimmed, qwen35ThinkingOpenTag) {
after := strings.TrimPrefix(trimmed, qwen35ThinkingOpenTag)
after = strings.TrimLeftFunc(after, unicode.IsSpace)
p.buffer.Reset()
p.buffer.WriteString(after)
if after == "" {
return true, false
}
p.allowLeadingThinkOpenTag = false
return true, true
}
// the buffered text so far could still be a prefix of "<think>"; wait for more input
if strings.HasPrefix(qwen35ThinkingOpenTag, trimmed) {
return true, false
}
p.allowLeadingThinkOpenTag = false
return false, false
}
func (p *Qwen35Parser) eat() ([]qwen35Event, bool) {
var events []qwen35Event
switch p.state {
case qwen35ParserStateCollectingThinking:
acc := p.buffer.String()
if handled, continueNow := p.maybeConsumeLeadingThinkOpenTag(acc); handled {
return events, continueNow
}
if strings.Contains(acc, qwen35ThinkingCloseTag) {
thinking, remaining := p.splitAtTag(qwen35ThinkingCloseTag, true)
if len(thinking) > 0 {
events = append(events, qwen35EventThinkingContent{content: thinking})
}
if remaining == "" {
p.state = qwen35ParserStateThinkingDoneEatingWhitespace
} else {
p.state = qwen35ParserStateCollectingContent
}
return events, true
} else if overlapLen := overlap(acc, qwen35ThinkingCloseTag); overlapLen > 0 {
beforePartialTag := acc[:len(acc)-overlapLen]
trailingWsLen := trailingWhitespaceLen(beforePartialTag)
ambiguousStart := len(beforePartialTag) - trailingWsLen
unambiguous := acc[:ambiguousStart]
ambiguous := acc[ambiguousStart:]
p.buffer.Reset()
p.buffer.WriteString(ambiguous)
if len(unambiguous) > 0 {
events = append(events, qwen35EventThinkingContent{content: unambiguous})
}
return events, false
}
whitespaceLen := trailingWhitespaceLen(acc)
ambiguousStart := len(acc) - whitespaceLen
unambiguous := acc[:ambiguousStart]
ambiguous := acc[ambiguousStart:]
p.buffer.Reset()
p.buffer.WriteString(ambiguous)
if len(unambiguous) > 0 {
events = append(events, qwen35EventThinkingContent{content: unambiguous})
}
return events, false
case qwen35ParserStateThinkingDoneEatingWhitespace:
return p.eatLeadingWhitespaceAndTransitionTo(qwen35ParserStateCollectingContent)
case qwen35ParserStateCollectingContent:
if p.buffer.Len() == 0 {
return events, false
}
content := p.buffer.String()
p.buffer.Reset()
if len(content) > 0 {
events = append(events, qwen35EventContent{content: content})
}
return events, false
default:
slog.Warn("qwen3.5 parser entered unknown state; resetting to content mode", "state", p.state)
p.state = qwen35ParserStateCollectingContent
return events, false
}
}
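The streaming logic above calls three small helpers (splitAtTag, overlap, trailingWhitespaceLen) that are shared elsewhere in the package and are not part of this diff. The names and call sites are real; the bodies below are only a reconstruction of what they presumably do:
	// overlap reports the length of the longest suffix of s that is also a
	// proper prefix of tag, i.e. how many trailing bytes of s could be the
	// start of a tag split across streaming chunks.
	func overlap(s, tag string) int {
		for n := min(len(s), len(tag)-1); n > 0; n-- {
			if strings.HasSuffix(s, tag[:n]) {
				return n
			}
		}
		return 0
	}

	// trailingWhitespaceLen reports how many trailing bytes of s are whitespace.
	func trailingWhitespaceLen(s string) int {
		return len(s) - len(strings.TrimRightFunc(s, unicode.IsSpace))
	}

	// splitAtTag splits the buffered text at the first occurrence of tag
	// (callers guarantee it is present), resets the buffer to the remainder,
	// and optionally trims whitespace that follows the tag.
	func splitAtTag(buffer *strings.Builder, tag string, trimAfter bool) (before, after string) {
		acc := buffer.String()
		idx := strings.Index(acc, tag)
		before, after = acc[:idx], acc[idx+len(tag):]
		if trimAfter {
			after = strings.TrimLeftFunc(after, unicode.IsSpace)
		}
		buffer.Reset()
		buffer.WriteString(after)
		return before, after
	}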

View File

@@ -0,0 +1,382 @@
package parsers
import (
"testing"
"github.com/ollama/ollama/api"
)
func TestQwen35ParserXMLToolCall(t *testing.T) {
parser := ParserForName("qwen3.5")
if parser == nil {
t.Fatal("expected qwen3.5 parser")
}
tools := []api.Tool{
{
Function: api.ToolFunction{
Name: "get_weather",
Parameters: api.ToolFunctionParameters{
Properties: func() *api.ToolPropertiesMap {
props := api.NewToolPropertiesMap()
props.Set("location", api.ToolProperty{Type: api.PropertyType{"string"}})
props.Set("days", api.ToolProperty{Type: api.PropertyType{"integer"}})
return props
}(),
},
},
},
}
parser.Init(tools, nil, &api.ThinkValue{Value: false})
input := "<tool_call><function=get_weather><parameter=location>\nSan Francisco\n</parameter><parameter=days>\n3\n</parameter></function></tool_call>"
content, thinking, calls, err := parser.Add(input, true)
if err != nil {
t.Fatalf("parse failed: %v", err)
}
if content != "" {
t.Fatalf("expected empty content, got %q", content)
}
if thinking != "" {
t.Fatalf("expected empty thinking, got %q", thinking)
}
if len(calls) != 1 {
t.Fatalf("expected 1 tool call, got %d", len(calls))
}
if calls[0].Function.Name != "get_weather" {
t.Fatalf("expected tool name %q, got %q", "get_weather", calls[0].Function.Name)
}
location, ok := calls[0].Function.Arguments.Get("location")
if !ok || location != "San Francisco" {
t.Fatalf("expected location %q, got %v", "San Francisco", location)
}
days, ok := calls[0].Function.Arguments.Get("days")
if !ok || days != 3 {
t.Fatalf("expected days %d, got %v", 3, days)
}
}
func TestQwen35ParserThinkingWithExplicitOpeningTag(t *testing.T) {
parser := ParserForName("qwen3.5")
if parser == nil {
t.Fatal("expected qwen3.5 parser")
}
parser.Init(nil, nil, &api.ThinkValue{Value: true})
content, thinking, calls, err := parser.Add("<think>\nLet me think...</think>Answer.", true)
if err != nil {
t.Fatalf("parse failed: %v", err)
}
if thinking != "Let me think..." {
t.Fatalf("expected thinking %q, got %q", "Let me think...", thinking)
}
if content != "Answer." {
t.Fatalf("expected content %q, got %q", "Answer.", content)
}
if len(calls) != 0 {
t.Fatalf("expected no tool calls, got %d", len(calls))
}
}
func TestQwen35ParserAssistantPrefillStartsInContent(t *testing.T) {
parser := ParserForName("qwen3.5")
if parser == nil {
t.Fatal("expected qwen3.5 parser")
}
last := &api.Message{Role: "assistant", Content: "Prefilled response start"}
parser.Init(nil, last, nil)
content, thinking, calls, err := parser.Add(" and continued", true)
if err != nil {
t.Fatalf("parse failed: %v", err)
}
if thinking != "" {
t.Fatalf("expected no thinking for assistant prefill continuation, got %q", thinking)
}
if content != " and continued" {
t.Fatalf("expected content %q, got %q", " and continued", content)
}
if len(calls) != 0 {
t.Fatalf("expected no tool calls, got %d", len(calls))
}
}
func TestQwen35ParserToolCallEmittedInThinkingIsNotParsed(t *testing.T) {
parser := ParserForName("qwen3.5")
if parser == nil {
t.Fatal("expected qwen3.5 parser")
}
tools := []api.Tool{
{
Function: api.ToolFunction{
Name: "get_weather",
Parameters: api.ToolFunctionParameters{
Properties: func() *api.ToolPropertiesMap {
props := api.NewToolPropertiesMap()
props.Set("location", api.ToolProperty{Type: api.PropertyType{"string"}})
return props
}(),
},
},
},
}
parser.Init(tools, nil, &api.ThinkValue{Value: true})
input := `Need weather lookup<tool_call><function=get_weather><parameter=location>
SF
</parameter></function></tool_call>`
content, thinking, calls, err := parser.Add(input, true)
if err != nil {
t.Fatalf("parse failed: %v", err)
}
if content != "" {
t.Fatalf("expected empty content, got %q", content)
}
expectedThinking := `Need weather lookup<tool_call><function=get_weather><parameter=location>
SF
</parameter></function></tool_call>`
if thinking != expectedThinking {
t.Fatalf("expected thinking %q, got %q", expectedThinking, thinking)
}
if len(calls) != 0 {
t.Fatalf("expected no tool calls before </think>, got %d", len(calls))
}
}
func TestQwen35ParserToolCallAfterThinkingCloseIsParsed(t *testing.T) {
parser := ParserForName("qwen3.5")
if parser == nil {
t.Fatal("expected qwen3.5 parser")
}
tools := []api.Tool{
{
Function: api.ToolFunction{
Name: "get_weather",
Parameters: api.ToolFunctionParameters{
Properties: func() *api.ToolPropertiesMap {
props := api.NewToolPropertiesMap()
props.Set("location", api.ToolProperty{Type: api.PropertyType{"string"}})
return props
}(),
},
},
},
}
parser.Init(tools, nil, &api.ThinkValue{Value: true})
input := `Need weather lookup</think><tool_call><function=get_weather><parameter=location>
SF
</parameter></function></tool_call>`
content, thinking, calls, err := parser.Add(input, true)
if err != nil {
t.Fatalf("parse failed: %v", err)
}
if content != "" {
t.Fatalf("expected empty content, got %q", content)
}
if thinking != "Need weather lookup" {
t.Fatalf("expected thinking %q, got %q", "Need weather lookup", thinking)
}
if len(calls) != 1 {
t.Fatalf("expected 1 tool call after </think>, got %d", len(calls))
}
if calls[0].Function.Name != "get_weather" {
t.Fatalf("expected tool name %q, got %q", "get_weather", calls[0].Function.Name)
}
location, ok := calls[0].Function.Arguments.Get("location")
if !ok || location != "SF" {
t.Fatalf("expected location %q, got %v", "SF", location)
}
}
func TestQwen35ParserThinkingDisabledPassesContentThrough(t *testing.T) {
parser := ParserForName("qwen3.5")
if parser == nil {
t.Fatal("expected qwen3.5 parser")
}
parser.Init(nil, nil, &api.ThinkValue{Value: false})
content, thinking, calls, err := parser.Add("Plain answer without think close tag.", true)
if err != nil {
t.Fatalf("parse failed: %v", err)
}
if thinking != "" {
t.Fatalf("expected empty thinking, got %q", thinking)
}
if content != "Plain answer without think close tag." {
t.Fatalf("expected content %q, got %q", "Plain answer without think close tag.", content)
}
if len(calls) != 0 {
t.Fatalf("expected no tool calls, got %d", len(calls))
}
}
func TestQwen35ParserThinkingDisabledWithCloseTagTreatsAsContent(t *testing.T) {
parser := ParserForName("qwen3.5")
if parser == nil {
t.Fatal("expected qwen3.5 parser")
}
parser.Init(nil, nil, &api.ThinkValue{Value: false})
content, thinking, calls, err := parser.Add("</think>Some content after spurious tag.", true)
if err != nil {
t.Fatalf("parse failed: %v", err)
}
if thinking != "" {
t.Fatalf("expected empty thinking, got %q", thinking)
}
if content != "</think>Some content after spurious tag." {
t.Fatalf("expected content %q, got %q", "</think>Some content after spurious tag.", content)
}
if len(calls) != 0 {
t.Fatalf("expected no tool calls, got %d", len(calls))
}
}
func TestQwen35ParserLeadingThinkCloseProducesContent(t *testing.T) {
parser := ParserForName("qwen3.5")
if parser == nil {
t.Fatal("expected qwen3.5 parser")
}
parser.Init(nil, nil, &api.ThinkValue{Value: true})
content, thinking, calls, err := parser.Add("</think>The final answer.", true)
if err != nil {
t.Fatalf("parse failed: %v", err)
}
if thinking != "" {
t.Fatalf("expected empty thinking, got %q", thinking)
}
if content != "The final answer." {
t.Fatalf("expected content %q, got %q", "The final answer.", content)
}
if len(calls) != 0 {
t.Fatalf("expected no tool calls, got %d", len(calls))
}
}
func TestQwen35ParserStreamingSplitThinkCloseTag(t *testing.T) {
parser := ParserForName("qwen3.5")
if parser == nil {
t.Fatal("expected qwen3.5 parser")
}
parser.Init(nil, nil, &api.ThinkValue{Value: true})
content, thinking, calls, err := parser.Add("Reasoning text</thi", false)
if err != nil {
t.Fatalf("parse failed on first chunk: %v", err)
}
if thinking != "Reasoning text" {
t.Fatalf("expected thinking %q, got %q", "Reasoning text", thinking)
}
if content != "" {
t.Fatalf("expected empty content, got %q", content)
}
if len(calls) != 0 {
t.Fatalf("expected no tool calls, got %d", len(calls))
}
content, thinking, calls, err = parser.Add("nk>The final answer.", true)
if err != nil {
t.Fatalf("parse failed on second chunk: %v", err)
}
if thinking != "" {
t.Fatalf("expected no additional thinking on second chunk, got %q", thinking)
}
if content != "The final answer." {
t.Fatalf("expected content %q, got %q", "The final answer.", content)
}
if len(calls) != 0 {
t.Fatalf("expected no tool calls, got %d", len(calls))
}
}
func TestQwen35ParserStreamingEatsWhitespaceAfterThinkClose(t *testing.T) {
parser := ParserForName("qwen3.5")
if parser == nil {
t.Fatal("expected qwen3.5 parser")
}
parser.Init(nil, nil, &api.ThinkValue{Value: true})
content, thinking, calls, err := parser.Add("Reasoning</think>", false)
if err != nil {
t.Fatalf("parse failed on first chunk: %v", err)
}
if thinking != "Reasoning" {
t.Fatalf("expected thinking %q, got %q", "Reasoning", thinking)
}
if content != "" {
t.Fatalf("expected empty content, got %q", content)
}
if len(calls) != 0 {
t.Fatalf("expected no tool calls, got %d", len(calls))
}
content, thinking, calls, err = parser.Add("\n \t", false)
if err != nil {
t.Fatalf("parse failed on whitespace chunk: %v", err)
}
if thinking != "" {
t.Fatalf("expected no thinking on whitespace chunk, got %q", thinking)
}
if content != "" {
t.Fatalf("expected whitespace after </think> to be eaten, got content %q", content)
}
if len(calls) != 0 {
t.Fatalf("expected no tool calls, got %d", len(calls))
}
content, thinking, calls, err = parser.Add("The final answer.", true)
if err != nil {
t.Fatalf("parse failed on content chunk: %v", err)
}
if thinking != "" {
t.Fatalf("expected no additional thinking, got %q", thinking)
}
if content != "The final answer." {
t.Fatalf("expected content %q, got %q", "The final answer.", content)
}
if len(calls) != 0 {
t.Fatalf("expected no tool calls, got %d", len(calls))
}
}
func TestQwen35ParserThinkingTruncatedWithoutCloseTag(t *testing.T) {
parser := ParserForName("qwen3.5")
if parser == nil {
t.Fatal("expected qwen3.5 parser")
}
parser.Init(nil, nil, &api.ThinkValue{Value: true})
content, thinking, calls, err := parser.Add("Reasoning that never closes", true)
if err != nil {
t.Fatalf("parse failed: %v", err)
}
if thinking != "Reasoning that never closes" {
t.Fatalf("expected thinking %q, got %q", "Reasoning that never closes", thinking)
}
if content != "" {
t.Fatalf("expected empty content, got %q", content)
}
if len(calls) != 0 {
t.Fatalf("expected no tool calls, got %d", len(calls))
}
}

View File

@@ -230,3 +230,89 @@ func TestQwen35ParserRespectsNoThink(t *testing.T) {
t.Fatalf("expected no tool calls, got %d", len(calls))
}
}
func TestQwen3ParserToolCallIndexing(t *testing.T) {
parser := &Qwen3Parser{hasThinkingSupport: false, defaultThinking: false}
parser.Init(nil, nil, &api.ThinkValue{Value: false})
input := `<tool_call>{"name":"first","arguments":{"a":"1"}}</tool_call>
<tool_call>{"name":"second","arguments":{"b":"2"}}</tool_call>
<tool_call>{"name":"third","arguments":{"c":"3"}}</tool_call>`
_, _, calls, err := parser.Add(input, true)
if err != nil {
t.Fatalf("parse failed: %v", err)
}
want := []api.ToolCall{
{Function: api.ToolCallFunction{Name: "first", Arguments: args(`{"a":"1"}`), Index: 0}},
{Function: api.ToolCallFunction{Name: "second", Arguments: args(`{"b":"2"}`), Index: 1}},
{Function: api.ToolCallFunction{Name: "third", Arguments: args(`{"c":"3"}`), Index: 2}},
}
if len(calls) != len(want) {
t.Fatalf("expected %d calls, got %d", len(want), len(calls))
}
for i := range want {
if !toolCallEqual(calls[i], want[i]) {
t.Fatalf("call %d mismatch: got %#v, want %#v", i, calls[i], want[i])
}
}
}
func TestQwen3ParserToolCallIndexingStreaming(t *testing.T) {
parser := &Qwen3Parser{hasThinkingSupport: false, defaultThinking: false}
parser.Init(nil, nil, &api.ThinkValue{Value: false})
var all []api.ToolCall
_, _, calls, err := parser.Add(`<tool_call>{"name":"first","arguments":{"a":"1"}}</tool_call><tool_call>{"name":"second","arguments":{"b":"2"}`, false)
if err != nil {
t.Fatalf("step 1 parse failed: %v", err)
}
all = append(all, calls...)
_, _, calls, err = parser.Add(`}</tool_call><tool_call>{"name":"third","arguments":{"c":"3"}}</tool_call>`, true)
if err != nil {
t.Fatalf("step 2 parse failed: %v", err)
}
all = append(all, calls...)
want := []api.ToolCall{
{Function: api.ToolCallFunction{Name: "first", Arguments: args(`{"a":"1"}`), Index: 0}},
{Function: api.ToolCallFunction{Name: "second", Arguments: args(`{"b":"2"}`), Index: 1}},
{Function: api.ToolCallFunction{Name: "third", Arguments: args(`{"c":"3"}`), Index: 2}},
}
if len(all) != len(want) {
t.Fatalf("expected %d calls, got %d", len(want), len(all))
}
for i := range want {
if !toolCallEqual(all[i], want[i]) {
t.Fatalf("call %d mismatch: got %#v, want %#v", i, all[i], want[i])
}
}
}
func TestQwen3ParserToolCallIndexResetOnInit(t *testing.T) {
parser := &Qwen3Parser{hasThinkingSupport: false, defaultThinking: false}
parser.Init(nil, nil, &api.ThinkValue{Value: false})
_, _, _, err := parser.Add(`<tool_call>{"name":"first","arguments":{"a":"1"}}</tool_call>`, true)
if err != nil {
t.Fatalf("first parse failed: %v", err)
}
parser.Init(nil, nil, &api.ThinkValue{Value: false})
_, _, calls, err := parser.Add(`<tool_call>{"name":"second","arguments":{"b":"2"}}</tool_call>`, true)
if err != nil {
t.Fatalf("second parse failed: %v", err)
}
want := api.ToolCall{
Function: api.ToolCallFunction{Name: "second", Arguments: args(`{"b":"2"}`), Index: 0},
}
if len(calls) != 1 {
t.Fatalf("expected 1 call, got %d", len(calls))
}
if !toolCallEqual(calls[0], want) {
t.Fatalf("got %#v, want %#v", calls[0], want)
}
}

View File

@@ -29,9 +29,10 @@ const (
)
type Qwen3CoderParser struct {
state qwenParserState
acc strings.Builder
tools []api.Tool
state qwenParserState
acc strings.Builder
tools []api.Tool
callIndex int
}
func (p *Qwen3CoderParser) HasToolSupport() bool {
@@ -44,6 +45,7 @@ func (p *Qwen3CoderParser) HasThinkingSupport() bool {
func (p *Qwen3CoderParser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
p.tools = tools
p.callIndex = 0
return tools // Qwen doesn't modify tools
}
@@ -62,6 +64,8 @@ func (p *Qwen3CoderParser) Add(s string, done bool) (content string, thinking st
slog.Warn("qwen tool call parsing failed", "error", err)
return "", "", nil, err
}
toolCall.Function.Index = p.callIndex
p.callIndex++
toolCalls = append(toolCalls, toolCall)
case qwenEventContent:
// TODO(drifkin): if the same turn contains multiple interleaved content

View File

@@ -1035,6 +1035,92 @@ func TestQwenToolCallValueParsing(t *testing.T) {
}
}
func TestQwen3CoderParserToolCallIndexing(t *testing.T) {
parser := Qwen3CoderParser{}
parser.Init(nil, nil, nil)
input := `<tool_call><function=first><parameter=a>1</parameter></function></tool_call>
<tool_call><function=second><parameter=b>2</parameter></function></tool_call>
<tool_call><function=third><parameter=c>3</parameter></function></tool_call>`
_, _, calls, err := parser.Add(input, true)
if err != nil {
t.Fatalf("parse failed: %v", err)
}
want := []api.ToolCall{
{Function: api.ToolCallFunction{Name: "first", Arguments: testArgs(map[string]any{"a": "1"}), Index: 0}},
{Function: api.ToolCallFunction{Name: "second", Arguments: testArgs(map[string]any{"b": "2"}), Index: 1}},
{Function: api.ToolCallFunction{Name: "third", Arguments: testArgs(map[string]any{"c": "3"}), Index: 2}},
}
if len(calls) != len(want) {
t.Fatalf("expected %d calls, got %d", len(want), len(calls))
}
for i := range want {
if !toolCallEqual(calls[i], want[i]) {
t.Fatalf("call %d mismatch: got %#v, want %#v", i, calls[i], want[i])
}
}
}
func TestQwen3CoderParserToolCallIndexingStreaming(t *testing.T) {
parser := Qwen3CoderParser{}
parser.Init(nil, nil, nil)
var all []api.ToolCall
_, _, calls, err := parser.Add("<tool_call><function=first><parameter=a>1</parameter></function></tool_call><tool_call><function=second>", false)
if err != nil {
t.Fatalf("step 1 parse failed: %v", err)
}
all = append(all, calls...)
_, _, calls, err = parser.Add("<parameter=b>2</parameter></function></tool_call><tool_call><function=third><parameter=c>3</parameter></function></tool_call>", true)
if err != nil {
t.Fatalf("step 2 parse failed: %v", err)
}
all = append(all, calls...)
want := []api.ToolCall{
{Function: api.ToolCallFunction{Name: "first", Arguments: testArgs(map[string]any{"a": "1"}), Index: 0}},
{Function: api.ToolCallFunction{Name: "second", Arguments: testArgs(map[string]any{"b": "2"}), Index: 1}},
{Function: api.ToolCallFunction{Name: "third", Arguments: testArgs(map[string]any{"c": "3"}), Index: 2}},
}
if len(all) != len(want) {
t.Fatalf("expected %d calls, got %d", len(want), len(all))
}
for i := range want {
if !toolCallEqual(all[i], want[i]) {
t.Fatalf("call %d mismatch: got %#v, want %#v", i, all[i], want[i])
}
}
}
func TestQwen3CoderParserToolCallIndexResetOnInit(t *testing.T) {
parser := Qwen3CoderParser{}
parser.Init(nil, nil, nil)
_, _, _, err := parser.Add("<tool_call><function=first><parameter=a>1</parameter></function></tool_call>", true)
if err != nil {
t.Fatalf("first parse failed: %v", err)
}
parser.Init(nil, nil, nil)
_, _, calls, err := parser.Add("<tool_call><function=second><parameter=b>2</parameter></function></tool_call>", true)
if err != nil {
t.Fatalf("second parse failed: %v", err)
}
want := api.ToolCall{
Function: api.ToolCallFunction{Name: "second", Arguments: testArgs(map[string]any{"b": "2"}), Index: 0},
}
if len(calls) != 1 {
t.Fatalf("expected 1 call, got %d", len(calls))
}
if !toolCallEqual(calls[0], want) {
t.Fatalf("got %#v, want %#v", calls[0], want)
}
}
func TestQwenXMLTransform(t *testing.T) {
cases := []struct {
desc string

View File

@@ -8,7 +8,21 @@ import (
"github.com/ollama/ollama/api"
)
type GlmOcrRenderer struct{}
type GlmOcrRenderer struct {
useImgTags bool
}
func (r *GlmOcrRenderer) renderContent(message api.Message, imageOffset int) (string, int) {
var sb strings.Builder
// one [img-N] marker per image when enabled; by default images contribute
// nothing to the rendered text here
for range message.Images {
if r.useImgTags {
sb.WriteString(fmt.Sprintf("[img-%d]", imageOffset))
imageOffset++
}
}
sb.WriteString(message.Content)
return sb.String(), imageOffset
}
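For instance (matching the renderer tests below), two images in one user message yield sequential global markers:
	r := &GlmOcrRenderer{useImgTags: true}
	content, next := r.renderContent(api.Message{
		Content: "Describe these images.",
		Images:  []api.ImageData{api.ImageData("a"), api.ImageData("b")},
	}, 0)
	// content == "[img-0][img-1]Describe these images.", next == 2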
func (r *GlmOcrRenderer) Render(messages []api.Message, tools []api.Tool, thinkValue *api.ThinkValue) (string, error) {
var sb strings.Builder
@@ -38,11 +52,14 @@ func (r *GlmOcrRenderer) Render(messages []api.Message, tools []api.Tool, thinkV
thinkingExplicitlySet = true
}
imageOffset := 0
for i, message := range messages {
switch message.Role {
case "user":
sb.WriteString("<|user|>\n")
sb.WriteString(message.Content)
content, nextOffset := r.renderContent(message, imageOffset)
imageOffset = nextOffset
sb.WriteString(content)
if thinkingExplicitlySet && !enableThinking && !strings.HasSuffix(message.Content, "/nothink") {
sb.WriteString("/nothink")
}

View File

@@ -0,0 +1,99 @@
package renderers
import (
"testing"
"github.com/google/go-cmp/cmp"
"github.com/ollama/ollama/api"
)
func TestGlmOcrRenderer_Images(t *testing.T) {
tests := []struct {
name string
renderer *GlmOcrRenderer
messages []api.Message
expected string
}{
{
name: "use_img_tags_single_image",
renderer: &GlmOcrRenderer{useImgTags: true},
messages: []api.Message{
{
Role: "user",
Content: "Describe this image.",
Images: []api.ImageData{api.ImageData("img1")},
},
},
expected: "[gMASK]<sop><|user|>\n[img-0]Describe this image.<|assistant|>\n",
},
{
name: "use_img_tags_multiple_images",
renderer: &GlmOcrRenderer{useImgTags: true},
messages: []api.Message{
{
Role: "user",
Content: "Describe these images.",
Images: []api.ImageData{api.ImageData("img1"), api.ImageData("img2")},
},
},
expected: "[gMASK]<sop><|user|>\n[img-0][img-1]Describe these images.<|assistant|>\n",
},
{
name: "multi_turn_increments_image_offset",
renderer: &GlmOcrRenderer{useImgTags: true},
messages: []api.Message{
{
Role: "user",
Content: "First image",
Images: []api.ImageData{api.ImageData("img1")},
},
{
Role: "assistant",
Content: "Processed.",
},
{
Role: "user",
Content: "Second image",
Images: []api.ImageData{api.ImageData("img2")},
},
},
expected: "[gMASK]<sop><|user|>\n[img-0]First image<|assistant|>\n<think></think>\nProcessed.\n<|user|>\n[img-1]Second image<|assistant|>\n",
},
{
name: "default_no_img_tags",
renderer: &GlmOcrRenderer{},
messages: []api.Message{
{
Role: "user",
Content: "No image tags expected.",
Images: []api.ImageData{api.ImageData("img1")},
},
},
expected: "[gMASK]<sop><|user|>\nNo image tags expected.<|assistant|>\n",
},
{
name: "no_images_content_unchanged",
renderer: &GlmOcrRenderer{useImgTags: true},
messages: []api.Message{
{
Role: "user",
Content: "Text only message.",
},
},
expected: "[gMASK]<sop><|user|>\nText only message.<|assistant|>\n",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := tt.renderer.Render(tt.messages, nil, nil)
if err != nil {
t.Fatalf("Render() error = %v", err)
}
if diff := cmp.Diff(tt.expected, got); diff != "" {
t.Fatalf("Render() mismatch (-want +got):\n%s", diff)
}
})
}
}

model/renderers/qwen35.go Normal file
View File

@@ -0,0 +1,194 @@
package renderers
import (
"fmt"
"strings"
"github.com/ollama/ollama/api"
)
const (
qwen35ThinkOpenTag = "<think>"
qwen35ThinkCloseTag = "</think>"
qwen35ToolPostamble = `
</tools>
If you choose to call a function ONLY reply in the following format with NO suffix:
<tool_call>
<function=example_function_name>
<parameter=example_parameter_1>
value_1
</parameter>
<parameter=example_parameter_2>
This is the value for the second parameter
that can span
multiple lines
</parameter>
</function>
</tool_call>
<IMPORTANT>
Reminder:
- Function calls MUST follow the specified format: an inner <function=...></function> block must be nested within <tool_call></tool_call> XML tags
- Required parameters MUST be specified
- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after
- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls
</IMPORTANT>`
)
type Qwen35Renderer struct {
isThinking bool
emitEmptyThinkOnNoThink bool
useImgTags bool
}
func (r *Qwen35Renderer) renderContent(content api.Message, imageOffset int) (string, int) {
// This assumes all images are at the front of the message - same assumption as ollama/ollama/runner.go
var subSb strings.Builder
for range content.Images {
if r.useImgTags {
subSb.WriteString(fmt.Sprintf("[img-%d]", imageOffset))
imageOffset++
} else {
subSb.WriteString("<|vision_start|><|image_pad|><|vision_end|>")
}
}
// TODO: support videos
subSb.WriteString(content.Content)
return subSb.String(), imageOffset
}
func splitQwen35ReasoningContent(content, messageThinking string, isThinking bool) (reasoning string, remaining string) {
if isThinking && messageThinking != "" {
return strings.TrimSpace(messageThinking), content
}
if idx := strings.Index(content, qwen35ThinkCloseTag); idx != -1 {
before := content[:idx]
if open := strings.LastIndex(before, qwen35ThinkOpenTag); open != -1 {
reasoning = before[open+len(qwen35ThinkOpenTag):]
} else {
reasoning = before
}
content = strings.TrimLeft(content[idx+len(qwen35ThinkCloseTag):], "\n")
}
return strings.TrimSpace(reasoning), content
}
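The split prefers the structured Thinking field when thinking is on and otherwise falls back to inline <think> tags; a quick illustration:
	r, c := splitQwen35ReasoningContent("<think>plan</think>\nAnswer.", "", true)
	// r == "plan", c == "Answer."
	r, c = splitQwen35ReasoningContent("Answer.", "plan", true)
	// r == "plan", c == "Answer." (structured field wins; content passes through)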
func (r *Qwen35Renderer) Render(messages []api.Message, tools []api.Tool, think *api.ThinkValue) (string, error) {
var sb strings.Builder
isThinking := r.isThinking
if think != nil {
isThinking = think.Bool()
}
if len(tools) > 0 {
sb.WriteString(imStartTag + "system\n")
sb.WriteString("# Tools\n\nYou have access to the following functions:\n\n<tools>")
for _, tool := range tools {
sb.WriteString("\n")
if b, err := marshalWithSpaces(tool); err == nil {
sb.Write(b)
}
}
sb.WriteString(qwen35ToolPostamble)
if len(messages) > 0 && messages[0].Role == "system" {
systemContent, _ := r.renderContent(messages[0], 0)
systemContent = strings.TrimSpace(systemContent)
if systemContent != "" {
sb.WriteString("\n\n")
sb.WriteString(systemContent)
}
}
sb.WriteString(imEndTag + "\n")
} else if len(messages) > 0 && messages[0].Role == "system" {
systemContent, _ := r.renderContent(messages[0], 0)
sb.WriteString(imStartTag + "system\n" + strings.TrimSpace(systemContent) + imEndTag + "\n")
}
multiStepTool := true
lastQueryIndex := len(messages) - 1 // index of the last real user query (not a tool response); refined by the loop below
for i := len(messages) - 1; i >= 0; i-- {
message := messages[i]
if multiStepTool && message.Role == "user" {
content, _ := r.renderContent(message, 0)
content = strings.TrimSpace(content)
if !(strings.HasPrefix(content, "<tool_response>") && strings.HasSuffix(content, "</tool_response>")) {
multiStepTool = false
lastQueryIndex = i
}
}
}
imageOffset := 0
for i, message := range messages {
content, nextImageOffset := r.renderContent(message, imageOffset)
imageOffset = nextImageOffset
content = strings.TrimSpace(content)
lastMessage := i == len(messages)-1
prefill := lastMessage && message.Role == "assistant"
if message.Role == "user" || (message.Role == "system" && i != 0) {
sb.WriteString(imStartTag + message.Role + "\n" + content + imEndTag + "\n")
} else if message.Role == "assistant" {
contentReasoning, content := splitQwen35ReasoningContent(content, message.Thinking, isThinking)
if isThinking && i > lastQueryIndex {
sb.WriteString(imStartTag + message.Role + "\n<think>\n" + contentReasoning + "\n</think>\n\n" + content)
} else {
sb.WriteString(imStartTag + message.Role + "\n" + content)
}
if len(message.ToolCalls) > 0 {
for j, toolCall := range message.ToolCalls {
if j == 0 {
if strings.TrimSpace(content) != "" {
sb.WriteString("\n\n")
}
} else {
sb.WriteString("\n")
}
sb.WriteString("<tool_call>\n<function=" + toolCall.Function.Name + ">\n")
for name, value := range toolCall.Function.Arguments.All() {
sb.WriteString("<parameter=" + name + ">\n")
sb.WriteString(formatToolCallArgument(value))
sb.WriteString("\n</parameter>\n")
}
sb.WriteString("</function>\n</tool_call>")
}
}
if !prefill {
sb.WriteString(imEndTag + "\n")
}
} else if message.Role == "tool" {
if i == 0 || messages[i-1].Role != "tool" {
sb.WriteString(imStartTag + "user")
}
sb.WriteString("\n<tool_response>\n" + content + "\n</tool_response>")
if i == len(messages)-1 || messages[i+1].Role != "tool" {
sb.WriteString(imEndTag + "\n")
}
}
// prefill at the end
if lastMessage && !prefill {
sb.WriteString(imStartTag + "assistant\n")
if isThinking {
sb.WriteString("<think>\n")
} else if r.emitEmptyThinkOnNoThink {
sb.WriteString("<think>\n\n</think>\n\n")
}
}
}
return sb.String(), nil
}

View File

@@ -0,0 +1,389 @@
package renderers
import (
"strings"
"testing"
"github.com/ollama/ollama/api"
)
func TestQwen35RendererUsesXMLToolCallingFormat(t *testing.T) {
renderer := &Qwen35Renderer{isThinking: true}
msgs := []api.Message{
{Role: "system", Content: "You are a helpful assistant."},
{Role: "user", Content: "What's the weather in Paris?"},
{
Role: "assistant",
Content: "I'll check.",
ToolCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgsOrdered([]orderedArg{
{Key: "location", Value: "Paris"},
}),
},
},
},
},
{Role: "tool", Content: "22C"},
{Role: "user", Content: "Thanks"},
}
tools := []api.Tool{
{
Type: "function",
Function: api.ToolFunction{
Name: "get_weather",
Parameters: api.ToolFunctionParameters{
Type: "object",
Properties: testPropsOrdered([]orderedProp{
{
Key: "location",
Value: api.ToolProperty{
Type: api.PropertyType{"string"},
},
},
}),
Required: []string{"location"},
},
},
},
}
got, err := renderer.Render(msgs, tools, nil)
if err != nil {
t.Fatalf("render failed: %v", err)
}
if !strings.Contains(got, "<tools>") {
t.Fatalf("expected tools section in prompt, got:\n%s", got)
}
if !strings.Contains(got, "<function=example_function_name>") {
t.Fatalf("expected xml-style tool call instructions, got:\n%s", got)
}
wantToolCall := "<tool_call>\n<function=get_weather>\n<parameter=location>\nParis\n</parameter>\n</function>\n</tool_call>"
if !strings.Contains(got, wantToolCall) {
t.Fatalf("expected xml tool call payload, got:\n%s", got)
}
toolsIdx := strings.Index(got, "# Tools")
systemIdx := strings.Index(got, "You are a helpful assistant.")
if toolsIdx == -1 || systemIdx == -1 || systemIdx < toolsIdx {
t.Fatalf("expected system prompt appended after tool instructions, got:\n%s", got)
}
}
func TestQwen35RendererNoThinkPrefill(t *testing.T) {
renderer := &Qwen35Renderer{isThinking: true, emitEmptyThinkOnNoThink: true}
msgs := []api.Message{
{Role: "user", Content: "hello"},
}
got, err := renderer.Render(msgs, nil, &api.ThinkValue{Value: false})
if err != nil {
t.Fatalf("render failed: %v", err)
}
if !strings.HasSuffix(got, "<|im_start|>assistant\n<think>\n\n</think>\n\n") {
t.Fatalf("expected explicit no-think prefill, got:\n%s", got)
}
}
func TestQwen35RendererBackToBackToolCallsAndResponses(t *testing.T) {
renderer := &Qwen35Renderer{isThinking: true}
msgs := []api.Message{
{Role: "system", Content: "You are a helpful assistant."},
{Role: "user", Content: "Run add and multiply."},
{
Role: "assistant",
Content: "I'll run both now.",
Thinking: "Need to call add and multiply.",
ToolCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "add",
Arguments: testArgsOrdered([]orderedArg{
{Key: "a", Value: 2},
{Key: "b", Value: 3},
}),
},
},
{
Function: api.ToolCallFunction{
Name: "multiply",
Arguments: testArgsOrdered([]orderedArg{
{Key: "x", Value: 4},
{Key: "y", Value: 5},
}),
},
},
},
},
{Role: "tool", Content: "5"},
{Role: "tool", Content: "20"},
{Role: "user", Content: "Summarize the results."},
}
got, err := renderer.Render(msgs, qwen35MathTools(), nil)
if err != nil {
t.Fatalf("render failed: %v", err)
}
if strings.Contains(got, "Need to call add and multiply.") {
t.Fatalf("did not expect historical reasoning block in this sequence, got:\n%s", got)
}
wantToolCalls := `<tool_call>
<function=add>
<parameter=a>
2
</parameter>
<parameter=b>
3
</parameter>
</function>
</tool_call>
<tool_call>
<function=multiply>
<parameter=x>
4
</parameter>
<parameter=y>
5
</parameter>
</function>
</tool_call>`
if !strings.Contains(got, wantToolCalls) {
t.Fatalf("expected back-to-back tool calls, got:\n%s", got)
}
wantToolResponses := `<|im_start|>user
<tool_response>
5
</tool_response>
<tool_response>
20
</tool_response><|im_end|>`
if !strings.Contains(got, wantToolResponses) {
t.Fatalf("expected grouped back-to-back tool responses, got:\n%s", got)
}
if !strings.HasSuffix(got, "<|im_start|>assistant\n<think>\n") {
t.Fatalf("expected assistant thinking prefill at end, got:\n%s", got)
}
}
func TestQwen35RendererInterleavedThinkingAndTools(t *testing.T) {
renderer := &Qwen35Renderer{isThinking: true}
msgs := []api.Message{
{Role: "system", Content: "You are a helpful assistant."},
{Role: "user", Content: "Plan a picnic in Paris."},
{
Role: "assistant",
Content: "Checking weather first.",
Thinking: "Need weather before giving advice.",
ToolCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgsOrdered([]orderedArg{
{Key: "location", Value: "Paris"},
}),
},
},
},
},
{Role: "tool", Content: "22C"},
{
Role: "assistant",
Content: "Checking UV too.",
Thinking: "Need UV index for sunscreen advice.",
ToolCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_uv",
Arguments: testArgsOrdered([]orderedArg{
{Key: "location", Value: "Paris"},
}),
},
},
},
},
{Role: "tool", Content: "5"},
}
got, err := renderer.Render(msgs, qwen35WeatherUVTools(), nil)
if err != nil {
t.Fatalf("render failed: %v", err)
}
wantFirstTurn := `<|im_start|>assistant
<think>
Need weather before giving advice.
</think>
Checking weather first.
<tool_call>
<function=get_weather>
<parameter=location>
Paris
</parameter>
</function>
</tool_call><|im_end|>`
if !strings.Contains(got, wantFirstTurn) {
t.Fatalf("expected first assistant thinking/tool sequence, got:\n%s", got)
}
wantSecondTurn := `<|im_start|>assistant
<think>
Need UV index for sunscreen advice.
</think>
Checking UV too.
<tool_call>
<function=get_uv>
<parameter=location>
Paris
</parameter>
</function>
</tool_call><|im_end|>`
if !strings.Contains(got, wantSecondTurn) {
t.Fatalf("expected second assistant thinking/tool sequence, got:\n%s", got)
}
if !strings.HasSuffix(got, "<|im_start|>assistant\n<think>\n") {
t.Fatalf("expected assistant thinking prefill at end, got:\n%s", got)
}
}
func TestQwen35RendererAssistantPrefillWithThinking(t *testing.T) {
renderer := &Qwen35Renderer{isThinking: true}
msgs := []api.Message{
{Role: "user", Content: "Write two words."},
{
Role: "assistant",
Thinking: "Keep it short.",
Content: "Hello world",
},
}
got, err := renderer.Render(msgs, nil, nil)
if err != nil {
t.Fatalf("render failed: %v", err)
}
want := `<|im_start|>user
Write two words.<|im_end|>
<|im_start|>assistant
<think>
Keep it short.
</think>
Hello world`
if got != want {
t.Fatalf("unexpected prefill output\n--- got ---\n%s\n--- want ---\n%s", got, want)
}
}
func qwen35MathTools() []api.Tool {
return []api.Tool{
{
Type: "function",
Function: api.ToolFunction{
Name: "add",
Description: "Add two numbers",
Parameters: api.ToolFunctionParameters{
Type: "object",
Properties: testPropsOrdered([]orderedProp{
{
Key: "a",
Value: api.ToolProperty{
Type: api.PropertyType{"integer"},
},
},
{
Key: "b",
Value: api.ToolProperty{
Type: api.PropertyType{"integer"},
},
},
}),
Required: []string{"a", "b"},
},
},
},
{
Type: "function",
Function: api.ToolFunction{
Name: "multiply",
Description: "Multiply two numbers",
Parameters: api.ToolFunctionParameters{
Type: "object",
Properties: testPropsOrdered([]orderedProp{
{
Key: "x",
Value: api.ToolProperty{
Type: api.PropertyType{"integer"},
},
},
{
Key: "y",
Value: api.ToolProperty{
Type: api.PropertyType{"integer"},
},
},
}),
Required: []string{"x", "y"},
},
},
},
}
}
func qwen35WeatherUVTools() []api.Tool {
return []api.Tool{
{
Type: "function",
Function: api.ToolFunction{
Name: "get_weather",
Description: "Get weather for a location",
Parameters: api.ToolFunctionParameters{
Type: "object",
Properties: testPropsOrdered([]orderedProp{
{
Key: "location",
Value: api.ToolProperty{
Type: api.PropertyType{"string"},
},
},
}),
Required: []string{"location"},
},
},
},
{
Type: "function",
Function: api.ToolFunction{
Name: "get_uv",
Description: "Get UV index for a location",
Parameters: api.ToolFunctionParameters{
Type: "object",
Properties: testPropsOrdered([]orderedProp{
{
Key: "location",
Value: api.ToolProperty{
Type: api.PropertyType{"string"},
},
},
}),
Required: []string{"location"},
},
},
},
}
}

View File

@@ -57,7 +57,7 @@ func rendererForName(name string) Renderer {
renderer := &Qwen3VLRenderer{isThinking: true, useImgTags: RenderImgTags}
return renderer
case "qwen3.5":
renderer := &Qwen3VLRenderer{isThinking: true, emitEmptyThinkOnNoThink: true, useImgTags: RenderImgTags}
renderer := &Qwen35Renderer{isThinking: true, emitEmptyThinkOnNoThink: true, useImgTags: RenderImgTags}
return renderer
case "cogito":
renderer := &CogitoRenderer{isThinking: true}
@@ -86,7 +86,7 @@ func rendererForName(name string) Renderer {
case "glm-4.7":
return &GLM47Renderer{}
case "glm-ocr":
return &GlmOcrRenderer{}
return &GlmOcrRenderer{useImgTags: RenderImgTags}
case "lfm2":
return &LFM2Renderer{IsThinking: false, useImgTags: RenderImgTags}
case "lfm2-thinking":

View File

@@ -181,6 +181,9 @@ func fileDigestMap(path string) (map[string]string, error) {
}
if !filepath.IsLocal(rel) {
if strings.Contains(rel, ".cache") {
return nil, fmt.Errorf("insecure path: %s\n\nUse --local-dir <dir> when downloading model to disable caching", rel)
}
return nil, fmt.Errorf("insecure path: %s", rel)
}

View File

@@ -562,6 +562,7 @@ func (s *Server) forwardBatch(pendingBatch batchState) (nextBatch batchState, er
if errors.As(err, &reprocess) {
// Prepend these inputs to the sequence's inputs queue for reprocessing
seq.inputs = append(reprocess.Inputs, seq.inputs...)
seq.sampler.Reset()
// Skip this sequence but continue processing the rest
nextBatch.seqs[seqIdx] = nil // clear this sequence for this batch
err = nil
@@ -692,6 +693,12 @@ func (s *Server) computeBatch(activeBatch batchState) {
// (unless we take down the whole runner).
if len(seq.pendingInputs) > 0 {
seq.cache.Inputs = append(seq.cache.Inputs, seq.pendingInputs...)
for _, inp := range seq.pendingInputs {
if len(inp.Multimodal) != 0 {
continue
}
seq.sampler.Accept(inp.Token)
}
seq.pendingInputs = []*input.Input{}
}
@@ -892,6 +899,9 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
req.Options.TopK,
req.Options.TopP,
req.Options.MinP,
req.Options.RepeatPenalty,
req.Options.PresencePenalty,
req.Options.FrequencyPenalty,
req.Options.Seed,
grammar,
)
@@ -938,6 +948,14 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
return
}
seq.sampler.Reset()
for _, inp := range seq.cache.Inputs {
if len(inp.Multimodal) != 0 {
continue
}
seq.sampler.Accept(inp.Token)
}
s.seqs[i] = seq
s.cond.Signal()
found = true
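This reseed pattern (Reset, then replay cached token ids while skipping multimodal inputs) now appears both here and in computeBatch above. A hypothetical helper, not part of this change, could capture it:
	// reseedSampler rebuilds the penalty history from inputs already in the
	// cache; multimodal inputs carry no token id, so they are skipped.
	// (hypothetical refactor; assumes the sample and input packages used above)
	func reseedSampler(s *sample.Sampler, inputs []*input.Input) {
		s.Reset()
		for _, inp := range inputs {
			if len(inp.Multimodal) != 0 {
				continue
			}
			s.Accept(inp.Token)
		}
	}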

View File

@@ -16,24 +16,49 @@ type token struct {
value float32 // The raw logit or probability from the model
}
const DefaultPenaltyLookback = 64
type Sampler struct {
rng *rand.Rand
topK int
topP float32
minP float32
temperature float32
repeat float32
presence float32
frequency float32
history []int32
grammar *GrammarSampler
}
func (s *Sampler) Reset() {
s.history = s.history[:0]
}
func (s *Sampler) Accept(token int32) {
s.history = append(s.history, token)
if len(s.history) > DefaultPenaltyLookback {
copy(s.history, s.history[len(s.history)-DefaultPenaltyLookback:])
s.history = s.history[:DefaultPenaltyLookback]
}
}
func (s *Sampler) Sample(logits []float32) (int32, error) {
if len(logits) == 0 {
return -1, errors.New("sample: no logits provided to sample")
}
counts := tokenCounts(s.history, len(logits))
tokens := make([]token, len(logits))
for i := range logits {
value := logits[i]
if count := counts[int32(i)]; count > 0 {
value = applyPenalty(value, count, s.repeat, s.presence, s.frequency)
}
tokens[i].id = int32(i)
tokens[i].value = logits[i]
tokens[i].value = value
}
t, err := s.sample(tokens)
@@ -55,8 +80,12 @@ func (s *Sampler) Sample(logits []float32) (int32, error) {
// we need to reset them before applying the grammar and
// sampling again
for i := range logits {
value := logits[i]
if count := counts[int32(i)]; count > 0 {
value = applyPenalty(value, count, s.repeat, s.presence, s.frequency)
}
tokens[i].id = int32(i)
tokens[i].value = logits[i]
tokens[i].value = value
}
s.grammar.Apply(tokens)
t, err = s.sample(tokens)
@@ -127,7 +156,7 @@ func (s *Sampler) sample(tokens []token) (token, error) {
}
// TODO(parthsareen): update sampler interface to use json unmarshal https://github.com/ollama/ollama/issues/9278
func NewSampler(temperature float32, topK int, topP float32, minP float32, seed int, grammar *GrammarSampler) Sampler {
func NewSampler(temperature float32, topK int, topP float32, minP float32, repeatPenalty float32, presencePenalty float32, frequencyPenalty float32, seed int, grammar *GrammarSampler) Sampler {
var rng *rand.Rand
if seed != -1 {
// PCG requires two parameters: sequence and stream
@@ -154,12 +183,19 @@ func NewSampler(temperature float32, topK int, topP float32, minP float32, seed
minP = 1.0
}
if repeatPenalty <= 0 {
repeatPenalty = 1.0
}
return Sampler{
rng: rng,
topK: topK,
topP: topP,
minP: minP,
temperature: temperature,
repeat: repeatPenalty,
presence: presencePenalty,
frequency: frequencyPenalty,
grammar: grammar,
}
}

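Taken together, a minimal usage sketch of the new stateful sampler (argument values are illustrative; the order matches the NewSampler signature above, and promptTokens/logits are assumed inputs):

s := NewSampler(0.8, 40, 0.9, 0.05, 1.1, 0, 0, 42, nil)
for _, t := range promptTokens {
	s.Accept(t) // seed penalty history from the prompt
}
tok, err := s.Sample(logits) // one decode step's logits
if err != nil {
	return err
}
s.Accept(tok) // record the generated token for the next step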
View File

@@ -16,7 +16,7 @@ func BenchmarkWeightedSampler(b *testing.B) {
logits[i] = float32(rand.Float64()*10 - 5)
}
sampler := NewSampler(0.8, 0, 0, 0, 42, nil)
sampler := NewSampler(0.8, 0, 0, 0, 1, 0, 0, 42, nil)
b.ResetTimer()
for b.Loop() {
sampler.Sample(logits)
@@ -49,7 +49,7 @@ func BenchmarkWeightedSampler(b *testing.B) {
for _, tc := range configs {
b.Run("Config"+tc.name, func(b *testing.B) {
sampler := NewSampler(tc.temperature, tc.topK, tc.topP, tc.minP, tc.seed, nil)
sampler := NewSampler(tc.temperature, tc.topK, tc.topP, tc.minP, 1, 0, 0, tc.seed, nil)
sampler.Sample(logits)
b.ResetTimer()
@@ -62,7 +62,7 @@ func BenchmarkWeightedSampler(b *testing.B) {
// Test with combined transforms separately - topK influences performance greatly
b.Run("TransformCombined", func(b *testing.B) {
sampler := NewSampler(0.8, 50, 0.9, 0.05, 42, nil)
sampler := NewSampler(0.8, 50, 0.9, 0.05, 1, 0, 0, 42, nil)
b.ResetTimer()
for b.Loop() {
@@ -81,7 +81,7 @@ func BenchmarkGreedySampler(b *testing.B) {
logits[i] = float32(rand.Float64()*10 - 5)
}
sampler := NewSampler(0, -1, 0, 0, -1, nil)
sampler := NewSampler(0, -1, 0, 0, 1, 0, 0, -1, nil)
b.ResetTimer()
for b.Loop() {

View File

@@ -13,7 +13,7 @@ import (
func TestWeighted(t *testing.T) {
logits := []float32{-10, 3, -10, -10}
sampler := NewSampler(0, 0, 0, 0, 0, nil)
sampler := NewSampler(0, 0, 0, 0, 1, 0, 0, 0, nil)
got, err := sampler.Sample(logits)
if err != nil {
t.Error(err)
@@ -25,7 +25,7 @@ func TestWeighted(t *testing.T) {
}
logits = []float32{-100, -10, 0, 10}
sampler = NewSampler(0, 0, 0, 0, 0, nil)
sampler = NewSampler(0, 0, 0, 0, 1, 0, 0, 0, nil)
got, err = sampler.Sample(logits)
if err != nil {
t.Error(err)
@@ -39,7 +39,7 @@ func TestWeighted(t *testing.T) {
// Test very high p
logits = []float32{1.0, 0.9999999999999999, 0.5, 0.1}
// Use extremely small topP to filter out all tokens
sampler = NewSampler(1.0, 0, 1e-10, 0, 0, nil)
sampler = NewSampler(1.0, 0, 1e-10, 0, 1, 0, 0, 0, nil)
got, err = sampler.Sample(logits)
if err != nil {
t.Error(err)
@@ -52,7 +52,7 @@ func TestWeighted(t *testing.T) {
}
logits = []float32{float32(math.NaN()), float32(math.NaN()), float32(math.NaN())}
sampler = NewSampler(1, 0, 0.95, 0.05, 0, nil)
sampler = NewSampler(1, 0, 0.95, 0.05, 1, 0, 0, 0, nil)
got, err = sampler.Sample(logits)
if err == nil {
t.Errorf("expected error, got %d", got)
@@ -151,8 +151,8 @@ func TestGrammar(t *testing.T) {
func BenchmarkSample(b *testing.B) {
samplers := map[string]Sampler{
"Greedy": NewSampler(0, 0, 0, 0, 0, nil), // Use NewSampler with temp=0 for greedy
"Weighted": NewSampler(0.5, 10, 0.9, 0.2, -1, nil),
"Greedy": NewSampler(0, 0, 0, 0, 1, 0, 0, 0, nil), // Use NewSampler with temp=0 for greedy
"Weighted": NewSampler(0.5, 10, 0.9, 0.2, 1, 0, 0, -1, nil),
}
// Generate random logits for benchmarking

View File

@@ -25,6 +25,48 @@ func (h *tokenHeap) Pop() any {
return x
}
func tokenCounts(history []int32, vocabSize int) map[int32]int {
if len(history) == 0 {
return nil
}
start := 0
if len(history) > DefaultPenaltyLookback {
start = len(history) - DefaultPenaltyLookback
}
counts := make(map[int32]int, len(history)-start)
for _, token := range history[start:] {
if token < 0 || int(token) >= vocabSize {
continue
}
counts[token]++
}
return counts
}
func applyPenalty(logit float32, count int, repeatPenalty float32, presencePenalty float32, frequencyPenalty float32) float32 {
if repeatPenalty != 1.0 {
// Preserve ordering for negative logits when applying repeat penalty.
if logit < 0 {
logit *= repeatPenalty
} else {
logit /= repeatPenalty
}
}
if frequencyPenalty != 0 {
logit -= float32(count) * frequencyPenalty
}
if presencePenalty != 0 {
logit -= presencePenalty
}
return logit
}
// temperature applies scaling to the logits
func temperature(ts []token, temp float32) {
// Ensure temperature clipping near 0 to avoid numerical instability

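A worked example of the penalty arithmetic above (the same cases are asserted in TestApplyPenalty below):

// applyPenalty(5.0, count=3, repeat=1.0, presence=1.5, frequency=0.5)
//   repeat == 1.0  -> no scaling, logit stays 5.0
//   frequency      -> 5.0 - 3*0.5 = 3.5
//   presence       -> 3.5 - 1.5   = 2.0
// applyPenalty(-4.0, count=1, repeat=2.0, presence=0, frequency=0)
//   logit < 0 -> multiply: -4.0 * 2.0 = -8.0 (ordering preserved)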
View File

@@ -295,6 +295,86 @@ func TestMinP(t *testing.T) {
}
}
func TestTokenCounts(t *testing.T) {
history := make([]int32, 70)
history[0] = 7
history[69] = 7
counts := tokenCounts(history, 8)
if got := counts[7]; got != 1 {
t.Fatalf("lookback mismatch: got %d want %d", got, 1)
}
}
func TestApplyPenalty(t *testing.T) {
logit := applyPenalty(5.0, 3, 1.0, 1.5, 0.5)
if math.Abs(float64(logit-2.0)) > 1e-6 {
t.Fatalf("unexpected penalty result: got %f want %f", logit, 2.0)
}
logit = applyPenalty(4.0, 1, 2.0, 0, 0)
if math.Abs(float64(logit-2.0)) > 1e-6 {
t.Fatalf("unexpected repeat penalty result for positive logits: got %f want %f", logit, 2.0)
}
logit = applyPenalty(-4.0, 1, 2.0, 0, 0)
if math.Abs(float64(logit-(-8.0))) > 1e-6 {
t.Fatalf("unexpected repeat penalty result for negative logits: got %f want %f", logit, -8.0)
}
}
func TestSamplerPresencePenalty(t *testing.T) {
logits := []float32{0.0, 5.0, 0.0}
baseline := NewSampler(0, 0, 1, 0, 1, 0, 0, -1, nil)
baseline.Accept(1)
got, err := baseline.Sample(logits)
if err != nil {
t.Fatal(err)
}
if got != 1 {
t.Fatalf("unexpected baseline token: got %d want %d", got, 1)
}
presence := NewSampler(0, 0, 1, 0, 1, 6, 0, -1, nil)
presence.Accept(1)
got, err = presence.Sample(logits)
if err != nil {
t.Fatal(err)
}
if got == 1 {
t.Fatalf("presence penalty did not change repeated token selection")
}
}
func TestSamplerFrequencyPenalty(t *testing.T) {
logits := []float32{0.0, 5.0, 4.0}
baseline := NewSampler(0, 0, 1, 0, 1, 0, 0, -1, nil)
baseline.Accept(1)
baseline.Accept(1)
baseline.Accept(1)
got, err := baseline.Sample(logits)
if err != nil {
t.Fatal(err)
}
if got != 1 {
t.Fatalf("unexpected baseline token: got %d want %d", got, 1)
}
frequency := NewSampler(0, 0, 1, 0, 1, 0, 1.0, -1, nil)
frequency.Accept(1)
frequency.Accept(1)
frequency.Accept(1)
got, err = frequency.Sample(logits)
if err != nil {
t.Fatal(err)
}
if got != 2 {
t.Fatalf("frequency penalty did not demote repeated token as expected: got %d want %d", got, 2)
}
}
func BenchmarkTransforms(b *testing.B) {
// Generate random logits
tokens := make([]token, 1<<16)

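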
View File

@@ -59,7 +59,7 @@ _build_darwin() {
cmake --install $BUILD_DIR --component CPU
cmake --install $BUILD_DIR --component MLX
# Override CGO flags to point to the amd64 build directory
MLX_CGO_CFLAGS="-O3 -I$(pwd)/$BUILD_DIR/_deps/mlx-c-src -mmacosx-version-min=14.0"
MLX_CGO_CFLAGS="-O3 -mmacosx-version-min=14.0"
MLX_CGO_LDFLAGS="-ldl -lc++ -framework Accelerate -mmacosx-version-min=14.0"
else
BUILD_DIR=build
@@ -70,10 +70,10 @@ _build_darwin() {
cmake --build --preset MLX --parallel
cmake --install $BUILD_DIR --component MLX
# Use default CGO flags from mlx.go for arm64
MLX_CGO_CFLAGS="-O3 -I$(pwd)/$BUILD_DIR/_deps/mlx-c-src -mmacosx-version-min=14.0"
MLX_CGO_CFLAGS="-O3 -mmacosx-version-min=14.0"
MLX_CGO_LDFLAGS="-lc++ -framework Metal -framework Foundation -framework Accelerate -mmacosx-version-min=14.0"
fi
GOOS=darwin GOARCH=$ARCH CGO_ENABLED=1 CGO_CFLAGS="$MLX_CGO_CFLAGS" CGO_LDFLAGS="$MLX_CGO_LDFLAGS" go build -tags mlx -o $INSTALL_PREFIX .
GOOS=darwin GOARCH=$ARCH CGO_ENABLED=1 CGO_CFLAGS="$MLX_CGO_CFLAGS" CGO_LDFLAGS="$MLX_CGO_LDFLAGS" go build -o $INSTALL_PREFIX .
# Copy MLX libraries to same directory as executable for dlopen
cp $INSTALL_PREFIX/lib/ollama/libmlxc.dylib $INSTALL_PREFIX/
cp $INSTALL_PREFIX/lib/ollama/libmlx.dylib $INSTALL_PREFIX/

View File

@@ -4,7 +4,10 @@
#
# gcloud auth application-default login
$ErrorActionPreference = "Stop"
# Use "Continue" so that stderr output from native commands (e.g. CGo warnings)
# is not promoted to a terminating exception by the try/catch block.
# All native commands already check $LASTEXITCODE explicitly.
$ErrorActionPreference = "Continue"
mkdir -Force -path .\dist | Out-Null
@@ -16,13 +19,13 @@ function checkEnv {
if ($null -ne $arch) {
$script:ARCH = ($arch.ToString().ToLower()).Replace("x64", "amd64")
} else {
write-host "WARNING: old powershell detected, assuming amd64 architecture - set `$env:ARCH to override"
Write-Output "WARNING: old powershell detected, assuming amd64 architecture - set `$env:ARCH to override"
$script:ARCH="amd64"
}
}
$script:TARGET_ARCH=$script:ARCH
Write-host "Building for ${script:TARGET_ARCH}"
write-host "Locating required tools and paths"
Write-Output "Locating required tools and paths"
$script:SRC_DIR=$PWD
# Locate CUDA versions
@@ -37,16 +40,17 @@ function checkEnv {
$script:CUDA_DIRS=($cudaList | sort-object -Descending)
}
if ($script:CUDA_DIRS.length -gt 0) {
write-host "Available CUDA Versions: $script:CUDA_DIRS"
Write-Output "Available CUDA Versions: $script:CUDA_DIRS"
} else {
write-host "No CUDA versions detected"
Write-Output "No CUDA versions detected"
}
# Locate ROCm version
if ($null -ne $env:HIP_PATH) {
# Locate ROCm v6
$rocmDir=(get-item "C:\Program Files\AMD\ROCm\6.*" -ea 'silentlycontinue' | sort-object -Descending | select-object -First 1)
if ($null -ne $rocmDir) {
$script:HIP_PATH=$rocmDir.FullName
} elseif ($null -ne $env:HIP_PATH -and $env:HIP_PATH -match '[/\\]6\.') {
$script:HIP_PATH=$env:HIP_PATH
} else {
$script:HIP_PATH=(get-item "C:\Program Files\AMD\ROCm\*\bin\" -ea 'silentlycontinue' | sort-object -Descending)
}
$inoSetup=(get-item "C:\Program Files*\Inno Setup*\")
@@ -78,7 +82,7 @@ function checkEnv {
} else {
$script:PKG_VERSION="0.0.0"
}
write-host "Building Ollama $script:VERSION with package version $script:PKG_VERSION"
Write-Output "Building Ollama $script:VERSION with package version $script:PKG_VERSION"
# Note: Windows Kits 10 signtool crashes with GCP's plugin
if ($null -eq $env:SIGN_TOOL) {
@@ -87,12 +91,32 @@ function checkEnv {
${script:SignTool}=${env:SIGN_TOOL}
}
if ("${env:KEY_CONTAINER}") {
${script:OLLAMA_CERT}=$(resolve-path "${script:SRC_DIR}\ollama_inc.crt")
Write-host "Code signing enabled"
if (Test-Path "${script:SRC_DIR}\ollama_inc.crt") {
${script:OLLAMA_CERT}=$(resolve-path "${script:SRC_DIR}\ollama_inc.crt")
Write-host "Code signing enabled"
} else {
Write-Output "WARNING: KEY_CONTAINER is set but ollama_inc.crt not found at ${script:SRC_DIR}\ollama_inc.crt - code signing disabled"
}
} else {
write-host "Code signing disabled - please set KEY_CONTAINERS to sign and copy ollama_inc.crt to the top of the source tree"
Write-Output "Code signing disabled - please set KEY_CONTAINERS to sign and copy ollama_inc.crt to the top of the source tree"
}
$script:JOBS=([Environment]::ProcessorCount)
if ($env:OLLAMA_BUILD_PARALLEL) {
$script:JOBS=[int]$env:OLLAMA_BUILD_PARALLEL
} else {
# Use physical core count rather than logical processors (hyperthreads)
# to avoid saturating the system during builds
try {
$cores = (Get-CimInstance Win32_Processor | Measure-Object -Property NumberOfCores -Sum).Sum
} catch {
$cores = 0
}
if ($cores -gt 0) {
$script:JOBS = $cores
} else {
$script:JOBS = [Environment]::ProcessorCount
}
}
Write-Output "Build parallelism: $script:JOBS (set OLLAMA_BUILD_PARALLEL to override)"
}
@@ -127,7 +151,7 @@ function cuda11 {
}
}
}
write-host "Building CUDA v$cudaMajorVer backend libraries $cuda"
Write-Output "Building CUDA v$cudaMajorVer backend libraries $cuda"
$env:CUDAToolkit_ROOT=$cuda
& cmake -B build\cuda_v$cudaMajorVer --preset "CUDA $cudaMajorVer" -T cuda="$cuda" -DCMAKE_CUDA_COMPILER="$cuda\bin\nvcc.exe" -G "Visual Studio 16 2019" --install-prefix "$script:DIST_DIR"
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
@@ -136,12 +160,12 @@ function cuda11 {
& cmake --install build\cuda_v$cudaMajorVer --component "CUDA" --strip
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
} else {
write-host "CUDA v$cudaMajorVer not detected, skipping"
Write-Output "CUDA v$cudaMajorVer not detected, skipping"
}
} else {
write-host "not arch we wanted"
Write-Output "not arch we wanted"
}
write-host "done"
Write-Output "done"
}
function cudaCommon {
@@ -159,7 +183,7 @@ function cudaCommon {
}
}
}
write-host "Building CUDA v$cudaMajorVer backend libraries $cuda"
Write-Output "Building CUDA v$cudaMajorVer backend libraries $cuda"
$env:CUDAToolkit_ROOT=$cuda
& cmake -B build\cuda_v$cudaMajorVer --preset "CUDA $cudaMajorVer" -T cuda="$cuda" --install-prefix "$script:DIST_DIR"
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
@@ -168,7 +192,7 @@ function cudaCommon {
& cmake --install build\cuda_v$cudaMajorVer --component "CUDA" --strip
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
} else {
write-host "CUDA v$cudaMajorVer not detected, skipping"
Write-Output "CUDA v$cudaMajorVer not detected, skipping"
}
}
}
@@ -181,11 +205,11 @@ function cuda13 {
cudaCommon("13")
}
function rocm {
function rocm6 {
mkdir -Force -path "${script:DIST_DIR}\" | Out-Null
if ($script:ARCH -ne "arm64") {
if ($script:HIP_PATH) {
write-host "Building ROCm backend libraries $script:HIP_PATH"
Write-Output "Building ROCm backend libraries $script:HIP_PATH"
if (-Not (get-command -ErrorAction silent ninja)) {
$NINJA_DIR=(gci -path (Get-CimInstance MSFT_VSInstance -Namespace root/cimv2/vs)[0].InstallLocation -r -fi ninja.exe).Directory.FullName
$env:PATH="$NINJA_DIR;$env:PATH"
@@ -193,9 +217,11 @@ function rocm {
$env:HIPCXX="${script:HIP_PATH}\bin\clang++.exe"
$env:HIP_PLATFORM="amd"
$env:CMAKE_PREFIX_PATH="${script:HIP_PATH}"
# Set CC/CXX via environment instead of -D flags to avoid triggering
# spurious compiler-change reconfigures that reset CMAKE_INSTALL_PREFIX
$env:CC="${script:HIP_PATH}\bin\clang.exe"
$env:CXX="${script:HIP_PATH}\bin\clang++.exe"
& cmake -B build\rocm --preset "ROCm 6" -G Ninja `
-DCMAKE_C_COMPILER=clang `
-DCMAKE_CXX_COMPILER=clang++ `
-DCMAKE_C_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma" `
-DCMAKE_CXX_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma" `
--install-prefix $script:DIST_DIR
@@ -203,20 +229,22 @@ function rocm {
$env:HIPCXX=""
$env:HIP_PLATFORM=""
$env:CMAKE_PREFIX_PATH=""
$env:CC=""
$env:CXX=""
& cmake --build build\rocm --target ggml-hip --config Release --parallel $script:JOBS
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
& cmake --install build\rocm --component "HIP" --strip
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
Remove-Item -Path $script:DIST_DIR\lib\ollama\rocm\rocblas\library\*gfx906* -ErrorAction SilentlyContinue
} else {
write-host "ROCm not detected, skipping"
Write-Output "ROCm not detected, skipping"
}
}
}
function vulkan {
if ($env:VULKAN_SDK) {
write-host "Building Vulkan backend libraries"
Write-Output "Building Vulkan backend libraries"
& cmake -B build\vulkan --preset Vulkan --install-prefix $script:DIST_DIR
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
& cmake --build build\vulkan --target ggml-vulkan --config Release --parallel $script:JOBS
@@ -224,33 +252,91 @@ function vulkan {
& cmake --install build\vulkan --component Vulkan --strip
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
} else {
write-host "Vulkan not detected, skipping"
Write-Output "Vulkan not detected, skipping"
}
}
function mlxCuda13 {
mkdir -Force -path "${script:DIST_DIR}\" | Out-Null
$cudaMajorVer="13"
if ($script:ARCH -ne "arm64") {
if ("$script:CUDA_DIRS".Contains("v$cudaMajorVer")) {
foreach ($d in $Script:CUDA_DIRS){
if ($d.FullName.Contains("v$cudaMajorVer")) {
if (test-path -literalpath (join-path -path $d -childpath "nvcc.exe" ) ) {
$cuda=($d.FullName|split-path -parent)
break
}
}
}
# Check for cuDNN - required for MLX CUDA backend
# Supports two layouts:
# 1. CI/zip extract: CUDNN\include\cudnn.h, lib\x64\, bin\x64\
# 2. Official installer: CUDNN\v*\include\{cuda-ver}\cudnn.h, lib\{cuda-ver}\x64\, bin\{cuda-ver}\
if ($env:CUDNN_INCLUDE_PATH -and $env:CUDNN_LIBRARY_PATH) {
Write-Output "Using cuDNN from environment: $env:CUDNN_INCLUDE_PATH"
} elseif (Test-Path "C:\Program Files\NVIDIA\CUDNN\include\cudnn.h") {
# CI/zip layout (flat)
$cudnnRoot = "C:\Program Files\NVIDIA\CUDNN"
$env:CUDNN_ROOT_DIR = $cudnnRoot
$env:CUDNN_INCLUDE_PATH = "$cudnnRoot\include"
$env:CUDNN_LIBRARY_PATH = "$cudnnRoot\lib\x64"
Write-Output "Found cuDNN at $cudnnRoot (flat layout)"
} else {
# Official installer layout (versioned)
$cudnnRoot = $null
$resolved = Resolve-Path -Path "C:\Program Files\NVIDIA\CUDNN\v*" -ErrorAction SilentlyContinue | Sort-Object -Descending | Select-Object -First 1
if ($resolved -and (Test-Path "$($resolved.Path)\include\$cudaMajorVer.0\cudnn.h")) {
$cudnnRoot = $resolved.Path
$env:CUDNN_ROOT_DIR = $cudnnRoot
$env:CUDNN_INCLUDE_PATH = "$cudnnRoot\include\$cudaMajorVer.0"
$env:CUDNN_LIBRARY_PATH = "$cudnnRoot\lib\$cudaMajorVer.0\x64"
Write-Output "Found cuDNN at $cudnnRoot (official installer, CUDA $cudaMajorVer.0)"
} else {
Write-Output "cuDNN not found - set CUDNN_INCLUDE_PATH and CUDNN_LIBRARY_PATH environment variables"
Write-Output "Skipping MLX build"
return
}
}
Write-Output "Building MLX CUDA v$cudaMajorVer backend libraries $cuda"
$env:CUDAToolkit_ROOT=$cuda
& cmake -B build\mlx_cuda_v$cudaMajorVer --preset "MLX CUDA $cudaMajorVer" -T cuda="$cuda" --install-prefix "$script:DIST_DIR"
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
& cmake --build build\mlx_cuda_v$cudaMajorVer --target mlx --target mlxc --config Release --parallel $script:JOBS -- /nodeReuse:false
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
& cmake --install build\mlx_cuda_v$cudaMajorVer --component "MLX" --strip
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
} else {
Write-Output "CUDA v$cudaMajorVer not detected, skipping MLX build"
}
}
}
function ollama {
mkdir -Force -path "${script:DIST_DIR}\" | Out-Null
write-host "Building ollama CLI"
Write-Output "Building ollama CLI"
& go build -trimpath -ldflags "-s -w -X=github.com/ollama/ollama/version.Version=$script:VERSION -X=github.com/ollama/ollama/server.mode=release" .
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
cp .\ollama.exe "${script:DIST_DIR}\"
}
function app {
write-host "Building Ollama App $script:VERSION with package version $script:PKG_VERSION"
Write-Output "Building Ollama App $script:VERSION with package version $script:PKG_VERSION"
if (!(Get-Command npm -ErrorAction SilentlyContinue)) {
write-host "npm is not installed. Please install Node.js and npm first:"
write-host " Visit: https://nodejs.org/"
Write-Output "npm is not installed. Please install Node.js and npm first:"
Write-Output " Visit: https://nodejs.org/"
exit 1
}
if (!(Get-Command tsc -ErrorAction SilentlyContinue)) {
write-host "Installing TypeScript compiler..."
Write-Output "Installing TypeScript compiler..."
npm install -g typescript
}
if (!(Get-Command tscriptify -ErrorAction SilentlyContinue)) {
write-host "Installing tscriptify..."
Write-Output "Installing tscriptify..."
go install github.com/tkrajina/typescriptify-golang-structs/tscriptify@latest
}
if (!(Get-Command tscriptify -ErrorAction SilentlyContinue)) {
@@ -260,32 +346,32 @@ function app {
Push-Location app/ui/app
npm install
if ($LASTEXITCODE -ne 0) {
write-host "ERROR: npm install failed with exit code $LASTEXITCODE"
Write-Output "ERROR: npm install failed with exit code $LASTEXITCODE"
exit $LASTEXITCODE
}
write-host "Building React application..."
Write-Output "Building React application..."
npm run build
if ($LASTEXITCODE -ne 0) {
write-host "ERROR: npm run build failed with exit code $LASTEXITCODE"
Write-Output "ERROR: npm run build failed with exit code $LASTEXITCODE"
exit $LASTEXITCODE
}
# Check if dist directory exists and has content
if (!(Test-Path "dist")) {
write-host "ERROR: dist directory was not created by npm run build"
Write-Output "ERROR: dist directory was not created by npm run build"
exit 1
}
$distFiles = Get-ChildItem "dist" -Recurse
if ($distFiles.Count -eq 0) {
write-host "ERROR: dist directory is empty after npm run build"
Write-Output "ERROR: dist directory is empty after npm run build"
exit 1
}
Pop-Location
write-host "Running go generate"
Write-Output "Running go generate"
& go generate ./...
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
& go build -trimpath -ldflags "-s -w -H windowsgui -X=github.com/ollama/ollama/app/version.Version=$script:VERSION" -o .\dist\windows-ollama-app-${script:ARCH}.exe ./app/cmd/app/
@@ -293,42 +379,42 @@ function app {
}
function deps {
write-host "Download MSVC Redistributables"
Write-Output "Download MSVC Redistributables"
mkdir -Force -path "${script:SRC_DIR}\dist\\windows-arm64" | Out-Null
mkdir -Force -path "${script:SRC_DIR}\dist\\windows-amd64" | Out-Null
invoke-webrequest -Uri "https://aka.ms/vs/17/release/vc_redist.arm64.exe" -OutFile "${script:SRC_DIR}\dist\windows-arm64\vc_redist.arm64.exe"
invoke-webrequest -Uri "https://aka.ms/vs/17/release/vc_redist.x64.exe" -OutFile "${script:SRC_DIR}\dist\windows-amd64\vc_redist.x64.exe"
write-host "Done."
invoke-webrequest -Uri "https://aka.ms/vs/17/release/vc_redist.arm64.exe" -OutFile "${script:SRC_DIR}\dist\windows-arm64\vc_redist.arm64.exe" -ErrorAction Stop
invoke-webrequest -Uri "https://aka.ms/vs/17/release/vc_redist.x64.exe" -OutFile "${script:SRC_DIR}\dist\windows-amd64\vc_redist.x64.exe" -ErrorAction Stop
Write-Output "Done."
}
function sign {
# Copy install.ps1 to dist for release packaging
write-host "Copying install.ps1 to dist"
Copy-Item -Path "${script:SRC_DIR}\scripts\install.ps1" -Destination "${script:SRC_DIR}\dist\install.ps1"
Write-Output "Copying install.ps1 to dist"
Copy-Item -Path "${script:SRC_DIR}\scripts\install.ps1" -Destination "${script:SRC_DIR}\dist\install.ps1" -ErrorAction Stop
if ("${env:KEY_CONTAINER}") {
write-host "Signing Ollama executables, scripts and libraries"
Write-Output "Signing Ollama executables, scripts and libraries"
& "${script:SignTool}" sign /v /fd sha256 /t http://timestamp.digicert.com /f "${script:OLLAMA_CERT}" `
/csp "Google Cloud KMS Provider" /kc ${env:KEY_CONTAINER} `
$(get-childitem -path "${script:SRC_DIR}\dist\windows-*" -r -include @('*.exe', '*.dll'))
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
write-host "Signing install.ps1"
Write-Output "Signing install.ps1"
& "${script:SignTool}" sign /v /fd sha256 /t http://timestamp.digicert.com /f "${script:OLLAMA_CERT}" `
/csp "Google Cloud KMS Provider" /kc ${env:KEY_CONTAINER} `
"${script:SRC_DIR}\dist\install.ps1"
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
} else {
write-host "Signing not enabled"
Write-Output "Signing not enabled"
}
}
function installer {
if ($null -eq ${script:INNO_SETUP_DIR}) {
write-host "ERROR: missing Inno Setup installation directory - install from https://jrsoftware.org/isdl.php"
Write-Output "ERROR: missing Inno Setup installation directory - install from https://jrsoftware.org/isdl.php"
exit 1
}
write-host "Building Ollama Installer"
Write-Output "Building Ollama Installer"
cd "${script:SRC_DIR}\app"
$env:PKG_VERSION=$script:PKG_VERSION
if ("${env:KEY_CONTAINER}") {
@@ -342,24 +428,24 @@ function installer {
function zip {
if (Test-Path -Path "${script:SRC_DIR}\dist\windows-amd64") {
if (Test-Path -Path "${script:SRC_DIR}\dist\windows-amd64\lib\ollama\rocm") {
write-host "Generating stand-alone distribution zip file ${script:SRC_DIR}\dist\ollama-windows-amd64-rocm.zip"
Write-Output "Generating stand-alone distribution zip file ${script:SRC_DIR}\dist\ollama-windows-amd64-rocm.zip"
# Temporarily adjust paths so we can retain the same directory structure
Remove-Item -ea 0 -r "${script:SRC_DIR}\dist\windows-amd64-rocm"
mkdir -Force -path "${script:SRC_DIR}\dist\windows-amd64-rocm\lib\ollama"
Write-Output "Extract this ROCm zip file to the same location where you extracted ollama-windows-amd64.zip" > "${script:SRC_DIR}\dist\windows-amd64-rocm\README.txt"
Move-Item -path "${script:SRC_DIR}\dist\windows-amd64\lib\ollama\rocm" -destination "${script:SRC_DIR}\dist\windows-amd64-rocm\lib\ollama"
Move-Item -path "${script:SRC_DIR}\dist\windows-amd64\lib\ollama\rocm" -destination "${script:SRC_DIR}\dist\windows-amd64-rocm\lib\ollama" -ErrorAction Stop
Compress-Archive -CompressionLevel Optimal -Path "${script:SRC_DIR}\dist\windows-amd64-rocm\*" -DestinationPath "${script:SRC_DIR}\dist\ollama-windows-amd64-rocm.zip" -Force
}
write-host "Generating stand-alone distribution zip file ${script:SRC_DIR}\dist\ollama-windows-amd64.zip"
Write-Output "Generating stand-alone distribution zip file ${script:SRC_DIR}\dist\ollama-windows-amd64.zip"
Compress-Archive -CompressionLevel Optimal -Path "${script:SRC_DIR}\dist\windows-amd64\*" -DestinationPath "${script:SRC_DIR}\dist\ollama-windows-amd64.zip" -Force
if (Test-Path -Path "${script:SRC_DIR}\dist\windows-amd64-rocm") {
Move-Item -destination "${script:SRC_DIR}\dist\windows-amd64\lib\ollama\rocm" -path "${script:SRC_DIR}\dist\windows-amd64-rocm\lib\ollama\rocm"
Move-Item -destination "${script:SRC_DIR}\dist\windows-amd64\lib\ollama\rocm" -path "${script:SRC_DIR}\dist\windows-amd64-rocm\lib\ollama\rocm" -ErrorAction Stop
}
}
if (Test-Path -Path "${script:SRC_DIR}\dist\windows-arm64") {
write-host "Generating stand-alone distribution zip file ${script:SRC_DIR}\dist\ollama-windows-arm64.zip"
Write-Output "Generating stand-alone distribution zip file ${script:SRC_DIR}\dist\ollama-windows-arm64.zip"
Compress-Archive -CompressionLevel Optimal -Path "${script:SRC_DIR}\dist\windows-arm64\*" -DestinationPath "${script:SRC_DIR}\dist\ollama-windows-arm64.zip" -Force
}
}
@@ -375,8 +461,9 @@ try {
cpu
cuda12
cuda13
rocm
rocm6
vulkan
mlxCuda13
ollama
app
deps
@@ -385,13 +472,13 @@ try {
zip
} else {
for ( $i = 0; $i -lt $args.count; $i++ ) {
write-host "running build step $($args[$i])"
Write-Output "running build step $($args[$i])"
& $($args[$i])
}
}
} catch {
write-host "Build Failed"
write-host $_
Write-Error "Build Failed: $($_.Exception.Message)"
Write-Error "$($_.ScriptStackTrace)"
} finally {
set-location $script:SRC_DIR
$env:PKG_VERSION=""

View File

@@ -16,9 +16,16 @@ OLLAMA_COMMON_BUILD_ARGS="--build-arg=VERSION \
--build-arg=OLLAMA_FAST_BUILD \
--build-arg=CUSTOM_CPU_FLAGS \
--build-arg=GPU_RUNNER_CPU_FLAGS \
--build-arg=PARALLEL \
--build-arg=AMDGPU_TARGETS"
# Forward local MLX source overrides as Docker build contexts
if [ -n "${OLLAMA_MLX_SOURCE:-}" ]; then
OLLAMA_COMMON_BUILD_ARGS="$OLLAMA_COMMON_BUILD_ARGS --build-context local-mlx=$(cd "$OLLAMA_MLX_SOURCE" && pwd)"
fi
if [ -n "${OLLAMA_MLX_C_SOURCE:-}" ]; then
OLLAMA_COMMON_BUILD_ARGS="$OLLAMA_COMMON_BUILD_ARGS --build-context local-mlx-c=$(cd "$OLLAMA_MLX_C_SOURCE" && pwd)"
fi
echo "Building Ollama"
echo "VERSION=$VERSION"
echo "PLATFORM=$PLATFORM"

server/cloud_proxy.go (new file, 479 lines)
View File

@@ -0,0 +1,479 @@
package server
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"log/slog"
"net"
"net/http"
"net/url"
"strconv"
"strings"
"time"
"github.com/gin-gonic/gin"
"github.com/ollama/ollama/auth"
"github.com/ollama/ollama/envconfig"
internalcloud "github.com/ollama/ollama/internal/cloud"
)
const (
defaultCloudProxyBaseURL = "https://ollama.com:443"
defaultCloudProxySigningHost = "ollama.com"
cloudProxyBaseURLEnv = "OLLAMA_CLOUD_BASE_URL"
legacyCloudAnthropicKey = "legacy_cloud_anthropic_web_search"
)
var (
cloudProxyBaseURL = defaultCloudProxyBaseURL
cloudProxySigningHost = defaultCloudProxySigningHost
cloudProxySignRequest = signCloudProxyRequest
cloudProxySigninURL = signinURL
)
var hopByHopHeaders = map[string]struct{}{
"connection": {},
"content-length": {},
"proxy-connection": {},
"keep-alive": {},
"proxy-authenticate": {},
"proxy-authorization": {},
"te": {},
"trailer": {},
"transfer-encoding": {},
"upgrade": {},
}
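// Hop-by-hop fields (RFC 7230 section 6.1); Content-Length is included
// since the proxied body can differ from the original. Headers named in a
// Connection header are also stripped; see connectionHeaderTokens and
// isConnectionTokenHeader below.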
func init() {
baseURL, signingHost, overridden, err := resolveCloudProxyBaseURL(envconfig.Var(cloudProxyBaseURLEnv), mode)
if err != nil {
slog.Warn("ignoring cloud base URL override", "env", cloudProxyBaseURLEnv, "error", err)
return
}
cloudProxyBaseURL = baseURL
cloudProxySigningHost = signingHost
if overridden {
slog.Info("cloud base URL override enabled", "env", cloudProxyBaseURLEnv, "url", cloudProxyBaseURL, "mode", mode)
}
}
func cloudPassthroughMiddleware(disabledOperation string) gin.HandlerFunc {
return func(c *gin.Context) {
if c.Request.Method != http.MethodPost {
c.Next()
return
}
// TODO(drifkin): Avoid full-body buffering here for model detection.
// A future optimization can parse just enough JSON to read "model" (and
// optionally short-circuit cloud-disabled explicit-cloud requests) while
// preserving raw passthrough semantics.
body, err := readRequestBody(c.Request)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
c.Abort()
return
}
model, ok := extractModelField(body)
if !ok {
c.Next()
return
}
modelRef, err := parseAndValidateModelRef(model)
if err != nil || modelRef.Source != modelSourceCloud {
c.Next()
return
}
normalizedBody, err := replaceJSONModelField(body, modelRef.Base)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
c.Abort()
return
}
// TEMP(drifkin): keep Anthropic web search requests on the local middleware
// path so WebSearchAnthropicWriter can orchestrate follow-up calls.
if c.Request.URL.Path == "/v1/messages" {
if hasAnthropicWebSearchTool(body) {
c.Set(legacyCloudAnthropicKey, true)
c.Next()
return
}
}
proxyCloudRequest(c, normalizedBody, disabledOperation)
c.Abort()
}
}
func cloudModelPathPassthroughMiddleware(disabledOperation string) gin.HandlerFunc {
return func(c *gin.Context) {
modelName := strings.TrimSpace(c.Param("model"))
if modelName == "" {
c.Next()
return
}
modelRef, err := parseAndValidateModelRef(modelName)
if err != nil || modelRef.Source != modelSourceCloud {
c.Next()
return
}
proxyPath := "/v1/models/" + modelRef.Base
proxyCloudRequestWithPath(c, nil, proxyPath, disabledOperation)
c.Abort()
}
}
func proxyCloudJSONRequest(c *gin.Context, payload any, disabledOperation string) {
// TEMP(drifkin): we currently split out this `WithPath` method because we are
// mapping `/v1/messages` + web_search to `/api/chat` temporarily. Once we
// stop doing this, we can inline this method.
proxyCloudJSONRequestWithPath(c, payload, c.Request.URL.Path, disabledOperation)
}
func proxyCloudJSONRequestWithPath(c *gin.Context, payload any, path string, disabledOperation string) {
body, err := json.Marshal(payload)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
proxyCloudRequestWithPath(c, body, path, disabledOperation)
}
func proxyCloudRequest(c *gin.Context, body []byte, disabledOperation string) {
proxyCloudRequestWithPath(c, body, c.Request.URL.Path, disabledOperation)
}
func proxyCloudRequestWithPath(c *gin.Context, body []byte, path string, disabledOperation string) {
if disabled, _ := internalcloud.Status(); disabled {
c.JSON(http.StatusForbidden, gin.H{"error": internalcloud.DisabledError(disabledOperation)})
return
}
baseURL, err := url.Parse(cloudProxyBaseURL)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
targetURL := baseURL.ResolveReference(&url.URL{
Path: path,
RawQuery: c.Request.URL.RawQuery,
})
outReq, err := http.NewRequestWithContext(c.Request.Context(), c.Request.Method, targetURL.String(), bytes.NewReader(body))
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
copyProxyRequestHeaders(outReq.Header, c.Request.Header)
if outReq.Header.Get("Content-Type") == "" && len(body) > 0 {
outReq.Header.Set("Content-Type", "application/json")
}
if err := cloudProxySignRequest(outReq.Context(), outReq); err != nil {
slog.Warn("cloud proxy signing failed", "error", err)
writeCloudUnauthorized(c)
return
}
// TODO(drifkin): Add phase-specific proxy timeouts.
// Connect/TLS/TTFB should have bounded timeouts, but once streaming starts
// we should not enforce a short total timeout for long-lived responses.
resp, err := http.DefaultClient.Do(outReq)
if err != nil {
c.JSON(http.StatusBadGateway, gin.H{"error": err.Error()})
return
}
defer resp.Body.Close()
copyProxyResponseHeaders(c.Writer.Header(), resp.Header)
c.Status(resp.StatusCode)
if err := copyProxyResponseBody(c.Writer, resp.Body); err != nil {
ctxErr := c.Request.Context().Err()
if errors.Is(err, context.Canceled) && errors.Is(ctxErr, context.Canceled) {
slog.Debug(
"cloud proxy response stream closed by client",
"path", c.Request.URL.Path,
"status", resp.StatusCode,
)
return
}
slog.Warn(
"cloud proxy response copy failed",
"path", c.Request.URL.Path,
"status", resp.StatusCode,
"request_context_canceled", ctxErr != nil,
"request_context_err", ctxErr,
"error", err,
)
return
}
}
func replaceJSONModelField(body []byte, model string) ([]byte, error) {
if len(body) == 0 {
return body, nil
}
var payload map[string]json.RawMessage
if err := json.Unmarshal(body, &payload); err != nil {
return nil, err
}
modelJSON, err := json.Marshal(model)
if err != nil {
return nil, err
}
payload["model"] = modelJSON
return json.Marshal(payload)
}
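// Example (illustrative): replaceJSONModelField(
//   []byte(`{"model":"gpt-oss:20b:cloud","stream":true}`), "gpt-oss:20b")
// returns {"model":"gpt-oss:20b","stream":true}; all other fields are
// carried through as raw JSON, untouched.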
func readRequestBody(r *http.Request) ([]byte, error) {
if r.Body == nil {
return nil, nil
}
body, err := io.ReadAll(r.Body)
if err != nil {
return nil, err
}
r.Body = io.NopCloser(bytes.NewReader(body))
return body, nil
}
func extractModelField(body []byte) (string, bool) {
if len(body) == 0 {
return "", false
}
var payload map[string]json.RawMessage
if err := json.Unmarshal(body, &payload); err != nil {
return "", false
}
raw, ok := payload["model"]
if !ok {
return "", false
}
var model string
if err := json.Unmarshal(raw, &model); err != nil {
return "", false
}
model = strings.TrimSpace(model)
return model, model != ""
}
func hasAnthropicWebSearchTool(body []byte) bool {
if len(body) == 0 {
return false
}
var payload struct {
Tools []struct {
Type string `json:"type"`
} `json:"tools"`
}
if err := json.Unmarshal(body, &payload); err != nil {
return false
}
for _, tool := range payload.Tools {
if strings.HasPrefix(strings.TrimSpace(tool.Type), "web_search") {
return true
}
}
return false
}
func writeCloudUnauthorized(c *gin.Context) {
signinURL, err := cloudProxySigninURL()
if err != nil {
c.JSON(http.StatusUnauthorized, gin.H{"error": "unauthorized"})
return
}
c.JSON(http.StatusUnauthorized, gin.H{"error": "unauthorized", "signin_url": signinURL})
}
func signCloudProxyRequest(ctx context.Context, req *http.Request) error {
if !strings.EqualFold(req.URL.Hostname(), cloudProxySigningHost) {
return nil
}
ts := strconv.FormatInt(time.Now().Unix(), 10)
challenge := buildCloudSignatureChallenge(req, ts)
signature, err := auth.Sign(ctx, []byte(challenge))
if err != nil {
return err
}
req.Header.Set("Authorization", signature)
return nil
}
func buildCloudSignatureChallenge(req *http.Request, ts string) string {
query := req.URL.Query()
query.Set("ts", ts)
req.URL.RawQuery = query.Encode()
return fmt.Sprintf("%s,%s", req.Method, req.URL.RequestURI())
}
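// Example (asserted in the tests): a POST to
//   https://ollama.com/v1/messages?beta=true&foo=bar
// with ts "123" produces the challenge
//   "POST,/v1/messages?beta=true&foo=bar&ts=123"
// Note this mutates req.URL.RawQuery, so the signed query (including an
// overwritten ts) is exactly what gets sent upstream.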
func resolveCloudProxyBaseURL(rawOverride string, runMode string) (baseURL string, signingHost string, overridden bool, err error) {
baseURL = defaultCloudProxyBaseURL
signingHost = defaultCloudProxySigningHost
rawOverride = strings.TrimSpace(rawOverride)
if rawOverride == "" {
return baseURL, signingHost, false, nil
}
u, err := url.Parse(rawOverride)
if err != nil {
return "", "", false, fmt.Errorf("invalid URL: %w", err)
}
if u.Scheme == "" || u.Host == "" {
return "", "", false, fmt.Errorf("invalid URL: scheme and host are required")
}
if u.User != nil {
return "", "", false, fmt.Errorf("invalid URL: userinfo is not allowed")
}
if u.Path != "" && u.Path != "/" {
return "", "", false, fmt.Errorf("invalid URL: path is not allowed")
}
if u.RawQuery != "" || u.Fragment != "" {
return "", "", false, fmt.Errorf("invalid URL: query and fragment are not allowed")
}
host := u.Hostname()
if host == "" {
return "", "", false, fmt.Errorf("invalid URL: host is required")
}
loopback := isLoopbackHost(host)
if runMode == gin.ReleaseMode && !loopback {
return "", "", false, fmt.Errorf("non-loopback cloud override is not allowed in release mode")
}
if !loopback && !strings.EqualFold(u.Scheme, "https") {
return "", "", false, fmt.Errorf("non-loopback cloud override must use https")
}
u.Path = ""
u.RawPath = ""
u.RawQuery = ""
u.Fragment = ""
return u.String(), strings.ToLower(host), true, nil
}
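// Examples (mirroring the tests):
//   ""                         -> default base URL, overridden=false
//   "http://localhost:8080"    -> allowed in any mode (loopback)
//   "https://example.com:8443" -> allowed in debug mode only
//   "https://example.com"      -> rejected in release mode (non-loopback)
//   "http://example.com"       -> rejected in all modes (must use https)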
func isLoopbackHost(host string) bool {
if strings.EqualFold(host, "localhost") {
return true
}
ip := net.ParseIP(host)
return ip != nil && ip.IsLoopback()
}
func copyProxyRequestHeaders(dst, src http.Header) {
connectionTokens := connectionHeaderTokens(src)
for key, values := range src {
if isHopByHopHeader(key) || isConnectionTokenHeader(key, connectionTokens) {
continue
}
dst.Del(key)
for _, value := range values {
dst.Add(key, value)
}
}
}
func copyProxyResponseHeaders(dst, src http.Header) {
connectionTokens := connectionHeaderTokens(src)
for key, values := range src {
if isHopByHopHeader(key) || isConnectionTokenHeader(key, connectionTokens) {
continue
}
dst.Del(key)
for _, value := range values {
dst.Add(key, value)
}
}
}
func copyProxyResponseBody(dst http.ResponseWriter, src io.Reader) error {
flusher, canFlush := dst.(http.Flusher)
buf := make([]byte, 32*1024)
for {
n, err := src.Read(buf)
if n > 0 {
if _, writeErr := dst.Write(buf[:n]); writeErr != nil {
return writeErr
}
if canFlush {
// TODO(drifkin): Consider conditional flushing so non-streaming
// responses don't flush every write and can optimize throughput.
flusher.Flush()
}
}
if err != nil {
if err == io.EOF {
return nil
}
return err
}
}
}
func isHopByHopHeader(name string) bool {
_, ok := hopByHopHeaders[strings.ToLower(name)]
return ok
}
func connectionHeaderTokens(header http.Header) map[string]struct{} {
tokens := map[string]struct{}{}
for _, raw := range header.Values("Connection") {
for _, token := range strings.Split(raw, ",") {
token = strings.TrimSpace(strings.ToLower(token))
if token == "" {
continue
}
tokens[token] = struct{}{}
}
}
return tokens
}
func isConnectionTokenHeader(name string, tokens map[string]struct{}) bool {
if len(tokens) == 0 {
return false
}
_, ok := tokens[strings.ToLower(name)]
return ok
}

server/cloud_proxy_test.go (new file, 154 lines)
View File

@@ -0,0 +1,154 @@
package server
import (
"net/http"
"testing"
"github.com/gin-gonic/gin"
)
func TestCopyProxyRequestHeaders_StripsConnectionTokenHeaders(t *testing.T) {
src := http.Header{}
src.Add("Connection", "keep-alive, X-Trace-Hop, x-alt-hop")
src.Add("X-Trace-Hop", "drop-me")
src.Add("X-Alt-Hop", "drop-me-too")
src.Add("Keep-Alive", "timeout=5")
src.Add("X-End-To-End", "keep-me")
dst := http.Header{}
copyProxyRequestHeaders(dst, src)
if got := dst.Get("Connection"); got != "" {
t.Fatalf("expected Connection to be stripped, got %q", got)
}
if got := dst.Get("Keep-Alive"); got != "" {
t.Fatalf("expected Keep-Alive to be stripped, got %q", got)
}
if got := dst.Get("X-Trace-Hop"); got != "" {
t.Fatalf("expected X-Trace-Hop to be stripped via Connection token, got %q", got)
}
if got := dst.Get("X-Alt-Hop"); got != "" {
t.Fatalf("expected X-Alt-Hop to be stripped via Connection token, got %q", got)
}
if got := dst.Get("X-End-To-End"); got != "keep-me" {
t.Fatalf("expected X-End-To-End to be forwarded, got %q", got)
}
}
func TestCopyProxyResponseHeaders_StripsConnectionTokenHeaders(t *testing.T) {
src := http.Header{}
src.Add("Connection", "X-Upstream-Hop")
src.Add("X-Upstream-Hop", "drop-me")
src.Add("Content-Type", "application/json")
src.Add("X-Server-Trace", "keep-me")
dst := http.Header{}
copyProxyResponseHeaders(dst, src)
if got := dst.Get("Connection"); got != "" {
t.Fatalf("expected Connection to be stripped, got %q", got)
}
if got := dst.Get("X-Upstream-Hop"); got != "" {
t.Fatalf("expected X-Upstream-Hop to be stripped via Connection token, got %q", got)
}
if got := dst.Get("Content-Type"); got != "application/json" {
t.Fatalf("expected Content-Type to be forwarded, got %q", got)
}
if got := dst.Get("X-Server-Trace"); got != "keep-me" {
t.Fatalf("expected X-Server-Trace to be forwarded, got %q", got)
}
}
func TestResolveCloudProxyBaseURL_Default(t *testing.T) {
baseURL, signingHost, overridden, err := resolveCloudProxyBaseURL("", gin.ReleaseMode)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if overridden {
t.Fatal("expected override=false for empty input")
}
if baseURL != defaultCloudProxyBaseURL {
t.Fatalf("expected default base URL %q, got %q", defaultCloudProxyBaseURL, baseURL)
}
if signingHost != defaultCloudProxySigningHost {
t.Fatalf("expected default signing host %q, got %q", defaultCloudProxySigningHost, signingHost)
}
}
func TestResolveCloudProxyBaseURL_ReleaseAllowsLoopback(t *testing.T) {
baseURL, signingHost, overridden, err := resolveCloudProxyBaseURL("http://localhost:8080", gin.ReleaseMode)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if !overridden {
t.Fatal("expected override=true")
}
if baseURL != "http://localhost:8080" {
t.Fatalf("unexpected base URL: %q", baseURL)
}
if signingHost != "localhost" {
t.Fatalf("unexpected signing host: %q", signingHost)
}
}
func TestResolveCloudProxyBaseURL_ReleaseRejectsNonLoopback(t *testing.T) {
_, _, _, err := resolveCloudProxyBaseURL("https://example.com", gin.ReleaseMode)
if err == nil {
t.Fatal("expected error for non-loopback override in release mode")
}
}
func TestResolveCloudProxyBaseURL_DevAllowsNonLoopbackHTTPS(t *testing.T) {
baseURL, signingHost, overridden, err := resolveCloudProxyBaseURL("https://example.com:8443", gin.DebugMode)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if !overridden {
t.Fatal("expected override=true")
}
if baseURL != "https://example.com:8443" {
t.Fatalf("unexpected base URL: %q", baseURL)
}
if signingHost != "example.com" {
t.Fatalf("unexpected signing host: %q", signingHost)
}
}
func TestResolveCloudProxyBaseURL_DevRejectsNonLoopbackHTTP(t *testing.T) {
_, _, _, err := resolveCloudProxyBaseURL("http://example.com", gin.DebugMode)
if err == nil {
t.Fatal("expected error for non-loopback http override in dev mode")
}
}
func TestBuildCloudSignatureChallengeIncludesExistingQuery(t *testing.T) {
req, err := http.NewRequest(http.MethodPost, "https://ollama.com/v1/messages?beta=true&foo=bar", nil)
if err != nil {
t.Fatalf("failed to create request: %v", err)
}
got := buildCloudSignatureChallenge(req, "123")
want := "POST,/v1/messages?beta=true&foo=bar&ts=123"
if got != want {
t.Fatalf("challenge mismatch: got %q want %q", got, want)
}
if req.URL.RawQuery != "beta=true&foo=bar&ts=123" {
t.Fatalf("unexpected signed query: %q", req.URL.RawQuery)
}
}
func TestBuildCloudSignatureChallengeOverwritesExistingTimestamp(t *testing.T) {
req, err := http.NewRequest(http.MethodPost, "https://ollama.com/v1/messages?beta=true&ts=999", nil)
if err != nil {
t.Fatalf("failed to create request: %v", err)
}
got := buildCloudSignatureChallenge(req, "123")
want := "POST,/v1/messages?beta=true&ts=123"
if got != want {
t.Fatalf("challenge mismatch: got %q want %q", got, want)
}
if req.URL.RawQuery != "beta=true&ts=123" {
t.Fatalf("unexpected signed query: %q", req.URL.RawQuery)
}
}

View File

@@ -65,11 +65,22 @@ func (s *Server) CreateHandler(c *gin.Context) {
config.Parser = r.Parser
config.Requires = r.Requires
for v := range r.Files {
for v, digest := range r.Files {
if !fs.ValidPath(v) {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": errFilePath.Error()})
return
}
if digest == "" {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": manifest.ErrInvalidDigestFormat.Error()})
return
}
}
for _, digest := range r.Adapters {
if digest == "" {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": manifest.ErrInvalidDigestFormat.Error()})
return
}
}
name := model.ParseName(cmp.Or(r.Model, r.Name))
@@ -99,19 +110,26 @@ func (s *Server) CreateHandler(c *gin.Context) {
if r.From != "" {
slog.Debug("create model from model name", "from", r.From)
fromName := model.ParseName(r.From)
if !fromName.IsValid() {
fromRef, err := parseAndValidateModelRef(r.From)
if err != nil {
ch <- gin.H{"error": errtypes.InvalidModelNameErrMsg, "status": http.StatusBadRequest}
return
}
if r.RemoteHost != "" {
ru, err := remoteURL(r.RemoteHost)
fromName := fromRef.Name
remoteHost := r.RemoteHost
if fromRef.Source == modelSourceCloud && remoteHost == "" {
remoteHost = cloudProxyBaseURL
}
if remoteHost != "" {
ru, err := remoteURL(remoteHost)
if err != nil {
ch <- gin.H{"error": "bad remote", "status": http.StatusBadRequest}
return
}
config.RemoteModel = r.From
config.RemoteModel = fromRef.Base
config.RemoteHost = ru
remote = true
} else {

View File

@@ -71,6 +71,10 @@ type Model struct {
Template *template.Template
}
func (m *Model) IsMLX() bool {
return m.Config.ModelFormat == "safetensors"
}
// Capabilities returns the capabilities that the model supports
func (m *Model) Capabilities() []model.Capability {
capabilities := []model.Capability{}

server/model_resolver.go (new file, 81 lines)
View File

@@ -0,0 +1,81 @@
package server
import (
"github.com/ollama/ollama/internal/modelref"
"github.com/ollama/ollama/types/model"
)
type modelSource = modelref.ModelSource
const (
modelSourceUnspecified modelSource = modelref.ModelSourceUnspecified
modelSourceLocal modelSource = modelref.ModelSourceLocal
modelSourceCloud modelSource = modelref.ModelSourceCloud
)
var (
errConflictingModelSource = modelref.ErrConflictingSourceSuffix
errModelRequired = modelref.ErrModelRequired
)
type parsedModelRef struct {
// Original is the caller-provided model string before source parsing.
// Example: "gpt-oss:20b:cloud".
Original string
// Base is the model string after source suffix normalization.
// Example: "gpt-oss:20b:cloud" -> "gpt-oss:20b".
Base string
// Name is Base parsed as a fully-qualified model.Name with defaults applied.
// Example: "registry.ollama.ai/library/gpt-oss:20b".
Name model.Name
// Source captures explicit source intent from the original input.
// Example: "gpt-oss:20b:cloud" -> modelSourceCloud.
Source modelSource
}
func parseAndValidateModelRef(raw string) (parsedModelRef, error) {
var zero parsedModelRef
parsed, err := modelref.ParseRef(raw)
if err != nil {
return zero, err
}
name := model.ParseName(parsed.Base)
if !name.IsValid() {
return zero, model.Unqualified(name)
}
return parsedModelRef{
Original: parsed.Original,
Base: parsed.Base,
Name: name,
Source: parsed.Source,
}, nil
}
func parseNormalizePullModelRef(raw string) (parsedModelRef, error) {
var zero parsedModelRef
parsedRef, err := modelref.ParseRef(raw)
if err != nil {
return zero, err
}
normalizedName, _, err := modelref.NormalizePullName(raw)
if err != nil {
return zero, err
}
name := model.ParseName(normalizedName)
if !name.IsValid() {
return zero, model.Unqualified(name)
}
return parsedModelRef{
Original: parsedRef.Original,
Base: normalizedName,
Name: name,
Source: parsedRef.Source,
}, nil
}

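A quick illustration of the two helpers above, using values from the tests that follow (presumably server/model_resolver_test.go):

ref, _ := parseAndValidateModelRef("gpt-oss:20b:cloud")
// ref.Base == "gpt-oss:20b", ref.Source == modelSourceCloud
// ref.Name.String() == "registry.ollama.ai/library/gpt-oss:20b"

pull, _ := parseNormalizePullModelRef("gpt-oss:20b:cloud")
// pull.Base == "gpt-oss:20b-cloud": pulls map an explicit :cloud suffix
// with a size back to the legacy "-cloud" tag that exists in the registry.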
View File

@@ -0,0 +1,170 @@
package server
import (
"errors"
"strings"
"testing"
)
func TestParseModelSelector(t *testing.T) {
t.Run("cloud suffix", func(t *testing.T) {
got, err := parseAndValidateModelRef("gpt-oss:20b:cloud")
if err != nil {
t.Fatalf("parseModelSelector returned error: %v", err)
}
if got.Source != modelSourceCloud {
t.Fatalf("expected source cloud, got %v", got.Source)
}
if got.Base != "gpt-oss:20b" {
t.Fatalf("expected base gpt-oss:20b, got %q", got.Base)
}
if got.Name.String() != "registry.ollama.ai/library/gpt-oss:20b" {
t.Fatalf("unexpected resolved name: %q", got.Name.String())
}
})
t.Run("legacy cloud suffix", func(t *testing.T) {
got, err := parseAndValidateModelRef("gpt-oss:20b-cloud")
if err != nil {
t.Fatalf("parseModelSelector returned error: %v", err)
}
if got.Source != modelSourceCloud {
t.Fatalf("expected source cloud, got %v", got.Source)
}
if got.Base != "gpt-oss:20b" {
t.Fatalf("expected base gpt-oss:20b, got %q", got.Base)
}
})
t.Run("bare dash cloud name is not explicit cloud", func(t *testing.T) {
got, err := parseAndValidateModelRef("my-cloud-model")
if err != nil {
t.Fatalf("parseModelSelector returned error: %v", err)
}
if got.Source != modelSourceUnspecified {
t.Fatalf("expected source unspecified, got %v", got.Source)
}
if got.Base != "my-cloud-model" {
t.Fatalf("expected base my-cloud-model, got %q", got.Base)
}
})
t.Run("local suffix", func(t *testing.T) {
got, err := parseAndValidateModelRef("qwen3:8b:local")
if err != nil {
t.Fatalf("parseModelSelector returned error: %v", err)
}
if got.Source != modelSourceLocal {
t.Fatalf("expected source local, got %v", got.Source)
}
if got.Base != "qwen3:8b" {
t.Fatalf("expected base qwen3:8b, got %q", got.Base)
}
})
t.Run("conflicting source suffixes fail", func(t *testing.T) {
_, err := parseAndValidateModelRef("foo:cloud:local")
if !errors.Is(err, errConflictingModelSource) {
t.Fatalf("expected errConflictingModelSource, got %v", err)
}
})
t.Run("unspecified source", func(t *testing.T) {
got, err := parseAndValidateModelRef("llama3")
if err != nil {
t.Fatalf("parseModelSelector returned error: %v", err)
}
if got.Source != modelSourceUnspecified {
t.Fatalf("expected source unspecified, got %v", got.Source)
}
if got.Name.Tag != "latest" {
t.Fatalf("expected default latest tag, got %q", got.Name.Tag)
}
})
t.Run("unknown suffix is treated as tag", func(t *testing.T) {
got, err := parseAndValidateModelRef("gpt-oss:clod")
if err != nil {
t.Fatalf("parseModelSelector returned error: %v", err)
}
if got.Source != modelSourceUnspecified {
t.Fatalf("expected source unspecified, got %v", got.Source)
}
if got.Name.Tag != "clod" {
t.Fatalf("expected tag clod, got %q", got.Name.Tag)
}
})
t.Run("empty model fails", func(t *testing.T) {
_, err := parseAndValidateModelRef("")
if !errors.Is(err, errModelRequired) {
t.Fatalf("expected errModelRequired, got %v", err)
}
})
t.Run("invalid model fails", func(t *testing.T) {
_, err := parseAndValidateModelRef("::cloud")
if err == nil {
t.Fatal("expected error for invalid model")
}
if !strings.Contains(err.Error(), "unqualified") {
t.Fatalf("expected unqualified model error, got %v", err)
}
})
}
func TestParsePullModelRef(t *testing.T) {
t.Run("explicit local is normalized", func(t *testing.T) {
got, err := parseNormalizePullModelRef("gpt-oss:20b:local")
if err != nil {
t.Fatalf("parseNormalizePullModelRef returned error: %v", err)
}
if got.Source != modelSourceLocal {
t.Fatalf("expected source local, got %v", got.Source)
}
if got.Base != "gpt-oss:20b" {
t.Fatalf("expected base gpt-oss:20b, got %q", got.Base)
}
})
t.Run("explicit cloud with size maps to legacy cloud suffix", func(t *testing.T) {
got, err := parseNormalizePullModelRef("gpt-oss:20b:cloud")
if err != nil {
t.Fatalf("parseNormalizePullModelRef returned error: %v", err)
}
if got.Base != "gpt-oss:20b-cloud" {
t.Fatalf("expected base gpt-oss:20b-cloud, got %q", got.Base)
}
if got.Name.String() != "registry.ollama.ai/library/gpt-oss:20b-cloud" {
t.Fatalf("unexpected resolved name: %q", got.Name.String())
}
})
t.Run("explicit cloud without size maps to cloud tag", func(t *testing.T) {
got, err := parseNormalizePullModelRef("qwen3:cloud")
if err != nil {
t.Fatalf("parseNormalizePullModelRef returned error: %v", err)
}
if got.Base != "qwen3:cloud" {
t.Fatalf("expected base qwen3:cloud, got %q", got.Base)
}
if got.Name.String() != "registry.ollama.ai/library/qwen3:cloud" {
t.Fatalf("unexpected resolved name: %q", got.Name.String())
}
})
}

View File

@@ -30,42 +30,44 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
lastMsgIdx := len(msgs) - 1
currMsgIdx := 0
// Start with all messages and remove from the front until it fits in context
for i := 0; i <= lastMsgIdx; i++ {
// Collect system messages from the portion we're about to skip
system = make([]api.Message, 0)
for j := range i {
if msgs[j].Role == "system" {
system = append(system, msgs[j])
if truncate {
// Start with all messages and remove from the front until it fits in context
for i := 0; i <= lastMsgIdx; i++ {
// Collect system messages from the portion we're about to skip
system = make([]api.Message, 0)
for j := range i {
if msgs[j].Role == "system" {
system = append(system, msgs[j])
}
}
}
p, err := renderPrompt(m, append(system, msgs[i:]...), tools, think)
if err != nil {
return "", nil, err
}
s, err := tokenize(ctx, p)
if err != nil {
return "", nil, err
}
ctxLen := len(s)
if m.ProjectorPaths != nil {
for _, msg := range msgs[i:] {
ctxLen += imageNumTokens * len(msg.Images)
p, err := renderPrompt(m, append(system, msgs[i:]...), tools, think)
if err != nil {
return "", nil, err
}
}
if !truncate || ctxLen <= opts.NumCtx {
currMsgIdx = i
break
}
s, err := tokenize(ctx, p)
if err != nil {
return "", nil, err
}
// Must always include at least the last message
if i == lastMsgIdx {
currMsgIdx = lastMsgIdx
break
ctxLen := len(s)
if m.ProjectorPaths != nil {
for _, msg := range msgs[i:] {
ctxLen += imageNumTokens * len(msg.Images)
}
}
if ctxLen <= opts.NumCtx {
currMsgIdx = i
break
}
// Must always include at least the last message
if i == lastMsgIdx {
currMsgIdx = lastMsgIdx
break
}
}
}

View File

@@ -3,6 +3,7 @@ package server
import (
"bytes"
"context"
"strings"
"testing"
"github.com/google/go-cmp/cmp"
@@ -366,3 +367,33 @@ func TestChatPromptRendererDoesNotRewriteMessageContent(t *testing.T) {
t.Fatal("prompt is empty")
}
}
func TestChatPromptGLMOcrRendererAddsImageTags(t *testing.T) {
msgs := []api.Message{
{
Role: "user",
Content: "extract text",
Images: []api.ImageData{[]byte("img-1"), []byte("img-2")},
},
}
m := Model{
Config: model.ConfigV2{Renderer: "glm-ocr"},
ProjectorPaths: []string{"vision"},
}
opts := api.Options{Runner: api.Runner{NumCtx: 8192}}
think := false
prompt, images, err := chatPrompt(t.Context(), &m, mockRunner{}.Tokenize, &opts, msgs, nil, &api.ThinkValue{Value: think}, true)
if err != nil {
t.Fatal(err)
}
if got, want := len(images), 2; got != want {
t.Fatalf("len(images) = %d, want %d", got, want)
}
if !strings.Contains(prompt, "<|user|>\n[img-0][img-1]extract text") {
t.Fatalf("prompt missing glm-ocr image tags, got: %q", prompt)
}
}

View File

@@ -62,8 +62,21 @@ const signinURLStr = "https://ollama.com/connect?name=%s&key=%s"
const (
cloudErrRemoteInferenceUnavailable = "remote model is unavailable"
cloudErrRemoteModelDetailsUnavailable = "remote model details are unavailable"
cloudErrWebSearchUnavailable = "web search is unavailable"
cloudErrWebFetchUnavailable = "web fetch is unavailable"
)
func writeModelRefParseError(c *gin.Context, err error, fallbackStatus int, fallbackMessage string) {
switch {
case errors.Is(err, errConflictingModelSource):
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
case errors.Is(err, model.ErrUnqualifiedName):
c.JSON(http.StatusBadRequest, gin.H{"error": errtypes.InvalidModelNameErrMsg})
default:
c.JSON(fallbackStatus, gin.H{"error": fallbackMessage})
}
}
func shouldUseHarmony(model *Model) bool {
if slices.Contains([]string{"gptoss", "gpt-oss"}, model.Config.ModelFamily) {
// heuristic to check whether the template expects to be parsed via harmony:
@@ -150,7 +163,7 @@ func (s *Server) scheduleRunner(ctx context.Context, name string, caps []model.C
return nil, nil, nil, fmt.Errorf("%s %w", name, err)
}
useImagegen, _ := requestOpts["use_imagegen_runner"].(bool)
// Deprecated runner override option; ignore if present.
delete(requestOpts, "use_imagegen_runner")
opts, err := s.modelOptions(model, requestOpts)
@@ -158,7 +171,7 @@ func (s *Server) scheduleRunner(ctx context.Context, name string, caps []model.C
return nil, nil, nil, err
}
runnerCh, errCh := s.sched.GetRunner(ctx, model, opts, keepAlive, useImagegen)
runnerCh, errCh := s.sched.GetRunner(ctx, model, opts, keepAlive)
var runner *runnerRef
select {
case runner = <-runnerCh:
@@ -196,14 +209,22 @@ func (s *Server) GenerateHandler(c *gin.Context) {
return
}
name := model.ParseName(req.Model)
if !name.IsValid() {
// Ideally this is "invalid model name" but we're keeping with
// what the API currently returns until we can change it.
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)})
modelRef, err := parseAndValidateModelRef(req.Model)
if err != nil {
writeModelRefParseError(c, err, http.StatusNotFound, fmt.Sprintf("model '%s' not found", req.Model))
return
}
if modelRef.Source == modelSourceCloud {
// TODO(drifkin): evaluate an `/api/*` passthrough for cloud where the
// original body (modulo model name normalization) is sent to cloud.
req.Model = modelRef.Base
proxyCloudJSONRequest(c, req, cloudErrRemoteInferenceUnavailable)
return
}
name := modelRef.Name
resolvedName, _, err := s.resolveAlias(name)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
@@ -237,6 +258,11 @@ func (s *Server) GenerateHandler(c *gin.Context) {
return
}
if modelRef.Source == modelSourceLocal && m.Config.RemoteHost != "" && m.Config.RemoteModel != "" {
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)})
return
}
if m.Config.RemoteHost != "" && m.Config.RemoteModel != "" {
if disabled, _ := internalcloud.Status(); disabled {
c.JSON(http.StatusForbidden, gin.H{"error": internalcloud.DisabledError(cloudErrRemoteInferenceUnavailable)})
@@ -370,12 +396,6 @@ func (s *Server) GenerateHandler(c *gin.Context) {
}
}
// Validate Think value: string values currently only allowed for harmony/gptoss models
if req.Think != nil && req.Think.IsString() && m.Config.Parser != "harmony" {
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("think value %q is not supported for this model", req.Think.String())})
return
}
caps := []model.Capability{model.CapabilityCompletion}
if req.Suffix != "" {
caps = append(caps, model.CapabilityInsert)
@@ -484,7 +504,8 @@ func (s *Server) GenerateHandler(c *gin.Context) {
// the real chat handler, but doing this as a stopgap to get renderer
// support for generate
if values.Messages != nil && values.Suffix == "" && req.Template == "" {
prompt, images, err = chatPrompt(c.Request.Context(), m, r.Tokenize, opts, values.Messages, []api.Tool{}, req.Think, req.Truncate == nil || *req.Truncate)
genTruncate := (req.Truncate == nil || *req.Truncate) && !m.IsMLX()
prompt, images, err = chatPrompt(c.Request.Context(), m, r.Tokenize, opts, values.Messages, []api.Tool{}, req.Think, genTruncate)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
@@ -675,6 +696,18 @@ func (s *Server) EmbedHandler(c *gin.Context) {
return
}
modelRef, err := parseAndValidateModelRef(req.Model)
if err != nil {
writeModelRefParseError(c, err, http.StatusNotFound, fmt.Sprintf("model '%s' not found", req.Model))
return
}
if modelRef.Source == modelSourceCloud {
req.Model = modelRef.Base
proxyCloudJSONRequest(c, req, cloudErrRemoteInferenceUnavailable)
return
}
var input []string
switch i := req.Input.(type) {
@@ -697,7 +730,7 @@ func (s *Server) EmbedHandler(c *gin.Context) {
}
}
name, err := getExistingName(model.ParseName(req.Model))
name, err := getExistingName(modelRef.Name)
if err != nil {
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)})
return
@@ -844,12 +877,20 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
return
}
name := model.ParseName(req.Model)
if !name.IsValid() {
c.JSON(http.StatusBadRequest, gin.H{"error": "model is required"})
modelRef, err := parseAndValidateModelRef(req.Model)
if err != nil {
writeModelRefParseError(c, err, http.StatusBadRequest, "model is required")
return
}
if modelRef.Source == modelSourceCloud {
req.Model = modelRef.Base
proxyCloudJSONRequest(c, req, cloudErrRemoteInferenceUnavailable)
return
}
name := modelRef.Name
r, _, _, err := s.scheduleRunner(c.Request.Context(), name.String(), []model.Capability{}, req.Options, req.KeepAlive)
if err != nil {
handleScheduleError(c, req.Model, err)
@@ -891,12 +932,19 @@ func (s *Server) PullHandler(c *gin.Context) {
return
}
name := model.ParseName(cmp.Or(req.Model, req.Name))
if !name.IsValid() {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": errtypes.InvalidModelNameErrMsg})
// TEMP(drifkin): we're temporarily allowing pulls of cloud model stub-files
// to continue until we integrate cloud models into `/api/tags` (at which
// point this roundabout way of "adding" cloud models won't be needed
// anymore). So right here we normalize any `:cloud` models into the
// legacy-style suffixes `:<tag>-cloud` and `:cloud`
modelRef, err := parseNormalizePullModelRef(cmp.Or(req.Model, req.Name))
if err != nil {
writeModelRefParseError(c, err, http.StatusBadRequest, errtypes.InvalidModelNameErrMsg)
return
}
name := modelRef.Name
name, err = getExistingName(name)
if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
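Aside: the net effect of this normalization, as exercised by TestDeleteCloudSourceNormalizesToLegacyName further down, is a textual rewrite of the `:cloud` source suffix. A hedged sketch of just that mapping (not the real parseNormalizePullModelRef, which also validates the name and records the model source):

package main

import (
	"fmt"
	"strings"
)

// normalizeCloudSuffix is a hypothetical illustration of the rewrite:
// "gpt-oss:20b:cloud" becomes "gpt-oss:20b-cloud". Names without an
// explicit tag are left for the real parser to handle.
func normalizeCloudSuffix(name string) string {
	if base, ok := strings.CutSuffix(name, ":cloud"); ok && strings.Contains(base, ":") {
		return base + "-cloud"
	}
	return name
}

func main() {
	fmt.Println(normalizeCloudSuffix("gpt-oss:20b:cloud")) // gpt-oss:20b-cloud
	fmt.Println(normalizeCloudSuffix("gpt-oss:20b"))       // unchanged
}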
@@ -1023,13 +1071,20 @@ func (s *Server) DeleteHandler(c *gin.Context) {
return
}
n := model.ParseName(cmp.Or(r.Model, r.Name))
if !n.IsValid() {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("name %q is invalid", cmp.Or(r.Model, r.Name))})
modelRef, err := parseNormalizePullModelRef(cmp.Or(r.Model, r.Name))
if err != nil {
switch {
case errors.Is(err, errConflictingModelSource):
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
case errors.Is(err, model.ErrUnqualifiedName):
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("name %q is invalid", cmp.Or(r.Model, r.Name))})
default:
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
}
return
}
n, err := getExistingName(n)
n, err := getExistingName(modelRef.Name)
if err != nil {
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", cmp.Or(r.Model, r.Name))})
return
@@ -1078,6 +1133,20 @@ func (s *Server) ShowHandler(c *gin.Context) {
return
}
modelRef, err := parseAndValidateModelRef(req.Model)
if err != nil {
writeModelRefParseError(c, err, http.StatusBadRequest, err.Error())
return
}
if modelRef.Source == modelSourceCloud {
req.Model = modelRef.Base
proxyCloudJSONRequest(c, req, cloudErrRemoteModelDetailsUnavailable)
return
}
req.Model = modelRef.Base
resp, err := GetModelInfo(req)
if err != nil {
var statusErr api.StatusError
@@ -1094,6 +1163,11 @@ func (s *Server) ShowHandler(c *gin.Context) {
return
}
if modelRef.Source == modelSourceLocal && resp.RemoteHost != "" {
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", modelRef.Original)})
return
}
c.JSON(http.StatusOK, resp)
}
@@ -1621,6 +1695,8 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
r.GET("/api/experimental/aliases", s.ListAliasesHandler)
r.POST("/api/experimental/aliases", s.CreateAliasHandler)
r.DELETE("/api/experimental/aliases", s.DeleteAliasHandler)
r.POST("/api/experimental/web_search", s.WebSearchExperimentalHandler)
r.POST("/api/experimental/web_fetch", s.WebFetchExperimentalHandler)
// Inference
r.GET("/api/ps", s.PsHandler)
@@ -1630,18 +1706,20 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
r.POST("/api/embeddings", s.EmbeddingsHandler)
// Inference (OpenAI compatibility)
r.POST("/v1/chat/completions", middleware.ChatMiddleware(), s.ChatHandler)
r.POST("/v1/completions", middleware.CompletionsMiddleware(), s.GenerateHandler)
r.POST("/v1/embeddings", middleware.EmbeddingsMiddleware(), s.EmbedHandler)
// TODO(cloud-stage-a): apply Modelfile overlay deltas for local models with cloud
// parents on v1 request families while preserving this explicit :cloud passthrough.
r.POST("/v1/chat/completions", cloudPassthroughMiddleware(cloudErrRemoteInferenceUnavailable), middleware.ChatMiddleware(), s.ChatHandler)
r.POST("/v1/completions", cloudPassthroughMiddleware(cloudErrRemoteInferenceUnavailable), middleware.CompletionsMiddleware(), s.GenerateHandler)
r.POST("/v1/embeddings", cloudPassthroughMiddleware(cloudErrRemoteInferenceUnavailable), middleware.EmbeddingsMiddleware(), s.EmbedHandler)
r.GET("/v1/models", middleware.ListMiddleware(), s.ListHandler)
r.GET("/v1/models/:model", middleware.RetrieveMiddleware(), s.ShowHandler)
r.POST("/v1/responses", middleware.ResponsesMiddleware(), s.ChatHandler)
r.GET("/v1/models/:model", cloudModelPathPassthroughMiddleware(cloudErrRemoteModelDetailsUnavailable), middleware.RetrieveMiddleware(), s.ShowHandler)
r.POST("/v1/responses", cloudPassthroughMiddleware(cloudErrRemoteInferenceUnavailable), middleware.ResponsesMiddleware(), s.ChatHandler)
// OpenAI-compatible image generation endpoints
r.POST("/v1/images/generations", middleware.ImageGenerationsMiddleware(), s.GenerateHandler)
r.POST("/v1/images/edits", middleware.ImageEditsMiddleware(), s.GenerateHandler)
r.POST("/v1/images/generations", cloudPassthroughMiddleware(cloudErrRemoteInferenceUnavailable), middleware.ImageGenerationsMiddleware(), s.GenerateHandler)
r.POST("/v1/images/edits", cloudPassthroughMiddleware(cloudErrRemoteInferenceUnavailable), middleware.ImageEditsMiddleware(), s.GenerateHandler)
// Inference (Anthropic compatibility)
r.POST("/v1/messages", middleware.AnthropicMessagesMiddleware(), s.ChatHandler)
r.POST("/v1/messages", cloudPassthroughMiddleware(cloudErrRemoteInferenceUnavailable), middleware.AnthropicMessagesMiddleware(), s.ChatHandler)
if rc != nil {
// wrap old with new
@@ -1863,6 +1941,29 @@ func (s *Server) StatusHandler(c *gin.Context) {
})
}
func (s *Server) WebSearchExperimentalHandler(c *gin.Context) {
s.webExperimentalProxyHandler(c, "/api/web_search", cloudErrWebSearchUnavailable)
}
func (s *Server) WebFetchExperimentalHandler(c *gin.Context) {
s.webExperimentalProxyHandler(c, "/api/web_fetch", cloudErrWebFetchUnavailable)
}
func (s *Server) webExperimentalProxyHandler(c *gin.Context, proxyPath, disabledOperation string) {
body, err := readRequestBody(c.Request)
if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if len(bytes.TrimSpace(body)) == 0 {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
return
}
proxyCloudRequestWithPath(c, body, proxyPath, disabledOperation)
}
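Aside: a minimal client call against the new route, with the request shape taken from the tests further down. It assumes the default local port 11434; the server signs the request and forwards it to ollama.com, so a signed-in account is assumed upstream:

package main

import (
	"bytes"
	"fmt"
	"io"
	"net/http"
)

func main() {
	// Body shape mirrors the passthrough tests below.
	body := bytes.NewBufferString(`{"query":"what is ollama?","max_results":3}`)
	resp, err := http.Post("http://localhost:11434/api/experimental/web_search",
		"application/json", body)
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()
	out, _ := io.ReadAll(resp.Body)
	fmt.Println(resp.Status, string(out))
}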
func (s *Server) WhoamiHandler(c *gin.Context) {
// todo allow other hosts
u, err := url.Parse("https://ollama.com")
@@ -1951,6 +2052,9 @@ func (s *Server) PsHandler(c *gin.Context) {
}
if v.llama != nil {
mr.ContextLength = v.llama.ContextLength()
total, vram := v.llama.MemorySize()
mr.Size = int64(total)
mr.SizeVRAM = int64(vram)
}
// The scheduler waits to set expiresAt, so if a model is loading it's
// possible that it will be set to the unix epoch. For those cases, just
@@ -1997,12 +2101,24 @@ func (s *Server) ChatHandler(c *gin.Context) {
return
}
name := model.ParseName(req.Model)
if !name.IsValid() {
c.JSON(http.StatusBadRequest, gin.H{"error": "model is required"})
modelRef, err := parseAndValidateModelRef(req.Model)
if err != nil {
writeModelRefParseError(c, err, http.StatusBadRequest, "model is required")
return
}
if modelRef.Source == modelSourceCloud {
req.Model = modelRef.Base
if c.GetBool(legacyCloudAnthropicKey) {
proxyCloudJSONRequestWithPath(c, req, "/api/chat", cloudErrRemoteInferenceUnavailable)
return
}
proxyCloudJSONRequest(c, req, cloudErrRemoteInferenceUnavailable)
return
}
name := modelRef.Name
resolvedName, _, err := s.resolveAlias(name)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
@@ -2034,6 +2150,11 @@ func (s *Server) ChatHandler(c *gin.Context) {
return
}
if modelRef.Source == modelSourceLocal && m.Config.RemoteHost != "" && m.Config.RemoteModel != "" {
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)})
return
}
// expire the runner
if len(req.Messages) == 0 && req.KeepAlive != nil && req.KeepAlive.Duration == 0 {
s.sched.expireRunner(m)
@@ -2213,6 +2334,9 @@ func (s *Server) ChatHandler(c *gin.Context) {
}
truncate := req.Truncate == nil || *req.Truncate
if m.IsMLX() {
truncate = false
}
prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, msgs, processedTools, req.Think, truncate)
if err != nil {
slog.Error("chat prompt error", "error", err)
@@ -2233,12 +2357,6 @@ func (s *Server) ChatHandler(c *gin.Context) {
return
}
// Validate Think value: string values currently only allowed for harmony/gptoss models
if req.Think != nil && req.Think.IsString() && m.Config.Parser != "harmony" {
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("think value %q is not supported for this model", req.Think.String())})
return
}
var thinkingState *thinking.Parser
openingTag, closingTag := thinking.InferTags(m.Template.Template)
if req.Think != nil && req.Think.Bool() && openingTag != "" && closingTag != "" {

File diff suppressed because it is too large

View File

@@ -144,6 +144,37 @@ func TestCreateFromBin(t *testing.T) {
filepath.Join(p, "blobs", "sha256-6bcdb8859d417753645538d7bbfbd7ca91a3f0c191aef5379c53c05e86b669dd"),
filepath.Join(p, "blobs", "sha256-89a2116c3a82d6a97f59f748d86ed4417214353fd178ee54df418fde32495fad"),
})
t.Run("empty file digest", func(t *testing.T) {
w := createRequest(t, s.CreateHandler, api.CreateRequest{
Name: "my-gguf-model",
Files: map[string]string{"0.gguf": ""},
Stream: &stream,
})
if w.Code != http.StatusBadRequest {
t.Fatalf("expected status 400, got %d", w.Code)
}
if !strings.Contains(w.Body.String(), "invalid digest format") {
t.Errorf("expected invalid digest format error, got:\n%s", w.Body.String())
}
})
t.Run("empty adapter digest", func(t *testing.T) {
w := createRequest(t, s.CreateHandler, api.CreateRequest{
Name: "my-gguf-model",
Files: map[string]string{"0.gguf": digest},
Adapters: map[string]string{"adapter.gguf": ""},
Stream: &stream,
})
if w.Code != http.StatusBadRequest {
t.Fatalf("expected status 400, got %d", w.Code)
}
if !strings.Contains(w.Body.String(), "invalid digest format") {
t.Errorf("expected invalid digest format error, got:\n%s", w.Body.String())
}
})
}
func TestCreateFromModel(t *testing.T) {
@@ -763,6 +794,43 @@ func TestCreateAndShowRemoteModel(t *testing.T) {
fmt.Printf("resp = %#v\n", resp)
}
func TestCreateFromCloudSourceSuffix(t *testing.T) {
gin.SetMode(gin.TestMode)
var s Server
w := createRequest(t, s.CreateHandler, api.CreateRequest{
Model: "test-cloud-from-suffix",
From: "gpt-oss:20b:cloud",
Info: map[string]any{
"capabilities": []string{"completion"},
},
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Fatalf("expected status code 200, got %d", w.Code)
}
w = createRequest(t, s.ShowHandler, api.ShowRequest{Model: "test-cloud-from-suffix"})
if w.Code != http.StatusOK {
t.Fatalf("expected status code 200, got %d", w.Code)
}
var resp api.ShowResponse
if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
t.Fatal(err)
}
if resp.RemoteHost != "https://ollama.com:443" {
t.Fatalf("expected remote host https://ollama.com:443, got %q", resp.RemoteHost)
}
if resp.RemoteModel != "gpt-oss:20b" {
t.Fatalf("expected remote model gpt-oss:20b, got %q", resp.RemoteModel)
}
}
func TestCreateLicenses(t *testing.T) {
gin.SetMode(gin.TestMode)

View File

@@ -111,3 +111,32 @@ func TestDeleteDuplicateLayers(t *testing.T) {
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{})
}
func TestDeleteCloudSourceNormalizesToLegacyName(t *testing.T) {
gin.SetMode(gin.TestMode)
p := t.TempDir()
t.Setenv("OLLAMA_MODELS", p)
var s Server
_, digest := createBinFile(t, nil, nil)
w := createRequest(t, s.CreateHandler, api.CreateRequest{
Name: "gpt-oss:20b-cloud",
Files: map[string]string{"test.gguf": digest},
})
if w.Code != http.StatusOK {
t.Fatalf("expected status code 200, actual %d", w.Code)
}
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "gpt-oss", "20b-cloud"),
})
w = createRequest(t, s.DeleteHandler, api.DeleteRequest{Name: "gpt-oss:20b:cloud"})
if w.Code != http.StatusOK {
t.Fatalf("expected status code 200, actual %d (%s)", w.Code, w.Body.String())
}
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{})
}

View File

@@ -0,0 +1,335 @@
package server
import (
"bytes"
"context"
"encoding/json"
"errors"
"io"
"net/http"
"net/http/httptest"
"testing"
"github.com/gin-gonic/gin"
internalcloud "github.com/ollama/ollama/internal/cloud"
)
type webExperimentalUpstreamCapture struct {
path string
body string
header http.Header
}
func newWebExperimentalUpstream(t *testing.T, responseBody string) (*httptest.Server, *webExperimentalUpstreamCapture) {
t.Helper()
capture := &webExperimentalUpstreamCapture{}
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
payload, _ := io.ReadAll(r.Body)
capture.path = r.URL.Path
capture.body = string(payload)
capture.header = r.Header.Clone()
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte(responseBody))
}))
return srv, capture
}
func TestExperimentalWebEndpointsPassthrough(t *testing.T) {
gin.SetMode(gin.TestMode)
setTestHome(t, t.TempDir())
tests := []struct {
name string
localPath string
upstreamPath string
requestBody string
responseBody string
assertBody string
}{
{
name: "web_search",
localPath: "/api/experimental/web_search",
upstreamPath: "/api/web_search",
requestBody: `{"query":"what is ollama?","max_results":3}`,
responseBody: `{"results":[{"title":"Ollama","url":"https://ollama.com","content":"Cloud models are now available"}]}`,
assertBody: `"query":"what is ollama?"`,
},
{
name: "web_fetch",
localPath: "/api/experimental/web_fetch",
upstreamPath: "/api/web_fetch",
requestBody: `{"url":"https://ollama.com"}`,
responseBody: `{"title":"Ollama","content":"Cloud models are now available","links":["https://ollama.com/"]}`,
assertBody: `"url":"https://ollama.com"`,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
upstream, capture := newWebExperimentalUpstream(t, tt.responseBody)
defer upstream.Close()
original := cloudProxyBaseURL
cloudProxyBaseURL = upstream.URL
t.Cleanup(func() { cloudProxyBaseURL = original })
s := &Server{}
router, err := s.GenerateRoutes(nil)
if err != nil {
t.Fatal(err)
}
local := httptest.NewServer(router)
defer local.Close()
req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, local.URL+tt.localPath, bytes.NewBufferString(tt.requestBody))
if err != nil {
t.Fatal(err)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer should-forward")
req.Header.Set("X-Test-Header", "web-experimental")
resp, err := local.Client().Do(req)
if err != nil {
t.Fatal(err)
}
defer resp.Body.Close()
body, _ := io.ReadAll(resp.Body)
if resp.StatusCode != http.StatusOK {
t.Fatalf("expected status 200, got %d (%s)", resp.StatusCode, string(body))
}
if capture.path != tt.upstreamPath {
t.Fatalf("expected upstream path %q, got %q", tt.upstreamPath, capture.path)
}
if !bytes.Contains([]byte(capture.body), []byte(tt.assertBody)) {
t.Fatalf("expected upstream body to contain %q, got %q", tt.assertBody, capture.body)
}
if got := capture.header.Get("Authorization"); got != "Bearer should-forward" {
t.Fatalf("expected forwarded Authorization header, got %q", got)
}
if got := capture.header.Get("X-Test-Header"); got != "web-experimental" {
t.Fatalf("expected forwarded X-Test-Header=web-experimental, got %q", got)
}
})
}
}
func TestExperimentalWebEndpointsMissingBody(t *testing.T) {
gin.SetMode(gin.TestMode)
setTestHome(t, t.TempDir())
s := &Server{}
router, err := s.GenerateRoutes(nil)
if err != nil {
t.Fatal(err)
}
local := httptest.NewServer(router)
defer local.Close()
tests := []string{
"/api/experimental/web_search",
"/api/experimental/web_fetch",
}
for _, path := range tests {
t.Run(path, func(t *testing.T) {
req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, local.URL+path, nil)
if err != nil {
t.Fatal(err)
}
req.Header.Set("Content-Type", "application/json")
resp, err := local.Client().Do(req)
if err != nil {
t.Fatal(err)
}
defer resp.Body.Close()
body, _ := io.ReadAll(resp.Body)
if resp.StatusCode != http.StatusBadRequest {
t.Fatalf("expected status 400, got %d (%s)", resp.StatusCode, string(body))
}
if string(body) != `{"error":"missing request body"}` {
t.Fatalf("unexpected response body: %s", string(body))
}
})
}
}
func TestExperimentalWebEndpointsCloudDisabled(t *testing.T) {
gin.SetMode(gin.TestMode)
setTestHome(t, t.TempDir())
t.Setenv("OLLAMA_NO_CLOUD", "1")
s := &Server{}
router, err := s.GenerateRoutes(nil)
if err != nil {
t.Fatal(err)
}
local := httptest.NewServer(router)
defer local.Close()
tests := []struct {
name string
path string
request string
operation string
}{
{
name: "web_search",
path: "/api/experimental/web_search",
request: `{"query":"latest ollama release"}`,
operation: cloudErrWebSearchUnavailable,
},
{
name: "web_fetch",
path: "/api/experimental/web_fetch",
request: `{"url":"https://ollama.com"}`,
operation: cloudErrWebFetchUnavailable,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, local.URL+tt.path, bytes.NewBufferString(tt.request))
if err != nil {
t.Fatal(err)
}
req.Header.Set("Content-Type", "application/json")
resp, err := local.Client().Do(req)
if err != nil {
t.Fatal(err)
}
defer resp.Body.Close()
body, _ := io.ReadAll(resp.Body)
if resp.StatusCode != http.StatusForbidden {
t.Fatalf("expected status 403, got %d (%s)", resp.StatusCode, string(body))
}
var got map[string]string
if err := json.Unmarshal(body, &got); err != nil {
t.Fatalf("expected json error body, got: %q", string(body))
}
if got["error"] != internalcloud.DisabledError(tt.operation) {
t.Fatalf("unexpected error message: %q", got["error"])
}
})
}
}
func TestExperimentalWebEndpointSigningFailureReturnsUnauthorized(t *testing.T) {
gin.SetMode(gin.TestMode)
setTestHome(t, t.TempDir())
origSignRequest := cloudProxySignRequest
origSigninURL := cloudProxySigninURL
cloudProxySignRequest = func(context.Context, *http.Request) error {
return errors.New("ssh: no key found")
}
cloudProxySigninURL = func() (string, error) {
return "https://ollama.com/signin/example", nil
}
t.Cleanup(func() {
cloudProxySignRequest = origSignRequest
cloudProxySigninURL = origSigninURL
})
s := &Server{}
router, err := s.GenerateRoutes(nil)
if err != nil {
t.Fatal(err)
}
local := httptest.NewServer(router)
defer local.Close()
req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, local.URL+"/api/experimental/web_search", bytes.NewBufferString(`{"query":"hello"}`))
if err != nil {
t.Fatal(err)
}
req.Header.Set("Content-Type", "application/json")
resp, err := local.Client().Do(req)
if err != nil {
t.Fatal(err)
}
defer resp.Body.Close()
body, _ := io.ReadAll(resp.Body)
if resp.StatusCode != http.StatusUnauthorized {
t.Fatalf("expected status 401, got %d (%s)", resp.StatusCode, string(body))
}
var got map[string]any
if err := json.Unmarshal(body, &got); err != nil {
t.Fatalf("expected json error body, got: %q", string(body))
}
if got["error"] != "unauthorized" {
t.Fatalf("unexpected error message: %v", got["error"])
}
if got["signin_url"] != "https://ollama.com/signin/example" {
t.Fatalf("unexpected signin_url: %v", got["signin_url"])
}
}
func TestExperimentalWebEndpointSigningFailureWithoutSigninURL(t *testing.T) {
gin.SetMode(gin.TestMode)
setTestHome(t, t.TempDir())
origSignRequest := cloudProxySignRequest
origSigninURL := cloudProxySigninURL
cloudProxySignRequest = func(context.Context, *http.Request) error {
return errors.New("ssh: no key found")
}
cloudProxySigninURL = func() (string, error) {
return "", errors.New("key missing")
}
t.Cleanup(func() {
cloudProxySignRequest = origSignRequest
cloudProxySigninURL = origSigninURL
})
s := &Server{}
router, err := s.GenerateRoutes(nil)
if err != nil {
t.Fatal(err)
}
local := httptest.NewServer(router)
defer local.Close()
req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, local.URL+"/api/experimental/web_fetch", bytes.NewBufferString(`{"url":"https://ollama.com"}`))
if err != nil {
t.Fatal(err)
}
req.Header.Set("Content-Type", "application/json")
resp, err := local.Client().Do(req)
if err != nil {
t.Fatal(err)
}
defer resp.Body.Close()
body, _ := io.ReadAll(resp.Body)
if resp.StatusCode != http.StatusUnauthorized {
t.Fatalf("expected status 401, got %d (%s)", resp.StatusCode, string(body))
}
var got map[string]any
if err := json.Unmarshal(body, &got); err != nil {
t.Fatalf("expected json error body, got: %q", string(body))
}
if got["error"] != "unauthorized" {
t.Fatalf("unexpected error message: %v", got["error"])
}
if _, ok := got["signin_url"]; ok {
t.Fatalf("did not expect signin_url when helper fails, got %v", got["signin_url"])
}
}

View File

@@ -33,7 +33,6 @@ type LlmRequest struct {
successCh chan *runnerRef
errCh chan error
schedAttempts uint
useImagegen bool
}
type Scheduler struct {
@@ -106,7 +105,7 @@ func schedulerModelKey(m *Model) string {
}
// context must be canceled to decrement ref count and release the runner
func (s *Scheduler) GetRunner(c context.Context, m *Model, opts api.Options, sessionDuration *api.Duration, useImagegen bool) (chan *runnerRef, chan error) {
func (s *Scheduler) GetRunner(c context.Context, m *Model, opts api.Options, sessionDuration *api.Duration) (chan *runnerRef, chan error) {
if opts.NumCtx < 4 {
opts.NumCtx = 4
}
@@ -123,7 +122,6 @@ func (s *Scheduler) GetRunner(c context.Context, m *Model, opts api.Options, ses
sessionDuration: sessionDuration,
successCh: make(chan *runnerRef, 1),
errCh: make(chan error, 1),
useImagegen: useImagegen,
}
key := schedulerModelKey(req.model)
@@ -231,7 +229,7 @@ func (s *Scheduler) processPending(ctx context.Context) {
}
// Check for experimental safetensors LLM models
if pending.model.Config.ModelFormat == "safetensors" {
if pending.model.IsMLX() {
if slices.Contains(pending.model.Config.Capabilities, "completion") {
// LLM model with safetensors format - use MLX runner
if s.loadMLX(pending) {
@@ -536,6 +534,7 @@ iGPUScan:
}
}
totalSize, vramSize := llama.MemorySize()
runner := &runnerRef{
model: req.model,
modelPath: req.model.ModelPath,
@@ -545,8 +544,8 @@ iGPUScan:
sessionDuration: sessionDuration,
gpus: gpuIDs,
discreteGPUs: discreteGPUs,
vramSize: llama.VRAMSize(),
totalSize: llama.TotalSize(),
totalSize: totalSize,
vramSize: vramSize,
loading: true,
pid: llama.Pid(),
}
@@ -592,20 +591,15 @@ iGPUScan:
return false
}
// loadMLX loads an experimental safetensors model using the unified MLX runner.
// This supports both LLM (completion) and image generation models.
// loadMLX loads an experimental safetensors model using MLX runners.
// Image models use x/imagegen; LLM models use x/mlxrunner.
func (s *Scheduler) loadMLX(req *LlmRequest) bool {
modelName := req.model.ShortName
var server llm.LlamaServer
var err error
isImagegen := false
if slices.Contains(req.model.Config.Capabilities, "image") {
server, err = imagegen.NewServer(modelName, imagegen.ModeImageGen)
isImagegen = true
} else if req.useImagegen {
server, err = imagegen.NewServer(modelName, imagegen.ModeLLM)
isImagegen = true
server, err = imagegen.NewServer(modelName)
} else {
server, err = mlxrunner.NewClient(modelName)
}
@@ -619,6 +613,7 @@ func (s *Scheduler) loadMLX(req *LlmRequest) bool {
sessionDuration = req.sessionDuration.Duration
}
totalSize, vramSize := server.MemorySize()
runner := &runnerRef{
model: req.model,
modelPath: req.model.ModelPath,
@@ -626,10 +621,10 @@ func (s *Scheduler) loadMLX(req *LlmRequest) bool {
llama: server,
Options: &req.opts,
loading: false,
isImagegen: isImagegen,
isImagegen: slices.Contains(req.model.Config.Capabilities, "image"),
sessionDuration: sessionDuration,
totalSize: server.TotalSize(),
vramSize: server.VRAMSize(),
totalSize: totalSize,
vramSize: vramSize,
}
s.loadedMu.Lock()
@@ -735,8 +730,8 @@ func (runner *runnerRef) needsReload(ctx context.Context, req *LlmRequest) bool
runner.refMu.Lock()
defer runner.refMu.Unlock()
// Check if runner type (imagegen vs mlxrunner) matches what's requested
wantImagegen := req.useImagegen || slices.Contains(req.model.Config.Capabilities, "image")
// Check if runner type (imagegen vs mlxrunner) matches what's requested.
wantImagegen := slices.Contains(req.model.Config.Capabilities, "image")
if runner.isImagegen != wantImagegen {
return true
}
@@ -762,7 +757,7 @@ func (runner *runnerRef) needsReload(ctx context.Context, req *LlmRequest) bool
defer cancel()
if !reflect.DeepEqual(runner.model.AdapterPaths, req.model.AdapterPaths) || // have the adapters changed?
!reflect.DeepEqual(runner.model.ProjectorPaths, req.model.ProjectorPaths) || // have the projectors changed?
!reflect.DeepEqual(optsExisting, optsNew) || // have the runner options changed?
(!runner.model.IsMLX() && !reflect.DeepEqual(optsExisting, optsNew)) || // have the runner options changed?
runner.llama.Ping(ctx) != nil {
return true
}

View File

@@ -408,10 +408,10 @@ func TestSchedGetRunner(t *testing.T) {
s.getSystemInfoFn = getSystemInfoFn
s.newServerFn = a.newServer
slog.Info("a")
successCh1a, errCh1a := s.GetRunner(a.ctx, a.req.model, a.req.opts, a.req.sessionDuration, false)
successCh1a, errCh1a := s.GetRunner(a.ctx, a.req.model, a.req.opts, a.req.sessionDuration)
require.Len(t, s.pendingReqCh, 1)
slog.Info("b")
successCh1b, errCh1b := s.GetRunner(b.ctx, b.req.model, b.req.opts, b.req.sessionDuration, false)
successCh1b, errCh1b := s.GetRunner(b.ctx, b.req.model, b.req.opts, b.req.sessionDuration)
require.Len(t, s.pendingReqCh, 1)
require.Empty(t, successCh1b)
require.Len(t, errCh1b, 1)
@@ -435,7 +435,7 @@ func TestSchedGetRunner(t *testing.T) {
c.req.model.ModelPath = "bad path"
slog.Info("c")
successCh1c, errCh1c := s.GetRunner(c.ctx, c.req.model, c.req.opts, c.req.sessionDuration, false)
successCh1c, errCh1c := s.GetRunner(c.ctx, c.req.model, c.req.opts, c.req.sessionDuration)
// Starts in pending channel, then should be quickly processed to return an error
time.Sleep(50 * time.Millisecond) // Long enough for the "a" model to expire and unload
require.Empty(t, successCh1c)
@@ -470,7 +470,7 @@ func TestSchedGetRunnerUsesDigestKeyWhenModelPathEmpty(t *testing.T) {
s.loadedMu.Unlock()
reqModel := &Model{Name: "safetensors-b", Digest: "sha-b"}
successCh, errCh := s.GetRunner(ctx, reqModel, opts, nil, false)
successCh, errCh := s.GetRunner(ctx, reqModel, opts, nil)
require.Empty(t, successCh)
require.Empty(t, errCh)
@@ -499,7 +499,7 @@ func TestSchedGetRunnerReusesSameDigestWhenModelPathEmpty(t *testing.T) {
s.loadedMu.Unlock()
reqCtx, cancelReq := context.WithCancel(ctx)
successCh, errCh := s.GetRunner(reqCtx, &Model{Name: "safetensors-a-copy", Digest: "sha-a"}, opts, nil, false)
successCh, errCh := s.GetRunner(reqCtx, &Model{Name: "safetensors-a-copy", Digest: "sha-a"}, opts, nil)
cancelReq()
select {
@@ -574,7 +574,7 @@ func TestSchedPrematureExpired(t *testing.T) {
s.getGpuFn = getGpuFn
s.getSystemInfoFn = getSystemInfoFn
s.newServerFn = scenario1a.newServer
successCh1a, errCh1a := s.GetRunner(scenario1a.ctx, scenario1a.req.model, scenario1a.req.opts, scenario1a.req.sessionDuration, false)
successCh1a, errCh1a := s.GetRunner(scenario1a.ctx, scenario1a.req.model, scenario1a.req.opts, scenario1a.req.sessionDuration)
require.Len(t, s.pendingReqCh, 1)
s.Run(ctx)
select {
@@ -861,8 +861,7 @@ func (s *mockLlm) Close() error {
s.closeCalled = true
return s.closeResp
}
func (s *mockLlm) VRAMSize() uint64 { return s.vramSize }
func (s *mockLlm) TotalSize() uint64 { return s.totalSize }
func (s *mockLlm) MemorySize() (uint64, uint64) { return s.totalSize, s.vramSize }
func (s *mockLlm) VRAMByGPU(id ml.DeviceID) uint64 { return s.vramByGPU[id] }
func (s *mockLlm) Pid() int { return -1 }
func (s *mockLlm) GetPort() int { return -1 }
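Aside: these hunks collapse the former TotalSize()/VRAMSize() accessors into a single MemorySize() returning (total, vram). A small self-contained sketch of the new shape, with fakeServer standing in for llm.LlamaServer:

package main

import "fmt"

// memoryReporter is a trimmed-down stand-in for the runner interface after
// the change above: one MemorySize() (total, vram) call replaces the former
// TotalSize() and VRAMSize() methods.
type memoryReporter interface {
	MemorySize() (uint64, uint64)
}

type fakeServer struct{ total, vram uint64 }

func (f fakeServer) MemorySize() (uint64, uint64) { return f.total, f.vram }

func main() {
	var s memoryReporter = fakeServer{total: 8 << 30, vram: 6 << 30}
	total, vram := s.MemorySize()
	fmt.Printf("total=%d vram=%d\n", total, vram)
}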

View File

@@ -20,6 +20,7 @@ import (
"github.com/ollama/ollama/api"
internalcloud "github.com/ollama/ollama/internal/cloud"
"github.com/ollama/ollama/internal/modelref"
"github.com/ollama/ollama/progress"
"github.com/ollama/ollama/readline"
"github.com/ollama/ollama/types/model"
@@ -43,7 +44,7 @@ const (
// isLocalModel checks if the model is running locally (not a cloud model).
// TODO: Improve local/cloud model identification - could check model metadata
func isLocalModel(modelName string) bool {
return !strings.HasSuffix(modelName, "-cloud")
return !modelref.HasExplicitCloudSource(modelName)
}
// isLocalServer checks if connecting to a local Ollama server.
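Aside: an approximation of modelref.HasExplicitCloudSource matching the test table below, where both the legacy "-cloud" tag suffix and the newer ":cloud" source suffix count as cloud. This is an assumption about the helper's behavior, not its actual implementation:

package main

import (
	"fmt"
	"strings"
)

// hasExplicitCloudSource is a hypothetical sketch: a name is treated as
// cloud-hosted if it carries either cloud suffix form.
func hasExplicitCloudSource(name string) bool {
	return strings.HasSuffix(name, "-cloud") || strings.HasSuffix(name, ":cloud")
}

func main() {
	for _, n := range []string{"llama3.2", "gpt-oss:20b-cloud", "gpt-oss:20b:cloud"} {
		fmt.Println(n, !hasExplicitCloudSource(n)) // isLocalModel equivalent
	}
}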

View File

@@ -22,12 +22,22 @@ func TestIsLocalModel(t *testing.T) {
},
{
name: "cloud model",
modelName: "gpt-4-cloud",
modelName: "gpt-oss:latest-cloud",
expected: false,
},
{
name: "cloud model with :cloud suffix",
modelName: "gpt-oss:cloud",
expected: false,
},
{
name: "cloud model with version",
modelName: "claude-3-cloud",
modelName: "gpt-oss:20b-cloud",
expected: false,
},
{
name: "cloud model with version and :cloud suffix",
modelName: "gpt-oss:20b:cloud",
expected: false,
},
{
@@ -134,7 +144,7 @@ func TestTruncateToolOutput(t *testing.T) {
{
name: "long output cloud model - uses 10k limit",
output: string(localLimitOutput), // 20k chars, under 10k token limit
modelName: "gpt-4-cloud",
modelName: "gpt-oss:latest-cloud",
host: "",
shouldTrim: false,
expectedLimit: defaultTokenLimit,
@@ -142,7 +152,7 @@ func TestTruncateToolOutput(t *testing.T) {
{
name: "very long output cloud model - trimmed at 10k",
output: string(defaultLimitOutput),
modelName: "gpt-4-cloud",
modelName: "gpt-oss:latest-cloud",
host: "",
shouldTrim: true,
expectedLimit: defaultTokenLimit,

View File

@@ -13,9 +13,12 @@ import (
"io"
"os"
"path/filepath"
"slices"
"strings"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/manifest"
"github.com/ollama/ollama/parser"
"github.com/ollama/ollama/progress"
"github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/x/create"
@@ -27,11 +30,79 @@ const MinOllamaVersion = "0.14.0"
// ModelfileConfig holds configuration extracted from a Modelfile.
type ModelfileConfig struct {
Template string
System string
License string
Parser string
Renderer string
Template string
System string
License string
Parser string
Renderer string
Parameters map[string]any
}
var ignoredModelfileParameters = []string{
"penalize_newline",
"low_vram",
"f16_kv",
"logits_all",
"vocab_only",
"use_mlock",
"mirostat",
"mirostat_tau",
"mirostat_eta",
}
// ConfigFromModelfile extracts the model directory and x/create-specific
// Modelfile configuration from a parsed Modelfile.
func ConfigFromModelfile(modelfile *parser.Modelfile) (string, *ModelfileConfig, error) {
var modelDir string
mfConfig := &ModelfileConfig{}
for _, cmd := range modelfile.Commands {
switch cmd.Name {
case "model":
modelDir = cmd.Args
case "template":
mfConfig.Template = cmd.Args
case "system":
mfConfig.System = cmd.Args
case "license":
mfConfig.License = cmd.Args
case "parser":
mfConfig.Parser = cmd.Args
case "renderer":
mfConfig.Renderer = cmd.Args
case "adapter", "message", "requires":
continue
default:
if slices.Contains(ignoredModelfileParameters, cmd.Name) {
continue
}
ps, err := api.FormatParams(map[string][]string{cmd.Name: {cmd.Args}})
if err != nil {
return "", nil, err
}
if mfConfig.Parameters == nil {
mfConfig.Parameters = make(map[string]any)
}
for k, v := range ps {
if ks, ok := mfConfig.Parameters[k].([]string); ok {
mfConfig.Parameters[k] = append(ks, v.([]string)...)
} else if vs, ok := v.([]string); ok {
mfConfig.Parameters[k] = vs
} else {
mfConfig.Parameters[k] = v
}
}
}
}
if modelDir == "" {
modelDir = "."
}
return modelDir, mfConfig, nil
}
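Aside: the merge loop above lets repeated slice-valued parameters (such as multiple PARAMETER stop lines) accumulate while scalar keys simply overwrite. A runnable sketch of just that rule:

package main

import "fmt"

func main() {
	params := map[string]any{}
	merge := func(k string, v any) {
		if ks, ok := params[k].([]string); ok {
			params[k] = append(ks, v.([]string)...)
		} else if vs, ok := v.([]string); ok {
			params[k] = vs
		} else {
			params[k] = v
		}
	}
	merge("stop", []string{"USER:"})
	merge("stop", []string{"ASSISTANT:"})
	merge("temperature", float32(0.7))
	fmt.Println(params) // map[stop:[USER: ASSISTANT:] temperature:0.7]
}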
// CreateOptions holds all options for model creation.
@@ -39,7 +110,7 @@ type CreateOptions struct {
ModelName string
ModelDir string
Quantize string // "int4", "int8", "nvfp4", or "mxfp8" for quantization
Modelfile *ModelfileConfig // template/system/license/parser/renderer from Modelfile
Modelfile *ModelfileConfig // template/system/license/parser/renderer/parameters from Modelfile
}
// CreateModel imports a model from a local directory.
@@ -351,6 +422,19 @@ func createModelfileLayers(mf *ModelfileConfig) ([]manifest.Layer, error) {
layers = append(layers, layer)
}
if len(mf.Parameters) > 0 {
var b bytes.Buffer
if err := json.NewEncoder(&b).Encode(mf.Parameters); err != nil {
return nil, fmt.Errorf("failed to encode parameters: %w", err)
}
layer, err := manifest.NewLayer(&b, "application/vnd.ollama.image.params")
if err != nil {
return nil, fmt.Errorf("failed to create params layer: %w", err)
}
layers = append(layers, layer)
}
return layers, nil
}
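Aside: the params layer is simply the Parameters map JSON-encoded under the media type application/vnd.ollama.image.params. For the fixture used in the tests below, the blob would look like this (json.Marshal emits map keys in sorted order, so the output is deterministic):

package main

import (
	"encoding/json"
	"fmt"
)

func main() {
	b, err := json.Marshal(map[string]any{
		"temperature": float32(0.7),
		"stop":        []string{"USER:", "ASSISTANT:"},
	})
	if err != nil {
		panic(err)
	}
	fmt.Println(string(b)) // {"stop":["USER:","ASSISTANT:"],"temperature":0.7}
}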

View File

@@ -1,7 +1,13 @@
package client
import (
"encoding/json"
"os"
"strings"
"testing"
"github.com/ollama/ollama/manifest"
"github.com/ollama/ollama/parser"
)
func TestModelfileConfig(t *testing.T) {
@@ -31,6 +37,40 @@ func TestModelfileConfig(t *testing.T) {
}
}
func TestConfigFromModelfile(t *testing.T) {
modelfile, err := parser.ParseFile(strings.NewReader(`
FROM ./model
TEMPLATE {{ .Prompt }}
PARAMETER temperature 0.7
PARAMETER stop USER:
PARAMETER stop ASSISTANT:
`))
if err != nil {
t.Fatal(err)
}
modelDir, mfConfig, err := ConfigFromModelfile(modelfile)
if err != nil {
t.Fatal(err)
}
if modelDir != "./model" {
t.Fatalf("modelDir = %q, want %q", modelDir, "./model")
}
if mfConfig.Template != "{{ .Prompt }}" {
t.Fatalf("Template = %q, want %q", mfConfig.Template, "{{ .Prompt }}")
}
if got := mfConfig.Parameters["temperature"]; got != float32(0.7) {
t.Fatalf("temperature = %#v, want %v", got, float32(0.7))
}
if got := mfConfig.Parameters["stop"]; got == nil || len(got.([]string)) != 2 {
t.Fatalf("unexpected stop params: %#v", got)
}
}
func TestModelfileConfig_Empty(t *testing.T) {
config := &ModelfileConfig{}
@@ -120,6 +160,9 @@ func TestCreateOptions(t *testing.T) {
License: "MIT",
Parser: "qwen3-thinking",
Renderer: "qwen3",
Parameters: map[string]any{
"temperature": float32(0.7),
},
},
}
@@ -144,6 +187,9 @@ func TestCreateOptions(t *testing.T) {
if opts.Modelfile.Renderer != "qwen3" {
t.Errorf("Modelfile.Renderer = %q, want %q", opts.Modelfile.Renderer, "qwen3")
}
if opts.Modelfile.Parameters["temperature"] != float32(0.7) {
t.Errorf("Modelfile.Parameters[temperature] = %v, want %v", opts.Modelfile.Parameters["temperature"], float32(0.7))
}
}
func TestResolveParserName(t *testing.T) {
@@ -252,3 +298,44 @@ func TestQuantizeSupported(t *testing.T) {
// We can't easily test both cases, so just verify it returns something
_ = supported
}
func TestCreateModelfileLayersIncludesParameters(t *testing.T) {
t.Setenv("OLLAMA_MODELS", t.TempDir())
layers, err := createModelfileLayers(&ModelfileConfig{
Parameters: map[string]any{
"temperature": float32(0.7),
"stop": []string{"USER:", "ASSISTANT:"},
},
})
if err != nil {
t.Fatal(err)
}
if len(layers) != 1 {
t.Fatalf("len(layers) = %d, want 1", len(layers))
}
if layers[0].MediaType != "application/vnd.ollama.image.params" {
t.Fatalf("MediaType = %q, want %q", layers[0].MediaType, "application/vnd.ollama.image.params")
}
blobPath, err := manifest.BlobsPath(layers[0].Digest)
if err != nil {
t.Fatal(err)
}
data, err := os.ReadFile(blobPath)
if err != nil {
t.Fatal(err)
}
var got map[string]any
if err := json.Unmarshal(data, &got); err != nil {
t.Fatal(err)
}
if got["temperature"] != float64(0.7) {
t.Fatalf("temperature = %v, want %v", got["temperature"], float64(0.7))
}
}

View File

@@ -1,5 +1,3 @@
//go:build mlx
package client
import (
@@ -21,7 +19,7 @@ var quantizeParams = map[string]struct {
bits int
mode string
}{
"int4": {32, 4, "affine"},
"int4": {64, 4, "affine"},
"nvfp4": {16, 4, "nvfp4"},
"int8": {64, 8, "affine"},
"mxfp8": {32, 8, "mxfp8"},
@@ -194,9 +192,10 @@ func quantizePackedGroup(inputs []create.PackedTensorInput) ([]byte, error) {
return blobData, nil
}
// QuantizeSupported returns true if quantization is supported (MLX build)
// QuantizeSupported returns true if quantization is supported (MLX library available)
func QuantizeSupported() bool {
return true
mlx.InitMLX()
return mlx.IsMLXAvailable()
}
// ensureTempDir creates the temp directory for quantization if it doesn't exist

View File

@@ -1,25 +0,0 @@
//go:build !mlx
package client
import (
"fmt"
"io"
"github.com/ollama/ollama/x/create"
)
// quantizeTensor is not available without MLX
func quantizeTensor(r io.Reader, tensorName, dtype string, shape []int32, quantize string) (blobData []byte, err error) {
return nil, fmt.Errorf("quantization requires MLX support (build with mlx tag)")
}
// quantizePackedGroup is not available without MLX
func quantizePackedGroup(inputs []create.PackedTensorInput) ([]byte, error) {
return nil, fmt.Errorf("quantization requires MLX support (build with mlx tag)")
}
// QuantizeSupported returns false when MLX is not available
func QuantizeSupported() bool {
return false
}

View File

@@ -288,6 +288,18 @@ func normalizeQuantType(quantize string) string {
}
}
func isStackedExpertWeight(name string) bool {
// Combined/stacked expert tensors may be emitted either as "...proj.weight" (per-expert)
// or "...proj" (pre-stacked packed tensor).
if strings.HasSuffix(name, ".bias") || strings.HasSuffix(name, ".scale") || strings.HasSuffix(name, ".qbias") {
return false
}
return strings.Contains(name, ".mlp.switch_mlp.") ||
strings.Contains(name, ".mlp.experts.") ||
strings.Contains(name, ".mlp.shared_experts.")
}
// GetTensorQuantization returns the appropriate quantization type for a tensor.
// Returns "" if the tensor should not be quantized.
// This implements mixed-precision quantization:
@@ -296,18 +308,25 @@ func normalizeQuantType(quantize string) string {
// - Down projection weights: int8 (more sensitive, would be Q6 in GGML but no MLX kernel)
// - Norms, embeddings, biases, routing gates: no quantization
func GetTensorQuantization(name string, shape []int32, quantize string) string {
stackedExpert := isStackedExpertWeight(name)
// Use basic name-based check first
if !ShouldQuantize(name, "") {
if !stackedExpert && !ShouldQuantize(name, "") {
return ""
}
// Only quantize 2D tensors (linear layers) - skip 1D (biases, norms) and higher-D (convolutions if any)
if len(shape) != 2 {
// Quantize standard linear weights (2D). Also allow stacked expert weights (3D),
// e.g. qwen switch_mlp / experts combined tensors.
if len(shape) != 2 && !(len(shape) == 3 && stackedExpert) {
return ""
}
// Skip small tensors (less than 1024 elements) - not worth quantizing
if len(shape) >= 2 && int64(shape[0])*int64(shape[1]) < 1024 {
var elems int64 = 1
for _, d := range shape {
elems *= int64(d)
}
if elems < 1024 {
return ""
}
@@ -315,12 +334,12 @@ func GetTensorQuantization(name string, shape []int32, quantize string) string {
quantNorm := normalizeQuantType(quantize)
// MLX quantization requires last dimension to be divisible by group size
// nvfp4: 16, int4/mxfp8: 32, int8: 64
// nvfp4: 16, mxfp8: 32, int4/int8: 64
groupSize := int32(32)
switch quantNorm {
case "nvfp4":
groupSize = 16
case "int8":
case "int4", "int8":
groupSize = 64
}
if shape[len(shape)-1]%groupSize != 0 {
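Aside: a worked check of the divisibility rule above, now that int4 shares group size 64 with int8 per the quantizeParams change earlier in this diff:

package main

import "fmt"

func main() {
	// A stacked expert tensor [64, 4096, 14336] has last dim
	// 14336 = 64*224, so it quantizes; a hypothetical last dim of
	// 1000 (1000 % 64 != 0) would be skipped.
	fmt.Println(14336%64 == 0, 1000%64 == 0) // true false
}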

View File

@@ -557,6 +557,10 @@ func TestShouldQuantizeTensor(t *testing.T) {
// 3D+ tensors should not be quantized
{"3D tensor", "conv.weight", []int32{64, 64, 3}, "fp8", false},
{"4D tensor", "conv2d.weight", []int32{64, 64, 3, 3}, "fp8", false},
{"stacked expert switch_mlp gate_up 3D int8", "model.layers.1.mlp.switch_mlp.gate_up_proj.weight", []int32{64, 22016, 4096}, "int8", true},
{"stacked expert experts down_proj 3D int8", "model.layers.1.mlp.experts.down_proj.weight", []int32{64, 4096, 14336}, "int8", true},
{"stacked expert combined gate_up 3D int8", "model.language_model.layers.0.mlp.experts.gate_up_proj", []int32{256, 1024, 2048}, "int8", true},
{"stacked expert combined down_proj 3D int8", "model.language_model.layers.0.mlp.experts.down_proj", []int32{256, 2048, 512}, "int8", true},
// Embeddings should not be quantized regardless of shape
{"embedding 2D", "embed_tokens.weight", []int32{32000, 4096}, "fp8", false},
@@ -619,6 +623,44 @@ func TestExpertGroupPrefix(t *testing.T) {
}
}
func TestGetTensorQuantization_StackedExpert3D(t *testing.T) {
gateUp := GetTensorQuantization(
"model.layers.1.mlp.switch_mlp.gate_up_proj.weight",
[]int32{64, 22016, 4096},
"int4",
)
if gateUp != "int4" {
t.Fatalf("gate_up_proj quantization = %q, want %q", gateUp, "int4")
}
down := GetTensorQuantization(
"model.layers.1.mlp.experts.down_proj.weight",
[]int32{64, 4096, 14336},
"int4",
)
if down != "int8" {
t.Fatalf("down_proj quantization = %q, want %q", down, "int8")
}
combinedGateUp := GetTensorQuantization(
"model.language_model.layers.0.mlp.experts.gate_up_proj",
[]int32{256, 1024, 2048},
"int8",
)
if combinedGateUp != "int8" {
t.Fatalf("combined gate_up_proj quantization = %q, want %q", combinedGateUp, "int8")
}
combinedDown := GetTensorQuantization(
"model.language_model.layers.0.mlp.experts.down_proj",
[]int32{256, 2048, 512},
"int4",
)
if combinedDown != "int8" {
t.Fatalf("combined down_proj quantization = %q, want %q", combinedDown, "int8")
}
}
func TestCreateSafetensorsModel_WithQuantize(t *testing.T) {
dir := t.TempDir()

View File

@@ -1,5 +1,3 @@
//go:build mlx
package cache
import "github.com/ollama/ollama/x/imagegen/mlx"

View File

@@ -1,5 +1,3 @@
//go:build mlx
package cache
import "github.com/ollama/ollama/x/imagegen/mlx"

View File

@@ -1,5 +1,3 @@
//go:build mlx
// Package cache provides caching mechanisms for diffusion model inference.
package cache

View File

@@ -5,22 +5,12 @@ Experimental MLX backend for running models on Apple Silicon and CUDA.
## Build
```bash
go build -tags mlx -o engine ./x/imagegen/cmd/engine
go build -o engine ./x/imagegen/cmd/engine
```
## Text Generation
```bash
./engine -model /path/to/model -prompt "Hello" -max-tokens 100
```
Options:
- `-temperature` - sampling temperature (default 0.7)
- `-top-p` - nucleus sampling (default 0.9)
- `-top-k` - top-k sampling (default 40)
Supports: Llama, Gemma3, GPT-OSS
Text generation models are no longer supported by this engine.
## Image Generation

View File

@@ -1,5 +1,3 @@
//go:build mlx
package main
import (

View File

@@ -1,5 +1,3 @@
//go:build mlx
package main
import (

View File

@@ -1,5 +1,3 @@
//go:build mlx
package main
import (
@@ -18,9 +16,6 @@ import (
"github.com/ollama/ollama/x/imagegen"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/models/flux2"
"github.com/ollama/ollama/x/imagegen/models/gemma3"
"github.com/ollama/ollama/x/imagegen/models/gpt_oss"
"github.com/ollama/ollama/x/imagegen/models/llama"
"github.com/ollama/ollama/x/imagegen/models/zimage"
"github.com/ollama/ollama/x/imagegen/safetensors"
)
@@ -170,11 +165,11 @@ func main() {
log.Fatal(err)
}
// Load image if provided and model supports it
// Load image if provided and model supports it.
var image *mlx.Array
if *imagePath != "" {
if mm, ok := m.(interface{ ImageSize() int32 }); ok {
image, err = gemma3.ProcessImage(*imagePath, mm.ImageSize())
image, err = imagegen.ProcessImage(*imagePath, mm.ImageSize())
if err != nil {
log.Fatal("load image:", err)
}
@@ -236,14 +231,8 @@ func load(modelPath string) (Model, error) {
}
switch kind {
case "gpt_oss":
return gpt_oss.Load(modelPath)
case "gemma3":
return gemma3.Load(modelPath)
case "gemma3_text":
return gemma3.LoadText(modelPath)
default:
return llama.Load(modelPath)
return nil, fmt.Errorf("model type %q is not supported by x/imagegen/cmd/engine", kind)
}
}

View File

@@ -1,5 +1,3 @@
//go:build mlx
package main
import "github.com/ollama/ollama/x/imagegen/mlx"

View File

@@ -1,5 +1,3 @@
//go:build mlx
package imagegen
import (

View File

@@ -1,6 +1,4 @@
//go:build mlx
package gemma3
package imagegen
import (
"fmt"
@@ -13,8 +11,8 @@ import (
"golang.org/x/image/draw"
)
// ProcessImage loads and preprocesses an image for the vision tower
// Returns [1, H, W, C] tensor in NHWC format normalized for SigLIP
// ProcessImage loads and preprocesses an image for multimodal vision towers.
// Returns [1, H, W, C] tensor in NHWC format normalized for SigLIP.
func ProcessImage(path string, imageSize int32) (*mlx.Array, error) {
f, err := os.Open(path)
if err != nil {
@@ -30,20 +28,20 @@ func ProcessImage(path string, imageSize int32) (*mlx.Array, error) {
return ProcessImageData(img, imageSize)
}
// ProcessImageData preprocesses an image.Image for the vision tower
// ProcessImageData preprocesses an image.Image for multimodal vision towers.
func ProcessImageData(img image.Image, imageSize int32) (*mlx.Array, error) {
// Resize to target size using bilinear interpolation
// Resize to target size using bilinear interpolation.
resized := image.NewRGBA(image.Rect(0, 0, int(imageSize), int(imageSize)))
draw.BiLinear.Scale(resized, resized.Bounds(), img, img.Bounds(), draw.Over, nil)
// Convert to float32 array [H, W, C] and normalize
// SigLIP normalization: (pixel / 255.0 - 0.5) / 0.5 = pixel / 127.5 - 1.0
// Convert to float32 array [H, W, C] and normalize.
// SigLIP normalization: (pixel / 255.0 - 0.5) / 0.5 = pixel / 127.5 - 1.0.
data := make([]float32, imageSize*imageSize*3)
idx := 0
for y := int32(0); y < imageSize; y++ {
for x := int32(0); x < imageSize; x++ {
r, g, b, _ := resized.At(int(x), int(y)).RGBA()
// RGBA returns 16-bit values, convert to 8-bit
// RGBA returns 16-bit values, convert to 8-bit.
data[idx] = float32(r>>8)/127.5 - 1.0
data[idx+1] = float32(g>>8)/127.5 - 1.0
data[idx+2] = float32(b>>8)/127.5 - 1.0
@@ -51,8 +49,8 @@ func ProcessImageData(img image.Image, imageSize int32) (*mlx.Array, error) {
}
}
// Create MLX array [1, H, W, C] for NHWC layout
// Create MLX array [1, H, W, C] for NHWC layout.
arr := mlx.NewArrayFloat32(data, []int32{1, imageSize, imageSize, 3})
mlx.Eval(arr) // Materialize to prevent use-after-free
mlx.Eval(arr) // Materialize to prevent use-after-free.
return arr, nil
}
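Aside: a quick spot-check of the normalization formula above. px/127.5 - 1 sends 0 to -1.0, 127.5 to 0.0, and 255 to 1.0:

package main

import "fmt"

func main() {
	for _, px := range []float32{0, 127.5, 255} {
		fmt.Println(px/127.5 - 1.0) // -1, 0, 1
	}
}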

View File

@@ -1,5 +1,3 @@
//go:build mlx
package imagegen
import (

View File

@@ -1,420 +0,0 @@
//go:build mlx
package imagegen
import (
"encoding/json"
"errors"
"fmt"
"log/slog"
"net/http"
"strings"
"sync"
"time"
"github.com/ollama/ollama/x/imagegen/cache"
"github.com/ollama/ollama/x/imagegen/manifest"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/models/glm4_moe_lite"
"github.com/ollama/ollama/x/imagegen/tokenizer"
)
// TextModel is the interface for LLM text generation models.
type TextModel interface {
Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array
NewCache(maxSeqLen int32) []cache.Cache
Tokenizer() *tokenizer.Tokenizer
VocabSize() int32
MaxContextLength() int32
NumLayers() int
}
// llmState holds the state for LLM generation
type llmState struct {
model TextModel
}
var llmMu sync.Mutex
// Dedicated stream for generation (like mlx-lm's generation_stream)
var generationStream *mlx.Stream
// withStream runs fn with the generation stream as default
func withStream(fn func()) {
// Lazy initialization of generationStream
if generationStream == nil {
generationStream = mlx.NewStream()
}
orig := mlx.GetDefaultStream()
mlx.SetDefaultStream(generationStream)
fn()
mlx.SetDefaultStream(orig)
}
// Decoder wraps model + cache for autoregressive generation.
// This matches the pattern from cmd/engine/generate.go
type Decoder struct {
model TextModel
caches []cache.Cache
vocabSize int32
temp float32
token *mlx.Array // Current token (kept across iterations)
oldCacheState []*mlx.Array // Preallocated slice for old cache state
}
func NewDecoder(m TextModel, temp float32) *Decoder {
caches := m.NewCache(0)
return &Decoder{
model: m,
caches: caches,
vocabSize: m.VocabSize(),
temp: temp,
oldCacheState: make([]*mlx.Array, 0, len(caches)*2),
}
}
func (d *Decoder) prefill(inputIDs []int32) int {
processed := 0
// Track old cache state to free after each chunk
var oldCacheState []*mlx.Array
// Process all-but-1 tokens in chunks, eval cache state for memory management
for len(inputIDs) > 1 {
chunkSize := min(2048, len(inputIDs)-1)
if chunkSize <= 0 {
break
}
chunk := inputIDs[:chunkSize]
// Save old cache state before forward
oldCacheState = oldCacheState[:0]
for _, c := range d.caches {
oldCacheState = append(oldCacheState, c.State()...)
}
var cacheState []*mlx.Array
withStream(func() {
x := mlx.NewArrayInt32(chunk, []int32{1, int32(len(chunk))})
d.model.Forward(x, d.caches)
for _, c := range d.caches {
cacheState = append(cacheState, c.State()...)
}
})
mlx.Eval(cacheState...)
// Free old cache state
for _, arr := range oldCacheState {
if arr != nil {
arr.Free()
}
}
inputIDs = inputIDs[chunkSize:]
processed += chunkSize
}
// Save old cache state before final step
oldCacheState = oldCacheState[:0]
for _, c := range d.caches {
oldCacheState = append(oldCacheState, c.State()...)
}
// Final token + sampling
withStream(func() {
x := mlx.NewArrayInt32(inputIDs, []int32{1, int32(len(inputIDs))})
mlx.Eval(x) // Materialize before any other evals
logits := d.model.Forward(x, d.caches)
d.token = sample(logits, d.temp, d.vocabSize)
})
// Keep cache state (token auto-kept by AsyncEval)
for _, c := range d.caches {
mlx.Keep(c.State()...)
}
mlx.AsyncEval(d.token)
// Free old cache state from before final step
for _, arr := range oldCacheState {
if arr != nil {
arr.Free()
}
}
mlx.ClearCache()
return processed + len(inputIDs)
}
func (d *Decoder) step() int32 {
prevToken := d.token
// Save old cache state (reuse preallocated slice)
d.oldCacheState = d.oldCacheState[:0]
for _, c := range d.caches {
d.oldCacheState = append(d.oldCacheState, c.State()...)
}
withStream(func() {
logits := d.model.Forward(mlx.Reshape(prevToken, 1, 1), d.caches)
d.token = sample(logits, d.temp, d.vocabSize)
})
// Keep token and new cache state so they survive cleanup
mlx.Keep(d.token)
for _, c := range d.caches {
mlx.Keep(c.State()...)
}
mlx.AsyncEval(d.token)
// Sync on previous token (GPU already working on next step)
val := prevToken.ItemInt32()
// Free old token and old cache state
prevToken.Free()
for _, arr := range d.oldCacheState {
arr.Free()
}
return val
}
// sample samples from logits using temperature scaling
func sample(logits *mlx.Array, temp float32, vocabSize int32) *mlx.Array {
// Get last position logits: [1, L, vocab] -> [vocab]
shape := logits.Shape()
seqLen := shape[1]
lastLogits := mlx.Slice(logits, []int32{0, seqLen - 1, 0}, []int32{1, seqLen, vocabSize})
lastLogits = mlx.Reshape(lastLogits, vocabSize)
if temp <= 0 || temp < 0.01 {
// Greedy decoding
return mlx.Argmax(lastLogits, -1, false)
}
// Apply temperature scaling
scaled := mlx.DivScalar(lastLogits, temp)
return mlx.RandomCategorical(scaled, -1, 1)
}
// loadLLMModel loads a safetensors LLM model and its tokenizer from manifest storage.
func (s *server) loadLLMModel() error {
// Load the manifest to get model information
modelManifest, err := manifest.LoadManifest(s.modelName)
if err != nil {
return fmt.Errorf("failed to load manifest: %w", err)
}
// Detect model architecture from config.json
configData, err := modelManifest.ReadConfig("config.json")
if err != nil {
return fmt.Errorf("failed to read config.json: %w", err)
}
var modelConfig struct {
Architectures []string `json:"architectures"`
ModelType string `json:"model_type"`
}
if err := json.Unmarshal(configData, &modelConfig); err != nil {
return fmt.Errorf("failed to parse config.json: %w", err)
}
arch := ""
if len(modelConfig.Architectures) > 0 {
arch = modelConfig.Architectures[0]
}
if arch == "" {
arch = modelConfig.ModelType
}
slog.Info("detected LLM architecture", "architecture", arch, "model_type", modelConfig.ModelType)
// Load the appropriate model based on architecture
var model TextModel
archLower := strings.ToLower(arch)
switch {
case strings.Contains(archLower, "glm4moelite"):
m, err := glm4_moe_lite.LoadFromManifest(modelManifest)
if err != nil {
return fmt.Errorf("failed to load glm4-moe-lite model: %w", err)
}
model = m
slog.Info("loaded glm4-moe-lite model", "vocab_size", m.VocabSize(), "layers", m.NumLayers())
default:
return fmt.Errorf("LLM architecture %q is not yet supported. "+
"Supported architectures: glm4-moe-lite. "+
"Please convert your model to GGUF format or use a supported architecture", arch)
}
s.llmModel = &llmState{
model: model,
}
return nil
}
// handleLLMCompletion handles LLM text generation requests.
func (s *server) handleLLMCompletion(w http.ResponseWriter, r *http.Request, req Request) {
if s.llmModel == nil {
http.Error(w, "LLM model not loaded", http.StatusInternalServerError)
return
}
// Serialize generation requests
llmMu.Lock()
defer llmMu.Unlock()
if err := s.llmGenerate(w, r, req); err != nil {
slog.Error("LLM generation failed", "error", err)
// Don't send error if we've already started streaming
}
}
// llmGenerate runs the generation loop using the Decoder pattern from cmd/engine
func (s *server) llmGenerate(w http.ResponseWriter, r *http.Request, req Request) error {
state := s.llmModel
// Set up streaming response
w.Header().Set("Content-Type", "application/x-ndjson")
w.Header().Set("Transfer-Encoding", "chunked")
flusher, ok := w.(http.Flusher)
if !ok {
return errors.New("streaming not supported")
}
tok := state.model.Tokenizer()
// The prompt is already formatted by the server using the model's renderer
// (see server/prompt.go renderPrompt), so we don't apply FormatPrompt here.
prompt := req.Prompt
// Tokenize the prompt
inputIDs := tok.Encode(prompt, true)
slog.Debug("tokenized prompt", "num_tokens", len(inputIDs))
// Generation parameters
maxTokens := int(state.model.MaxContextLength())
if maxTokens <= 0 {
maxTokens = 4096
}
if req.Options != nil && req.Options.NumPredict > 0 {
maxTokens = req.Options.NumPredict
}
temperature := float32(0.7)
if req.Options != nil && req.Options.Temperature > 0 {
temperature = float32(req.Options.Temperature)
}
// Enable MLX compilation for better performance
mlx.EnableCompile()
// Create decoder with fresh caches
dec := NewDecoder(state.model, temperature)
prefillStart := time.Now()
prefillTokens := dec.prefill(inputIDs)
// Prefill measurement includes time to first token
firstToken := dec.step()
prefillDuration := time.Since(prefillStart)
promptEvalDuration := prefillDuration
enc := json.NewEncoder(w)
ctx := r.Context()
generated := 0
stopReason := "max_tokens"
// Handle first token
generated++
if tok.IsEOS(firstToken) {
resp := Response{
Done: true,
StopReason: fmt.Sprintf("first_token_eos:%d", firstToken),
PromptEvalCount: prefillTokens,
PromptEvalDuration: int(promptEvalDuration.Nanoseconds()),
}
enc.Encode(resp)
flusher.Flush()
return nil
}
text := tok.Decode([]int32{firstToken})
resp := Response{Content: text}
enc.Encode(resp)
flusher.Flush()
genStart := time.Now()
// Generation loop
for n := 1; n < maxTokens; n++ {
// Check for cancellation
select {
case <-ctx.Done():
stopReason = fmt.Sprintf("context_cancelled:%d", generated)
break
default:
}
if stopReason != "max_tokens" {
break
}
token := dec.step()
generated++
if tok.IsEOS(token) {
stopReason = fmt.Sprintf("eos_token:%d", token)
break
}
text := tok.Decode([]int32{token})
// Check for stop sequences
if req.Options != nil && len(req.Options.Stop) > 0 {
shouldStop := false
var matchedStop string
for _, stop := range req.Options.Stop {
if strings.Contains(text, stop) {
text = strings.Split(text, stop)[0]
shouldStop = true
matchedStop = stop
break
}
}
if shouldStop {
if text != "" {
resp := Response{Content: text}
enc.Encode(resp)
flusher.Flush()
}
stopReason = fmt.Sprintf("stop_sequence:%s", matchedStop)
break
}
}
resp := Response{Content: text}
enc.Encode(resp)
flusher.Flush()
// Periodically clear MLX cache
if n%256 == 0 {
mlx.ClearCache()
}
}
// Clean up
mlx.ClearCache()
// Send final response with stats
evalDuration := time.Since(genStart)
resp = Response{
Done: true,
StopReason: fmt.Sprintf("%s:generated=%d", stopReason, generated),
PromptEvalCount: prefillTokens,
PromptEvalDuration: int(promptEvalDuration.Nanoseconds()),
EvalCount: generated,
EvalDuration: int(evalDuration.Nanoseconds()),
}
enc.Encode(resp)
flusher.Flush()
return nil
}

View File

@@ -1,5 +1,3 @@
//go:build mlx
package manifest
import (

Some files were not shown because too many files have changed in this diff