mirror of
https://github.com/ollama/ollama.git
synced 2026-04-24 09:46:01 +02:00
Compare commits
43 Commits
pdevine/sa
...
jmorganca/
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4fdeb59325 | ||
|
|
61086083eb | ||
|
|
62d1f01ab4 | ||
|
|
10e51c5177 | ||
|
|
3e06bde643 | ||
|
|
6be2de8214 | ||
|
|
ebb1b9ec14 | ||
|
|
d126467d5d | ||
|
|
afb4c62fbf | ||
|
|
e790dc435b | ||
|
|
288077c3a3 | ||
|
|
4425c54eda | ||
|
|
778899a5d2 | ||
|
|
4eab60c1e2 | ||
|
|
1af850e6e3 | ||
|
|
9b0c7cc7b9 | ||
|
|
6928630601 | ||
|
|
9896e3627f | ||
|
|
15732f0ea7 | ||
|
|
562c76d7cc | ||
|
|
122c68c151 | ||
|
|
82848a7806 | ||
|
|
39982a954e | ||
|
|
e9f6ea232f | ||
|
|
110eff01a9 | ||
|
|
799e51d419 | ||
|
|
e8fcb29586 | ||
|
|
97d2f05a6d | ||
|
|
8207e55ec7 | ||
|
|
ad16bffc7d | ||
|
|
c1e3ef4bcc | ||
|
|
a3093cd5e5 | ||
|
|
23d4cad1a2 | ||
|
|
86513cb697 | ||
|
|
3490e9590b | ||
|
|
8da09b1e7e | ||
|
|
a60b9adcce | ||
|
|
a16f96658b | ||
|
|
18ab09b431 | ||
|
|
638faeac54 | ||
|
|
dd5eb6337d | ||
|
|
79917cf80b | ||
|
|
cc90a035a0 |
50
.github/workflows/release.yaml
vendored
50
.github/workflows/release.yaml
vendored
@@ -117,6 +117,25 @@ jobs:
|
|||||||
install: https://sdk.lunarg.com/sdk/download/1.4.321.1/windows/vulkansdk-windows-X64-1.4.321.1.exe
|
install: https://sdk.lunarg.com/sdk/download/1.4.321.1/windows/vulkansdk-windows-X64-1.4.321.1.exe
|
||||||
flags: ''
|
flags: ''
|
||||||
runner_dir: 'vulkan'
|
runner_dir: 'vulkan'
|
||||||
|
- os: windows
|
||||||
|
arch: amd64
|
||||||
|
preset: 'MLX CUDA 13'
|
||||||
|
install: https://developer.download.nvidia.com/compute/cuda/13.0.0/local_installers/cuda_13.0.0_windows.exe
|
||||||
|
cudnn-install: https://developer.download.nvidia.com/compute/cudnn/redist/cudnn/windows-x86_64/cudnn-windows-x86_64-9.18.1.3_cuda13-archive.zip
|
||||||
|
cuda-components:
|
||||||
|
- '"cudart"'
|
||||||
|
- '"nvcc"'
|
||||||
|
- '"cublas"'
|
||||||
|
- '"cublas_dev"'
|
||||||
|
- '"cufft"'
|
||||||
|
- '"cufft_dev"'
|
||||||
|
- '"nvrtc"'
|
||||||
|
- '"nvrtc_dev"'
|
||||||
|
- '"crt"'
|
||||||
|
- '"nvvm"'
|
||||||
|
- '"nvptxcompiler"'
|
||||||
|
cuda-version: '13.0'
|
||||||
|
flags: ''
|
||||||
runs-on: ${{ matrix.arch == 'arm64' && format('{0}-{1}', matrix.os, matrix.arch) || matrix.os }}
|
runs-on: ${{ matrix.arch == 'arm64' && format('{0}-{1}', matrix.os, matrix.arch) || matrix.os }}
|
||||||
environment: release
|
environment: release
|
||||||
env:
|
env:
|
||||||
@@ -125,8 +144,10 @@ jobs:
|
|||||||
- name: Install system dependencies
|
- name: Install system dependencies
|
||||||
run: |
|
run: |
|
||||||
choco install -y --no-progress ccache ninja
|
choco install -y --no-progress ccache ninja
|
||||||
|
if (Get-Command ccache -ErrorAction SilentlyContinue) {
|
||||||
ccache -o cache_dir=${{ github.workspace }}\.ccache
|
ccache -o cache_dir=${{ github.workspace }}\.ccache
|
||||||
- if: startsWith(matrix.preset, 'CUDA ') || startsWith(matrix.preset, 'ROCm ') || startsWith(matrix.preset, 'Vulkan')
|
}
|
||||||
|
- if: startsWith(matrix.preset, 'CUDA ') || startsWith(matrix.preset, 'ROCm ') || startsWith(matrix.preset, 'Vulkan') || startsWith(matrix.preset, 'MLX ')
|
||||||
id: cache-install
|
id: cache-install
|
||||||
uses: actions/cache/restore@v4
|
uses: actions/cache/restore@v4
|
||||||
with:
|
with:
|
||||||
@@ -134,8 +155,9 @@ jobs:
|
|||||||
C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA
|
C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA
|
||||||
C:\Program Files\AMD\ROCm
|
C:\Program Files\AMD\ROCm
|
||||||
C:\VulkanSDK
|
C:\VulkanSDK
|
||||||
key: ${{ matrix.install }}
|
C:\Program Files\NVIDIA\CUDNN
|
||||||
- if: startsWith(matrix.preset, 'CUDA ')
|
key: ${{ matrix.install }}-${{ matrix.cudnn-install }}
|
||||||
|
- if: startsWith(matrix.preset, 'CUDA ') || startsWith(matrix.preset, 'MLX ')
|
||||||
name: Install CUDA ${{ matrix.cuda-version }}
|
name: Install CUDA ${{ matrix.cuda-version }}
|
||||||
run: |
|
run: |
|
||||||
$ErrorActionPreference = "Stop"
|
$ErrorActionPreference = "Stop"
|
||||||
@@ -179,6 +201,23 @@ jobs:
|
|||||||
run: |
|
run: |
|
||||||
echo "CC=clang.exe" | Out-File -FilePath $env:GITHUB_ENV -Append
|
echo "CC=clang.exe" | Out-File -FilePath $env:GITHUB_ENV -Append
|
||||||
echo "CXX=clang++.exe" | Out-File -FilePath $env:GITHUB_ENV -Append
|
echo "CXX=clang++.exe" | Out-File -FilePath $env:GITHUB_ENV -Append
|
||||||
|
- if: startsWith(matrix.preset, 'MLX ')
|
||||||
|
name: Install cuDNN for MLX
|
||||||
|
run: |
|
||||||
|
$ErrorActionPreference = "Stop"
|
||||||
|
$cudnnRoot = "C:\Program Files\NVIDIA\CUDNN"
|
||||||
|
if ("${{ steps.cache-install.outputs.cache-hit }}" -ne 'true') {
|
||||||
|
Invoke-WebRequest -Uri "${{ matrix.cudnn-install }}" -OutFile "cudnn.zip"
|
||||||
|
Expand-Archive -Path cudnn.zip -DestinationPath cudnn-extracted
|
||||||
|
$cudnnDir = (Get-ChildItem -Path cudnn-extracted -Directory)[0].FullName
|
||||||
|
New-Item -ItemType Directory -Force -Path $cudnnRoot
|
||||||
|
Copy-Item -Path "$cudnnDir\*" -Destination "$cudnnRoot\" -Recurse
|
||||||
|
}
|
||||||
|
|
||||||
|
echo "CUDNN_ROOT_DIR=$cudnnRoot" | Out-File -FilePath $env:GITHUB_ENV -Append
|
||||||
|
echo "CUDNN_INCLUDE_PATH=$cudnnRoot\include" | Out-File -FilePath $env:GITHUB_ENV -Append
|
||||||
|
echo "CUDNN_LIBRARY_PATH=$cudnnRoot\lib\x64" | Out-File -FilePath $env:GITHUB_ENV -Append
|
||||||
|
echo "$cudnnRoot\bin\x64" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
|
||||||
- if: ${{ !cancelled() && steps.cache-install.outputs.cache-hit != 'true' }}
|
- if: ${{ !cancelled() && steps.cache-install.outputs.cache-hit != 'true' }}
|
||||||
uses: actions/cache/save@v4
|
uses: actions/cache/save@v4
|
||||||
with:
|
with:
|
||||||
@@ -186,7 +225,8 @@ jobs:
|
|||||||
C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA
|
C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA
|
||||||
C:\Program Files\AMD\ROCm
|
C:\Program Files\AMD\ROCm
|
||||||
C:\VulkanSDK
|
C:\VulkanSDK
|
||||||
key: ${{ matrix.install }}
|
C:\Program Files\NVIDIA\CUDNN
|
||||||
|
key: ${{ matrix.install }}-${{ matrix.cudnn-install }}
|
||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v4
|
||||||
- uses: actions/cache@v4
|
- uses: actions/cache@v4
|
||||||
with:
|
with:
|
||||||
@@ -198,7 +238,7 @@ jobs:
|
|||||||
Enter-VsDevShell -VsInstallPath 'C:\Program Files\Microsoft Visual Studio\2022\Enterprise' -SkipAutomaticLocation -DevCmdArguments '-arch=x64 -no_logo'
|
Enter-VsDevShell -VsInstallPath 'C:\Program Files\Microsoft Visual Studio\2022\Enterprise' -SkipAutomaticLocation -DevCmdArguments '-arch=x64 -no_logo'
|
||||||
cmake --preset "${{ matrix.preset }}" ${{ matrix.flags }} --install-prefix "$((pwd).Path)\dist\${{ matrix.os }}-${{ matrix.arch }}"
|
cmake --preset "${{ matrix.preset }}" ${{ matrix.flags }} --install-prefix "$((pwd).Path)\dist\${{ matrix.os }}-${{ matrix.arch }}"
|
||||||
cmake --build --parallel ([Environment]::ProcessorCount) --preset "${{ matrix.preset }}"
|
cmake --build --parallel ([Environment]::ProcessorCount) --preset "${{ matrix.preset }}"
|
||||||
cmake --install build --component "${{ startsWith(matrix.preset, 'CUDA ') && 'CUDA' || startsWith(matrix.preset, 'ROCm ') && 'HIP' || startsWith(matrix.preset, 'Vulkan') && 'Vulkan' || 'CPU' }}" --strip
|
cmake --install build --component "${{ startsWith(matrix.preset, 'MLX ') && 'MLX' || startsWith(matrix.preset, 'CUDA ') && 'CUDA' || startsWith(matrix.preset, 'ROCm ') && 'HIP' || startsWith(matrix.preset, 'Vulkan') && 'Vulkan' || 'CPU' }}" --strip
|
||||||
Remove-Item -Path dist\lib\ollama\rocm\rocblas\library\*gfx906* -ErrorAction SilentlyContinue
|
Remove-Item -Path dist\lib\ollama\rocm\rocblas\library\*gfx906* -ErrorAction SilentlyContinue
|
||||||
env:
|
env:
|
||||||
CMAKE_GENERATOR: Ninja
|
CMAKE_GENERATOR: Ninja
|
||||||
|
|||||||
62
.github/workflows/test.yaml
vendored
62
.github/workflows/test.yaml
vendored
@@ -37,7 +37,7 @@ jobs:
|
|||||||
| xargs python3 -c "import sys; from pathlib import Path; print(any(Path(x).match(glob) for x in sys.argv[1:] for glob in '$*'.split(' ')))"
|
| xargs python3 -c "import sys; from pathlib import Path; print(any(Path(x).match(glob) for x in sys.argv[1:] for glob in '$*'.split(' ')))"
|
||||||
}
|
}
|
||||||
|
|
||||||
echo changed=$(changed 'llama/llama.cpp/**/*' 'ml/backend/ggml/ggml/**/*') | tee -a $GITHUB_OUTPUT
|
echo changed=$(changed 'llama/llama.cpp/**/*' 'ml/backend/ggml/ggml/**/*' '.github/**/*') | tee -a $GITHUB_OUTPUT
|
||||||
echo vendorsha=$(make -f Makefile.sync print-base) | tee -a $GITHUB_OUTPUT
|
echo vendorsha=$(make -f Makefile.sync print-base) | tee -a $GITHUB_OUTPUT
|
||||||
|
|
||||||
linux:
|
linux:
|
||||||
@@ -51,7 +51,7 @@ jobs:
|
|||||||
container: nvidia/cuda:13.0.0-devel-ubuntu22.04
|
container: nvidia/cuda:13.0.0-devel-ubuntu22.04
|
||||||
flags: '-DCMAKE_CUDA_ARCHITECTURES=87'
|
flags: '-DCMAKE_CUDA_ARCHITECTURES=87'
|
||||||
- preset: ROCm
|
- preset: ROCm
|
||||||
container: rocm/dev-ubuntu-22.04:6.1.2
|
container: rocm/dev-ubuntu-22.04:7.2
|
||||||
extra-packages: rocm-libs
|
extra-packages: rocm-libs
|
||||||
flags: '-DAMDGPU_TARGETS=gfx1010 -DCMAKE_PREFIX_PATH=/opt/rocm'
|
flags: '-DAMDGPU_TARGETS=gfx1010 -DCMAKE_PREFIX_PATH=/opt/rocm'
|
||||||
- preset: Vulkan
|
- preset: Vulkan
|
||||||
@@ -60,6 +60,10 @@ jobs:
|
|||||||
mesa-vulkan-drivers vulkan-tools
|
mesa-vulkan-drivers vulkan-tools
|
||||||
libvulkan1 libvulkan-dev
|
libvulkan1 libvulkan-dev
|
||||||
vulkan-sdk cmake ccache g++ make
|
vulkan-sdk cmake ccache g++ make
|
||||||
|
- preset: 'MLX CUDA 13'
|
||||||
|
container: nvidia/cuda:13.0.0-devel-ubuntu22.04
|
||||||
|
extra-packages: libcudnn9-dev-cuda-13 libopenblas-dev liblapack-dev liblapacke-dev git curl
|
||||||
|
flags: '-DCMAKE_CUDA_ARCHITECTURES=87 -DBLAS_INCLUDE_DIRS=/usr/include/x86_64-linux-gnu -DLAPACK_INCLUDE_DIRS=/usr/include/x86_64-linux-gnu'
|
||||||
runs-on: linux
|
runs-on: linux
|
||||||
container: ${{ matrix.container }}
|
container: ${{ matrix.container }}
|
||||||
steps:
|
steps:
|
||||||
@@ -76,6 +80,10 @@ jobs:
|
|||||||
$sudo apt-get update
|
$sudo apt-get update
|
||||||
fi
|
fi
|
||||||
$sudo apt-get install -y cmake ccache ${{ matrix.extra-packages }}
|
$sudo apt-get install -y cmake ccache ${{ matrix.extra-packages }}
|
||||||
|
# MLX requires CMake 3.25+, install from official releases
|
||||||
|
if [ "${{ matrix.preset }}" = "MLX CUDA 13" ]; then
|
||||||
|
curl -fsSL https://github.com/Kitware/CMake/releases/download/v3.31.2/cmake-3.31.2-linux-$(uname -m).tar.gz | $sudo tar xz -C /usr/local --strip-components 1
|
||||||
|
fi
|
||||||
# Export VULKAN_SDK if provided by LunarG package (defensive)
|
# Export VULKAN_SDK if provided by LunarG package (defensive)
|
||||||
if [ -d "/usr/lib/x86_64-linux-gnu/vulkan" ] && [ "${{ matrix.preset }}" = "Vulkan" ]; then
|
if [ -d "/usr/lib/x86_64-linux-gnu/vulkan" ] && [ "${{ matrix.preset }}" = "Vulkan" ]; then
|
||||||
echo "VULKAN_SDK=/usr" >> $GITHUB_ENV
|
echo "VULKAN_SDK=/usr" >> $GITHUB_ENV
|
||||||
@@ -87,8 +95,8 @@ jobs:
|
|||||||
path: /github/home/.cache/ccache
|
path: /github/home/.cache/ccache
|
||||||
key: ccache-${{ runner.os }}-${{ runner.arch }}-${{ matrix.preset }}-${{ needs.changes.outputs.vendorsha }}
|
key: ccache-${{ runner.os }}-${{ runner.arch }}-${{ matrix.preset }}-${{ needs.changes.outputs.vendorsha }}
|
||||||
- run: |
|
- run: |
|
||||||
cmake --preset ${{ matrix.preset }} ${{ matrix.flags }}
|
cmake --preset "${{ matrix.preset }}" ${{ matrix.flags }}
|
||||||
cmake --build --preset ${{ matrix.preset }} --parallel
|
cmake --build --preset "${{ matrix.preset }}" --parallel
|
||||||
|
|
||||||
windows:
|
windows:
|
||||||
needs: [changes]
|
needs: [changes]
|
||||||
@@ -114,12 +122,31 @@ jobs:
|
|||||||
flags: '-DAMDGPU_TARGETS=gfx1010 -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -DCMAKE_C_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma" -DCMAKE_CXX_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma"'
|
flags: '-DAMDGPU_TARGETS=gfx1010 -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -DCMAKE_C_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma" -DCMAKE_CXX_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma"'
|
||||||
- preset: Vulkan
|
- preset: Vulkan
|
||||||
install: https://sdk.lunarg.com/sdk/download/1.4.321.1/windows/vulkansdk-windows-X64-1.4.321.1.exe
|
install: https://sdk.lunarg.com/sdk/download/1.4.321.1/windows/vulkansdk-windows-X64-1.4.321.1.exe
|
||||||
|
- preset: 'MLX CUDA 13'
|
||||||
|
install: https://developer.download.nvidia.com/compute/cuda/13.0.0/local_installers/cuda_13.0.0_windows.exe
|
||||||
|
cudnn-install: https://developer.download.nvidia.com/compute/cudnn/redist/cudnn/windows-x86_64/cudnn-windows-x86_64-9.18.1.3_cuda13-archive.zip
|
||||||
|
flags: '-DCMAKE_CUDA_ARCHITECTURES=80'
|
||||||
|
cuda-components:
|
||||||
|
- '"cudart"'
|
||||||
|
- '"nvcc"'
|
||||||
|
- '"cublas"'
|
||||||
|
- '"cublas_dev"'
|
||||||
|
- '"cufft"'
|
||||||
|
- '"cufft_dev"'
|
||||||
|
- '"nvrtc"'
|
||||||
|
- '"nvrtc_dev"'
|
||||||
|
- '"crt"'
|
||||||
|
- '"nvvm"'
|
||||||
|
- '"nvptxcompiler"'
|
||||||
|
cuda-version: '13.0'
|
||||||
runs-on: windows
|
runs-on: windows
|
||||||
steps:
|
steps:
|
||||||
- run: |
|
- run: |
|
||||||
choco install -y --no-progress ccache ninja
|
choco install -y --no-progress ccache ninja
|
||||||
|
if (Get-Command ccache -ErrorAction SilentlyContinue) {
|
||||||
ccache -o cache_dir=${{ github.workspace }}\.ccache
|
ccache -o cache_dir=${{ github.workspace }}\.ccache
|
||||||
- if: matrix.preset == 'CUDA' || matrix.preset == 'ROCm' || matrix.preset == 'Vulkan'
|
}
|
||||||
|
- if: matrix.preset == 'CUDA' || matrix.preset == 'ROCm' || matrix.preset == 'Vulkan' || matrix.preset == 'MLX CUDA 13'
|
||||||
id: cache-install
|
id: cache-install
|
||||||
uses: actions/cache/restore@v4
|
uses: actions/cache/restore@v4
|
||||||
with:
|
with:
|
||||||
@@ -127,8 +154,9 @@ jobs:
|
|||||||
C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA
|
C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA
|
||||||
C:\Program Files\AMD\ROCm
|
C:\Program Files\AMD\ROCm
|
||||||
C:\VulkanSDK
|
C:\VulkanSDK
|
||||||
key: ${{ matrix.install }}
|
C:\Program Files\NVIDIA\CUDNN
|
||||||
- if: matrix.preset == 'CUDA'
|
key: ${{ matrix.install }}-${{ matrix.cudnn-install }}
|
||||||
|
- if: matrix.preset == 'CUDA' || matrix.preset == 'MLX CUDA 13'
|
||||||
name: Install CUDA ${{ matrix.cuda-version }}
|
name: Install CUDA ${{ matrix.cuda-version }}
|
||||||
run: |
|
run: |
|
||||||
$ErrorActionPreference = "Stop"
|
$ErrorActionPreference = "Stop"
|
||||||
@@ -168,6 +196,23 @@ jobs:
|
|||||||
$vulkanPath = (Resolve-Path "C:\VulkanSDK\*").path
|
$vulkanPath = (Resolve-Path "C:\VulkanSDK\*").path
|
||||||
echo "$vulkanPath\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
|
echo "$vulkanPath\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
|
||||||
echo "VULKAN_SDK=$vulkanPath" >> $env:GITHUB_ENV
|
echo "VULKAN_SDK=$vulkanPath" >> $env:GITHUB_ENV
|
||||||
|
- if: matrix.preset == 'MLX CUDA 13'
|
||||||
|
name: Install cuDNN for MLX
|
||||||
|
run: |
|
||||||
|
$ErrorActionPreference = "Stop"
|
||||||
|
$cudnnRoot = "C:\Program Files\NVIDIA\CUDNN"
|
||||||
|
if ("${{ steps.cache-install.outputs.cache-hit }}" -ne 'true') {
|
||||||
|
Invoke-WebRequest -Uri "${{ matrix.cudnn-install }}" -OutFile "cudnn.zip"
|
||||||
|
Expand-Archive -Path cudnn.zip -DestinationPath cudnn-extracted
|
||||||
|
$cudnnDir = (Get-ChildItem -Path cudnn-extracted -Directory)[0].FullName
|
||||||
|
New-Item -ItemType Directory -Force -Path $cudnnRoot
|
||||||
|
Copy-Item -Path "$cudnnDir\*" -Destination "$cudnnRoot\" -Recurse
|
||||||
|
}
|
||||||
|
|
||||||
|
echo "CUDNN_ROOT_DIR=$cudnnRoot" | Out-File -FilePath $env:GITHUB_ENV -Append
|
||||||
|
echo "CUDNN_INCLUDE_PATH=$cudnnRoot\include" | Out-File -FilePath $env:GITHUB_ENV -Append
|
||||||
|
echo "CUDNN_LIBRARY_PATH=$cudnnRoot\lib\x64" | Out-File -FilePath $env:GITHUB_ENV -Append
|
||||||
|
echo "$cudnnRoot\bin\x64" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
|
||||||
- if: ${{ !cancelled() && steps.cache-install.outputs.cache-hit != 'true' }}
|
- if: ${{ !cancelled() && steps.cache-install.outputs.cache-hit != 'true' }}
|
||||||
uses: actions/cache/save@v4
|
uses: actions/cache/save@v4
|
||||||
with:
|
with:
|
||||||
@@ -175,7 +220,8 @@ jobs:
|
|||||||
C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA
|
C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA
|
||||||
C:\Program Files\AMD\ROCm
|
C:\Program Files\AMD\ROCm
|
||||||
C:\VulkanSDK
|
C:\VulkanSDK
|
||||||
key: ${{ matrix.install }}
|
C:\Program Files\NVIDIA\CUDNN
|
||||||
|
key: ${{ matrix.install }}-${{ matrix.cudnn-install }}
|
||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v4
|
||||||
- uses: actions/cache@v4
|
- uses: actions/cache@v4
|
||||||
with:
|
with:
|
||||||
|
|||||||
108
CMakeLists.txt
108
CMakeLists.txt
@@ -64,10 +64,15 @@ set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${OLLAMA_BUILD_DIR})
|
|||||||
set(CMAKE_LIBRARY_OUTPUT_DIRECTORY_DEBUG ${OLLAMA_BUILD_DIR})
|
set(CMAKE_LIBRARY_OUTPUT_DIRECTORY_DEBUG ${OLLAMA_BUILD_DIR})
|
||||||
set(CMAKE_LIBRARY_OUTPUT_DIRECTORY_RELEASE ${OLLAMA_BUILD_DIR})
|
set(CMAKE_LIBRARY_OUTPUT_DIRECTORY_RELEASE ${OLLAMA_BUILD_DIR})
|
||||||
|
|
||||||
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src)
|
# Store ggml include paths for use with target_include_directories later.
|
||||||
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/include)
|
# We avoid global include_directories() to prevent polluting the include path
|
||||||
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-cpu)
|
# for other projects like MLX (whose openblas dependency has its own common.h).
|
||||||
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-cpu/amx)
|
set(GGML_INCLUDE_DIRS
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/include
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-cpu
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-cpu/amx
|
||||||
|
)
|
||||||
|
|
||||||
add_compile_definitions(NDEBUG GGML_VERSION=0x0 GGML_COMMIT=0x0)
|
add_compile_definitions(NDEBUG GGML_VERSION=0x0 GGML_COMMIT=0x0)
|
||||||
|
|
||||||
@@ -87,6 +92,14 @@ if(NOT CPU_VARIANTS)
|
|||||||
set(CPU_VARIANTS "ggml-cpu")
|
set(CPU_VARIANTS "ggml-cpu")
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
|
# Apply ggml include directories to ggml targets only (not globally)
|
||||||
|
target_include_directories(ggml-base PRIVATE ${GGML_INCLUDE_DIRS})
|
||||||
|
foreach(variant ${CPU_VARIANTS})
|
||||||
|
if(TARGET ${variant})
|
||||||
|
target_include_directories(${variant} PRIVATE ${GGML_INCLUDE_DIRS})
|
||||||
|
endif()
|
||||||
|
endforeach()
|
||||||
|
|
||||||
install(TARGETS ggml-base ${CPU_VARIANTS}
|
install(TARGETS ggml-base ${CPU_VARIANTS}
|
||||||
RUNTIME_DEPENDENCIES
|
RUNTIME_DEPENDENCIES
|
||||||
PRE_EXCLUDE_REGEXES ".*"
|
PRE_EXCLUDE_REGEXES ".*"
|
||||||
@@ -103,6 +116,7 @@ if(CMAKE_CUDA_COMPILER)
|
|||||||
|
|
||||||
find_package(CUDAToolkit)
|
find_package(CUDAToolkit)
|
||||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-cuda)
|
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-cuda)
|
||||||
|
target_include_directories(ggml-cuda PRIVATE ${GGML_INCLUDE_DIRS})
|
||||||
install(TARGETS ggml-cuda
|
install(TARGETS ggml-cuda
|
||||||
RUNTIME_DEPENDENCIES
|
RUNTIME_DEPENDENCIES
|
||||||
DIRECTORIES ${CUDAToolkit_BIN_DIR} ${CUDAToolkit_BIN_DIR}/x64 ${CUDAToolkit_LIBRARY_DIR}
|
DIRECTORIES ${CUDAToolkit_BIN_DIR} ${CUDAToolkit_BIN_DIR}/x64 ${CUDAToolkit_LIBRARY_DIR}
|
||||||
@@ -134,6 +148,7 @@ if(CMAKE_HIP_COMPILER)
|
|||||||
if(AMDGPU_TARGETS)
|
if(AMDGPU_TARGETS)
|
||||||
find_package(hip REQUIRED)
|
find_package(hip REQUIRED)
|
||||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-hip)
|
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-hip)
|
||||||
|
target_include_directories(ggml-hip PRIVATE ${GGML_INCLUDE_DIRS})
|
||||||
|
|
||||||
if (WIN32)
|
if (WIN32)
|
||||||
target_compile_definitions(ggml-hip PRIVATE GGML_CUDA_NO_PEER_COPY)
|
target_compile_definitions(ggml-hip PRIVATE GGML_CUDA_NO_PEER_COPY)
|
||||||
@@ -148,7 +163,7 @@ if(CMAKE_HIP_COMPILER)
|
|||||||
)
|
)
|
||||||
install(RUNTIME_DEPENDENCY_SET rocm
|
install(RUNTIME_DEPENDENCY_SET rocm
|
||||||
DIRECTORIES ${HIP_BIN_INSTALL_DIR} ${HIP_LIB_INSTALL_DIR}
|
DIRECTORIES ${HIP_BIN_INSTALL_DIR} ${HIP_LIB_INSTALL_DIR}
|
||||||
PRE_INCLUDE_REGEXES hipblas rocblas amdhip64 rocsolver amd_comgr hsa-runtime64 rocsparse tinfo rocprofiler-register drm drm_amdgpu numa elf
|
PRE_INCLUDE_REGEXES hipblas rocblas amdhip64 rocsolver amd_comgr hsa-runtime64 rocsparse tinfo rocprofiler-register roctx64 rocroller drm drm_amdgpu numa elf
|
||||||
PRE_EXCLUDE_REGEXES ".*"
|
PRE_EXCLUDE_REGEXES ".*"
|
||||||
POST_EXCLUDE_REGEXES "system32"
|
POST_EXCLUDE_REGEXES "system32"
|
||||||
RUNTIME DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT HIP
|
RUNTIME DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT HIP
|
||||||
@@ -168,6 +183,7 @@ if(NOT APPLE)
|
|||||||
find_package(Vulkan)
|
find_package(Vulkan)
|
||||||
if(Vulkan_FOUND)
|
if(Vulkan_FOUND)
|
||||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-vulkan)
|
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-vulkan)
|
||||||
|
target_include_directories(ggml-vulkan PRIVATE ${GGML_INCLUDE_DIRS})
|
||||||
install(TARGETS ggml-vulkan
|
install(TARGETS ggml-vulkan
|
||||||
RUNTIME_DEPENDENCIES
|
RUNTIME_DEPENDENCIES
|
||||||
PRE_INCLUDE_REGEXES vulkan
|
PRE_INCLUDE_REGEXES vulkan
|
||||||
@@ -179,7 +195,6 @@ if(NOT APPLE)
|
|||||||
endif()
|
endif()
|
||||||
|
|
||||||
option(MLX_ENGINE "Enable MLX backend" OFF)
|
option(MLX_ENGINE "Enable MLX backend" OFF)
|
||||||
|
|
||||||
if(MLX_ENGINE)
|
if(MLX_ENGINE)
|
||||||
message(STATUS "Setting up MLX (this takes a while...)")
|
message(STATUS "Setting up MLX (this takes a while...)")
|
||||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/x/imagegen/mlx)
|
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/x/imagegen/mlx)
|
||||||
@@ -187,10 +202,36 @@ if(MLX_ENGINE)
|
|||||||
# Find CUDA toolkit if MLX is built with CUDA support
|
# Find CUDA toolkit if MLX is built with CUDA support
|
||||||
find_package(CUDAToolkit)
|
find_package(CUDAToolkit)
|
||||||
|
|
||||||
|
# Build list of directories for runtime dependency resolution
|
||||||
|
set(MLX_RUNTIME_DIRS ${CUDAToolkit_BIN_DIR} ${CUDAToolkit_BIN_DIR}/x64 ${CUDAToolkit_LIBRARY_DIR})
|
||||||
|
# Add cuDNN bin paths for DLLs (Windows MLX CUDA builds)
|
||||||
|
# CUDNN_ROOT_DIR is the standard CMake variable for cuDNN location
|
||||||
|
if(DEFINED ENV{CUDNN_ROOT_DIR})
|
||||||
|
# cuDNN 9.x has versioned subdirectories under bin/ (e.g., bin/13.0/)
|
||||||
|
file(GLOB CUDNN_BIN_SUBDIRS "$ENV{CUDNN_ROOT_DIR}/bin/*")
|
||||||
|
list(APPEND MLX_RUNTIME_DIRS ${CUDNN_BIN_SUBDIRS})
|
||||||
|
endif()
|
||||||
|
# Add build output directory and MLX dependency build directories
|
||||||
|
list(APPEND MLX_RUNTIME_DIRS ${OLLAMA_BUILD_DIR})
|
||||||
|
# OpenBLAS DLL location (pre-built zip extracts into openblas-src/bin/)
|
||||||
|
list(APPEND MLX_RUNTIME_DIRS ${CMAKE_BINARY_DIR}/_deps/openblas-src/bin)
|
||||||
|
# NCCL: on Linux, if real NCCL is found, cmake bundles libnccl.so via the
|
||||||
|
# regex below. If NCCL is not found, MLX links a static stub (OBJECT lib)
|
||||||
|
# so there is no runtime dependency. This path covers the stub build dir
|
||||||
|
# for windows so we include the DLL in our dependencies.
|
||||||
|
list(APPEND MLX_RUNTIME_DIRS ${CMAKE_BINARY_DIR}/_deps/mlx-build/mlx/distributed/nccl/nccl_stub-prefix/src/nccl_stub-build/Release)
|
||||||
|
|
||||||
|
# Base regexes for runtime dependencies (cross-platform)
|
||||||
|
set(MLX_INCLUDE_REGEXES cublas cublasLt cudart cufft nvrtc nvrtc-builtins cudnn nccl openblas gfortran)
|
||||||
|
# On Windows, also include dl.dll (dlfcn-win32 POSIX emulation layer)
|
||||||
|
if(WIN32)
|
||||||
|
list(APPEND MLX_INCLUDE_REGEXES "^dl\\.dll$")
|
||||||
|
endif()
|
||||||
|
|
||||||
install(TARGETS mlx mlxc
|
install(TARGETS mlx mlxc
|
||||||
RUNTIME_DEPENDENCIES
|
RUNTIME_DEPENDENCIES
|
||||||
DIRECTORIES ${CUDAToolkit_BIN_DIR} ${CUDAToolkit_BIN_DIR}/x64 ${CUDAToolkit_LIBRARY_DIR}
|
DIRECTORIES ${MLX_RUNTIME_DIRS}
|
||||||
PRE_INCLUDE_REGEXES cublas cublasLt cudart nvrtc nvrtc-builtins cudnn nccl openblas gfortran
|
PRE_INCLUDE_REGEXES ${MLX_INCLUDE_REGEXES}
|
||||||
PRE_EXCLUDE_REGEXES ".*"
|
PRE_EXCLUDE_REGEXES ".*"
|
||||||
RUNTIME DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT MLX
|
RUNTIME DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT MLX
|
||||||
LIBRARY DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT MLX
|
LIBRARY DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT MLX
|
||||||
@@ -205,13 +246,54 @@ if(MLX_ENGINE)
|
|||||||
COMPONENT MLX)
|
COMPONENT MLX)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
# Manually install cudart and cublas since they might not be picked up as direct dependencies
|
# Install CCCL headers for NVRTC JIT compilation at runtime.
|
||||||
|
# MLX's own install rules use the default component so they get skipped by
|
||||||
|
# --component MLX. Headers are installed alongside libmlx in OLLAMA_INSTALL_DIR.
|
||||||
|
# On Linux, MLX's jit_module.cpp resolves CCCL via
|
||||||
|
# current_binary_dir().parent_path() / "include" / "cccl", so we create a
|
||||||
|
# symlink from lib/ollama/include -> ${OLLAMA_RUNNER_DIR}/include
|
||||||
|
# This will need refinement if we add multiple CUDA versions for MLX in the future.
|
||||||
|
if(EXISTS ${CMAKE_BINARY_DIR}/_deps/cccl-src/include/cuda)
|
||||||
|
install(DIRECTORY ${CMAKE_BINARY_DIR}/_deps/cccl-src/include/cuda
|
||||||
|
DESTINATION ${OLLAMA_INSTALL_DIR}/include/cccl
|
||||||
|
COMPONENT MLX)
|
||||||
|
install(DIRECTORY ${CMAKE_BINARY_DIR}/_deps/cccl-src/include/nv
|
||||||
|
DESTINATION ${OLLAMA_INSTALL_DIR}/include/cccl
|
||||||
|
COMPONENT MLX)
|
||||||
|
if(NOT WIN32 AND NOT APPLE)
|
||||||
|
install(CODE "
|
||||||
|
set(_link \"${CMAKE_INSTALL_PREFIX}/lib/ollama/include\")
|
||||||
|
set(_target \"${OLLAMA_RUNNER_DIR}/include\")
|
||||||
|
if(NOT EXISTS \${_link})
|
||||||
|
execute_process(COMMAND \${CMAKE_COMMAND} -E create_symlink \${_target} \${_link})
|
||||||
|
endif()
|
||||||
|
" COMPONENT MLX)
|
||||||
|
endif()
|
||||||
|
endif()
|
||||||
|
|
||||||
|
# On Windows, explicitly install dl.dll (dlfcn-win32 POSIX dlopen emulation)
|
||||||
|
# RUNTIME_DEPENDENCIES auto-excludes it via POST_EXCLUDE_FILES_STRICT because
|
||||||
|
# dlfcn-win32 is a known CMake target with its own install rules (which install
|
||||||
|
# to the wrong destination). We must install it explicitly here.
|
||||||
|
if(WIN32)
|
||||||
|
install(FILES ${OLLAMA_BUILD_DIR}/dl.dll
|
||||||
|
DESTINATION ${OLLAMA_INSTALL_DIR}
|
||||||
|
COMPONENT MLX)
|
||||||
|
endif()
|
||||||
|
|
||||||
|
# Manually install CUDA runtime libraries that MLX loads via dlopen
|
||||||
|
# (not detected by RUNTIME_DEPENDENCIES since they aren't link-time deps)
|
||||||
if(CUDAToolkit_FOUND)
|
if(CUDAToolkit_FOUND)
|
||||||
file(GLOB CUDART_LIBS
|
file(GLOB MLX_CUDA_LIBS
|
||||||
"${CUDAToolkit_LIBRARY_DIR}/libcudart.so*"
|
"${CUDAToolkit_LIBRARY_DIR}/libcudart.so*"
|
||||||
"${CUDAToolkit_LIBRARY_DIR}/libcublas.so*")
|
"${CUDAToolkit_LIBRARY_DIR}/libcublas.so*"
|
||||||
if(CUDART_LIBS)
|
"${CUDAToolkit_LIBRARY_DIR}/libcublasLt.so*"
|
||||||
install(FILES ${CUDART_LIBS}
|
"${CUDAToolkit_LIBRARY_DIR}/libnvrtc.so*"
|
||||||
|
"${CUDAToolkit_LIBRARY_DIR}/libnvrtc-builtins.so*"
|
||||||
|
"${CUDAToolkit_LIBRARY_DIR}/libcufft.so*"
|
||||||
|
"${CUDAToolkit_LIBRARY_DIR}/libcudnn.so*")
|
||||||
|
if(MLX_CUDA_LIBS)
|
||||||
|
install(FILES ${MLX_CUDA_LIBS}
|
||||||
DESTINATION ${OLLAMA_INSTALL_DIR}
|
DESTINATION ${OLLAMA_INSTALL_DIR}
|
||||||
COMPONENT MLX)
|
COMPONENT MLX)
|
||||||
endif()
|
endif()
|
||||||
|
|||||||
@@ -77,6 +77,15 @@
|
|||||||
"OLLAMA_RUNNER_DIR": "rocm"
|
"OLLAMA_RUNNER_DIR": "rocm"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"name": "ROCm 7",
|
||||||
|
"inherits": [ "ROCm" ],
|
||||||
|
"cacheVariables": {
|
||||||
|
"CMAKE_HIP_FLAGS": "-parallel-jobs=4",
|
||||||
|
"AMDGPU_TARGETS": "gfx942;gfx950;gfx1010;gfx1012;gfx1030;gfx1100;gfx1101;gfx1102;gfx1103;gfx1150;gfx1151;gfx1200;gfx1201;gfx908:xnack-;gfx90a:xnack+;gfx90a:xnack-",
|
||||||
|
"OLLAMA_RUNNER_DIR": "rocm"
|
||||||
|
}
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"name": "Vulkan",
|
"name": "Vulkan",
|
||||||
"inherits": [ "Default" ],
|
"inherits": [ "Default" ],
|
||||||
@@ -103,6 +112,7 @@
|
|||||||
"name": "MLX CUDA 13",
|
"name": "MLX CUDA 13",
|
||||||
"inherits": [ "MLX", "CUDA 13" ],
|
"inherits": [ "MLX", "CUDA 13" ],
|
||||||
"cacheVariables": {
|
"cacheVariables": {
|
||||||
|
"MLX_CUDA_ARCHITECTURES": "86;89;90;90a;100;103;75-virtual;80-virtual;110-virtual;120-virtual;121-virtual",
|
||||||
"OLLAMA_RUNNER_DIR": "mlx_cuda_v13"
|
"OLLAMA_RUNNER_DIR": "mlx_cuda_v13"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -158,6 +168,11 @@
|
|||||||
"inherits": [ "ROCm" ],
|
"inherits": [ "ROCm" ],
|
||||||
"configurePreset": "ROCm 6"
|
"configurePreset": "ROCm 6"
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"name": "ROCm 7",
|
||||||
|
"inherits": [ "ROCm" ],
|
||||||
|
"configurePreset": "ROCm 7"
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"name": "Vulkan",
|
"name": "Vulkan",
|
||||||
"targets": [ "ggml-vulkan" ],
|
"targets": [ "ggml-vulkan" ],
|
||||||
|
|||||||
122
Dockerfile
122
Dockerfile
@@ -1,28 +1,23 @@
|
|||||||
# vim: filetype=dockerfile
|
# vim: filetype=dockerfile
|
||||||
|
|
||||||
ARG FLAVOR=${TARGETARCH}
|
ARG FLAVOR=${TARGETARCH}
|
||||||
ARG PARALLEL=8
|
|
||||||
|
|
||||||
ARG ROCMVERSION=6.3.3
|
ARG ROCMVERSION=7.2
|
||||||
ARG JETPACK5VERSION=r35.4.1
|
ARG JETPACK5VERSION=r35.4.1
|
||||||
ARG JETPACK6VERSION=r36.4.0
|
ARG JETPACK6VERSION=r36.4.0
|
||||||
ARG CMAKEVERSION=3.31.2
|
ARG CMAKEVERSION=3.31.2
|
||||||
|
ARG NINJAVERSION=1.12.1
|
||||||
ARG VULKANVERSION=1.4.321.1
|
ARG VULKANVERSION=1.4.321.1
|
||||||
|
|
||||||
|
# Default empty stages for local MLX source overrides.
|
||||||
|
# Override with: docker build --build-context local-mlx=../mlx --build-context local-mlx-c=../mlx-c
|
||||||
|
FROM scratch AS local-mlx
|
||||||
|
FROM scratch AS local-mlx-c
|
||||||
|
|
||||||
FROM --platform=linux/amd64 rocm/dev-almalinux-8:${ROCMVERSION}-complete AS base-amd64
|
FROM --platform=linux/amd64 rocm/dev-almalinux-8:${ROCMVERSION}-complete AS base-amd64
|
||||||
RUN dnf install -y yum-utils ccache gcc-toolset-11-gcc gcc-toolset-11-gcc-c++ gcc-toolset-11-binutils \
|
RUN dnf install -y yum-utils ccache gcc-toolset-11-gcc gcc-toolset-11-gcc-c++ gcc-toolset-11-binutils \
|
||||||
&& yum-config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel8/x86_64/cuda-rhel8.repo
|
&& yum-config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel8/x86_64/cuda-rhel8.repo
|
||||||
ENV PATH=/opt/rh/gcc-toolset-11/root/usr/bin:$PATH
|
ENV PATH=/opt/rh/gcc-toolset-11/root/usr/bin:$PATH
|
||||||
ARG VULKANVERSION
|
|
||||||
RUN wget https://sdk.lunarg.com/sdk/download/${VULKANVERSION}/linux/vulkansdk-linux-x86_64-${VULKANVERSION}.tar.xz -O /tmp/vulkansdk-linux-x86_64-${VULKANVERSION}.tar.xz \
|
|
||||||
&& tar xvf /tmp/vulkansdk-linux-x86_64-${VULKANVERSION}.tar.xz \
|
|
||||||
&& dnf -y install ninja-build \
|
|
||||||
&& ln -s /usr/bin/python3 /usr/bin/python \
|
|
||||||
&& /${VULKANVERSION}/vulkansdk -j 8 vulkan-headers \
|
|
||||||
&& /${VULKANVERSION}/vulkansdk -j 8 shaderc
|
|
||||||
RUN cp -r /${VULKANVERSION}/x86_64/include/* /usr/local/include/ \
|
|
||||||
&& cp -r /${VULKANVERSION}/x86_64/lib/* /usr/local/lib
|
|
||||||
ENV PATH=/${VULKANVERSION}/x86_64/bin:$PATH
|
|
||||||
|
|
||||||
FROM --platform=linux/arm64 almalinux:8 AS base-arm64
|
FROM --platform=linux/arm64 almalinux:8 AS base-arm64
|
||||||
# install epel-release for ccache
|
# install epel-release for ccache
|
||||||
@@ -33,100 +28,119 @@ ENV CC=clang CXX=clang++
|
|||||||
|
|
||||||
FROM base-${TARGETARCH} AS base
|
FROM base-${TARGETARCH} AS base
|
||||||
ARG CMAKEVERSION
|
ARG CMAKEVERSION
|
||||||
|
ARG NINJAVERSION
|
||||||
RUN curl -fsSL https://github.com/Kitware/CMake/releases/download/v${CMAKEVERSION}/cmake-${CMAKEVERSION}-linux-$(uname -m).tar.gz | tar xz -C /usr/local --strip-components 1
|
RUN curl -fsSL https://github.com/Kitware/CMake/releases/download/v${CMAKEVERSION}/cmake-${CMAKEVERSION}-linux-$(uname -m).tar.gz | tar xz -C /usr/local --strip-components 1
|
||||||
|
RUN dnf install -y unzip \
|
||||||
|
&& curl -fsSL -o /tmp/ninja.zip https://github.com/ninja-build/ninja/releases/download/v${NINJAVERSION}/ninja-linux$([ "$(uname -m)" = "aarch64" ] && echo "-aarch64").zip \
|
||||||
|
&& unzip /tmp/ninja.zip -d /usr/local/bin \
|
||||||
|
&& rm /tmp/ninja.zip
|
||||||
|
ENV CMAKE_GENERATOR=Ninja
|
||||||
ENV LDFLAGS=-s
|
ENV LDFLAGS=-s
|
||||||
|
|
||||||
FROM base AS cpu
|
FROM base AS cpu
|
||||||
RUN dnf install -y gcc-toolset-11-gcc gcc-toolset-11-gcc-c++
|
RUN dnf install -y gcc-toolset-11-gcc gcc-toolset-11-gcc-c++
|
||||||
ENV PATH=/opt/rh/gcc-toolset-11/root/usr/bin:$PATH
|
ENV PATH=/opt/rh/gcc-toolset-11/root/usr/bin:$PATH
|
||||||
ARG PARALLEL
|
|
||||||
COPY CMakeLists.txt CMakePresets.json .
|
COPY CMakeLists.txt CMakePresets.json .
|
||||||
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
|
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
|
||||||
RUN --mount=type=cache,target=/root/.ccache \
|
RUN --mount=type=cache,target=/root/.ccache \
|
||||||
cmake --preset 'CPU' \
|
cmake --preset 'CPU' \
|
||||||
&& cmake --build --parallel ${PARALLEL} --preset 'CPU' \
|
&& cmake --build --preset 'CPU' -- -l $(nproc) \
|
||||||
&& cmake --install build --component CPU --strip --parallel ${PARALLEL}
|
&& cmake --install build --component CPU --strip
|
||||||
|
|
||||||
FROM base AS cuda-11
|
FROM base AS cuda-11
|
||||||
ARG CUDA11VERSION=11.8
|
ARG CUDA11VERSION=11.8
|
||||||
RUN dnf install -y cuda-toolkit-${CUDA11VERSION//./-}
|
RUN dnf install -y cuda-toolkit-${CUDA11VERSION//./-}
|
||||||
ENV PATH=/usr/local/cuda-11/bin:$PATH
|
ENV PATH=/usr/local/cuda-11/bin:$PATH
|
||||||
ARG PARALLEL
|
|
||||||
COPY CMakeLists.txt CMakePresets.json .
|
COPY CMakeLists.txt CMakePresets.json .
|
||||||
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
|
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
|
||||||
RUN --mount=type=cache,target=/root/.ccache \
|
RUN --mount=type=cache,target=/root/.ccache \
|
||||||
cmake --preset 'CUDA 11' \
|
cmake --preset 'CUDA 11' \
|
||||||
&& cmake --build --parallel ${PARALLEL} --preset 'CUDA 11' \
|
&& cmake --build --preset 'CUDA 11' -- -l $(nproc) \
|
||||||
&& cmake --install build --component CUDA --strip --parallel ${PARALLEL}
|
&& cmake --install build --component CUDA --strip
|
||||||
|
|
||||||
FROM base AS cuda-12
|
FROM base AS cuda-12
|
||||||
ARG CUDA12VERSION=12.8
|
ARG CUDA12VERSION=12.8
|
||||||
RUN dnf install -y cuda-toolkit-${CUDA12VERSION//./-}
|
RUN dnf install -y cuda-toolkit-${CUDA12VERSION//./-}
|
||||||
ENV PATH=/usr/local/cuda-12/bin:$PATH
|
ENV PATH=/usr/local/cuda-12/bin:$PATH
|
||||||
ARG PARALLEL
|
|
||||||
COPY CMakeLists.txt CMakePresets.json .
|
COPY CMakeLists.txt CMakePresets.json .
|
||||||
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
|
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
|
||||||
RUN --mount=type=cache,target=/root/.ccache \
|
RUN --mount=type=cache,target=/root/.ccache \
|
||||||
cmake --preset 'CUDA 12' \
|
cmake --preset 'CUDA 12' \
|
||||||
&& cmake --build --parallel ${PARALLEL} --preset 'CUDA 12' \
|
&& cmake --build --preset 'CUDA 12' -- -l $(nproc) \
|
||||||
&& cmake --install build --component CUDA --strip --parallel ${PARALLEL}
|
&& cmake --install build --component CUDA --strip
|
||||||
|
|
||||||
|
|
||||||
FROM base AS cuda-13
|
FROM base AS cuda-13
|
||||||
ARG CUDA13VERSION=13.0
|
ARG CUDA13VERSION=13.0
|
||||||
RUN dnf install -y cuda-toolkit-${CUDA13VERSION//./-}
|
RUN dnf install -y cuda-toolkit-${CUDA13VERSION//./-}
|
||||||
ENV PATH=/usr/local/cuda-13/bin:$PATH
|
ENV PATH=/usr/local/cuda-13/bin:$PATH
|
||||||
ARG PARALLEL
|
|
||||||
COPY CMakeLists.txt CMakePresets.json .
|
COPY CMakeLists.txt CMakePresets.json .
|
||||||
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
|
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
|
||||||
RUN --mount=type=cache,target=/root/.ccache \
|
RUN --mount=type=cache,target=/root/.ccache \
|
||||||
cmake --preset 'CUDA 13' \
|
cmake --preset 'CUDA 13' \
|
||||||
&& cmake --build --parallel ${PARALLEL} --preset 'CUDA 13' \
|
&& cmake --build --preset 'CUDA 13' -- -l $(nproc) \
|
||||||
&& cmake --install build --component CUDA --strip --parallel ${PARALLEL}
|
&& cmake --install build --component CUDA --strip
|
||||||
|
|
||||||
|
|
||||||
FROM base AS rocm-6
|
FROM base AS rocm-7
|
||||||
ENV PATH=/opt/rocm/hcc/bin:/opt/rocm/hip/bin:/opt/rocm/bin:/opt/rocm/hcc/bin:$PATH
|
ENV PATH=/opt/rocm/hcc/bin:/opt/rocm/hip/bin:/opt/rocm/bin:/opt/rocm/hcc/bin:$PATH
|
||||||
ARG PARALLEL
|
|
||||||
COPY CMakeLists.txt CMakePresets.json .
|
COPY CMakeLists.txt CMakePresets.json .
|
||||||
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
|
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
|
||||||
RUN --mount=type=cache,target=/root/.ccache \
|
RUN --mount=type=cache,target=/root/.ccache \
|
||||||
cmake --preset 'ROCm 6' \
|
cmake --preset 'ROCm 7' \
|
||||||
&& cmake --build --parallel ${PARALLEL} --preset 'ROCm 6' \
|
&& cmake --build --preset 'ROCm 7' -- -l $(nproc) \
|
||||||
&& cmake --install build --component HIP --strip --parallel ${PARALLEL}
|
&& cmake --install build --component HIP --strip
|
||||||
RUN rm -f dist/lib/ollama/rocm/rocblas/library/*gfx90[06]*
|
RUN rm -f dist/lib/ollama/rocm/rocblas/library/*gfx90[06]*
|
||||||
|
|
||||||
FROM --platform=linux/arm64 nvcr.io/nvidia/l4t-jetpack:${JETPACK5VERSION} AS jetpack-5
|
FROM --platform=linux/arm64 nvcr.io/nvidia/l4t-jetpack:${JETPACK5VERSION} AS jetpack-5
|
||||||
ARG CMAKEVERSION
|
ARG CMAKEVERSION
|
||||||
RUN apt-get update && apt-get install -y curl ccache \
|
ARG NINJAVERSION
|
||||||
&& curl -fsSL https://github.com/Kitware/CMake/releases/download/v${CMAKEVERSION}/cmake-${CMAKEVERSION}-linux-$(uname -m).tar.gz | tar xz -C /usr/local --strip-components 1
|
RUN apt-get update && apt-get install -y curl ccache unzip \
|
||||||
|
&& curl -fsSL https://github.com/Kitware/CMake/releases/download/v${CMAKEVERSION}/cmake-${CMAKEVERSION}-linux-$(uname -m).tar.gz | tar xz -C /usr/local --strip-components 1 \
|
||||||
|
&& curl -fsSL -o /tmp/ninja.zip https://github.com/ninja-build/ninja/releases/download/v${NINJAVERSION}/ninja-linux-aarch64.zip \
|
||||||
|
&& unzip /tmp/ninja.zip -d /usr/local/bin \
|
||||||
|
&& rm /tmp/ninja.zip
|
||||||
|
ENV CMAKE_GENERATOR=Ninja
|
||||||
COPY CMakeLists.txt CMakePresets.json .
|
COPY CMakeLists.txt CMakePresets.json .
|
||||||
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
|
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
|
||||||
ARG PARALLEL
|
|
||||||
RUN --mount=type=cache,target=/root/.ccache \
|
RUN --mount=type=cache,target=/root/.ccache \
|
||||||
cmake --preset 'JetPack 5' \
|
cmake --preset 'JetPack 5' \
|
||||||
&& cmake --build --parallel ${PARALLEL} --preset 'JetPack 5' \
|
&& cmake --build --preset 'JetPack 5' -- -l $(nproc) \
|
||||||
&& cmake --install build --component CUDA --strip --parallel ${PARALLEL}
|
&& cmake --install build --component CUDA --strip
|
||||||
|
|
||||||
FROM --platform=linux/arm64 nvcr.io/nvidia/l4t-jetpack:${JETPACK6VERSION} AS jetpack-6
|
FROM --platform=linux/arm64 nvcr.io/nvidia/l4t-jetpack:${JETPACK6VERSION} AS jetpack-6
|
||||||
ARG CMAKEVERSION
|
ARG CMAKEVERSION
|
||||||
RUN apt-get update && apt-get install -y curl ccache \
|
ARG NINJAVERSION
|
||||||
&& curl -fsSL https://github.com/Kitware/CMake/releases/download/v${CMAKEVERSION}/cmake-${CMAKEVERSION}-linux-$(uname -m).tar.gz | tar xz -C /usr/local --strip-components 1
|
RUN apt-get update && apt-get install -y curl ccache unzip \
|
||||||
|
&& curl -fsSL https://github.com/Kitware/CMake/releases/download/v${CMAKEVERSION}/cmake-${CMAKEVERSION}-linux-$(uname -m).tar.gz | tar xz -C /usr/local --strip-components 1 \
|
||||||
|
&& curl -fsSL -o /tmp/ninja.zip https://github.com/ninja-build/ninja/releases/download/v${NINJAVERSION}/ninja-linux-aarch64.zip \
|
||||||
|
&& unzip /tmp/ninja.zip -d /usr/local/bin \
|
||||||
|
&& rm /tmp/ninja.zip
|
||||||
|
ENV CMAKE_GENERATOR=Ninja
|
||||||
COPY CMakeLists.txt CMakePresets.json .
|
COPY CMakeLists.txt CMakePresets.json .
|
||||||
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
|
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
|
||||||
ARG PARALLEL
|
|
||||||
RUN --mount=type=cache,target=/root/.ccache \
|
RUN --mount=type=cache,target=/root/.ccache \
|
||||||
cmake --preset 'JetPack 6' \
|
cmake --preset 'JetPack 6' \
|
||||||
&& cmake --build --parallel ${PARALLEL} --preset 'JetPack 6' \
|
&& cmake --build --preset 'JetPack 6' -- -l $(nproc) \
|
||||||
&& cmake --install build --component CUDA --strip --parallel ${PARALLEL}
|
&& cmake --install build --component CUDA --strip
|
||||||
|
|
||||||
FROM base AS vulkan
|
FROM base AS vulkan
|
||||||
|
ARG VULKANVERSION
|
||||||
|
RUN ln -s /usr/bin/python3 /usr/bin/python \
|
||||||
|
&& wget https://sdk.lunarg.com/sdk/download/${VULKANVERSION}/linux/vulkansdk-linux-x86_64-${VULKANVERSION}.tar.xz -O /tmp/vulkansdk.tar.xz \
|
||||||
|
&& tar xvf /tmp/vulkansdk.tar.xz -C /tmp \
|
||||||
|
&& /tmp/${VULKANVERSION}/vulkansdk -j 8 vulkan-headers \
|
||||||
|
&& /tmp/${VULKANVERSION}/vulkansdk -j 8 shaderc \
|
||||||
|
&& cp -r /tmp/${VULKANVERSION}/x86_64/include/* /usr/local/include/ \
|
||||||
|
&& cp -r /tmp/${VULKANVERSION}/x86_64/lib/* /usr/local/lib \
|
||||||
|
&& cp -r /tmp/${VULKANVERSION}/x86_64/bin/* /usr/local/bin/ \
|
||||||
|
&& rm -rf /tmp/${VULKANVERSION} /tmp/vulkansdk.tar.xz
|
||||||
COPY CMakeLists.txt CMakePresets.json .
|
COPY CMakeLists.txt CMakePresets.json .
|
||||||
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
|
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
|
||||||
RUN --mount=type=cache,target=/root/.ccache \
|
RUN --mount=type=cache,target=/root/.ccache \
|
||||||
cmake --preset 'Vulkan' \
|
cmake --preset 'Vulkan' \
|
||||||
&& cmake --build --parallel --preset 'Vulkan' \
|
&& cmake --build --preset 'Vulkan' -- -l $(nproc) \
|
||||||
&& cmake --install build --component Vulkan --strip --parallel 8
|
&& cmake --install build --component Vulkan --strip
|
||||||
|
|
||||||
FROM base AS mlx
|
FROM base AS mlx
|
||||||
ARG CUDA13VERSION=13.0
|
ARG CUDA13VERSION=13.0
|
||||||
@@ -138,20 +152,27 @@ ENV PATH=/usr/local/cuda-13/bin:$PATH
|
|||||||
ENV BLAS_INCLUDE_DIRS=/usr/include/openblas
|
ENV BLAS_INCLUDE_DIRS=/usr/include/openblas
|
||||||
ENV LAPACK_INCLUDE_DIRS=/usr/include/openblas
|
ENV LAPACK_INCLUDE_DIRS=/usr/include/openblas
|
||||||
ENV CGO_LDFLAGS="-L/usr/local/cuda-13/lib64 -L/usr/local/cuda-13/targets/x86_64-linux/lib/stubs"
|
ENV CGO_LDFLAGS="-L/usr/local/cuda-13/lib64 -L/usr/local/cuda-13/targets/x86_64-linux/lib/stubs"
|
||||||
ARG PARALLEL
|
|
||||||
WORKDIR /go/src/github.com/ollama/ollama
|
WORKDIR /go/src/github.com/ollama/ollama
|
||||||
COPY CMakeLists.txt CMakePresets.json .
|
COPY CMakeLists.txt CMakePresets.json .
|
||||||
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
|
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
|
||||||
COPY x/imagegen/mlx x/imagegen/mlx
|
COPY x/imagegen/mlx x/imagegen/mlx
|
||||||
COPY go.mod go.sum .
|
COPY go.mod go.sum .
|
||||||
COPY MLX_VERSION .
|
COPY MLX_VERSION MLX_CORE_VERSION .
|
||||||
RUN curl -fsSL https://golang.org/dl/go$(awk '/^go/ { print $2 }' go.mod).linux-$(case $(uname -m) in x86_64) echo amd64 ;; aarch64) echo arm64 ;; esac).tar.gz | tar xz -C /usr/local
|
RUN curl -fsSL https://golang.org/dl/go$(awk '/^go/ { print $2 }' go.mod).linux-$(case $(uname -m) in x86_64) echo amd64 ;; aarch64) echo arm64 ;; esac).tar.gz | tar xz -C /usr/local
|
||||||
ENV PATH=/usr/local/go/bin:$PATH
|
ENV PATH=/usr/local/go/bin:$PATH
|
||||||
RUN go mod download
|
RUN go mod download
|
||||||
RUN --mount=type=cache,target=/root/.ccache \
|
RUN --mount=type=cache,target=/root/.ccache \
|
||||||
cmake --preset 'MLX CUDA 13' -DBLAS_INCLUDE_DIRS=/usr/include/openblas -DLAPACK_INCLUDE_DIRS=/usr/include/openblas \
|
--mount=type=bind,from=local-mlx,target=/tmp/local-mlx \
|
||||||
&& cmake --build --parallel ${PARALLEL} --preset 'MLX CUDA 13' \
|
--mount=type=bind,from=local-mlx-c,target=/tmp/local-mlx-c \
|
||||||
&& cmake --install build --component MLX --strip --parallel ${PARALLEL}
|
if [ -f /tmp/local-mlx/CMakeLists.txt ]; then \
|
||||||
|
export OLLAMA_MLX_SOURCE=/tmp/local-mlx; \
|
||||||
|
fi \
|
||||||
|
&& if [ -f /tmp/local-mlx-c/CMakeLists.txt ]; then \
|
||||||
|
export OLLAMA_MLX_C_SOURCE=/tmp/local-mlx-c; \
|
||||||
|
fi \
|
||||||
|
&& cmake --preset 'MLX CUDA 13' -DBLAS_INCLUDE_DIRS=/usr/include/openblas -DLAPACK_INCLUDE_DIRS=/usr/include/openblas \
|
||||||
|
&& cmake --build --preset 'MLX CUDA 13' -- -l $(nproc) \
|
||||||
|
&& cmake --install build --component MLX --strip
|
||||||
|
|
||||||
FROM base AS build
|
FROM base AS build
|
||||||
WORKDIR /go/src/github.com/ollama/ollama
|
WORKDIR /go/src/github.com/ollama/ollama
|
||||||
@@ -160,16 +181,14 @@ RUN curl -fsSL https://golang.org/dl/go$(awk '/^go/ { print $2 }' go.mod).linux-
|
|||||||
ENV PATH=/usr/local/go/bin:$PATH
|
ENV PATH=/usr/local/go/bin:$PATH
|
||||||
RUN go mod download
|
RUN go mod download
|
||||||
COPY . .
|
COPY . .
|
||||||
# Clone mlx-c headers for CGO (version from MLX_VERSION file)
|
|
||||||
RUN git clone --depth 1 --branch "$(cat MLX_VERSION)" https://github.com/ml-explore/mlx-c.git build/_deps/mlx-c-src
|
|
||||||
ARG GOFLAGS="'-ldflags=-w -s'"
|
ARG GOFLAGS="'-ldflags=-w -s'"
|
||||||
ENV CGO_ENABLED=1
|
ENV CGO_ENABLED=1
|
||||||
ARG CGO_CFLAGS
|
ARG CGO_CFLAGS
|
||||||
ARG CGO_CXXFLAGS
|
ARG CGO_CXXFLAGS
|
||||||
ENV CGO_CFLAGS="${CGO_CFLAGS} -I/go/src/github.com/ollama/ollama/build/_deps/mlx-c-src"
|
ENV CGO_CFLAGS="${CGO_CFLAGS}"
|
||||||
ENV CGO_CXXFLAGS="${CGO_CXXFLAGS}"
|
ENV CGO_CXXFLAGS="${CGO_CXXFLAGS}"
|
||||||
RUN --mount=type=cache,target=/root/.cache/go-build \
|
RUN --mount=type=cache,target=/root/.cache/go-build \
|
||||||
go build -tags mlx -trimpath -buildmode=pie -o /bin/ollama .
|
go build -trimpath -buildmode=pie -o /bin/ollama .
|
||||||
|
|
||||||
FROM --platform=linux/amd64 scratch AS amd64
|
FROM --platform=linux/amd64 scratch AS amd64
|
||||||
# COPY --from=cuda-11 dist/lib/ollama/ /lib/ollama/
|
# COPY --from=cuda-11 dist/lib/ollama/ /lib/ollama/
|
||||||
@@ -186,10 +205,9 @@ COPY --from=jetpack-5 dist/lib/ollama/ /lib/ollama/
|
|||||||
COPY --from=jetpack-6 dist/lib/ollama/ /lib/ollama/
|
COPY --from=jetpack-6 dist/lib/ollama/ /lib/ollama/
|
||||||
|
|
||||||
FROM scratch AS rocm
|
FROM scratch AS rocm
|
||||||
COPY --from=rocm-6 dist/lib/ollama /lib/ollama
|
COPY --from=rocm-7 dist/lib/ollama /lib/ollama
|
||||||
|
|
||||||
FROM ${FLAVOR} AS archive
|
FROM ${FLAVOR} AS archive
|
||||||
ARG VULKANVERSION
|
|
||||||
COPY --from=cpu dist/lib/ollama /lib/ollama
|
COPY --from=cpu dist/lib/ollama /lib/ollama
|
||||||
COPY --from=build /bin/ollama /bin/ollama
|
COPY --from=build /bin/ollama /bin/ollama
|
||||||
|
|
||||||
|
|||||||
1
MLX_CORE_VERSION
Normal file
1
MLX_CORE_VERSION
Normal file
@@ -0,0 +1 @@
|
|||||||
|
v0.30.6
|
||||||
@@ -1063,7 +1063,7 @@ func DefaultOptions() Options {
|
|||||||
TopP: 0.9,
|
TopP: 0.9,
|
||||||
TypicalP: 1.0,
|
TypicalP: 1.0,
|
||||||
RepeatLastN: 64,
|
RepeatLastN: 64,
|
||||||
RepeatPenalty: 1.1,
|
RepeatPenalty: 1.0,
|
||||||
PresencePenalty: 0.0,
|
PresencePenalty: 0.0,
|
||||||
FrequencyPenalty: 0.0,
|
FrequencyPenalty: 0.0,
|
||||||
Seed: -1,
|
Seed: -1,
|
||||||
|
|||||||
@@ -214,6 +214,7 @@ export default function Settings() {
|
|||||||
Agent: false,
|
Agent: false,
|
||||||
Tools: false,
|
Tools: false,
|
||||||
ContextLength: 0,
|
ContextLength: 0,
|
||||||
|
AutoUpdateEnabled: true,
|
||||||
});
|
});
|
||||||
updateSettingsMutation.mutate(defaultSettings);
|
updateSettingsMutation.mutate(defaultSettings);
|
||||||
}
|
}
|
||||||
|
|||||||
98
cmd/cmd.go
98
cmd/cmd.go
@@ -41,6 +41,7 @@ import (
|
|||||||
"github.com/ollama/ollama/cmd/tui"
|
"github.com/ollama/ollama/cmd/tui"
|
||||||
"github.com/ollama/ollama/envconfig"
|
"github.com/ollama/ollama/envconfig"
|
||||||
"github.com/ollama/ollama/format"
|
"github.com/ollama/ollama/format"
|
||||||
|
"github.com/ollama/ollama/internal/modelref"
|
||||||
"github.com/ollama/ollama/parser"
|
"github.com/ollama/ollama/parser"
|
||||||
"github.com/ollama/ollama/progress"
|
"github.com/ollama/ollama/progress"
|
||||||
"github.com/ollama/ollama/readline"
|
"github.com/ollama/ollama/readline"
|
||||||
@@ -131,6 +132,17 @@ func getModelfileName(cmd *cobra.Command) (string, error) {
|
|||||||
return absName, nil
|
return absName, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// isLocalhost returns true if the configured Ollama host is a loopback or unspecified address.
|
||||||
|
func isLocalhost() bool {
|
||||||
|
host := envconfig.Host()
|
||||||
|
h, _, _ := net.SplitHostPort(host.Host)
|
||||||
|
if h == "localhost" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
ip := net.ParseIP(h)
|
||||||
|
return ip != nil && (ip.IsLoopback() || ip.IsUnspecified())
|
||||||
|
}
|
||||||
|
|
||||||
func CreateHandler(cmd *cobra.Command, args []string) error {
|
func CreateHandler(cmd *cobra.Command, args []string) error {
|
||||||
p := progress.NewProgress(os.Stderr)
|
p := progress.NewProgress(os.Stderr)
|
||||||
defer p.Stop()
|
defer p.Stop()
|
||||||
@@ -145,6 +157,9 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
|
|||||||
// Check for --experimental flag for safetensors model creation
|
// Check for --experimental flag for safetensors model creation
|
||||||
experimental, _ := cmd.Flags().GetBool("experimental")
|
experimental, _ := cmd.Flags().GetBool("experimental")
|
||||||
if experimental {
|
if experimental {
|
||||||
|
if !isLocalhost() {
|
||||||
|
return errors.New("remote safetensor model creation not yet supported")
|
||||||
|
}
|
||||||
// Get Modelfile content - either from -f flag or default to "FROM ."
|
// Get Modelfile content - either from -f flag or default to "FROM ."
|
||||||
var reader io.Reader
|
var reader io.Reader
|
||||||
filename, err := getModelfileName(cmd)
|
filename, err := getModelfileName(cmd)
|
||||||
@@ -168,29 +183,9 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
|
|||||||
return fmt.Errorf("failed to parse Modelfile: %w", err)
|
return fmt.Errorf("failed to parse Modelfile: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Extract FROM path and configuration
|
modelDir, mfConfig, err := xcreateclient.ConfigFromModelfile(modelfile)
|
||||||
var modelDir string
|
if err != nil {
|
||||||
mfConfig := &xcreateclient.ModelfileConfig{}
|
return err
|
||||||
|
|
||||||
for _, cmd := range modelfile.Commands {
|
|
||||||
switch cmd.Name {
|
|
||||||
case "model":
|
|
||||||
modelDir = cmd.Args
|
|
||||||
case "template":
|
|
||||||
mfConfig.Template = cmd.Args
|
|
||||||
case "system":
|
|
||||||
mfConfig.System = cmd.Args
|
|
||||||
case "license":
|
|
||||||
mfConfig.License = cmd.Args
|
|
||||||
case "parser":
|
|
||||||
mfConfig.Parser = cmd.Args
|
|
||||||
case "renderer":
|
|
||||||
mfConfig.Renderer = cmd.Args
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if modelDir == "" {
|
|
||||||
modelDir = "."
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Resolve relative paths based on Modelfile location
|
// Resolve relative paths based on Modelfile location
|
||||||
@@ -214,6 +209,9 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
|
|||||||
if filename == "" {
|
if filename == "" {
|
||||||
// No Modelfile found - check if current directory is an image gen model
|
// No Modelfile found - check if current directory is an image gen model
|
||||||
if create.IsTensorModelDir(".") {
|
if create.IsTensorModelDir(".") {
|
||||||
|
if !isLocalhost() {
|
||||||
|
return errors.New("remote safetensor model creation not yet supported")
|
||||||
|
}
|
||||||
quantize, _ := cmd.Flags().GetString("quantize")
|
quantize, _ := cmd.Flags().GetString("quantize")
|
||||||
return xcreateclient.CreateModel(xcreateclient.CreateOptions{
|
return xcreateclient.CreateModel(xcreateclient.CreateOptions{
|
||||||
ModelName: modelName,
|
ModelName: modelName,
|
||||||
@@ -406,12 +404,14 @@ func loadOrUnloadModel(cmd *cobra.Command, opts *runOptions) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
requestedCloud := modelref.HasExplicitCloudSource(opts.Model)
|
||||||
|
|
||||||
if info, err := client.Show(cmd.Context(), &api.ShowRequest{Model: opts.Model}); err != nil {
|
if info, err := client.Show(cmd.Context(), &api.ShowRequest{Model: opts.Model}); err != nil {
|
||||||
return err
|
return err
|
||||||
} else if info.RemoteHost != "" {
|
} else if info.RemoteHost != "" || requestedCloud {
|
||||||
// Cloud model, no need to load/unload
|
// Cloud model, no need to load/unload
|
||||||
|
|
||||||
isCloud := strings.HasPrefix(info.RemoteHost, "https://ollama.com")
|
isCloud := requestedCloud || strings.HasPrefix(info.RemoteHost, "https://ollama.com")
|
||||||
|
|
||||||
// Check if user is signed in for ollama.com cloud models
|
// Check if user is signed in for ollama.com cloud models
|
||||||
if isCloud {
|
if isCloud {
|
||||||
@@ -422,10 +422,14 @@ func loadOrUnloadModel(cmd *cobra.Command, opts *runOptions) error {
|
|||||||
|
|
||||||
if opts.ShowConnect {
|
if opts.ShowConnect {
|
||||||
p.StopAndClear()
|
p.StopAndClear()
|
||||||
|
remoteModel := info.RemoteModel
|
||||||
|
if remoteModel == "" {
|
||||||
|
remoteModel = opts.Model
|
||||||
|
}
|
||||||
if isCloud {
|
if isCloud {
|
||||||
fmt.Fprintf(os.Stderr, "Connecting to '%s' on 'ollama.com' ⚡\n", info.RemoteModel)
|
fmt.Fprintf(os.Stderr, "Connecting to '%s' on 'ollama.com' ⚡\n", remoteModel)
|
||||||
} else {
|
} else {
|
||||||
fmt.Fprintf(os.Stderr, "Connecting to '%s' on '%s'\n", info.RemoteModel, info.RemoteHost)
|
fmt.Fprintf(os.Stderr, "Connecting to '%s' on '%s'\n", remoteModel, info.RemoteHost)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -497,6 +501,20 @@ func generateEmbedding(cmd *cobra.Command, modelName, input string, keepAlive *a
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO(parthsareen): consolidate with TUI signin flow
|
||||||
|
func handleCloudAuthorizationError(err error) bool {
|
||||||
|
var authErr api.AuthorizationError
|
||||||
|
if errors.As(err, &authErr) && authErr.StatusCode == http.StatusUnauthorized {
|
||||||
|
fmt.Printf("You need to be signed in to Ollama to run Cloud models.\n\n")
|
||||||
|
if authErr.SigninURL != "" {
|
||||||
|
fmt.Printf(ConnectInstructions, authErr.SigninURL)
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
func RunHandler(cmd *cobra.Command, args []string) error {
|
func RunHandler(cmd *cobra.Command, args []string) error {
|
||||||
interactive := true
|
interactive := true
|
||||||
|
|
||||||
@@ -585,17 +603,6 @@ func RunHandler(cmd *cobra.Command, args []string) error {
|
|||||||
}
|
}
|
||||||
opts.WordWrap = !nowrap
|
opts.WordWrap = !nowrap
|
||||||
|
|
||||||
useImagegen := false
|
|
||||||
if cmd.Flags().Lookup("imagegen") != nil {
|
|
||||||
useImagegen, err = cmd.Flags().GetBool("imagegen")
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if useImagegen {
|
|
||||||
opts.Options["use_imagegen_runner"] = true
|
|
||||||
}
|
|
||||||
|
|
||||||
// Fill out the rest of the options based on information about the
|
// Fill out the rest of the options based on information about the
|
||||||
// model.
|
// model.
|
||||||
client, err := api.ClientFromEnvironment()
|
client, err := api.ClientFromEnvironment()
|
||||||
@@ -604,12 +611,16 @@ func RunHandler(cmd *cobra.Command, args []string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
name := args[0]
|
name := args[0]
|
||||||
|
requestedCloud := modelref.HasExplicitCloudSource(name)
|
||||||
|
|
||||||
info, err := func() (*api.ShowResponse, error) {
|
info, err := func() (*api.ShowResponse, error) {
|
||||||
showReq := &api.ShowRequest{Name: name}
|
showReq := &api.ShowRequest{Name: name}
|
||||||
info, err := client.Show(cmd.Context(), showReq)
|
info, err := client.Show(cmd.Context(), showReq)
|
||||||
var se api.StatusError
|
var se api.StatusError
|
||||||
if errors.As(err, &se) && se.StatusCode == http.StatusNotFound {
|
if errors.As(err, &se) && se.StatusCode == http.StatusNotFound {
|
||||||
|
if requestedCloud {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
if err := PullHandler(cmd, []string{name}); err != nil {
|
if err := PullHandler(cmd, []string{name}); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -618,6 +629,9 @@ func RunHandler(cmd *cobra.Command, args []string) error {
|
|||||||
return info, err
|
return info, err
|
||||||
}()
|
}()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
if handleCloudAuthorizationError(err) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -712,7 +726,13 @@ func RunHandler(cmd *cobra.Command, args []string) error {
|
|||||||
|
|
||||||
return generateInteractive(cmd, opts)
|
return generateInteractive(cmd, opts)
|
||||||
}
|
}
|
||||||
return generate(cmd, opts)
|
if err := generate(cmd, opts); err != nil {
|
||||||
|
if handleCloudAuthorizationError(err) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func SigninHandler(cmd *cobra.Command, args []string) error {
|
func SigninHandler(cmd *cobra.Command, args []string) error {
|
||||||
|
|||||||
209
cmd/cmd_test.go
209
cmd/cmd_test.go
@@ -18,6 +18,7 @@ import (
|
|||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
|
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
|
"github.com/ollama/ollama/internal/modelref"
|
||||||
"github.com/ollama/ollama/types/model"
|
"github.com/ollama/ollama/types/model"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -705,6 +706,139 @@ func TestRunEmbeddingModelNoInput(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestRunHandler_CloudAuthErrorOnShow_PrintsSigninMessage(t *testing.T) {
|
||||||
|
var generateCalled bool
|
||||||
|
|
||||||
|
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
switch {
|
||||||
|
case r.URL.Path == "/api/show" && r.Method == http.MethodPost:
|
||||||
|
w.WriteHeader(http.StatusUnauthorized)
|
||||||
|
if err := json.NewEncoder(w).Encode(map[string]string{
|
||||||
|
"error": "unauthorized",
|
||||||
|
"signin_url": "https://ollama.com/signin",
|
||||||
|
}); err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
case r.URL.Path == "/api/generate" && r.Method == http.MethodPost:
|
||||||
|
generateCalled = true
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
if err := json.NewEncoder(w).Encode(api.GenerateResponse{Done: true}); err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
http.NotFound(w, r)
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
|
||||||
|
t.Setenv("OLLAMA_HOST", mockServer.URL)
|
||||||
|
t.Cleanup(mockServer.Close)
|
||||||
|
|
||||||
|
cmd := &cobra.Command{}
|
||||||
|
cmd.SetContext(t.Context())
|
||||||
|
cmd.Flags().String("keepalive", "", "")
|
||||||
|
cmd.Flags().Bool("truncate", false, "")
|
||||||
|
cmd.Flags().Int("dimensions", 0, "")
|
||||||
|
cmd.Flags().Bool("verbose", false, "")
|
||||||
|
cmd.Flags().Bool("insecure", false, "")
|
||||||
|
cmd.Flags().Bool("nowordwrap", false, "")
|
||||||
|
cmd.Flags().String("format", "", "")
|
||||||
|
cmd.Flags().String("think", "", "")
|
||||||
|
cmd.Flags().Bool("hidethinking", false, "")
|
||||||
|
|
||||||
|
oldStdout := os.Stdout
|
||||||
|
readOut, writeOut, _ := os.Pipe()
|
||||||
|
os.Stdout = writeOut
|
||||||
|
t.Cleanup(func() { os.Stdout = oldStdout })
|
||||||
|
|
||||||
|
err := RunHandler(cmd, []string{"gpt-oss:20b:cloud", "hi"})
|
||||||
|
|
||||||
|
_ = writeOut.Close()
|
||||||
|
var out bytes.Buffer
|
||||||
|
_, _ = io.Copy(&out, readOut)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("RunHandler returned error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if generateCalled {
|
||||||
|
t.Fatal("expected run to stop before /api/generate after unauthorized /api/show")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !strings.Contains(out.String(), "You need to be signed in to Ollama to run Cloud models.") {
|
||||||
|
t.Fatalf("expected sign-in guidance message, got %q", out.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
if !strings.Contains(out.String(), "https://ollama.com/signin") {
|
||||||
|
t.Fatalf("expected signin_url in output, got %q", out.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRunHandler_CloudAuthErrorOnGenerate_PrintsSigninMessage(t *testing.T) {
|
||||||
|
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
switch {
|
||||||
|
case r.URL.Path == "/api/show" && r.Method == http.MethodPost:
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
if err := json.NewEncoder(w).Encode(api.ShowResponse{
|
||||||
|
Capabilities: []model.Capability{model.CapabilityCompletion},
|
||||||
|
}); err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
case r.URL.Path == "/api/generate" && r.Method == http.MethodPost:
|
||||||
|
w.WriteHeader(http.StatusUnauthorized)
|
||||||
|
if err := json.NewEncoder(w).Encode(map[string]string{
|
||||||
|
"error": "unauthorized",
|
||||||
|
"signin_url": "https://ollama.com/signin",
|
||||||
|
}); err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
http.NotFound(w, r)
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
|
||||||
|
t.Setenv("OLLAMA_HOST", mockServer.URL)
|
||||||
|
t.Cleanup(mockServer.Close)
|
||||||
|
|
||||||
|
cmd := &cobra.Command{}
|
||||||
|
cmd.SetContext(t.Context())
|
||||||
|
cmd.Flags().String("keepalive", "", "")
|
||||||
|
cmd.Flags().Bool("truncate", false, "")
|
||||||
|
cmd.Flags().Int("dimensions", 0, "")
|
||||||
|
cmd.Flags().Bool("verbose", false, "")
|
||||||
|
cmd.Flags().Bool("insecure", false, "")
|
||||||
|
cmd.Flags().Bool("nowordwrap", false, "")
|
||||||
|
cmd.Flags().String("format", "", "")
|
||||||
|
cmd.Flags().String("think", "", "")
|
||||||
|
cmd.Flags().Bool("hidethinking", false, "")
|
||||||
|
|
||||||
|
oldStdout := os.Stdout
|
||||||
|
readOut, writeOut, _ := os.Pipe()
|
||||||
|
os.Stdout = writeOut
|
||||||
|
t.Cleanup(func() { os.Stdout = oldStdout })
|
||||||
|
|
||||||
|
err := RunHandler(cmd, []string{"gpt-oss:20b:cloud", "hi"})
|
||||||
|
|
||||||
|
_ = writeOut.Close()
|
||||||
|
var out bytes.Buffer
|
||||||
|
_, _ = io.Copy(&out, readOut)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("RunHandler returned error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !strings.Contains(out.String(), "You need to be signed in to Ollama to run Cloud models.") {
|
||||||
|
t.Fatalf("expected sign-in guidance message, got %q", out.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
if !strings.Contains(out.String(), "https://ollama.com/signin") {
|
||||||
|
t.Fatalf("expected signin_url in output, got %q", out.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestGetModelfileName(t *testing.T) {
|
func TestGetModelfileName(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
@@ -1664,20 +1798,26 @@ func TestRunOptions_Copy_Independence(t *testing.T) {
|
|||||||
func TestLoadOrUnloadModel_CloudModelAuth(t *testing.T) {
|
func TestLoadOrUnloadModel_CloudModelAuth(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
|
model string
|
||||||
remoteHost string
|
remoteHost string
|
||||||
|
remoteModel string
|
||||||
whoamiStatus int
|
whoamiStatus int
|
||||||
whoamiResp any
|
whoamiResp any
|
||||||
expectedError string
|
expectedError string
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "ollama.com cloud model - user signed in",
|
name: "ollama.com cloud model - user signed in",
|
||||||
|
model: "test-cloud-model",
|
||||||
remoteHost: "https://ollama.com",
|
remoteHost: "https://ollama.com",
|
||||||
|
remoteModel: "test-model",
|
||||||
whoamiStatus: http.StatusOK,
|
whoamiStatus: http.StatusOK,
|
||||||
whoamiResp: api.UserResponse{Name: "testuser"},
|
whoamiResp: api.UserResponse{Name: "testuser"},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "ollama.com cloud model - user not signed in",
|
name: "ollama.com cloud model - user not signed in",
|
||||||
|
model: "test-cloud-model",
|
||||||
remoteHost: "https://ollama.com",
|
remoteHost: "https://ollama.com",
|
||||||
|
remoteModel: "test-model",
|
||||||
whoamiStatus: http.StatusUnauthorized,
|
whoamiStatus: http.StatusUnauthorized,
|
||||||
whoamiResp: map[string]string{
|
whoamiResp: map[string]string{
|
||||||
"error": "unauthorized",
|
"error": "unauthorized",
|
||||||
@@ -1687,7 +1827,33 @@ func TestLoadOrUnloadModel_CloudModelAuth(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "non-ollama.com remote - no auth check",
|
name: "non-ollama.com remote - no auth check",
|
||||||
|
model: "test-cloud-model",
|
||||||
remoteHost: "https://other-remote.com",
|
remoteHost: "https://other-remote.com",
|
||||||
|
remoteModel: "test-model",
|
||||||
|
whoamiStatus: http.StatusUnauthorized, // should not be called
|
||||||
|
whoamiResp: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "explicit :cloud model - auth check without remote metadata",
|
||||||
|
model: "kimi-k2.5:cloud",
|
||||||
|
remoteHost: "",
|
||||||
|
remoteModel: "",
|
||||||
|
whoamiStatus: http.StatusOK,
|
||||||
|
whoamiResp: api.UserResponse{Name: "testuser"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "explicit -cloud model - auth check without remote metadata",
|
||||||
|
model: "kimi-k2.5:latest-cloud",
|
||||||
|
remoteHost: "",
|
||||||
|
remoteModel: "",
|
||||||
|
whoamiStatus: http.StatusOK,
|
||||||
|
whoamiResp: api.UserResponse{Name: "testuser"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "dash cloud-like name without explicit source does not require auth",
|
||||||
|
model: "test-cloud-model",
|
||||||
|
remoteHost: "",
|
||||||
|
remoteModel: "",
|
||||||
whoamiStatus: http.StatusUnauthorized, // should not be called
|
whoamiStatus: http.StatusUnauthorized, // should not be called
|
||||||
whoamiResp: nil,
|
whoamiResp: nil,
|
||||||
},
|
},
|
||||||
@@ -1702,7 +1868,7 @@ func TestLoadOrUnloadModel_CloudModelAuth(t *testing.T) {
|
|||||||
w.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
if err := json.NewEncoder(w).Encode(api.ShowResponse{
|
if err := json.NewEncoder(w).Encode(api.ShowResponse{
|
||||||
RemoteHost: tt.remoteHost,
|
RemoteHost: tt.remoteHost,
|
||||||
RemoteModel: "test-model",
|
RemoteModel: tt.remoteModel,
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
@@ -1715,6 +1881,8 @@ func TestLoadOrUnloadModel_CloudModelAuth(t *testing.T) {
|
|||||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
case "/api/generate":
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
default:
|
default:
|
||||||
http.NotFound(w, r)
|
http.NotFound(w, r)
|
||||||
}
|
}
|
||||||
@@ -1727,13 +1895,13 @@ func TestLoadOrUnloadModel_CloudModelAuth(t *testing.T) {
|
|||||||
cmd.SetContext(t.Context())
|
cmd.SetContext(t.Context())
|
||||||
|
|
||||||
opts := &runOptions{
|
opts := &runOptions{
|
||||||
Model: "test-cloud-model",
|
Model: tt.model,
|
||||||
ShowConnect: false,
|
ShowConnect: false,
|
||||||
}
|
}
|
||||||
|
|
||||||
err := loadOrUnloadModel(cmd, opts)
|
err := loadOrUnloadModel(cmd, opts)
|
||||||
|
|
||||||
if strings.HasPrefix(tt.remoteHost, "https://ollama.com") {
|
if strings.HasPrefix(tt.remoteHost, "https://ollama.com") || modelref.HasExplicitCloudSource(tt.model) {
|
||||||
if !whoamiCalled {
|
if !whoamiCalled {
|
||||||
t.Error("expected whoami to be called for ollama.com cloud model")
|
t.Error("expected whoami to be called for ollama.com cloud model")
|
||||||
}
|
}
|
||||||
@@ -1760,3 +1928,38 @@ func TestLoadOrUnloadModel_CloudModelAuth(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestIsLocalhost(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
host string
|
||||||
|
expected bool
|
||||||
|
}{
|
||||||
|
{"default empty", "", true},
|
||||||
|
{"localhost no port", "localhost", true},
|
||||||
|
{"localhost with port", "localhost:11435", true},
|
||||||
|
{"127.0.0.1 no port", "127.0.0.1", true},
|
||||||
|
{"127.0.0.1 with port", "127.0.0.1:11434", true},
|
||||||
|
{"0.0.0.0 no port", "0.0.0.0", true},
|
||||||
|
{"0.0.0.0 with port", "0.0.0.0:11434", true},
|
||||||
|
{"::1 no port", "::1", true},
|
||||||
|
{"[::1] with port", "[::1]:11434", true},
|
||||||
|
{"loopback with scheme", "http://localhost:11434", true},
|
||||||
|
{"remote hostname", "example.com", false},
|
||||||
|
{"remote hostname with port", "example.com:11434", false},
|
||||||
|
{"remote IP", "192.168.1.1", false},
|
||||||
|
{"remote IP with port", "192.168.1.1:11434", false},
|
||||||
|
{"remote with scheme", "http://example.com:11434", false},
|
||||||
|
{"https remote", "https://example.com:443", false},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
t.Setenv("OLLAMA_HOST", tt.host)
|
||||||
|
got := isLocalhost()
|
||||||
|
if got != tt.expected {
|
||||||
|
t.Errorf("isLocalhost() with OLLAMA_HOST=%q = %v, want %v", tt.host, got, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -107,16 +107,13 @@ func (c *Claude) ConfigureAliases(ctx context.Context, model string, existingAli
|
|||||||
}
|
}
|
||||||
|
|
||||||
if !force && aliases["primary"] != "" {
|
if !force && aliases["primary"] != "" {
|
||||||
client, _ := api.ClientFromEnvironment()
|
if isCloudModelName(aliases["primary"]) {
|
||||||
if isCloudModel(ctx, client, aliases["primary"]) {
|
aliases["fast"] = aliases["primary"]
|
||||||
if isCloudModel(ctx, client, aliases["fast"]) {
|
|
||||||
return aliases, false, nil
|
return aliases, false, nil
|
||||||
}
|
}
|
||||||
} else {
|
|
||||||
delete(aliases, "fast")
|
delete(aliases, "fast")
|
||||||
return aliases, false, nil
|
return aliases, false, nil
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
items, existingModels, cloudModels, client, err := listModels(ctx)
|
items, existingModels, cloudModels, client, err := listModels(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -139,10 +136,8 @@ func (c *Claude) ConfigureAliases(ctx context.Context, model string, existingAli
|
|||||||
aliases["primary"] = primary
|
aliases["primary"] = primary
|
||||||
}
|
}
|
||||||
|
|
||||||
if isCloudModel(ctx, client, aliases["primary"]) {
|
if isCloudModelName(aliases["primary"]) {
|
||||||
if aliases["fast"] == "" || !isCloudModel(ctx, client, aliases["fast"]) {
|
|
||||||
aliases["fast"] = aliases["primary"]
|
aliases["fast"] = aliases["primary"]
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
delete(aliases, "fast")
|
delete(aliases, "fast")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -233,6 +233,9 @@ func ModelExists(ctx context.Context, name string) bool {
|
|||||||
if name == "" {
|
if name == "" {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
if isCloudModelName(name) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
client, err := api.ClientFromEnvironment()
|
client, err := api.ClientFromEnvironment()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false
|
return false
|
||||||
|
|||||||
@@ -10,7 +10,6 @@ import (
|
|||||||
"path/filepath"
|
"path/filepath"
|
||||||
"slices"
|
"slices"
|
||||||
|
|
||||||
"github.com/ollama/ollama/api"
|
|
||||||
"github.com/ollama/ollama/envconfig"
|
"github.com/ollama/ollama/envconfig"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -125,13 +124,12 @@ func (d *Droid) Edit(models []string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Build new Ollama model entries with sequential indices (0, 1, 2, ...)
|
// Build new Ollama model entries with sequential indices (0, 1, 2, ...)
|
||||||
client, _ := api.ClientFromEnvironment()
|
|
||||||
|
|
||||||
var newModels []any
|
var newModels []any
|
||||||
var defaultModelID string
|
var defaultModelID string
|
||||||
for i, model := range models {
|
for i, model := range models {
|
||||||
maxOutput := 64000
|
maxOutput := 64000
|
||||||
if isCloudModel(context.Background(), client, model) {
|
if isCloudModelName(model) {
|
||||||
if l, ok := lookupCloudModelLimit(model); ok {
|
if l, ok := lookupCloudModelLimit(model); ok {
|
||||||
maxOutput = l.Output
|
maxOutput = l.Output
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1276,25 +1276,17 @@ func TestDroidEdit_LocalModelDefaultMaxOutput(t *testing.T) {
|
|||||||
|
|
||||||
func TestDroidEdit_CloudModelLimitsUsed(t *testing.T) {
|
func TestDroidEdit_CloudModelLimitsUsed(t *testing.T) {
|
||||||
// Verify that every cloud model in cloudModelLimits has a valid output
|
// Verify that every cloud model in cloudModelLimits has a valid output
|
||||||
// value that would be used for maxOutputTokens when isCloudModel returns true.
|
// value that would be used for maxOutputTokens when the selected model uses
|
||||||
// :cloud suffix stripping must also work since that's how users specify them.
|
// the explicit :cloud source tag.
|
||||||
for name, expected := range cloudModelLimits {
|
for name, expected := range cloudModelLimits {
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
l, ok := lookupCloudModelLimit(name)
|
|
||||||
if !ok {
|
|
||||||
t.Fatalf("lookupCloudModelLimit(%q) returned false", name)
|
|
||||||
}
|
|
||||||
if l.Output != expected.Output {
|
|
||||||
t.Errorf("output = %d, want %d", l.Output, expected.Output)
|
|
||||||
}
|
|
||||||
// Also verify :cloud suffix lookup
|
|
||||||
cloudName := name + ":cloud"
|
cloudName := name + ":cloud"
|
||||||
l2, ok := lookupCloudModelLimit(cloudName)
|
l, ok := lookupCloudModelLimit(cloudName)
|
||||||
if !ok {
|
if !ok {
|
||||||
t.Fatalf("lookupCloudModelLimit(%q) returned false", cloudName)
|
t.Fatalf("lookupCloudModelLimit(%q) returned false", cloudName)
|
||||||
}
|
}
|
||||||
if l2.Output != expected.Output {
|
if l.Output != expected.Output {
|
||||||
t.Errorf(":cloud output = %d, want %d", l2.Output, expected.Output)
|
t.Errorf("output = %d, want %d", l.Output, expected.Output)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ import (
|
|||||||
|
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
internalcloud "github.com/ollama/ollama/internal/cloud"
|
internalcloud "github.com/ollama/ollama/internal/cloud"
|
||||||
|
"github.com/ollama/ollama/internal/modelref"
|
||||||
"github.com/ollama/ollama/progress"
|
"github.com/ollama/ollama/progress"
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
)
|
)
|
||||||
@@ -81,6 +82,7 @@ var cloudModelLimits = map[string]cloudModelLimit{
|
|||||||
"deepseek-v3.2": {Context: 163_840, Output: 65_536},
|
"deepseek-v3.2": {Context: 163_840, Output: 65_536},
|
||||||
"glm-4.6": {Context: 202_752, Output: 131_072},
|
"glm-4.6": {Context: 202_752, Output: 131_072},
|
||||||
"glm-4.7": {Context: 202_752, Output: 131_072},
|
"glm-4.7": {Context: 202_752, Output: 131_072},
|
||||||
|
"glm-5": {Context: 202_752, Output: 131_072},
|
||||||
"gpt-oss:120b": {Context: 131_072, Output: 131_072},
|
"gpt-oss:120b": {Context: 131_072, Output: 131_072},
|
||||||
"gpt-oss:20b": {Context: 131_072, Output: 131_072},
|
"gpt-oss:20b": {Context: 131_072, Output: 131_072},
|
||||||
"kimi-k2:1t": {Context: 262_144, Output: 262_144},
|
"kimi-k2:1t": {Context: 262_144, Output: 262_144},
|
||||||
@@ -90,6 +92,7 @@ var cloudModelLimits = map[string]cloudModelLimit{
|
|||||||
"qwen3-coder:480b": {Context: 262_144, Output: 65_536},
|
"qwen3-coder:480b": {Context: 262_144, Output: 65_536},
|
||||||
"qwen3-coder-next": {Context: 262_144, Output: 32_768},
|
"qwen3-coder-next": {Context: 262_144, Output: 32_768},
|
||||||
"qwen3-next:80b": {Context: 262_144, Output: 32_768},
|
"qwen3-next:80b": {Context: 262_144, Output: 32_768},
|
||||||
|
"qwen3.5": {Context: 262_144, Output: 32_768},
|
||||||
}
|
}
|
||||||
|
|
||||||
// recommendedVRAM maps local recommended models to their approximate VRAM requirement.
|
// recommendedVRAM maps local recommended models to their approximate VRAM requirement.
|
||||||
@@ -324,12 +327,7 @@ func SelectModelWithSelector(ctx context.Context, selector SingleSelector) (stri
|
|||||||
|
|
||||||
// If the selected model isn't installed, pull it first
|
// If the selected model isn't installed, pull it first
|
||||||
if !existingModels[selected] {
|
if !existingModels[selected] {
|
||||||
if cloudModels[selected] {
|
if !isCloudModelName(selected) {
|
||||||
// Cloud models only pull a small manifest; no confirmation needed
|
|
||||||
if err := pullModel(ctx, client, selected); err != nil {
|
|
||||||
return "", fmt.Errorf("failed to pull %s: %w", selected, err)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
msg := fmt.Sprintf("Download %s?", selected)
|
msg := fmt.Sprintf("Download %s?", selected)
|
||||||
if ok, err := confirmPrompt(msg); err != nil {
|
if ok, err := confirmPrompt(msg); err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
@@ -524,7 +522,7 @@ func selectModelsWithSelectors(ctx context.Context, name, current string, single
|
|||||||
|
|
||||||
var toPull []string
|
var toPull []string
|
||||||
for _, m := range selected {
|
for _, m := range selected {
|
||||||
if !existingModels[m] {
|
if !existingModels[m] && !isCloudModelName(m) {
|
||||||
toPull = append(toPull, m)
|
toPull = append(toPull, m)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -550,12 +548,28 @@ func selectModelsWithSelectors(ctx context.Context, name, current string, single
|
|||||||
return selected, nil
|
return selected, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO(parthsareen): consolidate pull logic from call sites
|
||||||
func pullIfNeeded(ctx context.Context, client *api.Client, existingModels map[string]bool, model string) error {
|
func pullIfNeeded(ctx context.Context, client *api.Client, existingModels map[string]bool, model string) error {
|
||||||
if existingModels[model] {
|
if isCloudModelName(model) || existingModels[model] {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
msg := fmt.Sprintf("Download %s?", model)
|
return confirmAndPull(ctx, client, model)
|
||||||
if ok, err := confirmPrompt(msg); err != nil {
|
}
|
||||||
|
|
||||||
|
// TODO(parthsareen): pull this out to tui package
|
||||||
|
// ShowOrPull checks if a model exists via client.Show and offers to pull it if not found.
|
||||||
|
func ShowOrPull(ctx context.Context, client *api.Client, model string) error {
|
||||||
|
if _, err := client.Show(ctx, &api.ShowRequest{Model: model}); err == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if isCloudModelName(model) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return confirmAndPull(ctx, client, model)
|
||||||
|
}
|
||||||
|
|
||||||
|
func confirmAndPull(ctx context.Context, client *api.Client, model string) error {
|
||||||
|
if ok, err := confirmPrompt(fmt.Sprintf("Download %s?", model)); err != nil {
|
||||||
return err
|
return err
|
||||||
} else if !ok {
|
} else if !ok {
|
||||||
return errCancelled
|
return errCancelled
|
||||||
@@ -567,26 +581,6 @@ func pullIfNeeded(ctx context.Context, client *api.Client, existingModels map[st
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO(parthsareen): pull this out to tui package
|
|
||||||
// ShowOrPull checks if a model exists via client.Show and offers to pull it if not found.
|
|
||||||
func ShowOrPull(ctx context.Context, client *api.Client, model string) error {
|
|
||||||
if _, err := client.Show(ctx, &api.ShowRequest{Model: model}); err == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
// Cloud models only pull a small manifest; skip the download confirmation
|
|
||||||
// TODO(parthsareen): consolidate with cloud config changes
|
|
||||||
if strings.HasSuffix(model, "cloud") {
|
|
||||||
return pullModel(ctx, client, model)
|
|
||||||
}
|
|
||||||
if ok, err := confirmPrompt(fmt.Sprintf("Download %s?", model)); err != nil {
|
|
||||||
return err
|
|
||||||
} else if !ok {
|
|
||||||
return errCancelled
|
|
||||||
}
|
|
||||||
fmt.Fprintf(os.Stderr, "\n")
|
|
||||||
return pullModel(ctx, client, model)
|
|
||||||
}
|
|
||||||
|
|
||||||
func listModels(ctx context.Context) ([]ModelItem, map[string]bool, map[string]bool, *api.Client, error) {
|
func listModels(ctx context.Context) ([]ModelItem, map[string]bool, map[string]bool, *api.Client, error) {
|
||||||
client, err := api.ClientFromEnvironment()
|
client, err := api.ClientFromEnvironment()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -731,10 +725,8 @@ func syncAliases(ctx context.Context, client *api.Client, ac AliasConfigurer, na
|
|||||||
}
|
}
|
||||||
aliases["primary"] = model
|
aliases["primary"] = model
|
||||||
|
|
||||||
if isCloudModel(ctx, client, model) {
|
if isCloudModelName(model) {
|
||||||
if aliases["fast"] == "" || !isCloudModel(ctx, client, aliases["fast"]) {
|
|
||||||
aliases["fast"] = model
|
aliases["fast"] = model
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
delete(aliases, "fast")
|
delete(aliases, "fast")
|
||||||
}
|
}
|
||||||
@@ -1020,7 +1012,7 @@ Examples:
|
|||||||
existingAliases = aliases
|
existingAliases = aliases
|
||||||
|
|
||||||
// Ensure cloud models are authenticated
|
// Ensure cloud models are authenticated
|
||||||
if isCloudModel(cmd.Context(), client, model) {
|
if isCloudModelName(model) {
|
||||||
if err := ensureAuth(cmd.Context(), client, map[string]bool{model: true}, []string{model}); err != nil {
|
if err := ensureAuth(cmd.Context(), client, map[string]bool{model: true}, []string{model}); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -1209,7 +1201,7 @@ func buildModelList(existing []modelInfo, preChecked []string, current string) (
|
|||||||
// When user has no models, preserve recommended order.
|
// When user has no models, preserve recommended order.
|
||||||
notInstalled := make(map[string]bool)
|
notInstalled := make(map[string]bool)
|
||||||
for i := range items {
|
for i := range items {
|
||||||
if !existingModels[items[i].Name] {
|
if !existingModels[items[i].Name] && !cloudModels[items[i].Name] {
|
||||||
notInstalled[items[i].Name] = true
|
notInstalled[items[i].Name] = true
|
||||||
var parts []string
|
var parts []string
|
||||||
if items[i].Description != "" {
|
if items[i].Description != "" {
|
||||||
@@ -1303,7 +1295,8 @@ func IsCloudModelDisabled(ctx context.Context, name string) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func isCloudModelName(name string) bool {
|
func isCloudModelName(name string) bool {
|
||||||
return strings.HasSuffix(name, ":cloud") || strings.HasSuffix(name, "-cloud")
|
// TODO(drifkin): Replace this wrapper with inlining once things stabilize a bit
|
||||||
|
return modelref.HasExplicitCloudSource(name)
|
||||||
}
|
}
|
||||||
|
|
||||||
func filterCloudModels(existing []modelInfo) []modelInfo {
|
func filterCloudModels(existing []modelInfo) []modelInfo {
|
||||||
|
|||||||
@@ -426,11 +426,17 @@ func TestBuildModelList_NoExistingModels(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, item := range items {
|
for _, item := range items {
|
||||||
|
if strings.HasSuffix(item.Name, ":cloud") {
|
||||||
|
if strings.HasSuffix(item.Description, "(not downloaded)") {
|
||||||
|
t.Errorf("cloud model %q should not have '(not downloaded)' suffix, got %q", item.Name, item.Description)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
if !strings.HasSuffix(item.Description, "(not downloaded)") {
|
if !strings.HasSuffix(item.Description, "(not downloaded)") {
|
||||||
t.Errorf("item %q should have description ending with '(not downloaded)', got %q", item.Name, item.Description)
|
t.Errorf("item %q should have description ending with '(not downloaded)', got %q", item.Name, item.Description)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestBuildModelList_OnlyLocalModels_CloudRecsAtBottom(t *testing.T) {
|
func TestBuildModelList_OnlyLocalModels_CloudRecsAtBottom(t *testing.T) {
|
||||||
existing := []modelInfo{
|
existing := []modelInfo{
|
||||||
@@ -492,10 +498,14 @@ func TestBuildModelList_ExistingRecommendedMarked(t *testing.T) {
|
|||||||
if strings.HasSuffix(item.Description, "(not downloaded)") {
|
if strings.HasSuffix(item.Description, "(not downloaded)") {
|
||||||
t.Errorf("installed recommended %q should not have '(not downloaded)' suffix, got %q", item.Name, item.Description)
|
t.Errorf("installed recommended %q should not have '(not downloaded)' suffix, got %q", item.Name, item.Description)
|
||||||
}
|
}
|
||||||
case "minimax-m2.5:cloud", "kimi-k2.5:cloud", "qwen3:8b":
|
case "qwen3:8b":
|
||||||
if !strings.HasSuffix(item.Description, "(not downloaded)") {
|
if !strings.HasSuffix(item.Description, "(not downloaded)") {
|
||||||
t.Errorf("non-installed recommended %q should have '(not downloaded)' suffix, got %q", item.Name, item.Description)
|
t.Errorf("non-installed recommended %q should have '(not downloaded)' suffix, got %q", item.Name, item.Description)
|
||||||
}
|
}
|
||||||
|
case "minimax-m2.5:cloud", "kimi-k2.5:cloud":
|
||||||
|
if strings.HasSuffix(item.Description, "(not downloaded)") {
|
||||||
|
t.Errorf("cloud model %q should not have '(not downloaded)' suffix, got %q", item.Name, item.Description)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -536,7 +546,13 @@ func TestBuildModelList_HasRecommendedCloudModel_OnlyNonInstalledAtBottom(t *tes
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, item := range items {
|
for _, item := range items {
|
||||||
if !slices.Contains([]string{"kimi-k2.5:cloud", "llama3.2"}, item.Name) {
|
isCloud := strings.HasSuffix(item.Name, ":cloud")
|
||||||
|
isInstalled := slices.Contains([]string{"kimi-k2.5:cloud", "llama3.2"}, item.Name)
|
||||||
|
if isInstalled || isCloud {
|
||||||
|
if strings.HasSuffix(item.Description, "(not downloaded)") {
|
||||||
|
t.Errorf("installed or cloud model %q should not have '(not downloaded)' suffix, got %q", item.Name, item.Description)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
if !strings.HasSuffix(item.Description, "(not downloaded)") {
|
if !strings.HasSuffix(item.Description, "(not downloaded)") {
|
||||||
t.Errorf("non-installed %q should have '(not downloaded)' suffix, got %q", item.Name, item.Description)
|
t.Errorf("non-installed %q should have '(not downloaded)' suffix, got %q", item.Name, item.Description)
|
||||||
}
|
}
|
||||||
@@ -1000,8 +1016,8 @@ func TestShowOrPull_ModelNotFound_ConfirmNo_Cancelled(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestShowOrPull_CloudModel_SkipsConfirmation(t *testing.T) {
|
func TestShowOrPull_CloudModel_DoesNotPull(t *testing.T) {
|
||||||
// Confirm prompt should NOT be called for cloud models
|
// Confirm prompt should NOT be called for explicit cloud models
|
||||||
oldHook := DefaultConfirmPrompt
|
oldHook := DefaultConfirmPrompt
|
||||||
DefaultConfirmPrompt = func(prompt string) (bool, error) {
|
DefaultConfirmPrompt = func(prompt string) (bool, error) {
|
||||||
t.Error("confirm prompt should not be called for cloud models")
|
t.Error("confirm prompt should not be called for cloud models")
|
||||||
@@ -1032,8 +1048,115 @@ func TestShowOrPull_CloudModel_SkipsConfirmation(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("ShowOrPull should succeed for cloud model, got: %v", err)
|
t.Errorf("ShowOrPull should succeed for cloud model, got: %v", err)
|
||||||
}
|
}
|
||||||
if !pullCalled {
|
if pullCalled {
|
||||||
t.Error("expected pull to be called for cloud model without confirmation")
|
t.Error("expected pull not to be called for cloud model")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestShowOrPull_CloudLegacySuffix_DoesNotPull(t *testing.T) {
|
||||||
|
// Confirm prompt should NOT be called for explicit cloud models
|
||||||
|
oldHook := DefaultConfirmPrompt
|
||||||
|
DefaultConfirmPrompt = func(prompt string) (bool, error) {
|
||||||
|
t.Error("confirm prompt should not be called for cloud models")
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
defer func() { DefaultConfirmPrompt = oldHook }()
|
||||||
|
|
||||||
|
var pullCalled bool
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
switch r.URL.Path {
|
||||||
|
case "/api/show":
|
||||||
|
w.WriteHeader(http.StatusNotFound)
|
||||||
|
fmt.Fprintf(w, `{"error":"model not found"}`)
|
||||||
|
case "/api/pull":
|
||||||
|
pullCalled = true
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
fmt.Fprintf(w, `{"status":"success"}`)
|
||||||
|
default:
|
||||||
|
w.WriteHeader(http.StatusNotFound)
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
u, _ := url.Parse(srv.URL)
|
||||||
|
client := api.NewClient(u, srv.Client())
|
||||||
|
|
||||||
|
err := ShowOrPull(context.Background(), client, "gpt-oss:20b-cloud")
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("ShowOrPull should succeed for cloud model, got: %v", err)
|
||||||
|
}
|
||||||
|
if pullCalled {
|
||||||
|
t.Error("expected pull not to be called for cloud model")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPullIfNeeded_CloudModel_DoesNotPull(t *testing.T) {
|
||||||
|
oldHook := DefaultConfirmPrompt
|
||||||
|
DefaultConfirmPrompt = func(prompt string) (bool, error) {
|
||||||
|
t.Error("confirm prompt should not be called for cloud models")
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
defer func() { DefaultConfirmPrompt = oldHook }()
|
||||||
|
|
||||||
|
err := pullIfNeeded(context.Background(), nil, map[string]bool{}, "glm-5:cloud")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("expected no error for cloud model, got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = pullIfNeeded(context.Background(), nil, map[string]bool{}, "gpt-oss:20b-cloud")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("expected no error for cloud model with legacy suffix, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSelectModelsWithSelectors_CloudSelection_DoesNotPull(t *testing.T) {
|
||||||
|
var pullCalled bool
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
switch r.URL.Path {
|
||||||
|
case "/api/status":
|
||||||
|
w.WriteHeader(http.StatusNotFound)
|
||||||
|
fmt.Fprintf(w, `{"error":"not found"}`)
|
||||||
|
case "/api/tags":
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
fmt.Fprintf(w, `{"models":[]}`)
|
||||||
|
case "/api/me":
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
fmt.Fprintf(w, `{"name":"test-user"}`)
|
||||||
|
case "/api/pull":
|
||||||
|
pullCalled = true
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
fmt.Fprintf(w, `{"status":"success"}`)
|
||||||
|
default:
|
||||||
|
w.WriteHeader(http.StatusNotFound)
|
||||||
|
fmt.Fprintf(w, `{"error":"not found"}`)
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
t.Setenv("OLLAMA_HOST", srv.URL)
|
||||||
|
|
||||||
|
single := func(title string, items []ModelItem, current string) (string, error) {
|
||||||
|
for _, item := range items {
|
||||||
|
if item.Name == "glm-5:cloud" {
|
||||||
|
return item.Name, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
t.Fatalf("expected glm-5:cloud in selector items, got %v", items)
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
multi := func(title string, items []ModelItem, preChecked []string) ([]string, error) {
|
||||||
|
return nil, fmt.Errorf("multi selector should not be called")
|
||||||
|
}
|
||||||
|
|
||||||
|
selected, err := selectModelsWithSelectors(context.Background(), "codex", "", single, multi)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("selectModelsWithSelectors returned error: %v", err)
|
||||||
|
}
|
||||||
|
if !slices.Equal(selected, []string{"glm-5:cloud"}) {
|
||||||
|
t.Fatalf("unexpected selected models: %v", selected)
|
||||||
|
}
|
||||||
|
if pullCalled {
|
||||||
|
t.Fatal("expected cloud selection to skip pull")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -502,7 +502,7 @@ func (c *Openclaw) Edit(models []string) error {
|
|||||||
ollama = make(map[string]any)
|
ollama = make(map[string]any)
|
||||||
}
|
}
|
||||||
|
|
||||||
ollama["baseUrl"] = envconfig.Host().String() + "/v1"
|
ollama["baseUrl"] = envconfig.Host().String()
|
||||||
// needed to register provider
|
// needed to register provider
|
||||||
ollama["apiKey"] = "ollama-local"
|
ollama["apiKey"] = "ollama-local"
|
||||||
ollama["api"] = "ollama"
|
ollama["api"] = "ollama"
|
||||||
|
|||||||
@@ -589,7 +589,7 @@ const testOpenclawFixture = `{
|
|||||||
"providers": {
|
"providers": {
|
||||||
"anthropic": {"apiKey": "xxx"},
|
"anthropic": {"apiKey": "xxx"},
|
||||||
"ollama": {
|
"ollama": {
|
||||||
"baseUrl": "http://127.0.0.1:11434/v1",
|
"baseUrl": "http://127.0.0.1:11434",
|
||||||
"models": [{"id": "old-model", "customField": "preserved"}]
|
"models": [{"id": "old-model", "customField": "preserved"}]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -12,8 +12,8 @@ import (
|
|||||||
"slices"
|
"slices"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/ollama/ollama/api"
|
|
||||||
"github.com/ollama/ollama/envconfig"
|
"github.com/ollama/ollama/envconfig"
|
||||||
|
"github.com/ollama/ollama/internal/modelref"
|
||||||
)
|
)
|
||||||
|
|
||||||
// OpenCode implements Runner and Editor for OpenCode integration
|
// OpenCode implements Runner and Editor for OpenCode integration
|
||||||
@@ -26,13 +26,10 @@ type cloudModelLimit struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// lookupCloudModelLimit returns the token limits for a cloud model.
|
// lookupCloudModelLimit returns the token limits for a cloud model.
|
||||||
// It tries the exact name first, then strips the ":cloud" suffix.
|
// It normalizes explicit cloud source suffixes before checking the shared limit map.
|
||||||
func lookupCloudModelLimit(name string) (cloudModelLimit, bool) {
|
func lookupCloudModelLimit(name string) (cloudModelLimit, bool) {
|
||||||
if l, ok := cloudModelLimits[name]; ok {
|
base, stripped := modelref.StripCloudSourceTag(name)
|
||||||
return l, true
|
if stripped {
|
||||||
}
|
|
||||||
base := strings.TrimSuffix(name, ":cloud")
|
|
||||||
if base != name {
|
|
||||||
if l, ok := cloudModelLimits[base]; ok {
|
if l, ok := cloudModelLimits[base]; ok {
|
||||||
return l, true
|
return l, true
|
||||||
}
|
}
|
||||||
@@ -122,13 +119,18 @@ func (o *OpenCode) Edit(modelList []string) error {
|
|||||||
if !ok {
|
if !ok {
|
||||||
ollama = map[string]any{
|
ollama = map[string]any{
|
||||||
"npm": "@ai-sdk/openai-compatible",
|
"npm": "@ai-sdk/openai-compatible",
|
||||||
"name": "Ollama (local)",
|
"name": "Ollama",
|
||||||
"options": map[string]any{
|
"options": map[string]any{
|
||||||
"baseURL": envconfig.Host().String() + "/v1",
|
"baseURL": envconfig.Host().String() + "/v1",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Migrate legacy provider name
|
||||||
|
if name, _ := ollama["name"].(string); name == "Ollama (local)" {
|
||||||
|
ollama["name"] = "Ollama"
|
||||||
|
}
|
||||||
|
|
||||||
models, ok := ollama["models"].(map[string]any)
|
models, ok := ollama["models"].(map[string]any)
|
||||||
if !ok {
|
if !ok {
|
||||||
models = make(map[string]any)
|
models = make(map[string]any)
|
||||||
@@ -147,8 +149,6 @@ func (o *OpenCode) Edit(modelList []string) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
client, _ := api.ClientFromEnvironment()
|
|
||||||
|
|
||||||
for _, model := range modelList {
|
for _, model := range modelList {
|
||||||
if existing, ok := models[model].(map[string]any); ok {
|
if existing, ok := models[model].(map[string]any); ok {
|
||||||
// migrate existing models without _launch marker
|
// migrate existing models without _launch marker
|
||||||
@@ -158,7 +158,7 @@ func (o *OpenCode) Edit(modelList []string) error {
|
|||||||
existing["name"] = strings.TrimSuffix(name, " [Ollama]")
|
existing["name"] = strings.TrimSuffix(name, " [Ollama]")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if isCloudModel(context.Background(), client, model) {
|
if isCloudModelName(model) {
|
||||||
if l, ok := lookupCloudModelLimit(model); ok {
|
if l, ok := lookupCloudModelLimit(model); ok {
|
||||||
existing["limit"] = map[string]any{
|
existing["limit"] = map[string]any{
|
||||||
"context": l.Context,
|
"context": l.Context,
|
||||||
@@ -172,7 +172,7 @@ func (o *OpenCode) Edit(modelList []string) error {
|
|||||||
"name": model,
|
"name": model,
|
||||||
"_launch": true,
|
"_launch": true,
|
||||||
}
|
}
|
||||||
if isCloudModel(context.Background(), client, model) {
|
if isCloudModelName(model) {
|
||||||
if l, ok := lookupCloudModelLimit(model); ok {
|
if l, ok := lookupCloudModelLimit(model); ok {
|
||||||
entry["limit"] = map[string]any{
|
entry["limit"] = map[string]any{
|
||||||
"context": l.Context,
|
"context": l.Context,
|
||||||
|
|||||||
@@ -3,6 +3,8 @@ package config
|
|||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"testing"
|
"testing"
|
||||||
@@ -232,6 +234,44 @@ func TestOpenCodeEdit(t *testing.T) {
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
t.Run("migrate Ollama (local) provider name", func(t *testing.T) {
|
||||||
|
cleanup()
|
||||||
|
os.MkdirAll(configDir, 0o755)
|
||||||
|
os.WriteFile(configPath, []byte(`{"provider":{"ollama":{"name":"Ollama (local)","npm":"@ai-sdk/openai-compatible","options":{"baseURL":"http://localhost:11434/v1"}}}}`), 0o644)
|
||||||
|
|
||||||
|
if err := o.Edit([]string{"llama3.2"}); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
data, _ := os.ReadFile(configPath)
|
||||||
|
var cfg map[string]any
|
||||||
|
json.Unmarshal(data, &cfg)
|
||||||
|
provider := cfg["provider"].(map[string]any)
|
||||||
|
ollama := provider["ollama"].(map[string]any)
|
||||||
|
if ollama["name"] != "Ollama" {
|
||||||
|
t.Errorf("provider name not migrated: got %q, want %q", ollama["name"], "Ollama")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("preserve custom provider name", func(t *testing.T) {
|
||||||
|
cleanup()
|
||||||
|
os.MkdirAll(configDir, 0o755)
|
||||||
|
os.WriteFile(configPath, []byte(`{"provider":{"ollama":{"name":"My Custom Ollama","npm":"@ai-sdk/openai-compatible","options":{"baseURL":"http://localhost:11434/v1"}}}}`), 0o644)
|
||||||
|
|
||||||
|
if err := o.Edit([]string{"llama3.2"}); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
data, _ := os.ReadFile(configPath)
|
||||||
|
var cfg map[string]any
|
||||||
|
json.Unmarshal(data, &cfg)
|
||||||
|
provider := cfg["provider"].(map[string]any)
|
||||||
|
ollama := provider["ollama"].(map[string]any)
|
||||||
|
if ollama["name"] != "My Custom Ollama" {
|
||||||
|
t.Errorf("custom provider name was changed: got %q, want %q", ollama["name"], "My Custom Ollama")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
t.Run("remove model preserves non-ollama models", func(t *testing.T) {
|
t.Run("remove model preserves non-ollama models", func(t *testing.T) {
|
||||||
cleanup()
|
cleanup()
|
||||||
os.MkdirAll(configDir, 0o755)
|
os.MkdirAll(configDir, 0o755)
|
||||||
@@ -619,6 +659,54 @@ func TestOpenCodeEdit_CloudModelLimitStructure(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestOpenCodeEdit_BackfillsCloudModelLimitOnExistingEntry(t *testing.T) {
|
||||||
|
o := &OpenCode{}
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
setTestHome(t, tmpDir)
|
||||||
|
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.URL.Path == "/api/show" {
|
||||||
|
fmt.Fprintf(w, `{"capabilities":[],"model_info":{},"remote_model":"glm-5"}`)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
w.WriteHeader(http.StatusNotFound)
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
t.Setenv("OLLAMA_HOST", srv.URL)
|
||||||
|
|
||||||
|
configDir := filepath.Join(tmpDir, ".config", "opencode")
|
||||||
|
configPath := filepath.Join(configDir, "opencode.json")
|
||||||
|
os.MkdirAll(configDir, 0o755)
|
||||||
|
os.WriteFile(configPath, []byte(`{
|
||||||
|
"provider": {
|
||||||
|
"ollama": {
|
||||||
|
"models": {
|
||||||
|
"glm-5:cloud": {
|
||||||
|
"name": "glm-5:cloud",
|
||||||
|
"_launch": true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}`), 0o644)
|
||||||
|
|
||||||
|
if err := o.Edit([]string{"glm-5:cloud"}); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
entry := readOpenCodeModel(t, configPath, "glm-5:cloud")
|
||||||
|
limit, ok := entry["limit"].(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("cloud model limit was not added on re-edit")
|
||||||
|
}
|
||||||
|
if limit["context"] != float64(202_752) {
|
||||||
|
t.Errorf("context = %v, want 202752", limit["context"])
|
||||||
|
}
|
||||||
|
if limit["output"] != float64(131_072) {
|
||||||
|
t.Errorf("output = %v, want 131072", limit["output"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestLookupCloudModelLimit(t *testing.T) {
|
func TestLookupCloudModelLimit(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
@@ -626,13 +714,17 @@ func TestLookupCloudModelLimit(t *testing.T) {
|
|||||||
wantContext int
|
wantContext int
|
||||||
wantOutput int
|
wantOutput int
|
||||||
}{
|
}{
|
||||||
{"glm-4.7", true, 202_752, 131_072},
|
{"glm-4.7", false, 0, 0},
|
||||||
{"glm-4.7:cloud", true, 202_752, 131_072},
|
{"glm-4.7:cloud", true, 202_752, 131_072},
|
||||||
{"kimi-k2.5", true, 262_144, 262_144},
|
{"glm-5:cloud", true, 202_752, 131_072},
|
||||||
|
{"gpt-oss:120b-cloud", true, 131_072, 131_072},
|
||||||
|
{"gpt-oss:20b-cloud", true, 131_072, 131_072},
|
||||||
|
{"kimi-k2.5", false, 0, 0},
|
||||||
{"kimi-k2.5:cloud", true, 262_144, 262_144},
|
{"kimi-k2.5:cloud", true, 262_144, 262_144},
|
||||||
{"deepseek-v3.2", true, 163_840, 65_536},
|
{"deepseek-v3.2", false, 0, 0},
|
||||||
{"deepseek-v3.2:cloud", true, 163_840, 65_536},
|
{"deepseek-v3.2:cloud", true, 163_840, 65_536},
|
||||||
{"qwen3-coder:480b", true, 262_144, 65_536},
|
{"qwen3-coder:480b", false, 0, 0},
|
||||||
|
{"qwen3-coder:480b:cloud", true, 262_144, 65_536},
|
||||||
{"qwen3-coder-next:cloud", true, 262_144, 32_768},
|
{"qwen3-coder-next:cloud", true, 262_144, 32_768},
|
||||||
{"llama3.2", false, 0, 0},
|
{"llama3.2", false, 0, 0},
|
||||||
{"unknown-model:cloud", false, 0, 0},
|
{"unknown-model:cloud", false, 0, 0},
|
||||||
|
|||||||
@@ -107,7 +107,8 @@ func (p *Pi) Edit(models []string) error {
|
|||||||
|
|
||||||
// Build new models list:
|
// Build new models list:
|
||||||
// 1. Keep user-managed models (no _launch marker) - untouched
|
// 1. Keep user-managed models (no _launch marker) - untouched
|
||||||
// 2. Keep ollama-managed models (_launch marker) that are still selected
|
// 2. Keep ollama-managed models (_launch marker) that are still selected,
|
||||||
|
// except stale cloud entries that should be rebuilt below
|
||||||
// 3. Add new ollama-managed models
|
// 3. Add new ollama-managed models
|
||||||
var newModels []any
|
var newModels []any
|
||||||
for _, m := range existingModels {
|
for _, m := range existingModels {
|
||||||
@@ -117,7 +118,13 @@ func (p *Pi) Edit(models []string) error {
|
|||||||
if !isPiOllamaModel(modelObj) {
|
if !isPiOllamaModel(modelObj) {
|
||||||
newModels = append(newModels, m)
|
newModels = append(newModels, m)
|
||||||
} else if selectedSet[id] {
|
} else if selectedSet[id] {
|
||||||
// Ollama-managed and still selected - keep it
|
// Rebuild stale managed cloud entries so createConfig refreshes
|
||||||
|
// the whole entry instead of patching it in place.
|
||||||
|
if !hasContextWindow(modelObj) {
|
||||||
|
if _, ok := lookupCloudModelLimit(id); ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
newModels = append(newModels, m)
|
newModels = append(newModels, m)
|
||||||
selectedSet[id] = false
|
selectedSet[id] = false
|
||||||
}
|
}
|
||||||
@@ -199,12 +206,28 @@ func isPiOllamaModel(cfg map[string]any) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func hasContextWindow(cfg map[string]any) bool {
|
||||||
|
switch v := cfg["contextWindow"].(type) {
|
||||||
|
case float64:
|
||||||
|
return v > 0
|
||||||
|
case int:
|
||||||
|
return v > 0
|
||||||
|
case int64:
|
||||||
|
return v > 0
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// createConfig builds Pi model config with capability detection
|
// createConfig builds Pi model config with capability detection
|
||||||
func createConfig(ctx context.Context, client *api.Client, modelID string) map[string]any {
|
func createConfig(ctx context.Context, client *api.Client, modelID string) map[string]any {
|
||||||
cfg := map[string]any{
|
cfg := map[string]any{
|
||||||
"id": modelID,
|
"id": modelID,
|
||||||
"_launch": true,
|
"_launch": true,
|
||||||
}
|
}
|
||||||
|
if l, ok := lookupCloudModelLimit(modelID); ok {
|
||||||
|
cfg["contextWindow"] = l.Context
|
||||||
|
}
|
||||||
|
|
||||||
resp, err := client.Show(ctx, &api.ShowRequest{Model: modelID})
|
resp, err := client.Show(ctx, &api.ShowRequest{Model: modelID})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -223,7 +246,8 @@ func createConfig(ctx context.Context, client *api.Client, modelID string) map[s
|
|||||||
cfg["reasoning"] = true
|
cfg["reasoning"] = true
|
||||||
}
|
}
|
||||||
|
|
||||||
// Extract context window from ModelInfo
|
// Extract context window from ModelInfo. For known cloud models, the
|
||||||
|
// pre-filled shared limit remains unless the server provides a positive value.
|
||||||
for key, val := range resp.ModelInfo {
|
for key, val := range resp.ModelInfo {
|
||||||
if strings.HasSuffix(key, ".context_length") {
|
if strings.HasSuffix(key, ".context_length") {
|
||||||
if ctxLen, ok := val.(float64); ok && ctxLen > 0 {
|
if ctxLen, ok := val.(float64); ok && ctxLen > 0 {
|
||||||
|
|||||||
@@ -192,6 +192,48 @@ func TestPiEdit(t *testing.T) {
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
t.Run("rebuilds stale existing managed cloud model", func(t *testing.T) {
|
||||||
|
cleanup()
|
||||||
|
os.MkdirAll(configDir, 0o755)
|
||||||
|
|
||||||
|
existingConfig := `{
|
||||||
|
"providers": {
|
||||||
|
"ollama": {
|
||||||
|
"baseUrl": "http://localhost:11434/v1",
|
||||||
|
"api": "openai-completions",
|
||||||
|
"apiKey": "ollama",
|
||||||
|
"models": [
|
||||||
|
{"id": "glm-5:cloud", "_launch": true, "legacyField": "stale"}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}`
|
||||||
|
if err := os.WriteFile(configPath, []byte(existingConfig), 0o644); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := pi.Edit([]string{"glm-5:cloud"}); err != nil {
|
||||||
|
t.Fatalf("Edit() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg := readConfig()
|
||||||
|
providers := cfg["providers"].(map[string]any)
|
||||||
|
ollama := providers["ollama"].(map[string]any)
|
||||||
|
modelsArray := ollama["models"].([]any)
|
||||||
|
modelEntry := modelsArray[0].(map[string]any)
|
||||||
|
|
||||||
|
if modelEntry["contextWindow"] != float64(202_752) {
|
||||||
|
t.Errorf("contextWindow = %v, want 202752", modelEntry["contextWindow"])
|
||||||
|
}
|
||||||
|
input, ok := modelEntry["input"].([]any)
|
||||||
|
if !ok || len(input) != 1 || input[0] != "text" {
|
||||||
|
t.Errorf("input = %v, want [text]", modelEntry["input"])
|
||||||
|
}
|
||||||
|
if _, ok := modelEntry["legacyField"]; ok {
|
||||||
|
t.Error("legacyField should be removed when stale managed cloud entry is rebuilt")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
t.Run("replaces old models with new ones", func(t *testing.T) {
|
t.Run("replaces old models with new ones", func(t *testing.T) {
|
||||||
cleanup()
|
cleanup()
|
||||||
os.MkdirAll(configDir, 0o755)
|
os.MkdirAll(configDir, 0o755)
|
||||||
@@ -798,6 +840,60 @@ func TestCreateConfig(t *testing.T) {
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
t.Run("falls back to cloud context when show fails", func(t *testing.T) {
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusNotFound)
|
||||||
|
fmt.Fprintf(w, `{"error":"model not found"}`)
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
u, _ := url.Parse(srv.URL)
|
||||||
|
client := api.NewClient(u, srv.Client())
|
||||||
|
|
||||||
|
cfg := createConfig(context.Background(), client, "kimi-k2.5:cloud")
|
||||||
|
|
||||||
|
if cfg["contextWindow"] != 262_144 {
|
||||||
|
t.Errorf("contextWindow = %v, want 262144", cfg["contextWindow"])
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("falls back to cloud context when model info is empty", func(t *testing.T) {
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.URL.Path == "/api/show" {
|
||||||
|
fmt.Fprintf(w, `{"capabilities":[],"model_info":{}}`)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
w.WriteHeader(http.StatusNotFound)
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
u, _ := url.Parse(srv.URL)
|
||||||
|
client := api.NewClient(u, srv.Client())
|
||||||
|
|
||||||
|
cfg := createConfig(context.Background(), client, "glm-5:cloud")
|
||||||
|
|
||||||
|
if cfg["contextWindow"] != 202_752 {
|
||||||
|
t.Errorf("contextWindow = %v, want 202752", cfg["contextWindow"])
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("falls back to cloud context for dash cloud suffix", func(t *testing.T) {
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusNotFound)
|
||||||
|
fmt.Fprintf(w, `{"error":"model not found"}`)
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
u, _ := url.Parse(srv.URL)
|
||||||
|
client := api.NewClient(u, srv.Client())
|
||||||
|
|
||||||
|
cfg := createConfig(context.Background(), client, "gpt-oss:120b-cloud")
|
||||||
|
|
||||||
|
if cfg["contextWindow"] != 131_072 {
|
||||||
|
t.Errorf("contextWindow = %v, want 131072", cfg["contextWindow"])
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
t.Run("skips zero context length", func(t *testing.T) {
|
t.Run("skips zero context length", func(t *testing.T) {
|
||||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
if r.URL.Path == "/api/show" {
|
if r.URL.Path == "/api/show" {
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ import (
|
|||||||
"github.com/charmbracelet/lipgloss"
|
"github.com/charmbracelet/lipgloss"
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
"github.com/ollama/ollama/cmd/config"
|
"github.com/ollama/ollama/cmd/config"
|
||||||
|
"github.com/ollama/ollama/internal/modelref"
|
||||||
"github.com/ollama/ollama/version"
|
"github.com/ollama/ollama/version"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -147,7 +148,13 @@ type signInCheckMsg struct {
|
|||||||
type clearStatusMsg struct{}
|
type clearStatusMsg struct{}
|
||||||
|
|
||||||
func (m *model) modelExists(name string) bool {
|
func (m *model) modelExists(name string) bool {
|
||||||
if m.availableModels == nil || name == "" {
|
if name == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if modelref.HasExplicitCloudSource(name) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if m.availableModels == nil {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
if m.availableModels[name] {
|
if m.availableModels[name] {
|
||||||
@@ -209,7 +216,7 @@ func (m *model) openMultiModelModal(integration string) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func isCloudModel(name string) bool {
|
func isCloudModel(name string) bool {
|
||||||
return strings.HasSuffix(name, ":cloud") || strings.HasSuffix(name, "-cloud")
|
return modelref.HasExplicitCloudSource(name)
|
||||||
}
|
}
|
||||||
|
|
||||||
func cloudStatusDisabled(client *api.Client) bool {
|
func cloudStatusDisabled(client *api.Client) bool {
|
||||||
|
|||||||
@@ -54,6 +54,7 @@ type nemotronHModel struct {
|
|||||||
NGroups uint32 `json:"n_groups"`
|
NGroups uint32 `json:"n_groups"`
|
||||||
IntermediateSize uint32 `json:"intermediate_size"`
|
IntermediateSize uint32 `json:"intermediate_size"`
|
||||||
HybridOverridePattern hybridPattern `json:"hybrid_override_pattern"`
|
HybridOverridePattern hybridPattern `json:"hybrid_override_pattern"`
|
||||||
|
LayersBlockType []string `json:"layers_block_type"`
|
||||||
|
|
||||||
// MoE
|
// MoE
|
||||||
NumExperts uint32 `json:"num_experts"`
|
NumExperts uint32 `json:"num_experts"`
|
||||||
@@ -162,8 +163,27 @@ func (n *nemotronHModel) denseIntermediateSize() uint32 {
|
|||||||
|
|
||||||
func (n *nemotronHModel) layerArrays() (headCountKV []uint32, ffnLengths []uint32, err error) {
|
func (n *nemotronHModel) layerArrays() (headCountKV []uint32, ffnLengths []uint32, err error) {
|
||||||
pattern := strings.TrimSpace(string(n.HybridOverridePattern))
|
pattern := strings.TrimSpace(string(n.HybridOverridePattern))
|
||||||
|
|
||||||
|
// Convert layers_block_type array to pattern string if hybrid_override_pattern is not set
|
||||||
|
if pattern == "" && len(n.LayersBlockType) > 0 {
|
||||||
|
var sb strings.Builder
|
||||||
|
for _, blockType := range n.LayersBlockType {
|
||||||
|
switch strings.ToLower(blockType) {
|
||||||
|
case "mamba":
|
||||||
|
sb.WriteRune('M')
|
||||||
|
case "moe":
|
||||||
|
sb.WriteRune('E')
|
||||||
|
case "attention":
|
||||||
|
sb.WriteRune('A')
|
||||||
|
default:
|
||||||
|
return nil, nil, fmt.Errorf("nemotron_h: unsupported block type %q in layers_block_type", blockType)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
pattern = sb.String()
|
||||||
|
}
|
||||||
|
|
||||||
if pattern == "" {
|
if pattern == "" {
|
||||||
return nil, nil, fmt.Errorf("nemotron_h: hybrid_override_pattern must be set")
|
return nil, nil, fmt.Errorf("nemotron_h: hybrid_override_pattern or layers_block_type must be set")
|
||||||
}
|
}
|
||||||
|
|
||||||
runes := []rune(pattern)
|
runes := []rune(pattern)
|
||||||
|
|||||||
@@ -12,7 +12,6 @@ To use Ollama with tools that expect the Anthropic API (like Claude Code), set t
|
|||||||
|
|
||||||
```shell
|
```shell
|
||||||
export ANTHROPIC_AUTH_TOKEN=ollama # required but ignored
|
export ANTHROPIC_AUTH_TOKEN=ollama # required but ignored
|
||||||
export ANTHROPIC_API_KEY="" # required but ignored
|
|
||||||
export ANTHROPIC_BASE_URL=http://localhost:11434
|
export ANTHROPIC_BASE_URL=http://localhost:11434
|
||||||
```
|
```
|
||||||
|
|
||||||
@@ -269,7 +268,7 @@ ollama launch claude --config
|
|||||||
Set the environment variables and run Claude Code:
|
Set the environment variables and run Claude Code:
|
||||||
|
|
||||||
```shell
|
```shell
|
||||||
ANTHROPIC_AUTH_TOKEN=ollama ANTHROPIC_BASE_URL=http://localhost:11434 ANTHROPIC_API_KEY="" claude --model qwen3-coder
|
ANTHROPIC_AUTH_TOKEN=ollama ANTHROPIC_BASE_URL=http://localhost:11434 claude --model qwen3-coder
|
||||||
```
|
```
|
||||||
|
|
||||||
Or set the environment variables in your shell profile:
|
Or set the environment variables in your shell profile:
|
||||||
@@ -277,7 +276,6 @@ Or set the environment variables in your shell profile:
|
|||||||
```shell
|
```shell
|
||||||
export ANTHROPIC_AUTH_TOKEN=ollama
|
export ANTHROPIC_AUTH_TOKEN=ollama
|
||||||
export ANTHROPIC_BASE_URL=http://localhost:11434
|
export ANTHROPIC_BASE_URL=http://localhost:11434
|
||||||
export ANTHROPIC_API_KEY=""
|
|
||||||
```
|
```
|
||||||
|
|
||||||
Then run Claude Code with any Ollama model:
|
Then run Claude Code with any Ollama model:
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ Ollama provides compatibility with parts of the [OpenAI API](https://platform.op
|
|||||||
|
|
||||||
## Usage
|
## Usage
|
||||||
|
|
||||||
### Simple `v1/chat/completions` example
|
### Simple `/v1/chat/completions` example
|
||||||
|
|
||||||
<CodeGroup dropdown>
|
<CodeGroup dropdown>
|
||||||
|
|
||||||
@@ -57,7 +57,7 @@ curl -X POST http://localhost:11434/v1/chat/completions \
|
|||||||
|
|
||||||
</CodeGroup>
|
</CodeGroup>
|
||||||
|
|
||||||
### Simple `v1/responses` example
|
### Simple `/v1/responses` example
|
||||||
|
|
||||||
<CodeGroup dropdown>
|
<CodeGroup dropdown>
|
||||||
|
|
||||||
@@ -103,7 +103,7 @@ curl -X POST http://localhost:11434/v1/responses \
|
|||||||
|
|
||||||
</CodeGroup>
|
</CodeGroup>
|
||||||
|
|
||||||
### v1/chat/completions with vision example
|
### `/v1/chat/completions` with vision example
|
||||||
|
|
||||||
<CodeGroup dropdown>
|
<CodeGroup dropdown>
|
||||||
|
|
||||||
|
|||||||
@@ -51,6 +51,9 @@ Install prerequisites:
|
|||||||
- [CUDA SDK](https://developer.nvidia.com/cuda-downloads?target_os=Windows&target_arch=x86_64&target_version=11&target_type=exe_network)
|
- [CUDA SDK](https://developer.nvidia.com/cuda-downloads?target_os=Windows&target_arch=x86_64&target_version=11&target_type=exe_network)
|
||||||
- (Optional) VULKAN GPU support
|
- (Optional) VULKAN GPU support
|
||||||
- [VULKAN SDK](https://vulkan.lunarg.com/sdk/home) - useful for AMD/Intel GPUs
|
- [VULKAN SDK](https://vulkan.lunarg.com/sdk/home) - useful for AMD/Intel GPUs
|
||||||
|
- (Optional) MLX engine support
|
||||||
|
- [CUDA 13+ SDK](https://developer.nvidia.com/cuda-downloads)
|
||||||
|
- [cuDNN 9+](https://developer.nvidia.com/cudnn)
|
||||||
|
|
||||||
Then, configure and build the project:
|
Then, configure and build the project:
|
||||||
|
|
||||||
@@ -101,6 +104,10 @@ Install prerequisites:
|
|||||||
- (Optional) VULKAN GPU support
|
- (Optional) VULKAN GPU support
|
||||||
- [VULKAN SDK](https://vulkan.lunarg.com/sdk/home) - useful for AMD/Intel GPUs
|
- [VULKAN SDK](https://vulkan.lunarg.com/sdk/home) - useful for AMD/Intel GPUs
|
||||||
- Or install via package manager: `sudo apt install vulkan-sdk` (Ubuntu/Debian) or `sudo dnf install vulkan-sdk` (Fedora/CentOS)
|
- Or install via package manager: `sudo apt install vulkan-sdk` (Ubuntu/Debian) or `sudo dnf install vulkan-sdk` (Fedora/CentOS)
|
||||||
|
- (Optional) MLX engine support
|
||||||
|
- [CUDA 13+ SDK](https://developer.nvidia.com/cuda-downloads)
|
||||||
|
- [cuDNN 9+](https://developer.nvidia.com/cudnn)
|
||||||
|
- OpenBLAS/LAPACK: `sudo apt install libopenblas-dev liblapack-dev liblapacke-dev` (Ubuntu/Debian)
|
||||||
> [!IMPORTANT]
|
> [!IMPORTANT]
|
||||||
> Ensure prerequisites are in `PATH` before running CMake.
|
> Ensure prerequisites are in `PATH` before running CMake.
|
||||||
|
|
||||||
@@ -118,6 +125,67 @@ Lastly, run Ollama:
|
|||||||
go run . serve
|
go run . serve
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## MLX Engine (Optional)
|
||||||
|
|
||||||
|
The MLX engine enables running safetensor based models. It requires building the [MLX](https://github.com/ml-explore/mlx) and [MLX-C](https://github.com/ml-explore/mlx-c) shared libraries separately via CMake. On MacOS, MLX leverages the Metal library to run on the GPU, and on Windows and Linux, runs on NVIDIA GPUs via CUDA v13.
|
||||||
|
|
||||||
|
### macOS (Apple Silicon)
|
||||||
|
|
||||||
|
Requires the Metal toolchain. Install [Xcode](https://developer.apple.com/xcode/) first, then:
|
||||||
|
|
||||||
|
```shell
|
||||||
|
xcodebuild -downloadComponent MetalToolchain
|
||||||
|
```
|
||||||
|
|
||||||
|
Verify it's installed correctly (should print "no input files"):
|
||||||
|
|
||||||
|
```shell
|
||||||
|
xcrun metal
|
||||||
|
```
|
||||||
|
|
||||||
|
Then build:
|
||||||
|
|
||||||
|
```shell
|
||||||
|
cmake -B build --preset MLX
|
||||||
|
cmake --build build --preset MLX --parallel
|
||||||
|
cmake --install build --component MLX
|
||||||
|
```
|
||||||
|
|
||||||
|
> [!NOTE]
|
||||||
|
> Without the Metal toolchain, cmake will silently complete with Metal disabled. Check the cmake output for `Setting MLX_BUILD_METAL=OFF` which indicates the toolchain is missing.
|
||||||
|
|
||||||
|
### Windows / Linux (CUDA)
|
||||||
|
|
||||||
|
Requires CUDA 13+ and [cuDNN](https://developer.nvidia.com/cudnn) 9+.
|
||||||
|
|
||||||
|
```shell
|
||||||
|
cmake -B build --preset "MLX CUDA 13"
|
||||||
|
cmake --build build --target mlx --target mlxc --config Release --parallel
|
||||||
|
cmake --install build --component MLX --strip
|
||||||
|
```
|
||||||
|
|
||||||
|
### Local MLX source overrides
|
||||||
|
|
||||||
|
To build against a local checkout of MLX and/or MLX-C (useful for development), set environment variables before running CMake:
|
||||||
|
|
||||||
|
```shell
|
||||||
|
export OLLAMA_MLX_SOURCE=/path/to/mlx
|
||||||
|
export OLLAMA_MLX_C_SOURCE=/path/to/mlx-c
|
||||||
|
```
|
||||||
|
|
||||||
|
For example, using the helper scripts with local mlx and mlx-c repos:
|
||||||
|
```shell
|
||||||
|
OLLAMA_MLX_SOURCE=../mlx OLLAMA_MLX_C_SOURCE=../mlx-c ./scripts/build_linux.sh
|
||||||
|
|
||||||
|
OLLAMA_MLX_SOURCE=../mlx OLLAMA_MLX_C_SOURCE=../mlx-c ./scripts/build_darwin.sh
|
||||||
|
```
|
||||||
|
|
||||||
|
```powershell
|
||||||
|
$env:OLLAMA_MLX_SOURCE="../mlx"
|
||||||
|
$env:OLLAMA_MLX_C_SOURCE="../mlx-c"
|
||||||
|
./scripts/build_darwin.ps1
|
||||||
|
```
|
||||||
|
|
||||||
## Docker
|
## Docker
|
||||||
|
|
||||||
```shell
|
```shell
|
||||||
|
|||||||
27
docs/gpu.mdx
27
docs/gpu.mdx
@@ -62,10 +62,12 @@ Ollama supports the following AMD GPUs via the ROCm library:
|
|||||||
### Linux Support
|
### Linux Support
|
||||||
|
|
||||||
| Family | Cards and accelerators |
|
| Family | Cards and accelerators |
|
||||||
| -------------- | ---------------------------------------------------------------------------------------------------------------------------------------------- |
|
| -------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||||
| AMD Radeon RX | `7900 XTX` `7900 XT` `7900 GRE` `7800 XT` `7700 XT` `7600 XT` `7600` `6950 XT` `6900 XTX` `6900XT` `6800 XT` `6800` `Vega 64` |
|
| AMD Radeon RX | `9070 XT` `9070 GRE` `9070` `9060 XT` `9060 XT LP` `9060` `7900 XTX` `7900 XT` `7900 GRE` `7800 XT` `7700 XT` `7700` `7600 XT` `7600` `6950 XT` `6900 XTX` `6900XT` `6800 XT` `6800` `5700 XT` `5700` `5600 XT` `5500 XT` |
|
||||||
| AMD Radeon PRO | `W7900` `W7800` `W7700` `W7600` `W7500` `W6900X` `W6800X Duo` `W6800X` `W6800` `V620` `V420` `V340` `V320` `Vega II Duo` `Vega II` `SSG` |
|
| AMD Radeon AI PRO | `R9700` `R9600D` |
|
||||||
| AMD Instinct | `MI300X` `MI300A` `MI300` `MI250X` `MI250` `MI210` `MI200` `MI100` `MI60` |
|
| AMD Radeon PRO | `W7900` `W7800` `W7700` `W7600` `W7500` `W6900X` `W6800X Duo` `W6800X` `W6800` `V620` |
|
||||||
|
| AMD Ryzen AI | `Ryzen AI Max+ 395` `Ryzen AI Max 390` `Ryzen AI Max 385` `Ryzen AI 9 HX 475` `Ryzen AI 9 HX 470` `Ryzen AI 9 465` `Ryzen AI 9 HX 375` `Ryzen AI 9 HX 370` `Ryzen AI 9 365` |
|
||||||
|
| AMD Instinct | `MI350X` `MI300X` `MI300A` `MI250X` `MI250` `MI210` `MI100` |
|
||||||
|
|
||||||
### Windows Support
|
### Windows Support
|
||||||
|
|
||||||
@@ -97,17 +99,20 @@ This table shows some example GPUs that map to these LLVM targets:
|
|||||||
| **LLVM Target** | **An Example GPU** |
|
| **LLVM Target** | **An Example GPU** |
|
||||||
|-----------------|---------------------|
|
|-----------------|---------------------|
|
||||||
| gfx908 | Radeon Instinct MI100 |
|
| gfx908 | Radeon Instinct MI100 |
|
||||||
| gfx90a | Radeon Instinct MI210 |
|
| gfx90a | Radeon Instinct MI210/MI250 |
|
||||||
| gfx940 | Radeon Instinct MI300 |
|
| gfx942 | Radeon Instinct MI300X/MI300A |
|
||||||
| gfx941 | |
|
| gfx950 | Radeon Instinct MI350X |
|
||||||
| gfx942 | |
|
| gfx1010 | Radeon RX 5700 XT |
|
||||||
|
| gfx1012 | Radeon RX 5500 XT |
|
||||||
| gfx1030 | Radeon PRO V620 |
|
| gfx1030 | Radeon PRO V620 |
|
||||||
| gfx1100 | Radeon PRO W7900 |
|
| gfx1100 | Radeon PRO W7900 |
|
||||||
| gfx1101 | Radeon PRO W7700 |
|
| gfx1101 | Radeon PRO W7700 |
|
||||||
| gfx1102 | Radeon RX 7600 |
|
| gfx1102 | Radeon RX 7600 |
|
||||||
|
| gfx1103 | Radeon 780M |
|
||||||
AMD is working on enhancing ROCm v6 to broaden support for families of GPUs in a
|
| gfx1150 | Ryzen AI 9 HX 375 |
|
||||||
future release which should increase support for more GPUs.
|
| gfx1151 | Ryzen AI Max+ 395 |
|
||||||
|
| gfx1200 | Radeon RX 9070 |
|
||||||
|
| gfx1201 | Radeon RX 9070 XT |
|
||||||
|
|
||||||
Reach out on [Discord](https://discord.gg/ollama) or file an
|
Reach out on [Discord](https://discord.gg/ollama) or file an
|
||||||
[issue](https://github.com/ollama/ollama/issues) for additional help.
|
[issue](https://github.com/ollama/ollama/issues) for additional help.
|
||||||
|
|||||||
@@ -101,7 +101,7 @@ nvidia-smi
|
|||||||
|
|
||||||
### Install AMD ROCm drivers (optional)
|
### Install AMD ROCm drivers (optional)
|
||||||
|
|
||||||
[Download and Install](https://rocm.docs.amd.com/projects/install-on-linux/en/latest/tutorial/quick-start.html) ROCm v6.
|
[Download and Install](https://rocm.docs.amd.com/projects/install-on-linux/en/latest/tutorial/quick-start.html) ROCm v7.
|
||||||
|
|
||||||
### Start Ollama
|
### Start Ollama
|
||||||
|
|
||||||
|
|||||||
@@ -152,7 +152,9 @@ PARAMETER <parameter> <parametervalue>
|
|||||||
| -------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ---------- | -------------------- |
|
| -------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ---------- | -------------------- |
|
||||||
| num_ctx | Sets the size of the context window used to generate the next token. (Default: 2048) | int | num_ctx 4096 |
|
| num_ctx | Sets the size of the context window used to generate the next token. (Default: 2048) | int | num_ctx 4096 |
|
||||||
| repeat_last_n | Sets how far back for the model to look back to prevent repetition. (Default: 64, 0 = disabled, -1 = num_ctx) | int | repeat_last_n 64 |
|
| repeat_last_n | Sets how far back for the model to look back to prevent repetition. (Default: 64, 0 = disabled, -1 = num_ctx) | int | repeat_last_n 64 |
|
||||||
| repeat_penalty | Sets how strongly to penalize repetitions. A higher value (e.g., 1.5) will penalize repetitions more strongly, while a lower value (e.g., 0.9) will be more lenient. (Default: 1.1) | float | repeat_penalty 1.1 |
|
| repeat_penalty | Sets how strongly to penalize repetitions. A higher value (e.g., 1.5) will penalize repetitions more strongly, while a lower value (e.g., 0.9) will be more lenient. (Default: 1.0) | float | repeat_penalty 1.0 |
|
||||||
|
| presence_penalty | Penalizes tokens that have already appeared in the generated text to reduce repetition. (Default: 0.0) | float | presence_penalty 1.5 |
|
||||||
|
| frequency_penalty | Penalizes tokens based on how often they have appeared in the generated text. (Default: 0.0) | float | frequency_penalty 1.0 |
|
||||||
| temperature | The temperature of the model. Increasing the temperature will make the model answer more creatively. (Default: 0.8) | float | temperature 0.7 |
|
| temperature | The temperature of the model. Increasing the temperature will make the model answer more creatively. (Default: 0.8) | float | temperature 0.7 |
|
||||||
| seed | Sets the random number seed to use for generation. Setting this to a specific number will make the model generate the same text for the same prompt. (Default: 0) | int | seed 42 |
|
| seed | Sets the random number seed to use for generation. Setting this to a specific number will make the model generate the same text for the same prompt. (Default: 0) | int | seed 42 |
|
||||||
| stop | Sets the stop sequences to use. When this pattern is encountered the LLM will stop generating text and return. Multiple stop patterns may be set by specifying multiple separate `stop` parameters in a modelfile. | string | stop "AI assistant:" |
|
| stop | Sets the stop sequences to use. When this pattern is encountered the LLM will stop generating text and return. Multiple stop patterns may be set by specifying multiple separate `stop` parameters in a modelfile. | string | stop "AI assistant:" |
|
||||||
|
|||||||
115
internal/modelref/modelref.go
Normal file
115
internal/modelref/modelref.go
Normal file
@@ -0,0 +1,115 @@
|
|||||||
|
package modelref
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ModelSource uint8
|
||||||
|
|
||||||
|
const (
|
||||||
|
ModelSourceUnspecified ModelSource = iota
|
||||||
|
ModelSourceLocal
|
||||||
|
ModelSourceCloud
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrConflictingSourceSuffix = errors.New("use either :local or :cloud, not both")
|
||||||
|
ErrModelRequired = errors.New("model is required")
|
||||||
|
)
|
||||||
|
|
||||||
|
type ParsedRef struct {
|
||||||
|
Original string
|
||||||
|
Base string
|
||||||
|
Source ModelSource
|
||||||
|
}
|
||||||
|
|
||||||
|
func ParseRef(raw string) (ParsedRef, error) {
|
||||||
|
var zero ParsedRef
|
||||||
|
|
||||||
|
raw = strings.TrimSpace(raw)
|
||||||
|
if raw == "" {
|
||||||
|
return zero, ErrModelRequired
|
||||||
|
}
|
||||||
|
|
||||||
|
base, source, explicit := parseSourceSuffix(raw)
|
||||||
|
if explicit {
|
||||||
|
if _, _, nested := parseSourceSuffix(base); nested {
|
||||||
|
return zero, fmt.Errorf("%w: %q", ErrConflictingSourceSuffix, raw)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return ParsedRef{
|
||||||
|
Original: raw,
|
||||||
|
Base: base,
|
||||||
|
Source: source,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func HasExplicitCloudSource(raw string) bool {
|
||||||
|
parsedRef, err := ParseRef(raw)
|
||||||
|
return err == nil && parsedRef.Source == ModelSourceCloud
|
||||||
|
}
|
||||||
|
|
||||||
|
func HasExplicitLocalSource(raw string) bool {
|
||||||
|
parsedRef, err := ParseRef(raw)
|
||||||
|
return err == nil && parsedRef.Source == ModelSourceLocal
|
||||||
|
}
|
||||||
|
|
||||||
|
func StripCloudSourceTag(raw string) (string, bool) {
|
||||||
|
parsedRef, err := ParseRef(raw)
|
||||||
|
if err != nil || parsedRef.Source != ModelSourceCloud {
|
||||||
|
return strings.TrimSpace(raw), false
|
||||||
|
}
|
||||||
|
|
||||||
|
return parsedRef.Base, true
|
||||||
|
}
|
||||||
|
|
||||||
|
func NormalizePullName(raw string) (string, bool, error) {
|
||||||
|
parsedRef, err := ParseRef(raw)
|
||||||
|
if err != nil {
|
||||||
|
return "", false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if parsedRef.Source != ModelSourceCloud {
|
||||||
|
return parsedRef.Base, false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return toLegacyCloudPullName(parsedRef.Base), true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func toLegacyCloudPullName(base string) string {
|
||||||
|
if hasExplicitTag(base) {
|
||||||
|
return base + "-cloud"
|
||||||
|
}
|
||||||
|
|
||||||
|
return base + ":cloud"
|
||||||
|
}
|
||||||
|
|
||||||
|
func hasExplicitTag(name string) bool {
|
||||||
|
lastSlash := strings.LastIndex(name, "/")
|
||||||
|
lastColon := strings.LastIndex(name, ":")
|
||||||
|
return lastColon > lastSlash
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseSourceSuffix(raw string) (string, ModelSource, bool) {
|
||||||
|
idx := strings.LastIndex(raw, ":")
|
||||||
|
if idx >= 0 {
|
||||||
|
suffixRaw := strings.TrimSpace(raw[idx+1:])
|
||||||
|
suffix := strings.ToLower(suffixRaw)
|
||||||
|
|
||||||
|
switch suffix {
|
||||||
|
case "cloud":
|
||||||
|
return raw[:idx], ModelSourceCloud, true
|
||||||
|
case "local":
|
||||||
|
return raw[:idx], ModelSourceLocal, true
|
||||||
|
}
|
||||||
|
|
||||||
|
if !strings.Contains(suffixRaw, "/") && strings.HasSuffix(suffix, "-cloud") {
|
||||||
|
return raw[:idx+1] + suffixRaw[:len(suffixRaw)-len("-cloud")], ModelSourceCloud, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return raw, ModelSourceUnspecified, false
|
||||||
|
}
|
||||||
268
internal/modelref/modelref_test.go
Normal file
268
internal/modelref/modelref_test.go
Normal file
@@ -0,0 +1,268 @@
|
|||||||
|
package modelref
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestParseRef(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input string
|
||||||
|
wantBase string
|
||||||
|
wantSource ModelSource
|
||||||
|
wantErr error
|
||||||
|
wantCloud bool
|
||||||
|
wantLocal bool
|
||||||
|
wantStripped string
|
||||||
|
wantStripOK bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "cloud suffix",
|
||||||
|
input: "gpt-oss:20b:cloud",
|
||||||
|
wantBase: "gpt-oss:20b",
|
||||||
|
wantSource: ModelSourceCloud,
|
||||||
|
wantCloud: true,
|
||||||
|
wantStripped: "gpt-oss:20b",
|
||||||
|
wantStripOK: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "legacy cloud suffix",
|
||||||
|
input: "gpt-oss:20b-cloud",
|
||||||
|
wantBase: "gpt-oss:20b",
|
||||||
|
wantSource: ModelSourceCloud,
|
||||||
|
wantCloud: true,
|
||||||
|
wantStripped: "gpt-oss:20b",
|
||||||
|
wantStripOK: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "local suffix",
|
||||||
|
input: "qwen3:8b:local",
|
||||||
|
wantBase: "qwen3:8b",
|
||||||
|
wantSource: ModelSourceLocal,
|
||||||
|
wantLocal: true,
|
||||||
|
wantStripped: "qwen3:8b:local",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "no source suffix",
|
||||||
|
input: "llama3.2",
|
||||||
|
wantBase: "llama3.2",
|
||||||
|
wantSource: ModelSourceUnspecified,
|
||||||
|
wantStripped: "llama3.2",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "bare cloud name is not explicit cloud",
|
||||||
|
input: "my-cloud-model",
|
||||||
|
wantBase: "my-cloud-model",
|
||||||
|
wantSource: ModelSourceUnspecified,
|
||||||
|
wantStripped: "my-cloud-model",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "slash in suffix blocks legacy cloud parsing",
|
||||||
|
input: "foo:bar-cloud/baz",
|
||||||
|
wantBase: "foo:bar-cloud/baz",
|
||||||
|
wantSource: ModelSourceUnspecified,
|
||||||
|
wantStripped: "foo:bar-cloud/baz",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "conflicting source suffixes",
|
||||||
|
input: "foo:cloud:local",
|
||||||
|
wantErr: ErrConflictingSourceSuffix,
|
||||||
|
wantSource: ModelSourceUnspecified,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty input",
|
||||||
|
input: " ",
|
||||||
|
wantErr: ErrModelRequired,
|
||||||
|
wantSource: ModelSourceUnspecified,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got, err := ParseRef(tt.input)
|
||||||
|
if tt.wantErr != nil {
|
||||||
|
if !errors.Is(err, tt.wantErr) {
|
||||||
|
t.Fatalf("ParseRef(%q) error = %v, want %v", tt.input, err, tt.wantErr)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ParseRef(%q) returned error: %v", tt.input, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if got.Base != tt.wantBase {
|
||||||
|
t.Fatalf("base = %q, want %q", got.Base, tt.wantBase)
|
||||||
|
}
|
||||||
|
|
||||||
|
if got.Source != tt.wantSource {
|
||||||
|
t.Fatalf("source = %v, want %v", got.Source, tt.wantSource)
|
||||||
|
}
|
||||||
|
|
||||||
|
if HasExplicitCloudSource(tt.input) != tt.wantCloud {
|
||||||
|
t.Fatalf("HasExplicitCloudSource(%q) = %v, want %v", tt.input, HasExplicitCloudSource(tt.input), tt.wantCloud)
|
||||||
|
}
|
||||||
|
|
||||||
|
if HasExplicitLocalSource(tt.input) != tt.wantLocal {
|
||||||
|
t.Fatalf("HasExplicitLocalSource(%q) = %v, want %v", tt.input, HasExplicitLocalSource(tt.input), tt.wantLocal)
|
||||||
|
}
|
||||||
|
|
||||||
|
stripped, ok := StripCloudSourceTag(tt.input)
|
||||||
|
if ok != tt.wantStripOK {
|
||||||
|
t.Fatalf("StripCloudSourceTag(%q) ok = %v, want %v", tt.input, ok, tt.wantStripOK)
|
||||||
|
}
|
||||||
|
if stripped != tt.wantStripped {
|
||||||
|
t.Fatalf("StripCloudSourceTag(%q) base = %q, want %q", tt.input, stripped, tt.wantStripped)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNormalizePullName(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input string
|
||||||
|
wantName string
|
||||||
|
wantCloud bool
|
||||||
|
wantErr error
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "explicit local strips source",
|
||||||
|
input: "gpt-oss:20b:local",
|
||||||
|
wantName: "gpt-oss:20b",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "explicit cloud with size maps to legacy dash cloud tag",
|
||||||
|
input: "gpt-oss:20b:cloud",
|
||||||
|
wantName: "gpt-oss:20b-cloud",
|
||||||
|
wantCloud: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "legacy cloud with size remains stable",
|
||||||
|
input: "gpt-oss:20b-cloud",
|
||||||
|
wantName: "gpt-oss:20b-cloud",
|
||||||
|
wantCloud: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "explicit cloud without tag maps to cloud tag",
|
||||||
|
input: "qwen3:cloud",
|
||||||
|
wantName: "qwen3:cloud",
|
||||||
|
wantCloud: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "host port without tag keeps host port and appends cloud tag",
|
||||||
|
input: "localhost:11434/library/foo:cloud",
|
||||||
|
wantName: "localhost:11434/library/foo:cloud",
|
||||||
|
wantCloud: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "conflicting source suffixes fail",
|
||||||
|
input: "foo:cloud:local",
|
||||||
|
wantErr: ErrConflictingSourceSuffix,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
gotName, gotCloud, err := NormalizePullName(tt.input)
|
||||||
|
if tt.wantErr != nil {
|
||||||
|
if !errors.Is(err, tt.wantErr) {
|
||||||
|
t.Fatalf("NormalizePullName(%q) error = %v, want %v", tt.input, err, tt.wantErr)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NormalizePullName(%q) returned error: %v", tt.input, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if gotName != tt.wantName {
|
||||||
|
t.Fatalf("normalized name = %q, want %q", gotName, tt.wantName)
|
||||||
|
}
|
||||||
|
if gotCloud != tt.wantCloud {
|
||||||
|
t.Fatalf("cloud = %v, want %v", gotCloud, tt.wantCloud)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseSourceSuffix(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input string
|
||||||
|
wantBase string
|
||||||
|
wantSource ModelSource
|
||||||
|
wantExplicit bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "explicit cloud suffix",
|
||||||
|
input: "gpt-oss:20b:cloud",
|
||||||
|
wantBase: "gpt-oss:20b",
|
||||||
|
wantSource: ModelSourceCloud,
|
||||||
|
wantExplicit: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "explicit local suffix",
|
||||||
|
input: "qwen3:8b:local",
|
||||||
|
wantBase: "qwen3:8b",
|
||||||
|
wantSource: ModelSourceLocal,
|
||||||
|
wantExplicit: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "legacy cloud suffix on tag",
|
||||||
|
input: "gpt-oss:20b-cloud",
|
||||||
|
wantBase: "gpt-oss:20b",
|
||||||
|
wantSource: ModelSourceCloud,
|
||||||
|
wantExplicit: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "legacy cloud suffix does not match model segment",
|
||||||
|
input: "my-cloud-model",
|
||||||
|
wantBase: "my-cloud-model",
|
||||||
|
wantSource: ModelSourceUnspecified,
|
||||||
|
wantExplicit: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "legacy cloud suffix blocked when suffix includes slash",
|
||||||
|
input: "foo:bar-cloud/baz",
|
||||||
|
wantBase: "foo:bar-cloud/baz",
|
||||||
|
wantSource: ModelSourceUnspecified,
|
||||||
|
wantExplicit: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "unknown suffix is not explicit source",
|
||||||
|
input: "gpt-oss:clod",
|
||||||
|
wantBase: "gpt-oss:clod",
|
||||||
|
wantSource: ModelSourceUnspecified,
|
||||||
|
wantExplicit: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "uppercase suffix is accepted",
|
||||||
|
input: "gpt-oss:20b:CLOUD",
|
||||||
|
wantBase: "gpt-oss:20b",
|
||||||
|
wantSource: ModelSourceCloud,
|
||||||
|
wantExplicit: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "no suffix",
|
||||||
|
input: "llama3.2",
|
||||||
|
wantBase: "llama3.2",
|
||||||
|
wantSource: ModelSourceUnspecified,
|
||||||
|
wantExplicit: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
gotBase, gotSource, gotExplicit := parseSourceSuffix(tt.input)
|
||||||
|
if gotBase != tt.wantBase {
|
||||||
|
t.Fatalf("base = %q, want %q", gotBase, tt.wantBase)
|
||||||
|
}
|
||||||
|
if gotSource != tt.wantSource {
|
||||||
|
t.Fatalf("source = %v, want %v", gotSource, tt.wantSource)
|
||||||
|
}
|
||||||
|
if gotExplicit != tt.wantExplicit {
|
||||||
|
t.Fatalf("explicit = %v, want %v", gotExplicit, tt.wantExplicit)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -74,8 +74,7 @@ type LlamaServer interface {
|
|||||||
Tokenize(ctx context.Context, content string) ([]int, error)
|
Tokenize(ctx context.Context, content string) ([]int, error)
|
||||||
Detokenize(ctx context.Context, tokens []int) (string, error)
|
Detokenize(ctx context.Context, tokens []int) (string, error)
|
||||||
Close() error
|
Close() error
|
||||||
VRAMSize() uint64 // Total VRAM across all GPUs
|
MemorySize() (total, vram uint64)
|
||||||
TotalSize() uint64
|
|
||||||
VRAMByGPU(id ml.DeviceID) uint64
|
VRAMByGPU(id ml.DeviceID) uint64
|
||||||
Pid() int
|
Pid() int
|
||||||
GetPort() int
|
GetPort() int
|
||||||
@@ -685,8 +684,9 @@ func (s *llamaServer) Load(ctx context.Context, systemInfo ml.SystemInfo, system
|
|||||||
// Windows CUDA should not use mmap for best performance
|
// Windows CUDA should not use mmap for best performance
|
||||||
// Linux with a model larger than free space, mmap leads to thrashing
|
// Linux with a model larger than free space, mmap leads to thrashing
|
||||||
// For CPU loads we want the memory to be allocated, not FS cache
|
// For CPU loads we want the memory to be allocated, not FS cache
|
||||||
|
totalSize, _ := s.MemorySize()
|
||||||
if (runtime.GOOS == "windows" && len(gpus) > 0 && gpus[0].Library == "CUDA" && s.options.UseMMap == nil) ||
|
if (runtime.GOOS == "windows" && len(gpus) > 0 && gpus[0].Library == "CUDA" && s.options.UseMMap == nil) ||
|
||||||
(runtime.GOOS == "linux" && systemInfo.FreeMemory < s.TotalSize() && s.options.UseMMap == nil) ||
|
(runtime.GOOS == "linux" && systemInfo.FreeMemory < totalSize && s.options.UseMMap == nil) ||
|
||||||
(len(gpus) == 0 && s.options.UseMMap == nil) ||
|
(len(gpus) == 0 && s.options.UseMMap == nil) ||
|
||||||
(len(gpus) > 0 && gpus[0].Library == "Vulkan" && s.options.UseMMap == nil) ||
|
(len(gpus) > 0 && gpus[0].Library == "Vulkan" && s.options.UseMMap == nil) ||
|
||||||
(s.options.UseMMap != nil && !*s.options.UseMMap) {
|
(s.options.UseMMap != nil && !*s.options.UseMMap) {
|
||||||
@@ -1848,17 +1848,17 @@ func (s *llamaServer) GetDeviceInfos(ctx context.Context) []ml.DeviceInfo {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *llmServer) VRAMSize() uint64 {
|
func (s *llmServer) MemorySize() (total, vram uint64) {
|
||||||
if s.mem == nil {
|
if s.mem == nil {
|
||||||
return 0
|
return 0, 0
|
||||||
}
|
}
|
||||||
|
|
||||||
var mem uint64
|
|
||||||
|
|
||||||
for _, g := range s.mem.GPUs {
|
for _, g := range s.mem.GPUs {
|
||||||
mem += g.Size()
|
vram += g.Size()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
total = s.mem.InputWeights + s.mem.CPU.Size() + vram
|
||||||
|
|
||||||
// Some elements are always on CPU. However, if we have allocated all layers
|
// Some elements are always on CPU. However, if we have allocated all layers
|
||||||
// on the GPU then include the CPU components as well, to represent complete offloading.
|
// on the GPU then include the CPU components as well, to represent complete offloading.
|
||||||
noCPULayers := true
|
noCPULayers := true
|
||||||
@@ -1869,25 +1869,11 @@ func (s *llmServer) VRAMSize() uint64 {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
if noCPULayers {
|
if noCPULayers {
|
||||||
mem += s.mem.InputWeights
|
vram += s.mem.InputWeights
|
||||||
mem += s.mem.CPU.Graph
|
vram += s.mem.CPU.Graph
|
||||||
}
|
}
|
||||||
|
|
||||||
return mem
|
return total, vram
|
||||||
}
|
|
||||||
|
|
||||||
func (s *llmServer) TotalSize() uint64 {
|
|
||||||
if s.mem == nil {
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
|
|
||||||
mem := s.mem.InputWeights
|
|
||||||
mem += s.mem.CPU.Size()
|
|
||||||
for _, g := range s.mem.GPUs {
|
|
||||||
mem += g.Size()
|
|
||||||
}
|
|
||||||
|
|
||||||
return mem
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *llmServer) VRAMByGPU(id ml.DeviceID) uint64 {
|
func (s *llmServer) VRAMByGPU(id ml.DeviceID) uint64 {
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ import (
|
|||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
"github.com/ollama/ollama/envconfig"
|
"github.com/ollama/ollama/envconfig"
|
||||||
internalcloud "github.com/ollama/ollama/internal/cloud"
|
internalcloud "github.com/ollama/ollama/internal/cloud"
|
||||||
|
"github.com/ollama/ollama/internal/modelref"
|
||||||
"github.com/ollama/ollama/logutil"
|
"github.com/ollama/ollama/logutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -919,7 +920,7 @@ func hasWebSearchTool(tools []anthropic.Tool) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func isCloudModelName(name string) bool {
|
func isCloudModelName(name string) bool {
|
||||||
return strings.HasSuffix(name, ":cloud") || strings.HasSuffix(name, "-cloud")
|
return modelref.HasExplicitCloudSource(name)
|
||||||
}
|
}
|
||||||
|
|
||||||
// extractQueryFromToolCall extracts the search query from a web_search tool call
|
// extractQueryFromToolCall extracts the search query from a web_search tool call
|
||||||
|
|||||||
@@ -41,7 +41,7 @@ type GatedDeltaNet struct {
|
|||||||
SSMBeta *nn.Linear `gguf:"ssm_beta"` // -> beta (qwen35)
|
SSMBeta *nn.Linear `gguf:"ssm_beta"` // -> beta (qwen35)
|
||||||
SSMAlpha *nn.Linear `gguf:"ssm_alpha"` // -> alpha (qwen35)
|
SSMAlpha *nn.Linear `gguf:"ssm_alpha"` // -> alpha (qwen35)
|
||||||
SSMConv1D *convKernel `gguf:"ssm_conv1d"`
|
SSMConv1D *convKernel `gguf:"ssm_conv1d"`
|
||||||
SSMDT ml.Tensor `gguf:"ssm_dt"` // alpha bias
|
SSMDT ml.Tensor `gguf:"ssm_dt,alt:ssm_dt.bias"` // alpha bias
|
||||||
SSMA ml.Tensor `gguf:"ssm_a"` // -A_log.exp()
|
SSMA ml.Tensor `gguf:"ssm_a"` // -A_log.exp()
|
||||||
SSMNorm *nn.RMSNorm `gguf:"ssm_norm"`
|
SSMNorm *nn.RMSNorm `gguf:"ssm_norm"`
|
||||||
SSMOut *nn.Linear `gguf:"ssm_out"`
|
SSMOut *nn.Linear `gguf:"ssm_out"`
|
||||||
@@ -135,6 +135,18 @@ func (gdn *GatedDeltaNet) Forward(ctx ml.Context, hiddenStates, _ ml.Tensor, cac
|
|||||||
default:
|
default:
|
||||||
return nil, errors.New("qwen3next: missing linear attention beta/alpha projections")
|
return nil, errors.New("qwen3next: missing linear attention beta/alpha projections")
|
||||||
}
|
}
|
||||||
|
if gdn.SSMDT == nil {
|
||||||
|
return nil, errors.New("qwen3next: missing linear attention ssm_dt tensor")
|
||||||
|
}
|
||||||
|
if gdn.SSMA == nil {
|
||||||
|
return nil, errors.New("qwen3next: missing linear attention ssm_a tensor")
|
||||||
|
}
|
||||||
|
if gdn.SSMConv1D == nil || gdn.SSMConv1D.Weight == nil {
|
||||||
|
return nil, errors.New("qwen3next: missing linear attention ssm_conv1d tensor")
|
||||||
|
}
|
||||||
|
if gdn.SSMNorm == nil || gdn.SSMOut == nil {
|
||||||
|
return nil, errors.New("qwen3next: missing linear attention ssm_norm/ssm_out projections")
|
||||||
|
}
|
||||||
|
|
||||||
// Compute gate: softplus(alpha + dt_bias) * -A
|
// Compute gate: softplus(alpha + dt_bias) * -A
|
||||||
alphaBiased := alpha.Add(ctx, gdn.SSMDT)
|
alphaBiased := alpha.Add(ctx, gdn.SSMDT)
|
||||||
@@ -442,6 +454,10 @@ func (gdn *GatedDeltaNet) deltaNetChunked(
|
|||||||
vT := v.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx, chunkSize, headVDim, nChunks, numVHeads*nSeqs)
|
vT := v.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx, chunkSize, headVDim, nChunks, numVHeads*nSeqs)
|
||||||
stateT := state.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx, headVDim, headVDim, 1, numVHeads*nSeqs)
|
stateT := state.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx, headVDim, headVDim, 1, numVHeads*nSeqs)
|
||||||
|
|
||||||
|
// Collect chunk outputs and concatenate at the end.
|
||||||
|
// Avoids SET on buffer-less intermediates under partial offload.
|
||||||
|
chunks := make([]ml.Tensor, nChunks)
|
||||||
|
|
||||||
for chunk := range nChunks {
|
for chunk := range nChunks {
|
||||||
qChunk := q.Slice(ctx, 2, chunk, chunk+1, 1)
|
qChunk := q.Slice(ctx, 2, chunk, chunk+1, 1)
|
||||||
vTChunk := vT.Slice(ctx, 2, chunk, chunk+1, 1)
|
vTChunk := vT.Slice(ctx, 2, chunk, chunk+1, 1)
|
||||||
@@ -463,14 +479,7 @@ func (gdn *GatedDeltaNet) deltaNetChunked(
|
|||||||
vAttn := vTNewChunk.Mulmat(ctx, attnChunk)
|
vAttn := vTNewChunk.Mulmat(ctx, attnChunk)
|
||||||
coreAttnOutChunk := attnInter.Add(ctx, vAttn)
|
coreAttnOutChunk := attnInter.Add(ctx, vAttn)
|
||||||
|
|
||||||
v = v.SetInplace(
|
chunks[chunk] = coreAttnOutChunk
|
||||||
ctx,
|
|
||||||
coreAttnOutChunk,
|
|
||||||
v.Stride(1),
|
|
||||||
v.Stride(2),
|
|
||||||
v.Stride(3),
|
|
||||||
chunk*v.Stride(2),
|
|
||||||
)
|
|
||||||
|
|
||||||
// Update state for next chunk
|
// Update state for next chunk
|
||||||
gExpLastChunk := gLastExp.Slice(ctx, 2, chunk, chunk+1, 1)
|
gExpLastChunk := gLastExp.Slice(ctx, 2, chunk, chunk+1, 1)
|
||||||
@@ -483,6 +492,20 @@ func (gdn *GatedDeltaNet) deltaNetChunked(
|
|||||||
stateT = stateT.Add(ctx, kgdMulVNew)
|
stateT = stateT.Add(ctx, kgdMulVNew)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Use a balanced concat tree so concat work does not balloon on long prompts.
|
||||||
|
for len(chunks) > 1 {
|
||||||
|
merged := make([]ml.Tensor, 0, (len(chunks)+1)/2)
|
||||||
|
for i := 0; i < len(chunks); i += 2 {
|
||||||
|
if i+1 < len(chunks) {
|
||||||
|
merged = append(merged, chunks[i].Concat(ctx, chunks[i+1], 2))
|
||||||
|
} else {
|
||||||
|
merged = append(merged, chunks[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
chunks = merged
|
||||||
|
}
|
||||||
|
v = chunks[0]
|
||||||
|
|
||||||
// Final reshape
|
// Final reshape
|
||||||
coreAttnOut := v.Contiguous(ctx, headVDim, chunkSize*nChunks, numVHeads, nSeqs)
|
coreAttnOut := v.Contiguous(ctx, headVDim, chunkSize*nChunks, numVHeads, nSeqs)
|
||||||
|
|
||||||
|
|||||||
@@ -437,6 +437,46 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
|||||||
return m.Output.Forward(ctx, hiddenStates), nil
|
return m.Output.Forward(ctx, hiddenStates), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *Model) Validate() error {
|
||||||
|
if m.Options == nil {
|
||||||
|
return fmt.Errorf("qwen3next: missing model options")
|
||||||
|
}
|
||||||
|
if len(m.Layers) != len(m.Options.isRecurrent) {
|
||||||
|
return fmt.Errorf("qwen3next: layer config mismatch: have %d layers, %d recurrent flags", len(m.Layers), len(m.Options.isRecurrent))
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, layer := range m.Layers {
|
||||||
|
if !m.Options.isRecurrent[i] {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
gdn, ok := layer.Operator.(*GatedDeltaNet)
|
||||||
|
if !ok || gdn == nil {
|
||||||
|
return fmt.Errorf("qwen3next: layer %d expected recurrent operator", i)
|
||||||
|
}
|
||||||
|
if gdn.SSMQKV == nil || gdn.SSMQKVGate == nil {
|
||||||
|
return fmt.Errorf("qwen3next: layer %d missing attn_qkv/attn_gate projections", i)
|
||||||
|
}
|
||||||
|
if gdn.SSMBetaAlpha == nil && (gdn.SSMBeta == nil || gdn.SSMAlpha == nil) {
|
||||||
|
return fmt.Errorf("qwen3next: layer %d missing linear attention beta/alpha projections", i)
|
||||||
|
}
|
||||||
|
if gdn.SSMDT == nil {
|
||||||
|
return fmt.Errorf("qwen3next: layer %d missing ssm_dt tensor", i)
|
||||||
|
}
|
||||||
|
if gdn.SSMA == nil {
|
||||||
|
return fmt.Errorf("qwen3next: layer %d missing ssm_a tensor", i)
|
||||||
|
}
|
||||||
|
if gdn.SSMConv1D == nil || gdn.SSMConv1D.Weight == nil {
|
||||||
|
return fmt.Errorf("qwen3next: layer %d missing ssm_conv1d tensor", i)
|
||||||
|
}
|
||||||
|
if gdn.SSMNorm == nil || gdn.SSMOut == nil {
|
||||||
|
return fmt.Errorf("qwen3next: layer %d missing ssm_norm/ssm_out projections", i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
|
func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
|
||||||
m.positionCache = nil
|
m.positionCache = nil
|
||||||
if len(m.mropeSections) > 0 {
|
if len(m.mropeSections) > 0 {
|
||||||
@@ -450,6 +490,64 @@ var (
|
|||||||
_ model.MultimodalProcessor = (*Model)(nil)
|
_ model.MultimodalProcessor = (*Model)(nil)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func defaultVHeadReordered(arch string) bool {
|
||||||
|
return arch == "qwen35" || arch == "qwen35moe"
|
||||||
|
}
|
||||||
|
|
||||||
|
func inferRecurrentLayers(headCountKV []uint64, numLayers int, fullAttentionInterval uint32) ([]bool, error) {
|
||||||
|
isRecurrent := make([]bool, numLayers)
|
||||||
|
|
||||||
|
hasZero := false
|
||||||
|
hasFull := false
|
||||||
|
for i := range numLayers {
|
||||||
|
if i >= len(headCountKV) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if headCountKV[i] == 0 {
|
||||||
|
isRecurrent[i] = true
|
||||||
|
hasZero = true
|
||||||
|
} else {
|
||||||
|
hasFull = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if hasZero && hasFull {
|
||||||
|
return isRecurrent, nil
|
||||||
|
}
|
||||||
|
if !hasFull {
|
||||||
|
return nil, fmt.Errorf("qwen3next: attention.head_count_kv must include at least one non-zero value")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compatibility path: older imports store a scalar KV head count and omit
|
||||||
|
// per-layer recurrent flags. Derive the hybrid layout from the interval.
|
||||||
|
interval := int(fullAttentionInterval)
|
||||||
|
if interval == 0 {
|
||||||
|
interval = min(4, numLayers)
|
||||||
|
}
|
||||||
|
if interval <= 0 {
|
||||||
|
return nil, fmt.Errorf("qwen3next: invalid block_count (%d)", numLayers)
|
||||||
|
}
|
||||||
|
if interval > numLayers {
|
||||||
|
return nil, fmt.Errorf("qwen3next: full_attention_interval (%d) exceeds block_count (%d)", interval, numLayers)
|
||||||
|
}
|
||||||
|
|
||||||
|
hasZero = false
|
||||||
|
hasFull = false
|
||||||
|
for i := range numLayers {
|
||||||
|
isRecurrent[i] = (i+1)%interval != 0
|
||||||
|
if isRecurrent[i] {
|
||||||
|
hasZero = true
|
||||||
|
} else {
|
||||||
|
hasFull = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !hasZero || !hasFull {
|
||||||
|
return nil, fmt.Errorf("qwen3next: full_attention_interval (%d) does not produce a mixed recurrent/full layout", interval)
|
||||||
|
}
|
||||||
|
|
||||||
|
return isRecurrent, nil
|
||||||
|
}
|
||||||
|
|
||||||
func New(c fs.Config) (model.Model, error) {
|
func New(c fs.Config) (model.Model, error) {
|
||||||
numLayers := int(c.Uint("block_count"))
|
numLayers := int(c.Uint("block_count"))
|
||||||
layers := make([]Layer, numLayers)
|
layers := make([]Layer, numLayers)
|
||||||
@@ -460,26 +558,14 @@ func New(c fs.Config) (model.Model, error) {
|
|||||||
HeadCountKV() []uint64
|
HeadCountKV() []uint64
|
||||||
}
|
}
|
||||||
|
|
||||||
var isRecurrent []bool
|
|
||||||
var headCountKV []uint64
|
var headCountKV []uint64
|
||||||
if hc, ok := c.(headCounts); ok {
|
if hc, ok := c.(headCounts); ok {
|
||||||
headCountKV = hc.HeadCountKV()
|
headCountKV = hc.HeadCountKV()
|
||||||
}
|
}
|
||||||
|
|
||||||
isRecurrent = make([]bool, numLayers)
|
isRecurrent, err := inferRecurrentLayers(headCountKV, numLayers, c.Uint("full_attention_interval"))
|
||||||
hasZero := false
|
if err != nil {
|
||||||
hasFull := false
|
return nil, err
|
||||||
for i := range numLayers {
|
|
||||||
// If KV head count is 0, it's a recurrent layer
|
|
||||||
if i < len(headCountKV) && headCountKV[i] == 0 {
|
|
||||||
isRecurrent[i] = true
|
|
||||||
hasZero = true
|
|
||||||
} else if i < len(headCountKV) && headCountKV[i] > 0 {
|
|
||||||
hasFull = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if !hasZero || !hasFull {
|
|
||||||
return nil, fmt.Errorf("qwen3next: invalid attention.head_count_kv array; expected mix of zero and non-zero values")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Determine if MoE
|
// Determine if MoE
|
||||||
@@ -543,7 +629,7 @@ func New(c fs.Config) (model.Model, error) {
|
|||||||
ssmNGroup: int(c.Uint("ssm.group_count")),
|
ssmNGroup: int(c.Uint("ssm.group_count")),
|
||||||
ssmDtRank: int(c.Uint("ssm.time_step_rank")),
|
ssmDtRank: int(c.Uint("ssm.time_step_rank")),
|
||||||
convKernelSize: int(c.Uint("ssm.conv_kernel")),
|
convKernelSize: int(c.Uint("ssm.conv_kernel")),
|
||||||
vHeadReordered: c.Bool("ssm.v_head_reordered", false),
|
vHeadReordered: c.Bool("ssm.v_head_reordered", defaultVHeadReordered(c.Architecture())),
|
||||||
isRecurrent: isRecurrent,
|
isRecurrent: isRecurrent,
|
||||||
mropeSections: slices.Collect(func(yield func(int) bool) {
|
mropeSections: slices.Collect(func(yield func(int) bool) {
|
||||||
for _, section := range mropeSections {
|
for _, section := range mropeSections {
|
||||||
@@ -555,7 +641,7 @@ func New(c fs.Config) (model.Model, error) {
|
|||||||
mropeInterleaved: c.Bool("rope.mrope_interleaved", c.Bool("mrope_interleaved", false)),
|
mropeInterleaved: c.Bool("rope.mrope_interleaved", c.Bool("mrope_interleaved", false)),
|
||||||
}
|
}
|
||||||
if opts.numKVHeads == 0 {
|
if opts.numKVHeads == 0 {
|
||||||
return nil, fmt.Errorf("qwen3next: attention.head_count_kv array must include at least one non-zero value")
|
return nil, fmt.Errorf("qwen3next: attention.head_count_kv must include at least one non-zero value")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Calculate cache dimensions
|
// Calculate cache dimensions
|
||||||
|
|||||||
65
model/models/qwen3next/model_new_test.go
Normal file
65
model/models/qwen3next/model_new_test.go
Normal file
@@ -0,0 +1,65 @@
|
|||||||
|
package qwen3next
|
||||||
|
|
||||||
|
import (
|
||||||
|
"slices"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestInferRecurrentLayersMixedKVArray(t *testing.T) {
|
||||||
|
got, err := inferRecurrentLayers([]uint64{0, 2, 0, 2}, 4, 0)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("inferRecurrentLayers() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
want := []bool{true, false, true, false}
|
||||||
|
if !slices.Equal(got, want) {
|
||||||
|
t.Fatalf("inferRecurrentLayers() = %v, want %v", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInferRecurrentLayersScalarKVDefaultInterval(t *testing.T) {
|
||||||
|
got, err := inferRecurrentLayers([]uint64{2, 2, 2, 2, 2, 2, 2, 2}, 8, 0)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("inferRecurrentLayers() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
want := []bool{true, true, true, false, true, true, true, false}
|
||||||
|
if !slices.Equal(got, want) {
|
||||||
|
t.Fatalf("inferRecurrentLayers() = %v, want %v", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInferRecurrentLayersScalarKVConfiguredInterval(t *testing.T) {
|
||||||
|
got, err := inferRecurrentLayers([]uint64{2, 2, 2, 2, 2, 2}, 6, 3)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("inferRecurrentLayers() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
want := []bool{true, true, false, true, true, false}
|
||||||
|
if !slices.Equal(got, want) {
|
||||||
|
t.Fatalf("inferRecurrentLayers() = %v, want %v", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInferRecurrentLayersAllZeroRejects(t *testing.T) {
|
||||||
|
_, err := inferRecurrentLayers([]uint64{0, 0, 0, 0}, 4, 0)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("inferRecurrentLayers() expected error, got nil")
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "must include at least one non-zero value") {
|
||||||
|
t.Fatalf("unexpected error = %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDefaultVHeadReordered(t *testing.T) {
|
||||||
|
if !defaultVHeadReordered("qwen35") {
|
||||||
|
t.Fatal("defaultVHeadReordered(qwen35) = false, want true")
|
||||||
|
}
|
||||||
|
if !defaultVHeadReordered("qwen35moe") {
|
||||||
|
t.Fatal("defaultVHeadReordered(qwen35moe) = false, want true")
|
||||||
|
}
|
||||||
|
if defaultVHeadReordered("qwen3next") {
|
||||||
|
t.Fatal("defaultVHeadReordered(qwen3next) = true, want false")
|
||||||
|
}
|
||||||
|
}
|
||||||
45
model/models/qwen3next/model_validate_test.go
Normal file
45
model/models/qwen3next/model_validate_test.go
Normal file
@@ -0,0 +1,45 @@
|
|||||||
|
package qwen3next
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/ml/nn"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestValidateRecurrentLayerRequiresSSMDT(t *testing.T) {
|
||||||
|
m := &Model{
|
||||||
|
Layers: []Layer{{
|
||||||
|
Operator: &GatedDeltaNet{
|
||||||
|
SSMQKV: &nn.Linear{},
|
||||||
|
SSMQKVGate: &nn.Linear{},
|
||||||
|
SSMBeta: &nn.Linear{},
|
||||||
|
SSMAlpha: &nn.Linear{},
|
||||||
|
},
|
||||||
|
}},
|
||||||
|
Options: &Options{
|
||||||
|
isRecurrent: []bool{true},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
err := m.Validate()
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("Validate() expected error, got nil")
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "missing ssm_dt") {
|
||||||
|
t.Fatalf("unexpected error = %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateNonRecurrentSkipsLinearChecks(t *testing.T) {
|
||||||
|
m := &Model{
|
||||||
|
Layers: []Layer{{Operator: &FullAttention{}}},
|
||||||
|
Options: &Options{
|
||||||
|
isRecurrent: []bool{false},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := m.Validate(); err != nil {
|
||||||
|
t.Fatalf("Validate() error = %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -35,6 +35,7 @@ type GLM46Parser struct {
|
|||||||
state glm46ParserState
|
state glm46ParserState
|
||||||
buffer strings.Builder
|
buffer strings.Builder
|
||||||
tools []api.Tool
|
tools []api.Tool
|
||||||
|
callIndex int
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *GLM46Parser) HasToolSupport() bool {
|
func (p *GLM46Parser) HasToolSupport() bool {
|
||||||
@@ -48,6 +49,7 @@ func (p *GLM46Parser) HasThinkingSupport() bool {
|
|||||||
// func (p *GLM46Parser) Init(tools []api.Tool, lastMessage *api.Message) []api.Tool {
|
// func (p *GLM46Parser) Init(tools []api.Tool, lastMessage *api.Message) []api.Tool {
|
||||||
func (p *GLM46Parser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
|
func (p *GLM46Parser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
|
||||||
p.tools = tools
|
p.tools = tools
|
||||||
|
p.callIndex = 0
|
||||||
return tools
|
return tools
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -89,6 +91,8 @@ func (p *GLM46Parser) Add(s string, done bool) (content string, thinking string,
|
|||||||
slog.Warn("glm-4.6 tool call parsing failed", "error", err)
|
slog.Warn("glm-4.6 tool call parsing failed", "error", err)
|
||||||
return "", "", nil, err
|
return "", "", nil, err
|
||||||
}
|
}
|
||||||
|
toolCall.Function.Index = p.callIndex
|
||||||
|
p.callIndex++
|
||||||
toolCalls = append(toolCalls, toolCall)
|
toolCalls = append(toolCalls, toolCall)
|
||||||
case glm46EventThinkingContent:
|
case glm46EventThinkingContent:
|
||||||
thinkingSb.WriteString(event.content)
|
thinkingSb.WriteString(event.content)
|
||||||
@@ -341,6 +345,47 @@ func escapeGLM46Content(s string) string {
|
|||||||
return result.String()
|
return result.String()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// repairUnclosedArgValues inserts missing </arg_value> closing tags.
|
||||||
|
// GLM models sometimes omit the closing tag, producing XML like:
|
||||||
|
//
|
||||||
|
// <arg_value>value</tool_call>
|
||||||
|
//
|
||||||
|
// instead of:
|
||||||
|
//
|
||||||
|
// <arg_value>value</arg_value></tool_call>
|
||||||
|
func repairUnclosedArgValues(s string) string {
|
||||||
|
var result strings.Builder
|
||||||
|
for {
|
||||||
|
openIdx := strings.Index(s, "<arg_value>")
|
||||||
|
if openIdx == -1 {
|
||||||
|
result.WriteString(s)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
afterOpen := openIdx + len("<arg_value>")
|
||||||
|
closeIdx := strings.Index(s[afterOpen:], "</arg_value>")
|
||||||
|
nextKeyIdx := strings.Index(s[afterOpen:], "<arg_key>")
|
||||||
|
// Check if properly closed before the next <arg_key> (or no next key)
|
||||||
|
if closeIdx != -1 && (nextKeyIdx == -1 || closeIdx < nextKeyIdx) {
|
||||||
|
end := afterOpen + closeIdx + len("</arg_value>")
|
||||||
|
result.WriteString(s[:end])
|
||||||
|
s = s[end:]
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// Unclosed — insert </arg_value> before the next <arg_key> or at end
|
||||||
|
if nextKeyIdx != -1 {
|
||||||
|
insertAt := afterOpen + nextKeyIdx
|
||||||
|
result.WriteString(s[:insertAt])
|
||||||
|
result.WriteString("</arg_value>")
|
||||||
|
s = s[insertAt:]
|
||||||
|
} else {
|
||||||
|
result.WriteString(s)
|
||||||
|
result.WriteString("</arg_value>")
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result.String()
|
||||||
|
}
|
||||||
|
|
||||||
func parseGLM46ToolCall(raw glm46EventRawToolCall, tools []api.Tool) (api.ToolCall, error) {
|
func parseGLM46ToolCall(raw glm46EventRawToolCall, tools []api.Tool) (api.ToolCall, error) {
|
||||||
// Escape any unescaped entities in text content
|
// Escape any unescaped entities in text content
|
||||||
// We need to escape text between tags, but not the tags themselves
|
// We need to escape text between tags, but not the tags themselves
|
||||||
@@ -349,11 +394,15 @@ func parseGLM46ToolCall(raw glm46EventRawToolCall, tools []api.Tool) (api.ToolCa
|
|||||||
// Wrap the content in a root element to make it valid XML
|
// Wrap the content in a root element to make it valid XML
|
||||||
xmlString := "<tool_call>" + escaped + "</tool_call>"
|
xmlString := "<tool_call>" + escaped + "</tool_call>"
|
||||||
|
|
||||||
// Parse XML into struct
|
// Parse XML into struct, retrying once with repaired XML if it fails
|
||||||
var parsed GLMToolCallXML
|
var parsed GLMToolCallXML
|
||||||
if err := xml.Unmarshal([]byte(xmlString), &parsed); err != nil {
|
if err := xml.Unmarshal([]byte(xmlString), &parsed); err != nil {
|
||||||
|
parsed = GLMToolCallXML{}
|
||||||
|
repaired := "<tool_call>" + repairUnclosedArgValues(escaped) + "</tool_call>"
|
||||||
|
if err2 := xml.Unmarshal([]byte(repaired), &parsed); err2 != nil {
|
||||||
return api.ToolCall{}, fmt.Errorf("failed to parse XML: %w", err)
|
return api.ToolCall{}, fmt.Errorf("failed to parse XML: %w", err)
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Extract and trim function name
|
// Extract and trim function name
|
||||||
functionName := strings.TrimSpace(parsed.Content)
|
functionName := strings.TrimSpace(parsed.Content)
|
||||||
|
|||||||
@@ -846,6 +846,47 @@ line3</arg_value>`,
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "unclosed arg_value at end",
|
||||||
|
tools: []api.Tool{},
|
||||||
|
rawToolCall: `get-weather
|
||||||
|
<arg_key>city</arg_key>
|
||||||
|
<arg_value>Paris`,
|
||||||
|
wantToolCall: api.ToolCall{
|
||||||
|
Function: api.ToolCallFunction{
|
||||||
|
Name: "get-weather",
|
||||||
|
Arguments: args(`{"city": "Paris"}`),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "unclosed arg_value before next arg_key",
|
||||||
|
tools: []api.Tool{},
|
||||||
|
rawToolCall: `get-weather
|
||||||
|
<arg_key>city</arg_key>
|
||||||
|
<arg_value>Paris<arg_key>unit</arg_key>
|
||||||
|
<arg_value>celsius</arg_value>`,
|
||||||
|
wantToolCall: api.ToolCall{
|
||||||
|
Function: api.ToolCallFunction{
|
||||||
|
Name: "get-weather",
|
||||||
|
Arguments: args(`{"city": "Paris", "unit": "celsius"}`),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "multiple unclosed arg_values",
|
||||||
|
tools: []api.Tool{},
|
||||||
|
rawToolCall: `get-weather
|
||||||
|
<arg_key>city</arg_key>
|
||||||
|
<arg_value>Paris<arg_key>unit</arg_key>
|
||||||
|
<arg_value>celsius`,
|
||||||
|
wantToolCall: api.ToolCall{
|
||||||
|
Function: api.ToolCallFunction{
|
||||||
|
Name: "get-weather",
|
||||||
|
Arguments: args(`{"city": "Paris", "unit": "celsius"}`),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for i, tc := range cases {
|
for i, tc := range cases {
|
||||||
@@ -860,3 +901,45 @@ line3</arg_value>`,
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestRepairUnclosedArgValues(t *testing.T) {
|
||||||
|
cases := []struct {
|
||||||
|
name string
|
||||||
|
input string
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "already valid",
|
||||||
|
input: `<arg_key>k</arg_key><arg_value>v</arg_value>`,
|
||||||
|
want: `<arg_key>k</arg_key><arg_value>v</arg_value>`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "unclosed at end",
|
||||||
|
input: `<arg_key>k</arg_key><arg_value>v`,
|
||||||
|
want: `<arg_key>k</arg_key><arg_value>v</arg_value>`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "unclosed before next arg_key",
|
||||||
|
input: `<arg_key>a</arg_key><arg_value>1<arg_key>b</arg_key><arg_value>2</arg_value>`,
|
||||||
|
want: `<arg_key>a</arg_key><arg_value>1</arg_value><arg_key>b</arg_key><arg_value>2</arg_value>`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "no arg_value tags",
|
||||||
|
input: `just plain text`,
|
||||||
|
want: `just plain text`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "multiple unclosed",
|
||||||
|
input: `<arg_key>a</arg_key><arg_value>1<arg_key>b</arg_key><arg_value>2`,
|
||||||
|
want: `<arg_key>a</arg_key><arg_value>1</arg_value><arg_key>b</arg_key><arg_value>2</arg_value>`,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tc := range cases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
got := repairUnclosedArgValues(tc.input)
|
||||||
|
if got != tc.want {
|
||||||
|
t.Errorf("got %q, want %q", got, tc.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ type GLM47Parser struct {
|
|||||||
|
|
||||||
func (p *GLM47Parser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
|
func (p *GLM47Parser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
|
||||||
p.tools = tools
|
p.tools = tools
|
||||||
|
p.callIndex = 0
|
||||||
// When thinking is enabled (nil or true), the prompt ends with <think>,
|
// When thinking is enabled (nil or true), the prompt ends with <think>,
|
||||||
// so model output starts directly with thinking content (no opening tag).
|
// so model output starts directly with thinking content (no opening tag).
|
||||||
if thinkValue == nil || thinkValue.Bool() {
|
if thinkValue == nil || thinkValue.Bool() {
|
||||||
|
|||||||
@@ -97,3 +97,91 @@ func TestGLM47ParserToolCallEscaping(t *testing.T) {
|
|||||||
t.Fatalf("expected %#v, got %#v", expected, toolCall)
|
t.Fatalf("expected %#v, got %#v", expected, toolCall)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestGLM47ParserToolCallIndexing(t *testing.T) {
|
||||||
|
parser := GLM47Parser{}
|
||||||
|
parser.Init(nil, nil, nil)
|
||||||
|
|
||||||
|
input := `plan</think>
|
||||||
|
<tool_call>first<arg_key>a</arg_key><arg_value>1</arg_value></tool_call>
|
||||||
|
<tool_call>second<arg_key>b</arg_key><arg_value>2</arg_value></tool_call>
|
||||||
|
<tool_call>third<arg_key>c</arg_key><arg_value>3</arg_value></tool_call>`
|
||||||
|
|
||||||
|
_, _, calls, err := parser.Add(input, true)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parse failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
want := []api.ToolCall{
|
||||||
|
{Function: api.ToolCallFunction{Name: "first", Arguments: args(`{"a":"1"}`), Index: 0}},
|
||||||
|
{Function: api.ToolCallFunction{Name: "second", Arguments: args(`{"b":"2"}`), Index: 1}},
|
||||||
|
{Function: api.ToolCallFunction{Name: "third", Arguments: args(`{"c":"3"}`), Index: 2}},
|
||||||
|
}
|
||||||
|
if len(calls) != len(want) {
|
||||||
|
t.Fatalf("expected %d calls, got %d", len(want), len(calls))
|
||||||
|
}
|
||||||
|
for i := range want {
|
||||||
|
if !toolCallEqual(calls[i], want[i]) {
|
||||||
|
t.Fatalf("call %d mismatch: got %#v, want %#v", i, calls[i], want[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGLM47ParserToolCallIndexingStreaming(t *testing.T) {
|
||||||
|
parser := GLM47Parser{}
|
||||||
|
parser.Init(nil, nil, nil)
|
||||||
|
|
||||||
|
var all []api.ToolCall
|
||||||
|
|
||||||
|
_, _, calls, err := parser.Add("plan</think><tool_call>first<arg_key>a</arg_key><arg_value>1</arg_value></tool_call><tool_call>second<arg_key>b</arg_key>", false)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("step 1 parse failed: %v", err)
|
||||||
|
}
|
||||||
|
all = append(all, calls...)
|
||||||
|
|
||||||
|
_, _, calls, err = parser.Add("<arg_value>2</arg_value></tool_call><tool_call>third<arg_key>c</arg_key><arg_value>3</arg_value></tool_call>", true)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("step 2 parse failed: %v", err)
|
||||||
|
}
|
||||||
|
all = append(all, calls...)
|
||||||
|
|
||||||
|
want := []api.ToolCall{
|
||||||
|
{Function: api.ToolCallFunction{Name: "first", Arguments: args(`{"a":"1"}`), Index: 0}},
|
||||||
|
{Function: api.ToolCallFunction{Name: "second", Arguments: args(`{"b":"2"}`), Index: 1}},
|
||||||
|
{Function: api.ToolCallFunction{Name: "third", Arguments: args(`{"c":"3"}`), Index: 2}},
|
||||||
|
}
|
||||||
|
if len(all) != len(want) {
|
||||||
|
t.Fatalf("expected %d calls, got %d", len(want), len(all))
|
||||||
|
}
|
||||||
|
for i := range want {
|
||||||
|
if !toolCallEqual(all[i], want[i]) {
|
||||||
|
t.Fatalf("call %d mismatch: got %#v, want %#v", i, all[i], want[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGLM47ParserToolCallIndexResetOnInit(t *testing.T) {
|
||||||
|
parser := GLM47Parser{}
|
||||||
|
parser.Init(nil, nil, nil)
|
||||||
|
|
||||||
|
_, _, _, err := parser.Add("plan</think><tool_call>first<arg_key>a</arg_key><arg_value>1</arg_value></tool_call>", true)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("first parse failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
parser.Init(nil, nil, nil)
|
||||||
|
_, _, calls, err := parser.Add("plan</think><tool_call>second<arg_key>b</arg_key><arg_value>2</arg_value></tool_call>", true)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("second parse failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
want := api.ToolCall{
|
||||||
|
Function: api.ToolCallFunction{Name: "second", Arguments: args(`{"b":"2"}`), Index: 0},
|
||||||
|
}
|
||||||
|
if len(calls) != 1 {
|
||||||
|
t.Fatalf("expected 1 call, got %d", len(calls))
|
||||||
|
}
|
||||||
|
if !toolCallEqual(calls[0], want) {
|
||||||
|
t.Fatalf("got %#v, want %#v", calls[0], want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -50,7 +50,7 @@ func ParserForName(name string) Parser {
|
|||||||
case "qwen3-thinking":
|
case "qwen3-thinking":
|
||||||
p = &Qwen3Parser{hasThinkingSupport: true, defaultThinking: true}
|
p = &Qwen3Parser{hasThinkingSupport: true, defaultThinking: true}
|
||||||
case "qwen3.5":
|
case "qwen3.5":
|
||||||
p = &Qwen3Parser{hasThinkingSupport: true, defaultThinking: true}
|
p = &Qwen35Parser{}
|
||||||
case "qwen3-coder":
|
case "qwen3-coder":
|
||||||
p = &Qwen3CoderParser{}
|
p = &Qwen3CoderParser{}
|
||||||
case "qwen3-vl-instruct":
|
case "qwen3-vl-instruct":
|
||||||
|
|||||||
@@ -38,6 +38,7 @@ type Qwen3Parser struct {
|
|||||||
state qwen3ParserState
|
state qwen3ParserState
|
||||||
buffer strings.Builder
|
buffer strings.Builder
|
||||||
tools []api.Tool
|
tools []api.Tool
|
||||||
|
callIndex int
|
||||||
hasThinkingSupport bool
|
hasThinkingSupport bool
|
||||||
defaultThinking bool
|
defaultThinking bool
|
||||||
maybeThinkingOpenAtBOL bool
|
maybeThinkingOpenAtBOL bool
|
||||||
@@ -54,6 +55,7 @@ func (p *Qwen3Parser) HasThinkingSupport() bool {
|
|||||||
func (p *Qwen3Parser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
|
func (p *Qwen3Parser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
|
||||||
p.tools = tools
|
p.tools = tools
|
||||||
p.buffer.Reset()
|
p.buffer.Reset()
|
||||||
|
p.callIndex = 0
|
||||||
|
|
||||||
thinkingEnabled := thinkValue != nil && thinkValue.Bool()
|
thinkingEnabled := thinkValue != nil && thinkValue.Bool()
|
||||||
if thinkValue == nil {
|
if thinkValue == nil {
|
||||||
@@ -106,6 +108,8 @@ func (p *Qwen3Parser) Add(s string, done bool) (content string, thinking string,
|
|||||||
slog.Warn("qwen3 tool call parsing failed", "error", err)
|
slog.Warn("qwen3 tool call parsing failed", "error", err)
|
||||||
return "", "", nil, err
|
return "", "", nil, err
|
||||||
}
|
}
|
||||||
|
toolCall.Function.Index = p.callIndex
|
||||||
|
p.callIndex++
|
||||||
calls = append(calls, toolCall)
|
calls = append(calls, toolCall)
|
||||||
case qwen3EventThinkingContent:
|
case qwen3EventThinkingContent:
|
||||||
thinkingSb.WriteString(event.content)
|
thinkingSb.WriteString(event.content)
|
||||||
|
|||||||
238
model/parsers/qwen35.go
Normal file
238
model/parsers/qwen35.go
Normal file
@@ -0,0 +1,238 @@
|
|||||||
|
package parsers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"log/slog"
|
||||||
|
"strings"
|
||||||
|
"unicode"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/api"
|
||||||
|
"github.com/ollama/ollama/logutil"
|
||||||
|
)
|
||||||
|
|
||||||
|
type qwen35ParserState int
|
||||||
|
|
||||||
|
const (
|
||||||
|
qwen35ParserStateCollectingThinking qwen35ParserState = iota
|
||||||
|
qwen35ParserStateThinkingDoneEatingWhitespace
|
||||||
|
qwen35ParserStateCollectingContent
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
qwen35ThinkingOpenTag = "<think>"
|
||||||
|
qwen35ThinkingCloseTag = "</think>"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Qwen35Parser handles qwen3.5 reasoning extraction and delegates post-thinking
|
||||||
|
// content (including XML tool calls) to Qwen3CoderParser.
|
||||||
|
type Qwen35Parser struct {
|
||||||
|
toolParser Qwen3CoderParser
|
||||||
|
|
||||||
|
state qwen35ParserState
|
||||||
|
buffer strings.Builder
|
||||||
|
// Some checkpoints may emit an explicit leading <think> even when the
|
||||||
|
// prompt already opened thinking. Strip at most one such tag.
|
||||||
|
allowLeadingThinkOpenTag bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Qwen35Parser) HasToolSupport() bool {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Qwen35Parser) HasThinkingSupport() bool {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Qwen35Parser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
|
||||||
|
p.buffer.Reset()
|
||||||
|
p.toolParser = Qwen3CoderParser{}
|
||||||
|
p.toolParser.Init(tools, nil, nil)
|
||||||
|
|
||||||
|
thinkingEnabled := thinkValue != nil && thinkValue.Bool()
|
||||||
|
if thinkValue == nil {
|
||||||
|
thinkingEnabled = true
|
||||||
|
}
|
||||||
|
|
||||||
|
assistantPrefill := lastMessage != nil && lastMessage.Role == "assistant" && lastMessage.Content != ""
|
||||||
|
if thinkingEnabled && !assistantPrefill {
|
||||||
|
p.state = qwen35ParserStateCollectingThinking
|
||||||
|
p.allowLeadingThinkOpenTag = true
|
||||||
|
} else {
|
||||||
|
p.state = qwen35ParserStateCollectingContent
|
||||||
|
p.allowLeadingThinkOpenTag = false
|
||||||
|
}
|
||||||
|
|
||||||
|
return tools
|
||||||
|
}
|
||||||
|
|
||||||
|
type qwen35Event interface {
|
||||||
|
isQwen35Event()
|
||||||
|
}
|
||||||
|
|
||||||
|
type qwen35EventContent struct {
|
||||||
|
content string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (qwen35EventContent) isQwen35Event() {}
|
||||||
|
|
||||||
|
type qwen35EventThinkingContent struct {
|
||||||
|
content string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (qwen35EventThinkingContent) isQwen35Event() {}
|
||||||
|
|
||||||
|
func (p *Qwen35Parser) Add(s string, done bool) (content string, thinking string, calls []api.ToolCall, err error) {
|
||||||
|
p.buffer.WriteString(s)
|
||||||
|
events := p.parseEvents()
|
||||||
|
|
||||||
|
var contentSb strings.Builder
|
||||||
|
var thinkingSb strings.Builder
|
||||||
|
for _, event := range events {
|
||||||
|
switch event := event.(type) {
|
||||||
|
case qwen35EventContent:
|
||||||
|
parsedContent, _, parsedCalls, err := p.toolParser.Add(event.content, done)
|
||||||
|
if err != nil {
|
||||||
|
slog.Warn("qwen3.5 tool call parsing failed", "error", err)
|
||||||
|
return "", "", nil, err
|
||||||
|
}
|
||||||
|
contentSb.WriteString(parsedContent)
|
||||||
|
calls = append(calls, parsedCalls...)
|
||||||
|
case qwen35EventThinkingContent:
|
||||||
|
thinkingSb.WriteString(event.content)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return contentSb.String(), thinkingSb.String(), calls, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Qwen35Parser) parseEvents() []qwen35Event {
|
||||||
|
var all []qwen35Event
|
||||||
|
|
||||||
|
keepLooping := true
|
||||||
|
for keepLooping {
|
||||||
|
var events []qwen35Event
|
||||||
|
events, keepLooping = p.eat()
|
||||||
|
if len(events) > 0 {
|
||||||
|
all = append(all, events...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(all) > 0 {
|
||||||
|
slog.Log(context.TODO(), logutil.LevelTrace, "qwen3.5 events parsed", "events", all, "state", p.state, "buffer", p.buffer.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
return all
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Qwen35Parser) splitAtTag(tag string, trimAfter bool) (string, string) {
|
||||||
|
return splitAtTag(&p.buffer, tag, trimAfter)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Qwen35Parser) eatLeadingWhitespaceAndTransitionTo(nextState qwen35ParserState) ([]qwen35Event, bool) {
|
||||||
|
trimmed := strings.TrimLeftFunc(p.buffer.String(), unicode.IsSpace)
|
||||||
|
p.buffer.Reset()
|
||||||
|
if trimmed == "" {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
p.state = nextState
|
||||||
|
p.buffer.WriteString(trimmed)
|
||||||
|
return nil, true
|
||||||
|
}
|
||||||
|
|
||||||
|
// maybeConsumeLeadingThinkOpenTag handles a single optional leading <think> tag.
|
||||||
|
// Returns (handled, shouldContinueParsingNow).
|
||||||
|
func (p *Qwen35Parser) maybeConsumeLeadingThinkOpenTag(acc string) (bool, bool) {
|
||||||
|
if !p.allowLeadingThinkOpenTag {
|
||||||
|
return false, false
|
||||||
|
}
|
||||||
|
|
||||||
|
trimmed := strings.TrimLeftFunc(acc, unicode.IsSpace)
|
||||||
|
if strings.HasPrefix(trimmed, qwen35ThinkingOpenTag) {
|
||||||
|
after := strings.TrimPrefix(trimmed, qwen35ThinkingOpenTag)
|
||||||
|
after = strings.TrimLeftFunc(after, unicode.IsSpace)
|
||||||
|
p.buffer.Reset()
|
||||||
|
p.buffer.WriteString(after)
|
||||||
|
if after == "" {
|
||||||
|
return true, false
|
||||||
|
}
|
||||||
|
p.allowLeadingThinkOpenTag = false
|
||||||
|
return true, true
|
||||||
|
}
|
||||||
|
|
||||||
|
if strings.HasPrefix(qwen35ThinkingOpenTag, trimmed) {
|
||||||
|
return true, false
|
||||||
|
}
|
||||||
|
|
||||||
|
p.allowLeadingThinkOpenTag = false
|
||||||
|
return false, false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Qwen35Parser) eat() ([]qwen35Event, bool) {
|
||||||
|
var events []qwen35Event
|
||||||
|
|
||||||
|
switch p.state {
|
||||||
|
case qwen35ParserStateCollectingThinking:
|
||||||
|
acc := p.buffer.String()
|
||||||
|
|
||||||
|
if handled, continueNow := p.maybeConsumeLeadingThinkOpenTag(acc); handled {
|
||||||
|
return events, continueNow
|
||||||
|
}
|
||||||
|
|
||||||
|
if strings.Contains(acc, qwen35ThinkingCloseTag) {
|
||||||
|
thinking, remaining := p.splitAtTag(qwen35ThinkingCloseTag, true)
|
||||||
|
if len(thinking) > 0 {
|
||||||
|
events = append(events, qwen35EventThinkingContent{content: thinking})
|
||||||
|
}
|
||||||
|
if remaining == "" {
|
||||||
|
p.state = qwen35ParserStateThinkingDoneEatingWhitespace
|
||||||
|
} else {
|
||||||
|
p.state = qwen35ParserStateCollectingContent
|
||||||
|
}
|
||||||
|
return events, true
|
||||||
|
} else if overlapLen := overlap(acc, qwen35ThinkingCloseTag); overlapLen > 0 {
|
||||||
|
beforePartialTag := acc[:len(acc)-overlapLen]
|
||||||
|
trailingWsLen := trailingWhitespaceLen(beforePartialTag)
|
||||||
|
ambiguousStart := len(beforePartialTag) - trailingWsLen
|
||||||
|
|
||||||
|
unambiguous := acc[:ambiguousStart]
|
||||||
|
ambiguous := acc[ambiguousStart:]
|
||||||
|
p.buffer.Reset()
|
||||||
|
p.buffer.WriteString(ambiguous)
|
||||||
|
if len(unambiguous) > 0 {
|
||||||
|
events = append(events, qwen35EventThinkingContent{content: unambiguous})
|
||||||
|
}
|
||||||
|
return events, false
|
||||||
|
}
|
||||||
|
|
||||||
|
whitespaceLen := trailingWhitespaceLen(acc)
|
||||||
|
ambiguousStart := len(acc) - whitespaceLen
|
||||||
|
unambiguous := acc[:ambiguousStart]
|
||||||
|
ambiguous := acc[ambiguousStart:]
|
||||||
|
p.buffer.Reset()
|
||||||
|
p.buffer.WriteString(ambiguous)
|
||||||
|
if len(unambiguous) > 0 {
|
||||||
|
events = append(events, qwen35EventThinkingContent{content: unambiguous})
|
||||||
|
}
|
||||||
|
return events, false
|
||||||
|
|
||||||
|
case qwen35ParserStateThinkingDoneEatingWhitespace:
|
||||||
|
return p.eatLeadingWhitespaceAndTransitionTo(qwen35ParserStateCollectingContent)
|
||||||
|
|
||||||
|
case qwen35ParserStateCollectingContent:
|
||||||
|
if p.buffer.Len() == 0 {
|
||||||
|
return events, false
|
||||||
|
}
|
||||||
|
|
||||||
|
content := p.buffer.String()
|
||||||
|
p.buffer.Reset()
|
||||||
|
if len(content) > 0 {
|
||||||
|
events = append(events, qwen35EventContent{content: content})
|
||||||
|
}
|
||||||
|
return events, false
|
||||||
|
|
||||||
|
default:
|
||||||
|
slog.Warn("qwen3.5 parser entered unknown state; resetting to content mode", "state", p.state)
|
||||||
|
p.state = qwen35ParserStateCollectingContent
|
||||||
|
return events, false
|
||||||
|
}
|
||||||
|
}
|
||||||
382
model/parsers/qwen35_test.go
Normal file
382
model/parsers/qwen35_test.go
Normal file
@@ -0,0 +1,382 @@
|
|||||||
|
package parsers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/api"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestQwen35ParserXMLToolCall(t *testing.T) {
|
||||||
|
parser := ParserForName("qwen3.5")
|
||||||
|
if parser == nil {
|
||||||
|
t.Fatal("expected qwen3.5 parser")
|
||||||
|
}
|
||||||
|
|
||||||
|
tools := []api.Tool{
|
||||||
|
{
|
||||||
|
Function: api.ToolFunction{
|
||||||
|
Name: "get_weather",
|
||||||
|
Parameters: api.ToolFunctionParameters{
|
||||||
|
Properties: func() *api.ToolPropertiesMap {
|
||||||
|
props := api.NewToolPropertiesMap()
|
||||||
|
props.Set("location", api.ToolProperty{Type: api.PropertyType{"string"}})
|
||||||
|
props.Set("days", api.ToolProperty{Type: api.PropertyType{"integer"}})
|
||||||
|
return props
|
||||||
|
}(),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
parser.Init(tools, nil, &api.ThinkValue{Value: false})
|
||||||
|
input := "<tool_call><function=get_weather><parameter=location>\nSan Francisco\n</parameter><parameter=days>\n3\n</parameter></function></tool_call>"
|
||||||
|
content, thinking, calls, err := parser.Add(input, true)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parse failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if content != "" {
|
||||||
|
t.Fatalf("expected empty content, got %q", content)
|
||||||
|
}
|
||||||
|
if thinking != "" {
|
||||||
|
t.Fatalf("expected empty thinking, got %q", thinking)
|
||||||
|
}
|
||||||
|
if len(calls) != 1 {
|
||||||
|
t.Fatalf("expected 1 tool call, got %d", len(calls))
|
||||||
|
}
|
||||||
|
|
||||||
|
if calls[0].Function.Name != "get_weather" {
|
||||||
|
t.Fatalf("expected tool name %q, got %q", "get_weather", calls[0].Function.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
location, ok := calls[0].Function.Arguments.Get("location")
|
||||||
|
if !ok || location != "San Francisco" {
|
||||||
|
t.Fatalf("expected location %q, got %v", "San Francisco", location)
|
||||||
|
}
|
||||||
|
|
||||||
|
days, ok := calls[0].Function.Arguments.Get("days")
|
||||||
|
if !ok || days != 3 {
|
||||||
|
t.Fatalf("expected days %d, got %v", 3, days)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQwen35ParserThinkingWithExplicitOpeningTag(t *testing.T) {
|
||||||
|
parser := ParserForName("qwen3.5")
|
||||||
|
if parser == nil {
|
||||||
|
t.Fatal("expected qwen3.5 parser")
|
||||||
|
}
|
||||||
|
|
||||||
|
parser.Init(nil, nil, &api.ThinkValue{Value: true})
|
||||||
|
content, thinking, calls, err := parser.Add("<think>\nLet me think...</think>Answer.", true)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parse failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if thinking != "Let me think..." {
|
||||||
|
t.Fatalf("expected thinking %q, got %q", "Let me think...", thinking)
|
||||||
|
}
|
||||||
|
if content != "Answer." {
|
||||||
|
t.Fatalf("expected content %q, got %q", "Answer.", content)
|
||||||
|
}
|
||||||
|
if len(calls) != 0 {
|
||||||
|
t.Fatalf("expected no tool calls, got %d", len(calls))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQwen35ParserAssistantPrefillStartsInContent(t *testing.T) {
|
||||||
|
parser := ParserForName("qwen3.5")
|
||||||
|
if parser == nil {
|
||||||
|
t.Fatal("expected qwen3.5 parser")
|
||||||
|
}
|
||||||
|
|
||||||
|
last := &api.Message{Role: "assistant", Content: "Prefilled response start"}
|
||||||
|
parser.Init(nil, last, nil)
|
||||||
|
|
||||||
|
content, thinking, calls, err := parser.Add(" and continued", true)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parse failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if thinking != "" {
|
||||||
|
t.Fatalf("expected no thinking for assistant prefill continuation, got %q", thinking)
|
||||||
|
}
|
||||||
|
if content != " and continued" {
|
||||||
|
t.Fatalf("expected content %q, got %q", " and continued", content)
|
||||||
|
}
|
||||||
|
if len(calls) != 0 {
|
||||||
|
t.Fatalf("expected no tool calls, got %d", len(calls))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQwen35ParserToolCallEmittedInThinkingIsNotParsed(t *testing.T) {
|
||||||
|
parser := ParserForName("qwen3.5")
|
||||||
|
if parser == nil {
|
||||||
|
t.Fatal("expected qwen3.5 parser")
|
||||||
|
}
|
||||||
|
|
||||||
|
tools := []api.Tool{
|
||||||
|
{
|
||||||
|
Function: api.ToolFunction{
|
||||||
|
Name: "get_weather",
|
||||||
|
Parameters: api.ToolFunctionParameters{
|
||||||
|
Properties: func() *api.ToolPropertiesMap {
|
||||||
|
props := api.NewToolPropertiesMap()
|
||||||
|
props.Set("location", api.ToolProperty{Type: api.PropertyType{"string"}})
|
||||||
|
return props
|
||||||
|
}(),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
parser.Init(tools, nil, &api.ThinkValue{Value: true})
|
||||||
|
input := `Need weather lookup<tool_call><function=get_weather><parameter=location>
|
||||||
|
SF
|
||||||
|
</parameter></function></tool_call>`
|
||||||
|
content, thinking, calls, err := parser.Add(input, true)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parse failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if content != "" {
|
||||||
|
t.Fatalf("expected empty content, got %q", content)
|
||||||
|
}
|
||||||
|
expectedThinking := `Need weather lookup<tool_call><function=get_weather><parameter=location>
|
||||||
|
SF
|
||||||
|
</parameter></function></tool_call>`
|
||||||
|
if thinking != expectedThinking {
|
||||||
|
t.Fatalf("expected thinking %q, got %q", expectedThinking, thinking)
|
||||||
|
}
|
||||||
|
if len(calls) != 0 {
|
||||||
|
t.Fatalf("expected no tool calls before </think>, got %d", len(calls))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQwen35ParserToolCallAfterThinkingCloseIsParsed(t *testing.T) {
|
||||||
|
parser := ParserForName("qwen3.5")
|
||||||
|
if parser == nil {
|
||||||
|
t.Fatal("expected qwen3.5 parser")
|
||||||
|
}
|
||||||
|
|
||||||
|
tools := []api.Tool{
|
||||||
|
{
|
||||||
|
Function: api.ToolFunction{
|
||||||
|
Name: "get_weather",
|
||||||
|
Parameters: api.ToolFunctionParameters{
|
||||||
|
Properties: func() *api.ToolPropertiesMap {
|
||||||
|
props := api.NewToolPropertiesMap()
|
||||||
|
props.Set("location", api.ToolProperty{Type: api.PropertyType{"string"}})
|
||||||
|
return props
|
||||||
|
}(),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
parser.Init(tools, nil, &api.ThinkValue{Value: true})
|
||||||
|
input := `Need weather lookup</think><tool_call><function=get_weather><parameter=location>
|
||||||
|
SF
|
||||||
|
</parameter></function></tool_call>`
|
||||||
|
content, thinking, calls, err := parser.Add(input, true)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parse failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if content != "" {
|
||||||
|
t.Fatalf("expected empty content, got %q", content)
|
||||||
|
}
|
||||||
|
if thinking != "Need weather lookup" {
|
||||||
|
t.Fatalf("expected thinking %q, got %q", "Need weather lookup", thinking)
|
||||||
|
}
|
||||||
|
if len(calls) != 1 {
|
||||||
|
t.Fatalf("expected 1 tool call after </think>, got %d", len(calls))
|
||||||
|
}
|
||||||
|
if calls[0].Function.Name != "get_weather" {
|
||||||
|
t.Fatalf("expected tool name %q, got %q", "get_weather", calls[0].Function.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
location, ok := calls[0].Function.Arguments.Get("location")
|
||||||
|
if !ok || location != "SF" {
|
||||||
|
t.Fatalf("expected location %q, got %v", "SF", location)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQwen35ParserThinkingDisabledPassesContentThrough(t *testing.T) {
|
||||||
|
parser := ParserForName("qwen3.5")
|
||||||
|
if parser == nil {
|
||||||
|
t.Fatal("expected qwen3.5 parser")
|
||||||
|
}
|
||||||
|
|
||||||
|
parser.Init(nil, nil, &api.ThinkValue{Value: false})
|
||||||
|
content, thinking, calls, err := parser.Add("Plain answer without think close tag.", true)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parse failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if thinking != "" {
|
||||||
|
t.Fatalf("expected empty thinking, got %q", thinking)
|
||||||
|
}
|
||||||
|
if content != "Plain answer without think close tag." {
|
||||||
|
t.Fatalf("expected content %q, got %q", "Plain answer without think close tag.", content)
|
||||||
|
}
|
||||||
|
if len(calls) != 0 {
|
||||||
|
t.Fatalf("expected no tool calls, got %d", len(calls))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQwen35ParserThinkingDisabledWithCloseTagTreatsAsContent(t *testing.T) {
|
||||||
|
parser := ParserForName("qwen3.5")
|
||||||
|
if parser == nil {
|
||||||
|
t.Fatal("expected qwen3.5 parser")
|
||||||
|
}
|
||||||
|
|
||||||
|
parser.Init(nil, nil, &api.ThinkValue{Value: false})
|
||||||
|
content, thinking, calls, err := parser.Add("</think>Some content after spurious tag.", true)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parse failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if thinking != "" {
|
||||||
|
t.Fatalf("expected empty thinking, got %q", thinking)
|
||||||
|
}
|
||||||
|
if content != "</think>Some content after spurious tag." {
|
||||||
|
t.Fatalf("expected content %q, got %q", "</think>Some content after spurious tag.", content)
|
||||||
|
}
|
||||||
|
if len(calls) != 0 {
|
||||||
|
t.Fatalf("expected no tool calls, got %d", len(calls))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQwen35ParserLeadingThinkCloseProducesContent(t *testing.T) {
|
||||||
|
parser := ParserForName("qwen3.5")
|
||||||
|
if parser == nil {
|
||||||
|
t.Fatal("expected qwen3.5 parser")
|
||||||
|
}
|
||||||
|
|
||||||
|
parser.Init(nil, nil, &api.ThinkValue{Value: true})
|
||||||
|
content, thinking, calls, err := parser.Add("</think>The final answer.", true)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parse failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if thinking != "" {
|
||||||
|
t.Fatalf("expected empty thinking, got %q", thinking)
|
||||||
|
}
|
||||||
|
if content != "The final answer." {
|
||||||
|
t.Fatalf("expected content %q, got %q", "The final answer.", content)
|
||||||
|
}
|
||||||
|
if len(calls) != 0 {
|
||||||
|
t.Fatalf("expected no tool calls, got %d", len(calls))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQwen35ParserStreamingSplitThinkCloseTag(t *testing.T) {
|
||||||
|
parser := ParserForName("qwen3.5")
|
||||||
|
if parser == nil {
|
||||||
|
t.Fatal("expected qwen3.5 parser")
|
||||||
|
}
|
||||||
|
|
||||||
|
parser.Init(nil, nil, &api.ThinkValue{Value: true})
|
||||||
|
|
||||||
|
content, thinking, calls, err := parser.Add("Reasoning text</thi", false)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parse failed on first chunk: %v", err)
|
||||||
|
}
|
||||||
|
if thinking != "Reasoning text" {
|
||||||
|
t.Fatalf("expected thinking %q, got %q", "Reasoning text", thinking)
|
||||||
|
}
|
||||||
|
if content != "" {
|
||||||
|
t.Fatalf("expected empty content, got %q", content)
|
||||||
|
}
|
||||||
|
if len(calls) != 0 {
|
||||||
|
t.Fatalf("expected no tool calls, got %d", len(calls))
|
||||||
|
}
|
||||||
|
|
||||||
|
content, thinking, calls, err = parser.Add("nk>The final answer.", true)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parse failed on second chunk: %v", err)
|
||||||
|
}
|
||||||
|
if thinking != "" {
|
||||||
|
t.Fatalf("expected no additional thinking on second chunk, got %q", thinking)
|
||||||
|
}
|
||||||
|
if content != "The final answer." {
|
||||||
|
t.Fatalf("expected content %q, got %q", "The final answer.", content)
|
||||||
|
}
|
||||||
|
if len(calls) != 0 {
|
||||||
|
t.Fatalf("expected no tool calls, got %d", len(calls))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQwen35ParserStreamingEatsWhitespaceAfterThinkClose(t *testing.T) {
|
||||||
|
parser := ParserForName("qwen3.5")
|
||||||
|
if parser == nil {
|
||||||
|
t.Fatal("expected qwen3.5 parser")
|
||||||
|
}
|
||||||
|
|
||||||
|
parser.Init(nil, nil, &api.ThinkValue{Value: true})
|
||||||
|
|
||||||
|
content, thinking, calls, err := parser.Add("Reasoning</think>", false)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parse failed on first chunk: %v", err)
|
||||||
|
}
|
||||||
|
if thinking != "Reasoning" {
|
||||||
|
t.Fatalf("expected thinking %q, got %q", "Reasoning", thinking)
|
||||||
|
}
|
||||||
|
if content != "" {
|
||||||
|
t.Fatalf("expected empty content, got %q", content)
|
||||||
|
}
|
||||||
|
if len(calls) != 0 {
|
||||||
|
t.Fatalf("expected no tool calls, got %d", len(calls))
|
||||||
|
}
|
||||||
|
|
||||||
|
content, thinking, calls, err = parser.Add("\n \t", false)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parse failed on whitespace chunk: %v", err)
|
||||||
|
}
|
||||||
|
if thinking != "" {
|
||||||
|
t.Fatalf("expected no thinking on whitespace chunk, got %q", thinking)
|
||||||
|
}
|
||||||
|
if content != "" {
|
||||||
|
t.Fatalf("expected whitespace after </think> to be eaten, got content %q", content)
|
||||||
|
}
|
||||||
|
if len(calls) != 0 {
|
||||||
|
t.Fatalf("expected no tool calls, got %d", len(calls))
|
||||||
|
}
|
||||||
|
|
||||||
|
content, thinking, calls, err = parser.Add("The final answer.", true)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parse failed on content chunk: %v", err)
|
||||||
|
}
|
||||||
|
if thinking != "" {
|
||||||
|
t.Fatalf("expected no additional thinking, got %q", thinking)
|
||||||
|
}
|
||||||
|
if content != "The final answer." {
|
||||||
|
t.Fatalf("expected content %q, got %q", "The final answer.", content)
|
||||||
|
}
|
||||||
|
if len(calls) != 0 {
|
||||||
|
t.Fatalf("expected no tool calls, got %d", len(calls))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQwen35ParserThinkingTruncatedWithoutCloseTag(t *testing.T) {
|
||||||
|
parser := ParserForName("qwen3.5")
|
||||||
|
if parser == nil {
|
||||||
|
t.Fatal("expected qwen3.5 parser")
|
||||||
|
}
|
||||||
|
|
||||||
|
parser.Init(nil, nil, &api.ThinkValue{Value: true})
|
||||||
|
content, thinking, calls, err := parser.Add("Reasoning that never closes", true)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parse failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if thinking != "Reasoning that never closes" {
|
||||||
|
t.Fatalf("expected thinking %q, got %q", "Reasoning that never closes", thinking)
|
||||||
|
}
|
||||||
|
if content != "" {
|
||||||
|
t.Fatalf("expected empty content, got %q", content)
|
||||||
|
}
|
||||||
|
if len(calls) != 0 {
|
||||||
|
t.Fatalf("expected no tool calls, got %d", len(calls))
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -230,3 +230,89 @@ func TestQwen35ParserRespectsNoThink(t *testing.T) {
|
|||||||
t.Fatalf("expected no tool calls, got %d", len(calls))
|
t.Fatalf("expected no tool calls, got %d", len(calls))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestQwen3ParserToolCallIndexing(t *testing.T) {
|
||||||
|
parser := &Qwen3Parser{hasThinkingSupport: false, defaultThinking: false}
|
||||||
|
parser.Init(nil, nil, &api.ThinkValue{Value: false})
|
||||||
|
|
||||||
|
input := `<tool_call>{"name":"first","arguments":{"a":"1"}}</tool_call>
|
||||||
|
<tool_call>{"name":"second","arguments":{"b":"2"}}</tool_call>
|
||||||
|
<tool_call>{"name":"third","arguments":{"c":"3"}}</tool_call>`
|
||||||
|
_, _, calls, err := parser.Add(input, true)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parse failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
want := []api.ToolCall{
|
||||||
|
{Function: api.ToolCallFunction{Name: "first", Arguments: args(`{"a":"1"}`), Index: 0}},
|
||||||
|
{Function: api.ToolCallFunction{Name: "second", Arguments: args(`{"b":"2"}`), Index: 1}},
|
||||||
|
{Function: api.ToolCallFunction{Name: "third", Arguments: args(`{"c":"3"}`), Index: 2}},
|
||||||
|
}
|
||||||
|
if len(calls) != len(want) {
|
||||||
|
t.Fatalf("expected %d calls, got %d", len(want), len(calls))
|
||||||
|
}
|
||||||
|
for i := range want {
|
||||||
|
if !toolCallEqual(calls[i], want[i]) {
|
||||||
|
t.Fatalf("call %d mismatch: got %#v, want %#v", i, calls[i], want[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQwen3ParserToolCallIndexingStreaming(t *testing.T) {
|
||||||
|
parser := &Qwen3Parser{hasThinkingSupport: false, defaultThinking: false}
|
||||||
|
parser.Init(nil, nil, &api.ThinkValue{Value: false})
|
||||||
|
|
||||||
|
var all []api.ToolCall
|
||||||
|
|
||||||
|
_, _, calls, err := parser.Add(`<tool_call>{"name":"first","arguments":{"a":"1"}}</tool_call><tool_call>{"name":"second","arguments":{"b":"2"}`, false)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("step 1 parse failed: %v", err)
|
||||||
|
}
|
||||||
|
all = append(all, calls...)
|
||||||
|
|
||||||
|
_, _, calls, err = parser.Add(`}</tool_call><tool_call>{"name":"third","arguments":{"c":"3"}}</tool_call>`, true)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("step 2 parse failed: %v", err)
|
||||||
|
}
|
||||||
|
all = append(all, calls...)
|
||||||
|
|
||||||
|
want := []api.ToolCall{
|
||||||
|
{Function: api.ToolCallFunction{Name: "first", Arguments: args(`{"a":"1"}`), Index: 0}},
|
||||||
|
{Function: api.ToolCallFunction{Name: "second", Arguments: args(`{"b":"2"}`), Index: 1}},
|
||||||
|
{Function: api.ToolCallFunction{Name: "third", Arguments: args(`{"c":"3"}`), Index: 2}},
|
||||||
|
}
|
||||||
|
if len(all) != len(want) {
|
||||||
|
t.Fatalf("expected %d calls, got %d", len(want), len(all))
|
||||||
|
}
|
||||||
|
for i := range want {
|
||||||
|
if !toolCallEqual(all[i], want[i]) {
|
||||||
|
t.Fatalf("call %d mismatch: got %#v, want %#v", i, all[i], want[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQwen3ParserToolCallIndexResetOnInit(t *testing.T) {
|
||||||
|
parser := &Qwen3Parser{hasThinkingSupport: false, defaultThinking: false}
|
||||||
|
parser.Init(nil, nil, &api.ThinkValue{Value: false})
|
||||||
|
|
||||||
|
_, _, _, err := parser.Add(`<tool_call>{"name":"first","arguments":{"a":"1"}}</tool_call>`, true)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("first parse failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
parser.Init(nil, nil, &api.ThinkValue{Value: false})
|
||||||
|
_, _, calls, err := parser.Add(`<tool_call>{"name":"second","arguments":{"b":"2"}}</tool_call>`, true)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("second parse failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
want := api.ToolCall{
|
||||||
|
Function: api.ToolCallFunction{Name: "second", Arguments: args(`{"b":"2"}`), Index: 0},
|
||||||
|
}
|
||||||
|
if len(calls) != 1 {
|
||||||
|
t.Fatalf("expected 1 call, got %d", len(calls))
|
||||||
|
}
|
||||||
|
if !toolCallEqual(calls[0], want) {
|
||||||
|
t.Fatalf("got %#v, want %#v", calls[0], want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -32,6 +32,7 @@ type Qwen3CoderParser struct {
|
|||||||
state qwenParserState
|
state qwenParserState
|
||||||
acc strings.Builder
|
acc strings.Builder
|
||||||
tools []api.Tool
|
tools []api.Tool
|
||||||
|
callIndex int
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Qwen3CoderParser) HasToolSupport() bool {
|
func (p *Qwen3CoderParser) HasToolSupport() bool {
|
||||||
@@ -44,6 +45,7 @@ func (p *Qwen3CoderParser) HasThinkingSupport() bool {
|
|||||||
|
|
||||||
func (p *Qwen3CoderParser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
|
func (p *Qwen3CoderParser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
|
||||||
p.tools = tools
|
p.tools = tools
|
||||||
|
p.callIndex = 0
|
||||||
return tools // Qwen doesn't modify tools
|
return tools // Qwen doesn't modify tools
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -62,6 +64,8 @@ func (p *Qwen3CoderParser) Add(s string, done bool) (content string, thinking st
|
|||||||
slog.Warn("qwen tool call parsing failed", "error", err)
|
slog.Warn("qwen tool call parsing failed", "error", err)
|
||||||
return "", "", nil, err
|
return "", "", nil, err
|
||||||
}
|
}
|
||||||
|
toolCall.Function.Index = p.callIndex
|
||||||
|
p.callIndex++
|
||||||
toolCalls = append(toolCalls, toolCall)
|
toolCalls = append(toolCalls, toolCall)
|
||||||
case qwenEventContent:
|
case qwenEventContent:
|
||||||
// TODO(drifkin): if the same turn contains multiple interleaved content
|
// TODO(drifkin): if the same turn contains multiple interleaved content
|
||||||
|
|||||||
@@ -1035,6 +1035,92 @@ func TestQwenToolCallValueParsing(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestQwen3CoderParserToolCallIndexing(t *testing.T) {
|
||||||
|
parser := Qwen3CoderParser{}
|
||||||
|
parser.Init(nil, nil, nil)
|
||||||
|
|
||||||
|
input := `<tool_call><function=first><parameter=a>1</parameter></function></tool_call>
|
||||||
|
<tool_call><function=second><parameter=b>2</parameter></function></tool_call>
|
||||||
|
<tool_call><function=third><parameter=c>3</parameter></function></tool_call>`
|
||||||
|
_, _, calls, err := parser.Add(input, true)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parse failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
want := []api.ToolCall{
|
||||||
|
{Function: api.ToolCallFunction{Name: "first", Arguments: testArgs(map[string]any{"a": "1"}), Index: 0}},
|
||||||
|
{Function: api.ToolCallFunction{Name: "second", Arguments: testArgs(map[string]any{"b": "2"}), Index: 1}},
|
||||||
|
{Function: api.ToolCallFunction{Name: "third", Arguments: testArgs(map[string]any{"c": "3"}), Index: 2}},
|
||||||
|
}
|
||||||
|
if len(calls) != len(want) {
|
||||||
|
t.Fatalf("expected %d calls, got %d", len(want), len(calls))
|
||||||
|
}
|
||||||
|
for i := range want {
|
||||||
|
if !toolCallEqual(calls[i], want[i]) {
|
||||||
|
t.Fatalf("call %d mismatch: got %#v, want %#v", i, calls[i], want[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQwen3CoderParserToolCallIndexingStreaming(t *testing.T) {
|
||||||
|
parser := Qwen3CoderParser{}
|
||||||
|
parser.Init(nil, nil, nil)
|
||||||
|
|
||||||
|
var all []api.ToolCall
|
||||||
|
|
||||||
|
_, _, calls, err := parser.Add("<tool_call><function=first><parameter=a>1</parameter></function></tool_call><tool_call><function=second>", false)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("step 1 parse failed: %v", err)
|
||||||
|
}
|
||||||
|
all = append(all, calls...)
|
||||||
|
|
||||||
|
_, _, calls, err = parser.Add("<parameter=b>2</parameter></function></tool_call><tool_call><function=third><parameter=c>3</parameter></function></tool_call>", true)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("step 2 parse failed: %v", err)
|
||||||
|
}
|
||||||
|
all = append(all, calls...)
|
||||||
|
|
||||||
|
want := []api.ToolCall{
|
||||||
|
{Function: api.ToolCallFunction{Name: "first", Arguments: testArgs(map[string]any{"a": "1"}), Index: 0}},
|
||||||
|
{Function: api.ToolCallFunction{Name: "second", Arguments: testArgs(map[string]any{"b": "2"}), Index: 1}},
|
||||||
|
{Function: api.ToolCallFunction{Name: "third", Arguments: testArgs(map[string]any{"c": "3"}), Index: 2}},
|
||||||
|
}
|
||||||
|
if len(all) != len(want) {
|
||||||
|
t.Fatalf("expected %d calls, got %d", len(want), len(all))
|
||||||
|
}
|
||||||
|
for i := range want {
|
||||||
|
if !toolCallEqual(all[i], want[i]) {
|
||||||
|
t.Fatalf("call %d mismatch: got %#v, want %#v", i, all[i], want[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQwen3CoderParserToolCallIndexResetOnInit(t *testing.T) {
|
||||||
|
parser := Qwen3CoderParser{}
|
||||||
|
parser.Init(nil, nil, nil)
|
||||||
|
|
||||||
|
_, _, _, err := parser.Add("<tool_call><function=first><parameter=a>1</parameter></function></tool_call>", true)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("first parse failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
parser.Init(nil, nil, nil)
|
||||||
|
_, _, calls, err := parser.Add("<tool_call><function=second><parameter=b>2</parameter></function></tool_call>", true)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("second parse failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
want := api.ToolCall{
|
||||||
|
Function: api.ToolCallFunction{Name: "second", Arguments: testArgs(map[string]any{"b": "2"}), Index: 0},
|
||||||
|
}
|
||||||
|
if len(calls) != 1 {
|
||||||
|
t.Fatalf("expected 1 call, got %d", len(calls))
|
||||||
|
}
|
||||||
|
if !toolCallEqual(calls[0], want) {
|
||||||
|
t.Fatalf("got %#v, want %#v", calls[0], want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestQwenXMLTransform(t *testing.T) {
|
func TestQwenXMLTransform(t *testing.T) {
|
||||||
cases := []struct {
|
cases := []struct {
|
||||||
desc string
|
desc string
|
||||||
|
|||||||
@@ -8,7 +8,21 @@ import (
|
|||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
)
|
)
|
||||||
|
|
||||||
type GlmOcrRenderer struct{}
|
type GlmOcrRenderer struct {
|
||||||
|
useImgTags bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *GlmOcrRenderer) renderContent(message api.Message, imageOffset int) (string, int) {
|
||||||
|
var sb strings.Builder
|
||||||
|
for range message.Images {
|
||||||
|
if r.useImgTags {
|
||||||
|
sb.WriteString(fmt.Sprintf("[img-%d]", imageOffset))
|
||||||
|
imageOffset++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
sb.WriteString(message.Content)
|
||||||
|
return sb.String(), imageOffset
|
||||||
|
}
|
||||||
|
|
||||||
func (r *GlmOcrRenderer) Render(messages []api.Message, tools []api.Tool, thinkValue *api.ThinkValue) (string, error) {
|
func (r *GlmOcrRenderer) Render(messages []api.Message, tools []api.Tool, thinkValue *api.ThinkValue) (string, error) {
|
||||||
var sb strings.Builder
|
var sb strings.Builder
|
||||||
@@ -38,11 +52,14 @@ func (r *GlmOcrRenderer) Render(messages []api.Message, tools []api.Tool, thinkV
|
|||||||
thinkingExplicitlySet = true
|
thinkingExplicitlySet = true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
imageOffset := 0
|
||||||
for i, message := range messages {
|
for i, message := range messages {
|
||||||
switch message.Role {
|
switch message.Role {
|
||||||
case "user":
|
case "user":
|
||||||
sb.WriteString("<|user|>\n")
|
sb.WriteString("<|user|>\n")
|
||||||
sb.WriteString(message.Content)
|
content, nextOffset := r.renderContent(message, imageOffset)
|
||||||
|
imageOffset = nextOffset
|
||||||
|
sb.WriteString(content)
|
||||||
if thinkingExplicitlySet && !enableThinking && !strings.HasSuffix(message.Content, "/nothink") {
|
if thinkingExplicitlySet && !enableThinking && !strings.HasSuffix(message.Content, "/nothink") {
|
||||||
sb.WriteString("/nothink")
|
sb.WriteString("/nothink")
|
||||||
}
|
}
|
||||||
|
|||||||
99
model/renderers/glmocr_test.go
Normal file
99
model/renderers/glmocr_test.go
Normal file
@@ -0,0 +1,99 @@
|
|||||||
|
package renderers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/google/go-cmp/cmp"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/api"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestGlmOcrRenderer_Images(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
renderer *GlmOcrRenderer
|
||||||
|
messages []api.Message
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "use_img_tags_single_image",
|
||||||
|
renderer: &GlmOcrRenderer{useImgTags: true},
|
||||||
|
messages: []api.Message{
|
||||||
|
{
|
||||||
|
Role: "user",
|
||||||
|
Content: "Describe this image.",
|
||||||
|
Images: []api.ImageData{api.ImageData("img1")},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expected: "[gMASK]<sop><|user|>\n[img-0]Describe this image.<|assistant|>\n",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "use_img_tags_multiple_images",
|
||||||
|
renderer: &GlmOcrRenderer{useImgTags: true},
|
||||||
|
messages: []api.Message{
|
||||||
|
{
|
||||||
|
Role: "user",
|
||||||
|
Content: "Describe these images.",
|
||||||
|
Images: []api.ImageData{api.ImageData("img1"), api.ImageData("img2")},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expected: "[gMASK]<sop><|user|>\n[img-0][img-1]Describe these images.<|assistant|>\n",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "multi_turn_increments_image_offset",
|
||||||
|
renderer: &GlmOcrRenderer{useImgTags: true},
|
||||||
|
messages: []api.Message{
|
||||||
|
{
|
||||||
|
Role: "user",
|
||||||
|
Content: "First image",
|
||||||
|
Images: []api.ImageData{api.ImageData("img1")},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Role: "assistant",
|
||||||
|
Content: "Processed.",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Role: "user",
|
||||||
|
Content: "Second image",
|
||||||
|
Images: []api.ImageData{api.ImageData("img2")},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expected: "[gMASK]<sop><|user|>\n[img-0]First image<|assistant|>\n<think></think>\nProcessed.\n<|user|>\n[img-1]Second image<|assistant|>\n",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "default_no_img_tags",
|
||||||
|
renderer: &GlmOcrRenderer{},
|
||||||
|
messages: []api.Message{
|
||||||
|
{
|
||||||
|
Role: "user",
|
||||||
|
Content: "No image tags expected.",
|
||||||
|
Images: []api.ImageData{api.ImageData("img1")},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expected: "[gMASK]<sop><|user|>\nNo image tags expected.<|assistant|>\n",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "no_images_content_unchanged",
|
||||||
|
renderer: &GlmOcrRenderer{useImgTags: true},
|
||||||
|
messages: []api.Message{
|
||||||
|
{
|
||||||
|
Role: "user",
|
||||||
|
Content: "Text only message.",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expected: "[gMASK]<sop><|user|>\nText only message.<|assistant|>\n",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got, err := tt.renderer.Render(tt.messages, nil, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Render() error = %v", err)
|
||||||
|
}
|
||||||
|
if diff := cmp.Diff(tt.expected, got); diff != "" {
|
||||||
|
t.Fatalf("Render() mismatch (-want +got):\n%s", diff)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
194
model/renderers/qwen35.go
Normal file
194
model/renderers/qwen35.go
Normal file
@@ -0,0 +1,194 @@
|
|||||||
|
package renderers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/api"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
qwen35ThinkOpenTag = "<think>"
|
||||||
|
qwen35ThinkCloseTag = "</think>"
|
||||||
|
qwen35ToolPostamble = `
|
||||||
|
</tools>
|
||||||
|
|
||||||
|
If you choose to call a function ONLY reply in the following format with NO suffix:
|
||||||
|
|
||||||
|
<tool_call>
|
||||||
|
<function=example_function_name>
|
||||||
|
<parameter=example_parameter_1>
|
||||||
|
value_1
|
||||||
|
</parameter>
|
||||||
|
<parameter=example_parameter_2>
|
||||||
|
This is the value for the second parameter
|
||||||
|
that can span
|
||||||
|
multiple lines
|
||||||
|
</parameter>
|
||||||
|
</function>
|
||||||
|
</tool_call>
|
||||||
|
|
||||||
|
<IMPORTANT>
|
||||||
|
Reminder:
|
||||||
|
- Function calls MUST follow the specified format: an inner <function=...></function> block must be nested within <tool_call></tool_call> XML tags
|
||||||
|
- Required parameters MUST be specified
|
||||||
|
- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after
|
||||||
|
- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls
|
||||||
|
</IMPORTANT>`
|
||||||
|
)
|
||||||
|
|
||||||
|
type Qwen35Renderer struct {
|
||||||
|
isThinking bool
|
||||||
|
|
||||||
|
emitEmptyThinkOnNoThink bool
|
||||||
|
useImgTags bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Qwen35Renderer) renderContent(content api.Message, imageOffset int) (string, int) {
|
||||||
|
// This assumes all images are at the front of the message - same assumption as ollama/ollama/runner.go
|
||||||
|
var subSb strings.Builder
|
||||||
|
for range content.Images {
|
||||||
|
if r.useImgTags {
|
||||||
|
subSb.WriteString(fmt.Sprintf("[img-%d]", imageOffset))
|
||||||
|
imageOffset++
|
||||||
|
} else {
|
||||||
|
subSb.WriteString("<|vision_start|><|image_pad|><|vision_end|>")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// TODO: support videos
|
||||||
|
|
||||||
|
subSb.WriteString(content.Content)
|
||||||
|
return subSb.String(), imageOffset
|
||||||
|
}
|
||||||
|
|
||||||
|
func splitQwen35ReasoningContent(content, messageThinking string, isThinking bool) (reasoning string, remaining string) {
|
||||||
|
if isThinking && messageThinking != "" {
|
||||||
|
return strings.TrimSpace(messageThinking), content
|
||||||
|
}
|
||||||
|
|
||||||
|
if idx := strings.Index(content, qwen35ThinkCloseTag); idx != -1 {
|
||||||
|
before := content[:idx]
|
||||||
|
if open := strings.LastIndex(before, qwen35ThinkOpenTag); open != -1 {
|
||||||
|
reasoning = before[open+len(qwen35ThinkOpenTag):]
|
||||||
|
} else {
|
||||||
|
reasoning = before
|
||||||
|
}
|
||||||
|
content = strings.TrimLeft(content[idx+len(qwen35ThinkCloseTag):], "\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
return strings.TrimSpace(reasoning), content
|
||||||
|
}
|
||||||
|
|
||||||
|
// Render builds the full Qwen 3.5 chat prompt from messages, tools, and the
// requested thinking mode. Tool definitions are injected into the system
// turn, assistant tool calls use the XML-style <tool_call> format, and
// consecutive tool responses are grouped into a single user turn. A trailing
// assistant message is treated as a prefill and left unterminated.
func (r *Qwen35Renderer) Render(messages []api.Message, tools []api.Tool, think *api.ThinkValue) (string, error) {
	var sb strings.Builder

	// Per-request think setting overrides the renderer default.
	isThinking := r.isThinking
	if think != nil {
		isThinking = think.Bool()
	}

	if len(tools) > 0 {
		sb.WriteString(imStartTag + "system\n")
		sb.WriteString("# Tools\n\nYou have access to the following functions:\n\n<tools>")
		for _, tool := range tools {
			sb.WriteString("\n")
			// Marshal errors are deliberately ignored: a tool that fails to
			// marshal is simply omitted from the prompt.
			if b, err := marshalWithSpaces(tool); err == nil {
				sb.Write(b)
			}
		}
		sb.WriteString(qwen35ToolPostamble)
		// A leading system message is appended after the tool instructions
		// inside the same system turn.
		if len(messages) > 0 && messages[0].Role == "system" {
			systemContent, _ := r.renderContent(messages[0], 0)
			systemContent = strings.TrimSpace(systemContent)
			if systemContent != "" {
				sb.WriteString("\n\n")
				sb.WriteString(systemContent)
			}
		}
		sb.WriteString(imEndTag + "\n")
	} else if len(messages) > 0 && messages[0].Role == "system" {
		systemContent, _ := r.renderContent(messages[0], 0)
		sb.WriteString(imStartTag + "system\n" + strings.TrimSpace(systemContent) + imEndTag + "\n")
	}

	// Walk backwards to find the last "real" user query: user messages that
	// wrap <tool_response> blocks are part of a multi-step tool exchange
	// rather than a new query. Reasoning blocks are only rendered for
	// assistant turns after this index.
	multiStepTool := true
	lastQueryIndex := len(messages) - 1 // so this is the last user message

	for i := len(messages) - 1; i >= 0; i-- {
		message := messages[i]
		if multiStepTool && message.Role == "user" {
			content, _ := r.renderContent(message, 0)
			content = strings.TrimSpace(content)
			if !(strings.HasPrefix(content, "<tool_response>") && strings.HasSuffix(content, "</tool_response>")) {
				multiStepTool = false
				lastQueryIndex = i
			}
		}
	}

	imageOffset := 0
	for i, message := range messages {
		content, nextImageOffset := r.renderContent(message, imageOffset)
		imageOffset = nextImageOffset
		content = strings.TrimSpace(content)

		lastMessage := i == len(messages)-1
		// A trailing assistant message is a prefill: generation continues
		// from it, so its turn is left open.
		prefill := lastMessage && message.Role == "assistant"

		if message.Role == "user" || (message.Role == "system" && i != 0) {
			sb.WriteString(imStartTag + message.Role + "\n" + content + imEndTag + "\n")
		} else if message.Role == "assistant" {
			contentReasoning, content := splitQwen35ReasoningContent(content, message.Thinking, isThinking)

			// Only assistant turns after the last user query keep their
			// <think> block; older reasoning is dropped.
			if isThinking && i > lastQueryIndex {
				sb.WriteString(imStartTag + message.Role + "\n<think>\n" + contentReasoning + "\n</think>\n\n" + content)
			} else {
				sb.WriteString(imStartTag + message.Role + "\n" + content)
			}

			if len(message.ToolCalls) > 0 {
				for j, toolCall := range message.ToolCalls {
					// Separate the first tool call from any preceding
					// content; subsequent calls get a single newline.
					if j == 0 {
						if strings.TrimSpace(content) != "" {
							sb.WriteString("\n\n")
						}
					} else {
						sb.WriteString("\n")
					}

					sb.WriteString("<tool_call>\n<function=" + toolCall.Function.Name + ">\n")
					for name, value := range toolCall.Function.Arguments.All() {
						sb.WriteString("<parameter=" + name + ">\n")
						sb.WriteString(formatToolCallArgument(value))
						sb.WriteString("\n</parameter>\n")
					}
					sb.WriteString("</function>\n</tool_call>")
				}
			}

			if !prefill {
				sb.WriteString(imEndTag + "\n")
			}
		} else if message.Role == "tool" {
			// Consecutive tool responses share one user turn: open it only
			// for the first response and close it only after the last.
			if i == 0 || messages[i-1].Role != "tool" {
				sb.WriteString(imStartTag + "user")
			}
			sb.WriteString("\n<tool_response>\n" + content + "\n</tool_response>")
			if i == len(messages)-1 || messages[i+1].Role != "tool" {
				sb.WriteString(imEndTag + "\n")
			}
		}

		// prefill at the end
		if lastMessage && !prefill {
			sb.WriteString(imStartTag + "assistant\n")
			if isThinking {
				sb.WriteString("<think>\n")
			} else if r.emitEmptyThinkOnNoThink {
				sb.WriteString("<think>\n\n</think>\n\n")
			}
		}
	}

	return sb.String(), nil
}
|
||||||
389
model/renderers/qwen35_test.go
Normal file
389
model/renderers/qwen35_test.go
Normal file
@@ -0,0 +1,389 @@
|
|||||||
|
package renderers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/api"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestQwen35RendererUsesXMLToolCallingFormat verifies that tool definitions
// appear in a <tools> section, that the instructions advertise the
// XML-style <function=...> call syntax, that a historical tool call is
// rendered as an XML payload, and that the system prompt is appended after
// the tool instructions.
func TestQwen35RendererUsesXMLToolCallingFormat(t *testing.T) {
	renderer := &Qwen35Renderer{isThinking: true}
	msgs := []api.Message{
		{Role: "system", Content: "You are a helpful assistant."},
		{Role: "user", Content: "What's the weather in Paris?"},
		{
			Role:    "assistant",
			Content: "I'll check.",
			ToolCalls: []api.ToolCall{
				{
					Function: api.ToolCallFunction{
						Name: "get_weather",
						Arguments: testArgsOrdered([]orderedArg{
							{Key: "location", Value: "Paris"},
						}),
					},
				},
			},
		},
		{Role: "tool", Content: "22C"},
		{Role: "user", Content: "Thanks"},
	}
	tools := []api.Tool{
		{
			Type: "function",
			Function: api.ToolFunction{
				Name: "get_weather",
				Parameters: api.ToolFunctionParameters{
					Type: "object",
					Properties: testPropsOrdered([]orderedProp{
						{
							Key: "location",
							Value: api.ToolProperty{
								Type: api.PropertyType{"string"},
							},
						},
					}),
					Required: []string{"location"},
				},
			},
		},
	}

	got, err := renderer.Render(msgs, tools, nil)
	if err != nil {
		t.Fatalf("render failed: %v", err)
	}

	if !strings.Contains(got, "<tools>") {
		t.Fatalf("expected tools section in prompt, got:\n%s", got)
	}
	// The postamble's usage example uses a placeholder function name.
	if !strings.Contains(got, "<function=example_function_name>") {
		t.Fatalf("expected xml-style tool call instructions, got:\n%s", got)
	}

	wantToolCall := "<tool_call>\n<function=get_weather>\n<parameter=location>\nParis\n</parameter>\n</function>\n</tool_call>"
	if !strings.Contains(got, wantToolCall) {
		t.Fatalf("expected xml tool call payload, got:\n%s", got)
	}

	// The system prompt must come after the "# Tools" header in the prompt.
	toolsIdx := strings.Index(got, "# Tools")
	systemIdx := strings.Index(got, "You are a helpful assistant.")
	if toolsIdx == -1 || systemIdx == -1 || systemIdx < toolsIdx {
		t.Fatalf("expected system prompt appended after tool instructions, got:\n%s", got)
	}
}
|
||||||
|
|
||||||
|
func TestQwen35RendererNoThinkPrefill(t *testing.T) {
|
||||||
|
renderer := &Qwen35Renderer{isThinking: true, emitEmptyThinkOnNoThink: true}
|
||||||
|
msgs := []api.Message{
|
||||||
|
{Role: "user", Content: "hello"},
|
||||||
|
}
|
||||||
|
|
||||||
|
got, err := renderer.Render(msgs, nil, &api.ThinkValue{Value: false})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("render failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !strings.HasSuffix(got, "<|im_start|>assistant\n<think>\n\n</think>\n\n") {
|
||||||
|
t.Fatalf("expected explicit no-think prefill, got:\n%s", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestQwen35RendererBackToBackToolCallsAndResponses verifies that multiple
// tool calls in one assistant turn render as consecutive <tool_call>
// blocks, that consecutive tool results are grouped into a single user
// turn, that reasoning from before the final user query is dropped, and
// that the prompt ends with a thinking prefill.
func TestQwen35RendererBackToBackToolCallsAndResponses(t *testing.T) {
	renderer := &Qwen35Renderer{isThinking: true}

	msgs := []api.Message{
		{Role: "system", Content: "You are a helpful assistant."},
		{Role: "user", Content: "Run add and multiply."},
		{
			Role:     "assistant",
			Content:  "I'll run both now.",
			Thinking: "Need to call add and multiply.",
			ToolCalls: []api.ToolCall{
				{
					Function: api.ToolCallFunction{
						Name: "add",
						Arguments: testArgsOrdered([]orderedArg{
							{Key: "a", Value: 2},
							{Key: "b", Value: 3},
						}),
					},
				},
				{
					Function: api.ToolCallFunction{
						Name: "multiply",
						Arguments: testArgsOrdered([]orderedArg{
							{Key: "x", Value: 4},
							{Key: "y", Value: 5},
						}),
					},
				},
			},
		},
		{Role: "tool", Content: "5"},
		{Role: "tool", Content: "20"},
		{Role: "user", Content: "Summarize the results."},
	}

	got, err := renderer.Render(msgs, qwen35MathTools(), nil)
	if err != nil {
		t.Fatalf("render failed: %v", err)
	}

	// The assistant turn precedes the final user query, so its reasoning
	// must not be rendered.
	if strings.Contains(got, "Need to call add and multiply.") {
		t.Fatalf("did not expect historical reasoning block in this sequence, got:\n%s", got)
	}

	wantToolCalls := `<tool_call>
<function=add>
<parameter=a>
2
</parameter>
<parameter=b>
3
</parameter>
</function>
</tool_call>
<tool_call>
<function=multiply>
<parameter=x>
4
</parameter>
<parameter=y>
5
</parameter>
</function>
</tool_call>`
	if !strings.Contains(got, wantToolCalls) {
		t.Fatalf("expected back-to-back tool calls, got:\n%s", got)
	}

	// Both tool results share one user turn, opened once and closed once.
	wantToolResponses := `<|im_start|>user
<tool_response>
5
</tool_response>
<tool_response>
20
</tool_response><|im_end|>`
	if !strings.Contains(got, wantToolResponses) {
		t.Fatalf("expected grouped back-to-back tool responses, got:\n%s", got)
	}

	if !strings.HasSuffix(got, "<|im_start|>assistant\n<think>\n") {
		t.Fatalf("expected assistant thinking prefill at end, got:\n%s", got)
	}
}
|
||||||
|
|
||||||
|
// TestQwen35RendererInterleavedThinkingAndTools verifies that when all
// assistant turns are part of a multi-step tool exchange (the trailing user
// messages are tool responses), each assistant turn keeps its <think> block
// alongside its tool call, and the prompt still ends with a thinking
// prefill.
func TestQwen35RendererInterleavedThinkingAndTools(t *testing.T) {
	renderer := &Qwen35Renderer{isThinking: true}

	msgs := []api.Message{
		{Role: "system", Content: "You are a helpful assistant."},
		{Role: "user", Content: "Plan a picnic in Paris."},
		{
			Role:     "assistant",
			Content:  "Checking weather first.",
			Thinking: "Need weather before giving advice.",
			ToolCalls: []api.ToolCall{
				{
					Function: api.ToolCallFunction{
						Name: "get_weather",
						Arguments: testArgsOrdered([]orderedArg{
							{Key: "location", Value: "Paris"},
						}),
					},
				},
			},
		},
		{Role: "tool", Content: "22C"},
		{
			Role:     "assistant",
			Content:  "Checking UV too.",
			Thinking: "Need UV index for sunscreen advice.",
			ToolCalls: []api.ToolCall{
				{
					Function: api.ToolCallFunction{
						Name: "get_uv",
						Arguments: testArgsOrdered([]orderedArg{
							{Key: "location", Value: "Paris"},
						}),
					},
				},
			},
		},
		{Role: "tool", Content: "5"},
	}

	got, err := renderer.Render(msgs, qwen35WeatherUVTools(), nil)
	if err != nil {
		t.Fatalf("render failed: %v", err)
	}

	wantFirstTurn := `<|im_start|>assistant
<think>
Need weather before giving advice.
</think>

Checking weather first.

<tool_call>
<function=get_weather>
<parameter=location>
Paris
</parameter>
</function>
</tool_call><|im_end|>`
	if !strings.Contains(got, wantFirstTurn) {
		t.Fatalf("expected first assistant thinking/tool sequence, got:\n%s", got)
	}

	wantSecondTurn := `<|im_start|>assistant
<think>
Need UV index for sunscreen advice.
</think>

Checking UV too.

<tool_call>
<function=get_uv>
<parameter=location>
Paris
</parameter>
</function>
</tool_call><|im_end|>`
	if !strings.Contains(got, wantSecondTurn) {
		t.Fatalf("expected second assistant thinking/tool sequence, got:\n%s", got)
	}

	if !strings.HasSuffix(got, "<|im_start|>assistant\n<think>\n") {
		t.Fatalf("expected assistant thinking prefill at end, got:\n%s", got)
	}
}
|
||||||
|
|
||||||
|
func TestQwen35RendererAssistantPrefillWithThinking(t *testing.T) {
|
||||||
|
renderer := &Qwen35Renderer{isThinking: true}
|
||||||
|
msgs := []api.Message{
|
||||||
|
{Role: "user", Content: "Write two words."},
|
||||||
|
{
|
||||||
|
Role: "assistant",
|
||||||
|
Thinking: "Keep it short.",
|
||||||
|
Content: "Hello world",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
got, err := renderer.Render(msgs, nil, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("render failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
want := `<|im_start|>user
|
||||||
|
Write two words.<|im_end|>
|
||||||
|
<|im_start|>assistant
|
||||||
|
<think>
|
||||||
|
Keep it short.
|
||||||
|
</think>
|
||||||
|
|
||||||
|
Hello world`
|
||||||
|
if got != want {
|
||||||
|
t.Fatalf("unexpected prefill output\n--- got ---\n%s\n--- want ---\n%s", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// qwen35MathTools returns two fixture tool definitions ("add" and
// "multiply", each taking two required integer parameters) used by the
// multi-tool renderer tests.
func qwen35MathTools() []api.Tool {
	return []api.Tool{
		{
			Type: "function",
			Function: api.ToolFunction{
				Name:        "add",
				Description: "Add two numbers",
				Parameters: api.ToolFunctionParameters{
					Type: "object",
					Properties: testPropsOrdered([]orderedProp{
						{
							Key: "a",
							Value: api.ToolProperty{
								Type: api.PropertyType{"integer"},
							},
						},
						{
							Key: "b",
							Value: api.ToolProperty{
								Type: api.PropertyType{"integer"},
							},
						},
					}),
					Required: []string{"a", "b"},
				},
			},
		},
		{
			Type: "function",
			Function: api.ToolFunction{
				Name:        "multiply",
				Description: "Multiply two numbers",
				Parameters: api.ToolFunctionParameters{
					Type: "object",
					Properties: testPropsOrdered([]orderedProp{
						{
							Key: "x",
							Value: api.ToolProperty{
								Type: api.PropertyType{"integer"},
							},
						},
						{
							Key: "y",
							Value: api.ToolProperty{
								Type: api.PropertyType{"integer"},
							},
						},
					}),
					Required: []string{"x", "y"},
				},
			},
		},
	}
}
|
||||||
|
|
||||||
|
// qwen35WeatherUVTools returns two fixture tool definitions ("get_weather"
// and "get_uv", each taking a required string "location" parameter) used by
// the interleaved thinking/tool renderer tests.
func qwen35WeatherUVTools() []api.Tool {
	return []api.Tool{
		{
			Type: "function",
			Function: api.ToolFunction{
				Name:        "get_weather",
				Description: "Get weather for a location",
				Parameters: api.ToolFunctionParameters{
					Type: "object",
					Properties: testPropsOrdered([]orderedProp{
						{
							Key: "location",
							Value: api.ToolProperty{
								Type: api.PropertyType{"string"},
							},
						},
					}),
					Required: []string{"location"},
				},
			},
		},
		{
			Type: "function",
			Function: api.ToolFunction{
				Name:        "get_uv",
				Description: "Get UV index for a location",
				Parameters: api.ToolFunctionParameters{
					Type: "object",
					Properties: testPropsOrdered([]orderedProp{
						{
							Key: "location",
							Value: api.ToolProperty{
								Type: api.PropertyType{"string"},
							},
						},
					}),
					Required: []string{"location"},
				},
			},
		},
	}
}
|
||||||
@@ -57,7 +57,7 @@ func rendererForName(name string) Renderer {
|
|||||||
renderer := &Qwen3VLRenderer{isThinking: true, useImgTags: RenderImgTags}
|
renderer := &Qwen3VLRenderer{isThinking: true, useImgTags: RenderImgTags}
|
||||||
return renderer
|
return renderer
|
||||||
case "qwen3.5":
|
case "qwen3.5":
|
||||||
renderer := &Qwen3VLRenderer{isThinking: true, emitEmptyThinkOnNoThink: true, useImgTags: RenderImgTags}
|
renderer := &Qwen35Renderer{isThinking: true, emitEmptyThinkOnNoThink: true, useImgTags: RenderImgTags}
|
||||||
return renderer
|
return renderer
|
||||||
case "cogito":
|
case "cogito":
|
||||||
renderer := &CogitoRenderer{isThinking: true}
|
renderer := &CogitoRenderer{isThinking: true}
|
||||||
@@ -86,7 +86,7 @@ func rendererForName(name string) Renderer {
|
|||||||
case "glm-4.7":
|
case "glm-4.7":
|
||||||
return &GLM47Renderer{}
|
return &GLM47Renderer{}
|
||||||
case "glm-ocr":
|
case "glm-ocr":
|
||||||
return &GlmOcrRenderer{}
|
return &GlmOcrRenderer{useImgTags: RenderImgTags}
|
||||||
case "lfm2":
|
case "lfm2":
|
||||||
return &LFM2Renderer{IsThinking: false, useImgTags: RenderImgTags}
|
return &LFM2Renderer{IsThinking: false, useImgTags: RenderImgTags}
|
||||||
case "lfm2-thinking":
|
case "lfm2-thinking":
|
||||||
|
|||||||
@@ -181,6 +181,9 @@ func fileDigestMap(path string) (map[string]string, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if !filepath.IsLocal(rel) {
|
if !filepath.IsLocal(rel) {
|
||||||
|
if strings.Contains(rel, ".cache") {
|
||||||
|
return nil, fmt.Errorf("insecure path: %s\n\nUse --local-dir <dir> when downloading model to disable caching", rel)
|
||||||
|
}
|
||||||
return nil, fmt.Errorf("insecure path: %s", rel)
|
return nil, fmt.Errorf("insecure path: %s", rel)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -562,6 +562,7 @@ func (s *Server) forwardBatch(pendingBatch batchState) (nextBatch batchState, er
|
|||||||
if errors.As(err, &reprocess) {
|
if errors.As(err, &reprocess) {
|
||||||
// Prepend these inputs to the sequence's inputs queue for reprocessing
|
// Prepend these inputs to the sequence's inputs queue for reprocessing
|
||||||
seq.inputs = append(reprocess.Inputs, seq.inputs...)
|
seq.inputs = append(reprocess.Inputs, seq.inputs...)
|
||||||
|
seq.sampler.Reset()
|
||||||
// Skip this sequence but continue processing the rest
|
// Skip this sequence but continue processing the rest
|
||||||
nextBatch.seqs[seqIdx] = nil // clear this sequence for this batch
|
nextBatch.seqs[seqIdx] = nil // clear this sequence for this batch
|
||||||
err = nil
|
err = nil
|
||||||
@@ -692,6 +693,12 @@ func (s *Server) computeBatch(activeBatch batchState) {
|
|||||||
// (unless we take down the whole runner).
|
// (unless we take down the whole runner).
|
||||||
if len(seq.pendingInputs) > 0 {
|
if len(seq.pendingInputs) > 0 {
|
||||||
seq.cache.Inputs = append(seq.cache.Inputs, seq.pendingInputs...)
|
seq.cache.Inputs = append(seq.cache.Inputs, seq.pendingInputs...)
|
||||||
|
for _, inp := range seq.pendingInputs {
|
||||||
|
if len(inp.Multimodal) != 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
seq.sampler.Accept(inp.Token)
|
||||||
|
}
|
||||||
seq.pendingInputs = []*input.Input{}
|
seq.pendingInputs = []*input.Input{}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -892,6 +899,9 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
|||||||
req.Options.TopK,
|
req.Options.TopK,
|
||||||
req.Options.TopP,
|
req.Options.TopP,
|
||||||
req.Options.MinP,
|
req.Options.MinP,
|
||||||
|
req.Options.RepeatPenalty,
|
||||||
|
req.Options.PresencePenalty,
|
||||||
|
req.Options.FrequencyPenalty,
|
||||||
req.Options.Seed,
|
req.Options.Seed,
|
||||||
grammar,
|
grammar,
|
||||||
)
|
)
|
||||||
@@ -938,6 +948,14 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
seq.sampler.Reset()
|
||||||
|
for _, inp := range seq.cache.Inputs {
|
||||||
|
if len(inp.Multimodal) != 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
seq.sampler.Accept(inp.Token)
|
||||||
|
}
|
||||||
|
|
||||||
s.seqs[i] = seq
|
s.seqs[i] = seq
|
||||||
s.cond.Signal()
|
s.cond.Signal()
|
||||||
found = true
|
found = true
|
||||||
|
|||||||
@@ -16,24 +16,49 @@ type token struct {
|
|||||||
value float32 // The raw logit or probability from the model
|
value float32 // The raw logit or probability from the model
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const DefaultPenaltyLookback = 64
|
||||||
|
|
||||||
type Sampler struct {
|
type Sampler struct {
|
||||||
rng *rand.Rand
|
rng *rand.Rand
|
||||||
topK int
|
topK int
|
||||||
topP float32
|
topP float32
|
||||||
minP float32
|
minP float32
|
||||||
temperature float32
|
temperature float32
|
||||||
|
repeat float32
|
||||||
|
presence float32
|
||||||
|
frequency float32
|
||||||
|
history []int32
|
||||||
grammar *GrammarSampler
|
grammar *GrammarSampler
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Reset clears the token history used for repetition penalties while
// keeping the backing storage for reuse.
func (s *Sampler) Reset() {
	s.history = s.history[:0]
}
|
||||||
|
|
||||||
|
// Accept records a token into the penalty history, retaining only the most
// recent DefaultPenaltyLookback tokens.
func (s *Sampler) Accept(token int32) {
	s.history = append(s.history, token)
	if len(s.history) > DefaultPenaltyLookback {
		// Shift the newest window to the front in place; copy handles the
		// overlapping source/destination ranges.
		copy(s.history, s.history[len(s.history)-DefaultPenaltyLookback:])
		s.history = s.history[:DefaultPenaltyLookback]
	}
}
|
||||||
|
|
||||||
func (s *Sampler) Sample(logits []float32) (int32, error) {
|
func (s *Sampler) Sample(logits []float32) (int32, error) {
|
||||||
if len(logits) == 0 {
|
if len(logits) == 0 {
|
||||||
return -1, errors.New("sample: no logits provided to sample")
|
return -1, errors.New("sample: no logits provided to sample")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
counts := tokenCounts(s.history, len(logits))
|
||||||
|
|
||||||
tokens := make([]token, len(logits))
|
tokens := make([]token, len(logits))
|
||||||
for i := range logits {
|
for i := range logits {
|
||||||
|
value := logits[i]
|
||||||
|
if count := counts[int32(i)]; count > 0 {
|
||||||
|
value = applyPenalty(value, count, s.repeat, s.presence, s.frequency)
|
||||||
|
}
|
||||||
|
|
||||||
tokens[i].id = int32(i)
|
tokens[i].id = int32(i)
|
||||||
tokens[i].value = logits[i]
|
tokens[i].value = value
|
||||||
}
|
}
|
||||||
|
|
||||||
t, err := s.sample(tokens)
|
t, err := s.sample(tokens)
|
||||||
@@ -55,8 +80,12 @@ func (s *Sampler) Sample(logits []float32) (int32, error) {
|
|||||||
// we need to reset them before applying the grammar and
|
// we need to reset them before applying the grammar and
|
||||||
// sampling again
|
// sampling again
|
||||||
for i := range logits {
|
for i := range logits {
|
||||||
|
value := logits[i]
|
||||||
|
if count := counts[int32(i)]; count > 0 {
|
||||||
|
value = applyPenalty(value, count, s.repeat, s.presence, s.frequency)
|
||||||
|
}
|
||||||
tokens[i].id = int32(i)
|
tokens[i].id = int32(i)
|
||||||
tokens[i].value = logits[i]
|
tokens[i].value = value
|
||||||
}
|
}
|
||||||
s.grammar.Apply(tokens)
|
s.grammar.Apply(tokens)
|
||||||
t, err = s.sample(tokens)
|
t, err = s.sample(tokens)
|
||||||
@@ -127,7 +156,7 @@ func (s *Sampler) sample(tokens []token) (token, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// TODO(parthsareen): update sampler interface to use json unmarshal https://github.com/ollama/ollama/issues/9278
|
// TODO(parthsareen): update sampler interface to use json unmarshal https://github.com/ollama/ollama/issues/9278
|
||||||
func NewSampler(temperature float32, topK int, topP float32, minP float32, seed int, grammar *GrammarSampler) Sampler {
|
func NewSampler(temperature float32, topK int, topP float32, minP float32, repeatPenalty float32, presencePenalty float32, frequencyPenalty float32, seed int, grammar *GrammarSampler) Sampler {
|
||||||
var rng *rand.Rand
|
var rng *rand.Rand
|
||||||
if seed != -1 {
|
if seed != -1 {
|
||||||
// PCG requires two parameters: sequence and stream
|
// PCG requires two parameters: sequence and stream
|
||||||
@@ -154,12 +183,19 @@ func NewSampler(temperature float32, topK int, topP float32, minP float32, seed
|
|||||||
minP = 1.0
|
minP = 1.0
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if repeatPenalty <= 0 {
|
||||||
|
repeatPenalty = 1.0
|
||||||
|
}
|
||||||
|
|
||||||
return Sampler{
|
return Sampler{
|
||||||
rng: rng,
|
rng: rng,
|
||||||
topK: topK,
|
topK: topK,
|
||||||
topP: topP,
|
topP: topP,
|
||||||
minP: minP,
|
minP: minP,
|
||||||
temperature: temperature,
|
temperature: temperature,
|
||||||
|
repeat: repeatPenalty,
|
||||||
|
presence: presencePenalty,
|
||||||
|
frequency: frequencyPenalty,
|
||||||
grammar: grammar,
|
grammar: grammar,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ func BenchmarkWeightedSampler(b *testing.B) {
|
|||||||
logits[i] = float32(rand.Float64()*10 - 5)
|
logits[i] = float32(rand.Float64()*10 - 5)
|
||||||
}
|
}
|
||||||
|
|
||||||
sampler := NewSampler(0.8, 0, 0, 0, 42, nil)
|
sampler := NewSampler(0.8, 0, 0, 0, 1, 0, 0, 42, nil)
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for b.Loop() {
|
for b.Loop() {
|
||||||
sampler.Sample(logits)
|
sampler.Sample(logits)
|
||||||
@@ -49,7 +49,7 @@ func BenchmarkWeightedSampler(b *testing.B) {
|
|||||||
|
|
||||||
for _, tc := range configs {
|
for _, tc := range configs {
|
||||||
b.Run("Config"+tc.name, func(b *testing.B) {
|
b.Run("Config"+tc.name, func(b *testing.B) {
|
||||||
sampler := NewSampler(tc.temperature, tc.topK, tc.topP, tc.minP, tc.seed, nil)
|
sampler := NewSampler(tc.temperature, tc.topK, tc.topP, tc.minP, 1, 0, 0, tc.seed, nil)
|
||||||
sampler.Sample(logits)
|
sampler.Sample(logits)
|
||||||
|
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
@@ -62,7 +62,7 @@ func BenchmarkWeightedSampler(b *testing.B) {
|
|||||||
|
|
||||||
// Test with combined transforms separately - topK influences performance greatly
|
// Test with combined transforms separately - topK influences performance greatly
|
||||||
b.Run("TransformCombined", func(b *testing.B) {
|
b.Run("TransformCombined", func(b *testing.B) {
|
||||||
sampler := NewSampler(0.8, 50, 0.9, 0.05, 42, nil)
|
sampler := NewSampler(0.8, 50, 0.9, 0.05, 1, 0, 0, 42, nil)
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
|
|
||||||
for b.Loop() {
|
for b.Loop() {
|
||||||
@@ -81,7 +81,7 @@ func BenchmarkGreedySampler(b *testing.B) {
|
|||||||
logits[i] = float32(rand.Float64()*10 - 5)
|
logits[i] = float32(rand.Float64()*10 - 5)
|
||||||
}
|
}
|
||||||
|
|
||||||
sampler := NewSampler(0, -1, 0, 0, -1, nil)
|
sampler := NewSampler(0, -1, 0, 0, 1, 0, 0, -1, nil)
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
|
|
||||||
for b.Loop() {
|
for b.Loop() {
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ import (
|
|||||||
|
|
||||||
func TestWeighted(t *testing.T) {
|
func TestWeighted(t *testing.T) {
|
||||||
logits := []float32{-10, 3, -10, -10}
|
logits := []float32{-10, 3, -10, -10}
|
||||||
sampler := NewSampler(0, 0, 0, 0, 0, nil)
|
sampler := NewSampler(0, 0, 0, 0, 1, 0, 0, 0, nil)
|
||||||
got, err := sampler.Sample(logits)
|
got, err := sampler.Sample(logits)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
@@ -25,7 +25,7 @@ func TestWeighted(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
logits = []float32{-100, -10, 0, 10}
|
logits = []float32{-100, -10, 0, 10}
|
||||||
sampler = NewSampler(0, 0, 0, 0, 0, nil)
|
sampler = NewSampler(0, 0, 0, 0, 1, 0, 0, 0, nil)
|
||||||
got, err = sampler.Sample(logits)
|
got, err = sampler.Sample(logits)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
@@ -39,7 +39,7 @@ func TestWeighted(t *testing.T) {
|
|||||||
// Test very high p
|
// Test very high p
|
||||||
logits = []float32{1.0, 0.9999999999999999, 0.5, 0.1}
|
logits = []float32{1.0, 0.9999999999999999, 0.5, 0.1}
|
||||||
// Use extremely small topP to filter out all tokens
|
// Use extremely small topP to filter out all tokens
|
||||||
sampler = NewSampler(1.0, 0, 1e-10, 0, 0, nil)
|
sampler = NewSampler(1.0, 0, 1e-10, 0, 1, 0, 0, 0, nil)
|
||||||
got, err = sampler.Sample(logits)
|
got, err = sampler.Sample(logits)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
@@ -52,7 +52,7 @@ func TestWeighted(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
logits = []float32{float32(math.NaN()), float32(math.NaN()), float32(math.NaN())}
|
logits = []float32{float32(math.NaN()), float32(math.NaN()), float32(math.NaN())}
|
||||||
sampler = NewSampler(1, 0, 0.95, 0.05, 0, nil)
|
sampler = NewSampler(1, 0, 0.95, 0.05, 1, 0, 0, 0, nil)
|
||||||
got, err = sampler.Sample(logits)
|
got, err = sampler.Sample(logits)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Errorf("expected error, got %d", got)
|
t.Errorf("expected error, got %d", got)
|
||||||
@@ -151,8 +151,8 @@ func TestGrammar(t *testing.T) {
|
|||||||
|
|
||||||
func BenchmarkSample(b *testing.B) {
|
func BenchmarkSample(b *testing.B) {
|
||||||
samplers := map[string]Sampler{
|
samplers := map[string]Sampler{
|
||||||
"Greedy": NewSampler(0, 0, 0, 0, 0, nil), // Use NewSampler with temp=0 for greedy
|
"Greedy": NewSampler(0, 0, 0, 0, 1, 0, 0, 0, nil), // Use NewSampler with temp=0 for greedy
|
||||||
"Weighted": NewSampler(0.5, 10, 0.9, 0.2, -1, nil),
|
"Weighted": NewSampler(0.5, 10, 0.9, 0.2, 1, 0, 0, -1, nil),
|
||||||
}
|
}
|
||||||
|
|
||||||
// Generate random logits for benchmarking
|
// Generate random logits for benchmarking
|
||||||
|
|||||||
@@ -25,6 +25,48 @@ func (h *tokenHeap) Pop() any {
|
|||||||
return x
|
return x
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func tokenCounts(history []int32, vocabSize int) map[int32]int {
|
||||||
|
if len(history) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
start := 0
|
||||||
|
if len(history) > DefaultPenaltyLookback {
|
||||||
|
start = len(history) - DefaultPenaltyLookback
|
||||||
|
}
|
||||||
|
|
||||||
|
counts := make(map[int32]int, len(history)-start)
|
||||||
|
for _, token := range history[start:] {
|
||||||
|
if token < 0 || int(token) >= vocabSize {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
counts[token]++
|
||||||
|
}
|
||||||
|
|
||||||
|
return counts
|
||||||
|
}
|
||||||
|
|
||||||
|
// applyPenalty adjusts one logit for a token that already appeared `count`
// times in the recent history, combining a multiplicative repeat penalty
// with OpenAI-style frequency (per-occurrence) and presence (flat)
// penalties.
func applyPenalty(logit float32, count int, repeatPenalty float32, presencePenalty float32, frequencyPenalty float32) float32 {
	// The repeat penalty divides positive logits and multiplies negative
	// ones so the relative ordering of penalized tokens is preserved.
	switch {
	case repeatPenalty == 1.0:
		// neutral value: leave the logit untouched
	case logit < 0:
		logit *= repeatPenalty
	default:
		logit /= repeatPenalty
	}

	if frequencyPenalty != 0 {
		logit -= float32(count) * frequencyPenalty
	}
	if presencePenalty != 0 {
		logit -= presencePenalty
	}
	return logit
}
|
||||||
|
|
||||||
// temperature applies scaling to the logits
|
// temperature applies scaling to the logits
|
||||||
func temperature(ts []token, temp float32) {
|
func temperature(ts []token, temp float32) {
|
||||||
// Ensure temperature clipping near 0 to avoid numerical instability
|
// Ensure temperature clipping near 0 to avoid numerical instability
|
||||||
|
|||||||
@@ -295,6 +295,86 @@ func TestMinP(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestTokenCounts(t *testing.T) {
|
||||||
|
history := make([]int32, 70)
|
||||||
|
history[0] = 7
|
||||||
|
history[69] = 7
|
||||||
|
|
||||||
|
counts := tokenCounts(history, 8)
|
||||||
|
if got := counts[7]; got != 1 {
|
||||||
|
t.Fatalf("lookback mismatch: got %d want %d", got, 1)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestApplyPenalty(t *testing.T) {
|
||||||
|
logit := applyPenalty(5.0, 3, 1.0, 1.5, 0.5)
|
||||||
|
if math.Abs(float64(logit-2.0)) > 1e-6 {
|
||||||
|
t.Fatalf("unexpected penalty result: got %f want %f", logit, 2.0)
|
||||||
|
}
|
||||||
|
|
||||||
|
logit = applyPenalty(4.0, 1, 2.0, 0, 0)
|
||||||
|
if math.Abs(float64(logit-2.0)) > 1e-6 {
|
||||||
|
t.Fatalf("unexpected repeat penalty result for positive logits: got %f want %f", logit, 2.0)
|
||||||
|
}
|
||||||
|
|
||||||
|
logit = applyPenalty(-4.0, 1, 2.0, 0, 0)
|
||||||
|
if math.Abs(float64(logit-(-8.0))) > 1e-6 {
|
||||||
|
t.Fatalf("unexpected repeat penalty result for negative logits: got %f want %f", logit, -8.0)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSamplerPresencePenalty(t *testing.T) {
|
||||||
|
logits := []float32{0.0, 5.0, 0.0}
|
||||||
|
|
||||||
|
baseline := NewSampler(0, 0, 1, 0, 1, 0, 0, -1, nil)
|
||||||
|
baseline.Accept(1)
|
||||||
|
got, err := baseline.Sample(logits)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if got != 1 {
|
||||||
|
t.Fatalf("unexpected baseline token: got %d want %d", got, 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
presence := NewSampler(0, 0, 1, 0, 1, 6, 0, -1, nil)
|
||||||
|
presence.Accept(1)
|
||||||
|
got, err = presence.Sample(logits)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if got == 1 {
|
||||||
|
t.Fatalf("presence penalty did not change repeated token selection")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSamplerFrequencyPenalty(t *testing.T) {
|
||||||
|
logits := []float32{0.0, 5.0, 4.0}
|
||||||
|
|
||||||
|
baseline := NewSampler(0, 0, 1, 0, 1, 0, 0, -1, nil)
|
||||||
|
baseline.Accept(1)
|
||||||
|
baseline.Accept(1)
|
||||||
|
baseline.Accept(1)
|
||||||
|
got, err := baseline.Sample(logits)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if got != 1 {
|
||||||
|
t.Fatalf("unexpected baseline token: got %d want %d", got, 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
frequency := NewSampler(0, 0, 1, 0, 1, 0, 1.0, -1, nil)
|
||||||
|
frequency.Accept(1)
|
||||||
|
frequency.Accept(1)
|
||||||
|
frequency.Accept(1)
|
||||||
|
got, err = frequency.Sample(logits)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if got != 2 {
|
||||||
|
t.Fatalf("frequency penalty did not demote repeated token as expected: got %d want %d", got, 2)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func BenchmarkTransforms(b *testing.B) {
|
func BenchmarkTransforms(b *testing.B) {
|
||||||
// Generate random logits
|
// Generate random logits
|
||||||
tokens := make([]token, 1<<16)
|
tokens := make([]token, 1<<16)
|
||||||
|
|||||||
@@ -59,7 +59,7 @@ _build_darwin() {
|
|||||||
cmake --install $BUILD_DIR --component CPU
|
cmake --install $BUILD_DIR --component CPU
|
||||||
cmake --install $BUILD_DIR --component MLX
|
cmake --install $BUILD_DIR --component MLX
|
||||||
# Override CGO flags to point to the amd64 build directory
|
# Override CGO flags to point to the amd64 build directory
|
||||||
MLX_CGO_CFLAGS="-O3 -I$(pwd)/$BUILD_DIR/_deps/mlx-c-src -mmacosx-version-min=14.0"
|
MLX_CGO_CFLAGS="-O3 -mmacosx-version-min=14.0"
|
||||||
MLX_CGO_LDFLAGS="-ldl -lc++ -framework Accelerate -mmacosx-version-min=14.0"
|
MLX_CGO_LDFLAGS="-ldl -lc++ -framework Accelerate -mmacosx-version-min=14.0"
|
||||||
else
|
else
|
||||||
BUILD_DIR=build
|
BUILD_DIR=build
|
||||||
@@ -70,10 +70,10 @@ _build_darwin() {
|
|||||||
cmake --build --preset MLX --parallel
|
cmake --build --preset MLX --parallel
|
||||||
cmake --install $BUILD_DIR --component MLX
|
cmake --install $BUILD_DIR --component MLX
|
||||||
# Use default CGO flags from mlx.go for arm64
|
# Use default CGO flags from mlx.go for arm64
|
||||||
MLX_CGO_CFLAGS="-O3 -I$(pwd)/$BUILD_DIR/_deps/mlx-c-src -mmacosx-version-min=14.0"
|
MLX_CGO_CFLAGS="-O3 -mmacosx-version-min=14.0"
|
||||||
MLX_CGO_LDFLAGS="-lc++ -framework Metal -framework Foundation -framework Accelerate -mmacosx-version-min=14.0"
|
MLX_CGO_LDFLAGS="-lc++ -framework Metal -framework Foundation -framework Accelerate -mmacosx-version-min=14.0"
|
||||||
fi
|
fi
|
||||||
GOOS=darwin GOARCH=$ARCH CGO_ENABLED=1 CGO_CFLAGS="$MLX_CGO_CFLAGS" CGO_LDFLAGS="$MLX_CGO_LDFLAGS" go build -tags mlx -o $INSTALL_PREFIX .
|
GOOS=darwin GOARCH=$ARCH CGO_ENABLED=1 CGO_CFLAGS="$MLX_CGO_CFLAGS" CGO_LDFLAGS="$MLX_CGO_LDFLAGS" go build -o $INSTALL_PREFIX .
|
||||||
# Copy MLX libraries to same directory as executable for dlopen
|
# Copy MLX libraries to same directory as executable for dlopen
|
||||||
cp $INSTALL_PREFIX/lib/ollama/libmlxc.dylib $INSTALL_PREFIX/
|
cp $INSTALL_PREFIX/lib/ollama/libmlxc.dylib $INSTALL_PREFIX/
|
||||||
cp $INSTALL_PREFIX/lib/ollama/libmlx.dylib $INSTALL_PREFIX/
|
cp $INSTALL_PREFIX/lib/ollama/libmlx.dylib $INSTALL_PREFIX/
|
||||||
|
|||||||
@@ -4,7 +4,10 @@
|
|||||||
#
|
#
|
||||||
# gcloud auth application-default login
|
# gcloud auth application-default login
|
||||||
|
|
||||||
$ErrorActionPreference = "Stop"
|
# Use "Continue" so that stderr output from native commands (e.g. CGo warnings)
|
||||||
|
# is not promoted to a terminating exception by the try/catch block.
|
||||||
|
# All native commands already check $LASTEXITCODE explicitly.
|
||||||
|
$ErrorActionPreference = "Continue"
|
||||||
|
|
||||||
mkdir -Force -path .\dist | Out-Null
|
mkdir -Force -path .\dist | Out-Null
|
||||||
|
|
||||||
@@ -16,13 +19,13 @@ function checkEnv {
|
|||||||
if ($null -ne $arch) {
|
if ($null -ne $arch) {
|
||||||
$script:ARCH = ($arch.ToString().ToLower()).Replace("x64", "amd64")
|
$script:ARCH = ($arch.ToString().ToLower()).Replace("x64", "amd64")
|
||||||
} else {
|
} else {
|
||||||
write-host "WARNING: old powershell detected, assuming amd64 architecture - set `$env:ARCH to override"
|
Write-Output "WARNING: old powershell detected, assuming amd64 architecture - set `$env:ARCH to override"
|
||||||
$script:ARCH="amd64"
|
$script:ARCH="amd64"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
$script:TARGET_ARCH=$script:ARCH
|
$script:TARGET_ARCH=$script:ARCH
|
||||||
Write-host "Building for ${script:TARGET_ARCH}"
|
Write-host "Building for ${script:TARGET_ARCH}"
|
||||||
write-host "Locating required tools and paths"
|
Write-Output "Locating required tools and paths"
|
||||||
$script:SRC_DIR=$PWD
|
$script:SRC_DIR=$PWD
|
||||||
|
|
||||||
# Locate CUDA versions
|
# Locate CUDA versions
|
||||||
@@ -37,16 +40,17 @@ function checkEnv {
|
|||||||
$script:CUDA_DIRS=($cudaList | sort-object -Descending)
|
$script:CUDA_DIRS=($cudaList | sort-object -Descending)
|
||||||
}
|
}
|
||||||
if ($script:CUDA_DIRS.length -gt 0) {
|
if ($script:CUDA_DIRS.length -gt 0) {
|
||||||
write-host "Available CUDA Versions: $script:CUDA_DIRS"
|
Write-Output "Available CUDA Versions: $script:CUDA_DIRS"
|
||||||
} else {
|
} else {
|
||||||
write-host "No CUDA versions detected"
|
Write-Output "No CUDA versions detected"
|
||||||
}
|
}
|
||||||
|
|
||||||
# Locate ROCm version
|
# Locate ROCm v6
|
||||||
if ($null -ne $env:HIP_PATH) {
|
$rocmDir=(get-item "C:\Program Files\AMD\ROCm\6.*" -ea 'silentlycontinue' | sort-object -Descending | select-object -First 1)
|
||||||
|
if ($null -ne $rocmDir) {
|
||||||
|
$script:HIP_PATH=$rocmDir.FullName
|
||||||
|
} elseif ($null -ne $env:HIP_PATH -and $env:HIP_PATH -match '[/\\]6\.') {
|
||||||
$script:HIP_PATH=$env:HIP_PATH
|
$script:HIP_PATH=$env:HIP_PATH
|
||||||
} else {
|
|
||||||
$script:HIP_PATH=(get-item "C:\Program Files\AMD\ROCm\*\bin\" -ea 'silentlycontinue' | sort-object -Descending)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
$inoSetup=(get-item "C:\Program Files*\Inno Setup*\")
|
$inoSetup=(get-item "C:\Program Files*\Inno Setup*\")
|
||||||
@@ -78,7 +82,7 @@ function checkEnv {
|
|||||||
} else {
|
} else {
|
||||||
$script:PKG_VERSION="0.0.0"
|
$script:PKG_VERSION="0.0.0"
|
||||||
}
|
}
|
||||||
write-host "Building Ollama $script:VERSION with package version $script:PKG_VERSION"
|
Write-Output "Building Ollama $script:VERSION with package version $script:PKG_VERSION"
|
||||||
|
|
||||||
# Note: Windows Kits 10 signtool crashes with GCP's plugin
|
# Note: Windows Kits 10 signtool crashes with GCP's plugin
|
||||||
if ($null -eq $env:SIGN_TOOL) {
|
if ($null -eq $env:SIGN_TOOL) {
|
||||||
@@ -87,12 +91,32 @@ function checkEnv {
|
|||||||
${script:SignTool}=${env:SIGN_TOOL}
|
${script:SignTool}=${env:SIGN_TOOL}
|
||||||
}
|
}
|
||||||
if ("${env:KEY_CONTAINER}") {
|
if ("${env:KEY_CONTAINER}") {
|
||||||
|
if (Test-Path "${script:SRC_DIR}\ollama_inc.crt") {
|
||||||
${script:OLLAMA_CERT}=$(resolve-path "${script:SRC_DIR}\ollama_inc.crt")
|
${script:OLLAMA_CERT}=$(resolve-path "${script:SRC_DIR}\ollama_inc.crt")
|
||||||
Write-host "Code signing enabled"
|
Write-host "Code signing enabled"
|
||||||
} else {
|
} else {
|
||||||
write-host "Code signing disabled - please set KEY_CONTAINERS to sign and copy ollama_inc.crt to the top of the source tree"
|
Write-Output "WARNING: KEY_CONTAINER is set but ollama_inc.crt not found at ${script:SRC_DIR}\ollama_inc.crt - code signing disabled"
|
||||||
}
|
}
|
||||||
$script:JOBS=([Environment]::ProcessorCount)
|
} else {
|
||||||
|
Write-Output "Code signing disabled - please set KEY_CONTAINERS to sign and copy ollama_inc.crt to the top of the source tree"
|
||||||
|
}
|
||||||
|
if ($env:OLLAMA_BUILD_PARALLEL) {
|
||||||
|
$script:JOBS=[int]$env:OLLAMA_BUILD_PARALLEL
|
||||||
|
} else {
|
||||||
|
# Use physical core count rather than logical processors (hyperthreads)
|
||||||
|
# to avoid saturating the system during builds
|
||||||
|
try {
|
||||||
|
$cores = (Get-CimInstance Win32_Processor | Measure-Object -Property NumberOfCores -Sum).Sum
|
||||||
|
} catch {
|
||||||
|
$cores = 0
|
||||||
|
}
|
||||||
|
if ($cores -gt 0) {
|
||||||
|
$script:JOBS = $cores
|
||||||
|
} else {
|
||||||
|
$script:JOBS = [Environment]::ProcessorCount
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Write-Output "Build parallelism: $script:JOBS (set OLLAMA_BUILD_PARALLEL to override)"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -127,7 +151,7 @@ function cuda11 {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
write-host "Building CUDA v$cudaMajorVer backend libraries $cuda"
|
Write-Output "Building CUDA v$cudaMajorVer backend libraries $cuda"
|
||||||
$env:CUDAToolkit_ROOT=$cuda
|
$env:CUDAToolkit_ROOT=$cuda
|
||||||
& cmake -B build\cuda_v$cudaMajorVer --preset "CUDA $cudaMajorVer" -T cuda="$cuda" -DCMAKE_CUDA_COMPILER="$cuda\bin\nvcc.exe" -G "Visual Studio 16 2019" --install-prefix "$script:DIST_DIR"
|
& cmake -B build\cuda_v$cudaMajorVer --preset "CUDA $cudaMajorVer" -T cuda="$cuda" -DCMAKE_CUDA_COMPILER="$cuda\bin\nvcc.exe" -G "Visual Studio 16 2019" --install-prefix "$script:DIST_DIR"
|
||||||
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
|
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
|
||||||
@@ -136,12 +160,12 @@ function cuda11 {
|
|||||||
& cmake --install build\cuda_v$cudaMajorVer --component "CUDA" --strip
|
& cmake --install build\cuda_v$cudaMajorVer --component "CUDA" --strip
|
||||||
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
|
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
|
||||||
} else {
|
} else {
|
||||||
write-host "CUDA v$cudaMajorVer not detected, skipping"
|
Write-Output "CUDA v$cudaMajorVer not detected, skipping"
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
write-host "not arch we wanted"
|
Write-Output "not arch we wanted"
|
||||||
}
|
}
|
||||||
write-host "done"
|
Write-Output "done"
|
||||||
}
|
}
|
||||||
|
|
||||||
function cudaCommon {
|
function cudaCommon {
|
||||||
@@ -159,7 +183,7 @@ function cudaCommon {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
write-host "Building CUDA v$cudaMajorVer backend libraries $cuda"
|
Write-Output "Building CUDA v$cudaMajorVer backend libraries $cuda"
|
||||||
$env:CUDAToolkit_ROOT=$cuda
|
$env:CUDAToolkit_ROOT=$cuda
|
||||||
& cmake -B build\cuda_v$cudaMajorVer --preset "CUDA $cudaMajorVer" -T cuda="$cuda" --install-prefix "$script:DIST_DIR"
|
& cmake -B build\cuda_v$cudaMajorVer --preset "CUDA $cudaMajorVer" -T cuda="$cuda" --install-prefix "$script:DIST_DIR"
|
||||||
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
|
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
|
||||||
@@ -168,7 +192,7 @@ function cudaCommon {
|
|||||||
& cmake --install build\cuda_v$cudaMajorVer --component "CUDA" --strip
|
& cmake --install build\cuda_v$cudaMajorVer --component "CUDA" --strip
|
||||||
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
|
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
|
||||||
} else {
|
} else {
|
||||||
write-host "CUDA v$cudaMajorVer not detected, skipping"
|
Write-Output "CUDA v$cudaMajorVer not detected, skipping"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -181,11 +205,11 @@ function cuda13 {
|
|||||||
cudaCommon("13")
|
cudaCommon("13")
|
||||||
}
|
}
|
||||||
|
|
||||||
function rocm {
|
function rocm6 {
|
||||||
mkdir -Force -path "${script:DIST_DIR}\" | Out-Null
|
mkdir -Force -path "${script:DIST_DIR}\" | Out-Null
|
||||||
if ($script:ARCH -ne "arm64") {
|
if ($script:ARCH -ne "arm64") {
|
||||||
if ($script:HIP_PATH) {
|
if ($script:HIP_PATH) {
|
||||||
write-host "Building ROCm backend libraries $script:HIP_PATH"
|
Write-Output "Building ROCm backend libraries $script:HIP_PATH"
|
||||||
if (-Not (get-command -ErrorAction silent ninja)) {
|
if (-Not (get-command -ErrorAction silent ninja)) {
|
||||||
$NINJA_DIR=(gci -path (Get-CimInstance MSFT_VSInstance -Namespace root/cimv2/vs)[0].InstallLocation -r -fi ninja.exe).Directory.FullName
|
$NINJA_DIR=(gci -path (Get-CimInstance MSFT_VSInstance -Namespace root/cimv2/vs)[0].InstallLocation -r -fi ninja.exe).Directory.FullName
|
||||||
$env:PATH="$NINJA_DIR;$env:PATH"
|
$env:PATH="$NINJA_DIR;$env:PATH"
|
||||||
@@ -193,9 +217,11 @@ function rocm {
|
|||||||
$env:HIPCXX="${script:HIP_PATH}\bin\clang++.exe"
|
$env:HIPCXX="${script:HIP_PATH}\bin\clang++.exe"
|
||||||
$env:HIP_PLATFORM="amd"
|
$env:HIP_PLATFORM="amd"
|
||||||
$env:CMAKE_PREFIX_PATH="${script:HIP_PATH}"
|
$env:CMAKE_PREFIX_PATH="${script:HIP_PATH}"
|
||||||
|
# Set CC/CXX via environment instead of -D flags to avoid triggering
|
||||||
|
# spurious compiler-change reconfigures that reset CMAKE_INSTALL_PREFIX
|
||||||
|
$env:CC="${script:HIP_PATH}\bin\clang.exe"
|
||||||
|
$env:CXX="${script:HIP_PATH}\bin\clang++.exe"
|
||||||
& cmake -B build\rocm --preset "ROCm 6" -G Ninja `
|
& cmake -B build\rocm --preset "ROCm 6" -G Ninja `
|
||||||
-DCMAKE_C_COMPILER=clang `
|
|
||||||
-DCMAKE_CXX_COMPILER=clang++ `
|
|
||||||
-DCMAKE_C_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma" `
|
-DCMAKE_C_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma" `
|
||||||
-DCMAKE_CXX_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma" `
|
-DCMAKE_CXX_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma" `
|
||||||
--install-prefix $script:DIST_DIR
|
--install-prefix $script:DIST_DIR
|
||||||
@@ -203,20 +229,22 @@ function rocm {
|
|||||||
$env:HIPCXX=""
|
$env:HIPCXX=""
|
||||||
$env:HIP_PLATFORM=""
|
$env:HIP_PLATFORM=""
|
||||||
$env:CMAKE_PREFIX_PATH=""
|
$env:CMAKE_PREFIX_PATH=""
|
||||||
|
$env:CC=""
|
||||||
|
$env:CXX=""
|
||||||
& cmake --build build\rocm --target ggml-hip --config Release --parallel $script:JOBS
|
& cmake --build build\rocm --target ggml-hip --config Release --parallel $script:JOBS
|
||||||
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
|
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
|
||||||
& cmake --install build\rocm --component "HIP" --strip
|
& cmake --install build\rocm --component "HIP" --strip
|
||||||
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
|
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
|
||||||
Remove-Item -Path $script:DIST_DIR\lib\ollama\rocm\rocblas\library\*gfx906* -ErrorAction SilentlyContinue
|
Remove-Item -Path $script:DIST_DIR\lib\ollama\rocm\rocblas\library\*gfx906* -ErrorAction SilentlyContinue
|
||||||
} else {
|
} else {
|
||||||
write-host "ROCm not detected, skipping"
|
Write-Output "ROCm not detected, skipping"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
function vulkan {
|
function vulkan {
|
||||||
if ($env:VULKAN_SDK) {
|
if ($env:VULKAN_SDK) {
|
||||||
write-host "Building Vulkan backend libraries"
|
Write-Output "Building Vulkan backend libraries"
|
||||||
& cmake -B build\vulkan --preset Vulkan --install-prefix $script:DIST_DIR
|
& cmake -B build\vulkan --preset Vulkan --install-prefix $script:DIST_DIR
|
||||||
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
|
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
|
||||||
& cmake --build build\vulkan --target ggml-vulkan --config Release --parallel $script:JOBS
|
& cmake --build build\vulkan --target ggml-vulkan --config Release --parallel $script:JOBS
|
||||||
@@ -224,33 +252,91 @@ function vulkan {
|
|||||||
& cmake --install build\vulkan --component Vulkan --strip
|
& cmake --install build\vulkan --component Vulkan --strip
|
||||||
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
|
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
|
||||||
} else {
|
} else {
|
||||||
write-host "Vulkan not detected, skipping"
|
Write-Output "Vulkan not detected, skipping"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
function mlxCuda13 {
|
||||||
|
mkdir -Force -path "${script:DIST_DIR}\" | Out-Null
|
||||||
|
$cudaMajorVer="13"
|
||||||
|
if ($script:ARCH -ne "arm64") {
|
||||||
|
if ("$script:CUDA_DIRS".Contains("v$cudaMajorVer")) {
|
||||||
|
foreach ($d in $Script:CUDA_DIRS){
|
||||||
|
if ($d.FullName.Contains("v$cudaMajorVer")) {
|
||||||
|
if (test-path -literalpath (join-path -path $d -childpath "nvcc.exe" ) ) {
|
||||||
|
$cuda=($d.FullName|split-path -parent)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
# Check for cuDNN - required for MLX CUDA backend
|
||||||
|
# Supports two layouts:
|
||||||
|
# 1. CI/zip extract: CUDNN\include\cudnn.h, lib\x64\, bin\x64\
|
||||||
|
# 2. Official installer: CUDNN\v*\include\{cuda-ver}\cudnn.h, lib\{cuda-ver}\x64\, bin\{cuda-ver}\
|
||||||
|
if ($env:CUDNN_INCLUDE_PATH -and $env:CUDNN_LIBRARY_PATH) {
|
||||||
|
Write-Output "Using cuDNN from environment: $env:CUDNN_INCLUDE_PATH"
|
||||||
|
} elseif (Test-Path "C:\Program Files\NVIDIA\CUDNN\include\cudnn.h") {
|
||||||
|
# CI/zip layout (flat)
|
||||||
|
$cudnnRoot = "C:\Program Files\NVIDIA\CUDNN"
|
||||||
|
$env:CUDNN_ROOT_DIR = $cudnnRoot
|
||||||
|
$env:CUDNN_INCLUDE_PATH = "$cudnnRoot\include"
|
||||||
|
$env:CUDNN_LIBRARY_PATH = "$cudnnRoot\lib\x64"
|
||||||
|
Write-Output "Found cuDNN at $cudnnRoot (flat layout)"
|
||||||
|
} else {
|
||||||
|
# Official installer layout (versioned)
|
||||||
|
$cudnnRoot = $null
|
||||||
|
$resolved = Resolve-Path -Path "C:\Program Files\NVIDIA\CUDNN\v*" -ErrorAction SilentlyContinue | Sort-Object -Descending | Select-Object -First 1
|
||||||
|
if ($resolved -and (Test-Path "$($resolved.Path)\include\$cudaMajorVer.0\cudnn.h")) {
|
||||||
|
$cudnnRoot = $resolved.Path
|
||||||
|
$env:CUDNN_ROOT_DIR = $cudnnRoot
|
||||||
|
$env:CUDNN_INCLUDE_PATH = "$cudnnRoot\include\$cudaMajorVer.0"
|
||||||
|
$env:CUDNN_LIBRARY_PATH = "$cudnnRoot\lib\$cudaMajorVer.0\x64"
|
||||||
|
Write-Output "Found cuDNN at $cudnnRoot (official installer, CUDA $cudaMajorVer.0)"
|
||||||
|
} else {
|
||||||
|
Write-Output "cuDNN not found - set CUDNN_INCLUDE_PATH and CUDNN_LIBRARY_PATH environment variables"
|
||||||
|
Write-Output "Skipping MLX build"
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Write-Output "Building MLX CUDA v$cudaMajorVer backend libraries $cuda"
|
||||||
|
$env:CUDAToolkit_ROOT=$cuda
|
||||||
|
& cmake -B build\mlx_cuda_v$cudaMajorVer --preset "MLX CUDA $cudaMajorVer" -T cuda="$cuda" --install-prefix "$script:DIST_DIR"
|
||||||
|
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
|
||||||
|
& cmake --build build\mlx_cuda_v$cudaMajorVer --target mlx --target mlxc --config Release --parallel $script:JOBS -- /nodeReuse:false
|
||||||
|
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
|
||||||
|
& cmake --install build\mlx_cuda_v$cudaMajorVer --component "MLX" --strip
|
||||||
|
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
|
||||||
|
} else {
|
||||||
|
Write-Output "CUDA v$cudaMajorVer not detected, skipping MLX build"
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
function ollama {
|
function ollama {
|
||||||
mkdir -Force -path "${script:DIST_DIR}\" | Out-Null
|
mkdir -Force -path "${script:DIST_DIR}\" | Out-Null
|
||||||
write-host "Building ollama CLI"
|
Write-Output "Building ollama CLI"
|
||||||
& go build -trimpath -ldflags "-s -w -X=github.com/ollama/ollama/version.Version=$script:VERSION -X=github.com/ollama/ollama/server.mode=release" .
|
& go build -trimpath -ldflags "-s -w -X=github.com/ollama/ollama/version.Version=$script:VERSION -X=github.com/ollama/ollama/server.mode=release" .
|
||||||
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
|
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
|
||||||
cp .\ollama.exe "${script:DIST_DIR}\"
|
cp .\ollama.exe "${script:DIST_DIR}\"
|
||||||
}
|
}
|
||||||
|
|
||||||
function app {
|
function app {
|
||||||
write-host "Building Ollama App $script:VERSION with package version $script:PKG_VERSION"
|
Write-Output "Building Ollama App $script:VERSION with package version $script:PKG_VERSION"
|
||||||
|
|
||||||
if (!(Get-Command npm -ErrorAction SilentlyContinue)) {
|
if (!(Get-Command npm -ErrorAction SilentlyContinue)) {
|
||||||
write-host "npm is not installed. Please install Node.js and npm first:"
|
Write-Output "npm is not installed. Please install Node.js and npm first:"
|
||||||
write-host " Visit: https://nodejs.org/"
|
Write-Output " Visit: https://nodejs.org/"
|
||||||
exit 1
|
exit 1
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!(Get-Command tsc -ErrorAction SilentlyContinue)) {
|
if (!(Get-Command tsc -ErrorAction SilentlyContinue)) {
|
||||||
write-host "Installing TypeScript compiler..."
|
Write-Output "Installing TypeScript compiler..."
|
||||||
npm install -g typescript
|
npm install -g typescript
|
||||||
}
|
}
|
||||||
if (!(Get-Command tscriptify -ErrorAction SilentlyContinue)) {
|
if (!(Get-Command tscriptify -ErrorAction SilentlyContinue)) {
|
||||||
write-host "Installing tscriptify..."
|
Write-Output "Installing tscriptify..."
|
||||||
go install github.com/tkrajina/typescriptify-golang-structs/tscriptify@latest
|
go install github.com/tkrajina/typescriptify-golang-structs/tscriptify@latest
|
||||||
}
|
}
|
||||||
if (!(Get-Command tscriptify -ErrorAction SilentlyContinue)) {
|
if (!(Get-Command tscriptify -ErrorAction SilentlyContinue)) {
|
||||||
@@ -260,32 +346,32 @@ function app {
|
|||||||
Push-Location app/ui/app
|
Push-Location app/ui/app
|
||||||
npm install
|
npm install
|
||||||
if ($LASTEXITCODE -ne 0) {
|
if ($LASTEXITCODE -ne 0) {
|
||||||
write-host "ERROR: npm install failed with exit code $LASTEXITCODE"
|
Write-Output "ERROR: npm install failed with exit code $LASTEXITCODE"
|
||||||
exit $LASTEXITCODE
|
exit $LASTEXITCODE
|
||||||
}
|
}
|
||||||
|
|
||||||
write-host "Building React application..."
|
Write-Output "Building React application..."
|
||||||
npm run build
|
npm run build
|
||||||
if ($LASTEXITCODE -ne 0) {
|
if ($LASTEXITCODE -ne 0) {
|
||||||
write-host "ERROR: npm run build failed with exit code $LASTEXITCODE"
|
Write-Output "ERROR: npm run build failed with exit code $LASTEXITCODE"
|
||||||
exit $LASTEXITCODE
|
exit $LASTEXITCODE
|
||||||
}
|
}
|
||||||
|
|
||||||
# Check if dist directory exists and has content
|
# Check if dist directory exists and has content
|
||||||
if (!(Test-Path "dist")) {
|
if (!(Test-Path "dist")) {
|
||||||
write-host "ERROR: dist directory was not created by npm run build"
|
Write-Output "ERROR: dist directory was not created by npm run build"
|
||||||
exit 1
|
exit 1
|
||||||
}
|
}
|
||||||
|
|
||||||
$distFiles = Get-ChildItem "dist" -Recurse
|
$distFiles = Get-ChildItem "dist" -Recurse
|
||||||
if ($distFiles.Count -eq 0) {
|
if ($distFiles.Count -eq 0) {
|
||||||
write-host "ERROR: dist directory is empty after npm run build"
|
Write-Output "ERROR: dist directory is empty after npm run build"
|
||||||
exit 1
|
exit 1
|
||||||
}
|
}
|
||||||
|
|
||||||
Pop-Location
|
Pop-Location
|
||||||
|
|
||||||
write-host "Running go generate"
|
Write-Output "Running go generate"
|
||||||
& go generate ./...
|
& go generate ./...
|
||||||
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
|
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
|
||||||
& go build -trimpath -ldflags "-s -w -H windowsgui -X=github.com/ollama/ollama/app/version.Version=$script:VERSION" -o .\dist\windows-ollama-app-${script:ARCH}.exe ./app/cmd/app/
|
& go build -trimpath -ldflags "-s -w -H windowsgui -X=github.com/ollama/ollama/app/version.Version=$script:VERSION" -o .\dist\windows-ollama-app-${script:ARCH}.exe ./app/cmd/app/
|
||||||
@@ -293,42 +379,42 @@ function app {
|
|||||||
}
|
}
|
||||||
|
|
||||||
function deps {
|
function deps {
|
||||||
write-host "Download MSVC Redistributables"
|
Write-Output "Download MSVC Redistributables"
|
||||||
mkdir -Force -path "${script:SRC_DIR}\dist\\windows-arm64" | Out-Null
|
mkdir -Force -path "${script:SRC_DIR}\dist\\windows-arm64" | Out-Null
|
||||||
mkdir -Force -path "${script:SRC_DIR}\dist\\windows-amd64" | Out-Null
|
mkdir -Force -path "${script:SRC_DIR}\dist\\windows-amd64" | Out-Null
|
||||||
invoke-webrequest -Uri "https://aka.ms/vs/17/release/vc_redist.arm64.exe" -OutFile "${script:SRC_DIR}\dist\windows-arm64\vc_redist.arm64.exe"
|
invoke-webrequest -Uri "https://aka.ms/vs/17/release/vc_redist.arm64.exe" -OutFile "${script:SRC_DIR}\dist\windows-arm64\vc_redist.arm64.exe" -ErrorAction Stop
|
||||||
invoke-webrequest -Uri "https://aka.ms/vs/17/release/vc_redist.x64.exe" -OutFile "${script:SRC_DIR}\dist\windows-amd64\vc_redist.x64.exe"
|
invoke-webrequest -Uri "https://aka.ms/vs/17/release/vc_redist.x64.exe" -OutFile "${script:SRC_DIR}\dist\windows-amd64\vc_redist.x64.exe" -ErrorAction Stop
|
||||||
write-host "Done."
|
Write-Output "Done."
|
||||||
}
|
}
|
||||||
|
|
||||||
function sign {
|
function sign {
|
||||||
# Copy install.ps1 to dist for release packaging
|
# Copy install.ps1 to dist for release packaging
|
||||||
write-host "Copying install.ps1 to dist"
|
Write-Output "Copying install.ps1 to dist"
|
||||||
Copy-Item -Path "${script:SRC_DIR}\scripts\install.ps1" -Destination "${script:SRC_DIR}\dist\install.ps1"
|
Copy-Item -Path "${script:SRC_DIR}\scripts\install.ps1" -Destination "${script:SRC_DIR}\dist\install.ps1" -ErrorAction Stop
|
||||||
|
|
||||||
if ("${env:KEY_CONTAINER}") {
|
if ("${env:KEY_CONTAINER}") {
|
||||||
write-host "Signing Ollama executables, scripts and libraries"
|
Write-Output "Signing Ollama executables, scripts and libraries"
|
||||||
& "${script:SignTool}" sign /v /fd sha256 /t http://timestamp.digicert.com /f "${script:OLLAMA_CERT}" `
|
& "${script:SignTool}" sign /v /fd sha256 /t http://timestamp.digicert.com /f "${script:OLLAMA_CERT}" `
|
||||||
/csp "Google Cloud KMS Provider" /kc ${env:KEY_CONTAINER} `
|
/csp "Google Cloud KMS Provider" /kc ${env:KEY_CONTAINER} `
|
||||||
$(get-childitem -path "${script:SRC_DIR}\dist\windows-*" -r -include @('*.exe', '*.dll'))
|
$(get-childitem -path "${script:SRC_DIR}\dist\windows-*" -r -include @('*.exe', '*.dll'))
|
||||||
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
|
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
|
||||||
|
|
||||||
write-host "Signing install.ps1"
|
Write-Output "Signing install.ps1"
|
||||||
& "${script:SignTool}" sign /v /fd sha256 /t http://timestamp.digicert.com /f "${script:OLLAMA_CERT}" `
|
& "${script:SignTool}" sign /v /fd sha256 /t http://timestamp.digicert.com /f "${script:OLLAMA_CERT}" `
|
||||||
/csp "Google Cloud KMS Provider" /kc ${env:KEY_CONTAINER} `
|
/csp "Google Cloud KMS Provider" /kc ${env:KEY_CONTAINER} `
|
||||||
"${script:SRC_DIR}\dist\install.ps1"
|
"${script:SRC_DIR}\dist\install.ps1"
|
||||||
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
|
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
|
||||||
} else {
|
} else {
|
||||||
write-host "Signing not enabled"
|
Write-Output "Signing not enabled"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
function installer {
|
function installer {
|
||||||
if ($null -eq ${script:INNO_SETUP_DIR}) {
|
if ($null -eq ${script:INNO_SETUP_DIR}) {
|
||||||
write-host "ERROR: missing Inno Setup installation directory - install from https://jrsoftware.org/isdl.php"
|
Write-Output "ERROR: missing Inno Setup installation directory - install from https://jrsoftware.org/isdl.php"
|
||||||
exit 1
|
exit 1
|
||||||
}
|
}
|
||||||
write-host "Building Ollama Installer"
|
Write-Output "Building Ollama Installer"
|
||||||
cd "${script:SRC_DIR}\app"
|
cd "${script:SRC_DIR}\app"
|
||||||
$env:PKG_VERSION=$script:PKG_VERSION
|
$env:PKG_VERSION=$script:PKG_VERSION
|
||||||
if ("${env:KEY_CONTAINER}") {
|
if ("${env:KEY_CONTAINER}") {
|
||||||
@@ -342,24 +428,24 @@ function installer {
|
|||||||
function zip {
|
function zip {
|
||||||
if (Test-Path -Path "${script:SRC_DIR}\dist\windows-amd64") {
|
if (Test-Path -Path "${script:SRC_DIR}\dist\windows-amd64") {
|
||||||
if (Test-Path -Path "${script:SRC_DIR}\dist\windows-amd64\lib\ollama\rocm") {
|
if (Test-Path -Path "${script:SRC_DIR}\dist\windows-amd64\lib\ollama\rocm") {
|
||||||
write-host "Generating stand-alone distribution zip file ${script:SRC_DIR}\dist\ollama-windows-amd64-rocm.zip"
|
Write-Output "Generating stand-alone distribution zip file ${script:SRC_DIR}\dist\ollama-windows-amd64-rocm.zip"
|
||||||
# Temporarily adjust paths so we can retain the same directory structure
|
# Temporarily adjust paths so we can retain the same directory structure
|
||||||
Remove-Item -ea 0 -r "${script:SRC_DIR}\dist\windows-amd64-rocm"
|
Remove-Item -ea 0 -r "${script:SRC_DIR}\dist\windows-amd64-rocm"
|
||||||
mkdir -Force -path "${script:SRC_DIR}\dist\windows-amd64-rocm\lib\ollama"
|
mkdir -Force -path "${script:SRC_DIR}\dist\windows-amd64-rocm\lib\ollama"
|
||||||
Write-Output "Extract this ROCm zip file to the same location where you extracted ollama-windows-amd64.zip" > "${script:SRC_DIR}\dist\windows-amd64-rocm\README.txt"
|
Write-Output "Extract this ROCm zip file to the same location where you extracted ollama-windows-amd64.zip" > "${script:SRC_DIR}\dist\windows-amd64-rocm\README.txt"
|
||||||
Move-Item -path "${script:SRC_DIR}\dist\windows-amd64\lib\ollama\rocm" -destination "${script:SRC_DIR}\dist\windows-amd64-rocm\lib\ollama"
|
Move-Item -path "${script:SRC_DIR}\dist\windows-amd64\lib\ollama\rocm" -destination "${script:SRC_DIR}\dist\windows-amd64-rocm\lib\ollama" -ErrorAction Stop
|
||||||
Compress-Archive -CompressionLevel Optimal -Path "${script:SRC_DIR}\dist\windows-amd64-rocm\*" -DestinationPath "${script:SRC_DIR}\dist\ollama-windows-amd64-rocm.zip" -Force
|
Compress-Archive -CompressionLevel Optimal -Path "${script:SRC_DIR}\dist\windows-amd64-rocm\*" -DestinationPath "${script:SRC_DIR}\dist\ollama-windows-amd64-rocm.zip" -Force
|
||||||
}
|
}
|
||||||
|
|
||||||
write-host "Generating stand-alone distribution zip file ${script:SRC_DIR}\dist\ollama-windows-amd64.zip"
|
Write-Output "Generating stand-alone distribution zip file ${script:SRC_DIR}\dist\ollama-windows-amd64.zip"
|
||||||
Compress-Archive -CompressionLevel Optimal -Path "${script:SRC_DIR}\dist\windows-amd64\*" -DestinationPath "${script:SRC_DIR}\dist\ollama-windows-amd64.zip" -Force
|
Compress-Archive -CompressionLevel Optimal -Path "${script:SRC_DIR}\dist\windows-amd64\*" -DestinationPath "${script:SRC_DIR}\dist\ollama-windows-amd64.zip" -Force
|
||||||
if (Test-Path -Path "${script:SRC_DIR}\dist\windows-amd64-rocm") {
|
if (Test-Path -Path "${script:SRC_DIR}\dist\windows-amd64-rocm") {
|
||||||
Move-Item -destination "${script:SRC_DIR}\dist\windows-amd64\lib\ollama\rocm" -path "${script:SRC_DIR}\dist\windows-amd64-rocm\lib\ollama\rocm"
|
Move-Item -destination "${script:SRC_DIR}\dist\windows-amd64\lib\ollama\rocm" -path "${script:SRC_DIR}\dist\windows-amd64-rocm\lib\ollama\rocm" -ErrorAction Stop
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (Test-Path -Path "${script:SRC_DIR}\dist\windows-arm64") {
|
if (Test-Path -Path "${script:SRC_DIR}\dist\windows-arm64") {
|
||||||
write-host "Generating stand-alone distribution zip file ${script:SRC_DIR}\dist\ollama-windows-arm64.zip"
|
Write-Output "Generating stand-alone distribution zip file ${script:SRC_DIR}\dist\ollama-windows-arm64.zip"
|
||||||
Compress-Archive -CompressionLevel Optimal -Path "${script:SRC_DIR}\dist\windows-arm64\*" -DestinationPath "${script:SRC_DIR}\dist\ollama-windows-arm64.zip" -Force
|
Compress-Archive -CompressionLevel Optimal -Path "${script:SRC_DIR}\dist\windows-arm64\*" -DestinationPath "${script:SRC_DIR}\dist\ollama-windows-arm64.zip" -Force
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -375,8 +461,9 @@ try {
|
|||||||
cpu
|
cpu
|
||||||
cuda12
|
cuda12
|
||||||
cuda13
|
cuda13
|
||||||
rocm
|
rocm6
|
||||||
vulkan
|
vulkan
|
||||||
|
mlxCuda13
|
||||||
ollama
|
ollama
|
||||||
app
|
app
|
||||||
deps
|
deps
|
||||||
@@ -385,13 +472,13 @@ try {
|
|||||||
zip
|
zip
|
||||||
} else {
|
} else {
|
||||||
for ( $i = 0; $i -lt $args.count; $i++ ) {
|
for ( $i = 0; $i -lt $args.count; $i++ ) {
|
||||||
write-host "running build step $($args[$i])"
|
Write-Output "running build step $($args[$i])"
|
||||||
& $($args[$i])
|
& $($args[$i])
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} catch {
|
} catch {
|
||||||
write-host "Build Failed"
|
Write-Error "Build Failed: $($_.Exception.Message)"
|
||||||
write-host $_
|
Write-Error "$($_.ScriptStackTrace)"
|
||||||
} finally {
|
} finally {
|
||||||
set-location $script:SRC_DIR
|
set-location $script:SRC_DIR
|
||||||
$env:PKG_VERSION=""
|
$env:PKG_VERSION=""
|
||||||
|
|||||||
@@ -16,9 +16,16 @@ OLLAMA_COMMON_BUILD_ARGS="--build-arg=VERSION \
|
|||||||
--build-arg=OLLAMA_FAST_BUILD \
|
--build-arg=OLLAMA_FAST_BUILD \
|
||||||
--build-arg=CUSTOM_CPU_FLAGS \
|
--build-arg=CUSTOM_CPU_FLAGS \
|
||||||
--build-arg=GPU_RUNNER_CPU_FLAGS \
|
--build-arg=GPU_RUNNER_CPU_FLAGS \
|
||||||
--build-arg=PARALLEL \
|
|
||||||
--build-arg=AMDGPU_TARGETS"
|
--build-arg=AMDGPU_TARGETS"
|
||||||
|
|
||||||
|
# Forward local MLX source overrides as Docker build contexts
|
||||||
|
if [ -n "${OLLAMA_MLX_SOURCE:-}" ]; then
|
||||||
|
OLLAMA_COMMON_BUILD_ARGS="$OLLAMA_COMMON_BUILD_ARGS --build-context local-mlx=$(cd "$OLLAMA_MLX_SOURCE" && pwd)"
|
||||||
|
fi
|
||||||
|
if [ -n "${OLLAMA_MLX_C_SOURCE:-}" ]; then
|
||||||
|
OLLAMA_COMMON_BUILD_ARGS="$OLLAMA_COMMON_BUILD_ARGS --build-context local-mlx-c=$(cd "$OLLAMA_MLX_C_SOURCE" && pwd)"
|
||||||
|
fi
|
||||||
|
|
||||||
echo "Building Ollama"
|
echo "Building Ollama"
|
||||||
echo "VERSION=$VERSION"
|
echo "VERSION=$VERSION"
|
||||||
echo "PLATFORM=$PLATFORM"
|
echo "PLATFORM=$PLATFORM"
|
||||||
479
server/cloud_proxy.go
Normal file
479
server/cloud_proxy.go
Normal file
@@ -0,0 +1,479 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"log/slog"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/auth"
|
||||||
|
"github.com/ollama/ollama/envconfig"
|
||||||
|
internalcloud "github.com/ollama/ollama/internal/cloud"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
defaultCloudProxyBaseURL = "https://ollama.com:443"
|
||||||
|
defaultCloudProxySigningHost = "ollama.com"
|
||||||
|
cloudProxyBaseURLEnv = "OLLAMA_CLOUD_BASE_URL"
|
||||||
|
legacyCloudAnthropicKey = "legacy_cloud_anthropic_web_search"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
cloudProxyBaseURL = defaultCloudProxyBaseURL
|
||||||
|
cloudProxySigningHost = defaultCloudProxySigningHost
|
||||||
|
cloudProxySignRequest = signCloudProxyRequest
|
||||||
|
cloudProxySigninURL = signinURL
|
||||||
|
)
|
||||||
|
|
||||||
|
var hopByHopHeaders = map[string]struct{}{
|
||||||
|
"connection": {},
|
||||||
|
"content-length": {},
|
||||||
|
"proxy-connection": {},
|
||||||
|
"keep-alive": {},
|
||||||
|
"proxy-authenticate": {},
|
||||||
|
"proxy-authorization": {},
|
||||||
|
"te": {},
|
||||||
|
"trailer": {},
|
||||||
|
"transfer-encoding": {},
|
||||||
|
"upgrade": {},
|
||||||
|
}
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
baseURL, signingHost, overridden, err := resolveCloudProxyBaseURL(envconfig.Var(cloudProxyBaseURLEnv), mode)
|
||||||
|
if err != nil {
|
||||||
|
slog.Warn("ignoring cloud base URL override", "env", cloudProxyBaseURLEnv, "error", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
cloudProxyBaseURL = baseURL
|
||||||
|
cloudProxySigningHost = signingHost
|
||||||
|
|
||||||
|
if overridden {
|
||||||
|
slog.Info("cloud base URL override enabled", "env", cloudProxyBaseURLEnv, "url", cloudProxyBaseURL, "mode", mode)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func cloudPassthroughMiddleware(disabledOperation string) gin.HandlerFunc {
|
||||||
|
return func(c *gin.Context) {
|
||||||
|
if c.Request.Method != http.MethodPost {
|
||||||
|
c.Next()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO(drifkin): Avoid full-body buffering here for model detection.
|
||||||
|
// A future optimization can parse just enough JSON to read "model" (and
|
||||||
|
// optionally short-circuit cloud-disabled explicit-cloud requests) while
|
||||||
|
// preserving raw passthrough semantics.
|
||||||
|
body, err := readRequestBody(c.Request)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
|
c.Abort()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
model, ok := extractModelField(body)
|
||||||
|
if !ok {
|
||||||
|
c.Next()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
modelRef, err := parseAndValidateModelRef(model)
|
||||||
|
if err != nil || modelRef.Source != modelSourceCloud {
|
||||||
|
c.Next()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
normalizedBody, err := replaceJSONModelField(body, modelRef.Base)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
|
c.Abort()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// TEMP(drifkin): keep Anthropic web search requests on the local middleware
|
||||||
|
// path so WebSearchAnthropicWriter can orchestrate follow-up calls.
|
||||||
|
if c.Request.URL.Path == "/v1/messages" {
|
||||||
|
if hasAnthropicWebSearchTool(body) {
|
||||||
|
c.Set(legacyCloudAnthropicKey, true)
|
||||||
|
c.Next()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
proxyCloudRequest(c, normalizedBody, disabledOperation)
|
||||||
|
c.Abort()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func cloudModelPathPassthroughMiddleware(disabledOperation string) gin.HandlerFunc {
|
||||||
|
return func(c *gin.Context) {
|
||||||
|
modelName := strings.TrimSpace(c.Param("model"))
|
||||||
|
if modelName == "" {
|
||||||
|
c.Next()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
modelRef, err := parseAndValidateModelRef(modelName)
|
||||||
|
if err != nil || modelRef.Source != modelSourceCloud {
|
||||||
|
c.Next()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
proxyPath := "/v1/models/" + modelRef.Base
|
||||||
|
proxyCloudRequestWithPath(c, nil, proxyPath, disabledOperation)
|
||||||
|
c.Abort()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func proxyCloudJSONRequest(c *gin.Context, payload any, disabledOperation string) {
|
||||||
|
// TEMP(drifkin): we currently split out this `WithPath` method because we are
|
||||||
|
// mapping `/v1/messages` + web_search to `/api/chat` temporarily. Once we
|
||||||
|
// stop doing this, we can inline this method.
|
||||||
|
proxyCloudJSONRequestWithPath(c, payload, c.Request.URL.Path, disabledOperation)
|
||||||
|
}
|
||||||
|
|
||||||
|
func proxyCloudJSONRequestWithPath(c *gin.Context, payload any, path string, disabledOperation string) {
|
||||||
|
body, err := json.Marshal(payload)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
proxyCloudRequestWithPath(c, body, path, disabledOperation)
|
||||||
|
}
|
||||||
|
|
||||||
|
func proxyCloudRequest(c *gin.Context, body []byte, disabledOperation string) {
|
||||||
|
proxyCloudRequestWithPath(c, body, c.Request.URL.Path, disabledOperation)
|
||||||
|
}
|
||||||
|
|
||||||
|
func proxyCloudRequestWithPath(c *gin.Context, body []byte, path string, disabledOperation string) {
|
||||||
|
if disabled, _ := internalcloud.Status(); disabled {
|
||||||
|
c.JSON(http.StatusForbidden, gin.H{"error": internalcloud.DisabledError(disabledOperation)})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
baseURL, err := url.Parse(cloudProxyBaseURL)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
targetURL := baseURL.ResolveReference(&url.URL{
|
||||||
|
Path: path,
|
||||||
|
RawQuery: c.Request.URL.RawQuery,
|
||||||
|
})
|
||||||
|
|
||||||
|
outReq, err := http.NewRequestWithContext(c.Request.Context(), c.Request.Method, targetURL.String(), bytes.NewReader(body))
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
copyProxyRequestHeaders(outReq.Header, c.Request.Header)
|
||||||
|
if outReq.Header.Get("Content-Type") == "" && len(body) > 0 {
|
||||||
|
outReq.Header.Set("Content-Type", "application/json")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := cloudProxySignRequest(outReq.Context(), outReq); err != nil {
|
||||||
|
slog.Warn("cloud proxy signing failed", "error", err)
|
||||||
|
writeCloudUnauthorized(c)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO(drifkin): Add phase-specific proxy timeouts.
|
||||||
|
// Connect/TLS/TTFB should have bounded timeouts, but once streaming starts
|
||||||
|
// we should not enforce a short total timeout for long-lived responses.
|
||||||
|
resp, err := http.DefaultClient.Do(outReq)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusBadGateway, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
copyProxyResponseHeaders(c.Writer.Header(), resp.Header)
|
||||||
|
c.Status(resp.StatusCode)
|
||||||
|
|
||||||
|
if err := copyProxyResponseBody(c.Writer, resp.Body); err != nil {
|
||||||
|
ctxErr := c.Request.Context().Err()
|
||||||
|
if errors.Is(err, context.Canceled) && errors.Is(ctxErr, context.Canceled) {
|
||||||
|
slog.Debug(
|
||||||
|
"cloud proxy response stream closed by client",
|
||||||
|
"path", c.Request.URL.Path,
|
||||||
|
"status", resp.StatusCode,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
slog.Warn(
|
||||||
|
"cloud proxy response copy failed",
|
||||||
|
"path", c.Request.URL.Path,
|
||||||
|
"status", resp.StatusCode,
|
||||||
|
"request_context_canceled", ctxErr != nil,
|
||||||
|
"request_context_err", ctxErr,
|
||||||
|
"error", err,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func replaceJSONModelField(body []byte, model string) ([]byte, error) {
|
||||||
|
if len(body) == 0 {
|
||||||
|
return body, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var payload map[string]json.RawMessage
|
||||||
|
if err := json.Unmarshal(body, &payload); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
modelJSON, err := json.Marshal(model)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
payload["model"] = modelJSON
|
||||||
|
|
||||||
|
return json.Marshal(payload)
|
||||||
|
}
|
||||||
|
|
||||||
|
func readRequestBody(r *http.Request) ([]byte, error) {
|
||||||
|
if r.Body == nil {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
body, err := io.ReadAll(r.Body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
r.Body = io.NopCloser(bytes.NewReader(body))
|
||||||
|
return body, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func extractModelField(body []byte) (string, bool) {
|
||||||
|
if len(body) == 0 {
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
|
||||||
|
var payload map[string]json.RawMessage
|
||||||
|
if err := json.Unmarshal(body, &payload); err != nil {
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
|
||||||
|
raw, ok := payload["model"]
|
||||||
|
if !ok {
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
|
||||||
|
var model string
|
||||||
|
if err := json.Unmarshal(raw, &model); err != nil {
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
|
||||||
|
model = strings.TrimSpace(model)
|
||||||
|
return model, model != ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func hasAnthropicWebSearchTool(body []byte) bool {
|
||||||
|
if len(body) == 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
var payload struct {
|
||||||
|
Tools []struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
} `json:"tools"`
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(body, &payload); err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tool := range payload.Tools {
|
||||||
|
if strings.HasPrefix(strings.TrimSpace(tool.Type), "web_search") {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func writeCloudUnauthorized(c *gin.Context) {
|
||||||
|
signinURL, err := cloudProxySigninURL()
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusUnauthorized, gin.H{"error": "unauthorized"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(http.StatusUnauthorized, gin.H{"error": "unauthorized", "signin_url": signinURL})
|
||||||
|
}
|
||||||
|
|
||||||
|
func signCloudProxyRequest(ctx context.Context, req *http.Request) error {
|
||||||
|
if !strings.EqualFold(req.URL.Hostname(), cloudProxySigningHost) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
ts := strconv.FormatInt(time.Now().Unix(), 10)
|
||||||
|
challenge := buildCloudSignatureChallenge(req, ts)
|
||||||
|
signature, err := auth.Sign(ctx, []byte(challenge))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
req.Header.Set("Authorization", signature)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildCloudSignatureChallenge(req *http.Request, ts string) string {
|
||||||
|
query := req.URL.Query()
|
||||||
|
query.Set("ts", ts)
|
||||||
|
req.URL.RawQuery = query.Encode()
|
||||||
|
|
||||||
|
return fmt.Sprintf("%s,%s", req.Method, req.URL.RequestURI())
|
||||||
|
}
|
||||||
|
|
||||||
|
func resolveCloudProxyBaseURL(rawOverride string, runMode string) (baseURL string, signingHost string, overridden bool, err error) {
|
||||||
|
baseURL = defaultCloudProxyBaseURL
|
||||||
|
signingHost = defaultCloudProxySigningHost
|
||||||
|
|
||||||
|
rawOverride = strings.TrimSpace(rawOverride)
|
||||||
|
if rawOverride == "" {
|
||||||
|
return baseURL, signingHost, false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
u, err := url.Parse(rawOverride)
|
||||||
|
if err != nil {
|
||||||
|
return "", "", false, fmt.Errorf("invalid URL: %w", err)
|
||||||
|
}
|
||||||
|
if u.Scheme == "" || u.Host == "" {
|
||||||
|
return "", "", false, fmt.Errorf("invalid URL: scheme and host are required")
|
||||||
|
}
|
||||||
|
if u.User != nil {
|
||||||
|
return "", "", false, fmt.Errorf("invalid URL: userinfo is not allowed")
|
||||||
|
}
|
||||||
|
if u.Path != "" && u.Path != "/" {
|
||||||
|
return "", "", false, fmt.Errorf("invalid URL: path is not allowed")
|
||||||
|
}
|
||||||
|
if u.RawQuery != "" || u.Fragment != "" {
|
||||||
|
return "", "", false, fmt.Errorf("invalid URL: query and fragment are not allowed")
|
||||||
|
}
|
||||||
|
|
||||||
|
host := u.Hostname()
|
||||||
|
if host == "" {
|
||||||
|
return "", "", false, fmt.Errorf("invalid URL: host is required")
|
||||||
|
}
|
||||||
|
|
||||||
|
loopback := isLoopbackHost(host)
|
||||||
|
if runMode == gin.ReleaseMode && !loopback {
|
||||||
|
return "", "", false, fmt.Errorf("non-loopback cloud override is not allowed in release mode")
|
||||||
|
}
|
||||||
|
if !loopback && !strings.EqualFold(u.Scheme, "https") {
|
||||||
|
return "", "", false, fmt.Errorf("non-loopback cloud override must use https")
|
||||||
|
}
|
||||||
|
|
||||||
|
u.Path = ""
|
||||||
|
u.RawPath = ""
|
||||||
|
u.RawQuery = ""
|
||||||
|
u.Fragment = ""
|
||||||
|
|
||||||
|
return u.String(), strings.ToLower(host), true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func isLoopbackHost(host string) bool {
|
||||||
|
if strings.EqualFold(host, "localhost") {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
ip := net.ParseIP(host)
|
||||||
|
return ip != nil && ip.IsLoopback()
|
||||||
|
}
|
||||||
|
|
||||||
|
func copyProxyRequestHeaders(dst, src http.Header) {
|
||||||
|
connectionTokens := connectionHeaderTokens(src)
|
||||||
|
for key, values := range src {
|
||||||
|
if isHopByHopHeader(key) || isConnectionTokenHeader(key, connectionTokens) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
dst.Del(key)
|
||||||
|
for _, value := range values {
|
||||||
|
dst.Add(key, value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func copyProxyResponseHeaders(dst, src http.Header) {
|
||||||
|
connectionTokens := connectionHeaderTokens(src)
|
||||||
|
for key, values := range src {
|
||||||
|
if isHopByHopHeader(key) || isConnectionTokenHeader(key, connectionTokens) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
dst.Del(key)
|
||||||
|
for _, value := range values {
|
||||||
|
dst.Add(key, value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func copyProxyResponseBody(dst http.ResponseWriter, src io.Reader) error {
|
||||||
|
flusher, canFlush := dst.(http.Flusher)
|
||||||
|
buf := make([]byte, 32*1024)
|
||||||
|
|
||||||
|
for {
|
||||||
|
n, err := src.Read(buf)
|
||||||
|
if n > 0 {
|
||||||
|
if _, writeErr := dst.Write(buf[:n]); writeErr != nil {
|
||||||
|
return writeErr
|
||||||
|
}
|
||||||
|
if canFlush {
|
||||||
|
// TODO(drifkin): Consider conditional flushing so non-streaming
|
||||||
|
// responses don't flush every write and can optimize throughput.
|
||||||
|
flusher.Flush()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
if err == io.EOF {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func isHopByHopHeader(name string) bool {
|
||||||
|
_, ok := hopByHopHeaders[strings.ToLower(name)]
|
||||||
|
return ok
|
||||||
|
}
|
||||||
|
|
||||||
|
func connectionHeaderTokens(header http.Header) map[string]struct{} {
|
||||||
|
tokens := map[string]struct{}{}
|
||||||
|
for _, raw := range header.Values("Connection") {
|
||||||
|
for _, token := range strings.Split(raw, ",") {
|
||||||
|
token = strings.TrimSpace(strings.ToLower(token))
|
||||||
|
if token == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
tokens[token] = struct{}{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return tokens
|
||||||
|
}
|
||||||
|
|
||||||
|
func isConnectionTokenHeader(name string, tokens map[string]struct{}) bool {
|
||||||
|
if len(tokens) == 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
_, ok := tokens[strings.ToLower(name)]
|
||||||
|
return ok
|
||||||
|
}
|
||||||
154
server/cloud_proxy_test.go
Normal file
154
server/cloud_proxy_test.go
Normal file
@@ -0,0 +1,154 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestCopyProxyRequestHeaders_StripsConnectionTokenHeaders(t *testing.T) {
|
||||||
|
src := http.Header{}
|
||||||
|
src.Add("Connection", "keep-alive, X-Trace-Hop, x-alt-hop")
|
||||||
|
src.Add("X-Trace-Hop", "drop-me")
|
||||||
|
src.Add("X-Alt-Hop", "drop-me-too")
|
||||||
|
src.Add("Keep-Alive", "timeout=5")
|
||||||
|
src.Add("X-End-To-End", "keep-me")
|
||||||
|
|
||||||
|
dst := http.Header{}
|
||||||
|
copyProxyRequestHeaders(dst, src)
|
||||||
|
|
||||||
|
if got := dst.Get("Connection"); got != "" {
|
||||||
|
t.Fatalf("expected Connection to be stripped, got %q", got)
|
||||||
|
}
|
||||||
|
if got := dst.Get("Keep-Alive"); got != "" {
|
||||||
|
t.Fatalf("expected Keep-Alive to be stripped, got %q", got)
|
||||||
|
}
|
||||||
|
if got := dst.Get("X-Trace-Hop"); got != "" {
|
||||||
|
t.Fatalf("expected X-Trace-Hop to be stripped via Connection token, got %q", got)
|
||||||
|
}
|
||||||
|
if got := dst.Get("X-Alt-Hop"); got != "" {
|
||||||
|
t.Fatalf("expected X-Alt-Hop to be stripped via Connection token, got %q", got)
|
||||||
|
}
|
||||||
|
if got := dst.Get("X-End-To-End"); got != "keep-me" {
|
||||||
|
t.Fatalf("expected X-End-To-End to be forwarded, got %q", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCopyProxyResponseHeaders_StripsConnectionTokenHeaders(t *testing.T) {
|
||||||
|
src := http.Header{}
|
||||||
|
src.Add("Connection", "X-Upstream-Hop")
|
||||||
|
src.Add("X-Upstream-Hop", "drop-me")
|
||||||
|
src.Add("Content-Type", "application/json")
|
||||||
|
src.Add("X-Server-Trace", "keep-me")
|
||||||
|
|
||||||
|
dst := http.Header{}
|
||||||
|
copyProxyResponseHeaders(dst, src)
|
||||||
|
|
||||||
|
if got := dst.Get("Connection"); got != "" {
|
||||||
|
t.Fatalf("expected Connection to be stripped, got %q", got)
|
||||||
|
}
|
||||||
|
if got := dst.Get("X-Upstream-Hop"); got != "" {
|
||||||
|
t.Fatalf("expected X-Upstream-Hop to be stripped via Connection token, got %q", got)
|
||||||
|
}
|
||||||
|
if got := dst.Get("Content-Type"); got != "application/json" {
|
||||||
|
t.Fatalf("expected Content-Type to be forwarded, got %q", got)
|
||||||
|
}
|
||||||
|
if got := dst.Get("X-Server-Trace"); got != "keep-me" {
|
||||||
|
t.Fatalf("expected X-Server-Trace to be forwarded, got %q", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResolveCloudProxyBaseURL_Default(t *testing.T) {
|
||||||
|
baseURL, signingHost, overridden, err := resolveCloudProxyBaseURL("", gin.ReleaseMode)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if overridden {
|
||||||
|
t.Fatal("expected override=false for empty input")
|
||||||
|
}
|
||||||
|
if baseURL != defaultCloudProxyBaseURL {
|
||||||
|
t.Fatalf("expected default base URL %q, got %q", defaultCloudProxyBaseURL, baseURL)
|
||||||
|
}
|
||||||
|
if signingHost != defaultCloudProxySigningHost {
|
||||||
|
t.Fatalf("expected default signing host %q, got %q", defaultCloudProxySigningHost, signingHost)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResolveCloudProxyBaseURL_ReleaseAllowsLoopback(t *testing.T) {
|
||||||
|
baseURL, signingHost, overridden, err := resolveCloudProxyBaseURL("http://localhost:8080", gin.ReleaseMode)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if !overridden {
|
||||||
|
t.Fatal("expected override=true")
|
||||||
|
}
|
||||||
|
if baseURL != "http://localhost:8080" {
|
||||||
|
t.Fatalf("unexpected base URL: %q", baseURL)
|
||||||
|
}
|
||||||
|
if signingHost != "localhost" {
|
||||||
|
t.Fatalf("unexpected signing host: %q", signingHost)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResolveCloudProxyBaseURL_ReleaseRejectsNonLoopback(t *testing.T) {
|
||||||
|
_, _, _, err := resolveCloudProxyBaseURL("https://example.com", gin.ReleaseMode)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for non-loopback override in release mode")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResolveCloudProxyBaseURL_DevAllowsNonLoopbackHTTPS(t *testing.T) {
|
||||||
|
baseURL, signingHost, overridden, err := resolveCloudProxyBaseURL("https://example.com:8443", gin.DebugMode)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if !overridden {
|
||||||
|
t.Fatal("expected override=true")
|
||||||
|
}
|
||||||
|
if baseURL != "https://example.com:8443" {
|
||||||
|
t.Fatalf("unexpected base URL: %q", baseURL)
|
||||||
|
}
|
||||||
|
if signingHost != "example.com" {
|
||||||
|
t.Fatalf("unexpected signing host: %q", signingHost)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResolveCloudProxyBaseURL_DevRejectsNonLoopbackHTTP(t *testing.T) {
|
||||||
|
_, _, _, err := resolveCloudProxyBaseURL("http://example.com", gin.DebugMode)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for non-loopback http override in dev mode")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildCloudSignatureChallengeIncludesExistingQuery(t *testing.T) {
|
||||||
|
req, err := http.NewRequest(http.MethodPost, "https://ollama.com/v1/messages?beta=true&foo=bar", nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to create request: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
got := buildCloudSignatureChallenge(req, "123")
|
||||||
|
want := "POST,/v1/messages?beta=true&foo=bar&ts=123"
|
||||||
|
if got != want {
|
||||||
|
t.Fatalf("challenge mismatch: got %q want %q", got, want)
|
||||||
|
}
|
||||||
|
if req.URL.RawQuery != "beta=true&foo=bar&ts=123" {
|
||||||
|
t.Fatalf("unexpected signed query: %q", req.URL.RawQuery)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildCloudSignatureChallengeOverwritesExistingTimestamp(t *testing.T) {
|
||||||
|
req, err := http.NewRequest(http.MethodPost, "https://ollama.com/v1/messages?beta=true&ts=999", nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to create request: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
got := buildCloudSignatureChallenge(req, "123")
|
||||||
|
want := "POST,/v1/messages?beta=true&ts=123"
|
||||||
|
if got != want {
|
||||||
|
t.Fatalf("challenge mismatch: got %q want %q", got, want)
|
||||||
|
}
|
||||||
|
if req.URL.RawQuery != "beta=true&ts=123" {
|
||||||
|
t.Fatalf("unexpected signed query: %q", req.URL.RawQuery)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -65,11 +65,22 @@ func (s *Server) CreateHandler(c *gin.Context) {
|
|||||||
config.Parser = r.Parser
|
config.Parser = r.Parser
|
||||||
config.Requires = r.Requires
|
config.Requires = r.Requires
|
||||||
|
|
||||||
for v := range r.Files {
|
for v, digest := range r.Files {
|
||||||
if !fs.ValidPath(v) {
|
if !fs.ValidPath(v) {
|
||||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": errFilePath.Error()})
|
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": errFilePath.Error()})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if digest == "" {
|
||||||
|
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": manifest.ErrInvalidDigestFormat.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, digest := range r.Adapters {
|
||||||
|
if digest == "" {
|
||||||
|
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": manifest.ErrInvalidDigestFormat.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
name := model.ParseName(cmp.Or(r.Model, r.Name))
|
name := model.ParseName(cmp.Or(r.Model, r.Name))
|
||||||
@@ -99,19 +110,26 @@ func (s *Server) CreateHandler(c *gin.Context) {
|
|||||||
|
|
||||||
if r.From != "" {
|
if r.From != "" {
|
||||||
slog.Debug("create model from model name", "from", r.From)
|
slog.Debug("create model from model name", "from", r.From)
|
||||||
fromName := model.ParseName(r.From)
|
fromRef, err := parseAndValidateModelRef(r.From)
|
||||||
if !fromName.IsValid() {
|
if err != nil {
|
||||||
ch <- gin.H{"error": errtypes.InvalidModelNameErrMsg, "status": http.StatusBadRequest}
|
ch <- gin.H{"error": errtypes.InvalidModelNameErrMsg, "status": http.StatusBadRequest}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if r.RemoteHost != "" {
|
|
||||||
ru, err := remoteURL(r.RemoteHost)
|
fromName := fromRef.Name
|
||||||
|
remoteHost := r.RemoteHost
|
||||||
|
if fromRef.Source == modelSourceCloud && remoteHost == "" {
|
||||||
|
remoteHost = cloudProxyBaseURL
|
||||||
|
}
|
||||||
|
|
||||||
|
if remoteHost != "" {
|
||||||
|
ru, err := remoteURL(remoteHost)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ch <- gin.H{"error": "bad remote", "status": http.StatusBadRequest}
|
ch <- gin.H{"error": "bad remote", "status": http.StatusBadRequest}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
config.RemoteModel = r.From
|
config.RemoteModel = fromRef.Base
|
||||||
config.RemoteHost = ru
|
config.RemoteHost = ru
|
||||||
remote = true
|
remote = true
|
||||||
} else {
|
} else {
|
||||||
|
|||||||
@@ -71,6 +71,10 @@ type Model struct {
|
|||||||
Template *template.Template
|
Template *template.Template
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *Model) IsMLX() bool {
|
||||||
|
return m.Config.ModelFormat == "safetensors"
|
||||||
|
}
|
||||||
|
|
||||||
// Capabilities returns the capabilities that the model supports
|
// Capabilities returns the capabilities that the model supports
|
||||||
func (m *Model) Capabilities() []model.Capability {
|
func (m *Model) Capabilities() []model.Capability {
|
||||||
capabilities := []model.Capability{}
|
capabilities := []model.Capability{}
|
||||||
|
|||||||
81
server/model_resolver.go
Normal file
81
server/model_resolver.go
Normal file
@@ -0,0 +1,81 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/ollama/ollama/internal/modelref"
|
||||||
|
"github.com/ollama/ollama/types/model"
|
||||||
|
)
|
||||||
|
|
||||||
|
type modelSource = modelref.ModelSource
|
||||||
|
|
||||||
|
const (
|
||||||
|
modelSourceUnspecified modelSource = modelref.ModelSourceUnspecified
|
||||||
|
modelSourceLocal modelSource = modelref.ModelSourceLocal
|
||||||
|
modelSourceCloud modelSource = modelref.ModelSourceCloud
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
errConflictingModelSource = modelref.ErrConflictingSourceSuffix
|
||||||
|
errModelRequired = modelref.ErrModelRequired
|
||||||
|
)
|
||||||
|
|
||||||
|
type parsedModelRef struct {
|
||||||
|
// Original is the caller-provided model string before source parsing.
|
||||||
|
// Example: "gpt-oss:20b:cloud".
|
||||||
|
Original string
|
||||||
|
// Base is the model string after source suffix normalization.
|
||||||
|
// Example: "gpt-oss:20b:cloud" -> "gpt-oss:20b".
|
||||||
|
Base string
|
||||||
|
// Name is Base parsed as a fully-qualified model.Name with defaults applied.
|
||||||
|
// Example: "registry.ollama.ai/library/gpt-oss:20b".
|
||||||
|
Name model.Name
|
||||||
|
// Source captures explicit source intent from the original input.
|
||||||
|
// Example: "gpt-oss:20b:cloud" -> modelSourceCloud.
|
||||||
|
Source modelSource
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseAndValidateModelRef(raw string) (parsedModelRef, error) {
|
||||||
|
var zero parsedModelRef
|
||||||
|
|
||||||
|
parsed, err := modelref.ParseRef(raw)
|
||||||
|
if err != nil {
|
||||||
|
return zero, err
|
||||||
|
}
|
||||||
|
|
||||||
|
name := model.ParseName(parsed.Base)
|
||||||
|
if !name.IsValid() {
|
||||||
|
return zero, model.Unqualified(name)
|
||||||
|
}
|
||||||
|
|
||||||
|
return parsedModelRef{
|
||||||
|
Original: parsed.Original,
|
||||||
|
Base: parsed.Base,
|
||||||
|
Name: name,
|
||||||
|
Source: parsed.Source,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseNormalizePullModelRef(raw string) (parsedModelRef, error) {
|
||||||
|
var zero parsedModelRef
|
||||||
|
|
||||||
|
parsedRef, err := modelref.ParseRef(raw)
|
||||||
|
if err != nil {
|
||||||
|
return zero, err
|
||||||
|
}
|
||||||
|
|
||||||
|
normalizedName, _, err := modelref.NormalizePullName(raw)
|
||||||
|
if err != nil {
|
||||||
|
return zero, err
|
||||||
|
}
|
||||||
|
|
||||||
|
name := model.ParseName(normalizedName)
|
||||||
|
if !name.IsValid() {
|
||||||
|
return zero, model.Unqualified(name)
|
||||||
|
}
|
||||||
|
|
||||||
|
return parsedModelRef{
|
||||||
|
Original: parsedRef.Original,
|
||||||
|
Base: normalizedName,
|
||||||
|
Name: name,
|
||||||
|
Source: parsedRef.Source,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
170
server/model_resolver_test.go
Normal file
170
server/model_resolver_test.go
Normal file
@@ -0,0 +1,170 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestParseModelSelector(t *testing.T) {
|
||||||
|
t.Run("cloud suffix", func(t *testing.T) {
|
||||||
|
got, err := parseAndValidateModelRef("gpt-oss:20b:cloud")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parseModelSelector returned error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if got.Source != modelSourceCloud {
|
||||||
|
t.Fatalf("expected source cloud, got %v", got.Source)
|
||||||
|
}
|
||||||
|
|
||||||
|
if got.Base != "gpt-oss:20b" {
|
||||||
|
t.Fatalf("expected base gpt-oss:20b, got %q", got.Base)
|
||||||
|
}
|
||||||
|
|
||||||
|
if got.Name.String() != "registry.ollama.ai/library/gpt-oss:20b" {
|
||||||
|
t.Fatalf("unexpected resolved name: %q", got.Name.String())
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("legacy cloud suffix", func(t *testing.T) {
|
||||||
|
got, err := parseAndValidateModelRef("gpt-oss:20b-cloud")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parseModelSelector returned error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if got.Source != modelSourceCloud {
|
||||||
|
t.Fatalf("expected source cloud, got %v", got.Source)
|
||||||
|
}
|
||||||
|
|
||||||
|
if got.Base != "gpt-oss:20b" {
|
||||||
|
t.Fatalf("expected base gpt-oss:20b, got %q", got.Base)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("bare dash cloud name is not explicit cloud", func(t *testing.T) {
|
||||||
|
got, err := parseAndValidateModelRef("my-cloud-model")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parseModelSelector returned error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if got.Source != modelSourceUnspecified {
|
||||||
|
t.Fatalf("expected source unspecified, got %v", got.Source)
|
||||||
|
}
|
||||||
|
|
||||||
|
if got.Base != "my-cloud-model" {
|
||||||
|
t.Fatalf("expected base my-cloud-model, got %q", got.Base)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("local suffix", func(t *testing.T) {
|
||||||
|
got, err := parseAndValidateModelRef("qwen3:8b:local")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parseModelSelector returned error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if got.Source != modelSourceLocal {
|
||||||
|
t.Fatalf("expected source local, got %v", got.Source)
|
||||||
|
}
|
||||||
|
|
||||||
|
if got.Base != "qwen3:8b" {
|
||||||
|
t.Fatalf("expected base qwen3:8b, got %q", got.Base)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("conflicting source suffixes fail", func(t *testing.T) {
|
||||||
|
_, err := parseAndValidateModelRef("foo:cloud:local")
|
||||||
|
if !errors.Is(err, errConflictingModelSource) {
|
||||||
|
t.Fatalf("expected errConflictingModelSource, got %v", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("unspecified source", func(t *testing.T) {
|
||||||
|
got, err := parseAndValidateModelRef("llama3")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parseModelSelector returned error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if got.Source != modelSourceUnspecified {
|
||||||
|
t.Fatalf("expected source unspecified, got %v", got.Source)
|
||||||
|
}
|
||||||
|
|
||||||
|
if got.Name.Tag != "latest" {
|
||||||
|
t.Fatalf("expected default latest tag, got %q", got.Name.Tag)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("unknown suffix is treated as tag", func(t *testing.T) {
|
||||||
|
got, err := parseAndValidateModelRef("gpt-oss:clod")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parseModelSelector returned error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if got.Source != modelSourceUnspecified {
|
||||||
|
t.Fatalf("expected source unspecified, got %v", got.Source)
|
||||||
|
}
|
||||||
|
|
||||||
|
if got.Name.Tag != "clod" {
|
||||||
|
t.Fatalf("expected tag clod, got %q", got.Name.Tag)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("empty model fails", func(t *testing.T) {
|
||||||
|
_, err := parseAndValidateModelRef("")
|
||||||
|
if !errors.Is(err, errModelRequired) {
|
||||||
|
t.Fatalf("expected errModelRequired, got %v", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("invalid model fails", func(t *testing.T) {
|
||||||
|
_, err := parseAndValidateModelRef("::cloud")
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for invalid model")
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "unqualified") {
|
||||||
|
t.Fatalf("expected unqualified model error, got %v", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParsePullModelRef(t *testing.T) {
|
||||||
|
t.Run("explicit local is normalized", func(t *testing.T) {
|
||||||
|
got, err := parseNormalizePullModelRef("gpt-oss:20b:local")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parseNormalizePullModelRef returned error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if got.Source != modelSourceLocal {
|
||||||
|
t.Fatalf("expected source local, got %v", got.Source)
|
||||||
|
}
|
||||||
|
|
||||||
|
if got.Base != "gpt-oss:20b" {
|
||||||
|
t.Fatalf("expected base gpt-oss:20b, got %q", got.Base)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("explicit cloud with size maps to legacy cloud suffix", func(t *testing.T) {
|
||||||
|
got, err := parseNormalizePullModelRef("gpt-oss:20b:cloud")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parseNormalizePullModelRef returned error: %v", err)
|
||||||
|
}
|
||||||
|
if got.Base != "gpt-oss:20b-cloud" {
|
||||||
|
t.Fatalf("expected base gpt-oss:20b-cloud, got %q", got.Base)
|
||||||
|
}
|
||||||
|
if got.Name.String() != "registry.ollama.ai/library/gpt-oss:20b-cloud" {
|
||||||
|
t.Fatalf("unexpected resolved name: %q", got.Name.String())
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("explicit cloud without size maps to cloud tag", func(t *testing.T) {
|
||||||
|
got, err := parseNormalizePullModelRef("qwen3:cloud")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parseNormalizePullModelRef returned error: %v", err)
|
||||||
|
}
|
||||||
|
if got.Base != "qwen3:cloud" {
|
||||||
|
t.Fatalf("expected base qwen3:cloud, got %q", got.Base)
|
||||||
|
}
|
||||||
|
if got.Name.String() != "registry.ollama.ai/library/qwen3:cloud" {
|
||||||
|
t.Fatalf("unexpected resolved name: %q", got.Name.String())
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
@@ -30,6 +30,7 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
|
|||||||
lastMsgIdx := len(msgs) - 1
|
lastMsgIdx := len(msgs) - 1
|
||||||
currMsgIdx := 0
|
currMsgIdx := 0
|
||||||
|
|
||||||
|
if truncate {
|
||||||
// Start with all messages and remove from the front until it fits in context
|
// Start with all messages and remove from the front until it fits in context
|
||||||
for i := 0; i <= lastMsgIdx; i++ {
|
for i := 0; i <= lastMsgIdx; i++ {
|
||||||
// Collect system messages from the portion we're about to skip
|
// Collect system messages from the portion we're about to skip
|
||||||
@@ -57,7 +58,7 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if !truncate || ctxLen <= opts.NumCtx {
|
if ctxLen <= opts.NumCtx {
|
||||||
currMsgIdx = i
|
currMsgIdx = i
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
@@ -68,6 +69,7 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
|
|||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if currMsgIdx > 0 {
|
if currMsgIdx > 0 {
|
||||||
slog.Debug("truncating input messages which exceed context length", "truncated", len(msgs[currMsgIdx:]))
|
slog.Debug("truncating input messages which exceed context length", "truncated", len(msgs[currMsgIdx:]))
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package server
|
|||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/google/go-cmp/cmp"
|
"github.com/google/go-cmp/cmp"
|
||||||
@@ -366,3 +367,33 @@ func TestChatPromptRendererDoesNotRewriteMessageContent(t *testing.T) {
|
|||||||
t.Fatal("prompt is empty")
|
t.Fatal("prompt is empty")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestChatPromptGLMOcrRendererAddsImageTags(t *testing.T) {
|
||||||
|
msgs := []api.Message{
|
||||||
|
{
|
||||||
|
Role: "user",
|
||||||
|
Content: "extract text",
|
||||||
|
Images: []api.ImageData{[]byte("img-1"), []byte("img-2")},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
m := Model{
|
||||||
|
Config: model.ConfigV2{Renderer: "glm-ocr"},
|
||||||
|
ProjectorPaths: []string{"vision"},
|
||||||
|
}
|
||||||
|
opts := api.Options{Runner: api.Runner{NumCtx: 8192}}
|
||||||
|
think := false
|
||||||
|
|
||||||
|
prompt, images, err := chatPrompt(t.Context(), &m, mockRunner{}.Tokenize, &opts, msgs, nil, &api.ThinkValue{Value: think}, true)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if got, want := len(images), 2; got != want {
|
||||||
|
t.Fatalf("len(images) = %d, want %d", got, want)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !strings.Contains(prompt, "<|user|>\n[img-0][img-1]extract text") {
|
||||||
|
t.Fatalf("prompt missing glm-ocr image tags, got: %q", prompt)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
200
server/routes.go
200
server/routes.go
@@ -62,8 +62,21 @@ const signinURLStr = "https://ollama.com/connect?name=%s&key=%s"
|
|||||||
const (
|
const (
|
||||||
cloudErrRemoteInferenceUnavailable = "remote model is unavailable"
|
cloudErrRemoteInferenceUnavailable = "remote model is unavailable"
|
||||||
cloudErrRemoteModelDetailsUnavailable = "remote model details are unavailable"
|
cloudErrRemoteModelDetailsUnavailable = "remote model details are unavailable"
|
||||||
|
cloudErrWebSearchUnavailable = "web search is unavailable"
|
||||||
|
cloudErrWebFetchUnavailable = "web fetch is unavailable"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func writeModelRefParseError(c *gin.Context, err error, fallbackStatus int, fallbackMessage string) {
|
||||||
|
switch {
|
||||||
|
case errors.Is(err, errConflictingModelSource):
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
|
case errors.Is(err, model.ErrUnqualifiedName):
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": errtypes.InvalidModelNameErrMsg})
|
||||||
|
default:
|
||||||
|
c.JSON(fallbackStatus, gin.H{"error": fallbackMessage})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func shouldUseHarmony(model *Model) bool {
|
func shouldUseHarmony(model *Model) bool {
|
||||||
if slices.Contains([]string{"gptoss", "gpt-oss"}, model.Config.ModelFamily) {
|
if slices.Contains([]string{"gptoss", "gpt-oss"}, model.Config.ModelFamily) {
|
||||||
// heuristic to check whether the template expects to be parsed via harmony:
|
// heuristic to check whether the template expects to be parsed via harmony:
|
||||||
@@ -150,7 +163,7 @@ func (s *Server) scheduleRunner(ctx context.Context, name string, caps []model.C
|
|||||||
return nil, nil, nil, fmt.Errorf("%s %w", name, err)
|
return nil, nil, nil, fmt.Errorf("%s %w", name, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
useImagegen, _ := requestOpts["use_imagegen_runner"].(bool)
|
// Deprecated runner override option; ignore if present.
|
||||||
delete(requestOpts, "use_imagegen_runner")
|
delete(requestOpts, "use_imagegen_runner")
|
||||||
|
|
||||||
opts, err := s.modelOptions(model, requestOpts)
|
opts, err := s.modelOptions(model, requestOpts)
|
||||||
@@ -158,7 +171,7 @@ func (s *Server) scheduleRunner(ctx context.Context, name string, caps []model.C
|
|||||||
return nil, nil, nil, err
|
return nil, nil, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
runnerCh, errCh := s.sched.GetRunner(ctx, model, opts, keepAlive, useImagegen)
|
runnerCh, errCh := s.sched.GetRunner(ctx, model, opts, keepAlive)
|
||||||
var runner *runnerRef
|
var runner *runnerRef
|
||||||
select {
|
select {
|
||||||
case runner = <-runnerCh:
|
case runner = <-runnerCh:
|
||||||
@@ -196,14 +209,22 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
name := model.ParseName(req.Model)
|
modelRef, err := parseAndValidateModelRef(req.Model)
|
||||||
if !name.IsValid() {
|
if err != nil {
|
||||||
// Ideally this is "invalid model name" but we're keeping with
|
writeModelRefParseError(c, err, http.StatusNotFound, fmt.Sprintf("model '%s' not found", req.Model))
|
||||||
// what the API currently returns until we can change it.
|
|
||||||
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if modelRef.Source == modelSourceCloud {
|
||||||
|
// TODO(drifkin): evaluate an `/api/*` passthrough for cloud where the
|
||||||
|
// original body (modulo model name normalization) is sent to cloud.
|
||||||
|
req.Model = modelRef.Base
|
||||||
|
proxyCloudJSONRequest(c, req, cloudErrRemoteInferenceUnavailable)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
name := modelRef.Name
|
||||||
|
|
||||||
resolvedName, _, err := s.resolveAlias(name)
|
resolvedName, _, err := s.resolveAlias(name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
@@ -237,6 +258,11 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if modelRef.Source == modelSourceLocal && m.Config.RemoteHost != "" && m.Config.RemoteModel != "" {
|
||||||
|
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
if m.Config.RemoteHost != "" && m.Config.RemoteModel != "" {
|
if m.Config.RemoteHost != "" && m.Config.RemoteModel != "" {
|
||||||
if disabled, _ := internalcloud.Status(); disabled {
|
if disabled, _ := internalcloud.Status(); disabled {
|
||||||
c.JSON(http.StatusForbidden, gin.H{"error": internalcloud.DisabledError(cloudErrRemoteInferenceUnavailable)})
|
c.JSON(http.StatusForbidden, gin.H{"error": internalcloud.DisabledError(cloudErrRemoteInferenceUnavailable)})
|
||||||
@@ -370,12 +396,6 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Validate Think value: string values currently only allowed for harmony/gptoss models
|
|
||||||
if req.Think != nil && req.Think.IsString() && m.Config.Parser != "harmony" {
|
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("think value %q is not supported for this model", req.Think.String())})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
caps := []model.Capability{model.CapabilityCompletion}
|
caps := []model.Capability{model.CapabilityCompletion}
|
||||||
if req.Suffix != "" {
|
if req.Suffix != "" {
|
||||||
caps = append(caps, model.CapabilityInsert)
|
caps = append(caps, model.CapabilityInsert)
|
||||||
@@ -484,7 +504,8 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
|||||||
// the real chat handler, but doing this as a stopgap to get renderer
|
// the real chat handler, but doing this as a stopgap to get renderer
|
||||||
// support for generate
|
// support for generate
|
||||||
if values.Messages != nil && values.Suffix == "" && req.Template == "" {
|
if values.Messages != nil && values.Suffix == "" && req.Template == "" {
|
||||||
prompt, images, err = chatPrompt(c.Request.Context(), m, r.Tokenize, opts, values.Messages, []api.Tool{}, req.Think, req.Truncate == nil || *req.Truncate)
|
genTruncate := (req.Truncate == nil || *req.Truncate) && !m.IsMLX()
|
||||||
|
prompt, images, err = chatPrompt(c.Request.Context(), m, r.Tokenize, opts, values.Messages, []api.Tool{}, req.Think, genTruncate)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
return
|
return
|
||||||
@@ -675,6 +696,18 @@ func (s *Server) EmbedHandler(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
modelRef, err := parseAndValidateModelRef(req.Model)
|
||||||
|
if err != nil {
|
||||||
|
writeModelRefParseError(c, err, http.StatusNotFound, fmt.Sprintf("model '%s' not found", req.Model))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if modelRef.Source == modelSourceCloud {
|
||||||
|
req.Model = modelRef.Base
|
||||||
|
proxyCloudJSONRequest(c, req, cloudErrRemoteInferenceUnavailable)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
var input []string
|
var input []string
|
||||||
|
|
||||||
switch i := req.Input.(type) {
|
switch i := req.Input.(type) {
|
||||||
@@ -697,7 +730,7 @@ func (s *Server) EmbedHandler(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
name, err := getExistingName(model.ParseName(req.Model))
|
name, err := getExistingName(modelRef.Name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)})
|
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)})
|
||||||
return
|
return
|
||||||
@@ -844,12 +877,20 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
name := model.ParseName(req.Model)
|
modelRef, err := parseAndValidateModelRef(req.Model)
|
||||||
if !name.IsValid() {
|
if err != nil {
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": "model is required"})
|
writeModelRefParseError(c, err, http.StatusBadRequest, "model is required")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if modelRef.Source == modelSourceCloud {
|
||||||
|
req.Model = modelRef.Base
|
||||||
|
proxyCloudJSONRequest(c, req, cloudErrRemoteInferenceUnavailable)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
name := modelRef.Name
|
||||||
|
|
||||||
r, _, _, err := s.scheduleRunner(c.Request.Context(), name.String(), []model.Capability{}, req.Options, req.KeepAlive)
|
r, _, _, err := s.scheduleRunner(c.Request.Context(), name.String(), []model.Capability{}, req.Options, req.KeepAlive)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
handleScheduleError(c, req.Model, err)
|
handleScheduleError(c, req.Model, err)
|
||||||
@@ -891,12 +932,19 @@ func (s *Server) PullHandler(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
name := model.ParseName(cmp.Or(req.Model, req.Name))
|
// TEMP(drifkin): we're temporarily allowing to continue pulling cloud model
|
||||||
if !name.IsValid() {
|
// stub-files until we integrate cloud models into `/api/tags` (in which case
|
||||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": errtypes.InvalidModelNameErrMsg})
|
// this roundabout way of "adding" cloud models won't be needed anymore). So
|
||||||
|
// right here normalize any `:cloud` models into the legacy-style suffixes
|
||||||
|
// `:<tag>-cloud` and `:cloud`
|
||||||
|
modelRef, err := parseNormalizePullModelRef(cmp.Or(req.Model, req.Name))
|
||||||
|
if err != nil {
|
||||||
|
writeModelRefParseError(c, err, http.StatusBadRequest, errtypes.InvalidModelNameErrMsg)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
name := modelRef.Name
|
||||||
|
|
||||||
name, err = getExistingName(name)
|
name, err = getExistingName(name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
@@ -1023,13 +1071,20 @@ func (s *Server) DeleteHandler(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
n := model.ParseName(cmp.Or(r.Model, r.Name))
|
modelRef, err := parseNormalizePullModelRef(cmp.Or(r.Model, r.Name))
|
||||||
if !n.IsValid() {
|
if err != nil {
|
||||||
|
switch {
|
||||||
|
case errors.Is(err, errConflictingModelSource):
|
||||||
|
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
|
case errors.Is(err, model.ErrUnqualifiedName):
|
||||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("name %q is invalid", cmp.Or(r.Model, r.Name))})
|
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("name %q is invalid", cmp.Or(r.Model, r.Name))})
|
||||||
|
default:
|
||||||
|
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
n, err := getExistingName(n)
|
n, err := getExistingName(modelRef.Name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", cmp.Or(r.Model, r.Name))})
|
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", cmp.Or(r.Model, r.Name))})
|
||||||
return
|
return
|
||||||
@@ -1078,6 +1133,20 @@ func (s *Server) ShowHandler(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
modelRef, err := parseAndValidateModelRef(req.Model)
|
||||||
|
if err != nil {
|
||||||
|
writeModelRefParseError(c, err, http.StatusBadRequest, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if modelRef.Source == modelSourceCloud {
|
||||||
|
req.Model = modelRef.Base
|
||||||
|
proxyCloudJSONRequest(c, req, cloudErrRemoteModelDetailsUnavailable)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
req.Model = modelRef.Base
|
||||||
|
|
||||||
resp, err := GetModelInfo(req)
|
resp, err := GetModelInfo(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
var statusErr api.StatusError
|
var statusErr api.StatusError
|
||||||
@@ -1094,6 +1163,11 @@ func (s *Server) ShowHandler(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if modelRef.Source == modelSourceLocal && resp.RemoteHost != "" {
|
||||||
|
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", modelRef.Original)})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
c.JSON(http.StatusOK, resp)
|
c.JSON(http.StatusOK, resp)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1621,6 +1695,8 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
|
|||||||
r.GET("/api/experimental/aliases", s.ListAliasesHandler)
|
r.GET("/api/experimental/aliases", s.ListAliasesHandler)
|
||||||
r.POST("/api/experimental/aliases", s.CreateAliasHandler)
|
r.POST("/api/experimental/aliases", s.CreateAliasHandler)
|
||||||
r.DELETE("/api/experimental/aliases", s.DeleteAliasHandler)
|
r.DELETE("/api/experimental/aliases", s.DeleteAliasHandler)
|
||||||
|
r.POST("/api/experimental/web_search", s.WebSearchExperimentalHandler)
|
||||||
|
r.POST("/api/experimental/web_fetch", s.WebFetchExperimentalHandler)
|
||||||
|
|
||||||
// Inference
|
// Inference
|
||||||
r.GET("/api/ps", s.PsHandler)
|
r.GET("/api/ps", s.PsHandler)
|
||||||
@@ -1630,18 +1706,20 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
|
|||||||
r.POST("/api/embeddings", s.EmbeddingsHandler)
|
r.POST("/api/embeddings", s.EmbeddingsHandler)
|
||||||
|
|
||||||
// Inference (OpenAI compatibility)
|
// Inference (OpenAI compatibility)
|
||||||
r.POST("/v1/chat/completions", middleware.ChatMiddleware(), s.ChatHandler)
|
// TODO(cloud-stage-a): apply Modelfile overlay deltas for local models with cloud
|
||||||
r.POST("/v1/completions", middleware.CompletionsMiddleware(), s.GenerateHandler)
|
// parents on v1 request families while preserving this explicit :cloud passthrough.
|
||||||
r.POST("/v1/embeddings", middleware.EmbeddingsMiddleware(), s.EmbedHandler)
|
r.POST("/v1/chat/completions", cloudPassthroughMiddleware(cloudErrRemoteInferenceUnavailable), middleware.ChatMiddleware(), s.ChatHandler)
|
||||||
|
r.POST("/v1/completions", cloudPassthroughMiddleware(cloudErrRemoteInferenceUnavailable), middleware.CompletionsMiddleware(), s.GenerateHandler)
|
||||||
|
r.POST("/v1/embeddings", cloudPassthroughMiddleware(cloudErrRemoteInferenceUnavailable), middleware.EmbeddingsMiddleware(), s.EmbedHandler)
|
||||||
r.GET("/v1/models", middleware.ListMiddleware(), s.ListHandler)
|
r.GET("/v1/models", middleware.ListMiddleware(), s.ListHandler)
|
||||||
r.GET("/v1/models/:model", middleware.RetrieveMiddleware(), s.ShowHandler)
|
r.GET("/v1/models/:model", cloudModelPathPassthroughMiddleware(cloudErrRemoteModelDetailsUnavailable), middleware.RetrieveMiddleware(), s.ShowHandler)
|
||||||
r.POST("/v1/responses", middleware.ResponsesMiddleware(), s.ChatHandler)
|
r.POST("/v1/responses", cloudPassthroughMiddleware(cloudErrRemoteInferenceUnavailable), middleware.ResponsesMiddleware(), s.ChatHandler)
|
||||||
// OpenAI-compatible image generation endpoints
|
// OpenAI-compatible image generation endpoints
|
||||||
r.POST("/v1/images/generations", middleware.ImageGenerationsMiddleware(), s.GenerateHandler)
|
r.POST("/v1/images/generations", cloudPassthroughMiddleware(cloudErrRemoteInferenceUnavailable), middleware.ImageGenerationsMiddleware(), s.GenerateHandler)
|
||||||
r.POST("/v1/images/edits", middleware.ImageEditsMiddleware(), s.GenerateHandler)
|
r.POST("/v1/images/edits", cloudPassthroughMiddleware(cloudErrRemoteInferenceUnavailable), middleware.ImageEditsMiddleware(), s.GenerateHandler)
|
||||||
|
|
||||||
// Inference (Anthropic compatibility)
|
// Inference (Anthropic compatibility)
|
||||||
r.POST("/v1/messages", middleware.AnthropicMessagesMiddleware(), s.ChatHandler)
|
r.POST("/v1/messages", cloudPassthroughMiddleware(cloudErrRemoteInferenceUnavailable), middleware.AnthropicMessagesMiddleware(), s.ChatHandler)
|
||||||
|
|
||||||
if rc != nil {
|
if rc != nil {
|
||||||
// wrap old with new
|
// wrap old with new
|
||||||
@@ -1863,6 +1941,29 @@ func (s *Server) StatusHandler(c *gin.Context) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *Server) WebSearchExperimentalHandler(c *gin.Context) {
|
||||||
|
s.webExperimentalProxyHandler(c, "/api/web_search", cloudErrWebSearchUnavailable)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) WebFetchExperimentalHandler(c *gin.Context) {
|
||||||
|
s.webExperimentalProxyHandler(c, "/api/web_fetch", cloudErrWebFetchUnavailable)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) webExperimentalProxyHandler(c *gin.Context, proxyPath, disabledOperation string) {
|
||||||
|
body, err := readRequestBody(c.Request)
|
||||||
|
if err != nil {
|
||||||
|
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(bytes.TrimSpace(body)) == 0 {
|
||||||
|
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
proxyCloudRequestWithPath(c, body, proxyPath, disabledOperation)
|
||||||
|
}
|
||||||
|
|
||||||
func (s *Server) WhoamiHandler(c *gin.Context) {
|
func (s *Server) WhoamiHandler(c *gin.Context) {
|
||||||
// todo allow other hosts
|
// todo allow other hosts
|
||||||
u, err := url.Parse("https://ollama.com")
|
u, err := url.Parse("https://ollama.com")
|
||||||
@@ -1951,6 +2052,9 @@ func (s *Server) PsHandler(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
if v.llama != nil {
|
if v.llama != nil {
|
||||||
mr.ContextLength = v.llama.ContextLength()
|
mr.ContextLength = v.llama.ContextLength()
|
||||||
|
total, vram := v.llama.MemorySize()
|
||||||
|
mr.Size = int64(total)
|
||||||
|
mr.SizeVRAM = int64(vram)
|
||||||
}
|
}
|
||||||
// The scheduler waits to set expiresAt, so if a model is loading it's
|
// The scheduler waits to set expiresAt, so if a model is loading it's
|
||||||
// possible that it will be set to the unix epoch. For those cases, just
|
// possible that it will be set to the unix epoch. For those cases, just
|
||||||
@@ -1997,12 +2101,24 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
name := model.ParseName(req.Model)
|
modelRef, err := parseAndValidateModelRef(req.Model)
|
||||||
if !name.IsValid() {
|
if err != nil {
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": "model is required"})
|
writeModelRefParseError(c, err, http.StatusBadRequest, "model is required")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if modelRef.Source == modelSourceCloud {
|
||||||
|
req.Model = modelRef.Base
|
||||||
|
if c.GetBool(legacyCloudAnthropicKey) {
|
||||||
|
proxyCloudJSONRequestWithPath(c, req, "/api/chat", cloudErrRemoteInferenceUnavailable)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
proxyCloudJSONRequest(c, req, cloudErrRemoteInferenceUnavailable)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
name := modelRef.Name
|
||||||
|
|
||||||
resolvedName, _, err := s.resolveAlias(name)
|
resolvedName, _, err := s.resolveAlias(name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
@@ -2034,6 +2150,11 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if modelRef.Source == modelSourceLocal && m.Config.RemoteHost != "" && m.Config.RemoteModel != "" {
|
||||||
|
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// expire the runner
|
// expire the runner
|
||||||
if len(req.Messages) == 0 && req.KeepAlive != nil && req.KeepAlive.Duration == 0 {
|
if len(req.Messages) == 0 && req.KeepAlive != nil && req.KeepAlive.Duration == 0 {
|
||||||
s.sched.expireRunner(m)
|
s.sched.expireRunner(m)
|
||||||
@@ -2213,6 +2334,9 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
truncate := req.Truncate == nil || *req.Truncate
|
truncate := req.Truncate == nil || *req.Truncate
|
||||||
|
if m.IsMLX() {
|
||||||
|
truncate = false
|
||||||
|
}
|
||||||
prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, msgs, processedTools, req.Think, truncate)
|
prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, msgs, processedTools, req.Think, truncate)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Error("chat prompt error", "error", err)
|
slog.Error("chat prompt error", "error", err)
|
||||||
@@ -2233,12 +2357,6 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Validate Think value: string values currently only allowed for harmony/gptoss models
|
|
||||||
if req.Think != nil && req.Think.IsString() && m.Config.Parser != "harmony" {
|
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("think value %q is not supported for this model", req.Think.String())})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
var thinkingState *thinking.Parser
|
var thinkingState *thinking.Parser
|
||||||
openingTag, closingTag := thinking.InferTags(m.Template.Template)
|
openingTag, closingTag := thinking.InferTags(m.Template.Template)
|
||||||
if req.Think != nil && req.Think.Bool() && openingTag != "" && closingTag != "" {
|
if req.Think != nil && req.Think.Bool() && openingTag != "" && closingTag != "" {
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -144,6 +144,37 @@ func TestCreateFromBin(t *testing.T) {
|
|||||||
filepath.Join(p, "blobs", "sha256-6bcdb8859d417753645538d7bbfbd7ca91a3f0c191aef5379c53c05e86b669dd"),
|
filepath.Join(p, "blobs", "sha256-6bcdb8859d417753645538d7bbfbd7ca91a3f0c191aef5379c53c05e86b669dd"),
|
||||||
filepath.Join(p, "blobs", "sha256-89a2116c3a82d6a97f59f748d86ed4417214353fd178ee54df418fde32495fad"),
|
filepath.Join(p, "blobs", "sha256-89a2116c3a82d6a97f59f748d86ed4417214353fd178ee54df418fde32495fad"),
|
||||||
})
|
})
|
||||||
|
|
||||||
|
t.Run("empty file digest", func(t *testing.T) {
|
||||||
|
w := createRequest(t, s.CreateHandler, api.CreateRequest{
|
||||||
|
Name: "my-gguf-model",
|
||||||
|
Files: map[string]string{"0.gguf": ""},
|
||||||
|
Stream: &stream,
|
||||||
|
})
|
||||||
|
|
||||||
|
if w.Code != http.StatusBadRequest {
|
||||||
|
t.Fatalf("expected status 400, got %d", w.Code)
|
||||||
|
}
|
||||||
|
if !strings.Contains(w.Body.String(), "invalid digest format") {
|
||||||
|
t.Errorf("expected invalid digest format error, got:\n%s", w.Body.String())
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("empty adapter digest", func(t *testing.T) {
|
||||||
|
w := createRequest(t, s.CreateHandler, api.CreateRequest{
|
||||||
|
Name: "my-gguf-model",
|
||||||
|
Files: map[string]string{"0.gguf": digest},
|
||||||
|
Adapters: map[string]string{"adapter.gguf": ""},
|
||||||
|
Stream: &stream,
|
||||||
|
})
|
||||||
|
|
||||||
|
if w.Code != http.StatusBadRequest {
|
||||||
|
t.Fatalf("expected status 400, got %d", w.Code)
|
||||||
|
}
|
||||||
|
if !strings.Contains(w.Body.String(), "invalid digest format") {
|
||||||
|
t.Errorf("expected invalid digest format error, got:\n%s", w.Body.String())
|
||||||
|
}
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCreateFromModel(t *testing.T) {
|
func TestCreateFromModel(t *testing.T) {
|
||||||
@@ -763,6 +794,43 @@ func TestCreateAndShowRemoteModel(t *testing.T) {
|
|||||||
fmt.Printf("resp = %#v\n", resp)
|
fmt.Printf("resp = %#v\n", resp)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestCreateFromCloudSourceSuffix(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
var s Server
|
||||||
|
|
||||||
|
w := createRequest(t, s.CreateHandler, api.CreateRequest{
|
||||||
|
Model: "test-cloud-from-suffix",
|
||||||
|
From: "gpt-oss:20b:cloud",
|
||||||
|
Info: map[string]any{
|
||||||
|
"capabilities": []string{"completion"},
|
||||||
|
},
|
||||||
|
Stream: &stream,
|
||||||
|
})
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected status code 200, got %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
w = createRequest(t, s.ShowHandler, api.ShowRequest{Model: "test-cloud-from-suffix"})
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected status code 200, got %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
var resp api.ShowResponse
|
||||||
|
if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.RemoteHost != "https://ollama.com:443" {
|
||||||
|
t.Fatalf("expected remote host https://ollama.com:443, got %q", resp.RemoteHost)
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.RemoteModel != "gpt-oss:20b" {
|
||||||
|
t.Fatalf("expected remote model gpt-oss:20b, got %q", resp.RemoteModel)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestCreateLicenses(t *testing.T) {
|
func TestCreateLicenses(t *testing.T) {
|
||||||
gin.SetMode(gin.TestMode)
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
|||||||
@@ -111,3 +111,32 @@ func TestDeleteDuplicateLayers(t *testing.T) {
|
|||||||
|
|
||||||
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{})
|
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestDeleteCloudSourceNormalizesToLegacyName(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
p := t.TempDir()
|
||||||
|
t.Setenv("OLLAMA_MODELS", p)
|
||||||
|
|
||||||
|
var s Server
|
||||||
|
|
||||||
|
_, digest := createBinFile(t, nil, nil)
|
||||||
|
w := createRequest(t, s.CreateHandler, api.CreateRequest{
|
||||||
|
Name: "gpt-oss:20b-cloud",
|
||||||
|
Files: map[string]string{"test.gguf": digest},
|
||||||
|
})
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected status code 200, actual %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{
|
||||||
|
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "gpt-oss", "20b-cloud"),
|
||||||
|
})
|
||||||
|
|
||||||
|
w = createRequest(t, s.DeleteHandler, api.DeleteRequest{Name: "gpt-oss:20b:cloud"})
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected status code 200, actual %d (%s)", w.Code, w.Body.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{})
|
||||||
|
}
|
||||||
|
|||||||
335
server/routes_web_experimental_test.go
Normal file
335
server/routes_web_experimental_test.go
Normal file
@@ -0,0 +1,335 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
internalcloud "github.com/ollama/ollama/internal/cloud"
|
||||||
|
)
|
||||||
|
|
||||||
|
type webExperimentalUpstreamCapture struct {
|
||||||
|
path string
|
||||||
|
body string
|
||||||
|
header http.Header
|
||||||
|
}
|
||||||
|
|
||||||
|
func newWebExperimentalUpstream(t *testing.T, responseBody string) (*httptest.Server, *webExperimentalUpstreamCapture) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
capture := &webExperimentalUpstreamCapture{}
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
payload, _ := io.ReadAll(r.Body)
|
||||||
|
capture.path = r.URL.Path
|
||||||
|
capture.body = string(payload)
|
||||||
|
capture.header = r.Header.Clone()
|
||||||
|
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
_, _ = w.Write([]byte(responseBody))
|
||||||
|
}))
|
||||||
|
|
||||||
|
return srv, capture
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExperimentalWebEndpointsPassthrough(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
setTestHome(t, t.TempDir())
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
localPath string
|
||||||
|
upstreamPath string
|
||||||
|
requestBody string
|
||||||
|
responseBody string
|
||||||
|
assertBody string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "web_search",
|
||||||
|
localPath: "/api/experimental/web_search",
|
||||||
|
upstreamPath: "/api/web_search",
|
||||||
|
requestBody: `{"query":"what is ollama?","max_results":3}`,
|
||||||
|
responseBody: `{"results":[{"title":"Ollama","url":"https://ollama.com","content":"Cloud models are now available"}]}`,
|
||||||
|
assertBody: `"query":"what is ollama?"`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "web_fetch",
|
||||||
|
localPath: "/api/experimental/web_fetch",
|
||||||
|
upstreamPath: "/api/web_fetch",
|
||||||
|
requestBody: `{"url":"https://ollama.com"}`,
|
||||||
|
responseBody: `{"title":"Ollama","content":"Cloud models are now available","links":["https://ollama.com/"]}`,
|
||||||
|
assertBody: `"url":"https://ollama.com"`,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
upstream, capture := newWebExperimentalUpstream(t, tt.responseBody)
|
||||||
|
defer upstream.Close()
|
||||||
|
|
||||||
|
original := cloudProxyBaseURL
|
||||||
|
cloudProxyBaseURL = upstream.URL
|
||||||
|
t.Cleanup(func() { cloudProxyBaseURL = original })
|
||||||
|
|
||||||
|
s := &Server{}
|
||||||
|
router, err := s.GenerateRoutes(nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
local := httptest.NewServer(router)
|
||||||
|
defer local.Close()
|
||||||
|
|
||||||
|
req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, local.URL+tt.localPath, bytes.NewBufferString(tt.requestBody))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
req.Header.Set("Authorization", "Bearer should-forward")
|
||||||
|
req.Header.Set("X-Test-Header", "web-experimental")
|
||||||
|
|
||||||
|
resp, err := local.Client().Do(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
body, _ := io.ReadAll(resp.Body)
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
t.Fatalf("expected status 200, got %d (%s)", resp.StatusCode, string(body))
|
||||||
|
}
|
||||||
|
if capture.path != tt.upstreamPath {
|
||||||
|
t.Fatalf("expected upstream path %q, got %q", tt.upstreamPath, capture.path)
|
||||||
|
}
|
||||||
|
if !bytes.Contains([]byte(capture.body), []byte(tt.assertBody)) {
|
||||||
|
t.Fatalf("expected upstream body to contain %q, got %q", tt.assertBody, capture.body)
|
||||||
|
}
|
||||||
|
if got := capture.header.Get("Authorization"); got != "Bearer should-forward" {
|
||||||
|
t.Fatalf("expected forwarded Authorization header, got %q", got)
|
||||||
|
}
|
||||||
|
if got := capture.header.Get("X-Test-Header"); got != "web-experimental" {
|
||||||
|
t.Fatalf("expected forwarded X-Test-Header=web-experimental, got %q", got)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExperimentalWebEndpointsMissingBody(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
setTestHome(t, t.TempDir())
|
||||||
|
|
||||||
|
s := &Server{}
|
||||||
|
router, err := s.GenerateRoutes(nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
local := httptest.NewServer(router)
|
||||||
|
defer local.Close()
|
||||||
|
|
||||||
|
tests := []string{
|
||||||
|
"/api/experimental/web_search",
|
||||||
|
"/api/experimental/web_fetch",
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, path := range tests {
|
||||||
|
t.Run(path, func(t *testing.T) {
|
||||||
|
req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, local.URL+path, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
resp, err := local.Client().Do(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
body, _ := io.ReadAll(resp.Body)
|
||||||
|
if resp.StatusCode != http.StatusBadRequest {
|
||||||
|
t.Fatalf("expected status 400, got %d (%s)", resp.StatusCode, string(body))
|
||||||
|
}
|
||||||
|
if string(body) != `{"error":"missing request body"}` {
|
||||||
|
t.Fatalf("unexpected response body: %s", string(body))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExperimentalWebEndpointsCloudDisabled(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
setTestHome(t, t.TempDir())
|
||||||
|
t.Setenv("OLLAMA_NO_CLOUD", "1")
|
||||||
|
|
||||||
|
s := &Server{}
|
||||||
|
router, err := s.GenerateRoutes(nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
local := httptest.NewServer(router)
|
||||||
|
defer local.Close()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
path string
|
||||||
|
request string
|
||||||
|
operation string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "web_search",
|
||||||
|
path: "/api/experimental/web_search",
|
||||||
|
request: `{"query":"latest ollama release"}`,
|
||||||
|
operation: cloudErrWebSearchUnavailable,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "web_fetch",
|
||||||
|
path: "/api/experimental/web_fetch",
|
||||||
|
request: `{"url":"https://ollama.com"}`,
|
||||||
|
operation: cloudErrWebFetchUnavailable,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, local.URL+tt.path, bytes.NewBufferString(tt.request))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
resp, err := local.Client().Do(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
body, _ := io.ReadAll(resp.Body)
|
||||||
|
if resp.StatusCode != http.StatusForbidden {
|
||||||
|
t.Fatalf("expected status 403, got %d (%s)", resp.StatusCode, string(body))
|
||||||
|
}
|
||||||
|
|
||||||
|
var got map[string]string
|
||||||
|
if err := json.Unmarshal(body, &got); err != nil {
|
||||||
|
t.Fatalf("expected json error body, got: %q", string(body))
|
||||||
|
}
|
||||||
|
if got["error"] != internalcloud.DisabledError(tt.operation) {
|
||||||
|
t.Fatalf("unexpected error message: %q", got["error"])
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExperimentalWebEndpointSigningFailureReturnsUnauthorized(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
setTestHome(t, t.TempDir())
|
||||||
|
|
||||||
|
origSignRequest := cloudProxySignRequest
|
||||||
|
origSigninURL := cloudProxySigninURL
|
||||||
|
cloudProxySignRequest = func(context.Context, *http.Request) error {
|
||||||
|
return errors.New("ssh: no key found")
|
||||||
|
}
|
||||||
|
cloudProxySigninURL = func() (string, error) {
|
||||||
|
return "https://ollama.com/signin/example", nil
|
||||||
|
}
|
||||||
|
t.Cleanup(func() {
|
||||||
|
cloudProxySignRequest = origSignRequest
|
||||||
|
cloudProxySigninURL = origSigninURL
|
||||||
|
})
|
||||||
|
|
||||||
|
s := &Server{}
|
||||||
|
router, err := s.GenerateRoutes(nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
local := httptest.NewServer(router)
|
||||||
|
defer local.Close()
|
||||||
|
|
||||||
|
req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, local.URL+"/api/experimental/web_search", bytes.NewBufferString(`{"query":"hello"}`))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
resp, err := local.Client().Do(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
body, _ := io.ReadAll(resp.Body)
|
||||||
|
if resp.StatusCode != http.StatusUnauthorized {
|
||||||
|
t.Fatalf("expected status 401, got %d (%s)", resp.StatusCode, string(body))
|
||||||
|
}
|
||||||
|
|
||||||
|
var got map[string]any
|
||||||
|
if err := json.Unmarshal(body, &got); err != nil {
|
||||||
|
t.Fatalf("expected json error body, got: %q", string(body))
|
||||||
|
}
|
||||||
|
if got["error"] != "unauthorized" {
|
||||||
|
t.Fatalf("unexpected error message: %v", got["error"])
|
||||||
|
}
|
||||||
|
if got["signin_url"] != "https://ollama.com/signin/example" {
|
||||||
|
t.Fatalf("unexpected signin_url: %v", got["signin_url"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExperimentalWebEndpointSigningFailureWithoutSigninURL(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
setTestHome(t, t.TempDir())
|
||||||
|
|
||||||
|
origSignRequest := cloudProxySignRequest
|
||||||
|
origSigninURL := cloudProxySigninURL
|
||||||
|
cloudProxySignRequest = func(context.Context, *http.Request) error {
|
||||||
|
return errors.New("ssh: no key found")
|
||||||
|
}
|
||||||
|
cloudProxySigninURL = func() (string, error) {
|
||||||
|
return "", errors.New("key missing")
|
||||||
|
}
|
||||||
|
t.Cleanup(func() {
|
||||||
|
cloudProxySignRequest = origSignRequest
|
||||||
|
cloudProxySigninURL = origSigninURL
|
||||||
|
})
|
||||||
|
|
||||||
|
s := &Server{}
|
||||||
|
router, err := s.GenerateRoutes(nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
local := httptest.NewServer(router)
|
||||||
|
defer local.Close()
|
||||||
|
|
||||||
|
req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, local.URL+"/api/experimental/web_fetch", bytes.NewBufferString(`{"url":"https://ollama.com"}`))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
resp, err := local.Client().Do(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
body, _ := io.ReadAll(resp.Body)
|
||||||
|
if resp.StatusCode != http.StatusUnauthorized {
|
||||||
|
t.Fatalf("expected status 401, got %d (%s)", resp.StatusCode, string(body))
|
||||||
|
}
|
||||||
|
|
||||||
|
var got map[string]any
|
||||||
|
if err := json.Unmarshal(body, &got); err != nil {
|
||||||
|
t.Fatalf("expected json error body, got: %q", string(body))
|
||||||
|
}
|
||||||
|
if got["error"] != "unauthorized" {
|
||||||
|
t.Fatalf("unexpected error message: %v", got["error"])
|
||||||
|
}
|
||||||
|
if _, ok := got["signin_url"]; ok {
|
||||||
|
t.Fatalf("did not expect signin_url when helper fails, got %v", got["signin_url"])
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -33,7 +33,6 @@ type LlmRequest struct {
|
|||||||
successCh chan *runnerRef
|
successCh chan *runnerRef
|
||||||
errCh chan error
|
errCh chan error
|
||||||
schedAttempts uint
|
schedAttempts uint
|
||||||
useImagegen bool
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type Scheduler struct {
|
type Scheduler struct {
|
||||||
@@ -106,7 +105,7 @@ func schedulerModelKey(m *Model) string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// context must be canceled to decrement ref count and release the runner
|
// context must be canceled to decrement ref count and release the runner
|
||||||
func (s *Scheduler) GetRunner(c context.Context, m *Model, opts api.Options, sessionDuration *api.Duration, useImagegen bool) (chan *runnerRef, chan error) {
|
func (s *Scheduler) GetRunner(c context.Context, m *Model, opts api.Options, sessionDuration *api.Duration) (chan *runnerRef, chan error) {
|
||||||
if opts.NumCtx < 4 {
|
if opts.NumCtx < 4 {
|
||||||
opts.NumCtx = 4
|
opts.NumCtx = 4
|
||||||
}
|
}
|
||||||
@@ -123,7 +122,6 @@ func (s *Scheduler) GetRunner(c context.Context, m *Model, opts api.Options, ses
|
|||||||
sessionDuration: sessionDuration,
|
sessionDuration: sessionDuration,
|
||||||
successCh: make(chan *runnerRef, 1),
|
successCh: make(chan *runnerRef, 1),
|
||||||
errCh: make(chan error, 1),
|
errCh: make(chan error, 1),
|
||||||
useImagegen: useImagegen,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
key := schedulerModelKey(req.model)
|
key := schedulerModelKey(req.model)
|
||||||
@@ -231,7 +229,7 @@ func (s *Scheduler) processPending(ctx context.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Check for experimental safetensors LLM models
|
// Check for experimental safetensors LLM models
|
||||||
if pending.model.Config.ModelFormat == "safetensors" {
|
if pending.model.IsMLX() {
|
||||||
if slices.Contains(pending.model.Config.Capabilities, "completion") {
|
if slices.Contains(pending.model.Config.Capabilities, "completion") {
|
||||||
// LLM model with safetensors format - use MLX runner
|
// LLM model with safetensors format - use MLX runner
|
||||||
if s.loadMLX(pending) {
|
if s.loadMLX(pending) {
|
||||||
@@ -536,6 +534,7 @@ iGPUScan:
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
totalSize, vramSize := llama.MemorySize()
|
||||||
runner := &runnerRef{
|
runner := &runnerRef{
|
||||||
model: req.model,
|
model: req.model,
|
||||||
modelPath: req.model.ModelPath,
|
modelPath: req.model.ModelPath,
|
||||||
@@ -545,8 +544,8 @@ iGPUScan:
|
|||||||
sessionDuration: sessionDuration,
|
sessionDuration: sessionDuration,
|
||||||
gpus: gpuIDs,
|
gpus: gpuIDs,
|
||||||
discreteGPUs: discreteGPUs,
|
discreteGPUs: discreteGPUs,
|
||||||
vramSize: llama.VRAMSize(),
|
totalSize: totalSize,
|
||||||
totalSize: llama.TotalSize(),
|
vramSize: vramSize,
|
||||||
loading: true,
|
loading: true,
|
||||||
pid: llama.Pid(),
|
pid: llama.Pid(),
|
||||||
}
|
}
|
||||||
@@ -592,20 +591,15 @@ iGPUScan:
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// loadMLX loads an experimental safetensors model using the unified MLX runner.
|
// loadMLX loads an experimental safetensors model using MLX runners.
|
||||||
// This supports both LLM (completion) and image generation models.
|
// Image models use x/imagegen; LLM models use x/mlxrunner.
|
||||||
func (s *Scheduler) loadMLX(req *LlmRequest) bool {
|
func (s *Scheduler) loadMLX(req *LlmRequest) bool {
|
||||||
modelName := req.model.ShortName
|
modelName := req.model.ShortName
|
||||||
var server llm.LlamaServer
|
var server llm.LlamaServer
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
isImagegen := false
|
|
||||||
if slices.Contains(req.model.Config.Capabilities, "image") {
|
if slices.Contains(req.model.Config.Capabilities, "image") {
|
||||||
server, err = imagegen.NewServer(modelName, imagegen.ModeImageGen)
|
server, err = imagegen.NewServer(modelName)
|
||||||
isImagegen = true
|
|
||||||
} else if req.useImagegen {
|
|
||||||
server, err = imagegen.NewServer(modelName, imagegen.ModeLLM)
|
|
||||||
isImagegen = true
|
|
||||||
} else {
|
} else {
|
||||||
server, err = mlxrunner.NewClient(modelName)
|
server, err = mlxrunner.NewClient(modelName)
|
||||||
}
|
}
|
||||||
@@ -619,6 +613,7 @@ func (s *Scheduler) loadMLX(req *LlmRequest) bool {
|
|||||||
sessionDuration = req.sessionDuration.Duration
|
sessionDuration = req.sessionDuration.Duration
|
||||||
}
|
}
|
||||||
|
|
||||||
|
totalSize, vramSize := server.MemorySize()
|
||||||
runner := &runnerRef{
|
runner := &runnerRef{
|
||||||
model: req.model,
|
model: req.model,
|
||||||
modelPath: req.model.ModelPath,
|
modelPath: req.model.ModelPath,
|
||||||
@@ -626,10 +621,10 @@ func (s *Scheduler) loadMLX(req *LlmRequest) bool {
|
|||||||
llama: server,
|
llama: server,
|
||||||
Options: &req.opts,
|
Options: &req.opts,
|
||||||
loading: false,
|
loading: false,
|
||||||
isImagegen: isImagegen,
|
isImagegen: slices.Contains(req.model.Config.Capabilities, "image"),
|
||||||
sessionDuration: sessionDuration,
|
sessionDuration: sessionDuration,
|
||||||
totalSize: server.TotalSize(),
|
totalSize: totalSize,
|
||||||
vramSize: server.VRAMSize(),
|
vramSize: vramSize,
|
||||||
}
|
}
|
||||||
|
|
||||||
s.loadedMu.Lock()
|
s.loadedMu.Lock()
|
||||||
@@ -735,8 +730,8 @@ func (runner *runnerRef) needsReload(ctx context.Context, req *LlmRequest) bool
|
|||||||
runner.refMu.Lock()
|
runner.refMu.Lock()
|
||||||
defer runner.refMu.Unlock()
|
defer runner.refMu.Unlock()
|
||||||
|
|
||||||
// Check if runner type (imagegen vs mlxrunner) matches what's requested
|
// Check if runner type (imagegen vs mlxrunner) matches what's requested.
|
||||||
wantImagegen := req.useImagegen || slices.Contains(req.model.Config.Capabilities, "image")
|
wantImagegen := slices.Contains(req.model.Config.Capabilities, "image")
|
||||||
if runner.isImagegen != wantImagegen {
|
if runner.isImagegen != wantImagegen {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
@@ -762,7 +757,7 @@ func (runner *runnerRef) needsReload(ctx context.Context, req *LlmRequest) bool
|
|||||||
defer cancel()
|
defer cancel()
|
||||||
if !reflect.DeepEqual(runner.model.AdapterPaths, req.model.AdapterPaths) || // have the adapters changed?
|
if !reflect.DeepEqual(runner.model.AdapterPaths, req.model.AdapterPaths) || // have the adapters changed?
|
||||||
!reflect.DeepEqual(runner.model.ProjectorPaths, req.model.ProjectorPaths) || // have the projectors changed?
|
!reflect.DeepEqual(runner.model.ProjectorPaths, req.model.ProjectorPaths) || // have the projectors changed?
|
||||||
!reflect.DeepEqual(optsExisting, optsNew) || // have the runner options changed?
|
(!runner.model.IsMLX() && !reflect.DeepEqual(optsExisting, optsNew)) || // have the runner options changed?
|
||||||
runner.llama.Ping(ctx) != nil {
|
runner.llama.Ping(ctx) != nil {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -408,10 +408,10 @@ func TestSchedGetRunner(t *testing.T) {
|
|||||||
s.getSystemInfoFn = getSystemInfoFn
|
s.getSystemInfoFn = getSystemInfoFn
|
||||||
s.newServerFn = a.newServer
|
s.newServerFn = a.newServer
|
||||||
slog.Info("a")
|
slog.Info("a")
|
||||||
successCh1a, errCh1a := s.GetRunner(a.ctx, a.req.model, a.req.opts, a.req.sessionDuration, false)
|
successCh1a, errCh1a := s.GetRunner(a.ctx, a.req.model, a.req.opts, a.req.sessionDuration)
|
||||||
require.Len(t, s.pendingReqCh, 1)
|
require.Len(t, s.pendingReqCh, 1)
|
||||||
slog.Info("b")
|
slog.Info("b")
|
||||||
successCh1b, errCh1b := s.GetRunner(b.ctx, b.req.model, b.req.opts, b.req.sessionDuration, false)
|
successCh1b, errCh1b := s.GetRunner(b.ctx, b.req.model, b.req.opts, b.req.sessionDuration)
|
||||||
require.Len(t, s.pendingReqCh, 1)
|
require.Len(t, s.pendingReqCh, 1)
|
||||||
require.Empty(t, successCh1b)
|
require.Empty(t, successCh1b)
|
||||||
require.Len(t, errCh1b, 1)
|
require.Len(t, errCh1b, 1)
|
||||||
@@ -435,7 +435,7 @@ func TestSchedGetRunner(t *testing.T) {
|
|||||||
|
|
||||||
c.req.model.ModelPath = "bad path"
|
c.req.model.ModelPath = "bad path"
|
||||||
slog.Info("c")
|
slog.Info("c")
|
||||||
successCh1c, errCh1c := s.GetRunner(c.ctx, c.req.model, c.req.opts, c.req.sessionDuration, false)
|
successCh1c, errCh1c := s.GetRunner(c.ctx, c.req.model, c.req.opts, c.req.sessionDuration)
|
||||||
// Starts in pending channel, then should be quickly processed to return an error
|
// Starts in pending channel, then should be quickly processed to return an error
|
||||||
time.Sleep(50 * time.Millisecond) // Long enough for the "a" model to expire and unload
|
time.Sleep(50 * time.Millisecond) // Long enough for the "a" model to expire and unload
|
||||||
require.Empty(t, successCh1c)
|
require.Empty(t, successCh1c)
|
||||||
@@ -470,7 +470,7 @@ func TestSchedGetRunnerUsesDigestKeyWhenModelPathEmpty(t *testing.T) {
|
|||||||
s.loadedMu.Unlock()
|
s.loadedMu.Unlock()
|
||||||
|
|
||||||
reqModel := &Model{Name: "safetensors-b", Digest: "sha-b"}
|
reqModel := &Model{Name: "safetensors-b", Digest: "sha-b"}
|
||||||
successCh, errCh := s.GetRunner(ctx, reqModel, opts, nil, false)
|
successCh, errCh := s.GetRunner(ctx, reqModel, opts, nil)
|
||||||
|
|
||||||
require.Empty(t, successCh)
|
require.Empty(t, successCh)
|
||||||
require.Empty(t, errCh)
|
require.Empty(t, errCh)
|
||||||
@@ -499,7 +499,7 @@ func TestSchedGetRunnerReusesSameDigestWhenModelPathEmpty(t *testing.T) {
|
|||||||
s.loadedMu.Unlock()
|
s.loadedMu.Unlock()
|
||||||
|
|
||||||
reqCtx, cancelReq := context.WithCancel(ctx)
|
reqCtx, cancelReq := context.WithCancel(ctx)
|
||||||
successCh, errCh := s.GetRunner(reqCtx, &Model{Name: "safetensors-a-copy", Digest: "sha-a"}, opts, nil, false)
|
successCh, errCh := s.GetRunner(reqCtx, &Model{Name: "safetensors-a-copy", Digest: "sha-a"}, opts, nil)
|
||||||
cancelReq()
|
cancelReq()
|
||||||
|
|
||||||
select {
|
select {
|
||||||
@@ -574,7 +574,7 @@ func TestSchedPrematureExpired(t *testing.T) {
|
|||||||
s.getGpuFn = getGpuFn
|
s.getGpuFn = getGpuFn
|
||||||
s.getSystemInfoFn = getSystemInfoFn
|
s.getSystemInfoFn = getSystemInfoFn
|
||||||
s.newServerFn = scenario1a.newServer
|
s.newServerFn = scenario1a.newServer
|
||||||
successCh1a, errCh1a := s.GetRunner(scenario1a.ctx, scenario1a.req.model, scenario1a.req.opts, scenario1a.req.sessionDuration, false)
|
successCh1a, errCh1a := s.GetRunner(scenario1a.ctx, scenario1a.req.model, scenario1a.req.opts, scenario1a.req.sessionDuration)
|
||||||
require.Len(t, s.pendingReqCh, 1)
|
require.Len(t, s.pendingReqCh, 1)
|
||||||
s.Run(ctx)
|
s.Run(ctx)
|
||||||
select {
|
select {
|
||||||
@@ -861,8 +861,7 @@ func (s *mockLlm) Close() error {
|
|||||||
s.closeCalled = true
|
s.closeCalled = true
|
||||||
return s.closeResp
|
return s.closeResp
|
||||||
}
|
}
|
||||||
func (s *mockLlm) VRAMSize() uint64 { return s.vramSize }
|
func (s *mockLlm) MemorySize() (uint64, uint64) { return s.totalSize, s.vramSize }
|
||||||
func (s *mockLlm) TotalSize() uint64 { return s.totalSize }
|
|
||||||
func (s *mockLlm) VRAMByGPU(id ml.DeviceID) uint64 { return s.vramByGPU[id] }
|
func (s *mockLlm) VRAMByGPU(id ml.DeviceID) uint64 { return s.vramByGPU[id] }
|
||||||
func (s *mockLlm) Pid() int { return -1 }
|
func (s *mockLlm) Pid() int { return -1 }
|
||||||
func (s *mockLlm) GetPort() int { return -1 }
|
func (s *mockLlm) GetPort() int { return -1 }
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ import (
|
|||||||
|
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
internalcloud "github.com/ollama/ollama/internal/cloud"
|
internalcloud "github.com/ollama/ollama/internal/cloud"
|
||||||
|
"github.com/ollama/ollama/internal/modelref"
|
||||||
"github.com/ollama/ollama/progress"
|
"github.com/ollama/ollama/progress"
|
||||||
"github.com/ollama/ollama/readline"
|
"github.com/ollama/ollama/readline"
|
||||||
"github.com/ollama/ollama/types/model"
|
"github.com/ollama/ollama/types/model"
|
||||||
@@ -43,7 +44,7 @@ const (
|
|||||||
// isLocalModel checks if the model is running locally (not a cloud model).
|
// isLocalModel checks if the model is running locally (not a cloud model).
|
||||||
// TODO: Improve local/cloud model identification - could check model metadata
|
// TODO: Improve local/cloud model identification - could check model metadata
|
||||||
func isLocalModel(modelName string) bool {
|
func isLocalModel(modelName string) bool {
|
||||||
return !strings.HasSuffix(modelName, "-cloud")
|
return !modelref.HasExplicitCloudSource(modelName)
|
||||||
}
|
}
|
||||||
|
|
||||||
// isLocalServer checks if connecting to a local Ollama server.
|
// isLocalServer checks if connecting to a local Ollama server.
|
||||||
|
|||||||
@@ -22,12 +22,22 @@ func TestIsLocalModel(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "cloud model",
|
name: "cloud model",
|
||||||
modelName: "gpt-4-cloud",
|
modelName: "gpt-oss:latest-cloud",
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "cloud model with :cloud suffix",
|
||||||
|
modelName: "gpt-oss:cloud",
|
||||||
expected: false,
|
expected: false,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "cloud model with version",
|
name: "cloud model with version",
|
||||||
modelName: "claude-3-cloud",
|
modelName: "gpt-oss:20b-cloud",
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "cloud model with version and :cloud suffix",
|
||||||
|
modelName: "gpt-oss:20b:cloud",
|
||||||
expected: false,
|
expected: false,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -134,7 +144,7 @@ func TestTruncateToolOutput(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "long output cloud model - uses 10k limit",
|
name: "long output cloud model - uses 10k limit",
|
||||||
output: string(localLimitOutput), // 20k chars, under 10k token limit
|
output: string(localLimitOutput), // 20k chars, under 10k token limit
|
||||||
modelName: "gpt-4-cloud",
|
modelName: "gpt-oss:latest-cloud",
|
||||||
host: "",
|
host: "",
|
||||||
shouldTrim: false,
|
shouldTrim: false,
|
||||||
expectedLimit: defaultTokenLimit,
|
expectedLimit: defaultTokenLimit,
|
||||||
@@ -142,7 +152,7 @@ func TestTruncateToolOutput(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "very long output cloud model - trimmed at 10k",
|
name: "very long output cloud model - trimmed at 10k",
|
||||||
output: string(defaultLimitOutput),
|
output: string(defaultLimitOutput),
|
||||||
modelName: "gpt-4-cloud",
|
modelName: "gpt-oss:latest-cloud",
|
||||||
host: "",
|
host: "",
|
||||||
shouldTrim: true,
|
shouldTrim: true,
|
||||||
expectedLimit: defaultTokenLimit,
|
expectedLimit: defaultTokenLimit,
|
||||||
|
|||||||
@@ -13,9 +13,12 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
"slices"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/api"
|
||||||
"github.com/ollama/ollama/manifest"
|
"github.com/ollama/ollama/manifest"
|
||||||
|
"github.com/ollama/ollama/parser"
|
||||||
"github.com/ollama/ollama/progress"
|
"github.com/ollama/ollama/progress"
|
||||||
"github.com/ollama/ollama/types/model"
|
"github.com/ollama/ollama/types/model"
|
||||||
"github.com/ollama/ollama/x/create"
|
"github.com/ollama/ollama/x/create"
|
||||||
@@ -32,6 +35,74 @@ type ModelfileConfig struct {
|
|||||||
License string
|
License string
|
||||||
Parser string
|
Parser string
|
||||||
Renderer string
|
Renderer string
|
||||||
|
Parameters map[string]any
|
||||||
|
}
|
||||||
|
|
||||||
|
var ignoredModelfileParameters = []string{
|
||||||
|
"penalize_newline",
|
||||||
|
"low_vram",
|
||||||
|
"f16_kv",
|
||||||
|
"logits_all",
|
||||||
|
"vocab_only",
|
||||||
|
"use_mlock",
|
||||||
|
"mirostat",
|
||||||
|
"mirostat_tau",
|
||||||
|
"mirostat_eta",
|
||||||
|
}
|
||||||
|
|
||||||
|
// ConfigFromModelfile extracts the model directory and x/create-specific
|
||||||
|
// Modelfile configuration from a parsed Modelfile.
|
||||||
|
func ConfigFromModelfile(modelfile *parser.Modelfile) (string, *ModelfileConfig, error) {
|
||||||
|
var modelDir string
|
||||||
|
mfConfig := &ModelfileConfig{}
|
||||||
|
|
||||||
|
for _, cmd := range modelfile.Commands {
|
||||||
|
switch cmd.Name {
|
||||||
|
case "model":
|
||||||
|
modelDir = cmd.Args
|
||||||
|
case "template":
|
||||||
|
mfConfig.Template = cmd.Args
|
||||||
|
case "system":
|
||||||
|
mfConfig.System = cmd.Args
|
||||||
|
case "license":
|
||||||
|
mfConfig.License = cmd.Args
|
||||||
|
case "parser":
|
||||||
|
mfConfig.Parser = cmd.Args
|
||||||
|
case "renderer":
|
||||||
|
mfConfig.Renderer = cmd.Args
|
||||||
|
case "adapter", "message", "requires":
|
||||||
|
continue
|
||||||
|
default:
|
||||||
|
if slices.Contains(ignoredModelfileParameters, cmd.Name) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
ps, err := api.FormatParams(map[string][]string{cmd.Name: {cmd.Args}})
|
||||||
|
if err != nil {
|
||||||
|
return "", nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if mfConfig.Parameters == nil {
|
||||||
|
mfConfig.Parameters = make(map[string]any)
|
||||||
|
}
|
||||||
|
|
||||||
|
for k, v := range ps {
|
||||||
|
if ks, ok := mfConfig.Parameters[k].([]string); ok {
|
||||||
|
mfConfig.Parameters[k] = append(ks, v.([]string)...)
|
||||||
|
} else if vs, ok := v.([]string); ok {
|
||||||
|
mfConfig.Parameters[k] = vs
|
||||||
|
} else {
|
||||||
|
mfConfig.Parameters[k] = v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if modelDir == "" {
|
||||||
|
modelDir = "."
|
||||||
|
}
|
||||||
|
|
||||||
|
return modelDir, mfConfig, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreateOptions holds all options for model creation.
|
// CreateOptions holds all options for model creation.
|
||||||
@@ -39,7 +110,7 @@ type CreateOptions struct {
|
|||||||
ModelName string
|
ModelName string
|
||||||
ModelDir string
|
ModelDir string
|
||||||
Quantize string // "int4", "int8", "nvfp4", or "mxfp8" for quantization
|
Quantize string // "int4", "int8", "nvfp4", or "mxfp8" for quantization
|
||||||
Modelfile *ModelfileConfig // template/system/license/parser/renderer from Modelfile
|
Modelfile *ModelfileConfig // template/system/license/parser/renderer/parameters from Modelfile
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreateModel imports a model from a local directory.
|
// CreateModel imports a model from a local directory.
|
||||||
@@ -351,6 +422,19 @@ func createModelfileLayers(mf *ModelfileConfig) ([]manifest.Layer, error) {
|
|||||||
layers = append(layers, layer)
|
layers = append(layers, layer)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if len(mf.Parameters) > 0 {
|
||||||
|
var b bytes.Buffer
|
||||||
|
if err := json.NewEncoder(&b).Encode(mf.Parameters); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to encode parameters: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
layer, err := manifest.NewLayer(&b, "application/vnd.ollama.image.params")
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create params layer: %w", err)
|
||||||
|
}
|
||||||
|
layers = append(layers, layer)
|
||||||
|
}
|
||||||
|
|
||||||
return layers, nil
|
return layers, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,7 +1,13 @@
|
|||||||
package client
|
package client
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"os"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/manifest"
|
||||||
|
"github.com/ollama/ollama/parser"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestModelfileConfig(t *testing.T) {
|
func TestModelfileConfig(t *testing.T) {
|
||||||
@@ -31,6 +37,40 @@ func TestModelfileConfig(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestConfigFromModelfile(t *testing.T) {
|
||||||
|
modelfile, err := parser.ParseFile(strings.NewReader(`
|
||||||
|
FROM ./model
|
||||||
|
TEMPLATE {{ .Prompt }}
|
||||||
|
PARAMETER temperature 0.7
|
||||||
|
PARAMETER stop USER:
|
||||||
|
PARAMETER stop ASSISTANT:
|
||||||
|
`))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
modelDir, mfConfig, err := ConfigFromModelfile(modelfile)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if modelDir != "./model" {
|
||||||
|
t.Fatalf("modelDir = %q, want %q", modelDir, "./model")
|
||||||
|
}
|
||||||
|
|
||||||
|
if mfConfig.Template != "{{ .Prompt }}" {
|
||||||
|
t.Fatalf("Template = %q, want %q", mfConfig.Template, "{{ .Prompt }}")
|
||||||
|
}
|
||||||
|
|
||||||
|
if got := mfConfig.Parameters["temperature"]; got != float32(0.7) {
|
||||||
|
t.Fatalf("temperature = %#v, want %v", got, float32(0.7))
|
||||||
|
}
|
||||||
|
|
||||||
|
if got := mfConfig.Parameters["stop"]; got == nil || len(got.([]string)) != 2 {
|
||||||
|
t.Fatalf("unexpected stop params: %#v", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestModelfileConfig_Empty(t *testing.T) {
|
func TestModelfileConfig_Empty(t *testing.T) {
|
||||||
config := &ModelfileConfig{}
|
config := &ModelfileConfig{}
|
||||||
|
|
||||||
@@ -120,6 +160,9 @@ func TestCreateOptions(t *testing.T) {
|
|||||||
License: "MIT",
|
License: "MIT",
|
||||||
Parser: "qwen3-thinking",
|
Parser: "qwen3-thinking",
|
||||||
Renderer: "qwen3",
|
Renderer: "qwen3",
|
||||||
|
Parameters: map[string]any{
|
||||||
|
"temperature": float32(0.7),
|
||||||
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -144,6 +187,9 @@ func TestCreateOptions(t *testing.T) {
|
|||||||
if opts.Modelfile.Renderer != "qwen3" {
|
if opts.Modelfile.Renderer != "qwen3" {
|
||||||
t.Errorf("Modelfile.Renderer = %q, want %q", opts.Modelfile.Renderer, "qwen3")
|
t.Errorf("Modelfile.Renderer = %q, want %q", opts.Modelfile.Renderer, "qwen3")
|
||||||
}
|
}
|
||||||
|
if opts.Modelfile.Parameters["temperature"] != float32(0.7) {
|
||||||
|
t.Errorf("Modelfile.Parameters[temperature] = %v, want %v", opts.Modelfile.Parameters["temperature"], float32(0.7))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestResolveParserName(t *testing.T) {
|
func TestResolveParserName(t *testing.T) {
|
||||||
@@ -252,3 +298,44 @@ func TestQuantizeSupported(t *testing.T) {
|
|||||||
// We can't easily test both cases, so just verify it returns something
|
// We can't easily test both cases, so just verify it returns something
|
||||||
_ = supported
|
_ = supported
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestCreateModelfileLayersIncludesParameters(t *testing.T) {
|
||||||
|
t.Setenv("OLLAMA_MODELS", t.TempDir())
|
||||||
|
|
||||||
|
layers, err := createModelfileLayers(&ModelfileConfig{
|
||||||
|
Parameters: map[string]any{
|
||||||
|
"temperature": float32(0.7),
|
||||||
|
"stop": []string{"USER:", "ASSISTANT:"},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(layers) != 1 {
|
||||||
|
t.Fatalf("len(layers) = %d, want 1", len(layers))
|
||||||
|
}
|
||||||
|
|
||||||
|
if layers[0].MediaType != "application/vnd.ollama.image.params" {
|
||||||
|
t.Fatalf("MediaType = %q, want %q", layers[0].MediaType, "application/vnd.ollama.image.params")
|
||||||
|
}
|
||||||
|
|
||||||
|
blobPath, err := manifest.BlobsPath(layers[0].Digest)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := os.ReadFile(blobPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var got map[string]any
|
||||||
|
if err := json.Unmarshal(data, &got); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if got["temperature"] != float64(0.7) {
|
||||||
|
t.Fatalf("temperature = %v, want %v", got["temperature"], float64(0.7))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,5 +1,3 @@
|
|||||||
//go:build mlx
|
|
||||||
|
|
||||||
package client
|
package client
|
||||||
|
|
||||||
import (
|
import (
|
||||||
@@ -21,7 +19,7 @@ var quantizeParams = map[string]struct {
|
|||||||
bits int
|
bits int
|
||||||
mode string
|
mode string
|
||||||
}{
|
}{
|
||||||
"int4": {32, 4, "affine"},
|
"int4": {64, 4, "affine"},
|
||||||
"nvfp4": {16, 4, "nvfp4"},
|
"nvfp4": {16, 4, "nvfp4"},
|
||||||
"int8": {64, 8, "affine"},
|
"int8": {64, 8, "affine"},
|
||||||
"mxfp8": {32, 8, "mxfp8"},
|
"mxfp8": {32, 8, "mxfp8"},
|
||||||
@@ -194,9 +192,10 @@ func quantizePackedGroup(inputs []create.PackedTensorInput) ([]byte, error) {
|
|||||||
return blobData, nil
|
return blobData, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// QuantizeSupported returns true if quantization is supported (MLX build)
|
// QuantizeSupported returns true if quantization is supported (MLX library available)
|
||||||
func QuantizeSupported() bool {
|
func QuantizeSupported() bool {
|
||||||
return true
|
mlx.InitMLX()
|
||||||
|
return mlx.IsMLXAvailable()
|
||||||
}
|
}
|
||||||
|
|
||||||
// ensureTempDir creates the temp directory for quantization if it doesn't exist
|
// ensureTempDir creates the temp directory for quantization if it doesn't exist
|
||||||
|
|||||||
@@ -1,25 +0,0 @@
|
|||||||
//go:build !mlx
|
|
||||||
|
|
||||||
package client
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
|
|
||||||
"github.com/ollama/ollama/x/create"
|
|
||||||
)
|
|
||||||
|
|
||||||
// quantizeTensor is not available without MLX
|
|
||||||
func quantizeTensor(r io.Reader, tensorName, dtype string, shape []int32, quantize string) (blobData []byte, err error) {
|
|
||||||
return nil, fmt.Errorf("quantization requires MLX support (build with mlx tag)")
|
|
||||||
}
|
|
||||||
|
|
||||||
// quantizePackedGroup is not available without MLX
|
|
||||||
func quantizePackedGroup(inputs []create.PackedTensorInput) ([]byte, error) {
|
|
||||||
return nil, fmt.Errorf("quantization requires MLX support (build with mlx tag)")
|
|
||||||
}
|
|
||||||
|
|
||||||
// QuantizeSupported returns false when MLX is not available
|
|
||||||
func QuantizeSupported() bool {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
@@ -288,6 +288,18 @@ func normalizeQuantType(quantize string) string {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func isStackedExpertWeight(name string) bool {
|
||||||
|
// Combined/stacked expert tensors may be emitted either as "...proj.weight" (per-expert)
|
||||||
|
// or "...proj" (pre-stacked packed tensor).
|
||||||
|
if strings.HasSuffix(name, ".bias") || strings.HasSuffix(name, ".scale") || strings.HasSuffix(name, ".qbias") {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return strings.Contains(name, ".mlp.switch_mlp.") ||
|
||||||
|
strings.Contains(name, ".mlp.experts.") ||
|
||||||
|
strings.Contains(name, ".mlp.shared_experts.")
|
||||||
|
}
|
||||||
|
|
||||||
// GetTensorQuantization returns the appropriate quantization type for a tensor.
|
// GetTensorQuantization returns the appropriate quantization type for a tensor.
|
||||||
// Returns "" if the tensor should not be quantized.
|
// Returns "" if the tensor should not be quantized.
|
||||||
// This implements mixed-precision quantization:
|
// This implements mixed-precision quantization:
|
||||||
@@ -296,18 +308,25 @@ func normalizeQuantType(quantize string) string {
|
|||||||
// - Down projection weights: int8 (more sensitive, would be Q6 in GGML but no MLX kernel)
|
// - Down projection weights: int8 (more sensitive, would be Q6 in GGML but no MLX kernel)
|
||||||
// - Norms, embeddings, biases, routing gates: no quantization
|
// - Norms, embeddings, biases, routing gates: no quantization
|
||||||
func GetTensorQuantization(name string, shape []int32, quantize string) string {
|
func GetTensorQuantization(name string, shape []int32, quantize string) string {
|
||||||
|
stackedExpert := isStackedExpertWeight(name)
|
||||||
|
|
||||||
// Use basic name-based check first
|
// Use basic name-based check first
|
||||||
if !ShouldQuantize(name, "") {
|
if !stackedExpert && !ShouldQuantize(name, "") {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
// Only quantize 2D tensors (linear layers) - skip 1D (biases, norms) and higher-D (convolutions if any)
|
// Quantize standard linear weights (2D). Also allow stacked expert weights (3D),
|
||||||
if len(shape) != 2 {
|
// e.g. qwen switch_mlp / experts combined tensors.
|
||||||
|
if len(shape) != 2 && !(len(shape) == 3 && stackedExpert) {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
// Skip small tensors (less than 1024 elements) - not worth quantizing
|
// Skip small tensors (less than 1024 elements) - not worth quantizing
|
||||||
if len(shape) >= 2 && int64(shape[0])*int64(shape[1]) < 1024 {
|
var elems int64 = 1
|
||||||
|
for _, d := range shape {
|
||||||
|
elems *= int64(d)
|
||||||
|
}
|
||||||
|
if elems < 1024 {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -315,12 +334,12 @@ func GetTensorQuantization(name string, shape []int32, quantize string) string {
|
|||||||
quantNorm := normalizeQuantType(quantize)
|
quantNorm := normalizeQuantType(quantize)
|
||||||
|
|
||||||
// MLX quantization requires last dimension to be divisible by group size
|
// MLX quantization requires last dimension to be divisible by group size
|
||||||
// nvfp4: 16, int4/mxfp8: 32, int8: 64
|
// nvfp4: 16, mxfp8: 32, int4/int8: 64
|
||||||
groupSize := int32(32)
|
groupSize := int32(32)
|
||||||
switch quantNorm {
|
switch quantNorm {
|
||||||
case "nvfp4":
|
case "nvfp4":
|
||||||
groupSize = 16
|
groupSize = 16
|
||||||
case "int8":
|
case "int4", "int8":
|
||||||
groupSize = 64
|
groupSize = 64
|
||||||
}
|
}
|
||||||
if shape[len(shape)-1]%groupSize != 0 {
|
if shape[len(shape)-1]%groupSize != 0 {
|
||||||
|
|||||||
@@ -557,6 +557,10 @@ func TestShouldQuantizeTensor(t *testing.T) {
|
|||||||
// 3D+ tensors should not be quantized
|
// 3D+ tensors should not be quantized
|
||||||
{"3D tensor", "conv.weight", []int32{64, 64, 3}, "fp8", false},
|
{"3D tensor", "conv.weight", []int32{64, 64, 3}, "fp8", false},
|
||||||
{"4D tensor", "conv2d.weight", []int32{64, 64, 3, 3}, "fp8", false},
|
{"4D tensor", "conv2d.weight", []int32{64, 64, 3, 3}, "fp8", false},
|
||||||
|
{"stacked expert switch_mlp gate_up 3D int8", "model.layers.1.mlp.switch_mlp.gate_up_proj.weight", []int32{64, 22016, 4096}, "int8", true},
|
||||||
|
{"stacked expert experts down_proj 3D int8", "model.layers.1.mlp.experts.down_proj.weight", []int32{64, 4096, 14336}, "int8", true},
|
||||||
|
{"stacked expert combined gate_up 3D int8", "model.language_model.layers.0.mlp.experts.gate_up_proj", []int32{256, 1024, 2048}, "int8", true},
|
||||||
|
{"stacked expert combined down_proj 3D int8", "model.language_model.layers.0.mlp.experts.down_proj", []int32{256, 2048, 512}, "int8", true},
|
||||||
|
|
||||||
// Embeddings should not be quantized regardless of shape
|
// Embeddings should not be quantized regardless of shape
|
||||||
{"embedding 2D", "embed_tokens.weight", []int32{32000, 4096}, "fp8", false},
|
{"embedding 2D", "embed_tokens.weight", []int32{32000, 4096}, "fp8", false},
|
||||||
@@ -619,6 +623,44 @@ func TestExpertGroupPrefix(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestGetTensorQuantization_StackedExpert3D(t *testing.T) {
|
||||||
|
gateUp := GetTensorQuantization(
|
||||||
|
"model.layers.1.mlp.switch_mlp.gate_up_proj.weight",
|
||||||
|
[]int32{64, 22016, 4096},
|
||||||
|
"int4",
|
||||||
|
)
|
||||||
|
if gateUp != "int4" {
|
||||||
|
t.Fatalf("gate_up_proj quantization = %q, want %q", gateUp, "int4")
|
||||||
|
}
|
||||||
|
|
||||||
|
down := GetTensorQuantization(
|
||||||
|
"model.layers.1.mlp.experts.down_proj.weight",
|
||||||
|
[]int32{64, 4096, 14336},
|
||||||
|
"int4",
|
||||||
|
)
|
||||||
|
if down != "int8" {
|
||||||
|
t.Fatalf("down_proj quantization = %q, want %q", down, "int8")
|
||||||
|
}
|
||||||
|
|
||||||
|
combinedGateUp := GetTensorQuantization(
|
||||||
|
"model.language_model.layers.0.mlp.experts.gate_up_proj",
|
||||||
|
[]int32{256, 1024, 2048},
|
||||||
|
"int8",
|
||||||
|
)
|
||||||
|
if combinedGateUp != "int8" {
|
||||||
|
t.Fatalf("combined gate_up_proj quantization = %q, want %q", combinedGateUp, "int8")
|
||||||
|
}
|
||||||
|
|
||||||
|
combinedDown := GetTensorQuantization(
|
||||||
|
"model.language_model.layers.0.mlp.experts.down_proj",
|
||||||
|
[]int32{256, 2048, 512},
|
||||||
|
"int4",
|
||||||
|
)
|
||||||
|
if combinedDown != "int8" {
|
||||||
|
t.Fatalf("combined down_proj quantization = %q, want %q", combinedDown, "int8")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestCreateSafetensorsModel_WithQuantize(t *testing.T) {
|
func TestCreateSafetensorsModel_WithQuantize(t *testing.T) {
|
||||||
dir := t.TempDir()
|
dir := t.TempDir()
|
||||||
|
|
||||||
|
|||||||
2
x/imagegen/cache/cache.go
vendored
2
x/imagegen/cache/cache.go
vendored
@@ -1,5 +1,3 @@
|
|||||||
//go:build mlx
|
|
||||||
|
|
||||||
package cache
|
package cache
|
||||||
|
|
||||||
import "github.com/ollama/ollama/x/imagegen/mlx"
|
import "github.com/ollama/ollama/x/imagegen/mlx"
|
||||||
|
|||||||
2
x/imagegen/cache/step.go
vendored
2
x/imagegen/cache/step.go
vendored
@@ -1,5 +1,3 @@
|
|||||||
//go:build mlx
|
|
||||||
|
|
||||||
package cache
|
package cache
|
||||||
|
|
||||||
import "github.com/ollama/ollama/x/imagegen/mlx"
|
import "github.com/ollama/ollama/x/imagegen/mlx"
|
||||||
|
|||||||
2
x/imagegen/cache/teacache.go
vendored
2
x/imagegen/cache/teacache.go
vendored
@@ -1,5 +1,3 @@
|
|||||||
//go:build mlx
|
|
||||||
|
|
||||||
// Package cache provides caching mechanisms for diffusion model inference.
|
// Package cache provides caching mechanisms for diffusion model inference.
|
||||||
package cache
|
package cache
|
||||||
|
|
||||||
|
|||||||
@@ -5,22 +5,12 @@ Experimental MLX backend for running models on Apple Silicon and CUDA.
|
|||||||
## Build
|
## Build
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
go build -tags mlx -o engine ./x/imagegen/cmd/engine
|
go build -o engine ./x/imagegen/cmd/engine
|
||||||
```
|
```
|
||||||
|
|
||||||
## Text Generation
|
## Text Generation
|
||||||
|
|
||||||
```bash
|
Text generation models are no longer supported by this engine.
|
||||||
./engine -model /path/to/model -prompt "Hello" -max-tokens 100
|
|
||||||
```
|
|
||||||
|
|
||||||
Options:
|
|
||||||
|
|
||||||
- `-temperature` - sampling temperature (default 0.7)
|
|
||||||
- `-top-p` - nucleus sampling (default 0.9)
|
|
||||||
- `-top-k` - top-k sampling (default 40)
|
|
||||||
|
|
||||||
Supports: Llama, Gemma3, GPT-OSS
|
|
||||||
|
|
||||||
## Image Generation
|
## Image Generation
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,3 @@
|
|||||||
//go:build mlx
|
|
||||||
|
|
||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
|||||||
@@ -1,5 +1,3 @@
|
|||||||
//go:build mlx
|
|
||||||
|
|
||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
|||||||
@@ -1,5 +1,3 @@
|
|||||||
//go:build mlx
|
|
||||||
|
|
||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
@@ -18,9 +16,6 @@ import (
|
|||||||
"github.com/ollama/ollama/x/imagegen"
|
"github.com/ollama/ollama/x/imagegen"
|
||||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||||
"github.com/ollama/ollama/x/imagegen/models/flux2"
|
"github.com/ollama/ollama/x/imagegen/models/flux2"
|
||||||
"github.com/ollama/ollama/x/imagegen/models/gemma3"
|
|
||||||
"github.com/ollama/ollama/x/imagegen/models/gpt_oss"
|
|
||||||
"github.com/ollama/ollama/x/imagegen/models/llama"
|
|
||||||
"github.com/ollama/ollama/x/imagegen/models/zimage"
|
"github.com/ollama/ollama/x/imagegen/models/zimage"
|
||||||
"github.com/ollama/ollama/x/imagegen/safetensors"
|
"github.com/ollama/ollama/x/imagegen/safetensors"
|
||||||
)
|
)
|
||||||
@@ -170,11 +165,11 @@ func main() {
|
|||||||
log.Fatal(err)
|
log.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Load image if provided and model supports it
|
// Load image if provided and model supports it.
|
||||||
var image *mlx.Array
|
var image *mlx.Array
|
||||||
if *imagePath != "" {
|
if *imagePath != "" {
|
||||||
if mm, ok := m.(interface{ ImageSize() int32 }); ok {
|
if mm, ok := m.(interface{ ImageSize() int32 }); ok {
|
||||||
image, err = gemma3.ProcessImage(*imagePath, mm.ImageSize())
|
image, err = imagegen.ProcessImage(*imagePath, mm.ImageSize())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal("load image:", err)
|
log.Fatal("load image:", err)
|
||||||
}
|
}
|
||||||
@@ -236,14 +231,8 @@ func load(modelPath string) (Model, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
switch kind {
|
switch kind {
|
||||||
case "gpt_oss":
|
|
||||||
return gpt_oss.Load(modelPath)
|
|
||||||
case "gemma3":
|
|
||||||
return gemma3.Load(modelPath)
|
|
||||||
case "gemma3_text":
|
|
||||||
return gemma3.LoadText(modelPath)
|
|
||||||
default:
|
default:
|
||||||
return llama.Load(modelPath)
|
return nil, fmt.Errorf("model type %q is not supported by x/imagegen/cmd/engine", kind)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,3 @@
|
|||||||
//go:build mlx
|
|
||||||
|
|
||||||
package main
|
package main
|
||||||
|
|
||||||
import "github.com/ollama/ollama/x/imagegen/mlx"
|
import "github.com/ollama/ollama/x/imagegen/mlx"
|
||||||
|
|||||||
@@ -1,5 +1,3 @@
|
|||||||
//go:build mlx
|
|
||||||
|
|
||||||
package imagegen
|
package imagegen
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
|||||||
@@ -1,6 +1,4 @@
|
|||||||
//go:build mlx
|
package imagegen
|
||||||
|
|
||||||
package gemma3
|
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
@@ -13,8 +11,8 @@ import (
|
|||||||
"golang.org/x/image/draw"
|
"golang.org/x/image/draw"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ProcessImage loads and preprocesses an image for the vision tower
|
// ProcessImage loads and preprocesses an image for multimodal vision towers.
|
||||||
// Returns [1, H, W, C] tensor in NHWC format normalized for SigLIP
|
// Returns [1, H, W, C] tensor in NHWC format normalized for SigLIP.
|
||||||
func ProcessImage(path string, imageSize int32) (*mlx.Array, error) {
|
func ProcessImage(path string, imageSize int32) (*mlx.Array, error) {
|
||||||
f, err := os.Open(path)
|
f, err := os.Open(path)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -30,20 +28,20 @@ func ProcessImage(path string, imageSize int32) (*mlx.Array, error) {
|
|||||||
return ProcessImageData(img, imageSize)
|
return ProcessImageData(img, imageSize)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ProcessImageData preprocesses an image.Image for the vision tower
|
// ProcessImageData preprocesses an image.Image for multimodal vision towers.
|
||||||
func ProcessImageData(img image.Image, imageSize int32) (*mlx.Array, error) {
|
func ProcessImageData(img image.Image, imageSize int32) (*mlx.Array, error) {
|
||||||
// Resize to target size using bilinear interpolation
|
// Resize to target size using bilinear interpolation.
|
||||||
resized := image.NewRGBA(image.Rect(0, 0, int(imageSize), int(imageSize)))
|
resized := image.NewRGBA(image.Rect(0, 0, int(imageSize), int(imageSize)))
|
||||||
draw.BiLinear.Scale(resized, resized.Bounds(), img, img.Bounds(), draw.Over, nil)
|
draw.BiLinear.Scale(resized, resized.Bounds(), img, img.Bounds(), draw.Over, nil)
|
||||||
|
|
||||||
// Convert to float32 array [H, W, C] and normalize
|
// Convert to float32 array [H, W, C] and normalize.
|
||||||
// SigLIP normalization: (pixel / 255.0 - 0.5) / 0.5 = pixel / 127.5 - 1.0
|
// SigLIP normalization: (pixel / 255.0 - 0.5) / 0.5 = pixel / 127.5 - 1.0.
|
||||||
data := make([]float32, imageSize*imageSize*3)
|
data := make([]float32, imageSize*imageSize*3)
|
||||||
idx := 0
|
idx := 0
|
||||||
for y := int32(0); y < imageSize; y++ {
|
for y := int32(0); y < imageSize; y++ {
|
||||||
for x := int32(0); x < imageSize; x++ {
|
for x := int32(0); x < imageSize; x++ {
|
||||||
r, g, b, _ := resized.At(int(x), int(y)).RGBA()
|
r, g, b, _ := resized.At(int(x), int(y)).RGBA()
|
||||||
// RGBA returns 16-bit values, convert to 8-bit
|
// RGBA returns 16-bit values, convert to 8-bit.
|
||||||
data[idx] = float32(r>>8)/127.5 - 1.0
|
data[idx] = float32(r>>8)/127.5 - 1.0
|
||||||
data[idx+1] = float32(g>>8)/127.5 - 1.0
|
data[idx+1] = float32(g>>8)/127.5 - 1.0
|
||||||
data[idx+2] = float32(b>>8)/127.5 - 1.0
|
data[idx+2] = float32(b>>8)/127.5 - 1.0
|
||||||
@@ -51,8 +49,8 @@ func ProcessImageData(img image.Image, imageSize int32) (*mlx.Array, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create MLX array [1, H, W, C] for NHWC layout
|
// Create MLX array [1, H, W, C] for NHWC layout.
|
||||||
arr := mlx.NewArrayFloat32(data, []int32{1, imageSize, imageSize, 3})
|
arr := mlx.NewArrayFloat32(data, []int32{1, imageSize, imageSize, 3})
|
||||||
mlx.Eval(arr) // Materialize to prevent use-after-free
|
mlx.Eval(arr) // Materialize to prevent use-after-free.
|
||||||
return arr, nil
|
return arr, nil
|
||||||
}
|
}
|
||||||
@@ -1,5 +1,3 @@
|
|||||||
//go:build mlx
|
|
||||||
|
|
||||||
package imagegen
|
package imagegen
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
|||||||
@@ -1,420 +0,0 @@
|
|||||||
//go:build mlx
|
|
||||||
|
|
||||||
package imagegen
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/json"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"log/slog"
|
|
||||||
"net/http"
|
|
||||||
"strings"
|
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/ollama/ollama/x/imagegen/cache"
|
|
||||||
"github.com/ollama/ollama/x/imagegen/manifest"
|
|
||||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
|
||||||
"github.com/ollama/ollama/x/imagegen/models/glm4_moe_lite"
|
|
||||||
"github.com/ollama/ollama/x/imagegen/tokenizer"
|
|
||||||
)
|
|
||||||
|
|
||||||
// TextModel is the interface for LLM text generation models.
|
|
||||||
type TextModel interface {
|
|
||||||
Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array
|
|
||||||
NewCache(maxSeqLen int32) []cache.Cache
|
|
||||||
Tokenizer() *tokenizer.Tokenizer
|
|
||||||
VocabSize() int32
|
|
||||||
MaxContextLength() int32
|
|
||||||
NumLayers() int
|
|
||||||
}
|
|
||||||
|
|
||||||
// llmState holds the state for LLM generation
|
|
||||||
type llmState struct {
|
|
||||||
model TextModel
|
|
||||||
}
|
|
||||||
|
|
||||||
var llmMu sync.Mutex
|
|
||||||
|
|
||||||
// Dedicated stream for generation (like mlx-lm's generation_stream)
|
|
||||||
var generationStream *mlx.Stream
|
|
||||||
|
|
||||||
// withStream runs fn with the generation stream as default
|
|
||||||
func withStream(fn func()) {
|
|
||||||
// Lazy initialization of generationStream
|
|
||||||
if generationStream == nil {
|
|
||||||
generationStream = mlx.NewStream()
|
|
||||||
}
|
|
||||||
orig := mlx.GetDefaultStream()
|
|
||||||
mlx.SetDefaultStream(generationStream)
|
|
||||||
fn()
|
|
||||||
mlx.SetDefaultStream(orig)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Decoder wraps model + cache for autoregressive generation.
|
|
||||||
// This matches the pattern from cmd/engine/generate.go
|
|
||||||
type Decoder struct {
|
|
||||||
model TextModel
|
|
||||||
caches []cache.Cache
|
|
||||||
vocabSize int32
|
|
||||||
temp float32
|
|
||||||
token *mlx.Array // Current token (kept across iterations)
|
|
||||||
oldCacheState []*mlx.Array // Preallocated slice for old cache state
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewDecoder(m TextModel, temp float32) *Decoder {
|
|
||||||
caches := m.NewCache(0)
|
|
||||||
return &Decoder{
|
|
||||||
model: m,
|
|
||||||
caches: caches,
|
|
||||||
vocabSize: m.VocabSize(),
|
|
||||||
temp: temp,
|
|
||||||
oldCacheState: make([]*mlx.Array, 0, len(caches)*2),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *Decoder) prefill(inputIDs []int32) int {
|
|
||||||
processed := 0
|
|
||||||
|
|
||||||
// Track old cache state to free after each chunk
|
|
||||||
var oldCacheState []*mlx.Array
|
|
||||||
|
|
||||||
// Process all-but-1 tokens in chunks, eval cache state for memory management
|
|
||||||
for len(inputIDs) > 1 {
|
|
||||||
chunkSize := min(2048, len(inputIDs)-1)
|
|
||||||
if chunkSize <= 0 {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
chunk := inputIDs[:chunkSize]
|
|
||||||
|
|
||||||
// Save old cache state before forward
|
|
||||||
oldCacheState = oldCacheState[:0]
|
|
||||||
for _, c := range d.caches {
|
|
||||||
oldCacheState = append(oldCacheState, c.State()...)
|
|
||||||
}
|
|
||||||
|
|
||||||
var cacheState []*mlx.Array
|
|
||||||
withStream(func() {
|
|
||||||
x := mlx.NewArrayInt32(chunk, []int32{1, int32(len(chunk))})
|
|
||||||
d.model.Forward(x, d.caches)
|
|
||||||
for _, c := range d.caches {
|
|
||||||
cacheState = append(cacheState, c.State()...)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
mlx.Eval(cacheState...)
|
|
||||||
|
|
||||||
// Free old cache state
|
|
||||||
for _, arr := range oldCacheState {
|
|
||||||
if arr != nil {
|
|
||||||
arr.Free()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
inputIDs = inputIDs[chunkSize:]
|
|
||||||
processed += chunkSize
|
|
||||||
}
|
|
||||||
|
|
||||||
// Save old cache state before final step
|
|
||||||
oldCacheState = oldCacheState[:0]
|
|
||||||
for _, c := range d.caches {
|
|
||||||
oldCacheState = append(oldCacheState, c.State()...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Final token + sampling
|
|
||||||
withStream(func() {
|
|
||||||
x := mlx.NewArrayInt32(inputIDs, []int32{1, int32(len(inputIDs))})
|
|
||||||
mlx.Eval(x) // Materialize before any other evals
|
|
||||||
logits := d.model.Forward(x, d.caches)
|
|
||||||
d.token = sample(logits, d.temp, d.vocabSize)
|
|
||||||
})
|
|
||||||
// Keep cache state (token auto-kept by AsyncEval)
|
|
||||||
for _, c := range d.caches {
|
|
||||||
mlx.Keep(c.State()...)
|
|
||||||
}
|
|
||||||
mlx.AsyncEval(d.token)
|
|
||||||
|
|
||||||
// Free old cache state from before final step
|
|
||||||
for _, arr := range oldCacheState {
|
|
||||||
if arr != nil {
|
|
||||||
arr.Free()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
mlx.ClearCache()
|
|
||||||
|
|
||||||
return processed + len(inputIDs)
|
|
||||||
}
|
|
||||||
|
|
||||||
// step runs one autoregressive decode step: it feeds the previously
// sampled token back through the model, samples the next token, and
// returns the previous token's concrete int32 value.
//
// Evaluation is pipelined: the next token is dispatched with AsyncEval
// before we synchronize on the previous one via ItemInt32, so the GPU
// works on the next step while the CPU reads this one. NOTE(review):
// the ordering (save old state -> Keep new state -> AsyncEval -> sync
// -> Free old state) appears load-bearing for array lifetimes; do not
// reorder without confirming mlx's Keep/Free semantics.
func (d *Decoder) step() int32 {
	prevToken := d.token

	// Save old cache state (reuse preallocated slice)
	d.oldCacheState = d.oldCacheState[:0]
	for _, c := range d.caches {
		d.oldCacheState = append(d.oldCacheState, c.State()...)
	}

	withStream(func() {
		// prevToken is reshaped to [1, 1] (batch, seq) for Forward.
		logits := d.model.Forward(mlx.Reshape(prevToken, 1, 1), d.caches)
		d.token = sample(logits, d.temp, d.vocabSize)
	})
	// Keep token and new cache state so they survive cleanup
	mlx.Keep(d.token)
	for _, c := range d.caches {
		mlx.Keep(c.State()...)
	}
	mlx.AsyncEval(d.token)

	// Sync on previous token (GPU already working on next step)
	val := prevToken.ItemInt32()

	// Free old token and old cache state
	prevToken.Free()
	for _, arr := range d.oldCacheState {
		arr.Free()
	}
	return val
}
|
|
||||||
|
|
||||||
// sample samples from logits using temperature scaling
|
|
||||||
func sample(logits *mlx.Array, temp float32, vocabSize int32) *mlx.Array {
|
|
||||||
// Get last position logits: [1, L, vocab] -> [vocab]
|
|
||||||
shape := logits.Shape()
|
|
||||||
seqLen := shape[1]
|
|
||||||
lastLogits := mlx.Slice(logits, []int32{0, seqLen - 1, 0}, []int32{1, seqLen, vocabSize})
|
|
||||||
lastLogits = mlx.Reshape(lastLogits, vocabSize)
|
|
||||||
|
|
||||||
if temp <= 0 || temp < 0.01 {
|
|
||||||
// Greedy decoding
|
|
||||||
return mlx.Argmax(lastLogits, -1, false)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Apply temperature scaling
|
|
||||||
scaled := mlx.DivScalar(lastLogits, temp)
|
|
||||||
return mlx.RandomCategorical(scaled, -1, 1)
|
|
||||||
}
|
|
||||||
|
|
||||||
// loadLLMModel loads a safetensors LLM model and its tokenizer from manifest storage.
|
|
||||||
func (s *server) loadLLMModel() error {
|
|
||||||
// Load the manifest to get model information
|
|
||||||
modelManifest, err := manifest.LoadManifest(s.modelName)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to load manifest: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Detect model architecture from config.json
|
|
||||||
configData, err := modelManifest.ReadConfig("config.json")
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to read config.json: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
var modelConfig struct {
|
|
||||||
Architectures []string `json:"architectures"`
|
|
||||||
ModelType string `json:"model_type"`
|
|
||||||
}
|
|
||||||
if err := json.Unmarshal(configData, &modelConfig); err != nil {
|
|
||||||
return fmt.Errorf("failed to parse config.json: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
arch := ""
|
|
||||||
if len(modelConfig.Architectures) > 0 {
|
|
||||||
arch = modelConfig.Architectures[0]
|
|
||||||
}
|
|
||||||
if arch == "" {
|
|
||||||
arch = modelConfig.ModelType
|
|
||||||
}
|
|
||||||
|
|
||||||
slog.Info("detected LLM architecture", "architecture", arch, "model_type", modelConfig.ModelType)
|
|
||||||
|
|
||||||
// Load the appropriate model based on architecture
|
|
||||||
var model TextModel
|
|
||||||
archLower := strings.ToLower(arch)
|
|
||||||
|
|
||||||
switch {
|
|
||||||
case strings.Contains(archLower, "glm4moelite"):
|
|
||||||
m, err := glm4_moe_lite.LoadFromManifest(modelManifest)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to load glm4-moe-lite model: %w", err)
|
|
||||||
}
|
|
||||||
model = m
|
|
||||||
slog.Info("loaded glm4-moe-lite model", "vocab_size", m.VocabSize(), "layers", m.NumLayers())
|
|
||||||
|
|
||||||
default:
|
|
||||||
return fmt.Errorf("LLM architecture %q is not yet supported. "+
|
|
||||||
"Supported architectures: glm4-moe-lite. "+
|
|
||||||
"Please convert your model to GGUF format or use a supported architecture", arch)
|
|
||||||
}
|
|
||||||
|
|
||||||
s.llmModel = &llmState{
|
|
||||||
model: model,
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// handleLLMCompletion handles LLM text generation requests.
|
|
||||||
func (s *server) handleLLMCompletion(w http.ResponseWriter, r *http.Request, req Request) {
|
|
||||||
if s.llmModel == nil {
|
|
||||||
http.Error(w, "LLM model not loaded", http.StatusInternalServerError)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Serialize generation requests
|
|
||||||
llmMu.Lock()
|
|
||||||
defer llmMu.Unlock()
|
|
||||||
|
|
||||||
if err := s.llmGenerate(w, r, req); err != nil {
|
|
||||||
slog.Error("LLM generation failed", "error", err)
|
|
||||||
// Don't send error if we've already started streaming
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// llmGenerate runs the generation loop using the Decoder pattern from cmd/engine
|
|
||||||
func (s *server) llmGenerate(w http.ResponseWriter, r *http.Request, req Request) error {
|
|
||||||
state := s.llmModel
|
|
||||||
|
|
||||||
// Set up streaming response
|
|
||||||
w.Header().Set("Content-Type", "application/x-ndjson")
|
|
||||||
w.Header().Set("Transfer-Encoding", "chunked")
|
|
||||||
flusher, ok := w.(http.Flusher)
|
|
||||||
if !ok {
|
|
||||||
return errors.New("streaming not supported")
|
|
||||||
}
|
|
||||||
|
|
||||||
tok := state.model.Tokenizer()
|
|
||||||
|
|
||||||
// The prompt is already formatted by the server using the model's renderer
|
|
||||||
// (see server/prompt.go renderPrompt), so we don't apply FormatPrompt here.
|
|
||||||
prompt := req.Prompt
|
|
||||||
|
|
||||||
// Tokenize the prompt
|
|
||||||
inputIDs := tok.Encode(prompt, true)
|
|
||||||
slog.Debug("tokenized prompt", "num_tokens", len(inputIDs))
|
|
||||||
|
|
||||||
// Generation parameters
|
|
||||||
maxTokens := int(state.model.MaxContextLength())
|
|
||||||
if maxTokens <= 0 {
|
|
||||||
maxTokens = 4096
|
|
||||||
}
|
|
||||||
if req.Options != nil && req.Options.NumPredict > 0 {
|
|
||||||
maxTokens = req.Options.NumPredict
|
|
||||||
}
|
|
||||||
|
|
||||||
temperature := float32(0.7)
|
|
||||||
if req.Options != nil && req.Options.Temperature > 0 {
|
|
||||||
temperature = float32(req.Options.Temperature)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Enable MLX compilation for better performance
|
|
||||||
mlx.EnableCompile()
|
|
||||||
|
|
||||||
// Create decoder with fresh caches
|
|
||||||
dec := NewDecoder(state.model, temperature)
|
|
||||||
|
|
||||||
prefillStart := time.Now()
|
|
||||||
prefillTokens := dec.prefill(inputIDs)
|
|
||||||
// Prefill measurement includes time to first token
|
|
||||||
firstToken := dec.step()
|
|
||||||
prefillDuration := time.Since(prefillStart)
|
|
||||||
promptEvalDuration := prefillDuration
|
|
||||||
|
|
||||||
enc := json.NewEncoder(w)
|
|
||||||
ctx := r.Context()
|
|
||||||
generated := 0
|
|
||||||
stopReason := "max_tokens"
|
|
||||||
|
|
||||||
// Handle first token
|
|
||||||
generated++
|
|
||||||
if tok.IsEOS(firstToken) {
|
|
||||||
resp := Response{
|
|
||||||
Done: true,
|
|
||||||
StopReason: fmt.Sprintf("first_token_eos:%d", firstToken),
|
|
||||||
PromptEvalCount: prefillTokens,
|
|
||||||
PromptEvalDuration: int(promptEvalDuration.Nanoseconds()),
|
|
||||||
}
|
|
||||||
enc.Encode(resp)
|
|
||||||
flusher.Flush()
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
text := tok.Decode([]int32{firstToken})
|
|
||||||
resp := Response{Content: text}
|
|
||||||
enc.Encode(resp)
|
|
||||||
flusher.Flush()
|
|
||||||
|
|
||||||
genStart := time.Now()
|
|
||||||
|
|
||||||
// Generation loop
|
|
||||||
for n := 1; n < maxTokens; n++ {
|
|
||||||
// Check for cancellation
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
stopReason = fmt.Sprintf("context_cancelled:%d", generated)
|
|
||||||
break
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
if stopReason != "max_tokens" {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
token := dec.step()
|
|
||||||
generated++
|
|
||||||
|
|
||||||
if tok.IsEOS(token) {
|
|
||||||
stopReason = fmt.Sprintf("eos_token:%d", token)
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
text := tok.Decode([]int32{token})
|
|
||||||
|
|
||||||
// Check for stop sequences
|
|
||||||
if req.Options != nil && len(req.Options.Stop) > 0 {
|
|
||||||
shouldStop := false
|
|
||||||
var matchedStop string
|
|
||||||
for _, stop := range req.Options.Stop {
|
|
||||||
if strings.Contains(text, stop) {
|
|
||||||
text = strings.Split(text, stop)[0]
|
|
||||||
shouldStop = true
|
|
||||||
matchedStop = stop
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if shouldStop {
|
|
||||||
if text != "" {
|
|
||||||
resp := Response{Content: text}
|
|
||||||
enc.Encode(resp)
|
|
||||||
flusher.Flush()
|
|
||||||
}
|
|
||||||
stopReason = fmt.Sprintf("stop_sequence:%s", matchedStop)
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
resp := Response{Content: text}
|
|
||||||
enc.Encode(resp)
|
|
||||||
flusher.Flush()
|
|
||||||
|
|
||||||
// Periodically clear MLX cache
|
|
||||||
if n%256 == 0 {
|
|
||||||
mlx.ClearCache()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Clean up
|
|
||||||
mlx.ClearCache()
|
|
||||||
|
|
||||||
// Send final response with stats
|
|
||||||
evalDuration := time.Since(genStart)
|
|
||||||
resp = Response{
|
|
||||||
Done: true,
|
|
||||||
StopReason: fmt.Sprintf("%s:generated=%d", stopReason, generated),
|
|
||||||
PromptEvalCount: prefillTokens,
|
|
||||||
PromptEvalDuration: int(promptEvalDuration.Nanoseconds()),
|
|
||||||
EvalCount: generated,
|
|
||||||
EvalDuration: int(evalDuration.Nanoseconds()),
|
|
||||||
}
|
|
||||||
enc.Encode(resp)
|
|
||||||
flusher.Flush()
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
@@ -1,5 +1,3 @@
|
|||||||
//go:build mlx
|
|
||||||
|
|
||||||
package manifest
|
package manifest
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user