mirror of
https://github.com/ollama/ollama.git
synced 2026-04-24 01:35:49 +02:00
Compare commits
43 Commits
pdevine/sa
...
jmorganca/
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4fdeb59325 | ||
|
|
61086083eb | ||
|
|
62d1f01ab4 | ||
|
|
10e51c5177 | ||
|
|
3e06bde643 | ||
|
|
6be2de8214 | ||
|
|
ebb1b9ec14 | ||
|
|
d126467d5d | ||
|
|
afb4c62fbf | ||
|
|
e790dc435b | ||
|
|
288077c3a3 | ||
|
|
4425c54eda | ||
|
|
778899a5d2 | ||
|
|
4eab60c1e2 | ||
|
|
1af850e6e3 | ||
|
|
9b0c7cc7b9 | ||
|
|
6928630601 | ||
|
|
9896e3627f | ||
|
|
15732f0ea7 | ||
|
|
562c76d7cc | ||
|
|
122c68c151 | ||
|
|
82848a7806 | ||
|
|
39982a954e | ||
|
|
e9f6ea232f | ||
|
|
110eff01a9 | ||
|
|
799e51d419 | ||
|
|
e8fcb29586 | ||
|
|
97d2f05a6d | ||
|
|
8207e55ec7 | ||
|
|
ad16bffc7d | ||
|
|
c1e3ef4bcc | ||
|
|
a3093cd5e5 | ||
|
|
23d4cad1a2 | ||
|
|
86513cb697 | ||
|
|
3490e9590b | ||
|
|
8da09b1e7e | ||
|
|
a60b9adcce | ||
|
|
a16f96658b | ||
|
|
18ab09b431 | ||
|
|
638faeac54 | ||
|
|
dd5eb6337d | ||
|
|
79917cf80b | ||
|
|
cc90a035a0 |
52
.github/workflows/release.yaml
vendored
52
.github/workflows/release.yaml
vendored
@@ -117,6 +117,25 @@ jobs:
|
||||
install: https://sdk.lunarg.com/sdk/download/1.4.321.1/windows/vulkansdk-windows-X64-1.4.321.1.exe
|
||||
flags: ''
|
||||
runner_dir: 'vulkan'
|
||||
- os: windows
|
||||
arch: amd64
|
||||
preset: 'MLX CUDA 13'
|
||||
install: https://developer.download.nvidia.com/compute/cuda/13.0.0/local_installers/cuda_13.0.0_windows.exe
|
||||
cudnn-install: https://developer.download.nvidia.com/compute/cudnn/redist/cudnn/windows-x86_64/cudnn-windows-x86_64-9.18.1.3_cuda13-archive.zip
|
||||
cuda-components:
|
||||
- '"cudart"'
|
||||
- '"nvcc"'
|
||||
- '"cublas"'
|
||||
- '"cublas_dev"'
|
||||
- '"cufft"'
|
||||
- '"cufft_dev"'
|
||||
- '"nvrtc"'
|
||||
- '"nvrtc_dev"'
|
||||
- '"crt"'
|
||||
- '"nvvm"'
|
||||
- '"nvptxcompiler"'
|
||||
cuda-version: '13.0'
|
||||
flags: ''
|
||||
runs-on: ${{ matrix.arch == 'arm64' && format('{0}-{1}', matrix.os, matrix.arch) || matrix.os }}
|
||||
environment: release
|
||||
env:
|
||||
@@ -125,8 +144,10 @@ jobs:
|
||||
- name: Install system dependencies
|
||||
run: |
|
||||
choco install -y --no-progress ccache ninja
|
||||
ccache -o cache_dir=${{ github.workspace }}\.ccache
|
||||
- if: startsWith(matrix.preset, 'CUDA ') || startsWith(matrix.preset, 'ROCm ') || startsWith(matrix.preset, 'Vulkan')
|
||||
if (Get-Command ccache -ErrorAction SilentlyContinue) {
|
||||
ccache -o cache_dir=${{ github.workspace }}\.ccache
|
||||
}
|
||||
- if: startsWith(matrix.preset, 'CUDA ') || startsWith(matrix.preset, 'ROCm ') || startsWith(matrix.preset, 'Vulkan') || startsWith(matrix.preset, 'MLX ')
|
||||
id: cache-install
|
||||
uses: actions/cache/restore@v4
|
||||
with:
|
||||
@@ -134,8 +155,9 @@ jobs:
|
||||
C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA
|
||||
C:\Program Files\AMD\ROCm
|
||||
C:\VulkanSDK
|
||||
key: ${{ matrix.install }}
|
||||
- if: startsWith(matrix.preset, 'CUDA ')
|
||||
C:\Program Files\NVIDIA\CUDNN
|
||||
key: ${{ matrix.install }}-${{ matrix.cudnn-install }}
|
||||
- if: startsWith(matrix.preset, 'CUDA ') || startsWith(matrix.preset, 'MLX ')
|
||||
name: Install CUDA ${{ matrix.cuda-version }}
|
||||
run: |
|
||||
$ErrorActionPreference = "Stop"
|
||||
@@ -179,6 +201,23 @@ jobs:
|
||||
run: |
|
||||
echo "CC=clang.exe" | Out-File -FilePath $env:GITHUB_ENV -Append
|
||||
echo "CXX=clang++.exe" | Out-File -FilePath $env:GITHUB_ENV -Append
|
||||
- if: startsWith(matrix.preset, 'MLX ')
|
||||
name: Install cuDNN for MLX
|
||||
run: |
|
||||
$ErrorActionPreference = "Stop"
|
||||
$cudnnRoot = "C:\Program Files\NVIDIA\CUDNN"
|
||||
if ("${{ steps.cache-install.outputs.cache-hit }}" -ne 'true') {
|
||||
Invoke-WebRequest -Uri "${{ matrix.cudnn-install }}" -OutFile "cudnn.zip"
|
||||
Expand-Archive -Path cudnn.zip -DestinationPath cudnn-extracted
|
||||
$cudnnDir = (Get-ChildItem -Path cudnn-extracted -Directory)[0].FullName
|
||||
New-Item -ItemType Directory -Force -Path $cudnnRoot
|
||||
Copy-Item -Path "$cudnnDir\*" -Destination "$cudnnRoot\" -Recurse
|
||||
}
|
||||
|
||||
echo "CUDNN_ROOT_DIR=$cudnnRoot" | Out-File -FilePath $env:GITHUB_ENV -Append
|
||||
echo "CUDNN_INCLUDE_PATH=$cudnnRoot\include" | Out-File -FilePath $env:GITHUB_ENV -Append
|
||||
echo "CUDNN_LIBRARY_PATH=$cudnnRoot\lib\x64" | Out-File -FilePath $env:GITHUB_ENV -Append
|
||||
echo "$cudnnRoot\bin\x64" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
|
||||
- if: ${{ !cancelled() && steps.cache-install.outputs.cache-hit != 'true' }}
|
||||
uses: actions/cache/save@v4
|
||||
with:
|
||||
@@ -186,7 +225,8 @@ jobs:
|
||||
C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA
|
||||
C:\Program Files\AMD\ROCm
|
||||
C:\VulkanSDK
|
||||
key: ${{ matrix.install }}
|
||||
C:\Program Files\NVIDIA\CUDNN
|
||||
key: ${{ matrix.install }}-${{ matrix.cudnn-install }}
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/cache@v4
|
||||
with:
|
||||
@@ -198,7 +238,7 @@ jobs:
|
||||
Enter-VsDevShell -VsInstallPath 'C:\Program Files\Microsoft Visual Studio\2022\Enterprise' -SkipAutomaticLocation -DevCmdArguments '-arch=x64 -no_logo'
|
||||
cmake --preset "${{ matrix.preset }}" ${{ matrix.flags }} --install-prefix "$((pwd).Path)\dist\${{ matrix.os }}-${{ matrix.arch }}"
|
||||
cmake --build --parallel ([Environment]::ProcessorCount) --preset "${{ matrix.preset }}"
|
||||
cmake --install build --component "${{ startsWith(matrix.preset, 'CUDA ') && 'CUDA' || startsWith(matrix.preset, 'ROCm ') && 'HIP' || startsWith(matrix.preset, 'Vulkan') && 'Vulkan' || 'CPU' }}" --strip
|
||||
cmake --install build --component "${{ startsWith(matrix.preset, 'MLX ') && 'MLX' || startsWith(matrix.preset, 'CUDA ') && 'CUDA' || startsWith(matrix.preset, 'ROCm ') && 'HIP' || startsWith(matrix.preset, 'Vulkan') && 'Vulkan' || 'CPU' }}" --strip
|
||||
Remove-Item -Path dist\lib\ollama\rocm\rocblas\library\*gfx906* -ErrorAction SilentlyContinue
|
||||
env:
|
||||
CMAKE_GENERATOR: Ninja
|
||||
|
||||
64
.github/workflows/test.yaml
vendored
64
.github/workflows/test.yaml
vendored
@@ -37,7 +37,7 @@ jobs:
|
||||
| xargs python3 -c "import sys; from pathlib import Path; print(any(Path(x).match(glob) for x in sys.argv[1:] for glob in '$*'.split(' ')))"
|
||||
}
|
||||
|
||||
echo changed=$(changed 'llama/llama.cpp/**/*' 'ml/backend/ggml/ggml/**/*') | tee -a $GITHUB_OUTPUT
|
||||
echo changed=$(changed 'llama/llama.cpp/**/*' 'ml/backend/ggml/ggml/**/*' '.github/**/*') | tee -a $GITHUB_OUTPUT
|
||||
echo vendorsha=$(make -f Makefile.sync print-base) | tee -a $GITHUB_OUTPUT
|
||||
|
||||
linux:
|
||||
@@ -51,7 +51,7 @@ jobs:
|
||||
container: nvidia/cuda:13.0.0-devel-ubuntu22.04
|
||||
flags: '-DCMAKE_CUDA_ARCHITECTURES=87'
|
||||
- preset: ROCm
|
||||
container: rocm/dev-ubuntu-22.04:6.1.2
|
||||
container: rocm/dev-ubuntu-22.04:7.2
|
||||
extra-packages: rocm-libs
|
||||
flags: '-DAMDGPU_TARGETS=gfx1010 -DCMAKE_PREFIX_PATH=/opt/rocm'
|
||||
- preset: Vulkan
|
||||
@@ -60,6 +60,10 @@ jobs:
|
||||
mesa-vulkan-drivers vulkan-tools
|
||||
libvulkan1 libvulkan-dev
|
||||
vulkan-sdk cmake ccache g++ make
|
||||
- preset: 'MLX CUDA 13'
|
||||
container: nvidia/cuda:13.0.0-devel-ubuntu22.04
|
||||
extra-packages: libcudnn9-dev-cuda-13 libopenblas-dev liblapack-dev liblapacke-dev git curl
|
||||
flags: '-DCMAKE_CUDA_ARCHITECTURES=87 -DBLAS_INCLUDE_DIRS=/usr/include/x86_64-linux-gnu -DLAPACK_INCLUDE_DIRS=/usr/include/x86_64-linux-gnu'
|
||||
runs-on: linux
|
||||
container: ${{ matrix.container }}
|
||||
steps:
|
||||
@@ -76,6 +80,10 @@ jobs:
|
||||
$sudo apt-get update
|
||||
fi
|
||||
$sudo apt-get install -y cmake ccache ${{ matrix.extra-packages }}
|
||||
# MLX requires CMake 3.25+, install from official releases
|
||||
if [ "${{ matrix.preset }}" = "MLX CUDA 13" ]; then
|
||||
curl -fsSL https://github.com/Kitware/CMake/releases/download/v3.31.2/cmake-3.31.2-linux-$(uname -m).tar.gz | $sudo tar xz -C /usr/local --strip-components 1
|
||||
fi
|
||||
# Export VULKAN_SDK if provided by LunarG package (defensive)
|
||||
if [ -d "/usr/lib/x86_64-linux-gnu/vulkan" ] && [ "${{ matrix.preset }}" = "Vulkan" ]; then
|
||||
echo "VULKAN_SDK=/usr" >> $GITHUB_ENV
|
||||
@@ -87,8 +95,8 @@ jobs:
|
||||
path: /github/home/.cache/ccache
|
||||
key: ccache-${{ runner.os }}-${{ runner.arch }}-${{ matrix.preset }}-${{ needs.changes.outputs.vendorsha }}
|
||||
- run: |
|
||||
cmake --preset ${{ matrix.preset }} ${{ matrix.flags }}
|
||||
cmake --build --preset ${{ matrix.preset }} --parallel
|
||||
cmake --preset "${{ matrix.preset }}" ${{ matrix.flags }}
|
||||
cmake --build --preset "${{ matrix.preset }}" --parallel
|
||||
|
||||
windows:
|
||||
needs: [changes]
|
||||
@@ -114,12 +122,31 @@ jobs:
|
||||
flags: '-DAMDGPU_TARGETS=gfx1010 -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -DCMAKE_C_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma" -DCMAKE_CXX_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma"'
|
||||
- preset: Vulkan
|
||||
install: https://sdk.lunarg.com/sdk/download/1.4.321.1/windows/vulkansdk-windows-X64-1.4.321.1.exe
|
||||
- preset: 'MLX CUDA 13'
|
||||
install: https://developer.download.nvidia.com/compute/cuda/13.0.0/local_installers/cuda_13.0.0_windows.exe
|
||||
cudnn-install: https://developer.download.nvidia.com/compute/cudnn/redist/cudnn/windows-x86_64/cudnn-windows-x86_64-9.18.1.3_cuda13-archive.zip
|
||||
flags: '-DCMAKE_CUDA_ARCHITECTURES=80'
|
||||
cuda-components:
|
||||
- '"cudart"'
|
||||
- '"nvcc"'
|
||||
- '"cublas"'
|
||||
- '"cublas_dev"'
|
||||
- '"cufft"'
|
||||
- '"cufft_dev"'
|
||||
- '"nvrtc"'
|
||||
- '"nvrtc_dev"'
|
||||
- '"crt"'
|
||||
- '"nvvm"'
|
||||
- '"nvptxcompiler"'
|
||||
cuda-version: '13.0'
|
||||
runs-on: windows
|
||||
steps:
|
||||
- run: |
|
||||
choco install -y --no-progress ccache ninja
|
||||
ccache -o cache_dir=${{ github.workspace }}\.ccache
|
||||
- if: matrix.preset == 'CUDA' || matrix.preset == 'ROCm' || matrix.preset == 'Vulkan'
|
||||
if (Get-Command ccache -ErrorAction SilentlyContinue) {
|
||||
ccache -o cache_dir=${{ github.workspace }}\.ccache
|
||||
}
|
||||
- if: matrix.preset == 'CUDA' || matrix.preset == 'ROCm' || matrix.preset == 'Vulkan' || matrix.preset == 'MLX CUDA 13'
|
||||
id: cache-install
|
||||
uses: actions/cache/restore@v4
|
||||
with:
|
||||
@@ -127,8 +154,9 @@ jobs:
|
||||
C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA
|
||||
C:\Program Files\AMD\ROCm
|
||||
C:\VulkanSDK
|
||||
key: ${{ matrix.install }}
|
||||
- if: matrix.preset == 'CUDA'
|
||||
C:\Program Files\NVIDIA\CUDNN
|
||||
key: ${{ matrix.install }}-${{ matrix.cudnn-install }}
|
||||
- if: matrix.preset == 'CUDA' || matrix.preset == 'MLX CUDA 13'
|
||||
name: Install CUDA ${{ matrix.cuda-version }}
|
||||
run: |
|
||||
$ErrorActionPreference = "Stop"
|
||||
@@ -168,6 +196,23 @@ jobs:
|
||||
$vulkanPath = (Resolve-Path "C:\VulkanSDK\*").path
|
||||
echo "$vulkanPath\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
|
||||
echo "VULKAN_SDK=$vulkanPath" >> $env:GITHUB_ENV
|
||||
- if: matrix.preset == 'MLX CUDA 13'
|
||||
name: Install cuDNN for MLX
|
||||
run: |
|
||||
$ErrorActionPreference = "Stop"
|
||||
$cudnnRoot = "C:\Program Files\NVIDIA\CUDNN"
|
||||
if ("${{ steps.cache-install.outputs.cache-hit }}" -ne 'true') {
|
||||
Invoke-WebRequest -Uri "${{ matrix.cudnn-install }}" -OutFile "cudnn.zip"
|
||||
Expand-Archive -Path cudnn.zip -DestinationPath cudnn-extracted
|
||||
$cudnnDir = (Get-ChildItem -Path cudnn-extracted -Directory)[0].FullName
|
||||
New-Item -ItemType Directory -Force -Path $cudnnRoot
|
||||
Copy-Item -Path "$cudnnDir\*" -Destination "$cudnnRoot\" -Recurse
|
||||
}
|
||||
|
||||
echo "CUDNN_ROOT_DIR=$cudnnRoot" | Out-File -FilePath $env:GITHUB_ENV -Append
|
||||
echo "CUDNN_INCLUDE_PATH=$cudnnRoot\include" | Out-File -FilePath $env:GITHUB_ENV -Append
|
||||
echo "CUDNN_LIBRARY_PATH=$cudnnRoot\lib\x64" | Out-File -FilePath $env:GITHUB_ENV -Append
|
||||
echo "$cudnnRoot\bin\x64" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
|
||||
- if: ${{ !cancelled() && steps.cache-install.outputs.cache-hit != 'true' }}
|
||||
uses: actions/cache/save@v4
|
||||
with:
|
||||
@@ -175,7 +220,8 @@ jobs:
|
||||
C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA
|
||||
C:\Program Files\AMD\ROCm
|
||||
C:\VulkanSDK
|
||||
key: ${{ matrix.install }}
|
||||
C:\Program Files\NVIDIA\CUDNN
|
||||
key: ${{ matrix.install }}-${{ matrix.cudnn-install }}
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/cache@v4
|
||||
with:
|
||||
|
||||
108
CMakeLists.txt
108
CMakeLists.txt
@@ -64,10 +64,15 @@ set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${OLLAMA_BUILD_DIR})
|
||||
set(CMAKE_LIBRARY_OUTPUT_DIRECTORY_DEBUG ${OLLAMA_BUILD_DIR})
|
||||
set(CMAKE_LIBRARY_OUTPUT_DIRECTORY_RELEASE ${OLLAMA_BUILD_DIR})
|
||||
|
||||
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src)
|
||||
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/include)
|
||||
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-cpu)
|
||||
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-cpu/amx)
|
||||
# Store ggml include paths for use with target_include_directories later.
|
||||
# We avoid global include_directories() to prevent polluting the include path
|
||||
# for other projects like MLX (whose openblas dependency has its own common.h).
|
||||
set(GGML_INCLUDE_DIRS
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/include
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-cpu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-cpu/amx
|
||||
)
|
||||
|
||||
add_compile_definitions(NDEBUG GGML_VERSION=0x0 GGML_COMMIT=0x0)
|
||||
|
||||
@@ -87,6 +92,14 @@ if(NOT CPU_VARIANTS)
|
||||
set(CPU_VARIANTS "ggml-cpu")
|
||||
endif()
|
||||
|
||||
# Apply ggml include directories to ggml targets only (not globally)
|
||||
target_include_directories(ggml-base PRIVATE ${GGML_INCLUDE_DIRS})
|
||||
foreach(variant ${CPU_VARIANTS})
|
||||
if(TARGET ${variant})
|
||||
target_include_directories(${variant} PRIVATE ${GGML_INCLUDE_DIRS})
|
||||
endif()
|
||||
endforeach()
|
||||
|
||||
install(TARGETS ggml-base ${CPU_VARIANTS}
|
||||
RUNTIME_DEPENDENCIES
|
||||
PRE_EXCLUDE_REGEXES ".*"
|
||||
@@ -103,6 +116,7 @@ if(CMAKE_CUDA_COMPILER)
|
||||
|
||||
find_package(CUDAToolkit)
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-cuda)
|
||||
target_include_directories(ggml-cuda PRIVATE ${GGML_INCLUDE_DIRS})
|
||||
install(TARGETS ggml-cuda
|
||||
RUNTIME_DEPENDENCIES
|
||||
DIRECTORIES ${CUDAToolkit_BIN_DIR} ${CUDAToolkit_BIN_DIR}/x64 ${CUDAToolkit_LIBRARY_DIR}
|
||||
@@ -134,6 +148,7 @@ if(CMAKE_HIP_COMPILER)
|
||||
if(AMDGPU_TARGETS)
|
||||
find_package(hip REQUIRED)
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-hip)
|
||||
target_include_directories(ggml-hip PRIVATE ${GGML_INCLUDE_DIRS})
|
||||
|
||||
if (WIN32)
|
||||
target_compile_definitions(ggml-hip PRIVATE GGML_CUDA_NO_PEER_COPY)
|
||||
@@ -148,7 +163,7 @@ if(CMAKE_HIP_COMPILER)
|
||||
)
|
||||
install(RUNTIME_DEPENDENCY_SET rocm
|
||||
DIRECTORIES ${HIP_BIN_INSTALL_DIR} ${HIP_LIB_INSTALL_DIR}
|
||||
PRE_INCLUDE_REGEXES hipblas rocblas amdhip64 rocsolver amd_comgr hsa-runtime64 rocsparse tinfo rocprofiler-register drm drm_amdgpu numa elf
|
||||
PRE_INCLUDE_REGEXES hipblas rocblas amdhip64 rocsolver amd_comgr hsa-runtime64 rocsparse tinfo rocprofiler-register roctx64 rocroller drm drm_amdgpu numa elf
|
||||
PRE_EXCLUDE_REGEXES ".*"
|
||||
POST_EXCLUDE_REGEXES "system32"
|
||||
RUNTIME DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT HIP
|
||||
@@ -168,6 +183,7 @@ if(NOT APPLE)
|
||||
find_package(Vulkan)
|
||||
if(Vulkan_FOUND)
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-vulkan)
|
||||
target_include_directories(ggml-vulkan PRIVATE ${GGML_INCLUDE_DIRS})
|
||||
install(TARGETS ggml-vulkan
|
||||
RUNTIME_DEPENDENCIES
|
||||
PRE_INCLUDE_REGEXES vulkan
|
||||
@@ -179,7 +195,6 @@ if(NOT APPLE)
|
||||
endif()
|
||||
|
||||
option(MLX_ENGINE "Enable MLX backend" OFF)
|
||||
|
||||
if(MLX_ENGINE)
|
||||
message(STATUS "Setting up MLX (this takes a while...)")
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/x/imagegen/mlx)
|
||||
@@ -187,10 +202,36 @@ if(MLX_ENGINE)
|
||||
# Find CUDA toolkit if MLX is built with CUDA support
|
||||
find_package(CUDAToolkit)
|
||||
|
||||
# Build list of directories for runtime dependency resolution
|
||||
set(MLX_RUNTIME_DIRS ${CUDAToolkit_BIN_DIR} ${CUDAToolkit_BIN_DIR}/x64 ${CUDAToolkit_LIBRARY_DIR})
|
||||
# Add cuDNN bin paths for DLLs (Windows MLX CUDA builds)
|
||||
# CUDNN_ROOT_DIR is the standard CMake variable for cuDNN location
|
||||
if(DEFINED ENV{CUDNN_ROOT_DIR})
|
||||
# cuDNN 9.x has versioned subdirectories under bin/ (e.g., bin/13.0/)
|
||||
file(GLOB CUDNN_BIN_SUBDIRS "$ENV{CUDNN_ROOT_DIR}/bin/*")
|
||||
list(APPEND MLX_RUNTIME_DIRS ${CUDNN_BIN_SUBDIRS})
|
||||
endif()
|
||||
# Add build output directory and MLX dependency build directories
|
||||
list(APPEND MLX_RUNTIME_DIRS ${OLLAMA_BUILD_DIR})
|
||||
# OpenBLAS DLL location (pre-built zip extracts into openblas-src/bin/)
|
||||
list(APPEND MLX_RUNTIME_DIRS ${CMAKE_BINARY_DIR}/_deps/openblas-src/bin)
|
||||
# NCCL: on Linux, if real NCCL is found, cmake bundles libnccl.so via the
|
||||
# regex below. If NCCL is not found, MLX links a static stub (OBJECT lib)
|
||||
# so there is no runtime dependency. This path covers the stub build dir
|
||||
# for windows so we include the DLL in our dependencies.
|
||||
list(APPEND MLX_RUNTIME_DIRS ${CMAKE_BINARY_DIR}/_deps/mlx-build/mlx/distributed/nccl/nccl_stub-prefix/src/nccl_stub-build/Release)
|
||||
|
||||
# Base regexes for runtime dependencies (cross-platform)
|
||||
set(MLX_INCLUDE_REGEXES cublas cublasLt cudart cufft nvrtc nvrtc-builtins cudnn nccl openblas gfortran)
|
||||
# On Windows, also include dl.dll (dlfcn-win32 POSIX emulation layer)
|
||||
if(WIN32)
|
||||
list(APPEND MLX_INCLUDE_REGEXES "^dl\\.dll$")
|
||||
endif()
|
||||
|
||||
install(TARGETS mlx mlxc
|
||||
RUNTIME_DEPENDENCIES
|
||||
DIRECTORIES ${CUDAToolkit_BIN_DIR} ${CUDAToolkit_BIN_DIR}/x64 ${CUDAToolkit_LIBRARY_DIR}
|
||||
PRE_INCLUDE_REGEXES cublas cublasLt cudart nvrtc nvrtc-builtins cudnn nccl openblas gfortran
|
||||
DIRECTORIES ${MLX_RUNTIME_DIRS}
|
||||
PRE_INCLUDE_REGEXES ${MLX_INCLUDE_REGEXES}
|
||||
PRE_EXCLUDE_REGEXES ".*"
|
||||
RUNTIME DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT MLX
|
||||
LIBRARY DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT MLX
|
||||
@@ -205,13 +246,54 @@ if(MLX_ENGINE)
|
||||
COMPONENT MLX)
|
||||
endif()
|
||||
|
||||
# Manually install cudart and cublas since they might not be picked up as direct dependencies
|
||||
# Install CCCL headers for NVRTC JIT compilation at runtime.
|
||||
# MLX's own install rules use the default component so they get skipped by
|
||||
# --component MLX. Headers are installed alongside libmlx in OLLAMA_INSTALL_DIR.
|
||||
# On Linux, MLX's jit_module.cpp resolves CCCL via
|
||||
# current_binary_dir().parent_path() / "include" / "cccl", so we create a
|
||||
# symlink from lib/ollama/include -> ${OLLAMA_RUNNER_DIR}/include
|
||||
# This will need refinement if we add multiple CUDA versions for MLX in the future.
|
||||
if(EXISTS ${CMAKE_BINARY_DIR}/_deps/cccl-src/include/cuda)
|
||||
install(DIRECTORY ${CMAKE_BINARY_DIR}/_deps/cccl-src/include/cuda
|
||||
DESTINATION ${OLLAMA_INSTALL_DIR}/include/cccl
|
||||
COMPONENT MLX)
|
||||
install(DIRECTORY ${CMAKE_BINARY_DIR}/_deps/cccl-src/include/nv
|
||||
DESTINATION ${OLLAMA_INSTALL_DIR}/include/cccl
|
||||
COMPONENT MLX)
|
||||
if(NOT WIN32 AND NOT APPLE)
|
||||
install(CODE "
|
||||
set(_link \"${CMAKE_INSTALL_PREFIX}/lib/ollama/include\")
|
||||
set(_target \"${OLLAMA_RUNNER_DIR}/include\")
|
||||
if(NOT EXISTS \${_link})
|
||||
execute_process(COMMAND \${CMAKE_COMMAND} -E create_symlink \${_target} \${_link})
|
||||
endif()
|
||||
" COMPONENT MLX)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
# On Windows, explicitly install dl.dll (dlfcn-win32 POSIX dlopen emulation)
|
||||
# RUNTIME_DEPENDENCIES auto-excludes it via POST_EXCLUDE_FILES_STRICT because
|
||||
# dlfcn-win32 is a known CMake target with its own install rules (which install
|
||||
# to the wrong destination). We must install it explicitly here.
|
||||
if(WIN32)
|
||||
install(FILES ${OLLAMA_BUILD_DIR}/dl.dll
|
||||
DESTINATION ${OLLAMA_INSTALL_DIR}
|
||||
COMPONENT MLX)
|
||||
endif()
|
||||
|
||||
# Manually install CUDA runtime libraries that MLX loads via dlopen
|
||||
# (not detected by RUNTIME_DEPENDENCIES since they aren't link-time deps)
|
||||
if(CUDAToolkit_FOUND)
|
||||
file(GLOB CUDART_LIBS
|
||||
file(GLOB MLX_CUDA_LIBS
|
||||
"${CUDAToolkit_LIBRARY_DIR}/libcudart.so*"
|
||||
"${CUDAToolkit_LIBRARY_DIR}/libcublas.so*")
|
||||
if(CUDART_LIBS)
|
||||
install(FILES ${CUDART_LIBS}
|
||||
"${CUDAToolkit_LIBRARY_DIR}/libcublas.so*"
|
||||
"${CUDAToolkit_LIBRARY_DIR}/libcublasLt.so*"
|
||||
"${CUDAToolkit_LIBRARY_DIR}/libnvrtc.so*"
|
||||
"${CUDAToolkit_LIBRARY_DIR}/libnvrtc-builtins.so*"
|
||||
"${CUDAToolkit_LIBRARY_DIR}/libcufft.so*"
|
||||
"${CUDAToolkit_LIBRARY_DIR}/libcudnn.so*")
|
||||
if(MLX_CUDA_LIBS)
|
||||
install(FILES ${MLX_CUDA_LIBS}
|
||||
DESTINATION ${OLLAMA_INSTALL_DIR}
|
||||
COMPONENT MLX)
|
||||
endif()
|
||||
|
||||
@@ -77,6 +77,15 @@
|
||||
"OLLAMA_RUNNER_DIR": "rocm"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "ROCm 7",
|
||||
"inherits": [ "ROCm" ],
|
||||
"cacheVariables": {
|
||||
"CMAKE_HIP_FLAGS": "-parallel-jobs=4",
|
||||
"AMDGPU_TARGETS": "gfx942;gfx950;gfx1010;gfx1012;gfx1030;gfx1100;gfx1101;gfx1102;gfx1103;gfx1150;gfx1151;gfx1200;gfx1201;gfx908:xnack-;gfx90a:xnack+;gfx90a:xnack-",
|
||||
"OLLAMA_RUNNER_DIR": "rocm"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Vulkan",
|
||||
"inherits": [ "Default" ],
|
||||
@@ -103,6 +112,7 @@
|
||||
"name": "MLX CUDA 13",
|
||||
"inherits": [ "MLX", "CUDA 13" ],
|
||||
"cacheVariables": {
|
||||
"MLX_CUDA_ARCHITECTURES": "86;89;90;90a;100;103;75-virtual;80-virtual;110-virtual;120-virtual;121-virtual",
|
||||
"OLLAMA_RUNNER_DIR": "mlx_cuda_v13"
|
||||
}
|
||||
}
|
||||
@@ -158,6 +168,11 @@
|
||||
"inherits": [ "ROCm" ],
|
||||
"configurePreset": "ROCm 6"
|
||||
},
|
||||
{
|
||||
"name": "ROCm 7",
|
||||
"inherits": [ "ROCm" ],
|
||||
"configurePreset": "ROCm 7"
|
||||
},
|
||||
{
|
||||
"name": "Vulkan",
|
||||
"targets": [ "ggml-vulkan" ],
|
||||
|
||||
122
Dockerfile
122
Dockerfile
@@ -1,28 +1,23 @@
|
||||
# vim: filetype=dockerfile
|
||||
|
||||
ARG FLAVOR=${TARGETARCH}
|
||||
ARG PARALLEL=8
|
||||
|
||||
ARG ROCMVERSION=6.3.3
|
||||
ARG ROCMVERSION=7.2
|
||||
ARG JETPACK5VERSION=r35.4.1
|
||||
ARG JETPACK6VERSION=r36.4.0
|
||||
ARG CMAKEVERSION=3.31.2
|
||||
ARG NINJAVERSION=1.12.1
|
||||
ARG VULKANVERSION=1.4.321.1
|
||||
|
||||
# Default empty stages for local MLX source overrides.
|
||||
# Override with: docker build --build-context local-mlx=../mlx --build-context local-mlx-c=../mlx-c
|
||||
FROM scratch AS local-mlx
|
||||
FROM scratch AS local-mlx-c
|
||||
|
||||
FROM --platform=linux/amd64 rocm/dev-almalinux-8:${ROCMVERSION}-complete AS base-amd64
|
||||
RUN dnf install -y yum-utils ccache gcc-toolset-11-gcc gcc-toolset-11-gcc-c++ gcc-toolset-11-binutils \
|
||||
&& yum-config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel8/x86_64/cuda-rhel8.repo
|
||||
ENV PATH=/opt/rh/gcc-toolset-11/root/usr/bin:$PATH
|
||||
ARG VULKANVERSION
|
||||
RUN wget https://sdk.lunarg.com/sdk/download/${VULKANVERSION}/linux/vulkansdk-linux-x86_64-${VULKANVERSION}.tar.xz -O /tmp/vulkansdk-linux-x86_64-${VULKANVERSION}.tar.xz \
|
||||
&& tar xvf /tmp/vulkansdk-linux-x86_64-${VULKANVERSION}.tar.xz \
|
||||
&& dnf -y install ninja-build \
|
||||
&& ln -s /usr/bin/python3 /usr/bin/python \
|
||||
&& /${VULKANVERSION}/vulkansdk -j 8 vulkan-headers \
|
||||
&& /${VULKANVERSION}/vulkansdk -j 8 shaderc
|
||||
RUN cp -r /${VULKANVERSION}/x86_64/include/* /usr/local/include/ \
|
||||
&& cp -r /${VULKANVERSION}/x86_64/lib/* /usr/local/lib
|
||||
ENV PATH=/${VULKANVERSION}/x86_64/bin:$PATH
|
||||
|
||||
FROM --platform=linux/arm64 almalinux:8 AS base-arm64
|
||||
# install epel-release for ccache
|
||||
@@ -33,100 +28,119 @@ ENV CC=clang CXX=clang++
|
||||
|
||||
FROM base-${TARGETARCH} AS base
|
||||
ARG CMAKEVERSION
|
||||
ARG NINJAVERSION
|
||||
RUN curl -fsSL https://github.com/Kitware/CMake/releases/download/v${CMAKEVERSION}/cmake-${CMAKEVERSION}-linux-$(uname -m).tar.gz | tar xz -C /usr/local --strip-components 1
|
||||
RUN dnf install -y unzip \
|
||||
&& curl -fsSL -o /tmp/ninja.zip https://github.com/ninja-build/ninja/releases/download/v${NINJAVERSION}/ninja-linux$([ "$(uname -m)" = "aarch64" ] && echo "-aarch64").zip \
|
||||
&& unzip /tmp/ninja.zip -d /usr/local/bin \
|
||||
&& rm /tmp/ninja.zip
|
||||
ENV CMAKE_GENERATOR=Ninja
|
||||
ENV LDFLAGS=-s
|
||||
|
||||
FROM base AS cpu
|
||||
RUN dnf install -y gcc-toolset-11-gcc gcc-toolset-11-gcc-c++
|
||||
ENV PATH=/opt/rh/gcc-toolset-11/root/usr/bin:$PATH
|
||||
ARG PARALLEL
|
||||
COPY CMakeLists.txt CMakePresets.json .
|
||||
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
|
||||
RUN --mount=type=cache,target=/root/.ccache \
|
||||
cmake --preset 'CPU' \
|
||||
&& cmake --build --parallel ${PARALLEL} --preset 'CPU' \
|
||||
&& cmake --install build --component CPU --strip --parallel ${PARALLEL}
|
||||
&& cmake --build --preset 'CPU' -- -l $(nproc) \
|
||||
&& cmake --install build --component CPU --strip
|
||||
|
||||
FROM base AS cuda-11
|
||||
ARG CUDA11VERSION=11.8
|
||||
RUN dnf install -y cuda-toolkit-${CUDA11VERSION//./-}
|
||||
ENV PATH=/usr/local/cuda-11/bin:$PATH
|
||||
ARG PARALLEL
|
||||
COPY CMakeLists.txt CMakePresets.json .
|
||||
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
|
||||
RUN --mount=type=cache,target=/root/.ccache \
|
||||
cmake --preset 'CUDA 11' \
|
||||
&& cmake --build --parallel ${PARALLEL} --preset 'CUDA 11' \
|
||||
&& cmake --install build --component CUDA --strip --parallel ${PARALLEL}
|
||||
&& cmake --build --preset 'CUDA 11' -- -l $(nproc) \
|
||||
&& cmake --install build --component CUDA --strip
|
||||
|
||||
FROM base AS cuda-12
|
||||
ARG CUDA12VERSION=12.8
|
||||
RUN dnf install -y cuda-toolkit-${CUDA12VERSION//./-}
|
||||
ENV PATH=/usr/local/cuda-12/bin:$PATH
|
||||
ARG PARALLEL
|
||||
COPY CMakeLists.txt CMakePresets.json .
|
||||
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
|
||||
RUN --mount=type=cache,target=/root/.ccache \
|
||||
cmake --preset 'CUDA 12' \
|
||||
&& cmake --build --parallel ${PARALLEL} --preset 'CUDA 12' \
|
||||
&& cmake --install build --component CUDA --strip --parallel ${PARALLEL}
|
||||
&& cmake --build --preset 'CUDA 12' -- -l $(nproc) \
|
||||
&& cmake --install build --component CUDA --strip
|
||||
|
||||
|
||||
FROM base AS cuda-13
|
||||
ARG CUDA13VERSION=13.0
|
||||
RUN dnf install -y cuda-toolkit-${CUDA13VERSION//./-}
|
||||
ENV PATH=/usr/local/cuda-13/bin:$PATH
|
||||
ARG PARALLEL
|
||||
COPY CMakeLists.txt CMakePresets.json .
|
||||
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
|
||||
RUN --mount=type=cache,target=/root/.ccache \
|
||||
cmake --preset 'CUDA 13' \
|
||||
&& cmake --build --parallel ${PARALLEL} --preset 'CUDA 13' \
|
||||
&& cmake --install build --component CUDA --strip --parallel ${PARALLEL}
|
||||
&& cmake --build --preset 'CUDA 13' -- -l $(nproc) \
|
||||
&& cmake --install build --component CUDA --strip
|
||||
|
||||
|
||||
FROM base AS rocm-6
|
||||
FROM base AS rocm-7
|
||||
ENV PATH=/opt/rocm/hcc/bin:/opt/rocm/hip/bin:/opt/rocm/bin:/opt/rocm/hcc/bin:$PATH
|
||||
ARG PARALLEL
|
||||
COPY CMakeLists.txt CMakePresets.json .
|
||||
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
|
||||
RUN --mount=type=cache,target=/root/.ccache \
|
||||
cmake --preset 'ROCm 6' \
|
||||
&& cmake --build --parallel ${PARALLEL} --preset 'ROCm 6' \
|
||||
&& cmake --install build --component HIP --strip --parallel ${PARALLEL}
|
||||
cmake --preset 'ROCm 7' \
|
||||
&& cmake --build --preset 'ROCm 7' -- -l $(nproc) \
|
||||
&& cmake --install build --component HIP --strip
|
||||
RUN rm -f dist/lib/ollama/rocm/rocblas/library/*gfx90[06]*
|
||||
|
||||
FROM --platform=linux/arm64 nvcr.io/nvidia/l4t-jetpack:${JETPACK5VERSION} AS jetpack-5
|
||||
ARG CMAKEVERSION
|
||||
RUN apt-get update && apt-get install -y curl ccache \
|
||||
&& curl -fsSL https://github.com/Kitware/CMake/releases/download/v${CMAKEVERSION}/cmake-${CMAKEVERSION}-linux-$(uname -m).tar.gz | tar xz -C /usr/local --strip-components 1
|
||||
ARG NINJAVERSION
|
||||
RUN apt-get update && apt-get install -y curl ccache unzip \
|
||||
&& curl -fsSL https://github.com/Kitware/CMake/releases/download/v${CMAKEVERSION}/cmake-${CMAKEVERSION}-linux-$(uname -m).tar.gz | tar xz -C /usr/local --strip-components 1 \
|
||||
&& curl -fsSL -o /tmp/ninja.zip https://github.com/ninja-build/ninja/releases/download/v${NINJAVERSION}/ninja-linux-aarch64.zip \
|
||||
&& unzip /tmp/ninja.zip -d /usr/local/bin \
|
||||
&& rm /tmp/ninja.zip
|
||||
ENV CMAKE_GENERATOR=Ninja
|
||||
COPY CMakeLists.txt CMakePresets.json .
|
||||
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
|
||||
ARG PARALLEL
|
||||
RUN --mount=type=cache,target=/root/.ccache \
|
||||
cmake --preset 'JetPack 5' \
|
||||
&& cmake --build --parallel ${PARALLEL} --preset 'JetPack 5' \
|
||||
&& cmake --install build --component CUDA --strip --parallel ${PARALLEL}
|
||||
&& cmake --build --preset 'JetPack 5' -- -l $(nproc) \
|
||||
&& cmake --install build --component CUDA --strip
|
||||
|
||||
FROM --platform=linux/arm64 nvcr.io/nvidia/l4t-jetpack:${JETPACK6VERSION} AS jetpack-6
|
||||
ARG CMAKEVERSION
|
||||
RUN apt-get update && apt-get install -y curl ccache \
|
||||
&& curl -fsSL https://github.com/Kitware/CMake/releases/download/v${CMAKEVERSION}/cmake-${CMAKEVERSION}-linux-$(uname -m).tar.gz | tar xz -C /usr/local --strip-components 1
|
||||
ARG NINJAVERSION
|
||||
RUN apt-get update && apt-get install -y curl ccache unzip \
|
||||
&& curl -fsSL https://github.com/Kitware/CMake/releases/download/v${CMAKEVERSION}/cmake-${CMAKEVERSION}-linux-$(uname -m).tar.gz | tar xz -C /usr/local --strip-components 1 \
|
||||
&& curl -fsSL -o /tmp/ninja.zip https://github.com/ninja-build/ninja/releases/download/v${NINJAVERSION}/ninja-linux-aarch64.zip \
|
||||
&& unzip /tmp/ninja.zip -d /usr/local/bin \
|
||||
&& rm /tmp/ninja.zip
|
||||
ENV CMAKE_GENERATOR=Ninja
|
||||
COPY CMakeLists.txt CMakePresets.json .
|
||||
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
|
||||
ARG PARALLEL
|
||||
RUN --mount=type=cache,target=/root/.ccache \
|
||||
cmake --preset 'JetPack 6' \
|
||||
&& cmake --build --parallel ${PARALLEL} --preset 'JetPack 6' \
|
||||
&& cmake --install build --component CUDA --strip --parallel ${PARALLEL}
|
||||
&& cmake --build --preset 'JetPack 6' -- -l $(nproc) \
|
||||
&& cmake --install build --component CUDA --strip
|
||||
|
||||
FROM base AS vulkan
|
||||
ARG VULKANVERSION
|
||||
RUN ln -s /usr/bin/python3 /usr/bin/python \
|
||||
&& wget https://sdk.lunarg.com/sdk/download/${VULKANVERSION}/linux/vulkansdk-linux-x86_64-${VULKANVERSION}.tar.xz -O /tmp/vulkansdk.tar.xz \
|
||||
&& tar xvf /tmp/vulkansdk.tar.xz -C /tmp \
|
||||
&& /tmp/${VULKANVERSION}/vulkansdk -j 8 vulkan-headers \
|
||||
&& /tmp/${VULKANVERSION}/vulkansdk -j 8 shaderc \
|
||||
&& cp -r /tmp/${VULKANVERSION}/x86_64/include/* /usr/local/include/ \
|
||||
&& cp -r /tmp/${VULKANVERSION}/x86_64/lib/* /usr/local/lib \
|
||||
&& cp -r /tmp/${VULKANVERSION}/x86_64/bin/* /usr/local/bin/ \
|
||||
&& rm -rf /tmp/${VULKANVERSION} /tmp/vulkansdk.tar.xz
|
||||
COPY CMakeLists.txt CMakePresets.json .
|
||||
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
|
||||
RUN --mount=type=cache,target=/root/.ccache \
|
||||
cmake --preset 'Vulkan' \
|
||||
&& cmake --build --parallel --preset 'Vulkan' \
|
||||
&& cmake --install build --component Vulkan --strip --parallel 8
|
||||
&& cmake --build --preset 'Vulkan' -- -l $(nproc) \
|
||||
&& cmake --install build --component Vulkan --strip
|
||||
|
||||
FROM base AS mlx
|
||||
ARG CUDA13VERSION=13.0
|
||||
@@ -138,20 +152,27 @@ ENV PATH=/usr/local/cuda-13/bin:$PATH
|
||||
ENV BLAS_INCLUDE_DIRS=/usr/include/openblas
|
||||
ENV LAPACK_INCLUDE_DIRS=/usr/include/openblas
|
||||
ENV CGO_LDFLAGS="-L/usr/local/cuda-13/lib64 -L/usr/local/cuda-13/targets/x86_64-linux/lib/stubs"
|
||||
ARG PARALLEL
|
||||
WORKDIR /go/src/github.com/ollama/ollama
|
||||
COPY CMakeLists.txt CMakePresets.json .
|
||||
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
|
||||
COPY x/imagegen/mlx x/imagegen/mlx
|
||||
COPY go.mod go.sum .
|
||||
COPY MLX_VERSION .
|
||||
COPY MLX_VERSION MLX_CORE_VERSION .
|
||||
RUN curl -fsSL https://golang.org/dl/go$(awk '/^go/ { print $2 }' go.mod).linux-$(case $(uname -m) in x86_64) echo amd64 ;; aarch64) echo arm64 ;; esac).tar.gz | tar xz -C /usr/local
|
||||
ENV PATH=/usr/local/go/bin:$PATH
|
||||
RUN go mod download
|
||||
RUN --mount=type=cache,target=/root/.ccache \
|
||||
cmake --preset 'MLX CUDA 13' -DBLAS_INCLUDE_DIRS=/usr/include/openblas -DLAPACK_INCLUDE_DIRS=/usr/include/openblas \
|
||||
&& cmake --build --parallel ${PARALLEL} --preset 'MLX CUDA 13' \
|
||||
&& cmake --install build --component MLX --strip --parallel ${PARALLEL}
|
||||
--mount=type=bind,from=local-mlx,target=/tmp/local-mlx \
|
||||
--mount=type=bind,from=local-mlx-c,target=/tmp/local-mlx-c \
|
||||
if [ -f /tmp/local-mlx/CMakeLists.txt ]; then \
|
||||
export OLLAMA_MLX_SOURCE=/tmp/local-mlx; \
|
||||
fi \
|
||||
&& if [ -f /tmp/local-mlx-c/CMakeLists.txt ]; then \
|
||||
export OLLAMA_MLX_C_SOURCE=/tmp/local-mlx-c; \
|
||||
fi \
|
||||
&& cmake --preset 'MLX CUDA 13' -DBLAS_INCLUDE_DIRS=/usr/include/openblas -DLAPACK_INCLUDE_DIRS=/usr/include/openblas \
|
||||
&& cmake --build --preset 'MLX CUDA 13' -- -l $(nproc) \
|
||||
&& cmake --install build --component MLX --strip
|
||||
|
||||
FROM base AS build
|
||||
WORKDIR /go/src/github.com/ollama/ollama
|
||||
@@ -160,16 +181,14 @@ RUN curl -fsSL https://golang.org/dl/go$(awk '/^go/ { print $2 }' go.mod).linux-
|
||||
ENV PATH=/usr/local/go/bin:$PATH
|
||||
RUN go mod download
|
||||
COPY . .
|
||||
# Clone mlx-c headers for CGO (version from MLX_VERSION file)
|
||||
RUN git clone --depth 1 --branch "$(cat MLX_VERSION)" https://github.com/ml-explore/mlx-c.git build/_deps/mlx-c-src
|
||||
ARG GOFLAGS="'-ldflags=-w -s'"
|
||||
ENV CGO_ENABLED=1
|
||||
ARG CGO_CFLAGS
|
||||
ARG CGO_CXXFLAGS
|
||||
ENV CGO_CFLAGS="${CGO_CFLAGS} -I/go/src/github.com/ollama/ollama/build/_deps/mlx-c-src"
|
||||
ENV CGO_CFLAGS="${CGO_CFLAGS}"
|
||||
ENV CGO_CXXFLAGS="${CGO_CXXFLAGS}"
|
||||
RUN --mount=type=cache,target=/root/.cache/go-build \
|
||||
go build -tags mlx -trimpath -buildmode=pie -o /bin/ollama .
|
||||
go build -trimpath -buildmode=pie -o /bin/ollama .
|
||||
|
||||
FROM --platform=linux/amd64 scratch AS amd64
|
||||
# COPY --from=cuda-11 dist/lib/ollama/ /lib/ollama/
|
||||
@@ -186,10 +205,9 @@ COPY --from=jetpack-5 dist/lib/ollama/ /lib/ollama/
|
||||
COPY --from=jetpack-6 dist/lib/ollama/ /lib/ollama/
|
||||
|
||||
FROM scratch AS rocm
|
||||
COPY --from=rocm-6 dist/lib/ollama /lib/ollama
|
||||
COPY --from=rocm-7 dist/lib/ollama /lib/ollama
|
||||
|
||||
FROM ${FLAVOR} AS archive
|
||||
ARG VULKANVERSION
|
||||
COPY --from=cpu dist/lib/ollama /lib/ollama
|
||||
COPY --from=build /bin/ollama /bin/ollama
|
||||
|
||||
|
||||
1
MLX_CORE_VERSION
Normal file
1
MLX_CORE_VERSION
Normal file
@@ -0,0 +1 @@
|
||||
v0.30.6
|
||||
@@ -1063,7 +1063,7 @@ func DefaultOptions() Options {
|
||||
TopP: 0.9,
|
||||
TypicalP: 1.0,
|
||||
RepeatLastN: 64,
|
||||
RepeatPenalty: 1.1,
|
||||
RepeatPenalty: 1.0,
|
||||
PresencePenalty: 0.0,
|
||||
FrequencyPenalty: 0.0,
|
||||
Seed: -1,
|
||||
|
||||
@@ -214,6 +214,7 @@ export default function Settings() {
|
||||
Agent: false,
|
||||
Tools: false,
|
||||
ContextLength: 0,
|
||||
AutoUpdateEnabled: true,
|
||||
});
|
||||
updateSettingsMutation.mutate(defaultSettings);
|
||||
}
|
||||
|
||||
98
cmd/cmd.go
98
cmd/cmd.go
@@ -41,6 +41,7 @@ import (
|
||||
"github.com/ollama/ollama/cmd/tui"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
"github.com/ollama/ollama/format"
|
||||
"github.com/ollama/ollama/internal/modelref"
|
||||
"github.com/ollama/ollama/parser"
|
||||
"github.com/ollama/ollama/progress"
|
||||
"github.com/ollama/ollama/readline"
|
||||
@@ -131,6 +132,17 @@ func getModelfileName(cmd *cobra.Command) (string, error) {
|
||||
return absName, nil
|
||||
}
|
||||
|
||||
// isLocalhost returns true if the configured Ollama host is a loopback or unspecified address.
|
||||
func isLocalhost() bool {
|
||||
host := envconfig.Host()
|
||||
h, _, _ := net.SplitHostPort(host.Host)
|
||||
if h == "localhost" {
|
||||
return true
|
||||
}
|
||||
ip := net.ParseIP(h)
|
||||
return ip != nil && (ip.IsLoopback() || ip.IsUnspecified())
|
||||
}
|
||||
|
||||
func CreateHandler(cmd *cobra.Command, args []string) error {
|
||||
p := progress.NewProgress(os.Stderr)
|
||||
defer p.Stop()
|
||||
@@ -145,6 +157,9 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
|
||||
// Check for --experimental flag for safetensors model creation
|
||||
experimental, _ := cmd.Flags().GetBool("experimental")
|
||||
if experimental {
|
||||
if !isLocalhost() {
|
||||
return errors.New("remote safetensor model creation not yet supported")
|
||||
}
|
||||
// Get Modelfile content - either from -f flag or default to "FROM ."
|
||||
var reader io.Reader
|
||||
filename, err := getModelfileName(cmd)
|
||||
@@ -168,29 +183,9 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
|
||||
return fmt.Errorf("failed to parse Modelfile: %w", err)
|
||||
}
|
||||
|
||||
// Extract FROM path and configuration
|
||||
var modelDir string
|
||||
mfConfig := &xcreateclient.ModelfileConfig{}
|
||||
|
||||
for _, cmd := range modelfile.Commands {
|
||||
switch cmd.Name {
|
||||
case "model":
|
||||
modelDir = cmd.Args
|
||||
case "template":
|
||||
mfConfig.Template = cmd.Args
|
||||
case "system":
|
||||
mfConfig.System = cmd.Args
|
||||
case "license":
|
||||
mfConfig.License = cmd.Args
|
||||
case "parser":
|
||||
mfConfig.Parser = cmd.Args
|
||||
case "renderer":
|
||||
mfConfig.Renderer = cmd.Args
|
||||
}
|
||||
}
|
||||
|
||||
if modelDir == "" {
|
||||
modelDir = "."
|
||||
modelDir, mfConfig, err := xcreateclient.ConfigFromModelfile(modelfile)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Resolve relative paths based on Modelfile location
|
||||
@@ -214,6 +209,9 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
|
||||
if filename == "" {
|
||||
// No Modelfile found - check if current directory is an image gen model
|
||||
if create.IsTensorModelDir(".") {
|
||||
if !isLocalhost() {
|
||||
return errors.New("remote safetensor model creation not yet supported")
|
||||
}
|
||||
quantize, _ := cmd.Flags().GetString("quantize")
|
||||
return xcreateclient.CreateModel(xcreateclient.CreateOptions{
|
||||
ModelName: modelName,
|
||||
@@ -406,12 +404,14 @@ func loadOrUnloadModel(cmd *cobra.Command, opts *runOptions) error {
|
||||
return err
|
||||
}
|
||||
|
||||
requestedCloud := modelref.HasExplicitCloudSource(opts.Model)
|
||||
|
||||
if info, err := client.Show(cmd.Context(), &api.ShowRequest{Model: opts.Model}); err != nil {
|
||||
return err
|
||||
} else if info.RemoteHost != "" {
|
||||
} else if info.RemoteHost != "" || requestedCloud {
|
||||
// Cloud model, no need to load/unload
|
||||
|
||||
isCloud := strings.HasPrefix(info.RemoteHost, "https://ollama.com")
|
||||
isCloud := requestedCloud || strings.HasPrefix(info.RemoteHost, "https://ollama.com")
|
||||
|
||||
// Check if user is signed in for ollama.com cloud models
|
||||
if isCloud {
|
||||
@@ -422,10 +422,14 @@ func loadOrUnloadModel(cmd *cobra.Command, opts *runOptions) error {
|
||||
|
||||
if opts.ShowConnect {
|
||||
p.StopAndClear()
|
||||
remoteModel := info.RemoteModel
|
||||
if remoteModel == "" {
|
||||
remoteModel = opts.Model
|
||||
}
|
||||
if isCloud {
|
||||
fmt.Fprintf(os.Stderr, "Connecting to '%s' on 'ollama.com' ⚡\n", info.RemoteModel)
|
||||
fmt.Fprintf(os.Stderr, "Connecting to '%s' on 'ollama.com' ⚡\n", remoteModel)
|
||||
} else {
|
||||
fmt.Fprintf(os.Stderr, "Connecting to '%s' on '%s'\n", info.RemoteModel, info.RemoteHost)
|
||||
fmt.Fprintf(os.Stderr, "Connecting to '%s' on '%s'\n", remoteModel, info.RemoteHost)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -497,6 +501,20 @@ func generateEmbedding(cmd *cobra.Command, modelName, input string, keepAlive *a
|
||||
return nil
|
||||
}
|
||||
|
||||
// TODO(parthsareen): consolidate with TUI signin flow
|
||||
func handleCloudAuthorizationError(err error) bool {
|
||||
var authErr api.AuthorizationError
|
||||
if errors.As(err, &authErr) && authErr.StatusCode == http.StatusUnauthorized {
|
||||
fmt.Printf("You need to be signed in to Ollama to run Cloud models.\n\n")
|
||||
if authErr.SigninURL != "" {
|
||||
fmt.Printf(ConnectInstructions, authErr.SigninURL)
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func RunHandler(cmd *cobra.Command, args []string) error {
|
||||
interactive := true
|
||||
|
||||
@@ -585,17 +603,6 @@ func RunHandler(cmd *cobra.Command, args []string) error {
|
||||
}
|
||||
opts.WordWrap = !nowrap
|
||||
|
||||
useImagegen := false
|
||||
if cmd.Flags().Lookup("imagegen") != nil {
|
||||
useImagegen, err = cmd.Flags().GetBool("imagegen")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if useImagegen {
|
||||
opts.Options["use_imagegen_runner"] = true
|
||||
}
|
||||
|
||||
// Fill out the rest of the options based on information about the
|
||||
// model.
|
||||
client, err := api.ClientFromEnvironment()
|
||||
@@ -604,12 +611,16 @@ func RunHandler(cmd *cobra.Command, args []string) error {
|
||||
}
|
||||
|
||||
name := args[0]
|
||||
requestedCloud := modelref.HasExplicitCloudSource(name)
|
||||
|
||||
info, err := func() (*api.ShowResponse, error) {
|
||||
showReq := &api.ShowRequest{Name: name}
|
||||
info, err := client.Show(cmd.Context(), showReq)
|
||||
var se api.StatusError
|
||||
if errors.As(err, &se) && se.StatusCode == http.StatusNotFound {
|
||||
if requestedCloud {
|
||||
return nil, err
|
||||
}
|
||||
if err := PullHandler(cmd, []string{name}); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -618,6 +629,9 @@ func RunHandler(cmd *cobra.Command, args []string) error {
|
||||
return info, err
|
||||
}()
|
||||
if err != nil {
|
||||
if handleCloudAuthorizationError(err) {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -712,7 +726,13 @@ func RunHandler(cmd *cobra.Command, args []string) error {
|
||||
|
||||
return generateInteractive(cmd, opts)
|
||||
}
|
||||
return generate(cmd, opts)
|
||||
if err := generate(cmd, opts); err != nil {
|
||||
if handleCloudAuthorizationError(err) {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func SigninHandler(cmd *cobra.Command, args []string) error {
|
||||
|
||||
209
cmd/cmd_test.go
209
cmd/cmd_test.go
@@ -18,6 +18,7 @@ import (
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/internal/modelref"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
)
|
||||
|
||||
@@ -705,6 +706,139 @@ func TestRunEmbeddingModelNoInput(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunHandler_CloudAuthErrorOnShow_PrintsSigninMessage(t *testing.T) {
|
||||
var generateCalled bool
|
||||
|
||||
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch {
|
||||
case r.URL.Path == "/api/show" && r.Method == http.MethodPost:
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
if err := json.NewEncoder(w).Encode(map[string]string{
|
||||
"error": "unauthorized",
|
||||
"signin_url": "https://ollama.com/signin",
|
||||
}); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
}
|
||||
return
|
||||
case r.URL.Path == "/api/generate" && r.Method == http.MethodPost:
|
||||
generateCalled = true
|
||||
w.WriteHeader(http.StatusOK)
|
||||
if err := json.NewEncoder(w).Encode(api.GenerateResponse{Done: true}); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
}
|
||||
return
|
||||
default:
|
||||
http.NotFound(w, r)
|
||||
}
|
||||
}))
|
||||
|
||||
t.Setenv("OLLAMA_HOST", mockServer.URL)
|
||||
t.Cleanup(mockServer.Close)
|
||||
|
||||
cmd := &cobra.Command{}
|
||||
cmd.SetContext(t.Context())
|
||||
cmd.Flags().String("keepalive", "", "")
|
||||
cmd.Flags().Bool("truncate", false, "")
|
||||
cmd.Flags().Int("dimensions", 0, "")
|
||||
cmd.Flags().Bool("verbose", false, "")
|
||||
cmd.Flags().Bool("insecure", false, "")
|
||||
cmd.Flags().Bool("nowordwrap", false, "")
|
||||
cmd.Flags().String("format", "", "")
|
||||
cmd.Flags().String("think", "", "")
|
||||
cmd.Flags().Bool("hidethinking", false, "")
|
||||
|
||||
oldStdout := os.Stdout
|
||||
readOut, writeOut, _ := os.Pipe()
|
||||
os.Stdout = writeOut
|
||||
t.Cleanup(func() { os.Stdout = oldStdout })
|
||||
|
||||
err := RunHandler(cmd, []string{"gpt-oss:20b:cloud", "hi"})
|
||||
|
||||
_ = writeOut.Close()
|
||||
var out bytes.Buffer
|
||||
_, _ = io.Copy(&out, readOut)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("RunHandler returned error: %v", err)
|
||||
}
|
||||
|
||||
if generateCalled {
|
||||
t.Fatal("expected run to stop before /api/generate after unauthorized /api/show")
|
||||
}
|
||||
|
||||
if !strings.Contains(out.String(), "You need to be signed in to Ollama to run Cloud models.") {
|
||||
t.Fatalf("expected sign-in guidance message, got %q", out.String())
|
||||
}
|
||||
|
||||
if !strings.Contains(out.String(), "https://ollama.com/signin") {
|
||||
t.Fatalf("expected signin_url in output, got %q", out.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunHandler_CloudAuthErrorOnGenerate_PrintsSigninMessage(t *testing.T) {
|
||||
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch {
|
||||
case r.URL.Path == "/api/show" && r.Method == http.MethodPost:
|
||||
w.WriteHeader(http.StatusOK)
|
||||
if err := json.NewEncoder(w).Encode(api.ShowResponse{
|
||||
Capabilities: []model.Capability{model.CapabilityCompletion},
|
||||
}); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
}
|
||||
return
|
||||
case r.URL.Path == "/api/generate" && r.Method == http.MethodPost:
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
if err := json.NewEncoder(w).Encode(map[string]string{
|
||||
"error": "unauthorized",
|
||||
"signin_url": "https://ollama.com/signin",
|
||||
}); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
}
|
||||
return
|
||||
default:
|
||||
http.NotFound(w, r)
|
||||
}
|
||||
}))
|
||||
|
||||
t.Setenv("OLLAMA_HOST", mockServer.URL)
|
||||
t.Cleanup(mockServer.Close)
|
||||
|
||||
cmd := &cobra.Command{}
|
||||
cmd.SetContext(t.Context())
|
||||
cmd.Flags().String("keepalive", "", "")
|
||||
cmd.Flags().Bool("truncate", false, "")
|
||||
cmd.Flags().Int("dimensions", 0, "")
|
||||
cmd.Flags().Bool("verbose", false, "")
|
||||
cmd.Flags().Bool("insecure", false, "")
|
||||
cmd.Flags().Bool("nowordwrap", false, "")
|
||||
cmd.Flags().String("format", "", "")
|
||||
cmd.Flags().String("think", "", "")
|
||||
cmd.Flags().Bool("hidethinking", false, "")
|
||||
|
||||
oldStdout := os.Stdout
|
||||
readOut, writeOut, _ := os.Pipe()
|
||||
os.Stdout = writeOut
|
||||
t.Cleanup(func() { os.Stdout = oldStdout })
|
||||
|
||||
err := RunHandler(cmd, []string{"gpt-oss:20b:cloud", "hi"})
|
||||
|
||||
_ = writeOut.Close()
|
||||
var out bytes.Buffer
|
||||
_, _ = io.Copy(&out, readOut)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("RunHandler returned error: %v", err)
|
||||
}
|
||||
|
||||
if !strings.Contains(out.String(), "You need to be signed in to Ollama to run Cloud models.") {
|
||||
t.Fatalf("expected sign-in guidance message, got %q", out.String())
|
||||
}
|
||||
|
||||
if !strings.Contains(out.String(), "https://ollama.com/signin") {
|
||||
t.Fatalf("expected signin_url in output, got %q", out.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetModelfileName(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -1664,20 +1798,26 @@ func TestRunOptions_Copy_Independence(t *testing.T) {
|
||||
func TestLoadOrUnloadModel_CloudModelAuth(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
model string
|
||||
remoteHost string
|
||||
remoteModel string
|
||||
whoamiStatus int
|
||||
whoamiResp any
|
||||
expectedError string
|
||||
}{
|
||||
{
|
||||
name: "ollama.com cloud model - user signed in",
|
||||
model: "test-cloud-model",
|
||||
remoteHost: "https://ollama.com",
|
||||
remoteModel: "test-model",
|
||||
whoamiStatus: http.StatusOK,
|
||||
whoamiResp: api.UserResponse{Name: "testuser"},
|
||||
},
|
||||
{
|
||||
name: "ollama.com cloud model - user not signed in",
|
||||
model: "test-cloud-model",
|
||||
remoteHost: "https://ollama.com",
|
||||
remoteModel: "test-model",
|
||||
whoamiStatus: http.StatusUnauthorized,
|
||||
whoamiResp: map[string]string{
|
||||
"error": "unauthorized",
|
||||
@@ -1687,7 +1827,33 @@ func TestLoadOrUnloadModel_CloudModelAuth(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "non-ollama.com remote - no auth check",
|
||||
model: "test-cloud-model",
|
||||
remoteHost: "https://other-remote.com",
|
||||
remoteModel: "test-model",
|
||||
whoamiStatus: http.StatusUnauthorized, // should not be called
|
||||
whoamiResp: nil,
|
||||
},
|
||||
{
|
||||
name: "explicit :cloud model - auth check without remote metadata",
|
||||
model: "kimi-k2.5:cloud",
|
||||
remoteHost: "",
|
||||
remoteModel: "",
|
||||
whoamiStatus: http.StatusOK,
|
||||
whoamiResp: api.UserResponse{Name: "testuser"},
|
||||
},
|
||||
{
|
||||
name: "explicit -cloud model - auth check without remote metadata",
|
||||
model: "kimi-k2.5:latest-cloud",
|
||||
remoteHost: "",
|
||||
remoteModel: "",
|
||||
whoamiStatus: http.StatusOK,
|
||||
whoamiResp: api.UserResponse{Name: "testuser"},
|
||||
},
|
||||
{
|
||||
name: "dash cloud-like name without explicit source does not require auth",
|
||||
model: "test-cloud-model",
|
||||
remoteHost: "",
|
||||
remoteModel: "",
|
||||
whoamiStatus: http.StatusUnauthorized, // should not be called
|
||||
whoamiResp: nil,
|
||||
},
|
||||
@@ -1702,7 +1868,7 @@ func TestLoadOrUnloadModel_CloudModelAuth(t *testing.T) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
if err := json.NewEncoder(w).Encode(api.ShowResponse{
|
||||
RemoteHost: tt.remoteHost,
|
||||
RemoteModel: "test-model",
|
||||
RemoteModel: tt.remoteModel,
|
||||
}); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
}
|
||||
@@ -1715,6 +1881,8 @@ func TestLoadOrUnloadModel_CloudModelAuth(t *testing.T) {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
case "/api/generate":
|
||||
w.WriteHeader(http.StatusOK)
|
||||
default:
|
||||
http.NotFound(w, r)
|
||||
}
|
||||
@@ -1727,13 +1895,13 @@ func TestLoadOrUnloadModel_CloudModelAuth(t *testing.T) {
|
||||
cmd.SetContext(t.Context())
|
||||
|
||||
opts := &runOptions{
|
||||
Model: "test-cloud-model",
|
||||
Model: tt.model,
|
||||
ShowConnect: false,
|
||||
}
|
||||
|
||||
err := loadOrUnloadModel(cmd, opts)
|
||||
|
||||
if strings.HasPrefix(tt.remoteHost, "https://ollama.com") {
|
||||
if strings.HasPrefix(tt.remoteHost, "https://ollama.com") || modelref.HasExplicitCloudSource(tt.model) {
|
||||
if !whoamiCalled {
|
||||
t.Error("expected whoami to be called for ollama.com cloud model")
|
||||
}
|
||||
@@ -1760,3 +1928,38 @@ func TestLoadOrUnloadModel_CloudModelAuth(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsLocalhost(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
host string
|
||||
expected bool
|
||||
}{
|
||||
{"default empty", "", true},
|
||||
{"localhost no port", "localhost", true},
|
||||
{"localhost with port", "localhost:11435", true},
|
||||
{"127.0.0.1 no port", "127.0.0.1", true},
|
||||
{"127.0.0.1 with port", "127.0.0.1:11434", true},
|
||||
{"0.0.0.0 no port", "0.0.0.0", true},
|
||||
{"0.0.0.0 with port", "0.0.0.0:11434", true},
|
||||
{"::1 no port", "::1", true},
|
||||
{"[::1] with port", "[::1]:11434", true},
|
||||
{"loopback with scheme", "http://localhost:11434", true},
|
||||
{"remote hostname", "example.com", false},
|
||||
{"remote hostname with port", "example.com:11434", false},
|
||||
{"remote IP", "192.168.1.1", false},
|
||||
{"remote IP with port", "192.168.1.1:11434", false},
|
||||
{"remote with scheme", "http://example.com:11434", false},
|
||||
{"https remote", "https://example.com:443", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Setenv("OLLAMA_HOST", tt.host)
|
||||
got := isLocalhost()
|
||||
if got != tt.expected {
|
||||
t.Errorf("isLocalhost() with OLLAMA_HOST=%q = %v, want %v", tt.host, got, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -107,15 +107,12 @@ func (c *Claude) ConfigureAliases(ctx context.Context, model string, existingAli
|
||||
}
|
||||
|
||||
if !force && aliases["primary"] != "" {
|
||||
client, _ := api.ClientFromEnvironment()
|
||||
if isCloudModel(ctx, client, aliases["primary"]) {
|
||||
if isCloudModel(ctx, client, aliases["fast"]) {
|
||||
return aliases, false, nil
|
||||
}
|
||||
} else {
|
||||
delete(aliases, "fast")
|
||||
if isCloudModelName(aliases["primary"]) {
|
||||
aliases["fast"] = aliases["primary"]
|
||||
return aliases, false, nil
|
||||
}
|
||||
delete(aliases, "fast")
|
||||
return aliases, false, nil
|
||||
}
|
||||
|
||||
items, existingModels, cloudModels, client, err := listModels(ctx)
|
||||
@@ -139,10 +136,8 @@ func (c *Claude) ConfigureAliases(ctx context.Context, model string, existingAli
|
||||
aliases["primary"] = primary
|
||||
}
|
||||
|
||||
if isCloudModel(ctx, client, aliases["primary"]) {
|
||||
if aliases["fast"] == "" || !isCloudModel(ctx, client, aliases["fast"]) {
|
||||
aliases["fast"] = aliases["primary"]
|
||||
}
|
||||
if isCloudModelName(aliases["primary"]) {
|
||||
aliases["fast"] = aliases["primary"]
|
||||
} else {
|
||||
delete(aliases, "fast")
|
||||
}
|
||||
|
||||
@@ -233,6 +233,9 @@ func ModelExists(ctx context.Context, name string) bool {
|
||||
if name == "" {
|
||||
return false
|
||||
}
|
||||
if isCloudModelName(name) {
|
||||
return true
|
||||
}
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
return false
|
||||
|
||||
@@ -10,7 +10,6 @@ import (
|
||||
"path/filepath"
|
||||
"slices"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
)
|
||||
|
||||
@@ -125,13 +124,12 @@ func (d *Droid) Edit(models []string) error {
|
||||
}
|
||||
|
||||
// Build new Ollama model entries with sequential indices (0, 1, 2, ...)
|
||||
client, _ := api.ClientFromEnvironment()
|
||||
|
||||
var newModels []any
|
||||
var defaultModelID string
|
||||
for i, model := range models {
|
||||
maxOutput := 64000
|
||||
if isCloudModel(context.Background(), client, model) {
|
||||
if isCloudModelName(model) {
|
||||
if l, ok := lookupCloudModelLimit(model); ok {
|
||||
maxOutput = l.Output
|
||||
}
|
||||
|
||||
@@ -1276,25 +1276,17 @@ func TestDroidEdit_LocalModelDefaultMaxOutput(t *testing.T) {
|
||||
|
||||
func TestDroidEdit_CloudModelLimitsUsed(t *testing.T) {
|
||||
// Verify that every cloud model in cloudModelLimits has a valid output
|
||||
// value that would be used for maxOutputTokens when isCloudModel returns true.
|
||||
// :cloud suffix stripping must also work since that's how users specify them.
|
||||
// value that would be used for maxOutputTokens when the selected model uses
|
||||
// the explicit :cloud source tag.
|
||||
for name, expected := range cloudModelLimits {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
l, ok := lookupCloudModelLimit(name)
|
||||
if !ok {
|
||||
t.Fatalf("lookupCloudModelLimit(%q) returned false", name)
|
||||
}
|
||||
if l.Output != expected.Output {
|
||||
t.Errorf("output = %d, want %d", l.Output, expected.Output)
|
||||
}
|
||||
// Also verify :cloud suffix lookup
|
||||
cloudName := name + ":cloud"
|
||||
l2, ok := lookupCloudModelLimit(cloudName)
|
||||
l, ok := lookupCloudModelLimit(cloudName)
|
||||
if !ok {
|
||||
t.Fatalf("lookupCloudModelLimit(%q) returned false", cloudName)
|
||||
}
|
||||
if l2.Output != expected.Output {
|
||||
t.Errorf(":cloud output = %d, want %d", l2.Output, expected.Output)
|
||||
if l.Output != expected.Output {
|
||||
t.Errorf("output = %d, want %d", l.Output, expected.Output)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
internalcloud "github.com/ollama/ollama/internal/cloud"
|
||||
"github.com/ollama/ollama/internal/modelref"
|
||||
"github.com/ollama/ollama/progress"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
@@ -81,6 +82,7 @@ var cloudModelLimits = map[string]cloudModelLimit{
|
||||
"deepseek-v3.2": {Context: 163_840, Output: 65_536},
|
||||
"glm-4.6": {Context: 202_752, Output: 131_072},
|
||||
"glm-4.7": {Context: 202_752, Output: 131_072},
|
||||
"glm-5": {Context: 202_752, Output: 131_072},
|
||||
"gpt-oss:120b": {Context: 131_072, Output: 131_072},
|
||||
"gpt-oss:20b": {Context: 131_072, Output: 131_072},
|
||||
"kimi-k2:1t": {Context: 262_144, Output: 262_144},
|
||||
@@ -90,6 +92,7 @@ var cloudModelLimits = map[string]cloudModelLimit{
|
||||
"qwen3-coder:480b": {Context: 262_144, Output: 65_536},
|
||||
"qwen3-coder-next": {Context: 262_144, Output: 32_768},
|
||||
"qwen3-next:80b": {Context: 262_144, Output: 32_768},
|
||||
"qwen3.5": {Context: 262_144, Output: 32_768},
|
||||
}
|
||||
|
||||
// recommendedVRAM maps local recommended models to their approximate VRAM requirement.
|
||||
@@ -324,12 +327,7 @@ func SelectModelWithSelector(ctx context.Context, selector SingleSelector) (stri
|
||||
|
||||
// If the selected model isn't installed, pull it first
|
||||
if !existingModels[selected] {
|
||||
if cloudModels[selected] {
|
||||
// Cloud models only pull a small manifest; no confirmation needed
|
||||
if err := pullModel(ctx, client, selected); err != nil {
|
||||
return "", fmt.Errorf("failed to pull %s: %w", selected, err)
|
||||
}
|
||||
} else {
|
||||
if !isCloudModelName(selected) {
|
||||
msg := fmt.Sprintf("Download %s?", selected)
|
||||
if ok, err := confirmPrompt(msg); err != nil {
|
||||
return "", err
|
||||
@@ -524,7 +522,7 @@ func selectModelsWithSelectors(ctx context.Context, name, current string, single
|
||||
|
||||
var toPull []string
|
||||
for _, m := range selected {
|
||||
if !existingModels[m] {
|
||||
if !existingModels[m] && !isCloudModelName(m) {
|
||||
toPull = append(toPull, m)
|
||||
}
|
||||
}
|
||||
@@ -550,12 +548,28 @@ func selectModelsWithSelectors(ctx context.Context, name, current string, single
|
||||
return selected, nil
|
||||
}
|
||||
|
||||
// TODO(parthsareen): consolidate pull logic from call sites
|
||||
func pullIfNeeded(ctx context.Context, client *api.Client, existingModels map[string]bool, model string) error {
|
||||
if existingModels[model] {
|
||||
if isCloudModelName(model) || existingModels[model] {
|
||||
return nil
|
||||
}
|
||||
msg := fmt.Sprintf("Download %s?", model)
|
||||
if ok, err := confirmPrompt(msg); err != nil {
|
||||
return confirmAndPull(ctx, client, model)
|
||||
}
|
||||
|
||||
// TODO(parthsareen): pull this out to tui package
|
||||
// ShowOrPull checks if a model exists via client.Show and offers to pull it if not found.
|
||||
func ShowOrPull(ctx context.Context, client *api.Client, model string) error {
|
||||
if _, err := client.Show(ctx, &api.ShowRequest{Model: model}); err == nil {
|
||||
return nil
|
||||
}
|
||||
if isCloudModelName(model) {
|
||||
return nil
|
||||
}
|
||||
return confirmAndPull(ctx, client, model)
|
||||
}
|
||||
|
||||
func confirmAndPull(ctx context.Context, client *api.Client, model string) error {
|
||||
if ok, err := confirmPrompt(fmt.Sprintf("Download %s?", model)); err != nil {
|
||||
return err
|
||||
} else if !ok {
|
||||
return errCancelled
|
||||
@@ -567,26 +581,6 @@ func pullIfNeeded(ctx context.Context, client *api.Client, existingModels map[st
|
||||
return nil
|
||||
}
|
||||
|
||||
// TODO(parthsareen): pull this out to tui package
|
||||
// ShowOrPull checks if a model exists via client.Show and offers to pull it if not found.
|
||||
func ShowOrPull(ctx context.Context, client *api.Client, model string) error {
|
||||
if _, err := client.Show(ctx, &api.ShowRequest{Model: model}); err == nil {
|
||||
return nil
|
||||
}
|
||||
// Cloud models only pull a small manifest; skip the download confirmation
|
||||
// TODO(parthsareen): consolidate with cloud config changes
|
||||
if strings.HasSuffix(model, "cloud") {
|
||||
return pullModel(ctx, client, model)
|
||||
}
|
||||
if ok, err := confirmPrompt(fmt.Sprintf("Download %s?", model)); err != nil {
|
||||
return err
|
||||
} else if !ok {
|
||||
return errCancelled
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, "\n")
|
||||
return pullModel(ctx, client, model)
|
||||
}
|
||||
|
||||
func listModels(ctx context.Context) ([]ModelItem, map[string]bool, map[string]bool, *api.Client, error) {
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
@@ -731,10 +725,8 @@ func syncAliases(ctx context.Context, client *api.Client, ac AliasConfigurer, na
|
||||
}
|
||||
aliases["primary"] = model
|
||||
|
||||
if isCloudModel(ctx, client, model) {
|
||||
if aliases["fast"] == "" || !isCloudModel(ctx, client, aliases["fast"]) {
|
||||
aliases["fast"] = model
|
||||
}
|
||||
if isCloudModelName(model) {
|
||||
aliases["fast"] = model
|
||||
} else {
|
||||
delete(aliases, "fast")
|
||||
}
|
||||
@@ -1020,7 +1012,7 @@ Examples:
|
||||
existingAliases = aliases
|
||||
|
||||
// Ensure cloud models are authenticated
|
||||
if isCloudModel(cmd.Context(), client, model) {
|
||||
if isCloudModelName(model) {
|
||||
if err := ensureAuth(cmd.Context(), client, map[string]bool{model: true}, []string{model}); err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -1209,7 +1201,7 @@ func buildModelList(existing []modelInfo, preChecked []string, current string) (
|
||||
// When user has no models, preserve recommended order.
|
||||
notInstalled := make(map[string]bool)
|
||||
for i := range items {
|
||||
if !existingModels[items[i].Name] {
|
||||
if !existingModels[items[i].Name] && !cloudModels[items[i].Name] {
|
||||
notInstalled[items[i].Name] = true
|
||||
var parts []string
|
||||
if items[i].Description != "" {
|
||||
@@ -1303,7 +1295,8 @@ func IsCloudModelDisabled(ctx context.Context, name string) bool {
|
||||
}
|
||||
|
||||
func isCloudModelName(name string) bool {
|
||||
return strings.HasSuffix(name, ":cloud") || strings.HasSuffix(name, "-cloud")
|
||||
// TODO(drifkin): Replace this wrapper with inlining once things stabilize a bit
|
||||
return modelref.HasExplicitCloudSource(name)
|
||||
}
|
||||
|
||||
func filterCloudModels(existing []modelInfo) []modelInfo {
|
||||
|
||||
@@ -426,8 +426,14 @@ func TestBuildModelList_NoExistingModels(t *testing.T) {
|
||||
}
|
||||
|
||||
for _, item := range items {
|
||||
if !strings.HasSuffix(item.Description, "(not downloaded)") {
|
||||
t.Errorf("item %q should have description ending with '(not downloaded)', got %q", item.Name, item.Description)
|
||||
if strings.HasSuffix(item.Name, ":cloud") {
|
||||
if strings.HasSuffix(item.Description, "(not downloaded)") {
|
||||
t.Errorf("cloud model %q should not have '(not downloaded)' suffix, got %q", item.Name, item.Description)
|
||||
}
|
||||
} else {
|
||||
if !strings.HasSuffix(item.Description, "(not downloaded)") {
|
||||
t.Errorf("item %q should have description ending with '(not downloaded)', got %q", item.Name, item.Description)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -492,10 +498,14 @@ func TestBuildModelList_ExistingRecommendedMarked(t *testing.T) {
|
||||
if strings.HasSuffix(item.Description, "(not downloaded)") {
|
||||
t.Errorf("installed recommended %q should not have '(not downloaded)' suffix, got %q", item.Name, item.Description)
|
||||
}
|
||||
case "minimax-m2.5:cloud", "kimi-k2.5:cloud", "qwen3:8b":
|
||||
case "qwen3:8b":
|
||||
if !strings.HasSuffix(item.Description, "(not downloaded)") {
|
||||
t.Errorf("non-installed recommended %q should have '(not downloaded)' suffix, got %q", item.Name, item.Description)
|
||||
}
|
||||
case "minimax-m2.5:cloud", "kimi-k2.5:cloud":
|
||||
if strings.HasSuffix(item.Description, "(not downloaded)") {
|
||||
t.Errorf("cloud model %q should not have '(not downloaded)' suffix, got %q", item.Name, item.Description)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -536,7 +546,13 @@ func TestBuildModelList_HasRecommendedCloudModel_OnlyNonInstalledAtBottom(t *tes
|
||||
}
|
||||
|
||||
for _, item := range items {
|
||||
if !slices.Contains([]string{"kimi-k2.5:cloud", "llama3.2"}, item.Name) {
|
||||
isCloud := strings.HasSuffix(item.Name, ":cloud")
|
||||
isInstalled := slices.Contains([]string{"kimi-k2.5:cloud", "llama3.2"}, item.Name)
|
||||
if isInstalled || isCloud {
|
||||
if strings.HasSuffix(item.Description, "(not downloaded)") {
|
||||
t.Errorf("installed or cloud model %q should not have '(not downloaded)' suffix, got %q", item.Name, item.Description)
|
||||
}
|
||||
} else {
|
||||
if !strings.HasSuffix(item.Description, "(not downloaded)") {
|
||||
t.Errorf("non-installed %q should have '(not downloaded)' suffix, got %q", item.Name, item.Description)
|
||||
}
|
||||
@@ -1000,8 +1016,8 @@ func TestShowOrPull_ModelNotFound_ConfirmNo_Cancelled(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestShowOrPull_CloudModel_SkipsConfirmation(t *testing.T) {
|
||||
// Confirm prompt should NOT be called for cloud models
|
||||
func TestShowOrPull_CloudModel_DoesNotPull(t *testing.T) {
|
||||
// Confirm prompt should NOT be called for explicit cloud models
|
||||
oldHook := DefaultConfirmPrompt
|
||||
DefaultConfirmPrompt = func(prompt string) (bool, error) {
|
||||
t.Error("confirm prompt should not be called for cloud models")
|
||||
@@ -1032,8 +1048,115 @@ func TestShowOrPull_CloudModel_SkipsConfirmation(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Errorf("ShowOrPull should succeed for cloud model, got: %v", err)
|
||||
}
|
||||
if !pullCalled {
|
||||
t.Error("expected pull to be called for cloud model without confirmation")
|
||||
if pullCalled {
|
||||
t.Error("expected pull not to be called for cloud model")
|
||||
}
|
||||
}
|
||||
|
||||
func TestShowOrPull_CloudLegacySuffix_DoesNotPull(t *testing.T) {
|
||||
// Confirm prompt should NOT be called for explicit cloud models
|
||||
oldHook := DefaultConfirmPrompt
|
||||
DefaultConfirmPrompt = func(prompt string) (bool, error) {
|
||||
t.Error("confirm prompt should not be called for cloud models")
|
||||
return false, nil
|
||||
}
|
||||
defer func() { DefaultConfirmPrompt = oldHook }()
|
||||
|
||||
var pullCalled bool
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/api/show":
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
fmt.Fprintf(w, `{"error":"model not found"}`)
|
||||
case "/api/pull":
|
||||
pullCalled = true
|
||||
w.WriteHeader(http.StatusOK)
|
||||
fmt.Fprintf(w, `{"status":"success"}`)
|
||||
default:
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
u, _ := url.Parse(srv.URL)
|
||||
client := api.NewClient(u, srv.Client())
|
||||
|
||||
err := ShowOrPull(context.Background(), client, "gpt-oss:20b-cloud")
|
||||
if err != nil {
|
||||
t.Errorf("ShowOrPull should succeed for cloud model, got: %v", err)
|
||||
}
|
||||
if pullCalled {
|
||||
t.Error("expected pull not to be called for cloud model")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPullIfNeeded_CloudModel_DoesNotPull(t *testing.T) {
|
||||
oldHook := DefaultConfirmPrompt
|
||||
DefaultConfirmPrompt = func(prompt string) (bool, error) {
|
||||
t.Error("confirm prompt should not be called for cloud models")
|
||||
return false, nil
|
||||
}
|
||||
defer func() { DefaultConfirmPrompt = oldHook }()
|
||||
|
||||
err := pullIfNeeded(context.Background(), nil, map[string]bool{}, "glm-5:cloud")
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error for cloud model, got %v", err)
|
||||
}
|
||||
|
||||
err = pullIfNeeded(context.Background(), nil, map[string]bool{}, "gpt-oss:20b-cloud")
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error for cloud model with legacy suffix, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSelectModelsWithSelectors_CloudSelection_DoesNotPull(t *testing.T) {
|
||||
var pullCalled bool
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/api/status":
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
fmt.Fprintf(w, `{"error":"not found"}`)
|
||||
case "/api/tags":
|
||||
w.WriteHeader(http.StatusOK)
|
||||
fmt.Fprintf(w, `{"models":[]}`)
|
||||
case "/api/me":
|
||||
w.WriteHeader(http.StatusOK)
|
||||
fmt.Fprintf(w, `{"name":"test-user"}`)
|
||||
case "/api/pull":
|
||||
pullCalled = true
|
||||
w.WriteHeader(http.StatusOK)
|
||||
fmt.Fprintf(w, `{"status":"success"}`)
|
||||
default:
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
fmt.Fprintf(w, `{"error":"not found"}`)
|
||||
}
|
||||
}))
|
||||
defer srv.Close()
|
||||
t.Setenv("OLLAMA_HOST", srv.URL)
|
||||
|
||||
single := func(title string, items []ModelItem, current string) (string, error) {
|
||||
for _, item := range items {
|
||||
if item.Name == "glm-5:cloud" {
|
||||
return item.Name, nil
|
||||
}
|
||||
}
|
||||
t.Fatalf("expected glm-5:cloud in selector items, got %v", items)
|
||||
return "", nil
|
||||
}
|
||||
|
||||
multi := func(title string, items []ModelItem, preChecked []string) ([]string, error) {
|
||||
return nil, fmt.Errorf("multi selector should not be called")
|
||||
}
|
||||
|
||||
selected, err := selectModelsWithSelectors(context.Background(), "codex", "", single, multi)
|
||||
if err != nil {
|
||||
t.Fatalf("selectModelsWithSelectors returned error: %v", err)
|
||||
}
|
||||
if !slices.Equal(selected, []string{"glm-5:cloud"}) {
|
||||
t.Fatalf("unexpected selected models: %v", selected)
|
||||
}
|
||||
if pullCalled {
|
||||
t.Fatal("expected cloud selection to skip pull")
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -502,7 +502,7 @@ func (c *Openclaw) Edit(models []string) error {
|
||||
ollama = make(map[string]any)
|
||||
}
|
||||
|
||||
ollama["baseUrl"] = envconfig.Host().String() + "/v1"
|
||||
ollama["baseUrl"] = envconfig.Host().String()
|
||||
// needed to register provider
|
||||
ollama["apiKey"] = "ollama-local"
|
||||
ollama["api"] = "ollama"
|
||||
|
||||
@@ -589,7 +589,7 @@ const testOpenclawFixture = `{
|
||||
"providers": {
|
||||
"anthropic": {"apiKey": "xxx"},
|
||||
"ollama": {
|
||||
"baseUrl": "http://127.0.0.1:11434/v1",
|
||||
"baseUrl": "http://127.0.0.1:11434",
|
||||
"models": [{"id": "old-model", "customField": "preserved"}]
|
||||
}
|
||||
}
|
||||
|
||||
@@ -12,8 +12,8 @@ import (
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
"github.com/ollama/ollama/internal/modelref"
|
||||
)
|
||||
|
||||
// OpenCode implements Runner and Editor for OpenCode integration
|
||||
@@ -26,13 +26,10 @@ type cloudModelLimit struct {
|
||||
}
|
||||
|
||||
// lookupCloudModelLimit returns the token limits for a cloud model.
|
||||
// It tries the exact name first, then strips the ":cloud" suffix.
|
||||
// It normalizes explicit cloud source suffixes before checking the shared limit map.
|
||||
func lookupCloudModelLimit(name string) (cloudModelLimit, bool) {
|
||||
if l, ok := cloudModelLimits[name]; ok {
|
||||
return l, true
|
||||
}
|
||||
base := strings.TrimSuffix(name, ":cloud")
|
||||
if base != name {
|
||||
base, stripped := modelref.StripCloudSourceTag(name)
|
||||
if stripped {
|
||||
if l, ok := cloudModelLimits[base]; ok {
|
||||
return l, true
|
||||
}
|
||||
@@ -122,13 +119,18 @@ func (o *OpenCode) Edit(modelList []string) error {
|
||||
if !ok {
|
||||
ollama = map[string]any{
|
||||
"npm": "@ai-sdk/openai-compatible",
|
||||
"name": "Ollama (local)",
|
||||
"name": "Ollama",
|
||||
"options": map[string]any{
|
||||
"baseURL": envconfig.Host().String() + "/v1",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Migrate legacy provider name
|
||||
if name, _ := ollama["name"].(string); name == "Ollama (local)" {
|
||||
ollama["name"] = "Ollama"
|
||||
}
|
||||
|
||||
models, ok := ollama["models"].(map[string]any)
|
||||
if !ok {
|
||||
models = make(map[string]any)
|
||||
@@ -147,8 +149,6 @@ func (o *OpenCode) Edit(modelList []string) error {
|
||||
}
|
||||
}
|
||||
|
||||
client, _ := api.ClientFromEnvironment()
|
||||
|
||||
for _, model := range modelList {
|
||||
if existing, ok := models[model].(map[string]any); ok {
|
||||
// migrate existing models without _launch marker
|
||||
@@ -158,7 +158,7 @@ func (o *OpenCode) Edit(modelList []string) error {
|
||||
existing["name"] = strings.TrimSuffix(name, " [Ollama]")
|
||||
}
|
||||
}
|
||||
if isCloudModel(context.Background(), client, model) {
|
||||
if isCloudModelName(model) {
|
||||
if l, ok := lookupCloudModelLimit(model); ok {
|
||||
existing["limit"] = map[string]any{
|
||||
"context": l.Context,
|
||||
@@ -172,7 +172,7 @@ func (o *OpenCode) Edit(modelList []string) error {
|
||||
"name": model,
|
||||
"_launch": true,
|
||||
}
|
||||
if isCloudModel(context.Background(), client, model) {
|
||||
if isCloudModelName(model) {
|
||||
if l, ok := lookupCloudModelLimit(model); ok {
|
||||
entry["limit"] = map[string]any{
|
||||
"context": l.Context,
|
||||
|
||||
@@ -3,6 +3,8 @@ package config
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
@@ -232,6 +234,44 @@ func TestOpenCodeEdit(t *testing.T) {
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("migrate Ollama (local) provider name", func(t *testing.T) {
|
||||
cleanup()
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(configPath, []byte(`{"provider":{"ollama":{"name":"Ollama (local)","npm":"@ai-sdk/openai-compatible","options":{"baseURL":"http://localhost:11434/v1"}}}}`), 0o644)
|
||||
|
||||
if err := o.Edit([]string{"llama3.2"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
data, _ := os.ReadFile(configPath)
|
||||
var cfg map[string]any
|
||||
json.Unmarshal(data, &cfg)
|
||||
provider := cfg["provider"].(map[string]any)
|
||||
ollama := provider["ollama"].(map[string]any)
|
||||
if ollama["name"] != "Ollama" {
|
||||
t.Errorf("provider name not migrated: got %q, want %q", ollama["name"], "Ollama")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("preserve custom provider name", func(t *testing.T) {
|
||||
cleanup()
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(configPath, []byte(`{"provider":{"ollama":{"name":"My Custom Ollama","npm":"@ai-sdk/openai-compatible","options":{"baseURL":"http://localhost:11434/v1"}}}}`), 0o644)
|
||||
|
||||
if err := o.Edit([]string{"llama3.2"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
data, _ := os.ReadFile(configPath)
|
||||
var cfg map[string]any
|
||||
json.Unmarshal(data, &cfg)
|
||||
provider := cfg["provider"].(map[string]any)
|
||||
ollama := provider["ollama"].(map[string]any)
|
||||
if ollama["name"] != "My Custom Ollama" {
|
||||
t.Errorf("custom provider name was changed: got %q, want %q", ollama["name"], "My Custom Ollama")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("remove model preserves non-ollama models", func(t *testing.T) {
|
||||
cleanup()
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
@@ -619,6 +659,54 @@ func TestOpenCodeEdit_CloudModelLimitStructure(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenCodeEdit_BackfillsCloudModelLimitOnExistingEntry(t *testing.T) {
|
||||
o := &OpenCode{}
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == "/api/show" {
|
||||
fmt.Fprintf(w, `{"capabilities":[],"model_info":{},"remote_model":"glm-5"}`)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}))
|
||||
defer srv.Close()
|
||||
t.Setenv("OLLAMA_HOST", srv.URL)
|
||||
|
||||
configDir := filepath.Join(tmpDir, ".config", "opencode")
|
||||
configPath := filepath.Join(configDir, "opencode.json")
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(configPath, []byte(`{
|
||||
"provider": {
|
||||
"ollama": {
|
||||
"models": {
|
||||
"glm-5:cloud": {
|
||||
"name": "glm-5:cloud",
|
||||
"_launch": true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}`), 0o644)
|
||||
|
||||
if err := o.Edit([]string{"glm-5:cloud"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
entry := readOpenCodeModel(t, configPath, "glm-5:cloud")
|
||||
limit, ok := entry["limit"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatal("cloud model limit was not added on re-edit")
|
||||
}
|
||||
if limit["context"] != float64(202_752) {
|
||||
t.Errorf("context = %v, want 202752", limit["context"])
|
||||
}
|
||||
if limit["output"] != float64(131_072) {
|
||||
t.Errorf("output = %v, want 131072", limit["output"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestLookupCloudModelLimit(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -626,13 +714,17 @@ func TestLookupCloudModelLimit(t *testing.T) {
|
||||
wantContext int
|
||||
wantOutput int
|
||||
}{
|
||||
{"glm-4.7", true, 202_752, 131_072},
|
||||
{"glm-4.7", false, 0, 0},
|
||||
{"glm-4.7:cloud", true, 202_752, 131_072},
|
||||
{"kimi-k2.5", true, 262_144, 262_144},
|
||||
{"glm-5:cloud", true, 202_752, 131_072},
|
||||
{"gpt-oss:120b-cloud", true, 131_072, 131_072},
|
||||
{"gpt-oss:20b-cloud", true, 131_072, 131_072},
|
||||
{"kimi-k2.5", false, 0, 0},
|
||||
{"kimi-k2.5:cloud", true, 262_144, 262_144},
|
||||
{"deepseek-v3.2", true, 163_840, 65_536},
|
||||
{"deepseek-v3.2", false, 0, 0},
|
||||
{"deepseek-v3.2:cloud", true, 163_840, 65_536},
|
||||
{"qwen3-coder:480b", true, 262_144, 65_536},
|
||||
{"qwen3-coder:480b", false, 0, 0},
|
||||
{"qwen3-coder:480b:cloud", true, 262_144, 65_536},
|
||||
{"qwen3-coder-next:cloud", true, 262_144, 32_768},
|
||||
{"llama3.2", false, 0, 0},
|
||||
{"unknown-model:cloud", false, 0, 0},
|
||||
|
||||
@@ -107,7 +107,8 @@ func (p *Pi) Edit(models []string) error {
|
||||
|
||||
// Build new models list:
|
||||
// 1. Keep user-managed models (no _launch marker) - untouched
|
||||
// 2. Keep ollama-managed models (_launch marker) that are still selected
|
||||
// 2. Keep ollama-managed models (_launch marker) that are still selected,
|
||||
// except stale cloud entries that should be rebuilt below
|
||||
// 3. Add new ollama-managed models
|
||||
var newModels []any
|
||||
for _, m := range existingModels {
|
||||
@@ -117,7 +118,13 @@ func (p *Pi) Edit(models []string) error {
|
||||
if !isPiOllamaModel(modelObj) {
|
||||
newModels = append(newModels, m)
|
||||
} else if selectedSet[id] {
|
||||
// Ollama-managed and still selected - keep it
|
||||
// Rebuild stale managed cloud entries so createConfig refreshes
|
||||
// the whole entry instead of patching it in place.
|
||||
if !hasContextWindow(modelObj) {
|
||||
if _, ok := lookupCloudModelLimit(id); ok {
|
||||
continue
|
||||
}
|
||||
}
|
||||
newModels = append(newModels, m)
|
||||
selectedSet[id] = false
|
||||
}
|
||||
@@ -199,12 +206,28 @@ func isPiOllamaModel(cfg map[string]any) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func hasContextWindow(cfg map[string]any) bool {
|
||||
switch v := cfg["contextWindow"].(type) {
|
||||
case float64:
|
||||
return v > 0
|
||||
case int:
|
||||
return v > 0
|
||||
case int64:
|
||||
return v > 0
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// createConfig builds Pi model config with capability detection
|
||||
func createConfig(ctx context.Context, client *api.Client, modelID string) map[string]any {
|
||||
cfg := map[string]any{
|
||||
"id": modelID,
|
||||
"_launch": true,
|
||||
}
|
||||
if l, ok := lookupCloudModelLimit(modelID); ok {
|
||||
cfg["contextWindow"] = l.Context
|
||||
}
|
||||
|
||||
resp, err := client.Show(ctx, &api.ShowRequest{Model: modelID})
|
||||
if err != nil {
|
||||
@@ -223,7 +246,8 @@ func createConfig(ctx context.Context, client *api.Client, modelID string) map[s
|
||||
cfg["reasoning"] = true
|
||||
}
|
||||
|
||||
// Extract context window from ModelInfo
|
||||
// Extract context window from ModelInfo. For known cloud models, the
|
||||
// pre-filled shared limit remains unless the server provides a positive value.
|
||||
for key, val := range resp.ModelInfo {
|
||||
if strings.HasSuffix(key, ".context_length") {
|
||||
if ctxLen, ok := val.(float64); ok && ctxLen > 0 {
|
||||
|
||||
@@ -192,6 +192,48 @@ func TestPiEdit(t *testing.T) {
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("rebuilds stale existing managed cloud model", func(t *testing.T) {
|
||||
cleanup()
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
|
||||
existingConfig := `{
|
||||
"providers": {
|
||||
"ollama": {
|
||||
"baseUrl": "http://localhost:11434/v1",
|
||||
"api": "openai-completions",
|
||||
"apiKey": "ollama",
|
||||
"models": [
|
||||
{"id": "glm-5:cloud", "_launch": true, "legacyField": "stale"}
|
||||
]
|
||||
}
|
||||
}
|
||||
}`
|
||||
if err := os.WriteFile(configPath, []byte(existingConfig), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := pi.Edit([]string{"glm-5:cloud"}); err != nil {
|
||||
t.Fatalf("Edit() error = %v", err)
|
||||
}
|
||||
|
||||
cfg := readConfig()
|
||||
providers := cfg["providers"].(map[string]any)
|
||||
ollama := providers["ollama"].(map[string]any)
|
||||
modelsArray := ollama["models"].([]any)
|
||||
modelEntry := modelsArray[0].(map[string]any)
|
||||
|
||||
if modelEntry["contextWindow"] != float64(202_752) {
|
||||
t.Errorf("contextWindow = %v, want 202752", modelEntry["contextWindow"])
|
||||
}
|
||||
input, ok := modelEntry["input"].([]any)
|
||||
if !ok || len(input) != 1 || input[0] != "text" {
|
||||
t.Errorf("input = %v, want [text]", modelEntry["input"])
|
||||
}
|
||||
if _, ok := modelEntry["legacyField"]; ok {
|
||||
t.Error("legacyField should be removed when stale managed cloud entry is rebuilt")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("replaces old models with new ones", func(t *testing.T) {
|
||||
cleanup()
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
@@ -798,6 +840,60 @@ func TestCreateConfig(t *testing.T) {
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("falls back to cloud context when show fails", func(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
fmt.Fprintf(w, `{"error":"model not found"}`)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
u, _ := url.Parse(srv.URL)
|
||||
client := api.NewClient(u, srv.Client())
|
||||
|
||||
cfg := createConfig(context.Background(), client, "kimi-k2.5:cloud")
|
||||
|
||||
if cfg["contextWindow"] != 262_144 {
|
||||
t.Errorf("contextWindow = %v, want 262144", cfg["contextWindow"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("falls back to cloud context when model info is empty", func(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == "/api/show" {
|
||||
fmt.Fprintf(w, `{"capabilities":[],"model_info":{}}`)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
u, _ := url.Parse(srv.URL)
|
||||
client := api.NewClient(u, srv.Client())
|
||||
|
||||
cfg := createConfig(context.Background(), client, "glm-5:cloud")
|
||||
|
||||
if cfg["contextWindow"] != 202_752 {
|
||||
t.Errorf("contextWindow = %v, want 202752", cfg["contextWindow"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("falls back to cloud context for dash cloud suffix", func(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
fmt.Fprintf(w, `{"error":"model not found"}`)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
u, _ := url.Parse(srv.URL)
|
||||
client := api.NewClient(u, srv.Client())
|
||||
|
||||
cfg := createConfig(context.Background(), client, "gpt-oss:120b-cloud")
|
||||
|
||||
if cfg["contextWindow"] != 131_072 {
|
||||
t.Errorf("contextWindow = %v, want 131072", cfg["contextWindow"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("skips zero context length", func(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == "/api/show" {
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"github.com/charmbracelet/lipgloss"
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/cmd/config"
|
||||
"github.com/ollama/ollama/internal/modelref"
|
||||
"github.com/ollama/ollama/version"
|
||||
)
|
||||
|
||||
@@ -147,7 +148,13 @@ type signInCheckMsg struct {
|
||||
type clearStatusMsg struct{}
|
||||
|
||||
func (m *model) modelExists(name string) bool {
|
||||
if m.availableModels == nil || name == "" {
|
||||
if name == "" {
|
||||
return false
|
||||
}
|
||||
if modelref.HasExplicitCloudSource(name) {
|
||||
return true
|
||||
}
|
||||
if m.availableModels == nil {
|
||||
return false
|
||||
}
|
||||
if m.availableModels[name] {
|
||||
@@ -209,7 +216,7 @@ func (m *model) openMultiModelModal(integration string) {
|
||||
}
|
||||
|
||||
func isCloudModel(name string) bool {
|
||||
return strings.HasSuffix(name, ":cloud") || strings.HasSuffix(name, "-cloud")
|
||||
return modelref.HasExplicitCloudSource(name)
|
||||
}
|
||||
|
||||
func cloudStatusDisabled(client *api.Client) bool {
|
||||
|
||||
@@ -54,6 +54,7 @@ type nemotronHModel struct {
|
||||
NGroups uint32 `json:"n_groups"`
|
||||
IntermediateSize uint32 `json:"intermediate_size"`
|
||||
HybridOverridePattern hybridPattern `json:"hybrid_override_pattern"`
|
||||
LayersBlockType []string `json:"layers_block_type"`
|
||||
|
||||
// MoE
|
||||
NumExperts uint32 `json:"num_experts"`
|
||||
@@ -162,8 +163,27 @@ func (n *nemotronHModel) denseIntermediateSize() uint32 {
|
||||
|
||||
func (n *nemotronHModel) layerArrays() (headCountKV []uint32, ffnLengths []uint32, err error) {
|
||||
pattern := strings.TrimSpace(string(n.HybridOverridePattern))
|
||||
|
||||
// Convert layers_block_type array to pattern string if hybrid_override_pattern is not set
|
||||
if pattern == "" && len(n.LayersBlockType) > 0 {
|
||||
var sb strings.Builder
|
||||
for _, blockType := range n.LayersBlockType {
|
||||
switch strings.ToLower(blockType) {
|
||||
case "mamba":
|
||||
sb.WriteRune('M')
|
||||
case "moe":
|
||||
sb.WriteRune('E')
|
||||
case "attention":
|
||||
sb.WriteRune('A')
|
||||
default:
|
||||
return nil, nil, fmt.Errorf("nemotron_h: unsupported block type %q in layers_block_type", blockType)
|
||||
}
|
||||
}
|
||||
pattern = sb.String()
|
||||
}
|
||||
|
||||
if pattern == "" {
|
||||
return nil, nil, fmt.Errorf("nemotron_h: hybrid_override_pattern must be set")
|
||||
return nil, nil, fmt.Errorf("nemotron_h: hybrid_override_pattern or layers_block_type must be set")
|
||||
}
|
||||
|
||||
runes := []rune(pattern)
|
||||
|
||||
@@ -12,7 +12,6 @@ To use Ollama with tools that expect the Anthropic API (like Claude Code), set t
|
||||
|
||||
```shell
|
||||
export ANTHROPIC_AUTH_TOKEN=ollama # required but ignored
|
||||
export ANTHROPIC_API_KEY="" # required but ignored
|
||||
export ANTHROPIC_BASE_URL=http://localhost:11434
|
||||
```
|
||||
|
||||
@@ -269,7 +268,7 @@ ollama launch claude --config
|
||||
Set the environment variables and run Claude Code:
|
||||
|
||||
```shell
|
||||
ANTHROPIC_AUTH_TOKEN=ollama ANTHROPIC_BASE_URL=http://localhost:11434 ANTHROPIC_API_KEY="" claude --model qwen3-coder
|
||||
ANTHROPIC_AUTH_TOKEN=ollama ANTHROPIC_BASE_URL=http://localhost:11434 claude --model qwen3-coder
|
||||
```
|
||||
|
||||
Or set the environment variables in your shell profile:
|
||||
@@ -277,7 +276,6 @@ Or set the environment variables in your shell profile:
|
||||
```shell
|
||||
export ANTHROPIC_AUTH_TOKEN=ollama
|
||||
export ANTHROPIC_BASE_URL=http://localhost:11434
|
||||
export ANTHROPIC_API_KEY=""
|
||||
```
|
||||
|
||||
Then run Claude Code with any Ollama model:
|
||||
|
||||
@@ -6,7 +6,7 @@ Ollama provides compatibility with parts of the [OpenAI API](https://platform.op
|
||||
|
||||
## Usage
|
||||
|
||||
### Simple `v1/chat/completions` example
|
||||
### Simple `/v1/chat/completions` example
|
||||
|
||||
<CodeGroup dropdown>
|
||||
|
||||
@@ -57,7 +57,7 @@ curl -X POST http://localhost:11434/v1/chat/completions \
|
||||
|
||||
</CodeGroup>
|
||||
|
||||
### Simple `v1/responses` example
|
||||
### Simple `/v1/responses` example
|
||||
|
||||
<CodeGroup dropdown>
|
||||
|
||||
@@ -103,7 +103,7 @@ curl -X POST http://localhost:11434/v1/responses \
|
||||
|
||||
</CodeGroup>
|
||||
|
||||
### v1/chat/completions with vision example
|
||||
### `/v1/chat/completions` with vision example
|
||||
|
||||
<CodeGroup dropdown>
|
||||
|
||||
|
||||
@@ -51,6 +51,9 @@ Install prerequisites:
|
||||
- [CUDA SDK](https://developer.nvidia.com/cuda-downloads?target_os=Windows&target_arch=x86_64&target_version=11&target_type=exe_network)
|
||||
- (Optional) VULKAN GPU support
|
||||
- [VULKAN SDK](https://vulkan.lunarg.com/sdk/home) - useful for AMD/Intel GPUs
|
||||
- (Optional) MLX engine support
|
||||
- [CUDA 13+ SDK](https://developer.nvidia.com/cuda-downloads)
|
||||
- [cuDNN 9+](https://developer.nvidia.com/cudnn)
|
||||
|
||||
Then, configure and build the project:
|
||||
|
||||
@@ -101,6 +104,10 @@ Install prerequisites:
|
||||
- (Optional) VULKAN GPU support
|
||||
- [VULKAN SDK](https://vulkan.lunarg.com/sdk/home) - useful for AMD/Intel GPUs
|
||||
- Or install via package manager: `sudo apt install vulkan-sdk` (Ubuntu/Debian) or `sudo dnf install vulkan-sdk` (Fedora/CentOS)
|
||||
- (Optional) MLX engine support
|
||||
- [CUDA 13+ SDK](https://developer.nvidia.com/cuda-downloads)
|
||||
- [cuDNN 9+](https://developer.nvidia.com/cudnn)
|
||||
- OpenBLAS/LAPACK: `sudo apt install libopenblas-dev liblapack-dev liblapacke-dev` (Ubuntu/Debian)
|
||||
> [!IMPORTANT]
|
||||
> Ensure prerequisites are in `PATH` before running CMake.
|
||||
|
||||
@@ -118,6 +125,67 @@ Lastly, run Ollama:
|
||||
go run . serve
|
||||
```
|
||||
|
||||
## MLX Engine (Optional)
|
||||
|
||||
The MLX engine enables running safetensor based models. It requires building the [MLX](https://github.com/ml-explore/mlx) and [MLX-C](https://github.com/ml-explore/mlx-c) shared libraries separately via CMake. On MacOS, MLX leverages the Metal library to run on the GPU, and on Windows and Linux, runs on NVIDIA GPUs via CUDA v13.
|
||||
|
||||
### macOS (Apple Silicon)
|
||||
|
||||
Requires the Metal toolchain. Install [Xcode](https://developer.apple.com/xcode/) first, then:
|
||||
|
||||
```shell
|
||||
xcodebuild -downloadComponent MetalToolchain
|
||||
```
|
||||
|
||||
Verify it's installed correctly (should print "no input files"):
|
||||
|
||||
```shell
|
||||
xcrun metal
|
||||
```
|
||||
|
||||
Then build:
|
||||
|
||||
```shell
|
||||
cmake -B build --preset MLX
|
||||
cmake --build build --preset MLX --parallel
|
||||
cmake --install build --component MLX
|
||||
```
|
||||
|
||||
> [!NOTE]
|
||||
> Without the Metal toolchain, cmake will silently complete with Metal disabled. Check the cmake output for `Setting MLX_BUILD_METAL=OFF` which indicates the toolchain is missing.
|
||||
|
||||
### Windows / Linux (CUDA)
|
||||
|
||||
Requires CUDA 13+ and [cuDNN](https://developer.nvidia.com/cudnn) 9+.
|
||||
|
||||
```shell
|
||||
cmake -B build --preset "MLX CUDA 13"
|
||||
cmake --build build --target mlx --target mlxc --config Release --parallel
|
||||
cmake --install build --component MLX --strip
|
||||
```
|
||||
|
||||
### Local MLX source overrides
|
||||
|
||||
To build against a local checkout of MLX and/or MLX-C (useful for development), set environment variables before running CMake:
|
||||
|
||||
```shell
|
||||
export OLLAMA_MLX_SOURCE=/path/to/mlx
|
||||
export OLLAMA_MLX_C_SOURCE=/path/to/mlx-c
|
||||
```
|
||||
|
||||
For example, using the helper scripts with local mlx and mlx-c repos:
|
||||
```shell
|
||||
OLLAMA_MLX_SOURCE=../mlx OLLAMA_MLX_C_SOURCE=../mlx-c ./scripts/build_linux.sh
|
||||
|
||||
OLLAMA_MLX_SOURCE=../mlx OLLAMA_MLX_C_SOURCE=../mlx-c ./scripts/build_darwin.sh
|
||||
```
|
||||
|
||||
```powershell
|
||||
$env:OLLAMA_MLX_SOURCE="../mlx"
|
||||
$env:OLLAMA_MLX_C_SOURCE="../mlx-c"
|
||||
./scripts/build_darwin.ps1
|
||||
```
|
||||
|
||||
## Docker
|
||||
|
||||
```shell
|
||||
|
||||
29
docs/gpu.mdx
29
docs/gpu.mdx
@@ -61,11 +61,13 @@ Ollama supports the following AMD GPUs via the ROCm library:
|
||||
|
||||
### Linux Support
|
||||
|
||||
| Family | Cards and accelerators |
|
||||
| -------------- | ---------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| AMD Radeon RX | `7900 XTX` `7900 XT` `7900 GRE` `7800 XT` `7700 XT` `7600 XT` `7600` `6950 XT` `6900 XTX` `6900XT` `6800 XT` `6800` `Vega 64` |
|
||||
| AMD Radeon PRO | `W7900` `W7800` `W7700` `W7600` `W7500` `W6900X` `W6800X Duo` `W6800X` `W6800` `V620` `V420` `V340` `V320` `Vega II Duo` `Vega II` `SSG` |
|
||||
| AMD Instinct | `MI300X` `MI300A` `MI300` `MI250X` `MI250` `MI210` `MI200` `MI100` `MI60` |
|
||||
| Family | Cards and accelerators |
|
||||
| -------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| AMD Radeon RX | `9070 XT` `9070 GRE` `9070` `9060 XT` `9060 XT LP` `9060` `7900 XTX` `7900 XT` `7900 GRE` `7800 XT` `7700 XT` `7700` `7600 XT` `7600` `6950 XT` `6900 XTX` `6900XT` `6800 XT` `6800` `5700 XT` `5700` `5600 XT` `5500 XT` |
|
||||
| AMD Radeon AI PRO | `R9700` `R9600D` |
|
||||
| AMD Radeon PRO | `W7900` `W7800` `W7700` `W7600` `W7500` `W6900X` `W6800X Duo` `W6800X` `W6800` `V620` |
|
||||
| AMD Ryzen AI | `Ryzen AI Max+ 395` `Ryzen AI Max 390` `Ryzen AI Max 385` `Ryzen AI 9 HX 475` `Ryzen AI 9 HX 470` `Ryzen AI 9 465` `Ryzen AI 9 HX 375` `Ryzen AI 9 HX 370` `Ryzen AI 9 365` |
|
||||
| AMD Instinct | `MI350X` `MI300X` `MI300A` `MI250X` `MI250` `MI210` `MI100` |
|
||||
|
||||
### Windows Support
|
||||
|
||||
@@ -97,17 +99,20 @@ This table shows some example GPUs that map to these LLVM targets:
|
||||
| **LLVM Target** | **An Example GPU** |
|
||||
|-----------------|---------------------|
|
||||
| gfx908 | Radeon Instinct MI100 |
|
||||
| gfx90a | Radeon Instinct MI210 |
|
||||
| gfx940 | Radeon Instinct MI300 |
|
||||
| gfx941 | |
|
||||
| gfx942 | |
|
||||
| gfx90a | Radeon Instinct MI210/MI250 |
|
||||
| gfx942 | Radeon Instinct MI300X/MI300A |
|
||||
| gfx950 | Radeon Instinct MI350X |
|
||||
| gfx1010 | Radeon RX 5700 XT |
|
||||
| gfx1012 | Radeon RX 5500 XT |
|
||||
| gfx1030 | Radeon PRO V620 |
|
||||
| gfx1100 | Radeon PRO W7900 |
|
||||
| gfx1101 | Radeon PRO W7700 |
|
||||
| gfx1102 | Radeon RX 7600 |
|
||||
|
||||
AMD is working on enhancing ROCm v6 to broaden support for families of GPUs in a
|
||||
future release which should increase support for more GPUs.
|
||||
| gfx1103 | Radeon 780M |
|
||||
| gfx1150 | Ryzen AI 9 HX 375 |
|
||||
| gfx1151 | Ryzen AI Max+ 395 |
|
||||
| gfx1200 | Radeon RX 9070 |
|
||||
| gfx1201 | Radeon RX 9070 XT |
|
||||
|
||||
Reach out on [Discord](https://discord.gg/ollama) or file an
|
||||
[issue](https://github.com/ollama/ollama/issues) for additional help.
|
||||
|
||||
@@ -101,7 +101,7 @@ nvidia-smi
|
||||
|
||||
### Install AMD ROCm drivers (optional)
|
||||
|
||||
[Download and Install](https://rocm.docs.amd.com/projects/install-on-linux/en/latest/tutorial/quick-start.html) ROCm v6.
|
||||
[Download and Install](https://rocm.docs.amd.com/projects/install-on-linux/en/latest/tutorial/quick-start.html) ROCm v7.
|
||||
|
||||
### Start Ollama
|
||||
|
||||
|
||||
@@ -152,7 +152,9 @@ PARAMETER <parameter> <parametervalue>
|
||||
| -------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ---------- | -------------------- |
|
||||
| num_ctx | Sets the size of the context window used to generate the next token. (Default: 2048) | int | num_ctx 4096 |
|
||||
| repeat_last_n | Sets how far back for the model to look back to prevent repetition. (Default: 64, 0 = disabled, -1 = num_ctx) | int | repeat_last_n 64 |
|
||||
| repeat_penalty | Sets how strongly to penalize repetitions. A higher value (e.g., 1.5) will penalize repetitions more strongly, while a lower value (e.g., 0.9) will be more lenient. (Default: 1.1) | float | repeat_penalty 1.1 |
|
||||
| repeat_penalty | Sets how strongly to penalize repetitions. A higher value (e.g., 1.5) will penalize repetitions more strongly, while a lower value (e.g., 0.9) will be more lenient. (Default: 1.0) | float | repeat_penalty 1.0 |
|
||||
| presence_penalty | Penalizes tokens that have already appeared in the generated text to reduce repetition. (Default: 0.0) | float | presence_penalty 1.5 |
|
||||
| frequency_penalty | Penalizes tokens based on how often they have appeared in the generated text. (Default: 0.0) | float | frequency_penalty 1.0 |
|
||||
| temperature | The temperature of the model. Increasing the temperature will make the model answer more creatively. (Default: 0.8) | float | temperature 0.7 |
|
||||
| seed | Sets the random number seed to use for generation. Setting this to a specific number will make the model generate the same text for the same prompt. (Default: 0) | int | seed 42 |
|
||||
| stop | Sets the stop sequences to use. When this pattern is encountered the LLM will stop generating text and return. Multiple stop patterns may be set by specifying multiple separate `stop` parameters in a modelfile. | string | stop "AI assistant:" |
|
||||
|
||||
115
internal/modelref/modelref.go
Normal file
115
internal/modelref/modelref.go
Normal file
@@ -0,0 +1,115 @@
|
||||
package modelref
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type ModelSource uint8
|
||||
|
||||
const (
|
||||
ModelSourceUnspecified ModelSource = iota
|
||||
ModelSourceLocal
|
||||
ModelSourceCloud
|
||||
)
|
||||
|
||||
var (
|
||||
ErrConflictingSourceSuffix = errors.New("use either :local or :cloud, not both")
|
||||
ErrModelRequired = errors.New("model is required")
|
||||
)
|
||||
|
||||
type ParsedRef struct {
|
||||
Original string
|
||||
Base string
|
||||
Source ModelSource
|
||||
}
|
||||
|
||||
func ParseRef(raw string) (ParsedRef, error) {
|
||||
var zero ParsedRef
|
||||
|
||||
raw = strings.TrimSpace(raw)
|
||||
if raw == "" {
|
||||
return zero, ErrModelRequired
|
||||
}
|
||||
|
||||
base, source, explicit := parseSourceSuffix(raw)
|
||||
if explicit {
|
||||
if _, _, nested := parseSourceSuffix(base); nested {
|
||||
return zero, fmt.Errorf("%w: %q", ErrConflictingSourceSuffix, raw)
|
||||
}
|
||||
}
|
||||
|
||||
return ParsedRef{
|
||||
Original: raw,
|
||||
Base: base,
|
||||
Source: source,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func HasExplicitCloudSource(raw string) bool {
|
||||
parsedRef, err := ParseRef(raw)
|
||||
return err == nil && parsedRef.Source == ModelSourceCloud
|
||||
}
|
||||
|
||||
func HasExplicitLocalSource(raw string) bool {
|
||||
parsedRef, err := ParseRef(raw)
|
||||
return err == nil && parsedRef.Source == ModelSourceLocal
|
||||
}
|
||||
|
||||
func StripCloudSourceTag(raw string) (string, bool) {
|
||||
parsedRef, err := ParseRef(raw)
|
||||
if err != nil || parsedRef.Source != ModelSourceCloud {
|
||||
return strings.TrimSpace(raw), false
|
||||
}
|
||||
|
||||
return parsedRef.Base, true
|
||||
}
|
||||
|
||||
func NormalizePullName(raw string) (string, bool, error) {
|
||||
parsedRef, err := ParseRef(raw)
|
||||
if err != nil {
|
||||
return "", false, err
|
||||
}
|
||||
|
||||
if parsedRef.Source != ModelSourceCloud {
|
||||
return parsedRef.Base, false, nil
|
||||
}
|
||||
|
||||
return toLegacyCloudPullName(parsedRef.Base), true, nil
|
||||
}
|
||||
|
||||
func toLegacyCloudPullName(base string) string {
|
||||
if hasExplicitTag(base) {
|
||||
return base + "-cloud"
|
||||
}
|
||||
|
||||
return base + ":cloud"
|
||||
}
|
||||
|
||||
func hasExplicitTag(name string) bool {
|
||||
lastSlash := strings.LastIndex(name, "/")
|
||||
lastColon := strings.LastIndex(name, ":")
|
||||
return lastColon > lastSlash
|
||||
}
|
||||
|
||||
func parseSourceSuffix(raw string) (string, ModelSource, bool) {
|
||||
idx := strings.LastIndex(raw, ":")
|
||||
if idx >= 0 {
|
||||
suffixRaw := strings.TrimSpace(raw[idx+1:])
|
||||
suffix := strings.ToLower(suffixRaw)
|
||||
|
||||
switch suffix {
|
||||
case "cloud":
|
||||
return raw[:idx], ModelSourceCloud, true
|
||||
case "local":
|
||||
return raw[:idx], ModelSourceLocal, true
|
||||
}
|
||||
|
||||
if !strings.Contains(suffixRaw, "/") && strings.HasSuffix(suffix, "-cloud") {
|
||||
return raw[:idx+1] + suffixRaw[:len(suffixRaw)-len("-cloud")], ModelSourceCloud, true
|
||||
}
|
||||
}
|
||||
|
||||
return raw, ModelSourceUnspecified, false
|
||||
}
|
||||
268
internal/modelref/modelref_test.go
Normal file
268
internal/modelref/modelref_test.go
Normal file
@@ -0,0 +1,268 @@
|
||||
package modelref
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestParseRef(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
wantBase string
|
||||
wantSource ModelSource
|
||||
wantErr error
|
||||
wantCloud bool
|
||||
wantLocal bool
|
||||
wantStripped string
|
||||
wantStripOK bool
|
||||
}{
|
||||
{
|
||||
name: "cloud suffix",
|
||||
input: "gpt-oss:20b:cloud",
|
||||
wantBase: "gpt-oss:20b",
|
||||
wantSource: ModelSourceCloud,
|
||||
wantCloud: true,
|
||||
wantStripped: "gpt-oss:20b",
|
||||
wantStripOK: true,
|
||||
},
|
||||
{
|
||||
name: "legacy cloud suffix",
|
||||
input: "gpt-oss:20b-cloud",
|
||||
wantBase: "gpt-oss:20b",
|
||||
wantSource: ModelSourceCloud,
|
||||
wantCloud: true,
|
||||
wantStripped: "gpt-oss:20b",
|
||||
wantStripOK: true,
|
||||
},
|
||||
{
|
||||
name: "local suffix",
|
||||
input: "qwen3:8b:local",
|
||||
wantBase: "qwen3:8b",
|
||||
wantSource: ModelSourceLocal,
|
||||
wantLocal: true,
|
||||
wantStripped: "qwen3:8b:local",
|
||||
},
|
||||
{
|
||||
name: "no source suffix",
|
||||
input: "llama3.2",
|
||||
wantBase: "llama3.2",
|
||||
wantSource: ModelSourceUnspecified,
|
||||
wantStripped: "llama3.2",
|
||||
},
|
||||
{
|
||||
name: "bare cloud name is not explicit cloud",
|
||||
input: "my-cloud-model",
|
||||
wantBase: "my-cloud-model",
|
||||
wantSource: ModelSourceUnspecified,
|
||||
wantStripped: "my-cloud-model",
|
||||
},
|
||||
{
|
||||
name: "slash in suffix blocks legacy cloud parsing",
|
||||
input: "foo:bar-cloud/baz",
|
||||
wantBase: "foo:bar-cloud/baz",
|
||||
wantSource: ModelSourceUnspecified,
|
||||
wantStripped: "foo:bar-cloud/baz",
|
||||
},
|
||||
{
|
||||
name: "conflicting source suffixes",
|
||||
input: "foo:cloud:local",
|
||||
wantErr: ErrConflictingSourceSuffix,
|
||||
wantSource: ModelSourceUnspecified,
|
||||
},
|
||||
{
|
||||
name: "empty input",
|
||||
input: " ",
|
||||
wantErr: ErrModelRequired,
|
||||
wantSource: ModelSourceUnspecified,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := ParseRef(tt.input)
|
||||
if tt.wantErr != nil {
|
||||
if !errors.Is(err, tt.wantErr) {
|
||||
t.Fatalf("ParseRef(%q) error = %v, want %v", tt.input, err, tt.wantErr)
|
||||
}
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("ParseRef(%q) returned error: %v", tt.input, err)
|
||||
}
|
||||
|
||||
if got.Base != tt.wantBase {
|
||||
t.Fatalf("base = %q, want %q", got.Base, tt.wantBase)
|
||||
}
|
||||
|
||||
if got.Source != tt.wantSource {
|
||||
t.Fatalf("source = %v, want %v", got.Source, tt.wantSource)
|
||||
}
|
||||
|
||||
if HasExplicitCloudSource(tt.input) != tt.wantCloud {
|
||||
t.Fatalf("HasExplicitCloudSource(%q) = %v, want %v", tt.input, HasExplicitCloudSource(tt.input), tt.wantCloud)
|
||||
}
|
||||
|
||||
if HasExplicitLocalSource(tt.input) != tt.wantLocal {
|
||||
t.Fatalf("HasExplicitLocalSource(%q) = %v, want %v", tt.input, HasExplicitLocalSource(tt.input), tt.wantLocal)
|
||||
}
|
||||
|
||||
stripped, ok := StripCloudSourceTag(tt.input)
|
||||
if ok != tt.wantStripOK {
|
||||
t.Fatalf("StripCloudSourceTag(%q) ok = %v, want %v", tt.input, ok, tt.wantStripOK)
|
||||
}
|
||||
if stripped != tt.wantStripped {
|
||||
t.Fatalf("StripCloudSourceTag(%q) base = %q, want %q", tt.input, stripped, tt.wantStripped)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizePullName(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
wantName string
|
||||
wantCloud bool
|
||||
wantErr error
|
||||
}{
|
||||
{
|
||||
name: "explicit local strips source",
|
||||
input: "gpt-oss:20b:local",
|
||||
wantName: "gpt-oss:20b",
|
||||
},
|
||||
{
|
||||
name: "explicit cloud with size maps to legacy dash cloud tag",
|
||||
input: "gpt-oss:20b:cloud",
|
||||
wantName: "gpt-oss:20b-cloud",
|
||||
wantCloud: true,
|
||||
},
|
||||
{
|
||||
name: "legacy cloud with size remains stable",
|
||||
input: "gpt-oss:20b-cloud",
|
||||
wantName: "gpt-oss:20b-cloud",
|
||||
wantCloud: true,
|
||||
},
|
||||
{
|
||||
name: "explicit cloud without tag maps to cloud tag",
|
||||
input: "qwen3:cloud",
|
||||
wantName: "qwen3:cloud",
|
||||
wantCloud: true,
|
||||
},
|
||||
{
|
||||
name: "host port without tag keeps host port and appends cloud tag",
|
||||
input: "localhost:11434/library/foo:cloud",
|
||||
wantName: "localhost:11434/library/foo:cloud",
|
||||
wantCloud: true,
|
||||
},
|
||||
{
|
||||
name: "conflicting source suffixes fail",
|
||||
input: "foo:cloud:local",
|
||||
wantErr: ErrConflictingSourceSuffix,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
gotName, gotCloud, err := NormalizePullName(tt.input)
|
||||
if tt.wantErr != nil {
|
||||
if !errors.Is(err, tt.wantErr) {
|
||||
t.Fatalf("NormalizePullName(%q) error = %v, want %v", tt.input, err, tt.wantErr)
|
||||
}
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("NormalizePullName(%q) returned error: %v", tt.input, err)
|
||||
}
|
||||
|
||||
if gotName != tt.wantName {
|
||||
t.Fatalf("normalized name = %q, want %q", gotName, tt.wantName)
|
||||
}
|
||||
if gotCloud != tt.wantCloud {
|
||||
t.Fatalf("cloud = %v, want %v", gotCloud, tt.wantCloud)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseSourceSuffix(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
wantBase string
|
||||
wantSource ModelSource
|
||||
wantExplicit bool
|
||||
}{
|
||||
{
|
||||
name: "explicit cloud suffix",
|
||||
input: "gpt-oss:20b:cloud",
|
||||
wantBase: "gpt-oss:20b",
|
||||
wantSource: ModelSourceCloud,
|
||||
wantExplicit: true,
|
||||
},
|
||||
{
|
||||
name: "explicit local suffix",
|
||||
input: "qwen3:8b:local",
|
||||
wantBase: "qwen3:8b",
|
||||
wantSource: ModelSourceLocal,
|
||||
wantExplicit: true,
|
||||
},
|
||||
{
|
||||
name: "legacy cloud suffix on tag",
|
||||
input: "gpt-oss:20b-cloud",
|
||||
wantBase: "gpt-oss:20b",
|
||||
wantSource: ModelSourceCloud,
|
||||
wantExplicit: true,
|
||||
},
|
||||
{
|
||||
name: "legacy cloud suffix does not match model segment",
|
||||
input: "my-cloud-model",
|
||||
wantBase: "my-cloud-model",
|
||||
wantSource: ModelSourceUnspecified,
|
||||
wantExplicit: false,
|
||||
},
|
||||
{
|
||||
name: "legacy cloud suffix blocked when suffix includes slash",
|
||||
input: "foo:bar-cloud/baz",
|
||||
wantBase: "foo:bar-cloud/baz",
|
||||
wantSource: ModelSourceUnspecified,
|
||||
wantExplicit: false,
|
||||
},
|
||||
{
|
||||
name: "unknown suffix is not explicit source",
|
||||
input: "gpt-oss:clod",
|
||||
wantBase: "gpt-oss:clod",
|
||||
wantSource: ModelSourceUnspecified,
|
||||
wantExplicit: false,
|
||||
},
|
||||
{
|
||||
name: "uppercase suffix is accepted",
|
||||
input: "gpt-oss:20b:CLOUD",
|
||||
wantBase: "gpt-oss:20b",
|
||||
wantSource: ModelSourceCloud,
|
||||
wantExplicit: true,
|
||||
},
|
||||
{
|
||||
name: "no suffix",
|
||||
input: "llama3.2",
|
||||
wantBase: "llama3.2",
|
||||
wantSource: ModelSourceUnspecified,
|
||||
wantExplicit: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
gotBase, gotSource, gotExplicit := parseSourceSuffix(tt.input)
|
||||
if gotBase != tt.wantBase {
|
||||
t.Fatalf("base = %q, want %q", gotBase, tt.wantBase)
|
||||
}
|
||||
if gotSource != tt.wantSource {
|
||||
t.Fatalf("source = %v, want %v", gotSource, tt.wantSource)
|
||||
}
|
||||
if gotExplicit != tt.wantExplicit {
|
||||
t.Fatalf("explicit = %v, want %v", gotExplicit, tt.wantExplicit)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -74,8 +74,7 @@ type LlamaServer interface {
|
||||
Tokenize(ctx context.Context, content string) ([]int, error)
|
||||
Detokenize(ctx context.Context, tokens []int) (string, error)
|
||||
Close() error
|
||||
VRAMSize() uint64 // Total VRAM across all GPUs
|
||||
TotalSize() uint64
|
||||
MemorySize() (total, vram uint64)
|
||||
VRAMByGPU(id ml.DeviceID) uint64
|
||||
Pid() int
|
||||
GetPort() int
|
||||
@@ -685,8 +684,9 @@ func (s *llamaServer) Load(ctx context.Context, systemInfo ml.SystemInfo, system
|
||||
// Windows CUDA should not use mmap for best performance
|
||||
// Linux with a model larger than free space, mmap leads to thrashing
|
||||
// For CPU loads we want the memory to be allocated, not FS cache
|
||||
totalSize, _ := s.MemorySize()
|
||||
if (runtime.GOOS == "windows" && len(gpus) > 0 && gpus[0].Library == "CUDA" && s.options.UseMMap == nil) ||
|
||||
(runtime.GOOS == "linux" && systemInfo.FreeMemory < s.TotalSize() && s.options.UseMMap == nil) ||
|
||||
(runtime.GOOS == "linux" && systemInfo.FreeMemory < totalSize && s.options.UseMMap == nil) ||
|
||||
(len(gpus) == 0 && s.options.UseMMap == nil) ||
|
||||
(len(gpus) > 0 && gpus[0].Library == "Vulkan" && s.options.UseMMap == nil) ||
|
||||
(s.options.UseMMap != nil && !*s.options.UseMMap) {
|
||||
@@ -1848,17 +1848,17 @@ func (s *llamaServer) GetDeviceInfos(ctx context.Context) []ml.DeviceInfo {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *llmServer) VRAMSize() uint64 {
|
||||
func (s *llmServer) MemorySize() (total, vram uint64) {
|
||||
if s.mem == nil {
|
||||
return 0
|
||||
return 0, 0
|
||||
}
|
||||
|
||||
var mem uint64
|
||||
|
||||
for _, g := range s.mem.GPUs {
|
||||
mem += g.Size()
|
||||
vram += g.Size()
|
||||
}
|
||||
|
||||
total = s.mem.InputWeights + s.mem.CPU.Size() + vram
|
||||
|
||||
// Some elements are always on CPU. However, if we have allocated all layers
|
||||
// on the GPU then include the CPU components as well, to represent complete offloading.
|
||||
noCPULayers := true
|
||||
@@ -1869,25 +1869,11 @@ func (s *llmServer) VRAMSize() uint64 {
|
||||
}
|
||||
}
|
||||
if noCPULayers {
|
||||
mem += s.mem.InputWeights
|
||||
mem += s.mem.CPU.Graph
|
||||
vram += s.mem.InputWeights
|
||||
vram += s.mem.CPU.Graph
|
||||
}
|
||||
|
||||
return mem
|
||||
}
|
||||
|
||||
func (s *llmServer) TotalSize() uint64 {
|
||||
if s.mem == nil {
|
||||
return 0
|
||||
}
|
||||
|
||||
mem := s.mem.InputWeights
|
||||
mem += s.mem.CPU.Size()
|
||||
for _, g := range s.mem.GPUs {
|
||||
mem += g.Size()
|
||||
}
|
||||
|
||||
return mem
|
||||
return total, vram
|
||||
}
|
||||
|
||||
func (s *llmServer) VRAMByGPU(id ml.DeviceID) uint64 {
|
||||
|
||||
@@ -17,6 +17,7 @@ import (
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
internalcloud "github.com/ollama/ollama/internal/cloud"
|
||||
"github.com/ollama/ollama/internal/modelref"
|
||||
"github.com/ollama/ollama/logutil"
|
||||
)
|
||||
|
||||
@@ -919,7 +920,7 @@ func hasWebSearchTool(tools []anthropic.Tool) bool {
|
||||
}
|
||||
|
||||
func isCloudModelName(name string) bool {
|
||||
return strings.HasSuffix(name, ":cloud") || strings.HasSuffix(name, "-cloud")
|
||||
return modelref.HasExplicitCloudSource(name)
|
||||
}
|
||||
|
||||
// extractQueryFromToolCall extracts the search query from a web_search tool call
|
||||
|
||||
@@ -41,8 +41,8 @@ type GatedDeltaNet struct {
|
||||
SSMBeta *nn.Linear `gguf:"ssm_beta"` // -> beta (qwen35)
|
||||
SSMAlpha *nn.Linear `gguf:"ssm_alpha"` // -> alpha (qwen35)
|
||||
SSMConv1D *convKernel `gguf:"ssm_conv1d"`
|
||||
SSMDT ml.Tensor `gguf:"ssm_dt"` // alpha bias
|
||||
SSMA ml.Tensor `gguf:"ssm_a"` // -A_log.exp()
|
||||
SSMDT ml.Tensor `gguf:"ssm_dt,alt:ssm_dt.bias"` // alpha bias
|
||||
SSMA ml.Tensor `gguf:"ssm_a"` // -A_log.exp()
|
||||
SSMNorm *nn.RMSNorm `gguf:"ssm_norm"`
|
||||
SSMOut *nn.Linear `gguf:"ssm_out"`
|
||||
|
||||
@@ -135,6 +135,18 @@ func (gdn *GatedDeltaNet) Forward(ctx ml.Context, hiddenStates, _ ml.Tensor, cac
|
||||
default:
|
||||
return nil, errors.New("qwen3next: missing linear attention beta/alpha projections")
|
||||
}
|
||||
if gdn.SSMDT == nil {
|
||||
return nil, errors.New("qwen3next: missing linear attention ssm_dt tensor")
|
||||
}
|
||||
if gdn.SSMA == nil {
|
||||
return nil, errors.New("qwen3next: missing linear attention ssm_a tensor")
|
||||
}
|
||||
if gdn.SSMConv1D == nil || gdn.SSMConv1D.Weight == nil {
|
||||
return nil, errors.New("qwen3next: missing linear attention ssm_conv1d tensor")
|
||||
}
|
||||
if gdn.SSMNorm == nil || gdn.SSMOut == nil {
|
||||
return nil, errors.New("qwen3next: missing linear attention ssm_norm/ssm_out projections")
|
||||
}
|
||||
|
||||
// Compute gate: softplus(alpha + dt_bias) * -A
|
||||
alphaBiased := alpha.Add(ctx, gdn.SSMDT)
|
||||
@@ -442,6 +454,10 @@ func (gdn *GatedDeltaNet) deltaNetChunked(
|
||||
vT := v.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx, chunkSize, headVDim, nChunks, numVHeads*nSeqs)
|
||||
stateT := state.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx, headVDim, headVDim, 1, numVHeads*nSeqs)
|
||||
|
||||
// Collect chunk outputs and concatenate at the end.
|
||||
// Avoids SET on buffer-less intermediates under partial offload.
|
||||
chunks := make([]ml.Tensor, nChunks)
|
||||
|
||||
for chunk := range nChunks {
|
||||
qChunk := q.Slice(ctx, 2, chunk, chunk+1, 1)
|
||||
vTChunk := vT.Slice(ctx, 2, chunk, chunk+1, 1)
|
||||
@@ -463,14 +479,7 @@ func (gdn *GatedDeltaNet) deltaNetChunked(
|
||||
vAttn := vTNewChunk.Mulmat(ctx, attnChunk)
|
||||
coreAttnOutChunk := attnInter.Add(ctx, vAttn)
|
||||
|
||||
v = v.SetInplace(
|
||||
ctx,
|
||||
coreAttnOutChunk,
|
||||
v.Stride(1),
|
||||
v.Stride(2),
|
||||
v.Stride(3),
|
||||
chunk*v.Stride(2),
|
||||
)
|
||||
chunks[chunk] = coreAttnOutChunk
|
||||
|
||||
// Update state for next chunk
|
||||
gExpLastChunk := gLastExp.Slice(ctx, 2, chunk, chunk+1, 1)
|
||||
@@ -483,6 +492,20 @@ func (gdn *GatedDeltaNet) deltaNetChunked(
|
||||
stateT = stateT.Add(ctx, kgdMulVNew)
|
||||
}
|
||||
|
||||
// Use a balanced concat tree so concat work does not balloon on long prompts.
|
||||
for len(chunks) > 1 {
|
||||
merged := make([]ml.Tensor, 0, (len(chunks)+1)/2)
|
||||
for i := 0; i < len(chunks); i += 2 {
|
||||
if i+1 < len(chunks) {
|
||||
merged = append(merged, chunks[i].Concat(ctx, chunks[i+1], 2))
|
||||
} else {
|
||||
merged = append(merged, chunks[i])
|
||||
}
|
||||
}
|
||||
chunks = merged
|
||||
}
|
||||
v = chunks[0]
|
||||
|
||||
// Final reshape
|
||||
coreAttnOut := v.Contiguous(ctx, headVDim, chunkSize*nChunks, numVHeads, nSeqs)
|
||||
|
||||
|
||||
@@ -437,6 +437,46 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||
return m.Output.Forward(ctx, hiddenStates), nil
|
||||
}
|
||||
|
||||
func (m *Model) Validate() error {
|
||||
if m.Options == nil {
|
||||
return fmt.Errorf("qwen3next: missing model options")
|
||||
}
|
||||
if len(m.Layers) != len(m.Options.isRecurrent) {
|
||||
return fmt.Errorf("qwen3next: layer config mismatch: have %d layers, %d recurrent flags", len(m.Layers), len(m.Options.isRecurrent))
|
||||
}
|
||||
|
||||
for i, layer := range m.Layers {
|
||||
if !m.Options.isRecurrent[i] {
|
||||
continue
|
||||
}
|
||||
|
||||
gdn, ok := layer.Operator.(*GatedDeltaNet)
|
||||
if !ok || gdn == nil {
|
||||
return fmt.Errorf("qwen3next: layer %d expected recurrent operator", i)
|
||||
}
|
||||
if gdn.SSMQKV == nil || gdn.SSMQKVGate == nil {
|
||||
return fmt.Errorf("qwen3next: layer %d missing attn_qkv/attn_gate projections", i)
|
||||
}
|
||||
if gdn.SSMBetaAlpha == nil && (gdn.SSMBeta == nil || gdn.SSMAlpha == nil) {
|
||||
return fmt.Errorf("qwen3next: layer %d missing linear attention beta/alpha projections", i)
|
||||
}
|
||||
if gdn.SSMDT == nil {
|
||||
return fmt.Errorf("qwen3next: layer %d missing ssm_dt tensor", i)
|
||||
}
|
||||
if gdn.SSMA == nil {
|
||||
return fmt.Errorf("qwen3next: layer %d missing ssm_a tensor", i)
|
||||
}
|
||||
if gdn.SSMConv1D == nil || gdn.SSMConv1D.Weight == nil {
|
||||
return fmt.Errorf("qwen3next: layer %d missing ssm_conv1d tensor", i)
|
||||
}
|
||||
if gdn.SSMNorm == nil || gdn.SSMOut == nil {
|
||||
return fmt.Errorf("qwen3next: layer %d missing ssm_norm/ssm_out projections", i)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
|
||||
m.positionCache = nil
|
||||
if len(m.mropeSections) > 0 {
|
||||
@@ -450,6 +490,64 @@ var (
|
||||
_ model.MultimodalProcessor = (*Model)(nil)
|
||||
)
|
||||
|
||||
func defaultVHeadReordered(arch string) bool {
|
||||
return arch == "qwen35" || arch == "qwen35moe"
|
||||
}
|
||||
|
||||
func inferRecurrentLayers(headCountKV []uint64, numLayers int, fullAttentionInterval uint32) ([]bool, error) {
|
||||
isRecurrent := make([]bool, numLayers)
|
||||
|
||||
hasZero := false
|
||||
hasFull := false
|
||||
for i := range numLayers {
|
||||
if i >= len(headCountKV) {
|
||||
continue
|
||||
}
|
||||
|
||||
if headCountKV[i] == 0 {
|
||||
isRecurrent[i] = true
|
||||
hasZero = true
|
||||
} else {
|
||||
hasFull = true
|
||||
}
|
||||
}
|
||||
if hasZero && hasFull {
|
||||
return isRecurrent, nil
|
||||
}
|
||||
if !hasFull {
|
||||
return nil, fmt.Errorf("qwen3next: attention.head_count_kv must include at least one non-zero value")
|
||||
}
|
||||
|
||||
// Compatibility path: older imports store a scalar KV head count and omit
|
||||
// per-layer recurrent flags. Derive the hybrid layout from the interval.
|
||||
interval := int(fullAttentionInterval)
|
||||
if interval == 0 {
|
||||
interval = min(4, numLayers)
|
||||
}
|
||||
if interval <= 0 {
|
||||
return nil, fmt.Errorf("qwen3next: invalid block_count (%d)", numLayers)
|
||||
}
|
||||
if interval > numLayers {
|
||||
return nil, fmt.Errorf("qwen3next: full_attention_interval (%d) exceeds block_count (%d)", interval, numLayers)
|
||||
}
|
||||
|
||||
hasZero = false
|
||||
hasFull = false
|
||||
for i := range numLayers {
|
||||
isRecurrent[i] = (i+1)%interval != 0
|
||||
if isRecurrent[i] {
|
||||
hasZero = true
|
||||
} else {
|
||||
hasFull = true
|
||||
}
|
||||
}
|
||||
if !hasZero || !hasFull {
|
||||
return nil, fmt.Errorf("qwen3next: full_attention_interval (%d) does not produce a mixed recurrent/full layout", interval)
|
||||
}
|
||||
|
||||
return isRecurrent, nil
|
||||
}
|
||||
|
||||
func New(c fs.Config) (model.Model, error) {
|
||||
numLayers := int(c.Uint("block_count"))
|
||||
layers := make([]Layer, numLayers)
|
||||
@@ -460,26 +558,14 @@ func New(c fs.Config) (model.Model, error) {
|
||||
HeadCountKV() []uint64
|
||||
}
|
||||
|
||||
var isRecurrent []bool
|
||||
var headCountKV []uint64
|
||||
if hc, ok := c.(headCounts); ok {
|
||||
headCountKV = hc.HeadCountKV()
|
||||
}
|
||||
|
||||
isRecurrent = make([]bool, numLayers)
|
||||
hasZero := false
|
||||
hasFull := false
|
||||
for i := range numLayers {
|
||||
// If KV head count is 0, it's a recurrent layer
|
||||
if i < len(headCountKV) && headCountKV[i] == 0 {
|
||||
isRecurrent[i] = true
|
||||
hasZero = true
|
||||
} else if i < len(headCountKV) && headCountKV[i] > 0 {
|
||||
hasFull = true
|
||||
}
|
||||
}
|
||||
if !hasZero || !hasFull {
|
||||
return nil, fmt.Errorf("qwen3next: invalid attention.head_count_kv array; expected mix of zero and non-zero values")
|
||||
isRecurrent, err := inferRecurrentLayers(headCountKV, numLayers, c.Uint("full_attention_interval"))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Determine if MoE
|
||||
@@ -543,7 +629,7 @@ func New(c fs.Config) (model.Model, error) {
|
||||
ssmNGroup: int(c.Uint("ssm.group_count")),
|
||||
ssmDtRank: int(c.Uint("ssm.time_step_rank")),
|
||||
convKernelSize: int(c.Uint("ssm.conv_kernel")),
|
||||
vHeadReordered: c.Bool("ssm.v_head_reordered", false),
|
||||
vHeadReordered: c.Bool("ssm.v_head_reordered", defaultVHeadReordered(c.Architecture())),
|
||||
isRecurrent: isRecurrent,
|
||||
mropeSections: slices.Collect(func(yield func(int) bool) {
|
||||
for _, section := range mropeSections {
|
||||
@@ -555,7 +641,7 @@ func New(c fs.Config) (model.Model, error) {
|
||||
mropeInterleaved: c.Bool("rope.mrope_interleaved", c.Bool("mrope_interleaved", false)),
|
||||
}
|
||||
if opts.numKVHeads == 0 {
|
||||
return nil, fmt.Errorf("qwen3next: attention.head_count_kv array must include at least one non-zero value")
|
||||
return nil, fmt.Errorf("qwen3next: attention.head_count_kv must include at least one non-zero value")
|
||||
}
|
||||
|
||||
// Calculate cache dimensions
|
||||
|
||||
65
model/models/qwen3next/model_new_test.go
Normal file
65
model/models/qwen3next/model_new_test.go
Normal file
@@ -0,0 +1,65 @@
|
||||
package qwen3next
|
||||
|
||||
import (
|
||||
"slices"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestInferRecurrentLayersMixedKVArray(t *testing.T) {
|
||||
got, err := inferRecurrentLayers([]uint64{0, 2, 0, 2}, 4, 0)
|
||||
if err != nil {
|
||||
t.Fatalf("inferRecurrentLayers() error = %v", err)
|
||||
}
|
||||
|
||||
want := []bool{true, false, true, false}
|
||||
if !slices.Equal(got, want) {
|
||||
t.Fatalf("inferRecurrentLayers() = %v, want %v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestInferRecurrentLayersScalarKVDefaultInterval(t *testing.T) {
|
||||
got, err := inferRecurrentLayers([]uint64{2, 2, 2, 2, 2, 2, 2, 2}, 8, 0)
|
||||
if err != nil {
|
||||
t.Fatalf("inferRecurrentLayers() error = %v", err)
|
||||
}
|
||||
|
||||
want := []bool{true, true, true, false, true, true, true, false}
|
||||
if !slices.Equal(got, want) {
|
||||
t.Fatalf("inferRecurrentLayers() = %v, want %v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestInferRecurrentLayersScalarKVConfiguredInterval(t *testing.T) {
|
||||
got, err := inferRecurrentLayers([]uint64{2, 2, 2, 2, 2, 2}, 6, 3)
|
||||
if err != nil {
|
||||
t.Fatalf("inferRecurrentLayers() error = %v", err)
|
||||
}
|
||||
|
||||
want := []bool{true, true, false, true, true, false}
|
||||
if !slices.Equal(got, want) {
|
||||
t.Fatalf("inferRecurrentLayers() = %v, want %v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestInferRecurrentLayersAllZeroRejects(t *testing.T) {
|
||||
_, err := inferRecurrentLayers([]uint64{0, 0, 0, 0}, 4, 0)
|
||||
if err == nil {
|
||||
t.Fatal("inferRecurrentLayers() expected error, got nil")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "must include at least one non-zero value") {
|
||||
t.Fatalf("unexpected error = %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultVHeadReordered(t *testing.T) {
|
||||
if !defaultVHeadReordered("qwen35") {
|
||||
t.Fatal("defaultVHeadReordered(qwen35) = false, want true")
|
||||
}
|
||||
if !defaultVHeadReordered("qwen35moe") {
|
||||
t.Fatal("defaultVHeadReordered(qwen35moe) = false, want true")
|
||||
}
|
||||
if defaultVHeadReordered("qwen3next") {
|
||||
t.Fatal("defaultVHeadReordered(qwen3next) = true, want false")
|
||||
}
|
||||
}
|
||||
45
model/models/qwen3next/model_validate_test.go
Normal file
45
model/models/qwen3next/model_validate_test.go
Normal file
@@ -0,0 +1,45 @@
|
||||
package qwen3next
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/ml/nn"
|
||||
)
|
||||
|
||||
func TestValidateRecurrentLayerRequiresSSMDT(t *testing.T) {
|
||||
m := &Model{
|
||||
Layers: []Layer{{
|
||||
Operator: &GatedDeltaNet{
|
||||
SSMQKV: &nn.Linear{},
|
||||
SSMQKVGate: &nn.Linear{},
|
||||
SSMBeta: &nn.Linear{},
|
||||
SSMAlpha: &nn.Linear{},
|
||||
},
|
||||
}},
|
||||
Options: &Options{
|
||||
isRecurrent: []bool{true},
|
||||
},
|
||||
}
|
||||
|
||||
err := m.Validate()
|
||||
if err == nil {
|
||||
t.Fatal("Validate() expected error, got nil")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "missing ssm_dt") {
|
||||
t.Fatalf("unexpected error = %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateNonRecurrentSkipsLinearChecks(t *testing.T) {
|
||||
m := &Model{
|
||||
Layers: []Layer{{Operator: &FullAttention{}}},
|
||||
Options: &Options{
|
||||
isRecurrent: []bool{false},
|
||||
},
|
||||
}
|
||||
|
||||
if err := m.Validate(); err != nil {
|
||||
t.Fatalf("Validate() error = %v", err)
|
||||
}
|
||||
}
|
||||
@@ -32,9 +32,10 @@ const (
|
||||
)
|
||||
|
||||
type GLM46Parser struct {
|
||||
state glm46ParserState
|
||||
buffer strings.Builder
|
||||
tools []api.Tool
|
||||
state glm46ParserState
|
||||
buffer strings.Builder
|
||||
tools []api.Tool
|
||||
callIndex int
|
||||
}
|
||||
|
||||
func (p *GLM46Parser) HasToolSupport() bool {
|
||||
@@ -48,6 +49,7 @@ func (p *GLM46Parser) HasThinkingSupport() bool {
|
||||
// func (p *GLM46Parser) Init(tools []api.Tool, lastMessage *api.Message) []api.Tool {
|
||||
func (p *GLM46Parser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
|
||||
p.tools = tools
|
||||
p.callIndex = 0
|
||||
return tools
|
||||
}
|
||||
|
||||
@@ -89,6 +91,8 @@ func (p *GLM46Parser) Add(s string, done bool) (content string, thinking string,
|
||||
slog.Warn("glm-4.6 tool call parsing failed", "error", err)
|
||||
return "", "", nil, err
|
||||
}
|
||||
toolCall.Function.Index = p.callIndex
|
||||
p.callIndex++
|
||||
toolCalls = append(toolCalls, toolCall)
|
||||
case glm46EventThinkingContent:
|
||||
thinkingSb.WriteString(event.content)
|
||||
@@ -341,6 +345,47 @@ func escapeGLM46Content(s string) string {
|
||||
return result.String()
|
||||
}
|
||||
|
||||
// repairUnclosedArgValues inserts missing </arg_value> closing tags.
|
||||
// GLM models sometimes omit the closing tag, producing XML like:
|
||||
//
|
||||
// <arg_value>value</tool_call>
|
||||
//
|
||||
// instead of:
|
||||
//
|
||||
// <arg_value>value</arg_value></tool_call>
|
||||
func repairUnclosedArgValues(s string) string {
|
||||
var result strings.Builder
|
||||
for {
|
||||
openIdx := strings.Index(s, "<arg_value>")
|
||||
if openIdx == -1 {
|
||||
result.WriteString(s)
|
||||
break
|
||||
}
|
||||
afterOpen := openIdx + len("<arg_value>")
|
||||
closeIdx := strings.Index(s[afterOpen:], "</arg_value>")
|
||||
nextKeyIdx := strings.Index(s[afterOpen:], "<arg_key>")
|
||||
// Check if properly closed before the next <arg_key> (or no next key)
|
||||
if closeIdx != -1 && (nextKeyIdx == -1 || closeIdx < nextKeyIdx) {
|
||||
end := afterOpen + closeIdx + len("</arg_value>")
|
||||
result.WriteString(s[:end])
|
||||
s = s[end:]
|
||||
continue
|
||||
}
|
||||
// Unclosed — insert </arg_value> before the next <arg_key> or at end
|
||||
if nextKeyIdx != -1 {
|
||||
insertAt := afterOpen + nextKeyIdx
|
||||
result.WriteString(s[:insertAt])
|
||||
result.WriteString("</arg_value>")
|
||||
s = s[insertAt:]
|
||||
} else {
|
||||
result.WriteString(s)
|
||||
result.WriteString("</arg_value>")
|
||||
break
|
||||
}
|
||||
}
|
||||
return result.String()
|
||||
}
|
||||
|
||||
func parseGLM46ToolCall(raw glm46EventRawToolCall, tools []api.Tool) (api.ToolCall, error) {
|
||||
// Escape any unescaped entities in text content
|
||||
// We need to escape text between tags, but not the tags themselves
|
||||
@@ -349,10 +394,14 @@ func parseGLM46ToolCall(raw glm46EventRawToolCall, tools []api.Tool) (api.ToolCa
|
||||
// Wrap the content in a root element to make it valid XML
|
||||
xmlString := "<tool_call>" + escaped + "</tool_call>"
|
||||
|
||||
// Parse XML into struct
|
||||
// Parse XML into struct, retrying once with repaired XML if it fails
|
||||
var parsed GLMToolCallXML
|
||||
if err := xml.Unmarshal([]byte(xmlString), &parsed); err != nil {
|
||||
return api.ToolCall{}, fmt.Errorf("failed to parse XML: %w", err)
|
||||
parsed = GLMToolCallXML{}
|
||||
repaired := "<tool_call>" + repairUnclosedArgValues(escaped) + "</tool_call>"
|
||||
if err2 := xml.Unmarshal([]byte(repaired), &parsed); err2 != nil {
|
||||
return api.ToolCall{}, fmt.Errorf("failed to parse XML: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Extract and trim function name
|
||||
|
||||
@@ -846,6 +846,47 @@ line3</arg_value>`,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "unclosed arg_value at end",
|
||||
tools: []api.Tool{},
|
||||
rawToolCall: `get-weather
|
||||
<arg_key>city</arg_key>
|
||||
<arg_value>Paris`,
|
||||
wantToolCall: api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get-weather",
|
||||
Arguments: args(`{"city": "Paris"}`),
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "unclosed arg_value before next arg_key",
|
||||
tools: []api.Tool{},
|
||||
rawToolCall: `get-weather
|
||||
<arg_key>city</arg_key>
|
||||
<arg_value>Paris<arg_key>unit</arg_key>
|
||||
<arg_value>celsius</arg_value>`,
|
||||
wantToolCall: api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get-weather",
|
||||
Arguments: args(`{"city": "Paris", "unit": "celsius"}`),
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "multiple unclosed arg_values",
|
||||
tools: []api.Tool{},
|
||||
rawToolCall: `get-weather
|
||||
<arg_key>city</arg_key>
|
||||
<arg_value>Paris<arg_key>unit</arg_key>
|
||||
<arg_value>celsius`,
|
||||
wantToolCall: api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get-weather",
|
||||
Arguments: args(`{"city": "Paris", "unit": "celsius"}`),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for i, tc := range cases {
|
||||
@@ -860,3 +901,45 @@ line3</arg_value>`,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRepairUnclosedArgValues(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
input string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "already valid",
|
||||
input: `<arg_key>k</arg_key><arg_value>v</arg_value>`,
|
||||
want: `<arg_key>k</arg_key><arg_value>v</arg_value>`,
|
||||
},
|
||||
{
|
||||
name: "unclosed at end",
|
||||
input: `<arg_key>k</arg_key><arg_value>v`,
|
||||
want: `<arg_key>k</arg_key><arg_value>v</arg_value>`,
|
||||
},
|
||||
{
|
||||
name: "unclosed before next arg_key",
|
||||
input: `<arg_key>a</arg_key><arg_value>1<arg_key>b</arg_key><arg_value>2</arg_value>`,
|
||||
want: `<arg_key>a</arg_key><arg_value>1</arg_value><arg_key>b</arg_key><arg_value>2</arg_value>`,
|
||||
},
|
||||
{
|
||||
name: "no arg_value tags",
|
||||
input: `just plain text`,
|
||||
want: `just plain text`,
|
||||
},
|
||||
{
|
||||
name: "multiple unclosed",
|
||||
input: `<arg_key>a</arg_key><arg_value>1<arg_key>b</arg_key><arg_value>2`,
|
||||
want: `<arg_key>a</arg_key><arg_value>1</arg_value><arg_key>b</arg_key><arg_value>2</arg_value>`,
|
||||
},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got := repairUnclosedArgValues(tc.input)
|
||||
if got != tc.want {
|
||||
t.Errorf("got %q, want %q", got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -11,6 +11,7 @@ type GLM47Parser struct {
|
||||
|
||||
func (p *GLM47Parser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
|
||||
p.tools = tools
|
||||
p.callIndex = 0
|
||||
// When thinking is enabled (nil or true), the prompt ends with <think>,
|
||||
// so model output starts directly with thinking content (no opening tag).
|
||||
if thinkValue == nil || thinkValue.Bool() {
|
||||
|
||||
@@ -97,3 +97,91 @@ func TestGLM47ParserToolCallEscaping(t *testing.T) {
|
||||
t.Fatalf("expected %#v, got %#v", expected, toolCall)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGLM47ParserToolCallIndexing(t *testing.T) {
|
||||
parser := GLM47Parser{}
|
||||
parser.Init(nil, nil, nil)
|
||||
|
||||
input := `plan</think>
|
||||
<tool_call>first<arg_key>a</arg_key><arg_value>1</arg_value></tool_call>
|
||||
<tool_call>second<arg_key>b</arg_key><arg_value>2</arg_value></tool_call>
|
||||
<tool_call>third<arg_key>c</arg_key><arg_value>3</arg_value></tool_call>`
|
||||
|
||||
_, _, calls, err := parser.Add(input, true)
|
||||
if err != nil {
|
||||
t.Fatalf("parse failed: %v", err)
|
||||
}
|
||||
|
||||
want := []api.ToolCall{
|
||||
{Function: api.ToolCallFunction{Name: "first", Arguments: args(`{"a":"1"}`), Index: 0}},
|
||||
{Function: api.ToolCallFunction{Name: "second", Arguments: args(`{"b":"2"}`), Index: 1}},
|
||||
{Function: api.ToolCallFunction{Name: "third", Arguments: args(`{"c":"3"}`), Index: 2}},
|
||||
}
|
||||
if len(calls) != len(want) {
|
||||
t.Fatalf("expected %d calls, got %d", len(want), len(calls))
|
||||
}
|
||||
for i := range want {
|
||||
if !toolCallEqual(calls[i], want[i]) {
|
||||
t.Fatalf("call %d mismatch: got %#v, want %#v", i, calls[i], want[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestGLM47ParserToolCallIndexingStreaming(t *testing.T) {
|
||||
parser := GLM47Parser{}
|
||||
parser.Init(nil, nil, nil)
|
||||
|
||||
var all []api.ToolCall
|
||||
|
||||
_, _, calls, err := parser.Add("plan</think><tool_call>first<arg_key>a</arg_key><arg_value>1</arg_value></tool_call><tool_call>second<arg_key>b</arg_key>", false)
|
||||
if err != nil {
|
||||
t.Fatalf("step 1 parse failed: %v", err)
|
||||
}
|
||||
all = append(all, calls...)
|
||||
|
||||
_, _, calls, err = parser.Add("<arg_value>2</arg_value></tool_call><tool_call>third<arg_key>c</arg_key><arg_value>3</arg_value></tool_call>", true)
|
||||
if err != nil {
|
||||
t.Fatalf("step 2 parse failed: %v", err)
|
||||
}
|
||||
all = append(all, calls...)
|
||||
|
||||
want := []api.ToolCall{
|
||||
{Function: api.ToolCallFunction{Name: "first", Arguments: args(`{"a":"1"}`), Index: 0}},
|
||||
{Function: api.ToolCallFunction{Name: "second", Arguments: args(`{"b":"2"}`), Index: 1}},
|
||||
{Function: api.ToolCallFunction{Name: "third", Arguments: args(`{"c":"3"}`), Index: 2}},
|
||||
}
|
||||
if len(all) != len(want) {
|
||||
t.Fatalf("expected %d calls, got %d", len(want), len(all))
|
||||
}
|
||||
for i := range want {
|
||||
if !toolCallEqual(all[i], want[i]) {
|
||||
t.Fatalf("call %d mismatch: got %#v, want %#v", i, all[i], want[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestGLM47ParserToolCallIndexResetOnInit(t *testing.T) {
|
||||
parser := GLM47Parser{}
|
||||
parser.Init(nil, nil, nil)
|
||||
|
||||
_, _, _, err := parser.Add("plan</think><tool_call>first<arg_key>a</arg_key><arg_value>1</arg_value></tool_call>", true)
|
||||
if err != nil {
|
||||
t.Fatalf("first parse failed: %v", err)
|
||||
}
|
||||
|
||||
parser.Init(nil, nil, nil)
|
||||
_, _, calls, err := parser.Add("plan</think><tool_call>second<arg_key>b</arg_key><arg_value>2</arg_value></tool_call>", true)
|
||||
if err != nil {
|
||||
t.Fatalf("second parse failed: %v", err)
|
||||
}
|
||||
|
||||
want := api.ToolCall{
|
||||
Function: api.ToolCallFunction{Name: "second", Arguments: args(`{"b":"2"}`), Index: 0},
|
||||
}
|
||||
if len(calls) != 1 {
|
||||
t.Fatalf("expected 1 call, got %d", len(calls))
|
||||
}
|
||||
if !toolCallEqual(calls[0], want) {
|
||||
t.Fatalf("got %#v, want %#v", calls[0], want)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -50,7 +50,7 @@ func ParserForName(name string) Parser {
|
||||
case "qwen3-thinking":
|
||||
p = &Qwen3Parser{hasThinkingSupport: true, defaultThinking: true}
|
||||
case "qwen3.5":
|
||||
p = &Qwen3Parser{hasThinkingSupport: true, defaultThinking: true}
|
||||
p = &Qwen35Parser{}
|
||||
case "qwen3-coder":
|
||||
p = &Qwen3CoderParser{}
|
||||
case "qwen3-vl-instruct":
|
||||
|
||||
@@ -38,6 +38,7 @@ type Qwen3Parser struct {
|
||||
state qwen3ParserState
|
||||
buffer strings.Builder
|
||||
tools []api.Tool
|
||||
callIndex int
|
||||
hasThinkingSupport bool
|
||||
defaultThinking bool
|
||||
maybeThinkingOpenAtBOL bool
|
||||
@@ -54,6 +55,7 @@ func (p *Qwen3Parser) HasThinkingSupport() bool {
|
||||
func (p *Qwen3Parser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
|
||||
p.tools = tools
|
||||
p.buffer.Reset()
|
||||
p.callIndex = 0
|
||||
|
||||
thinkingEnabled := thinkValue != nil && thinkValue.Bool()
|
||||
if thinkValue == nil {
|
||||
@@ -106,6 +108,8 @@ func (p *Qwen3Parser) Add(s string, done bool) (content string, thinking string,
|
||||
slog.Warn("qwen3 tool call parsing failed", "error", err)
|
||||
return "", "", nil, err
|
||||
}
|
||||
toolCall.Function.Index = p.callIndex
|
||||
p.callIndex++
|
||||
calls = append(calls, toolCall)
|
||||
case qwen3EventThinkingContent:
|
||||
thinkingSb.WriteString(event.content)
|
||||
|
||||
238
model/parsers/qwen35.go
Normal file
238
model/parsers/qwen35.go
Normal file
@@ -0,0 +1,238 @@
|
||||
package parsers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"strings"
|
||||
"unicode"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/logutil"
|
||||
)
|
||||
|
||||
type qwen35ParserState int
|
||||
|
||||
const (
|
||||
qwen35ParserStateCollectingThinking qwen35ParserState = iota
|
||||
qwen35ParserStateThinkingDoneEatingWhitespace
|
||||
qwen35ParserStateCollectingContent
|
||||
)
|
||||
|
||||
const (
|
||||
qwen35ThinkingOpenTag = "<think>"
|
||||
qwen35ThinkingCloseTag = "</think>"
|
||||
)
|
||||
|
||||
// Qwen35Parser handles qwen3.5 reasoning extraction and delegates post-thinking
|
||||
// content (including XML tool calls) to Qwen3CoderParser.
|
||||
type Qwen35Parser struct {
|
||||
toolParser Qwen3CoderParser
|
||||
|
||||
state qwen35ParserState
|
||||
buffer strings.Builder
|
||||
// Some checkpoints may emit an explicit leading <think> even when the
|
||||
// prompt already opened thinking. Strip at most one such tag.
|
||||
allowLeadingThinkOpenTag bool
|
||||
}
|
||||
|
||||
func (p *Qwen35Parser) HasToolSupport() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (p *Qwen35Parser) HasThinkingSupport() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (p *Qwen35Parser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
|
||||
p.buffer.Reset()
|
||||
p.toolParser = Qwen3CoderParser{}
|
||||
p.toolParser.Init(tools, nil, nil)
|
||||
|
||||
thinkingEnabled := thinkValue != nil && thinkValue.Bool()
|
||||
if thinkValue == nil {
|
||||
thinkingEnabled = true
|
||||
}
|
||||
|
||||
assistantPrefill := lastMessage != nil && lastMessage.Role == "assistant" && lastMessage.Content != ""
|
||||
if thinkingEnabled && !assistantPrefill {
|
||||
p.state = qwen35ParserStateCollectingThinking
|
||||
p.allowLeadingThinkOpenTag = true
|
||||
} else {
|
||||
p.state = qwen35ParserStateCollectingContent
|
||||
p.allowLeadingThinkOpenTag = false
|
||||
}
|
||||
|
||||
return tools
|
||||
}
|
||||
|
||||
type qwen35Event interface {
|
||||
isQwen35Event()
|
||||
}
|
||||
|
||||
type qwen35EventContent struct {
|
||||
content string
|
||||
}
|
||||
|
||||
func (qwen35EventContent) isQwen35Event() {}
|
||||
|
||||
type qwen35EventThinkingContent struct {
|
||||
content string
|
||||
}
|
||||
|
||||
func (qwen35EventThinkingContent) isQwen35Event() {}
|
||||
|
||||
func (p *Qwen35Parser) Add(s string, done bool) (content string, thinking string, calls []api.ToolCall, err error) {
|
||||
p.buffer.WriteString(s)
|
||||
events := p.parseEvents()
|
||||
|
||||
var contentSb strings.Builder
|
||||
var thinkingSb strings.Builder
|
||||
for _, event := range events {
|
||||
switch event := event.(type) {
|
||||
case qwen35EventContent:
|
||||
parsedContent, _, parsedCalls, err := p.toolParser.Add(event.content, done)
|
||||
if err != nil {
|
||||
slog.Warn("qwen3.5 tool call parsing failed", "error", err)
|
||||
return "", "", nil, err
|
||||
}
|
||||
contentSb.WriteString(parsedContent)
|
||||
calls = append(calls, parsedCalls...)
|
||||
case qwen35EventThinkingContent:
|
||||
thinkingSb.WriteString(event.content)
|
||||
}
|
||||
}
|
||||
|
||||
return contentSb.String(), thinkingSb.String(), calls, nil
|
||||
}
|
||||
|
||||
func (p *Qwen35Parser) parseEvents() []qwen35Event {
|
||||
var all []qwen35Event
|
||||
|
||||
keepLooping := true
|
||||
for keepLooping {
|
||||
var events []qwen35Event
|
||||
events, keepLooping = p.eat()
|
||||
if len(events) > 0 {
|
||||
all = append(all, events...)
|
||||
}
|
||||
}
|
||||
|
||||
if len(all) > 0 {
|
||||
slog.Log(context.TODO(), logutil.LevelTrace, "qwen3.5 events parsed", "events", all, "state", p.state, "buffer", p.buffer.String())
|
||||
}
|
||||
|
||||
return all
|
||||
}
|
||||
|
||||
func (p *Qwen35Parser) splitAtTag(tag string, trimAfter bool) (string, string) {
|
||||
return splitAtTag(&p.buffer, tag, trimAfter)
|
||||
}
|
||||
|
||||
func (p *Qwen35Parser) eatLeadingWhitespaceAndTransitionTo(nextState qwen35ParserState) ([]qwen35Event, bool) {
|
||||
trimmed := strings.TrimLeftFunc(p.buffer.String(), unicode.IsSpace)
|
||||
p.buffer.Reset()
|
||||
if trimmed == "" {
|
||||
return nil, false
|
||||
}
|
||||
p.state = nextState
|
||||
p.buffer.WriteString(trimmed)
|
||||
return nil, true
|
||||
}
|
||||
|
||||
// maybeConsumeLeadingThinkOpenTag handles a single optional leading <think> tag.
|
||||
// Returns (handled, shouldContinueParsingNow).
|
||||
func (p *Qwen35Parser) maybeConsumeLeadingThinkOpenTag(acc string) (bool, bool) {
|
||||
if !p.allowLeadingThinkOpenTag {
|
||||
return false, false
|
||||
}
|
||||
|
||||
trimmed := strings.TrimLeftFunc(acc, unicode.IsSpace)
|
||||
if strings.HasPrefix(trimmed, qwen35ThinkingOpenTag) {
|
||||
after := strings.TrimPrefix(trimmed, qwen35ThinkingOpenTag)
|
||||
after = strings.TrimLeftFunc(after, unicode.IsSpace)
|
||||
p.buffer.Reset()
|
||||
p.buffer.WriteString(after)
|
||||
if after == "" {
|
||||
return true, false
|
||||
}
|
||||
p.allowLeadingThinkOpenTag = false
|
||||
return true, true
|
||||
}
|
||||
|
||||
if strings.HasPrefix(qwen35ThinkingOpenTag, trimmed) {
|
||||
return true, false
|
||||
}
|
||||
|
||||
p.allowLeadingThinkOpenTag = false
|
||||
return false, false
|
||||
}
|
||||
|
||||
func (p *Qwen35Parser) eat() ([]qwen35Event, bool) {
|
||||
var events []qwen35Event
|
||||
|
||||
switch p.state {
|
||||
case qwen35ParserStateCollectingThinking:
|
||||
acc := p.buffer.String()
|
||||
|
||||
if handled, continueNow := p.maybeConsumeLeadingThinkOpenTag(acc); handled {
|
||||
return events, continueNow
|
||||
}
|
||||
|
||||
if strings.Contains(acc, qwen35ThinkingCloseTag) {
|
||||
thinking, remaining := p.splitAtTag(qwen35ThinkingCloseTag, true)
|
||||
if len(thinking) > 0 {
|
||||
events = append(events, qwen35EventThinkingContent{content: thinking})
|
||||
}
|
||||
if remaining == "" {
|
||||
p.state = qwen35ParserStateThinkingDoneEatingWhitespace
|
||||
} else {
|
||||
p.state = qwen35ParserStateCollectingContent
|
||||
}
|
||||
return events, true
|
||||
} else if overlapLen := overlap(acc, qwen35ThinkingCloseTag); overlapLen > 0 {
|
||||
beforePartialTag := acc[:len(acc)-overlapLen]
|
||||
trailingWsLen := trailingWhitespaceLen(beforePartialTag)
|
||||
ambiguousStart := len(beforePartialTag) - trailingWsLen
|
||||
|
||||
unambiguous := acc[:ambiguousStart]
|
||||
ambiguous := acc[ambiguousStart:]
|
||||
p.buffer.Reset()
|
||||
p.buffer.WriteString(ambiguous)
|
||||
if len(unambiguous) > 0 {
|
||||
events = append(events, qwen35EventThinkingContent{content: unambiguous})
|
||||
}
|
||||
return events, false
|
||||
}
|
||||
|
||||
whitespaceLen := trailingWhitespaceLen(acc)
|
||||
ambiguousStart := len(acc) - whitespaceLen
|
||||
unambiguous := acc[:ambiguousStart]
|
||||
ambiguous := acc[ambiguousStart:]
|
||||
p.buffer.Reset()
|
||||
p.buffer.WriteString(ambiguous)
|
||||
if len(unambiguous) > 0 {
|
||||
events = append(events, qwen35EventThinkingContent{content: unambiguous})
|
||||
}
|
||||
return events, false
|
||||
|
||||
case qwen35ParserStateThinkingDoneEatingWhitespace:
|
||||
return p.eatLeadingWhitespaceAndTransitionTo(qwen35ParserStateCollectingContent)
|
||||
|
||||
case qwen35ParserStateCollectingContent:
|
||||
if p.buffer.Len() == 0 {
|
||||
return events, false
|
||||
}
|
||||
|
||||
content := p.buffer.String()
|
||||
p.buffer.Reset()
|
||||
if len(content) > 0 {
|
||||
events = append(events, qwen35EventContent{content: content})
|
||||
}
|
||||
return events, false
|
||||
|
||||
default:
|
||||
slog.Warn("qwen3.5 parser entered unknown state; resetting to content mode", "state", p.state)
|
||||
p.state = qwen35ParserStateCollectingContent
|
||||
return events, false
|
||||
}
|
||||
}
|
||||
382
model/parsers/qwen35_test.go
Normal file
382
model/parsers/qwen35_test.go
Normal file
@@ -0,0 +1,382 @@
|
||||
package parsers
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
func TestQwen35ParserXMLToolCall(t *testing.T) {
|
||||
parser := ParserForName("qwen3.5")
|
||||
if parser == nil {
|
||||
t.Fatal("expected qwen3.5 parser")
|
||||
}
|
||||
|
||||
tools := []api.Tool{
|
||||
{
|
||||
Function: api.ToolFunction{
|
||||
Name: "get_weather",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Properties: func() *api.ToolPropertiesMap {
|
||||
props := api.NewToolPropertiesMap()
|
||||
props.Set("location", api.ToolProperty{Type: api.PropertyType{"string"}})
|
||||
props.Set("days", api.ToolProperty{Type: api.PropertyType{"integer"}})
|
||||
return props
|
||||
}(),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
parser.Init(tools, nil, &api.ThinkValue{Value: false})
|
||||
input := "<tool_call><function=get_weather><parameter=location>\nSan Francisco\n</parameter><parameter=days>\n3\n</parameter></function></tool_call>"
|
||||
content, thinking, calls, err := parser.Add(input, true)
|
||||
if err != nil {
|
||||
t.Fatalf("parse failed: %v", err)
|
||||
}
|
||||
|
||||
if content != "" {
|
||||
t.Fatalf("expected empty content, got %q", content)
|
||||
}
|
||||
if thinking != "" {
|
||||
t.Fatalf("expected empty thinking, got %q", thinking)
|
||||
}
|
||||
if len(calls) != 1 {
|
||||
t.Fatalf("expected 1 tool call, got %d", len(calls))
|
||||
}
|
||||
|
||||
if calls[0].Function.Name != "get_weather" {
|
||||
t.Fatalf("expected tool name %q, got %q", "get_weather", calls[0].Function.Name)
|
||||
}
|
||||
|
||||
location, ok := calls[0].Function.Arguments.Get("location")
|
||||
if !ok || location != "San Francisco" {
|
||||
t.Fatalf("expected location %q, got %v", "San Francisco", location)
|
||||
}
|
||||
|
||||
days, ok := calls[0].Function.Arguments.Get("days")
|
||||
if !ok || days != 3 {
|
||||
t.Fatalf("expected days %d, got %v", 3, days)
|
||||
}
|
||||
}
|
||||
|
||||
func TestQwen35ParserThinkingWithExplicitOpeningTag(t *testing.T) {
|
||||
parser := ParserForName("qwen3.5")
|
||||
if parser == nil {
|
||||
t.Fatal("expected qwen3.5 parser")
|
||||
}
|
||||
|
||||
parser.Init(nil, nil, &api.ThinkValue{Value: true})
|
||||
content, thinking, calls, err := parser.Add("<think>\nLet me think...</think>Answer.", true)
|
||||
if err != nil {
|
||||
t.Fatalf("parse failed: %v", err)
|
||||
}
|
||||
|
||||
if thinking != "Let me think..." {
|
||||
t.Fatalf("expected thinking %q, got %q", "Let me think...", thinking)
|
||||
}
|
||||
if content != "Answer." {
|
||||
t.Fatalf("expected content %q, got %q", "Answer.", content)
|
||||
}
|
||||
if len(calls) != 0 {
|
||||
t.Fatalf("expected no tool calls, got %d", len(calls))
|
||||
}
|
||||
}
|
||||
|
||||
func TestQwen35ParserAssistantPrefillStartsInContent(t *testing.T) {
|
||||
parser := ParserForName("qwen3.5")
|
||||
if parser == nil {
|
||||
t.Fatal("expected qwen3.5 parser")
|
||||
}
|
||||
|
||||
last := &api.Message{Role: "assistant", Content: "Prefilled response start"}
|
||||
parser.Init(nil, last, nil)
|
||||
|
||||
content, thinking, calls, err := parser.Add(" and continued", true)
|
||||
if err != nil {
|
||||
t.Fatalf("parse failed: %v", err)
|
||||
}
|
||||
|
||||
if thinking != "" {
|
||||
t.Fatalf("expected no thinking for assistant prefill continuation, got %q", thinking)
|
||||
}
|
||||
if content != " and continued" {
|
||||
t.Fatalf("expected content %q, got %q", " and continued", content)
|
||||
}
|
||||
if len(calls) != 0 {
|
||||
t.Fatalf("expected no tool calls, got %d", len(calls))
|
||||
}
|
||||
}
|
||||
|
||||
func TestQwen35ParserToolCallEmittedInThinkingIsNotParsed(t *testing.T) {
|
||||
parser := ParserForName("qwen3.5")
|
||||
if parser == nil {
|
||||
t.Fatal("expected qwen3.5 parser")
|
||||
}
|
||||
|
||||
tools := []api.Tool{
|
||||
{
|
||||
Function: api.ToolFunction{
|
||||
Name: "get_weather",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Properties: func() *api.ToolPropertiesMap {
|
||||
props := api.NewToolPropertiesMap()
|
||||
props.Set("location", api.ToolProperty{Type: api.PropertyType{"string"}})
|
||||
return props
|
||||
}(),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
parser.Init(tools, nil, &api.ThinkValue{Value: true})
|
||||
input := `Need weather lookup<tool_call><function=get_weather><parameter=location>
|
||||
SF
|
||||
</parameter></function></tool_call>`
|
||||
content, thinking, calls, err := parser.Add(input, true)
|
||||
if err != nil {
|
||||
t.Fatalf("parse failed: %v", err)
|
||||
}
|
||||
|
||||
if content != "" {
|
||||
t.Fatalf("expected empty content, got %q", content)
|
||||
}
|
||||
expectedThinking := `Need weather lookup<tool_call><function=get_weather><parameter=location>
|
||||
SF
|
||||
</parameter></function></tool_call>`
|
||||
if thinking != expectedThinking {
|
||||
t.Fatalf("expected thinking %q, got %q", expectedThinking, thinking)
|
||||
}
|
||||
if len(calls) != 0 {
|
||||
t.Fatalf("expected no tool calls before </think>, got %d", len(calls))
|
||||
}
|
||||
}
|
||||
|
||||
func TestQwen35ParserToolCallAfterThinkingCloseIsParsed(t *testing.T) {
|
||||
parser := ParserForName("qwen3.5")
|
||||
if parser == nil {
|
||||
t.Fatal("expected qwen3.5 parser")
|
||||
}
|
||||
|
||||
tools := []api.Tool{
|
||||
{
|
||||
Function: api.ToolFunction{
|
||||
Name: "get_weather",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Properties: func() *api.ToolPropertiesMap {
|
||||
props := api.NewToolPropertiesMap()
|
||||
props.Set("location", api.ToolProperty{Type: api.PropertyType{"string"}})
|
||||
return props
|
||||
}(),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
parser.Init(tools, nil, &api.ThinkValue{Value: true})
|
||||
input := `Need weather lookup</think><tool_call><function=get_weather><parameter=location>
|
||||
SF
|
||||
</parameter></function></tool_call>`
|
||||
content, thinking, calls, err := parser.Add(input, true)
|
||||
if err != nil {
|
||||
t.Fatalf("parse failed: %v", err)
|
||||
}
|
||||
|
||||
if content != "" {
|
||||
t.Fatalf("expected empty content, got %q", content)
|
||||
}
|
||||
if thinking != "Need weather lookup" {
|
||||
t.Fatalf("expected thinking %q, got %q", "Need weather lookup", thinking)
|
||||
}
|
||||
if len(calls) != 1 {
|
||||
t.Fatalf("expected 1 tool call after </think>, got %d", len(calls))
|
||||
}
|
||||
if calls[0].Function.Name != "get_weather" {
|
||||
t.Fatalf("expected tool name %q, got %q", "get_weather", calls[0].Function.Name)
|
||||
}
|
||||
|
||||
location, ok := calls[0].Function.Arguments.Get("location")
|
||||
if !ok || location != "SF" {
|
||||
t.Fatalf("expected location %q, got %v", "SF", location)
|
||||
}
|
||||
}
|
||||
|
||||
func TestQwen35ParserThinkingDisabledPassesContentThrough(t *testing.T) {
|
||||
parser := ParserForName("qwen3.5")
|
||||
if parser == nil {
|
||||
t.Fatal("expected qwen3.5 parser")
|
||||
}
|
||||
|
||||
parser.Init(nil, nil, &api.ThinkValue{Value: false})
|
||||
content, thinking, calls, err := parser.Add("Plain answer without think close tag.", true)
|
||||
if err != nil {
|
||||
t.Fatalf("parse failed: %v", err)
|
||||
}
|
||||
|
||||
if thinking != "" {
|
||||
t.Fatalf("expected empty thinking, got %q", thinking)
|
||||
}
|
||||
if content != "Plain answer without think close tag." {
|
||||
t.Fatalf("expected content %q, got %q", "Plain answer without think close tag.", content)
|
||||
}
|
||||
if len(calls) != 0 {
|
||||
t.Fatalf("expected no tool calls, got %d", len(calls))
|
||||
}
|
||||
}
|
||||
|
||||
func TestQwen35ParserThinkingDisabledWithCloseTagTreatsAsContent(t *testing.T) {
|
||||
parser := ParserForName("qwen3.5")
|
||||
if parser == nil {
|
||||
t.Fatal("expected qwen3.5 parser")
|
||||
}
|
||||
|
||||
parser.Init(nil, nil, &api.ThinkValue{Value: false})
|
||||
content, thinking, calls, err := parser.Add("</think>Some content after spurious tag.", true)
|
||||
if err != nil {
|
||||
t.Fatalf("parse failed: %v", err)
|
||||
}
|
||||
|
||||
if thinking != "" {
|
||||
t.Fatalf("expected empty thinking, got %q", thinking)
|
||||
}
|
||||
if content != "</think>Some content after spurious tag." {
|
||||
t.Fatalf("expected content %q, got %q", "</think>Some content after spurious tag.", content)
|
||||
}
|
||||
if len(calls) != 0 {
|
||||
t.Fatalf("expected no tool calls, got %d", len(calls))
|
||||
}
|
||||
}
|
||||
|
||||
func TestQwen35ParserLeadingThinkCloseProducesContent(t *testing.T) {
|
||||
parser := ParserForName("qwen3.5")
|
||||
if parser == nil {
|
||||
t.Fatal("expected qwen3.5 parser")
|
||||
}
|
||||
|
||||
parser.Init(nil, nil, &api.ThinkValue{Value: true})
|
||||
content, thinking, calls, err := parser.Add("</think>The final answer.", true)
|
||||
if err != nil {
|
||||
t.Fatalf("parse failed: %v", err)
|
||||
}
|
||||
|
||||
if thinking != "" {
|
||||
t.Fatalf("expected empty thinking, got %q", thinking)
|
||||
}
|
||||
if content != "The final answer." {
|
||||
t.Fatalf("expected content %q, got %q", "The final answer.", content)
|
||||
}
|
||||
if len(calls) != 0 {
|
||||
t.Fatalf("expected no tool calls, got %d", len(calls))
|
||||
}
|
||||
}
|
||||
|
||||
func TestQwen35ParserStreamingSplitThinkCloseTag(t *testing.T) {
|
||||
parser := ParserForName("qwen3.5")
|
||||
if parser == nil {
|
||||
t.Fatal("expected qwen3.5 parser")
|
||||
}
|
||||
|
||||
parser.Init(nil, nil, &api.ThinkValue{Value: true})
|
||||
|
||||
content, thinking, calls, err := parser.Add("Reasoning text</thi", false)
|
||||
if err != nil {
|
||||
t.Fatalf("parse failed on first chunk: %v", err)
|
||||
}
|
||||
if thinking != "Reasoning text" {
|
||||
t.Fatalf("expected thinking %q, got %q", "Reasoning text", thinking)
|
||||
}
|
||||
if content != "" {
|
||||
t.Fatalf("expected empty content, got %q", content)
|
||||
}
|
||||
if len(calls) != 0 {
|
||||
t.Fatalf("expected no tool calls, got %d", len(calls))
|
||||
}
|
||||
|
||||
content, thinking, calls, err = parser.Add("nk>The final answer.", true)
|
||||
if err != nil {
|
||||
t.Fatalf("parse failed on second chunk: %v", err)
|
||||
}
|
||||
if thinking != "" {
|
||||
t.Fatalf("expected no additional thinking on second chunk, got %q", thinking)
|
||||
}
|
||||
if content != "The final answer." {
|
||||
t.Fatalf("expected content %q, got %q", "The final answer.", content)
|
||||
}
|
||||
if len(calls) != 0 {
|
||||
t.Fatalf("expected no tool calls, got %d", len(calls))
|
||||
}
|
||||
}
|
||||
|
||||
func TestQwen35ParserStreamingEatsWhitespaceAfterThinkClose(t *testing.T) {
|
||||
parser := ParserForName("qwen3.5")
|
||||
if parser == nil {
|
||||
t.Fatal("expected qwen3.5 parser")
|
||||
}
|
||||
|
||||
parser.Init(nil, nil, &api.ThinkValue{Value: true})
|
||||
|
||||
content, thinking, calls, err := parser.Add("Reasoning</think>", false)
|
||||
if err != nil {
|
||||
t.Fatalf("parse failed on first chunk: %v", err)
|
||||
}
|
||||
if thinking != "Reasoning" {
|
||||
t.Fatalf("expected thinking %q, got %q", "Reasoning", thinking)
|
||||
}
|
||||
if content != "" {
|
||||
t.Fatalf("expected empty content, got %q", content)
|
||||
}
|
||||
if len(calls) != 0 {
|
||||
t.Fatalf("expected no tool calls, got %d", len(calls))
|
||||
}
|
||||
|
||||
content, thinking, calls, err = parser.Add("\n \t", false)
|
||||
if err != nil {
|
||||
t.Fatalf("parse failed on whitespace chunk: %v", err)
|
||||
}
|
||||
if thinking != "" {
|
||||
t.Fatalf("expected no thinking on whitespace chunk, got %q", thinking)
|
||||
}
|
||||
if content != "" {
|
||||
t.Fatalf("expected whitespace after </think> to be eaten, got content %q", content)
|
||||
}
|
||||
if len(calls) != 0 {
|
||||
t.Fatalf("expected no tool calls, got %d", len(calls))
|
||||
}
|
||||
|
||||
content, thinking, calls, err = parser.Add("The final answer.", true)
|
||||
if err != nil {
|
||||
t.Fatalf("parse failed on content chunk: %v", err)
|
||||
}
|
||||
if thinking != "" {
|
||||
t.Fatalf("expected no additional thinking, got %q", thinking)
|
||||
}
|
||||
if content != "The final answer." {
|
||||
t.Fatalf("expected content %q, got %q", "The final answer.", content)
|
||||
}
|
||||
if len(calls) != 0 {
|
||||
t.Fatalf("expected no tool calls, got %d", len(calls))
|
||||
}
|
||||
}
|
||||
|
||||
func TestQwen35ParserThinkingTruncatedWithoutCloseTag(t *testing.T) {
|
||||
parser := ParserForName("qwen3.5")
|
||||
if parser == nil {
|
||||
t.Fatal("expected qwen3.5 parser")
|
||||
}
|
||||
|
||||
parser.Init(nil, nil, &api.ThinkValue{Value: true})
|
||||
content, thinking, calls, err := parser.Add("Reasoning that never closes", true)
|
||||
if err != nil {
|
||||
t.Fatalf("parse failed: %v", err)
|
||||
}
|
||||
|
||||
if thinking != "Reasoning that never closes" {
|
||||
t.Fatalf("expected thinking %q, got %q", "Reasoning that never closes", thinking)
|
||||
}
|
||||
if content != "" {
|
||||
t.Fatalf("expected empty content, got %q", content)
|
||||
}
|
||||
if len(calls) != 0 {
|
||||
t.Fatalf("expected no tool calls, got %d", len(calls))
|
||||
}
|
||||
}
|
||||
@@ -230,3 +230,89 @@ func TestQwen35ParserRespectsNoThink(t *testing.T) {
|
||||
t.Fatalf("expected no tool calls, got %d", len(calls))
|
||||
}
|
||||
}
|
||||
|
||||
func TestQwen3ParserToolCallIndexing(t *testing.T) {
|
||||
parser := &Qwen3Parser{hasThinkingSupport: false, defaultThinking: false}
|
||||
parser.Init(nil, nil, &api.ThinkValue{Value: false})
|
||||
|
||||
input := `<tool_call>{"name":"first","arguments":{"a":"1"}}</tool_call>
|
||||
<tool_call>{"name":"second","arguments":{"b":"2"}}</tool_call>
|
||||
<tool_call>{"name":"third","arguments":{"c":"3"}}</tool_call>`
|
||||
_, _, calls, err := parser.Add(input, true)
|
||||
if err != nil {
|
||||
t.Fatalf("parse failed: %v", err)
|
||||
}
|
||||
|
||||
want := []api.ToolCall{
|
||||
{Function: api.ToolCallFunction{Name: "first", Arguments: args(`{"a":"1"}`), Index: 0}},
|
||||
{Function: api.ToolCallFunction{Name: "second", Arguments: args(`{"b":"2"}`), Index: 1}},
|
||||
{Function: api.ToolCallFunction{Name: "third", Arguments: args(`{"c":"3"}`), Index: 2}},
|
||||
}
|
||||
if len(calls) != len(want) {
|
||||
t.Fatalf("expected %d calls, got %d", len(want), len(calls))
|
||||
}
|
||||
for i := range want {
|
||||
if !toolCallEqual(calls[i], want[i]) {
|
||||
t.Fatalf("call %d mismatch: got %#v, want %#v", i, calls[i], want[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestQwen3ParserToolCallIndexingStreaming(t *testing.T) {
|
||||
parser := &Qwen3Parser{hasThinkingSupport: false, defaultThinking: false}
|
||||
parser.Init(nil, nil, &api.ThinkValue{Value: false})
|
||||
|
||||
var all []api.ToolCall
|
||||
|
||||
_, _, calls, err := parser.Add(`<tool_call>{"name":"first","arguments":{"a":"1"}}</tool_call><tool_call>{"name":"second","arguments":{"b":"2"}`, false)
|
||||
if err != nil {
|
||||
t.Fatalf("step 1 parse failed: %v", err)
|
||||
}
|
||||
all = append(all, calls...)
|
||||
|
||||
_, _, calls, err = parser.Add(`}</tool_call><tool_call>{"name":"third","arguments":{"c":"3"}}</tool_call>`, true)
|
||||
if err != nil {
|
||||
t.Fatalf("step 2 parse failed: %v", err)
|
||||
}
|
||||
all = append(all, calls...)
|
||||
|
||||
want := []api.ToolCall{
|
||||
{Function: api.ToolCallFunction{Name: "first", Arguments: args(`{"a":"1"}`), Index: 0}},
|
||||
{Function: api.ToolCallFunction{Name: "second", Arguments: args(`{"b":"2"}`), Index: 1}},
|
||||
{Function: api.ToolCallFunction{Name: "third", Arguments: args(`{"c":"3"}`), Index: 2}},
|
||||
}
|
||||
if len(all) != len(want) {
|
||||
t.Fatalf("expected %d calls, got %d", len(want), len(all))
|
||||
}
|
||||
for i := range want {
|
||||
if !toolCallEqual(all[i], want[i]) {
|
||||
t.Fatalf("call %d mismatch: got %#v, want %#v", i, all[i], want[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestQwen3ParserToolCallIndexResetOnInit(t *testing.T) {
|
||||
parser := &Qwen3Parser{hasThinkingSupport: false, defaultThinking: false}
|
||||
parser.Init(nil, nil, &api.ThinkValue{Value: false})
|
||||
|
||||
_, _, _, err := parser.Add(`<tool_call>{"name":"first","arguments":{"a":"1"}}</tool_call>`, true)
|
||||
if err != nil {
|
||||
t.Fatalf("first parse failed: %v", err)
|
||||
}
|
||||
|
||||
parser.Init(nil, nil, &api.ThinkValue{Value: false})
|
||||
_, _, calls, err := parser.Add(`<tool_call>{"name":"second","arguments":{"b":"2"}}</tool_call>`, true)
|
||||
if err != nil {
|
||||
t.Fatalf("second parse failed: %v", err)
|
||||
}
|
||||
|
||||
want := api.ToolCall{
|
||||
Function: api.ToolCallFunction{Name: "second", Arguments: args(`{"b":"2"}`), Index: 0},
|
||||
}
|
||||
if len(calls) != 1 {
|
||||
t.Fatalf("expected 1 call, got %d", len(calls))
|
||||
}
|
||||
if !toolCallEqual(calls[0], want) {
|
||||
t.Fatalf("got %#v, want %#v", calls[0], want)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -29,9 +29,10 @@ const (
|
||||
)
|
||||
|
||||
type Qwen3CoderParser struct {
|
||||
state qwenParserState
|
||||
acc strings.Builder
|
||||
tools []api.Tool
|
||||
state qwenParserState
|
||||
acc strings.Builder
|
||||
tools []api.Tool
|
||||
callIndex int
|
||||
}
|
||||
|
||||
func (p *Qwen3CoderParser) HasToolSupport() bool {
|
||||
@@ -44,6 +45,7 @@ func (p *Qwen3CoderParser) HasThinkingSupport() bool {
|
||||
|
||||
func (p *Qwen3CoderParser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
|
||||
p.tools = tools
|
||||
p.callIndex = 0
|
||||
return tools // Qwen doesn't modify tools
|
||||
}
|
||||
|
||||
@@ -62,6 +64,8 @@ func (p *Qwen3CoderParser) Add(s string, done bool) (content string, thinking st
|
||||
slog.Warn("qwen tool call parsing failed", "error", err)
|
||||
return "", "", nil, err
|
||||
}
|
||||
toolCall.Function.Index = p.callIndex
|
||||
p.callIndex++
|
||||
toolCalls = append(toolCalls, toolCall)
|
||||
case qwenEventContent:
|
||||
// TODO(drifkin): if the same turn contains multiple interleaved content
|
||||
|
||||
@@ -1035,6 +1035,92 @@ func TestQwenToolCallValueParsing(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestQwen3CoderParserToolCallIndexing(t *testing.T) {
|
||||
parser := Qwen3CoderParser{}
|
||||
parser.Init(nil, nil, nil)
|
||||
|
||||
input := `<tool_call><function=first><parameter=a>1</parameter></function></tool_call>
|
||||
<tool_call><function=second><parameter=b>2</parameter></function></tool_call>
|
||||
<tool_call><function=third><parameter=c>3</parameter></function></tool_call>`
|
||||
_, _, calls, err := parser.Add(input, true)
|
||||
if err != nil {
|
||||
t.Fatalf("parse failed: %v", err)
|
||||
}
|
||||
|
||||
want := []api.ToolCall{
|
||||
{Function: api.ToolCallFunction{Name: "first", Arguments: testArgs(map[string]any{"a": "1"}), Index: 0}},
|
||||
{Function: api.ToolCallFunction{Name: "second", Arguments: testArgs(map[string]any{"b": "2"}), Index: 1}},
|
||||
{Function: api.ToolCallFunction{Name: "third", Arguments: testArgs(map[string]any{"c": "3"}), Index: 2}},
|
||||
}
|
||||
if len(calls) != len(want) {
|
||||
t.Fatalf("expected %d calls, got %d", len(want), len(calls))
|
||||
}
|
||||
for i := range want {
|
||||
if !toolCallEqual(calls[i], want[i]) {
|
||||
t.Fatalf("call %d mismatch: got %#v, want %#v", i, calls[i], want[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestQwen3CoderParserToolCallIndexingStreaming(t *testing.T) {
|
||||
parser := Qwen3CoderParser{}
|
||||
parser.Init(nil, nil, nil)
|
||||
|
||||
var all []api.ToolCall
|
||||
|
||||
_, _, calls, err := parser.Add("<tool_call><function=first><parameter=a>1</parameter></function></tool_call><tool_call><function=second>", false)
|
||||
if err != nil {
|
||||
t.Fatalf("step 1 parse failed: %v", err)
|
||||
}
|
||||
all = append(all, calls...)
|
||||
|
||||
_, _, calls, err = parser.Add("<parameter=b>2</parameter></function></tool_call><tool_call><function=third><parameter=c>3</parameter></function></tool_call>", true)
|
||||
if err != nil {
|
||||
t.Fatalf("step 2 parse failed: %v", err)
|
||||
}
|
||||
all = append(all, calls...)
|
||||
|
||||
want := []api.ToolCall{
|
||||
{Function: api.ToolCallFunction{Name: "first", Arguments: testArgs(map[string]any{"a": "1"}), Index: 0}},
|
||||
{Function: api.ToolCallFunction{Name: "second", Arguments: testArgs(map[string]any{"b": "2"}), Index: 1}},
|
||||
{Function: api.ToolCallFunction{Name: "third", Arguments: testArgs(map[string]any{"c": "3"}), Index: 2}},
|
||||
}
|
||||
if len(all) != len(want) {
|
||||
t.Fatalf("expected %d calls, got %d", len(want), len(all))
|
||||
}
|
||||
for i := range want {
|
||||
if !toolCallEqual(all[i], want[i]) {
|
||||
t.Fatalf("call %d mismatch: got %#v, want %#v", i, all[i], want[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestQwen3CoderParserToolCallIndexResetOnInit(t *testing.T) {
|
||||
parser := Qwen3CoderParser{}
|
||||
parser.Init(nil, nil, nil)
|
||||
|
||||
_, _, _, err := parser.Add("<tool_call><function=first><parameter=a>1</parameter></function></tool_call>", true)
|
||||
if err != nil {
|
||||
t.Fatalf("first parse failed: %v", err)
|
||||
}
|
||||
|
||||
parser.Init(nil, nil, nil)
|
||||
_, _, calls, err := parser.Add("<tool_call><function=second><parameter=b>2</parameter></function></tool_call>", true)
|
||||
if err != nil {
|
||||
t.Fatalf("second parse failed: %v", err)
|
||||
}
|
||||
|
||||
want := api.ToolCall{
|
||||
Function: api.ToolCallFunction{Name: "second", Arguments: testArgs(map[string]any{"b": "2"}), Index: 0},
|
||||
}
|
||||
if len(calls) != 1 {
|
||||
t.Fatalf("expected 1 call, got %d", len(calls))
|
||||
}
|
||||
if !toolCallEqual(calls[0], want) {
|
||||
t.Fatalf("got %#v, want %#v", calls[0], want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestQwenXMLTransform(t *testing.T) {
|
||||
cases := []struct {
|
||||
desc string
|
||||
|
||||
@@ -8,7 +8,21 @@ import (
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
type GlmOcrRenderer struct{}
|
||||
type GlmOcrRenderer struct {
|
||||
useImgTags bool
|
||||
}
|
||||
|
||||
func (r *GlmOcrRenderer) renderContent(message api.Message, imageOffset int) (string, int) {
|
||||
var sb strings.Builder
|
||||
for range message.Images {
|
||||
if r.useImgTags {
|
||||
sb.WriteString(fmt.Sprintf("[img-%d]", imageOffset))
|
||||
imageOffset++
|
||||
}
|
||||
}
|
||||
sb.WriteString(message.Content)
|
||||
return sb.String(), imageOffset
|
||||
}
|
||||
|
||||
func (r *GlmOcrRenderer) Render(messages []api.Message, tools []api.Tool, thinkValue *api.ThinkValue) (string, error) {
|
||||
var sb strings.Builder
|
||||
@@ -38,11 +52,14 @@ func (r *GlmOcrRenderer) Render(messages []api.Message, tools []api.Tool, thinkV
|
||||
thinkingExplicitlySet = true
|
||||
}
|
||||
|
||||
imageOffset := 0
|
||||
for i, message := range messages {
|
||||
switch message.Role {
|
||||
case "user":
|
||||
sb.WriteString("<|user|>\n")
|
||||
sb.WriteString(message.Content)
|
||||
content, nextOffset := r.renderContent(message, imageOffset)
|
||||
imageOffset = nextOffset
|
||||
sb.WriteString(content)
|
||||
if thinkingExplicitlySet && !enableThinking && !strings.HasSuffix(message.Content, "/nothink") {
|
||||
sb.WriteString("/nothink")
|
||||
}
|
||||
|
||||
99
model/renderers/glmocr_test.go
Normal file
99
model/renderers/glmocr_test.go
Normal file
@@ -0,0 +1,99 @@
|
||||
package renderers
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
func TestGlmOcrRenderer_Images(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
renderer *GlmOcrRenderer
|
||||
messages []api.Message
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "use_img_tags_single_image",
|
||||
renderer: &GlmOcrRenderer{useImgTags: true},
|
||||
messages: []api.Message{
|
||||
{
|
||||
Role: "user",
|
||||
Content: "Describe this image.",
|
||||
Images: []api.ImageData{api.ImageData("img1")},
|
||||
},
|
||||
},
|
||||
expected: "[gMASK]<sop><|user|>\n[img-0]Describe this image.<|assistant|>\n",
|
||||
},
|
||||
{
|
||||
name: "use_img_tags_multiple_images",
|
||||
renderer: &GlmOcrRenderer{useImgTags: true},
|
||||
messages: []api.Message{
|
||||
{
|
||||
Role: "user",
|
||||
Content: "Describe these images.",
|
||||
Images: []api.ImageData{api.ImageData("img1"), api.ImageData("img2")},
|
||||
},
|
||||
},
|
||||
expected: "[gMASK]<sop><|user|>\n[img-0][img-1]Describe these images.<|assistant|>\n",
|
||||
},
|
||||
{
|
||||
name: "multi_turn_increments_image_offset",
|
||||
renderer: &GlmOcrRenderer{useImgTags: true},
|
||||
messages: []api.Message{
|
||||
{
|
||||
Role: "user",
|
||||
Content: "First image",
|
||||
Images: []api.ImageData{api.ImageData("img1")},
|
||||
},
|
||||
{
|
||||
Role: "assistant",
|
||||
Content: "Processed.",
|
||||
},
|
||||
{
|
||||
Role: "user",
|
||||
Content: "Second image",
|
||||
Images: []api.ImageData{api.ImageData("img2")},
|
||||
},
|
||||
},
|
||||
expected: "[gMASK]<sop><|user|>\n[img-0]First image<|assistant|>\n<think></think>\nProcessed.\n<|user|>\n[img-1]Second image<|assistant|>\n",
|
||||
},
|
||||
{
|
||||
name: "default_no_img_tags",
|
||||
renderer: &GlmOcrRenderer{},
|
||||
messages: []api.Message{
|
||||
{
|
||||
Role: "user",
|
||||
Content: "No image tags expected.",
|
||||
Images: []api.ImageData{api.ImageData("img1")},
|
||||
},
|
||||
},
|
||||
expected: "[gMASK]<sop><|user|>\nNo image tags expected.<|assistant|>\n",
|
||||
},
|
||||
{
|
||||
name: "no_images_content_unchanged",
|
||||
renderer: &GlmOcrRenderer{useImgTags: true},
|
||||
messages: []api.Message{
|
||||
{
|
||||
Role: "user",
|
||||
Content: "Text only message.",
|
||||
},
|
||||
},
|
||||
expected: "[gMASK]<sop><|user|>\nText only message.<|assistant|>\n",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := tt.renderer.Render(tt.messages, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Render() error = %v", err)
|
||||
}
|
||||
if diff := cmp.Diff(tt.expected, got); diff != "" {
|
||||
t.Fatalf("Render() mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
194
model/renderers/qwen35.go
Normal file
194
model/renderers/qwen35.go
Normal file
@@ -0,0 +1,194 @@
|
||||
package renderers
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
const (
|
||||
qwen35ThinkOpenTag = "<think>"
|
||||
qwen35ThinkCloseTag = "</think>"
|
||||
qwen35ToolPostamble = `
|
||||
</tools>
|
||||
|
||||
If you choose to call a function ONLY reply in the following format with NO suffix:
|
||||
|
||||
<tool_call>
|
||||
<function=example_function_name>
|
||||
<parameter=example_parameter_1>
|
||||
value_1
|
||||
</parameter>
|
||||
<parameter=example_parameter_2>
|
||||
This is the value for the second parameter
|
||||
that can span
|
||||
multiple lines
|
||||
</parameter>
|
||||
</function>
|
||||
</tool_call>
|
||||
|
||||
<IMPORTANT>
|
||||
Reminder:
|
||||
- Function calls MUST follow the specified format: an inner <function=...></function> block must be nested within <tool_call></tool_call> XML tags
|
||||
- Required parameters MUST be specified
|
||||
- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after
|
||||
- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls
|
||||
</IMPORTANT>`
|
||||
)
|
||||
|
||||
type Qwen35Renderer struct {
|
||||
isThinking bool
|
||||
|
||||
emitEmptyThinkOnNoThink bool
|
||||
useImgTags bool
|
||||
}
|
||||
|
||||
func (r *Qwen35Renderer) renderContent(content api.Message, imageOffset int) (string, int) {
|
||||
// This assumes all images are at the front of the message - same assumption as ollama/ollama/runner.go
|
||||
var subSb strings.Builder
|
||||
for range content.Images {
|
||||
if r.useImgTags {
|
||||
subSb.WriteString(fmt.Sprintf("[img-%d]", imageOffset))
|
||||
imageOffset++
|
||||
} else {
|
||||
subSb.WriteString("<|vision_start|><|image_pad|><|vision_end|>")
|
||||
}
|
||||
}
|
||||
// TODO: support videos
|
||||
|
||||
subSb.WriteString(content.Content)
|
||||
return subSb.String(), imageOffset
|
||||
}
|
||||
|
||||
func splitQwen35ReasoningContent(content, messageThinking string, isThinking bool) (reasoning string, remaining string) {
|
||||
if isThinking && messageThinking != "" {
|
||||
return strings.TrimSpace(messageThinking), content
|
||||
}
|
||||
|
||||
if idx := strings.Index(content, qwen35ThinkCloseTag); idx != -1 {
|
||||
before := content[:idx]
|
||||
if open := strings.LastIndex(before, qwen35ThinkOpenTag); open != -1 {
|
||||
reasoning = before[open+len(qwen35ThinkOpenTag):]
|
||||
} else {
|
||||
reasoning = before
|
||||
}
|
||||
content = strings.TrimLeft(content[idx+len(qwen35ThinkCloseTag):], "\n")
|
||||
}
|
||||
|
||||
return strings.TrimSpace(reasoning), content
|
||||
}
|
||||
|
||||
func (r *Qwen35Renderer) Render(messages []api.Message, tools []api.Tool, think *api.ThinkValue) (string, error) {
|
||||
var sb strings.Builder
|
||||
|
||||
isThinking := r.isThinking
|
||||
if think != nil {
|
||||
isThinking = think.Bool()
|
||||
}
|
||||
|
||||
if len(tools) > 0 {
|
||||
sb.WriteString(imStartTag + "system\n")
|
||||
sb.WriteString("# Tools\n\nYou have access to the following functions:\n\n<tools>")
|
||||
for _, tool := range tools {
|
||||
sb.WriteString("\n")
|
||||
if b, err := marshalWithSpaces(tool); err == nil {
|
||||
sb.Write(b)
|
||||
}
|
||||
}
|
||||
sb.WriteString(qwen35ToolPostamble)
|
||||
if len(messages) > 0 && messages[0].Role == "system" {
|
||||
systemContent, _ := r.renderContent(messages[0], 0)
|
||||
systemContent = strings.TrimSpace(systemContent)
|
||||
if systemContent != "" {
|
||||
sb.WriteString("\n\n")
|
||||
sb.WriteString(systemContent)
|
||||
}
|
||||
}
|
||||
sb.WriteString(imEndTag + "\n")
|
||||
} else if len(messages) > 0 && messages[0].Role == "system" {
|
||||
systemContent, _ := r.renderContent(messages[0], 0)
|
||||
sb.WriteString(imStartTag + "system\n" + strings.TrimSpace(systemContent) + imEndTag + "\n")
|
||||
}
|
||||
|
||||
multiStepTool := true
|
||||
lastQueryIndex := len(messages) - 1 // so this is the last user message
|
||||
|
||||
for i := len(messages) - 1; i >= 0; i-- {
|
||||
message := messages[i]
|
||||
if multiStepTool && message.Role == "user" {
|
||||
content, _ := r.renderContent(message, 0)
|
||||
content = strings.TrimSpace(content)
|
||||
if !(strings.HasPrefix(content, "<tool_response>") && strings.HasSuffix(content, "</tool_response>")) {
|
||||
multiStepTool = false
|
||||
lastQueryIndex = i
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
imageOffset := 0
|
||||
for i, message := range messages {
|
||||
content, nextImageOffset := r.renderContent(message, imageOffset)
|
||||
imageOffset = nextImageOffset
|
||||
content = strings.TrimSpace(content)
|
||||
|
||||
lastMessage := i == len(messages)-1
|
||||
prefill := lastMessage && message.Role == "assistant"
|
||||
|
||||
if message.Role == "user" || (message.Role == "system" && i != 0) {
|
||||
sb.WriteString(imStartTag + message.Role + "\n" + content + imEndTag + "\n")
|
||||
} else if message.Role == "assistant" {
|
||||
contentReasoning, content := splitQwen35ReasoningContent(content, message.Thinking, isThinking)
|
||||
|
||||
if isThinking && i > lastQueryIndex {
|
||||
sb.WriteString(imStartTag + message.Role + "\n<think>\n" + contentReasoning + "\n</think>\n\n" + content)
|
||||
} else {
|
||||
sb.WriteString(imStartTag + message.Role + "\n" + content)
|
||||
}
|
||||
|
||||
if len(message.ToolCalls) > 0 {
|
||||
for j, toolCall := range message.ToolCalls {
|
||||
if j == 0 {
|
||||
if strings.TrimSpace(content) != "" {
|
||||
sb.WriteString("\n\n")
|
||||
}
|
||||
} else {
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
|
||||
sb.WriteString("<tool_call>\n<function=" + toolCall.Function.Name + ">\n")
|
||||
for name, value := range toolCall.Function.Arguments.All() {
|
||||
sb.WriteString("<parameter=" + name + ">\n")
|
||||
sb.WriteString(formatToolCallArgument(value))
|
||||
sb.WriteString("\n</parameter>\n")
|
||||
}
|
||||
sb.WriteString("</function>\n</tool_call>")
|
||||
}
|
||||
}
|
||||
|
||||
if !prefill {
|
||||
sb.WriteString(imEndTag + "\n")
|
||||
}
|
||||
} else if message.Role == "tool" {
|
||||
if i == 0 || messages[i-1].Role != "tool" {
|
||||
sb.WriteString(imStartTag + "user")
|
||||
}
|
||||
sb.WriteString("\n<tool_response>\n" + content + "\n</tool_response>")
|
||||
if i == len(messages)-1 || messages[i+1].Role != "tool" {
|
||||
sb.WriteString(imEndTag + "\n")
|
||||
}
|
||||
}
|
||||
|
||||
// prefill at the end
|
||||
if lastMessage && !prefill {
|
||||
sb.WriteString(imStartTag + "assistant\n")
|
||||
if isThinking {
|
||||
sb.WriteString("<think>\n")
|
||||
} else if r.emitEmptyThinkOnNoThink {
|
||||
sb.WriteString("<think>\n\n</think>\n\n")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return sb.String(), nil
|
||||
}
|
||||
389
model/renderers/qwen35_test.go
Normal file
389
model/renderers/qwen35_test.go
Normal file
@@ -0,0 +1,389 @@
|
||||
package renderers
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
func TestQwen35RendererUsesXMLToolCallingFormat(t *testing.T) {
|
||||
renderer := &Qwen35Renderer{isThinking: true}
|
||||
msgs := []api.Message{
|
||||
{Role: "system", Content: "You are a helpful assistant."},
|
||||
{Role: "user", Content: "What's the weather in Paris?"},
|
||||
{
|
||||
Role: "assistant",
|
||||
Content: "I'll check.",
|
||||
ToolCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: testArgsOrdered([]orderedArg{
|
||||
{Key: "location", Value: "Paris"},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{Role: "tool", Content: "22C"},
|
||||
{Role: "user", Content: "Thanks"},
|
||||
}
|
||||
tools := []api.Tool{
|
||||
{
|
||||
Type: "function",
|
||||
Function: api.ToolFunction{
|
||||
Name: "get_weather",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: testPropsOrdered([]orderedProp{
|
||||
{
|
||||
Key: "location",
|
||||
Value: api.ToolProperty{
|
||||
Type: api.PropertyType{"string"},
|
||||
},
|
||||
},
|
||||
}),
|
||||
Required: []string{"location"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
got, err := renderer.Render(msgs, tools, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("render failed: %v", err)
|
||||
}
|
||||
|
||||
if !strings.Contains(got, "<tools>") {
|
||||
t.Fatalf("expected tools section in prompt, got:\n%s", got)
|
||||
}
|
||||
if !strings.Contains(got, "<function=example_function_name>") {
|
||||
t.Fatalf("expected xml-style tool call instructions, got:\n%s", got)
|
||||
}
|
||||
|
||||
wantToolCall := "<tool_call>\n<function=get_weather>\n<parameter=location>\nParis\n</parameter>\n</function>\n</tool_call>"
|
||||
if !strings.Contains(got, wantToolCall) {
|
||||
t.Fatalf("expected xml tool call payload, got:\n%s", got)
|
||||
}
|
||||
|
||||
toolsIdx := strings.Index(got, "# Tools")
|
||||
systemIdx := strings.Index(got, "You are a helpful assistant.")
|
||||
if toolsIdx == -1 || systemIdx == -1 || systemIdx < toolsIdx {
|
||||
t.Fatalf("expected system prompt appended after tool instructions, got:\n%s", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestQwen35RendererNoThinkPrefill(t *testing.T) {
|
||||
renderer := &Qwen35Renderer{isThinking: true, emitEmptyThinkOnNoThink: true}
|
||||
msgs := []api.Message{
|
||||
{Role: "user", Content: "hello"},
|
||||
}
|
||||
|
||||
got, err := renderer.Render(msgs, nil, &api.ThinkValue{Value: false})
|
||||
if err != nil {
|
||||
t.Fatalf("render failed: %v", err)
|
||||
}
|
||||
|
||||
if !strings.HasSuffix(got, "<|im_start|>assistant\n<think>\n\n</think>\n\n") {
|
||||
t.Fatalf("expected explicit no-think prefill, got:\n%s", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestQwen35RendererBackToBackToolCallsAndResponses(t *testing.T) {
|
||||
renderer := &Qwen35Renderer{isThinking: true}
|
||||
|
||||
msgs := []api.Message{
|
||||
{Role: "system", Content: "You are a helpful assistant."},
|
||||
{Role: "user", Content: "Run add and multiply."},
|
||||
{
|
||||
Role: "assistant",
|
||||
Content: "I'll run both now.",
|
||||
Thinking: "Need to call add and multiply.",
|
||||
ToolCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "add",
|
||||
Arguments: testArgsOrdered([]orderedArg{
|
||||
{Key: "a", Value: 2},
|
||||
{Key: "b", Value: 3},
|
||||
}),
|
||||
},
|
||||
},
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "multiply",
|
||||
Arguments: testArgsOrdered([]orderedArg{
|
||||
{Key: "x", Value: 4},
|
||||
{Key: "y", Value: 5},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{Role: "tool", Content: "5"},
|
||||
{Role: "tool", Content: "20"},
|
||||
{Role: "user", Content: "Summarize the results."},
|
||||
}
|
||||
|
||||
got, err := renderer.Render(msgs, qwen35MathTools(), nil)
|
||||
if err != nil {
|
||||
t.Fatalf("render failed: %v", err)
|
||||
}
|
||||
|
||||
if strings.Contains(got, "Need to call add and multiply.") {
|
||||
t.Fatalf("did not expect historical reasoning block in this sequence, got:\n%s", got)
|
||||
}
|
||||
|
||||
wantToolCalls := `<tool_call>
|
||||
<function=add>
|
||||
<parameter=a>
|
||||
2
|
||||
</parameter>
|
||||
<parameter=b>
|
||||
3
|
||||
</parameter>
|
||||
</function>
|
||||
</tool_call>
|
||||
<tool_call>
|
||||
<function=multiply>
|
||||
<parameter=x>
|
||||
4
|
||||
</parameter>
|
||||
<parameter=y>
|
||||
5
|
||||
</parameter>
|
||||
</function>
|
||||
</tool_call>`
|
||||
if !strings.Contains(got, wantToolCalls) {
|
||||
t.Fatalf("expected back-to-back tool calls, got:\n%s", got)
|
||||
}
|
||||
|
||||
wantToolResponses := `<|im_start|>user
|
||||
<tool_response>
|
||||
5
|
||||
</tool_response>
|
||||
<tool_response>
|
||||
20
|
||||
</tool_response><|im_end|>`
|
||||
if !strings.Contains(got, wantToolResponses) {
|
||||
t.Fatalf("expected grouped back-to-back tool responses, got:\n%s", got)
|
||||
}
|
||||
|
||||
if !strings.HasSuffix(got, "<|im_start|>assistant\n<think>\n") {
|
||||
t.Fatalf("expected assistant thinking prefill at end, got:\n%s", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestQwen35RendererInterleavedThinkingAndTools(t *testing.T) {
|
||||
renderer := &Qwen35Renderer{isThinking: true}
|
||||
|
||||
msgs := []api.Message{
|
||||
{Role: "system", Content: "You are a helpful assistant."},
|
||||
{Role: "user", Content: "Plan a picnic in Paris."},
|
||||
{
|
||||
Role: "assistant",
|
||||
Content: "Checking weather first.",
|
||||
Thinking: "Need weather before giving advice.",
|
||||
ToolCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: testArgsOrdered([]orderedArg{
|
||||
{Key: "location", Value: "Paris"},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{Role: "tool", Content: "22C"},
|
||||
{
|
||||
Role: "assistant",
|
||||
Content: "Checking UV too.",
|
||||
Thinking: "Need UV index for sunscreen advice.",
|
||||
ToolCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_uv",
|
||||
Arguments: testArgsOrdered([]orderedArg{
|
||||
{Key: "location", Value: "Paris"},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{Role: "tool", Content: "5"},
|
||||
}
|
||||
|
||||
got, err := renderer.Render(msgs, qwen35WeatherUVTools(), nil)
|
||||
if err != nil {
|
||||
t.Fatalf("render failed: %v", err)
|
||||
}
|
||||
|
||||
wantFirstTurn := `<|im_start|>assistant
|
||||
<think>
|
||||
Need weather before giving advice.
|
||||
</think>
|
||||
|
||||
Checking weather first.
|
||||
|
||||
<tool_call>
|
||||
<function=get_weather>
|
||||
<parameter=location>
|
||||
Paris
|
||||
</parameter>
|
||||
</function>
|
||||
</tool_call><|im_end|>`
|
||||
if !strings.Contains(got, wantFirstTurn) {
|
||||
t.Fatalf("expected first assistant thinking/tool sequence, got:\n%s", got)
|
||||
}
|
||||
|
||||
wantSecondTurn := `<|im_start|>assistant
|
||||
<think>
|
||||
Need UV index for sunscreen advice.
|
||||
</think>
|
||||
|
||||
Checking UV too.
|
||||
|
||||
<tool_call>
|
||||
<function=get_uv>
|
||||
<parameter=location>
|
||||
Paris
|
||||
</parameter>
|
||||
</function>
|
||||
</tool_call><|im_end|>`
|
||||
if !strings.Contains(got, wantSecondTurn) {
|
||||
t.Fatalf("expected second assistant thinking/tool sequence, got:\n%s", got)
|
||||
}
|
||||
|
||||
if !strings.HasSuffix(got, "<|im_start|>assistant\n<think>\n") {
|
||||
t.Fatalf("expected assistant thinking prefill at end, got:\n%s", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestQwen35RendererAssistantPrefillWithThinking(t *testing.T) {
|
||||
renderer := &Qwen35Renderer{isThinking: true}
|
||||
msgs := []api.Message{
|
||||
{Role: "user", Content: "Write two words."},
|
||||
{
|
||||
Role: "assistant",
|
||||
Thinking: "Keep it short.",
|
||||
Content: "Hello world",
|
||||
},
|
||||
}
|
||||
|
||||
got, err := renderer.Render(msgs, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("render failed: %v", err)
|
||||
}
|
||||
|
||||
want := `<|im_start|>user
|
||||
Write two words.<|im_end|>
|
||||
<|im_start|>assistant
|
||||
<think>
|
||||
Keep it short.
|
||||
</think>
|
||||
|
||||
Hello world`
|
||||
if got != want {
|
||||
t.Fatalf("unexpected prefill output\n--- got ---\n%s\n--- want ---\n%s", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func qwen35MathTools() []api.Tool {
|
||||
return []api.Tool{
|
||||
{
|
||||
Type: "function",
|
||||
Function: api.ToolFunction{
|
||||
Name: "add",
|
||||
Description: "Add two numbers",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: testPropsOrdered([]orderedProp{
|
||||
{
|
||||
Key: "a",
|
||||
Value: api.ToolProperty{
|
||||
Type: api.PropertyType{"integer"},
|
||||
},
|
||||
},
|
||||
{
|
||||
Key: "b",
|
||||
Value: api.ToolProperty{
|
||||
Type: api.PropertyType{"integer"},
|
||||
},
|
||||
},
|
||||
}),
|
||||
Required: []string{"a", "b"},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Type: "function",
|
||||
Function: api.ToolFunction{
|
||||
Name: "multiply",
|
||||
Description: "Multiply two numbers",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: testPropsOrdered([]orderedProp{
|
||||
{
|
||||
Key: "x",
|
||||
Value: api.ToolProperty{
|
||||
Type: api.PropertyType{"integer"},
|
||||
},
|
||||
},
|
||||
{
|
||||
Key: "y",
|
||||
Value: api.ToolProperty{
|
||||
Type: api.PropertyType{"integer"},
|
||||
},
|
||||
},
|
||||
}),
|
||||
Required: []string{"x", "y"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func qwen35WeatherUVTools() []api.Tool {
|
||||
return []api.Tool{
|
||||
{
|
||||
Type: "function",
|
||||
Function: api.ToolFunction{
|
||||
Name: "get_weather",
|
||||
Description: "Get weather for a location",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: testPropsOrdered([]orderedProp{
|
||||
{
|
||||
Key: "location",
|
||||
Value: api.ToolProperty{
|
||||
Type: api.PropertyType{"string"},
|
||||
},
|
||||
},
|
||||
}),
|
||||
Required: []string{"location"},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Type: "function",
|
||||
Function: api.ToolFunction{
|
||||
Name: "get_uv",
|
||||
Description: "Get UV index for a location",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: testPropsOrdered([]orderedProp{
|
||||
{
|
||||
Key: "location",
|
||||
Value: api.ToolProperty{
|
||||
Type: api.PropertyType{"string"},
|
||||
},
|
||||
},
|
||||
}),
|
||||
Required: []string{"location"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -57,7 +57,7 @@ func rendererForName(name string) Renderer {
|
||||
renderer := &Qwen3VLRenderer{isThinking: true, useImgTags: RenderImgTags}
|
||||
return renderer
|
||||
case "qwen3.5":
|
||||
renderer := &Qwen3VLRenderer{isThinking: true, emitEmptyThinkOnNoThink: true, useImgTags: RenderImgTags}
|
||||
renderer := &Qwen35Renderer{isThinking: true, emitEmptyThinkOnNoThink: true, useImgTags: RenderImgTags}
|
||||
return renderer
|
||||
case "cogito":
|
||||
renderer := &CogitoRenderer{isThinking: true}
|
||||
@@ -86,7 +86,7 @@ func rendererForName(name string) Renderer {
|
||||
case "glm-4.7":
|
||||
return &GLM47Renderer{}
|
||||
case "glm-ocr":
|
||||
return &GlmOcrRenderer{}
|
||||
return &GlmOcrRenderer{useImgTags: RenderImgTags}
|
||||
case "lfm2":
|
||||
return &LFM2Renderer{IsThinking: false, useImgTags: RenderImgTags}
|
||||
case "lfm2-thinking":
|
||||
|
||||
@@ -181,6 +181,9 @@ func fileDigestMap(path string) (map[string]string, error) {
|
||||
}
|
||||
|
||||
if !filepath.IsLocal(rel) {
|
||||
if strings.Contains(rel, ".cache") {
|
||||
return nil, fmt.Errorf("insecure path: %s\n\nUse --local-dir <dir> when downloading model to disable caching", rel)
|
||||
}
|
||||
return nil, fmt.Errorf("insecure path: %s", rel)
|
||||
}
|
||||
|
||||
|
||||
@@ -562,6 +562,7 @@ func (s *Server) forwardBatch(pendingBatch batchState) (nextBatch batchState, er
|
||||
if errors.As(err, &reprocess) {
|
||||
// Prepend these inputs to the sequence's inputs queue for reprocessing
|
||||
seq.inputs = append(reprocess.Inputs, seq.inputs...)
|
||||
seq.sampler.Reset()
|
||||
// Skip this sequence but continue processing the rest
|
||||
nextBatch.seqs[seqIdx] = nil // clear this sequence for this batch
|
||||
err = nil
|
||||
@@ -692,6 +693,12 @@ func (s *Server) computeBatch(activeBatch batchState) {
|
||||
// (unless we take down the whole runner).
|
||||
if len(seq.pendingInputs) > 0 {
|
||||
seq.cache.Inputs = append(seq.cache.Inputs, seq.pendingInputs...)
|
||||
for _, inp := range seq.pendingInputs {
|
||||
if len(inp.Multimodal) != 0 {
|
||||
continue
|
||||
}
|
||||
seq.sampler.Accept(inp.Token)
|
||||
}
|
||||
seq.pendingInputs = []*input.Input{}
|
||||
}
|
||||
|
||||
@@ -892,6 +899,9 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
||||
req.Options.TopK,
|
||||
req.Options.TopP,
|
||||
req.Options.MinP,
|
||||
req.Options.RepeatPenalty,
|
||||
req.Options.PresencePenalty,
|
||||
req.Options.FrequencyPenalty,
|
||||
req.Options.Seed,
|
||||
grammar,
|
||||
)
|
||||
@@ -938,6 +948,14 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
seq.sampler.Reset()
|
||||
for _, inp := range seq.cache.Inputs {
|
||||
if len(inp.Multimodal) != 0 {
|
||||
continue
|
||||
}
|
||||
seq.sampler.Accept(inp.Token)
|
||||
}
|
||||
|
||||
s.seqs[i] = seq
|
||||
s.cond.Signal()
|
||||
found = true
|
||||
|
||||
@@ -16,24 +16,49 @@ type token struct {
|
||||
value float32 // The raw logit or probability from the model
|
||||
}
|
||||
|
||||
const DefaultPenaltyLookback = 64
|
||||
|
||||
type Sampler struct {
|
||||
rng *rand.Rand
|
||||
topK int
|
||||
topP float32
|
||||
minP float32
|
||||
temperature float32
|
||||
repeat float32
|
||||
presence float32
|
||||
frequency float32
|
||||
history []int32
|
||||
grammar *GrammarSampler
|
||||
}
|
||||
|
||||
func (s *Sampler) Reset() {
|
||||
s.history = s.history[:0]
|
||||
}
|
||||
|
||||
func (s *Sampler) Accept(token int32) {
|
||||
s.history = append(s.history, token)
|
||||
if len(s.history) > DefaultPenaltyLookback {
|
||||
copy(s.history, s.history[len(s.history)-DefaultPenaltyLookback:])
|
||||
s.history = s.history[:DefaultPenaltyLookback]
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Sampler) Sample(logits []float32) (int32, error) {
|
||||
if len(logits) == 0 {
|
||||
return -1, errors.New("sample: no logits provided to sample")
|
||||
}
|
||||
|
||||
counts := tokenCounts(s.history, len(logits))
|
||||
|
||||
tokens := make([]token, len(logits))
|
||||
for i := range logits {
|
||||
value := logits[i]
|
||||
if count := counts[int32(i)]; count > 0 {
|
||||
value = applyPenalty(value, count, s.repeat, s.presence, s.frequency)
|
||||
}
|
||||
|
||||
tokens[i].id = int32(i)
|
||||
tokens[i].value = logits[i]
|
||||
tokens[i].value = value
|
||||
}
|
||||
|
||||
t, err := s.sample(tokens)
|
||||
@@ -55,8 +80,12 @@ func (s *Sampler) Sample(logits []float32) (int32, error) {
|
||||
// we need to reset them before applying the grammar and
|
||||
// sampling again
|
||||
for i := range logits {
|
||||
value := logits[i]
|
||||
if count := counts[int32(i)]; count > 0 {
|
||||
value = applyPenalty(value, count, s.repeat, s.presence, s.frequency)
|
||||
}
|
||||
tokens[i].id = int32(i)
|
||||
tokens[i].value = logits[i]
|
||||
tokens[i].value = value
|
||||
}
|
||||
s.grammar.Apply(tokens)
|
||||
t, err = s.sample(tokens)
|
||||
@@ -127,7 +156,7 @@ func (s *Sampler) sample(tokens []token) (token, error) {
|
||||
}
|
||||
|
||||
// TODO(parthsareen): update sampler interface to use json unmarshal https://github.com/ollama/ollama/issues/9278
|
||||
func NewSampler(temperature float32, topK int, topP float32, minP float32, seed int, grammar *GrammarSampler) Sampler {
|
||||
func NewSampler(temperature float32, topK int, topP float32, minP float32, repeatPenalty float32, presencePenalty float32, frequencyPenalty float32, seed int, grammar *GrammarSampler) Sampler {
|
||||
var rng *rand.Rand
|
||||
if seed != -1 {
|
||||
// PCG requires two parameters: sequence and stream
|
||||
@@ -154,12 +183,19 @@ func NewSampler(temperature float32, topK int, topP float32, minP float32, seed
|
||||
minP = 1.0
|
||||
}
|
||||
|
||||
if repeatPenalty <= 0 {
|
||||
repeatPenalty = 1.0
|
||||
}
|
||||
|
||||
return Sampler{
|
||||
rng: rng,
|
||||
topK: topK,
|
||||
topP: topP,
|
||||
minP: minP,
|
||||
temperature: temperature,
|
||||
repeat: repeatPenalty,
|
||||
presence: presencePenalty,
|
||||
frequency: frequencyPenalty,
|
||||
grammar: grammar,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -16,7 +16,7 @@ func BenchmarkWeightedSampler(b *testing.B) {
|
||||
logits[i] = float32(rand.Float64()*10 - 5)
|
||||
}
|
||||
|
||||
sampler := NewSampler(0.8, 0, 0, 0, 42, nil)
|
||||
sampler := NewSampler(0.8, 0, 0, 0, 1, 0, 0, 42, nil)
|
||||
b.ResetTimer()
|
||||
for b.Loop() {
|
||||
sampler.Sample(logits)
|
||||
@@ -49,7 +49,7 @@ func BenchmarkWeightedSampler(b *testing.B) {
|
||||
|
||||
for _, tc := range configs {
|
||||
b.Run("Config"+tc.name, func(b *testing.B) {
|
||||
sampler := NewSampler(tc.temperature, tc.topK, tc.topP, tc.minP, tc.seed, nil)
|
||||
sampler := NewSampler(tc.temperature, tc.topK, tc.topP, tc.minP, 1, 0, 0, tc.seed, nil)
|
||||
sampler.Sample(logits)
|
||||
|
||||
b.ResetTimer()
|
||||
@@ -62,7 +62,7 @@ func BenchmarkWeightedSampler(b *testing.B) {
|
||||
|
||||
// Test with combined transforms separately - topK influences performance greatly
|
||||
b.Run("TransformCombined", func(b *testing.B) {
|
||||
sampler := NewSampler(0.8, 50, 0.9, 0.05, 42, nil)
|
||||
sampler := NewSampler(0.8, 50, 0.9, 0.05, 1, 0, 0, 42, nil)
|
||||
b.ResetTimer()
|
||||
|
||||
for b.Loop() {
|
||||
@@ -81,7 +81,7 @@ func BenchmarkGreedySampler(b *testing.B) {
|
||||
logits[i] = float32(rand.Float64()*10 - 5)
|
||||
}
|
||||
|
||||
sampler := NewSampler(0, -1, 0, 0, -1, nil)
|
||||
sampler := NewSampler(0, -1, 0, 0, 1, 0, 0, -1, nil)
|
||||
b.ResetTimer()
|
||||
|
||||
for b.Loop() {
|
||||
|
||||
@@ -13,7 +13,7 @@ import (
|
||||
|
||||
func TestWeighted(t *testing.T) {
|
||||
logits := []float32{-10, 3, -10, -10}
|
||||
sampler := NewSampler(0, 0, 0, 0, 0, nil)
|
||||
sampler := NewSampler(0, 0, 0, 0, 1, 0, 0, 0, nil)
|
||||
got, err := sampler.Sample(logits)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
@@ -25,7 +25,7 @@ func TestWeighted(t *testing.T) {
|
||||
}
|
||||
|
||||
logits = []float32{-100, -10, 0, 10}
|
||||
sampler = NewSampler(0, 0, 0, 0, 0, nil)
|
||||
sampler = NewSampler(0, 0, 0, 0, 1, 0, 0, 0, nil)
|
||||
got, err = sampler.Sample(logits)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
@@ -39,7 +39,7 @@ func TestWeighted(t *testing.T) {
|
||||
// Test very high p
|
||||
logits = []float32{1.0, 0.9999999999999999, 0.5, 0.1}
|
||||
// Use extremely small topP to filter out all tokens
|
||||
sampler = NewSampler(1.0, 0, 1e-10, 0, 0, nil)
|
||||
sampler = NewSampler(1.0, 0, 1e-10, 0, 1, 0, 0, 0, nil)
|
||||
got, err = sampler.Sample(logits)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
@@ -52,7 +52,7 @@ func TestWeighted(t *testing.T) {
|
||||
}
|
||||
|
||||
logits = []float32{float32(math.NaN()), float32(math.NaN()), float32(math.NaN())}
|
||||
sampler = NewSampler(1, 0, 0.95, 0.05, 0, nil)
|
||||
sampler = NewSampler(1, 0, 0.95, 0.05, 1, 0, 0, 0, nil)
|
||||
got, err = sampler.Sample(logits)
|
||||
if err == nil {
|
||||
t.Errorf("expected error, got %d", got)
|
||||
@@ -151,8 +151,8 @@ func TestGrammar(t *testing.T) {
|
||||
|
||||
func BenchmarkSample(b *testing.B) {
|
||||
samplers := map[string]Sampler{
|
||||
"Greedy": NewSampler(0, 0, 0, 0, 0, nil), // Use NewSampler with temp=0 for greedy
|
||||
"Weighted": NewSampler(0.5, 10, 0.9, 0.2, -1, nil),
|
||||
"Greedy": NewSampler(0, 0, 0, 0, 1, 0, 0, 0, nil), // Use NewSampler with temp=0 for greedy
|
||||
"Weighted": NewSampler(0.5, 10, 0.9, 0.2, 1, 0, 0, -1, nil),
|
||||
}
|
||||
|
||||
// Generate random logits for benchmarking
|
||||
|
||||
@@ -25,6 +25,48 @@ func (h *tokenHeap) Pop() any {
|
||||
return x
|
||||
}
|
||||
|
||||
func tokenCounts(history []int32, vocabSize int) map[int32]int {
|
||||
if len(history) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
start := 0
|
||||
if len(history) > DefaultPenaltyLookback {
|
||||
start = len(history) - DefaultPenaltyLookback
|
||||
}
|
||||
|
||||
counts := make(map[int32]int, len(history)-start)
|
||||
for _, token := range history[start:] {
|
||||
if token < 0 || int(token) >= vocabSize {
|
||||
continue
|
||||
}
|
||||
counts[token]++
|
||||
}
|
||||
|
||||
return counts
|
||||
}
|
||||
|
||||
func applyPenalty(logit float32, count int, repeatPenalty float32, presencePenalty float32, frequencyPenalty float32) float32 {
|
||||
if repeatPenalty != 1.0 {
|
||||
// Preserve ordering for negative logits when applying repeat penalty.
|
||||
if logit < 0 {
|
||||
logit *= repeatPenalty
|
||||
} else {
|
||||
logit /= repeatPenalty
|
||||
}
|
||||
}
|
||||
|
||||
if frequencyPenalty != 0 {
|
||||
logit -= float32(count) * frequencyPenalty
|
||||
}
|
||||
|
||||
if presencePenalty != 0 {
|
||||
logit -= presencePenalty
|
||||
}
|
||||
|
||||
return logit
|
||||
}
|
||||
|
||||
// temperature applies scaling to the logits
|
||||
func temperature(ts []token, temp float32) {
|
||||
// Ensure temperature clipping near 0 to avoid numerical instability
|
||||
|
||||
@@ -295,6 +295,86 @@ func TestMinP(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestTokenCounts(t *testing.T) {
|
||||
history := make([]int32, 70)
|
||||
history[0] = 7
|
||||
history[69] = 7
|
||||
|
||||
counts := tokenCounts(history, 8)
|
||||
if got := counts[7]; got != 1 {
|
||||
t.Fatalf("lookback mismatch: got %d want %d", got, 1)
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyPenalty(t *testing.T) {
|
||||
logit := applyPenalty(5.0, 3, 1.0, 1.5, 0.5)
|
||||
if math.Abs(float64(logit-2.0)) > 1e-6 {
|
||||
t.Fatalf("unexpected penalty result: got %f want %f", logit, 2.0)
|
||||
}
|
||||
|
||||
logit = applyPenalty(4.0, 1, 2.0, 0, 0)
|
||||
if math.Abs(float64(logit-2.0)) > 1e-6 {
|
||||
t.Fatalf("unexpected repeat penalty result for positive logits: got %f want %f", logit, 2.0)
|
||||
}
|
||||
|
||||
logit = applyPenalty(-4.0, 1, 2.0, 0, 0)
|
||||
if math.Abs(float64(logit-(-8.0))) > 1e-6 {
|
||||
t.Fatalf("unexpected repeat penalty result for negative logits: got %f want %f", logit, -8.0)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSamplerPresencePenalty(t *testing.T) {
|
||||
logits := []float32{0.0, 5.0, 0.0}
|
||||
|
||||
baseline := NewSampler(0, 0, 1, 0, 1, 0, 0, -1, nil)
|
||||
baseline.Accept(1)
|
||||
got, err := baseline.Sample(logits)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if got != 1 {
|
||||
t.Fatalf("unexpected baseline token: got %d want %d", got, 1)
|
||||
}
|
||||
|
||||
presence := NewSampler(0, 0, 1, 0, 1, 6, 0, -1, nil)
|
||||
presence.Accept(1)
|
||||
got, err = presence.Sample(logits)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if got == 1 {
|
||||
t.Fatalf("presence penalty did not change repeated token selection")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSamplerFrequencyPenalty(t *testing.T) {
|
||||
logits := []float32{0.0, 5.0, 4.0}
|
||||
|
||||
baseline := NewSampler(0, 0, 1, 0, 1, 0, 0, -1, nil)
|
||||
baseline.Accept(1)
|
||||
baseline.Accept(1)
|
||||
baseline.Accept(1)
|
||||
got, err := baseline.Sample(logits)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if got != 1 {
|
||||
t.Fatalf("unexpected baseline token: got %d want %d", got, 1)
|
||||
}
|
||||
|
||||
frequency := NewSampler(0, 0, 1, 0, 1, 0, 1.0, -1, nil)
|
||||
frequency.Accept(1)
|
||||
frequency.Accept(1)
|
||||
frequency.Accept(1)
|
||||
got, err = frequency.Sample(logits)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if got != 2 {
|
||||
t.Fatalf("frequency penalty did not demote repeated token as expected: got %d want %d", got, 2)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkTransforms(b *testing.B) {
|
||||
// Generate random logits
|
||||
tokens := make([]token, 1<<16)
|
||||
|
||||
@@ -59,7 +59,7 @@ _build_darwin() {
|
||||
cmake --install $BUILD_DIR --component CPU
|
||||
cmake --install $BUILD_DIR --component MLX
|
||||
# Override CGO flags to point to the amd64 build directory
|
||||
MLX_CGO_CFLAGS="-O3 -I$(pwd)/$BUILD_DIR/_deps/mlx-c-src -mmacosx-version-min=14.0"
|
||||
MLX_CGO_CFLAGS="-O3 -mmacosx-version-min=14.0"
|
||||
MLX_CGO_LDFLAGS="-ldl -lc++ -framework Accelerate -mmacosx-version-min=14.0"
|
||||
else
|
||||
BUILD_DIR=build
|
||||
@@ -70,10 +70,10 @@ _build_darwin() {
|
||||
cmake --build --preset MLX --parallel
|
||||
cmake --install $BUILD_DIR --component MLX
|
||||
# Use default CGO flags from mlx.go for arm64
|
||||
MLX_CGO_CFLAGS="-O3 -I$(pwd)/$BUILD_DIR/_deps/mlx-c-src -mmacosx-version-min=14.0"
|
||||
MLX_CGO_CFLAGS="-O3 -mmacosx-version-min=14.0"
|
||||
MLX_CGO_LDFLAGS="-lc++ -framework Metal -framework Foundation -framework Accelerate -mmacosx-version-min=14.0"
|
||||
fi
|
||||
GOOS=darwin GOARCH=$ARCH CGO_ENABLED=1 CGO_CFLAGS="$MLX_CGO_CFLAGS" CGO_LDFLAGS="$MLX_CGO_LDFLAGS" go build -tags mlx -o $INSTALL_PREFIX .
|
||||
GOOS=darwin GOARCH=$ARCH CGO_ENABLED=1 CGO_CFLAGS="$MLX_CGO_CFLAGS" CGO_LDFLAGS="$MLX_CGO_LDFLAGS" go build -o $INSTALL_PREFIX .
|
||||
# Copy MLX libraries to same directory as executable for dlopen
|
||||
cp $INSTALL_PREFIX/lib/ollama/libmlxc.dylib $INSTALL_PREFIX/
|
||||
cp $INSTALL_PREFIX/lib/ollama/libmlx.dylib $INSTALL_PREFIX/
|
||||
|
||||
@@ -4,7 +4,10 @@
|
||||
#
|
||||
# gcloud auth application-default login
|
||||
|
||||
$ErrorActionPreference = "Stop"
|
||||
# Use "Continue" so that stderr output from native commands (e.g. CGo warnings)
|
||||
# is not promoted to a terminating exception by the try/catch block.
|
||||
# All native commands already check $LASTEXITCODE explicitly.
|
||||
$ErrorActionPreference = "Continue"
|
||||
|
||||
mkdir -Force -path .\dist | Out-Null
|
||||
|
||||
@@ -16,13 +19,13 @@ function checkEnv {
|
||||
if ($null -ne $arch) {
|
||||
$script:ARCH = ($arch.ToString().ToLower()).Replace("x64", "amd64")
|
||||
} else {
|
||||
write-host "WARNING: old powershell detected, assuming amd64 architecture - set `$env:ARCH to override"
|
||||
Write-Output "WARNING: old powershell detected, assuming amd64 architecture - set `$env:ARCH to override"
|
||||
$script:ARCH="amd64"
|
||||
}
|
||||
}
|
||||
$script:TARGET_ARCH=$script:ARCH
|
||||
Write-host "Building for ${script:TARGET_ARCH}"
|
||||
write-host "Locating required tools and paths"
|
||||
Write-Output "Locating required tools and paths"
|
||||
$script:SRC_DIR=$PWD
|
||||
|
||||
# Locate CUDA versions
|
||||
@@ -37,16 +40,17 @@ function checkEnv {
|
||||
$script:CUDA_DIRS=($cudaList | sort-object -Descending)
|
||||
}
|
||||
if ($script:CUDA_DIRS.length -gt 0) {
|
||||
write-host "Available CUDA Versions: $script:CUDA_DIRS"
|
||||
Write-Output "Available CUDA Versions: $script:CUDA_DIRS"
|
||||
} else {
|
||||
write-host "No CUDA versions detected"
|
||||
Write-Output "No CUDA versions detected"
|
||||
}
|
||||
|
||||
# Locate ROCm version
|
||||
if ($null -ne $env:HIP_PATH) {
|
||||
# Locate ROCm v6
|
||||
$rocmDir=(get-item "C:\Program Files\AMD\ROCm\6.*" -ea 'silentlycontinue' | sort-object -Descending | select-object -First 1)
|
||||
if ($null -ne $rocmDir) {
|
||||
$script:HIP_PATH=$rocmDir.FullName
|
||||
} elseif ($null -ne $env:HIP_PATH -and $env:HIP_PATH -match '[/\\]6\.') {
|
||||
$script:HIP_PATH=$env:HIP_PATH
|
||||
} else {
|
||||
$script:HIP_PATH=(get-item "C:\Program Files\AMD\ROCm\*\bin\" -ea 'silentlycontinue' | sort-object -Descending)
|
||||
}
|
||||
|
||||
$inoSetup=(get-item "C:\Program Files*\Inno Setup*\")
|
||||
@@ -78,7 +82,7 @@ function checkEnv {
|
||||
} else {
|
||||
$script:PKG_VERSION="0.0.0"
|
||||
}
|
||||
write-host "Building Ollama $script:VERSION with package version $script:PKG_VERSION"
|
||||
Write-Output "Building Ollama $script:VERSION with package version $script:PKG_VERSION"
|
||||
|
||||
# Note: Windows Kits 10 signtool crashes with GCP's plugin
|
||||
if ($null -eq $env:SIGN_TOOL) {
|
||||
@@ -87,12 +91,32 @@ function checkEnv {
|
||||
${script:SignTool}=${env:SIGN_TOOL}
|
||||
}
|
||||
if ("${env:KEY_CONTAINER}") {
|
||||
${script:OLLAMA_CERT}=$(resolve-path "${script:SRC_DIR}\ollama_inc.crt")
|
||||
Write-host "Code signing enabled"
|
||||
if (Test-Path "${script:SRC_DIR}\ollama_inc.crt") {
|
||||
${script:OLLAMA_CERT}=$(resolve-path "${script:SRC_DIR}\ollama_inc.crt")
|
||||
Write-host "Code signing enabled"
|
||||
} else {
|
||||
Write-Output "WARNING: KEY_CONTAINER is set but ollama_inc.crt not found at ${script:SRC_DIR}\ollama_inc.crt - code signing disabled"
|
||||
}
|
||||
} else {
|
||||
write-host "Code signing disabled - please set KEY_CONTAINERS to sign and copy ollama_inc.crt to the top of the source tree"
|
||||
Write-Output "Code signing disabled - please set KEY_CONTAINERS to sign and copy ollama_inc.crt to the top of the source tree"
|
||||
}
|
||||
$script:JOBS=([Environment]::ProcessorCount)
|
||||
if ($env:OLLAMA_BUILD_PARALLEL) {
|
||||
$script:JOBS=[int]$env:OLLAMA_BUILD_PARALLEL
|
||||
} else {
|
||||
# Use physical core count rather than logical processors (hyperthreads)
|
||||
# to avoid saturating the system during builds
|
||||
try {
|
||||
$cores = (Get-CimInstance Win32_Processor | Measure-Object -Property NumberOfCores -Sum).Sum
|
||||
} catch {
|
||||
$cores = 0
|
||||
}
|
||||
if ($cores -gt 0) {
|
||||
$script:JOBS = $cores
|
||||
} else {
|
||||
$script:JOBS = [Environment]::ProcessorCount
|
||||
}
|
||||
}
|
||||
Write-Output "Build parallelism: $script:JOBS (set OLLAMA_BUILD_PARALLEL to override)"
|
||||
}
|
||||
|
||||
|
||||
@@ -127,7 +151,7 @@ function cuda11 {
|
||||
}
|
||||
}
|
||||
}
|
||||
write-host "Building CUDA v$cudaMajorVer backend libraries $cuda"
|
||||
Write-Output "Building CUDA v$cudaMajorVer backend libraries $cuda"
|
||||
$env:CUDAToolkit_ROOT=$cuda
|
||||
& cmake -B build\cuda_v$cudaMajorVer --preset "CUDA $cudaMajorVer" -T cuda="$cuda" -DCMAKE_CUDA_COMPILER="$cuda\bin\nvcc.exe" -G "Visual Studio 16 2019" --install-prefix "$script:DIST_DIR"
|
||||
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
|
||||
@@ -136,12 +160,12 @@ function cuda11 {
|
||||
& cmake --install build\cuda_v$cudaMajorVer --component "CUDA" --strip
|
||||
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
|
||||
} else {
|
||||
write-host "CUDA v$cudaMajorVer not detected, skipping"
|
||||
Write-Output "CUDA v$cudaMajorVer not detected, skipping"
|
||||
}
|
||||
} else {
|
||||
write-host "not arch we wanted"
|
||||
Write-Output "not arch we wanted"
|
||||
}
|
||||
write-host "done"
|
||||
Write-Output "done"
|
||||
}
|
||||
|
||||
function cudaCommon {
|
||||
@@ -159,7 +183,7 @@ function cudaCommon {
|
||||
}
|
||||
}
|
||||
}
|
||||
write-host "Building CUDA v$cudaMajorVer backend libraries $cuda"
|
||||
Write-Output "Building CUDA v$cudaMajorVer backend libraries $cuda"
|
||||
$env:CUDAToolkit_ROOT=$cuda
|
||||
& cmake -B build\cuda_v$cudaMajorVer --preset "CUDA $cudaMajorVer" -T cuda="$cuda" --install-prefix "$script:DIST_DIR"
|
||||
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
|
||||
@@ -168,7 +192,7 @@ function cudaCommon {
|
||||
& cmake --install build\cuda_v$cudaMajorVer --component "CUDA" --strip
|
||||
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
|
||||
} else {
|
||||
write-host "CUDA v$cudaMajorVer not detected, skipping"
|
||||
Write-Output "CUDA v$cudaMajorVer not detected, skipping"
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -181,11 +205,11 @@ function cuda13 {
|
||||
cudaCommon("13")
|
||||
}
|
||||
|
||||
function rocm {
|
||||
function rocm6 {
|
||||
mkdir -Force -path "${script:DIST_DIR}\" | Out-Null
|
||||
if ($script:ARCH -ne "arm64") {
|
||||
if ($script:HIP_PATH) {
|
||||
write-host "Building ROCm backend libraries $script:HIP_PATH"
|
||||
Write-Output "Building ROCm backend libraries $script:HIP_PATH"
|
||||
if (-Not (get-command -ErrorAction silent ninja)) {
|
||||
$NINJA_DIR=(gci -path (Get-CimInstance MSFT_VSInstance -Namespace root/cimv2/vs)[0].InstallLocation -r -fi ninja.exe).Directory.FullName
|
||||
$env:PATH="$NINJA_DIR;$env:PATH"
|
||||
@@ -193,9 +217,11 @@ function rocm {
|
||||
$env:HIPCXX="${script:HIP_PATH}\bin\clang++.exe"
|
||||
$env:HIP_PLATFORM="amd"
|
||||
$env:CMAKE_PREFIX_PATH="${script:HIP_PATH}"
|
||||
# Set CC/CXX via environment instead of -D flags to avoid triggering
|
||||
# spurious compiler-change reconfigures that reset CMAKE_INSTALL_PREFIX
|
||||
$env:CC="${script:HIP_PATH}\bin\clang.exe"
|
||||
$env:CXX="${script:HIP_PATH}\bin\clang++.exe"
|
||||
& cmake -B build\rocm --preset "ROCm 6" -G Ninja `
|
||||
-DCMAKE_C_COMPILER=clang `
|
||||
-DCMAKE_CXX_COMPILER=clang++ `
|
||||
-DCMAKE_C_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma" `
|
||||
-DCMAKE_CXX_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma" `
|
||||
--install-prefix $script:DIST_DIR
|
||||
@@ -203,20 +229,22 @@ function rocm {
|
||||
$env:HIPCXX=""
|
||||
$env:HIP_PLATFORM=""
|
||||
$env:CMAKE_PREFIX_PATH=""
|
||||
$env:CC=""
|
||||
$env:CXX=""
|
||||
& cmake --build build\rocm --target ggml-hip --config Release --parallel $script:JOBS
|
||||
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
|
||||
& cmake --install build\rocm --component "HIP" --strip
|
||||
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
|
||||
Remove-Item -Path $script:DIST_DIR\lib\ollama\rocm\rocblas\library\*gfx906* -ErrorAction SilentlyContinue
|
||||
} else {
|
||||
write-host "ROCm not detected, skipping"
|
||||
Write-Output "ROCm not detected, skipping"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
function vulkan {
|
||||
if ($env:VULKAN_SDK) {
|
||||
write-host "Building Vulkan backend libraries"
|
||||
Write-Output "Building Vulkan backend libraries"
|
||||
& cmake -B build\vulkan --preset Vulkan --install-prefix $script:DIST_DIR
|
||||
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
|
||||
& cmake --build build\vulkan --target ggml-vulkan --config Release --parallel $script:JOBS
|
||||
@@ -224,33 +252,91 @@ function vulkan {
|
||||
& cmake --install build\vulkan --component Vulkan --strip
|
||||
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
|
||||
} else {
|
||||
write-host "Vulkan not detected, skipping"
|
||||
Write-Output "Vulkan not detected, skipping"
|
||||
}
|
||||
}
|
||||
|
||||
function mlxCuda13 {
|
||||
mkdir -Force -path "${script:DIST_DIR}\" | Out-Null
|
||||
$cudaMajorVer="13"
|
||||
if ($script:ARCH -ne "arm64") {
|
||||
if ("$script:CUDA_DIRS".Contains("v$cudaMajorVer")) {
|
||||
foreach ($d in $Script:CUDA_DIRS){
|
||||
if ($d.FullName.Contains("v$cudaMajorVer")) {
|
||||
if (test-path -literalpath (join-path -path $d -childpath "nvcc.exe" ) ) {
|
||||
$cuda=($d.FullName|split-path -parent)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
# Check for cuDNN - required for MLX CUDA backend
|
||||
# Supports two layouts:
|
||||
# 1. CI/zip extract: CUDNN\include\cudnn.h, lib\x64\, bin\x64\
|
||||
# 2. Official installer: CUDNN\v*\include\{cuda-ver}\cudnn.h, lib\{cuda-ver}\x64\, bin\{cuda-ver}\
|
||||
if ($env:CUDNN_INCLUDE_PATH -and $env:CUDNN_LIBRARY_PATH) {
|
||||
Write-Output "Using cuDNN from environment: $env:CUDNN_INCLUDE_PATH"
|
||||
} elseif (Test-Path "C:\Program Files\NVIDIA\CUDNN\include\cudnn.h") {
|
||||
# CI/zip layout (flat)
|
||||
$cudnnRoot = "C:\Program Files\NVIDIA\CUDNN"
|
||||
$env:CUDNN_ROOT_DIR = $cudnnRoot
|
||||
$env:CUDNN_INCLUDE_PATH = "$cudnnRoot\include"
|
||||
$env:CUDNN_LIBRARY_PATH = "$cudnnRoot\lib\x64"
|
||||
Write-Output "Found cuDNN at $cudnnRoot (flat layout)"
|
||||
} else {
|
||||
# Official installer layout (versioned)
|
||||
$cudnnRoot = $null
|
||||
$resolved = Resolve-Path -Path "C:\Program Files\NVIDIA\CUDNN\v*" -ErrorAction SilentlyContinue | Sort-Object -Descending | Select-Object -First 1
|
||||
if ($resolved -and (Test-Path "$($resolved.Path)\include\$cudaMajorVer.0\cudnn.h")) {
|
||||
$cudnnRoot = $resolved.Path
|
||||
$env:CUDNN_ROOT_DIR = $cudnnRoot
|
||||
$env:CUDNN_INCLUDE_PATH = "$cudnnRoot\include\$cudaMajorVer.0"
|
||||
$env:CUDNN_LIBRARY_PATH = "$cudnnRoot\lib\$cudaMajorVer.0\x64"
|
||||
Write-Output "Found cuDNN at $cudnnRoot (official installer, CUDA $cudaMajorVer.0)"
|
||||
} else {
|
||||
Write-Output "cuDNN not found - set CUDNN_INCLUDE_PATH and CUDNN_LIBRARY_PATH environment variables"
|
||||
Write-Output "Skipping MLX build"
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
Write-Output "Building MLX CUDA v$cudaMajorVer backend libraries $cuda"
|
||||
$env:CUDAToolkit_ROOT=$cuda
|
||||
& cmake -B build\mlx_cuda_v$cudaMajorVer --preset "MLX CUDA $cudaMajorVer" -T cuda="$cuda" --install-prefix "$script:DIST_DIR"
|
||||
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
|
||||
& cmake --build build\mlx_cuda_v$cudaMajorVer --target mlx --target mlxc --config Release --parallel $script:JOBS -- /nodeReuse:false
|
||||
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
|
||||
& cmake --install build\mlx_cuda_v$cudaMajorVer --component "MLX" --strip
|
||||
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
|
||||
} else {
|
||||
Write-Output "CUDA v$cudaMajorVer not detected, skipping MLX build"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
function ollama {
|
||||
mkdir -Force -path "${script:DIST_DIR}\" | Out-Null
|
||||
write-host "Building ollama CLI"
|
||||
Write-Output "Building ollama CLI"
|
||||
& go build -trimpath -ldflags "-s -w -X=github.com/ollama/ollama/version.Version=$script:VERSION -X=github.com/ollama/ollama/server.mode=release" .
|
||||
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
|
||||
cp .\ollama.exe "${script:DIST_DIR}\"
|
||||
}
|
||||
|
||||
function app {
|
||||
write-host "Building Ollama App $script:VERSION with package version $script:PKG_VERSION"
|
||||
Write-Output "Building Ollama App $script:VERSION with package version $script:PKG_VERSION"
|
||||
|
||||
if (!(Get-Command npm -ErrorAction SilentlyContinue)) {
|
||||
write-host "npm is not installed. Please install Node.js and npm first:"
|
||||
write-host " Visit: https://nodejs.org/"
|
||||
Write-Output "npm is not installed. Please install Node.js and npm first:"
|
||||
Write-Output " Visit: https://nodejs.org/"
|
||||
exit 1
|
||||
}
|
||||
|
||||
if (!(Get-Command tsc -ErrorAction SilentlyContinue)) {
|
||||
write-host "Installing TypeScript compiler..."
|
||||
Write-Output "Installing TypeScript compiler..."
|
||||
npm install -g typescript
|
||||
}
|
||||
if (!(Get-Command tscriptify -ErrorAction SilentlyContinue)) {
|
||||
write-host "Installing tscriptify..."
|
||||
Write-Output "Installing tscriptify..."
|
||||
go install github.com/tkrajina/typescriptify-golang-structs/tscriptify@latest
|
||||
}
|
||||
if (!(Get-Command tscriptify -ErrorAction SilentlyContinue)) {
|
||||
@@ -260,32 +346,32 @@ function app {
|
||||
Push-Location app/ui/app
|
||||
npm install
|
||||
if ($LASTEXITCODE -ne 0) {
|
||||
write-host "ERROR: npm install failed with exit code $LASTEXITCODE"
|
||||
Write-Output "ERROR: npm install failed with exit code $LASTEXITCODE"
|
||||
exit $LASTEXITCODE
|
||||
}
|
||||
|
||||
write-host "Building React application..."
|
||||
Write-Output "Building React application..."
|
||||
npm run build
|
||||
if ($LASTEXITCODE -ne 0) {
|
||||
write-host "ERROR: npm run build failed with exit code $LASTEXITCODE"
|
||||
Write-Output "ERROR: npm run build failed with exit code $LASTEXITCODE"
|
||||
exit $LASTEXITCODE
|
||||
}
|
||||
|
||||
# Check if dist directory exists and has content
|
||||
if (!(Test-Path "dist")) {
|
||||
write-host "ERROR: dist directory was not created by npm run build"
|
||||
Write-Output "ERROR: dist directory was not created by npm run build"
|
||||
exit 1
|
||||
}
|
||||
|
||||
$distFiles = Get-ChildItem "dist" -Recurse
|
||||
if ($distFiles.Count -eq 0) {
|
||||
write-host "ERROR: dist directory is empty after npm run build"
|
||||
Write-Output "ERROR: dist directory is empty after npm run build"
|
||||
exit 1
|
||||
}
|
||||
|
||||
Pop-Location
|
||||
|
||||
write-host "Running go generate"
|
||||
Write-Output "Running go generate"
|
||||
& go generate ./...
|
||||
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
|
||||
& go build -trimpath -ldflags "-s -w -H windowsgui -X=github.com/ollama/ollama/app/version.Version=$script:VERSION" -o .\dist\windows-ollama-app-${script:ARCH}.exe ./app/cmd/app/
|
||||
@@ -293,42 +379,42 @@ function app {
|
||||
}
|
||||
|
||||
function deps {
|
||||
write-host "Download MSVC Redistributables"
|
||||
Write-Output "Download MSVC Redistributables"
|
||||
mkdir -Force -path "${script:SRC_DIR}\dist\\windows-arm64" | Out-Null
|
||||
mkdir -Force -path "${script:SRC_DIR}\dist\\windows-amd64" | Out-Null
|
||||
invoke-webrequest -Uri "https://aka.ms/vs/17/release/vc_redist.arm64.exe" -OutFile "${script:SRC_DIR}\dist\windows-arm64\vc_redist.arm64.exe"
|
||||
invoke-webrequest -Uri "https://aka.ms/vs/17/release/vc_redist.x64.exe" -OutFile "${script:SRC_DIR}\dist\windows-amd64\vc_redist.x64.exe"
|
||||
write-host "Done."
|
||||
invoke-webrequest -Uri "https://aka.ms/vs/17/release/vc_redist.arm64.exe" -OutFile "${script:SRC_DIR}\dist\windows-arm64\vc_redist.arm64.exe" -ErrorAction Stop
|
||||
invoke-webrequest -Uri "https://aka.ms/vs/17/release/vc_redist.x64.exe" -OutFile "${script:SRC_DIR}\dist\windows-amd64\vc_redist.x64.exe" -ErrorAction Stop
|
||||
Write-Output "Done."
|
||||
}
|
||||
|
||||
function sign {
|
||||
# Copy install.ps1 to dist for release packaging
|
||||
write-host "Copying install.ps1 to dist"
|
||||
Copy-Item -Path "${script:SRC_DIR}\scripts\install.ps1" -Destination "${script:SRC_DIR}\dist\install.ps1"
|
||||
Write-Output "Copying install.ps1 to dist"
|
||||
Copy-Item -Path "${script:SRC_DIR}\scripts\install.ps1" -Destination "${script:SRC_DIR}\dist\install.ps1" -ErrorAction Stop
|
||||
|
||||
if ("${env:KEY_CONTAINER}") {
|
||||
write-host "Signing Ollama executables, scripts and libraries"
|
||||
Write-Output "Signing Ollama executables, scripts and libraries"
|
||||
& "${script:SignTool}" sign /v /fd sha256 /t http://timestamp.digicert.com /f "${script:OLLAMA_CERT}" `
|
||||
/csp "Google Cloud KMS Provider" /kc ${env:KEY_CONTAINER} `
|
||||
$(get-childitem -path "${script:SRC_DIR}\dist\windows-*" -r -include @('*.exe', '*.dll'))
|
||||
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
|
||||
|
||||
write-host "Signing install.ps1"
|
||||
Write-Output "Signing install.ps1"
|
||||
& "${script:SignTool}" sign /v /fd sha256 /t http://timestamp.digicert.com /f "${script:OLLAMA_CERT}" `
|
||||
/csp "Google Cloud KMS Provider" /kc ${env:KEY_CONTAINER} `
|
||||
"${script:SRC_DIR}\dist\install.ps1"
|
||||
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
|
||||
} else {
|
||||
write-host "Signing not enabled"
|
||||
Write-Output "Signing not enabled"
|
||||
}
|
||||
}
|
||||
|
||||
function installer {
|
||||
if ($null -eq ${script:INNO_SETUP_DIR}) {
|
||||
write-host "ERROR: missing Inno Setup installation directory - install from https://jrsoftware.org/isdl.php"
|
||||
Write-Output "ERROR: missing Inno Setup installation directory - install from https://jrsoftware.org/isdl.php"
|
||||
exit 1
|
||||
}
|
||||
write-host "Building Ollama Installer"
|
||||
Write-Output "Building Ollama Installer"
|
||||
cd "${script:SRC_DIR}\app"
|
||||
$env:PKG_VERSION=$script:PKG_VERSION
|
||||
if ("${env:KEY_CONTAINER}") {
|
||||
@@ -342,24 +428,24 @@ function installer {
|
||||
function zip {
|
||||
if (Test-Path -Path "${script:SRC_DIR}\dist\windows-amd64") {
|
||||
if (Test-Path -Path "${script:SRC_DIR}\dist\windows-amd64\lib\ollama\rocm") {
|
||||
write-host "Generating stand-alone distribution zip file ${script:SRC_DIR}\dist\ollama-windows-amd64-rocm.zip"
|
||||
Write-Output "Generating stand-alone distribution zip file ${script:SRC_DIR}\dist\ollama-windows-amd64-rocm.zip"
|
||||
# Temporarily adjust paths so we can retain the same directory structure
|
||||
Remove-Item -ea 0 -r "${script:SRC_DIR}\dist\windows-amd64-rocm"
|
||||
mkdir -Force -path "${script:SRC_DIR}\dist\windows-amd64-rocm\lib\ollama"
|
||||
Write-Output "Extract this ROCm zip file to the same location where you extracted ollama-windows-amd64.zip" > "${script:SRC_DIR}\dist\windows-amd64-rocm\README.txt"
|
||||
Move-Item -path "${script:SRC_DIR}\dist\windows-amd64\lib\ollama\rocm" -destination "${script:SRC_DIR}\dist\windows-amd64-rocm\lib\ollama"
|
||||
Move-Item -path "${script:SRC_DIR}\dist\windows-amd64\lib\ollama\rocm" -destination "${script:SRC_DIR}\dist\windows-amd64-rocm\lib\ollama" -ErrorAction Stop
|
||||
Compress-Archive -CompressionLevel Optimal -Path "${script:SRC_DIR}\dist\windows-amd64-rocm\*" -DestinationPath "${script:SRC_DIR}\dist\ollama-windows-amd64-rocm.zip" -Force
|
||||
}
|
||||
|
||||
write-host "Generating stand-alone distribution zip file ${script:SRC_DIR}\dist\ollama-windows-amd64.zip"
|
||||
Write-Output "Generating stand-alone distribution zip file ${script:SRC_DIR}\dist\ollama-windows-amd64.zip"
|
||||
Compress-Archive -CompressionLevel Optimal -Path "${script:SRC_DIR}\dist\windows-amd64\*" -DestinationPath "${script:SRC_DIR}\dist\ollama-windows-amd64.zip" -Force
|
||||
if (Test-Path -Path "${script:SRC_DIR}\dist\windows-amd64-rocm") {
|
||||
Move-Item -destination "${script:SRC_DIR}\dist\windows-amd64\lib\ollama\rocm" -path "${script:SRC_DIR}\dist\windows-amd64-rocm\lib\ollama\rocm"
|
||||
Move-Item -destination "${script:SRC_DIR}\dist\windows-amd64\lib\ollama\rocm" -path "${script:SRC_DIR}\dist\windows-amd64-rocm\lib\ollama\rocm" -ErrorAction Stop
|
||||
}
|
||||
}
|
||||
|
||||
if (Test-Path -Path "${script:SRC_DIR}\dist\windows-arm64") {
|
||||
write-host "Generating stand-alone distribution zip file ${script:SRC_DIR}\dist\ollama-windows-arm64.zip"
|
||||
Write-Output "Generating stand-alone distribution zip file ${script:SRC_DIR}\dist\ollama-windows-arm64.zip"
|
||||
Compress-Archive -CompressionLevel Optimal -Path "${script:SRC_DIR}\dist\windows-arm64\*" -DestinationPath "${script:SRC_DIR}\dist\ollama-windows-arm64.zip" -Force
|
||||
}
|
||||
}
|
||||
@@ -375,8 +461,9 @@ try {
|
||||
cpu
|
||||
cuda12
|
||||
cuda13
|
||||
rocm
|
||||
rocm6
|
||||
vulkan
|
||||
mlxCuda13
|
||||
ollama
|
||||
app
|
||||
deps
|
||||
@@ -385,13 +472,13 @@ try {
|
||||
zip
|
||||
} else {
|
||||
for ( $i = 0; $i -lt $args.count; $i++ ) {
|
||||
write-host "running build step $($args[$i])"
|
||||
Write-Output "running build step $($args[$i])"
|
||||
& $($args[$i])
|
||||
}
|
||||
}
|
||||
} catch {
|
||||
write-host "Build Failed"
|
||||
write-host $_
|
||||
Write-Error "Build Failed: $($_.Exception.Message)"
|
||||
Write-Error "$($_.ScriptStackTrace)"
|
||||
} finally {
|
||||
set-location $script:SRC_DIR
|
||||
$env:PKG_VERSION=""
|
||||
|
||||
@@ -16,9 +16,16 @@ OLLAMA_COMMON_BUILD_ARGS="--build-arg=VERSION \
|
||||
--build-arg=OLLAMA_FAST_BUILD \
|
||||
--build-arg=CUSTOM_CPU_FLAGS \
|
||||
--build-arg=GPU_RUNNER_CPU_FLAGS \
|
||||
--build-arg=PARALLEL \
|
||||
--build-arg=AMDGPU_TARGETS"
|
||||
|
||||
# Forward local MLX source overrides as Docker build contexts
|
||||
if [ -n "${OLLAMA_MLX_SOURCE:-}" ]; then
|
||||
OLLAMA_COMMON_BUILD_ARGS="$OLLAMA_COMMON_BUILD_ARGS --build-context local-mlx=$(cd "$OLLAMA_MLX_SOURCE" && pwd)"
|
||||
fi
|
||||
if [ -n "${OLLAMA_MLX_C_SOURCE:-}" ]; then
|
||||
OLLAMA_COMMON_BUILD_ARGS="$OLLAMA_COMMON_BUILD_ARGS --build-context local-mlx-c=$(cd "$OLLAMA_MLX_C_SOURCE" && pwd)"
|
||||
fi
|
||||
|
||||
echo "Building Ollama"
|
||||
echo "VERSION=$VERSION"
|
||||
echo "PLATFORM=$PLATFORM"
|
||||
479
server/cloud_proxy.go
Normal file
479
server/cloud_proxy.go
Normal file
@@ -0,0 +1,479 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"github.com/ollama/ollama/auth"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
internalcloud "github.com/ollama/ollama/internal/cloud"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultCloudProxyBaseURL = "https://ollama.com:443"
|
||||
defaultCloudProxySigningHost = "ollama.com"
|
||||
cloudProxyBaseURLEnv = "OLLAMA_CLOUD_BASE_URL"
|
||||
legacyCloudAnthropicKey = "legacy_cloud_anthropic_web_search"
|
||||
)
|
||||
|
||||
var (
|
||||
cloudProxyBaseURL = defaultCloudProxyBaseURL
|
||||
cloudProxySigningHost = defaultCloudProxySigningHost
|
||||
cloudProxySignRequest = signCloudProxyRequest
|
||||
cloudProxySigninURL = signinURL
|
||||
)
|
||||
|
||||
var hopByHopHeaders = map[string]struct{}{
|
||||
"connection": {},
|
||||
"content-length": {},
|
||||
"proxy-connection": {},
|
||||
"keep-alive": {},
|
||||
"proxy-authenticate": {},
|
||||
"proxy-authorization": {},
|
||||
"te": {},
|
||||
"trailer": {},
|
||||
"transfer-encoding": {},
|
||||
"upgrade": {},
|
||||
}
|
||||
|
||||
func init() {
|
||||
baseURL, signingHost, overridden, err := resolveCloudProxyBaseURL(envconfig.Var(cloudProxyBaseURLEnv), mode)
|
||||
if err != nil {
|
||||
slog.Warn("ignoring cloud base URL override", "env", cloudProxyBaseURLEnv, "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
cloudProxyBaseURL = baseURL
|
||||
cloudProxySigningHost = signingHost
|
||||
|
||||
if overridden {
|
||||
slog.Info("cloud base URL override enabled", "env", cloudProxyBaseURLEnv, "url", cloudProxyBaseURL, "mode", mode)
|
||||
}
|
||||
}
|
||||
|
||||
func cloudPassthroughMiddleware(disabledOperation string) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
if c.Request.Method != http.MethodPost {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
// TODO(drifkin): Avoid full-body buffering here for model detection.
|
||||
// A future optimization can parse just enough JSON to read "model" (and
|
||||
// optionally short-circuit cloud-disabled explicit-cloud requests) while
|
||||
// preserving raw passthrough semantics.
|
||||
body, err := readRequestBody(c.Request)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
model, ok := extractModelField(body)
|
||||
if !ok {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
modelRef, err := parseAndValidateModelRef(model)
|
||||
if err != nil || modelRef.Source != modelSourceCloud {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
normalizedBody, err := replaceJSONModelField(body, modelRef.Base)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// TEMP(drifkin): keep Anthropic web search requests on the local middleware
|
||||
// path so WebSearchAnthropicWriter can orchestrate follow-up calls.
|
||||
if c.Request.URL.Path == "/v1/messages" {
|
||||
if hasAnthropicWebSearchTool(body) {
|
||||
c.Set(legacyCloudAnthropicKey, true)
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
proxyCloudRequest(c, normalizedBody, disabledOperation)
|
||||
c.Abort()
|
||||
}
|
||||
}
|
||||
|
||||
func cloudModelPathPassthroughMiddleware(disabledOperation string) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
modelName := strings.TrimSpace(c.Param("model"))
|
||||
if modelName == "" {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
modelRef, err := parseAndValidateModelRef(modelName)
|
||||
if err != nil || modelRef.Source != modelSourceCloud {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
proxyPath := "/v1/models/" + modelRef.Base
|
||||
proxyCloudRequestWithPath(c, nil, proxyPath, disabledOperation)
|
||||
c.Abort()
|
||||
}
|
||||
}
|
||||
|
||||
func proxyCloudJSONRequest(c *gin.Context, payload any, disabledOperation string) {
|
||||
// TEMP(drifkin): we currently split out this `WithPath` method because we are
|
||||
// mapping `/v1/messages` + web_search to `/api/chat` temporarily. Once we
|
||||
// stop doing this, we can inline this method.
|
||||
proxyCloudJSONRequestWithPath(c, payload, c.Request.URL.Path, disabledOperation)
|
||||
}
|
||||
|
||||
func proxyCloudJSONRequestWithPath(c *gin.Context, payload any, path string, disabledOperation string) {
|
||||
body, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
proxyCloudRequestWithPath(c, body, path, disabledOperation)
|
||||
}
|
||||
|
||||
func proxyCloudRequest(c *gin.Context, body []byte, disabledOperation string) {
|
||||
proxyCloudRequestWithPath(c, body, c.Request.URL.Path, disabledOperation)
|
||||
}
|
||||
|
||||
func proxyCloudRequestWithPath(c *gin.Context, body []byte, path string, disabledOperation string) {
|
||||
if disabled, _ := internalcloud.Status(); disabled {
|
||||
c.JSON(http.StatusForbidden, gin.H{"error": internalcloud.DisabledError(disabledOperation)})
|
||||
return
|
||||
}
|
||||
|
||||
baseURL, err := url.Parse(cloudProxyBaseURL)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
targetURL := baseURL.ResolveReference(&url.URL{
|
||||
Path: path,
|
||||
RawQuery: c.Request.URL.RawQuery,
|
||||
})
|
||||
|
||||
outReq, err := http.NewRequestWithContext(c.Request.Context(), c.Request.Method, targetURL.String(), bytes.NewReader(body))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
copyProxyRequestHeaders(outReq.Header, c.Request.Header)
|
||||
if outReq.Header.Get("Content-Type") == "" && len(body) > 0 {
|
||||
outReq.Header.Set("Content-Type", "application/json")
|
||||
}
|
||||
|
||||
if err := cloudProxySignRequest(outReq.Context(), outReq); err != nil {
|
||||
slog.Warn("cloud proxy signing failed", "error", err)
|
||||
writeCloudUnauthorized(c)
|
||||
return
|
||||
}
|
||||
|
||||
// TODO(drifkin): Add phase-specific proxy timeouts.
|
||||
// Connect/TLS/TTFB should have bounded timeouts, but once streaming starts
|
||||
// we should not enforce a short total timeout for long-lived responses.
|
||||
resp, err := http.DefaultClient.Do(outReq)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadGateway, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
copyProxyResponseHeaders(c.Writer.Header(), resp.Header)
|
||||
c.Status(resp.StatusCode)
|
||||
|
||||
if err := copyProxyResponseBody(c.Writer, resp.Body); err != nil {
|
||||
ctxErr := c.Request.Context().Err()
|
||||
if errors.Is(err, context.Canceled) && errors.Is(ctxErr, context.Canceled) {
|
||||
slog.Debug(
|
||||
"cloud proxy response stream closed by client",
|
||||
"path", c.Request.URL.Path,
|
||||
"status", resp.StatusCode,
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
slog.Warn(
|
||||
"cloud proxy response copy failed",
|
||||
"path", c.Request.URL.Path,
|
||||
"status", resp.StatusCode,
|
||||
"request_context_canceled", ctxErr != nil,
|
||||
"request_context_err", ctxErr,
|
||||
"error", err,
|
||||
)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func replaceJSONModelField(body []byte, model string) ([]byte, error) {
|
||||
if len(body) == 0 {
|
||||
return body, nil
|
||||
}
|
||||
|
||||
var payload map[string]json.RawMessage
|
||||
if err := json.Unmarshal(body, &payload); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
modelJSON, err := json.Marshal(model)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
payload["model"] = modelJSON
|
||||
|
||||
return json.Marshal(payload)
|
||||
}
|
||||
|
||||
func readRequestBody(r *http.Request) ([]byte, error) {
|
||||
if r.Body == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
r.Body = io.NopCloser(bytes.NewReader(body))
|
||||
return body, nil
|
||||
}
|
||||
|
||||
func extractModelField(body []byte) (string, bool) {
|
||||
if len(body) == 0 {
|
||||
return "", false
|
||||
}
|
||||
|
||||
var payload map[string]json.RawMessage
|
||||
if err := json.Unmarshal(body, &payload); err != nil {
|
||||
return "", false
|
||||
}
|
||||
|
||||
raw, ok := payload["model"]
|
||||
if !ok {
|
||||
return "", false
|
||||
}
|
||||
|
||||
var model string
|
||||
if err := json.Unmarshal(raw, &model); err != nil {
|
||||
return "", false
|
||||
}
|
||||
|
||||
model = strings.TrimSpace(model)
|
||||
return model, model != ""
|
||||
}
|
||||
|
||||
func hasAnthropicWebSearchTool(body []byte) bool {
|
||||
if len(body) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
var payload struct {
|
||||
Tools []struct {
|
||||
Type string `json:"type"`
|
||||
} `json:"tools"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &payload); err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
for _, tool := range payload.Tools {
|
||||
if strings.HasPrefix(strings.TrimSpace(tool.Type), "web_search") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func writeCloudUnauthorized(c *gin.Context) {
|
||||
signinURL, err := cloudProxySigninURL()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "unauthorized"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "unauthorized", "signin_url": signinURL})
|
||||
}
|
||||
|
||||
func signCloudProxyRequest(ctx context.Context, req *http.Request) error {
|
||||
if !strings.EqualFold(req.URL.Hostname(), cloudProxySigningHost) {
|
||||
return nil
|
||||
}
|
||||
|
||||
ts := strconv.FormatInt(time.Now().Unix(), 10)
|
||||
challenge := buildCloudSignatureChallenge(req, ts)
|
||||
signature, err := auth.Sign(ctx, []byte(challenge))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
req.Header.Set("Authorization", signature)
|
||||
return nil
|
||||
}
|
||||
|
||||
func buildCloudSignatureChallenge(req *http.Request, ts string) string {
|
||||
query := req.URL.Query()
|
||||
query.Set("ts", ts)
|
||||
req.URL.RawQuery = query.Encode()
|
||||
|
||||
return fmt.Sprintf("%s,%s", req.Method, req.URL.RequestURI())
|
||||
}
|
||||
|
||||
func resolveCloudProxyBaseURL(rawOverride string, runMode string) (baseURL string, signingHost string, overridden bool, err error) {
|
||||
baseURL = defaultCloudProxyBaseURL
|
||||
signingHost = defaultCloudProxySigningHost
|
||||
|
||||
rawOverride = strings.TrimSpace(rawOverride)
|
||||
if rawOverride == "" {
|
||||
return baseURL, signingHost, false, nil
|
||||
}
|
||||
|
||||
u, err := url.Parse(rawOverride)
|
||||
if err != nil {
|
||||
return "", "", false, fmt.Errorf("invalid URL: %w", err)
|
||||
}
|
||||
if u.Scheme == "" || u.Host == "" {
|
||||
return "", "", false, fmt.Errorf("invalid URL: scheme and host are required")
|
||||
}
|
||||
if u.User != nil {
|
||||
return "", "", false, fmt.Errorf("invalid URL: userinfo is not allowed")
|
||||
}
|
||||
if u.Path != "" && u.Path != "/" {
|
||||
return "", "", false, fmt.Errorf("invalid URL: path is not allowed")
|
||||
}
|
||||
if u.RawQuery != "" || u.Fragment != "" {
|
||||
return "", "", false, fmt.Errorf("invalid URL: query and fragment are not allowed")
|
||||
}
|
||||
|
||||
host := u.Hostname()
|
||||
if host == "" {
|
||||
return "", "", false, fmt.Errorf("invalid URL: host is required")
|
||||
}
|
||||
|
||||
loopback := isLoopbackHost(host)
|
||||
if runMode == gin.ReleaseMode && !loopback {
|
||||
return "", "", false, fmt.Errorf("non-loopback cloud override is not allowed in release mode")
|
||||
}
|
||||
if !loopback && !strings.EqualFold(u.Scheme, "https") {
|
||||
return "", "", false, fmt.Errorf("non-loopback cloud override must use https")
|
||||
}
|
||||
|
||||
u.Path = ""
|
||||
u.RawPath = ""
|
||||
u.RawQuery = ""
|
||||
u.Fragment = ""
|
||||
|
||||
return u.String(), strings.ToLower(host), true, nil
|
||||
}
|
||||
|
||||
func isLoopbackHost(host string) bool {
|
||||
if strings.EqualFold(host, "localhost") {
|
||||
return true
|
||||
}
|
||||
|
||||
ip := net.ParseIP(host)
|
||||
return ip != nil && ip.IsLoopback()
|
||||
}
|
||||
|
||||
func copyProxyRequestHeaders(dst, src http.Header) {
|
||||
connectionTokens := connectionHeaderTokens(src)
|
||||
for key, values := range src {
|
||||
if isHopByHopHeader(key) || isConnectionTokenHeader(key, connectionTokens) {
|
||||
continue
|
||||
}
|
||||
|
||||
dst.Del(key)
|
||||
for _, value := range values {
|
||||
dst.Add(key, value)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func copyProxyResponseHeaders(dst, src http.Header) {
|
||||
connectionTokens := connectionHeaderTokens(src)
|
||||
for key, values := range src {
|
||||
if isHopByHopHeader(key) || isConnectionTokenHeader(key, connectionTokens) {
|
||||
continue
|
||||
}
|
||||
|
||||
dst.Del(key)
|
||||
for _, value := range values {
|
||||
dst.Add(key, value)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func copyProxyResponseBody(dst http.ResponseWriter, src io.Reader) error {
|
||||
flusher, canFlush := dst.(http.Flusher)
|
||||
buf := make([]byte, 32*1024)
|
||||
|
||||
for {
|
||||
n, err := src.Read(buf)
|
||||
if n > 0 {
|
||||
if _, writeErr := dst.Write(buf[:n]); writeErr != nil {
|
||||
return writeErr
|
||||
}
|
||||
if canFlush {
|
||||
// TODO(drifkin): Consider conditional flushing so non-streaming
|
||||
// responses don't flush every write and can optimize throughput.
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
if err == io.EOF {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func isHopByHopHeader(name string) bool {
|
||||
_, ok := hopByHopHeaders[strings.ToLower(name)]
|
||||
return ok
|
||||
}
|
||||
|
||||
func connectionHeaderTokens(header http.Header) map[string]struct{} {
|
||||
tokens := map[string]struct{}{}
|
||||
for _, raw := range header.Values("Connection") {
|
||||
for _, token := range strings.Split(raw, ",") {
|
||||
token = strings.TrimSpace(strings.ToLower(token))
|
||||
if token == "" {
|
||||
continue
|
||||
}
|
||||
tokens[token] = struct{}{}
|
||||
}
|
||||
}
|
||||
return tokens
|
||||
}
|
||||
|
||||
func isConnectionTokenHeader(name string, tokens map[string]struct{}) bool {
|
||||
if len(tokens) == 0 {
|
||||
return false
|
||||
}
|
||||
_, ok := tokens[strings.ToLower(name)]
|
||||
return ok
|
||||
}
|
||||
154
server/cloud_proxy_test.go
Normal file
154
server/cloud_proxy_test.go
Normal file
@@ -0,0 +1,154 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func TestCopyProxyRequestHeaders_StripsConnectionTokenHeaders(t *testing.T) {
|
||||
src := http.Header{}
|
||||
src.Add("Connection", "keep-alive, X-Trace-Hop, x-alt-hop")
|
||||
src.Add("X-Trace-Hop", "drop-me")
|
||||
src.Add("X-Alt-Hop", "drop-me-too")
|
||||
src.Add("Keep-Alive", "timeout=5")
|
||||
src.Add("X-End-To-End", "keep-me")
|
||||
|
||||
dst := http.Header{}
|
||||
copyProxyRequestHeaders(dst, src)
|
||||
|
||||
if got := dst.Get("Connection"); got != "" {
|
||||
t.Fatalf("expected Connection to be stripped, got %q", got)
|
||||
}
|
||||
if got := dst.Get("Keep-Alive"); got != "" {
|
||||
t.Fatalf("expected Keep-Alive to be stripped, got %q", got)
|
||||
}
|
||||
if got := dst.Get("X-Trace-Hop"); got != "" {
|
||||
t.Fatalf("expected X-Trace-Hop to be stripped via Connection token, got %q", got)
|
||||
}
|
||||
if got := dst.Get("X-Alt-Hop"); got != "" {
|
||||
t.Fatalf("expected X-Alt-Hop to be stripped via Connection token, got %q", got)
|
||||
}
|
||||
if got := dst.Get("X-End-To-End"); got != "keep-me" {
|
||||
t.Fatalf("expected X-End-To-End to be forwarded, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCopyProxyResponseHeaders_StripsConnectionTokenHeaders(t *testing.T) {
|
||||
src := http.Header{}
|
||||
src.Add("Connection", "X-Upstream-Hop")
|
||||
src.Add("X-Upstream-Hop", "drop-me")
|
||||
src.Add("Content-Type", "application/json")
|
||||
src.Add("X-Server-Trace", "keep-me")
|
||||
|
||||
dst := http.Header{}
|
||||
copyProxyResponseHeaders(dst, src)
|
||||
|
||||
if got := dst.Get("Connection"); got != "" {
|
||||
t.Fatalf("expected Connection to be stripped, got %q", got)
|
||||
}
|
||||
if got := dst.Get("X-Upstream-Hop"); got != "" {
|
||||
t.Fatalf("expected X-Upstream-Hop to be stripped via Connection token, got %q", got)
|
||||
}
|
||||
if got := dst.Get("Content-Type"); got != "application/json" {
|
||||
t.Fatalf("expected Content-Type to be forwarded, got %q", got)
|
||||
}
|
||||
if got := dst.Get("X-Server-Trace"); got != "keep-me" {
|
||||
t.Fatalf("expected X-Server-Trace to be forwarded, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveCloudProxyBaseURL_Default(t *testing.T) {
|
||||
baseURL, signingHost, overridden, err := resolveCloudProxyBaseURL("", gin.ReleaseMode)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if overridden {
|
||||
t.Fatal("expected override=false for empty input")
|
||||
}
|
||||
if baseURL != defaultCloudProxyBaseURL {
|
||||
t.Fatalf("expected default base URL %q, got %q", defaultCloudProxyBaseURL, baseURL)
|
||||
}
|
||||
if signingHost != defaultCloudProxySigningHost {
|
||||
t.Fatalf("expected default signing host %q, got %q", defaultCloudProxySigningHost, signingHost)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveCloudProxyBaseURL_ReleaseAllowsLoopback(t *testing.T) {
|
||||
baseURL, signingHost, overridden, err := resolveCloudProxyBaseURL("http://localhost:8080", gin.ReleaseMode)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if !overridden {
|
||||
t.Fatal("expected override=true")
|
||||
}
|
||||
if baseURL != "http://localhost:8080" {
|
||||
t.Fatalf("unexpected base URL: %q", baseURL)
|
||||
}
|
||||
if signingHost != "localhost" {
|
||||
t.Fatalf("unexpected signing host: %q", signingHost)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveCloudProxyBaseURL_ReleaseRejectsNonLoopback(t *testing.T) {
|
||||
_, _, _, err := resolveCloudProxyBaseURL("https://example.com", gin.ReleaseMode)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for non-loopback override in release mode")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveCloudProxyBaseURL_DevAllowsNonLoopbackHTTPS(t *testing.T) {
|
||||
baseURL, signingHost, overridden, err := resolveCloudProxyBaseURL("https://example.com:8443", gin.DebugMode)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if !overridden {
|
||||
t.Fatal("expected override=true")
|
||||
}
|
||||
if baseURL != "https://example.com:8443" {
|
||||
t.Fatalf("unexpected base URL: %q", baseURL)
|
||||
}
|
||||
if signingHost != "example.com" {
|
||||
t.Fatalf("unexpected signing host: %q", signingHost)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveCloudProxyBaseURL_DevRejectsNonLoopbackHTTP(t *testing.T) {
|
||||
_, _, _, err := resolveCloudProxyBaseURL("http://example.com", gin.DebugMode)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for non-loopback http override in dev mode")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildCloudSignatureChallengeIncludesExistingQuery(t *testing.T) {
|
||||
req, err := http.NewRequest(http.MethodPost, "https://ollama.com/v1/messages?beta=true&foo=bar", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create request: %v", err)
|
||||
}
|
||||
|
||||
got := buildCloudSignatureChallenge(req, "123")
|
||||
want := "POST,/v1/messages?beta=true&foo=bar&ts=123"
|
||||
if got != want {
|
||||
t.Fatalf("challenge mismatch: got %q want %q", got, want)
|
||||
}
|
||||
if req.URL.RawQuery != "beta=true&foo=bar&ts=123" {
|
||||
t.Fatalf("unexpected signed query: %q", req.URL.RawQuery)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildCloudSignatureChallengeOverwritesExistingTimestamp(t *testing.T) {
|
||||
req, err := http.NewRequest(http.MethodPost, "https://ollama.com/v1/messages?beta=true&ts=999", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create request: %v", err)
|
||||
}
|
||||
|
||||
got := buildCloudSignatureChallenge(req, "123")
|
||||
want := "POST,/v1/messages?beta=true&ts=123"
|
||||
if got != want {
|
||||
t.Fatalf("challenge mismatch: got %q want %q", got, want)
|
||||
}
|
||||
if req.URL.RawQuery != "beta=true&ts=123" {
|
||||
t.Fatalf("unexpected signed query: %q", req.URL.RawQuery)
|
||||
}
|
||||
}
|
||||
@@ -65,11 +65,22 @@ func (s *Server) CreateHandler(c *gin.Context) {
|
||||
config.Parser = r.Parser
|
||||
config.Requires = r.Requires
|
||||
|
||||
for v := range r.Files {
|
||||
for v, digest := range r.Files {
|
||||
if !fs.ValidPath(v) {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": errFilePath.Error()})
|
||||
return
|
||||
}
|
||||
if digest == "" {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": manifest.ErrInvalidDigestFormat.Error()})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
for _, digest := range r.Adapters {
|
||||
if digest == "" {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": manifest.ErrInvalidDigestFormat.Error()})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
name := model.ParseName(cmp.Or(r.Model, r.Name))
|
||||
@@ -99,19 +110,26 @@ func (s *Server) CreateHandler(c *gin.Context) {
|
||||
|
||||
if r.From != "" {
|
||||
slog.Debug("create model from model name", "from", r.From)
|
||||
fromName := model.ParseName(r.From)
|
||||
if !fromName.IsValid() {
|
||||
fromRef, err := parseAndValidateModelRef(r.From)
|
||||
if err != nil {
|
||||
ch <- gin.H{"error": errtypes.InvalidModelNameErrMsg, "status": http.StatusBadRequest}
|
||||
return
|
||||
}
|
||||
if r.RemoteHost != "" {
|
||||
ru, err := remoteURL(r.RemoteHost)
|
||||
|
||||
fromName := fromRef.Name
|
||||
remoteHost := r.RemoteHost
|
||||
if fromRef.Source == modelSourceCloud && remoteHost == "" {
|
||||
remoteHost = cloudProxyBaseURL
|
||||
}
|
||||
|
||||
if remoteHost != "" {
|
||||
ru, err := remoteURL(remoteHost)
|
||||
if err != nil {
|
||||
ch <- gin.H{"error": "bad remote", "status": http.StatusBadRequest}
|
||||
return
|
||||
}
|
||||
|
||||
config.RemoteModel = r.From
|
||||
config.RemoteModel = fromRef.Base
|
||||
config.RemoteHost = ru
|
||||
remote = true
|
||||
} else {
|
||||
|
||||
@@ -71,6 +71,10 @@ type Model struct {
|
||||
Template *template.Template
|
||||
}
|
||||
|
||||
func (m *Model) IsMLX() bool {
|
||||
return m.Config.ModelFormat == "safetensors"
|
||||
}
|
||||
|
||||
// Capabilities returns the capabilities that the model supports
|
||||
func (m *Model) Capabilities() []model.Capability {
|
||||
capabilities := []model.Capability{}
|
||||
|
||||
81
server/model_resolver.go
Normal file
81
server/model_resolver.go
Normal file
@@ -0,0 +1,81 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"github.com/ollama/ollama/internal/modelref"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
)
|
||||
|
||||
type modelSource = modelref.ModelSource
|
||||
|
||||
const (
|
||||
modelSourceUnspecified modelSource = modelref.ModelSourceUnspecified
|
||||
modelSourceLocal modelSource = modelref.ModelSourceLocal
|
||||
modelSourceCloud modelSource = modelref.ModelSourceCloud
|
||||
)
|
||||
|
||||
var (
|
||||
errConflictingModelSource = modelref.ErrConflictingSourceSuffix
|
||||
errModelRequired = modelref.ErrModelRequired
|
||||
)
|
||||
|
||||
type parsedModelRef struct {
|
||||
// Original is the caller-provided model string before source parsing.
|
||||
// Example: "gpt-oss:20b:cloud".
|
||||
Original string
|
||||
// Base is the model string after source suffix normalization.
|
||||
// Example: "gpt-oss:20b:cloud" -> "gpt-oss:20b".
|
||||
Base string
|
||||
// Name is Base parsed as a fully-qualified model.Name with defaults applied.
|
||||
// Example: "registry.ollama.ai/library/gpt-oss:20b".
|
||||
Name model.Name
|
||||
// Source captures explicit source intent from the original input.
|
||||
// Example: "gpt-oss:20b:cloud" -> modelSourceCloud.
|
||||
Source modelSource
|
||||
}
|
||||
|
||||
func parseAndValidateModelRef(raw string) (parsedModelRef, error) {
|
||||
var zero parsedModelRef
|
||||
|
||||
parsed, err := modelref.ParseRef(raw)
|
||||
if err != nil {
|
||||
return zero, err
|
||||
}
|
||||
|
||||
name := model.ParseName(parsed.Base)
|
||||
if !name.IsValid() {
|
||||
return zero, model.Unqualified(name)
|
||||
}
|
||||
|
||||
return parsedModelRef{
|
||||
Original: parsed.Original,
|
||||
Base: parsed.Base,
|
||||
Name: name,
|
||||
Source: parsed.Source,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func parseNormalizePullModelRef(raw string) (parsedModelRef, error) {
|
||||
var zero parsedModelRef
|
||||
|
||||
parsedRef, err := modelref.ParseRef(raw)
|
||||
if err != nil {
|
||||
return zero, err
|
||||
}
|
||||
|
||||
normalizedName, _, err := modelref.NormalizePullName(raw)
|
||||
if err != nil {
|
||||
return zero, err
|
||||
}
|
||||
|
||||
name := model.ParseName(normalizedName)
|
||||
if !name.IsValid() {
|
||||
return zero, model.Unqualified(name)
|
||||
}
|
||||
|
||||
return parsedModelRef{
|
||||
Original: parsedRef.Original,
|
||||
Base: normalizedName,
|
||||
Name: name,
|
||||
Source: parsedRef.Source,
|
||||
}, nil
|
||||
}
|
||||
170
server/model_resolver_test.go
Normal file
170
server/model_resolver_test.go
Normal file
@@ -0,0 +1,170 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestParseModelSelector(t *testing.T) {
|
||||
t.Run("cloud suffix", func(t *testing.T) {
|
||||
got, err := parseAndValidateModelRef("gpt-oss:20b:cloud")
|
||||
if err != nil {
|
||||
t.Fatalf("parseModelSelector returned error: %v", err)
|
||||
}
|
||||
|
||||
if got.Source != modelSourceCloud {
|
||||
t.Fatalf("expected source cloud, got %v", got.Source)
|
||||
}
|
||||
|
||||
if got.Base != "gpt-oss:20b" {
|
||||
t.Fatalf("expected base gpt-oss:20b, got %q", got.Base)
|
||||
}
|
||||
|
||||
if got.Name.String() != "registry.ollama.ai/library/gpt-oss:20b" {
|
||||
t.Fatalf("unexpected resolved name: %q", got.Name.String())
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("legacy cloud suffix", func(t *testing.T) {
|
||||
got, err := parseAndValidateModelRef("gpt-oss:20b-cloud")
|
||||
if err != nil {
|
||||
t.Fatalf("parseModelSelector returned error: %v", err)
|
||||
}
|
||||
|
||||
if got.Source != modelSourceCloud {
|
||||
t.Fatalf("expected source cloud, got %v", got.Source)
|
||||
}
|
||||
|
||||
if got.Base != "gpt-oss:20b" {
|
||||
t.Fatalf("expected base gpt-oss:20b, got %q", got.Base)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("bare dash cloud name is not explicit cloud", func(t *testing.T) {
|
||||
got, err := parseAndValidateModelRef("my-cloud-model")
|
||||
if err != nil {
|
||||
t.Fatalf("parseModelSelector returned error: %v", err)
|
||||
}
|
||||
|
||||
if got.Source != modelSourceUnspecified {
|
||||
t.Fatalf("expected source unspecified, got %v", got.Source)
|
||||
}
|
||||
|
||||
if got.Base != "my-cloud-model" {
|
||||
t.Fatalf("expected base my-cloud-model, got %q", got.Base)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("local suffix", func(t *testing.T) {
|
||||
got, err := parseAndValidateModelRef("qwen3:8b:local")
|
||||
if err != nil {
|
||||
t.Fatalf("parseModelSelector returned error: %v", err)
|
||||
}
|
||||
|
||||
if got.Source != modelSourceLocal {
|
||||
t.Fatalf("expected source local, got %v", got.Source)
|
||||
}
|
||||
|
||||
if got.Base != "qwen3:8b" {
|
||||
t.Fatalf("expected base qwen3:8b, got %q", got.Base)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("conflicting source suffixes fail", func(t *testing.T) {
|
||||
_, err := parseAndValidateModelRef("foo:cloud:local")
|
||||
if !errors.Is(err, errConflictingModelSource) {
|
||||
t.Fatalf("expected errConflictingModelSource, got %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("unspecified source", func(t *testing.T) {
|
||||
got, err := parseAndValidateModelRef("llama3")
|
||||
if err != nil {
|
||||
t.Fatalf("parseModelSelector returned error: %v", err)
|
||||
}
|
||||
|
||||
if got.Source != modelSourceUnspecified {
|
||||
t.Fatalf("expected source unspecified, got %v", got.Source)
|
||||
}
|
||||
|
||||
if got.Name.Tag != "latest" {
|
||||
t.Fatalf("expected default latest tag, got %q", got.Name.Tag)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("unknown suffix is treated as tag", func(t *testing.T) {
|
||||
got, err := parseAndValidateModelRef("gpt-oss:clod")
|
||||
if err != nil {
|
||||
t.Fatalf("parseModelSelector returned error: %v", err)
|
||||
}
|
||||
|
||||
if got.Source != modelSourceUnspecified {
|
||||
t.Fatalf("expected source unspecified, got %v", got.Source)
|
||||
}
|
||||
|
||||
if got.Name.Tag != "clod" {
|
||||
t.Fatalf("expected tag clod, got %q", got.Name.Tag)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("empty model fails", func(t *testing.T) {
|
||||
_, err := parseAndValidateModelRef("")
|
||||
if !errors.Is(err, errModelRequired) {
|
||||
t.Fatalf("expected errModelRequired, got %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("invalid model fails", func(t *testing.T) {
|
||||
_, err := parseAndValidateModelRef("::cloud")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for invalid model")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "unqualified") {
|
||||
t.Fatalf("expected unqualified model error, got %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestParsePullModelRef(t *testing.T) {
|
||||
t.Run("explicit local is normalized", func(t *testing.T) {
|
||||
got, err := parseNormalizePullModelRef("gpt-oss:20b:local")
|
||||
if err != nil {
|
||||
t.Fatalf("parseNormalizePullModelRef returned error: %v", err)
|
||||
}
|
||||
|
||||
if got.Source != modelSourceLocal {
|
||||
t.Fatalf("expected source local, got %v", got.Source)
|
||||
}
|
||||
|
||||
if got.Base != "gpt-oss:20b" {
|
||||
t.Fatalf("expected base gpt-oss:20b, got %q", got.Base)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("explicit cloud with size maps to legacy cloud suffix", func(t *testing.T) {
|
||||
got, err := parseNormalizePullModelRef("gpt-oss:20b:cloud")
|
||||
if err != nil {
|
||||
t.Fatalf("parseNormalizePullModelRef returned error: %v", err)
|
||||
}
|
||||
if got.Base != "gpt-oss:20b-cloud" {
|
||||
t.Fatalf("expected base gpt-oss:20b-cloud, got %q", got.Base)
|
||||
}
|
||||
if got.Name.String() != "registry.ollama.ai/library/gpt-oss:20b-cloud" {
|
||||
t.Fatalf("unexpected resolved name: %q", got.Name.String())
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("explicit cloud without size maps to cloud tag", func(t *testing.T) {
|
||||
got, err := parseNormalizePullModelRef("qwen3:cloud")
|
||||
if err != nil {
|
||||
t.Fatalf("parseNormalizePullModelRef returned error: %v", err)
|
||||
}
|
||||
if got.Base != "qwen3:cloud" {
|
||||
t.Fatalf("expected base qwen3:cloud, got %q", got.Base)
|
||||
}
|
||||
if got.Name.String() != "registry.ollama.ai/library/qwen3:cloud" {
|
||||
t.Fatalf("unexpected resolved name: %q", got.Name.String())
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -30,42 +30,44 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
|
||||
lastMsgIdx := len(msgs) - 1
|
||||
currMsgIdx := 0
|
||||
|
||||
// Start with all messages and remove from the front until it fits in context
|
||||
for i := 0; i <= lastMsgIdx; i++ {
|
||||
// Collect system messages from the portion we're about to skip
|
||||
system = make([]api.Message, 0)
|
||||
for j := range i {
|
||||
if msgs[j].Role == "system" {
|
||||
system = append(system, msgs[j])
|
||||
if truncate {
|
||||
// Start with all messages and remove from the front until it fits in context
|
||||
for i := 0; i <= lastMsgIdx; i++ {
|
||||
// Collect system messages from the portion we're about to skip
|
||||
system = make([]api.Message, 0)
|
||||
for j := range i {
|
||||
if msgs[j].Role == "system" {
|
||||
system = append(system, msgs[j])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
p, err := renderPrompt(m, append(system, msgs[i:]...), tools, think)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
s, err := tokenize(ctx, p)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
ctxLen := len(s)
|
||||
if m.ProjectorPaths != nil {
|
||||
for _, msg := range msgs[i:] {
|
||||
ctxLen += imageNumTokens * len(msg.Images)
|
||||
p, err := renderPrompt(m, append(system, msgs[i:]...), tools, think)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if !truncate || ctxLen <= opts.NumCtx {
|
||||
currMsgIdx = i
|
||||
break
|
||||
}
|
||||
s, err := tokenize(ctx, p)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
// Must always include at least the last message
|
||||
if i == lastMsgIdx {
|
||||
currMsgIdx = lastMsgIdx
|
||||
break
|
||||
ctxLen := len(s)
|
||||
if m.ProjectorPaths != nil {
|
||||
for _, msg := range msgs[i:] {
|
||||
ctxLen += imageNumTokens * len(msg.Images)
|
||||
}
|
||||
}
|
||||
|
||||
if ctxLen <= opts.NumCtx {
|
||||
currMsgIdx = i
|
||||
break
|
||||
}
|
||||
|
||||
// Must always include at least the last message
|
||||
if i == lastMsgIdx {
|
||||
currMsgIdx = lastMsgIdx
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@ package server
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
@@ -366,3 +367,33 @@ func TestChatPromptRendererDoesNotRewriteMessageContent(t *testing.T) {
|
||||
t.Fatal("prompt is empty")
|
||||
}
|
||||
}
|
||||
|
||||
func TestChatPromptGLMOcrRendererAddsImageTags(t *testing.T) {
|
||||
msgs := []api.Message{
|
||||
{
|
||||
Role: "user",
|
||||
Content: "extract text",
|
||||
Images: []api.ImageData{[]byte("img-1"), []byte("img-2")},
|
||||
},
|
||||
}
|
||||
|
||||
m := Model{
|
||||
Config: model.ConfigV2{Renderer: "glm-ocr"},
|
||||
ProjectorPaths: []string{"vision"},
|
||||
}
|
||||
opts := api.Options{Runner: api.Runner{NumCtx: 8192}}
|
||||
think := false
|
||||
|
||||
prompt, images, err := chatPrompt(t.Context(), &m, mockRunner{}.Tokenize, &opts, msgs, nil, &api.ThinkValue{Value: think}, true)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if got, want := len(images), 2; got != want {
|
||||
t.Fatalf("len(images) = %d, want %d", got, want)
|
||||
}
|
||||
|
||||
if !strings.Contains(prompt, "<|user|>\n[img-0][img-1]extract text") {
|
||||
t.Fatalf("prompt missing glm-ocr image tags, got: %q", prompt)
|
||||
}
|
||||
}
|
||||
|
||||
202
server/routes.go
202
server/routes.go
@@ -62,8 +62,21 @@ const signinURLStr = "https://ollama.com/connect?name=%s&key=%s"
|
||||
const (
|
||||
cloudErrRemoteInferenceUnavailable = "remote model is unavailable"
|
||||
cloudErrRemoteModelDetailsUnavailable = "remote model details are unavailable"
|
||||
cloudErrWebSearchUnavailable = "web search is unavailable"
|
||||
cloudErrWebFetchUnavailable = "web fetch is unavailable"
|
||||
)
|
||||
|
||||
func writeModelRefParseError(c *gin.Context, err error, fallbackStatus int, fallbackMessage string) {
|
||||
switch {
|
||||
case errors.Is(err, errConflictingModelSource):
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
case errors.Is(err, model.ErrUnqualifiedName):
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": errtypes.InvalidModelNameErrMsg})
|
||||
default:
|
||||
c.JSON(fallbackStatus, gin.H{"error": fallbackMessage})
|
||||
}
|
||||
}
|
||||
|
||||
func shouldUseHarmony(model *Model) bool {
|
||||
if slices.Contains([]string{"gptoss", "gpt-oss"}, model.Config.ModelFamily) {
|
||||
// heuristic to check whether the template expects to be parsed via harmony:
|
||||
@@ -150,7 +163,7 @@ func (s *Server) scheduleRunner(ctx context.Context, name string, caps []model.C
|
||||
return nil, nil, nil, fmt.Errorf("%s %w", name, err)
|
||||
}
|
||||
|
||||
useImagegen, _ := requestOpts["use_imagegen_runner"].(bool)
|
||||
// Deprecated runner override option; ignore if present.
|
||||
delete(requestOpts, "use_imagegen_runner")
|
||||
|
||||
opts, err := s.modelOptions(model, requestOpts)
|
||||
@@ -158,7 +171,7 @@ func (s *Server) scheduleRunner(ctx context.Context, name string, caps []model.C
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
runnerCh, errCh := s.sched.GetRunner(ctx, model, opts, keepAlive, useImagegen)
|
||||
runnerCh, errCh := s.sched.GetRunner(ctx, model, opts, keepAlive)
|
||||
var runner *runnerRef
|
||||
select {
|
||||
case runner = <-runnerCh:
|
||||
@@ -196,14 +209,22 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
name := model.ParseName(req.Model)
|
||||
if !name.IsValid() {
|
||||
// Ideally this is "invalid model name" but we're keeping with
|
||||
// what the API currently returns until we can change it.
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)})
|
||||
modelRef, err := parseAndValidateModelRef(req.Model)
|
||||
if err != nil {
|
||||
writeModelRefParseError(c, err, http.StatusNotFound, fmt.Sprintf("model '%s' not found", req.Model))
|
||||
return
|
||||
}
|
||||
|
||||
if modelRef.Source == modelSourceCloud {
|
||||
// TODO(drifkin): evaluate an `/api/*` passthrough for cloud where the
|
||||
// original body (modulo model name normalization) is sent to cloud.
|
||||
req.Model = modelRef.Base
|
||||
proxyCloudJSONRequest(c, req, cloudErrRemoteInferenceUnavailable)
|
||||
return
|
||||
}
|
||||
|
||||
name := modelRef.Name
|
||||
|
||||
resolvedName, _, err := s.resolveAlias(name)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
@@ -237,6 +258,11 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if modelRef.Source == modelSourceLocal && m.Config.RemoteHost != "" && m.Config.RemoteModel != "" {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)})
|
||||
return
|
||||
}
|
||||
|
||||
if m.Config.RemoteHost != "" && m.Config.RemoteModel != "" {
|
||||
if disabled, _ := internalcloud.Status(); disabled {
|
||||
c.JSON(http.StatusForbidden, gin.H{"error": internalcloud.DisabledError(cloudErrRemoteInferenceUnavailable)})
|
||||
@@ -370,12 +396,6 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
// Validate Think value: string values currently only allowed for harmony/gptoss models
|
||||
if req.Think != nil && req.Think.IsString() && m.Config.Parser != "harmony" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("think value %q is not supported for this model", req.Think.String())})
|
||||
return
|
||||
}
|
||||
|
||||
caps := []model.Capability{model.CapabilityCompletion}
|
||||
if req.Suffix != "" {
|
||||
caps = append(caps, model.CapabilityInsert)
|
||||
@@ -484,7 +504,8 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
||||
// the real chat handler, but doing this as a stopgap to get renderer
|
||||
// support for generate
|
||||
if values.Messages != nil && values.Suffix == "" && req.Template == "" {
|
||||
prompt, images, err = chatPrompt(c.Request.Context(), m, r.Tokenize, opts, values.Messages, []api.Tool{}, req.Think, req.Truncate == nil || *req.Truncate)
|
||||
genTruncate := (req.Truncate == nil || *req.Truncate) && !m.IsMLX()
|
||||
prompt, images, err = chatPrompt(c.Request.Context(), m, r.Tokenize, opts, values.Messages, []api.Tool{}, req.Think, genTruncate)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
@@ -675,6 +696,18 @@ func (s *Server) EmbedHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
modelRef, err := parseAndValidateModelRef(req.Model)
|
||||
if err != nil {
|
||||
writeModelRefParseError(c, err, http.StatusNotFound, fmt.Sprintf("model '%s' not found", req.Model))
|
||||
return
|
||||
}
|
||||
|
||||
if modelRef.Source == modelSourceCloud {
|
||||
req.Model = modelRef.Base
|
||||
proxyCloudJSONRequest(c, req, cloudErrRemoteInferenceUnavailable)
|
||||
return
|
||||
}
|
||||
|
||||
var input []string
|
||||
|
||||
switch i := req.Input.(type) {
|
||||
@@ -697,7 +730,7 @@ func (s *Server) EmbedHandler(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
name, err := getExistingName(model.ParseName(req.Model))
|
||||
name, err := getExistingName(modelRef.Name)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)})
|
||||
return
|
||||
@@ -844,12 +877,20 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
name := model.ParseName(req.Model)
|
||||
if !name.IsValid() {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "model is required"})
|
||||
modelRef, err := parseAndValidateModelRef(req.Model)
|
||||
if err != nil {
|
||||
writeModelRefParseError(c, err, http.StatusBadRequest, "model is required")
|
||||
return
|
||||
}
|
||||
|
||||
if modelRef.Source == modelSourceCloud {
|
||||
req.Model = modelRef.Base
|
||||
proxyCloudJSONRequest(c, req, cloudErrRemoteInferenceUnavailable)
|
||||
return
|
||||
}
|
||||
|
||||
name := modelRef.Name
|
||||
|
||||
r, _, _, err := s.scheduleRunner(c.Request.Context(), name.String(), []model.Capability{}, req.Options, req.KeepAlive)
|
||||
if err != nil {
|
||||
handleScheduleError(c, req.Model, err)
|
||||
@@ -891,12 +932,19 @@ func (s *Server) PullHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
name := model.ParseName(cmp.Or(req.Model, req.Name))
|
||||
if !name.IsValid() {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": errtypes.InvalidModelNameErrMsg})
|
||||
// TEMP(drifkin): we're temporarily allowing to continue pulling cloud model
|
||||
// stub-files until we integrate cloud models into `/api/tags` (in which case
|
||||
// this roundabout way of "adding" cloud models won't be needed anymore). So
|
||||
// right here normalize any `:cloud` models into the legacy-style suffixes
|
||||
// `:<tag>-cloud` and `:cloud`
|
||||
modelRef, err := parseNormalizePullModelRef(cmp.Or(req.Model, req.Name))
|
||||
if err != nil {
|
||||
writeModelRefParseError(c, err, http.StatusBadRequest, errtypes.InvalidModelNameErrMsg)
|
||||
return
|
||||
}
|
||||
|
||||
name := modelRef.Name
|
||||
|
||||
name, err = getExistingName(name)
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
@@ -1023,13 +1071,20 @@ func (s *Server) DeleteHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
n := model.ParseName(cmp.Or(r.Model, r.Name))
|
||||
if !n.IsValid() {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("name %q is invalid", cmp.Or(r.Model, r.Name))})
|
||||
modelRef, err := parseNormalizePullModelRef(cmp.Or(r.Model, r.Name))
|
||||
if err != nil {
|
||||
switch {
|
||||
case errors.Is(err, errConflictingModelSource):
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
case errors.Is(err, model.ErrUnqualifiedName):
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("name %q is invalid", cmp.Or(r.Model, r.Name))})
|
||||
default:
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
n, err := getExistingName(n)
|
||||
n, err := getExistingName(modelRef.Name)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", cmp.Or(r.Model, r.Name))})
|
||||
return
|
||||
@@ -1078,6 +1133,20 @@ func (s *Server) ShowHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
modelRef, err := parseAndValidateModelRef(req.Model)
|
||||
if err != nil {
|
||||
writeModelRefParseError(c, err, http.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if modelRef.Source == modelSourceCloud {
|
||||
req.Model = modelRef.Base
|
||||
proxyCloudJSONRequest(c, req, cloudErrRemoteModelDetailsUnavailable)
|
||||
return
|
||||
}
|
||||
|
||||
req.Model = modelRef.Base
|
||||
|
||||
resp, err := GetModelInfo(req)
|
||||
if err != nil {
|
||||
var statusErr api.StatusError
|
||||
@@ -1094,6 +1163,11 @@ func (s *Server) ShowHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if modelRef.Source == modelSourceLocal && resp.RemoteHost != "" {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", modelRef.Original)})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, resp)
|
||||
}
|
||||
|
||||
@@ -1621,6 +1695,8 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
|
||||
r.GET("/api/experimental/aliases", s.ListAliasesHandler)
|
||||
r.POST("/api/experimental/aliases", s.CreateAliasHandler)
|
||||
r.DELETE("/api/experimental/aliases", s.DeleteAliasHandler)
|
||||
r.POST("/api/experimental/web_search", s.WebSearchExperimentalHandler)
|
||||
r.POST("/api/experimental/web_fetch", s.WebFetchExperimentalHandler)
|
||||
|
||||
// Inference
|
||||
r.GET("/api/ps", s.PsHandler)
|
||||
@@ -1630,18 +1706,20 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
|
||||
r.POST("/api/embeddings", s.EmbeddingsHandler)
|
||||
|
||||
// Inference (OpenAI compatibility)
|
||||
r.POST("/v1/chat/completions", middleware.ChatMiddleware(), s.ChatHandler)
|
||||
r.POST("/v1/completions", middleware.CompletionsMiddleware(), s.GenerateHandler)
|
||||
r.POST("/v1/embeddings", middleware.EmbeddingsMiddleware(), s.EmbedHandler)
|
||||
// TODO(cloud-stage-a): apply Modelfile overlay deltas for local models with cloud
|
||||
// parents on v1 request families while preserving this explicit :cloud passthrough.
|
||||
r.POST("/v1/chat/completions", cloudPassthroughMiddleware(cloudErrRemoteInferenceUnavailable), middleware.ChatMiddleware(), s.ChatHandler)
|
||||
r.POST("/v1/completions", cloudPassthroughMiddleware(cloudErrRemoteInferenceUnavailable), middleware.CompletionsMiddleware(), s.GenerateHandler)
|
||||
r.POST("/v1/embeddings", cloudPassthroughMiddleware(cloudErrRemoteInferenceUnavailable), middleware.EmbeddingsMiddleware(), s.EmbedHandler)
|
||||
r.GET("/v1/models", middleware.ListMiddleware(), s.ListHandler)
|
||||
r.GET("/v1/models/:model", middleware.RetrieveMiddleware(), s.ShowHandler)
|
||||
r.POST("/v1/responses", middleware.ResponsesMiddleware(), s.ChatHandler)
|
||||
r.GET("/v1/models/:model", cloudModelPathPassthroughMiddleware(cloudErrRemoteModelDetailsUnavailable), middleware.RetrieveMiddleware(), s.ShowHandler)
|
||||
r.POST("/v1/responses", cloudPassthroughMiddleware(cloudErrRemoteInferenceUnavailable), middleware.ResponsesMiddleware(), s.ChatHandler)
|
||||
// OpenAI-compatible image generation endpoints
|
||||
r.POST("/v1/images/generations", middleware.ImageGenerationsMiddleware(), s.GenerateHandler)
|
||||
r.POST("/v1/images/edits", middleware.ImageEditsMiddleware(), s.GenerateHandler)
|
||||
r.POST("/v1/images/generations", cloudPassthroughMiddleware(cloudErrRemoteInferenceUnavailable), middleware.ImageGenerationsMiddleware(), s.GenerateHandler)
|
||||
r.POST("/v1/images/edits", cloudPassthroughMiddleware(cloudErrRemoteInferenceUnavailable), middleware.ImageEditsMiddleware(), s.GenerateHandler)
|
||||
|
||||
// Inference (Anthropic compatibility)
|
||||
r.POST("/v1/messages", middleware.AnthropicMessagesMiddleware(), s.ChatHandler)
|
||||
r.POST("/v1/messages", cloudPassthroughMiddleware(cloudErrRemoteInferenceUnavailable), middleware.AnthropicMessagesMiddleware(), s.ChatHandler)
|
||||
|
||||
if rc != nil {
|
||||
// wrap old with new
|
||||
@@ -1863,6 +1941,29 @@ func (s *Server) StatusHandler(c *gin.Context) {
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Server) WebSearchExperimentalHandler(c *gin.Context) {
|
||||
s.webExperimentalProxyHandler(c, "/api/web_search", cloudErrWebSearchUnavailable)
|
||||
}
|
||||
|
||||
func (s *Server) WebFetchExperimentalHandler(c *gin.Context) {
|
||||
s.webExperimentalProxyHandler(c, "/api/web_fetch", cloudErrWebFetchUnavailable)
|
||||
}
|
||||
|
||||
func (s *Server) webExperimentalProxyHandler(c *gin.Context, proxyPath, disabledOperation string) {
|
||||
body, err := readRequestBody(c.Request)
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
if len(bytes.TrimSpace(body)) == 0 {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
|
||||
return
|
||||
}
|
||||
|
||||
proxyCloudRequestWithPath(c, body, proxyPath, disabledOperation)
|
||||
}
|
||||
|
||||
func (s *Server) WhoamiHandler(c *gin.Context) {
|
||||
// todo allow other hosts
|
||||
u, err := url.Parse("https://ollama.com")
|
||||
@@ -1951,6 +2052,9 @@ func (s *Server) PsHandler(c *gin.Context) {
|
||||
}
|
||||
if v.llama != nil {
|
||||
mr.ContextLength = v.llama.ContextLength()
|
||||
total, vram := v.llama.MemorySize()
|
||||
mr.Size = int64(total)
|
||||
mr.SizeVRAM = int64(vram)
|
||||
}
|
||||
// The scheduler waits to set expiresAt, so if a model is loading it's
|
||||
// possible that it will be set to the unix epoch. For those cases, just
|
||||
@@ -1997,12 +2101,24 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
name := model.ParseName(req.Model)
|
||||
if !name.IsValid() {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "model is required"})
|
||||
modelRef, err := parseAndValidateModelRef(req.Model)
|
||||
if err != nil {
|
||||
writeModelRefParseError(c, err, http.StatusBadRequest, "model is required")
|
||||
return
|
||||
}
|
||||
|
||||
if modelRef.Source == modelSourceCloud {
|
||||
req.Model = modelRef.Base
|
||||
if c.GetBool(legacyCloudAnthropicKey) {
|
||||
proxyCloudJSONRequestWithPath(c, req, "/api/chat", cloudErrRemoteInferenceUnavailable)
|
||||
return
|
||||
}
|
||||
proxyCloudJSONRequest(c, req, cloudErrRemoteInferenceUnavailable)
|
||||
return
|
||||
}
|
||||
|
||||
name := modelRef.Name
|
||||
|
||||
resolvedName, _, err := s.resolveAlias(name)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
@@ -2034,6 +2150,11 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if modelRef.Source == modelSourceLocal && m.Config.RemoteHost != "" && m.Config.RemoteModel != "" {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)})
|
||||
return
|
||||
}
|
||||
|
||||
// expire the runner
|
||||
if len(req.Messages) == 0 && req.KeepAlive != nil && req.KeepAlive.Duration == 0 {
|
||||
s.sched.expireRunner(m)
|
||||
@@ -2213,6 +2334,9 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
||||
}
|
||||
|
||||
truncate := req.Truncate == nil || *req.Truncate
|
||||
if m.IsMLX() {
|
||||
truncate = false
|
||||
}
|
||||
prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, msgs, processedTools, req.Think, truncate)
|
||||
if err != nil {
|
||||
slog.Error("chat prompt error", "error", err)
|
||||
@@ -2233,12 +2357,6 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// Validate Think value: string values currently only allowed for harmony/gptoss models
|
||||
if req.Think != nil && req.Think.IsString() && m.Config.Parser != "harmony" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("think value %q is not supported for this model", req.Think.String())})
|
||||
return
|
||||
}
|
||||
|
||||
var thinkingState *thinking.Parser
|
||||
openingTag, closingTag := thinking.InferTags(m.Template.Template)
|
||||
if req.Think != nil && req.Think.Bool() && openingTag != "" && closingTag != "" {
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -144,6 +144,37 @@ func TestCreateFromBin(t *testing.T) {
|
||||
filepath.Join(p, "blobs", "sha256-6bcdb8859d417753645538d7bbfbd7ca91a3f0c191aef5379c53c05e86b669dd"),
|
||||
filepath.Join(p, "blobs", "sha256-89a2116c3a82d6a97f59f748d86ed4417214353fd178ee54df418fde32495fad"),
|
||||
})
|
||||
|
||||
t.Run("empty file digest", func(t *testing.T) {
|
||||
w := createRequest(t, s.CreateHandler, api.CreateRequest{
|
||||
Name: "my-gguf-model",
|
||||
Files: map[string]string{"0.gguf": ""},
|
||||
Stream: &stream,
|
||||
})
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Fatalf("expected status 400, got %d", w.Code)
|
||||
}
|
||||
if !strings.Contains(w.Body.String(), "invalid digest format") {
|
||||
t.Errorf("expected invalid digest format error, got:\n%s", w.Body.String())
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("empty adapter digest", func(t *testing.T) {
|
||||
w := createRequest(t, s.CreateHandler, api.CreateRequest{
|
||||
Name: "my-gguf-model",
|
||||
Files: map[string]string{"0.gguf": digest},
|
||||
Adapters: map[string]string{"adapter.gguf": ""},
|
||||
Stream: &stream,
|
||||
})
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Fatalf("expected status 400, got %d", w.Code)
|
||||
}
|
||||
if !strings.Contains(w.Body.String(), "invalid digest format") {
|
||||
t.Errorf("expected invalid digest format error, got:\n%s", w.Body.String())
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestCreateFromModel(t *testing.T) {
|
||||
@@ -763,6 +794,43 @@ func TestCreateAndShowRemoteModel(t *testing.T) {
|
||||
fmt.Printf("resp = %#v\n", resp)
|
||||
}
|
||||
|
||||
func TestCreateFromCloudSourceSuffix(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
var s Server
|
||||
|
||||
w := createRequest(t, s.CreateHandler, api.CreateRequest{
|
||||
Model: "test-cloud-from-suffix",
|
||||
From: "gpt-oss:20b:cloud",
|
||||
Info: map[string]any{
|
||||
"capabilities": []string{"completion"},
|
||||
},
|
||||
Stream: &stream,
|
||||
})
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status code 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
w = createRequest(t, s.ShowHandler, api.ShowRequest{Model: "test-cloud-from-suffix"})
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status code 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
var resp api.ShowResponse
|
||||
if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if resp.RemoteHost != "https://ollama.com:443" {
|
||||
t.Fatalf("expected remote host https://ollama.com:443, got %q", resp.RemoteHost)
|
||||
}
|
||||
|
||||
if resp.RemoteModel != "gpt-oss:20b" {
|
||||
t.Fatalf("expected remote model gpt-oss:20b, got %q", resp.RemoteModel)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateLicenses(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
|
||||
@@ -111,3 +111,32 @@ func TestDeleteDuplicateLayers(t *testing.T) {
|
||||
|
||||
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{})
|
||||
}
|
||||
|
||||
func TestDeleteCloudSourceNormalizesToLegacyName(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
p := t.TempDir()
|
||||
t.Setenv("OLLAMA_MODELS", p)
|
||||
|
||||
var s Server
|
||||
|
||||
_, digest := createBinFile(t, nil, nil)
|
||||
w := createRequest(t, s.CreateHandler, api.CreateRequest{
|
||||
Name: "gpt-oss:20b-cloud",
|
||||
Files: map[string]string{"test.gguf": digest},
|
||||
})
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status code 200, actual %d", w.Code)
|
||||
}
|
||||
|
||||
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{
|
||||
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "gpt-oss", "20b-cloud"),
|
||||
})
|
||||
|
||||
w = createRequest(t, s.DeleteHandler, api.DeleteRequest{Name: "gpt-oss:20b:cloud"})
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status code 200, actual %d (%s)", w.Code, w.Body.String())
|
||||
}
|
||||
|
||||
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{})
|
||||
}
|
||||
|
||||
335
server/routes_web_experimental_test.go
Normal file
335
server/routes_web_experimental_test.go
Normal file
@@ -0,0 +1,335 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
internalcloud "github.com/ollama/ollama/internal/cloud"
|
||||
)
|
||||
|
||||
type webExperimentalUpstreamCapture struct {
|
||||
path string
|
||||
body string
|
||||
header http.Header
|
||||
}
|
||||
|
||||
func newWebExperimentalUpstream(t *testing.T, responseBody string) (*httptest.Server, *webExperimentalUpstreamCapture) {
|
||||
t.Helper()
|
||||
|
||||
capture := &webExperimentalUpstreamCapture{}
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
payload, _ := io.ReadAll(r.Body)
|
||||
capture.path = r.URL.Path
|
||||
capture.body = string(payload)
|
||||
capture.header = r.Header.Clone()
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte(responseBody))
|
||||
}))
|
||||
|
||||
return srv, capture
|
||||
}
|
||||
|
||||
func TestExperimentalWebEndpointsPassthrough(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
setTestHome(t, t.TempDir())
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
localPath string
|
||||
upstreamPath string
|
||||
requestBody string
|
||||
responseBody string
|
||||
assertBody string
|
||||
}{
|
||||
{
|
||||
name: "web_search",
|
||||
localPath: "/api/experimental/web_search",
|
||||
upstreamPath: "/api/web_search",
|
||||
requestBody: `{"query":"what is ollama?","max_results":3}`,
|
||||
responseBody: `{"results":[{"title":"Ollama","url":"https://ollama.com","content":"Cloud models are now available"}]}`,
|
||||
assertBody: `"query":"what is ollama?"`,
|
||||
},
|
||||
{
|
||||
name: "web_fetch",
|
||||
localPath: "/api/experimental/web_fetch",
|
||||
upstreamPath: "/api/web_fetch",
|
||||
requestBody: `{"url":"https://ollama.com"}`,
|
||||
responseBody: `{"title":"Ollama","content":"Cloud models are now available","links":["https://ollama.com/"]}`,
|
||||
assertBody: `"url":"https://ollama.com"`,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
upstream, capture := newWebExperimentalUpstream(t, tt.responseBody)
|
||||
defer upstream.Close()
|
||||
|
||||
original := cloudProxyBaseURL
|
||||
cloudProxyBaseURL = upstream.URL
|
||||
t.Cleanup(func() { cloudProxyBaseURL = original })
|
||||
|
||||
s := &Server{}
|
||||
router, err := s.GenerateRoutes(nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
local := httptest.NewServer(router)
|
||||
defer local.Close()
|
||||
|
||||
req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, local.URL+tt.localPath, bytes.NewBufferString(tt.requestBody))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer should-forward")
|
||||
req.Header.Set("X-Test-Header", "web-experimental")
|
||||
|
||||
resp, err := local.Client().Do(req)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("expected status 200, got %d (%s)", resp.StatusCode, string(body))
|
||||
}
|
||||
if capture.path != tt.upstreamPath {
|
||||
t.Fatalf("expected upstream path %q, got %q", tt.upstreamPath, capture.path)
|
||||
}
|
||||
if !bytes.Contains([]byte(capture.body), []byte(tt.assertBody)) {
|
||||
t.Fatalf("expected upstream body to contain %q, got %q", tt.assertBody, capture.body)
|
||||
}
|
||||
if got := capture.header.Get("Authorization"); got != "Bearer should-forward" {
|
||||
t.Fatalf("expected forwarded Authorization header, got %q", got)
|
||||
}
|
||||
if got := capture.header.Get("X-Test-Header"); got != "web-experimental" {
|
||||
t.Fatalf("expected forwarded X-Test-Header=web-experimental, got %q", got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExperimentalWebEndpointsMissingBody(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
setTestHome(t, t.TempDir())
|
||||
|
||||
s := &Server{}
|
||||
router, err := s.GenerateRoutes(nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
local := httptest.NewServer(router)
|
||||
defer local.Close()
|
||||
|
||||
tests := []string{
|
||||
"/api/experimental/web_search",
|
||||
"/api/experimental/web_fetch",
|
||||
}
|
||||
|
||||
for _, path := range tests {
|
||||
t.Run(path, func(t *testing.T) {
|
||||
req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, local.URL+path, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := local.Client().Do(req)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
if resp.StatusCode != http.StatusBadRequest {
|
||||
t.Fatalf("expected status 400, got %d (%s)", resp.StatusCode, string(body))
|
||||
}
|
||||
if string(body) != `{"error":"missing request body"}` {
|
||||
t.Fatalf("unexpected response body: %s", string(body))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExperimentalWebEndpointsCloudDisabled(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
setTestHome(t, t.TempDir())
|
||||
t.Setenv("OLLAMA_NO_CLOUD", "1")
|
||||
|
||||
s := &Server{}
|
||||
router, err := s.GenerateRoutes(nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
local := httptest.NewServer(router)
|
||||
defer local.Close()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
path string
|
||||
request string
|
||||
operation string
|
||||
}{
|
||||
{
|
||||
name: "web_search",
|
||||
path: "/api/experimental/web_search",
|
||||
request: `{"query":"latest ollama release"}`,
|
||||
operation: cloudErrWebSearchUnavailable,
|
||||
},
|
||||
{
|
||||
name: "web_fetch",
|
||||
path: "/api/experimental/web_fetch",
|
||||
request: `{"url":"https://ollama.com"}`,
|
||||
operation: cloudErrWebFetchUnavailable,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, local.URL+tt.path, bytes.NewBufferString(tt.request))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := local.Client().Do(req)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
if resp.StatusCode != http.StatusForbidden {
|
||||
t.Fatalf("expected status 403, got %d (%s)", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var got map[string]string
|
||||
if err := json.Unmarshal(body, &got); err != nil {
|
||||
t.Fatalf("expected json error body, got: %q", string(body))
|
||||
}
|
||||
if got["error"] != internalcloud.DisabledError(tt.operation) {
|
||||
t.Fatalf("unexpected error message: %q", got["error"])
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExperimentalWebEndpointSigningFailureReturnsUnauthorized(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
setTestHome(t, t.TempDir())
|
||||
|
||||
origSignRequest := cloudProxySignRequest
|
||||
origSigninURL := cloudProxySigninURL
|
||||
cloudProxySignRequest = func(context.Context, *http.Request) error {
|
||||
return errors.New("ssh: no key found")
|
||||
}
|
||||
cloudProxySigninURL = func() (string, error) {
|
||||
return "https://ollama.com/signin/example", nil
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
cloudProxySignRequest = origSignRequest
|
||||
cloudProxySigninURL = origSigninURL
|
||||
})
|
||||
|
||||
s := &Server{}
|
||||
router, err := s.GenerateRoutes(nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
local := httptest.NewServer(router)
|
||||
defer local.Close()
|
||||
|
||||
req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, local.URL+"/api/experimental/web_search", bytes.NewBufferString(`{"query":"hello"}`))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := local.Client().Do(req)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
if resp.StatusCode != http.StatusUnauthorized {
|
||||
t.Fatalf("expected status 401, got %d (%s)", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var got map[string]any
|
||||
if err := json.Unmarshal(body, &got); err != nil {
|
||||
t.Fatalf("expected json error body, got: %q", string(body))
|
||||
}
|
||||
if got["error"] != "unauthorized" {
|
||||
t.Fatalf("unexpected error message: %v", got["error"])
|
||||
}
|
||||
if got["signin_url"] != "https://ollama.com/signin/example" {
|
||||
t.Fatalf("unexpected signin_url: %v", got["signin_url"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestExperimentalWebEndpointSigningFailureWithoutSigninURL(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
setTestHome(t, t.TempDir())
|
||||
|
||||
origSignRequest := cloudProxySignRequest
|
||||
origSigninURL := cloudProxySigninURL
|
||||
cloudProxySignRequest = func(context.Context, *http.Request) error {
|
||||
return errors.New("ssh: no key found")
|
||||
}
|
||||
cloudProxySigninURL = func() (string, error) {
|
||||
return "", errors.New("key missing")
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
cloudProxySignRequest = origSignRequest
|
||||
cloudProxySigninURL = origSigninURL
|
||||
})
|
||||
|
||||
s := &Server{}
|
||||
router, err := s.GenerateRoutes(nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
local := httptest.NewServer(router)
|
||||
defer local.Close()
|
||||
|
||||
req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, local.URL+"/api/experimental/web_fetch", bytes.NewBufferString(`{"url":"https://ollama.com"}`))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := local.Client().Do(req)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
if resp.StatusCode != http.StatusUnauthorized {
|
||||
t.Fatalf("expected status 401, got %d (%s)", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var got map[string]any
|
||||
if err := json.Unmarshal(body, &got); err != nil {
|
||||
t.Fatalf("expected json error body, got: %q", string(body))
|
||||
}
|
||||
if got["error"] != "unauthorized" {
|
||||
t.Fatalf("unexpected error message: %v", got["error"])
|
||||
}
|
||||
if _, ok := got["signin_url"]; ok {
|
||||
t.Fatalf("did not expect signin_url when helper fails, got %v", got["signin_url"])
|
||||
}
|
||||
}
|
||||
@@ -33,7 +33,6 @@ type LlmRequest struct {
|
||||
successCh chan *runnerRef
|
||||
errCh chan error
|
||||
schedAttempts uint
|
||||
useImagegen bool
|
||||
}
|
||||
|
||||
type Scheduler struct {
|
||||
@@ -106,7 +105,7 @@ func schedulerModelKey(m *Model) string {
|
||||
}
|
||||
|
||||
// context must be canceled to decrement ref count and release the runner
|
||||
func (s *Scheduler) GetRunner(c context.Context, m *Model, opts api.Options, sessionDuration *api.Duration, useImagegen bool) (chan *runnerRef, chan error) {
|
||||
func (s *Scheduler) GetRunner(c context.Context, m *Model, opts api.Options, sessionDuration *api.Duration) (chan *runnerRef, chan error) {
|
||||
if opts.NumCtx < 4 {
|
||||
opts.NumCtx = 4
|
||||
}
|
||||
@@ -123,7 +122,6 @@ func (s *Scheduler) GetRunner(c context.Context, m *Model, opts api.Options, ses
|
||||
sessionDuration: sessionDuration,
|
||||
successCh: make(chan *runnerRef, 1),
|
||||
errCh: make(chan error, 1),
|
||||
useImagegen: useImagegen,
|
||||
}
|
||||
|
||||
key := schedulerModelKey(req.model)
|
||||
@@ -231,7 +229,7 @@ func (s *Scheduler) processPending(ctx context.Context) {
|
||||
}
|
||||
|
||||
// Check for experimental safetensors LLM models
|
||||
if pending.model.Config.ModelFormat == "safetensors" {
|
||||
if pending.model.IsMLX() {
|
||||
if slices.Contains(pending.model.Config.Capabilities, "completion") {
|
||||
// LLM model with safetensors format - use MLX runner
|
||||
if s.loadMLX(pending) {
|
||||
@@ -536,6 +534,7 @@ iGPUScan:
|
||||
}
|
||||
}
|
||||
|
||||
totalSize, vramSize := llama.MemorySize()
|
||||
runner := &runnerRef{
|
||||
model: req.model,
|
||||
modelPath: req.model.ModelPath,
|
||||
@@ -545,8 +544,8 @@ iGPUScan:
|
||||
sessionDuration: sessionDuration,
|
||||
gpus: gpuIDs,
|
||||
discreteGPUs: discreteGPUs,
|
||||
vramSize: llama.VRAMSize(),
|
||||
totalSize: llama.TotalSize(),
|
||||
totalSize: totalSize,
|
||||
vramSize: vramSize,
|
||||
loading: true,
|
||||
pid: llama.Pid(),
|
||||
}
|
||||
@@ -592,20 +591,15 @@ iGPUScan:
|
||||
return false
|
||||
}
|
||||
|
||||
// loadMLX loads an experimental safetensors model using the unified MLX runner.
|
||||
// This supports both LLM (completion) and image generation models.
|
||||
// loadMLX loads an experimental safetensors model using MLX runners.
|
||||
// Image models use x/imagegen; LLM models use x/mlxrunner.
|
||||
func (s *Scheduler) loadMLX(req *LlmRequest) bool {
|
||||
modelName := req.model.ShortName
|
||||
var server llm.LlamaServer
|
||||
var err error
|
||||
|
||||
isImagegen := false
|
||||
if slices.Contains(req.model.Config.Capabilities, "image") {
|
||||
server, err = imagegen.NewServer(modelName, imagegen.ModeImageGen)
|
||||
isImagegen = true
|
||||
} else if req.useImagegen {
|
||||
server, err = imagegen.NewServer(modelName, imagegen.ModeLLM)
|
||||
isImagegen = true
|
||||
server, err = imagegen.NewServer(modelName)
|
||||
} else {
|
||||
server, err = mlxrunner.NewClient(modelName)
|
||||
}
|
||||
@@ -619,6 +613,7 @@ func (s *Scheduler) loadMLX(req *LlmRequest) bool {
|
||||
sessionDuration = req.sessionDuration.Duration
|
||||
}
|
||||
|
||||
totalSize, vramSize := server.MemorySize()
|
||||
runner := &runnerRef{
|
||||
model: req.model,
|
||||
modelPath: req.model.ModelPath,
|
||||
@@ -626,10 +621,10 @@ func (s *Scheduler) loadMLX(req *LlmRequest) bool {
|
||||
llama: server,
|
||||
Options: &req.opts,
|
||||
loading: false,
|
||||
isImagegen: isImagegen,
|
||||
isImagegen: slices.Contains(req.model.Config.Capabilities, "image"),
|
||||
sessionDuration: sessionDuration,
|
||||
totalSize: server.TotalSize(),
|
||||
vramSize: server.VRAMSize(),
|
||||
totalSize: totalSize,
|
||||
vramSize: vramSize,
|
||||
}
|
||||
|
||||
s.loadedMu.Lock()
|
||||
@@ -735,8 +730,8 @@ func (runner *runnerRef) needsReload(ctx context.Context, req *LlmRequest) bool
|
||||
runner.refMu.Lock()
|
||||
defer runner.refMu.Unlock()
|
||||
|
||||
// Check if runner type (imagegen vs mlxrunner) matches what's requested
|
||||
wantImagegen := req.useImagegen || slices.Contains(req.model.Config.Capabilities, "image")
|
||||
// Check if runner type (imagegen vs mlxrunner) matches what's requested.
|
||||
wantImagegen := slices.Contains(req.model.Config.Capabilities, "image")
|
||||
if runner.isImagegen != wantImagegen {
|
||||
return true
|
||||
}
|
||||
@@ -762,7 +757,7 @@ func (runner *runnerRef) needsReload(ctx context.Context, req *LlmRequest) bool
|
||||
defer cancel()
|
||||
if !reflect.DeepEqual(runner.model.AdapterPaths, req.model.AdapterPaths) || // have the adapters changed?
|
||||
!reflect.DeepEqual(runner.model.ProjectorPaths, req.model.ProjectorPaths) || // have the projectors changed?
|
||||
!reflect.DeepEqual(optsExisting, optsNew) || // have the runner options changed?
|
||||
(!runner.model.IsMLX() && !reflect.DeepEqual(optsExisting, optsNew)) || // have the runner options changed?
|
||||
runner.llama.Ping(ctx) != nil {
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -408,10 +408,10 @@ func TestSchedGetRunner(t *testing.T) {
|
||||
s.getSystemInfoFn = getSystemInfoFn
|
||||
s.newServerFn = a.newServer
|
||||
slog.Info("a")
|
||||
successCh1a, errCh1a := s.GetRunner(a.ctx, a.req.model, a.req.opts, a.req.sessionDuration, false)
|
||||
successCh1a, errCh1a := s.GetRunner(a.ctx, a.req.model, a.req.opts, a.req.sessionDuration)
|
||||
require.Len(t, s.pendingReqCh, 1)
|
||||
slog.Info("b")
|
||||
successCh1b, errCh1b := s.GetRunner(b.ctx, b.req.model, b.req.opts, b.req.sessionDuration, false)
|
||||
successCh1b, errCh1b := s.GetRunner(b.ctx, b.req.model, b.req.opts, b.req.sessionDuration)
|
||||
require.Len(t, s.pendingReqCh, 1)
|
||||
require.Empty(t, successCh1b)
|
||||
require.Len(t, errCh1b, 1)
|
||||
@@ -435,7 +435,7 @@ func TestSchedGetRunner(t *testing.T) {
|
||||
|
||||
c.req.model.ModelPath = "bad path"
|
||||
slog.Info("c")
|
||||
successCh1c, errCh1c := s.GetRunner(c.ctx, c.req.model, c.req.opts, c.req.sessionDuration, false)
|
||||
successCh1c, errCh1c := s.GetRunner(c.ctx, c.req.model, c.req.opts, c.req.sessionDuration)
|
||||
// Starts in pending channel, then should be quickly processed to return an error
|
||||
time.Sleep(50 * time.Millisecond) // Long enough for the "a" model to expire and unload
|
||||
require.Empty(t, successCh1c)
|
||||
@@ -470,7 +470,7 @@ func TestSchedGetRunnerUsesDigestKeyWhenModelPathEmpty(t *testing.T) {
|
||||
s.loadedMu.Unlock()
|
||||
|
||||
reqModel := &Model{Name: "safetensors-b", Digest: "sha-b"}
|
||||
successCh, errCh := s.GetRunner(ctx, reqModel, opts, nil, false)
|
||||
successCh, errCh := s.GetRunner(ctx, reqModel, opts, nil)
|
||||
|
||||
require.Empty(t, successCh)
|
||||
require.Empty(t, errCh)
|
||||
@@ -499,7 +499,7 @@ func TestSchedGetRunnerReusesSameDigestWhenModelPathEmpty(t *testing.T) {
|
||||
s.loadedMu.Unlock()
|
||||
|
||||
reqCtx, cancelReq := context.WithCancel(ctx)
|
||||
successCh, errCh := s.GetRunner(reqCtx, &Model{Name: "safetensors-a-copy", Digest: "sha-a"}, opts, nil, false)
|
||||
successCh, errCh := s.GetRunner(reqCtx, &Model{Name: "safetensors-a-copy", Digest: "sha-a"}, opts, nil)
|
||||
cancelReq()
|
||||
|
||||
select {
|
||||
@@ -574,7 +574,7 @@ func TestSchedPrematureExpired(t *testing.T) {
|
||||
s.getGpuFn = getGpuFn
|
||||
s.getSystemInfoFn = getSystemInfoFn
|
||||
s.newServerFn = scenario1a.newServer
|
||||
successCh1a, errCh1a := s.GetRunner(scenario1a.ctx, scenario1a.req.model, scenario1a.req.opts, scenario1a.req.sessionDuration, false)
|
||||
successCh1a, errCh1a := s.GetRunner(scenario1a.ctx, scenario1a.req.model, scenario1a.req.opts, scenario1a.req.sessionDuration)
|
||||
require.Len(t, s.pendingReqCh, 1)
|
||||
s.Run(ctx)
|
||||
select {
|
||||
@@ -861,8 +861,7 @@ func (s *mockLlm) Close() error {
|
||||
s.closeCalled = true
|
||||
return s.closeResp
|
||||
}
|
||||
func (s *mockLlm) VRAMSize() uint64 { return s.vramSize }
|
||||
func (s *mockLlm) TotalSize() uint64 { return s.totalSize }
|
||||
func (s *mockLlm) MemorySize() (uint64, uint64) { return s.totalSize, s.vramSize }
|
||||
func (s *mockLlm) VRAMByGPU(id ml.DeviceID) uint64 { return s.vramByGPU[id] }
|
||||
func (s *mockLlm) Pid() int { return -1 }
|
||||
func (s *mockLlm) GetPort() int { return -1 }
|
||||
|
||||
@@ -20,6 +20,7 @@ import (
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
internalcloud "github.com/ollama/ollama/internal/cloud"
|
||||
"github.com/ollama/ollama/internal/modelref"
|
||||
"github.com/ollama/ollama/progress"
|
||||
"github.com/ollama/ollama/readline"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
@@ -43,7 +44,7 @@ const (
|
||||
// isLocalModel checks if the model is running locally (not a cloud model).
|
||||
// TODO: Improve local/cloud model identification - could check model metadata
|
||||
func isLocalModel(modelName string) bool {
|
||||
return !strings.HasSuffix(modelName, "-cloud")
|
||||
return !modelref.HasExplicitCloudSource(modelName)
|
||||
}
|
||||
|
||||
// isLocalServer checks if connecting to a local Ollama server.
|
||||
|
||||
@@ -22,12 +22,22 @@ func TestIsLocalModel(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "cloud model",
|
||||
modelName: "gpt-4-cloud",
|
||||
modelName: "gpt-oss:latest-cloud",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "cloud model with :cloud suffix",
|
||||
modelName: "gpt-oss:cloud",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "cloud model with version",
|
||||
modelName: "claude-3-cloud",
|
||||
modelName: "gpt-oss:20b-cloud",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "cloud model with version and :cloud suffix",
|
||||
modelName: "gpt-oss:20b:cloud",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
@@ -134,7 +144,7 @@ func TestTruncateToolOutput(t *testing.T) {
|
||||
{
|
||||
name: "long output cloud model - uses 10k limit",
|
||||
output: string(localLimitOutput), // 20k chars, under 10k token limit
|
||||
modelName: "gpt-4-cloud",
|
||||
modelName: "gpt-oss:latest-cloud",
|
||||
host: "",
|
||||
shouldTrim: false,
|
||||
expectedLimit: defaultTokenLimit,
|
||||
@@ -142,7 +152,7 @@ func TestTruncateToolOutput(t *testing.T) {
|
||||
{
|
||||
name: "very long output cloud model - trimmed at 10k",
|
||||
output: string(defaultLimitOutput),
|
||||
modelName: "gpt-4-cloud",
|
||||
modelName: "gpt-oss:latest-cloud",
|
||||
host: "",
|
||||
shouldTrim: true,
|
||||
expectedLimit: defaultTokenLimit,
|
||||
|
||||
@@ -13,9 +13,12 @@ import (
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/manifest"
|
||||
"github.com/ollama/ollama/parser"
|
||||
"github.com/ollama/ollama/progress"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
"github.com/ollama/ollama/x/create"
|
||||
@@ -27,11 +30,79 @@ const MinOllamaVersion = "0.14.0"
|
||||
|
||||
// ModelfileConfig holds configuration extracted from a Modelfile.
|
||||
type ModelfileConfig struct {
|
||||
Template string
|
||||
System string
|
||||
License string
|
||||
Parser string
|
||||
Renderer string
|
||||
Template string
|
||||
System string
|
||||
License string
|
||||
Parser string
|
||||
Renderer string
|
||||
Parameters map[string]any
|
||||
}
|
||||
|
||||
var ignoredModelfileParameters = []string{
|
||||
"penalize_newline",
|
||||
"low_vram",
|
||||
"f16_kv",
|
||||
"logits_all",
|
||||
"vocab_only",
|
||||
"use_mlock",
|
||||
"mirostat",
|
||||
"mirostat_tau",
|
||||
"mirostat_eta",
|
||||
}
|
||||
|
||||
// ConfigFromModelfile extracts the model directory and x/create-specific
|
||||
// Modelfile configuration from a parsed Modelfile.
|
||||
func ConfigFromModelfile(modelfile *parser.Modelfile) (string, *ModelfileConfig, error) {
|
||||
var modelDir string
|
||||
mfConfig := &ModelfileConfig{}
|
||||
|
||||
for _, cmd := range modelfile.Commands {
|
||||
switch cmd.Name {
|
||||
case "model":
|
||||
modelDir = cmd.Args
|
||||
case "template":
|
||||
mfConfig.Template = cmd.Args
|
||||
case "system":
|
||||
mfConfig.System = cmd.Args
|
||||
case "license":
|
||||
mfConfig.License = cmd.Args
|
||||
case "parser":
|
||||
mfConfig.Parser = cmd.Args
|
||||
case "renderer":
|
||||
mfConfig.Renderer = cmd.Args
|
||||
case "adapter", "message", "requires":
|
||||
continue
|
||||
default:
|
||||
if slices.Contains(ignoredModelfileParameters, cmd.Name) {
|
||||
continue
|
||||
}
|
||||
|
||||
ps, err := api.FormatParams(map[string][]string{cmd.Name: {cmd.Args}})
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
if mfConfig.Parameters == nil {
|
||||
mfConfig.Parameters = make(map[string]any)
|
||||
}
|
||||
|
||||
for k, v := range ps {
|
||||
if ks, ok := mfConfig.Parameters[k].([]string); ok {
|
||||
mfConfig.Parameters[k] = append(ks, v.([]string)...)
|
||||
} else if vs, ok := v.([]string); ok {
|
||||
mfConfig.Parameters[k] = vs
|
||||
} else {
|
||||
mfConfig.Parameters[k] = v
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if modelDir == "" {
|
||||
modelDir = "."
|
||||
}
|
||||
|
||||
return modelDir, mfConfig, nil
|
||||
}
|
||||
|
||||
// CreateOptions holds all options for model creation.
|
||||
@@ -39,7 +110,7 @@ type CreateOptions struct {
|
||||
ModelName string
|
||||
ModelDir string
|
||||
Quantize string // "int4", "int8", "nvfp4", or "mxfp8" for quantization
|
||||
Modelfile *ModelfileConfig // template/system/license/parser/renderer from Modelfile
|
||||
Modelfile *ModelfileConfig // template/system/license/parser/renderer/parameters from Modelfile
|
||||
}
|
||||
|
||||
// CreateModel imports a model from a local directory.
|
||||
@@ -351,6 +422,19 @@ func createModelfileLayers(mf *ModelfileConfig) ([]manifest.Layer, error) {
|
||||
layers = append(layers, layer)
|
||||
}
|
||||
|
||||
if len(mf.Parameters) > 0 {
|
||||
var b bytes.Buffer
|
||||
if err := json.NewEncoder(&b).Encode(mf.Parameters); err != nil {
|
||||
return nil, fmt.Errorf("failed to encode parameters: %w", err)
|
||||
}
|
||||
|
||||
layer, err := manifest.NewLayer(&b, "application/vnd.ollama.image.params")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create params layer: %w", err)
|
||||
}
|
||||
layers = append(layers, layer)
|
||||
}
|
||||
|
||||
return layers, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -1,7 +1,13 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/manifest"
|
||||
"github.com/ollama/ollama/parser"
|
||||
)
|
||||
|
||||
func TestModelfileConfig(t *testing.T) {
|
||||
@@ -31,6 +37,40 @@ func TestModelfileConfig(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigFromModelfile(t *testing.T) {
|
||||
modelfile, err := parser.ParseFile(strings.NewReader(`
|
||||
FROM ./model
|
||||
TEMPLATE {{ .Prompt }}
|
||||
PARAMETER temperature 0.7
|
||||
PARAMETER stop USER:
|
||||
PARAMETER stop ASSISTANT:
|
||||
`))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
modelDir, mfConfig, err := ConfigFromModelfile(modelfile)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if modelDir != "./model" {
|
||||
t.Fatalf("modelDir = %q, want %q", modelDir, "./model")
|
||||
}
|
||||
|
||||
if mfConfig.Template != "{{ .Prompt }}" {
|
||||
t.Fatalf("Template = %q, want %q", mfConfig.Template, "{{ .Prompt }}")
|
||||
}
|
||||
|
||||
if got := mfConfig.Parameters["temperature"]; got != float32(0.7) {
|
||||
t.Fatalf("temperature = %#v, want %v", got, float32(0.7))
|
||||
}
|
||||
|
||||
if got := mfConfig.Parameters["stop"]; got == nil || len(got.([]string)) != 2 {
|
||||
t.Fatalf("unexpected stop params: %#v", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelfileConfig_Empty(t *testing.T) {
|
||||
config := &ModelfileConfig{}
|
||||
|
||||
@@ -120,6 +160,9 @@ func TestCreateOptions(t *testing.T) {
|
||||
License: "MIT",
|
||||
Parser: "qwen3-thinking",
|
||||
Renderer: "qwen3",
|
||||
Parameters: map[string]any{
|
||||
"temperature": float32(0.7),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
@@ -144,6 +187,9 @@ func TestCreateOptions(t *testing.T) {
|
||||
if opts.Modelfile.Renderer != "qwen3" {
|
||||
t.Errorf("Modelfile.Renderer = %q, want %q", opts.Modelfile.Renderer, "qwen3")
|
||||
}
|
||||
if opts.Modelfile.Parameters["temperature"] != float32(0.7) {
|
||||
t.Errorf("Modelfile.Parameters[temperature] = %v, want %v", opts.Modelfile.Parameters["temperature"], float32(0.7))
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveParserName(t *testing.T) {
|
||||
@@ -252,3 +298,44 @@ func TestQuantizeSupported(t *testing.T) {
|
||||
// We can't easily test both cases, so just verify it returns something
|
||||
_ = supported
|
||||
}
|
||||
|
||||
func TestCreateModelfileLayersIncludesParameters(t *testing.T) {
|
||||
t.Setenv("OLLAMA_MODELS", t.TempDir())
|
||||
|
||||
layers, err := createModelfileLayers(&ModelfileConfig{
|
||||
Parameters: map[string]any{
|
||||
"temperature": float32(0.7),
|
||||
"stop": []string{"USER:", "ASSISTANT:"},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if len(layers) != 1 {
|
||||
t.Fatalf("len(layers) = %d, want 1", len(layers))
|
||||
}
|
||||
|
||||
if layers[0].MediaType != "application/vnd.ollama.image.params" {
|
||||
t.Fatalf("MediaType = %q, want %q", layers[0].MediaType, "application/vnd.ollama.image.params")
|
||||
}
|
||||
|
||||
blobPath, err := manifest.BlobsPath(layers[0].Digest)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(blobPath)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var got map[string]any
|
||||
if err := json.Unmarshal(data, &got); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if got["temperature"] != float64(0.7) {
|
||||
t.Fatalf("temperature = %v, want %v", got["temperature"], float64(0.7))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
//go:build mlx
|
||||
|
||||
package client
|
||||
|
||||
import (
|
||||
@@ -21,7 +19,7 @@ var quantizeParams = map[string]struct {
|
||||
bits int
|
||||
mode string
|
||||
}{
|
||||
"int4": {32, 4, "affine"},
|
||||
"int4": {64, 4, "affine"},
|
||||
"nvfp4": {16, 4, "nvfp4"},
|
||||
"int8": {64, 8, "affine"},
|
||||
"mxfp8": {32, 8, "mxfp8"},
|
||||
@@ -194,9 +192,10 @@ func quantizePackedGroup(inputs []create.PackedTensorInput) ([]byte, error) {
|
||||
return blobData, nil
|
||||
}
|
||||
|
||||
// QuantizeSupported returns true if quantization is supported (MLX build)
|
||||
// QuantizeSupported returns true if quantization is supported (MLX library available)
|
||||
func QuantizeSupported() bool {
|
||||
return true
|
||||
mlx.InitMLX()
|
||||
return mlx.IsMLXAvailable()
|
||||
}
|
||||
|
||||
// ensureTempDir creates the temp directory for quantization if it doesn't exist
|
||||
|
||||
@@ -1,25 +0,0 @@
|
||||
//go:build !mlx
|
||||
|
||||
package client
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/ollama/ollama/x/create"
|
||||
)
|
||||
|
||||
// quantizeTensor is not available without MLX
|
||||
func quantizeTensor(r io.Reader, tensorName, dtype string, shape []int32, quantize string) (blobData []byte, err error) {
|
||||
return nil, fmt.Errorf("quantization requires MLX support (build with mlx tag)")
|
||||
}
|
||||
|
||||
// quantizePackedGroup is not available without MLX
|
||||
func quantizePackedGroup(inputs []create.PackedTensorInput) ([]byte, error) {
|
||||
return nil, fmt.Errorf("quantization requires MLX support (build with mlx tag)")
|
||||
}
|
||||
|
||||
// QuantizeSupported returns false when MLX is not available
|
||||
func QuantizeSupported() bool {
|
||||
return false
|
||||
}
|
||||
@@ -288,6 +288,18 @@ func normalizeQuantType(quantize string) string {
|
||||
}
|
||||
}
|
||||
|
||||
func isStackedExpertWeight(name string) bool {
|
||||
// Combined/stacked expert tensors may be emitted either as "...proj.weight" (per-expert)
|
||||
// or "...proj" (pre-stacked packed tensor).
|
||||
if strings.HasSuffix(name, ".bias") || strings.HasSuffix(name, ".scale") || strings.HasSuffix(name, ".qbias") {
|
||||
return false
|
||||
}
|
||||
|
||||
return strings.Contains(name, ".mlp.switch_mlp.") ||
|
||||
strings.Contains(name, ".mlp.experts.") ||
|
||||
strings.Contains(name, ".mlp.shared_experts.")
|
||||
}
|
||||
|
||||
// GetTensorQuantization returns the appropriate quantization type for a tensor.
|
||||
// Returns "" if the tensor should not be quantized.
|
||||
// This implements mixed-precision quantization:
|
||||
@@ -296,18 +308,25 @@ func normalizeQuantType(quantize string) string {
|
||||
// - Down projection weights: int8 (more sensitive, would be Q6 in GGML but no MLX kernel)
|
||||
// - Norms, embeddings, biases, routing gates: no quantization
|
||||
func GetTensorQuantization(name string, shape []int32, quantize string) string {
|
||||
stackedExpert := isStackedExpertWeight(name)
|
||||
|
||||
// Use basic name-based check first
|
||||
if !ShouldQuantize(name, "") {
|
||||
if !stackedExpert && !ShouldQuantize(name, "") {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Only quantize 2D tensors (linear layers) - skip 1D (biases, norms) and higher-D (convolutions if any)
|
||||
if len(shape) != 2 {
|
||||
// Quantize standard linear weights (2D). Also allow stacked expert weights (3D),
|
||||
// e.g. qwen switch_mlp / experts combined tensors.
|
||||
if len(shape) != 2 && !(len(shape) == 3 && stackedExpert) {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Skip small tensors (less than 1024 elements) - not worth quantizing
|
||||
if len(shape) >= 2 && int64(shape[0])*int64(shape[1]) < 1024 {
|
||||
var elems int64 = 1
|
||||
for _, d := range shape {
|
||||
elems *= int64(d)
|
||||
}
|
||||
if elems < 1024 {
|
||||
return ""
|
||||
}
|
||||
|
||||
@@ -315,12 +334,12 @@ func GetTensorQuantization(name string, shape []int32, quantize string) string {
|
||||
quantNorm := normalizeQuantType(quantize)
|
||||
|
||||
// MLX quantization requires last dimension to be divisible by group size
|
||||
// nvfp4: 16, int4/mxfp8: 32, int8: 64
|
||||
// nvfp4: 16, mxfp8: 32, int4/int8: 64
|
||||
groupSize := int32(32)
|
||||
switch quantNorm {
|
||||
case "nvfp4":
|
||||
groupSize = 16
|
||||
case "int8":
|
||||
case "int4", "int8":
|
||||
groupSize = 64
|
||||
}
|
||||
if shape[len(shape)-1]%groupSize != 0 {
|
||||
|
||||
@@ -557,6 +557,10 @@ func TestShouldQuantizeTensor(t *testing.T) {
|
||||
// 3D+ tensors should not be quantized
|
||||
{"3D tensor", "conv.weight", []int32{64, 64, 3}, "fp8", false},
|
||||
{"4D tensor", "conv2d.weight", []int32{64, 64, 3, 3}, "fp8", false},
|
||||
{"stacked expert switch_mlp gate_up 3D int8", "model.layers.1.mlp.switch_mlp.gate_up_proj.weight", []int32{64, 22016, 4096}, "int8", true},
|
||||
{"stacked expert experts down_proj 3D int8", "model.layers.1.mlp.experts.down_proj.weight", []int32{64, 4096, 14336}, "int8", true},
|
||||
{"stacked expert combined gate_up 3D int8", "model.language_model.layers.0.mlp.experts.gate_up_proj", []int32{256, 1024, 2048}, "int8", true},
|
||||
{"stacked expert combined down_proj 3D int8", "model.language_model.layers.0.mlp.experts.down_proj", []int32{256, 2048, 512}, "int8", true},
|
||||
|
||||
// Embeddings should not be quantized regardless of shape
|
||||
{"embedding 2D", "embed_tokens.weight", []int32{32000, 4096}, "fp8", false},
|
||||
@@ -619,6 +623,44 @@ func TestExpertGroupPrefix(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetTensorQuantization_StackedExpert3D(t *testing.T) {
|
||||
gateUp := GetTensorQuantization(
|
||||
"model.layers.1.mlp.switch_mlp.gate_up_proj.weight",
|
||||
[]int32{64, 22016, 4096},
|
||||
"int4",
|
||||
)
|
||||
if gateUp != "int4" {
|
||||
t.Fatalf("gate_up_proj quantization = %q, want %q", gateUp, "int4")
|
||||
}
|
||||
|
||||
down := GetTensorQuantization(
|
||||
"model.layers.1.mlp.experts.down_proj.weight",
|
||||
[]int32{64, 4096, 14336},
|
||||
"int4",
|
||||
)
|
||||
if down != "int8" {
|
||||
t.Fatalf("down_proj quantization = %q, want %q", down, "int8")
|
||||
}
|
||||
|
||||
combinedGateUp := GetTensorQuantization(
|
||||
"model.language_model.layers.0.mlp.experts.gate_up_proj",
|
||||
[]int32{256, 1024, 2048},
|
||||
"int8",
|
||||
)
|
||||
if combinedGateUp != "int8" {
|
||||
t.Fatalf("combined gate_up_proj quantization = %q, want %q", combinedGateUp, "int8")
|
||||
}
|
||||
|
||||
combinedDown := GetTensorQuantization(
|
||||
"model.language_model.layers.0.mlp.experts.down_proj",
|
||||
[]int32{256, 2048, 512},
|
||||
"int4",
|
||||
)
|
||||
if combinedDown != "int8" {
|
||||
t.Fatalf("combined down_proj quantization = %q, want %q", combinedDown, "int8")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateSafetensorsModel_WithQuantize(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
|
||||
|
||||
2
x/imagegen/cache/cache.go
vendored
2
x/imagegen/cache/cache.go
vendored
@@ -1,5 +1,3 @@
|
||||
//go:build mlx
|
||||
|
||||
package cache
|
||||
|
||||
import "github.com/ollama/ollama/x/imagegen/mlx"
|
||||
|
||||
2
x/imagegen/cache/step.go
vendored
2
x/imagegen/cache/step.go
vendored
@@ -1,5 +1,3 @@
|
||||
//go:build mlx
|
||||
|
||||
package cache
|
||||
|
||||
import "github.com/ollama/ollama/x/imagegen/mlx"
|
||||
|
||||
2
x/imagegen/cache/teacache.go
vendored
2
x/imagegen/cache/teacache.go
vendored
@@ -1,5 +1,3 @@
|
||||
//go:build mlx
|
||||
|
||||
// Package cache provides caching mechanisms for diffusion model inference.
|
||||
package cache
|
||||
|
||||
|
||||
@@ -5,22 +5,12 @@ Experimental MLX backend for running models on Apple Silicon and CUDA.
|
||||
## Build
|
||||
|
||||
```bash
|
||||
go build -tags mlx -o engine ./x/imagegen/cmd/engine
|
||||
go build -o engine ./x/imagegen/cmd/engine
|
||||
```
|
||||
|
||||
## Text Generation
|
||||
|
||||
```bash
|
||||
./engine -model /path/to/model -prompt "Hello" -max-tokens 100
|
||||
```
|
||||
|
||||
Options:
|
||||
|
||||
- `-temperature` - sampling temperature (default 0.7)
|
||||
- `-top-p` - nucleus sampling (default 0.9)
|
||||
- `-top-k` - top-k sampling (default 40)
|
||||
|
||||
Supports: Llama, Gemma3, GPT-OSS
|
||||
Text generation models are no longer supported by this engine.
|
||||
|
||||
## Image Generation
|
||||
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
//go:build mlx
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
//go:build mlx
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
//go:build mlx
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
@@ -18,9 +16,6 @@ import (
|
||||
"github.com/ollama/ollama/x/imagegen"
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/models/flux2"
|
||||
"github.com/ollama/ollama/x/imagegen/models/gemma3"
|
||||
"github.com/ollama/ollama/x/imagegen/models/gpt_oss"
|
||||
"github.com/ollama/ollama/x/imagegen/models/llama"
|
||||
"github.com/ollama/ollama/x/imagegen/models/zimage"
|
||||
"github.com/ollama/ollama/x/imagegen/safetensors"
|
||||
)
|
||||
@@ -170,11 +165,11 @@ func main() {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
// Load image if provided and model supports it
|
||||
// Load image if provided and model supports it.
|
||||
var image *mlx.Array
|
||||
if *imagePath != "" {
|
||||
if mm, ok := m.(interface{ ImageSize() int32 }); ok {
|
||||
image, err = gemma3.ProcessImage(*imagePath, mm.ImageSize())
|
||||
image, err = imagegen.ProcessImage(*imagePath, mm.ImageSize())
|
||||
if err != nil {
|
||||
log.Fatal("load image:", err)
|
||||
}
|
||||
@@ -236,14 +231,8 @@ func load(modelPath string) (Model, error) {
|
||||
}
|
||||
|
||||
switch kind {
|
||||
case "gpt_oss":
|
||||
return gpt_oss.Load(modelPath)
|
||||
case "gemma3":
|
||||
return gemma3.Load(modelPath)
|
||||
case "gemma3_text":
|
||||
return gemma3.LoadText(modelPath)
|
||||
default:
|
||||
return llama.Load(modelPath)
|
||||
return nil, fmt.Errorf("model type %q is not supported by x/imagegen/cmd/engine", kind)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
//go:build mlx
|
||||
|
||||
package main
|
||||
|
||||
import "github.com/ollama/ollama/x/imagegen/mlx"
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
//go:build mlx
|
||||
|
||||
package imagegen
|
||||
|
||||
import (
|
||||
|
||||
@@ -1,6 +1,4 @@
|
||||
//go:build mlx
|
||||
|
||||
package gemma3
|
||||
package imagegen
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
@@ -13,8 +11,8 @@ import (
|
||||
"golang.org/x/image/draw"
|
||||
)
|
||||
|
||||
// ProcessImage loads and preprocesses an image for the vision tower
|
||||
// Returns [1, H, W, C] tensor in NHWC format normalized for SigLIP
|
||||
// ProcessImage loads and preprocesses an image for multimodal vision towers.
|
||||
// Returns [1, H, W, C] tensor in NHWC format normalized for SigLIP.
|
||||
func ProcessImage(path string, imageSize int32) (*mlx.Array, error) {
|
||||
f, err := os.Open(path)
|
||||
if err != nil {
|
||||
@@ -30,20 +28,20 @@ func ProcessImage(path string, imageSize int32) (*mlx.Array, error) {
|
||||
return ProcessImageData(img, imageSize)
|
||||
}
|
||||
|
||||
// ProcessImageData preprocesses an image.Image for the vision tower
|
||||
// ProcessImageData preprocesses an image.Image for multimodal vision towers.
|
||||
func ProcessImageData(img image.Image, imageSize int32) (*mlx.Array, error) {
|
||||
// Resize to target size using bilinear interpolation
|
||||
// Resize to target size using bilinear interpolation.
|
||||
resized := image.NewRGBA(image.Rect(0, 0, int(imageSize), int(imageSize)))
|
||||
draw.BiLinear.Scale(resized, resized.Bounds(), img, img.Bounds(), draw.Over, nil)
|
||||
|
||||
// Convert to float32 array [H, W, C] and normalize
|
||||
// SigLIP normalization: (pixel / 255.0 - 0.5) / 0.5 = pixel / 127.5 - 1.0
|
||||
// Convert to float32 array [H, W, C] and normalize.
|
||||
// SigLIP normalization: (pixel / 255.0 - 0.5) / 0.5 = pixel / 127.5 - 1.0.
|
||||
data := make([]float32, imageSize*imageSize*3)
|
||||
idx := 0
|
||||
for y := int32(0); y < imageSize; y++ {
|
||||
for x := int32(0); x < imageSize; x++ {
|
||||
r, g, b, _ := resized.At(int(x), int(y)).RGBA()
|
||||
// RGBA returns 16-bit values, convert to 8-bit
|
||||
// RGBA returns 16-bit values, convert to 8-bit.
|
||||
data[idx] = float32(r>>8)/127.5 - 1.0
|
||||
data[idx+1] = float32(g>>8)/127.5 - 1.0
|
||||
data[idx+2] = float32(b>>8)/127.5 - 1.0
|
||||
@@ -51,8 +49,8 @@ func ProcessImageData(img image.Image, imageSize int32) (*mlx.Array, error) {
|
||||
}
|
||||
}
|
||||
|
||||
// Create MLX array [1, H, W, C] for NHWC layout
|
||||
// Create MLX array [1, H, W, C] for NHWC layout.
|
||||
arr := mlx.NewArrayFloat32(data, []int32{1, imageSize, imageSize, 3})
|
||||
mlx.Eval(arr) // Materialize to prevent use-after-free
|
||||
mlx.Eval(arr) // Materialize to prevent use-after-free.
|
||||
return arr, nil
|
||||
}
|
||||
@@ -1,5 +1,3 @@
|
||||
//go:build mlx
|
||||
|
||||
package imagegen
|
||||
|
||||
import (
|
||||
|
||||
@@ -1,420 +0,0 @@
|
||||
//go:build mlx
|
||||
|
||||
package imagegen
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/cache"
|
||||
"github.com/ollama/ollama/x/imagegen/manifest"
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/models/glm4_moe_lite"
|
||||
"github.com/ollama/ollama/x/imagegen/tokenizer"
|
||||
)
|
||||
|
||||
// TextModel is the interface for LLM text generation models.
|
||||
type TextModel interface {
|
||||
Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array
|
||||
NewCache(maxSeqLen int32) []cache.Cache
|
||||
Tokenizer() *tokenizer.Tokenizer
|
||||
VocabSize() int32
|
||||
MaxContextLength() int32
|
||||
NumLayers() int
|
||||
}
|
||||
|
||||
// llmState holds the state for LLM generation
|
||||
type llmState struct {
|
||||
model TextModel
|
||||
}
|
||||
|
||||
var llmMu sync.Mutex
|
||||
|
||||
// Dedicated stream for generation (like mlx-lm's generation_stream)
|
||||
var generationStream *mlx.Stream
|
||||
|
||||
// withStream runs fn with the generation stream as default
|
||||
func withStream(fn func()) {
|
||||
// Lazy initialization of generationStream
|
||||
if generationStream == nil {
|
||||
generationStream = mlx.NewStream()
|
||||
}
|
||||
orig := mlx.GetDefaultStream()
|
||||
mlx.SetDefaultStream(generationStream)
|
||||
fn()
|
||||
mlx.SetDefaultStream(orig)
|
||||
}
|
||||
|
||||
// Decoder wraps model + cache for autoregressive generation.
|
||||
// This matches the pattern from cmd/engine/generate.go
|
||||
type Decoder struct {
|
||||
model TextModel
|
||||
caches []cache.Cache
|
||||
vocabSize int32
|
||||
temp float32
|
||||
token *mlx.Array // Current token (kept across iterations)
|
||||
oldCacheState []*mlx.Array // Preallocated slice for old cache state
|
||||
}
|
||||
|
||||
func NewDecoder(m TextModel, temp float32) *Decoder {
|
||||
caches := m.NewCache(0)
|
||||
return &Decoder{
|
||||
model: m,
|
||||
caches: caches,
|
||||
vocabSize: m.VocabSize(),
|
||||
temp: temp,
|
||||
oldCacheState: make([]*mlx.Array, 0, len(caches)*2),
|
||||
}
|
||||
}
|
||||
|
||||
func (d *Decoder) prefill(inputIDs []int32) int {
|
||||
processed := 0
|
||||
|
||||
// Track old cache state to free after each chunk
|
||||
var oldCacheState []*mlx.Array
|
||||
|
||||
// Process all-but-1 tokens in chunks, eval cache state for memory management
|
||||
for len(inputIDs) > 1 {
|
||||
chunkSize := min(2048, len(inputIDs)-1)
|
||||
if chunkSize <= 0 {
|
||||
break
|
||||
}
|
||||
chunk := inputIDs[:chunkSize]
|
||||
|
||||
// Save old cache state before forward
|
||||
oldCacheState = oldCacheState[:0]
|
||||
for _, c := range d.caches {
|
||||
oldCacheState = append(oldCacheState, c.State()...)
|
||||
}
|
||||
|
||||
var cacheState []*mlx.Array
|
||||
withStream(func() {
|
||||
x := mlx.NewArrayInt32(chunk, []int32{1, int32(len(chunk))})
|
||||
d.model.Forward(x, d.caches)
|
||||
for _, c := range d.caches {
|
||||
cacheState = append(cacheState, c.State()...)
|
||||
}
|
||||
})
|
||||
mlx.Eval(cacheState...)
|
||||
|
||||
// Free old cache state
|
||||
for _, arr := range oldCacheState {
|
||||
if arr != nil {
|
||||
arr.Free()
|
||||
}
|
||||
}
|
||||
|
||||
inputIDs = inputIDs[chunkSize:]
|
||||
processed += chunkSize
|
||||
}
|
||||
|
||||
// Save old cache state before final step
|
||||
oldCacheState = oldCacheState[:0]
|
||||
for _, c := range d.caches {
|
||||
oldCacheState = append(oldCacheState, c.State()...)
|
||||
}
|
||||
|
||||
// Final token + sampling
|
||||
withStream(func() {
|
||||
x := mlx.NewArrayInt32(inputIDs, []int32{1, int32(len(inputIDs))})
|
||||
mlx.Eval(x) // Materialize before any other evals
|
||||
logits := d.model.Forward(x, d.caches)
|
||||
d.token = sample(logits, d.temp, d.vocabSize)
|
||||
})
|
||||
// Keep cache state (token auto-kept by AsyncEval)
|
||||
for _, c := range d.caches {
|
||||
mlx.Keep(c.State()...)
|
||||
}
|
||||
mlx.AsyncEval(d.token)
|
||||
|
||||
// Free old cache state from before final step
|
||||
for _, arr := range oldCacheState {
|
||||
if arr != nil {
|
||||
arr.Free()
|
||||
}
|
||||
}
|
||||
|
||||
mlx.ClearCache()
|
||||
|
||||
return processed + len(inputIDs)
|
||||
}
|
||||
|
||||
func (d *Decoder) step() int32 {
|
||||
prevToken := d.token
|
||||
|
||||
// Save old cache state (reuse preallocated slice)
|
||||
d.oldCacheState = d.oldCacheState[:0]
|
||||
for _, c := range d.caches {
|
||||
d.oldCacheState = append(d.oldCacheState, c.State()...)
|
||||
}
|
||||
|
||||
withStream(func() {
|
||||
logits := d.model.Forward(mlx.Reshape(prevToken, 1, 1), d.caches)
|
||||
d.token = sample(logits, d.temp, d.vocabSize)
|
||||
})
|
||||
// Keep token and new cache state so they survive cleanup
|
||||
mlx.Keep(d.token)
|
||||
for _, c := range d.caches {
|
||||
mlx.Keep(c.State()...)
|
||||
}
|
||||
mlx.AsyncEval(d.token)
|
||||
|
||||
// Sync on previous token (GPU already working on next step)
|
||||
val := prevToken.ItemInt32()
|
||||
|
||||
// Free old token and old cache state
|
||||
prevToken.Free()
|
||||
for _, arr := range d.oldCacheState {
|
||||
arr.Free()
|
||||
}
|
||||
return val
|
||||
}
|
||||
|
||||
// sample samples from logits using temperature scaling
|
||||
func sample(logits *mlx.Array, temp float32, vocabSize int32) *mlx.Array {
|
||||
// Get last position logits: [1, L, vocab] -> [vocab]
|
||||
shape := logits.Shape()
|
||||
seqLen := shape[1]
|
||||
lastLogits := mlx.Slice(logits, []int32{0, seqLen - 1, 0}, []int32{1, seqLen, vocabSize})
|
||||
lastLogits = mlx.Reshape(lastLogits, vocabSize)
|
||||
|
||||
if temp <= 0 || temp < 0.01 {
|
||||
// Greedy decoding
|
||||
return mlx.Argmax(lastLogits, -1, false)
|
||||
}
|
||||
|
||||
// Apply temperature scaling
|
||||
scaled := mlx.DivScalar(lastLogits, temp)
|
||||
return mlx.RandomCategorical(scaled, -1, 1)
|
||||
}
|
||||
|
||||
// loadLLMModel loads a safetensors LLM model and its tokenizer from manifest storage.
|
||||
func (s *server) loadLLMModel() error {
|
||||
// Load the manifest to get model information
|
||||
modelManifest, err := manifest.LoadManifest(s.modelName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to load manifest: %w", err)
|
||||
}
|
||||
|
||||
// Detect model architecture from config.json
|
||||
configData, err := modelManifest.ReadConfig("config.json")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read config.json: %w", err)
|
||||
}
|
||||
|
||||
var modelConfig struct {
|
||||
Architectures []string `json:"architectures"`
|
||||
ModelType string `json:"model_type"`
|
||||
}
|
||||
if err := json.Unmarshal(configData, &modelConfig); err != nil {
|
||||
return fmt.Errorf("failed to parse config.json: %w", err)
|
||||
}
|
||||
|
||||
arch := ""
|
||||
if len(modelConfig.Architectures) > 0 {
|
||||
arch = modelConfig.Architectures[0]
|
||||
}
|
||||
if arch == "" {
|
||||
arch = modelConfig.ModelType
|
||||
}
|
||||
|
||||
slog.Info("detected LLM architecture", "architecture", arch, "model_type", modelConfig.ModelType)
|
||||
|
||||
// Load the appropriate model based on architecture
|
||||
var model TextModel
|
||||
archLower := strings.ToLower(arch)
|
||||
|
||||
switch {
|
||||
case strings.Contains(archLower, "glm4moelite"):
|
||||
m, err := glm4_moe_lite.LoadFromManifest(modelManifest)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to load glm4-moe-lite model: %w", err)
|
||||
}
|
||||
model = m
|
||||
slog.Info("loaded glm4-moe-lite model", "vocab_size", m.VocabSize(), "layers", m.NumLayers())
|
||||
|
||||
default:
|
||||
return fmt.Errorf("LLM architecture %q is not yet supported. "+
|
||||
"Supported architectures: glm4-moe-lite. "+
|
||||
"Please convert your model to GGUF format or use a supported architecture", arch)
|
||||
}
|
||||
|
||||
s.llmModel = &llmState{
|
||||
model: model,
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleLLMCompletion handles LLM text generation requests.
|
||||
func (s *server) handleLLMCompletion(w http.ResponseWriter, r *http.Request, req Request) {
|
||||
if s.llmModel == nil {
|
||||
http.Error(w, "LLM model not loaded", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Serialize generation requests
|
||||
llmMu.Lock()
|
||||
defer llmMu.Unlock()
|
||||
|
||||
if err := s.llmGenerate(w, r, req); err != nil {
|
||||
slog.Error("LLM generation failed", "error", err)
|
||||
// Don't send error if we've already started streaming
|
||||
}
|
||||
}
|
||||
|
||||
// llmGenerate runs the generation loop using the Decoder pattern from cmd/engine
|
||||
func (s *server) llmGenerate(w http.ResponseWriter, r *http.Request, req Request) error {
|
||||
state := s.llmModel
|
||||
|
||||
// Set up streaming response
|
||||
w.Header().Set("Content-Type", "application/x-ndjson")
|
||||
w.Header().Set("Transfer-Encoding", "chunked")
|
||||
flusher, ok := w.(http.Flusher)
|
||||
if !ok {
|
||||
return errors.New("streaming not supported")
|
||||
}
|
||||
|
||||
tok := state.model.Tokenizer()
|
||||
|
||||
// The prompt is already formatted by the server using the model's renderer
|
||||
// (see server/prompt.go renderPrompt), so we don't apply FormatPrompt here.
|
||||
prompt := req.Prompt
|
||||
|
||||
// Tokenize the prompt
|
||||
inputIDs := tok.Encode(prompt, true)
|
||||
slog.Debug("tokenized prompt", "num_tokens", len(inputIDs))
|
||||
|
||||
// Generation parameters
|
||||
maxTokens := int(state.model.MaxContextLength())
|
||||
if maxTokens <= 0 {
|
||||
maxTokens = 4096
|
||||
}
|
||||
if req.Options != nil && req.Options.NumPredict > 0 {
|
||||
maxTokens = req.Options.NumPredict
|
||||
}
|
||||
|
||||
temperature := float32(0.7)
|
||||
if req.Options != nil && req.Options.Temperature > 0 {
|
||||
temperature = float32(req.Options.Temperature)
|
||||
}
|
||||
|
||||
// Enable MLX compilation for better performance
|
||||
mlx.EnableCompile()
|
||||
|
||||
// Create decoder with fresh caches
|
||||
dec := NewDecoder(state.model, temperature)
|
||||
|
||||
prefillStart := time.Now()
|
||||
prefillTokens := dec.prefill(inputIDs)
|
||||
// Prefill measurement includes time to first token
|
||||
firstToken := dec.step()
|
||||
prefillDuration := time.Since(prefillStart)
|
||||
promptEvalDuration := prefillDuration
|
||||
|
||||
enc := json.NewEncoder(w)
|
||||
ctx := r.Context()
|
||||
generated := 0
|
||||
stopReason := "max_tokens"
|
||||
|
||||
// Handle first token
|
||||
generated++
|
||||
if tok.IsEOS(firstToken) {
|
||||
resp := Response{
|
||||
Done: true,
|
||||
StopReason: fmt.Sprintf("first_token_eos:%d", firstToken),
|
||||
PromptEvalCount: prefillTokens,
|
||||
PromptEvalDuration: int(promptEvalDuration.Nanoseconds()),
|
||||
}
|
||||
enc.Encode(resp)
|
||||
flusher.Flush()
|
||||
return nil
|
||||
}
|
||||
|
||||
text := tok.Decode([]int32{firstToken})
|
||||
resp := Response{Content: text}
|
||||
enc.Encode(resp)
|
||||
flusher.Flush()
|
||||
|
||||
genStart := time.Now()
|
||||
|
||||
// Generation loop
|
||||
for n := 1; n < maxTokens; n++ {
|
||||
// Check for cancellation
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
stopReason = fmt.Sprintf("context_cancelled:%d", generated)
|
||||
break
|
||||
default:
|
||||
}
|
||||
if stopReason != "max_tokens" {
|
||||
break
|
||||
}
|
||||
|
||||
token := dec.step()
|
||||
generated++
|
||||
|
||||
if tok.IsEOS(token) {
|
||||
stopReason = fmt.Sprintf("eos_token:%d", token)
|
||||
break
|
||||
}
|
||||
|
||||
text := tok.Decode([]int32{token})
|
||||
|
||||
// Check for stop sequences
|
||||
if req.Options != nil && len(req.Options.Stop) > 0 {
|
||||
shouldStop := false
|
||||
var matchedStop string
|
||||
for _, stop := range req.Options.Stop {
|
||||
if strings.Contains(text, stop) {
|
||||
text = strings.Split(text, stop)[0]
|
||||
shouldStop = true
|
||||
matchedStop = stop
|
||||
break
|
||||
}
|
||||
}
|
||||
if shouldStop {
|
||||
if text != "" {
|
||||
resp := Response{Content: text}
|
||||
enc.Encode(resp)
|
||||
flusher.Flush()
|
||||
}
|
||||
stopReason = fmt.Sprintf("stop_sequence:%s", matchedStop)
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
resp := Response{Content: text}
|
||||
enc.Encode(resp)
|
||||
flusher.Flush()
|
||||
|
||||
// Periodically clear MLX cache
|
||||
if n%256 == 0 {
|
||||
mlx.ClearCache()
|
||||
}
|
||||
}
|
||||
|
||||
// Clean up
|
||||
mlx.ClearCache()
|
||||
|
||||
// Send final response with stats
|
||||
evalDuration := time.Since(genStart)
|
||||
resp = Response{
|
||||
Done: true,
|
||||
StopReason: fmt.Sprintf("%s:generated=%d", stopReason, generated),
|
||||
PromptEvalCount: prefillTokens,
|
||||
PromptEvalDuration: int(promptEvalDuration.Nanoseconds()),
|
||||
EvalCount: generated,
|
||||
EvalDuration: int(evalDuration.Nanoseconds()),
|
||||
}
|
||||
enc.Encode(resp)
|
||||
flusher.Flush()
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -1,5 +1,3 @@
|
||||
//go:build mlx
|
||||
|
||||
package manifest
|
||||
|
||||
import (
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user