Compare commits

..

1 Commits

Author SHA1 Message Date
Patrick Devine
857cffd22a bugfix: fix crash bug in token cache logic
This change fixes a problem in the token cache logic to avoid panics caused by empty token arrays
by ensuring at least one token remains on full cache hits in the relevant function. This happens
if there is an exact match in the cache on subsequent generations.
2026-02-26 18:35:44 -08:00
492 changed files with 13668 additions and 62902 deletions

View File

@@ -27,7 +27,7 @@ jobs:
echo vendorsha=$(make -f Makefile.sync print-base) | tee -a $GITHUB_OUTPUT echo vendorsha=$(make -f Makefile.sync print-base) | tee -a $GITHUB_OUTPUT
darwin-build: darwin-build:
runs-on: macos-26-xlarge runs-on: macos-14-xlarge
environment: release environment: release
needs: setup-environment needs: setup-environment
env: env:
@@ -117,25 +117,6 @@ jobs:
install: https://sdk.lunarg.com/sdk/download/1.4.321.1/windows/vulkansdk-windows-X64-1.4.321.1.exe install: https://sdk.lunarg.com/sdk/download/1.4.321.1/windows/vulkansdk-windows-X64-1.4.321.1.exe
flags: '' flags: ''
runner_dir: 'vulkan' runner_dir: 'vulkan'
- os: windows
arch: amd64
preset: 'MLX CUDA 13'
install: https://developer.download.nvidia.com/compute/cuda/13.0.0/local_installers/cuda_13.0.0_windows.exe
cudnn-install: https://developer.download.nvidia.com/compute/cudnn/redist/cudnn/windows-x86_64/cudnn-windows-x86_64-9.18.1.3_cuda13-archive.zip
cuda-components:
- '"cudart"'
- '"nvcc"'
- '"cublas"'
- '"cublas_dev"'
- '"cufft"'
- '"cufft_dev"'
- '"nvrtc"'
- '"nvrtc_dev"'
- '"crt"'
- '"nvvm"'
- '"nvptxcompiler"'
cuda-version: '13.0'
flags: ''
runs-on: ${{ matrix.arch == 'arm64' && format('{0}-{1}', matrix.os, matrix.arch) || matrix.os }} runs-on: ${{ matrix.arch == 'arm64' && format('{0}-{1}', matrix.os, matrix.arch) || matrix.os }}
environment: release environment: release
env: env:
@@ -144,10 +125,8 @@ jobs:
- name: Install system dependencies - name: Install system dependencies
run: | run: |
choco install -y --no-progress ccache ninja choco install -y --no-progress ccache ninja
if (Get-Command ccache -ErrorAction SilentlyContinue) { ccache -o cache_dir=${{ github.workspace }}\.ccache
ccache -o cache_dir=${{ github.workspace }}\.ccache - if: startsWith(matrix.preset, 'CUDA ') || startsWith(matrix.preset, 'ROCm ') || startsWith(matrix.preset, 'Vulkan')
}
- if: startsWith(matrix.preset, 'CUDA ') || startsWith(matrix.preset, 'ROCm ') || startsWith(matrix.preset, 'Vulkan') || startsWith(matrix.preset, 'MLX ')
id: cache-install id: cache-install
uses: actions/cache/restore@v4 uses: actions/cache/restore@v4
with: with:
@@ -155,9 +134,8 @@ jobs:
C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA
C:\Program Files\AMD\ROCm C:\Program Files\AMD\ROCm
C:\VulkanSDK C:\VulkanSDK
C:\Program Files\NVIDIA\CUDNN key: ${{ matrix.install }}
key: ${{ matrix.install }}-${{ matrix.cudnn-install }} - if: startsWith(matrix.preset, 'CUDA ')
- if: startsWith(matrix.preset, 'CUDA ') || startsWith(matrix.preset, 'MLX ')
name: Install CUDA ${{ matrix.cuda-version }} name: Install CUDA ${{ matrix.cuda-version }}
run: | run: |
$ErrorActionPreference = "Stop" $ErrorActionPreference = "Stop"
@@ -201,23 +179,6 @@ jobs:
run: | run: |
echo "CC=clang.exe" | Out-File -FilePath $env:GITHUB_ENV -Append echo "CC=clang.exe" | Out-File -FilePath $env:GITHUB_ENV -Append
echo "CXX=clang++.exe" | Out-File -FilePath $env:GITHUB_ENV -Append echo "CXX=clang++.exe" | Out-File -FilePath $env:GITHUB_ENV -Append
- if: startsWith(matrix.preset, 'MLX ')
name: Install cuDNN for MLX
run: |
$ErrorActionPreference = "Stop"
$cudnnRoot = "C:\Program Files\NVIDIA\CUDNN"
if ("${{ steps.cache-install.outputs.cache-hit }}" -ne 'true') {
Invoke-WebRequest -Uri "${{ matrix.cudnn-install }}" -OutFile "cudnn.zip"
Expand-Archive -Path cudnn.zip -DestinationPath cudnn-extracted
$cudnnDir = (Get-ChildItem -Path cudnn-extracted -Directory)[0].FullName
New-Item -ItemType Directory -Force -Path $cudnnRoot
Copy-Item -Path "$cudnnDir\*" -Destination "$cudnnRoot\" -Recurse
}
echo "CUDNN_ROOT_DIR=$cudnnRoot" | Out-File -FilePath $env:GITHUB_ENV -Append
echo "CUDNN_INCLUDE_PATH=$cudnnRoot\include" | Out-File -FilePath $env:GITHUB_ENV -Append
echo "CUDNN_LIBRARY_PATH=$cudnnRoot\lib\x64" | Out-File -FilePath $env:GITHUB_ENV -Append
echo "$cudnnRoot\bin\x64" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
- if: ${{ !cancelled() && steps.cache-install.outputs.cache-hit != 'true' }} - if: ${{ !cancelled() && steps.cache-install.outputs.cache-hit != 'true' }}
uses: actions/cache/save@v4 uses: actions/cache/save@v4
with: with:
@@ -225,8 +186,7 @@ jobs:
C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA
C:\Program Files\AMD\ROCm C:\Program Files\AMD\ROCm
C:\VulkanSDK C:\VulkanSDK
C:\Program Files\NVIDIA\CUDNN key: ${{ matrix.install }}
key: ${{ matrix.install }}-${{ matrix.cudnn-install }}
- uses: actions/checkout@v4 - uses: actions/checkout@v4
- uses: actions/cache@v4 - uses: actions/cache@v4
with: with:
@@ -238,7 +198,7 @@ jobs:
Enter-VsDevShell -VsInstallPath 'C:\Program Files\Microsoft Visual Studio\2022\Enterprise' -SkipAutomaticLocation -DevCmdArguments '-arch=x64 -no_logo' Enter-VsDevShell -VsInstallPath 'C:\Program Files\Microsoft Visual Studio\2022\Enterprise' -SkipAutomaticLocation -DevCmdArguments '-arch=x64 -no_logo'
cmake --preset "${{ matrix.preset }}" ${{ matrix.flags }} --install-prefix "$((pwd).Path)\dist\${{ matrix.os }}-${{ matrix.arch }}" cmake --preset "${{ matrix.preset }}" ${{ matrix.flags }} --install-prefix "$((pwd).Path)\dist\${{ matrix.os }}-${{ matrix.arch }}"
cmake --build --parallel ([Environment]::ProcessorCount) --preset "${{ matrix.preset }}" cmake --build --parallel ([Environment]::ProcessorCount) --preset "${{ matrix.preset }}"
cmake --install build --component "${{ startsWith(matrix.preset, 'MLX ') && 'MLX' || startsWith(matrix.preset, 'CUDA ') && 'CUDA' || startsWith(matrix.preset, 'ROCm ') && 'HIP' || startsWith(matrix.preset, 'Vulkan') && 'Vulkan' || 'CPU' }}" --strip cmake --install build --component "${{ startsWith(matrix.preset, 'CUDA ') && 'CUDA' || startsWith(matrix.preset, 'ROCm ') && 'HIP' || startsWith(matrix.preset, 'Vulkan') && 'Vulkan' || 'CPU' }}" --strip
Remove-Item -Path dist\lib\ollama\rocm\rocblas\library\*gfx906* -ErrorAction SilentlyContinue Remove-Item -Path dist\lib\ollama\rocm\rocblas\library\*gfx906* -ErrorAction SilentlyContinue
env: env:
CMAKE_GENERATOR: Ninja CMAKE_GENERATOR: Ninja
@@ -424,7 +384,6 @@ jobs:
lib/ollama/cuda_v*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;; lib/ollama/cuda_v*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
lib/ollama/vulkan*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;; lib/ollama/vulkan*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
lib/ollama/mlx*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;; lib/ollama/mlx*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
lib/ollama/include*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
lib/ollama/cuda_jetpack5) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-jetpack5.tar.in ;; lib/ollama/cuda_jetpack5) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-jetpack5.tar.in ;;
lib/ollama/cuda_jetpack6) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-jetpack6.tar.in ;; lib/ollama/cuda_jetpack6) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-jetpack6.tar.in ;;
lib/ollama/rocm) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-rocm.tar.in ;; lib/ollama/rocm) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-rocm.tar.in ;;
@@ -584,19 +543,11 @@ jobs:
for payload in dist/*.txt dist/*.zip dist/*.tgz dist/*.tar.zst dist/*.exe dist/*.dmg dist/*.ps1 dist/*.sh ; do for payload in dist/*.txt dist/*.zip dist/*.tgz dist/*.tar.zst dist/*.exe dist/*.dmg dist/*.ps1 dist/*.sh ; do
echo "Uploading $payload" echo "Uploading $payload"
gh release upload ${GITHUB_REF_NAME} $payload --clobber & gh release upload ${GITHUB_REF_NAME} $payload --clobber &
pids+=($!) pids[$!]=$!
sleep 1 sleep 1
done done
echo "Waiting for uploads to complete" echo "Waiting for uploads to complete"
failed=0 for pid in "${pids[*]}"; do
for pid in "${pids[@]}"; do wait $pid
if ! wait $pid; then
echo "::error::Upload failed (pid $pid)"
failed=1
fi
done done
if [ $failed -ne 0 ]; then
echo "One or more uploads failed"
exit 1
fi
echo "done" echo "done"

View File

@@ -37,7 +37,7 @@ jobs:
| xargs python3 -c "import sys; from pathlib import Path; print(any(Path(x).match(glob) for x in sys.argv[1:] for glob in '$*'.split(' ')))" | xargs python3 -c "import sys; from pathlib import Path; print(any(Path(x).match(glob) for x in sys.argv[1:] for glob in '$*'.split(' ')))"
} }
echo changed=$(changed 'llama/llama.cpp/**/*' 'ml/backend/ggml/ggml/**/*' '.github/**/*') | tee -a $GITHUB_OUTPUT echo changed=$(changed 'llama/llama.cpp/**/*' 'ml/backend/ggml/ggml/**/*') | tee -a $GITHUB_OUTPUT
echo vendorsha=$(make -f Makefile.sync print-base) | tee -a $GITHUB_OUTPUT echo vendorsha=$(make -f Makefile.sync print-base) | tee -a $GITHUB_OUTPUT
linux: linux:
@@ -51,7 +51,7 @@ jobs:
container: nvidia/cuda:13.0.0-devel-ubuntu22.04 container: nvidia/cuda:13.0.0-devel-ubuntu22.04
flags: '-DCMAKE_CUDA_ARCHITECTURES=87' flags: '-DCMAKE_CUDA_ARCHITECTURES=87'
- preset: ROCm - preset: ROCm
container: rocm/dev-ubuntu-22.04:7.2.1 container: rocm/dev-ubuntu-22.04:6.1.2
extra-packages: rocm-libs extra-packages: rocm-libs
flags: '-DAMDGPU_TARGETS=gfx1010 -DCMAKE_PREFIX_PATH=/opt/rocm' flags: '-DAMDGPU_TARGETS=gfx1010 -DCMAKE_PREFIX_PATH=/opt/rocm'
- preset: Vulkan - preset: Vulkan
@@ -60,11 +60,6 @@ jobs:
mesa-vulkan-drivers vulkan-tools mesa-vulkan-drivers vulkan-tools
libvulkan1 libvulkan-dev libvulkan1 libvulkan-dev
vulkan-sdk cmake ccache g++ make vulkan-sdk cmake ccache g++ make
- preset: 'MLX CUDA 13'
container: nvidia/cuda:13.0.0-devel-ubuntu22.04
extra-packages: libcudnn9-dev-cuda-13 libopenblas-dev liblapack-dev liblapacke-dev git curl
flags: '-DCMAKE_CUDA_ARCHITECTURES=87 -DBLAS_INCLUDE_DIRS=/usr/include/x86_64-linux-gnu -DLAPACK_INCLUDE_DIRS=/usr/include/x86_64-linux-gnu'
install-go: true
runs-on: linux runs-on: linux
container: ${{ matrix.container }} container: ${{ matrix.container }}
steps: steps:
@@ -81,29 +76,19 @@ jobs:
$sudo apt-get update $sudo apt-get update
fi fi
$sudo apt-get install -y cmake ccache ${{ matrix.extra-packages }} $sudo apt-get install -y cmake ccache ${{ matrix.extra-packages }}
# MLX requires CMake 3.25+, install from official releases
if [ "${{ matrix.preset }}" = "MLX CUDA 13" ]; then
curl -fsSL https://github.com/Kitware/CMake/releases/download/v3.31.2/cmake-3.31.2-linux-$(uname -m).tar.gz | $sudo tar xz -C /usr/local --strip-components 1
fi
# Export VULKAN_SDK if provided by LunarG package (defensive) # Export VULKAN_SDK if provided by LunarG package (defensive)
if [ -d "/usr/lib/x86_64-linux-gnu/vulkan" ] && [ "${{ matrix.preset }}" = "Vulkan" ]; then if [ -d "/usr/lib/x86_64-linux-gnu/vulkan" ] && [ "${{ matrix.preset }}" = "Vulkan" ]; then
echo "VULKAN_SDK=/usr" >> $GITHUB_ENV echo "VULKAN_SDK=/usr" >> $GITHUB_ENV
fi fi
env: env:
DEBIAN_FRONTEND: noninteractive DEBIAN_FRONTEND: noninteractive
- if: matrix.install-go
name: Install Go
run: |
GO_VERSION=$(awk '/^go / { print $2 }' go.mod)
curl -fsSL "https://golang.org/dl/go${GO_VERSION}.linux-$(dpkg --print-architecture).tar.gz" | tar xz -C /usr/local
echo "/usr/local/go/bin" >> $GITHUB_PATH
- uses: actions/cache@v4 - uses: actions/cache@v4
with: with:
path: /github/home/.cache/ccache path: /github/home/.cache/ccache
key: ccache-${{ runner.os }}-${{ runner.arch }}-${{ matrix.preset }}-${{ needs.changes.outputs.vendorsha }} key: ccache-${{ runner.os }}-${{ runner.arch }}-${{ matrix.preset }}-${{ needs.changes.outputs.vendorsha }}
- run: | - run: |
cmake --preset "${{ matrix.preset }}" ${{ matrix.flags }} cmake --preset ${{ matrix.preset }} ${{ matrix.flags }}
cmake --build --preset "${{ matrix.preset }}" --parallel cmake --build --preset ${{ matrix.preset }} --parallel
windows: windows:
needs: [changes] needs: [changes]
@@ -129,31 +114,12 @@ jobs:
flags: '-DAMDGPU_TARGETS=gfx1010 -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -DCMAKE_C_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma" -DCMAKE_CXX_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma"' flags: '-DAMDGPU_TARGETS=gfx1010 -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -DCMAKE_C_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma" -DCMAKE_CXX_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma"'
- preset: Vulkan - preset: Vulkan
install: https://sdk.lunarg.com/sdk/download/1.4.321.1/windows/vulkansdk-windows-X64-1.4.321.1.exe install: https://sdk.lunarg.com/sdk/download/1.4.321.1/windows/vulkansdk-windows-X64-1.4.321.1.exe
- preset: 'MLX CUDA 13'
install: https://developer.download.nvidia.com/compute/cuda/13.0.0/local_installers/cuda_13.0.0_windows.exe
cudnn-install: https://developer.download.nvidia.com/compute/cudnn/redist/cudnn/windows-x86_64/cudnn-windows-x86_64-9.18.1.3_cuda13-archive.zip
flags: '-DCMAKE_CUDA_ARCHITECTURES=80'
cuda-components:
- '"cudart"'
- '"nvcc"'
- '"cublas"'
- '"cublas_dev"'
- '"cufft"'
- '"cufft_dev"'
- '"nvrtc"'
- '"nvrtc_dev"'
- '"crt"'
- '"nvvm"'
- '"nvptxcompiler"'
cuda-version: '13.0'
runs-on: windows runs-on: windows
steps: steps:
- run: | - run: |
choco install -y --no-progress ccache ninja choco install -y --no-progress ccache ninja
if (Get-Command ccache -ErrorAction SilentlyContinue) { ccache -o cache_dir=${{ github.workspace }}\.ccache
ccache -o cache_dir=${{ github.workspace }}\.ccache - if: matrix.preset == 'CUDA' || matrix.preset == 'ROCm' || matrix.preset == 'Vulkan'
}
- if: matrix.preset == 'CUDA' || matrix.preset == 'ROCm' || matrix.preset == 'Vulkan' || matrix.preset == 'MLX CUDA 13'
id: cache-install id: cache-install
uses: actions/cache/restore@v4 uses: actions/cache/restore@v4
with: with:
@@ -161,9 +127,8 @@ jobs:
C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA
C:\Program Files\AMD\ROCm C:\Program Files\AMD\ROCm
C:\VulkanSDK C:\VulkanSDK
C:\Program Files\NVIDIA\CUDNN key: ${{ matrix.install }}
key: ${{ matrix.install }}-${{ matrix.cudnn-install }} - if: matrix.preset == 'CUDA'
- if: matrix.preset == 'CUDA' || matrix.preset == 'MLX CUDA 13'
name: Install CUDA ${{ matrix.cuda-version }} name: Install CUDA ${{ matrix.cuda-version }}
run: | run: |
$ErrorActionPreference = "Stop" $ErrorActionPreference = "Stop"
@@ -203,23 +168,6 @@ jobs:
$vulkanPath = (Resolve-Path "C:\VulkanSDK\*").path $vulkanPath = (Resolve-Path "C:\VulkanSDK\*").path
echo "$vulkanPath\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append echo "$vulkanPath\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
echo "VULKAN_SDK=$vulkanPath" >> $env:GITHUB_ENV echo "VULKAN_SDK=$vulkanPath" >> $env:GITHUB_ENV
- if: matrix.preset == 'MLX CUDA 13'
name: Install cuDNN for MLX
run: |
$ErrorActionPreference = "Stop"
$cudnnRoot = "C:\Program Files\NVIDIA\CUDNN"
if ("${{ steps.cache-install.outputs.cache-hit }}" -ne 'true') {
Invoke-WebRequest -Uri "${{ matrix.cudnn-install }}" -OutFile "cudnn.zip"
Expand-Archive -Path cudnn.zip -DestinationPath cudnn-extracted
$cudnnDir = (Get-ChildItem -Path cudnn-extracted -Directory)[0].FullName
New-Item -ItemType Directory -Force -Path $cudnnRoot
Copy-Item -Path "$cudnnDir\*" -Destination "$cudnnRoot\" -Recurse
}
echo "CUDNN_ROOT_DIR=$cudnnRoot" | Out-File -FilePath $env:GITHUB_ENV -Append
echo "CUDNN_INCLUDE_PATH=$cudnnRoot\include" | Out-File -FilePath $env:GITHUB_ENV -Append
echo "CUDNN_LIBRARY_PATH=$cudnnRoot\lib\x64" | Out-File -FilePath $env:GITHUB_ENV -Append
echo "$cudnnRoot\bin\x64" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
- if: ${{ !cancelled() && steps.cache-install.outputs.cache-hit != 'true' }} - if: ${{ !cancelled() && steps.cache-install.outputs.cache-hit != 'true' }}
uses: actions/cache/save@v4 uses: actions/cache/save@v4
with: with:
@@ -227,8 +175,7 @@ jobs:
C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA
C:\Program Files\AMD\ROCm C:\Program Files\AMD\ROCm
C:\VulkanSDK C:\VulkanSDK
C:\Program Files\NVIDIA\CUDNN key: ${{ matrix.install }}
key: ${{ matrix.install }}-${{ matrix.cudnn-install }}
- uses: actions/checkout@v4 - uses: actions/checkout@v4
- uses: actions/cache@v4 - uses: actions/cache@v4
with: with:

1
.gitignore vendored
View File

@@ -15,4 +15,3 @@ __debug_bin*
llama/build llama/build
llama/vendor llama/vendor
/ollama /ollama
integration/testdata/models/

View File

@@ -64,15 +64,10 @@ set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${OLLAMA_BUILD_DIR})
set(CMAKE_LIBRARY_OUTPUT_DIRECTORY_DEBUG ${OLLAMA_BUILD_DIR}) set(CMAKE_LIBRARY_OUTPUT_DIRECTORY_DEBUG ${OLLAMA_BUILD_DIR})
set(CMAKE_LIBRARY_OUTPUT_DIRECTORY_RELEASE ${OLLAMA_BUILD_DIR}) set(CMAKE_LIBRARY_OUTPUT_DIRECTORY_RELEASE ${OLLAMA_BUILD_DIR})
# Store ggml include paths for use with target_include_directories later. include_directories(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src)
# We avoid global include_directories() to prevent polluting the include path include_directories(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/include)
# for other projects like MLX (whose openblas dependency has its own common.h). include_directories(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-cpu)
set(GGML_INCLUDE_DIRS include_directories(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-cpu/amx)
${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src
${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/include
${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-cpu
${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-cpu/amx
)
add_compile_definitions(NDEBUG GGML_VERSION=0x0 GGML_COMMIT=0x0) add_compile_definitions(NDEBUG GGML_VERSION=0x0 GGML_COMMIT=0x0)
@@ -92,14 +87,6 @@ if(NOT CPU_VARIANTS)
set(CPU_VARIANTS "ggml-cpu") set(CPU_VARIANTS "ggml-cpu")
endif() endif()
# Apply ggml include directories to ggml targets only (not globally)
target_include_directories(ggml-base PRIVATE ${GGML_INCLUDE_DIRS})
foreach(variant ${CPU_VARIANTS})
if(TARGET ${variant})
target_include_directories(${variant} PRIVATE ${GGML_INCLUDE_DIRS})
endif()
endforeach()
install(TARGETS ggml-base ${CPU_VARIANTS} install(TARGETS ggml-base ${CPU_VARIANTS}
RUNTIME_DEPENDENCIES RUNTIME_DEPENDENCIES
PRE_EXCLUDE_REGEXES ".*" PRE_EXCLUDE_REGEXES ".*"
@@ -116,7 +103,6 @@ if(CMAKE_CUDA_COMPILER)
find_package(CUDAToolkit) find_package(CUDAToolkit)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-cuda) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-cuda)
target_include_directories(ggml-cuda PRIVATE ${GGML_INCLUDE_DIRS})
install(TARGETS ggml-cuda install(TARGETS ggml-cuda
RUNTIME_DEPENDENCIES RUNTIME_DEPENDENCIES
DIRECTORIES ${CUDAToolkit_BIN_DIR} ${CUDAToolkit_BIN_DIR}/x64 ${CUDAToolkit_LIBRARY_DIR} DIRECTORIES ${CUDAToolkit_BIN_DIR} ${CUDAToolkit_BIN_DIR}/x64 ${CUDAToolkit_LIBRARY_DIR}
@@ -148,7 +134,6 @@ if(CMAKE_HIP_COMPILER)
if(AMDGPU_TARGETS) if(AMDGPU_TARGETS)
find_package(hip REQUIRED) find_package(hip REQUIRED)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-hip) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-hip)
target_include_directories(ggml-hip PRIVATE ${GGML_INCLUDE_DIRS})
if (WIN32) if (WIN32)
target_compile_definitions(ggml-hip PRIVATE GGML_CUDA_NO_PEER_COPY) target_compile_definitions(ggml-hip PRIVATE GGML_CUDA_NO_PEER_COPY)
@@ -163,7 +148,7 @@ if(CMAKE_HIP_COMPILER)
) )
install(RUNTIME_DEPENDENCY_SET rocm install(RUNTIME_DEPENDENCY_SET rocm
DIRECTORIES ${HIP_BIN_INSTALL_DIR} ${HIP_LIB_INSTALL_DIR} DIRECTORIES ${HIP_BIN_INSTALL_DIR} ${HIP_LIB_INSTALL_DIR}
PRE_INCLUDE_REGEXES hipblas rocblas amdhip64 rocsolver amd_comgr hsa-runtime64 rocsparse tinfo rocprofiler-register roctx64 rocroller drm drm_amdgpu numa elf PRE_INCLUDE_REGEXES hipblas rocblas amdhip64 rocsolver amd_comgr hsa-runtime64 rocsparse tinfo rocprofiler-register drm drm_amdgpu numa elf
PRE_EXCLUDE_REGEXES ".*" PRE_EXCLUDE_REGEXES ".*"
POST_EXCLUDE_REGEXES "system32" POST_EXCLUDE_REGEXES "system32"
RUNTIME DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT HIP RUNTIME DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT HIP
@@ -183,7 +168,6 @@ if(NOT APPLE)
find_package(Vulkan) find_package(Vulkan)
if(Vulkan_FOUND) if(Vulkan_FOUND)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-vulkan) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-vulkan)
target_include_directories(ggml-vulkan PRIVATE ${GGML_INCLUDE_DIRS})
install(TARGETS ggml-vulkan install(TARGETS ggml-vulkan
RUNTIME_DEPENDENCIES RUNTIME_DEPENDENCIES
PRE_INCLUDE_REGEXES vulkan PRE_INCLUDE_REGEXES vulkan
@@ -195,6 +179,7 @@ if(NOT APPLE)
endif() endif()
option(MLX_ENGINE "Enable MLX backend" OFF) option(MLX_ENGINE "Enable MLX backend" OFF)
if(MLX_ENGINE) if(MLX_ENGINE)
message(STATUS "Setting up MLX (this takes a while...)") message(STATUS "Setting up MLX (this takes a while...)")
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/x/imagegen/mlx) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/x/imagegen/mlx)
@@ -202,36 +187,10 @@ if(MLX_ENGINE)
# Find CUDA toolkit if MLX is built with CUDA support # Find CUDA toolkit if MLX is built with CUDA support
find_package(CUDAToolkit) find_package(CUDAToolkit)
# Build list of directories for runtime dependency resolution
set(MLX_RUNTIME_DIRS ${CUDAToolkit_BIN_DIR} ${CUDAToolkit_BIN_DIR}/x64 ${CUDAToolkit_LIBRARY_DIR})
# Add cuDNN bin paths for DLLs (Windows MLX CUDA builds)
# CUDNN_ROOT_DIR is the standard CMake variable for cuDNN location
if(DEFINED ENV{CUDNN_ROOT_DIR})
# cuDNN 9.x has versioned subdirectories under bin/ (e.g., bin/13.0/)
file(GLOB CUDNN_BIN_SUBDIRS "$ENV{CUDNN_ROOT_DIR}/bin/*")
list(APPEND MLX_RUNTIME_DIRS ${CUDNN_BIN_SUBDIRS})
endif()
# Add build output directory and MLX dependency build directories
list(APPEND MLX_RUNTIME_DIRS ${OLLAMA_BUILD_DIR})
# OpenBLAS DLL location (pre-built zip extracts into openblas-src/bin/)
list(APPEND MLX_RUNTIME_DIRS ${CMAKE_BINARY_DIR}/_deps/openblas-src/bin)
# NCCL: on Linux, if real NCCL is found, cmake bundles libnccl.so via the
# regex below. If NCCL is not found, MLX links a static stub (OBJECT lib)
# so there is no runtime dependency. This path covers the stub build dir
# for windows so we include the DLL in our dependencies.
list(APPEND MLX_RUNTIME_DIRS ${CMAKE_BINARY_DIR}/_deps/mlx-build/mlx/distributed/nccl/nccl_stub-prefix/src/nccl_stub-build/Release)
# Base regexes for runtime dependencies (cross-platform)
set(MLX_INCLUDE_REGEXES cublas cublasLt cudart cufft nvrtc nvrtc-builtins cudnn nccl openblas gfortran)
# On Windows, also include dl.dll (dlfcn-win32 POSIX emulation layer)
if(WIN32)
list(APPEND MLX_INCLUDE_REGEXES "^dl\\.dll$")
endif()
install(TARGETS mlx mlxc install(TARGETS mlx mlxc
RUNTIME_DEPENDENCIES RUNTIME_DEPENDENCIES
DIRECTORIES ${MLX_RUNTIME_DIRS} DIRECTORIES ${CUDAToolkit_BIN_DIR} ${CUDAToolkit_BIN_DIR}/x64 ${CUDAToolkit_LIBRARY_DIR}
PRE_INCLUDE_REGEXES ${MLX_INCLUDE_REGEXES} PRE_INCLUDE_REGEXES cublas cublasLt cudart nvrtc nvrtc-builtins cudnn nccl openblas gfortran
PRE_EXCLUDE_REGEXES ".*" PRE_EXCLUDE_REGEXES ".*"
RUNTIME DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT MLX RUNTIME DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT MLX
LIBRARY DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT MLX LIBRARY DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT MLX
@@ -246,117 +205,13 @@ if(MLX_ENGINE)
COMPONENT MLX) COMPONENT MLX)
endif() endif()
# Install headers for NVRTC JIT compilation at runtime. # Manually install cudart and cublas since they might not be picked up as direct dependencies
# MLX's own install rules use the default component so they get skipped by
# --component MLX. Headers are installed alongside libmlx in OLLAMA_INSTALL_DIR.
#
# Layout:
# ${OLLAMA_INSTALL_DIR}/include/cccl/{cuda,nv}/ — CCCL headers
# ${OLLAMA_INSTALL_DIR}/include/*.h — CUDA toolkit headers
#
# MLX's jit_module.cpp resolves CCCL via
# current_binary_dir()[.parent_path()] / "include" / "cccl"
# On Linux, MLX's jit_module.cpp resolves CCCL via
# current_binary_dir().parent_path() / "include" / "cccl", so we create a
# symlink from lib/ollama/include -> ${OLLAMA_RUNNER_DIR}/include
# This will need refinement if we add multiple CUDA versions for MLX in the future.
# CUDA runtime headers are found via CUDA_PATH env var (set by mlxrunner).
if(EXISTS ${CMAKE_BINARY_DIR}/_deps/cccl-src/include/cuda)
install(DIRECTORY ${CMAKE_BINARY_DIR}/_deps/cccl-src/include/cuda
DESTINATION ${OLLAMA_INSTALL_DIR}/include/cccl
COMPONENT MLX)
install(DIRECTORY ${CMAKE_BINARY_DIR}/_deps/cccl-src/include/nv
DESTINATION ${OLLAMA_INSTALL_DIR}/include/cccl
COMPONENT MLX)
if(NOT WIN32 AND NOT APPLE)
install(CODE "
set(_link \"${CMAKE_INSTALL_PREFIX}/lib/ollama/include\")
set(_target \"${OLLAMA_RUNNER_DIR}/include\")
if(NOT EXISTS \${_link})
execute_process(COMMAND \${CMAKE_COMMAND} -E create_symlink \${_target} \${_link})
endif()
" COMPONENT MLX)
endif()
endif()
# Install minimal CUDA toolkit headers needed by MLX JIT kernels.
# These are the transitive closure of includes from mlx/backend/cuda/device/*.cuh.
# The Go mlxrunner sets CUDA_PATH to OLLAMA_INSTALL_DIR so MLX finds them at
# $CUDA_PATH/include/*.h via NVRTC --include-path.
if(CUDAToolkit_FOUND) if(CUDAToolkit_FOUND)
# CUDAToolkit_INCLUDE_DIRS may be a semicolon-separated list file(GLOB CUDART_LIBS
# (e.g. ".../include;.../include/cccl"). Find the entry that
# contains the CUDA runtime headers we need.
set(_cuda_inc "")
foreach(_dir ${CUDAToolkit_INCLUDE_DIRS})
if(EXISTS "${_dir}/cuda_runtime_api.h")
set(_cuda_inc "${_dir}")
break()
endif()
endforeach()
if(NOT _cuda_inc)
message(WARNING "Could not find cuda_runtime_api.h in CUDAToolkit_INCLUDE_DIRS: ${CUDAToolkit_INCLUDE_DIRS}")
else()
set(_dst "${OLLAMA_INSTALL_DIR}/include")
set(_MLX_JIT_CUDA_HEADERS
builtin_types.h
cooperative_groups.h
cuda_bf16.h
cuda_bf16.hpp
cuda_device_runtime_api.h
cuda_fp16.h
cuda_fp16.hpp
cuda_fp8.h
cuda_fp8.hpp
cuda_runtime_api.h
device_types.h
driver_types.h
math_constants.h
surface_types.h
texture_types.h
vector_functions.h
vector_functions.hpp
vector_types.h
)
foreach(_hdr ${_MLX_JIT_CUDA_HEADERS})
install(FILES "${_cuda_inc}/${_hdr}"
DESTINATION ${_dst}
COMPONENT MLX)
endforeach()
# Subdirectory headers
install(DIRECTORY "${_cuda_inc}/cooperative_groups"
DESTINATION ${_dst}
COMPONENT MLX
FILES_MATCHING PATTERN "*.h")
install(FILES "${_cuda_inc}/crt/host_defines.h"
DESTINATION "${_dst}/crt"
COMPONENT MLX)
endif()
endif()
# On Windows, explicitly install dl.dll (dlfcn-win32 POSIX dlopen emulation)
# RUNTIME_DEPENDENCIES auto-excludes it via POST_EXCLUDE_FILES_STRICT because
# dlfcn-win32 is a known CMake target with its own install rules (which install
# to the wrong destination). We must install it explicitly here.
if(WIN32)
install(FILES ${OLLAMA_BUILD_DIR}/dl.dll
DESTINATION ${OLLAMA_INSTALL_DIR}
COMPONENT MLX)
endif()
# Manually install CUDA runtime libraries that MLX loads via dlopen
# (not detected by RUNTIME_DEPENDENCIES since they aren't link-time deps)
if(CUDAToolkit_FOUND)
file(GLOB MLX_CUDA_LIBS
"${CUDAToolkit_LIBRARY_DIR}/libcudart.so*" "${CUDAToolkit_LIBRARY_DIR}/libcudart.so*"
"${CUDAToolkit_LIBRARY_DIR}/libcublas.so*" "${CUDAToolkit_LIBRARY_DIR}/libcublas.so*")
"${CUDAToolkit_LIBRARY_DIR}/libcublasLt.so*" if(CUDART_LIBS)
"${CUDAToolkit_LIBRARY_DIR}/libnvrtc.so*" install(FILES ${CUDART_LIBS}
"${CUDAToolkit_LIBRARY_DIR}/libnvrtc-builtins.so*"
"${CUDAToolkit_LIBRARY_DIR}/libcufft.so*"
"${CUDAToolkit_LIBRARY_DIR}/libcudnn.so*")
if(MLX_CUDA_LIBS)
install(FILES ${MLX_CUDA_LIBS}
DESTINATION ${OLLAMA_INSTALL_DIR} DESTINATION ${OLLAMA_INSTALL_DIR}
COMPONENT MLX) COMPONENT MLX)
endif() endif()

View File

@@ -77,15 +77,6 @@
"OLLAMA_RUNNER_DIR": "rocm" "OLLAMA_RUNNER_DIR": "rocm"
} }
}, },
{
"name": "ROCm 7",
"inherits": [ "ROCm" ],
"cacheVariables": {
"CMAKE_HIP_FLAGS": "-parallel-jobs=4",
"AMDGPU_TARGETS": "gfx942;gfx950;gfx1010;gfx1012;gfx1030;gfx1100;gfx1101;gfx1102;gfx1103;gfx1150;gfx1151;gfx1200;gfx1201;gfx908:xnack-;gfx90a:xnack+;gfx90a:xnack-",
"OLLAMA_RUNNER_DIR": "rocm"
}
},
{ {
"name": "Vulkan", "name": "Vulkan",
"inherits": [ "Default" ], "inherits": [ "Default" ],
@@ -112,7 +103,6 @@
"name": "MLX CUDA 13", "name": "MLX CUDA 13",
"inherits": [ "MLX", "CUDA 13" ], "inherits": [ "MLX", "CUDA 13" ],
"cacheVariables": { "cacheVariables": {
"MLX_CUDA_ARCHITECTURES": "86;89;90;90a;100;103;75-virtual;80-virtual;110-virtual;120-virtual;121-virtual",
"OLLAMA_RUNNER_DIR": "mlx_cuda_v13" "OLLAMA_RUNNER_DIR": "mlx_cuda_v13"
} }
} }
@@ -168,11 +158,6 @@
"inherits": [ "ROCm" ], "inherits": [ "ROCm" ],
"configurePreset": "ROCm 6" "configurePreset": "ROCm 6"
}, },
{
"name": "ROCm 7",
"inherits": [ "ROCm" ],
"configurePreset": "ROCm 7"
},
{ {
"name": "Vulkan", "name": "Vulkan",
"targets": [ "ggml-vulkan" ], "targets": [ "ggml-vulkan" ],

View File

@@ -1,23 +1,28 @@
# vim: filetype=dockerfile # vim: filetype=dockerfile
ARG FLAVOR=${TARGETARCH} ARG FLAVOR=${TARGETARCH}
ARG PARALLEL=8
ARG ROCMVERSION=7.2.1 ARG ROCMVERSION=6.3.3
ARG JETPACK5VERSION=r35.4.1 ARG JETPACK5VERSION=r35.4.1
ARG JETPACK6VERSION=r36.4.0 ARG JETPACK6VERSION=r36.4.0
ARG CMAKEVERSION=3.31.2 ARG CMAKEVERSION=3.31.2
ARG NINJAVERSION=1.12.1
ARG VULKANVERSION=1.4.321.1 ARG VULKANVERSION=1.4.321.1
# Default empty stages for local MLX source overrides.
# Override with: docker build --build-context local-mlx=../mlx --build-context local-mlx-c=../mlx-c
FROM scratch AS local-mlx
FROM scratch AS local-mlx-c
FROM --platform=linux/amd64 rocm/dev-almalinux-8:${ROCMVERSION}-complete AS base-amd64 FROM --platform=linux/amd64 rocm/dev-almalinux-8:${ROCMVERSION}-complete AS base-amd64
RUN dnf install -y yum-utils ccache gcc-toolset-11-gcc gcc-toolset-11-gcc-c++ gcc-toolset-11-binutils \ RUN dnf install -y yum-utils ccache gcc-toolset-11-gcc gcc-toolset-11-gcc-c++ gcc-toolset-11-binutils \
&& yum-config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel8/x86_64/cuda-rhel8.repo && yum-config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel8/x86_64/cuda-rhel8.repo
ENV PATH=/opt/rh/gcc-toolset-11/root/usr/bin:$PATH ENV PATH=/opt/rh/gcc-toolset-11/root/usr/bin:$PATH
ARG VULKANVERSION
RUN wget https://sdk.lunarg.com/sdk/download/${VULKANVERSION}/linux/vulkansdk-linux-x86_64-${VULKANVERSION}.tar.xz -O /tmp/vulkansdk-linux-x86_64-${VULKANVERSION}.tar.xz \
&& tar xvf /tmp/vulkansdk-linux-x86_64-${VULKANVERSION}.tar.xz \
&& dnf -y install ninja-build \
&& ln -s /usr/bin/python3 /usr/bin/python \
&& /${VULKANVERSION}/vulkansdk -j 8 vulkan-headers \
&& /${VULKANVERSION}/vulkansdk -j 8 shaderc
RUN cp -r /${VULKANVERSION}/x86_64/include/* /usr/local/include/ \
&& cp -r /${VULKANVERSION}/x86_64/lib/* /usr/local/lib
ENV PATH=/${VULKANVERSION}/x86_64/bin:$PATH
FROM --platform=linux/arm64 almalinux:8 AS base-arm64 FROM --platform=linux/arm64 almalinux:8 AS base-arm64
# install epel-release for ccache # install epel-release for ccache
@@ -28,119 +33,100 @@ ENV CC=clang CXX=clang++
FROM base-${TARGETARCH} AS base FROM base-${TARGETARCH} AS base
ARG CMAKEVERSION ARG CMAKEVERSION
ARG NINJAVERSION
RUN curl -fsSL https://github.com/Kitware/CMake/releases/download/v${CMAKEVERSION}/cmake-${CMAKEVERSION}-linux-$(uname -m).tar.gz | tar xz -C /usr/local --strip-components 1 RUN curl -fsSL https://github.com/Kitware/CMake/releases/download/v${CMAKEVERSION}/cmake-${CMAKEVERSION}-linux-$(uname -m).tar.gz | tar xz -C /usr/local --strip-components 1
RUN dnf install -y unzip \
&& curl -fsSL -o /tmp/ninja.zip https://github.com/ninja-build/ninja/releases/download/v${NINJAVERSION}/ninja-linux$([ "$(uname -m)" = "aarch64" ] && echo "-aarch64").zip \
&& unzip /tmp/ninja.zip -d /usr/local/bin \
&& rm /tmp/ninja.zip
ENV CMAKE_GENERATOR=Ninja
ENV LDFLAGS=-s ENV LDFLAGS=-s
FROM base AS cpu FROM base AS cpu
RUN dnf install -y gcc-toolset-11-gcc gcc-toolset-11-gcc-c++ RUN dnf install -y gcc-toolset-11-gcc gcc-toolset-11-gcc-c++
ENV PATH=/opt/rh/gcc-toolset-11/root/usr/bin:$PATH ENV PATH=/opt/rh/gcc-toolset-11/root/usr/bin:$PATH
ARG PARALLEL
COPY CMakeLists.txt CMakePresets.json . COPY CMakeLists.txt CMakePresets.json .
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
RUN --mount=type=cache,target=/root/.ccache \ RUN --mount=type=cache,target=/root/.ccache \
cmake --preset 'CPU' \ cmake --preset 'CPU' \
&& cmake --build --preset 'CPU' -- -l $(nproc) \ && cmake --build --parallel ${PARALLEL} --preset 'CPU' \
&& cmake --install build --component CPU --strip && cmake --install build --component CPU --strip --parallel ${PARALLEL}
FROM base AS cuda-11 FROM base AS cuda-11
ARG CUDA11VERSION=11.8 ARG CUDA11VERSION=11.8
RUN dnf install -y cuda-toolkit-${CUDA11VERSION//./-} RUN dnf install -y cuda-toolkit-${CUDA11VERSION//./-}
ENV PATH=/usr/local/cuda-11/bin:$PATH ENV PATH=/usr/local/cuda-11/bin:$PATH
ARG PARALLEL
COPY CMakeLists.txt CMakePresets.json . COPY CMakeLists.txt CMakePresets.json .
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
RUN --mount=type=cache,target=/root/.ccache \ RUN --mount=type=cache,target=/root/.ccache \
cmake --preset 'CUDA 11' \ cmake --preset 'CUDA 11' \
&& cmake --build --preset 'CUDA 11' -- -l $(nproc) \ && cmake --build --parallel ${PARALLEL} --preset 'CUDA 11' \
&& cmake --install build --component CUDA --strip && cmake --install build --component CUDA --strip --parallel ${PARALLEL}
FROM base AS cuda-12 FROM base AS cuda-12
ARG CUDA12VERSION=12.8 ARG CUDA12VERSION=12.8
RUN dnf install -y cuda-toolkit-${CUDA12VERSION//./-} RUN dnf install -y cuda-toolkit-${CUDA12VERSION//./-}
ENV PATH=/usr/local/cuda-12/bin:$PATH ENV PATH=/usr/local/cuda-12/bin:$PATH
ARG PARALLEL
COPY CMakeLists.txt CMakePresets.json . COPY CMakeLists.txt CMakePresets.json .
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
RUN --mount=type=cache,target=/root/.ccache \ RUN --mount=type=cache,target=/root/.ccache \
cmake --preset 'CUDA 12' \ cmake --preset 'CUDA 12' \
&& cmake --build --preset 'CUDA 12' -- -l $(nproc) \ && cmake --build --parallel ${PARALLEL} --preset 'CUDA 12' \
&& cmake --install build --component CUDA --strip && cmake --install build --component CUDA --strip --parallel ${PARALLEL}
FROM base AS cuda-13 FROM base AS cuda-13
ARG CUDA13VERSION=13.0 ARG CUDA13VERSION=13.0
RUN dnf install -y cuda-toolkit-${CUDA13VERSION//./-} RUN dnf install -y cuda-toolkit-${CUDA13VERSION//./-}
ENV PATH=/usr/local/cuda-13/bin:$PATH ENV PATH=/usr/local/cuda-13/bin:$PATH
ARG PARALLEL
COPY CMakeLists.txt CMakePresets.json . COPY CMakeLists.txt CMakePresets.json .
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
RUN --mount=type=cache,target=/root/.ccache \ RUN --mount=type=cache,target=/root/.ccache \
cmake --preset 'CUDA 13' \ cmake --preset 'CUDA 13' \
&& cmake --build --preset 'CUDA 13' -- -l $(nproc) \ && cmake --build --parallel ${PARALLEL} --preset 'CUDA 13' \
&& cmake --install build --component CUDA --strip && cmake --install build --component CUDA --strip --parallel ${PARALLEL}
FROM base AS rocm-7 FROM base AS rocm-6
ENV PATH=/opt/rocm/hcc/bin:/opt/rocm/hip/bin:/opt/rocm/bin:/opt/rocm/hcc/bin:$PATH ENV PATH=/opt/rocm/hcc/bin:/opt/rocm/hip/bin:/opt/rocm/bin:/opt/rocm/hcc/bin:$PATH
ARG PARALLEL
COPY CMakeLists.txt CMakePresets.json . COPY CMakeLists.txt CMakePresets.json .
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
RUN --mount=type=cache,target=/root/.ccache \ RUN --mount=type=cache,target=/root/.ccache \
cmake --preset 'ROCm 7' \ cmake --preset 'ROCm 6' \
&& cmake --build --preset 'ROCm 7' -- -l $(nproc) \ && cmake --build --parallel ${PARALLEL} --preset 'ROCm 6' \
&& cmake --install build --component HIP --strip && cmake --install build --component HIP --strip --parallel ${PARALLEL}
RUN rm -f dist/lib/ollama/rocm/rocblas/library/*gfx90[06]* RUN rm -f dist/lib/ollama/rocm/rocblas/library/*gfx90[06]*
FROM --platform=linux/arm64 nvcr.io/nvidia/l4t-jetpack:${JETPACK5VERSION} AS jetpack-5 FROM --platform=linux/arm64 nvcr.io/nvidia/l4t-jetpack:${JETPACK5VERSION} AS jetpack-5
ARG CMAKEVERSION ARG CMAKEVERSION
ARG NINJAVERSION RUN apt-get update && apt-get install -y curl ccache \
RUN apt-get update && apt-get install -y curl ccache unzip \ && curl -fsSL https://github.com/Kitware/CMake/releases/download/v${CMAKEVERSION}/cmake-${CMAKEVERSION}-linux-$(uname -m).tar.gz | tar xz -C /usr/local --strip-components 1
&& curl -fsSL https://github.com/Kitware/CMake/releases/download/v${CMAKEVERSION}/cmake-${CMAKEVERSION}-linux-$(uname -m).tar.gz | tar xz -C /usr/local --strip-components 1 \
&& curl -fsSL -o /tmp/ninja.zip https://github.com/ninja-build/ninja/releases/download/v${NINJAVERSION}/ninja-linux-aarch64.zip \
&& unzip /tmp/ninja.zip -d /usr/local/bin \
&& rm /tmp/ninja.zip
ENV CMAKE_GENERATOR=Ninja
COPY CMakeLists.txt CMakePresets.json . COPY CMakeLists.txt CMakePresets.json .
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
ARG PARALLEL
RUN --mount=type=cache,target=/root/.ccache \ RUN --mount=type=cache,target=/root/.ccache \
cmake --preset 'JetPack 5' \ cmake --preset 'JetPack 5' \
&& cmake --build --preset 'JetPack 5' -- -l $(nproc) \ && cmake --build --parallel ${PARALLEL} --preset 'JetPack 5' \
&& cmake --install build --component CUDA --strip && cmake --install build --component CUDA --strip --parallel ${PARALLEL}
FROM --platform=linux/arm64 nvcr.io/nvidia/l4t-jetpack:${JETPACK6VERSION} AS jetpack-6 FROM --platform=linux/arm64 nvcr.io/nvidia/l4t-jetpack:${JETPACK6VERSION} AS jetpack-6
ARG CMAKEVERSION ARG CMAKEVERSION
ARG NINJAVERSION RUN apt-get update && apt-get install -y curl ccache \
RUN apt-get update && apt-get install -y curl ccache unzip \ && curl -fsSL https://github.com/Kitware/CMake/releases/download/v${CMAKEVERSION}/cmake-${CMAKEVERSION}-linux-$(uname -m).tar.gz | tar xz -C /usr/local --strip-components 1
&& curl -fsSL https://github.com/Kitware/CMake/releases/download/v${CMAKEVERSION}/cmake-${CMAKEVERSION}-linux-$(uname -m).tar.gz | tar xz -C /usr/local --strip-components 1 \
&& curl -fsSL -o /tmp/ninja.zip https://github.com/ninja-build/ninja/releases/download/v${NINJAVERSION}/ninja-linux-aarch64.zip \
&& unzip /tmp/ninja.zip -d /usr/local/bin \
&& rm /tmp/ninja.zip
ENV CMAKE_GENERATOR=Ninja
COPY CMakeLists.txt CMakePresets.json . COPY CMakeLists.txt CMakePresets.json .
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
ARG PARALLEL
RUN --mount=type=cache,target=/root/.ccache \ RUN --mount=type=cache,target=/root/.ccache \
cmake --preset 'JetPack 6' \ cmake --preset 'JetPack 6' \
&& cmake --build --preset 'JetPack 6' -- -l $(nproc) \ && cmake --build --parallel ${PARALLEL} --preset 'JetPack 6' \
&& cmake --install build --component CUDA --strip && cmake --install build --component CUDA --strip --parallel ${PARALLEL}
FROM base AS vulkan FROM base AS vulkan
ARG VULKANVERSION
RUN ln -s /usr/bin/python3 /usr/bin/python \
&& wget https://sdk.lunarg.com/sdk/download/${VULKANVERSION}/linux/vulkansdk-linux-x86_64-${VULKANVERSION}.tar.xz -O /tmp/vulkansdk.tar.xz \
&& tar xvf /tmp/vulkansdk.tar.xz -C /tmp \
&& /tmp/${VULKANVERSION}/vulkansdk -j 8 vulkan-headers \
&& /tmp/${VULKANVERSION}/vulkansdk -j 8 shaderc \
&& cp -r /tmp/${VULKANVERSION}/x86_64/include/* /usr/local/include/ \
&& cp -r /tmp/${VULKANVERSION}/x86_64/lib/* /usr/local/lib \
&& cp -r /tmp/${VULKANVERSION}/x86_64/bin/* /usr/local/bin/ \
&& rm -rf /tmp/${VULKANVERSION} /tmp/vulkansdk.tar.xz
COPY CMakeLists.txt CMakePresets.json . COPY CMakeLists.txt CMakePresets.json .
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
RUN --mount=type=cache,target=/root/.ccache \ RUN --mount=type=cache,target=/root/.ccache \
cmake --preset 'Vulkan' \ cmake --preset 'Vulkan' \
&& cmake --build --preset 'Vulkan' -- -l $(nproc) \ && cmake --build --parallel --preset 'Vulkan' \
&& cmake --install build --component Vulkan --strip && cmake --install build --component Vulkan --strip --parallel 8
FROM base AS mlx FROM base AS mlx
ARG CUDA13VERSION=13.0 ARG CUDA13VERSION=13.0
@@ -152,27 +138,20 @@ ENV PATH=/usr/local/cuda-13/bin:$PATH
ENV BLAS_INCLUDE_DIRS=/usr/include/openblas ENV BLAS_INCLUDE_DIRS=/usr/include/openblas
ENV LAPACK_INCLUDE_DIRS=/usr/include/openblas ENV LAPACK_INCLUDE_DIRS=/usr/include/openblas
ENV CGO_LDFLAGS="-L/usr/local/cuda-13/lib64 -L/usr/local/cuda-13/targets/x86_64-linux/lib/stubs" ENV CGO_LDFLAGS="-L/usr/local/cuda-13/lib64 -L/usr/local/cuda-13/targets/x86_64-linux/lib/stubs"
ARG PARALLEL
WORKDIR /go/src/github.com/ollama/ollama WORKDIR /go/src/github.com/ollama/ollama
COPY CMakeLists.txt CMakePresets.json . COPY CMakeLists.txt CMakePresets.json .
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
COPY x/imagegen/mlx x/imagegen/mlx COPY x/imagegen/mlx x/imagegen/mlx
COPY go.mod go.sum . COPY go.mod go.sum .
COPY MLX_VERSION MLX_C_VERSION . COPY MLX_VERSION .
RUN curl -fsSL https://golang.org/dl/go$(awk '/^go/ { print $2 }' go.mod).linux-$(case $(uname -m) in x86_64) echo amd64 ;; aarch64) echo arm64 ;; esac).tar.gz | tar xz -C /usr/local RUN curl -fsSL https://golang.org/dl/go$(awk '/^go/ { print $2 }' go.mod).linux-$(case $(uname -m) in x86_64) echo amd64 ;; aarch64) echo arm64 ;; esac).tar.gz | tar xz -C /usr/local
ENV PATH=/usr/local/go/bin:$PATH ENV PATH=/usr/local/go/bin:$PATH
RUN go mod download RUN go mod download
RUN --mount=type=cache,target=/root/.ccache \ RUN --mount=type=cache,target=/root/.ccache \
--mount=type=bind,from=local-mlx,target=/tmp/local-mlx \ cmake --preset 'MLX CUDA 13' -DBLAS_INCLUDE_DIRS=/usr/include/openblas -DLAPACK_INCLUDE_DIRS=/usr/include/openblas \
--mount=type=bind,from=local-mlx-c,target=/tmp/local-mlx-c \ && cmake --build --parallel ${PARALLEL} --preset 'MLX CUDA 13' \
if [ -f /tmp/local-mlx/CMakeLists.txt ]; then \ && cmake --install build --component MLX --strip --parallel ${PARALLEL}
export OLLAMA_MLX_SOURCE=/tmp/local-mlx; \
fi \
&& if [ -f /tmp/local-mlx-c/CMakeLists.txt ]; then \
export OLLAMA_MLX_C_SOURCE=/tmp/local-mlx-c; \
fi \
&& cmake --preset 'MLX CUDA 13' -DBLAS_INCLUDE_DIRS=/usr/include/openblas -DLAPACK_INCLUDE_DIRS=/usr/include/openblas \
&& cmake --build --preset 'MLX CUDA 13' -- -l $(nproc) \
&& cmake --install build --component MLX --strip
FROM base AS build FROM base AS build
WORKDIR /go/src/github.com/ollama/ollama WORKDIR /go/src/github.com/ollama/ollama
@@ -181,14 +160,16 @@ RUN curl -fsSL https://golang.org/dl/go$(awk '/^go/ { print $2 }' go.mod).linux-
ENV PATH=/usr/local/go/bin:$PATH ENV PATH=/usr/local/go/bin:$PATH
RUN go mod download RUN go mod download
COPY . . COPY . .
# Clone mlx-c headers for CGO (version from MLX_VERSION file)
RUN git clone --depth 1 --branch "$(cat MLX_VERSION)" https://github.com/ml-explore/mlx-c.git build/_deps/mlx-c-src
ARG GOFLAGS="'-ldflags=-w -s'" ARG GOFLAGS="'-ldflags=-w -s'"
ENV CGO_ENABLED=1 ENV CGO_ENABLED=1
ARG CGO_CFLAGS ARG CGO_CFLAGS
ARG CGO_CXXFLAGS ARG CGO_CXXFLAGS
ENV CGO_CFLAGS="${CGO_CFLAGS}" ENV CGO_CFLAGS="${CGO_CFLAGS} -I/go/src/github.com/ollama/ollama/build/_deps/mlx-c-src"
ENV CGO_CXXFLAGS="${CGO_CXXFLAGS}" ENV CGO_CXXFLAGS="${CGO_CXXFLAGS}"
RUN --mount=type=cache,target=/root/.cache/go-build \ RUN --mount=type=cache,target=/root/.cache/go-build \
go build -trimpath -buildmode=pie -o /bin/ollama . go build -tags mlx -trimpath -buildmode=pie -o /bin/ollama .
FROM --platform=linux/amd64 scratch AS amd64 FROM --platform=linux/amd64 scratch AS amd64
# COPY --from=cuda-11 dist/lib/ollama/ /lib/ollama/ # COPY --from=cuda-11 dist/lib/ollama/ /lib/ollama/
@@ -205,9 +186,10 @@ COPY --from=jetpack-5 dist/lib/ollama/ /lib/ollama/
COPY --from=jetpack-6 dist/lib/ollama/ /lib/ollama/ COPY --from=jetpack-6 dist/lib/ollama/ /lib/ollama/
FROM scratch AS rocm FROM scratch AS rocm
COPY --from=rocm-7 dist/lib/ollama /lib/ollama COPY --from=rocm-6 dist/lib/ollama /lib/ollama
FROM ${FLAVOR} AS archive FROM ${FLAVOR} AS archive
ARG VULKANVERSION
COPY --from=cpu dist/lib/ollama /lib/ollama COPY --from=cpu dist/lib/ollama /lib/ollama
COPY --from=build /bin/ollama /bin/ollama COPY --from=build /bin/ollama /bin/ollama

View File

@@ -1 +0,0 @@
0726ca922fc902c4c61ef9c27d94132be418e945

View File

@@ -1 +1 @@
38ad257088fb2193ad47e527cf6534a689f30943 v0.5.0

View File

@@ -55,7 +55,7 @@ The official [Ollama Docker image](https://hub.docker.com/r/ollama/ollama) `olla
ollama ollama
``` ```
You'll be prompted to run a model or connect Ollama to your existing agents or applications such as `Claude Code`, `OpenClaw`, `OpenCode` , `Codex`, `Copilot`, and more. You'll be prompted to run a model or connect Ollama to your existing agents or applications such as `claude`, `codex`, `openclaw` and more.
### Coding ### Coding
@@ -65,7 +65,7 @@ To launch a specific integration:
ollama launch claude ollama launch claude
``` ```
Supported integrations include [Claude Code](https://docs.ollama.com/integrations/claude-code), [Codex](https://docs.ollama.com/integrations/codex), [Copilot CLI](https://docs.ollama.com/integrations/copilot-cli), [Droid](https://docs.ollama.com/integrations/droid), and [OpenCode](https://docs.ollama.com/integrations/opencode). Supported integrations include [Claude Code](https://docs.ollama.com/integrations/claude-code), [Codex](https://docs.ollama.com/integrations/codex), [Droid](https://docs.ollama.com/integrations/droid), and [OpenCode](https://docs.ollama.com/integrations/opencode).
### AI assistant ### AI assistant

View File

@@ -68,7 +68,7 @@ type MessagesRequest struct {
Model string `json:"model"` Model string `json:"model"`
MaxTokens int `json:"max_tokens"` MaxTokens int `json:"max_tokens"`
Messages []MessageParam `json:"messages"` Messages []MessageParam `json:"messages"`
System any `json:"system,omitempty"` // string or []map[string]any (JSON-decoded ContentBlock) System any `json:"system,omitempty"` // string or []ContentBlock
Stream bool `json:"stream,omitempty"` Stream bool `json:"stream,omitempty"`
Temperature *float64 `json:"temperature,omitempty"` Temperature *float64 `json:"temperature,omitempty"`
TopP *float64 `json:"top_p,omitempty"` TopP *float64 `json:"top_p,omitempty"`
@@ -82,27 +82,8 @@ type MessagesRequest struct {
// MessageParam represents a message in the request // MessageParam represents a message in the request
type MessageParam struct { type MessageParam struct {
Role string `json:"role"` // "user" or "assistant" Role string `json:"role"` // "user" or "assistant"
Content []ContentBlock `json:"content"` // always []ContentBlock; plain strings are normalized on unmarshal Content any `json:"content"` // string or []ContentBlock
}
func (m *MessageParam) UnmarshalJSON(data []byte) error {
var raw struct {
Role string `json:"role"`
Content json.RawMessage `json:"content"`
}
if err := json.Unmarshal(data, &raw); err != nil {
return err
}
m.Role = raw.Role
var s string
if err := json.Unmarshal(raw.Content, &s); err == nil {
m.Content = []ContentBlock{{Type: "text", Text: &s}}
return nil
}
return json.Unmarshal(raw.Content, &m.Content)
} }
// ContentBlock represents a content block in a message. // ContentBlock represents a content block in a message.
@@ -121,9 +102,9 @@ type ContentBlock struct {
Source *ImageSource `json:"source,omitempty"` Source *ImageSource `json:"source,omitempty"`
// For tool_use and server_tool_use blocks // For tool_use and server_tool_use blocks
ID string `json:"id,omitempty"` ID string `json:"id,omitempty"`
Name string `json:"name,omitempty"` Name string `json:"name,omitempty"`
Input api.ToolCallFunctionArguments `json:"input,omitzero"` Input any `json:"input,omitempty"`
// For tool_result and web_search_tool_result blocks // For tool_result and web_search_tool_result blocks
ToolUseID string `json:"tool_use_id,omitempty"` ToolUseID string `json:"tool_use_id,omitempty"`
@@ -396,144 +377,177 @@ func convertMessage(msg MessageParam) ([]api.Message, error) {
var messages []api.Message var messages []api.Message
role := strings.ToLower(msg.Role) role := strings.ToLower(msg.Role)
var textContent strings.Builder switch content := msg.Content.(type) {
var images []api.ImageData case string:
var toolCalls []api.ToolCall messages = append(messages, api.Message{Role: role, Content: content})
var thinking string
var toolResults []api.Message
textBlocks := 0
imageBlocks := 0
toolUseBlocks := 0
toolResultBlocks := 0
serverToolUseBlocks := 0
webSearchToolResultBlocks := 0
thinkingBlocks := 0
unknownBlocks := 0
for _, block := range msg.Content { case []any:
switch block.Type { var textContent strings.Builder
case "text": var images []api.ImageData
textBlocks++ var toolCalls []api.ToolCall
if block.Text != nil { var thinking string
textContent.WriteString(*block.Text) var toolResults []api.Message
textBlocks := 0
imageBlocks := 0
toolUseBlocks := 0
toolResultBlocks := 0
serverToolUseBlocks := 0
webSearchToolResultBlocks := 0
thinkingBlocks := 0
unknownBlocks := 0
for _, block := range content {
blockMap, ok := block.(map[string]any)
if !ok {
logutil.Trace("anthropic: invalid content block format", "role", role)
return nil, errors.New("invalid content block format")
} }
case "image": blockType, _ := blockMap["type"].(string)
imageBlocks++
if block.Source == nil {
logutil.Trace("anthropic: invalid image source", "role", role)
return nil, errors.New("invalid image source")
}
if block.Source.Type == "base64" { switch blockType {
decoded, err := base64.StdEncoding.DecodeString(block.Source.Data) case "text":
if err != nil { textBlocks++
logutil.Trace("anthropic: invalid base64 image data", "role", role, "error", err) if text, ok := blockMap["text"].(string); ok {
return nil, fmt.Errorf("invalid base64 image data: %w", err) textContent.WriteString(text)
} }
images = append(images, decoded)
} else {
logutil.Trace("anthropic: unsupported image source type", "role", role, "source_type", block.Source.Type)
return nil, fmt.Errorf("invalid image source type: %s. Only base64 images are supported.", block.Source.Type)
}
case "tool_use": case "image":
toolUseBlocks++ imageBlocks++
if block.ID == "" { source, ok := blockMap["source"].(map[string]any)
logutil.Trace("anthropic: tool_use block missing id", "role", role) if !ok {
return nil, errors.New("tool_use block missing required 'id' field") logutil.Trace("anthropic: invalid image source", "role", role)
} return nil, errors.New("invalid image source")
if block.Name == "" { }
logutil.Trace("anthropic: tool_use block missing name", "role", role)
return nil, errors.New("tool_use block missing required 'name' field")
}
toolCalls = append(toolCalls, api.ToolCall{
ID: block.ID,
Function: api.ToolCallFunction{
Name: block.Name,
Arguments: block.Input,
},
})
case "tool_result": sourceType, _ := source["type"].(string)
toolResultBlocks++ if sourceType == "base64" {
var resultContent string data, _ := source["data"].(string)
decoded, err := base64.StdEncoding.DecodeString(data)
if err != nil {
logutil.Trace("anthropic: invalid base64 image data", "role", role, "error", err)
return nil, fmt.Errorf("invalid base64 image data: %w", err)
}
images = append(images, decoded)
} else {
logutil.Trace("anthropic: unsupported image source type", "role", role, "source_type", sourceType)
return nil, fmt.Errorf("invalid image source type: %s. Only base64 images are supported.", sourceType)
}
// URL images would need to be fetched - skip for now
switch c := block.Content.(type) { case "tool_use":
case string: toolUseBlocks++
resultContent = c id, ok := blockMap["id"].(string)
case []any: if !ok {
for _, cb := range c { logutil.Trace("anthropic: tool_use block missing id", "role", role)
if cbMap, ok := cb.(map[string]any); ok { return nil, errors.New("tool_use block missing required 'id' field")
if cbMap["type"] == "text" { }
if text, ok := cbMap["text"].(string); ok { name, ok := blockMap["name"].(string)
resultContent += text if !ok {
logutil.Trace("anthropic: tool_use block missing name", "role", role)
return nil, errors.New("tool_use block missing required 'name' field")
}
tc := api.ToolCall{
ID: id,
Function: api.ToolCallFunction{
Name: name,
},
}
if input, ok := blockMap["input"].(map[string]any); ok {
tc.Function.Arguments = mapToArgs(input)
}
toolCalls = append(toolCalls, tc)
case "tool_result":
toolResultBlocks++
toolUseID, _ := blockMap["tool_use_id"].(string)
var resultContent string
switch c := blockMap["content"].(type) {
case string:
resultContent = c
case []any:
for _, cb := range c {
if cbMap, ok := cb.(map[string]any); ok {
if cbMap["type"] == "text" {
if text, ok := cbMap["text"].(string); ok {
resultContent += text
}
} }
} }
} }
} }
toolResults = append(toolResults, api.Message{
Role: "tool",
Content: resultContent,
ToolCallID: toolUseID,
})
case "thinking":
thinkingBlocks++
if t, ok := blockMap["thinking"].(string); ok {
thinking = t
}
case "server_tool_use":
serverToolUseBlocks++
id, _ := blockMap["id"].(string)
name, _ := blockMap["name"].(string)
tc := api.ToolCall{
ID: id,
Function: api.ToolCallFunction{
Name: name,
},
}
if input, ok := blockMap["input"].(map[string]any); ok {
tc.Function.Arguments = mapToArgs(input)
}
toolCalls = append(toolCalls, tc)
case "web_search_tool_result":
webSearchToolResultBlocks++
toolUseID, _ := blockMap["tool_use_id"].(string)
toolResults = append(toolResults, api.Message{
Role: "tool",
Content: formatWebSearchToolResultContent(blockMap["content"]),
ToolCallID: toolUseID,
})
default:
unknownBlocks++
} }
}
toolResults = append(toolResults, api.Message{ if textContent.Len() > 0 || len(images) > 0 || len(toolCalls) > 0 || thinking != "" {
Role: "tool", m := api.Message{
Content: resultContent, Role: role,
ToolCallID: block.ToolUseID, Content: textContent.String(),
}) Images: images,
ToolCalls: toolCalls,
case "thinking": Thinking: thinking,
thinkingBlocks++
if block.Thinking != nil {
thinking = *block.Thinking
} }
messages = append(messages, m)
case "server_tool_use":
serverToolUseBlocks++
toolCalls = append(toolCalls, api.ToolCall{
ID: block.ID,
Function: api.ToolCallFunction{
Name: block.Name,
Arguments: block.Input,
},
})
case "web_search_tool_result":
webSearchToolResultBlocks++
toolResults = append(toolResults, api.Message{
Role: "tool",
Content: formatWebSearchToolResultContent(block.Content),
ToolCallID: block.ToolUseID,
})
default:
unknownBlocks++
} }
}
if textContent.Len() > 0 || len(images) > 0 || len(toolCalls) > 0 || thinking != "" { // Add tool results as separate messages
m := api.Message{ messages = append(messages, toolResults...)
Role: role, logutil.Trace("anthropic: converted block message",
Content: textContent.String(), "role", role,
Images: images, "blocks", len(content),
ToolCalls: toolCalls, "text", textBlocks,
Thinking: thinking, "image", imageBlocks,
} "tool_use", toolUseBlocks,
messages = append(messages, m) "tool_result", toolResultBlocks,
} "server_tool_use", serverToolUseBlocks,
"web_search_result", webSearchToolResultBlocks,
"thinking", thinkingBlocks,
"unknown", unknownBlocks,
"messages", TraceAPIMessages(messages),
)
// Add tool results as separate messages default:
messages = append(messages, toolResults...) return nil, fmt.Errorf("invalid message content type: %T", content)
logutil.Trace("anthropic: converted block message", }
"role", role,
"blocks", len(msg.Content),
"text", textBlocks,
"image", imageBlocks,
"tool_use", toolUseBlocks,
"tool_result", toolResultBlocks,
"server_tool_use", serverToolUseBlocks,
"web_search_result", webSearchToolResultBlocks,
"thinking", thinkingBlocks,
"unknown", unknownBlocks,
"messages", TraceAPIMessages(messages),
)
return messages, nil return messages, nil
} }
@@ -838,19 +852,6 @@ func (c *StreamConverter) Process(r api.ChatResponse) []StreamEvent {
continue continue
} }
// Close thinking block if still open (thinking → tool_use without text in between)
if c.thinkingStarted && !c.thinkingDone {
c.thinkingDone = true
events = append(events, StreamEvent{
Event: "content_block_stop",
Data: ContentBlockStopEvent{
Type: "content_block_stop",
Index: c.contentIndex,
},
})
c.contentIndex++
}
if c.textStarted { if c.textStarted {
events = append(events, StreamEvent{ events = append(events, StreamEvent{
Event: "content_block_stop", Event: "content_block_stop",
@@ -868,6 +869,7 @@ func (c *StreamConverter) Process(r api.ChatResponse) []StreamEvent {
slog.Error("failed to marshal tool arguments", "error", err, "tool_id", tc.ID) slog.Error("failed to marshal tool arguments", "error", err, "tool_id", tc.ID)
continue continue
} }
events = append(events, StreamEvent{ events = append(events, StreamEvent{
Event: "content_block_start", Event: "content_block_start",
Data: ContentBlockStartEvent{ Data: ContentBlockStartEvent{
@@ -877,7 +879,7 @@ func (c *StreamConverter) Process(r api.ChatResponse) []StreamEvent {
Type: "tool_use", Type: "tool_use",
ID: tc.ID, ID: tc.ID,
Name: tc.Function.Name, Name: tc.Function.Name,
Input: api.NewToolCallFunctionArguments(), Input: map[string]any{},
}, },
}, },
}) })
@@ -974,6 +976,15 @@ func ptr(s string) *string {
return &s return &s
} }
// mapToArgs converts a map to ToolCallFunctionArguments
func mapToArgs(m map[string]any) api.ToolCallFunctionArguments {
args := api.NewToolCallFunctionArguments()
for k, v := range m {
args.Set(k, v)
}
return args
}
// CountTokensRequest represents an Anthropic count_tokens request // CountTokensRequest represents an Anthropic count_tokens request
type CountTokensRequest struct { type CountTokensRequest struct {
Model string `json:"model"` Model string `json:"model"`
@@ -1006,13 +1017,17 @@ func estimateTokens(req CountTokensRequest) int {
var totalLen int var totalLen int
// Count system prompt // Count system prompt
totalLen += countAnyContent(req.System) if req.System != nil {
totalLen += countAnyContent(req.System)
}
// Count messages
for _, msg := range req.Messages { for _, msg := range req.Messages {
// Count role (always present) // Count role (always present)
totalLen += len(msg.Role) totalLen += len(msg.Role)
// Count content // Count content
totalLen += countAnyContent(msg.Content) contentLen := countAnyContent(msg.Content)
totalLen += contentLen
} }
for _, tool := range req.Tools { for _, tool := range req.Tools {
@@ -1035,25 +1050,12 @@ func countAnyContent(content any) int {
switch c := content.(type) { switch c := content.(type) {
case string: case string:
return len(c) return len(c)
case []ContentBlock: case []any:
total := 0 total := 0
for _, block := range c { for _, block := range c {
total += countContentBlock(block) total += countContentBlock(block)
} }
return total return total
case []any:
total := 0
for _, item := range c {
data, err := json.Marshal(item)
if err != nil {
continue
}
var block ContentBlock
if err := json.Unmarshal(data, &block); err == nil {
total += countContentBlock(block)
}
}
return total
default: default:
if data, err := json.Marshal(content); err == nil { if data, err := json.Marshal(content); err == nil {
return len(data) return len(data)
@@ -1062,19 +1064,38 @@ func countAnyContent(content any) int {
} }
} }
func countContentBlock(block ContentBlock) int { func countContentBlock(block any) int {
blockMap, ok := block.(map[string]any)
if !ok {
if s, ok := block.(string); ok {
return len(s)
}
return 0
}
total := 0 total := 0
if block.Text != nil { blockType, _ := blockMap["type"].(string)
total += len(*block.Text)
if text, ok := blockMap["text"].(string); ok {
total += len(text)
} }
if block.Thinking != nil {
total += len(*block.Thinking) if thinking, ok := blockMap["thinking"].(string); ok {
total += len(thinking)
} }
if block.Type == "tool_use" || block.Type == "tool_result" {
if data, err := json.Marshal(block); err == nil { if blockType == "tool_use" {
if data, err := json.Marshal(blockMap); err == nil {
total += len(data) total += len(data)
} }
} }
if blockType == "tool_result" {
if data, err := json.Marshal(blockMap); err == nil {
total += len(data)
}
}
return total return total
} }

View File

@@ -15,16 +15,11 @@ const (
testImage = `iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=` testImage = `iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=`
) )
// textContent is a convenience for constructing []ContentBlock with a single text block in tests. // testArgs creates ToolCallFunctionArguments from a map (convenience function for tests)
func textContent(s string) []ContentBlock { func testArgs(m map[string]any) api.ToolCallFunctionArguments {
return []ContentBlock{{Type: "text", Text: &s}}
}
// makeArgs creates ToolCallFunctionArguments from key-value pairs (convenience function for tests)
func makeArgs(kvs ...any) api.ToolCallFunctionArguments {
args := api.NewToolCallFunctionArguments() args := api.NewToolCallFunctionArguments()
for i := 0; i < len(kvs)-1; i += 2 { for k, v := range m {
args.Set(kvs[i].(string), kvs[i+1]) args.Set(k, v)
} }
return args return args
} }
@@ -34,7 +29,7 @@ func TestFromMessagesRequest_Basic(t *testing.T) {
Model: "test-model", Model: "test-model",
MaxTokens: 1024, MaxTokens: 1024,
Messages: []MessageParam{ Messages: []MessageParam{
{Role: "user", Content: textContent("Hello")}, {Role: "user", Content: "Hello"},
}, },
} }
@@ -66,7 +61,7 @@ func TestFromMessagesRequest_WithSystemPrompt(t *testing.T) {
MaxTokens: 1024, MaxTokens: 1024,
System: "You are a helpful assistant.", System: "You are a helpful assistant.",
Messages: []MessageParam{ Messages: []MessageParam{
{Role: "user", Content: textContent("Hello")}, {Role: "user", Content: "Hello"},
}, },
} }
@@ -93,7 +88,7 @@ func TestFromMessagesRequest_WithSystemPromptArray(t *testing.T) {
map[string]any{"type": "text", "text": " Be concise."}, map[string]any{"type": "text", "text": " Be concise."},
}, },
Messages: []MessageParam{ Messages: []MessageParam{
{Role: "user", Content: textContent("Hello")}, {Role: "user", Content: "Hello"},
}, },
} }
@@ -118,7 +113,7 @@ func TestFromMessagesRequest_WithOptions(t *testing.T) {
req := MessagesRequest{ req := MessagesRequest{
Model: "test-model", Model: "test-model",
MaxTokens: 2048, MaxTokens: 2048,
Messages: []MessageParam{{Role: "user", Content: textContent("Hello")}}, Messages: []MessageParam{{Role: "user", Content: "Hello"}},
Temperature: &temp, Temperature: &temp,
TopP: &topP, TopP: &topP,
TopK: &topK, TopK: &topK,
@@ -153,14 +148,14 @@ func TestFromMessagesRequest_WithImage(t *testing.T) {
Messages: []MessageParam{ Messages: []MessageParam{
{ {
Role: "user", Role: "user",
Content: []ContentBlock{ Content: []any{
{Type: "text", Text: ptr("What's in this image?")}, map[string]any{"type": "text", "text": "What's in this image?"},
{ map[string]any{
Type: "image", "type": "image",
Source: &ImageSource{ "source": map[string]any{
Type: "base64", "type": "base64",
MediaType: "image/png", "media_type": "image/png",
Data: testImage, "data": testImage,
}, },
}, },
}, },
@@ -195,15 +190,15 @@ func TestFromMessagesRequest_WithToolUse(t *testing.T) {
Model: "test-model", Model: "test-model",
MaxTokens: 1024, MaxTokens: 1024,
Messages: []MessageParam{ Messages: []MessageParam{
{Role: "user", Content: textContent("What's the weather in Paris?")}, {Role: "user", Content: "What's the weather in Paris?"},
{ {
Role: "assistant", Role: "assistant",
Content: []ContentBlock{ Content: []any{
{ map[string]any{
Type: "tool_use", "type": "tool_use",
ID: "call_123", "id": "call_123",
Name: "get_weather", "name": "get_weather",
Input: makeArgs("location", "Paris"), "input": map[string]any{"location": "Paris"},
}, },
}, },
}, },
@@ -239,11 +234,11 @@ func TestFromMessagesRequest_WithToolResult(t *testing.T) {
Messages: []MessageParam{ Messages: []MessageParam{
{ {
Role: "user", Role: "user",
Content: []ContentBlock{ Content: []any{
{ map[string]any{
Type: "tool_result", "type": "tool_result",
ToolUseID: "call_123", "tool_use_id": "call_123",
Content: "The weather in Paris is sunny, 22°C", "content": "The weather in Paris is sunny, 22°C",
}, },
}, },
}, },
@@ -275,7 +270,7 @@ func TestFromMessagesRequest_WithTools(t *testing.T) {
req := MessagesRequest{ req := MessagesRequest{
Model: "test-model", Model: "test-model",
MaxTokens: 1024, MaxTokens: 1024,
Messages: []MessageParam{{Role: "user", Content: textContent("Hello")}}, Messages: []MessageParam{{Role: "user", Content: "Hello"}},
Tools: []Tool{ Tools: []Tool{
{ {
Name: "get_weather", Name: "get_weather",
@@ -310,7 +305,7 @@ func TestFromMessagesRequest_DropsCustomWebSearchWhenBuiltinPresent(t *testing.T
req := MessagesRequest{ req := MessagesRequest{
Model: "test-model", Model: "test-model",
MaxTokens: 1024, MaxTokens: 1024,
Messages: []MessageParam{{Role: "user", Content: textContent("Hello")}}, Messages: []MessageParam{{Role: "user", Content: "Hello"}},
Tools: []Tool{ Tools: []Tool{
{ {
Type: "web_search_20250305", Type: "web_search_20250305",
@@ -351,7 +346,7 @@ func TestFromMessagesRequest_KeepsCustomWebSearchWhenBuiltinAbsent(t *testing.T)
req := MessagesRequest{ req := MessagesRequest{
Model: "test-model", Model: "test-model",
MaxTokens: 1024, MaxTokens: 1024,
Messages: []MessageParam{{Role: "user", Content: textContent("Hello")}}, Messages: []MessageParam{{Role: "user", Content: "Hello"}},
Tools: []Tool{ Tools: []Tool{
{ {
Type: "custom", Type: "custom",
@@ -382,7 +377,7 @@ func TestFromMessagesRequest_WithThinking(t *testing.T) {
req := MessagesRequest{ req := MessagesRequest{
Model: "test-model", Model: "test-model",
MaxTokens: 1024, MaxTokens: 1024,
Messages: []MessageParam{{Role: "user", Content: textContent("Hello")}}, Messages: []MessageParam{{Role: "user", Content: "Hello"}},
Thinking: &ThinkingConfig{Type: "enabled", BudgetTokens: 1000}, Thinking: &ThinkingConfig{Type: "enabled", BudgetTokens: 1000},
} }
@@ -404,13 +399,13 @@ func TestFromMessagesRequest_ThinkingOnlyBlock(t *testing.T) {
Model: "test-model", Model: "test-model",
MaxTokens: 1024, MaxTokens: 1024,
Messages: []MessageParam{ Messages: []MessageParam{
{Role: "user", Content: textContent("Hello")}, {Role: "user", Content: "Hello"},
{ {
Role: "assistant", Role: "assistant",
Content: []ContentBlock{ Content: []any{
{ map[string]any{
Type: "thinking", "type": "thinking",
Thinking: ptr("Let me think about this..."), "thinking": "Let me think about this...",
}, },
}, },
}, },
@@ -439,10 +434,10 @@ func TestFromMessagesRequest_ToolUseMissingID(t *testing.T) {
Messages: []MessageParam{ Messages: []MessageParam{
{ {
Role: "assistant", Role: "assistant",
Content: []ContentBlock{ Content: []any{
{ map[string]any{
Type: "tool_use", "type": "tool_use",
Name: "get_weather", "name": "get_weather",
}, },
}, },
}, },
@@ -465,10 +460,10 @@ func TestFromMessagesRequest_ToolUseMissingName(t *testing.T) {
Messages: []MessageParam{ Messages: []MessageParam{
{ {
Role: "assistant", Role: "assistant",
Content: []ContentBlock{ Content: []any{
{ map[string]any{
Type: "tool_use", "type": "tool_use",
ID: "call_123", "id": "call_123",
}, },
}, },
}, },
@@ -488,7 +483,7 @@ func TestFromMessagesRequest_InvalidToolSchema(t *testing.T) {
req := MessagesRequest{ req := MessagesRequest{
Model: "test-model", Model: "test-model",
MaxTokens: 1024, MaxTokens: 1024,
Messages: []MessageParam{{Role: "user", Content: textContent("Hello")}}, Messages: []MessageParam{{Role: "user", Content: "Hello"}},
Tools: []Tool{ Tools: []Tool{
{ {
Name: "bad_tool", Name: "bad_tool",
@@ -553,7 +548,7 @@ func TestToMessagesResponse_WithToolCalls(t *testing.T) {
ID: "call_123", ID: "call_123",
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: "get_weather", Name: "get_weather",
Arguments: makeArgs("location", "Paris"), Arguments: testArgs(map[string]any{"location": "Paris"}),
}, },
}, },
}, },
@@ -765,7 +760,7 @@ func TestStreamConverter_WithToolCalls(t *testing.T) {
ID: "call_123", ID: "call_123",
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: "get_weather", Name: "get_weather",
Arguments: makeArgs("location", "Paris"), Arguments: testArgs(map[string]any{"location": "Paris"}),
}, },
}, },
}, },
@@ -804,107 +799,6 @@ func TestStreamConverter_WithToolCalls(t *testing.T) {
} }
} }
// TestStreamConverter_ThinkingDirectlyFollowedByToolCall verifies that when a
// model emits a thinking block followed directly by a tool_use block (with no
// text block in between), the streaming converter correctly closes the thinking
// block and increments the content index before opening the tool_use block.
// Previously, the converter reused contentIndex=0 for the tool_use block,
// which caused "Content block not found" errors in clients. See #14816.
func TestStreamConverter_ThinkingDirectlyFollowedByToolCall(t *testing.T) {
conv := NewStreamConverter("msg_123", "test-model", 0)
// First chunk: thinking content (no text)
resp1 := api.ChatResponse{
Model: "test-model",
Message: api.Message{
Role: "assistant",
Thinking: "I should call the tool.",
},
}
events1 := conv.Process(resp1)
// Should have: message_start, content_block_start(thinking), content_block_delta(thinking)
if len(events1) < 3 {
t.Fatalf("expected at least 3 events for thinking chunk, got %d", len(events1))
}
if events1[0].Event != "message_start" {
t.Errorf("expected first event 'message_start', got %q", events1[0].Event)
}
thinkingStart, ok := events1[1].Data.(ContentBlockStartEvent)
if !ok || thinkingStart.ContentBlock.Type != "thinking" {
t.Errorf("expected content_block_start(thinking) as second event, got %+v", events1[1])
}
if thinkingStart.Index != 0 {
t.Errorf("expected thinking block at index 0, got %d", thinkingStart.Index)
}
// Second chunk: tool call (no text between thinking and tool)
resp2 := api.ChatResponse{
Model: "test-model",
Message: api.Message{
Role: "assistant",
ToolCalls: []api.ToolCall{
{
ID: "call_abc",
Function: api.ToolCallFunction{
Name: "ask_user",
Arguments: makeArgs("question", "cats or dogs?"),
},
},
},
},
Done: true,
DoneReason: "stop",
Metrics: api.Metrics{PromptEvalCount: 10, EvalCount: 5},
}
events2 := conv.Process(resp2)
// Expect: content_block_stop(index=0), content_block_start(tool_use, index=1),
// content_block_delta(input_json_delta, index=1), content_block_stop(index=1),
// message_delta, message_stop
var thinkingStop, toolStart, toolDelta, toolStop *StreamEvent
for i := range events2 {
e := &events2[i]
switch e.Event {
case "content_block_stop":
if stop, ok := e.Data.(ContentBlockStopEvent); ok {
if stop.Index == 0 && thinkingStop == nil {
thinkingStop = e
} else if stop.Index == 1 {
toolStop = e
}
}
case "content_block_start":
if start, ok := e.Data.(ContentBlockStartEvent); ok && start.ContentBlock.Type == "tool_use" {
toolStart = e
}
case "content_block_delta":
if delta, ok := e.Data.(ContentBlockDeltaEvent); ok && delta.Delta.Type == "input_json_delta" {
toolDelta = e
}
}
}
if thinkingStop == nil {
t.Error("expected content_block_stop for thinking block (index 0)")
}
if toolStart == nil {
t.Fatal("expected content_block_start for tool_use block")
}
if start, ok := toolStart.Data.(ContentBlockStartEvent); !ok || start.Index != 1 {
t.Errorf("expected tool_use block at index 1, got %+v", toolStart.Data)
}
if toolDelta == nil {
t.Fatal("expected input_json_delta event for tool call")
}
if delta, ok := toolDelta.Data.(ContentBlockDeltaEvent); !ok || delta.Index != 1 {
t.Errorf("expected tool delta at index 1, got %+v", toolDelta.Data)
}
if toolStop == nil {
t.Error("expected content_block_stop for tool_use block (index 1)")
}
}
func TestStreamConverter_ToolCallWithUnmarshalableArgs(t *testing.T) { func TestStreamConverter_ToolCallWithUnmarshalableArgs(t *testing.T) {
// Test that unmarshalable arguments (like channels) are handled gracefully // Test that unmarshalable arguments (like channels) are handled gracefully
// and don't cause a panic or corrupt stream // and don't cause a panic or corrupt stream
@@ -970,7 +864,7 @@ func TestStreamConverter_MultipleToolCallsWithMixedValidity(t *testing.T) {
ID: "call_good", ID: "call_good",
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: "good_function", Name: "good_function",
Arguments: makeArgs("location", "Paris"), Arguments: testArgs(map[string]any{"location": "Paris"}),
}, },
}, },
{ {
@@ -1072,57 +966,6 @@ func TestContentBlockJSON_EmptyFieldsPresent(t *testing.T) {
} }
} }
func TestContentBlockJSON_NonToolBlocksDoNotIncludeInput(t *testing.T) {
tests := []struct {
name string
block ContentBlock
}{
{
name: "text block",
block: ContentBlock{
Type: "text",
Text: ptr("hello"),
},
},
{
name: "thinking block",
block: ContentBlock{
Type: "thinking",
Thinking: ptr("let me think"),
},
},
{
name: "image block",
block: ContentBlock{
Type: "image",
Source: &ImageSource{
Type: "base64",
MediaType: "image/png",
Data: testImage,
},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
data, err := json.Marshal(tt.block)
if err != nil {
t.Fatalf("failed to marshal: %v", err)
}
var result map[string]any
if err := json.Unmarshal(data, &result); err != nil {
t.Fatalf("failed to unmarshal: %v", err)
}
if _, ok := result["input"]; ok {
t.Fatalf("unexpected input field in non-tool block JSON: %s", string(data))
}
})
}
}
func TestStreamConverter_ContentBlockStartIncludesEmptyFields(t *testing.T) { func TestStreamConverter_ContentBlockStartIncludesEmptyFields(t *testing.T) {
t.Run("text block start includes empty text", func(t *testing.T) { t.Run("text block start includes empty text", func(t *testing.T) {
conv := NewStreamConverter("msg_123", "test-model", 0) conv := NewStreamConverter("msg_123", "test-model", 0)
@@ -1143,9 +986,7 @@ func TestStreamConverter_ContentBlockStartIncludesEmptyFields(t *testing.T) {
// Marshal and verify the text field is present // Marshal and verify the text field is present
data, _ := json.Marshal(start) data, _ := json.Marshal(start)
var result map[string]any var result map[string]any
if err := json.Unmarshal(data, &result); err != nil { json.Unmarshal(data, &result)
t.Fatalf("failed to unmarshal content_block_start JSON: %v", err)
}
cb := result["content_block"].(map[string]any) cb := result["content_block"].(map[string]any)
if _, ok := cb["text"]; !ok { if _, ok := cb["text"]; !ok {
t.Error("content_block_start for text should include 'text' field") t.Error("content_block_start for text should include 'text' field")
@@ -1192,71 +1033,13 @@ func TestStreamConverter_ContentBlockStartIncludesEmptyFields(t *testing.T) {
t.Error("expected thinking content_block_start event") t.Error("expected thinking content_block_start event")
} }
}) })
t.Run("tool_use block start includes empty input object", func(t *testing.T) {
conv := NewStreamConverter("msg_123", "test-model", 0)
resp := api.ChatResponse{
Model: "test-model",
Message: api.Message{
Role: "assistant",
ToolCalls: []api.ToolCall{
{
ID: "call_123",
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: makeArgs("location", "Paris"),
},
},
},
},
}
events := conv.Process(resp)
var foundToolStart bool
for _, e := range events {
if e.Event == "content_block_start" {
if start, ok := e.Data.(ContentBlockStartEvent); ok {
if start.ContentBlock.Type == "tool_use" {
foundToolStart = true
if start.ContentBlock.Input.Len() != 0 {
t.Errorf("expected empty input object, got len=%d", start.ContentBlock.Input.Len())
}
data, _ := json.Marshal(start)
var result map[string]any
json.Unmarshal(data, &result)
cb := result["content_block"].(map[string]any)
input, ok := cb["input"]
if !ok {
t.Error("content_block_start for tool_use should include 'input' field")
continue
}
inputMap, ok := input.(map[string]any)
if !ok {
t.Errorf("input field should be an object, got %T", input)
continue
}
if len(inputMap) != 0 {
t.Errorf("expected empty input object in content_block_start, got %v", inputMap)
}
}
}
}
}
if !foundToolStart {
t.Error("expected tool_use content_block_start event")
}
})
} }
func TestEstimateTokens_SimpleMessage(t *testing.T) { func TestEstimateTokens_SimpleMessage(t *testing.T) {
req := CountTokensRequest{ req := CountTokensRequest{
Model: "test-model", Model: "test-model",
Messages: []MessageParam{ Messages: []MessageParam{
{Role: "user", Content: textContent("Hello, world!")}, {Role: "user", Content: "Hello, world!"},
}, },
} }
@@ -1277,7 +1060,7 @@ func TestEstimateTokens_WithSystemPrompt(t *testing.T) {
Model: "test-model", Model: "test-model",
System: "You are a helpful assistant.", System: "You are a helpful assistant.",
Messages: []MessageParam{ Messages: []MessageParam{
{Role: "user", Content: textContent("Hello")}, {Role: "user", Content: "Hello"},
}, },
} }
@@ -1293,7 +1076,7 @@ func TestEstimateTokens_WithTools(t *testing.T) {
req := CountTokensRequest{ req := CountTokensRequest{
Model: "test-model", Model: "test-model",
Messages: []MessageParam{ Messages: []MessageParam{
{Role: "user", Content: textContent("What's the weather?")}, {Role: "user", Content: "What's the weather?"},
}, },
Tools: []Tool{ Tools: []Tool{
{ {
@@ -1316,17 +1099,17 @@ func TestEstimateTokens_WithThinking(t *testing.T) {
req := CountTokensRequest{ req := CountTokensRequest{
Model: "test-model", Model: "test-model",
Messages: []MessageParam{ Messages: []MessageParam{
{Role: "user", Content: textContent("Hello")}, {Role: "user", Content: "Hello"},
{ {
Role: "assistant", Role: "assistant",
Content: []ContentBlock{ Content: []any{
{ map[string]any{
Type: "thinking", "type": "thinking",
Thinking: ptr("Let me think about this carefully..."), "thinking": "Let me think about this carefully...",
}, },
{ map[string]any{
Type: "text", "type": "text",
Text: ptr("Here is my response."), "text": "Here is my response.",
}, },
}, },
}, },
@@ -1424,12 +1207,12 @@ func TestConvertTool_RegularTool(t *testing.T) {
func TestConvertMessage_ServerToolUse(t *testing.T) { func TestConvertMessage_ServerToolUse(t *testing.T) {
msg := MessageParam{ msg := MessageParam{
Role: "assistant", Role: "assistant",
Content: []ContentBlock{ Content: []any{
{ map[string]any{
Type: "server_tool_use", "type": "server_tool_use",
ID: "srvtoolu_123", "id": "srvtoolu_123",
Name: "web_search", "name": "web_search",
Input: makeArgs("query", "test query"), "input": map[string]any{"query": "test query"},
}, },
}, },
} }
@@ -1460,11 +1243,11 @@ func TestConvertMessage_ServerToolUse(t *testing.T) {
func TestConvertMessage_WebSearchToolResult(t *testing.T) { func TestConvertMessage_WebSearchToolResult(t *testing.T) {
msg := MessageParam{ msg := MessageParam{
Role: "user", Role: "user",
Content: []ContentBlock{ Content: []any{
{ map[string]any{
Type: "web_search_tool_result", "type": "web_search_tool_result",
ToolUseID: "srvtoolu_123", "tool_use_id": "srvtoolu_123",
Content: []any{ "content": []any{
map[string]any{ map[string]any{
"type": "web_search_result", "type": "web_search_result",
"title": "Test Result", "title": "Test Result",
@@ -1501,11 +1284,11 @@ func TestConvertMessage_WebSearchToolResult(t *testing.T) {
func TestConvertMessage_WebSearchToolResultEmptyStillCreatesToolMessage(t *testing.T) { func TestConvertMessage_WebSearchToolResultEmptyStillCreatesToolMessage(t *testing.T) {
msg := MessageParam{ msg := MessageParam{
Role: "user", Role: "user",
Content: []ContentBlock{ Content: []any{
{ map[string]any{
Type: "web_search_tool_result", "type": "web_search_tool_result",
ToolUseID: "srvtoolu_empty", "tool_use_id": "srvtoolu_empty",
Content: []any{}, "content": []any{},
}, },
}, },
} }
@@ -1532,11 +1315,11 @@ func TestConvertMessage_WebSearchToolResultEmptyStillCreatesToolMessage(t *testi
func TestConvertMessage_WebSearchToolResultErrorStillCreatesToolMessage(t *testing.T) { func TestConvertMessage_WebSearchToolResultErrorStillCreatesToolMessage(t *testing.T) {
msg := MessageParam{ msg := MessageParam{
Role: "user", Role: "user",
Content: []ContentBlock{ Content: []any{
{ map[string]any{
Type: "web_search_tool_result", "type": "web_search_tool_result",
ToolUseID: "srvtoolu_error", "tool_use_id": "srvtoolu_error",
Content: map[string]any{ "content": map[string]any{
"type": "web_search_tool_result_error", "type": "web_search_tool_result_error",
"error_code": "max_uses_exceeded", "error_code": "max_uses_exceeded",
}, },

View File

@@ -476,3 +476,25 @@ func (c *Client) Whoami(ctx context.Context) (*UserResponse, error) {
} }
return &resp, nil return &resp, nil
} }
// AliasRequest is the request body for creating or updating a model alias.
type AliasRequest struct {
Alias string `json:"alias"`
Target string `json:"target"`
PrefixMatching bool `json:"prefix_matching,omitempty"`
}
// SetAliasExperimental creates or updates a model alias via the experimental aliases API.
func (c *Client) SetAliasExperimental(ctx context.Context, req *AliasRequest) error {
return c.do(ctx, http.MethodPost, "/api/experimental/aliases", req, nil)
}
// AliasDeleteRequest is the request body for deleting a model alias.
type AliasDeleteRequest struct {
Alias string `json:"alias"`
}
// DeleteAliasExperimental deletes a model alias via the experimental aliases API.
func (c *Client) DeleteAliasExperimental(ctx context.Context, req *AliasDeleteRequest) error {
return c.do(ctx, http.MethodDelete, "/api/experimental/aliases", req, nil)
}

View File

@@ -436,7 +436,6 @@ type ToolProperty struct {
Description string `json:"description,omitempty"` Description string `json:"description,omitempty"`
Enum []any `json:"enum,omitempty"` Enum []any `json:"enum,omitempty"`
Properties *ToolPropertiesMap `json:"properties,omitempty"` Properties *ToolPropertiesMap `json:"properties,omitempty"`
Required []string `json:"required,omitempty"`
} }
// ToTypeScriptType converts a ToolProperty to a TypeScript type string // ToTypeScriptType converts a ToolProperty to a TypeScript type string

View File

@@ -14,7 +14,7 @@ import (
// currentSchemaVersion defines the current database schema version. // currentSchemaVersion defines the current database schema version.
// Increment this when making schema changes that require migrations. // Increment this when making schema changes that require migrations.
const currentSchemaVersion = 16 const currentSchemaVersion = 15
// database wraps the SQLite connection. // database wraps the SQLite connection.
// SQLite handles its own locking for concurrent access: // SQLite handles its own locking for concurrent access:
@@ -82,7 +82,6 @@ func (db *database) init() error {
websearch_enabled BOOLEAN NOT NULL DEFAULT 0, websearch_enabled BOOLEAN NOT NULL DEFAULT 0,
selected_model TEXT NOT NULL DEFAULT '', selected_model TEXT NOT NULL DEFAULT '',
sidebar_open BOOLEAN NOT NULL DEFAULT 0, sidebar_open BOOLEAN NOT NULL DEFAULT 0,
last_home_view TEXT NOT NULL DEFAULT 'launch',
think_enabled BOOLEAN NOT NULL DEFAULT 0, think_enabled BOOLEAN NOT NULL DEFAULT 0,
think_level TEXT NOT NULL DEFAULT '', think_level TEXT NOT NULL DEFAULT '',
cloud_setting_migrated BOOLEAN NOT NULL DEFAULT 0, cloud_setting_migrated BOOLEAN NOT NULL DEFAULT 0,
@@ -265,12 +264,6 @@ func (db *database) migrate() error {
return fmt.Errorf("migrate v14 to v15: %w", err) return fmt.Errorf("migrate v14 to v15: %w", err)
} }
version = 15 version = 15
case 15:
// add last_home_view column to settings table
if err := db.migrateV15ToV16(); err != nil {
return fmt.Errorf("migrate v15 to v16: %w", err)
}
version = 16
default: default:
// If we have a version we don't recognize, just set it to current // If we have a version we don't recognize, just set it to current
// This might happen during development // This might happen during development
@@ -525,21 +518,6 @@ func (db *database) migrateV14ToV15() error {
return nil return nil
} }
// migrateV15ToV16 adds the last_home_view column to the settings table
func (db *database) migrateV15ToV16() error {
_, err := db.conn.Exec(`ALTER TABLE settings ADD COLUMN last_home_view TEXT NOT NULL DEFAULT 'launch'`)
if err != nil && !duplicateColumnError(err) {
return fmt.Errorf("add last_home_view column: %w", err)
}
_, err = db.conn.Exec(`UPDATE settings SET schema_version = 16`)
if err != nil {
return fmt.Errorf("update schema version: %w", err)
}
return nil
}
// cleanupOrphanedData removes orphaned records that may exist due to the foreign key bug // cleanupOrphanedData removes orphaned records that may exist due to the foreign key bug
func (db *database) cleanupOrphanedData() error { func (db *database) cleanupOrphanedData() error {
_, err := db.conn.Exec(` _, err := db.conn.Exec(`
@@ -1188,9 +1166,9 @@ func (db *database) getSettings() (Settings, error) {
var s Settings var s Settings
err := db.conn.QueryRow(` err := db.conn.QueryRow(`
SELECT expose, survey, browser, models, agent, tools, working_dir, context_length, turbo_enabled, websearch_enabled, selected_model, sidebar_open, last_home_view, think_enabled, think_level, auto_update_enabled SELECT expose, survey, browser, models, agent, tools, working_dir, context_length, turbo_enabled, websearch_enabled, selected_model, sidebar_open, think_enabled, think_level, auto_update_enabled
FROM settings FROM settings
`).Scan(&s.Expose, &s.Survey, &s.Browser, &s.Models, &s.Agent, &s.Tools, &s.WorkingDir, &s.ContextLength, &s.TurboEnabled, &s.WebSearchEnabled, &s.SelectedModel, &s.SidebarOpen, &s.LastHomeView, &s.ThinkEnabled, &s.ThinkLevel, &s.AutoUpdateEnabled) `).Scan(&s.Expose, &s.Survey, &s.Browser, &s.Models, &s.Agent, &s.Tools, &s.WorkingDir, &s.ContextLength, &s.TurboEnabled, &s.WebSearchEnabled, &s.SelectedModel, &s.SidebarOpen, &s.ThinkEnabled, &s.ThinkLevel, &s.AutoUpdateEnabled)
if err != nil { if err != nil {
return Settings{}, fmt.Errorf("get settings: %w", err) return Settings{}, fmt.Errorf("get settings: %w", err)
} }
@@ -1199,26 +1177,10 @@ func (db *database) getSettings() (Settings, error) {
} }
func (db *database) setSettings(s Settings) error { func (db *database) setSettings(s Settings) error {
lastHomeView := strings.ToLower(strings.TrimSpace(s.LastHomeView))
validLaunchView := map[string]struct{}{
"launch": {},
"openclaw": {},
"claude": {},
"codex": {},
"opencode": {},
"droid": {},
"pi": {},
}
if lastHomeView != "chat" {
if _, ok := validLaunchView[lastHomeView]; !ok {
lastHomeView = "launch"
}
}
_, err := db.conn.Exec(` _, err := db.conn.Exec(`
UPDATE settings UPDATE settings
SET expose = ?, survey = ?, browser = ?, models = ?, agent = ?, tools = ?, working_dir = ?, context_length = ?, turbo_enabled = ?, websearch_enabled = ?, selected_model = ?, sidebar_open = ?, last_home_view = ?, think_enabled = ?, think_level = ?, auto_update_enabled = ? SET expose = ?, survey = ?, browser = ?, models = ?, agent = ?, tools = ?, working_dir = ?, context_length = ?, turbo_enabled = ?, websearch_enabled = ?, selected_model = ?, sidebar_open = ?, think_enabled = ?, think_level = ?, auto_update_enabled = ?
`, s.Expose, s.Survey, s.Browser, s.Models, s.Agent, s.Tools, s.WorkingDir, s.ContextLength, s.TurboEnabled, s.WebSearchEnabled, s.SelectedModel, s.SidebarOpen, lastHomeView, s.ThinkEnabled, s.ThinkLevel, s.AutoUpdateEnabled) `, s.Expose, s.Survey, s.Browser, s.Models, s.Agent, s.Tools, s.WorkingDir, s.ContextLength, s.TurboEnabled, s.WebSearchEnabled, s.SelectedModel, s.SidebarOpen, s.ThinkEnabled, s.ThinkLevel, s.AutoUpdateEnabled)
if err != nil { if err != nil {
return fmt.Errorf("set settings: %w", err) return fmt.Errorf("set settings: %w", err)
} }

View File

@@ -135,45 +135,6 @@ func TestMigrationV13ToV14ContextLength(t *testing.T) {
} }
} }
func TestMigrationV15ToV16LastHomeViewDefaultsToLaunch(t *testing.T) {
tmpDir := t.TempDir()
dbPath := filepath.Join(tmpDir, "test.db")
db, err := newDatabase(dbPath)
if err != nil {
t.Fatalf("failed to create database: %v", err)
}
defer db.Close()
if _, err := db.conn.Exec(`
ALTER TABLE settings DROP COLUMN last_home_view;
UPDATE settings SET schema_version = 15;
`); err != nil {
t.Fatalf("failed to seed v15 settings row: %v", err)
}
if err := db.migrate(); err != nil {
t.Fatalf("migration from v15 to v16 failed: %v", err)
}
var lastHomeView string
if err := db.conn.QueryRow("SELECT last_home_view FROM settings").Scan(&lastHomeView); err != nil {
t.Fatalf("failed to read last_home_view: %v", err)
}
if lastHomeView != "launch" {
t.Fatalf("expected last_home_view to default to launch after migration, got %q", lastHomeView)
}
version, err := db.getSchemaVersion()
if err != nil {
t.Fatalf("failed to get schema version: %v", err)
}
if version != currentSchemaVersion {
t.Fatalf("expected schema version %d, got %d", currentSchemaVersion, version)
}
}
func TestChatDeletionWithCascade(t *testing.T) { func TestChatDeletionWithCascade(t *testing.T) {
t.Run("chat deletion cascades to related messages", func(t *testing.T) { t.Run("chat deletion cascades to related messages", func(t *testing.T) {
tmpDir := t.TempDir() tmpDir := t.TempDir()

View File

@@ -167,9 +167,6 @@ type Settings struct {
// SidebarOpen indicates if the chat sidebar is open // SidebarOpen indicates if the chat sidebar is open
SidebarOpen bool SidebarOpen bool
// LastHomeView stores the preferred home route target ("chat" or integration name)
LastHomeView string
// AutoUpdateEnabled indicates if automatic updates should be downloaded // AutoUpdateEnabled indicates if automatic updates should be downloaded
AutoUpdateEnabled bool AutoUpdateEnabled bool
} }
@@ -392,10 +389,6 @@ func (s *Store) Settings() (Settings, error) {
} }
} }
if settings.LastHomeView == "" {
settings.LastHomeView = "launch"
}
return settings, nil return settings, nil
} }

View File

@@ -81,32 +81,6 @@ func TestStore(t *testing.T) {
} }
}) })
t.Run("settings default home view is launch", func(t *testing.T) {
loaded, err := s.Settings()
if err != nil {
t.Fatal(err)
}
if loaded.LastHomeView != "launch" {
t.Fatalf("expected default LastHomeView to be launch, got %q", loaded.LastHomeView)
}
})
t.Run("settings empty home view falls back to launch", func(t *testing.T) {
if err := s.SetSettings(Settings{LastHomeView: ""}); err != nil {
t.Fatal(err)
}
loaded, err := s.Settings()
if err != nil {
t.Fatal(err)
}
if loaded.LastHomeView != "launch" {
t.Fatalf("expected empty LastHomeView to fall back to launch, got %q", loaded.LastHomeView)
}
})
t.Run("window size", func(t *testing.T) { t.Run("window size", func(t *testing.T) {
if err := s.SetWindowSize(1024, 768); err != nil { if err := s.SetWindowSize(1024, 768); err != nil {
t.Fatal(err) t.Fatal(err)

View File

@@ -414,7 +414,6 @@ export class Settings {
ThinkLevel: string; ThinkLevel: string;
SelectedModel: string; SelectedModel: string;
SidebarOpen: boolean; SidebarOpen: boolean;
LastHomeView: string;
AutoUpdateEnabled: boolean; AutoUpdateEnabled: boolean;
constructor(source: any = {}) { constructor(source: any = {}) {
@@ -433,7 +432,6 @@ export class Settings {
this.ThinkLevel = source["ThinkLevel"]; this.ThinkLevel = source["ThinkLevel"];
this.SelectedModel = source["SelectedModel"]; this.SelectedModel = source["SelectedModel"];
this.SidebarOpen = source["SidebarOpen"]; this.SidebarOpen = source["SidebarOpen"];
this.LastHomeView = source["LastHomeView"];
this.AutoUpdateEnabled = source["AutoUpdateEnabled"]; this.AutoUpdateEnabled = source["AutoUpdateEnabled"];
} }
} }
@@ -552,12 +550,14 @@ export class Error {
} }
} }
export class ModelUpstreamResponse { export class ModelUpstreamResponse {
stale: boolean; digest?: string;
pushTime: number;
error?: string; error?: string;
constructor(source: any = {}) { constructor(source: any = {}) {
if ('string' === typeof source) source = JSON.parse(source); if ('string' === typeof source) source = JSON.parse(source);
this.stale = source["stale"]; this.digest = source["digest"];
this.pushTime = source["pushTime"];
this.error = source["error"]; this.error = source["error"];
} }
} }

View File

@@ -1,7 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<!-- Generated by Pixelmator Pro 3.6.17 -->
<svg width="1200" height="1200" viewBox="0 0 1200 1200" xmlns="http://www.w3.org/2000/svg">
<g id="g314">
<path id="path147" fill="#d97757" stroke="none" d="M 233.959793 800.214905 L 468.644287 668.536987 L 472.590637 657.100647 L 468.644287 650.738403 L 457.208069 650.738403 L 417.986633 648.322144 L 283.892639 644.69812 L 167.597321 639.865845 L 54.926208 633.825623 L 26.577238 627.785339 L 3.3e-05 592.751709 L 2.73832 575.27533 L 26.577238 559.248352 L 60.724873 562.228149 L 136.187973 567.382629 L 249.422867 575.194763 L 331.570496 580.026978 L 453.261841 592.671082 L 472.590637 592.671082 L 475.328857 584.859009 L 468.724915 580.026978 L 463.570557 575.194763 L 346.389313 495.785217 L 219.543671 411.865906 L 153.100723 363.543762 L 117.181267 339.060425 L 99.060455 316.107361 L 91.248367 266.01355 L 123.865784 230.093994 L 167.677887 233.073853 L 178.872513 236.053772 L 223.248367 270.201477 L 318.040283 343.570496 L 441.825592 434.738342 L 459.946411 449.798706 L 467.194672 444.64447 L 468.080597 441.020203 L 459.946411 427.409485 L 392.617493 305.718323 L 320.778564 181.932983 L 288.80542 130.630859 L 280.348999 99.865845 C 277.369171 87.221436 275.194641 76.590698 275.194641 63.624268 L 312.322174 13.20813 L 332.8591 6.604126 L 382.389313 13.20813 L 403.248352 31.328979 L 434.013519 101.71814 L 483.865753 212.537048 L 561.181274 363.221497 L 583.812134 407.919434 L 595.892639 449.315491 L 600.40271 461.959839 L 608.214783 461.959839 L 608.214783 454.711609 L 614.577271 369.825623 L 626.335632 265.61084 L 637.771851 131.516846 L 641.718201 93.745117 L 660.402832 48.483276 L 697.530334 24.000122 L 726.52356 37.852417 L 750.362549 72 L 747.060486 94.067139 L 732.886047 186.201416 L 705.100708 330.52356 L 686.979919 427.167847 L 697.530334 427.167847 L 709.61084 415.087341 L 758.496704 350.174561 L 840.644348 247.490051 L 876.885925 206.738342 L 919.167847 161.71814 L 946.308838 140.29541 L 997.61084 140.29541 L 1035.38269 196.429626 L 1018.469849 254.416199 L 965.637634 321.422852 L 921.825562 378.201538 L 859.006714 462.765259 L 819.785278 530.41626 L 823.409424 
535.812073 L 832.75177 534.92627 L 974.657776 504.724915 L 1051.328979 490.872559 L 1142.818848 475.167786 L 1184.214844 494.496582 L 1188.724854 514.147644 L 1172.456421 554.335693 L 1074.604126 578.496765 L 959.838989 601.449829 L 788.939636 641.879272 L 786.845764 643.409485 L 789.261841 646.389343 L 866.255127 653.637634 L 899.194702 655.409424 L 979.812134 655.409424 L 1129.932861 666.604187 L 1169.154419 692.537109 L 1192.671265 724.268677 L 1188.724854 748.429688 L 1128.322144 779.194641 L 1046.818848 759.865845 L 856.590759 714.604126 L 791.355774 698.335754 L 782.335693 698.335754 L 782.335693 703.731567 L 836.69812 756.885986 L 936.322205 846.845581 L 1061.073975 962.81897 L 1067.436279 991.490112 L 1051.409424 1014.120911 L 1034.496704 1011.704712 L 924.885986 929.234924 L 882.604126 892.107544 L 786.845764 811.48999 L 780.483276 811.48999 L 780.483276 819.946289 L 802.550415 852.241699 L 919.087341 1027.409424 L 925.127625 1081.127686 L 916.671204 1098.604126 L 886.469849 1109.154419 L 853.288696 1103.114136 L 785.073914 1007.355835 L 714.684631 899.516785 L 657.906067 802.872498 L 650.979858 806.81897 L 617.476624 1167.704834 L 601.771851 1186.147705 L 565.530212 1200 L 535.328857 1177.046997 L 519.302124 1139.919556 L 535.328857 1066.550537 L 554.657776 970.792053 L 570.362488 894.68457 L 584.536926 800.134277 L 592.993347 768.724976 L 592.429626 766.630859 L 585.503479 767.516968 L 514.22821 865.369263 L 405.825531 1011.865906 L 320.053711 1103.677979 L 299.516815 1111.812256 L 263.919525 1093.369263 L 267.221497 1060.429688 L 287.114136 1031.114136 L 405.825531 880.107361 L 477.422913 786.52356 L 523.651062 732.483276 L 523.328918 724.671265 L 520.590698 724.671265 L 205.288605 929.395935 L 149.154434 936.644409 L 124.993355 914.01355 L 127.973183 876.885986 L 139.409409 864.80542 L 234.201385 799.570435 L 233.879227 799.8927 Z"/>
</g>
</svg>

Before

Width:  |  Height:  |  Size: 4.0 KiB

View File

@@ -1 +0,0 @@
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 320 320"><path fill="#fff" d="m297.06 130.97c7.26-21.79 4.76-45.66-6.85-65.48-17.46-30.4-52.56-46.04-86.84-38.68-15.25-17.18-37.16-26.95-60.13-26.81-35.04-.08-66.13 22.48-76.91 55.82-22.51 4.61-41.94 18.7-53.31 38.67-17.59 30.32-13.58 68.54 9.92 94.54-7.26 21.79-4.76 45.66 6.85 65.48 17.46 30.4 52.56 46.04 86.84 38.68 15.24 17.18 37.16 26.95 60.13 26.8 35.06.09 66.16-22.49 76.94-55.86 22.51-4.61 41.94-18.7 53.31-38.67 17.57-30.32 13.55-68.51-9.94-94.51zm-120.28 168.11c-14.03.02-27.62-4.89-38.39-13.88.49-.26 1.34-.73 1.89-1.07l63.72-36.8c3.26-1.85 5.26-5.32 5.24-9.07v-89.83l26.93 15.55c.29.14.48.42.52.74v74.39c-.04 33.08-26.83 59.9-59.91 59.97zm-128.84-55.03c-7.03-12.14-9.56-26.37-7.15-40.18.47.28 1.3.79 1.89 1.13l63.72 36.8c3.23 1.89 7.23 1.89 10.47 0l77.79-44.92v31.1c.02.32-.13.63-.38.83l-64.41 37.19c-28.69 16.52-65.33 6.7-81.92-21.95zm-16.77-139.09c7-12.16 18.05-21.46 31.21-26.29 0 .55-.03 1.52-.03 2.2v73.61c-.02 3.74 1.98 7.21 5.23 9.06l77.79 44.91-26.93 15.55c-.27.18-.61.21-.91.08l-64.42-37.22c-28.63-16.58-38.45-53.21-21.95-81.89zm221.26 51.49-77.79-44.92 26.93-15.54c.27-.18.61-.21.91-.08l64.42 37.19c28.68 16.57 38.51 53.26 21.94 81.94-7.01 12.14-18.05 21.44-31.2 26.28v-75.81c.03-3.74-1.96-7.2-5.2-9.06zm26.8-40.34c-.47-.29-1.3-.79-1.89-1.13l-63.72-36.8c-3.23-1.89-7.23-1.89-10.47 0l-77.79 44.92v-31.1c-.02-.32.13-.63.38-.83l64.41-37.16c28.69-16.55 65.37-6.7 81.91 22 6.99 12.12 9.52 26.31 7.15 40.1zm-168.51 55.43-26.94-15.55c-.29-.14-.48-.42-.52-.74v-74.39c.02-33.12 26.89-59.96 60.01-59.94 14.01 0 27.57 4.92 38.34 13.88-.49.26-1.33.73-1.89 1.07l-63.72 36.8c-3.26 1.85-5.26 5.31-5.24 9.06l-.04 89.79zm14.63-31.54 34.65-20.01 34.65 20v40.01l-34.65 20-34.65-20z"/></svg>

Before

Width:  |  Height:  |  Size: 1.7 KiB

View File

@@ -1 +0,0 @@
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 320 320"><path d="m297.06 130.97c7.26-21.79 4.76-45.66-6.85-65.48-17.46-30.4-52.56-46.04-86.84-38.68-15.25-17.18-37.16-26.95-60.13-26.81-35.04-.08-66.13 22.48-76.91 55.82-22.51 4.61-41.94 18.7-53.31 38.67-17.59 30.32-13.58 68.54 9.92 94.54-7.26 21.79-4.76 45.66 6.85 65.48 17.46 30.4 52.56 46.04 86.84 38.68 15.24 17.18 37.16 26.95 60.13 26.8 35.06.09 66.16-22.49 76.94-55.86 22.51-4.61 41.94-18.7 53.31-38.67 17.57-30.32 13.55-68.51-9.94-94.51zm-120.28 168.11c-14.03.02-27.62-4.89-38.39-13.88.49-.26 1.34-.73 1.89-1.07l63.72-36.8c3.26-1.85 5.26-5.32 5.24-9.07v-89.83l26.93 15.55c.29.14.48.42.52.74v74.39c-.04 33.08-26.83 59.9-59.91 59.97zm-128.84-55.03c-7.03-12.14-9.56-26.37-7.15-40.18.47.28 1.3.79 1.89 1.13l63.72 36.8c3.23 1.89 7.23 1.89 10.47 0l77.79-44.92v31.1c.02.32-.13.63-.38.83l-64.41 37.19c-28.69 16.52-65.33 6.7-81.92-21.95zm-16.77-139.09c7-12.16 18.05-21.46 31.21-26.29 0 .55-.03 1.52-.03 2.2v73.61c-.02 3.74 1.98 7.21 5.23 9.06l77.79 44.91-26.93 15.55c-.27.18-.61.21-.91.08l-64.42-37.22c-28.63-16.58-38.45-53.21-21.95-81.89zm221.26 51.49-77.79-44.92 26.93-15.54c.27-.18.61-.21.91-.08l64.42 37.19c28.68 16.57 38.51 53.26 21.94 81.94-7.01 12.14-18.05 21.44-31.2 26.28v-75.81c.03-3.74-1.96-7.2-5.2-9.06zm26.8-40.34c-.47-.29-1.3-.79-1.89-1.13l-63.72-36.8c-3.23-1.89-7.23-1.89-10.47 0l-77.79 44.92v-31.1c-.02-.32.13-.63.38-.83l64.41-37.16c28.69-16.55 65.37-6.7 81.91 22 6.99 12.12 9.52 26.31 7.15 40.1zm-168.51 55.43-26.94-15.55c-.29-.14-.48-.42-.52-.74v-74.39c.02-33.12 26.89-59.96 60.01-59.94 14.01 0 27.57 4.92 38.34 13.88-.49.26-1.33.73-1.89 1.07l-63.72 36.8c-3.26 1.85-5.26 5.31-5.24 9.06l-.04 89.79zm14.63-31.54 34.65-20.01 34.65 20v40.01l-34.65 20-34.65-20z"/></svg>

Before

Width:  |  Height:  |  Size: 1.7 KiB

File diff suppressed because one or more lines are too long

Before

Width:  |  Height:  |  Size: 6.2 KiB

View File

@@ -1,242 +0,0 @@
<svg version="1.2" xmlns="http://www.w3.org/2000/svg" viewBox="0 0 500 500" width="500" height="500">
<style>
.s0 { fill: #f6f4f4 }
.s1 { fill: #0b0303 }
.s2 { fill: #ef0011 }
.s3 { fill: #f3e2e2 }
.s4 { fill: #f00212 }
.s5 { fill: #ba000d }
.s6 { fill: #faf1f1 }
.s7 { fill: #0b0100 }
.s8 { fill: #fbedee }
.s9 { fill: #faeaea }
.s10 { fill: #ab797d }
.s11 { fill: #f8eaea }
.s12 { fill: #902021 }
.s13 { fill: #f9eeee }
.s14 { fill: #f6ecec }
.s15 { fill: #080201 }
.s16 { fill: #150100 }
.s17 { fill: #f2e7e7 }
.s18 { fill: #fbe7e8 }
.s19 { fill: #060101 }
.s20 { fill: #f5e7e7 }
.s21 { fill: #fa999e }
.s22 { fill: #c46064 }
.s23 { fill: #180300 }
.s24 { fill: #f6dcdd }
.s25 { fill: #f2e6e6 }
.s26 { fill: #110200 }
.s27 { fill: #eb0011 }
.s28 { fill: #e20010 }
.s29 { fill: #ea0011 }
.s30 { fill: #760007 }
.s31 { fill: #f00514 }
.s32 { fill: #fcebeb }
.s33 { fill: #ecd6d6 }
.s34 { fill: #f5e3e3 }
.s35 { fill: #f5e4e4 }
.s36 { fill: #faf6f6 }
.s37 { fill: #e50010 }
.s38 { fill: #d5000f }
.s39 { fill: #f2e2e3 }
.s40 { fill: #ef1018 }
.s41 { fill: #f4e8e9 }
.s42 { fill: #ef0513 }
.s43 { fill: #f5e5e5 }
.s44 { fill: #f00413 }
.s45 { fill: #f4e9ea }
.s46 { fill: #ed0011 }
.s47 { fill: #e80011 }
.s48 { fill: #e60613 }
.s49 { fill: #f0d6d6 }
.s50 { fill: #fca9ac }
.s51 { fill: #9c000c }
.s52 { fill: #73393b }
</style>
<g>
<path fill-rule="evenodd" class="s0" d="m166.5 52.5q3.5 0 7 0 2.75 2.99 1.5 7-21.27 45.61-20.5 96 39.99 2.76 72 26.5 7.87 6.86 13.5 15.5 42.88-56.39 103.5-92.5 47.35-25.46 101-25 14.52 0.38 23.5 11.5 3.19 7.74 2 16-1.81 7.18-4.5 14-1 0-1 1-5.04 6.05-9 13-1 0-1 1 0 0.5 0 1-12.42 12.15-28.5 19-6.02 36.27-41.5 45-0.83 2.75 0 5 19.02-12.85 41.5-9 10.85-8.09 23.5-13 15.01-6.37 31-2.5 14.09 7.43 14 23.5-2.83 23.25-15.5 43-6.42 9.92-14 19-10.04 8.8-19.5 18-72.02 48.88-156.5 27-19.63 9.6-41.5 10.5-4.59 1.27-9 3 2 1 4 2 20.09-1.11 35 12 25.46 6.95 37.5 30.5 1.26 5.69-1 11-3.38 3.79-7.5 6.5 5.74 10.07 1.5 20.5-7.55 7.47-17.5 3.5-11.01-5.34-22.5-9.5-18.26 10-38.5 13-15.5 0-31 0-26.62-4.54-51-17-4.17 1.33-8 3.5-7.23 5.87-15 11-8.62 2.58-13.5-4.5-1.82 2.32-4.5 3.5-6.06 2.24-12 3.5-7.5 0-15 0-27.42-2.56-50-18.5-18-17.25-23-41.5 0-11.5 0-23 4.12-22.7 25-33 6.95-16.67 22-26.5-20.39-20.8-14.5-49.5 7.01-26.98 28.5-44.5 7.56-5.27 15-10.5-13.09-30.88-7.5-64 3.16-15.57 14.5-26.5 6.85-2.48 8 4.5-6.59 39.53 11 75.5 7.99-0.49 16-2 2.42-34.57 14.5-67.5 8.51-22.23 27.5-36z"/>
</g>
<g>
<path fill-rule="evenodd" class="s1" d="m113.5 401.5q0.48-5.1-1-10-0.91 0.19-1 1-2.46 1.74-5 3.5 5.65 9.54-5 13-32.21 5.55-61-10-32.89-23.11-29.5-63.5 2.96-22.67 23.5-32 7.99-19.75 27-29.5-27.65-23.7-15.5-58.5 7.33-16.82 20.5-29.5 10.79-8.14 22-15.5-16.49-37.08-5.5-76 3.19-6.13 7.5-11.5 1.48-0.89 2 1-5.69 41.09 12.5 78.5 1 1 2 2 9.97-3.24 20.5-4 2 0 4 0 0-7.5 0-15 0.99-42.22 24.5-77 6.12-7.12 14-12-4.65 13.43-10 27-11.93 37.6-9.5 77 49.38 0.7 83.5 36 2.75 4.5 5.5 9 38.99-52.24 93-88.5 45.84-29.03 100-32.5 15.69-1.56 29 6.5 5.68 7.29 3.5 16.5-10.38 33.62-43.5 45-4.39 37.33-41 45-0.79 8.63-6 15.5 1.91 1.83 4.5 2.5 22.27-17.25 50.5-14.5 12.93-9.41 28-15 36.22-8.28 31.5 28.5-15.19 51.69-62.5 77.5-65.92 35.87-138 15.5-19.67 10.42-42 10.5-8.39 2.88-17 5 3.58 6.08 10 9 20.92-1.14 36 13 22.67 5.23 34.5 25.5 3.33 7.13-3.5 11.5-3.88 1.8-8 3 7.36 8.45 6.5 19.5-4.43 5.66-11.5 3.5-12.84-5.67-26-10.5-39.4 21.02-83 10.5-18.85-5.78-36.5-14.5-13.65 4.14-23.5 14.5-9.51 3.74-11-6.5z"/>
</g>
<g>
<path fill-rule="evenodd" class="s2" d="m153.5 173.5q24.62 1.46 46 13.5 12.11 8.1 17.5 21.5 0.74 2.45 0.5 5 0.09 0.81 1 1 1.48-4.9 1-10 5.04 10.48 1.5 22-9.81 27.86-35.5 42.5-26.17 14.97-56 19.5-2.77-0.4-2 1 2.86 1.27 6 1 25.64 1.53 48.5-10 0.34 10.08 2 20 1.08 5.76 5 10 1 1.5 0 3-31.11 20.84-68.5 17.5-23.7-5.7-32.5-28.5-4.39-9.18-3.5-19 15.41 6.23 32 4.5-20.68-6.39-39-18-34.81-27.22-12.5-65.5 11.84-14.83 29-23 4.21 7.66 11.5 12.5 3 1 6 0-26.04-34.62-29-78-0.13-8.46 2-16.5 1 6.5 2 13 3.43 39.53 24.5 73 2.03 2.28 4.5 4 0.5-1.25 1-2.5-1.27-6.54-5-12 0.5-0.75 1-1.5 9.72-3.43 20-4 0.55 10.34 8 17.5 1.94 0.74 4 0.5-17.8-64.6 16.5-122 0.98-1.79 1.5 0-28.21 56.64-13.5 118 1.08 1.43 2.5 0.5 2.21-4.98 2-10.5z"/>
</g>
<g>
<path fill-rule="evenodd" class="s3" d="m454.5 97.5q-18.37-2.97-37-1.5-16.14 2.08-32 5.5 32.38-14.09 67-7.5 1.98 1.22 2 3.5z"/>
</g>
<g>
<path fill-rule="evenodd" class="s4" d="m454.5 97.5q-1.33 11.18-8.5 20-21.81 26.28-55.5 32-1.11-0.2-2 0.5 2.31 2.82 5.5 4.5 1 2 0 4-9.56 11.3-19.5 20 19.71-8.72 31-27 2.68-0.43 5 1-14.24 30.97-48 36.5-9.93 1.71-20 1.5-6.8-0.48-13 1 5.81 6.92 14 11-10.78 16.03-27 26.5 27.16-7.4 38-33.5 4.34 1.35 9 1-9.08 23.84-33 33.5-18.45 6.41-38 7 22.59 8.92 45-1 12.05-5.52 24-11 9.01-1.79 17 2.5 5.28-4.38 11-8 12.8-6.07 27-5 0 0.5 0 1-19.34 2.69-34 15.5 0.5 0.25 1 0.5 17.79-8.09 36-15 2.71-0.79 5-2 2.5-1 5-2 5.53-4.04 11-8 11.7-4.18 24-6.5 7.78-1.36 15 1.5-2.97 18.45-13.5 34-34.92 49.37-94.5 62.5-59.27 12.45-108-23-15.53-12.52-21.5-31.5-2.47-14.26 4-27-3.15 24.41 14 42-4.92-10.28-7-22-1.97-17.63 7-33 47.28-69.5 125.5-100 15.86-3.42 32-5.5 18.63-1.47 37 1.5z"/>
</g>
<g>
<path fill-rule="evenodd" class="s5" d="m86.5 112.5q-1-6.5-2-13 0.7-5.34 3.5-10-1.8 11.32-1.5 23z"/>
</g>
<g>
<path fill-rule="evenodd" class="s6" d="m433.5 97.5q2.22-0.39 4 1-10 13.75-27 14-0.24-2.06 0.5-4 10.3-7.78 22.5-11z"/>
</g>
<g>
<path fill-rule="evenodd" class="s7" d="m407.5 101.5q2.55-0.24 5 0.5-52.87 18.31-84.5 64.5-6.94 7.95-17 11-9.38-2.38-5-11 40.38-48.62 101.5-65z"/>
</g>
<g>
<path fill-rule="evenodd" class="s8" d="m402.5 112.5q3 0 6 0-2.56 8.8-12 7-0.22-1.58 0.5-3 2.72-2.22 5.5-4z"/>
</g>
<g>
</g>
<g>
</g>
<g>
</g>
<g>
</g>
<g>
</g>
<g>
</g>
<g>
</g>
<g>
<path fill-rule="evenodd" class="s9" d="m390.5 149.5q7.77 0.52 15 2-11.29 18.28-31 27 9.94-8.7 19.5-20 1-2 0-4-3.19-1.68-5.5-4.5 0.89-0.7 2-0.5z"/>
</g>
<g>
<path fill-rule="evenodd" class="s10" d="m131.5 145.5q0 7.5 0 15-2 0-4 0 1.06-1.36 3-1-0.48-7.29 1-14z"/>
</g>
<g>
<path fill-rule="evenodd" class="s11" d="m219.5 204.5q-1 4.5-2 9 0.24-2.55-0.5-5-5.39-13.4-17.5-21.5-21.38-12.04-46-13.5 0-2 0-4 36.7-0.86 61.5 26 3.06 4.11 4.5 9z"/>
</g>
<g>
<path fill-rule="evenodd" class="s12" d="m329.5 191.5q6.2-1.48 13-1-3.5 1-7 2-2.9-0.97-6-1z"/>
</g>
<g>
<path fill-rule="evenodd" class="s13" d="m329.5 191.5q3.1 0.03 6 1 9.55 1.31 19 3-10.84 26.1-38 33.5 16.22-10.47 27-26.5-8.19-4.08-14-11z"/>
</g>
<g>
<path fill-rule="evenodd" class="s14" d="m479.5 199.5q-7.22-2.86-15-1.5-12.3 2.32-24 6.5 15.6-13.11 36-11.5 3.63 2.26 3 6.5z"/>
</g>
<g>
<path fill-rule="evenodd" class="s15" d="m193.5 216.5q-12.01 1.52-22 8-2.83 1.29-5.5 3-4.79-4.57-6.5-11-5.04 2.2-9.5-1-3.47-6.4 3.5-3 4.4 0.05 8-2.5 9.22-9.73 21-16 6.3-3.24 12 1-2.9 1.22-6 1.5 2.61 5.74 4.5 12 0.75 3.97 0.5 8z"/>
</g>
<g>
<path fill-rule="evenodd" class="s16" d="m458.5 200.5q3.04-0.24 6 0.5-18.02 7.05-33 19-1 1-2 0 11.53-14.3 29-19.5z"/>
</g>
<g>
<path fill-rule="evenodd" class="s17" d="m178.5 202.5q6.85-0.63 4.5 6-7.6 5.09-6-4 1.08-0.82 1.5-2z"/>
</g>
<g>
<path fill-rule="evenodd" class="s18" d="m469.5 201.5q-2.26 13.65-14.5 22-0.47-2.11 1-4 7.08-8.82 13.5-18z"/>
</g>
<g>
<path fill-rule="evenodd" class="s19" d="m74.5 208.5q8.22-0.2 16 2.5 11.8 4.26 23.5 8.5 5.65-0.63 8-6 2.41 11.83-9.5 13 0.55 3.61 2 7-0.5 1-1 2-4.67-0.94-9.5-1-9.96 0.44-19.5 2.5-5.05-3.55-6.5-9.5-0.75-7.48-0.5-15-6.47 0.15-3-4z"/>
</g>
<g>
<path fill-rule="evenodd" class="s20" d="m429.5 212.5q-2.5 1-5 2-4 0-8 0-14.2-1.07-27 5 15.27-12.44 35-9.5 2.72 1.14 5 2.5z"/>
</g>
<g>
<path fill-rule="evenodd" class="s21" d="m219.5 204.5q0.48 5.1-1 10-0.91-0.19-1-1 1-4.5 2-9z"/>
</g>
<g>
<path fill-rule="evenodd" class="s22" d="m416.5 215.5q0-0.5 0-1 4 0 8 0-2.29 1.21-5 2-1.06-1.36-3-1z"/>
</g>
<g>
<path fill-rule="evenodd" class="s23" d="m416.5 215.5q1.94-0.36 3 1-18.21 6.91-36 15-0.5-0.25-1-0.5 14.66-12.81 34-15.5z"/>
</g>
<g>
<path fill-rule="evenodd" class="s24" d="m193.5 216.5q4.39 1.3 9 3-0.79 1.04-2 1.5-14.77-0.13-29 3.5 9.99-6.48 22-8z"/>
</g>
<g>
<path fill-rule="evenodd" class="s25" d="m98.5 219.5q6.09-0.98 6 5-3.04 0.24-6-0.5-1.84-2.24 0-4.5z"/>
</g>
<g>
<path fill-rule="evenodd" class="s26" d="m176.5 229.5q8.85-1.14 16 4-4.98 1.75-10 0-13.56 14.3-33 19.5-28.06 8.2-55 1 3.32-6.4 10-5.5-0.71 1.47-2 2.5 36.58 4.24 69-14 4.68-2.13 1-5 2.35-0.91 4-2.5z"/>
</g>
<g>
<path fill-rule="evenodd" class="s27" d="m231.5 238.5q1.31-0.2 2 1-3.13 28.62 15 51-16.25 6.75-27-7.5-1-1-2 0 14.73 29.34 46 18.5 1.79 0.52 0 1.5-37.63 16.82-50.5-22.5-5.1-26.48 16.5-42z"/>
</g>
<g>
<path fill-rule="evenodd" class="s28" d="m243.5 259.5q5.88 3.62 10.5 9 12.96 18.46 32.5 29.5-31.51-7.75-43-38.5z"/>
</g>
<g>
<path fill-rule="evenodd" class="s29" d="m203.5 266.5q1.31-0.2 2 1-2.48 22.08 12 39-6.99 1.35-14 0.5 4.59 4.08 10 7-8.71 0.28-14.5-6.5-16.98-22.76 4.5-41z"/>
</g>
<g>
<path fill-rule="evenodd" class="s27" d="m58.5 284.5q9.6-2.17 14.5 6 5.15 14.18-1 28-11.05-13.14-27.5-17.5 5.15-9.9 14-16.5z"/>
</g>
<g>
<path fill-rule="evenodd" class="s30" d="m129.5 288.5q2 1 4 2-3.14 0.27-6-1-0.77-1.4 2-1z"/>
</g>
<g>
<path fill-rule="evenodd" class="s31" d="m56.5 313.5q3.43 5.43 8 10-4.88 0.44-8 4-1.11-0.2-2 0.5 28.91 1.65 38 28.5 0.45 3.16-1 6-11.02-7.01-23-12.5-4.75-3.75-9.5-7.5 1.47 7.42 7 13 8.34 27.18 32 43 0.99 2.41-1.5 3.5-40.25 5.58-66.5-25.5-15.67-22.01-8-48 10.46-23.87 34.5-15z"/>
</g>
<g>
<path fill-rule="evenodd" class="s32" d="m45.5 317.5q4.03-0.25 8 0.5 2.46 4.16-2 6-6.04 2.01-9-3.5 1.26-1.85 3-3z"/>
</g>
<g>
<path fill-rule="evenodd" class="s33" d="m56.5 313.5q4.91 3.14 9.5 7 0.88 2.25-1.5 3-4.57-4.57-8-10z"/>
</g>
<g>
<path fill-rule="evenodd" class="s34" d="m198.5 319.5q-11.1 11.56-27 15.5-15.75 4.88-32 2.5 28.81-3.69 54-18.5 2.65-0.96 5 0.5z"/>
</g>
<g>
<path fill-rule="evenodd" class="s4" d="m198.5 319.5q1.44 0.68 2.5 2 2.41 8.23 6 16 1.2 2.64-0.5 5-30.65 21.41-68 18.5-25.16-6.17-32.5-30.5 6.96 4.99 15.5 6.5 8.99 0.75 18 0.5 16.25 2.38 32-2.5 15.9-3.94 27-15.5z"/>
</g>
<g>
<path fill-rule="evenodd" class="s35" d="m92.5 356.5q-9.09-26.85-38-28.5 0.89-0.7 2-0.5 25.47-4.89 35.5 19 0.75 4.98 0.5 10z"/>
</g>
<g>
<path fill-rule="evenodd" class="s36" d="m72.5 335.5q3.62-0.38 5 3-4.22 1.83-5-3z"/>
</g>
<g>
<path fill-rule="evenodd" class="s37" d="m223.5 336.5q5.59-0.48 11 1-4.04 4.16-8.5 8-5.99-3.8-2.5-9z"/>
</g>
<g>
<path fill-rule="evenodd" class="s38" d="m90.5 334.5q0.59-1.54 2-0.5 3.94 5.45 9 10 7 6 14 12-6.91-1.7-13-6-6.21-7.72-12-15.5z"/>
</g>
<g>
<path fill-rule="evenodd" class="s39" d="m261.5 346.5q-3.54-2.44-8-3.5-6.98-0.75-14-0.5 0.63-1.08 2-1.5 13.82-2.52 26 4-2.63 1.98-6 1.5z"/>
</g>
<g>
<path fill-rule="evenodd" class="s40" d="m239.5 342.5q7.02-0.25 14 0.5 4.46 1.06 8 3.5-5.2 2.35-10 5.5-3.88 4.65-9 7.5-9.89-3.09-9.5-13 2.36-3.63 6.5-4z"/>
</g>
<g>
<path fill-rule="evenodd" class="s41" d="m214.5 349.5q-21.43 15.48-48 16 22.82-5.9 43-18.5 3.64-1.12 5 2.5z"/>
</g>
<g>
<path fill-rule="evenodd" class="s42" d="m214.5 349.5q5.96 7.2 13.5 13 1 1 0 2-28.58 23.34-65.5 20.5-18.15-4.24-27.5-19.5 1.13 0.94 2.5 1.5 14.7 1.42 29-1.5 26.57-0.52 48-16z"/>
</g>
<g>
<path fill-rule="evenodd" class="s43" d="m302.5 373.5q-14.74-16.73-37-19-4.55 0.25-9 1 25.3-10.24 43.5 11 2.85 2.91 2.5 7z"/>
</g>
<g>
<path fill-rule="evenodd" class="s44" d="m302.5 373.5q0.21 2.44-2 3.5-28.69 7.6-50.5-12.5-0.06-6.71 6.5-9 4.45-0.75 9-1 22.26 2.27 37 19z"/>
</g>
<g>
<path fill-rule="evenodd" class="s45" d="m100.5 356.5q5.42 2.71 11 5.5-13.04 7.54-18.5 21.5-7.57-7.14-10.5-17 5.58 1.54 10 5.5 4.2 0.84 5.5-3.5 1.41-5.99 2.5-12z"/>
</g>
<g>
<path fill-rule="evenodd" class="s8" d="m83.5 394.5q-18.9-10.15-29.5-29-1.54-3.52-2-7 5.79 2.39 10 7 7.82 16.63 21.5 29z"/>
</g>
<g>
<path fill-rule="evenodd" class="s46" d="m232.5 365.5q17.6 6.19 10.5 23-10.6 10.42-25.5 11.5-25.94 3.21-49-9 36.75-1.65 64-25.5z"/>
</g>
<g>
<path fill-rule="evenodd" class="s47" d="m113.5 367.5q7.7-0.01 9.5 7-9.69 7.19-18.5 15.5-7.23 5.76-5.5-3.5 3.12-12.84 14.5-19z"/>
</g>
<g>
<path fill-rule="evenodd" class="s29" d="m126.5 380.5q7.88-0.4 12 6.5-8.5 7.25-17 14.5-5.62-12.55 5-21z"/>
</g>
<g>
<path fill-rule="evenodd" class="s48" d="m283.5 385.5q3.22 2.95 7 5.5 2.8 4.03 6 7.5 0.42 2.77-2 4-15.5-9.75-31-19.5-1.79-0.98 0-1.5 9.96 2.49 20 4z"/>
</g>
<g>
<path fill-rule="evenodd" class="s49" d="m283.5 385.5q8.71-1.27 11.5 7 1.22 2.9 1.5 6-3.2-3.47-6-7.5-3.78-2.55-7-5.5z"/>
</g>
<g>
<path fill-rule="evenodd" class="s50" d="m83.5 394.5q1.88-0.06 3 1.5-2.25 0.88-3-1.5z"/>
</g>
<g>
<path fill-rule="evenodd" class="s51" d="m258.5 392.5q3.51 0.41 0 2.5-2.33 1.93-5 2 2.61-2.28 5-4.5z"/>
</g>
<g>
<path fill-rule="evenodd" class="s52" d="m111.5 392.5q0.09-0.81 1-1 1.48 4.9 1 10-1-4.5-2-9z"/>
</g>
</svg>

Before

Width:  |  Height:  |  Size: 13 KiB

View File

@@ -1,7 +0,0 @@
<svg xmlns="http://www.w3.org/2000/svg" version="1.1" xmlns:xlink="http://www.w3.org/1999/xlink" width="512" height="512"><svg width="512" height="512" viewBox="0 0 512 512" fill="none" xmlns="http://www.w3.org/2000/svg">
<rect width="512" height="512" fill="#131010"></rect>
<path d="M320 224V352H192V224H320Z" fill="#5A5858"></path>
<path fill-rule="evenodd" clip-rule="evenodd" d="M384 416H128V96H384V416ZM320 160H192V352H320V160Z" fill="white"></path>
</svg><style>@media (prefers-color-scheme: light) { :root { filter: none; } }
@media (prefers-color-scheme: dark) { :root { filter: none; } }
</style></svg>

Before

Width:  |  Height:  |  Size: 612 B

View File

@@ -1,9 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 800 800">
<rect width="800" height="800" rx="160" fill="#fff"/>
<path fill="#000" fill-rule="evenodd" d="
M165.29 165.29 H517.36 V400 H400 V517.36 H282.65 V634.72 H165.29 Z
M282.65 282.65 V400 H400 V282.65 Z
"/>
<path fill="#000" d="M517.36 400 H634.72 V634.72 H517.36 Z"/>
</svg>

Before

Width:  |  Height:  |  Size: 389 B

View File

@@ -1,9 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 800 800">
<rect width="800" height="800" rx="160" fill="#000"/>
<path fill="#fff" fill-rule="evenodd" d="
M165.29 165.29 H517.36 V400 H400 V517.36 H282.65 V634.72 H165.29 Z
M282.65 282.65 V400 H400 V282.65 Z
"/>
<path fill="#fff" d="M517.36 400 H634.72 V634.72 H517.36 Z"/>
</svg>

Before

Width:  |  Height:  |  Size: 389 B

View File

@@ -161,7 +161,7 @@ export async function getModels(query?: string): Promise<Model[]> {
// Add query if it's in the registry and not already in the list // Add query if it's in the registry and not already in the list
if (!exactMatch) { if (!exactMatch) {
const result = await getModelUpstreamInfo(new Model({ model: query })); const result = await getModelUpstreamInfo(new Model({ model: query }));
const existsUpstream = result.exists; const existsUpstream = !!result.digest && !result.error;
if (existsUpstream) { if (existsUpstream) {
filteredModels.push(new Model({ model: query })); filteredModels.push(new Model({ model: query }));
} }
@@ -339,7 +339,7 @@ export async function deleteChat(chatId: string): Promise<void> {
// Get upstream information for model staleness checking // Get upstream information for model staleness checking
export async function getModelUpstreamInfo( export async function getModelUpstreamInfo(
model: Model, model: Model,
): Promise<{ stale: boolean; exists: boolean; error?: string }> { ): Promise<{ digest?: string; pushTime: number; error?: string }> {
try { try {
const response = await fetch(`${API_BASE}/api/v1/model/upstream`, { const response = await fetch(`${API_BASE}/api/v1/model/upstream`, {
method: "POST", method: "POST",
@@ -353,22 +353,22 @@ export async function getModelUpstreamInfo(
if (!response.ok) { if (!response.ok) {
console.warn( console.warn(
`Failed to check upstream for ${model.model}: ${response.status}`, `Failed to check upstream digest for ${model.model}: ${response.status}`,
); );
return { stale: false, exists: false }; return { pushTime: 0 };
} }
const data = await response.json(); const data = await response.json();
if (data.error) { if (data.error) {
console.warn(`Upstream check: ${data.error}`); console.warn(`Upstream digest check: ${data.error}`);
return { stale: false, exists: false, error: data.error }; return { error: data.error, pushTime: 0 };
} }
return { stale: !!data.stale, exists: true }; return { digest: data.digest, pushTime: data.pushTime || 0 };
} catch (error) { } catch (error) {
console.warn(`Error checking model staleness:`, error); console.warn(`Error checking model staleness:`, error);
return { stale: false, exists: false }; return { pushTime: 0 };
} }
} }

View File

@@ -480,15 +480,13 @@ function ChatForm({
return; return;
} }
// Prepare attachments for submission, excluding unsupported images // Prepare attachments for submission
const attachmentsToSend: FileAttachment[] = message.attachments const attachmentsToSend: FileAttachment[] = message.attachments.map(
.filter( (att) => ({
(att) => hasVisionCapability || !isImageFile(att.filename),
)
.map((att) => ({
filename: att.filename, filename: att.filename,
data: att.data || new Uint8Array(0), // Empty data for existing files data: att.data || new Uint8Array(0), // Empty data for existing files
})); }),
);
const useWebSearch = const useWebSearch =
supportsWebSearch && webSearchEnabled && !cloudDisabled; supportsWebSearch && webSearchEnabled && !cloudDisabled;
@@ -738,17 +736,10 @@ function ChatForm({
)} )}
{(message.attachments.length > 0 || message.fileErrors.length > 0) && ( {(message.attachments.length > 0 || message.fileErrors.length > 0) && (
<div className="flex gap-2 overflow-x-auto px-3 pt pb-3 w-full scrollbar-hide"> <div className="flex gap-2 overflow-x-auto px-3 pt pb-3 w-full scrollbar-hide">
{message.attachments.map((attachment, index) => { {message.attachments.map((attachment, index) => (
const isUnsupportedImage =
!hasVisionCapability && isImageFile(attachment.filename);
return (
<div <div
key={attachment.id} key={attachment.id}
className={`group flex items-center gap-2 py-2 px-3 rounded-lg transition-colors flex-shrink-0 ${ className="group flex items-center gap-2 py-2 px-3 rounded-lg bg-neutral-50 dark:bg-neutral-700/50 hover:bg-neutral-100 dark:hover:bg-neutral-700 transition-colors flex-shrink-0"
isUnsupportedImage
? "bg-red-50 dark:bg-red-900/20 border border-red-200 dark:border-red-800"
: "bg-neutral-50 dark:bg-neutral-700/50 hover:bg-neutral-100 dark:hover:bg-neutral-700"
}`}
> >
{isImageFile(attachment.filename) ? ( {isImageFile(attachment.filename) ? (
<ImageThumbnail <ImageThumbnail
@@ -773,16 +764,9 @@ function ChatForm({
/> />
</svg> </svg>
)} )}
<div className="flex flex-col min-w-0"> <span className="text-sm text-neutral-700 dark:text-neutral-300 max-w-[150px] truncate">
<span className={`text-sm max-w-36 truncate ${isUnsupportedImage ? "text-red-700 dark:text-red-300" : "text-neutral-700 dark:text-neutral-300"}`}> {attachment.filename}
{attachment.filename} </span>
</span>
{isUnsupportedImage && (
<span className="text-xs text-red-600 dark:text-red-400 opacity-75">
This model does not support images
</span>
)}
</div>
<button <button
type="button" type="button"
onClick={() => removeFile(index)} onClick={() => removeFile(index)}
@@ -804,8 +788,7 @@ function ChatForm({
</svg> </svg>
</button> </button>
</div> </div>
); ))}
})}
{message.fileErrors.map((fileError, index) => ( {message.fileErrors.map((fileError, index) => (
<div <div
key={`error-${index}`} key={`error-${index}`}

View File

@@ -6,13 +6,12 @@ import { getChat } from "@/api";
import { Link } from "@/components/ui/link"; import { Link } from "@/components/ui/link";
import { useState, useRef, useEffect, useCallback, useMemo } from "react"; import { useState, useRef, useEffect, useCallback, useMemo } from "react";
import { ChatsResponse } from "@/gotypes"; import { ChatsResponse } from "@/gotypes";
import { CogIcon, RocketLaunchIcon } from "@heroicons/react/24/outline"; import { CogIcon } from "@heroicons/react/24/outline";
// there's a hidden debug feature to copy a chat's data to the clipboard by // there's a hidden debug feature to copy a chat's data to the clipboard by
// holding shift and clicking this many times within this many seconds // holding shift and clicking this many times within this many seconds
const DEBUG_SHIFT_CLICKS_REQUIRED = 5; const DEBUG_SHIFT_CLICKS_REQUIRED = 5;
const DEBUG_SHIFT_CLICK_WINDOW_MS = 7000; // 7 seconds const DEBUG_SHIFT_CLICK_WINDOW_MS = 7000; // 7 seconds
const launchSidebarRequestedKey = "ollama.launchSidebarRequested";
interface ChatSidebarProps { interface ChatSidebarProps {
currentChatId?: string; currentChatId?: string;
@@ -268,8 +267,9 @@ export function ChatSidebar({ currentChatId }: ChatSidebarProps) {
<Link <Link
href="/c/new" href="/c/new"
mask={{ to: "/" }} mask={{ to: "/" }}
className={`flex w-full items-center gap-3 rounded-lg px-2 py-2 text-left text-sm text-neutral-700 hover:bg-neutral-100 dark:hover:bg-neutral-800 dark:text-neutral-100 ${currentChatId === "new" ? "bg-neutral-100 dark:bg-neutral-800" : "" className={`flex w-full items-center gap-3 rounded-lg px-2 py-2 text-left text-sm text-neutral-700 hover:bg-neutral-100 dark:hover:bg-neutral-800 dark:text-neutral-100 ${
}`} currentChatId === "new" ? "bg-neutral-100 dark:bg-neutral-800" : ""
}`}
draggable={false} draggable={false}
> >
<svg <svg
@@ -283,23 +283,6 @@ export function ChatSidebar({ currentChatId }: ChatSidebarProps) {
</svg> </svg>
<span className="truncate">New Chat</span> <span className="truncate">New Chat</span>
</Link> </Link>
<Link
to="/c/$chatId"
params={{ chatId: "launch" }}
onClick={() => {
if (currentChatId !== "launch") {
sessionStorage.setItem(launchSidebarRequestedKey, "1");
}
}}
className={`flex w-full items-center gap-3 rounded-lg px-2 py-2 text-left text-sm text-neutral-700 hover:bg-neutral-100 dark:hover:bg-neutral-800 dark:text-neutral-100 cursor-pointer ${currentChatId === "launch"
? "bg-neutral-100 dark:bg-neutral-800"
: ""
}`}
draggable={false}
>
<RocketLaunchIcon className="h-5 w-5 stroke-current" />
<span className="truncate">Launch</span>
</Link>
{isWindows && ( {isWindows && (
<Link <Link
href="/settings" href="/settings"
@@ -321,18 +304,19 @@ export function ChatSidebar({ currentChatId }: ChatSidebarProps) {
{group.chats.map((chat) => ( {group.chats.map((chat) => (
<div <div
key={chat.id} key={chat.id}
className={`allow-context-menu flex items-center relative text-sm text-neutral-800 dark:text-neutral-400 rounded-lg hover:bg-neutral-100 dark:hover:bg-neutral-800 ${chat.id === currentChatId className={`allow-context-menu flex items-center relative text-sm text-neutral-800 dark:text-neutral-400 rounded-lg hover:bg-neutral-100 dark:hover:bg-neutral-800 ${
? "bg-neutral-100 text-black dark:bg-neutral-800" chat.id === currentChatId
: "" ? "bg-neutral-100 text-black dark:bg-neutral-800"
}`} : ""
}`}
onMouseEnter={() => handleMouseEnter(chat.id)} onMouseEnter={() => handleMouseEnter(chat.id)}
onContextMenu={(e) => onContextMenu={(e) =>
handleContextMenu( handleContextMenu(
e, e,
chat.id, chat.id,
chat.title || chat.title ||
chat.userExcerpt || chat.userExcerpt ||
chat.createdAt.toLocaleString(), chat.createdAt.toLocaleString(),
) )
} }
> >

View File

@@ -10,7 +10,6 @@ interface CopyButtonProps {
showLabels?: boolean; showLabels?: boolean;
className?: string; className?: string;
title?: string; title?: string;
onCopy?: () => void;
} }
const CopyButton: React.FC<CopyButtonProps> = ({ const CopyButton: React.FC<CopyButtonProps> = ({
@@ -21,7 +20,6 @@ const CopyButton: React.FC<CopyButtonProps> = ({
showLabels = false, showLabels = false,
className = "", className = "",
title = "", title = "",
onCopy,
}) => { }) => {
const [isCopied, setIsCopied] = useState(false); const [isCopied, setIsCopied] = useState(false);
@@ -50,14 +48,12 @@ const CopyButton: React.FC<CopyButtonProps> = ({
} }
setIsCopied(true); setIsCopied(true);
onCopy?.();
setTimeout(() => setIsCopied(false), 2000); setTimeout(() => setIsCopied(false), 2000);
} catch (error) { } catch (error) {
console.error("Clipboard API failed, falling back to plain text", error); console.error("Clipboard API failed, falling back to plain text", error);
try { try {
await navigator.clipboard.writeText(content); await navigator.clipboard.writeText(content);
setIsCopied(true); setIsCopied(true);
onCopy?.();
setTimeout(() => setIsCopied(false), 2000); setTimeout(() => setIsCopied(false), 2000);
} catch (fallbackError) { } catch (fallbackError) {
console.error("Fallback copy also failed:", fallbackError); console.error("Fallback copy also failed:", fallbackError);

View File

@@ -1,133 +0,0 @@
import { useSettings } from "@/hooks/useSettings";
import CopyButton from "@/components/CopyButton";
interface LaunchCommand {
id: string;
name: string;
command: string;
description: string;
icon: string;
darkIcon?: string;
iconClassName?: string;
borderless?: boolean;
}
const LAUNCH_COMMANDS: LaunchCommand[] = [
{
id: "openclaw",
name: "OpenClaw",
command: "ollama launch openclaw",
description: "Personal AI with 100+ skills",
icon: "/launch-icons/openclaw.svg",
},
{
id: "claude",
name: "Claude",
command: "ollama launch claude",
description: "Anthropic's coding tool with subagents",
icon: "/launch-icons/claude.svg",
iconClassName: "h-7 w-7",
},
{
id: "codex",
name: "Codex",
command: "ollama launch codex",
description: "OpenAI's open-source coding agent",
icon: "/launch-icons/codex.svg",
darkIcon: "/launch-icons/codex-dark.svg",
iconClassName: "h-7 w-7",
},
{
id: "opencode",
name: "OpenCode",
command: "ollama launch opencode",
description: "Anomaly's open-source coding agent",
icon: "/launch-icons/opencode.svg",
iconClassName: "h-7 w-7 rounded",
},
{
id: "droid",
name: "Droid",
command: "ollama launch droid",
description: "Factory's coding agent across terminal and IDEs",
icon: "/launch-icons/droid.svg",
},
{
id: "pi",
name: "Pi",
command: "ollama launch pi",
description: "Minimal AI agent toolkit with plugin support",
icon: "/launch-icons/pi.svg",
darkIcon: "/launch-icons/pi-dark.svg",
iconClassName: "h-7 w-7",
},
];
export default function LaunchCommands() {
const isWindows = navigator.platform.toLowerCase().includes("win");
const { setSettings } = useSettings();
const renderCommandCard = (item: LaunchCommand) => (
<div key={item.command} className="w-full text-left">
<div className="flex items-start gap-4 sm:gap-5">
<div
aria-hidden="true"
className={`flex h-10 w-10 shrink-0 items-center justify-center rounded-lg overflow-hidden ${item.borderless ? "" : "border border-neutral-200 bg-white dark:border-neutral-700 dark:bg-neutral-900"}`}
>
{item.darkIcon ? (
<picture>
<source srcSet={item.darkIcon} media="(prefers-color-scheme: dark)" />
<img src={item.icon} alt="" className={`${item.iconClassName ?? "h-8 w-8"} rounded-sm`} />
</picture>
) : (
<img src={item.icon} alt="" className={item.borderless ? "h-full w-full rounded-xl" : `${item.iconClassName ?? "h-8 w-8"} rounded-sm`} />
)}
</div>
<div className="min-w-0 flex-1">
<span className="text-sm font-medium text-neutral-900 dark:text-neutral-100">
{item.name}
</span>
<p className="mt-0.5 text-xs text-neutral-500 dark:text-neutral-400">
{item.description}
</p>
<div className="mt-2 flex items-center gap-2 rounded-xl border-neutral-200 dark:border-neutral-700 bg-neutral-50 dark:bg-neutral-800 px-3 py-2">
<code className="min-w-0 flex-1 truncate text-xs text-neutral-600 dark:text-neutral-300">
{item.command}
</code>
<CopyButton
content={item.command}
size="md"
title="Copy command to clipboard"
className="text-neutral-500 dark:text-neutral-400 hover:text-neutral-700 dark:hover:text-neutral-200 hover:bg-neutral-200/60 dark:hover:bg-neutral-700/70"
onCopy={() => {
setSettings({ LastHomeView: item.id }).catch(() => { });
}}
/>
</div>
</div>
</div>
</div>
);
return (
<main className="flex h-screen w-full flex-col relative">
<section
className={`flex-1 overflow-y-auto overscroll-contain relative min-h-0 ${isWindows ? "xl:pt-4" : "xl:pt-8"}`}
>
<div className="max-w-[730px] mx-auto w-full px-4 pt-4 pb-20 sm:px-6 sm:pt-6 sm:pb-24 lg:px-8 lg:pt-8 lg:pb-28">
<h1 className="text-xl font-semibold text-neutral-900 dark:text-neutral-100">
Launch
</h1>
<p className="mt-1 text-sm text-neutral-500 dark:text-neutral-400">
Copy a command and run it in your terminal.
</p>
<div className="mt-6 grid gap-7">
{LAUNCH_COMMANDS.map(renderCommandCard)}
</div>
</div>
</section>
</main>
);
}

View File

@@ -536,7 +536,7 @@ function ToolCallDisplay({
let args: Record<string, unknown> | null = null; let args: Record<string, unknown> | null = null;
try { try {
args = JSON.parse(toolCall.function.arguments) as Record<string, unknown>; args = JSON.parse(toolCall.function.arguments) as Record<string, unknown>;
} catch { } catch (e) {
args = null; args = null;
} }
const query = args && typeof args.query === "string" ? args.query : ""; const query = args && typeof args.query === "string" ? args.query : "";
@@ -562,7 +562,7 @@ function ToolCallDisplay({
let args: Record<string, unknown> | null = null; let args: Record<string, unknown> | null = null;
try { try {
args = JSON.parse(toolCall.function.arguments) as Record<string, unknown>; args = JSON.parse(toolCall.function.arguments) as Record<string, unknown>;
} catch { } catch (e) {
args = null; args = null;
} }
const url = args && typeof args.url === "string" ? args.url : ""; const url = args && typeof args.url === "string" ? args.url : "";

View File

@@ -73,7 +73,7 @@ export default function MessageList({
? String(args.url).trim() ? String(args.url).trim()
: ""; : "";
if (candidate) lastQuery = candidate; if (candidate) lastQuery = candidate;
} catch { /* ignored */ } } catch {}
} }
} }
} }

View File

@@ -61,7 +61,24 @@ export const ModelPicker = forwardRef<
try { try {
const upstreamInfo = await getModelUpstreamInfo(model); const upstreamInfo = await getModelUpstreamInfo(model);
if (upstreamInfo.stale) { // Compare local digest with upstream digest
let isStale =
model.digest &&
upstreamInfo.digest &&
model.digest !== upstreamInfo.digest;
// If the model has a modified time and upstream has a push time,
// check if the model was modified after the push time - if so, it's not stale
if (isStale && model.modified_at && upstreamInfo.pushTime > 0) {
const modifiedAtTime =
new Date(model.modified_at as string | number | Date).getTime() /
1000;
if (modifiedAtTime > upstreamInfo.pushTime) {
isStale = false;
}
}
if (isStale) {
const currentStaleModels = const currentStaleModels =
queryClient.getQueryData<Map<string, boolean>>(["staleModels"]) || queryClient.getQueryData<Map<string, boolean>>(["staleModels"]) ||
new Map(); new Map();

View File

@@ -214,7 +214,6 @@ export default function Settings() {
Agent: false, Agent: false,
Tools: false, Tools: false,
ContextLength: 0, ContextLength: 0,
AutoUpdateEnabled: true,
}); });
updateSettingsMutation.mutate(defaultSettings); updateSettingsMutation.mutate(defaultSettings);
} }
@@ -273,10 +272,6 @@ export default function Settings() {
} }
const isWindows = navigator.platform.toLowerCase().includes("win"); const isWindows = navigator.platform.toLowerCase().includes("win");
const handleCloseSettings = () => {
const chatId = settings.LastHomeView === "chat" ? "new" : "launch";
navigate({ to: "/c/$chatId", params: { chatId } });
};
return ( return (
<main className="flex h-screen w-full flex-col select-none dark:bg-neutral-900"> <main className="flex h-screen w-full flex-col select-none dark:bg-neutral-900">
@@ -290,7 +285,7 @@ export default function Settings() {
> >
{isWindows && ( {isWindows && (
<button <button
onClick={handleCloseSettings} onClick={() => navigate({ to: "/" })}
className="hover:bg-neutral-100 mr-3 dark:hover:bg-neutral-800 rounded-full p-1.5" className="hover:bg-neutral-100 mr-3 dark:hover:bg-neutral-800 rounded-full p-1.5"
> >
<ArrowLeftIcon className="w-5 h-5 dark:text-white" /> <ArrowLeftIcon className="w-5 h-5 dark:text-white" />
@@ -300,7 +295,7 @@ export default function Settings() {
</h1> </h1>
{!isWindows && ( {!isWindows && (
<button <button
onClick={handleCloseSettings} onClick={() => navigate({ to: "/" })}
className="p-1 hover:bg-neutral-100 mr-3 dark:hover:bg-neutral-800 rounded-full" className="p-1 hover:bg-neutral-100 mr-3 dark:hover:bg-neutral-800 rounded-full"
> >
<XMarkIcon className="w-6 h-6 dark:text-white" /> <XMarkIcon className="w-6 h-6 dark:text-white" />

View File

@@ -65,7 +65,7 @@ export const BadgeButton = forwardRef(function BadgeButton(
), ),
ref: React.ForwardedRef<HTMLElement>, ref: React.ForwardedRef<HTMLElement>,
) { ) {
const classes = clsx( let classes = clsx(
className, className,
"group relative inline-flex rounded-md focus:not-data-focus:outline-hidden data-focus:outline-2 data-focus:outline-offset-2 data-focus:outline-blue-500", "group relative inline-flex rounded-md focus:not-data-focus:outline-hidden data-focus:outline-2 data-focus:outline-offset-2 data-focus:outline-blue-500",
); );

View File

@@ -171,7 +171,7 @@ export const Button = forwardRef(function Button(
{ color, outline, plain, className, children, ...props }: ButtonProps, { color, outline, plain, className, children, ...props }: ButtonProps,
ref: React.ForwardedRef<HTMLElement>, ref: React.ForwardedRef<HTMLElement>,
) { ) {
const classes = clsx( let classes = clsx(
className, className,
styles.base, styles.base,
outline outline

View File

@@ -381,7 +381,7 @@ export const useSendMessage = (chatId: string) => {
role: "assistant", role: "assistant",
content: "", content: "",
thinking: "", thinking: "",
model: effectiveModel.model, model: effectiveModel,
}), }),
); );
lastMessage = newMessages[newMessages.length - 1]; lastMessage = newMessages[newMessages.length - 1];
@@ -433,7 +433,7 @@ export const useSendMessage = (chatId: string) => {
role: "assistant", role: "assistant",
content: "", content: "",
thinking: "", thinking: "",
model: effectiveModel.model, model: effectiveModel,
}), }),
); );
lastMessage = newMessages[newMessages.length - 1]; lastMessage = newMessages[newMessages.length - 1];
@@ -520,7 +520,7 @@ export const useSendMessage = (chatId: string) => {
thinkingTimeStart: thinkingTimeStart:
lastMessage.thinkingTimeStart || event.thinkingTimeStart, lastMessage.thinkingTimeStart || event.thinkingTimeStart,
thinkingTimeEnd: event.thinkingTimeEnd, thinkingTimeEnd: event.thinkingTimeEnd,
model: selectedModel.model, model: selectedModel,
}); });
newMessages[newMessages.length - 1] = updatedMessage; newMessages[newMessages.length - 1] = updatedMessage;
} else { } else {
@@ -533,7 +533,7 @@ export const useSendMessage = (chatId: string) => {
tool_calls: event.toolCalls, tool_calls: event.toolCalls,
thinkingTimeStart: event.thinkingTimeStart, thinkingTimeStart: event.thinkingTimeStart,
thinkingTimeEnd: event.thinkingTimeEnd, thinkingTimeEnd: event.thinkingTimeEnd,
model: selectedModel.model, model: selectedModel,
}), }),
); );
} }
@@ -699,7 +699,7 @@ export const useSendMessage = (chatId: string) => {
queryClient.setQueryData(["chat", newId], { queryClient.setQueryData(["chat", newId], {
chat: new Chat({ chat: new Chat({
id: newId, id: newId,
model: effectiveModel.model, model: effectiveModel,
messages: [ messages: [
new Message({ new Message({
role: "user", role: "user",

View File

@@ -9,7 +9,6 @@ interface SettingsState {
webSearchEnabled: boolean; webSearchEnabled: boolean;
selectedModel: string; selectedModel: string;
sidebarOpen: boolean; sidebarOpen: boolean;
lastHomeView: string;
thinkEnabled: boolean; thinkEnabled: boolean;
thinkLevel: string; thinkLevel: string;
} }
@@ -22,7 +21,6 @@ type SettingsUpdate = Partial<{
ThinkLevel: string; ThinkLevel: string;
SelectedModel: string; SelectedModel: string;
SidebarOpen: boolean; SidebarOpen: boolean;
LastHomeView: string;
}>; }>;
export function useSettings() { export function useSettings() {
@@ -52,7 +50,6 @@ export function useSettings() {
thinkLevel: settingsData?.settings?.ThinkLevel ?? "none", thinkLevel: settingsData?.settings?.ThinkLevel ?? "none",
selectedModel: settingsData?.settings?.SelectedModel ?? "", selectedModel: settingsData?.settings?.SelectedModel ?? "",
sidebarOpen: settingsData?.settings?.SidebarOpen ?? false, sidebarOpen: settingsData?.settings?.SidebarOpen ?? false,
lastHomeView: settingsData?.settings?.LastHomeView ?? "launch",
}), }),
[settingsData?.settings], [settingsData?.settings],
); );

View File

@@ -4,37 +4,12 @@ import Chat from "@/components/Chat";
import { getChat } from "@/api"; import { getChat } from "@/api";
import { SidebarLayout } from "@/components/layout/layout"; import { SidebarLayout } from "@/components/layout/layout";
import { ChatSidebar } from "@/components/ChatSidebar"; import { ChatSidebar } from "@/components/ChatSidebar";
import LaunchCommands from "@/components/LaunchCommands";
import { useEffect, useRef } from "react";
import { useSettings } from "@/hooks/useSettings";
const launchSidebarRequestedKey = "ollama.launchSidebarRequested";
const launchSidebarSeenKey = "ollama.launchSidebarSeen";
const fallbackSessionState = new Map<string, string>();
function getSessionState() {
if (typeof sessionStorage !== "undefined") {
return sessionStorage;
}
return {
getItem(key: string) {
return fallbackSessionState.get(key) ?? null;
},
setItem(key: string, value: string) {
fallbackSessionState.set(key, value);
},
removeItem(key: string) {
fallbackSessionState.delete(key);
},
};
}
export const Route = createFileRoute("/c/$chatId")({ export const Route = createFileRoute("/c/$chatId")({
component: RouteComponent, component: RouteComponent,
loader: async ({ context, params }) => { loader: async ({ context, params }) => {
// Skip loading for special non-chat views // Skip loading for "new" chat
if (params.chatId !== "new" && params.chatId !== "launch") { if (params.chatId !== "new") {
context.queryClient.ensureQueryData({ context.queryClient.ensureQueryData({
queryKey: ["chat", params.chatId], queryKey: ["chat", params.chatId],
queryFn: () => getChat(params.chatId), queryFn: () => getChat(params.chatId),
@@ -46,70 +21,13 @@ export const Route = createFileRoute("/c/$chatId")({
function RouteComponent() { function RouteComponent() {
const { chatId } = Route.useParams(); const { chatId } = Route.useParams();
const { settingsData, setSettings } = useSettings();
const previousChatIdRef = useRef<string | null>(null);
// Always call hooks at the top level - use a flag to skip data when chatId is a special view // Always call hooks at the top level - use a flag to skip data when chatId is "new"
const { const {
data: chatData, data: chatData,
isLoading: chatLoading, isLoading: chatLoading,
error: chatError, error: chatError,
} = useChat(chatId === "new" || chatId === "launch" ? "" : chatId); } = useChat(chatId === "new" ? "" : chatId);
useEffect(() => {
if (!settingsData) {
return;
}
const previousChatId = previousChatIdRef.current;
previousChatIdRef.current = chatId;
if (chatId === "launch") {
const sessionState = getSessionState();
const shouldOpenSidebar =
previousChatId !== "launch" &&
(() => {
if (sessionState.getItem(launchSidebarRequestedKey) === "1") {
sessionState.removeItem(launchSidebarRequestedKey);
sessionState.setItem(launchSidebarSeenKey, "1");
return true;
}
if (sessionState.getItem(launchSidebarSeenKey) !== "1") {
sessionState.setItem(launchSidebarSeenKey, "1");
return true;
}
return false;
})();
const updates: { LastHomeView?: string; SidebarOpen?: boolean } = {};
if (settingsData.LastHomeView !== "launch") {
updates.LastHomeView = "launch";
}
if (shouldOpenSidebar && !settingsData.SidebarOpen) {
updates.SidebarOpen = true;
}
if (Object.keys(updates).length === 0) {
return;
}
setSettings(updates).catch(() => {
// Best effort persistence for home view preference.
});
return;
}
if (settingsData.LastHomeView === "chat") {
return;
}
setSettings({ LastHomeView: "chat" }).catch(() => {
// Best effort persistence for home view preference.
});
}, [chatId, settingsData, setSettings]);
// Handle "new" chat case - just use Chat component which handles everything // Handle "new" chat case - just use Chat component which handles everything
if (chatId === "new") { if (chatId === "new") {
@@ -120,14 +38,6 @@ function RouteComponent() {
); );
} }
if (chatId === "launch") {
return (
<SidebarLayout sidebar={<ChatSidebar currentChatId={chatId} />}>
<LaunchCommands />
</SidebarLayout>
);
}
// Handle existing chat case // Handle existing chat case
if (chatLoading) { if (chatLoading) {
return ( return (

View File

@@ -1,18 +1,10 @@
import { createFileRoute, redirect } from "@tanstack/react-router"; import { createFileRoute, redirect } from "@tanstack/react-router";
import { getSettings } from "@/api";
export const Route = createFileRoute("/")({ export const Route = createFileRoute("/")({
beforeLoad: async ({ context }) => { beforeLoad: () => {
const settingsData = await context.queryClient.ensureQueryData({
queryKey: ["settings"],
queryFn: getSettings,
});
const chatId =
settingsData?.settings?.LastHomeView === "chat" ? "new" : "launch";
throw redirect({ throw redirect({
to: "/c/$chatId", to: "/c/$chatId",
params: { chatId }, params: { chatId: "new" },
mask: { mask: {
to: "/", to: "/",
}, },

View File

@@ -1,57 +0,0 @@
import { describe, expect, it, vi, beforeEach } from "vitest";
import { copyTextToClipboard } from "./clipboard";
describe("copyTextToClipboard", () => {
beforeEach(() => {
vi.restoreAllMocks();
});
it("copies via Clipboard API when available", async () => {
const writeText = vi.fn().mockResolvedValue(undefined);
vi.stubGlobal("navigator", {
clipboard: {
writeText,
},
});
const copied = await copyTextToClipboard("ollama launch claude");
expect(copied).toBe(true);
expect(writeText).toHaveBeenCalledWith("ollama launch claude");
});
it("falls back to execCommand when Clipboard API fails", async () => {
const writeText = vi.fn().mockRejectedValue(new Error("not allowed"));
vi.stubGlobal("navigator", {
clipboard: {
writeText,
},
});
const textarea = {
value: "",
setAttribute: vi.fn(),
style: {} as Record<string, string>,
focus: vi.fn(),
select: vi.fn(),
};
const appendChild = vi.fn();
const removeChild = vi.fn();
const execCommand = vi.fn().mockReturnValue(true);
vi.stubGlobal("document", {
createElement: vi.fn().mockReturnValue(textarea),
body: {
appendChild,
removeChild,
},
execCommand,
});
const copied = await copyTextToClipboard("ollama launch openclaw");
expect(copied).toBe(true);
expect(execCommand).toHaveBeenCalledWith("copy");
expect(appendChild).toHaveBeenCalled();
expect(removeChild).toHaveBeenCalled();
});
});

View File

@@ -1,30 +0,0 @@
export async function copyTextToClipboard(text: string): Promise<boolean> {
try {
await navigator.clipboard.writeText(text);
return true;
} catch (clipboardError) {
console.error(
"Clipboard API failed, falling back to execCommand",
clipboardError,
);
}
try {
const textarea = document.createElement("textarea");
textarea.value = text;
textarea.setAttribute("readonly", "true");
textarea.style.position = "fixed";
textarea.style.left = "-9999px";
textarea.style.opacity = "0";
document.body.appendChild(textarea);
textarea.focus();
textarea.select();
const copied = document.execCommand("copy");
document.body.removeChild(textarea);
return copied;
} catch (fallbackError) {
console.error("Fallback copy failed", fallbackError);
return false;
}
}

View File

@@ -29,15 +29,13 @@ describe("fileValidation", () => {
expect(result.valid).toBe(true); expect(result.valid).toBe(true);
}); });
it("should accept images regardless of vision capability", () => { it("should reject WebP images when vision capability is disabled", () => {
// Vision capability check is handled at the UI layer (ChatForm),
// not at validation time, so users can switch models without
// needing to re-upload files.
const file = createMockFile("test.webp", 1024, "image/webp"); const file = createMockFile("test.webp", 1024, "image/webp");
const result = validateFile(file, { const result = validateFile(file, {
hasVisionCapability: false, hasVisionCapability: false,
}); });
expect(result.valid).toBe(true); expect(result.valid).toBe(false);
expect(result.error).toBe("This model does not support images");
}); });
it("should accept PNG images when vision capability is enabled", () => { it("should accept PNG images when vision capability is enabled", () => {

View File

@@ -63,6 +63,7 @@ export function validateFile(
const { const {
maxFileSize = 10, maxFileSize = 10,
allowedExtensions = [...TEXT_FILE_EXTENSIONS, ...IMAGE_EXTENSIONS], allowedExtensions = [...TEXT_FILE_EXTENSIONS, ...IMAGE_EXTENSIONS],
hasVisionCapability = false,
customValidator, customValidator,
} = options; } = options;
@@ -82,6 +83,10 @@ export function validateFile(
return { valid: false, error: "File type not supported" }; return { valid: false, error: "File type not supported" };
} }
if (IMAGE_EXTENSIONS.includes(fileExtension) && !hasVisionCapability) {
return { valid: false, error: "This model does not support images" };
}
// File size validation // File size validation
if (file.size > MAX_FILE_SIZE) { if (file.size > MAX_FILE_SIZE) {
return { valid: false, error: "File too large" }; return { valid: false, error: "File too large" };

View File

@@ -2,28 +2,27 @@ import { Model } from "@/gotypes";
// Featured models list (in priority order) // Featured models list (in priority order)
export const FEATURED_MODELS = [ export const FEATURED_MODELS = [
"kimi-k2.5:cloud",
"glm-5:cloud",
"minimax-m2.7:cloud",
"gemma4:31b-cloud",
"qwen3.5:397b-cloud",
"gpt-oss:120b-cloud", "gpt-oss:120b-cloud",
"gpt-oss:20b-cloud", "gpt-oss:20b-cloud",
"deepseek-v3.1:671b-cloud", "deepseek-v3.1:671b-cloud",
"qwen3-coder:480b-cloud",
"qwen3-vl:235b-cloud",
"minimax-m2:cloud",
"glm-4.6:cloud",
"gpt-oss:120b", "gpt-oss:120b",
"gpt-oss:20b", "gpt-oss:20b",
"gemma4:31b", "gemma3:27b",
"gemma4:26b", "gemma3:12b",
"gemma4:e4b", "gemma3:4b",
"gemma4:e2b", "gemma3:1b",
"deepseek-r1:8b", "deepseek-r1:8b",
"qwen3-coder:30b", "qwen3-coder:30b",
"qwen3-vl:30b", "qwen3-vl:30b",
"qwen3-vl:8b", "qwen3-vl:8b",
"qwen3-vl:4b", "qwen3-vl:4b",
"qwen3.5:27b", "qwen3:30b",
"qwen3.5:9b", "qwen3:8b",
"qwen3.5:4b", "qwen3:4b",
]; ];
function alphabeticalSort(a: Model, b: Model): number { function alphabeticalSort(a: Model, b: Model): number {

View File

@@ -133,8 +133,9 @@ type Error struct {
} }
type ModelUpstreamResponse struct { type ModelUpstreamResponse struct {
Stale bool `json:"stale"` Digest string `json:"digest,omitempty"`
Error string `json:"error,omitempty"` PushTime int64 `json:"pushTime"`
Error string `json:"error,omitempty"`
} }
// Serializable data for the browser state // Serializable data for the browser state

View File

@@ -32,7 +32,6 @@ import (
"github.com/ollama/ollama/app/version" "github.com/ollama/ollama/app/version"
ollamaAuth "github.com/ollama/ollama/auth" ollamaAuth "github.com/ollama/ollama/auth"
"github.com/ollama/ollama/envconfig" "github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/manifest"
"github.com/ollama/ollama/types/model" "github.com/ollama/ollama/types/model"
_ "github.com/tkrajina/typescriptify-golang-structs/typescriptify" _ "github.com/tkrajina/typescriptify-golang-structs/typescriptify"
) )
@@ -156,7 +155,7 @@ func (s *Server) ollamaProxy() http.Handler {
return return
} }
target := envconfig.ConnectableHost() target := envconfig.Host()
s.log().Info("configuring ollama proxy", "target", target.String()) s.log().Info("configuring ollama proxy", "target", target.String())
newProxy := httputil.NewSingleHostReverseProxy(target) newProxy := httputil.NewSingleHostReverseProxy(target)
@@ -194,7 +193,7 @@ func (s *Server) Handler() http.Handler {
if CORS() { if CORS() {
w.Header().Set("Access-Control-Allow-Origin", "*") w.Header().Set("Access-Control-Allow-Origin", "*")
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS") w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization, User-Agent, Accept, X-Requested-With") w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization, X-Requested-With")
w.Header().Set("Access-Control-Allow-Credentials", "true") w.Header().Set("Access-Control-Allow-Credentials", "true")
// Handle preflight requests // Handle preflight requests
@@ -319,7 +318,7 @@ func (s *Server) handleError(w http.ResponseWriter, e error) {
if CORS() { if CORS() {
w.Header().Set("Access-Control-Allow-Origin", "*") w.Header().Set("Access-Control-Allow-Origin", "*")
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS") w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization, User-Agent, Accept, X-Requested-With") w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization, X-Requested-With")
w.Header().Set("Access-Control-Allow-Credentials", "true") w.Header().Set("Access-Control-Allow-Credentials", "true")
} }
@@ -342,18 +341,8 @@ func (t *userAgentTransport) RoundTrip(req *http.Request) (*http.Response, error
// httpClient returns an HTTP client that automatically adds the User-Agent header // httpClient returns an HTTP client that automatically adds the User-Agent header
func (s *Server) httpClient() *http.Client { func (s *Server) httpClient() *http.Client {
return userAgentHTTPClient(10 * time.Second)
}
// inferenceClient uses almost the same HTTP client, but without a timeout so
// long requests aren't truncated
func (s *Server) inferenceClient() *api.Client {
return api.NewClient(envconfig.Host(), userAgentHTTPClient(0))
}
func userAgentHTTPClient(timeout time.Duration) *http.Client {
return &http.Client{ return &http.Client{
Timeout: timeout, Timeout: 10 * time.Second,
Transport: &userAgentTransport{ Transport: &userAgentTransport{
base: http.DefaultTransport, base: http.DefaultTransport,
}, },
@@ -731,7 +720,11 @@ func (s *Server) chat(w http.ResponseWriter, r *http.Request) error {
_, cancelLoading := context.WithCancel(ctx) _, cancelLoading := context.WithCancel(ctx)
loading := false loading := false
c := s.inferenceClient() c, err := api.ClientFromEnvironment()
if err != nil {
cancelLoading()
return err
}
// Check if the model exists locally by trying to show it // Check if the model exists locally by trying to show it
// TODO (jmorganca): skip this round trip and instead just act // TODO (jmorganca): skip this round trip and instead just act
@@ -1579,18 +1572,9 @@ func (s *Server) modelUpstream(w http.ResponseWriter, r *http.Request) error {
return json.NewEncoder(w).Encode(response) return json.NewEncoder(w).Encode(response)
} }
n := model.ParseName(req.Model)
stale := true
if m, err := manifest.ParseNamedManifest(n); err == nil {
if m.Digest() == digest {
stale = false
} else if pushTime > 0 && m.FileInfo().ModTime().Unix() >= pushTime {
stale = false
}
}
response := responses.ModelUpstreamResponse{ response := responses.ModelUpstreamResponse{
Stale: stale, Digest: digest,
PushTime: pushTime,
} }
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
@@ -1688,6 +1672,7 @@ func supportsBrowserTools(model string) bool {
return strings.HasPrefix(strings.ToLower(model), "gpt-oss") return strings.HasPrefix(strings.ToLower(model), "gpt-oss")
} }
// buildChatRequest converts store.Chat to api.ChatRequest // buildChatRequest converts store.Chat to api.ChatRequest
func (s *Server) buildChatRequest(chat *store.Chat, model string, think any, availableTools []map[string]any) (*api.ChatRequest, error) { func (s *Server) buildChatRequest(chat *store.Chat, model string, think any, availableTools []map[string]any) (*api.ChatRequest, error) {
var msgs []api.Message var msgs []api.Message

View File

@@ -15,7 +15,6 @@ import (
"sync/atomic" "sync/atomic"
"testing" "testing"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/app/store" "github.com/ollama/ollama/app/store"
"github.com/ollama/ollama/app/updater" "github.com/ollama/ollama/app/updater"
) )
@@ -527,33 +526,6 @@ func TestUserAgentTransport(t *testing.T) {
t.Logf("User-Agent transport successfully set: %s", receivedUA) t.Logf("User-Agent transport successfully set: %s", receivedUA)
} }
func TestInferenceClientUsesUserAgent(t *testing.T) {
var gotUserAgent atomic.Value
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
gotUserAgent.Store(r.Header.Get("User-Agent"))
w.Header().Set("Content-Type", "application/json")
w.Write([]byte(`{}`))
}))
defer ts.Close()
t.Setenv("OLLAMA_HOST", ts.URL)
server := &Server{}
client := server.inferenceClient()
_, err := client.Show(context.Background(), &api.ShowRequest{Model: "test"})
if err != nil {
t.Fatalf("show request failed: %v", err)
}
receivedUA, _ := gotUserAgent.Load().(string)
expectedUA := userAgent()
if receivedUA != expectedUA {
t.Errorf("User-Agent mismatch\nExpected: %s\nReceived: %s", expectedUA, receivedUA)
}
}
func TestSupportsBrowserTools(t *testing.T) { func TestSupportsBrowserTools(t *testing.T) {
tests := []struct { tests := []struct {
model string model string

View File

@@ -1,31 +1,27 @@
Ollama Benchmark Tool Ollama Benchmark Tool
--------------------- ---------------------
A Go-based command-line tool for benchmarking Ollama models with configurable parameters, warmup phases, TTFT tracking, VRAM monitoring, and benchstat/CSV output. A Go-based command-line tool for benchmarking Ollama models with configurable parameters and multiple output formats.
## Features ## Features
* Benchmark multiple models in a single run * Benchmark multiple models in a single run
* Support for both text and image prompts * Support for both text and image prompts
* Configurable generation parameters (temperature, max tokens, seed, etc.) * Configurable generation parameters (temperature, max tokens, seed, etc.)
* Warmup phase before timed epochs to stabilize measurements * Supports benchstat and CSV output formats
* Time-to-first-token (TTFT) tracking per epoch * Detailed performance metrics (prefill, generate, load, total durations)
* Model metadata display (parameter size, quantization level, family)
* VRAM and CPU memory usage tracking via running process info
* Controlled prompt token length for reproducible benchmarks
* Benchstat and CSV output formats
## Building from Source ## Building from Source
``` ```
go build -o ollama-bench ./cmd/bench go build -o ollama-bench bench.go
./ollama-bench -model gemma3 -epochs 6 -format csv ./ollama-bench -model gpt-oss:20b -epochs 6 -format csv
``` ```
Using Go Run (without building) Using Go Run (without building)
``` ```
go run ./cmd/bench -model gemma3 -epochs 3 go run bench.go -model gpt-oss:20b -epochs 3
``` ```
## Usage ## Usage
@@ -49,16 +45,10 @@ benchstat -col /name gemma.bench
./ollama-bench -model qwen3-vl -image photo.jpg -epochs 6 -max-tokens 100 -p "Describe this image" ./ollama-bench -model qwen3-vl -image photo.jpg -epochs 6 -max-tokens 100 -p "Describe this image"
``` ```
### Controlled Prompt Length
```
./ollama-bench -model gemma3 -epochs 6 -prompt-tokens 512
```
### Advanced Example ### Advanced Example
``` ```
./ollama-bench -model llama3 -epochs 10 -temperature 0.7 -max-tokens 500 -seed 42 -warmup 2 -format csv -output results.csv ./ollama-bench -model llama3 -epochs 10 -temperature 0.7 -max-tokens 500 -seed 42 -format csv -output results.csv
``` ```
## Command Line Options ## Command Line Options
@@ -66,48 +56,41 @@ benchstat -col /name gemma.bench
| Option | Description | Default | | Option | Description | Default |
|----------|-------------|---------| |----------|-------------|---------|
| -model | Comma-separated list of models to benchmark | (required) | | -model | Comma-separated list of models to benchmark | (required) |
| -epochs | Number of iterations per model | 6 | | -epochs | Number of iterations per model | 1 |
| -max-tokens | Maximum tokens for model response | 200 | | -max-tokens | Maximum tokens for model response | 0 (unlimited) |
| -temperature | Temperature parameter | 0.0 | | -temperature | Temperature parameter | 0.0 |
| -seed | Random seed | 0 (random) | | -seed | Random seed | 0 (random) |
| -timeout | Timeout in seconds | 300 | | -timeout | Timeout in seconds | 300 |
| -p | Prompt text | (default story prompt) | | -p | Prompt text | "Write a long story." |
| -image | Image file to include in prompt | | | -image | Image file to include in prompt | |
| -k | Keep-alive duration in seconds | 0 | | -k | Keep-alive duration in seconds | 0 |
| -format | Output format (benchstat, csv) | benchstat | | -format | Output format (benchstat, csv) | benchstat |
| -output | Output file for results | "" (stdout) | | -output | Output file for results | "" (stdout) |
| -warmup | Number of warmup requests before timing | 1 |
| -prompt-tokens | Generate prompt targeting ~N tokens (0 = use -p) | 0 |
| -v | Verbose mode | false | | -v | Verbose mode | false |
| -debug | Show debug information | false | | -debug | Show debug information | false |
## Output Formats ## Output Formats
### Benchstat Format (default) ### Markdown Format
Compatible with Go's benchstat tool for statistical analysis. Uses one value/unit pair per line, standard `ns/op` for timing metrics, and `ns/token` for throughput. Each epoch produces one set of lines -- benchstat aggregates across repeated runs to compute statistics.
The default markdown format is suitable for copying and pasting into a GitHub issue and will look like:
``` ```
# Model: gemma3 | Params: 4.3B | Quant: Q4_K_M | Family: gemma3 | Size: 4080218931 | VRAM: 4080218931 Model | Step | Count | Duration | nsPerToken | tokensPerSec |
BenchmarkModel/name=gemma3/step=prefill 1 78125.00 ns/token 12800.00 token/sec |-------|------|-------|----------|------------|--------------|
BenchmarkModel/name=gemma3/step=generate 1 19531.25 ns/token 51200.00 token/sec | gpt-oss:20b | prefill | 124 | 30.006458ms | 241987.56 | 4132.44 |
BenchmarkModel/name=gemma3/step=ttft 1 45123000 ns/op | gpt-oss:20b | generate | 200 | 2.646843954s | 13234219.77 | 75.56 |
BenchmarkModel/name=gemma3/step=load 1 1500000000 ns/op | gpt-oss:20b | load | 1 | 121.674208ms | - | - |
BenchmarkModel/name=gemma3/step=total 1 2861047625 ns/op | gpt-oss:20b | total | 1 | 2.861047625s | - | - |
``` ```
Use with benchstat: ### Benchstat Format
```
./ollama-bench -model gemma3 -epochs 6 > gemma3.bench Compatible with Go's benchstat tool for statistical analysis:
benchstat -col /step gemma3.bench
```
Compare two runs:
``` ```
./ollama-bench -model gemma3 -epochs 6 > before.bench BenchmarkModel/name=gpt-oss:20b/step=prefill 128 78125.00 ns/token 12800.00 token/sec
# ... make changes ... BenchmarkModel/name=gpt-oss:20b/step=generate 512 19531.25 ns/token 51200.00 token/sec
./ollama-bench -model gemma3 -epochs 6 > after.bench BenchmarkModel/name=gpt-oss:20b/step=load 1 1500000000 ns/request
benchstat before.bench after.bench
``` ```
### CSV Format ### CSV Format
@@ -116,28 +99,17 @@ Machine-readable comma-separated values:
``` ```
NAME,STEP,COUNT,NS_PER_COUNT,TOKEN_PER_SEC NAME,STEP,COUNT,NS_PER_COUNT,TOKEN_PER_SEC
# Model: gemma3 | Params: 4.3B | Quant: Q4_K_M | Family: gemma3 | Size: 4080218931 | VRAM: 4080218931 gpt-oss:20b,prefill,128,78125.00,12800.00
gemma3,prefill,128,78125.00,12800.00 gpt-oss:20b,generate,512,19531.25,51200.00
gemma3,generate,512,19531.25,51200.00 gpt-oss:20b,load,1,1500000000,0
gemma3,ttft,1,45123000,0
gemma3,load,1,1500000000,0
gemma3,total,1,2861047625,0
``` ```
## Metrics Explained ## Metrics Explained
The tool reports the following metrics for each epoch: The tool reports four types of metrics for each model:
* **prefill**: Time spent processing the prompt (ns/token) * prefill: Time spent processing the prompt
* **generate**: Time spent generating the response (ns/token) * generate: Time spent generating the response
* **ttft**: Time to first token -- latency from request start to first response content * load: Model loading time (one-time cost)
* **load**: Model loading time (one-time cost) * total: Total request duration
* **total**: Total request duration
Additionally, the model info comment line (displayed once per model before epochs) includes:
* **Params**: Model parameter count (e.g., 4.3B)
* **Quant**: Quantization level (e.g., Q4_K_M)
* **Family**: Model family (e.g., gemma3)
* **Size**: Total model memory in bytes
* **VRAM**: GPU memory used by the loaded model (when Size > VRAM, the difference is CPU spill)

View File

@@ -17,22 +17,19 @@ import (
) )
type flagOptions struct { type flagOptions struct {
models *string models *string
epochs *int epochs *int
maxTokens *int maxTokens *int
temperature *float64 temperature *float64
seed *int seed *int
timeout *int timeout *int
prompt *string prompt *string
imageFile *string imageFile *string
keepAlive *float64 keepAlive *float64
format *string format *string
outputFile *string outputFile *string
debug *bool debug *bool
verbose *bool verbose *bool
warmup *int
promptTokens *int
numCtx *int
} }
type Metrics struct { type Metrics struct {
@@ -42,203 +39,48 @@ type Metrics struct {
Duration time.Duration Duration time.Duration
} }
type ModelInfo struct { var once sync.Once
Name string
ParameterSize string
QuantizationLevel string
Family string
SizeBytes int64
VRAMBytes int64
NumCtx int64
}
const DefaultPrompt = `Please write a descriptive story about a llama named Alonso who grows up to be President of the Land of Llamas. Include details about Alonso's childhood, adolescent years, and how he grew up to be a political mover and shaker. Write the story with a sense of whimsy.` const DefaultPrompt = `Please write a descriptive story about a llama named Alonso who grows up to be President of the Land of Llamas. Include details about Alonso's childhood, adolescent years, and how he grew up to be a political mover and shaker. Write the story with a sense of whimsy.`
// Word list for generating prompts targeting a specific token count.
var promptWordList = []string{
"the", "quick", "brown", "fox", "jumps", "over", "lazy", "dog",
"a", "bright", "sunny", "day", "in", "the", "meadow", "where",
"flowers", "bloom", "and", "birds", "sing", "their", "morning",
"songs", "while", "gentle", "breeze", "carries", "sweet", "scent",
"of", "pine", "trees", "across", "rolling", "hills", "toward",
"distant", "mountains", "covered", "with", "fresh", "snow",
"beneath", "clear", "blue", "sky", "children", "play", "near",
"old", "stone", "bridge", "that", "crosses", "winding", "river",
}
// tokensPerWord is the calibrated ratio of tokens to words for the current model.
// Initialized with a heuristic, then updated during warmup based on actual tokenization.
var tokensPerWord = 1.3
// generatePromptForTokenCount builds a synthetic prompt of roughly
// targetTokens tokens by cycling through promptWordList. The word sequence
// is rotated by a per-epoch offset so successive epochs do not share a
// prefix, which would otherwise let the KV cache skip prefill work.
func generatePromptForTokenCount(targetTokens int, epoch int) string {
	count := max(int(float64(targetTokens)/tokensPerWord), 1)
	// A prime stride spreads starting positions well across epochs; epoch
	// may be negative (warmup), hence the double-mod to stay in range.
	start := epoch * 7
	size := len(promptWordList)

	var b strings.Builder
	for i := 0; i < count; i++ {
		if i > 0 {
			b.WriteByte(' ')
		}
		b.WriteString(promptWordList[((i+start)%size+size)%size])
	}
	return b.String()
}
// calibratePromptTokens updates the package-level tokensPerWord ratio from
// the actual token count observed for a warmup prompt, so later generated
// prompts land closer to the requested token target. It is a no-op when
// either observed count is non-positive, and logs the calibration to stderr.
func calibratePromptTokens(targetTokens, actualTokens, wordCount int) {
	if actualTokens <= 0 || wordCount <= 0 {
		return
	}
	ratio := float64(actualTokens) / float64(wordCount)
	tokensPerWord = ratio
	adjusted := int(float64(targetTokens) / ratio)
	fmt.Fprintf(os.Stderr, "bench: calibrated %.2f tokens/word (target=%d, got=%d, words=%d → %d)\n",
		ratio, targetTokens, actualTokens, wordCount, adjusted)
}
// buildGenerateRequest assembles an api.GenerateRequest for one benchmark
// epoch, honoring the command-line options in fOpt.
//
// The prompt is varied per epoch: either a synthetic prompt targeting
// *fOpt.promptTokens tokens, or the user prompt prefixed with the epoch
// number. Both defeat KV-cache prefix matching so every epoch pays full
// prefill cost. Raw mode is set so no chat template is applied.
func buildGenerateRequest(model string, fOpt flagOptions, imgData api.ImageData, epoch int) *api.GenerateRequest {
	// Use the modern `any` alias (Go 1.18+); the file already relies on
	// Go 1.22 features, and api.GenerateRequest.Options accepts it.
	options := make(map[string]any)
	if *fOpt.maxTokens > 0 {
		options["num_predict"] = *fOpt.maxTokens
	}
	options["temperature"] = *fOpt.temperature
	if fOpt.seed != nil && *fOpt.seed > 0 {
		options["seed"] = *fOpt.seed
	}
	if fOpt.numCtx != nil && *fOpt.numCtx > 0 {
		options["num_ctx"] = *fOpt.numCtx
	}

	// KeepAlive is only sent when explicitly requested; a nil pointer lets
	// the server apply its default.
	var keepAliveDuration *api.Duration
	if *fOpt.keepAlive > 0 {
		duration := api.Duration{Duration: time.Duration(*fOpt.keepAlive * float64(time.Second))}
		keepAliveDuration = &duration
	}

	prompt := *fOpt.prompt
	if *fOpt.promptTokens > 0 {
		prompt = generatePromptForTokenCount(*fOpt.promptTokens, epoch)
	} else {
		// Vary the prompt per epoch to defeat KV cache prefix matching
		prompt = fmt.Sprintf("[%d] %s", epoch, prompt)
	}

	req := &api.GenerateRequest{
		Model:     model,
		Prompt:    prompt,
		Raw:       true,
		Options:   options,
		KeepAlive: keepAliveDuration,
	}
	if imgData != nil {
		req.Images = []api.ImageData{imgData}
	}
	return req
}
// fetchModelInfo queries the server's Show endpoint for a model's metadata
// (parameter size, quantization level, family). On failure it warns on
// stderr and returns an info struct carrying only the model name.
func fetchModelInfo(ctx context.Context, client *api.Client, model string) ModelInfo {
	resp, err := client.Show(ctx, &api.ShowRequest{Model: model})
	if err != nil {
		fmt.Fprintf(os.Stderr, "WARNING: Could not fetch model info for '%s': %v\n", model, err)
		return ModelInfo{Name: model}
	}
	return ModelInfo{
		Name:              model,
		ParameterSize:     resp.Details.ParameterSize,
		QuantizationLevel: resp.Details.QuantizationLevel,
		Family:            resp.Details.Family,
	}
}
// fetchMemoryUsage reports the total and VRAM-resident memory of a loaded
// model, taken from the server's running-process list. An exact name match
// is preferred; failing that, a prefix match is tried (so a bare model name
// can match a tagged one). Returns zeros when the model is not found or the
// query fails; query failures are logged only when OLLAMA_DEBUG is set.
func fetchMemoryUsage(ctx context.Context, client *api.Client, model string) (size, vram int64) {
	resp, err := client.ListRunning(ctx)
	if err != nil {
		if os.Getenv("OLLAMA_DEBUG") != "" {
			fmt.Fprintf(os.Stderr, "WARNING: Could not fetch memory usage: %v\n", err)
		}
		return 0, 0
	}

	// find scans the running models with the given predicate.
	find := func(match func(name string) bool) (int64, int64, bool) {
		for _, m := range resp.Models {
			if match(m.Name) || match(m.Model) {
				return m.Size, m.SizeVRAM, true
			}
		}
		return 0, 0, false
	}

	if s, v, ok := find(func(name string) bool { return name == model }); ok {
		return s, v
	}
	if s, v, ok := find(func(name string) bool { return strings.HasPrefix(name, model) }); ok {
		return s, v
	}
	return 0, 0
}
// fetchContextLength returns the context window of a loaded model from the
// server's running-process list, matching by exact name or prefix. Returns
// 0 when the model is not running or the query fails.
func fetchContextLength(ctx context.Context, client *api.Client, model string) int64 {
	resp, err := client.ListRunning(ctx)
	if err != nil {
		return 0
	}
	for _, m := range resp.Models {
		exact := m.Name == model || m.Model == model
		if exact || strings.HasPrefix(m.Name, model) || strings.HasPrefix(m.Model, model) {
			return int64(m.ContextLength)
		}
	}
	return 0
}
func outputFormatHeader(w io.Writer, format string, verbose bool) {
switch format {
case "benchstat":
if verbose {
fmt.Fprintf(w, "goos: %s\n", runtime.GOOS)
fmt.Fprintf(w, "goarch: %s\n", runtime.GOARCH)
}
case "csv":
headings := []string{"NAME", "STEP", "COUNT", "NS_PER_COUNT", "TOKEN_PER_SEC"}
fmt.Fprintln(w, strings.Join(headings, ","))
}
}
// outputModelInfo prints a one-line "# Model: ..." summary comment ahead of
// a model's benchmark results. Missing metadata fields fall back to
// "unknown"; the memory and context-size segments are appended only when
// the corresponding values are known (> 0). The format parameter is
// currently unused: the same comment line is emitted for every format.
func outputModelInfo(w io.Writer, format string, info ModelInfo) {
	var mem, ctxLen string
	if info.SizeBytes > 0 {
		mem = fmt.Sprintf(" | Size: %d | VRAM: %d", info.SizeBytes, info.VRAMBytes)
	}
	if info.NumCtx > 0 {
		ctxLen = fmt.Sprintf(" | NumCtx: %d", info.NumCtx)
	}
	fmt.Fprintf(w, "# Model: %s | Params: %s | Quant: %s | Family: %s%s%s\n",
		info.Name,
		cmp.Or(info.ParameterSize, "unknown"),
		cmp.Or(info.QuantizationLevel, "unknown"),
		cmp.Or(info.Family, "unknown"),
		mem, ctxLen)
}
func OutputMetrics(w io.Writer, format string, metrics []Metrics, verbose bool) { func OutputMetrics(w io.Writer, format string, metrics []Metrics, verbose bool) {
switch format { switch format {
case "benchstat": case "benchstat":
if verbose {
printHeader := func() {
fmt.Fprintf(w, "sysname: %s\n", runtime.GOOS)
fmt.Fprintf(w, "machine: %s\n", runtime.GOARCH)
}
once.Do(printHeader)
}
for _, m := range metrics { for _, m := range metrics {
if m.Step == "generate" || m.Step == "prefill" { if m.Step == "generate" || m.Step == "prefill" {
if m.Count > 0 { if m.Count > 0 {
nsPerToken := float64(m.Duration.Nanoseconds()) / float64(m.Count) nsPerToken := float64(m.Duration.Nanoseconds()) / float64(m.Count)
tokensPerSec := float64(m.Count) / (float64(m.Duration.Nanoseconds()) + 1e-12) * 1e9 tokensPerSec := float64(m.Count) / (float64(m.Duration.Nanoseconds()) + 1e-12) * 1e9
fmt.Fprintf(w, "BenchmarkModel/name=%s/step=%s 1 %.2f ns/token %.2f token/sec\n",
m.Model, m.Step, nsPerToken, tokensPerSec) fmt.Fprintf(w, "BenchmarkModel/name=%s/step=%s %d %.2f ns/token %.2f token/sec\n",
m.Model, m.Step, m.Count, nsPerToken, tokensPerSec)
} else { } else {
fmt.Fprintf(w, "BenchmarkModel/name=%s/step=%s 1 0 ns/token 0 token/sec\n", fmt.Fprintf(w, "BenchmarkModel/name=%s/step=%s %d 0 ns/token 0 token/sec\n",
m.Model, m.Step) m.Model, m.Step, m.Count)
} }
} else if m.Step == "ttft" {
fmt.Fprintf(w, "BenchmarkModel/name=%s/step=ttft 1 %d ns/op\n",
m.Model, m.Duration.Nanoseconds())
} else { } else {
fmt.Fprintf(w, "BenchmarkModel/name=%s/step=%s 1 %d ns/op\n", var suffix string
m.Model, m.Step, m.Duration.Nanoseconds()) if m.Step == "load" {
suffix = "/step=load"
}
fmt.Fprintf(w, "BenchmarkModel/name=%s%s 1 %d ns/request\n",
m.Model, suffix, m.Duration.Nanoseconds())
} }
} }
case "csv": case "csv":
printHeader := func() {
headings := []string{"NAME", "STEP", "COUNT", "NS_PER_COUNT", "TOKEN_PER_SEC"}
fmt.Fprintln(w, strings.Join(headings, ","))
}
once.Do(printHeader)
for _, m := range metrics { for _, m := range metrics {
if m.Step == "generate" || m.Step == "prefill" { if m.Step == "generate" || m.Step == "prefill" {
var nsPerToken float64 var nsPerToken float64
@@ -252,14 +94,39 @@ func OutputMetrics(w io.Writer, format string, metrics []Metrics, verbose bool)
fmt.Fprintf(w, "%s,%s,1,%d,0\n", m.Model, m.Step, m.Duration.Nanoseconds()) fmt.Fprintf(w, "%s,%s,1,%d,0\n", m.Model, m.Step, m.Duration.Nanoseconds())
} }
} }
case "markdown":
printHeader := func() {
fmt.Fprintln(w, "| Model | Step | Count | Duration | nsPerToken | tokensPerSec |")
fmt.Fprintln(w, "|-------|------|-------|----------|------------|--------------|")
}
once.Do(printHeader)
for _, m := range metrics {
var nsPerToken, tokensPerSec float64
var nsPerTokenStr, tokensPerSecStr string
if m.Step == "generate" || m.Step == "prefill" {
nsPerToken = float64(m.Duration.Nanoseconds()) / float64(m.Count)
tokensPerSec = float64(m.Count) / (float64(m.Duration.Nanoseconds()) + 1e-12) * 1e9
nsPerTokenStr = fmt.Sprintf("%.2f", nsPerToken)
tokensPerSecStr = fmt.Sprintf("%.2f", tokensPerSec)
} else {
nsPerTokenStr = "-"
tokensPerSecStr = "-"
}
fmt.Fprintf(w, "| %s | %s | %d | %v | %s | %s |\n",
m.Model, m.Step, m.Count, m.Duration, nsPerTokenStr, tokensPerSecStr)
}
default: default:
fmt.Fprintf(os.Stderr, "Unknown output format '%s'\n", format) fmt.Fprintf(os.Stderr, "Unknown output format '%s'\n", format)
} }
} }
func BenchmarkModel(fOpt flagOptions) error { func BenchmarkChat(fOpt flagOptions) error {
models := strings.Split(*fOpt.models, ",") models := strings.Split(*fOpt.models, ",")
// todo - add multi-image support
var imgData api.ImageData var imgData api.ImageData
var err error var err error
if *fOpt.imageFile != "" { if *fOpt.imageFile != "" {
@@ -291,141 +158,71 @@ func BenchmarkModel(fOpt flagOptions) error {
out = f out = f
} }
outputFormatHeader(out, *fOpt.format, *fOpt.verbose)
// Log prompt-tokens info in debug mode
if *fOpt.debug && *fOpt.promptTokens > 0 {
prompt := generatePromptForTokenCount(*fOpt.promptTokens, 0)
wordCount := len(strings.Fields(prompt))
fmt.Fprintf(os.Stderr, "Generated prompt targeting ~%d tokens (%d words, varied per epoch)\n", *fOpt.promptTokens, wordCount)
}
for _, model := range models { for _, model := range models {
// Fetch model info for range *fOpt.epochs {
infoCtx, infoCancel := context.WithTimeout(context.Background(), 10*time.Second) options := make(map[string]interface{})
info := fetchModelInfo(infoCtx, client, model) if *fOpt.maxTokens > 0 {
infoCancel() options["num_predict"] = *fOpt.maxTokens
}
options["temperature"] = *fOpt.temperature
if fOpt.seed != nil && *fOpt.seed > 0 {
options["seed"] = *fOpt.seed
}
var keepAliveDuration *api.Duration
if *fOpt.keepAlive > 0 {
duration := api.Duration{Duration: time.Duration(*fOpt.keepAlive * float64(time.Second))}
keepAliveDuration = &duration
}
req := &api.ChatRequest{
Model: model,
Messages: []api.Message{
{
Role: "user",
Content: *fOpt.prompt,
},
},
Options: options,
KeepAlive: keepAliveDuration,
}
if imgData != nil {
req.Messages[0].Images = []api.ImageData{imgData}
}
var responseMetrics *api.Metrics
// Warmup phase (uses negative epoch numbers to avoid colliding with timed epochs)
for i := range *fOpt.warmup {
req := buildGenerateRequest(model, fOpt, imgData, -(i + 1))
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(*fOpt.timeout)*time.Second) ctx, cancel := context.WithTimeout(context.Background(), time.Duration(*fOpt.timeout)*time.Second)
defer cancel()
err = client.Chat(ctx, req, func(resp api.ChatResponse) error {
if *fOpt.debug {
fmt.Fprintf(os.Stderr, "%s", cmp.Or(resp.Message.Thinking, resp.Message.Content))
}
var warmupMetrics *api.Metrics
err = client.Generate(ctx, req, func(resp api.GenerateResponse) error {
if resp.Done { if resp.Done {
warmupMetrics = &resp.Metrics responseMetrics = &resp.Metrics
} }
return nil return nil
}) })
cancel()
if *fOpt.debug {
fmt.Fprintln(os.Stderr)
}
if err != nil { if err != nil {
fmt.Fprintf(os.Stderr, "WARNING: Warmup %d/%d for %s failed: %v\n", i+1, *fOpt.warmup, model, err) if ctx.Err() == context.DeadlineExceeded {
} else { fmt.Fprintf(os.Stderr, "ERROR: Chat request timed out with model '%s' after %vs\n", model, 1)
if *fOpt.debug { continue
fmt.Fprintf(os.Stderr, "Warmup %d/%d for %s complete\n", i+1, *fOpt.warmup, model)
} }
// Calibrate prompt token count on last warmup run fmt.Fprintf(os.Stderr, "ERROR: Couldn't chat with model '%s': %v\n", model, err)
if i == *fOpt.warmup-1 && *fOpt.promptTokens > 0 && warmupMetrics != nil {
prompt := generatePromptForTokenCount(*fOpt.promptTokens, -(i + 1))
wordCount := len(strings.Fields(prompt))
calibratePromptTokens(*fOpt.promptTokens, warmupMetrics.PromptEvalCount, wordCount)
}
}
}
// Fetch memory/context info once after warmup (model is loaded and stable)
memCtx, memCancel := context.WithTimeout(context.Background(), 5*time.Second)
info.SizeBytes, info.VRAMBytes = fetchMemoryUsage(memCtx, client, model)
if fOpt.numCtx != nil && *fOpt.numCtx > 0 {
info.NumCtx = int64(*fOpt.numCtx)
} else {
info.NumCtx = fetchContextLength(memCtx, client, model)
}
memCancel()
outputModelInfo(out, *fOpt.format, info)
// Timed epoch loop
shortCount := 0
for epoch := range *fOpt.epochs {
var responseMetrics *api.Metrics
var ttft time.Duration
short := false
// Retry loop: if the model hits a stop token before max-tokens,
// retry with a different prompt (up to maxRetries times).
const maxRetries = 3
for attempt := range maxRetries + 1 {
responseMetrics = nil
ttft = 0
var ttftOnce sync.Once
req := buildGenerateRequest(model, fOpt, imgData, epoch+attempt*1000)
requestStart := time.Now()
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(*fOpt.timeout)*time.Second)
err = client.Generate(ctx, req, func(resp api.GenerateResponse) error {
if *fOpt.debug {
fmt.Fprintf(os.Stderr, "%s", cmp.Or(resp.Thinking, resp.Response))
}
// Capture TTFT on first content
ttftOnce.Do(func() {
if resp.Response != "" || resp.Thinking != "" {
ttft = time.Since(requestStart)
}
})
if resp.Done {
responseMetrics = &resp.Metrics
}
return nil
})
cancel()
if *fOpt.debug {
fmt.Fprintln(os.Stderr)
}
if err != nil {
if ctx.Err() == context.DeadlineExceeded {
fmt.Fprintf(os.Stderr, "ERROR: Request timed out with model '%s' after %vs\n", model, *fOpt.timeout)
} else {
fmt.Fprintf(os.Stderr, "ERROR: Couldn't generate with model '%s': %v\n", model, err)
}
break
}
if responseMetrics == nil {
fmt.Fprintf(os.Stderr, "ERROR: No metrics received for model '%s'\n", model)
break
}
// Check if the response was shorter than requested
short = *fOpt.maxTokens > 0 && responseMetrics.EvalCount < *fOpt.maxTokens
if !short || attempt == maxRetries {
break
}
if *fOpt.debug {
fmt.Fprintf(os.Stderr, "Short response (%d/%d tokens), retrying with different prompt (attempt %d/%d)\n",
responseMetrics.EvalCount, *fOpt.maxTokens, attempt+1, maxRetries)
}
}
if err != nil || responseMetrics == nil {
continue continue
} }
if short { if responseMetrics == nil {
shortCount++ fmt.Fprintf(os.Stderr, "ERROR: No metrics received for model '%s'\n", model)
if *fOpt.debug { continue
fmt.Fprintf(os.Stderr, "WARNING: Short response (%d/%d tokens) after %d retries for epoch %d\n",
responseMetrics.EvalCount, *fOpt.maxTokens, maxRetries, epoch+1)
}
} }
metrics := []Metrics{ metrics := []Metrics{
@@ -441,12 +238,6 @@ func BenchmarkModel(fOpt flagOptions) error {
Count: responseMetrics.EvalCount, Count: responseMetrics.EvalCount,
Duration: responseMetrics.EvalDuration, Duration: responseMetrics.EvalDuration,
}, },
{
Model: model,
Step: "ttft",
Count: 1,
Duration: ttft,
},
{ {
Model: model, Model: model,
Step: "load", Step: "load",
@@ -463,42 +254,15 @@ func BenchmarkModel(fOpt flagOptions) error {
OutputMetrics(out, *fOpt.format, metrics, *fOpt.verbose) OutputMetrics(out, *fOpt.format, metrics, *fOpt.verbose)
if *fOpt.debug && *fOpt.promptTokens > 0 {
fmt.Fprintf(os.Stderr, "Generated prompt targeting ~%d tokens (actual: %d)\n",
*fOpt.promptTokens, responseMetrics.PromptEvalCount)
}
if *fOpt.keepAlive > 0 { if *fOpt.keepAlive > 0 {
time.Sleep(time.Duration(*fOpt.keepAlive*float64(time.Second)) + 200*time.Millisecond) time.Sleep(time.Duration(*fOpt.keepAlive*float64(time.Second)) + 200*time.Millisecond)
} }
} }
if shortCount > 0 {
fmt.Fprintf(os.Stderr, "WARNING: %d/%d epochs for '%s' had short responses (<%d tokens). Generation metrics may be unreliable.\n",
shortCount, *fOpt.epochs, model, *fOpt.maxTokens)
}
// Unload model before moving to the next one
unloadModel(client, model, *fOpt.timeout)
} }
return nil return nil
} }
// unloadModel asks the server to evict a model by issuing an empty generate
// request with a KeepAlive of zero. The result is deliberately ignored:
// this is best-effort cleanup between benchmarked models.
func unloadModel(client *api.Client, model string, timeout int) {
	ctx, cancel := context.WithTimeout(context.Background(), time.Duration(timeout)*time.Second)
	defer cancel()

	req := &api.GenerateRequest{
		Model:     model,
		KeepAlive: &api.Duration{Duration: 0},
	}
	_ = client.Generate(ctx, req, func(api.GenerateResponse) error { return nil })
}
func readImage(filePath string) (api.ImageData, error) { func readImage(filePath string) (api.ImageData, error) {
file, err := os.Open(filePath) file, err := os.Open(filePath)
if err != nil { if err != nil {
@@ -516,22 +280,19 @@ func readImage(filePath string) (api.ImageData, error) {
func main() { func main() {
fOpt := flagOptions{ fOpt := flagOptions{
models: flag.String("model", "", "Model to benchmark"), models: flag.String("model", "", "Model to benchmark"),
epochs: flag.Int("epochs", 6, "Number of epochs (iterations) per model"), epochs: flag.Int("epochs", 6, "Number of epochs (iterations) per model"),
maxTokens: flag.Int("max-tokens", 200, "Maximum tokens for model response"), maxTokens: flag.Int("max-tokens", 200, "Maximum tokens for model response"),
temperature: flag.Float64("temperature", 0, "Temperature parameter"), temperature: flag.Float64("temperature", 0, "Temperature parameter"),
seed: flag.Int("seed", 0, "Random seed"), seed: flag.Int("seed", 0, "Random seed"),
timeout: flag.Int("timeout", 60*5, "Timeout in seconds (default 300s)"), timeout: flag.Int("timeout", 60*5, "Timeout in seconds (default 300s)"),
prompt: flag.String("p", DefaultPrompt, "Prompt to use"), prompt: flag.String("p", DefaultPrompt, "Prompt to use"),
imageFile: flag.String("image", "", "Filename for an image to include"), imageFile: flag.String("image", "", "Filename for an image to include"),
keepAlive: flag.Float64("k", 0, "Keep alive duration in seconds"), keepAlive: flag.Float64("k", 0, "Keep alive duration in seconds"),
format: flag.String("format", "benchstat", "Output format [benchstat|csv]"), format: flag.String("format", "markdown", "Output format [benchstat|csv] (default benchstat)"),
outputFile: flag.String("output", "", "Output file for results (stdout if empty)"), outputFile: flag.String("output", "", "Output file for results (stdout if empty)"),
verbose: flag.Bool("v", false, "Show system information"), verbose: flag.Bool("v", false, "Show system information"),
debug: flag.Bool("debug", false, "Show debug information"), debug: flag.Bool("debug", false, "Show debug information"),
warmup: flag.Int("warmup", 1, "Number of warmup requests before timing"),
promptTokens: flag.Int("prompt-tokens", 0, "Generate prompt targeting ~N tokens (0 = use -p prompt)"),
numCtx: flag.Int("num-ctx", 0, "Context size (0 = server default)"),
} }
flag.Usage = func() { flag.Usage = func() {
@@ -541,12 +302,11 @@ func main() {
fmt.Fprintf(os.Stderr, "Options:\n") fmt.Fprintf(os.Stderr, "Options:\n")
flag.PrintDefaults() flag.PrintDefaults()
fmt.Fprintf(os.Stderr, "\nExamples:\n") fmt.Fprintf(os.Stderr, "\nExamples:\n")
fmt.Fprintf(os.Stderr, " bench -model gemma3,llama3 -epochs 6\n") fmt.Fprintf(os.Stderr, " bench -model gpt-oss:20b -epochs 3 -temperature 0.7\n")
fmt.Fprintf(os.Stderr, " bench -model gemma3 -epochs 6 -prompt-tokens 512 -format csv\n")
} }
flag.Parse() flag.Parse()
if !slices.Contains([]string{"benchstat", "csv"}, *fOpt.format) { if !slices.Contains([]string{"markdown", "benchstat", "csv"}, *fOpt.format) {
fmt.Fprintf(os.Stderr, "ERROR: Unknown format '%s'\n", *fOpt.format) fmt.Fprintf(os.Stderr, "ERROR: Unknown format '%s'\n", *fOpt.format)
os.Exit(1) os.Exit(1)
} }
@@ -557,5 +317,5 @@ func main() {
return return
} }
BenchmarkModel(fOpt) BenchmarkChat(fOpt)
} }

File diff suppressed because it is too large Load Diff

View File

@@ -11,7 +11,6 @@ import (
"fmt" "fmt"
"io" "io"
"log" "log"
"log/slog"
"math" "math"
"net" "net"
"net/http" "net/http"
@@ -39,12 +38,9 @@ import (
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
"github.com/ollama/ollama/cmd/config" "github.com/ollama/ollama/cmd/config"
"github.com/ollama/ollama/cmd/launch"
"github.com/ollama/ollama/cmd/tui" "github.com/ollama/ollama/cmd/tui"
"github.com/ollama/ollama/envconfig" "github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/format" "github.com/ollama/ollama/format"
"github.com/ollama/ollama/internal/modelref"
"github.com/ollama/ollama/logutil"
"github.com/ollama/ollama/parser" "github.com/ollama/ollama/parser"
"github.com/ollama/ollama/progress" "github.com/ollama/ollama/progress"
"github.com/ollama/ollama/readline" "github.com/ollama/ollama/readline"
@@ -54,45 +50,46 @@ import (
"github.com/ollama/ollama/types/syncmap" "github.com/ollama/ollama/types/syncmap"
"github.com/ollama/ollama/version" "github.com/ollama/ollama/version"
xcmd "github.com/ollama/ollama/x/cmd" xcmd "github.com/ollama/ollama/x/cmd"
"github.com/ollama/ollama/x/create"
xcreateclient "github.com/ollama/ollama/x/create/client" xcreateclient "github.com/ollama/ollama/x/create/client"
"github.com/ollama/ollama/x/imagegen" "github.com/ollama/ollama/x/imagegen"
) )
func init() { func init() {
// Override default selectors to use Bubbletea TUI instead of raw terminal I/O. // Override default selectors to use Bubbletea TUI instead of raw terminal I/O.
launch.DefaultSingleSelector = func(title string, items []launch.ModelItem, current string) (string, error) { config.DefaultSingleSelector = func(title string, items []config.ModelItem, current string) (string, error) {
if !term.IsTerminal(int(os.Stdin.Fd())) || !term.IsTerminal(int(os.Stdout.Fd())) {
return "", fmt.Errorf("model selection requires an interactive terminal; use --model to run in headless mode")
}
tuiItems := tui.ReorderItems(tui.ConvertItems(items)) tuiItems := tui.ReorderItems(tui.ConvertItems(items))
result, err := tui.SelectSingle(title, tuiItems, current) result, err := tui.SelectSingle(title, tuiItems, current)
if errors.Is(err, tui.ErrCancelled) { if errors.Is(err, tui.ErrCancelled) {
return "", launch.ErrCancelled return "", config.ErrCancelled
} }
return result, err return result, err
} }
launch.DefaultMultiSelector = func(title string, items []launch.ModelItem, preChecked []string) ([]string, error) { config.DefaultMultiSelector = func(title string, items []config.ModelItem, preChecked []string) ([]string, error) {
if !term.IsTerminal(int(os.Stdin.Fd())) || !term.IsTerminal(int(os.Stdout.Fd())) {
return nil, fmt.Errorf("model selection requires an interactive terminal; use --model to run in headless mode")
}
tuiItems := tui.ReorderItems(tui.ConvertItems(items)) tuiItems := tui.ReorderItems(tui.ConvertItems(items))
result, err := tui.SelectMultiple(title, tuiItems, preChecked) result, err := tui.SelectMultiple(title, tuiItems, preChecked)
if errors.Is(err, tui.ErrCancelled) { if errors.Is(err, tui.ErrCancelled) {
return nil, launch.ErrCancelled return nil, config.ErrCancelled
} }
return result, err return result, err
} }
launch.DefaultSignIn = func(modelName, signInURL string) (string, error) { config.DefaultSignIn = func(modelName, signInURL string) (string, error) {
userName, err := tui.RunSignIn(modelName, signInURL) userName, err := tui.RunSignIn(modelName, signInURL)
if errors.Is(err, tui.ErrCancelled) { if errors.Is(err, tui.ErrCancelled) {
return "", launch.ErrCancelled return "", config.ErrCancelled
} }
return userName, err return userName, err
} }
launch.DefaultConfirmPrompt = tui.RunConfirmWithOptions config.DefaultConfirmPrompt = func(prompt string) (bool, error) {
ok, err := tui.RunConfirm(prompt)
if errors.Is(err, tui.ErrCancelled) {
return false, config.ErrCancelled
}
return ok, err
}
} }
const ConnectInstructions = "If your browser did not open, navigate to:\n %s\n\n" const ConnectInstructions = "If your browser did not open, navigate to:\n %s\n\n"
@@ -134,17 +131,6 @@ func getModelfileName(cmd *cobra.Command) (string, error) {
return absName, nil return absName, nil
} }
// isLocalhost reports whether the configured Ollama host is local: the
// literal name "localhost", a loopback IP (127.0.0.0/8, ::1), or an
// unspecified address (0.0.0.0, ::).
func isLocalhost() bool {
	host := envconfig.Host()
	h, _, err := net.SplitHostPort(host.Host)
	if err != nil {
		// Host has no port component (SplitHostPort requires one). Fall
		// back to classifying the raw host instead of silently treating a
		// bare "localhost" or loopback address as remote.
		h = host.Host
	}
	if h == "localhost" {
		return true
	}
	ip := net.ParseIP(h)
	return ip != nil && (ip.IsLoopback() || ip.IsUnspecified())
}
func CreateHandler(cmd *cobra.Command, args []string) error { func CreateHandler(cmd *cobra.Command, args []string) error {
p := progress.NewProgress(os.Stderr) p := progress.NewProgress(os.Stderr)
defer p.Stop() defer p.Stop()
@@ -157,13 +143,8 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
} }
// Check for --experimental flag for safetensors model creation // Check for --experimental flag for safetensors model creation
// This gates both safetensors LLM and imagegen model creation
experimental, _ := cmd.Flags().GetBool("experimental") experimental, _ := cmd.Flags().GetBool("experimental")
if experimental { if experimental {
if !isLocalhost() {
return errors.New("remote safetensor model creation not yet supported")
}
// Get Modelfile content - either from -f flag or default to "FROM ." // Get Modelfile content - either from -f flag or default to "FROM ."
var reader io.Reader var reader io.Reader
filename, err := getModelfileName(cmd) filename, err := getModelfileName(cmd)
@@ -187,9 +168,29 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
return fmt.Errorf("failed to parse Modelfile: %w", err) return fmt.Errorf("failed to parse Modelfile: %w", err)
} }
modelDir, mfConfig, err := xcreateclient.ConfigFromModelfile(modelfile) // Extract FROM path and configuration
if err != nil { var modelDir string
return err mfConfig := &xcreateclient.ModelfileConfig{}
for _, cmd := range modelfile.Commands {
switch cmd.Name {
case "model":
modelDir = cmd.Args
case "template":
mfConfig.Template = cmd.Args
case "system":
mfConfig.System = cmd.Args
case "license":
mfConfig.License = cmd.Args
case "parser":
mfConfig.Parser = cmd.Args
case "renderer":
mfConfig.Renderer = cmd.Args
}
}
if modelDir == "" {
modelDir = "."
} }
// Resolve relative paths based on Modelfile location // Resolve relative paths based on Modelfile location
@@ -206,12 +207,20 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
}, p) }, p)
} }
// Standard Modelfile + API path
var reader io.Reader var reader io.Reader
filename, err := getModelfileName(cmd) filename, err := getModelfileName(cmd)
if os.IsNotExist(err) { if os.IsNotExist(err) {
if filename == "" { if filename == "" {
// No Modelfile found - check if current directory is an image gen model
if create.IsTensorModelDir(".") {
quantize, _ := cmd.Flags().GetString("quantize")
return xcreateclient.CreateModel(xcreateclient.CreateOptions{
ModelName: modelName,
ModelDir: ".",
Quantize: quantize,
}, p)
}
reader = strings.NewReader("FROM .\n") reader = strings.NewReader("FROM .\n")
} else { } else {
return errModelfileNotFound return errModelfileNotFound
@@ -397,14 +406,12 @@ func loadOrUnloadModel(cmd *cobra.Command, opts *runOptions) error {
return err return err
} }
requestedCloud := modelref.HasExplicitCloudSource(opts.Model)
if info, err := client.Show(cmd.Context(), &api.ShowRequest{Model: opts.Model}); err != nil { if info, err := client.Show(cmd.Context(), &api.ShowRequest{Model: opts.Model}); err != nil {
return err return err
} else if info.RemoteHost != "" || requestedCloud { } else if info.RemoteHost != "" {
// Cloud model, no need to load/unload // Cloud model, no need to load/unload
isCloud := requestedCloud || strings.HasPrefix(info.RemoteHost, "https://ollama.com") isCloud := strings.HasPrefix(info.RemoteHost, "https://ollama.com")
// Check if user is signed in for ollama.com cloud models // Check if user is signed in for ollama.com cloud models
if isCloud { if isCloud {
@@ -415,14 +422,10 @@ func loadOrUnloadModel(cmd *cobra.Command, opts *runOptions) error {
if opts.ShowConnect { if opts.ShowConnect {
p.StopAndClear() p.StopAndClear()
remoteModel := info.RemoteModel
if remoteModel == "" {
remoteModel = opts.Model
}
if isCloud { if isCloud {
fmt.Fprintf(os.Stderr, "Connecting to '%s' on 'ollama.com' ⚡\n", remoteModel) fmt.Fprintf(os.Stderr, "Connecting to '%s' on 'ollama.com' ⚡\n", info.RemoteModel)
} else { } else {
fmt.Fprintf(os.Stderr, "Connecting to '%s' on '%s'\n", remoteModel, info.RemoteHost) fmt.Fprintf(os.Stderr, "Connecting to '%s' on '%s'\n", info.RemoteModel, info.RemoteHost)
} }
} }
@@ -494,64 +497,6 @@ func generateEmbedding(cmd *cobra.Command, modelName, input string, keepAlive *a
return nil return nil
} }
// handleCloudAuthorizationError prints sign-in guidance when err is an
// unauthorized cloud API error and reports whether the error was handled.
// TODO(parthsareen): consolidate with TUI signin flow
func handleCloudAuthorizationError(err error) bool {
	var authErr api.AuthorizationError
	if !errors.As(err, &authErr) || authErr.StatusCode != http.StatusUnauthorized {
		return false
	}
	fmt.Printf("You need to be signed in to Ollama to run Cloud models.\n\n")
	if authErr.SigninURL != "" {
		fmt.Printf(ConnectInstructions, authErr.SigninURL)
	}
	return true
}
// ensureCloudStub best-effort pulls a cloud stub for models with an explicit
// cloud source, matching legacy `ollama run some-model:cloud` behavior.
// All failures are logged and swallowed — this never blocks the caller.
//
// TEMP(drifkin): remove this once `/api/tags` is cloud-aware.
func ensureCloudStub(ctx context.Context, client *api.Client, modelName string) {
	if !modelref.HasExplicitCloudSource(modelName) {
		return
	}

	normalized, _, err := modelref.NormalizePullName(modelName)
	if err != nil {
		slog.Warn("failed to normalize pull name", "model", modelName, "error", err, "normalizedName", normalized)
		return
	}

	tags, err := client.List(ctx)
	if err != nil {
		slog.Warn("failed to list models", "error", err)
		return
	}
	// Already present under either spelling: nothing to do.
	if hasListedModelName(tags.Models, modelName) || hasListedModelName(tags.Models, normalized) {
		return
	}

	logutil.Trace("pulling cloud stub", "model", modelName, "normalizedName", normalized)
	discardProgress := func(api.ProgressResponse) error { return nil }
	if err := client.Pull(ctx, &api.PullRequest{Model: normalized}, discardProgress); err != nil {
		slog.Warn("failed to pull cloud stub", "model", modelName, "error", err)
	}
}
// hasListedModelName reports whether any listed model matches name
// case-insensitively by either its Name or Model field.
func hasListedModelName(models []api.ListModelResponse, name string) bool {
	return slices.ContainsFunc(models, func(m api.ListModelResponse) bool {
		return strings.EqualFold(m.Name, name) || strings.EqualFold(m.Model, name)
	})
}
func RunHandler(cmd *cobra.Command, args []string) error { func RunHandler(cmd *cobra.Command, args []string) error {
interactive := true interactive := true
@@ -640,6 +585,17 @@ func RunHandler(cmd *cobra.Command, args []string) error {
} }
opts.WordWrap = !nowrap opts.WordWrap = !nowrap
useImagegen := false
if cmd.Flags().Lookup("imagegen") != nil {
useImagegen, err = cmd.Flags().GetBool("imagegen")
if err != nil {
return err
}
}
if useImagegen {
opts.Options["use_imagegen_runner"] = true
}
// Fill out the rest of the options based on information about the // Fill out the rest of the options based on information about the
// model. // model.
client, err := api.ClientFromEnvironment() client, err := api.ClientFromEnvironment()
@@ -648,16 +604,12 @@ func RunHandler(cmd *cobra.Command, args []string) error {
} }
name := args[0] name := args[0]
requestedCloud := modelref.HasExplicitCloudSource(name)
info, err := func() (*api.ShowResponse, error) { info, err := func() (*api.ShowResponse, error) {
showReq := &api.ShowRequest{Name: name} showReq := &api.ShowRequest{Name: name}
info, err := client.Show(cmd.Context(), showReq) info, err := client.Show(cmd.Context(), showReq)
var se api.StatusError var se api.StatusError
if errors.As(err, &se) && se.StatusCode == http.StatusNotFound { if errors.As(err, &se) && se.StatusCode == http.StatusNotFound {
if requestedCloud {
return nil, err
}
if err := PullHandler(cmd, []string{name}); err != nil { if err := PullHandler(cmd, []string{name}); err != nil {
return nil, err return nil, err
} }
@@ -666,21 +618,15 @@ func RunHandler(cmd *cobra.Command, args []string) error {
return info, err return info, err
}() }()
if err != nil { if err != nil {
if handleCloudAuthorizationError(err) {
return nil
}
return err return err
} }
ensureCloudStub(cmd.Context(), client, name)
opts.Think, err = inferThinkingOption(&info.Capabilities, &opts, thinkFlag.Changed) opts.Think, err = inferThinkingOption(&info.Capabilities, &opts, thinkFlag.Changed)
if err != nil { if err != nil {
return err return err
} }
audioCapable := slices.Contains(info.Capabilities, model.CapabilityAudio) opts.MultiModal = slices.Contains(info.Capabilities, model.CapabilityVision)
opts.MultiModal = slices.Contains(info.Capabilities, model.CapabilityVision) || audioCapable
// TODO: remove the projector info and vision info checks below, // TODO: remove the projector info and vision info checks below,
// these are left in for backwards compatibility with older servers // these are left in for backwards compatibility with older servers
@@ -695,7 +641,7 @@ func RunHandler(cmd *cobra.Command, args []string) error {
} }
} }
applyShowResponseToRunOptions(&opts, info) opts.ParentModel = info.Details.ParentModel
// Check if this is an embedding model // Check if this is an embedding model
isEmbeddingModel := slices.Contains(info.Capabilities, model.CapabilityEmbedding) isEmbeddingModel := slices.Contains(info.Capabilities, model.CapabilityEmbedding)
@@ -766,13 +712,7 @@ func RunHandler(cmd *cobra.Command, args []string) error {
return generateInteractive(cmd, opts) return generateInteractive(cmd, opts)
} }
if err := generate(cmd, opts); err != nil { return generate(cmd, opts)
if handleCloudAuthorizationError(err) {
return nil
}
return err
}
return nil
} }
func SigninHandler(cmd *cobra.Command, args []string) error { func SigninHandler(cmd *cobra.Command, args []string) error {
@@ -1411,30 +1351,23 @@ func PullHandler(cmd *cobra.Command, args []string) error {
type generateContextKey string type generateContextKey string
type runOptions struct { type runOptions struct {
Model string Model string
ParentModel string ParentModel string
LoadedMessages []api.Message Prompt string
Prompt string Messages []api.Message
Messages []api.Message WordWrap bool
WordWrap bool Format string
Format string System string
System string Images []api.ImageData
Images []api.ImageData Options map[string]any
Options map[string]any MultiModal bool
MultiModal bool KeepAlive *api.Duration
KeepAlive *api.Duration Think *api.ThinkValue
Think *api.ThinkValue HideThinking bool
HideThinking bool ShowConnect bool
ShowConnect bool
} }
func (r runOptions) Copy() runOptions { func (r runOptions) Copy() runOptions {
var loadedMessages []api.Message
if r.LoadedMessages != nil {
loadedMessages = make([]api.Message, len(r.LoadedMessages))
copy(loadedMessages, r.LoadedMessages)
}
var messages []api.Message var messages []api.Message
if r.Messages != nil { if r.Messages != nil {
messages = make([]api.Message, len(r.Messages)) messages = make([]api.Message, len(r.Messages))
@@ -1462,29 +1395,23 @@ func (r runOptions) Copy() runOptions {
} }
return runOptions{ return runOptions{
Model: r.Model, Model: r.Model,
ParentModel: r.ParentModel, ParentModel: r.ParentModel,
LoadedMessages: loadedMessages, Prompt: r.Prompt,
Prompt: r.Prompt, Messages: messages,
Messages: messages, WordWrap: r.WordWrap,
WordWrap: r.WordWrap, Format: r.Format,
Format: r.Format, System: r.System,
System: r.System, Images: images,
Images: images, Options: opts,
Options: opts, MultiModal: r.MultiModal,
MultiModal: r.MultiModal, KeepAlive: r.KeepAlive,
KeepAlive: r.KeepAlive, Think: think,
Think: think, HideThinking: r.HideThinking,
HideThinking: r.HideThinking, ShowConnect: r.ShowConnect,
ShowConnect: r.ShowConnect,
} }
} }
func applyShowResponseToRunOptions(opts *runOptions, info *api.ShowResponse) {
opts.ParentModel = info.Details.ParentModel
opts.LoadedMessages = slices.Clone(info.Messages)
}
type displayResponseState struct { type displayResponseState struct {
lineLength int lineLength int
wordBuffer string wordBuffer string
@@ -1492,9 +1419,6 @@ type displayResponseState struct {
func displayResponse(content string, wordWrap bool, state *displayResponseState) { func displayResponse(content string, wordWrap bool, state *displayResponseState) {
termWidth, _, _ := term.GetSize(int(os.Stdout.Fd())) termWidth, _, _ := term.GetSize(int(os.Stdout.Fd()))
if termWidth == 0 {
termWidth = 80
}
if wordWrap && termWidth >= 10 { if wordWrap && termWidth >= 10 {
for _, ch := range content { for _, ch := range content {
if state.lineLength+1 > termWidth-5 { if state.lineLength+1 > termWidth-5 {
@@ -1968,77 +1892,6 @@ func ensureServerRunning(ctx context.Context) error {
} }
} }
// launchInteractiveModel resolves modelName (pulling it when missing), loads
// it, and starts an interactive chat session. It is the chat entry point used
// by the interactive TUI launcher.
func launchInteractiveModel(cmd *cobra.Command, modelName string) error {
	opts := runOptions{
		Model:       modelName,
		WordWrap:    os.Getenv("TERM") == "xterm-256color",
		Options:     map[string]any{},
		ShowConnect: true,
	}

	client, err := api.ClientFromEnvironment()
	if err != nil {
		return err
	}

	requestedCloud := modelref.HasExplicitCloudSource(modelName)

	// Show the model; on 404, pull it and re-show — except for explicit
	// cloud sources, which are never auto-pulled here.
	info, err := func() (*api.ShowResponse, error) {
		showReq := &api.ShowRequest{Name: modelName}
		info, err := client.Show(cmd.Context(), showReq)
		var se api.StatusError
		if errors.As(err, &se) && se.StatusCode == http.StatusNotFound {
			if requestedCloud {
				return nil, err
			}
			if err := PullHandler(cmd, []string{modelName}); err != nil {
				return nil, err
			}
			return client.Show(cmd.Context(), &api.ShowRequest{Name: modelName})
		}
		return info, err
	}()
	if err != nil {
		// Cloud auth failures print sign-in guidance instead of the raw error.
		if handleCloudAuthorizationError(err) {
			return nil
		}
		return err
	}

	ensureCloudStub(cmd.Context(), client, modelName)

	opts.Think, err = inferThinkingOption(&info.Capabilities, &opts, false)
	if err != nil {
		return err
	}

	// Vision or audio capability both enable the multimodal input path.
	audioCapable := slices.Contains(info.Capabilities, model.CapabilityAudio)
	opts.MultiModal = slices.Contains(info.Capabilities, model.CapabilityVision) || audioCapable

	// TODO: remove the projector info and vision info checks below,
	// these are left in for backwards compatibility with older servers
	// that don't have the capabilities field in the model info
	if len(info.ProjectorInfo) != 0 {
		opts.MultiModal = true
	}
	for k := range info.ModelInfo {
		if strings.Contains(k, ".vision.") {
			opts.MultiModal = true
			break
		}
	}

	applyShowResponseToRunOptions(&opts, info)

	if err := loadOrUnloadModel(cmd, &opts); err != nil {
		return fmt.Errorf("error loading model: %w", err)
	}

	if err := generateInteractive(cmd, opts); err != nil {
		return fmt.Errorf("error running model: %w", err)
	}
	return nil
}
// runInteractiveTUI runs the main interactive TUI menu. // runInteractiveTUI runs the main interactive TUI menu.
func runInteractiveTUI(cmd *cobra.Command) { func runInteractiveTUI(cmd *cobra.Command) {
// Ensure the server is running before showing the TUI // Ensure the server is running before showing the TUI
@@ -2047,85 +1900,175 @@ func runInteractiveTUI(cmd *cobra.Command) {
return return
} }
deps := launcherDeps{ // Selector adapters for tui
buildState: launch.BuildLauncherState, singleSelector := func(title string, items []config.ModelItem, current string) (string, error) {
runMenu: tui.RunMenu, tuiItems := tui.ReorderItems(tui.ConvertItems(items))
resolveRunModel: launch.ResolveRunModel, result, err := tui.SelectSingle(title, tuiItems, current)
launchIntegration: launch.LaunchIntegration, if errors.Is(err, tui.ErrCancelled) {
runModel: launchInteractiveModel, return "", config.ErrCancelled
}
return result, err
}
multiSelector := func(title string, items []config.ModelItem, preChecked []string) ([]string, error) {
tuiItems := tui.ReorderItems(tui.ConvertItems(items))
result, err := tui.SelectMultiple(title, tuiItems, preChecked)
if errors.Is(err, tui.ErrCancelled) {
return nil, config.ErrCancelled
}
return result, err
} }
for { for {
continueLoop, err := runInteractiveTUIStep(cmd, deps) result, err := tui.Run()
if err != nil { if err != nil {
fmt.Fprintf(os.Stderr, "Error: %v\n", err) fmt.Fprintf(os.Stderr, "Error: %v\n", err)
}
if !continueLoop {
return return
} }
}
}
type launcherDeps struct { runModel := func(modelName string) {
buildState func(context.Context) (*launch.LauncherState, error) client, err := api.ClientFromEnvironment()
runMenu func(*launch.LauncherState) (tui.TUIAction, error) if err != nil {
resolveRunModel func(context.Context, launch.RunModelRequest) (string, error) fmt.Fprintf(os.Stderr, "Error: %v\n", err)
launchIntegration func(context.Context, launch.IntegrationLaunchRequest) error return
runModel func(*cobra.Command, string) error }
} if err := config.ShowOrPull(cmd.Context(), client, modelName); err != nil {
if errors.Is(err, config.ErrCancelled) {
return
}
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
return
}
_ = config.SetLastModel(modelName)
opts := runOptions{
Model: modelName,
WordWrap: os.Getenv("TERM") == "xterm-256color",
Options: map[string]any{},
ShowConnect: true,
}
if err := loadOrUnloadModel(cmd, &opts); err != nil {
fmt.Fprintf(os.Stderr, "Error loading model: %v\n", err)
return
}
if err := generateInteractive(cmd, opts); err != nil {
fmt.Fprintf(os.Stderr, "Error running model: %v\n", err)
}
}
func runInteractiveTUIStep(cmd *cobra.Command, deps launcherDeps) (bool, error) { launchIntegration := func(name string) bool {
state, err := deps.buildState(cmd.Context()) if err := config.EnsureInstalled(name); err != nil {
if err != nil { fmt.Fprintf(os.Stderr, "Error: %v\n", err)
return false, fmt.Errorf("build launcher state: %w", err) return true
} }
// If not configured or model no longer exists, prompt for model selection
configuredModel := config.IntegrationModel(name)
if configuredModel == "" || !config.ModelExists(cmd.Context(), configuredModel) || config.IsCloudModelDisabled(cmd.Context(), configuredModel) {
err := config.ConfigureIntegrationWithSelectors(cmd.Context(), name, singleSelector, multiSelector)
if errors.Is(err, config.ErrCancelled) {
return false // Return to main menu
}
if err != nil {
fmt.Fprintf(os.Stderr, "Error configuring %s: %v\n", name, err)
return true
}
}
if err := config.LaunchIntegration(name); err != nil {
fmt.Fprintf(os.Stderr, "Error launching %s: %v\n", name, err)
}
return true
}
action, err := deps.runMenu(state) switch result.Selection {
if err != nil { case tui.SelectionNone:
return false, fmt.Errorf("run launcher menu: %w", err) // User quit
} return
case tui.SelectionRunModel:
return runLauncherAction(cmd, action, deps) _ = config.SetLastSelection("run")
} if modelName := config.LastModel(); modelName != "" && !config.IsCloudModelDisabled(cmd.Context(), modelName) {
runModel(modelName)
func saveLauncherSelection(action tui.TUIAction) { } else {
// Best effort only: this affects menu recall, not launch correctness. modelName, err := config.SelectModelWithSelector(cmd.Context(), singleSelector)
_ = config.SetLastSelection(action.LastSelection()) if errors.Is(err, config.ErrCancelled) {
} continue // Return to main menu
}
func runLauncherAction(cmd *cobra.Command, action tui.TUIAction, deps launcherDeps) (bool, error) { if err != nil {
switch action.Kind { fmt.Fprintf(os.Stderr, "Error selecting model: %v\n", err)
case tui.TUIActionNone: continue
return false, nil }
case tui.TUIActionRunModel: runModel(modelName)
saveLauncherSelection(action) }
modelName, err := deps.resolveRunModel(cmd.Context(), action.RunModelRequest()) case tui.SelectionChangeRunModel:
if errors.Is(err, launch.ErrCancelled) { _ = config.SetLastSelection("run")
return true, nil // Use model from modal if selected, otherwise show picker
modelName := result.Model
if modelName == "" {
var err error
modelName, err = config.SelectModelWithSelector(cmd.Context(), singleSelector)
if errors.Is(err, config.ErrCancelled) {
continue // Return to main menu
}
if err != nil {
fmt.Fprintf(os.Stderr, "Error selecting model: %v\n", err)
continue
}
}
if config.IsCloudModelDisabled(cmd.Context(), modelName) {
continue // Return to main menu
}
runModel(modelName)
case tui.SelectionIntegration:
_ = config.SetLastSelection(result.Integration)
if !launchIntegration(result.Integration) {
continue // Return to main menu
}
case tui.SelectionChangeIntegration:
_ = config.SetLastSelection(result.Integration)
if len(result.Models) > 0 {
// Filter out cloud-disabled models
var filtered []string
for _, m := range result.Models {
if !config.IsCloudModelDisabled(cmd.Context(), m) {
filtered = append(filtered, m)
}
}
if len(filtered) == 0 {
continue
}
result.Models = filtered
// Multi-select from modal (Editor integrations)
if err := config.SaveAndEditIntegration(result.Integration, result.Models); err != nil {
fmt.Fprintf(os.Stderr, "Error configuring %s: %v\n", result.Integration, err)
continue
}
if err := config.LaunchIntegrationWithModel(result.Integration, result.Models[0]); err != nil {
fmt.Fprintf(os.Stderr, "Error launching %s: %v\n", result.Integration, err)
}
} else if result.Model != "" {
if config.IsCloudModelDisabled(cmd.Context(), result.Model) {
continue
}
// Single-select from modal - save and launch
if err := config.SaveIntegration(result.Integration, []string{result.Model}); err != nil {
fmt.Fprintf(os.Stderr, "Error saving config: %v\n", err)
continue
}
if err := config.LaunchIntegrationWithModel(result.Integration, result.Model); err != nil {
fmt.Fprintf(os.Stderr, "Error launching %s: %v\n", result.Integration, err)
}
} else {
err := config.ConfigureIntegrationWithSelectors(cmd.Context(), result.Integration, singleSelector, multiSelector)
if errors.Is(err, config.ErrCancelled) {
continue // Return to main menu
}
if err != nil {
fmt.Fprintf(os.Stderr, "Error configuring %s: %v\n", result.Integration, err)
continue
}
if err := config.LaunchIntegration(result.Integration); err != nil {
fmt.Fprintf(os.Stderr, "Error launching %s: %v\n", result.Integration, err)
}
}
} }
if err != nil {
return true, fmt.Errorf("selecting model: %w", err)
}
if err := deps.runModel(cmd, modelName); err != nil {
return true, err
}
return true, nil
case tui.TUIActionLaunchIntegration:
saveLauncherSelection(action)
err := deps.launchIntegration(cmd.Context(), action.IntegrationLaunchRequest())
if errors.Is(err, launch.ErrCancelled) {
return true, nil
}
if err != nil {
return true, fmt.Errorf("launching %s: %w", action.Integration, err)
}
// VS Code is a GUI app — exit the TUI loop after launching
if action.Integration == "vscode" {
return false, nil
}
return true, nil
default:
return false, fmt.Errorf("unknown launcher action: %d", action.Kind)
} }
} }
@@ -2395,7 +2338,7 @@ func NewCLI() *cobra.Command {
copyCmd, copyCmd,
deleteCmd, deleteCmd,
runnerCmd, runnerCmd,
launch.LaunchCmd(checkServerHeartbeat, runInteractiveTUI), config.LaunchCmd(checkServerHeartbeat, runInteractiveTUI),
) )
return rootCmd return rootCmd

View File

@@ -1,270 +0,0 @@
package cmd
import (
"context"
"testing"
"github.com/spf13/cobra"
"github.com/ollama/ollama/cmd/config"
"github.com/ollama/ollama/cmd/launch"
"github.com/ollama/ollama/cmd/tui"
)
func setCmdTestHome(t *testing.T, dir string) {
t.Helper()
t.Setenv("HOME", dir)
t.Setenv("USERPROFILE", dir)
}
// unexpectedRunModelResolution returns a resolver stub that fails the test if
// any run-model resolution is attempted.
func unexpectedRunModelResolution(t *testing.T) func(context.Context, launch.RunModelRequest) (string, error) {
	t.Helper()
	return func(_ context.Context, req launch.RunModelRequest) (string, error) {
		t.Fatalf("did not expect run-model resolution: %+v", req)
		return "", nil
	}
}
// unexpectedIntegrationLaunch returns an integration-launch stub that fails
// the test if it is ever invoked.
func unexpectedIntegrationLaunch(t *testing.T) func(context.Context, launch.IntegrationLaunchRequest) error {
	t.Helper()
	return func(_ context.Context, req launch.IntegrationLaunchRequest) error {
		t.Fatalf("did not expect integration launch: %+v", req)
		return nil
	}
}
// unexpectedModelLaunch returns a chat-launch stub that fails the test if a
// model run is ever attempted.
func unexpectedModelLaunch(t *testing.T) func(*cobra.Command, string) error {
	t.Helper()
	return func(_ *cobra.Command, model string) error {
		t.Fatalf("did not expect chat launch: %s", model)
		return nil
	}
}
// TestRunInteractiveTUI_RunModelActionsUseResolveRunModel verifies that
// run-model actions from the launcher menu are routed through the
// resolveRunModel dependency, that ForceConfigure maps to ForcePicker, that
// the resolved model is actually launched, and that the "run" selection is
// persisted.
func TestRunInteractiveTUI_RunModelActionsUseResolveRunModel(t *testing.T) {
	tests := []struct {
		name      string
		action    tui.TUIAction
		wantForce bool
		wantModel string
	}{
		{
			name:      "enter uses saved model flow",
			action:    tui.TUIAction{Kind: tui.TUIActionRunModel},
			wantModel: "qwen3:8b",
		},
		{
			name:      "right forces picker",
			action:    tui.TUIAction{Kind: tui.TUIActionRunModel, ForceConfigure: true},
			wantForce: true,
			wantModel: "glm-5:cloud",
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			setCmdTestHome(t, t.TempDir())
			// The menu returns the action under test once, then None so the
			// loop terminates.
			var menuCalls int
			runMenu := func(state *launch.LauncherState) (tui.TUIAction, error) {
				menuCalls++
				if menuCalls == 1 {
					return tt.action, nil
				}
				return tui.TUIAction{Kind: tui.TUIActionNone}, nil
			}
			var gotReq launch.RunModelRequest
			var launched string
			deps := launcherDeps{
				buildState: func(ctx context.Context) (*launch.LauncherState, error) {
					return &launch.LauncherState{}, nil
				},
				runMenu: runMenu,
				resolveRunModel: func(ctx context.Context, req launch.RunModelRequest) (string, error) {
					gotReq = req
					return tt.wantModel, nil
				},
				launchIntegration: unexpectedIntegrationLaunch(t),
				runModel: func(cmd *cobra.Command, model string) error {
					launched = model
					return nil
				},
			}
			cmd := &cobra.Command{}
			cmd.SetContext(context.Background())
			// Drive the TUI step loop until it signals completion.
			for {
				continueLoop, err := runInteractiveTUIStep(cmd, deps)
				if err != nil {
					t.Fatalf("unexpected step error: %v", err)
				}
				if !continueLoop {
					break
				}
			}
			if gotReq.ForcePicker != tt.wantForce {
				t.Fatalf("expected ForcePicker=%v, got %v", tt.wantForce, gotReq.ForcePicker)
			}
			if launched != tt.wantModel {
				t.Fatalf("expected interactive launcher to run %q, got %q", tt.wantModel, launched)
			}
			if got := config.LastSelection(); got != "run" {
				t.Fatalf("expected last selection to be run, got %q", got)
			}
		})
	}
}
// TestRunInteractiveTUI_IntegrationActionsUseLaunchIntegration verifies that
// integration actions from the launcher menu are routed through the
// launchIntegration dependency with the integration name and ForceConfigure
// flag preserved, and that the selection is persisted.
func TestRunInteractiveTUI_IntegrationActionsUseLaunchIntegration(t *testing.T) {
	tests := []struct {
		name      string
		action    tui.TUIAction
		wantForce bool
	}{
		{
			name:   "enter launches integration",
			action: tui.TUIAction{Kind: tui.TUIActionLaunchIntegration, Integration: "claude"},
		},
		{
			name:      "right forces configure",
			action:    tui.TUIAction{Kind: tui.TUIActionLaunchIntegration, Integration: "claude", ForceConfigure: true},
			wantForce: true,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			setCmdTestHome(t, t.TempDir())
			// The menu returns the action under test once, then None so the
			// loop terminates.
			var menuCalls int
			runMenu := func(state *launch.LauncherState) (tui.TUIAction, error) {
				menuCalls++
				if menuCalls == 1 {
					return tt.action, nil
				}
				return tui.TUIAction{Kind: tui.TUIActionNone}, nil
			}
			var gotReq launch.IntegrationLaunchRequest
			deps := launcherDeps{
				buildState: func(ctx context.Context) (*launch.LauncherState, error) {
					return &launch.LauncherState{}, nil
				},
				runMenu:         runMenu,
				resolveRunModel: unexpectedRunModelResolution(t),
				launchIntegration: func(ctx context.Context, req launch.IntegrationLaunchRequest) error {
					gotReq = req
					return nil
				},
				runModel: unexpectedModelLaunch(t),
			}
			cmd := &cobra.Command{}
			cmd.SetContext(context.Background())
			// Drive the TUI step loop until it signals completion.
			for {
				continueLoop, err := runInteractiveTUIStep(cmd, deps)
				if err != nil {
					t.Fatalf("unexpected step error: %v", err)
				}
				if !continueLoop {
					break
				}
			}
			if gotReq.Name != "claude" {
				t.Fatalf("expected integration name to be passed through, got %q", gotReq.Name)
			}
			if gotReq.ForceConfigure != tt.wantForce {
				t.Fatalf("expected ForceConfigure=%v, got %v", tt.wantForce, gotReq.ForceConfigure)
			}
			if got := config.LastSelection(); got != "claude" {
				t.Fatalf("expected last selection to be claude, got %q", got)
			}
		})
	}
}
// TestRunLauncherAction_RunModelContinuesAfterCancellation verifies that a
// cancelled model selection is not surfaced as an error and that the menu
// loop continues.
func TestRunLauncherAction_RunModelContinuesAfterCancellation(t *testing.T) {
	setCmdTestHome(t, t.TempDir())
	cmd := &cobra.Command{}
	cmd.SetContext(context.Background())
	// buildState and runMenu are left nil: this action path would panic if
	// runLauncherAction consulted them.
	continueLoop, err := runLauncherAction(cmd, tui.TUIAction{Kind: tui.TUIActionRunModel}, launcherDeps{
		buildState: nil,
		runMenu:    nil,
		resolveRunModel: func(ctx context.Context, req launch.RunModelRequest) (string, error) {
			return "", launch.ErrCancelled
		},
		launchIntegration: unexpectedIntegrationLaunch(t),
		runModel:          unexpectedModelLaunch(t),
	})
	if err != nil {
		t.Fatalf("expected nil error on cancellation, got %v", err)
	}
	if !continueLoop {
		t.Fatal("expected cancellation to continue the menu loop")
	}
}
// TestRunLauncherAction_VSCodeExitsTUILoop verifies the special-case exit
// behavior for the "vscode" integration versus other integrations.
func TestRunLauncherAction_VSCodeExitsTUILoop(t *testing.T) {
	setCmdTestHome(t, t.TempDir())
	cmd := &cobra.Command{}
	cmd.SetContext(context.Background())
	// VS Code should exit the TUI loop (return false) after a successful launch.
	continueLoop, err := runLauncherAction(cmd, tui.TUIAction{Kind: tui.TUIActionLaunchIntegration, Integration: "vscode"}, launcherDeps{
		resolveRunModel: unexpectedRunModelResolution(t),
		launchIntegration: func(ctx context.Context, req launch.IntegrationLaunchRequest) error {
			return nil
		},
		runModel: unexpectedModelLaunch(t),
	})
	if err != nil {
		t.Fatalf("expected nil error, got %v", err)
	}
	if continueLoop {
		t.Fatal("expected vscode launch to exit the TUI loop (return false)")
	}
	// Other integrations should continue the TUI loop (return true).
	continueLoop, err = runLauncherAction(cmd, tui.TUIAction{Kind: tui.TUIActionLaunchIntegration, Integration: "claude"}, launcherDeps{
		resolveRunModel: unexpectedRunModelResolution(t),
		launchIntegration: func(ctx context.Context, req launch.IntegrationLaunchRequest) error {
			return nil
		},
		runModel: unexpectedModelLaunch(t),
	})
	if err != nil {
		t.Fatalf("expected nil error, got %v", err)
	}
	if !continueLoop {
		t.Fatal("expected non-vscode integration to continue the TUI loop (return true)")
	}
}
// TestRunLauncherAction_IntegrationContinuesAfterCancellation verifies that a
// cancelled integration launch is not surfaced as an error and that the menu
// loop continues.
func TestRunLauncherAction_IntegrationContinuesAfterCancellation(t *testing.T) {
	setCmdTestHome(t, t.TempDir())
	cmd := &cobra.Command{}
	cmd.SetContext(context.Background())
	// buildState and runMenu are left nil: this action path would panic if
	// runLauncherAction consulted them.
	continueLoop, err := runLauncherAction(cmd, tui.TUIAction{Kind: tui.TUIActionLaunchIntegration, Integration: "claude"}, launcherDeps{
		buildState: nil,
		runMenu:    nil,
		resolveRunModel: unexpectedRunModelResolution(t),
		launchIntegration: func(ctx context.Context, req launch.IntegrationLaunchRequest) error {
			return launch.ErrCancelled
		},
		runModel: unexpectedModelLaunch(t),
	})
	if err != nil {
		t.Fatalf("expected nil error on cancellation, got %v", err)
	}
	if !continueLoop {
		t.Fatal("expected cancellation to continue the menu loop")
	}
}

View File

@@ -301,7 +301,7 @@ Weigh anchor!
ParameterSize: "7B", ParameterSize: "7B",
QuantizationLevel: "FP16", QuantizationLevel: "FP16",
}, },
Requires: "0.19.0", Requires: "0.14.0",
}, false, &b); err != nil { }, false, &b); err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -310,17 +310,10 @@ Weigh anchor!
architecture test architecture test
parameters 7B parameters 7B
quantization FP16 quantization FP16
requires 0.19.0 requires 0.14.0
` `
trimLinePadding := func(s string) string { if diff := cmp.Diff(expect, b.String()); diff != "" {
lines := strings.Split(s, "\n")
for i, line := range lines {
lines[i] = strings.TrimRight(line, " \t\r")
}
return strings.Join(lines, "\n")
}
if diff := cmp.Diff(trimLinePadding(expect), trimLinePadding(b.String())); diff != "" {
t.Errorf("unexpected output (-want +got):\n%s", diff) t.Errorf("unexpected output (-want +got):\n%s", diff)
} }
}) })
@@ -712,347 +705,6 @@ func TestRunEmbeddingModelNoInput(t *testing.T) {
} }
} }
// TestRunHandler_CloudAuthErrorOnShow_PrintsSigninMessage verifies that an
// unauthorized /api/show response makes RunHandler print sign-in guidance
// (including the signin URL), return nil, and stop before /api/generate.
func TestRunHandler_CloudAuthErrorOnShow_PrintsSigninMessage(t *testing.T) {
	var generateCalled bool
	// Mock server: /api/show returns 401 with a signin_url; /api/generate
	// records that it was (unexpectedly) reached.
	mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		switch {
		case r.URL.Path == "/api/show" && r.Method == http.MethodPost:
			w.WriteHeader(http.StatusUnauthorized)
			if err := json.NewEncoder(w).Encode(map[string]string{
				"error":      "unauthorized",
				"signin_url": "https://ollama.com/signin",
			}); err != nil {
				http.Error(w, err.Error(), http.StatusInternalServerError)
			}
			return
		case r.URL.Path == "/api/generate" && r.Method == http.MethodPost:
			generateCalled = true
			w.WriteHeader(http.StatusOK)
			if err := json.NewEncoder(w).Encode(api.GenerateResponse{Done: true}); err != nil {
				http.Error(w, err.Error(), http.StatusInternalServerError)
			}
			return
		default:
			http.NotFound(w, r)
		}
	}))
	t.Setenv("OLLAMA_HOST", mockServer.URL)
	t.Cleanup(mockServer.Close)

	// Register the flags RunHandler reads.
	cmd := &cobra.Command{}
	cmd.SetContext(t.Context())
	cmd.Flags().String("keepalive", "", "")
	cmd.Flags().Bool("truncate", false, "")
	cmd.Flags().Int("dimensions", 0, "")
	cmd.Flags().Bool("verbose", false, "")
	cmd.Flags().Bool("insecure", false, "")
	cmd.Flags().Bool("nowordwrap", false, "")
	cmd.Flags().String("format", "", "")
	cmd.Flags().String("think", "", "")
	cmd.Flags().Bool("hidethinking", false, "")

	// Capture stdout so the printed guidance can be asserted on.
	oldStdout := os.Stdout
	readOut, writeOut, _ := os.Pipe()
	os.Stdout = writeOut
	t.Cleanup(func() { os.Stdout = oldStdout })

	err := RunHandler(cmd, []string{"gpt-oss:20b:cloud", "hi"})

	_ = writeOut.Close()
	var out bytes.Buffer
	_, _ = io.Copy(&out, readOut)

	if err != nil {
		t.Fatalf("RunHandler returned error: %v", err)
	}
	if generateCalled {
		t.Fatal("expected run to stop before /api/generate after unauthorized /api/show")
	}
	if !strings.Contains(out.String(), "You need to be signed in to Ollama to run Cloud models.") {
		t.Fatalf("expected sign-in guidance message, got %q", out.String())
	}
	if !strings.Contains(out.String(), "https://ollama.com/signin") {
		t.Fatalf("expected signin_url in output, got %q", out.String())
	}
}
// TestRunHandler_CloudAuthErrorOnGenerate_PrintsSigninMessage verifies that an
// unauthorized /api/generate response (after a successful /api/show) makes
// RunHandler print sign-in guidance with the signin URL and return nil.
func TestRunHandler_CloudAuthErrorOnGenerate_PrintsSigninMessage(t *testing.T) {
	// Mock server: /api/show succeeds; /api/generate returns 401 with a
	// signin_url.
	mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		switch {
		case r.URL.Path == "/api/show" && r.Method == http.MethodPost:
			w.WriteHeader(http.StatusOK)
			if err := json.NewEncoder(w).Encode(api.ShowResponse{
				Capabilities: []model.Capability{model.CapabilityCompletion},
			}); err != nil {
				http.Error(w, err.Error(), http.StatusInternalServerError)
			}
			return
		case r.URL.Path == "/api/generate" && r.Method == http.MethodPost:
			w.WriteHeader(http.StatusUnauthorized)
			if err := json.NewEncoder(w).Encode(map[string]string{
				"error":      "unauthorized",
				"signin_url": "https://ollama.com/signin",
			}); err != nil {
				http.Error(w, err.Error(), http.StatusInternalServerError)
			}
			return
		default:
			http.NotFound(w, r)
		}
	}))
	t.Setenv("OLLAMA_HOST", mockServer.URL)
	t.Cleanup(mockServer.Close)

	// Register the flags RunHandler reads.
	cmd := &cobra.Command{}
	cmd.SetContext(t.Context())
	cmd.Flags().String("keepalive", "", "")
	cmd.Flags().Bool("truncate", false, "")
	cmd.Flags().Int("dimensions", 0, "")
	cmd.Flags().Bool("verbose", false, "")
	cmd.Flags().Bool("insecure", false, "")
	cmd.Flags().Bool("nowordwrap", false, "")
	cmd.Flags().String("format", "", "")
	cmd.Flags().String("think", "", "")
	cmd.Flags().Bool("hidethinking", false, "")

	// Capture stdout so the printed guidance can be asserted on.
	oldStdout := os.Stdout
	readOut, writeOut, _ := os.Pipe()
	os.Stdout = writeOut
	t.Cleanup(func() { os.Stdout = oldStdout })

	err := RunHandler(cmd, []string{"gpt-oss:20b:cloud", "hi"})

	_ = writeOut.Close()
	var out bytes.Buffer
	_, _ = io.Copy(&out, readOut)

	if err != nil {
		t.Fatalf("RunHandler returned error: %v", err)
	}
	if !strings.Contains(out.String(), "You need to be signed in to Ollama to run Cloud models.") {
		t.Fatalf("expected sign-in guidance message, got %q", out.String())
	}
	if !strings.Contains(out.String(), "https://ollama.com/signin") {
		t.Fatalf("expected signin_url in output, got %q", out.String())
	}
}
func TestRunHandler_ExplicitCloudStubMissing_PullsNormalizedNameTEMP(t *testing.T) {
	var pulledModel string
	var generateCalled bool
	mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// ok writes v as a 200 JSON response.
		ok := func(v any) {
			w.WriteHeader(http.StatusOK)
			if err := json.NewEncoder(w).Encode(v); err != nil {
				http.Error(w, err.Error(), http.StatusInternalServerError)
			}
		}
		switch {
		case r.Method == http.MethodPost && r.URL.Path == "/api/show":
			ok(api.ShowResponse{
				Capabilities: []model.Capability{model.CapabilityCompletion},
				RemoteModel:  "gpt-oss:20b",
			})
		case r.Method == http.MethodGet && r.URL.Path == "/api/tags":
			// No local models: the cloud stub is missing and must be pulled.
			ok(api.ListResponse{Models: nil})
		case r.Method == http.MethodPost && r.URL.Path == "/api/pull":
			var req api.PullRequest
			if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
				http.Error(w, err.Error(), http.StatusBadRequest)
				return
			}
			pulledModel = req.Model
			ok(api.ProgressResponse{Status: "success"})
		case r.Method == http.MethodPost && r.URL.Path == "/api/generate":
			generateCalled = true
			ok(api.GenerateResponse{Done: true})
		default:
			http.NotFound(w, r)
		}
	}))
	t.Setenv("OLLAMA_HOST", mockServer.URL)
	t.Cleanup(mockServer.Close)
	cmd := &cobra.Command{}
	cmd.SetContext(t.Context())
	flags := cmd.Flags()
	for _, name := range []string{"keepalive", "format", "think"} {
		flags.String(name, "", "")
	}
	flags.Int("dimensions", 0, "")
	for _, name := range []string{"truncate", "verbose", "insecure", "nowordwrap", "hidethinking"} {
		flags.Bool(name, false, "")
	}
	if err := RunHandler(cmd, []string{"gpt-oss:20b:cloud", "hi"}); err != nil {
		t.Fatalf("RunHandler returned error: %v", err)
	}
	if pulledModel != "gpt-oss:20b-cloud" {
		t.Fatalf("expected normalized pull model %q, got %q", "gpt-oss:20b-cloud", pulledModel)
	}
	if !generateCalled {
		t.Fatal("expected /api/generate to be called")
	}
}
func TestRunHandler_ExplicitCloudStubPresent_SkipsPullTEMP(t *testing.T) {
	var pullCalled bool
	var generateCalled bool
	mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// ok writes v as a 200 JSON response.
		ok := func(v any) {
			w.WriteHeader(http.StatusOK)
			if err := json.NewEncoder(w).Encode(v); err != nil {
				http.Error(w, err.Error(), http.StatusInternalServerError)
			}
		}
		switch {
		case r.Method == http.MethodPost && r.URL.Path == "/api/show":
			ok(api.ShowResponse{
				Capabilities: []model.Capability{model.CapabilityCompletion},
				RemoteModel:  "gpt-oss:20b",
			})
		case r.Method == http.MethodGet && r.URL.Path == "/api/tags":
			// The cloud stub already exists locally, so no pull should happen.
			ok(api.ListResponse{
				Models: []api.ListModelResponse{{Name: "gpt-oss:20b-cloud"}},
			})
		case r.Method == http.MethodPost && r.URL.Path == "/api/pull":
			pullCalled = true
			ok(api.ProgressResponse{Status: "success"})
		case r.Method == http.MethodPost && r.URL.Path == "/api/generate":
			generateCalled = true
			ok(api.GenerateResponse{Done: true})
		default:
			http.NotFound(w, r)
		}
	}))
	t.Setenv("OLLAMA_HOST", mockServer.URL)
	t.Cleanup(mockServer.Close)
	cmd := &cobra.Command{}
	cmd.SetContext(t.Context())
	flags := cmd.Flags()
	for _, name := range []string{"keepalive", "format", "think"} {
		flags.String(name, "", "")
	}
	flags.Int("dimensions", 0, "")
	for _, name := range []string{"truncate", "verbose", "insecure", "nowordwrap", "hidethinking"} {
		flags.Bool(name, false, "")
	}
	if err := RunHandler(cmd, []string{"gpt-oss:20b:cloud", "hi"}); err != nil {
		t.Fatalf("RunHandler returned error: %v", err)
	}
	if pullCalled {
		t.Fatal("expected /api/pull not to be called when cloud stub already exists")
	}
	if !generateCalled {
		t.Fatal("expected /api/generate to be called")
	}
}
func TestRunHandler_ExplicitCloudStubPullFailure_IsBestEffortTEMP(t *testing.T) {
	var generateCalled bool
	mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// ok writes v as a 200 JSON response.
		ok := func(v any) {
			w.WriteHeader(http.StatusOK)
			if err := json.NewEncoder(w).Encode(v); err != nil {
				http.Error(w, err.Error(), http.StatusInternalServerError)
			}
		}
		switch {
		case r.Method == http.MethodPost && r.URL.Path == "/api/show":
			ok(api.ShowResponse{
				Capabilities: []model.Capability{model.CapabilityCompletion},
				RemoteModel:  "gpt-oss:20b",
			})
		case r.Method == http.MethodGet && r.URL.Path == "/api/tags":
			ok(api.ListResponse{Models: nil})
		case r.Method == http.MethodPost && r.URL.Path == "/api/pull":
			// Pull fails; the handler should continue to generation anyway.
			w.WriteHeader(http.StatusInternalServerError)
			if err := json.NewEncoder(w).Encode(map[string]string{"error": "pull failed"}); err != nil {
				http.Error(w, err.Error(), http.StatusInternalServerError)
			}
		case r.Method == http.MethodPost && r.URL.Path == "/api/generate":
			generateCalled = true
			ok(api.GenerateResponse{Done: true})
		default:
			http.NotFound(w, r)
		}
	}))
	t.Setenv("OLLAMA_HOST", mockServer.URL)
	t.Cleanup(mockServer.Close)
	cmd := &cobra.Command{}
	cmd.SetContext(t.Context())
	flags := cmd.Flags()
	for _, name := range []string{"keepalive", "format", "think"} {
		flags.String(name, "", "")
	}
	flags.Int("dimensions", 0, "")
	for _, name := range []string{"truncate", "verbose", "insecure", "nowordwrap", "hidethinking"} {
		flags.Bool(name, false, "")
	}
	if err := RunHandler(cmd, []string{"gpt-oss:20b:cloud", "hi"}); err != nil {
		t.Fatalf("RunHandler returned error: %v", err)
	}
	if !generateCalled {
		t.Fatal("expected /api/generate to be called despite pull failure")
	}
}
func TestGetModelfileName(t *testing.T) { func TestGetModelfileName(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
@@ -1560,20 +1212,6 @@ func TestNewCreateRequest(t *testing.T) {
Model: "newmodel", Model: "newmodel",
}, },
}, },
{
"explicit cloud model preserves source when parent lacks it",
"newmodel",
runOptions{
Model: "qwen3.5:cloud",
ParentModel: "qwen3.5",
Messages: []api.Message{},
WordWrap: true,
},
&api.CreateRequest{
From: "qwen3.5:cloud",
Model: "newmodel",
},
},
{ {
"parent model as filepath test", "parent model as filepath test",
"newmodel", "newmodel",
@@ -1655,24 +1293,6 @@ func TestNewCreateRequest(t *testing.T) {
}, },
}, },
}, },
{
"loaded messages are preserved when saving",
"newmodel",
runOptions{
Model: "mymodel",
ParentModel: "parentmodel",
LoadedMessages: []api.Message{{Role: "assistant", Content: "loaded"}},
Messages: []api.Message{{Role: "user", Content: "new"}},
},
&api.CreateRequest{
From: "parentmodel",
Model: "newmodel",
Messages: []api.Message{
{Role: "assistant", Content: "loaded"},
{Role: "user", Content: "new"},
},
},
},
} }
for _, tt := range tests { for _, tt := range tests {
@@ -1685,43 +1305,15 @@ func TestNewCreateRequest(t *testing.T) {
} }
} }
func TestApplyShowResponseToRunOptions(t *testing.T) {
opts := runOptions{}
info := &api.ShowResponse{
Details: api.ModelDetails{
ParentModel: "parentmodel",
},
Messages: []api.Message{
{Role: "assistant", Content: "loaded"},
},
}
applyShowResponseToRunOptions(&opts, info)
if opts.ParentModel != "parentmodel" {
t.Fatalf("ParentModel = %q, want %q", opts.ParentModel, "parentmodel")
}
if !cmp.Equal(opts.LoadedMessages, info.Messages) {
t.Fatalf("LoadedMessages = %#v, want %#v", opts.LoadedMessages, info.Messages)
}
info.Messages[0].Content = "modified"
if opts.LoadedMessages[0].Content == "modified" {
t.Fatal("LoadedMessages should be copied independently from ShowResponse")
}
}
func TestRunOptions_Copy(t *testing.T) { func TestRunOptions_Copy(t *testing.T) {
// Setup test data // Setup test data
originalKeepAlive := &api.Duration{Duration: 5 * time.Minute} originalKeepAlive := &api.Duration{Duration: 5 * time.Minute}
originalThink := &api.ThinkValue{Value: "test reasoning"} originalThink := &api.ThinkValue{Value: "test reasoning"}
original := runOptions{ original := runOptions{
Model: "test-model", Model: "test-model",
ParentModel: "parent-model", ParentModel: "parent-model",
LoadedMessages: []api.Message{{Role: "assistant", Content: "loaded hello"}}, Prompt: "test prompt",
Prompt: "test prompt",
Messages: []api.Message{ Messages: []api.Message{
{Role: "user", Content: "hello"}, {Role: "user", Content: "hello"},
{Role: "assistant", Content: "hi there"}, {Role: "assistant", Content: "hi there"},
@@ -1761,7 +1353,6 @@ func TestRunOptions_Copy(t *testing.T) {
}{ }{
{"Model", copied.Model, original.Model}, {"Model", copied.Model, original.Model},
{"ParentModel", copied.ParentModel, original.ParentModel}, {"ParentModel", copied.ParentModel, original.ParentModel},
{"LoadedMessages", copied.LoadedMessages, original.LoadedMessages},
{"Prompt", copied.Prompt, original.Prompt}, {"Prompt", copied.Prompt, original.Prompt},
{"WordWrap", copied.WordWrap, original.WordWrap}, {"WordWrap", copied.WordWrap, original.WordWrap},
{"Format", copied.Format, original.Format}, {"Format", copied.Format, original.Format},
@@ -1866,18 +1457,13 @@ func TestRunOptions_Copy(t *testing.T) {
func TestRunOptions_Copy_EmptySlicesAndMaps(t *testing.T) { func TestRunOptions_Copy_EmptySlicesAndMaps(t *testing.T) {
// Test with empty slices and maps // Test with empty slices and maps
original := runOptions{ original := runOptions{
LoadedMessages: []api.Message{}, Messages: []api.Message{},
Messages: []api.Message{}, Images: []api.ImageData{},
Images: []api.ImageData{}, Options: map[string]any{},
Options: map[string]any{},
} }
copied := original.Copy() copied := original.Copy()
if copied.LoadedMessages == nil {
t.Error("Empty LoadedMessages slice should remain empty, not nil")
}
if copied.Messages == nil { if copied.Messages == nil {
t.Error("Empty Messages slice should remain empty, not nil") t.Error("Empty Messages slice should remain empty, not nil")
} }
@@ -1894,10 +1480,6 @@ func TestRunOptions_Copy_EmptySlicesAndMaps(t *testing.T) {
t.Error("Empty Messages slice should remain empty") t.Error("Empty Messages slice should remain empty")
} }
if len(copied.LoadedMessages) != 0 {
t.Error("Empty LoadedMessages slice should remain empty")
}
if len(copied.Images) != 0 { if len(copied.Images) != 0 {
t.Error("Empty Images slice should remain empty") t.Error("Empty Images slice should remain empty")
} }
@@ -1975,7 +1557,7 @@ func TestShowInfoImageGen(t *testing.T) {
QuantizationLevel: "Q8", QuantizationLevel: "Q8",
}, },
Capabilities: []model.Capability{model.CapabilityImage}, Capabilities: []model.Capability{model.CapabilityImage},
Requires: "0.19.0", Requires: "0.14.0",
}, false, &b) }, false, &b)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@@ -1985,7 +1567,7 @@ func TestShowInfoImageGen(t *testing.T) {
" architecture ZImagePipeline \n" + " architecture ZImagePipeline \n" +
" parameters 10.3B \n" + " parameters 10.3B \n" +
" quantization Q8 \n" + " quantization Q8 \n" +
" requires 0.19.0 \n" + " requires 0.14.0 \n" +
"\n" + "\n" +
" Capabilities\n" + " Capabilities\n" +
" image \n" + " image \n" +
@@ -2043,20 +1625,16 @@ func TestRunOptions_Copy_Independence(t *testing.T) {
// Test that modifications to original don't affect copy // Test that modifications to original don't affect copy
originalThink := &api.ThinkValue{Value: "original"} originalThink := &api.ThinkValue{Value: "original"}
original := runOptions{ original := runOptions{
Model: "original-model", Model: "original-model",
LoadedMessages: []api.Message{{Role: "assistant", Content: "loaded"}}, Messages: []api.Message{{Role: "user", Content: "original"}},
Messages: []api.Message{{Role: "user", Content: "original"}}, Options: map[string]any{"key": "value"},
Options: map[string]any{"key": "value"}, Think: originalThink,
Think: originalThink,
} }
copied := original.Copy() copied := original.Copy()
// Modify original // Modify original
original.Model = "modified-model" original.Model = "modified-model"
if len(original.LoadedMessages) > 0 {
original.LoadedMessages[0].Content = "modified loaded"
}
if len(original.Messages) > 0 { if len(original.Messages) > 0 {
original.Messages[0].Content = "modified" original.Messages[0].Content = "modified"
} }
@@ -2070,10 +1648,6 @@ func TestRunOptions_Copy_Independence(t *testing.T) {
t.Error("Copy Model should not be affected by original modification") t.Error("Copy Model should not be affected by original modification")
} }
if len(copied.LoadedMessages) > 0 && copied.LoadedMessages[0].Content == "modified loaded" {
t.Error("Copy LoadedMessages should not be affected by original modification")
}
if len(copied.Messages) > 0 && copied.Messages[0].Content == "modified" { if len(copied.Messages) > 0 && copied.Messages[0].Content == "modified" {
t.Error("Copy Messages should not be affected by original modification") t.Error("Copy Messages should not be affected by original modification")
} }
@@ -2089,81 +1663,31 @@ func TestRunOptions_Copy_Independence(t *testing.T) {
func TestLoadOrUnloadModel_CloudModelAuth(t *testing.T) { func TestLoadOrUnloadModel_CloudModelAuth(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
model string remoteHost string
showStatus int whoamiStatus int
remoteHost string whoamiResp any
remoteModel string expectedError string
whoamiStatus int
whoamiResp any
expectWhoami bool
expectedError string
expectAuthError bool
}{ }{
{ {
name: "ollama.com cloud model - user signed in", name: "ollama.com cloud model - user signed in",
model: "test-cloud-model",
remoteHost: "https://ollama.com", remoteHost: "https://ollama.com",
remoteModel: "test-model",
whoamiStatus: http.StatusOK, whoamiStatus: http.StatusOK,
whoamiResp: api.UserResponse{Name: "testuser"}, whoamiResp: api.UserResponse{Name: "testuser"},
expectWhoami: true,
}, },
{ {
name: "ollama.com cloud model - user not signed in", name: "ollama.com cloud model - user not signed in",
model: "test-cloud-model",
remoteHost: "https://ollama.com", remoteHost: "https://ollama.com",
remoteModel: "test-model",
whoamiStatus: http.StatusUnauthorized, whoamiStatus: http.StatusUnauthorized,
whoamiResp: map[string]string{ whoamiResp: map[string]string{
"error": "unauthorized", "error": "unauthorized",
"signin_url": "https://ollama.com/signin", "signin_url": "https://ollama.com/signin",
}, },
expectWhoami: true, expectedError: "unauthorized",
expectedError: "unauthorized",
expectAuthError: true,
}, },
{ {
name: "non-ollama.com remote - no auth check", name: "non-ollama.com remote - no auth check",
model: "test-cloud-model",
remoteHost: "https://other-remote.com", remoteHost: "https://other-remote.com",
remoteModel: "test-model",
whoamiStatus: http.StatusUnauthorized, // should not be called
whoamiResp: nil,
},
{
name: "explicit :cloud model - auth check without remote metadata",
model: "kimi-k2.5:cloud",
remoteHost: "",
remoteModel: "",
whoamiStatus: http.StatusOK,
whoamiResp: api.UserResponse{Name: "testuser"},
expectWhoami: true,
},
{
name: "explicit :cloud model without local stub returns not found by default",
model: "minimax-m2.7:cloud",
showStatus: http.StatusNotFound,
whoamiStatus: http.StatusOK,
whoamiResp: api.UserResponse{Name: "testuser"},
expectedError: "not found",
expectWhoami: false,
expectAuthError: false,
},
{
name: "explicit -cloud model - auth check without remote metadata",
model: "kimi-k2.5:latest-cloud",
remoteHost: "",
remoteModel: "",
whoamiStatus: http.StatusOK,
whoamiResp: api.UserResponse{Name: "testuser"},
expectWhoami: true,
},
{
name: "dash cloud-like name without explicit source does not require auth",
model: "test-cloud-model",
remoteHost: "",
remoteModel: "",
whoamiStatus: http.StatusUnauthorized, // should not be called whoamiStatus: http.StatusUnauthorized, // should not be called
whoamiResp: nil, whoamiResp: nil,
}, },
@@ -2175,15 +1699,10 @@ func TestLoadOrUnloadModel_CloudModelAuth(t *testing.T) {
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path { switch r.URL.Path {
case "/api/show": case "/api/show":
if tt.showStatus != 0 && tt.showStatus != http.StatusOK {
w.WriteHeader(tt.showStatus)
_ = json.NewEncoder(w).Encode(map[string]string{"error": "not found"})
return
}
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(api.ShowResponse{ if err := json.NewEncoder(w).Encode(api.ShowResponse{
RemoteHost: tt.remoteHost, RemoteHost: tt.remoteHost,
RemoteModel: tt.remoteModel, RemoteModel: "test-model",
}); err != nil { }); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError) http.Error(w, err.Error(), http.StatusInternalServerError)
} }
@@ -2196,8 +1715,6 @@ func TestLoadOrUnloadModel_CloudModelAuth(t *testing.T) {
http.Error(w, err.Error(), http.StatusInternalServerError) http.Error(w, err.Error(), http.StatusInternalServerError)
} }
} }
case "/api/generate":
w.WriteHeader(http.StatusOK)
default: default:
http.NotFound(w, r) http.NotFound(w, r)
} }
@@ -2210,28 +1727,29 @@ func TestLoadOrUnloadModel_CloudModelAuth(t *testing.T) {
cmd.SetContext(t.Context()) cmd.SetContext(t.Context())
opts := &runOptions{ opts := &runOptions{
Model: tt.model, Model: "test-cloud-model",
ShowConnect: false, ShowConnect: false,
} }
err := loadOrUnloadModel(cmd, opts) err := loadOrUnloadModel(cmd, opts)
if whoamiCalled != tt.expectWhoami { if strings.HasPrefix(tt.remoteHost, "https://ollama.com") {
t.Errorf("whoami called = %v, want %v", whoamiCalled, tt.expectWhoami) if !whoamiCalled {
t.Error("expected whoami to be called for ollama.com cloud model")
}
} else {
if whoamiCalled {
t.Error("whoami should not be called for non-ollama.com remote")
}
} }
if tt.expectedError != "" { if tt.expectedError != "" {
if err == nil { if err == nil {
t.Errorf("expected error containing %q, got nil", tt.expectedError) t.Errorf("expected error containing %q, got nil", tt.expectedError)
} else { } else {
if !tt.expectAuthError && !strings.Contains(strings.ToLower(err.Error()), strings.ToLower(tt.expectedError)) { var authErr api.AuthorizationError
t.Errorf("expected error containing %q, got %v", tt.expectedError, err) if !errors.As(err, &authErr) {
} t.Errorf("expected AuthorizationError, got %T: %v", err, err)
if tt.expectAuthError {
var authErr api.AuthorizationError
if !errors.As(err, &authErr) {
t.Errorf("expected AuthorizationError, got %T: %v", err, err)
}
} }
} }
} else { } else {
@@ -2242,38 +1760,3 @@ func TestLoadOrUnloadModel_CloudModelAuth(t *testing.T) {
}) })
} }
} }
func TestIsLocalhost(t *testing.T) {
tests := []struct {
name string
host string
expected bool
}{
{"default empty", "", true},
{"localhost no port", "localhost", true},
{"localhost with port", "localhost:11435", true},
{"127.0.0.1 no port", "127.0.0.1", true},
{"127.0.0.1 with port", "127.0.0.1:11434", true},
{"0.0.0.0 no port", "0.0.0.0", true},
{"0.0.0.0 with port", "0.0.0.0:11434", true},
{"::1 no port", "::1", true},
{"[::1] with port", "[::1]:11434", true},
{"loopback with scheme", "http://localhost:11434", true},
{"remote hostname", "example.com", false},
{"remote hostname with port", "example.com:11434", false},
{"remote IP", "192.168.1.1", false},
{"remote IP with port", "192.168.1.1:11434", false},
{"remote with scheme", "http://example.com:11434", false},
{"https remote", "https://example.com:443", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Setenv("OLLAMA_HOST", tt.host)
got := isLocalhost()
if got != tt.expected {
t.Errorf("isLocalhost() with OLLAMA_HOST=%q = %v, want %v", tt.host, got, tt.expected)
}
})
}
}

192
cmd/config/claude.go Normal file
View File

@@ -0,0 +1,192 @@
package config
import (
"context"
"fmt"
"os"
"os/exec"
"path/filepath"
"runtime"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/envconfig"
)
// Claude implements Runner and AliasConfigurer for Claude Code integration.
type Claude struct{}

// Compile-time check that Claude implements AliasConfigurer.
var _ AliasConfigurer = (*Claude)(nil)

// String returns the human-readable integration name.
func (c *Claude) String() string { return "Claude Code" }
// args builds the claude CLI argument list: an optional --model selection
// followed by any passthrough arguments.
func (c *Claude) args(model string, extra []string) []string {
	var list []string
	if model != "" {
		list = append(list, "--model", model)
	}
	return append(list, extra...)
}
// findPath locates the claude executable. It first consults PATH, then
// falls back to the default local install location under the user's home
// directory (~/.claude/local). Returns an error if neither is found.
func (c *Claude) findPath() (string, error) {
	if p, err := exec.LookPath("claude"); err == nil {
		return p, nil
	}
	home, err := os.UserHomeDir()
	if err != nil {
		return "", err
	}
	name := "claude"
	if runtime.GOOS == "windows" {
		// Windows installs carry the .exe suffix.
		name = "claude.exe"
	}
	fallback := filepath.Join(home, ".claude", "local", name)
	if _, err := os.Stat(fallback); err != nil {
		// Not installed in the fallback location either; surface the stat error.
		return "", err
	}
	return fallback, nil
}
// Run launches Claude Code wired to the local Ollama server. Anthropic API
// traffic is redirected via environment variables, and model-tier routing is
// applied from saved aliases (see modelEnvVars). The child process inherits
// this process's stdin/stdout/stderr.
func (c *Claude) Run(model string, args []string) error {
	claudePath, err := c.findPath()
	if err != nil {
		return fmt.Errorf("claude is not installed, install from https://code.claude.com/docs/en/quickstart")
	}
	cmd := exec.Command(claudePath, c.args(model, args)...)
	cmd.Stdin = os.Stdin
	cmd.Stdout = os.Stdout
	cmd.Stderr = os.Stderr
	// Point the Anthropic SDK at the Ollama host; clear any real API key so
	// the placeholder auth token is used instead.
	env := append(os.Environ(),
		"ANTHROPIC_BASE_URL="+envconfig.Host().String(),
		"ANTHROPIC_API_KEY=",
		"ANTHROPIC_AUTH_TOKEN=ollama",
	)
	env = append(env, c.modelEnvVars(model)...)
	cmd.Env = env
	return cmd.Run()
}
// modelEnvVars returns Claude Code env vars that route all model tiers
// through Ollama. Saved "primary"/"fast" aliases, when present, take
// precedence over the model passed on the command line.
func (c *Claude) modelEnvVars(model string) []string {
	primary, fast := model, model
	if cfg, err := loadIntegration("claude"); err == nil && cfg.Aliases != nil {
		if alias := cfg.Aliases["primary"]; alias != "" {
			primary = alias
		}
		if alias := cfg.Aliases["fast"]; alias != "" {
			fast = alias
		}
	}
	return []string{
		"ANTHROPIC_DEFAULT_OPUS_MODEL=" + primary,
		"ANTHROPIC_DEFAULT_SONNET_MODEL=" + primary,
		"ANTHROPIC_DEFAULT_HAIKU_MODEL=" + fast,
		"CLAUDE_CODE_SUBAGENT_MODEL=" + primary,
	}
}
// ConfigureAliases sets up model aliases for Claude Code.
// model: the model to use (if empty, user will be prompted to select)
// existingAliases: existing alias configuration to preserve/update
// force: when true, always re-prompt even if a primary alias is configured.
// Returns the resulting alias map, whether it changed, and any error.
// Cloud-only: subagent routing (fast model) is gated to cloud models only until
// there is a better strategy for prompt caching on local models.
func (c *Claude) ConfigureAliases(ctx context.Context, model string, existingAliases map[string]string, force bool) (map[string]string, bool, error) {
	// Work on a copy so the caller's map is never mutated.
	aliases := make(map[string]string)
	for k, v := range existingAliases {
		aliases[k] = v
	}
	if model != "" {
		aliases["primary"] = model
	}
	// Fast path: a primary alias is already configured and we're not forcing.
	if !force && aliases["primary"] != "" {
		client, _ := api.ClientFromEnvironment()
		if isCloudModel(ctx, client, aliases["primary"]) {
			// Keep the existing config only if the fast alias is also cloud.
			if isCloudModel(ctx, client, aliases["fast"]) {
				return aliases, false, nil
			}
		} else {
			// Local primary: drop any stale fast alias (cloud-only feature).
			delete(aliases, "fast")
			return aliases, false, nil
		}
	}
	items, existingModels, cloudModels, client, err := listModels(ctx)
	if err != nil {
		return nil, false, err
	}
	fmt.Fprintf(os.Stderr, "\n%sModel Configuration%s\n\n", ansiBold, ansiReset)
	// Prompt for a primary model when none is set or a re-prompt was forced.
	if aliases["primary"] == "" || force {
		primary, err := DefaultSingleSelector("Select model:", items, aliases["primary"])
		if err != nil {
			return nil, false, err
		}
		if err := pullIfNeeded(ctx, client, existingModels, primary); err != nil {
			return nil, false, err
		}
		if err := ensureAuth(ctx, client, cloudModels, []string{primary}); err != nil {
			return nil, false, err
		}
		aliases["primary"] = primary
	}
	// Fast alias only applies to cloud primaries; default it to the primary.
	if isCloudModel(ctx, client, aliases["primary"]) {
		if aliases["fast"] == "" || !isCloudModel(ctx, client, aliases["fast"]) {
			aliases["fast"] = aliases["primary"]
		}
	} else {
		delete(aliases, "fast")
	}
	return aliases, true, nil
}
// SetAliases syncs the configured aliases to the Ollama server using prefix matching.
// Cloud-only: for local models (fast is empty), we delete any existing aliases to
// prevent stale routing to a previous cloud model.
func (c *Claude) SetAliases(ctx context.Context, aliases map[string]string) error {
	client, err := api.ClientFromEnvironment()
	if err != nil {
		return err
	}
	prefixes := []string{"claude-sonnet-", "claude-haiku-"}
	if aliases["fast"] == "" {
		// Best-effort cleanup: deletion failures are intentionally ignored.
		for _, prefix := range prefixes {
			_ = client.DeleteAliasExperimental(ctx, &api.AliasDeleteRequest{Alias: prefix})
		}
		return nil
	}
	// Route sonnet-tier requests to the primary model and haiku-tier to fast.
	prefixAliases := map[string]string{
		"claude-sonnet-": aliases["primary"],
		"claude-haiku-":  aliases["fast"],
	}
	// Attempt every alias and report all failing prefixes together.
	var errs []string
	for prefix, target := range prefixAliases {
		req := &api.AliasRequest{
			Alias:          prefix,
			Target:         target,
			PrefixMatching: true,
		}
		if err := client.SetAliasExperimental(ctx, req); err != nil {
			errs = append(errs, prefix)
		}
	}
	if len(errs) > 0 {
		return fmt.Errorf("failed to set aliases: %v", errs)
	}
	return nil
}

View File

@@ -1,4 +1,4 @@
package launch package config
import ( import (
"os" "os"
@@ -117,7 +117,10 @@ func TestClaudeModelEnvVars(t *testing.T) {
return m return m
} }
t.Run("maps all Claude model env vars to the provided model", func(t *testing.T) { t.Run("falls back to model param when no aliases saved", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
got := envMap(c.modelEnvVars("llama3.2")) got := envMap(c.modelEnvVars("llama3.2"))
if got["ANTHROPIC_DEFAULT_OPUS_MODEL"] != "llama3.2" { if got["ANTHROPIC_DEFAULT_OPUS_MODEL"] != "llama3.2" {
t.Errorf("OPUS = %q, want llama3.2", got["ANTHROPIC_DEFAULT_OPUS_MODEL"]) t.Errorf("OPUS = %q, want llama3.2", got["ANTHROPIC_DEFAULT_OPUS_MODEL"])
@@ -131,41 +134,65 @@ func TestClaudeModelEnvVars(t *testing.T) {
if got["CLAUDE_CODE_SUBAGENT_MODEL"] != "llama3.2" { if got["CLAUDE_CODE_SUBAGENT_MODEL"] != "llama3.2" {
t.Errorf("SUBAGENT = %q, want llama3.2", got["CLAUDE_CODE_SUBAGENT_MODEL"]) t.Errorf("SUBAGENT = %q, want llama3.2", got["CLAUDE_CODE_SUBAGENT_MODEL"])
} }
if got["CLAUDE_CODE_AUTO_COMPACT_WINDOW"] != "" { })
t.Errorf("AUTO_COMPACT_WINDOW = %q, want empty for local models", got["CLAUDE_CODE_AUTO_COMPACT_WINDOW"])
t.Run("uses primary alias for opus sonnet and subagent", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
SaveIntegration("claude", []string{"qwen3:8b"})
saveAliases("claude", map[string]string{"primary": "qwen3:8b"})
got := envMap(c.modelEnvVars("qwen3:8b"))
if got["ANTHROPIC_DEFAULT_OPUS_MODEL"] != "qwen3:8b" {
t.Errorf("OPUS = %q, want qwen3:8b", got["ANTHROPIC_DEFAULT_OPUS_MODEL"])
}
if got["ANTHROPIC_DEFAULT_SONNET_MODEL"] != "qwen3:8b" {
t.Errorf("SONNET = %q, want qwen3:8b", got["ANTHROPIC_DEFAULT_SONNET_MODEL"])
}
if got["ANTHROPIC_DEFAULT_HAIKU_MODEL"] != "qwen3:8b" {
t.Errorf("HAIKU = %q, want qwen3:8b (no fast alias)", got["ANTHROPIC_DEFAULT_HAIKU_MODEL"])
}
if got["CLAUDE_CODE_SUBAGENT_MODEL"] != "qwen3:8b" {
t.Errorf("SUBAGENT = %q, want qwen3:8b", got["CLAUDE_CODE_SUBAGENT_MODEL"])
} }
}) })
t.Run("supports empty model", func(t *testing.T) { t.Run("uses fast alias for haiku", func(t *testing.T) {
got := envMap(c.modelEnvVars("")) tmpDir := t.TempDir()
if got["ANTHROPIC_DEFAULT_OPUS_MODEL"] != "" { setTestHome(t, tmpDir)
t.Errorf("OPUS = %q, want empty", got["ANTHROPIC_DEFAULT_OPUS_MODEL"])
SaveIntegration("claude", []string{"llama3.2:70b"})
saveAliases("claude", map[string]string{
"primary": "llama3.2:70b",
"fast": "llama3.2:8b",
})
got := envMap(c.modelEnvVars("llama3.2:70b"))
if got["ANTHROPIC_DEFAULT_OPUS_MODEL"] != "llama3.2:70b" {
t.Errorf("OPUS = %q, want llama3.2:70b", got["ANTHROPIC_DEFAULT_OPUS_MODEL"])
} }
if got["ANTHROPIC_DEFAULT_SONNET_MODEL"] != "" { if got["ANTHROPIC_DEFAULT_SONNET_MODEL"] != "llama3.2:70b" {
t.Errorf("SONNET = %q, want empty", got["ANTHROPIC_DEFAULT_SONNET_MODEL"]) t.Errorf("SONNET = %q, want llama3.2:70b", got["ANTHROPIC_DEFAULT_SONNET_MODEL"])
} }
if got["ANTHROPIC_DEFAULT_HAIKU_MODEL"] != "" { if got["ANTHROPIC_DEFAULT_HAIKU_MODEL"] != "llama3.2:8b" {
t.Errorf("HAIKU = %q, want empty", got["ANTHROPIC_DEFAULT_HAIKU_MODEL"]) t.Errorf("HAIKU = %q, want llama3.2:8b", got["ANTHROPIC_DEFAULT_HAIKU_MODEL"])
} }
if got["CLAUDE_CODE_SUBAGENT_MODEL"] != "" { if got["CLAUDE_CODE_SUBAGENT_MODEL"] != "llama3.2:70b" {
t.Errorf("SUBAGENT = %q, want empty", got["CLAUDE_CODE_SUBAGENT_MODEL"]) t.Errorf("SUBAGENT = %q, want llama3.2:70b", got["CLAUDE_CODE_SUBAGENT_MODEL"])
}
if got["CLAUDE_CODE_AUTO_COMPACT_WINDOW"] != "" {
t.Errorf("AUTO_COMPACT_WINDOW = %q, want empty", got["CLAUDE_CODE_AUTO_COMPACT_WINDOW"])
} }
}) })
t.Run("sets auto compact window for known cloud models", func(t *testing.T) { t.Run("alias primary overrides model param", func(t *testing.T) {
got := envMap(c.modelEnvVars("glm-5:cloud")) tmpDir := t.TempDir()
if got["CLAUDE_CODE_AUTO_COMPACT_WINDOW"] != "202752" { setTestHome(t, tmpDir)
t.Errorf("AUTO_COMPACT_WINDOW = %q, want 202752", got["CLAUDE_CODE_AUTO_COMPACT_WINDOW"])
}
})
t.Run("does not set auto compact window for unknown cloud models", func(t *testing.T) { SaveIntegration("claude", []string{"saved-model"})
got := envMap(c.modelEnvVars("unknown-model:cloud")) saveAliases("claude", map[string]string{"primary": "saved-model"})
if got["CLAUDE_CODE_AUTO_COMPACT_WINDOW"] != "" {
t.Errorf("AUTO_COMPACT_WINDOW = %q, want empty", got["CLAUDE_CODE_AUTO_COMPACT_WINDOW"]) got := envMap(c.modelEnvVars("different-model"))
if got["ANTHROPIC_DEFAULT_OPUS_MODEL"] != "saved-model" {
t.Errorf("OPUS = %q, want saved-model", got["ANTHROPIC_DEFAULT_OPUS_MODEL"])
} }
}) })
} }

View File

@@ -1,13 +1,14 @@
package launch package config
import ( import (
"context"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"os" "os"
"os/exec" "os/exec"
"path/filepath" "path/filepath"
"github.com/ollama/ollama/cmd/internal/fileutil"
"github.com/ollama/ollama/envconfig" "github.com/ollama/ollama/envconfig"
) )
@@ -21,6 +22,24 @@ func (c *Cline) Run(model string, args []string) error {
return fmt.Errorf("cline is not installed, install with: npm install -g cline") return fmt.Errorf("cline is not installed, install with: npm install -g cline")
} }
models := []string{model}
if config, err := loadIntegration("cline"); err == nil && len(config.Models) > 0 {
models = config.Models
}
var err error
models, err = resolveEditorModels("cline", models, func() ([]string, error) {
return selectModels(context.Background(), "cline", "")
})
if errors.Is(err, errCancelled) {
return nil
}
if err != nil {
return err
}
if err := c.Edit(models); err != nil {
return fmt.Errorf("setup failed: %w", err)
}
cmd := exec.Command("cline", args...) cmd := exec.Command("cline", args...)
cmd.Stdin = os.Stdin cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout cmd.Stdout = os.Stdout
@@ -78,7 +97,7 @@ func (c *Cline) Edit(models []string) error {
if err != nil { if err != nil {
return err return err
} }
return fileutil.WriteWithBackup(configPath, data) return writeWithBackup(configPath, data)
} }
func (c *Cline) Models() []string { func (c *Cline) Models() []string {
@@ -87,7 +106,7 @@ func (c *Cline) Models() []string {
return nil return nil
} }
config, err := fileutil.ReadJSON(filepath.Join(home, ".cline", "data", "globalState.json")) config, err := readJSONFile(filepath.Join(home, ".cline", "data", "globalState.json"))
if err != nil { if err != nil {
return nil return nil
} }

View File

@@ -1,4 +1,4 @@
package launch package config
import ( import (
"encoding/json" "encoding/json"

67
cmd/config/codex.go Normal file
View File

@@ -0,0 +1,67 @@
package config
import (
"fmt"
"os"
"os/exec"
"strings"
"github.com/ollama/ollama/envconfig"
"golang.org/x/mod/semver"
)
// Codex implements Runner for Codex integration
type Codex struct{}

// String returns the display name of this integration.
func (c *Codex) String() string { return "Codex" }

// args assembles the codex CLI argument list: always "--oss", then
// "-m <model>" when a model is given, followed by any passthrough args.
func (c *Codex) args(model string, extra []string) []string {
	out := make([]string, 0, 3+len(extra))
	out = append(out, "--oss")
	if model != "" {
		out = append(out, "-m", model)
	}
	return append(out, extra...)
}
// Run launches the codex CLI pointed at the local Ollama server.
// It first verifies the installed codex version, then execs codex with
// the assembled arguments and the OpenAI-compatible endpoint exported
// through the environment.
func (c *Codex) Run(model string, args []string) error {
	if err := checkCodexVersion(); err != nil {
		return err
	}

	env := append(os.Environ(),
		"OPENAI_BASE_URL="+envconfig.Host().String()+"/v1/",
		"OPENAI_API_KEY=ollama",
	)

	cmd := exec.Command("codex", c.args(model, args)...)
	cmd.Stdin = os.Stdin
	cmd.Stdout = os.Stdout
	cmd.Stderr = os.Stderr
	cmd.Env = env
	return cmd.Run()
}
// checkCodexVersion verifies that the codex CLI is installed and meets the
// minimum supported version. It returns a user-actionable error telling the
// user how to install or update codex otherwise.
func checkCodexVersion() error {
	// Single source of truth for the minimum version; the user-facing
	// message below is derived from it so the two can never drift apart.
	const minVersion = "v0.81.0"

	if _, err := exec.LookPath("codex"); err != nil {
		return fmt.Errorf("codex is not installed, install with: npm install -g @openai/codex")
	}

	out, err := exec.Command("codex", "--version").Output()
	if err != nil {
		return fmt.Errorf("failed to get codex version: %w", err)
	}

	// Parse output like "codex-cli 0.87.0"; the version is the last field.
	fields := strings.Fields(strings.TrimSpace(string(out)))
	if len(fields) < 2 {
		return fmt.Errorf("unexpected codex version output: %s", string(out))
	}

	// semver.Compare treats an invalid version string as less than any valid
	// one, so unparseable output falls through to the "too old" error.
	version := "v" + fields[len(fields)-1]
	if semver.Compare(version, minVersion) < 0 {
		return fmt.Errorf("codex version %s is too old, minimum required is %s, update with: npm update -g @openai/codex",
			fields[len(fields)-1], strings.TrimPrefix(minVersion, "v"))
	}
	return nil
}

31
cmd/config/codex_test.go Normal file
View File

@@ -0,0 +1,31 @@
package config
import (
"slices"
"testing"
)
// TestCodexArgs verifies the CLI argument construction for codex launches.
func TestCodexArgs(t *testing.T) {
	c := &Codex{}
	cases := []struct {
		name  string
		model string
		extra []string
		want  []string
	}{
		{"with model", "llama3.2", nil, []string{"--oss", "-m", "llama3.2"}},
		{"empty model", "", nil, []string{"--oss"}},
		{"with model and profile", "qwen3-coder", []string{"-p", "myprofile"}, []string{"--oss", "-m", "qwen3-coder", "-p", "myprofile"}},
		{"with sandbox flag", "llama3.2", []string{"--sandbox", "workspace-write"}, []string{"--oss", "-m", "llama3.2", "--sandbox", "workspace-write"}},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			if got := c.args(tc.model, tc.extra); !slices.Equal(got, tc.want) {
				t.Errorf("args(%q, %v) = %v, want %v", tc.model, tc.extra, got, tc.want)
			}
		})
	}
}

View File

@@ -3,6 +3,7 @@
package config package config
import ( import (
"context"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
@@ -10,7 +11,7 @@ import (
"path/filepath" "path/filepath"
"strings" "strings"
"github.com/ollama/ollama/cmd/internal/fileutil" "github.com/ollama/ollama/api"
) )
type integration struct { type integration struct {
@@ -19,9 +20,6 @@ type integration struct {
Onboarded bool `json:"onboarded,omitempty"` Onboarded bool `json:"onboarded,omitempty"`
} }
// IntegrationConfig is the persisted config for one integration.
type IntegrationConfig = integration
type config struct { type config struct {
Integrations map[string]*integration `json:"integrations"` Integrations map[string]*integration `json:"integrations"`
LastModel string `json:"last_model,omitempty"` LastModel string `json:"last_model,omitempty"`
@@ -126,7 +124,7 @@ func save(cfg *config) error {
return err return err
} }
return fileutil.WriteWithBackup(path, data) return writeWithBackup(path, data)
} }
func SaveIntegration(appName string, models []string) error { func SaveIntegration(appName string, models []string) error {
@@ -157,8 +155,8 @@ func SaveIntegration(appName string, models []string) error {
return save(cfg) return save(cfg)
} }
// MarkIntegrationOnboarded marks an integration as onboarded in Ollama's config. // integrationOnboarded marks an integration as onboarded in ollama's config.
func MarkIntegrationOnboarded(appName string) error { func integrationOnboarded(appName string) error {
cfg, err := load() cfg, err := load()
if err != nil { if err != nil {
return err return err
@@ -176,7 +174,7 @@ func MarkIntegrationOnboarded(appName string) error {
// IntegrationModel returns the first configured model for an integration, or empty string if not configured. // IntegrationModel returns the first configured model for an integration, or empty string if not configured.
func IntegrationModel(appName string) string { func IntegrationModel(appName string) string {
integrationConfig, err := LoadIntegration(appName) integrationConfig, err := loadIntegration(appName)
if err != nil || len(integrationConfig.Models) == 0 { if err != nil || len(integrationConfig.Models) == 0 {
return "" return ""
} }
@@ -185,7 +183,7 @@ func IntegrationModel(appName string) string {
// IntegrationModels returns all configured models for an integration, or nil. // IntegrationModels returns all configured models for an integration, or nil.
func IntegrationModels(appName string) []string { func IntegrationModels(appName string) []string {
integrationConfig, err := LoadIntegration(appName) integrationConfig, err := loadIntegration(appName)
if err != nil || len(integrationConfig.Models) == 0 { if err != nil || len(integrationConfig.Models) == 0 {
return nil return nil
} }
@@ -230,8 +228,28 @@ func SetLastSelection(selection string) error {
return save(cfg) return save(cfg)
} }
// LoadIntegration returns the saved config for one integration. // ModelExists checks if a model exists on the Ollama server.
func LoadIntegration(appName string) (*integration, error) { func ModelExists(ctx context.Context, name string) bool {
if name == "" {
return false
}
client, err := api.ClientFromEnvironment()
if err != nil {
return false
}
models, err := client.List(ctx)
if err != nil {
return false
}
for _, m := range models.Models {
if m.Name == name || strings.HasPrefix(m.Name, name+":") {
return true
}
}
return false
}
func loadIntegration(appName string) (*integration, error) {
cfg, err := load() cfg, err := load()
if err != nil { if err != nil {
return nil, err return nil, err
@@ -245,8 +263,7 @@ func LoadIntegration(appName string) (*integration, error) {
return integrationConfig, nil return integrationConfig, nil
} }
// SaveAliases replaces the saved aliases for one integration. func saveAliases(appName string, aliases map[string]string) error {
func SaveAliases(appName string, aliases map[string]string) error {
if appName == "" { if appName == "" {
return errors.New("app name cannot be empty") return errors.New("app name cannot be empty")
} }

View File

@@ -1,6 +1,7 @@
package config package config
import ( import (
"context"
"errors" "errors"
"os" "os"
"path/filepath" "path/filepath"
@@ -44,12 +45,12 @@ func TestSaveAliases_ReplacesNotMerges(t *testing.T) {
"primary": "cloud-model", "primary": "cloud-model",
"fast": "cloud-model", "fast": "cloud-model",
} }
if err := SaveAliases("claude", initial); err != nil { if err := saveAliases("claude", initial); err != nil {
t.Fatalf("failed to save initial aliases: %v", err) t.Fatalf("failed to save initial aliases: %v", err)
} }
// Verify both are saved // Verify both are saved
loaded, err := LoadIntegration("claude") loaded, err := loadIntegration("claude")
if err != nil { if err != nil {
t.Fatalf("failed to load: %v", err) t.Fatalf("failed to load: %v", err)
} }
@@ -62,12 +63,12 @@ func TestSaveAliases_ReplacesNotMerges(t *testing.T) {
"primary": "local-model", "primary": "local-model",
// fast intentionally missing // fast intentionally missing
} }
if err := SaveAliases("claude", updated); err != nil { if err := saveAliases("claude", updated); err != nil {
t.Fatalf("failed to save updated aliases: %v", err) t.Fatalf("failed to save updated aliases: %v", err)
} }
// Verify fast is GONE (not merged/preserved) // Verify fast is GONE (not merged/preserved)
loaded, err = LoadIntegration("claude") loaded, err = loadIntegration("claude")
if err != nil { if err != nil {
t.Fatalf("failed to load after update: %v", err) t.Fatalf("failed to load after update: %v", err)
} }
@@ -90,12 +91,12 @@ func TestSaveAliases_PreservesModels(t *testing.T) {
// Then update aliases // Then update aliases
aliases := map[string]string{"primary": "new-model"} aliases := map[string]string{"primary": "new-model"}
if err := SaveAliases("claude", aliases); err != nil { if err := saveAliases("claude", aliases); err != nil {
t.Fatalf("failed to save aliases: %v", err) t.Fatalf("failed to save aliases: %v", err)
} }
// Verify models are preserved // Verify models are preserved
loaded, err := LoadIntegration("claude") loaded, err := loadIntegration("claude")
if err != nil { if err != nil {
t.Fatalf("failed to load: %v", err) t.Fatalf("failed to load: %v", err)
} }
@@ -110,16 +111,16 @@ func TestSaveAliases_EmptyMap(t *testing.T) {
setTestHome(t, tmpDir) setTestHome(t, tmpDir)
// Save with aliases // Save with aliases
if err := SaveAliases("claude", map[string]string{"primary": "model", "fast": "model"}); err != nil { if err := saveAliases("claude", map[string]string{"primary": "model", "fast": "model"}); err != nil {
t.Fatalf("failed to save: %v", err) t.Fatalf("failed to save: %v", err)
} }
// Save empty map // Save empty map
if err := SaveAliases("claude", map[string]string{}); err != nil { if err := saveAliases("claude", map[string]string{}); err != nil {
t.Fatalf("failed to save empty: %v", err) t.Fatalf("failed to save empty: %v", err)
} }
loaded, err := LoadIntegration("claude") loaded, err := loadIntegration("claude")
if err != nil { if err != nil {
t.Fatalf("failed to load: %v", err) t.Fatalf("failed to load: %v", err)
} }
@@ -134,16 +135,16 @@ func TestSaveAliases_NilMap(t *testing.T) {
setTestHome(t, tmpDir) setTestHome(t, tmpDir)
// Save with aliases first // Save with aliases first
if err := SaveAliases("claude", map[string]string{"primary": "model"}); err != nil { if err := saveAliases("claude", map[string]string{"primary": "model"}); err != nil {
t.Fatalf("failed to save: %v", err) t.Fatalf("failed to save: %v", err)
} }
// Save nil map - should clear aliases // Save nil map - should clear aliases
if err := SaveAliases("claude", nil); err != nil { if err := saveAliases("claude", nil); err != nil {
t.Fatalf("failed to save nil: %v", err) t.Fatalf("failed to save nil: %v", err)
} }
loaded, err := LoadIntegration("claude") loaded, err := loadIntegration("claude")
if err != nil { if err != nil {
t.Fatalf("failed to load: %v", err) t.Fatalf("failed to load: %v", err)
} }
@@ -154,7 +155,7 @@ func TestSaveAliases_NilMap(t *testing.T) {
// TestSaveAliases_EmptyAppName returns error // TestSaveAliases_EmptyAppName returns error
func TestSaveAliases_EmptyAppName(t *testing.T) { func TestSaveAliases_EmptyAppName(t *testing.T) {
err := SaveAliases("", map[string]string{"primary": "model"}) err := saveAliases("", map[string]string{"primary": "model"})
if err == nil { if err == nil {
t.Error("expected error for empty app name") t.Error("expected error for empty app name")
} }
@@ -164,12 +165,12 @@ func TestSaveAliases_CaseInsensitive(t *testing.T) {
tmpDir := t.TempDir() tmpDir := t.TempDir()
setTestHome(t, tmpDir) setTestHome(t, tmpDir)
if err := SaveAliases("Claude", map[string]string{"primary": "model1"}); err != nil { if err := saveAliases("Claude", map[string]string{"primary": "model1"}); err != nil {
t.Fatalf("failed to save: %v", err) t.Fatalf("failed to save: %v", err)
} }
// Load with different case // Load with different case
loaded, err := LoadIntegration("claude") loaded, err := loadIntegration("claude")
if err != nil { if err != nil {
t.Fatalf("failed to load: %v", err) t.Fatalf("failed to load: %v", err)
} }
@@ -178,11 +179,11 @@ func TestSaveAliases_CaseInsensitive(t *testing.T) {
} }
// Update with different case // Update with different case
if err := SaveAliases("CLAUDE", map[string]string{"primary": "model2"}); err != nil { if err := saveAliases("CLAUDE", map[string]string{"primary": "model2"}); err != nil {
t.Fatalf("failed to update: %v", err) t.Fatalf("failed to update: %v", err)
} }
loaded, err = LoadIntegration("claude") loaded, err = loadIntegration("claude")
if err != nil { if err != nil {
t.Fatalf("failed to load after update: %v", err) t.Fatalf("failed to load after update: %v", err)
} }
@@ -197,11 +198,11 @@ func TestSaveAliases_CreatesIntegration(t *testing.T) {
setTestHome(t, tmpDir) setTestHome(t, tmpDir)
// Save aliases for non-existent integration // Save aliases for non-existent integration
if err := SaveAliases("newintegration", map[string]string{"primary": "model"}); err != nil { if err := saveAliases("newintegration", map[string]string{"primary": "model"}); err != nil {
t.Fatalf("failed to save: %v", err) t.Fatalf("failed to save: %v", err)
} }
loaded, err := LoadIntegration("newintegration") loaded, err := loadIntegration("newintegration")
if err != nil { if err != nil {
t.Fatalf("failed to load: %v", err) t.Fatalf("failed to load: %v", err)
} }
@@ -370,12 +371,12 @@ func TestAtomicUpdate_ServerSucceedsConfigSaved(t *testing.T) {
t.Fatal("server should succeed") t.Fatal("server should succeed")
} }
if err := SaveAliases("claude", map[string]string{"primary": "model"}); err != nil { if err := saveAliases("claude", map[string]string{"primary": "model"}); err != nil {
t.Fatalf("saveAliases failed: %v", err) t.Fatalf("saveAliases failed: %v", err)
} }
// Verify it was actually saved // Verify it was actually saved
loaded, err := LoadIntegration("claude") loaded, err := loadIntegration("claude")
if err != nil { if err != nil {
t.Fatalf("failed to load: %v", err) t.Fatalf("failed to load: %v", err)
} }
@@ -407,7 +408,7 @@ func TestConfigFile_PreservesUnknownFields(t *testing.T) {
os.WriteFile(configPath, []byte(initialConfig), 0o644) os.WriteFile(configPath, []byte(initialConfig), 0o644)
// Update aliases // Update aliases
if err := SaveAliases("claude", map[string]string{"primary": "model2"}); err != nil { if err := saveAliases("claude", map[string]string{"primary": "model2"}); err != nil {
t.Fatalf("failed to save: %v", err) t.Fatalf("failed to save: %v", err)
} }
@@ -439,6 +440,11 @@ func containsHelper(s, substr string) bool {
return false return false
} }
func TestClaudeImplementsAliasConfigurer(t *testing.T) {
c := &Claude{}
var _ AliasConfigurer = c // Compile-time check
}
func TestModelNameEdgeCases(t *testing.T) { func TestModelNameEdgeCases(t *testing.T) {
testCases := []struct { testCases := []struct {
name string name string
@@ -458,11 +464,11 @@ func TestModelNameEdgeCases(t *testing.T) {
setTestHome(t, tmpDir) setTestHome(t, tmpDir)
aliases := map[string]string{"primary": tc.model} aliases := map[string]string{"primary": tc.model}
if err := SaveAliases("claude", aliases); err != nil { if err := saveAliases("claude", aliases); err != nil {
t.Fatalf("failed to save model %q: %v", tc.model, err) t.Fatalf("failed to save model %q: %v", tc.model, err)
} }
loaded, err := LoadIntegration("claude") loaded, err := loadIntegration("claude")
if err != nil { if err != nil {
t.Fatalf("failed to load: %v", err) t.Fatalf("failed to load: %v", err)
} }
@@ -479,7 +485,7 @@ func TestSwitchingScenarios(t *testing.T) {
setTestHome(t, tmpDir) setTestHome(t, tmpDir)
// Initial cloud config // Initial cloud config
if err := SaveAliases("claude", map[string]string{ if err := saveAliases("claude", map[string]string{
"primary": "cloud-model", "primary": "cloud-model",
"fast": "cloud-model", "fast": "cloud-model",
}); err != nil { }); err != nil {
@@ -487,13 +493,13 @@ func TestSwitchingScenarios(t *testing.T) {
} }
// Switch to local (no fast) // Switch to local (no fast)
if err := SaveAliases("claude", map[string]string{ if err := saveAliases("claude", map[string]string{
"primary": "local-model", "primary": "local-model",
}); err != nil { }); err != nil {
t.Fatal(err) t.Fatal(err)
} }
loaded, _ := LoadIntegration("claude") loaded, _ := loadIntegration("claude")
if loaded.Aliases["fast"] != "" { if loaded.Aliases["fast"] != "" {
t.Errorf("fast should be removed, got %q", loaded.Aliases["fast"]) t.Errorf("fast should be removed, got %q", loaded.Aliases["fast"])
} }
@@ -507,21 +513,21 @@ func TestSwitchingScenarios(t *testing.T) {
setTestHome(t, tmpDir) setTestHome(t, tmpDir)
// Initial local config // Initial local config
if err := SaveAliases("claude", map[string]string{ if err := saveAliases("claude", map[string]string{
"primary": "local-model", "primary": "local-model",
}); err != nil { }); err != nil {
t.Fatal(err) t.Fatal(err)
} }
// Switch to cloud (with fast) // Switch to cloud (with fast)
if err := SaveAliases("claude", map[string]string{ if err := saveAliases("claude", map[string]string{
"primary": "cloud-model", "primary": "cloud-model",
"fast": "cloud-model", "fast": "cloud-model",
}); err != nil { }); err != nil {
t.Fatal(err) t.Fatal(err)
} }
loaded, _ := LoadIntegration("claude") loaded, _ := loadIntegration("claude")
if loaded.Aliases["fast"] != "cloud-model" { if loaded.Aliases["fast"] != "cloud-model" {
t.Errorf("fast should be cloud-model, got %q", loaded.Aliases["fast"]) t.Errorf("fast should be cloud-model, got %q", loaded.Aliases["fast"])
} }
@@ -532,7 +538,7 @@ func TestSwitchingScenarios(t *testing.T) {
setTestHome(t, tmpDir) setTestHome(t, tmpDir)
// Initial cloud config // Initial cloud config
if err := SaveAliases("claude", map[string]string{ if err := saveAliases("claude", map[string]string{
"primary": "cloud-model-1", "primary": "cloud-model-1",
"fast": "cloud-model-1", "fast": "cloud-model-1",
}); err != nil { }); err != nil {
@@ -540,14 +546,14 @@ func TestSwitchingScenarios(t *testing.T) {
} }
// Switch to different cloud // Switch to different cloud
if err := SaveAliases("claude", map[string]string{ if err := saveAliases("claude", map[string]string{
"primary": "cloud-model-2", "primary": "cloud-model-2",
"fast": "cloud-model-2", "fast": "cloud-model-2",
}); err != nil { }); err != nil {
t.Fatal(err) t.Fatal(err)
} }
loaded, _ := LoadIntegration("claude") loaded, _ := loadIntegration("claude")
if loaded.Aliases["primary"] != "cloud-model-2" { if loaded.Aliases["primary"] != "cloud-model-2" {
t.Errorf("primary should be cloud-model-2, got %q", loaded.Aliases["primary"]) t.Errorf("primary should be cloud-model-2, got %q", loaded.Aliases["primary"])
} }
@@ -557,13 +563,43 @@ func TestSwitchingScenarios(t *testing.T) {
}) })
} }
// TestToolCapabilityFiltering covers the ToolCapable flag carried by modelInfo.
func TestToolCapabilityFiltering(t *testing.T) {
	t.Run("all models checked for tool capability", func(t *testing.T) {
		// Both cloud and local models are checked for tool capability via the
		// Show API; only models reporting "tools" in capabilities are included.
		local := modelInfo{Name: "tool-model", Remote: false, ToolCapable: true}
		if !local.ToolCapable {
			t.Error("tool capable model should be marked as such")
		}
	})

	t.Run("modelInfo includes ToolCapable field", func(t *testing.T) {
		remote := modelInfo{Name: "test", Remote: true, ToolCapable: true}
		if !remote.ToolCapable {
			t.Error("ToolCapable field should be accessible")
		}
	})
}
// TestIsCloudModel_RequiresClient verifies that isCloudModel never reports a
// cloud model without a usable API client.
func TestIsCloudModel_RequiresClient(t *testing.T) {
	t.Run("nil client always returns false", func(t *testing.T) {
		ctx := context.Background()
		// isCloudModel relies solely on the Show API, not on name suffixes,
		// so a ":cloud" suffix must not matter when the client is nil.
		if isCloudModel(ctx, nil, "model:cloud") {
			t.Error("nil client should return false regardless of suffix")
		}
		if isCloudModel(ctx, nil, "local-model") {
			t.Error("nil client should return false")
		}
	})
}
func TestModelsAndAliasesMustStayInSync(t *testing.T) { func TestModelsAndAliasesMustStayInSync(t *testing.T) {
t.Run("saveAliases followed by saveIntegration keeps them in sync", func(t *testing.T) { t.Run("saveAliases followed by saveIntegration keeps them in sync", func(t *testing.T) {
tmpDir := t.TempDir() tmpDir := t.TempDir()
setTestHome(t, tmpDir) setTestHome(t, tmpDir)
// Save aliases with one model // Save aliases with one model
if err := SaveAliases("claude", map[string]string{"primary": "model-a"}); err != nil { if err := saveAliases("claude", map[string]string{"primary": "model-a"}); err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -572,7 +608,7 @@ func TestModelsAndAliasesMustStayInSync(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
loaded, _ := LoadIntegration("claude") loaded, _ := loadIntegration("claude")
if loaded.Aliases["primary"] != loaded.Models[0] { if loaded.Aliases["primary"] != loaded.Models[0] {
t.Errorf("aliases.primary (%q) != models[0] (%q)", loaded.Aliases["primary"], loaded.Models[0]) t.Errorf("aliases.primary (%q) != models[0] (%q)", loaded.Aliases["primary"], loaded.Models[0])
} }
@@ -586,11 +622,11 @@ func TestModelsAndAliasesMustStayInSync(t *testing.T) {
if err := SaveIntegration("claude", []string{"old-model"}); err != nil { if err := SaveIntegration("claude", []string{"old-model"}); err != nil {
t.Fatal(err) t.Fatal(err)
} }
if err := SaveAliases("claude", map[string]string{"primary": "new-model"}); err != nil { if err := saveAliases("claude", map[string]string{"primary": "new-model"}); err != nil {
t.Fatal(err) t.Fatal(err)
} }
loaded, _ := LoadIntegration("claude") loaded, _ := loadIntegration("claude")
// They should be different (this is the bug state) // They should be different (this is the bug state)
if loaded.Models[0] == loaded.Aliases["primary"] { if loaded.Models[0] == loaded.Aliases["primary"] {
@@ -602,7 +638,7 @@ func TestModelsAndAliasesMustStayInSync(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
loaded, _ = LoadIntegration("claude") loaded, _ = loadIntegration("claude")
if loaded.Models[0] != loaded.Aliases["primary"] { if loaded.Models[0] != loaded.Aliases["primary"] {
t.Errorf("after fix: models[0] (%q) should equal aliases.primary (%q)", t.Errorf("after fix: models[0] (%q) should equal aliases.primary (%q)",
loaded.Models[0], loaded.Aliases["primary"]) loaded.Models[0], loaded.Aliases["primary"])
@@ -617,20 +653,20 @@ func TestModelsAndAliasesMustStayInSync(t *testing.T) {
if err := SaveIntegration("claude", []string{"initial-model"}); err != nil { if err := SaveIntegration("claude", []string{"initial-model"}); err != nil {
t.Fatal(err) t.Fatal(err)
} }
if err := SaveAliases("claude", map[string]string{"primary": "initial-model"}); err != nil { if err := saveAliases("claude", map[string]string{"primary": "initial-model"}); err != nil {
t.Fatal(err) t.Fatal(err)
} }
// Update aliases AND models together // Update aliases AND models together
newAliases := map[string]string{"primary": "updated-model"} newAliases := map[string]string{"primary": "updated-model"}
if err := SaveAliases("claude", newAliases); err != nil { if err := saveAliases("claude", newAliases); err != nil {
t.Fatal(err) t.Fatal(err)
} }
if err := SaveIntegration("claude", []string{newAliases["primary"]}); err != nil { if err := SaveIntegration("claude", []string{newAliases["primary"]}); err != nil {
t.Fatal(err) t.Fatal(err)
} }
loaded, _ := LoadIntegration("claude") loaded, _ := loadIntegration("claude")
if loaded.Models[0] != "updated-model" { if loaded.Models[0] != "updated-model" {
t.Errorf("models[0] should be updated-model, got %q", loaded.Models[0]) t.Errorf("models[0] should be updated-model, got %q", loaded.Models[0])
} }

View File

@@ -10,10 +10,17 @@ import (
// setTestHome sets both HOME (Unix) and USERPROFILE (Windows) for cross-platform tests // setTestHome sets both HOME (Unix) and USERPROFILE (Windows) for cross-platform tests
func setTestHome(t *testing.T, dir string) { func setTestHome(t *testing.T, dir string) {
t.Setenv("HOME", dir) t.Setenv("HOME", dir)
t.Setenv("TMPDIR", dir)
t.Setenv("USERPROFILE", dir) t.Setenv("USERPROFILE", dir)
} }
// editorPaths is a test helper that safely calls Paths if the runner implements Editor
func editorPaths(r Runner) []string {
	editor, ok := r.(Editor)
	if !ok {
		return nil
	}
	return editor.Paths()
}
func TestIntegrationConfig(t *testing.T) { func TestIntegrationConfig(t *testing.T) {
tmpDir := t.TempDir() tmpDir := t.TempDir()
setTestHome(t, tmpDir) setTestHome(t, tmpDir)
@@ -24,7 +31,7 @@ func TestIntegrationConfig(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
config, err := LoadIntegration("claude") config, err := loadIntegration("claude")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -48,11 +55,11 @@ func TestIntegrationConfig(t *testing.T) {
"primary": "llama3.2:70b", "primary": "llama3.2:70b",
"fast": "llama3.2:8b", "fast": "llama3.2:8b",
} }
if err := SaveAliases("claude", aliases); err != nil { if err := saveAliases("claude", aliases); err != nil {
t.Fatal(err) t.Fatal(err)
} }
config, err := LoadIntegration("claude") config, err := loadIntegration("claude")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -70,14 +77,14 @@ func TestIntegrationConfig(t *testing.T) {
if err := SaveIntegration("claude", []string{"model-a"}); err != nil { if err := SaveIntegration("claude", []string{"model-a"}); err != nil {
t.Fatal(err) t.Fatal(err)
} }
if err := SaveAliases("claude", map[string]string{"primary": "model-a", "fast": "model-small"}); err != nil { if err := saveAliases("claude", map[string]string{"primary": "model-a", "fast": "model-small"}); err != nil {
t.Fatal(err) t.Fatal(err)
} }
if err := SaveIntegration("claude", []string{"model-b"}); err != nil { if err := SaveIntegration("claude", []string{"model-b"}); err != nil {
t.Fatal(err) t.Fatal(err)
} }
config, err := LoadIntegration("claude") config, err := loadIntegration("claude")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -89,7 +96,7 @@ func TestIntegrationConfig(t *testing.T) {
t.Run("defaultModel returns first model", func(t *testing.T) { t.Run("defaultModel returns first model", func(t *testing.T) {
SaveIntegration("codex", []string{"model-a", "model-b"}) SaveIntegration("codex", []string{"model-a", "model-b"})
config, _ := LoadIntegration("codex") config, _ := loadIntegration("codex")
defaultModel := "" defaultModel := ""
if len(config.Models) > 0 { if len(config.Models) > 0 {
defaultModel = config.Models[0] defaultModel = config.Models[0]
@@ -113,7 +120,7 @@ func TestIntegrationConfig(t *testing.T) {
t.Run("app name is case-insensitive", func(t *testing.T) { t.Run("app name is case-insensitive", func(t *testing.T) {
SaveIntegration("Claude", []string{"model-x"}) SaveIntegration("Claude", []string{"model-x"})
config, err := LoadIntegration("claude") config, err := loadIntegration("claude")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -130,8 +137,8 @@ func TestIntegrationConfig(t *testing.T) {
SaveIntegration("app1", []string{"model-1"}) SaveIntegration("app1", []string{"model-1"})
SaveIntegration("app2", []string{"model-2"}) SaveIntegration("app2", []string{"model-2"})
config1, _ := LoadIntegration("app1") config1, _ := loadIntegration("app1")
config2, _ := LoadIntegration("app2") config2, _ := loadIntegration("app2")
defaultModel1 := "" defaultModel1 := ""
if len(config1.Models) > 0 { if len(config1.Models) > 0 {
@@ -178,6 +185,64 @@ func TestListIntegrations(t *testing.T) {
}) })
} }
// TestEditorPaths verifies Paths reporting for integrations that do and do not
// implement the Editor interface, with and without on-disk config files.
func TestEditorPaths(t *testing.T) {
	tmpDir := t.TempDir()
	setTestHome(t, tmpDir)

	// mustWrite creates dir and writes an empty JSON file there, failing the
	// test immediately on error so a broken fixture does not surface later as
	// a confusing assertion failure.
	mustWrite := func(t *testing.T, dir, file string) {
		t.Helper()
		if err := os.MkdirAll(dir, 0o755); err != nil {
			t.Fatalf("mkdir %s: %v", dir, err)
		}
		if err := os.WriteFile(filepath.Join(dir, file), []byte(`{}`), 0o644); err != nil {
			t.Fatalf("write %s: %v", file, err)
		}
	}

	t.Run("returns empty for claude (no Editor)", func(t *testing.T) {
		r := integrations["claude"]
		paths := editorPaths(r)
		if len(paths) != 0 {
			t.Errorf("expected no paths for claude, got %v", paths)
		}
	})

	t.Run("returns empty for codex (no Editor)", func(t *testing.T) {
		r := integrations["codex"]
		paths := editorPaths(r)
		if len(paths) != 0 {
			t.Errorf("expected no paths for codex, got %v", paths)
		}
	})

	t.Run("returns empty for droid when no config exists", func(t *testing.T) {
		r := integrations["droid"]
		paths := editorPaths(r)
		if len(paths) != 0 {
			t.Errorf("expected no paths, got %v", paths)
		}
	})

	t.Run("returns path for droid when config exists", func(t *testing.T) {
		home, err := os.UserHomeDir()
		if err != nil {
			t.Fatalf("user home dir: %v", err)
		}
		mustWrite(t, filepath.Join(home, ".factory"), "settings.json")

		r := integrations["droid"]
		paths := editorPaths(r)
		if len(paths) != 1 {
			t.Errorf("expected 1 path, got %d", len(paths))
		}
	})

	t.Run("returns paths for opencode when configs exist", func(t *testing.T) {
		home, err := os.UserHomeDir()
		if err != nil {
			t.Fatalf("user home dir: %v", err)
		}
		mustWrite(t, filepath.Join(home, ".config", "opencode"), "opencode.json")
		mustWrite(t, filepath.Join(home, ".local", "state", "opencode"), "model.json")

		r := integrations["opencode"]
		paths := editorPaths(r)
		if len(paths) != 2 {
			t.Errorf("expected 2 paths, got %d: %v", len(paths), paths)
		}
	})
}
func TestLoadIntegration_CorruptedJSON(t *testing.T) { func TestLoadIntegration_CorruptedJSON(t *testing.T) {
tmpDir := t.TempDir() tmpDir := t.TempDir()
setTestHome(t, tmpDir) setTestHome(t, tmpDir)
@@ -186,7 +251,7 @@ func TestLoadIntegration_CorruptedJSON(t *testing.T) {
os.MkdirAll(dir, 0o755) os.MkdirAll(dir, 0o755)
os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{corrupted json`), 0o644) os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{corrupted json`), 0o644)
_, err := LoadIntegration("test") _, err := loadIntegration("test")
if err == nil { if err == nil {
t.Error("expected error for nonexistent integration in corrupted file") t.Error("expected error for nonexistent integration in corrupted file")
} }
@@ -200,7 +265,7 @@ func TestSaveIntegration_NilModels(t *testing.T) {
t.Fatalf("saveIntegration with nil models failed: %v", err) t.Fatalf("saveIntegration with nil models failed: %v", err)
} }
config, err := LoadIntegration("test") config, err := loadIntegration("test")
if err != nil { if err != nil {
t.Fatalf("loadIntegration failed: %v", err) t.Fatalf("loadIntegration failed: %v", err)
} }
@@ -229,7 +294,7 @@ func TestLoadIntegration_NonexistentIntegration(t *testing.T) {
tmpDir := t.TempDir() tmpDir := t.TempDir()
setTestHome(t, tmpDir) setTestHome(t, tmpDir)
_, err := LoadIntegration("nonexistent") _, err := loadIntegration("nonexistent")
if err == nil { if err == nil {
t.Error("expected error for nonexistent integration, got nil") t.Error("expected error for nonexistent integration, got nil")
} }

View File

@@ -1,14 +1,16 @@
package launch package config
import ( import (
"context"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"os" "os"
"os/exec" "os/exec"
"path/filepath" "path/filepath"
"slices" "slices"
"github.com/ollama/ollama/cmd/internal/fileutil" "github.com/ollama/ollama/api"
"github.com/ollama/ollama/envconfig" "github.com/ollama/ollama/envconfig"
) )
@@ -45,6 +47,25 @@ func (d *Droid) Run(model string, args []string) error {
return fmt.Errorf("droid is not installed, install from https://docs.factory.ai/cli/getting-started/quickstart") return fmt.Errorf("droid is not installed, install from https://docs.factory.ai/cli/getting-started/quickstart")
} }
// Call Edit() to ensure config is up-to-date before launch
models := []string{model}
if config, err := loadIntegration("droid"); err == nil && len(config.Models) > 0 {
models = config.Models
}
var err error
models, err = resolveEditorModels("droid", models, func() ([]string, error) {
return selectModels(context.Background(), "droid", "")
})
if errors.Is(err, errCancelled) {
return nil
}
if err != nil {
return err
}
if err := d.Edit(models); err != nil {
return fmt.Errorf("setup failed: %w", err)
}
cmd := exec.Command("droid", args...) cmd := exec.Command("droid", args...)
cmd.Stdin = os.Stdin cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout cmd.Stdout = os.Stdout
@@ -90,16 +111,6 @@ func (d *Droid) Edit(models []string) error {
json.Unmarshal(data, &settings) // ignore error, zero values are fine json.Unmarshal(data, &settings) // ignore error, zero values are fine
} }
settingsMap = updateDroidSettings(settingsMap, settings, models)
data, err := json.MarshalIndent(settingsMap, "", " ")
if err != nil {
return err
}
return fileutil.WriteWithBackup(settingsPath, data)
}
func updateDroidSettings(settingsMap map[string]any, settings droidSettings, models []string) map[string]any {
// Keep only non-Ollama models from the raw map (preserves extra fields) // Keep only non-Ollama models from the raw map (preserves extra fields)
// Rebuild Ollama models // Rebuild Ollama models
var nonOllamaModels []any var nonOllamaModels []any
@@ -114,12 +125,13 @@ func updateDroidSettings(settingsMap map[string]any, settings droidSettings, mod
} }
// Build new Ollama model entries with sequential indices (0, 1, 2, ...) // Build new Ollama model entries with sequential indices (0, 1, 2, ...)
client, _ := api.ClientFromEnvironment()
var newModels []any var newModels []any
var defaultModelID string var defaultModelID string
for i, model := range models { for i, model := range models {
maxOutput := 64000 maxOutput := 64000
if isCloudModelName(model) { if isCloudModel(context.Background(), client, model) {
if l, ok := lookupCloudModelLimit(model); ok { if l, ok := lookupCloudModelLimit(model); ok {
maxOutput = l.Output maxOutput = l.Output
} }
@@ -155,7 +167,12 @@ func updateDroidSettings(settingsMap map[string]any, settings droidSettings, mod
} }
settingsMap["sessionDefaultSettings"] = sessionSettings settingsMap["sessionDefaultSettings"] = sessionSettings
return settingsMap
data, err := json.MarshalIndent(settingsMap, "", " ")
if err != nil {
return err
}
return writeWithBackup(settingsPath, data)
} }
func (d *Droid) Models() []string { func (d *Droid) Models() []string {

View File

@@ -1,4 +1,4 @@
package launch package config
import ( import (
"encoding/json" "encoding/json"
@@ -6,8 +6,6 @@ import (
"os" "os"
"path/filepath" "path/filepath"
"testing" "testing"
"github.com/ollama/ollama/cmd/internal/fileutil"
) )
func TestDroidIntegration(t *testing.T) { func TestDroidIntegration(t *testing.T) {
@@ -364,7 +362,7 @@ func TestDroidEdit_DuplicateModels(t *testing.T) {
t.Fatalf("Edit with duplicates failed: %v", err) t.Fatalf("Edit with duplicates failed: %v", err)
} }
settings, err := fileutil.ReadJSON(settingsPath) settings, err := readJSONFile(settingsPath)
if err != nil { if err != nil {
t.Fatalf("readJSONFile failed: %v", err) t.Fatalf("readJSONFile failed: %v", err)
} }
@@ -394,7 +392,7 @@ func TestDroidEdit_MalformedModelEntry(t *testing.T) {
} }
// Malformed entries (non-object) are dropped - only valid model objects are preserved // Malformed entries (non-object) are dropped - only valid model objects are preserved
settings, _ := fileutil.ReadJSON(settingsPath) settings, _ := readJSONFile(settingsPath)
customModels, _ := settings["customModels"].([]any) customModels, _ := settings["customModels"].([]any)
// Should have: 1 new Ollama model only (malformed entries dropped) // Should have: 1 new Ollama model only (malformed entries dropped)
@@ -421,7 +419,7 @@ func TestDroidEdit_WrongTypeSessionSettings(t *testing.T) {
} }
// Should create proper sessionDefaultSettings // Should create proper sessionDefaultSettings
settings, _ := fileutil.ReadJSON(settingsPath) settings, _ := readJSONFile(settingsPath)
session, ok := settings["sessionDefaultSettings"].(map[string]any) session, ok := settings["sessionDefaultSettings"].(map[string]any)
if !ok { if !ok {
t.Fatalf("sessionDefaultSettings should be map after setup, got %T", settings["sessionDefaultSettings"]) t.Fatalf("sessionDefaultSettings should be map after setup, got %T", settings["sessionDefaultSettings"])
@@ -1010,34 +1008,34 @@ func TestDroidEdit_ModelNamesWithSpecialCharacters(t *testing.T) {
} }
func TestDroidEdit_MissingCustomModelsKey(t *testing.T) { func TestDroidEdit_MissingCustomModelsKey(t *testing.T) {
d := &Droid{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
settingsDir := filepath.Join(tmpDir, ".factory")
settingsPath := filepath.Join(settingsDir, "settings.json")
os.MkdirAll(settingsDir, 0o755)
// No customModels key at all // No customModels key at all
original := `{ original := `{
"diffMode": "github", "diffMode": "github",
"sessionDefaultSettings": {"autonomyMode": "auto-high"} "sessionDefaultSettings": {"autonomyMode": "auto-high"}
}` }`
os.WriteFile(settingsPath, []byte(original), 0o644)
var settingsStruct droidSettings if err := d.Edit([]string{"model-a"}); err != nil {
t.Fatal(err)
}
data, _ := os.ReadFile(settingsPath)
var settings map[string]any var settings map[string]any
if err := json.Unmarshal([]byte(original), &settings); err != nil { json.Unmarshal(data, &settings)
t.Fatal(err)
}
if err := json.Unmarshal([]byte(original), &settingsStruct); err != nil {
t.Fatal(err)
}
settings = updateDroidSettings(settings, settingsStruct, []string{"model-a"})
// Original fields preserved // Original fields preserved
if settings["diffMode"] != "github" { if settings["diffMode"] != "github" {
t.Error("diffMode not preserved") t.Error("diffMode not preserved")
} }
session, ok := settings["sessionDefaultSettings"].(map[string]any)
if !ok {
t.Fatal("sessionDefaultSettings not preserved")
}
if session["autonomyMode"] != "auto-high" {
t.Error("sessionDefaultSettings.autonomyMode not preserved")
}
// customModels created // customModels created
models, ok := settings["customModels"].([]any) models, ok := settings["customModels"].([]any)
@@ -1278,18 +1276,26 @@ func TestDroidEdit_LocalModelDefaultMaxOutput(t *testing.T) {
func TestDroidEdit_CloudModelLimitsUsed(t *testing.T) { func TestDroidEdit_CloudModelLimitsUsed(t *testing.T) {
// Verify that every cloud model in cloudModelLimits has a valid output // Verify that every cloud model in cloudModelLimits has a valid output
// value that would be used for maxOutputTokens when the selected model uses // value that would be used for maxOutputTokens when isCloudModel returns true.
// the explicit :cloud source tag. // :cloud suffix stripping must also work since that's how users specify them.
for name, expected := range cloudModelLimits { for name, expected := range cloudModelLimits {
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
cloudName := name + ":cloud" l, ok := lookupCloudModelLimit(name)
l, ok := lookupCloudModelLimit(cloudName)
if !ok { if !ok {
t.Fatalf("lookupCloudModelLimit(%q) returned false", cloudName) t.Fatalf("lookupCloudModelLimit(%q) returned false", name)
} }
if l.Output != expected.Output { if l.Output != expected.Output {
t.Errorf("output = %d, want %d", l.Output, expected.Output) t.Errorf("output = %d, want %d", l.Output, expected.Output)
} }
// Also verify :cloud suffix lookup
cloudName := name + ":cloud"
l2, ok := lookupCloudModelLimit(cloudName)
if !ok {
t.Fatalf("lookupCloudModelLimit(%q) returned false", cloudName)
}
if l2.Output != expected.Output {
t.Errorf(":cloud output = %d, want %d", l2.Output, expected.Output)
}
}) })
} }
} }

View File

@@ -1,6 +1,4 @@
// Package fileutil provides small shared helpers for reading JSON files package config
// and writing config files with backup-on-overwrite semantics.
package fileutil
import ( import (
"bytes" "bytes"
@@ -11,8 +9,7 @@ import (
"time" "time"
) )
// ReadJSON reads a JSON object file into a generic map. func readJSONFile(path string) (map[string]any, error) {
func ReadJSON(path string) (map[string]any, error) {
data, err := os.ReadFile(path) data, err := os.ReadFile(path)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -36,13 +33,12 @@ func copyFile(src, dst string) error {
return os.WriteFile(dst, data, info.Mode().Perm()) return os.WriteFile(dst, data, info.Mode().Perm())
} }
// BackupDir returns the shared backup directory used before overwriting files. func backupDir() string {
func BackupDir() string {
return filepath.Join(os.TempDir(), "ollama-backups") return filepath.Join(os.TempDir(), "ollama-backups")
} }
func backupToTmp(srcPath string) (string, error) { func backupToTmp(srcPath string) (string, error) {
dir := BackupDir() dir := backupDir()
if err := os.MkdirAll(dir, 0o755); err != nil { if err := os.MkdirAll(dir, 0o755); err != nil {
return "", err return "", err
} }
@@ -54,8 +50,8 @@ func backupToTmp(srcPath string) (string, error) {
return backupPath, nil return backupPath, nil
} }
// WriteWithBackup writes data to path via temp file + rename, backing up any existing file first. // writeWithBackup writes data to path via temp file + rename, backing up any existing file first
func WriteWithBackup(path string, data []byte) error { func writeWithBackup(path string, data []byte) error {
var backupPath string var backupPath string
// backup must be created before any writes to the target file // backup must be created before any writes to the target file
if existingContent, err := os.ReadFile(path); err == nil { if existingContent, err := os.ReadFile(path); err == nil {

View File

@@ -1,4 +1,4 @@
package fileutil package config
import ( import (
"encoding/json" "encoding/json"
@@ -9,21 +9,6 @@ import (
"testing" "testing"
) )
func TestMain(m *testing.M) {
tmpRoot, err := os.MkdirTemp("", "fileutil-test-*")
if err != nil {
panic(err)
}
if err := os.Setenv("TMPDIR", tmpRoot); err != nil {
panic(err)
}
code := m.Run()
_ = os.RemoveAll(tmpRoot)
os.Exit(code)
}
func mustMarshal(t *testing.T, v any) []byte { func mustMarshal(t *testing.T, v any) []byte {
t.Helper() t.Helper()
data, err := json.MarshalIndent(v, "", " ") data, err := json.MarshalIndent(v, "", " ")
@@ -33,19 +18,14 @@ func mustMarshal(t *testing.T, v any) []byte {
return data return data
} }
func isolatedTempDir(t *testing.T) string {
t.Helper()
return t.TempDir()
}
func TestWriteWithBackup(t *testing.T) { func TestWriteWithBackup(t *testing.T) {
tmpDir := isolatedTempDir(t) tmpDir := t.TempDir()
t.Run("creates file", func(t *testing.T) { t.Run("creates file", func(t *testing.T) {
path := filepath.Join(tmpDir, "new.json") path := filepath.Join(tmpDir, "new.json")
data := mustMarshal(t, map[string]string{"key": "value"}) data := mustMarshal(t, map[string]string{"key": "value"})
if err := WriteWithBackup(path, data); err != nil { if err := writeWithBackup(path, data); err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -63,17 +43,17 @@ func TestWriteWithBackup(t *testing.T) {
} }
}) })
t.Run("creates backup in the temp backup directory", func(t *testing.T) { t.Run("creates backup in /tmp/ollama-backups", func(t *testing.T) {
path := filepath.Join(tmpDir, "backup.json") path := filepath.Join(tmpDir, "backup.json")
os.WriteFile(path, []byte(`{"original": true}`), 0o644) os.WriteFile(path, []byte(`{"original": true}`), 0o644)
data := mustMarshal(t, map[string]bool{"updated": true}) data := mustMarshal(t, map[string]bool{"updated": true})
if err := WriteWithBackup(path, data); err != nil { if err := writeWithBackup(path, data); err != nil {
t.Fatal(err) t.Fatal(err)
} }
entries, err := os.ReadDir(BackupDir()) entries, err := os.ReadDir(backupDir())
if err != nil { if err != nil {
t.Fatal("backup directory not created") t.Fatal("backup directory not created")
} }
@@ -83,7 +63,7 @@ func TestWriteWithBackup(t *testing.T) {
if filepath.Ext(entry.Name()) != ".json" { if filepath.Ext(entry.Name()) != ".json" {
name := entry.Name() name := entry.Name()
if len(name) > len("backup.json.") && name[:len("backup.json.")] == "backup.json." { if len(name) > len("backup.json.") && name[:len("backup.json.")] == "backup.json." {
backupPath := filepath.Join(BackupDir(), name) backupPath := filepath.Join(backupDir(), name)
backup, err := os.ReadFile(backupPath) backup, err := os.ReadFile(backupPath)
if err == nil { if err == nil {
var backupData map[string]bool var backupData map[string]bool
@@ -99,7 +79,7 @@ func TestWriteWithBackup(t *testing.T) {
} }
if !foundBackup { if !foundBackup {
t.Error("backup file not created in backup directory") t.Error("backup file not created in /tmp/ollama-backups")
} }
current, _ := os.ReadFile(path) current, _ := os.ReadFile(path)
@@ -114,11 +94,11 @@ func TestWriteWithBackup(t *testing.T) {
path := filepath.Join(tmpDir, "nobak.json") path := filepath.Join(tmpDir, "nobak.json")
data := mustMarshal(t, map[string]string{"new": "file"}) data := mustMarshal(t, map[string]string{"new": "file"})
if err := WriteWithBackup(path, data); err != nil { if err := writeWithBackup(path, data); err != nil {
t.Fatal(err) t.Fatal(err)
} }
entries, _ := os.ReadDir(BackupDir()) entries, _ := os.ReadDir(backupDir())
for _, entry := range entries { for _, entry := range entries {
if len(entry.Name()) > len("nobak.json.") && entry.Name()[:len("nobak.json.")] == "nobak.json." { if len(entry.Name()) > len("nobak.json.") && entry.Name()[:len("nobak.json.")] == "nobak.json." {
t.Error("backup should not exist for new file") t.Error("backup should not exist for new file")
@@ -131,11 +111,11 @@ func TestWriteWithBackup(t *testing.T) {
data := mustMarshal(t, map[string]string{"key": "value"}) data := mustMarshal(t, map[string]string{"key": "value"})
if err := WriteWithBackup(path, data); err != nil { if err := writeWithBackup(path, data); err != nil {
t.Fatal(err) t.Fatal(err)
} }
entries1, _ := os.ReadDir(BackupDir()) entries1, _ := os.ReadDir(backupDir())
countBefore := 0 countBefore := 0
for _, e := range entries1 { for _, e := range entries1 {
if len(e.Name()) > len("unchanged.json.") && e.Name()[:len("unchanged.json.")] == "unchanged.json." { if len(e.Name()) > len("unchanged.json.") && e.Name()[:len("unchanged.json.")] == "unchanged.json." {
@@ -143,11 +123,11 @@ func TestWriteWithBackup(t *testing.T) {
} }
} }
if err := WriteWithBackup(path, data); err != nil { if err := writeWithBackup(path, data); err != nil {
t.Fatal(err) t.Fatal(err)
} }
entries2, _ := os.ReadDir(BackupDir()) entries2, _ := os.ReadDir(backupDir())
countAfter := 0 countAfter := 0
for _, e := range entries2 { for _, e := range entries2 {
if len(e.Name()) > len("unchanged.json.") && e.Name()[:len("unchanged.json.")] == "unchanged.json." { if len(e.Name()) > len("unchanged.json.") && e.Name()[:len("unchanged.json.")] == "unchanged.json." {
@@ -165,11 +145,11 @@ func TestWriteWithBackup(t *testing.T) {
os.WriteFile(path, []byte(`{"v": 1}`), 0o644) os.WriteFile(path, []byte(`{"v": 1}`), 0o644)
data := mustMarshal(t, map[string]int{"v": 2}) data := mustMarshal(t, map[string]int{"v": 2})
if err := WriteWithBackup(path, data); err != nil { if err := writeWithBackup(path, data); err != nil {
t.Fatal(err) t.Fatal(err)
} }
entries, _ := os.ReadDir(BackupDir()) entries, _ := os.ReadDir(backupDir())
var found bool var found bool
for _, entry := range entries { for _, entry := range entries {
name := entry.Name() name := entry.Name()
@@ -181,7 +161,7 @@ func TestWriteWithBackup(t *testing.T) {
} }
} }
found = true found = true
os.Remove(filepath.Join(BackupDir(), name)) os.Remove(filepath.Join(backupDir(), name))
break break
} }
} }
@@ -200,7 +180,7 @@ func TestWriteWithBackup_FailsIfBackupFails(t *testing.T) {
t.Skip("permission tests unreliable on Windows") t.Skip("permission tests unreliable on Windows")
} }
tmpDir := isolatedTempDir(t) tmpDir := t.TempDir()
path := filepath.Join(tmpDir, "config.json") path := filepath.Join(tmpDir, "config.json")
// Create original file // Create original file
@@ -208,13 +188,13 @@ func TestWriteWithBackup_FailsIfBackupFails(t *testing.T) {
os.WriteFile(path, originalContent, 0o644) os.WriteFile(path, originalContent, 0o644)
// Make backup directory read-only to force backup failure // Make backup directory read-only to force backup failure
backupDir := BackupDir() backupDir := backupDir()
os.MkdirAll(backupDir, 0o755) os.MkdirAll(backupDir, 0o755)
os.Chmod(backupDir, 0o444) // Read-only os.Chmod(backupDir, 0o444) // Read-only
defer os.Chmod(backupDir, 0o755) defer os.Chmod(backupDir, 0o755)
newContent := []byte(`{"updated": true}`) newContent := []byte(`{"updated": true}`)
err := WriteWithBackup(path, newContent) err := writeWithBackup(path, newContent)
// Should fail because backup couldn't be created // Should fail because backup couldn't be created
if err == nil { if err == nil {
@@ -235,7 +215,7 @@ func TestWriteWithBackup_PermissionDenied(t *testing.T) {
t.Skip("permission tests unreliable on Windows") t.Skip("permission tests unreliable on Windows")
} }
tmpDir := isolatedTempDir(t) tmpDir := t.TempDir()
// Create a read-only directory // Create a read-only directory
readOnlyDir := filepath.Join(tmpDir, "readonly") readOnlyDir := filepath.Join(tmpDir, "readonly")
@@ -244,7 +224,7 @@ func TestWriteWithBackup_PermissionDenied(t *testing.T) {
defer os.Chmod(readOnlyDir, 0o755) defer os.Chmod(readOnlyDir, 0o755)
path := filepath.Join(readOnlyDir, "config.json") path := filepath.Join(readOnlyDir, "config.json")
err := WriteWithBackup(path, []byte(`{"test": true}`)) err := writeWithBackup(path, []byte(`{"test": true}`))
if err == nil { if err == nil {
t.Error("expected permission error, got nil") t.Error("expected permission error, got nil")
@@ -254,10 +234,10 @@ func TestWriteWithBackup_PermissionDenied(t *testing.T) {
// TestWriteWithBackup_DirectoryDoesNotExist verifies behavior when target directory doesn't exist. // TestWriteWithBackup_DirectoryDoesNotExist verifies behavior when target directory doesn't exist.
// writeWithBackup doesn't create directories - caller is responsible. // writeWithBackup doesn't create directories - caller is responsible.
func TestWriteWithBackup_DirectoryDoesNotExist(t *testing.T) { func TestWriteWithBackup_DirectoryDoesNotExist(t *testing.T) {
tmpDir := isolatedTempDir(t) tmpDir := t.TempDir()
path := filepath.Join(tmpDir, "nonexistent", "subdir", "config.json") path := filepath.Join(tmpDir, "nonexistent", "subdir", "config.json")
err := WriteWithBackup(path, []byte(`{"test": true}`)) err := writeWithBackup(path, []byte(`{"test": true}`))
// Should fail because directory doesn't exist // Should fail because directory doesn't exist
if err == nil { if err == nil {
@@ -272,7 +252,7 @@ func TestWriteWithBackup_SymlinkTarget(t *testing.T) {
t.Skip("symlink tests may require admin on Windows") t.Skip("symlink tests may require admin on Windows")
} }
tmpDir := isolatedTempDir(t) tmpDir := t.TempDir()
realFile := filepath.Join(tmpDir, "real.json") realFile := filepath.Join(tmpDir, "real.json")
symlink := filepath.Join(tmpDir, "link.json") symlink := filepath.Join(tmpDir, "link.json")
@@ -281,7 +261,7 @@ func TestWriteWithBackup_SymlinkTarget(t *testing.T) {
os.Symlink(realFile, symlink) os.Symlink(realFile, symlink)
// Write through symlink // Write through symlink
err := WriteWithBackup(symlink, []byte(`{"v": 2}`)) err := writeWithBackup(symlink, []byte(`{"v": 2}`))
if err != nil { if err != nil {
t.Fatalf("writeWithBackup through symlink failed: %v", err) t.Fatalf("writeWithBackup through symlink failed: %v", err)
} }
@@ -296,7 +276,7 @@ func TestWriteWithBackup_SymlinkTarget(t *testing.T) {
// TestBackupToTmp_SpecialCharsInFilename verifies backup works with special characters. // TestBackupToTmp_SpecialCharsInFilename verifies backup works with special characters.
// User may have config files with unusual names. // User may have config files with unusual names.
func TestBackupToTmp_SpecialCharsInFilename(t *testing.T) { func TestBackupToTmp_SpecialCharsInFilename(t *testing.T) {
tmpDir := isolatedTempDir(t) tmpDir := t.TempDir()
// File with spaces and special chars // File with spaces and special chars
path := filepath.Join(tmpDir, "my config (backup).json") path := filepath.Join(tmpDir, "my config (backup).json")
@@ -325,7 +305,7 @@ func TestCopyFile_PreservesPermissions(t *testing.T) {
t.Skip("permission preservation tests unreliable on Windows") t.Skip("permission preservation tests unreliable on Windows")
} }
tmpDir := isolatedTempDir(t) tmpDir := t.TempDir()
src := filepath.Join(tmpDir, "src.json") src := filepath.Join(tmpDir, "src.json")
dst := filepath.Join(tmpDir, "dst.json") dst := filepath.Join(tmpDir, "dst.json")
@@ -347,7 +327,7 @@ func TestCopyFile_PreservesPermissions(t *testing.T) {
// TestCopyFile_SourceNotFound verifies clear error when source doesn't exist. // TestCopyFile_SourceNotFound verifies clear error when source doesn't exist.
func TestCopyFile_SourceNotFound(t *testing.T) { func TestCopyFile_SourceNotFound(t *testing.T) {
tmpDir := isolatedTempDir(t) tmpDir := t.TempDir()
src := filepath.Join(tmpDir, "nonexistent.json") src := filepath.Join(tmpDir, "nonexistent.json")
dst := filepath.Join(tmpDir, "dst.json") dst := filepath.Join(tmpDir, "dst.json")
@@ -359,11 +339,11 @@ func TestCopyFile_SourceNotFound(t *testing.T) {
// TestWriteWithBackup_TargetIsDirectory verifies error when path points to a directory. // TestWriteWithBackup_TargetIsDirectory verifies error when path points to a directory.
func TestWriteWithBackup_TargetIsDirectory(t *testing.T) { func TestWriteWithBackup_TargetIsDirectory(t *testing.T) {
tmpDir := isolatedTempDir(t) tmpDir := t.TempDir()
dirPath := filepath.Join(tmpDir, "actualdir") dirPath := filepath.Join(tmpDir, "actualdir")
os.MkdirAll(dirPath, 0o755) os.MkdirAll(dirPath, 0o755)
err := WriteWithBackup(dirPath, []byte(`{"test": true}`)) err := writeWithBackup(dirPath, []byte(`{"test": true}`))
if err == nil { if err == nil {
t.Error("expected error when target is a directory, got nil") t.Error("expected error when target is a directory, got nil")
} }
@@ -371,10 +351,10 @@ func TestWriteWithBackup_TargetIsDirectory(t *testing.T) {
// TestWriteWithBackup_EmptyData verifies writing zero bytes works correctly. // TestWriteWithBackup_EmptyData verifies writing zero bytes works correctly.
func TestWriteWithBackup_EmptyData(t *testing.T) { func TestWriteWithBackup_EmptyData(t *testing.T) {
tmpDir := isolatedTempDir(t) tmpDir := t.TempDir()
path := filepath.Join(tmpDir, "empty.json") path := filepath.Join(tmpDir, "empty.json")
err := WriteWithBackup(path, []byte{}) err := writeWithBackup(path, []byte{})
if err != nil { if err != nil {
t.Fatalf("writeWithBackup with empty data failed: %v", err) t.Fatalf("writeWithBackup with empty data failed: %v", err)
} }
@@ -395,7 +375,7 @@ func TestWriteWithBackup_FileUnreadableButDirWritable(t *testing.T) {
t.Skip("permission tests unreliable on Windows") t.Skip("permission tests unreliable on Windows")
} }
tmpDir := isolatedTempDir(t) tmpDir := t.TempDir()
path := filepath.Join(tmpDir, "unreadable.json") path := filepath.Join(tmpDir, "unreadable.json")
// Create file and make it unreadable // Create file and make it unreadable
@@ -404,7 +384,7 @@ func TestWriteWithBackup_FileUnreadableButDirWritable(t *testing.T) {
defer os.Chmod(path, 0o644) defer os.Chmod(path, 0o644)
// Should fail because we can't read the file to compare/backup // Should fail because we can't read the file to compare/backup
err := WriteWithBackup(path, []byte(`{"updated": true}`)) err := writeWithBackup(path, []byte(`{"updated": true}`))
if err == nil { if err == nil {
t.Error("expected error when file is unreadable, got nil") t.Error("expected error when file is unreadable, got nil")
} }
@@ -413,7 +393,7 @@ func TestWriteWithBackup_FileUnreadableButDirWritable(t *testing.T) {
// TestWriteWithBackup_RapidSuccessiveWrites verifies backup works with multiple writes // TestWriteWithBackup_RapidSuccessiveWrites verifies backup works with multiple writes
// within the same second (timestamp collision scenario). // within the same second (timestamp collision scenario).
func TestWriteWithBackup_RapidSuccessiveWrites(t *testing.T) { func TestWriteWithBackup_RapidSuccessiveWrites(t *testing.T) {
tmpDir := isolatedTempDir(t) tmpDir := t.TempDir()
path := filepath.Join(tmpDir, "rapid.json") path := filepath.Join(tmpDir, "rapid.json")
// Create initial file // Create initial file
@@ -422,7 +402,7 @@ func TestWriteWithBackup_RapidSuccessiveWrites(t *testing.T) {
// Rapid successive writes // Rapid successive writes
for i := 1; i <= 3; i++ { for i := 1; i <= 3; i++ {
data := []byte(fmt.Sprintf(`{"v": %d}`, i)) data := []byte(fmt.Sprintf(`{"v": %d}`, i))
if err := WriteWithBackup(path, data); err != nil { if err := writeWithBackup(path, data); err != nil {
t.Fatalf("write %d failed: %v", i, err) t.Fatalf("write %d failed: %v", i, err)
} }
} }
@@ -434,7 +414,7 @@ func TestWriteWithBackup_RapidSuccessiveWrites(t *testing.T) {
} }
// Verify at least one backup exists // Verify at least one backup exists
entries, _ := os.ReadDir(BackupDir()) entries, _ := os.ReadDir(backupDir())
var backupCount int var backupCount int
for _, e := range entries { for _, e := range entries {
if len(e.Name()) > len("rapid.json.") && e.Name()[:len("rapid.json.")] == "rapid.json." { if len(e.Name()) > len("rapid.json.") && e.Name()[:len("rapid.json.")] == "rapid.json." {
@@ -452,9 +432,8 @@ func TestWriteWithBackup_BackupDirIsFile(t *testing.T) {
t.Skip("test modifies system temp directory") t.Skip("test modifies system temp directory")
} }
tmpDir := isolatedTempDir(t)
// Create a file at the backup directory path // Create a file at the backup directory path
backupPath := BackupDir() backupPath := backupDir()
// Clean up any existing directory first // Clean up any existing directory first
os.RemoveAll(backupPath) os.RemoveAll(backupPath)
// Create a file instead of directory // Create a file instead of directory
@@ -464,10 +443,11 @@ func TestWriteWithBackup_BackupDirIsFile(t *testing.T) {
os.MkdirAll(backupPath, 0o755) os.MkdirAll(backupPath, 0o755)
}() }()
tmpDir := t.TempDir()
path := filepath.Join(tmpDir, "test.json") path := filepath.Join(tmpDir, "test.json")
os.WriteFile(path, []byte(`{"original": true}`), 0o644) os.WriteFile(path, []byte(`{"original": true}`), 0o644)
err := WriteWithBackup(path, []byte(`{"updated": true}`)) err := writeWithBackup(path, []byte(`{"updated": true}`))
if err == nil { if err == nil {
t.Error("expected error when backup dir is a file, got nil") t.Error("expected error when backup dir is a file, got nil")
} }
@@ -479,7 +459,7 @@ func TestWriteWithBackup_NoOrphanTempFiles(t *testing.T) {
t.Skip("permission tests unreliable on Windows") t.Skip("permission tests unreliable on Windows")
} }
tmpDir := isolatedTempDir(t) tmpDir := t.TempDir()
// Count existing temp files // Count existing temp files
countTempFiles := func() int { countTempFiles := func() int {
@@ -513,7 +493,7 @@ func TestWriteWithBackup_NoOrphanTempFiles(t *testing.T) {
badPath := filepath.Join(tmpDir, "isdir") badPath := filepath.Join(tmpDir, "isdir")
os.MkdirAll(badPath, 0o755) os.MkdirAll(badPath, 0o755)
_ = WriteWithBackup(badPath, []byte(`{"test": true}`)) _ = writeWithBackup(badPath, []byte(`{"test": true}`))
after := countTempFiles() after := countTempFiles()
if after > before { if after > before {

1447
cmd/config/integrations.go Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -1,4 +1,4 @@
package launch package config
import ( import (
"context" "context"
@@ -15,7 +15,6 @@ import (
"time" "time"
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
"github.com/ollama/ollama/cmd/internal/fileutil"
"github.com/ollama/ollama/envconfig" "github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/types/model" "github.com/ollama/ollama/types/model"
) )
@@ -25,9 +24,6 @@ const defaultGatewayPort = 18789
// Bound model capability probing so launch/config cannot hang on slow/unreachable API calls. // Bound model capability probing so launch/config cannot hang on slow/unreachable API calls.
var openclawModelShowTimeout = 5 * time.Second var openclawModelShowTimeout = 5 * time.Second
// openclawFreshInstall is set to true when ensureOpenclawInstalled performs an install
var openclawFreshInstall bool
type Openclaw struct{} type Openclaw struct{}
func (c *Openclaw) String() string { return "OpenClaw" } func (c *Openclaw) String() string { return "OpenClaw" }
@@ -38,7 +34,10 @@ func (c *Openclaw) Run(model string, args []string) error {
return err return err
} }
firstLaunch := !c.onboarded() firstLaunch := true
if integrationConfig, err := loadIntegration("openclaw"); err == nil {
firstLaunch = !integrationConfig.Onboarded
}
if firstLaunch { if firstLaunch {
fmt.Fprintf(os.Stderr, "\n%sSecurity%s\n\n", ansiBold, ansiReset) fmt.Fprintf(os.Stderr, "\n%sSecurity%s\n\n", ansiBold, ansiReset)
@@ -46,46 +45,28 @@ func (c *Openclaw) Run(model string, args []string) error {
fmt.Fprintf(os.Stderr, " A bad prompt can trick it into doing unsafe things.\n\n") fmt.Fprintf(os.Stderr, " A bad prompt can trick it into doing unsafe things.\n\n")
fmt.Fprintf(os.Stderr, "%s Learn more: https://docs.openclaw.ai/gateway/security%s\n\n", ansiGray, ansiReset) fmt.Fprintf(os.Stderr, "%s Learn more: https://docs.openclaw.ai/gateway/security%s\n\n", ansiGray, ansiReset)
ok, err := ConfirmPrompt("I understand the risks. Continue?") ok, err := confirmPrompt("I understand the risks. Continue?")
if err != nil { if err != nil {
return err return err
} }
if !ok { if !ok {
return nil return nil
} }
}
// Ensure the latest version is installed before onboarding so we get if !c.onboarded() {
// the newest wizard flags (e.g. --auth-choice ollama).
if !openclawFreshInstall {
update := exec.Command(bin, "update")
update.Stdout = os.Stdout
update.Stderr = os.Stderr
_ = update.Run() // best-effort; continue even if update fails
}
fmt.Fprintf(os.Stderr, "\n%sSetting up OpenClaw with Ollama...%s\n", ansiGreen, ansiReset) fmt.Fprintf(os.Stderr, "\n%sSetting up OpenClaw with Ollama...%s\n", ansiGreen, ansiReset)
fmt.Fprintf(os.Stderr, "%s Model: %s%s\n\n", ansiGray, model, ansiReset) fmt.Fprintf(os.Stderr, "%s Model: %s%s\n\n", ansiGray, model, ansiReset)
onboardArgs := []string{ cmd := exec.Command(bin, "onboard",
"onboard",
"--non-interactive", "--non-interactive",
"--accept-risk", "--accept-risk",
"--auth-choice", "ollama", "--auth-choice", "skip",
"--custom-base-url", envconfig.Host().String(), "--gateway-token", "ollama",
"--custom-model-id", model, "--install-daemon",
"--skip-channels", "--skip-channels",
"--skip-skills", "--skip-skills",
} )
if canInstallDaemon() {
onboardArgs = append(onboardArgs, "--install-daemon")
} else {
// When we can't install a daemon (e.g. no systemd, sudo dropped
// XDG_RUNTIME_DIR, or container environment), skip the gateway
// health check so non-interactive onboarding completes. The
// gateway is started as a foreground child process after onboarding.
onboardArgs = append(onboardArgs, "--skip-health")
}
cmd := exec.Command(bin, onboardArgs...)
cmd.Stdin = os.Stdin cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr cmd.Stderr = os.Stderr
@@ -94,9 +75,25 @@ func (c *Openclaw) Run(model string, args []string) error {
} }
patchDeviceScopes() patchDeviceScopes()
// Onboarding overwrites openclaw.json, so re-apply the model config
// that Edit() wrote before Run() was called.
if err := c.Edit([]string{model}); err != nil {
fmt.Fprintf(os.Stderr, "%s Warning: could not re-apply model config: %v%s\n", ansiYellow, err, ansiReset)
}
} }
configureOllamaWebSearch() if strings.HasSuffix(model, ":cloud") || strings.HasSuffix(model, "-cloud") {
if ensureWebSearchPlugin() {
registerWebSearchPlugin()
}
}
if firstLaunch {
fmt.Fprintf(os.Stderr, "\n%sPreparing your assistant — this may take a moment...%s\n\n", ansiGray, ansiReset)
} else {
fmt.Fprintf(os.Stderr, "\n%sStarting your assistant — this may take a moment...%s\n\n", ansiGray, ansiReset)
}
// When extra args are passed through, run exactly what the user asked for // When extra args are passed through, run exactly what the user asked for
// after setup and skip the built-in gateway+TUI convenience flow. // after setup and skip the built-in gateway+TUI convenience flow.
@@ -109,23 +106,19 @@ func (c *Openclaw) Run(model string, args []string) error {
if err := cmd.Run(); err != nil { if err := cmd.Run(); err != nil {
return windowsHint(err) return windowsHint(err)
} }
if firstLaunch {
if err := integrationOnboarded("openclaw"); err != nil {
return fmt.Errorf("failed to save onboarding state: %w", err)
}
}
return nil return nil
} }
if err := c.runChannelSetupPreflight(bin); err != nil {
return err
}
// Keep local pairing scopes up to date before the gateway lifecycle
// (restart/start) regardless of channel preflight branch behavior.
patchDeviceScopes()
fmt.Fprintf(os.Stderr, "\n%sStarting your assistant — this may take a moment...%s\n\n", ansiGray, ansiReset)
token, port := c.gatewayInfo() token, port := c.gatewayInfo()
addr := fmt.Sprintf("localhost:%d", port) addr := fmt.Sprintf("localhost:%d", port)
// If the gateway is already running (e.g. via the daemon), restart it // If the gateway is already running (e.g. via the daemon), restart it
// so it picks up any config changes (model, provider, etc.). // so it picks up any config changes from Edit() above (model, provider, etc.).
if portOpen(addr) { if portOpen(addr) {
restart := exec.Command(bin, "daemon", "restart") restart := exec.Command(bin, "daemon", "restart")
restart.Env = openclawEnv() restart.Env = openclawEnv()
@@ -172,97 +165,14 @@ func (c *Openclaw) Run(model string, args []string) error {
return windowsHint(err) return windowsHint(err)
} }
if firstLaunch {
if err := integrationOnboarded("openclaw"); err != nil {
return fmt.Errorf("failed to save onboarding state: %w", err)
}
}
return nil return nil
} }
// runChannelSetupPreflight prompts users to connect a messaging channel before
// starting the built-in gateway+TUI flow. In interactive sessions, it loops
// until a channel is configured, unless the user chooses "Set up later".
func (c *Openclaw) runChannelSetupPreflight(bin string) error {
if !isInteractiveSession() {
return nil
}
// --yes is headless; channel setup spawns an interactive picker we can't
// auto-answer, so skip it. Users can run `openclaw channels add` later.
if currentLaunchConfirmPolicy.yes {
return nil
}
for {
if c.channelsConfigured() {
return nil
}
fmt.Fprintf(os.Stderr, "\nYour assistant can message you on WhatsApp, Telegram, Discord, and more.\n\n")
ok, err := ConfirmPromptWithOptions("Connect a channel (messaging app) now?", ConfirmOptions{
YesLabel: "Yes",
NoLabel: "Set up later",
})
if err != nil {
return err
}
if !ok {
return nil
}
cmd := exec.Command(bin, "channels", "add")
cmd.Env = openclawEnv()
cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
if err := cmd.Run(); err != nil {
return windowsHint(fmt.Errorf("openclaw channel setup failed: %w\n\nTry running: %s channels add", err, bin))
}
}
}
// channelsConfigured reports whether local OpenClaw config contains at least
// one meaningfully configured channel entry.
func (c *Openclaw) channelsConfigured() bool {
home, err := os.UserHomeDir()
if err != nil {
return false
}
for _, path := range []string{
filepath.Join(home, ".openclaw", "openclaw.json"),
filepath.Join(home, ".clawdbot", "clawdbot.json"),
} {
data, err := os.ReadFile(path)
if err != nil {
continue
}
var cfg map[string]any
if json.Unmarshal(data, &cfg) != nil {
continue
}
channels, _ := cfg["channels"].(map[string]any)
if channels == nil {
return false
}
for key, value := range channels {
if key == "defaults" || key == "modelByChannel" {
continue
}
entry, ok := value.(map[string]any)
if !ok {
continue
}
for entryKey := range entry {
if entryKey != "enabled" {
return true
}
}
}
return false
}
return false
}
// gatewayInfo reads the gateway auth token and port from the OpenClaw config. // gatewayInfo reads the gateway auth token and port from the OpenClaw config.
func (c *Openclaw) gatewayInfo() (token string, port int) { func (c *Openclaw) gatewayInfo() (token string, port int) {
port = defaultGatewayPort port = defaultGatewayPort
@@ -309,9 +219,12 @@ func printOpenclawReady(bin, token string, port int, firstLaunch bool) {
if firstLaunch { if firstLaunch {
fmt.Fprintf(os.Stderr, "%s Quick start:%s\n", ansiBold, ansiReset) fmt.Fprintf(os.Stderr, "%s Quick start:%s\n", ansiBold, ansiReset)
fmt.Fprintf(os.Stderr, "%s /help see all commands%s\n", ansiGray, ansiReset) fmt.Fprintf(os.Stderr, "%s /help see all commands%s\n", ansiGray, ansiReset)
fmt.Fprintf(os.Stderr, "%s %s configure --section channels connect WhatsApp, Telegram, etc.%s\n", ansiGray, bin, ansiReset)
fmt.Fprintf(os.Stderr, "%s %s skills browse and install skills%s\n\n", ansiGray, bin, ansiReset) fmt.Fprintf(os.Stderr, "%s %s skills browse and install skills%s\n\n", ansiGray, bin, ansiReset)
fmt.Fprintf(os.Stderr, "%s The OpenClaw gateway is running in the background.%s\n", ansiYellow, ansiReset) fmt.Fprintf(os.Stderr, "%s The OpenClaw gateway is running in the background.%s\n", ansiYellow, ansiReset)
fmt.Fprintf(os.Stderr, "%s Stop it with: %s gateway stop%s\n\n", ansiYellow, bin, ansiReset) fmt.Fprintf(os.Stderr, "%s Stop it with: %s gateway stop%s\n\n", ansiYellow, bin, ansiReset)
} else {
fmt.Fprintf(os.Stderr, "%sTip: connect WhatsApp, Telegram, and more with: %s configure --section channels%s\n", ansiGray, bin, ansiReset)
} }
} }
@@ -400,10 +313,9 @@ func (c *Openclaw) onboarded() bool {
return lastRunAt != "" return lastRunAt != ""
} }
// patchDeviceScopes upgrades the local CLI device's paired operator scopes so // patchDeviceScopes upgrades the local CLI device's paired scopes to include
// newer gateway auth baselines (approvedScopes) allow launch+TUI reconnects // operator.admin. Only patches the local device, not remote ones.
// without forcing an interactive re-pair. Only patches the local device, // Best-effort: silently returns on any error.
// not remote ones. Best-effort: silently returns on any error.
func patchDeviceScopes() { func patchDeviceScopes() {
home, err := os.UserHomeDir() home, err := os.UserHomeDir()
if err != nil { if err != nil {
@@ -439,15 +351,9 @@ func patchDeviceScopes() {
} }
changed := patchScopes(dev, "scopes", required) changed := patchScopes(dev, "scopes", required)
if patchScopes(dev, "approvedScopes", required) {
changed = true
}
if tokens, ok := dev["tokens"].(map[string]any); ok { if tokens, ok := dev["tokens"].(map[string]any); ok {
for role, tok := range tokens { for _, tok := range tokens {
if tokenMap, ok := tok.(map[string]any); ok { if tokenMap, ok := tok.(map[string]any); ok {
if !isOperatorToken(role, tokenMap) {
continue
}
if patchScopes(tokenMap, "scopes", required) { if patchScopes(tokenMap, "scopes", required) {
changed = true changed = true
} }
@@ -503,33 +409,6 @@ func patchScopes(obj map[string]any, key string, required []string) bool {
return added return added
} }
func isOperatorToken(tokenRole string, token map[string]any) bool {
if strings.EqualFold(strings.TrimSpace(tokenRole), "operator") {
return true
}
role, _ := token["role"].(string)
return strings.EqualFold(strings.TrimSpace(role), "operator")
}
// canInstallDaemon reports whether the openclaw daemon can be installed as a
// background service. Returns false on Linux when systemd is absent (e.g.
// containers) so that --install-daemon is omitted and the gateway is started
// as a foreground child process instead. Returns true in all other cases.
func canInstallDaemon() bool {
if runtime.GOOS != "linux" {
return true
}
// /run/systemd/system exists as a directory when systemd is the init system.
// This is absent in most containers.
fi, err := os.Stat("/run/systemd/system")
if err != nil || !fi.IsDir() {
return false
}
// Even when systemd is the init system, user services require a user
// manager instance. XDG_RUNTIME_DIR being set is a prerequisite.
return os.Getenv("XDG_RUNTIME_DIR") != ""
}
func ensureOpenclawInstalled() (string, error) { func ensureOpenclawInstalled() (string, error) {
if _, err := exec.LookPath("openclaw"); err == nil { if _, err := exec.LookPath("openclaw"); err == nil {
return "openclaw", nil return "openclaw", nil
@@ -538,20 +417,16 @@ func ensureOpenclawInstalled() (string, error) {
return "clawdbot", nil return "clawdbot", nil
} }
_, npmErr := exec.LookPath("npm") if _, err := exec.LookPath("npm"); err != nil {
_, gitErr := exec.LookPath("git") return "", fmt.Errorf("openclaw is not installed and npm was not found\n\n" +
if npmErr != nil || gitErr != nil { "Install Node.js first:\n" +
var missing []string " https://nodejs.org/\n\n" +
if npmErr != nil { "Then rerun:\n" +
missing = append(missing, "npm (Node.js): https://nodejs.org/") " ollama launch\n" +
} "and select OpenClaw")
if gitErr != nil {
missing = append(missing, "git: https://git-scm.com/")
}
return "", fmt.Errorf("OpenClaw is not installed and required dependencies are missing\n\nInstall the following first:\n %s\n\nThen re-run:\n ollama launch openclaw", strings.Join(missing, "\n "))
} }
ok, err := ConfirmPrompt("OpenClaw is not installed. Install with npm?") ok, err := confirmPrompt("OpenClaw is not installed. Install with npm?")
if err != nil { if err != nil {
return "", err return "", err
} }
@@ -573,7 +448,6 @@ func ensureOpenclawInstalled() (string, error) {
} }
fmt.Fprintf(os.Stderr, "%sOpenClaw installed successfully%s\n\n", ansiGreen, ansiReset) fmt.Fprintf(os.Stderr, "%sOpenClaw installed successfully%s\n\n", ansiGreen, ansiReset)
openclawFreshInstall = true
return "openclaw", nil return "openclaw", nil
} }
@@ -628,7 +502,7 @@ func (c *Openclaw) Edit(models []string) error {
ollama = make(map[string]any) ollama = make(map[string]any)
} }
ollama["baseUrl"] = envconfig.Host().String() ollama["baseUrl"] = envconfig.Host().String() + "/v1"
// needed to register provider // needed to register provider
ollama["apiKey"] = "ollama-local" ollama["apiKey"] = "ollama-local"
ollama["api"] = "ollama" ollama["api"] = "ollama"
@@ -687,7 +561,7 @@ func (c *Openclaw) Edit(models []string) error {
if err != nil { if err != nil {
return err return err
} }
if err := fileutil.WriteWithBackup(configPath, data); err != nil { if err := writeWithBackup(configPath, data); err != nil {
return err return err
} }
@@ -718,8 +592,6 @@ func clearSessionModelOverride(primary string) {
if override, _ := sess["modelOverride"].(string); override != "" && override != primary { if override, _ := sess["modelOverride"].(string); override != "" && override != primary {
delete(sess, "modelOverride") delete(sess, "modelOverride")
delete(sess, "providerOverride") delete(sess, "providerOverride")
}
if model, _ := sess["model"].(string); model != "" && model != primary {
sess["model"] = primary sess["model"] = primary
changed = true changed = true
} }
@@ -734,13 +606,57 @@ func clearSessionModelOverride(primary string) {
_ = os.WriteFile(path, out, 0o600) _ = os.WriteFile(path, out, 0o600)
} }
// configureOllamaWebSearch keeps launch-managed OpenClaw installs on the const webSearchNpmPackage = "@ollama/openclaw-web-search"
// bundled Ollama web_search provider. Older launch builds installed an
// external openclaw-web-search plugin that added custom ollama_web_search and // ensureWebSearchPlugin installs the openclaw-web-search extension into the
// ollama_web_fetch tools. Current OpenClaw versions ship Ollama web_search as // user-level extensions directory (~/.openclaw/extensions/) if it isn't already
// the bundled "ollama" plugin instead, so we migrate stale config and ensure // present. Returns true if the extension is available.
// fresh installs select the bundled provider. func ensureWebSearchPlugin() bool {
func configureOllamaWebSearch() { home, err := os.UserHomeDir()
if err != nil {
return false
}
pluginDir := filepath.Join(home, ".openclaw", "extensions", "openclaw-web-search")
if _, err := os.Stat(filepath.Join(pluginDir, "index.ts")); err == nil {
return true // already installed
}
npmBin, err := exec.LookPath("npm")
if err != nil {
return false
}
if err := os.MkdirAll(pluginDir, 0o755); err != nil {
return false
}
// Download the tarball via `npm pack`, extract it flat into the plugin dir.
pack := exec.Command(npmBin, "pack", webSearchNpmPackage, "--pack-destination", pluginDir)
out, err := pack.Output()
if err != nil {
fmt.Fprintf(os.Stderr, "%s Warning: could not download web search plugin: %v%s\n", ansiYellow, err, ansiReset)
return false
}
tgzName := strings.TrimSpace(string(out))
tgzPath := filepath.Join(pluginDir, tgzName)
defer os.Remove(tgzPath)
tar := exec.Command("tar", "xzf", tgzPath, "--strip-components=1", "-C", pluginDir)
if err := tar.Run(); err != nil {
fmt.Fprintf(os.Stderr, "%s Warning: could not extract web search plugin: %v%s\n", ansiYellow, err, ansiReset)
return false
}
fmt.Fprintf(os.Stderr, "%s ✓ Installed web search plugin%s\n", ansiGreen, ansiReset)
return true
}
// registerWebSearchPlugin adds plugins.entries.openclaw-web-search to the OpenClaw
// config so the gateway activates it on next start. Best-effort; silently returns
// on any error.
func registerWebSearchPlugin() {
home, err := os.UserHomeDir() home, err := os.UserHomeDir()
if err != nil { if err != nil {
return return
@@ -755,8 +671,6 @@ func configureOllamaWebSearch() {
return return
} }
stalePluginConfigured := false
plugins, _ := config["plugins"].(map[string]any) plugins, _ := config["plugins"].(map[string]any)
if plugins == nil { if plugins == nil {
plugins = make(map[string]any) plugins = make(map[string]any)
@@ -765,6 +679,14 @@ func configureOllamaWebSearch() {
if entries == nil { if entries == nil {
entries = make(map[string]any) entries = make(map[string]any)
} }
if _, ok := entries["openclaw-web-search"]; ok {
return // already registered
}
entries["openclaw-web-search"] = map[string]any{"enabled": true}
plugins["entries"] = entries
config["plugins"] = plugins
// Disable the built-in web search since our plugin replaces it.
tools, _ := config["tools"].(map[string]any) tools, _ := config["tools"].(map[string]any)
if tools == nil { if tools == nil {
tools = make(map[string]any) tools = make(map[string]any)
@@ -773,92 +695,8 @@ func configureOllamaWebSearch() {
if web == nil { if web == nil {
web = make(map[string]any) web = make(map[string]any)
} }
search, _ := web["search"].(map[string]any) web["search"] = map[string]any{"enabled": false}
if search == nil {
search = make(map[string]any)
}
fetch, _ := web["fetch"].(map[string]any)
if fetch == nil {
fetch = make(map[string]any)
}
alsoAllow, _ := tools["alsoAllow"].([]any)
var filteredAlsoAllow []any
for _, v := range alsoAllow {
s, ok := v.(string)
if !ok {
filteredAlsoAllow = append(filteredAlsoAllow, v)
continue
}
if s == "ollama_web_search" || s == "ollama_web_fetch" {
stalePluginConfigured = true
continue
}
filteredAlsoAllow = append(filteredAlsoAllow, v)
}
if len(filteredAlsoAllow) > 0 {
tools["alsoAllow"] = filteredAlsoAllow
} else {
delete(tools, "alsoAllow")
}
if _, ok := entries["openclaw-web-search"]; ok {
delete(entries, "openclaw-web-search")
stalePluginConfigured = true
}
ollamaEntry, _ := entries["ollama"].(map[string]any)
if ollamaEntry == nil {
ollamaEntry = make(map[string]any)
}
ollamaEntry["enabled"] = true
entries["ollama"] = ollamaEntry
plugins["entries"] = entries
if allow, ok := plugins["allow"].([]any); ok {
var nextAllow []any
hasOllama := false
for _, v := range allow {
s, ok := v.(string)
if ok && s == "openclaw-web-search" {
stalePluginConfigured = true
continue
}
if ok && s == "ollama" {
hasOllama = true
}
nextAllow = append(nextAllow, v)
}
if !hasOllama {
nextAllow = append(nextAllow, "ollama")
}
plugins["allow"] = nextAllow
}
if installs, ok := plugins["installs"].(map[string]any); ok {
if _, exists := installs["openclaw-web-search"]; exists {
delete(installs, "openclaw-web-search")
stalePluginConfigured = true
}
if len(installs) > 0 {
plugins["installs"] = installs
} else {
delete(plugins, "installs")
}
}
if stalePluginConfigured || search["provider"] == nil {
search["provider"] = "ollama"
}
if stalePluginConfigured {
fetch["enabled"] = true
}
search["enabled"] = true
web["search"] = search
if len(fetch) > 0 {
web["fetch"] = fetch
}
tools["web"] = web tools["web"] = web
config["plugins"] = plugins
config["tools"] = tools config["tools"] = tools
out, err := json.MarshalIndent(config, "", " ") out, err := json.MarshalIndent(config, "", " ")
@@ -938,9 +776,9 @@ func (c *Openclaw) Models() []string {
return nil return nil
} }
config, err := fileutil.ReadJSON(filepath.Join(home, ".openclaw", "openclaw.json")) config, err := readJSONFile(filepath.Join(home, ".openclaw", "openclaw.json"))
if err != nil { if err != nil {
config, err = fileutil.ReadJSON(filepath.Join(home, ".clawdbot", "clawdbot.json")) config, err = readJSONFile(filepath.Join(home, ".clawdbot", "clawdbot.json"))
if err != nil { if err != nil {
return nil return nil
} }

File diff suppressed because it is too large Load Diff

279
cmd/config/opencode.go Normal file
View File

@@ -0,0 +1,279 @@
package config
import (
"context"
"encoding/json"
"errors"
"fmt"
"maps"
"os"
"os/exec"
"path/filepath"
"slices"
"strings"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/envconfig"
)
// OpenCode implements Runner and Editor for OpenCode integration.
// It is stateless; the zero value is ready to use.
type OpenCode struct{}
// cloudModelLimit holds context and output token limits for a cloud model.
type cloudModelLimit struct {
	Context int // maximum context window size, in tokens
	Output  int // maximum output length, in tokens
}
// lookupCloudModelLimit returns the token limits for a cloud model.
// It tries the exact name first, then retries with the ":cloud" suffix
// stripped. The second return value reports whether limits were found.
func lookupCloudModelLimit(name string) (cloudModelLimit, bool) {
	if limit, found := cloudModelLimits[name]; found {
		return limit, true
	}
	if base, hadSuffix := strings.CutSuffix(name, ":cloud"); hadSuffix {
		if limit, found := cloudModelLimits[base]; found {
			return limit, true
		}
	}
	return cloudModelLimit{}, false
}
func (o *OpenCode) String() string { return "OpenCode" }
// Run ensures the OpenCode configuration is up to date for the requested
// model (prompting for a selection when needed), then launches the opencode
// CLI with the given args attached to the current terminal.
//
// Returns nil when the user cancels model selection; otherwise propagates
// setup and process errors.
func (o *OpenCode) Run(model string, args []string) error {
	if _, err := exec.LookPath("opencode"); err != nil {
		return fmt.Errorf("opencode is not installed, install from https://opencode.ai")
	}

	// Start from the requested model, but prefer a previously saved
	// selection when one exists.
	selection := []string{model}
	if cfg, err := loadIntegration("opencode"); err == nil && len(cfg.Models) > 0 {
		selection = cfg.Models
	}

	resolved, err := resolveEditorModels("opencode", selection, func() ([]string, error) {
		return selectModels(context.Background(), "opencode", "")
	})
	switch {
	case errors.Is(err, errCancelled):
		// User backed out of the picker; not an error.
		return nil
	case err != nil:
		return err
	}

	// Write the resolved selection into OpenCode's config before launch.
	if err := o.Edit(resolved); err != nil {
		return fmt.Errorf("setup failed: %w", err)
	}

	child := exec.Command("opencode", args...)
	child.Stdin = os.Stdin
	child.Stdout = os.Stdout
	child.Stderr = os.Stderr
	return child.Run()
}
// Paths returns the OpenCode files this integration manages that currently
// exist on disk: the provider config and the recent-model state file.
// Returns nil when the home directory cannot be determined or neither file
// exists.
func (o *OpenCode) Paths() []string {
	home, err := os.UserHomeDir()
	if err != nil {
		return nil
	}
	candidates := []string{
		filepath.Join(home, ".config", "opencode", "opencode.json"),
		filepath.Join(home, ".local", "state", "opencode", "model.json"),
	}
	var existing []string
	for _, candidate := range candidates {
		if _, err := os.Stat(candidate); err == nil {
			existing = append(existing, candidate)
		}
	}
	return existing
}
// Edit writes modelList into OpenCode's configuration and promotes the
// selected models to the front of OpenCode's recent-model list.
//
// Two files under the user's home directory are updated:
//   - ~/.config/opencode/opencode.json   (ollama provider + model entries)
//   - ~/.local/state/opencode/model.json (recent-model ordering)
//
// Entries managed by ollama launch carry a "_launch": true marker; entries
// under the ollama provider without that marker are left untouched. An empty
// modelList is a no-op.
func (o *OpenCode) Edit(modelList []string) error {
	if len(modelList) == 0 {
		return nil
	}
	home, err := os.UserHomeDir()
	if err != nil {
		return err
	}
	if err := o.writeProviderConfig(home, modelList); err != nil {
		return err
	}
	return o.writeModelState(home, modelList)
}

// cloudModelLimitEntry returns the "limit" config map for a cloud model, or
// nil when no limits are known for it.
func cloudModelLimitEntry(model string) map[string]any {
	if l, ok := lookupCloudModelLimit(model); ok {
		return map[string]any{
			"context": l.Context,
			"output":  l.Output,
		}
	}
	return nil
}

// writeProviderConfig merges modelList into the ollama provider section of
// ~/.config/opencode/opencode.json: launch-managed entries that were
// deselected are removed, kept entries preserve any user customizations,
// legacy entries are migrated to the "_launch" marker, and cloud models get
// token limits attached.
func (o *OpenCode) writeProviderConfig(home string, modelList []string) error {
	configPath := filepath.Join(home, ".config", "opencode", "opencode.json")
	if err := os.MkdirAll(filepath.Dir(configPath), 0o755); err != nil {
		return err
	}
	config := make(map[string]any)
	if data, err := os.ReadFile(configPath); err == nil {
		_ = json.Unmarshal(data, &config) // Ignore parse errors; treat missing/corrupt files as empty
	}
	config["$schema"] = "https://opencode.ai/config.json"
	provider, ok := config["provider"].(map[string]any)
	if !ok {
		provider = make(map[string]any)
	}
	ollama, ok := provider["ollama"].(map[string]any)
	if !ok {
		ollama = map[string]any{
			"npm":  "@ai-sdk/openai-compatible",
			"name": "Ollama (local)",
			"options": map[string]any{
				"baseURL": envconfig.Host().String() + "/v1",
			},
		}
	}
	models, ok := ollama["models"].(map[string]any)
	if !ok {
		models = make(map[string]any)
	}
	selected := make(map[string]bool, len(modelList))
	for _, m := range modelList {
		selected[m] = true
	}
	// Drop launch-managed entries the user deselected; leave entries we do
	// not manage alone.
	for name, cfg := range models {
		if cfgMap, ok := cfg.(map[string]any); ok {
			if isOllamaModel(cfgMap) && !selected[name] {
				delete(models, name)
			}
		}
	}
	ctx := context.Background()
	// NOTE(review): client error deliberately ignored; isCloudModel appears
	// to tolerate the resulting client — confirm.
	client, _ := api.ClientFromEnvironment()
	for _, model := range modelList {
		if existing, ok := models[model].(map[string]any); ok {
			// Migrate pre-marker entries: add "_launch" and strip the old
			// " [Ollama]" display suffix.
			if isOllamaModel(existing) {
				existing["_launch"] = true
				if name, ok := existing["name"].(string); ok {
					existing["name"] = strings.TrimSuffix(name, " [Ollama]")
				}
			}
			if isCloudModel(ctx, client, model) {
				if limit := cloudModelLimitEntry(model); limit != nil {
					existing["limit"] = limit
				}
			}
			continue
		}
		entry := map[string]any{
			"name":    model,
			"_launch": true,
		}
		if isCloudModel(ctx, client, model) {
			if limit := cloudModelLimitEntry(model); limit != nil {
				entry["limit"] = limit
			}
		}
		models[model] = entry
	}
	ollama["models"] = models
	provider["ollama"] = ollama
	config["provider"] = provider
	configData, err := json.MarshalIndent(config, "", " ")
	if err != nil {
		return err
	}
	return writeWithBackup(configPath, configData)
}

// writeModelState prepends modelList to OpenCode's recent-model list in
// ~/.local/state/opencode/model.json, deduplicating ollama entries that are
// being re-added and capping the list at ten models. Favorites and variants
// are preserved as-is.
func (o *OpenCode) writeModelState(home string, modelList []string) error {
	statePath := filepath.Join(home, ".local", "state", "opencode", "model.json")
	if err := os.MkdirAll(filepath.Dir(statePath), 0o755); err != nil {
		return err
	}
	state := map[string]any{
		"recent":   []any{},
		"favorite": []any{},
		"variant":  map[string]any{},
	}
	if data, err := os.ReadFile(statePath); err == nil {
		_ = json.Unmarshal(data, &state) // Ignore parse errors; use defaults
	}
	recent, _ := state["recent"].([]any)
	modelSet := make(map[string]bool, len(modelList))
	for _, m := range modelList {
		modelSet[m] = true
	}
	// Filter out existing ollama entries for models we're about to re-add so
	// they move to the front instead of duplicating.
	newRecent := slices.DeleteFunc(slices.Clone(recent), func(entry any) bool {
		e, ok := entry.(map[string]any)
		if !ok || e["providerID"] != "ollama" {
			return false
		}
		modelID, _ := e["modelID"].(string)
		return modelSet[modelID]
	})
	// Prepend models in reverse order so the first selected model ends up first.
	for _, model := range slices.Backward(modelList) {
		newRecent = slices.Insert(newRecent, 0, any(map[string]any{
			"providerID": "ollama",
			"modelID":    model,
		}))
	}
	const maxRecentModels = 10
	newRecent = newRecent[:min(len(newRecent), maxRecentModels)]
	state["recent"] = newRecent
	stateData, err := json.MarshalIndent(state, "", " ")
	if err != nil {
		return err
	}
	return writeWithBackup(statePath, stateData)
}
// Models returns the sorted names of all models configured under the ollama
// provider in the OpenCode config, or nil when the config is missing,
// unreadable, or has no such models.
func (o *OpenCode) Models() []string {
	home, err := os.UserHomeDir()
	if err != nil {
		return nil
	}
	config, err := readJSONFile(filepath.Join(home, ".config", "opencode", "opencode.json"))
	if err != nil {
		return nil
	}
	provider, _ := config["provider"].(map[string]any)
	ollama, _ := provider["ollama"].(map[string]any)
	models, _ := ollama["models"].(map[string]any)
	if len(models) == 0 {
		return nil
	}
	names := make([]string, 0, len(models))
	for name := range models {
		names = append(names, name)
	}
	slices.Sort(names)
	return names
}
// isOllamaModel reports whether a model config entry is managed by us:
// either it carries the "_launch": true marker, or (for entries written by
// older launch builds) its display name ends with the legacy "[Ollama]"
// suffix.
func isOllamaModel(cfg map[string]any) bool {
	if launch, _ := cfg["_launch"].(bool); launch {
		return true
	}
	name, _ := cfg["name"].(string)
	return strings.HasSuffix(name, "[Ollama]")
}

668
cmd/config/opencode_test.go Normal file
View File

@@ -0,0 +1,668 @@
package config
import (
"encoding/json"
"fmt"
"os"
"path/filepath"
"testing"
)
// TestOpenCodeIntegration checks the integration's display name and its
// compile-time conformance to the Runner and Editor interfaces.
func TestOpenCodeIntegration(t *testing.T) {
	oc := &OpenCode{}
	t.Run("String", func(t *testing.T) {
		want := "OpenCode"
		if got := oc.String(); got != want {
			t.Errorf("String() = %q, want %q", got, want)
		}
	})
	t.Run("implements Runner", func(t *testing.T) {
		var _ Runner = oc
	})
	t.Run("implements Editor", func(t *testing.T) {
		var _ Editor = oc
	})
}
// TestOpenCodeEdit exercises OpenCode.Edit against a temporary HOME,
// covering fresh installs, preservation of unrelated config and state,
// recent-model ordering, model removal, user-customization retention, and
// migration of legacy "[Ollama]"-suffixed entries. Subtests share one temp
// HOME; each resets the config/state directories via cleanup().
func TestOpenCodeEdit(t *testing.T) {
	o := &OpenCode{}
	tmpDir := t.TempDir()
	setTestHome(t, tmpDir)
	// The two files Edit() manages: provider config and recent-model state.
	configDir := filepath.Join(tmpDir, ".config", "opencode")
	configPath := filepath.Join(configDir, "opencode.json")
	stateDir := filepath.Join(tmpDir, ".local", "state", "opencode")
	statePath := filepath.Join(stateDir, "model.json")
	// cleanup wipes both directories so each subtest starts from scratch.
	cleanup := func() {
		os.RemoveAll(configDir)
		os.RemoveAll(stateDir)
	}
	// No prior config or state: Edit creates both files with the model.
	t.Run("fresh install", func(t *testing.T) {
		cleanup()
		if err := o.Edit([]string{"llama3.2"}); err != nil {
			t.Fatal(err)
		}
		assertOpenCodeModelExists(t, configPath, "llama3.2")
		assertOpenCodeRecentModel(t, statePath, 0, "ollama", "llama3.2")
	})
	// Existing non-ollama providers in the config must survive an Edit.
	t.Run("preserve other providers", func(t *testing.T) {
		cleanup()
		os.MkdirAll(configDir, 0o755)
		os.WriteFile(configPath, []byte(`{"provider":{"anthropic":{"apiKey":"xxx"}}}`), 0o644)
		if err := o.Edit([]string{"llama3.2"}); err != nil {
			t.Fatal(err)
		}
		data, _ := os.ReadFile(configPath)
		var cfg map[string]any
		json.Unmarshal(data, &cfg)
		provider := cfg["provider"].(map[string]any)
		if provider["anthropic"] == nil {
			t.Error("anthropic provider was removed")
		}
		assertOpenCodeModelExists(t, configPath, "llama3.2")
	})
	// Pre-existing models under the ollama provider (not launch-managed)
	// must survive an Edit that adds a different model.
	t.Run("preserve other models", func(t *testing.T) {
		cleanup()
		os.MkdirAll(configDir, 0o755)
		os.WriteFile(configPath, []byte(`{"provider":{"ollama":{"models":{"mistral":{"name":"Mistral"}}}}}`), 0o644)
		if err := o.Edit([]string{"llama3.2"}); err != nil {
			t.Fatal(err)
		}
		assertOpenCodeModelExists(t, configPath, "mistral")
		assertOpenCodeModelExists(t, configPath, "llama3.2")
	})
	// Running Edit twice for the same model is idempotent.
	t.Run("update existing model", func(t *testing.T) {
		cleanup()
		o.Edit([]string{"llama3.2"})
		o.Edit([]string{"llama3.2"})
		assertOpenCodeModelExists(t, configPath, "llama3.2")
	})
	// Unrelated top-level config keys (theme, keybindings) must survive.
	t.Run("preserve top-level keys", func(t *testing.T) {
		cleanup()
		os.MkdirAll(configDir, 0o755)
		os.WriteFile(configPath, []byte(`{"theme":"dark","keybindings":{}}`), 0o644)
		if err := o.Edit([]string{"llama3.2"}); err != nil {
			t.Fatal(err)
		}
		data, _ := os.ReadFile(configPath)
		var cfg map[string]any
		json.Unmarshal(data, &cfg)
		if cfg["theme"] != "dark" {
			t.Error("theme was removed")
		}
		if cfg["keybindings"] == nil {
			t.Error("keybindings was removed")
		}
	})
	// The edited model is prepended to the recent list, ahead of entries
	// from other providers.
	t.Run("model state - insert at index 0", func(t *testing.T) {
		cleanup()
		os.MkdirAll(stateDir, 0o755)
		os.WriteFile(statePath, []byte(`{"recent":[{"providerID":"anthropic","modelID":"claude"}],"favorite":[],"variant":{}}`), 0o644)
		if err := o.Edit([]string{"llama3.2"}); err != nil {
			t.Fatal(err)
		}
		assertOpenCodeRecentModel(t, statePath, 0, "ollama", "llama3.2")
		assertOpenCodeRecentModel(t, statePath, 1, "anthropic", "claude")
	})
	// Favorites and variants in the state file are never touched by Edit.
	t.Run("model state - preserve favorites and variants", func(t *testing.T) {
		cleanup()
		os.MkdirAll(stateDir, 0o755)
		os.WriteFile(statePath, []byte(`{"recent":[],"favorite":[{"providerID":"x","modelID":"y"}],"variant":{"a":"b"}}`), 0o644)
		if err := o.Edit([]string{"llama3.2"}); err != nil {
			t.Fatal(err)
		}
		data, _ := os.ReadFile(statePath)
		var state map[string]any
		json.Unmarshal(data, &state)
		if len(state["favorite"].([]any)) != 1 {
			t.Error("favorite was modified")
		}
		if state["variant"].(map[string]any)["a"] != "b" {
			t.Error("variant was modified")
		}
	})
	// Re-adding a model already in the recent list moves it to the front
	// instead of duplicating it.
	t.Run("model state - deduplicate on re-add", func(t *testing.T) {
		cleanup()
		os.MkdirAll(stateDir, 0o755)
		os.WriteFile(statePath, []byte(`{"recent":[{"providerID":"ollama","modelID":"llama3.2"},{"providerID":"anthropic","modelID":"claude"}],"favorite":[],"variant":{}}`), 0o644)
		if err := o.Edit([]string{"llama3.2"}); err != nil {
			t.Fatal(err)
		}
		data, _ := os.ReadFile(statePath)
		var state map[string]any
		json.Unmarshal(data, &state)
		recent := state["recent"].([]any)
		if len(recent) != 2 {
			t.Errorf("expected 2 recent entries, got %d", len(recent))
		}
		assertOpenCodeRecentModel(t, statePath, 0, "ollama", "llama3.2")
	})
	// Deselecting a previously-added model removes its entry.
	t.Run("remove model", func(t *testing.T) {
		cleanup()
		// First add two models
		o.Edit([]string{"llama3.2", "mistral"})
		assertOpenCodeModelExists(t, configPath, "llama3.2")
		assertOpenCodeModelExists(t, configPath, "mistral")
		// Then remove one by only selecting the other
		o.Edit([]string{"llama3.2"})
		assertOpenCodeModelExists(t, configPath, "llama3.2")
		assertOpenCodeModelNotExists(t, configPath, "mistral")
	})
	// Fields a user added by hand to a launch-managed entry must survive a
	// subsequent Edit of the same model.
	t.Run("preserve user customizations on managed models", func(t *testing.T) {
		cleanup()
		if err := o.Edit([]string{"llama3.2"}); err != nil {
			t.Fatal(err)
		}
		// Add custom fields to the model entry (simulating user edits)
		data, _ := os.ReadFile(configPath)
		var cfg map[string]any
		json.Unmarshal(data, &cfg)
		provider := cfg["provider"].(map[string]any)
		ollama := provider["ollama"].(map[string]any)
		models := ollama["models"].(map[string]any)
		entry := models["llama3.2"].(map[string]any)
		entry["_myPref"] = "custom-value"
		entry["_myNum"] = 42
		configData, _ := json.MarshalIndent(cfg, "", " ")
		os.WriteFile(configPath, configData, 0o644)
		// Re-run Edit — should preserve custom fields
		if err := o.Edit([]string{"llama3.2"}); err != nil {
			t.Fatal(err)
		}
		data, _ = os.ReadFile(configPath)
		json.Unmarshal(data, &cfg)
		provider = cfg["provider"].(map[string]any)
		ollama = provider["ollama"].(map[string]any)
		models = ollama["models"].(map[string]any)
		entry = models["llama3.2"].(map[string]any)
		if entry["_myPref"] != "custom-value" {
			t.Errorf("_myPref was lost: got %v", entry["_myPref"])
		}
		// Note: JSON round-trip turns the int 42 into float64(42).
		if entry["_myNum"] != float64(42) {
			t.Errorf("_myNum was lost: got %v", entry["_myNum"])
		}
		if v, ok := entry["_launch"].(bool); !ok || !v {
			t.Errorf("_launch marker missing or false: got %v", entry["_launch"])
		}
	})
	// Older launch builds tagged managed entries with a "[Ollama]" name
	// suffix; Edit must migrate them to the "_launch" marker.
	t.Run("migrate legacy [Ollama] suffix entries", func(t *testing.T) {
		cleanup()
		// Write a config with a legacy entry (has [Ollama] suffix but no _launch marker)
		os.MkdirAll(configDir, 0o755)
		os.WriteFile(configPath, []byte(`{"provider":{"ollama":{"models":{"llama3.2":{"name":"llama3.2 [Ollama]"}}}}}`), 0o644)
		if err := o.Edit([]string{"llama3.2"}); err != nil {
			t.Fatal(err)
		}
		data, _ := os.ReadFile(configPath)
		var cfg map[string]any
		json.Unmarshal(data, &cfg)
		provider := cfg["provider"].(map[string]any)
		ollama := provider["ollama"].(map[string]any)
		models := ollama["models"].(map[string]any)
		entry := models["llama3.2"].(map[string]any)
		// _launch marker should be added
		if v, ok := entry["_launch"].(bool); !ok || !v {
			t.Errorf("_launch marker not added during migration: got %v", entry["_launch"])
		}
		// [Ollama] suffix should be stripped
		if name, ok := entry["name"].(string); !ok || name != "llama3.2" {
			t.Errorf("name suffix not stripped: got %q", entry["name"])
		}
	})
	// Removal logic only deletes launch-managed entries; models added by the
	// user under the ollama provider are preserved.
	t.Run("remove model preserves non-ollama models", func(t *testing.T) {
		cleanup()
		os.MkdirAll(configDir, 0o755)
		// Add a non-Ollama model manually
		os.WriteFile(configPath, []byte(`{"provider":{"ollama":{"models":{"external":{"name":"External Model"}}}}}`), 0o644)
		o.Edit([]string{"llama3.2"})
		assertOpenCodeModelExists(t, configPath, "llama3.2")
		assertOpenCodeModelExists(t, configPath, "external") // Should be preserved
	})
}
func assertOpenCodeModelExists(t *testing.T, path, model string) {
t.Helper()
data, err := os.ReadFile(path)
if err != nil {
t.Fatal(err)
}
var cfg map[string]any
if err := json.Unmarshal(data, &cfg); err != nil {
t.Fatal(err)
}
provider, ok := cfg["provider"].(map[string]any)
if !ok {
t.Fatal("provider not found")
}
ollama, ok := provider["ollama"].(map[string]any)
if !ok {
t.Fatal("ollama provider not found")
}
models, ok := ollama["models"].(map[string]any)
if !ok {
t.Fatal("models not found")
}
if models[model] == nil {
t.Errorf("model %s not found", model)
}
}
func assertOpenCodeModelNotExists(t *testing.T, path, model string) {
t.Helper()
data, err := os.ReadFile(path)
if err != nil {
t.Fatal(err)
}
var cfg map[string]any
if err := json.Unmarshal(data, &cfg); err != nil {
t.Fatal(err)
}
provider, ok := cfg["provider"].(map[string]any)
if !ok {
return // No provider means no model
}
ollama, ok := provider["ollama"].(map[string]any)
if !ok {
return // No ollama means no model
}
models, ok := ollama["models"].(map[string]any)
if !ok {
return // No models means no model
}
if models[model] != nil {
t.Errorf("model %s should not exist but was found", model)
}
}
func assertOpenCodeRecentModel(t *testing.T, path string, index int, providerID, modelID string) {
t.Helper()
data, err := os.ReadFile(path)
if err != nil {
t.Fatal(err)
}
var state map[string]any
if err := json.Unmarshal(data, &state); err != nil {
t.Fatal(err)
}
recent, ok := state["recent"].([]any)
if !ok {
t.Fatal("recent not found")
}
if index >= len(recent) {
t.Fatalf("index %d out of range (len=%d)", index, len(recent))
}
entry, ok := recent[index].(map[string]any)
if !ok {
t.Fatal("entry is not a map")
}
if entry["providerID"] != providerID {
t.Errorf("expected providerID %s, got %s", providerID, entry["providerID"])
}
if entry["modelID"] != modelID {
t.Errorf("expected modelID %s, got %s", modelID, entry["modelID"])
}
}
// Edge case tests for opencode.go
// TestOpenCodeEdit_CorruptedConfigJSON verifies that Edit treats an unparsable
// config file as empty instead of panicking, and rewrites it as valid JSON.
func TestOpenCodeEdit_CorruptedConfigJSON(t *testing.T) {
	o := &OpenCode{}
	tmpDir := t.TempDir()
	setTestHome(t, tmpDir)
	configDir := filepath.Join(tmpDir, ".config", "opencode")
	configPath := filepath.Join(configDir, "opencode.json")
	// Check setup errors instead of silently ignoring them; an unnoticed
	// setup failure would make the test pass for the wrong reason.
	if err := os.MkdirAll(configDir, 0o755); err != nil {
		t.Fatal(err)
	}
	if err := os.WriteFile(configPath, []byte(`{corrupted json content`), 0o644); err != nil {
		t.Fatal(err)
	}
	// Should not panic - corrupted JSON should be treated as empty.
	if err := o.Edit([]string{"llama3.2"}); err != nil {
		t.Fatalf("Edit failed with corrupted config: %v", err)
	}
	// Verify valid JSON was created.
	data, err := os.ReadFile(configPath)
	if err != nil {
		t.Fatal(err)
	}
	var cfg map[string]any
	if err := json.Unmarshal(data, &cfg); err != nil {
		t.Errorf("resulting config is not valid JSON: %v", err)
	}
}
// TestOpenCodeEdit_CorruptedStateJSON verifies that Edit treats an unparsable
// state file as empty instead of failing, and rewrites it as valid JSON.
func TestOpenCodeEdit_CorruptedStateJSON(t *testing.T) {
	o := &OpenCode{}
	tmpDir := t.TempDir()
	setTestHome(t, tmpDir)
	stateDir := filepath.Join(tmpDir, ".local", "state", "opencode")
	statePath := filepath.Join(stateDir, "model.json")
	// Fail fast on setup errors rather than ignoring them.
	if err := os.MkdirAll(stateDir, 0o755); err != nil {
		t.Fatal(err)
	}
	if err := os.WriteFile(statePath, []byte(`{corrupted state`), 0o644); err != nil {
		t.Fatal(err)
	}
	if err := o.Edit([]string{"llama3.2"}); err != nil {
		t.Fatalf("Edit failed with corrupted state: %v", err)
	}
	// Verify valid state was created.
	data, err := os.ReadFile(statePath)
	if err != nil {
		t.Fatal(err)
	}
	var state map[string]any
	if err := json.Unmarshal(data, &state); err != nil {
		t.Errorf("resulting state is not valid JSON: %v", err)
	}
}
// TestOpenCodeEdit_WrongTypeProvider verifies that Edit replaces a "provider"
// value of the wrong JSON type with a proper object containing the ollama entry.
func TestOpenCodeEdit_WrongTypeProvider(t *testing.T) {
	o := &OpenCode{}
	tmpDir := t.TempDir()
	setTestHome(t, tmpDir)
	configDir := filepath.Join(tmpDir, ".config", "opencode")
	configPath := filepath.Join(configDir, "opencode.json")
	// Fail fast on setup errors rather than ignoring them.
	if err := os.MkdirAll(configDir, 0o755); err != nil {
		t.Fatal(err)
	}
	if err := os.WriteFile(configPath, []byte(`{"provider": "not a map"}`), 0o644); err != nil {
		t.Fatal(err)
	}
	if err := o.Edit([]string{"llama3.2"}); err != nil {
		t.Fatalf("Edit with wrong type provider failed: %v", err)
	}
	// Verify provider is now the correct type.
	data, err := os.ReadFile(configPath)
	if err != nil {
		t.Fatal(err)
	}
	var cfg map[string]any
	if err := json.Unmarshal(data, &cfg); err != nil {
		t.Fatalf("resulting config is not valid JSON: %v", err)
	}
	provider, ok := cfg["provider"].(map[string]any)
	if !ok {
		t.Fatalf("provider should be map after setup, got %T", cfg["provider"])
	}
	if provider["ollama"] == nil {
		t.Error("ollama provider should be created")
	}
}
// TestOpenCodeEdit_WrongTypeRecent verifies that Edit handles a "recent" state
// value of the wrong JSON type gracefully (it documents, not asserts, the result).
func TestOpenCodeEdit_WrongTypeRecent(t *testing.T) {
	o := &OpenCode{}
	tmpDir := t.TempDir()
	setTestHome(t, tmpDir)
	stateDir := filepath.Join(tmpDir, ".local", "state", "opencode")
	statePath := filepath.Join(stateDir, "model.json")
	// Fail fast on setup errors rather than ignoring them.
	if err := os.MkdirAll(stateDir, 0o755); err != nil {
		t.Fatal(err)
	}
	if err := os.WriteFile(statePath, []byte(`{"recent": "not an array", "favorite": [], "variant": {}}`), 0o644); err != nil {
		t.Fatal(err)
	}
	if err := o.Edit([]string{"llama3.2"}); err != nil {
		t.Fatalf("Edit with wrong type recent failed: %v", err)
	}
	// The function should handle this gracefully.
	data, err := os.ReadFile(statePath)
	if err != nil {
		t.Fatal(err)
	}
	var state map[string]any
	if err := json.Unmarshal(data, &state); err != nil {
		t.Fatalf("resulting state is not valid JSON: %v", err)
	}
	// recent should be properly set after setup; log (not fail) either way to
	// document the observed behavior.
	recent, ok := state["recent"].([]any)
	if !ok {
		t.Logf("Note: recent type after setup is %T (documenting behavior)", state["recent"])
	} else if len(recent) == 0 {
		t.Logf("Note: recent is empty (documenting behavior)")
	}
}
// TestOpenCodeEdit_EmptyModels verifies that Edit with an empty model list is
// a no-op and leaves the config file byte-identical.
func TestOpenCodeEdit_EmptyModels(t *testing.T) {
	o := &OpenCode{}
	tmpDir := t.TempDir()
	setTestHome(t, tmpDir)
	configDir := filepath.Join(tmpDir, ".config", "opencode")
	configPath := filepath.Join(configDir, "opencode.json")
	// Fail fast on setup errors rather than ignoring them.
	if err := os.MkdirAll(configDir, 0o755); err != nil {
		t.Fatal(err)
	}
	originalContent := `{"provider":{"ollama":{"models":{"existing":{}}}}}`
	if err := os.WriteFile(configPath, []byte(originalContent), 0o644); err != nil {
		t.Fatal(err)
	}
	// Empty models should be no-op.
	if err := o.Edit([]string{}); err != nil {
		t.Fatalf("Edit with empty models failed: %v", err)
	}
	// Original content should be preserved (file not modified).
	data, err := os.ReadFile(configPath)
	if err != nil {
		t.Fatal(err)
	}
	if string(data) != originalContent {
		t.Errorf("empty models should not modify file, but content changed")
	}
}
// TestOpenCodeEdit_SpecialCharsInModelName verifies that a model name
// containing JSON-significant characters round-trips through Edit intact.
func TestOpenCodeEdit_SpecialCharsInModelName(t *testing.T) {
	o := &OpenCode{}
	tmpDir := t.TempDir()
	setTestHome(t, tmpDir)
	// Model name with special characters (though unusual).
	specialModel := `model-with-"quotes"`
	if err := o.Edit([]string{specialModel}); err != nil {
		t.Fatalf("Edit with special chars failed: %v", err)
	}
	// Verify it was stored correctly; check the read error instead of ignoring it.
	configDir := filepath.Join(tmpDir, ".config", "opencode")
	configPath := filepath.Join(configDir, "opencode.json")
	data, err := os.ReadFile(configPath)
	if err != nil {
		t.Fatal(err)
	}
	var cfg map[string]any
	if err := json.Unmarshal(data, &cfg); err != nil {
		t.Fatalf("resulting config is invalid JSON: %v", err)
	}
	// Model should be accessible.
	provider, _ := cfg["provider"].(map[string]any)
	ollama, _ := provider["ollama"].(map[string]any)
	models, _ := ollama["models"].(map[string]any)
	if models[specialModel] == nil {
		t.Errorf("model with special chars not found in config")
	}
}
func readOpenCodeModel(t *testing.T, configPath, model string) map[string]any {
t.Helper()
data, err := os.ReadFile(configPath)
if err != nil {
t.Fatal(err)
}
var cfg map[string]any
json.Unmarshal(data, &cfg)
provider := cfg["provider"].(map[string]any)
ollama := provider["ollama"].(map[string]any)
models := ollama["models"].(map[string]any)
entry, ok := models[model].(map[string]any)
if !ok {
t.Fatalf("model %s not found in config", model)
}
return entry
}
// TestOpenCodeEdit_LocalModelNoLimit checks that Edit does not attach a token
// limit to a local (non-cloud) model entry.
func TestOpenCodeEdit_LocalModelNoLimit(t *testing.T) {
	o := &OpenCode{}
	home := t.TempDir()
	setTestHome(t, home)
	if err := o.Edit([]string{"llama3.2"}); err != nil {
		t.Fatal(err)
	}
	configPath := filepath.Join(home, ".config", "opencode", "opencode.json")
	entry := readOpenCodeModel(t, configPath, "llama3.2")
	if limit := entry["limit"]; limit != nil {
		t.Errorf("local model should not have limit set, got %v", limit)
	}
}
// TestOpenCodeEdit_PreservesUserLimit verifies that re-editing a model keeps a
// user-configured "limit" object intact instead of dropping it.
func TestOpenCodeEdit_PreservesUserLimit(t *testing.T) {
	o := &OpenCode{}
	tmpDir := t.TempDir()
	setTestHome(t, tmpDir)
	configDir := filepath.Join(tmpDir, ".config", "opencode")
	configPath := filepath.Join(configDir, "opencode.json")
	// Set up a model with a user-configured limit; fail fast on setup errors.
	if err := os.MkdirAll(configDir, 0o755); err != nil {
		t.Fatal(err)
	}
	if err := os.WriteFile(configPath, []byte(`{
	"provider": {
		"ollama": {
			"models": {
				"llama3.2": {
					"name": "llama3.2",
					"_launch": true,
					"limit": {"context": 8192, "output": 4096}
				}
			}
		}
	}
}`), 0o644); err != nil {
		t.Fatal(err)
	}
	// Re-edit should preserve the user's limit (not delete it).
	if err := o.Edit([]string{"llama3.2"}); err != nil {
		t.Fatal(err)
	}
	entry := readOpenCodeModel(t, configPath, "llama3.2")
	limit, ok := entry["limit"].(map[string]any)
	if !ok {
		t.Fatal("user-configured limit was removed")
	}
	// JSON numbers decode as float64.
	if limit["context"] != float64(8192) {
		t.Errorf("context limit changed: got %v, want 8192", limit["context"])
	}
	if limit["output"] != float64(4096) {
		t.Errorf("output limit changed: got %v, want 4096", limit["output"])
	}
}
// TestOpenCodeEdit_CloudModelLimitStructure verifies that when a cloud model
// entry has limits set (as Edit would do), the structure matches what opencode
// expects and re-edit preserves them.
func TestOpenCodeEdit_CloudModelLimitStructure(t *testing.T) {
	o := &OpenCode{}
	tmpDir := t.TempDir()
	setTestHome(t, tmpDir)
	configDir := filepath.Join(tmpDir, ".config", "opencode")
	configPath := filepath.Join(configDir, "opencode.json")
	expected := cloudModelLimits["glm-4.7"]
	// Simulate a cloud model that already has the limit set by a previous
	// Edit; fail fast on setup errors instead of ignoring them.
	if err := os.MkdirAll(configDir, 0o755); err != nil {
		t.Fatal(err)
	}
	if err := os.WriteFile(configPath, []byte(fmt.Sprintf(`{
	"provider": {
		"ollama": {
			"models": {
				"glm-4.7:cloud": {
					"name": "glm-4.7:cloud",
					"_launch": true,
					"limit": {"context": %d, "output": %d}
				}
			}
		}
	}
}`, expected.Context, expected.Output)), 0o644); err != nil {
		t.Fatal(err)
	}
	// Re-edit should preserve the cloud model limit.
	if err := o.Edit([]string{"glm-4.7:cloud"}); err != nil {
		t.Fatal(err)
	}
	entry := readOpenCodeModel(t, configPath, "glm-4.7:cloud")
	limit, ok := entry["limit"].(map[string]any)
	if !ok {
		t.Fatal("cloud model limit was removed on re-edit")
	}
	if limit["context"] != float64(expected.Context) {
		t.Errorf("context = %v, want %d", limit["context"], expected.Context)
	}
	if limit["output"] != float64(expected.Output) {
		t.Errorf("output = %v, want %d", limit["output"], expected.Output)
	}
}
// TestLookupCloudModelLimit exercises the hardcoded cloud-model limit table,
// including :cloud-suffixed aliases and unknown model names.
func TestLookupCloudModelLimit(t *testing.T) {
	cases := []struct {
		name        string
		wantOK      bool
		wantContext int
		wantOutput  int
	}{
		{name: "glm-4.7", wantOK: true, wantContext: 202_752, wantOutput: 131_072},
		{name: "glm-4.7:cloud", wantOK: true, wantContext: 202_752, wantOutput: 131_072},
		{name: "kimi-k2.5", wantOK: true, wantContext: 262_144, wantOutput: 262_144},
		{name: "kimi-k2.5:cloud", wantOK: true, wantContext: 262_144, wantOutput: 262_144},
		{name: "deepseek-v3.2", wantOK: true, wantContext: 163_840, wantOutput: 65_536},
		{name: "deepseek-v3.2:cloud", wantOK: true, wantContext: 163_840, wantOutput: 65_536},
		{name: "qwen3-coder:480b", wantOK: true, wantContext: 262_144, wantOutput: 65_536},
		{name: "qwen3-coder-next:cloud", wantOK: true, wantContext: 262_144, wantOutput: 32_768},
		{name: "llama3.2", wantOK: false},
		{name: "unknown-model:cloud", wantOK: false},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got, ok := lookupCloudModelLimit(tc.name)
			if ok != tc.wantOK {
				t.Errorf("lookupCloudModelLimit(%q) ok = %v, want %v", tc.name, ok, tc.wantOK)
			}
			if !ok {
				return
			}
			if got.Context != tc.wantContext {
				t.Errorf("context = %d, want %d", got.Context, tc.wantContext)
			}
			if got.Output != tc.wantOutput {
				t.Errorf("output = %d, want %d", got.Output, tc.wantOutput)
			}
		})
	}
}
// TestOpenCodeModels_NoConfig ensures Models reports nothing when no opencode
// config file exists under the test home.
func TestOpenCodeModels_NoConfig(t *testing.T) {
	o := &OpenCode{}
	setTestHome(t, t.TempDir())
	if got := o.Models(); len(got) > 0 {
		t.Errorf("expected nil/empty for missing config, got %v", got)
	}
}

237
cmd/config/pi.go Normal file
View File

@@ -0,0 +1,237 @@
package config
import (
"context"
"encoding/json"
"fmt"
"net/http"
"os"
"os/exec"
"path/filepath"
"slices"
"strings"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/types/model"
)
// Pi implements Runner and Editor for Pi (Pi Coding Agent) integration.
// The zero value is ready to use; the type carries no state.
type Pi struct{}

// String returns the integration's display name.
func (p *Pi) String() string { return "Pi" }
// Run refreshes the Pi configuration and then launches the "pi" binary with
// the given args, wiring it to the current terminal.
func (p *Pi) Run(model string, args []string) error {
	if _, err := exec.LookPath("pi"); err != nil {
		return fmt.Errorf("pi is not installed, install with: npm install -g @mariozechner/pi-coding-agent")
	}
	// Prefer the saved integration model list over the single passed-in model.
	models := []string{model}
	if config, err := loadIntegration("pi"); err == nil && len(config.Models) > 0 {
		models = config.Models
	}
	// Call Edit() to ensure config is up-to-date before launch.
	if err := p.Edit(models); err != nil {
		return fmt.Errorf("setup failed: %w", err)
	}
	cmd := exec.Command("pi", args...)
	cmd.Stdin = os.Stdin
	cmd.Stdout = os.Stdout
	cmd.Stderr = os.Stderr
	return cmd.Run()
}
// Paths reports which Pi config files currently exist on disk, in a fixed
// order: models.json first, then settings.json.
func (p *Pi) Paths() []string {
	home, err := os.UserHomeDir()
	if err != nil {
		return nil
	}
	var paths []string
	for _, name := range []string{"models.json", "settings.json"} {
		candidate := filepath.Join(home, ".pi", "agent", name)
		if _, err := os.Stat(candidate); err == nil {
			paths = append(paths, candidate)
		}
	}
	return paths
}
// Edit updates Pi's models.json so the selected models are available under the
// "ollama" provider, and points settings.json at the first selected model.
// User-managed model entries (those without the _launch marker) are preserved
// untouched; ollama-managed entries are kept only while still selected.
// An empty models slice is a no-op.
func (p *Pi) Edit(models []string) error {
	if len(models) == 0 {
		return nil
	}
	home, err := os.UserHomeDir()
	if err != nil {
		return err
	}
	configPath := filepath.Join(home, ".pi", "agent", "models.json")
	if err := os.MkdirAll(filepath.Dir(configPath), 0o755); err != nil {
		return err
	}
	config := make(map[string]any)
	if data, err := os.ReadFile(configPath); err == nil {
		// Corrupted JSON is deliberately treated as an empty config.
		_ = json.Unmarshal(data, &config)
	}
	providers, ok := config["providers"].(map[string]any)
	if !ok {
		providers = make(map[string]any)
	}
	ollama, ok := providers["ollama"].(map[string]any)
	if !ok {
		ollama = map[string]any{
			"baseUrl": envconfig.Host().String() + "/v1",
			"api":     "openai-completions",
			"apiKey":  "ollama",
		}
	}
	// A nil slice ranges the same as an empty one; no need to allocate.
	existingModels, _ := ollama["models"].([]any)
	// Build set of selected models to track which need to be added.
	selectedSet := make(map[string]bool, len(models))
	for _, name := range models { // "name", not "model": avoid shadowing the types/model package
		selectedSet[name] = true
	}
	// Build new models list:
	// 1. Keep user-managed models (no _launch marker) - untouched
	// 2. Keep ollama-managed models (_launch marker) that are still selected
	// 3. Add new ollama-managed models
	// NOTE(review): entries that are not objects or lack a string "id" are
	// silently dropped here — confirm that is intended for user-managed entries.
	var newModels []any
	for _, m := range existingModels {
		modelObj, ok := m.(map[string]any)
		if !ok {
			continue
		}
		id, ok := modelObj["id"].(string)
		if !ok {
			continue
		}
		if !isPiOllamaModel(modelObj) {
			// User-managed model (no _launch marker) - always preserve.
			newModels = append(newModels, m)
		} else if selectedSet[id] {
			// Ollama-managed and still selected - keep it and don't re-add below.
			newModels = append(newModels, m)
			selectedSet[id] = false
		}
	}
	// Add newly selected models that weren't already in the list.
	client := api.NewClient(envconfig.Host(), http.DefaultClient)
	ctx := context.Background()
	for _, name := range models {
		if selectedSet[name] {
			newModels = append(newModels, createConfig(ctx, client, name))
		}
	}
	ollama["models"] = newModels
	providers["ollama"] = ollama
	config["providers"] = providers
	configData, err := json.MarshalIndent(config, "", " ")
	if err != nil {
		return err
	}
	if err := writeWithBackup(configPath, configData); err != nil {
		return err
	}
	// Update settings.json with default provider and model.
	settingsPath := filepath.Join(home, ".pi", "agent", "settings.json")
	settings := make(map[string]any)
	if data, err := os.ReadFile(settingsPath); err == nil {
		_ = json.Unmarshal(data, &settings)
	}
	settings["defaultProvider"] = "ollama"
	settings["defaultModel"] = models[0]
	settingsData, err := json.MarshalIndent(settings, "", " ")
	if err != nil {
		return err
	}
	return writeWithBackup(settingsPath, settingsData)
}
// Models returns the sorted ids of all models configured under the ollama
// provider in Pi's models.json, or nil when the file is absent or unreadable.
func (p *Pi) Models() []string {
	home, err := os.UserHomeDir()
	if err != nil {
		return nil
	}
	config, err := readJSONFile(filepath.Join(home, ".pi", "agent", "models.json"))
	if err != nil {
		return nil
	}
	providers, _ := config["providers"].(map[string]any)
	ollama, _ := providers["ollama"].(map[string]any)
	entries, _ := ollama["models"].([]any)
	var ids []string
	for _, entry := range entries {
		obj, ok := entry.(map[string]any)
		if !ok {
			continue
		}
		if id, ok := obj["id"].(string); ok {
			ids = append(ids, id)
		}
	}
	slices.Sort(ids)
	return ids
}
// isPiOllamaModel reports whether a model config entry is managed by ollama
// launch, i.e. it carries a boolean "_launch" marker set to true.
func isPiOllamaModel(cfg map[string]any) bool {
	launched, ok := cfg["_launch"].(bool)
	return ok && launched
}
// createConfig builds a Pi model config entry for modelID, probing the server
// for capabilities (vision, thinking) and the context window. If the Show call
// fails (e.g. model not pulled), the minimal {id, _launch} entry is returned.
func createConfig(ctx context.Context, client *api.Client, modelID string) map[string]any {
	cfg := map[string]any{
		"id":      modelID,
		"_launch": true,
	}
	resp, err := client.Show(ctx, &api.ShowRequest{Model: modelID})
	if err != nil {
		return cfg
	}
	// Set input types based on vision capability.
	if slices.Contains(resp.Capabilities, model.CapabilityVision) {
		cfg["input"] = []string{"text", "image"}
	} else {
		cfg["input"] = []string{"text"}
	}
	// Set reasoning based on thinking capability.
	if slices.Contains(resp.Capabilities, model.CapabilityThinking) {
		cfg["reasoning"] = true
	}
	// Extract the context window from ModelInfo. Iterate the keys in sorted
	// order so the chosen "<arch>.context_length" key is deterministic —
	// ranging over the map directly examined an arbitrary key first and the
	// break made the outcome depend on Go's randomized map iteration order.
	keys := make([]string, 0, len(resp.ModelInfo))
	for key := range resp.ModelInfo {
		keys = append(keys, key)
	}
	slices.Sort(keys)
	for _, key := range keys {
		if strings.HasSuffix(key, ".context_length") {
			if ctxLen, ok := resp.ModelInfo[key].(float64); ok && ctxLen > 0 {
				cfg["contextWindow"] = int(ctxLen)
			}
			break
		}
	}
	return cfg
}

View File

@@ -1,4 +1,4 @@
package launch package config
import ( import (
"context" "context"
@@ -9,8 +9,6 @@ import (
"net/url" "net/url"
"os" "os"
"path/filepath" "path/filepath"
"runtime"
"strings"
"testing" "testing"
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
@@ -35,341 +33,6 @@ func TestPiIntegration(t *testing.T) {
}) })
} }
// TestPiRun_InstallAndWebSearchLifecycle drives Pi.Run end-to-end against stub
// "pi" and "npm" shell scripts placed on PATH, covering the install prompt and
// the web-search package install/update lifecycle. POSIX-only: the stubs are
// /bin/sh scripts.
func TestPiRun_InstallAndWebSearchLifecycle(t *testing.T) {
	if runtime.GOOS == "windows" {
		t.Skip("uses POSIX shell test binaries")
	}
	// writeScript writes an executable shell script at path.
	writeScript := func(t *testing.T, path, content string) {
		t.Helper()
		if err := os.WriteFile(path, []byte(content), 0o755); err != nil {
			t.Fatal(err)
		}
	}
	// seedPiScript installs a stub "pi" binary that appends its args to pi.log,
	// serves "pi list" from pi-list.txt, and can be forced to fail "update" or
	// "install" via the PI_FAIL_UPDATE / PI_FAIL_INSTALL env vars.
	seedPiScript := func(t *testing.T, dir string) {
		t.Helper()
		piPath := filepath.Join(dir, "pi")
		listPath := filepath.Join(dir, "pi-list.txt")
		piScript := fmt.Sprintf(`#!/bin/sh
echo "$@" >> %q
if [ "$1" = "list" ]; then
if [ -f %q ]; then
/bin/cat %q
fi
exit 0
fi
if [ "$1" = "update" ] && [ "$PI_FAIL_UPDATE" = "1" ]; then
echo "update failed" >&2
exit 1
fi
if [ "$1" = "install" ] && [ "$PI_FAIL_INSTALL" = "1" ]; then
echo "install failed" >&2
exit 1
fi
exit 0
`, filepath.Join(dir, "pi.log"), listPath, listPath)
		writeScript(t, piPath, piScript)
	}
	// seedNpmNoop installs an "npm" stub that succeeds without doing anything.
	seedNpmNoop := func(t *testing.T, dir string) {
		t.Helper()
		writeScript(t, filepath.Join(dir, "npm"), "#!/bin/sh\nexit 0\n")
	}
	// withConfirm swaps DefaultConfirmPrompt for fn for the test's duration.
	withConfirm := func(t *testing.T, fn func(prompt string) (bool, error)) {
		t.Helper()
		oldConfirm := DefaultConfirmPrompt
		DefaultConfirmPrompt = func(prompt string, options ConfirmOptions) (bool, error) {
			return fn(prompt)
		}
		t.Cleanup(func() { DefaultConfirmPrompt = oldConfirm })
	}
	// setCloudStatus points OLLAMA_HOST at a stub server whose /api/status
	// reports cloud as enabled or disabled.
	setCloudStatus := func(t *testing.T, disabled bool) {
		t.Helper()
		srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			if r.URL.Path == "/api/status" {
				fmt.Fprintf(w, `{"cloud":{"disabled":%t,"source":"config"}}`, disabled)
				return
			}
			http.NotFound(w, r)
		}))
		t.Cleanup(srv.Close)
		t.Setenv("OLLAMA_HOST", srv.URL)
	}
	// pi is absent from PATH; the npm stub "installs" it when confirmed.
	t.Run("pi missing + user accepts install", func(t *testing.T) {
		tmpDir := t.TempDir()
		setTestHome(t, tmpDir)
		t.Setenv("PATH", tmpDir)
		setCloudStatus(t, false)
		if err := os.WriteFile(filepath.Join(tmpDir, "pi-list.txt"), []byte("User packages:\n npm:@ollama/pi-web-search\n"), 0o644); err != nil {
			t.Fatal(err)
		}
		npmScript := fmt.Sprintf(`#!/bin/sh
echo "$@" >> %q
if [ "$1" = "install" ] && [ "$2" = "-g" ] && [ "$3" = %q ]; then
/bin/cat > %q <<'EOS'
#!/bin/sh
echo "$@" >> %q
if [ "$1" = "list" ]; then
if [ -f %q ]; then
/bin/cat %q
fi
exit 0
fi
exit 0
EOS
/bin/chmod +x %q
fi
exit 0
`, filepath.Join(tmpDir, "npm.log"), piNpmPackage+"@latest", filepath.Join(tmpDir, "pi"), filepath.Join(tmpDir, "pi.log"), filepath.Join(tmpDir, "pi-list.txt"), filepath.Join(tmpDir, "pi-list.txt"), filepath.Join(tmpDir, "pi"))
		writeScript(t, filepath.Join(tmpDir, "npm"), npmScript)
		withConfirm(t, func(prompt string) (bool, error) {
			if strings.Contains(prompt, "Pi is not installed.") {
				return true, nil
			}
			return true, nil
		})
		p := &Pi{}
		if err := p.Run("ignored", []string{"--version"}); err != nil {
			t.Fatalf("Run() error = %v", err)
		}
		npmCalls, err := os.ReadFile(filepath.Join(tmpDir, "npm.log"))
		if err != nil {
			t.Fatal(err)
		}
		if !strings.Contains(string(npmCalls), "install -g "+piNpmPackage+"@latest") {
			t.Fatalf("expected npm install call, got:\n%s", npmCalls)
		}
		piCalls, err := os.ReadFile(filepath.Join(tmpDir, "pi.log"))
		if err != nil {
			t.Fatal(err)
		}
		got := string(piCalls)
		if !strings.Contains(got, "list\n") {
			t.Fatalf("expected pi list call, got:\n%s", got)
		}
		if !strings.Contains(got, "update "+piWebSearchSource+"\n") {
			t.Fatalf("expected pi update call, got:\n%s", got)
		}
		if !strings.Contains(got, "--version\n") {
			t.Fatalf("expected final pi launch call, got:\n%s", got)
		}
	})
	// Declining the install prompt must abort with a cancellation error.
	t.Run("pi missing + user declines install", func(t *testing.T) {
		tmpDir := t.TempDir()
		setTestHome(t, tmpDir)
		t.Setenv("PATH", tmpDir)
		setCloudStatus(t, false)
		writeScript(t, filepath.Join(tmpDir, "npm"), "#!/bin/sh\nexit 0\n")
		withConfirm(t, func(prompt string) (bool, error) {
			if strings.Contains(prompt, "Pi is not installed.") {
				return false, nil
			}
			return true, nil
		})
		p := &Pi{}
		err := p.Run("ignored", nil)
		if err == nil || !strings.Contains(err.Error(), "pi installation cancelled") {
			t.Fatalf("expected install cancellation error, got %v", err)
		}
	})
	// Missing web-search package is installed without prompting.
	t.Run("pi installed + web search missing auto-installs", func(t *testing.T) {
		tmpDir := t.TempDir()
		setTestHome(t, tmpDir)
		t.Setenv("PATH", tmpDir)
		setCloudStatus(t, false)
		if err := os.WriteFile(filepath.Join(tmpDir, "pi-list.txt"), []byte("User packages:\n"), 0o644); err != nil {
			t.Fatal(err)
		}
		seedPiScript(t, tmpDir)
		seedNpmNoop(t, tmpDir)
		withConfirm(t, func(prompt string) (bool, error) {
			t.Fatalf("did not expect confirmation prompt, got %q", prompt)
			return false, nil
		})
		p := &Pi{}
		if err := p.Run("ignored", []string{"session"}); err != nil {
			t.Fatalf("Run() error = %v", err)
		}
		piCalls, err := os.ReadFile(filepath.Join(tmpDir, "pi.log"))
		if err != nil {
			t.Fatal(err)
		}
		got := string(piCalls)
		if !strings.Contains(got, "list\n") {
			t.Fatalf("expected pi list call, got:\n%s", got)
		}
		if !strings.Contains(got, "install "+piWebSearchSource+"\n") {
			t.Fatalf("expected pi install call, got:\n%s", got)
		}
		if strings.Contains(got, "update "+piWebSearchSource+"\n") {
			t.Fatalf("did not expect pi update call when package missing, got:\n%s", got)
		}
		if !strings.Contains(got, "session\n") {
			t.Fatalf("expected final pi launch call, got:\n%s", got)
		}
	})
	// An already-present web-search package is updated on every launch.
	t.Run("pi installed + web search present updates every launch", func(t *testing.T) {
		tmpDir := t.TempDir()
		setTestHome(t, tmpDir)
		t.Setenv("PATH", tmpDir)
		setCloudStatus(t, false)
		if err := os.WriteFile(filepath.Join(tmpDir, "pi-list.txt"), []byte("User packages:\n "+piWebSearchSource+"\n"), 0o644); err != nil {
			t.Fatal(err)
		}
		seedPiScript(t, tmpDir)
		seedNpmNoop(t, tmpDir)
		p := &Pi{}
		if err := p.Run("ignored", []string{"doctor"}); err != nil {
			t.Fatalf("Run() error = %v", err)
		}
		piCalls, err := os.ReadFile(filepath.Join(tmpDir, "pi.log"))
		if err != nil {
			t.Fatal(err)
		}
		got := string(piCalls)
		if !strings.Contains(got, "update "+piWebSearchSource+"\n") {
			t.Fatalf("expected pi update call, got:\n%s", got)
		}
	})
	// A failing update is non-fatal: a warning is printed and pi still launches.
	t.Run("web search update failure warns and continues", func(t *testing.T) {
		tmpDir := t.TempDir()
		setTestHome(t, tmpDir)
		t.Setenv("PATH", tmpDir)
		setCloudStatus(t, false)
		t.Setenv("PI_FAIL_UPDATE", "1")
		if err := os.WriteFile(filepath.Join(tmpDir, "pi-list.txt"), []byte("User packages:\n "+piWebSearchSource+"\n"), 0o644); err != nil {
			t.Fatal(err)
		}
		seedPiScript(t, tmpDir)
		seedNpmNoop(t, tmpDir)
		p := &Pi{}
		stderr := captureStderr(t, func() {
			if err := p.Run("ignored", []string{"session"}); err != nil {
				t.Fatalf("Run() should continue after web search update failure, got %v", err)
			}
		})
		if !strings.Contains(stderr, "Warning: could not update "+piWebSearchPkg) {
			t.Fatalf("expected update warning, got:\n%s", stderr)
		}
		piCalls, err := os.ReadFile(filepath.Join(tmpDir, "pi.log"))
		if err != nil {
			t.Fatal(err)
		}
		if !strings.Contains(string(piCalls), "session\n") {
			t.Fatalf("expected final pi launch call, got:\n%s", piCalls)
		}
	})
	// A failing install is likewise non-fatal.
	t.Run("web search install failure warns and continues", func(t *testing.T) {
		tmpDir := t.TempDir()
		setTestHome(t, tmpDir)
		t.Setenv("PATH", tmpDir)
		setCloudStatus(t, false)
		t.Setenv("PI_FAIL_INSTALL", "1")
		if err := os.WriteFile(filepath.Join(tmpDir, "pi-list.txt"), []byte("User packages:\n"), 0o644); err != nil {
			t.Fatal(err)
		}
		seedPiScript(t, tmpDir)
		seedNpmNoop(t, tmpDir)
		withConfirm(t, func(prompt string) (bool, error) {
			t.Fatalf("did not expect confirmation prompt, got %q", prompt)
			return false, nil
		})
		p := &Pi{}
		stderr := captureStderr(t, func() {
			if err := p.Run("ignored", []string{"session"}); err != nil {
				t.Fatalf("Run() should continue after web search install failure, got %v", err)
			}
		})
		if !strings.Contains(stderr, "Warning: could not install "+piWebSearchPkg) {
			t.Fatalf("expected install warning, got:\n%s", stderr)
		}
		piCalls, err := os.ReadFile(filepath.Join(tmpDir, "pi.log"))
		if err != nil {
			t.Fatal(err)
		}
		if !strings.Contains(string(piCalls), "session\n") {
			t.Fatalf("expected final pi launch call, got:\n%s", piCalls)
		}
	})
	// With cloud disabled on the server, web-search management is skipped entirely.
	t.Run("cloud disabled skips web search package management", func(t *testing.T) {
		tmpDir := t.TempDir()
		setTestHome(t, tmpDir)
		t.Setenv("PATH", tmpDir)
		setCloudStatus(t, true)
		if err := os.WriteFile(filepath.Join(tmpDir, "pi-list.txt"), []byte("User packages:\n"), 0o644); err != nil {
			t.Fatal(err)
		}
		seedPiScript(t, tmpDir)
		seedNpmNoop(t, tmpDir)
		p := &Pi{}
		stderr := captureStderr(t, func() {
			if err := p.Run("ignored", []string{"session"}); err != nil {
				t.Fatalf("Run() error = %v", err)
			}
		})
		if !strings.Contains(stderr, "Cloud is disabled; skipping "+piWebSearchPkg+" setup.") {
			t.Fatalf("expected cloud-disabled skip message, got:\n%s", stderr)
		}
		piCalls, err := os.ReadFile(filepath.Join(tmpDir, "pi.log"))
		if err != nil {
			t.Fatal(err)
		}
		got := string(piCalls)
		if strings.Contains(got, "list\n") || strings.Contains(got, "install "+piWebSearchSource+"\n") || strings.Contains(got, "update "+piWebSearchSource+"\n") {
			t.Fatalf("did not expect web search package management calls, got:\n%s", got)
		}
		if !strings.Contains(got, "session\n") {
			t.Fatalf("expected final pi launch call, got:\n%s", got)
		}
	})
	// Without npm on PATH, Run must error out before ever invoking pi.
	t.Run("missing npm returns error before pi flow", func(t *testing.T) {
		tmpDir := t.TempDir()
		setTestHome(t, tmpDir)
		t.Setenv("PATH", tmpDir)
		setCloudStatus(t, false)
		seedPiScript(t, tmpDir)
		p := &Pi{}
		err := p.Run("ignored", []string{"session"})
		if err == nil || !strings.Contains(err.Error(), "npm (Node.js) is required to launch pi") {
			t.Fatalf("expected missing npm error, got %v", err)
		}
		if _, statErr := os.Stat(filepath.Join(tmpDir, "pi.log")); !os.IsNotExist(statErr) {
			t.Fatalf("expected pi not to run when npm is missing, stat err = %v", statErr)
		}
	})
}
func TestPiPaths(t *testing.T) { func TestPiPaths(t *testing.T) {
pi := &Pi{} pi := &Pi{}
@@ -529,48 +192,6 @@ func TestPiEdit(t *testing.T) {
} }
}) })
t.Run("rebuilds stale existing managed cloud model", func(t *testing.T) {
cleanup()
os.MkdirAll(configDir, 0o755)
existingConfig := `{
"providers": {
"ollama": {
"baseUrl": "http://localhost:11434/v1",
"api": "openai-completions",
"apiKey": "ollama",
"models": [
{"id": "glm-5:cloud", "_launch": true, "legacyField": "stale"}
]
}
}
}`
if err := os.WriteFile(configPath, []byte(existingConfig), 0o644); err != nil {
t.Fatal(err)
}
if err := pi.Edit([]string{"glm-5:cloud"}); err != nil {
t.Fatalf("Edit() error = %v", err)
}
cfg := readConfig()
providers := cfg["providers"].(map[string]any)
ollama := providers["ollama"].(map[string]any)
modelsArray := ollama["models"].([]any)
modelEntry := modelsArray[0].(map[string]any)
if modelEntry["contextWindow"] != float64(202_752) {
t.Errorf("contextWindow = %v, want 202752", modelEntry["contextWindow"])
}
input, ok := modelEntry["input"].([]any)
if !ok || len(input) != 1 || input[0] != "text" {
t.Errorf("input = %v, want [text]", modelEntry["input"])
}
if _, ok := modelEntry["legacyField"]; ok {
t.Error("legacyField should be removed when stale managed cloud entry is rebuilt")
}
})
t.Run("replaces old models with new ones", func(t *testing.T) { t.Run("replaces old models with new ones", func(t *testing.T) {
cleanup() cleanup()
os.MkdirAll(configDir, 0o755) os.MkdirAll(configDir, 0o755)
@@ -1177,59 +798,6 @@ func TestCreateConfig(t *testing.T) {
} }
}) })
t.Run("cloud model falls back to hardcoded context when show fails", func(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNotFound)
fmt.Fprintf(w, `{"error":"model not found"}`)
}))
defer srv.Close()
u, _ := url.Parse(srv.URL)
client := api.NewClient(u, srv.Client())
cfg := createConfig(context.Background(), client, "kimi-k2.5:cloud")
if cfg["contextWindow"] != 262_144 {
t.Errorf("contextWindow = %v, want 262144", cfg["contextWindow"])
}
})
t.Run("cloud model falls back to hardcoded context when show omits model info", func(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/api/show" {
fmt.Fprintf(w, `{"capabilities":[],"model_info":{}}`)
return
}
w.WriteHeader(http.StatusNotFound)
}))
defer srv.Close()
u, _ := url.Parse(srv.URL)
client := api.NewClient(u, srv.Client())
cfg := createConfig(context.Background(), client, "glm-5:cloud")
if cfg["contextWindow"] != 202_752 {
t.Errorf("contextWindow = %v, want 202752", cfg["contextWindow"])
}
})
t.Run("cloud model with dash suffix falls back to hardcoded context", func(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNotFound)
fmt.Fprintf(w, `{"error":"model not found"}`)
}))
defer srv.Close()
u, _ := url.Parse(srv.URL)
client := api.NewClient(u, srv.Client())
cfg := createConfig(context.Background(), client, "gpt-oss:120b-cloud")
if cfg["contextWindow"] != 131_072 {
t.Errorf("contextWindow = %v, want 131072", cfg["contextWindow"])
}
})
t.Run("skips zero context length", func(t *testing.T) { t.Run("skips zero context length", func(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/api/show" { if r.URL.Path == "/api/show" {

59
cmd/config/selector.go Normal file
View File

@@ -0,0 +1,59 @@
package config
import (
"errors"
"fmt"
"os"
"golang.org/x/term"
)
// ANSI escape sequences for terminal formatting. Written directly to the
// terminal; pair a style code with ansiReset to return to normal text.
const (
	ansiBold   = "\033[1m"
	ansiReset  = "\033[0m"
	ansiGray   = "\033[37m" // SGR 37 is "white"; most terminals render it light gray
	ansiGreen  = "\033[32m"
	ansiYellow = "\033[33m"
)

// ErrCancelled is returned when the user cancels a selection.
var ErrCancelled = errors.New("cancelled")

// errCancelled is kept as an alias for backward compatibility within the package.
var errCancelled = ErrCancelled

// DefaultConfirmPrompt provides a TUI-based confirmation prompt.
// When set, confirmPrompt delegates to it instead of using raw terminal I/O.
var DefaultConfirmPrompt func(prompt string) (bool, error)
// confirmPrompt asks a yes/no question on stderr and reads a single keypress
// from stdin in raw mode. 'y'/'Y' or Enter confirm; 'n'/'N', Esc, or Ctrl-C
// decline; all other keys are ignored. When DefaultConfirmPrompt is set
// (e.g. by a TUI or tests), it is used instead of raw terminal I/O.
func confirmPrompt(prompt string) (bool, error) {
	if DefaultConfirmPrompt != nil {
		return DefaultConfirmPrompt(prompt)
	}
	fd := int(os.Stdin.Fd())
	// Raw mode so a single keypress is read without waiting for Enter.
	oldState, err := term.MakeRaw(fd)
	if err != nil {
		return false, err
	}
	defer term.Restore(fd, oldState)
	fmt.Fprintf(os.Stderr, "%s (\033[1my\033[0m/n) ", prompt)
	buf := make([]byte, 1)
	for {
		if _, err := os.Stdin.Read(buf); err != nil {
			return false, err
		}
		switch buf[0] {
		case 'Y', 'y', 13: // 13 = carriage return (Enter)
			fmt.Fprintf(os.Stderr, "yes\r\n") // raw mode needs explicit \r\n
			return true, nil
		case 'N', 'n', 27, 3: // 27 = Esc, 3 = Ctrl-C
			fmt.Fprintf(os.Stderr, "no\r\n")
			return false, nil
		}
	}
}

View File

@@ -0,0 +1,19 @@
package config
import (
"testing"
)
// TestErrCancelled checks the package-internal cancellation sentinel: it must
// be non-nil and carry the expected message.
func TestErrCancelled(t *testing.T) {
	t.Run("NotNil", func(t *testing.T) {
		if errCancelled == nil {
			t.Error("errCancelled should not be nil")
		}
	})
	t.Run("Message", func(t *testing.T) {
		if got := errCancelled.Error(); got != "cancelled" {
			t.Errorf("expected 'cancelled', got %q", got)
		}
	})
}

View File

@@ -17,7 +17,6 @@ import (
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
"github.com/ollama/ollama/envconfig" "github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/internal/modelref"
"github.com/ollama/ollama/readline" "github.com/ollama/ollama/readline"
"github.com/ollama/ollama/types/errtypes" "github.com/ollama/ollama/types/errtypes"
"github.com/ollama/ollama/types/model" "github.com/ollama/ollama/types/model"
@@ -47,7 +46,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
fmt.Fprintln(os.Stderr, "Use \"\"\" to begin a multi-line message.") fmt.Fprintln(os.Stderr, "Use \"\"\" to begin a multi-line message.")
if opts.MultiModal { if opts.MultiModal {
fmt.Fprintf(os.Stderr, "Use %s to include .jpg, .png, .webp images, or .wav audio files.\n", filepath.FromSlash("/path/to/file")) fmt.Fprintf(os.Stderr, "Use %s to include .jpg, .png, or .webp images.\n", filepath.FromSlash("/path/to/file"))
} }
fmt.Fprintln(os.Stderr, "") fmt.Fprintln(os.Stderr, "")
@@ -214,17 +213,10 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
} }
origOpts := opts.Copy() origOpts := opts.Copy()
client, err := api.ClientFromEnvironment()
if err != nil {
fmt.Println("error: couldn't connect to ollama server")
return err
}
opts.Model = args[1] opts.Model = args[1]
opts.Messages = []api.Message{} opts.Messages = []api.Message{}
opts.LoadedMessages = nil
fmt.Printf("Loading model '%s'\n", opts.Model) fmt.Printf("Loading model '%s'\n", opts.Model)
info, err := client.Show(cmd.Context(), &api.ShowRequest{Model: opts.Model}) opts.Think, err = inferThinkingOption(nil, &opts, thinkExplicitlySet)
if err != nil { if err != nil {
if strings.Contains(err.Error(), "not found") { if strings.Contains(err.Error(), "not found") {
fmt.Printf("Couldn't find model '%s'\n", opts.Model) fmt.Printf("Couldn't find model '%s'\n", opts.Model)
@@ -233,11 +225,6 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
} }
return err return err
} }
applyShowResponseToRunOptions(&opts, info)
opts.Think, err = inferThinkingOption(&info.Capabilities, &opts, thinkExplicitlySet)
if err != nil {
return err
}
if err := loadOrUnloadModel(cmd, &opts); err != nil { if err := loadOrUnloadModel(cmd, &opts); err != nil {
if strings.Contains(err.Error(), "not found") { if strings.Contains(err.Error(), "not found") {
fmt.Printf("Couldn't find model '%s'\n", opts.Model) fmt.Printf("Couldn't find model '%s'\n", opts.Model)
@@ -553,13 +540,6 @@ func NewCreateRequest(name string, opts runOptions) *api.CreateRequest {
parentModel = "" parentModel = ""
} }
// Preserve explicit cloud intent for sessions started with `:cloud`.
// Cloud model metadata can return a source-less parent_model (for example
// "qwen3.5"), which would otherwise make `/save` create a local derivative.
if modelref.HasExplicitCloudSource(opts.Model) && !modelref.HasExplicitCloudSource(parentModel) {
parentModel = ""
}
req := &api.CreateRequest{ req := &api.CreateRequest{
Model: name, Model: name,
From: cmp.Or(parentModel, opts.Model), From: cmp.Or(parentModel, opts.Model),
@@ -573,10 +553,8 @@ func NewCreateRequest(name string, opts runOptions) *api.CreateRequest {
req.Parameters = opts.Options req.Parameters = opts.Options
} }
messages := slices.Clone(opts.LoadedMessages) if len(opts.Messages) > 0 {
messages = append(messages, opts.Messages...) req.Messages = opts.Messages
if len(messages) > 0 {
req.Messages = messages
} }
return req return req
@@ -606,7 +584,7 @@ func extractFileNames(input string) []string {
// Regex to match file paths starting with optional drive letter, / ./ \ or .\ and include escaped or unescaped spaces (\ or %20) // Regex to match file paths starting with optional drive letter, / ./ \ or .\ and include escaped or unescaped spaces (\ or %20)
// and followed by more characters and a file extension // and followed by more characters and a file extension
// This will capture non filename strings, but we'll check for file existence to remove mismatches // This will capture non filename strings, but we'll check for file existence to remove mismatches
regexPattern := `(?:[a-zA-Z]:)?(?:\./|/|\\)[\S\\ ]+?\.(?i:jpg|jpeg|png|webp|wav)\b` regexPattern := `(?:[a-zA-Z]:)?(?:\./|/|\\)[\S\\ ]+?\.(?i:jpg|jpeg|png|webp)\b`
re := regexp.MustCompile(regexPattern) re := regexp.MustCompile(regexPattern)
return re.FindAllString(input, -1) return re.FindAllString(input, -1)
@@ -622,16 +600,10 @@ func extractFileData(input string) (string, []api.ImageData, error) {
if errors.Is(err, os.ErrNotExist) { if errors.Is(err, os.ErrNotExist) {
continue continue
} else if err != nil { } else if err != nil {
fmt.Fprintf(os.Stderr, "Couldn't process file: %q\n", err) fmt.Fprintf(os.Stderr, "Couldn't process image: %q\n", err)
return "", imgs, err return "", imgs, err
} }
ext := strings.ToLower(filepath.Ext(nfp)) fmt.Fprintf(os.Stderr, "Added image '%s'\n", nfp)
switch ext {
case ".wav":
fmt.Fprintf(os.Stderr, "Added audio '%s'\n", nfp)
default:
fmt.Fprintf(os.Stderr, "Added image '%s'\n", nfp)
}
input = strings.ReplaceAll(input, "'"+nfp+"'", "") input = strings.ReplaceAll(input, "'"+nfp+"'", "")
input = strings.ReplaceAll(input, "'"+fp+"'", "") input = strings.ReplaceAll(input, "'"+fp+"'", "")
input = strings.ReplaceAll(input, fp, "") input = strings.ReplaceAll(input, fp, "")
@@ -705,9 +677,9 @@ func getImageData(filePath string) ([]byte, error) {
} }
contentType := http.DetectContentType(buf) contentType := http.DetectContentType(buf)
allowedTypes := []string{"image/jpeg", "image/jpg", "image/png", "image/webp", "audio/wave"} allowedTypes := []string{"image/jpeg", "image/jpg", "image/png", "image/webp"}
if !slices.Contains(allowedTypes, contentType) { if !slices.Contains(allowedTypes, contentType) {
return nil, fmt.Errorf("invalid file type: %s", contentType) return nil, fmt.Errorf("invalid image type: %s", contentType)
} }
info, err := file.Stat() info, err := file.Stat()
@@ -715,7 +687,8 @@ func getImageData(filePath string) ([]byte, error) {
return nil, err return nil, err
} }
var maxSize int64 = 100 * 1024 * 1024 // 100MB // Check if the file size exceeds 100MB
var maxSize int64 = 100 * 1024 * 1024 // 100MB in bytes
if info.Size() > maxSize { if info.Size() > maxSize {
return nil, errors.New("file size exceeds maximum limit (100MB)") return nil, errors.New("file size exceeds maximum limit (100MB)")
} }

View File

@@ -84,33 +84,3 @@ func TestExtractFileDataRemovesQuotedFilepath(t *testing.T) {
assert.Len(t, imgs, 1) assert.Len(t, imgs, 1)
assert.Equal(t, cleaned, "before after") assert.Equal(t, cleaned, "before after")
} }
// TestExtractFileDataWAV verifies that extractFileData detects a .wav file
// path embedded in prompt text, loads the file contents as attachment data,
// and removes the path from the returned prompt string.
func TestExtractFileDataWAV(t *testing.T) {
	dir := t.TempDir()
	fp := filepath.Join(dir, "sample.wav")
	// Minimal 44-byte RIFF/WAVE header followed by zeroed sample bytes so
	// content-type sniffing classifies the file as wave audio.
	data := make([]byte, 600)
	copy(data[:44], []byte{
		'R', 'I', 'F', 'F',
		0x58, 0x02, 0x00, 0x00, // file size - 8
		'W', 'A', 'V', 'E',
		'f', 'm', 't', ' ',
		0x10, 0x00, 0x00, 0x00, // fmt chunk size
		0x01, 0x00, // PCM
		0x01, 0x00, // mono
		0x80, 0x3e, 0x00, 0x00, // 16000 Hz
		0x00, 0x7d, 0x00, 0x00, // byte rate
		0x02, 0x00, // block align
		0x10, 0x00, // 16-bit
		'd', 'a', 't', 'a',
		0x34, 0x02, 0x00, 0x00, // data size
	})
	if err := os.WriteFile(fp, data, 0o600); err != nil {
		t.Fatalf("failed to write test audio: %v", err)
	}
	input := "before " + fp + " after"
	cleaned, imgs, err := extractFileData(input)
	assert.NoError(t, err)
	// Exactly one attachment is extracted from the prompt.
	assert.Len(t, imgs, 1)
	// The file path itself is stripped from the returned text.
	assert.Equal(t, "before after", cleaned)
}

View File

@@ -1,87 +0,0 @@
package launch
import (
"fmt"
"os"
"os/exec"
"path/filepath"
"runtime"
"strconv"
"github.com/ollama/ollama/envconfig"
)
// Claude implements Runner for the Claude Code integration.
type Claude struct{}

// String returns the human-readable integration name.
func (c *Claude) String() string { return "Claude Code" }

// args assembles the claude CLI argument list: an optional
// "--model <name>" pair followed by any caller-supplied extras.
func (c *Claude) args(model string, extra []string) []string {
	var argv []string
	if model != "" {
		argv = append(argv, "--model", model)
	}
	return append(argv, extra...)
}
// findPath locates the claude executable: first on PATH, then at the
// local install location ~/.claude/local/claude (claude.exe on Windows).
func (c *Claude) findPath() (string, error) {
	if found, err := exec.LookPath("claude"); err == nil {
		return found, nil
	}
	home, err := os.UserHomeDir()
	if err != nil {
		return "", err
	}
	binary := "claude"
	if runtime.GOOS == "windows" {
		binary += ".exe"
	}
	local := filepath.Join(home, ".claude", "local", binary)
	if _, err := os.Stat(local); err != nil {
		return "", err
	}
	return local, nil
}
// Run launches Claude Code pointed at the local Ollama server. Stdio is
// passed straight through, and the Anthropic environment variables are
// overridden so API traffic is routed to envconfig.Host().
func (c *Claude) Run(model string, args []string) error {
	claudePath, err := c.findPath()
	if err != nil {
		return fmt.Errorf("claude is not installed, install from https://code.claude.com/docs/en/quickstart")
	}
	cmd := exec.Command(claudePath, c.args(model, args)...)
	cmd.Stdin, cmd.Stdout, cmd.Stderr = os.Stdin, os.Stdout, os.Stderr
	cmd.Env = append(os.Environ(),
		"ANTHROPIC_BASE_URL="+envconfig.Host().String(),
		"ANTHROPIC_API_KEY=",
		"ANTHROPIC_AUTH_TOKEN=ollama",
		"CLAUDE_CODE_ATTRIBUTION_HEADER=0",
	)
	cmd.Env = append(cmd.Env, c.modelEnvVars(model)...)
	return cmd.Run()
}
// modelEnvVars returns Claude Code env vars that route every model tier
// (opus/sonnet/haiku plus subagents) to the selected Ollama model. For
// cloud models with a known context limit it also pins the auto-compact
// window to that limit.
func (c *Claude) modelEnvVars(model string) []string {
	vars := []string{
		"ANTHROPIC_DEFAULT_OPUS_MODEL=" + model,
		"ANTHROPIC_DEFAULT_SONNET_MODEL=" + model,
		"ANTHROPIC_DEFAULT_HAIKU_MODEL=" + model,
		"CLAUDE_CODE_SUBAGENT_MODEL=" + model,
	}
	if !isCloudModelName(model) {
		return vars
	}
	if limit, ok := lookupCloudModelLimit(model); ok {
		vars = append(vars, "CLAUDE_CODE_AUTO_COMPACT_WINDOW="+strconv.Itoa(limit.Context))
	}
	return vars
}

View File

@@ -1,265 +0,0 @@
package launch
import (
"context"
"encoding/json"
"fmt"
"net/http"
"os"
"os/exec"
"path/filepath"
"slices"
"strconv"
"strings"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/types/model"
"golang.org/x/mod/semver"
)
// Codex implements Runner for the Codex integration.
type Codex struct{}

// String returns the human-readable integration name.
func (c *Codex) String() string { return "Codex" }

// codexProfileName is the Codex profile written to config.toml and
// selected on every launch.
const codexProfileName = "ollama-launch"

// args builds the codex CLI argument list: the managed profile, an
// optional "-m <model>" pair, then any caller-supplied extras.
func (c *Codex) args(model string, extra []string) []string {
	argv := []string{"--profile", codexProfileName}
	if model != "" {
		argv = append(argv, "-m", model)
	}
	return append(argv, extra...)
}
// Run verifies the installed codex version, writes the Ollama profile and
// model catalog, then launches codex with stdio passed through.
func (c *Codex) Run(model string, args []string) error {
	if err := checkCodexVersion(); err != nil {
		return err
	}
	if err := ensureCodexConfig(model); err != nil {
		return fmt.Errorf("failed to configure codex: %w", err)
	}
	cmd := exec.Command("codex", c.args(model, args)...)
	cmd.Stdin, cmd.Stdout, cmd.Stderr = os.Stdin, os.Stdout, os.Stderr
	// A non-empty key keeps codex from prompting for OpenAI credentials.
	cmd.Env = append(os.Environ(), "OPENAI_API_KEY=ollama")
	return cmd.Run()
}
// ensureCodexConfig writes the Codex model catalog (~/.codex/model.json)
// and profile (~/.codex/config.toml) so Codex uses the local Ollama server
// and has metadata for the selected model.
func ensureCodexConfig(modelName string) error {
	home, err := os.UserHomeDir()
	if err != nil {
		return err
	}
	dir := filepath.Join(home, ".codex")
	if err := os.MkdirAll(dir, 0o755); err != nil {
		return err
	}
	catalog := filepath.Join(dir, "model.json")
	if err := writeCodexModelCatalog(catalog, modelName); err != nil {
		return err
	}
	return writeCodexProfile(filepath.Join(dir, "config.toml"), catalog)
}
// writeCodexProfile ensures ~/.codex/config.toml has the ollama-launch profile
// and model provider sections with the correct base URL.
//
// Existing sections with the same headers are replaced in place (up to the
// next "[" section header); everything else in the file is preserved. Missing
// sections are appended. The result is written back with 0o644 permissions.
func writeCodexProfile(configPath, catalogPath string) error {
	// All requests go through the local Ollama server's OpenAI-compatible API.
	baseURL := envconfig.Host().String() + "/v1/"
	// The two TOML sections this function owns: the launch profile and its
	// matching model provider entry.
	sections := []struct {
		header string
		lines  []string
	}{
		{
			header: fmt.Sprintf("[profiles.%s]", codexProfileName),
			lines: []string{
				fmt.Sprintf("openai_base_url = %q", baseURL),
				`forced_login_method = "api"`,
				fmt.Sprintf("model_provider = %q", codexProfileName),
				fmt.Sprintf("model_catalog_json = %q", catalogPath),
			},
		},
		{
			header: fmt.Sprintf("[model_providers.%s]", codexProfileName),
			lines: []string{
				`name = "Ollama"`,
				fmt.Sprintf("base_url = %q", baseURL),
			},
		},
	}
	// A missing config file is fine — start from empty text.
	content, readErr := os.ReadFile(configPath)
	text := ""
	if readErr == nil {
		text = string(content)
	}
	for _, s := range sections {
		block := strings.Join(append([]string{s.header}, s.lines...), "\n") + "\n"
		if idx := strings.Index(text, s.header); idx >= 0 {
			// Replace the existing section up to the next section header.
			rest := text[idx+len(s.header):]
			if endIdx := strings.Index(rest, "\n["); endIdx >= 0 {
				// block ends with "\n"; rest[endIdx+1:] begins at the "[" of
				// the next header, so the following section is kept intact.
				text = text[:idx] + block + rest[endIdx+1:]
			} else {
				// Section runs to end of file; replace through the end.
				text = text[:idx] + block
			}
		} else {
			// Append the section.
			if text != "" && !strings.HasSuffix(text, "\n") {
				text += "\n"
			}
			if text != "" {
				text += "\n"
			}
			text += block
		}
	}
	return os.WriteFile(configPath, []byte(text), 0o644)
}
// writeCodexModelCatalog writes a single-entry Codex model catalog JSON
// file describing modelName to catalogPath.
func writeCodexModelCatalog(catalogPath, modelName string) error {
	catalog := map[string]any{
		"models": []any{buildCodexModelEntry(modelName)},
	}
	data, err := json.MarshalIndent(catalog, "", " ")
	if err != nil {
		return err
	}
	return os.WriteFile(catalogPath, data, 0o644)
}
// buildCodexModelEntry constructs the Codex model-catalog entry for
// modelName. It queries the local Ollama server for capabilities, system
// prompt, and context metadata; on any Show error it falls back to whatever
// defaults (and cloud limits) were already resolved.
//
// Context window precedence for local models: architectural context length
// from model_info, overridden by OLLAMA_CONTEXT_LENGTH, overridden by a
// num_ctx parameter (the env/num_ctx overrides are skipped for safetensors
// models). Cloud models use the hardcoded lookupCloudModelLimit value only.
func buildCodexModelEntry(modelName string) map[string]any {
	contextWindow := 0
	hasVision := false
	hasThinking := false
	systemPrompt := ""
	// Cloud models get their context window from the hardcoded limit table.
	if l, ok := lookupCloudModelLimit(modelName); ok {
		contextWindow = l.Context
	}
	client := api.NewClient(envconfig.Host(), http.DefaultClient)
	resp, err := client.Show(context.Background(), &api.ShowRequest{Model: modelName})
	if err == nil {
		systemPrompt = resp.System
		if slices.Contains(resp.Capabilities, model.CapabilityVision) {
			hasVision = true
		}
		if slices.Contains(resp.Capabilities, model.CapabilityThinking) {
			hasThinking = true
		}
		if !isCloudModelName(modelName) {
			// Architectural context length is the baseline for local models.
			if n, ok := modelInfoContextLength(resp.ModelInfo); ok {
				contextWindow = n
			}
			if resp.Details.Format != "safetensors" {
				// OLLAMA_CONTEXT_LENGTH overrides the architectural value...
				if ctxLen := envconfig.ContextLength(); ctxLen > 0 {
					contextWindow = int(ctxLen)
				}
				// ...and an explicit num_ctx parameter overrides both.
				if numCtx := parseNumCtx(resp.Parameters); numCtx > 0 {
					contextWindow = numCtx
				}
			}
		}
	}
	modalities := []string{"text"}
	if hasVision {
		modalities = append(modalities, "image")
	}
	// Thinking-capable models expose the three standard reasoning tiers.
	reasoningLevels := []any{}
	if hasThinking {
		reasoningLevels = []any{
			map[string]any{"effort": "low", "description": "Fast responses with lighter reasoning"},
			map[string]any{"effort": "medium", "description": "Balances speed and reasoning depth"},
			map[string]any{"effort": "high", "description": "Greater reasoning depth for complex problems"},
		}
	}
	// Cloud models truncate by tokens; local models by bytes.
	truncationMode := "bytes"
	if isCloudModelName(modelName) {
		truncationMode = "tokens"
	}
	return map[string]any{
		"slug":                         modelName,
		"display_name":                 modelName,
		"context_window":               contextWindow,
		"apply_patch_tool_type":        "function",
		"shell_type":                   "default",
		"visibility":                   "list",
		"supported_in_api":             true,
		"priority":                     0,
		"truncation_policy":            map[string]any{"mode": truncationMode, "limit": 10000},
		"input_modalities":             modalities,
		"base_instructions":            systemPrompt,
		"support_verbosity":            true,
		"default_verbosity":            "low",
		"supports_parallel_tool_calls": false,
		"supports_reasoning_summaries": hasThinking,
		"supported_reasoning_levels":   reasoningLevels,
		"experimental_supported_tools": []any{},
	}
}
// parseNumCtx scans a Modelfile-style parameter dump (one "key value" pair
// per line) and returns the num_ctx value, truncated to an int. It returns
// 0 when num_ctx is absent or its value does not parse as a number.
func parseNumCtx(parameters string) int {
	for _, line := range strings.Split(parameters, "\n") {
		fields := strings.Fields(line)
		if len(fields) != 2 || fields[0] != "num_ctx" {
			continue
		}
		if v, err := strconv.ParseFloat(fields[1], 64); err == nil {
			return int(v)
		}
	}
	return 0
}
// checkCodexVersion verifies that the codex CLI is installed and is at least
// version 0.81.0, returning an actionable install/update error otherwise.
func checkCodexVersion() error {
	if _, err := exec.LookPath("codex"); err != nil {
		return fmt.Errorf("codex is not installed, install with: npm install -g @openai/codex")
	}
	out, err := exec.Command("codex", "--version").Output()
	if err != nil {
		return fmt.Errorf("failed to get codex version: %w", err)
	}
	// Parse output like "codex-cli 0.87.0"
	fields := strings.Fields(strings.TrimSpace(string(out)))
	if len(fields) < 2 {
		return fmt.Errorf("unexpected codex version output: %s", string(out))
	}
	// semver.Compare requires the "v" prefix on both operands.
	version := "v" + fields[len(fields)-1]
	minVersion := "v0.81.0"
	if semver.Compare(version, minVersion) < 0 {
		return fmt.Errorf("codex version %s is too old, minimum required is %s, update with: npm update -g @openai/codex", fields[len(fields)-1], "0.81.0")
	}
	return nil
}

View File

@@ -1,452 +0,0 @@
package launch
import (
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"slices"
"strings"
"testing"
)
// TestCodexArgs checks the codex CLI argument list across combinations of
// model names and extra passthrough flags.
func TestCodexArgs(t *testing.T) {
	codex := &Codex{}
	cases := []struct {
		name  string
		model string
		args  []string
		want  []string
	}{
		{"with model", "llama3.2", nil, []string{"--profile", "ollama-launch", "-m", "llama3.2"}},
		{"empty model", "", nil, []string{"--profile", "ollama-launch"}},
		{"with model and extra args", "qwen3.5", []string{"-p", "myprofile"}, []string{"--profile", "ollama-launch", "-m", "qwen3.5", "-p", "myprofile"}},
		{"with sandbox flag", "llama3.2", []string{"--sandbox", "workspace-write"}, []string{"--profile", "ollama-launch", "-m", "llama3.2", "--sandbox", "workspace-write"}},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			if got := codex.args(tc.model, tc.args); !slices.Equal(got, tc.want) {
				t.Errorf("args(%q, %v) = %v, want %v", tc.model, tc.args, got, tc.want)
			}
		})
	}
}
// TestWriteCodexProfile covers creation, append, and in-place replacement of
// the ollama-launch profile and model-provider sections in config.toml.
func TestWriteCodexProfile(t *testing.T) {
	t.Run("creates new file when none exists", func(t *testing.T) {
		tmpDir := t.TempDir()
		configPath := filepath.Join(tmpDir, "config.toml")
		catalogPath := filepath.Join(tmpDir, "model.json")
		if err := writeCodexProfile(configPath, catalogPath); err != nil {
			t.Fatal(err)
		}
		data, err := os.ReadFile(configPath)
		if err != nil {
			t.Fatal(err)
		}
		content := string(data)
		// Both managed sections and all of their keys must be present.
		if !strings.Contains(content, "[profiles.ollama-launch]") {
			t.Error("missing [profiles.ollama-launch] header")
		}
		if !strings.Contains(content, "openai_base_url") {
			t.Error("missing openai_base_url key")
		}
		if !strings.Contains(content, "/v1/") {
			t.Error("missing /v1/ suffix in base URL")
		}
		if !strings.Contains(content, `forced_login_method = "api"`) {
			t.Error("missing forced_login_method key")
		}
		if !strings.Contains(content, `model_provider = "ollama-launch"`) {
			t.Error("missing model_provider key")
		}
		if !strings.Contains(content, fmt.Sprintf("model_catalog_json = %q", catalogPath)) {
			t.Error("missing model_catalog_json key")
		}
		if !strings.Contains(content, "[model_providers.ollama-launch]") {
			t.Error("missing [model_providers.ollama-launch] section")
		}
		if !strings.Contains(content, `name = "Ollama"`) {
			t.Error("missing model provider name")
		}
	})

	t.Run("appends profile to existing file without profile", func(t *testing.T) {
		tmpDir := t.TempDir()
		configPath := filepath.Join(tmpDir, "config.toml")
		catalogPath := filepath.Join(tmpDir, "model.json")
		existing := "[some_other_section]\nkey = \"value\"\n"
		os.WriteFile(configPath, []byte(existing), 0o644)
		if err := writeCodexProfile(configPath, catalogPath); err != nil {
			t.Fatal(err)
		}
		data, _ := os.ReadFile(configPath)
		content := string(data)
		// Unrelated sections survive; the profile is appended.
		if !strings.Contains(content, "[some_other_section]") {
			t.Error("existing section was removed")
		}
		if !strings.Contains(content, "[profiles.ollama-launch]") {
			t.Error("missing [profiles.ollama-launch] header")
		}
	})

	t.Run("replaces existing profile section", func(t *testing.T) {
		tmpDir := t.TempDir()
		configPath := filepath.Join(tmpDir, "config.toml")
		catalogPath := filepath.Join(tmpDir, "model.json")
		// Stale sections pointing at an old host must be rewritten, not duplicated.
		existing := "[profiles.ollama-launch]\nopenai_base_url = \"http://old:1234/v1/\"\n\n[model_providers.ollama-launch]\nname = \"Ollama\"\nbase_url = \"http://old:1234/v1/\"\n"
		os.WriteFile(configPath, []byte(existing), 0o644)
		if err := writeCodexProfile(configPath, catalogPath); err != nil {
			t.Fatal(err)
		}
		data, _ := os.ReadFile(configPath)
		content := string(data)
		if strings.Contains(content, "old:1234") {
			t.Error("old URL was not replaced")
		}
		if strings.Count(content, "[profiles.ollama-launch]") != 1 {
			t.Errorf("expected exactly one [profiles.ollama-launch] section, got %d", strings.Count(content, "[profiles.ollama-launch]"))
		}
		if strings.Count(content, "[model_providers.ollama-launch]") != 1 {
			t.Errorf("expected exactly one [model_providers.ollama-launch] section, got %d", strings.Count(content, "[model_providers.ollama-launch]"))
		}
	})

	t.Run("replaces profile while preserving following sections", func(t *testing.T) {
		tmpDir := t.TempDir()
		configPath := filepath.Join(tmpDir, "config.toml")
		catalogPath := filepath.Join(tmpDir, "model.json")
		existing := "[profiles.ollama-launch]\nopenai_base_url = \"http://old:1234/v1/\"\n[another_section]\nfoo = \"bar\"\n"
		os.WriteFile(configPath, []byte(existing), 0o644)
		if err := writeCodexProfile(configPath, catalogPath); err != nil {
			t.Fatal(err)
		}
		data, _ := os.ReadFile(configPath)
		content := string(data)
		if strings.Contains(content, "old:1234") {
			t.Error("old URL was not replaced")
		}
		// The section that immediately follows the replaced one stays intact.
		if !strings.Contains(content, "[another_section]") {
			t.Error("following section was removed")
		}
		if !strings.Contains(content, "foo = \"bar\"") {
			t.Error("following section content was removed")
		}
	})

	t.Run("appends newline to file not ending with newline", func(t *testing.T) {
		tmpDir := t.TempDir()
		configPath := filepath.Join(tmpDir, "config.toml")
		catalogPath := filepath.Join(tmpDir, "model.json")
		existing := "[other]\nkey = \"val\""
		os.WriteFile(configPath, []byte(existing), 0o644)
		if err := writeCodexProfile(configPath, catalogPath); err != nil {
			t.Fatal(err)
		}
		data, _ := os.ReadFile(configPath)
		content := string(data)
		if !strings.Contains(content, "[profiles.ollama-launch]") {
			t.Error("missing [profiles.ollama-launch] header")
		}
		// Should not have double blank lines from missing trailing newline
		if strings.Contains(content, "\n\n\n") {
			t.Error("unexpected triple newline in output")
		}
	})

	t.Run("uses custom OLLAMA_HOST", func(t *testing.T) {
		t.Setenv("OLLAMA_HOST", "http://myhost:9999")
		tmpDir := t.TempDir()
		configPath := filepath.Join(tmpDir, "config.toml")
		catalogPath := filepath.Join(tmpDir, "model.json")
		if err := writeCodexProfile(configPath, catalogPath); err != nil {
			t.Fatal(err)
		}
		data, _ := os.ReadFile(configPath)
		content := string(data)
		if !strings.Contains(content, "myhost:9999/v1/") {
			t.Errorf("expected custom host in URL, got:\n%s", content)
		}
	})
}
// TestEnsureCodexConfig exercises ensureCodexConfig against a temporary home
// directory: initial creation of ~/.codex files and idempotency on re-run.
func TestEnsureCodexConfig(t *testing.T) {
	t.Run("creates .codex dir and config.toml", func(t *testing.T) {
		tmpDir := t.TempDir()
		setTestHome(t, tmpDir)
		if err := ensureCodexConfig("llama3.2"); err != nil {
			t.Fatal(err)
		}
		configPath := filepath.Join(tmpDir, ".codex", "config.toml")
		data, err := os.ReadFile(configPath)
		if err != nil {
			t.Fatalf("config.toml not created: %v", err)
		}
		content := string(data)
		if !strings.Contains(content, "[profiles.ollama-launch]") {
			t.Error("missing [profiles.ollama-launch] header")
		}
		if !strings.Contains(content, "openai_base_url") {
			t.Error("missing openai_base_url key")
		}
		// The model catalog is written alongside the profile config.
		catalogPath := filepath.Join(tmpDir, ".codex", "model.json")
		data, err = os.ReadFile(catalogPath)
		if err != nil {
			t.Fatalf("model.json not created: %v", err)
		}
		if !strings.Contains(string(data), `"slug": "llama3.2"`) {
			t.Error("missing model catalog entry for selected model")
		}
	})

	t.Run("is idempotent", func(t *testing.T) {
		tmpDir := t.TempDir()
		setTestHome(t, tmpDir)
		if err := ensureCodexConfig("llama3.2"); err != nil {
			t.Fatal(err)
		}
		// A second run must replace the managed sections, not duplicate them.
		if err := ensureCodexConfig("llama3.2"); err != nil {
			t.Fatal(err)
		}
		configPath := filepath.Join(tmpDir, ".codex", "config.toml")
		data, _ := os.ReadFile(configPath)
		content := string(data)
		if strings.Count(content, "[profiles.ollama-launch]") != 1 {
			t.Errorf("expected exactly one [profiles.ollama-launch] section after two calls, got %d", strings.Count(content, "[profiles.ollama-launch]"))
		}
		if strings.Count(content, "[model_providers.ollama-launch]") != 1 {
			t.Errorf("expected exactly one [model_providers.ollama-launch] section after two calls, got %d", strings.Count(content, "[model_providers.ollama-launch]"))
		}
	})
}
// TestParseNumCtx covers present, absent, malformed, and float-valued
// num_ctx entries in a parameter dump.
func TestParseNumCtx(t *testing.T) {
	cases := []struct {
		name       string
		parameters string
		want       int
	}{
		{"num_ctx set", "num_ctx 8192", 8192},
		{"num_ctx with other params", "temperature 0.7\nnum_ctx 4096\ntop_p 0.9", 4096},
		{"no num_ctx", "temperature 0.7\ntop_p 0.9", 0},
		{"empty string", "", 0},
		{"malformed value", "num_ctx abc", 0},
		{"float value", "num_ctx 8192.0", 8192},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got := parseNumCtx(tc.parameters)
			if got != tc.want {
				t.Errorf("parseNumCtx(%q) = %d, want %d", tc.parameters, got, tc.want)
			}
		})
	}
}
// TestModelInfoContextLength checks extraction of the architecture-specific
// "<arch>.context_length" entry from a model_info map, including numeric
// type variants and missing-key/nil cases.
func TestModelInfoContextLength(t *testing.T) {
	cases := []struct {
		name      string
		modelInfo map[string]any
		want      int
	}{
		{"float64 value", map[string]any{"qwen3_5_moe.context_length": float64(262144)}, 262144},
		{"int value", map[string]any{"llama.context_length": 131072}, 131072},
		{"no context_length key", map[string]any{"llama.embedding_length": float64(4096)}, 0},
		{"empty map", map[string]any{}, 0},
		{"nil map", nil, 0},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			if got, _ := modelInfoContextLength(tc.modelInfo); got != tc.want {
				t.Errorf("modelInfoContextLength() = %d, want %d", got, tc.want)
			}
		})
	}
}
// TestBuildCodexModelEntryContextWindow drives buildCodexModelEntry against a
// stub /api/show server and checks the context-window precedence rules
// (architectural value < OLLAMA_CONTEXT_LENGTH < num_ctx, with safetensors and
// cloud-model exceptions) plus capability, system-prompt, and truncation
// handling of the produced catalog entry.
func TestBuildCodexModelEntryContextWindow(t *testing.T) {
	tests := []struct {
		name          string
		modelName     string
		showResponse  string
		envContextLen string
		wantContext   int
	}{
		{
			name:      "architectural context length as fallback",
			modelName: "llama3.2",
			showResponse: `{
				"model_info": {"llama.context_length": 131072},
				"details": {"format": "gguf"}
			}`,
			wantContext: 131072,
		},
		{
			name:      "OLLAMA_CONTEXT_LENGTH overrides architectural",
			modelName: "llama3.2",
			showResponse: `{
				"model_info": {"llama.context_length": 131072},
				"details": {"format": "gguf"}
			}`,
			envContextLen: "64000",
			wantContext:   64000,
		},
		{
			name:      "num_ctx overrides OLLAMA_CONTEXT_LENGTH",
			modelName: "llama3.2",
			showResponse: `{
				"model_info": {"llama.context_length": 131072},
				"parameters": "num_ctx 8192",
				"details": {"format": "gguf"}
			}`,
			envContextLen: "64000",
			wantContext:   8192,
		},
		{
			name:      "num_ctx overrides architectural",
			modelName: "llama3.2",
			showResponse: `{
				"model_info": {"llama.context_length": 131072},
				"parameters": "num_ctx 32768",
				"details": {"format": "gguf"}
			}`,
			wantContext: 32768,
		},
		{
			name:      "safetensors uses architectural context only",
			modelName: "llama3.2",
			showResponse: `{
				"model_info": {"llama.context_length": 131072},
				"parameters": "num_ctx 8192",
				"details": {"format": "safetensors"}
			}`,
			envContextLen: "64000",
			wantContext:   131072,
		},
		{
			name:      "cloud model uses hardcoded limits",
			modelName: "qwen3.5:cloud",
			showResponse: `{
				"model_info": {"qwen3_5_moe.context_length": 131072},
				"details": {"format": "gguf"}
			}`,
			envContextLen: "64000",
			wantContext:   262144,
		},
		{
			name:      "vision and thinking capabilities",
			modelName: "llama3.2",
			showResponse: `{
				"model_info": {"llama.context_length": 131072},
				"details": {"format": "gguf"},
				"capabilities": ["vision", "thinking"]
			}`,
			wantContext: 131072,
		},
		{
			name:      "system prompt passed through",
			modelName: "llama3.2",
			showResponse: `{
				"model_info": {"llama.context_length": 131072},
				"details": {"format": "gguf"},
				"system": "You are a helpful assistant."
			}`,
			wantContext: 131072,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// Stub Ollama server serving only the show endpoint.
			srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				switch r.URL.Path {
				case "/api/show":
					fmt.Fprint(w, tt.showResponse)
				default:
					http.NotFound(w, r)
				}
			}))
			defer srv.Close()
			t.Setenv("OLLAMA_HOST", srv.URL)
			if tt.envContextLen != "" {
				t.Setenv("OLLAMA_CONTEXT_LENGTH", tt.envContextLen)
			} else {
				// Explicitly clear any ambient value so cases are isolated.
				t.Setenv("OLLAMA_CONTEXT_LENGTH", "")
			}
			entry := buildCodexModelEntry(tt.modelName)
			gotContext, _ := entry["context_window"].(int)
			if gotContext != tt.wantContext {
				t.Errorf("context_window = %d, want %d", gotContext, tt.wantContext)
			}
			if tt.name == "vision and thinking capabilities" {
				modalities, _ := entry["input_modalities"].([]string)
				if !slices.Contains(modalities, "image") {
					t.Error("expected image in input_modalities")
				}
				levels, _ := entry["supported_reasoning_levels"].([]any)
				if len(levels) == 0 {
					t.Error("expected non-empty supported_reasoning_levels")
				}
			}
			if tt.name == "system prompt passed through" {
				if got, _ := entry["base_instructions"].(string); got != "You are a helpful assistant." {
					t.Errorf("base_instructions = %q, want %q", got, "You are a helpful assistant.")
				}
			}
			if tt.name == "cloud model uses hardcoded limits" {
				truncationPolicy, _ := entry["truncation_policy"].(map[string]any)
				if mode, _ := truncationPolicy["mode"].(string); mode != "tokens" {
					t.Errorf("truncation_policy mode = %q, want %q", mode, "tokens")
				}
			}
			// Every entry must carry the Codex-required keys and serialize cleanly.
			requiredKeys := []string{"slug", "display_name", "apply_patch_tool_type", "shell_type"}
			for _, key := range requiredKeys {
				if _, ok := entry[key]; !ok {
					t.Errorf("missing required key %q", key)
				}
			}
			if _, err := json.Marshal(entry); err != nil {
				t.Errorf("entry is not JSON serializable: %v", err)
			}
		})
	}
}

View File

@@ -1,604 +0,0 @@
package launch
import (
"bytes"
"fmt"
"io"
"net/http"
"net/http/httptest"
"os"
"strings"
"testing"
"github.com/google/go-cmp/cmp"
"github.com/ollama/ollama/cmd/config"
"github.com/spf13/cobra"
)
// captureStderr runs fn with os.Stderr redirected to a pipe and returns
// everything fn wrote to stderr. The original stderr is restored on return.
func captureStderr(t *testing.T, fn func()) string {
	t.Helper()
	r, w, err := os.Pipe()
	if err != nil {
		t.Fatalf("failed to create stderr pipe: %v", err)
	}
	saved := os.Stderr
	os.Stderr = w
	defer func() { os.Stderr = saved }()
	// Drain the read end concurrently so writers never block on a full pipe.
	collected := make(chan string, 1)
	go func() {
		var sink bytes.Buffer
		_, _ = io.Copy(&sink, r)
		collected <- sink.String()
	}()
	fn()
	_ = w.Close() // EOF unblocks io.Copy in the reader goroutine
	return <-collected
}
// TestLaunchCmd checks the static shape of the launch cobra command:
// usage string, descriptions, declared flags, and the PreRunE hook.
func TestLaunchCmd(t *testing.T) {
	mockCheck := func(cmd *cobra.Command, args []string) error {
		return nil
	}
	mockTUI := func(cmd *cobra.Command) {}
	cmd := LaunchCmd(mockCheck, mockTUI)

	t.Run("command structure", func(t *testing.T) {
		if cmd.Use != "launch [INTEGRATION] [-- [EXTRA_ARGS...]]" {
			t.Errorf("Use = %q, want %q", cmd.Use, "launch [INTEGRATION] [-- [EXTRA_ARGS...]]")
		}
		if cmd.Short == "" {
			t.Error("Short description should not be empty")
		}
		if cmd.Long == "" {
			t.Error("Long description should not be empty")
		}
		// The long help text should list available integrations.
		if !strings.Contains(cmd.Long, "hermes") {
			t.Error("Long description should mention hermes")
		}
		if !strings.Contains(cmd.Long, "kimi") {
			t.Error("Long description should mention kimi")
		}
	})

	t.Run("flags exist", func(t *testing.T) {
		if cmd.Flags().Lookup("model") == nil {
			t.Error("--model flag should exist")
		}
		if cmd.Flags().Lookup("config") == nil {
			t.Error("--config flag should exist")
		}
		if cmd.Flags().Lookup("yes") == nil {
			t.Error("--yes flag should exist")
		}
	})

	t.Run("PreRunE is set", func(t *testing.T) {
		if cmd.PreRunE == nil {
			t.Error("PreRunE should be set to checkServerHeartbeat")
		}
	})
}
// TestLaunchCmdTUICallback verifies when the interactive TUI callback fires:
// only for a bare "ollama launch" with no arguments. Any flag or extra arg
// without an integration name must error instead of opening the TUI.
func TestLaunchCmdTUICallback(t *testing.T) {
	mockCheck := func(cmd *cobra.Command, args []string) error {
		return nil
	}

	t.Run("no args calls TUI", func(t *testing.T) {
		tuiCalled := false
		mockTUI := func(cmd *cobra.Command) {
			tuiCalled = true
		}
		cmd := LaunchCmd(mockCheck, mockTUI)
		cmd.SetArgs([]string{})
		_ = cmd.Execute()
		if !tuiCalled {
			t.Error("TUI callback should be called when no args provided")
		}
	})

	t.Run("integration arg bypasses TUI", func(t *testing.T) {
		// A stub server satisfies any API calls the direct launch path makes.
		srv := httptest.NewServer(http.NotFoundHandler())
		defer srv.Close()
		t.Setenv("OLLAMA_HOST", srv.URL)
		tuiCalled := false
		mockTUI := func(cmd *cobra.Command) {
			tuiCalled = true
		}
		cmd := LaunchCmd(mockCheck, mockTUI)
		cmd.SetArgs([]string{"claude"})
		_ = cmd.Execute()
		if tuiCalled {
			t.Error("TUI callback should NOT be called when integration arg provided")
		}
	})

	t.Run("--model flag without integration returns error", func(t *testing.T) {
		tuiCalled := false
		mockTUI := func(cmd *cobra.Command) {
			tuiCalled = true
		}
		cmd := LaunchCmd(mockCheck, mockTUI)
		cmd.SetArgs([]string{"--model", "test-model"})
		err := cmd.Execute()
		if err == nil {
			t.Fatal("expected --model without an integration to fail")
		}
		if !strings.Contains(err.Error(), "require an integration name") {
			t.Fatalf("expected integration-name guidance, got %v", err)
		}
		if tuiCalled {
			t.Error("TUI callback should NOT be called when --model is provided without an integration")
		}
	})

	t.Run("--config flag without integration returns error", func(t *testing.T) {
		tuiCalled := false
		mockTUI := func(cmd *cobra.Command) {
			tuiCalled = true
		}
		cmd := LaunchCmd(mockCheck, mockTUI)
		cmd.SetArgs([]string{"--config"})
		err := cmd.Execute()
		if err == nil {
			t.Fatal("expected --config without an integration to fail")
		}
		if !strings.Contains(err.Error(), "require an integration name") {
			t.Fatalf("expected integration-name guidance, got %v", err)
		}
		if tuiCalled {
			t.Error("TUI callback should NOT be called when --config is provided without an integration")
		}
	})

	t.Run("--yes flag without integration returns error", func(t *testing.T) {
		tuiCalled := false
		mockTUI := func(cmd *cobra.Command) {
			tuiCalled = true
		}
		cmd := LaunchCmd(mockCheck, mockTUI)
		cmd.SetArgs([]string{"--yes"})
		err := cmd.Execute()
		if err == nil {
			t.Fatal("expected --yes without an integration to fail")
		}
		if !strings.Contains(err.Error(), "require an integration name") {
			t.Fatalf("expected integration-name guidance, got %v", err)
		}
		if tuiCalled {
			t.Error("TUI callback should NOT be called when --yes is provided without an integration")
		}
	})

	t.Run("extra args without integration return error", func(t *testing.T) {
		tuiCalled := false
		mockTUI := func(cmd *cobra.Command) {
			tuiCalled = true
		}
		cmd := LaunchCmd(mockCheck, mockTUI)
		cmd.SetArgs([]string{"--model", "test-model", "--", "--sandbox", "workspace-write"})
		err := cmd.Execute()
		if err == nil {
			t.Fatal("expected flags and extra args without an integration to fail")
		}
		if !strings.Contains(err.Error(), "require an integration name") {
			t.Fatalf("expected integration-name guidance, got %v", err)
		}
		if tuiCalled {
			t.Error("TUI callback should NOT be called when flags or extra args are provided without an integration")
		}
	})
}
// TestLaunchCmdNilHeartbeat verifies LaunchCmd tolerates nil callbacks and
// still returns a usable command.
func TestLaunchCmdNilHeartbeat(t *testing.T) {
	launched := LaunchCmd(nil, nil)
	if launched == nil {
		t.Fatal("LaunchCmd returned nil")
	}
	if launched.PreRunE != nil {
		t.Log("Note: PreRunE is set even when nil is passed (acceptable)")
	}
}
// TestLaunchCmdModelFlagFiltersDisabledCloudFromSavedConfig verifies that a
// saved cloud model is dropped from the integration config when the server
// reports cloud as disabled, and that the explicit --model choice replaces it
// in both the saved config and the editor's written model list.
func TestLaunchCmdModelFlagFiltersDisabledCloudFromSavedConfig(t *testing.T) {
	tmpDir := t.TempDir()
	setLaunchTestHome(t, tmpDir)
	// Seed the saved config with a cloud model that should be filtered out.
	if err := config.SaveIntegration("stubeditor", []string{"glm-5:cloud"}); err != nil {
		t.Fatalf("failed to seed saved config: %v", err)
	}
	// Stub Ollama server: cloud disabled, requested local model resolvable.
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		switch r.URL.Path {
		case "/api/status":
			fmt.Fprintf(w, `{"cloud":{"disabled":true,"source":"config"}}`)
		case "/api/show":
			fmt.Fprintf(w, `{"model":"llama3.2"}`)
		default:
			w.WriteHeader(http.StatusNotFound)
		}
	}))
	defer srv.Close()
	t.Setenv("OLLAMA_HOST", srv.URL)
	stub := &launcherEditorRunner{}
	restore := OverrideIntegration("stubeditor", stub)
	defer restore()
	cmd := LaunchCmd(func(cmd *cobra.Command, args []string) error { return nil }, func(cmd *cobra.Command) {})
	cmd.SetArgs([]string{"stubeditor", "--model", "llama3.2"})
	if err := cmd.Execute(); err != nil {
		t.Fatalf("launch command failed: %v", err)
	}
	saved, err := config.LoadIntegration("stubeditor")
	if err != nil {
		t.Fatalf("failed to reload integration config: %v", err)
	}
	// The disabled cloud model must be gone; only the requested model remains.
	if diff := cmp.Diff([]string{"llama3.2"}, saved.Models); diff != "" {
		t.Fatalf("saved models mismatch (-want +got):\n%s", diff)
	}
	if diff := cmp.Diff([][]string{{"llama3.2"}}, stub.edited); diff != "" {
		t.Fatalf("editor models mismatch (-want +got):\n%s", diff)
	}
	if stub.ranModel != "llama3.2" {
		t.Fatalf("expected launch to run with llama3.2, got %q", stub.ranModel)
	}
}

// TestLaunchCmdModelFlagClearsDisabledCloudOverride verifies that an explicit
// --model naming a cloud model is ignored (with a warning on stderr) when the
// server reports cloud as disabled, falling back to the interactive selector
// with no pre-selected model.
func TestLaunchCmdModelFlagClearsDisabledCloudOverride(t *testing.T) {
	tmpDir := t.TempDir()
	setLaunchTestHome(t, tmpDir)
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		switch r.URL.Path {
		case "/api/status":
			fmt.Fprintf(w, `{"cloud":{"disabled":true,"source":"config"}}`)
		case "/api/tags":
			fmt.Fprint(w, `{"models":[{"name":"llama3.2"}]}`)
		case "/api/show":
			fmt.Fprint(w, `{"model":"llama3.2"}`)
		default:
			w.WriteHeader(http.StatusNotFound)
		}
	}))
	defer srv.Close()
	t.Setenv("OLLAMA_HOST", srv.URL)
	stub := &launcherSingleRunner{}
	restore := OverrideIntegration("stubapp", stub)
	defer restore()
	oldSelector := DefaultSingleSelector
	defer func() { DefaultSingleSelector = oldSelector }()
	var selectorCalls int
	var gotCurrent string
	// Selector stub records how it was invoked and picks the local model.
	DefaultSingleSelector = func(title string, items []ModelItem, current string) (string, error) {
		selectorCalls++
		gotCurrent = current
		return "llama3.2", nil
	}
	cmd := LaunchCmd(func(cmd *cobra.Command, args []string) error { return nil }, func(cmd *cobra.Command) {})
	cmd.SetArgs([]string{"stubapp", "--model", "glm-5:cloud"})
	stderr := captureStderr(t, func() {
		if err := cmd.Execute(); err != nil {
			t.Fatalf("launch command failed: %v", err)
		}
	})
	if selectorCalls != 1 {
		t.Fatalf("expected disabled cloud override to fall back to selector, got %d calls", selectorCalls)
	}
	if gotCurrent != "" {
		t.Fatalf("expected disabled override to be cleared before selection, got current %q", gotCurrent)
	}
	if stub.ranModel != "llama3.2" {
		t.Fatalf("expected launch to run with replacement local model, got %q", stub.ranModel)
	}
	if !strings.Contains(stderr, "Warning: ignoring --model glm-5:cloud because cloud is disabled") {
		t.Fatalf("expected disabled-cloud warning, got stderr: %q", stderr)
	}
	saved, err := config.LoadIntegration("stubapp")
	if err != nil {
		t.Fatalf("failed to reload integration config: %v", err)
	}
	if diff := cmp.Diff([]string{"llama3.2"}, saved.Models); diff != "" {
		t.Fatalf("saved models mismatch (-want +got):\n%s", diff)
	}
}

// TestLaunchCmdYes_AutoConfirmsLaunchPromptPath verifies that --yes answers
// every confirmation automatically: any call into the confirm prompt while
// the session is non-interactive fails the test.
func TestLaunchCmdYes_AutoConfirmsLaunchPromptPath(t *testing.T) {
	tmpDir := t.TempDir()
	setLaunchTestHome(t, tmpDir)
	withLauncherHooks(t)
	withInteractiveSession(t, false)
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		switch r.URL.Path {
		case "/api/show":
			fmt.Fprint(w, `{"model":"llama3.2"}`)
		case "/api/status":
			w.WriteHeader(http.StatusNotFound)
			fmt.Fprint(w, `{"error":"not found"}`)
		default:
			w.WriteHeader(http.StatusNotFound)
		}
	}))
	defer srv.Close()
	t.Setenv("OLLAMA_HOST", srv.URL)
	stub := &launcherEditorRunner{paths: []string{"/tmp/stubeditor.json"}}
	restore := OverrideIntegration("stubeditor", stub)
	defer restore()
	// Any prompt is a failure: --yes must auto-confirm the config write.
	DefaultConfirmPrompt = func(prompt string, options ConfirmOptions) (bool, error) {
		t.Fatalf("unexpected prompt with --yes: %q", prompt)
		return false, nil
	}
	cmd := LaunchCmd(func(cmd *cobra.Command, args []string) error { return nil }, func(cmd *cobra.Command) {})
	cmd.SetArgs([]string{"stubeditor", "--model", "llama3.2", "--yes"})
	if err := cmd.Execute(); err != nil {
		t.Fatalf("launch command with --yes failed: %v", err)
	}
	if diff := cmp.Diff([][]string{{"llama3.2"}}, stub.edited); diff != "" {
		t.Fatalf("editor models mismatch (-want +got):\n%s", diff)
	}
	if stub.ranModel != "llama3.2" {
		t.Fatalf("expected launch to run with llama3.2, got %q", stub.ranModel)
	}
}
// TestLaunchCmdHeadlessWithYes_AutoPullsMissingLocalModel verifies that in a
// non-interactive session with --yes, a local model that /api/show cannot
// resolve is pulled automatically (no prompt) and then launched.
func TestLaunchCmdHeadlessWithYes_AutoPullsMissingLocalModel(t *testing.T) {
	tmpDir := t.TempDir()
	setLaunchTestHome(t, tmpDir)
	withLauncherHooks(t)
	withInteractiveSession(t, false)
	var pullCalled bool
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		switch r.URL.Path {
		case "/api/show":
			// Model is missing locally, forcing the pull path.
			w.WriteHeader(http.StatusNotFound)
			fmt.Fprint(w, `{"error":"model not found"}`)
		case "/api/pull":
			pullCalled = true
			w.WriteHeader(http.StatusOK)
			fmt.Fprint(w, `{"status":"success"}`)
		default:
			w.WriteHeader(http.StatusNotFound)
		}
	}))
	defer srv.Close()
	t.Setenv("OLLAMA_HOST", srv.URL)
	stub := &launcherSingleRunner{}
	restore := OverrideIntegration("stubapp", stub)
	defer restore()
	DefaultConfirmPrompt = func(prompt string, options ConfirmOptions) (bool, error) {
		t.Fatalf("unexpected prompt with --yes in headless autopull path: %q", prompt)
		return false, nil
	}
	cmd := LaunchCmd(func(cmd *cobra.Command, args []string) error { return nil }, func(cmd *cobra.Command) {})
	cmd.SetArgs([]string{"stubapp", "--model", "missing-model", "--yes"})
	if err := cmd.Execute(); err != nil {
		t.Fatalf("launch command with --yes failed: %v", err)
	}
	if !pullCalled {
		t.Fatal("expected missing local model to be auto-pulled with --yes in headless mode")
	}
	if stub.ranModel != "missing-model" {
		t.Fatalf("expected launch to run with pulled model, got %q", stub.ranModel)
	}
}

// TestLaunchCmdHeadlessWithoutYes_ReturnsActionableConfirmError verifies that
// a non-interactive session without --yes aborts before writing any config or
// running the integration, and that the error tells the user to re-run with
// --yes.
func TestLaunchCmdHeadlessWithoutYes_ReturnsActionableConfirmError(t *testing.T) {
	tmpDir := t.TempDir()
	setLaunchTestHome(t, tmpDir)
	withLauncherHooks(t)
	withInteractiveSession(t, false)
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		switch r.URL.Path {
		case "/api/show":
			fmt.Fprint(w, `{"model":"llama3.2"}`)
		case "/api/status":
			w.WriteHeader(http.StatusNotFound)
			fmt.Fprint(w, `{"error":"not found"}`)
		default:
			w.WriteHeader(http.StatusNotFound)
		}
	}))
	defer srv.Close()
	t.Setenv("OLLAMA_HOST", srv.URL)
	stub := &launcherEditorRunner{paths: []string{"/tmp/stubeditor.json"}}
	restore := OverrideIntegration("stubeditor", stub)
	defer restore()
	// Headless mode must not try to prompt at all; it should error instead.
	DefaultConfirmPrompt = func(prompt string, options ConfirmOptions) (bool, error) {
		t.Fatalf("unexpected prompt in headless non-yes mode: %q", prompt)
		return false, nil
	}
	cmd := LaunchCmd(func(cmd *cobra.Command, args []string) error { return nil }, func(cmd *cobra.Command) {})
	cmd.SetArgs([]string{"stubeditor", "--model", "llama3.2"})
	err := cmd.Execute()
	if err == nil {
		t.Fatal("expected launch command to fail without --yes in headless mode")
	}
	if !strings.Contains(err.Error(), "re-run with --yes") {
		t.Fatalf("expected actionable --yes guidance, got %v", err)
	}
	if len(stub.edited) != 0 {
		t.Fatalf("expected no editor writes when confirmation is blocked, got %v", stub.edited)
	}
	if stub.ranModel != "" {
		t.Fatalf("expected launch to abort before run, got %q", stub.ranModel)
	}
}

// TestLaunchCmdIntegrationArgPromptsForModelWithSavedSelection verifies that
// launching an integration without --model opens the selector pre-seeded with
// the previously saved model, and that the new selection is persisted.
func TestLaunchCmdIntegrationArgPromptsForModelWithSavedSelection(t *testing.T) {
	tmpDir := t.TempDir()
	setLaunchTestHome(t, tmpDir)
	if err := config.SaveIntegration("stubapp", []string{"llama3.2"}); err != nil {
		t.Fatalf("failed to seed saved config: %v", err)
	}
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		switch r.URL.Path {
		case "/api/tags":
			fmt.Fprint(w, `{"models":[{"name":"llama3.2"},{"name":"qwen3:8b"}]}`)
		case "/api/show":
			fmt.Fprint(w, `{"model":"qwen3:8b"}`)
		default:
			w.WriteHeader(http.StatusNotFound)
		}
	}))
	defer srv.Close()
	t.Setenv("OLLAMA_HOST", srv.URL)
	stub := &launcherSingleRunner{}
	restore := OverrideIntegration("stubapp", stub)
	defer restore()
	oldSelector := DefaultSingleSelector
	defer func() { DefaultSingleSelector = oldSelector }()
	var gotCurrent string
	// Selector stub records the pre-selected model and picks a different one.
	DefaultSingleSelector = func(title string, items []ModelItem, current string) (string, error) {
		gotCurrent = current
		return "qwen3:8b", nil
	}
	cmd := LaunchCmd(func(cmd *cobra.Command, args []string) error { return nil }, func(cmd *cobra.Command) {})
	cmd.SetArgs([]string{"stubapp"})
	if err := cmd.Execute(); err != nil {
		t.Fatalf("launch command failed: %v", err)
	}
	if gotCurrent != "llama3.2" {
		t.Fatalf("expected selector current model to be saved model llama3.2, got %q", gotCurrent)
	}
	if stub.ranModel != "qwen3:8b" {
		t.Fatalf("expected launch to run selected model qwen3:8b, got %q", stub.ranModel)
	}
	saved, err := config.LoadIntegration("stubapp")
	if err != nil {
		t.Fatalf("failed to reload integration config: %v", err)
	}
	if diff := cmp.Diff([]string{"qwen3:8b"}, saved.Models); diff != "" {
		t.Fatalf("saved models mismatch (-want +got):\n%s", diff)
	}
}
// TestLaunchCmdHeadlessYes_IntegrationRequiresModelEvenWhenSaved verifies
// that headless --yes refuses to launch without an explicit --model, even
// when a previously saved model exists, and never opens the selector.
func TestLaunchCmdHeadlessYes_IntegrationRequiresModelEvenWhenSaved(t *testing.T) {
	tmpDir := t.TempDir()
	setLaunchTestHome(t, tmpDir)
	withLauncherHooks(t)
	withInteractiveSession(t, false)
	if err := config.SaveIntegration("stubapp", []string{"llama3.2"}); err != nil {
		t.Fatalf("failed to seed saved config: %v", err)
	}
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		switch r.URL.Path {
		case "/api/show":
			fmt.Fprint(w, `{"model":"llama3.2"}`)
		default:
			w.WriteHeader(http.StatusNotFound)
		}
	}))
	defer srv.Close()
	t.Setenv("OLLAMA_HOST", srv.URL)
	stub := &launcherSingleRunner{}
	restore := OverrideIntegration("stubapp", stub)
	defer restore()
	oldSelector := DefaultSingleSelector
	defer func() { DefaultSingleSelector = oldSelector }()
	DefaultSingleSelector = func(title string, items []ModelItem, current string) (string, error) {
		t.Fatal("selector should not be called for headless --yes saved-model launch")
		return "", nil
	}
	cmd := LaunchCmd(func(cmd *cobra.Command, args []string) error { return nil }, func(cmd *cobra.Command) {})
	cmd.SetArgs([]string{"stubapp", "--yes"})
	err := cmd.Execute()
	if err == nil {
		t.Fatal("expected launch command to fail when --yes is used headlessly without --model")
	}
	if !strings.Contains(err.Error(), "requires --model <model>") {
		t.Fatalf("expected actionable --model guidance, got %v", err)
	}
	if stub.ranModel != "" {
		t.Fatalf("expected launch to abort before run, got %q", stub.ranModel)
	}
}

// TestLaunchCmdHeadlessYes_IntegrationWithoutSavedModelReturnsError verifies
// the same headless --yes guard when no saved model exists at all.
func TestLaunchCmdHeadlessYes_IntegrationWithoutSavedModelReturnsError(t *testing.T) {
	tmpDir := t.TempDir()
	setLaunchTestHome(t, tmpDir)
	withLauncherHooks(t)
	withInteractiveSession(t, false)
	// Server answers nothing: the command must fail before any model lookup.
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusNotFound)
	}))
	defer srv.Close()
	t.Setenv("OLLAMA_HOST", srv.URL)
	stub := &launcherSingleRunner{}
	restore := OverrideIntegration("stubapp", stub)
	defer restore()
	oldSelector := DefaultSingleSelector
	defer func() { DefaultSingleSelector = oldSelector }()
	DefaultSingleSelector = func(title string, items []ModelItem, current string) (string, error) {
		t.Fatal("selector should not be called for headless --yes without saved model")
		return "", nil
	}
	cmd := LaunchCmd(func(cmd *cobra.Command, args []string) error { return nil }, func(cmd *cobra.Command) {})
	cmd.SetArgs([]string{"stubapp", "--yes"})
	err := cmd.Execute()
	if err == nil {
		t.Fatal("expected launch command to fail when --yes is used headlessly without --model")
	}
	if !strings.Contains(err.Error(), "requires --model <model>") {
		t.Fatalf("expected actionable --model guidance, got %v", err)
	}
	if stub.ranModel != "" {
		t.Fatalf("expected launch to abort before run, got %q", stub.ranModel)
	}
}

View File

@@ -1,76 +0,0 @@
package launch
import (
"fmt"
"os"
"os/exec"
"path/filepath"
"runtime"
"github.com/ollama/ollama/envconfig"
)
// Copilot implements Runner for GitHub Copilot CLI integration.
type Copilot struct{}

// String returns the human-readable integration name.
func (c *Copilot) String() string { return "Copilot CLI" }

// args builds the copilot CLI argument list: an optional --model flag
// followed by any passthrough arguments.
func (c *Copilot) args(model string, extra []string) []string {
	var args []string
	if model != "" {
		args = append(args, "--model", model)
	}
	args = append(args, extra...)
	return args
}
// findPath locates the copilot executable: first via a PATH lookup, then by
// falling back to ~/.local/bin (with an .exe suffix on Windows). It returns
// an error when neither location yields the binary.
func (c *Copilot) findPath() (string, error) {
	if found, err := exec.LookPath("copilot"); err == nil {
		return found, nil
	}
	home, err := os.UserHomeDir()
	if err != nil {
		return "", err
	}
	binary := "copilot"
	if runtime.GOOS == "windows" {
		binary += ".exe"
	}
	candidate := filepath.Join(home, ".local", "bin", binary)
	if _, statErr := os.Stat(candidate); statErr != nil {
		return "", statErr
	}
	return candidate, nil
}
// Run launches the Copilot CLI with the given model and passthrough args,
// attaching the current process's stdio and injecting the Ollama provider
// environment. If the binary cannot be found, it returns a user-facing
// install hint instead of the raw lookup error.
func (c *Copilot) Run(model string, args []string) error {
	copilotPath, err := c.findPath()
	if err != nil {
		return fmt.Errorf("copilot is not installed, install from https://docs.github.com/en/copilot/how-tos/set-up/install-copilot-cli")
	}
	cmd := exec.Command(copilotPath, c.args(model, args)...)
	cmd.Stdin = os.Stdin
	cmd.Stdout = os.Stdout
	cmd.Stderr = os.Stderr
	cmd.Env = append(os.Environ(), c.envVars(model)...)
	return cmd.Run()
}
// envVars returns the environment variables that point Copilot CLI at the
// local Ollama server's OpenAI-compatible endpoint. COPILOT_MODEL is only
// included when a model has been chosen.
func (c *Copilot) envVars(model string) []string {
	env := make([]string, 0, 4)
	env = append(env,
		"COPILOT_PROVIDER_BASE_URL="+envconfig.Host().String()+"/v1",
		"COPILOT_PROVIDER_API_KEY=",
		"COPILOT_PROVIDER_WIRE_API=responses",
	)
	if model != "" {
		env = append(env, "COPILOT_MODEL="+model)
	}
	return env
}

View File

@@ -1,161 +0,0 @@
package launch
import (
"os"
"path/filepath"
"runtime"
"slices"
"strings"
"testing"
)
// TestCopilotIntegration covers the integration's display name and its
// compile-time conformance to the Runner interface.
func TestCopilotIntegration(t *testing.T) {
	c := &Copilot{}
	t.Run("String", func(t *testing.T) {
		if got := c.String(); got != "Copilot CLI" {
			t.Errorf("String() = %q, want %q", got, "Copilot CLI")
		}
	})
	t.Run("implements Runner", func(t *testing.T) {
		// Compile-time assertion only; no runtime behavior to check.
		var _ Runner = c
	})
}
// TestCopilotFindPath exercises (*Copilot).findPath's PATH lookup and its
// ~/.local/bin fallback. Fixture-creation errors are checked explicitly so a
// setup failure surfaces at its source instead of as a confusing lookup miss,
// and the platform-specific binary name is computed in one place.
func TestCopilotFindPath(t *testing.T) {
	c := &Copilot{}
	// binName returns the platform-specific executable name.
	binName := func() string {
		if runtime.GOOS == "windows" {
			return "copilot.exe"
		}
		return "copilot"
	}
	t.Run("finds copilot in PATH", func(t *testing.T) {
		tmpDir := t.TempDir()
		fakeBin := filepath.Join(tmpDir, binName())
		if err := os.WriteFile(fakeBin, []byte("#!/bin/sh\n"), 0o755); err != nil {
			t.Fatalf("failed to create fake binary: %v", err)
		}
		t.Setenv("PATH", tmpDir)
		got, err := c.findPath()
		if err != nil {
			t.Fatalf("unexpected error: %v", err)
		}
		if got != fakeBin {
			t.Errorf("findPath() = %q, want %q", got, fakeBin)
		}
	})
	t.Run("returns error when not in PATH", func(t *testing.T) {
		t.Setenv("PATH", t.TempDir()) // empty dir, no copilot binary
		_, err := c.findPath()
		if err == nil {
			t.Fatal("expected error, got nil")
		}
	})
	t.Run("falls back to ~/.local/bin/copilot", func(t *testing.T) {
		tmpDir := t.TempDir()
		setTestHome(t, tmpDir)
		t.Setenv("PATH", t.TempDir()) // empty dir, no copilot binary
		fallback := filepath.Join(tmpDir, ".local", "bin", binName())
		if err := os.MkdirAll(filepath.Dir(fallback), 0o755); err != nil {
			t.Fatalf("failed to create fallback dir: %v", err)
		}
		if err := os.WriteFile(fallback, []byte("#!/bin/sh\n"), 0o755); err != nil {
			t.Fatalf("failed to create fallback binary: %v", err)
		}
		got, err := c.findPath()
		if err != nil {
			t.Fatalf("unexpected error: %v", err)
		}
		if got != fallback {
			t.Errorf("findPath() = %q, want %q", got, fallback)
		}
	})
	t.Run("returns error when neither PATH nor fallback exists", func(t *testing.T) {
		tmpDir := t.TempDir()
		setTestHome(t, tmpDir)
		t.Setenv("PATH", t.TempDir()) // empty dir, no copilot binary
		_, err := c.findPath()
		if err == nil {
			t.Fatal("expected error, got nil")
		}
	})
}
// TestCopilotArgs covers model/extra-argument combinations for
// (*Copilot).args, including the empty-model passthrough case.
func TestCopilotArgs(t *testing.T) {
	c := &Copilot{}
	cases := []struct {
		name  string
		model string
		extra []string
		want  []string
	}{
		{name: "with model", model: "llama3.2", want: []string{"--model", "llama3.2"}},
		{name: "empty model"},
		{name: "with model and extra", model: "llama3.2", extra: []string{"--verbose"}, want: []string{"--model", "llama3.2", "--verbose"}},
		{name: "empty model with help", extra: []string{"--help"}, want: []string{"--help"}},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			if got := c.args(tc.model, tc.extra); !slices.Equal(got, tc.want) {
				t.Errorf("args(%q, %v) = %v, want %v", tc.model, tc.extra, got, tc.want)
			}
		})
	}
}
// TestCopilotEnvVars checks the provider environment Copilot receives: the
// /v1 base URL, the (deliberately empty) API key, the wire API, and the
// conditional COPILOT_MODEL entry, including custom OLLAMA_HOST handling.
func TestCopilotEnvVars(t *testing.T) {
	c := &Copilot{}
	// envMap converts KEY=VALUE pairs into a map for easy assertions.
	envMap := func(envs []string) map[string]string {
		m := make(map[string]string)
		for _, e := range envs {
			k, v, _ := strings.Cut(e, "=")
			m[k] = v
		}
		return m
	}
	t.Run("sets required provider env vars with model", func(t *testing.T) {
		got := envMap(c.envVars("llama3.2"))
		if got["COPILOT_PROVIDER_BASE_URL"] == "" {
			t.Error("COPILOT_PROVIDER_BASE_URL should be set")
		}
		if !strings.HasSuffix(got["COPILOT_PROVIDER_BASE_URL"], "/v1") {
			t.Errorf("COPILOT_PROVIDER_BASE_URL = %q, want /v1 suffix", got["COPILOT_PROVIDER_BASE_URL"])
		}
		if _, ok := got["COPILOT_PROVIDER_API_KEY"]; !ok {
			t.Error("COPILOT_PROVIDER_API_KEY should be set (empty)")
		}
		if got["COPILOT_PROVIDER_WIRE_API"] != "responses" {
			t.Errorf("COPILOT_PROVIDER_WIRE_API = %q, want %q", got["COPILOT_PROVIDER_WIRE_API"], "responses")
		}
		if got["COPILOT_MODEL"] != "llama3.2" {
			t.Errorf("COPILOT_MODEL = %q, want %q", got["COPILOT_MODEL"], "llama3.2")
		}
	})
	t.Run("omits COPILOT_MODEL when model is empty", func(t *testing.T) {
		got := envMap(c.envVars(""))
		if _, ok := got["COPILOT_MODEL"]; ok {
			t.Errorf("COPILOT_MODEL should not be set for empty model, got %q", got["COPILOT_MODEL"])
		}
	})
	t.Run("uses custom OLLAMA_HOST", func(t *testing.T) {
		t.Setenv("OLLAMA_HOST", "http://myhost:9999")
		got := envMap(c.envVars("test"))
		if !strings.Contains(got["COPILOT_PROVIDER_BASE_URL"], "myhost:9999") {
			t.Errorf("COPILOT_PROVIDER_BASE_URL = %q, want custom host", got["COPILOT_PROVIDER_BASE_URL"])
		}
	})
}

View File

@@ -1,679 +0,0 @@
package launch
import (
"bufio"
"bytes"
"context"
"fmt"
"net/http"
"os"
"os/exec"
"path/filepath"
"runtime"
"slices"
"strconv"
"strings"
"gopkg.in/yaml.v3"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/cmd/config"
"github.com/ollama/ollama/cmd/internal/fileutil"
"github.com/ollama/ollama/envconfig"
)
const (
	// hermesInstallScript is the one-liner used to install Hermes when it is
	// missing; --skip-setup defers interactive setup to launch's own flow.
	hermesInstallScript = "curl -fsSL https://raw.githubusercontent.com/NousResearch/hermes-agent/main/scripts/install.sh | bash -s -- --skip-setup"
	// hermesProviderName is the display name written into Hermes config.
	hermesProviderName = "Ollama"
	// hermesProviderKey is the launch-owned provider key in config.yaml.
	hermesProviderKey = "ollama-launch"
	// hermesLegacyKey is the old provider key that gets migrated away.
	hermesLegacyKey = "ollama"
	// hermesPlaceholderKey is a dummy API key value (Ollama requires none).
	hermesPlaceholderKey = "ollama"
	// hermesGatewaySetupHint is the command suggested when gateway setup fails.
	hermesGatewaySetupHint = "hermes gateway setup"
	// hermesGatewaySetupTitle is the prompt title for optional messaging setup.
	hermesGatewaySetupTitle = "Connect a messaging app now?"
)

// Indirections over OS/process/host lookups so tests can stub them.
var (
	hermesGOOS      = runtime.GOOS
	hermesLookPath  = exec.LookPath
	hermesCommand   = exec.Command
	hermesUserHome  = os.UserHomeDir
	hermesOllamaURL = envconfig.ConnectableHost
)

// hermesMessagingEnvGroups lists the env vars that indicate a messaging
// channel has been configured; a non-empty value in any group counts.
var hermesMessagingEnvGroups = [][]string{
	{"TELEGRAM_BOT_TOKEN"},
	{"DISCORD_BOT_TOKEN"},
	{"SLACK_BOT_TOKEN"},
	{"SIGNAL_ACCOUNT"},
	{"EMAIL_ADDRESS"},
	{"TWILIO_ACCOUNT_SID"},
	{"MATRIX_ACCESS_TOKEN", "MATRIX_PASSWORD"},
	{"MATTERMOST_TOKEN"},
	{"WHATSAPP_PHONE_NUMBER_ID"},
	{"DINGTALK_CLIENT_ID"},
	{"FEISHU_APP_ID"},
	{"WECOM_BOT_ID"},
	{"WEIXIN_ACCOUNT_ID"},
	{"BLUEBUBBLES_SERVER_URL"},
	{"WEBHOOK_ENABLED"},
}
// Hermes is intentionally not an Editor integration: launch owns one primary
// model and the local Ollama endpoint, while Hermes keeps its own discovery and
// switching UX after startup.
type Hermes struct{}

// String returns the human-readable integration name.
func (h *Hermes) String() string { return "Hermes Agent" }

// Run starts the Hermes CLI with stdio attached, optionally offering the
// one-time messaging gateway setup first. The model argument is ignored
// because Configure has already written the default model into config.yaml.
func (h *Hermes) Run(_ string, args []string) error {
	// Hermes reads its primary model from config.yaml. launch configures that
	// default model ahead of time so we can keep runtime invocation simple and
	// still let Hermes discover additional models later via its own UX.
	bin, err := h.binary()
	if err != nil {
		return err
	}
	if err := h.runGatewaySetupPreflight(args, func() error {
		return hermesAttachedCommand(bin, "gateway", "setup").Run()
	}); err != nil {
		return err
	}
	return hermesAttachedCommand(bin, args...).Run()
}

// Paths returns the config files launch may modify (for display/backup);
// nil when the home directory cannot be resolved.
func (h *Hermes) Paths() []string {
	configPath, err := hermesConfigPath()
	if err != nil {
		return nil
	}
	return []string{configPath}
}
// Configure writes the launch-managed provider entry and default model into
// ~/.hermes/config.yaml, preserving unrelated user settings. The previous
// file is kept as a backup via fileutil.WriteWithBackup. Returns an error if
// an existing config cannot be parsed, marshalled, or written.
func (h *Hermes) Configure(model string) error {
	configPath, err := hermesConfigPath()
	if err != nil {
		return err
	}
	cfg := map[string]any{}
	// A missing config file is fine; any other read/parse failure is not.
	if data, err := os.ReadFile(configPath); err == nil {
		if err := yaml.Unmarshal(data, &cfg); err != nil {
			return fmt.Errorf("parse hermes config: %w", err)
		}
	} else if !os.IsNotExist(err) {
		return err
	}
	modelSection, _ := cfg["model"].(map[string]any)
	if modelSection == nil {
		modelSection = make(map[string]any)
	}
	models := h.listModels(model)
	applyHermesManagedProviders(cfg, hermesBaseURL(), model, models)
	// launch writes the minimum provider/default-model settings needed to
	// bootstrap Hermes against Ollama. The active provider stays on a
	// launch-owned key so /model stays aligned with the launcher-managed entry,
	// and the Ollama endpoint lives in providers: so the picker shows one row.
	modelSection["provider"] = hermesProviderKey
	modelSection["default"] = model
	modelSection["base_url"] = hermesBaseURL()
	modelSection["api_key"] = hermesPlaceholderKey
	cfg["model"] = modelSection
	// use Hermes' built-in web toolset for now.
	// TODO(parthsareen): move this to using Ollama web search
	cfg["toolsets"] = mergeHermesToolsets(cfg["toolsets"])
	data, err := yaml.Marshal(cfg)
	if err != nil {
		return err
	}
	if err := os.MkdirAll(filepath.Dir(configPath), 0o755); err != nil {
		return err
	}
	return fileutil.WriteWithBackup(configPath, data)
}
// CurrentModel returns the default model from config.yaml, but only when the
// launch-managed provider entry is still intact and points at the configured
// Ollama endpoint; otherwise it returns "" so launch re-prompts.
func (h *Hermes) CurrentModel() string {
	configPath, err := hermesConfigPath()
	if err != nil {
		return ""
	}
	data, err := os.ReadFile(configPath)
	if err != nil {
		return ""
	}
	cfg := map[string]any{}
	if yaml.Unmarshal(data, &cfg) != nil {
		return ""
	}
	return hermesManagedCurrentModel(cfg, hermesBaseURL())
}

// Onboard records that the Hermes integration finished onboarding.
func (h *Hermes) Onboard() error {
	return config.MarkIntegrationOnboarded("hermes")
}

// RequiresInteractiveOnboarding reports that Hermes needs no interactive
// onboarding step from launch.
func (h *Hermes) RequiresInteractiveOnboarding() bool {
	return false
}

// RefreshRuntimeAfterConfigure restarts the Hermes messaging gateway when it
// is already running so a config change takes effect; a stopped gateway is
// left alone.
func (h *Hermes) RefreshRuntimeAfterConfigure() error {
	running, err := h.gatewayRunning()
	if err != nil {
		return fmt.Errorf("check Hermes gateway status: %w", err)
	}
	if !running {
		return nil
	}
	fmt.Fprintf(os.Stderr, "%sRefreshing Hermes messaging gateway...%s\n", ansiGray, ansiReset)
	if err := h.restartGateway(); err != nil {
		return fmt.Errorf("restart Hermes gateway: %w", err)
	}
	fmt.Fprintln(os.Stderr)
	return nil
}

// installed reports whether the hermes binary can be located.
func (h *Hermes) installed() bool {
	_, err := h.binary()
	return err == nil
}

// ensureInstalled installs Hermes (after a confirmation prompt) when it is
// missing. On Windows it points the user at WSL instead; on other platforms
// it verifies the bash/curl/git prerequisites before running the installer.
func (h *Hermes) ensureInstalled() error {
	if h.installed() {
		return nil
	}
	if hermesGOOS == "windows" {
		return hermesWindowsHint()
	}
	var missing []string
	for _, dep := range []string{"bash", "curl", "git"} {
		if _, err := hermesLookPath(dep); err != nil {
			missing = append(missing, dep)
		}
	}
	if len(missing) > 0 {
		return fmt.Errorf("Hermes is not installed and required dependencies are missing\n\nInstall the following first:\n  %s\n\nThen re-run:\n  ollama launch hermes", strings.Join(missing, "\n  "))
	}
	ok, err := ConfirmPrompt("Hermes is not installed. Install now?")
	if err != nil {
		return err
	}
	if !ok {
		return fmt.Errorf("hermes installation cancelled")
	}
	fmt.Fprintf(os.Stderr, "\nInstalling Hermes...\n")
	if err := hermesAttachedCommand("bash", "-lc", hermesInstallScript).Run(); err != nil {
		return fmt.Errorf("failed to install hermes: %w", err)
	}
	// The installer may place the binary somewhere not yet on PATH.
	if !h.installed() {
		return fmt.Errorf("hermes was installed but the binary was not found on PATH\n\nYou may need to restart your shell")
	}
	fmt.Fprintf(os.Stderr, "%sHermes installed successfully%s\n\n", ansiGreen, ansiReset)
	return nil
}

// listModels returns the default model followed by every model the Ollama
// server reports, de-duplicated and with blanks dropped. If the server cannot
// be reached, the default model alone is returned.
func (h *Hermes) listModels(defaultModel string) []string {
	client := hermesOllamaClient()
	resp, err := client.List(context.Background())
	if err != nil {
		return []string{defaultModel}
	}
	models := make([]string, 0, len(resp.Models)+1)
	seen := make(map[string]struct{}, len(resp.Models)+1)
	add := func(name string) {
		name = strings.TrimSpace(name)
		if name == "" {
			return
		}
		if _, ok := seen[name]; ok {
			return
		}
		seen[name] = struct{}{}
		models = append(models, name)
	}
	add(defaultModel)
	for _, entry := range resp.Models {
		add(entry.Name)
	}
	if len(models) == 0 {
		return []string{defaultModel}
	}
	return models
}

// binary resolves the hermes executable: PATH first, then ~/.local/bin.
// Windows gets the WSL hint instead of a fallback lookup.
func (h *Hermes) binary() (string, error) {
	if path, err := hermesLookPath("hermes"); err == nil {
		return path, nil
	}
	if hermesGOOS == "windows" {
		return "", hermesWindowsHint()
	}
	home, err := hermesUserHome()
	if err != nil {
		return "", err
	}
	fallback := filepath.Join(home, ".local", "bin", "hermes")
	if _, err := os.Stat(fallback); err == nil {
		return fallback, nil
	}
	return "", fmt.Errorf("hermes is not installed")
}

// hermesConfigPath returns the path of Hermes' YAML config file.
func hermesConfigPath() (string, error) {
	home, err := hermesUserHome()
	if err != nil {
		return "", err
	}
	return filepath.Join(home, ".hermes", "config.yaml"), nil
}

// hermesBaseURL returns the Ollama OpenAI-compatible endpoint written into
// Hermes config (host URL with a /v1 suffix, no trailing slash).
func hermesBaseURL() string {
	return strings.TrimRight(hermesOllamaURL().String(), "/") + "/v1"
}

// hermesEnvPath returns the path of Hermes' dotenv file.
func hermesEnvPath() (string, error) {
	home, err := hermesUserHome()
	if err != nil {
		return "", err
	}
	return filepath.Join(home, ".hermes", ".env"), nil
}
// runGatewaySetupPreflight optionally offers the one-time messaging gateway
// setup before launching Hermes. It is skipped whenever extra args were
// passed, the session is non-interactive, --yes semantics are in effect, or
// messaging is already configured. Declining the prompt is not an error.
func (h *Hermes) runGatewaySetupPreflight(args []string, runSetup func() error) error {
	if len(args) > 0 || !isInteractiveSession() || currentLaunchConfirmPolicy.yes || currentLaunchConfirmPolicy.requireYesMessage {
		return nil
	}
	if h.messagingConfigured() {
		return nil
	}
	fmt.Fprintf(os.Stderr, "\nHermes can message you on Telegram, Discord, Slack, and more.\n\n")
	ok, err := ConfirmPromptWithOptions(hermesGatewaySetupTitle, ConfirmOptions{
		YesLabel: "Yes",
		NoLabel:  "Set up later",
	})
	if err != nil {
		return err
	}
	if !ok {
		return nil
	}
	if err := runSetup(); err != nil {
		return fmt.Errorf("hermes messaging setup failed: %w\n\nTry running: %s", err, hermesGatewaySetupHint)
	}
	return nil
}

// messagingConfigured reports whether any messaging channel env var is set,
// either in the Hermes .env file or the process environment.
func (h *Hermes) messagingConfigured() bool {
	envVars, err := h.gatewayEnvVars()
	if err != nil {
		return false
	}
	for _, group := range hermesMessagingEnvGroups {
		for _, key := range group {
			if strings.TrimSpace(envVars[key]) != "" {
				return true
			}
		}
	}
	return false
}

// gatewayEnvVars merges the Hermes .env file with the process environment;
// process values win for the messaging keys. A missing .env file is fine.
func (h *Hermes) gatewayEnvVars() (map[string]string, error) {
	envVars := make(map[string]string)
	envFilePath, err := hermesEnvPath()
	if err != nil {
		return nil, err
	}
	switch data, err := os.ReadFile(envFilePath); {
	case err == nil:
		for key, value := range hermesParseEnvFile(data) {
			envVars[key] = value
		}
	case os.IsNotExist(err):
		// nothing persisted yet
	default:
		return nil, err
	}
	for _, group := range hermesMessagingEnvGroups {
		for _, key := range group {
			if value, ok := os.LookupEnv(key); ok {
				envVars[key] = value
			}
		}
	}
	return envVars, nil
}

// gatewayRunning reports whether `hermes gateway status` says the gateway is
// currently running.
func (h *Hermes) gatewayRunning() (bool, error) {
	status, err := h.gatewayStatusOutput()
	if err != nil {
		return false, err
	}
	return hermesGatewayStatusRunning(status), nil
}

// gatewayStatusOutput runs `hermes gateway status` and returns its combined
// stdout/stderr output.
func (h *Hermes) gatewayStatusOutput() (string, error) {
	bin, err := h.binary()
	if err != nil {
		return "", err
	}
	out, err := hermesCommand(bin, "gateway", "status").CombinedOutput()
	return string(out), err
}

// restartGateway runs `hermes gateway restart` with stdio attached.
func (h *Hermes) restartGateway() error {
	bin, err := h.binary()
	if err != nil {
		return err
	}
	return hermesAttachedCommand(bin, "gateway", "restart").Run()
}
// hermesGatewayStatusRunning reports whether the output of
// `hermes gateway status` indicates the gateway is currently running.
// Stopped markers are checked before running markers, matching the original
// precedence, so e.g. "gateway is not running" never matches positively.
// Unrecognized output is treated as not running.
func hermesGatewayStatusRunning(output string) bool {
	status := strings.ToLower(output)
	stoppedMarkers := []string{
		"gateway is not running",
		"gateway service is stopped",
		"gateway service is not loaded",
	}
	for _, marker := range stoppedMarkers {
		if strings.Contains(status, marker) {
			return false
		}
	}
	runningMarkers := []string{
		"gateway is running",
		"gateway service is running",
		"gateway service is loaded",
	}
	for _, marker := range runningMarkers {
		if strings.Contains(status, marker) {
			return true
		}
	}
	return false
}
func hermesParseEnvFile(data []byte) map[string]string {
out := make(map[string]string)
scanner := bufio.NewScanner(bytes.NewReader(data))
for scanner.Scan() {
line := strings.TrimSpace(strings.TrimPrefix(scanner.Text(), "\ufeff"))
if line == "" || strings.HasPrefix(line, "#") {
continue
}
if strings.HasPrefix(line, "export ") {
line = strings.TrimSpace(strings.TrimPrefix(line, "export "))
}
key, value, ok := strings.Cut(line, "=")
if !ok {
continue
}
key = strings.TrimSpace(key)
if key == "" {
continue
}
value = strings.TrimSpace(value)
if len(value) >= 2 {
switch {
case value[0] == '"' && value[len(value)-1] == '"':
if unquoted, err := strconv.Unquote(value); err == nil {
value = unquoted
}
case value[0] == '\'' && value[len(value)-1] == '\'':
value = value[1 : len(value)-1]
}
}
out[key] = value
}
return out
}
// hermesOllamaClient returns an API client for the launch-resolved host.
func hermesOllamaClient() *api.Client {
	// Hermes queries the same launch-resolved Ollama host that launch writes
	// into config, so model discovery follows the configured endpoint.
	return api.NewClient(hermesOllamaURL(), http.DefaultClient)
}

// applyHermesManagedProviders installs or refreshes the launch-owned provider
// entry in cfg's providers section, migrates away the legacy key, and removes
// any launch-managed entries left in custom_providers (dropping the section
// entirely when it becomes empty).
func applyHermesManagedProviders(cfg map[string]any, baseURL string, model string, models []string) {
	providers := hermesUserProviders(cfg["providers"])
	entry := hermesManagedProviderEntry(providers)
	if entry == nil {
		entry = make(map[string]any)
	}
	entry["name"] = hermesProviderName
	entry["api"] = baseURL
	entry["default_model"] = model
	entry["models"] = hermesStringListAny(models)
	providers[hermesProviderKey] = entry
	delete(providers, hermesLegacyKey)
	cfg["providers"] = providers
	customProviders := hermesWithoutManagedCustomProviders(cfg["custom_providers"])
	if len(customProviders) == 0 {
		delete(cfg, "custom_providers")
		return
	}
	cfg["custom_providers"] = customProviders
}

// hermesManagedCurrentModel returns the configured default model, but only
// when every launch-managed invariant still holds: the active provider is the
// launch-owned key, both the model section and provider entry point at
// baseURL, the provider's default matches, and no stale launch-managed
// custom_providers entry remains. Any mismatch returns "" so launch treats
// the config as unmanaged and re-prompts.
func hermesManagedCurrentModel(cfg map[string]any, baseURL string) string {
	modelCfg, _ := cfg["model"].(map[string]any)
	if modelCfg == nil {
		return ""
	}
	provider, _ := modelCfg["provider"].(string)
	if strings.TrimSpace(strings.ToLower(provider)) != hermesProviderKey {
		return ""
	}
	configBaseURL, _ := modelCfg["base_url"].(string)
	if hermesNormalizeURL(configBaseURL) != hermesNormalizeURL(baseURL) {
		return ""
	}
	current, _ := modelCfg["default"].(string)
	current = strings.TrimSpace(current)
	if current == "" {
		return ""
	}
	providers := hermesUserProviders(cfg["providers"])
	entry, _ := providers[hermesProviderKey].(map[string]any)
	if entry == nil {
		return ""
	}
	if hermesHasManagedCustomProvider(cfg["custom_providers"]) {
		return ""
	}
	apiURL, _ := entry["api"].(string)
	if hermesNormalizeURL(apiURL) != hermesNormalizeURL(baseURL) {
		return ""
	}
	defaultModel, _ := entry["default_model"].(string)
	if strings.TrimSpace(defaultModel) != current {
		return ""
	}
	return current
}
// hermesUserProviders returns a shallow, string-keyed copy of the providers
// section. YAML decoding may yield either map[string]any or map[any]any;
// non-string keys are dropped. Any other shape yields an empty map.
func hermesUserProviders(current any) map[string]any {
	out := make(map[string]any)
	if m, ok := current.(map[string]any); ok {
		for key, value := range m {
			out[key] = value
		}
		return out
	}
	if m, ok := current.(map[any]any); ok {
		for key, value := range m {
			if name, ok := key.(string); ok {
				out[name] = value
			}
		}
	}
	return out
}
// hermesCustomProviders normalizes the custom_providers section into a fresh
// []any slice so callers can mutate it safely. Unknown shapes yield nil.
func hermesCustomProviders(current any) []any {
	if entries, ok := current.([]any); ok {
		return append([]any(nil), entries...)
	}
	if entries, ok := current.([]map[string]any); ok {
		out := make([]any, len(entries))
		for i, entry := range entries {
			out[i] = entry
		}
		return out
	}
	return nil
}
// hermesManagedProviderEntry returns the existing launch-managed provider
// entry, preferring the current key over the legacy one; nil when absent.
func hermesManagedProviderEntry(providers map[string]any) map[string]any {
	for _, key := range []string{hermesProviderKey, hermesLegacyKey} {
		if entry, _ := providers[key].(map[string]any); entry != nil {
			return entry
		}
	}
	return nil
}

// hermesWithoutManagedCustomProviders returns custom_providers with every
// launch-managed entry removed; non-map entries are preserved untouched.
func hermesWithoutManagedCustomProviders(current any) []any {
	customProviders := hermesCustomProviders(current)
	preserved := make([]any, 0, len(customProviders))
	for _, item := range customProviders {
		entry, _ := item.(map[string]any)
		if entry == nil {
			preserved = append(preserved, item)
			continue
		}
		if hermesManagedCustomProvider(entry) {
			continue
		}
		preserved = append(preserved, entry)
	}
	return preserved
}

// hermesHasManagedCustomProvider reports whether custom_providers still
// contains a launch-managed entry.
func hermesHasManagedCustomProvider(current any) bool {
	for _, item := range hermesCustomProviders(current) {
		entry, _ := item.(map[string]any)
		if entry != nil && hermesManagedCustomProvider(entry) {
			return true
		}
	}
	return false
}

// hermesManagedCustomProvider reports whether a custom provider entry was
// written by launch, identified by its display name (case-insensitive).
func hermesManagedCustomProvider(entry map[string]any) bool {
	name, _ := entry["name"].(string)
	return strings.EqualFold(strings.TrimSpace(name), hermesProviderName)
}
// hermesNormalizeURL trims surrounding whitespace and any trailing slashes so
// two URLs can be compared structurally.
func hermesNormalizeURL(raw string) string {
	trimmed := strings.TrimSpace(raw)
	return strings.TrimRight(trimmed, "/")
}
// hermesStringListAny deduplicates models, drops blank names, and converts
// the result to []any for serialization into config files.
func hermesStringListAny(models []string) []any {
	out := make([]any, 0, len(models))
	for _, name := range dedupeModelList(models) {
		if trimmed := strings.TrimSpace(name); trimmed != "" {
			out = append(out, trimmed)
		}
	}
	return out
}
// mergeHermesToolsets ensures the "web" toolset is present in the configured
// toolset value while preserving every existing entry and its order. Lists
// are copied; comma-separated strings are split and trimmed; any other shape
// (or a blank string) resets to the default toolsets.
func mergeHermesToolsets(current any) any {
	hasWeb := false
	switch existing := current.(type) {
	case []any:
		merged := append(make([]any, 0, len(existing)+1), existing...)
		for _, item := range merged {
			if name, _ := item.(string); name == "web" {
				hasWeb = true
				break
			}
		}
		if !hasWeb {
			merged = append(merged, "web")
		}
		return merged
	case []string:
		merged := make([]any, 0, len(existing)+1)
		for _, name := range existing {
			if name == "web" {
				hasWeb = true
			}
			merged = append(merged, name)
		}
		if !hasWeb {
			merged = append(merged, "web")
		}
		return merged
	case string:
		if strings.TrimSpace(existing) == "" {
			return []any{"hermes-cli", "web"}
		}
		merged := []any{}
		for _, part := range strings.Split(existing, ",") {
			part = strings.TrimSpace(part)
			if part == "" {
				continue
			}
			if part == "web" {
				hasWeb = true
			}
			merged = append(merged, part)
		}
		if !hasWeb {
			merged = append(merged, "web")
		}
		return merged
	}
	return []any{"hermes-cli", "web"}
}
// hermesAttachedCommand builds a hermes command wired to the caller's
// terminal so interactive sessions work.
func hermesAttachedCommand(name string, args ...string) *exec.Cmd {
	cmd := hermesCommand(name, args...)
	cmd.Stdin, cmd.Stdout, cmd.Stderr = os.Stdin, os.Stdout, os.Stderr
	return cmd
}
// hermesWindowsHint returns the guidance error shown when Hermes is launched
// from native Windows, where only WSL2 is supported.
func hermesWindowsHint() error {
	lines := []string{
		"Hermes on Windows requires WSL2. Install WSL with: wsl --install",
		"Then run 'ollama launch hermes' from inside your WSL shell.",
		"Docs: https://hermes-agent.nousresearch.com/docs/getting-started/installation/",
	}
	return fmt.Errorf("%s", strings.Join(lines, "\n"))
}

File diff suppressed because it is too large Load Diff

View File

@@ -1,315 +0,0 @@
package launch
import (
"context"
"encoding/json"
"fmt"
"os"
"os/exec"
"path/filepath"
"runtime"
"strings"
"time"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/envconfig"
)
// Kimi implements Runner for Kimi Code CLI integration.
type Kimi struct{}

const (
	// kimiDefaultModelAlias is the provider/model alias written into the
	// inline kimi config; kimi resolves it to the selected Ollama model.
	kimiDefaultModelAlias = "ollama"
	// kimiDefaultMaxContextSize is the fallback context window used when the
	// model's real context length cannot be determined.
	kimiDefaultMaxContextSize = 32768
)

var (
	// kimiGOOS mirrors runtime.GOOS as a package variable so tests can
	// override the platform.
	kimiGOOS = runtime.GOOS
	// kimiModelShowTimeout bounds the /api/show call used to look up a
	// model's context length; tests shrink it.
	kimiModelShowTimeout = 5 * time.Second
)
// String returns the human-readable name of this runner.
func (k *Kimi) String() string {
	return "Kimi Code CLI"
}
// args builds the kimi CLI argument list: the managed inline --config pair
// first, followed by any user-supplied passthrough arguments.
func (k *Kimi) args(config string, extra []string) []string {
	out := make([]string, 0, len(extra)+2)
	out = append(out, "--config", config)
	return append(out, extra...)
}
// Run launches the Kimi Code CLI against the given Ollama model, installing
// kimi first if needed. Passthrough args are validated before any install
// prompt so conflicts fail fast.
func (k *Kimi) Run(model string, args []string) error {
	if strings.TrimSpace(model) == "" {
		return fmt.Errorf("model is required")
	}
	if err := validateKimiPassthroughArgs(args); err != nil {
		return err
	}
	// Build the inline JSON config pointing kimi at the local Ollama server.
	config, err := buildKimiInlineConfig(model, resolveKimiMaxContextSize(model))
	if err != nil {
		return fmt.Errorf("failed to build kimi config: %w", err)
	}
	bin, err := ensureKimiInstalled()
	if err != nil {
		return err
	}
	cmd := exec.Command(bin, k.args(config, args)...)
	cmd.Stdin, cmd.Stdout, cmd.Stderr = os.Stdin, os.Stdout, os.Stderr
	return cmd.Run()
}
// findKimiBinary locates the kimi executable: PATH first, then well-known
// install locations for the current (test-overridable) kimiGOOS. Candidate
// order is the search priority; the first existing regular file wins.
func findKimiBinary() (string, error) {
	if path, err := exec.LookPath("kimi"); err == nil {
		return path, nil
	}
	// A home-dir lookup failure just leaves home empty; the candidates below
	// then point at paths that fail the Stat check.
	home, _ := os.UserHomeDir()
	var candidates []string
	switch kimiGOOS {
	case "windows":
		candidates = appendWindowsKimiCandidates(candidates, filepath.Join(home, ".local", "bin"))
		candidates = appendWindowsKimiCandidates(candidates, filepath.Join(home, "bin"))
		if appData := strings.TrimSpace(os.Getenv("APPDATA")); appData != "" {
			candidates = appendWindowsKimiCandidates(candidates, filepath.Join(appData, "uv", "bin"))
		}
		if localAppData := strings.TrimSpace(os.Getenv("LOCALAPPDATA")); localAppData != "" {
			candidates = appendWindowsKimiCandidates(candidates, filepath.Join(localAppData, "uv", "bin"))
		}
	default:
		// User-local bin dirs plus uv tool-install layouts.
		candidates = append(candidates,
			filepath.Join(home, ".local", "bin", "kimi"),
			filepath.Join(home, "bin", "kimi"),
			filepath.Join(home, ".local", "share", "uv", "tools", "kimi-cli", "bin", "kimi"),
			filepath.Join(home, ".local", "share", "uv", "tools", "kimi", "bin", "kimi"),
		)
		if xdgDataHome := strings.TrimSpace(os.Getenv("XDG_DATA_HOME")); xdgDataHome != "" {
			candidates = append(candidates,
				filepath.Join(xdgDataHome, "uv", "tools", "kimi-cli", "bin", "kimi"),
				filepath.Join(xdgDataHome, "uv", "tools", "kimi", "bin", "kimi"),
			)
		}
		// WSL users can inherit Windows env vars while launching from Linux shells.
		if profile := windowsPathToWSL(os.Getenv("USERPROFILE")); profile != "" {
			candidates = appendWindowsKimiCandidates(candidates, filepath.Join(profile, ".local", "bin"))
		}
		if appData := windowsPathToWSL(os.Getenv("APPDATA")); appData != "" {
			candidates = appendWindowsKimiCandidates(candidates, filepath.Join(appData, "uv", "bin"))
		}
		if localAppData := windowsPathToWSL(os.Getenv("LOCALAPPDATA")); localAppData != "" {
			candidates = appendWindowsKimiCandidates(candidates, filepath.Join(localAppData, "uv", "bin"))
		}
	}
	for _, candidate := range candidates {
		if info, err := os.Stat(candidate); err == nil && !info.IsDir() {
			return candidate, nil
		}
	}
	return "", fmt.Errorf("kimi binary not found")
}
func appendWindowsKimiCandidates(candidates []string, dir string) []string {
if strings.TrimSpace(dir) == "" {
return candidates
}
return append(candidates,
filepath.Join(dir, "kimi.exe"),
filepath.Join(dir, "kimi.cmd"),
filepath.Join(dir, "kimi.bat"),
)
}
func windowsPathToWSL(path string) string {
trimmed := strings.TrimSpace(path)
if len(trimmed) < 3 || trimmed[1] != ':' {
return ""
}
drive := strings.ToLower(string(trimmed[0]))
rest := strings.ReplaceAll(trimmed[2:], "\\", "/")
rest = strings.TrimPrefix(rest, "/")
if rest == "" {
return filepath.Join("/mnt", drive)
}
return filepath.Join("/mnt", drive, rest)
}
func validateKimiPassthroughArgs(args []string) error {
for _, arg := range args {
switch {
case arg == "--config", strings.HasPrefix(arg, "--config="):
return fmt.Errorf("conflicting extra argument %q: ollama launch kimi manages --config", arg)
case arg == "--config-file", strings.HasPrefix(arg, "--config-file="):
return fmt.Errorf("conflicting extra argument %q: ollama launch kimi manages --config-file", arg)
case arg == "--model", strings.HasPrefix(arg, "--model="):
return fmt.Errorf("conflicting extra argument %q: ollama launch kimi manages --model", arg)
case arg == "-m", strings.HasPrefix(arg, "-m="):
return fmt.Errorf("conflicting extra argument %q: ollama launch kimi manages -m/--model", arg)
}
}
return nil
}
func buildKimiInlineConfig(model string, maxContextSize int) (string, error) {
cfg := map[string]any{
"default_model": kimiDefaultModelAlias,
"providers": map[string]any{
kimiDefaultModelAlias: map[string]any{
"type": "openai_legacy",
"base_url": envconfig.ConnectableHost().String() + "/v1",
"api_key": "ollama",
},
},
"models": map[string]any{
kimiDefaultModelAlias: map[string]any{
"provider": kimiDefaultModelAlias,
"model": model,
"max_context_size": maxContextSize,
},
},
}
data, err := json.Marshal(cfg)
if err != nil {
return "", err
}
return string(data), nil
}
// resolveKimiMaxContextSize picks the context window advertised to kimi:
// known cloud-model limits first, then the local model's reported context
// length from /api/show, then kimiDefaultMaxContextSize on any failure.
func resolveKimiMaxContextSize(model string) int {
	if limit, ok := lookupCloudModelLimit(model); ok {
		return limit.Context
	}
	client, err := api.ClientFromEnvironment()
	if err != nil {
		return kimiDefaultMaxContextSize
	}
	ctx, cancel := context.WithTimeout(context.Background(), kimiModelShowTimeout)
	defer cancel()
	resp, err := client.Show(ctx, &api.ShowRequest{Model: model})
	if err != nil {
		return kimiDefaultMaxContextSize
	}
	if n, ok := modelInfoContextLength(resp.ModelInfo); ok {
		return n
	}
	return kimiDefaultMaxContextSize
}
// modelInfoContextLength scans model metadata for a key ending in
// ".context_length" and returns its value when positive. JSON decoding
// yields float64, but int and int64 are accepted defensively.
func modelInfoContextLength(modelInfo map[string]any) (int, bool) {
	for key, val := range modelInfo {
		if !strings.HasSuffix(key, ".context_length") {
			continue
		}
		var n int
		switch v := val.(type) {
		case float64:
			n = int(v)
		case int:
			n = v
		case int64:
			n = int(v)
		}
		if n > 0 {
			return n, true
		}
	}
	return 0, false
}
// ensureKimiInstalled returns the path to the kimi binary, offering to run
// the official installer when it is missing. The user is always prompted
// before anything is installed.
func ensureKimiInstalled() (string, error) {
	if path, err := findKimiBinary(); err == nil {
		return path, nil
	}
	// Verify installer prerequisites before prompting, so the user is not
	// asked to confirm an install that cannot run.
	if err := checkKimiInstallerDependencies(); err != nil {
		return "", err
	}
	switch ok, err := ConfirmPrompt("Kimi is not installed. Install now?"); {
	case err != nil:
		return "", err
	case !ok:
		return "", fmt.Errorf("kimi installation cancelled")
	}
	bin, args, err := kimiInstallerCommand(kimiGOOS)
	if err != nil {
		return "", err
	}
	fmt.Fprintf(os.Stderr, "\nInstalling Kimi...\n")
	installer := exec.Command(bin, args...)
	installer.Stdin, installer.Stdout, installer.Stderr = os.Stdin, os.Stdout, os.Stderr
	if err := installer.Run(); err != nil {
		return "", fmt.Errorf("failed to install kimi: %w", err)
	}
	path, err := findKimiBinary()
	if err != nil {
		return "", fmt.Errorf("kimi was installed but the binary was not found on PATH\n\nYou may need to restart your shell")
	}
	fmt.Fprintf(os.Stderr, "%sKimi installed successfully%s\n\n", ansiGreen, ansiReset)
	return path, nil
}
// checkKimiInstallerDependencies verifies the tools the kimi installer needs
// are on PATH: PowerShell on Windows, curl and bash everywhere else.
func checkKimiInstallerDependencies() error {
	if kimiGOOS == "windows" {
		if _, err := exec.LookPath("powershell"); err != nil {
			return fmt.Errorf("kimi is not installed and required dependencies are missing\n\nInstall the following first:\n PowerShell: https://learn.microsoft.com/powershell/\n\nThen re-run:\n ollama launch kimi")
		}
		return nil
	}
	var missing []string
	for _, dep := range []struct{ name, url string }{
		{"curl", "https://curl.se/"},
		{"bash", "https://www.gnu.org/software/bash/"},
	} {
		if _, err := exec.LookPath(dep.name); err != nil {
			missing = append(missing, dep.name+": "+dep.url)
		}
	}
	if len(missing) > 0 {
		return fmt.Errorf("kimi is not installed and required dependencies are missing\n\nInstall the following first:\n %s\n\nThen re-run:\n ollama launch kimi", strings.Join(missing, "\n "))
	}
	return nil
}
// kimiInstallerCommand returns the binary and arguments used to run the
// official kimi installer for the given GOOS.
func kimiInstallerCommand(goos string) (string, []string, error) {
	switch goos {
	case "darwin", "linux":
		script := "curl -LsSf https://code.kimi.com/install.sh | bash"
		return "bash", []string{"-c", script}, nil
	case "windows":
		script := "Invoke-RestMethod https://code.kimi.com/install.ps1 | Invoke-Expression"
		return "powershell", []string{"-NoProfile", "-ExecutionPolicy", "Bypass", "-Command", script}, nil
	}
	return "", nil, fmt.Errorf("unsupported platform for kimi install: %s", goos)
}

View File

@@ -1,636 +0,0 @@
package launch
import (
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"runtime"
"slices"
"strings"
"testing"
)
// assertKimiBinPath fails the test unless bin's basename looks like a kimi
// executable (kimi, kimi.exe, kimi.cmd, ...), case-insensitively.
func assertKimiBinPath(t *testing.T, bin string) {
	t.Helper()
	name := strings.ToLower(filepath.Base(bin))
	if !strings.HasPrefix(name, "kimi") {
		t.Fatalf("bin = %q, want path to kimi executable", bin)
	}
}
// TestKimiIntegration covers the runner's display name and its conformance
// to the Runner interface.
func TestKimiIntegration(t *testing.T) {
	k := &Kimi{}
	t.Run("String", func(t *testing.T) {
		const want = "Kimi Code CLI"
		if got := k.String(); got != want {
			t.Errorf("String() = %q, want %q", got, want)
		}
	})
	t.Run("implements Runner", func(t *testing.T) {
		// Compile-time interface check.
		var _ Runner = k
	})
}
// TestKimiArgs verifies the managed --config pair precedes passthrough args.
func TestKimiArgs(t *testing.T) {
	k := &Kimi{}
	extra := []string{"--quiet", "--print"}
	got := k.args(`{"foo":"bar"}`, extra)
	want := []string{"--config", `{"foo":"bar"}`, "--quiet", "--print"}
	if !slices.Equal(got, want) {
		t.Fatalf("args() = %v, want %v", got, want)
	}
}
// TestWindowsPathToWSL exercises the Windows-drive-path to /mnt translation,
// including rejection of non-drive paths and empty input.
func TestWindowsPathToWSL(t *testing.T) {
	tests := []struct {
		name  string
		in    string
		want  string
		valid bool // false means the input must be rejected with ""
	}{
		{
			name:  "user profile path",
			in:    `C:\Users\parth`,
			want:  filepath.Join("/mnt", "c", "Users", "parth"),
			valid: true,
		},
		{
			name:  "path with trailing slash",
			in:    `D:\tools\bin\`,
			want:  filepath.Join("/mnt", "d", "tools", "bin"),
			valid: true,
		},
		{
			name:  "non windows path",
			in:    "/home/parth",
			valid: false,
		},
		{
			name:  "empty",
			in:    "",
			valid: false,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got := windowsPathToWSL(tt.in)
			if !tt.valid {
				if got != "" {
					t.Fatalf("windowsPathToWSL(%q) = %q, want empty", tt.in, got)
				}
				return
			}
			if got != tt.want {
				t.Fatalf("windowsPathToWSL(%q) = %q, want %q", tt.in, got, tt.want)
			}
		})
	}
}
// TestFindKimiBinaryFallbacks verifies the well-known install locations are
// searched when kimi is absent from PATH: the uv tool layout on Linux and
// the APPDATA uv bin on Windows.
func TestFindKimiBinaryFallbacks(t *testing.T) {
	// kimiGOOS is mutated per subtest; restore it afterwards.
	oldGOOS := kimiGOOS
	t.Cleanup(func() { kimiGOOS = oldGOOS })
	t.Run("linux/ubuntu uv tool path", func(t *testing.T) {
		homeDir := t.TempDir()
		setTestHome(t, homeDir)
		// Empty PATH dir guarantees exec.LookPath misses.
		t.Setenv("PATH", t.TempDir())
		kimiGOOS = "linux"
		target := filepath.Join(homeDir, ".local", "share", "uv", "tools", "kimi-cli", "bin", "kimi")
		if err := os.MkdirAll(filepath.Dir(target), 0o755); err != nil {
			t.Fatalf("failed to create candidate dir: %v", err)
		}
		if err := os.WriteFile(target, []byte("#!/bin/sh\nexit 0\n"), 0o755); err != nil {
			t.Fatalf("failed to write kimi candidate: %v", err)
		}
		got, err := findKimiBinary()
		if err != nil {
			t.Fatalf("findKimiBinary() error = %v", err)
		}
		if got != target {
			t.Fatalf("findKimiBinary() = %q, want %q", got, target)
		}
	})
	t.Run("windows appdata uv bin", func(t *testing.T) {
		setTestHome(t, t.TempDir())
		t.Setenv("PATH", t.TempDir())
		kimiGOOS = "windows"
		appDataDir := t.TempDir()
		t.Setenv("APPDATA", appDataDir)
		// Clear LOCALAPPDATA so only the APPDATA candidate can match.
		t.Setenv("LOCALAPPDATA", "")
		target := filepath.Join(appDataDir, "uv", "bin", "kimi.cmd")
		if err := os.MkdirAll(filepath.Dir(target), 0o755); err != nil {
			t.Fatalf("failed to create candidate dir: %v", err)
		}
		if err := os.WriteFile(target, []byte("@echo off\r\nexit /b 0\r\n"), 0o755); err != nil {
			t.Fatalf("failed to write kimi candidate: %v", err)
		}
		got, err := findKimiBinary()
		if err != nil {
			t.Fatalf("findKimiBinary() error = %v", err)
		}
		if got != target {
			t.Fatalf("findKimiBinary() = %q, want %q", got, target)
		}
	})
}
// TestValidateKimiPassthroughArgs_RejectsConflicts ensures every managed flag
// spelling (bare and =value forms) is rejected with a message naming it.
func TestValidateKimiPassthroughArgs_RejectsConflicts(t *testing.T) {
	cases := []struct {
		name string
		args []string
		want string
	}{
		{name: "--config", args: []string{"--config", "{}"}, want: "--config"},
		{name: "--config=", args: []string{"--config={}"}, want: "--config={"},
		{name: "--config-file", args: []string{"--config-file", "x.toml"}, want: "--config-file"},
		{name: "--config-file=", args: []string{"--config-file=x.toml"}, want: "--config-file=x.toml"},
		{name: "--model", args: []string{"--model", "foo"}, want: "--model"},
		{name: "--model=", args: []string{"--model=foo"}, want: "--model=foo"},
		{name: "-m", args: []string{"-m", "foo"}, want: "-m"},
		{name: "-m=", args: []string{"-m=foo"}, want: "-m=foo"},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			err := validateKimiPassthroughArgs(tc.args)
			if err == nil {
				t.Fatalf("expected error for args %v", tc.args)
			}
			if !strings.Contains(err.Error(), tc.want) {
				t.Fatalf("error %q does not contain %q", err.Error(), tc.want)
			}
		})
	}
}
// TestBuildKimiInlineConfig validates the full shape of the generated inline
// JSON config: the default model alias, the provider wiring to the local
// server, and the model entry carrying the requested context size.
func TestBuildKimiInlineConfig(t *testing.T) {
	t.Setenv("OLLAMA_HOST", "http://127.0.0.1:11434")
	cfg, err := buildKimiInlineConfig("llama3.2", 65536)
	if err != nil {
		t.Fatalf("buildKimiInlineConfig() error = %v", err)
	}
	var parsed map[string]any
	if err := json.Unmarshal([]byte(cfg), &parsed); err != nil {
		t.Fatalf("config is not valid JSON: %v", err)
	}
	if parsed["default_model"] != "ollama" {
		t.Fatalf("default_model = %v, want ollama", parsed["default_model"])
	}
	providers, ok := parsed["providers"].(map[string]any)
	if !ok {
		t.Fatalf("providers missing or wrong type: %T", parsed["providers"])
	}
	ollamaProvider, ok := providers["ollama"].(map[string]any)
	if !ok {
		t.Fatalf("providers.ollama missing or wrong type: %T", providers["ollama"])
	}
	if ollamaProvider["type"] != "openai_legacy" {
		t.Fatalf("provider type = %v, want openai_legacy", ollamaProvider["type"])
	}
	if ollamaProvider["base_url"] != "http://127.0.0.1:11434/v1" {
		t.Fatalf("provider base_url = %v, want http://127.0.0.1:11434/v1", ollamaProvider["base_url"])
	}
	if ollamaProvider["api_key"] != "ollama" {
		t.Fatalf("provider api_key = %v, want ollama", ollamaProvider["api_key"])
	}
	models, ok := parsed["models"].(map[string]any)
	if !ok {
		t.Fatalf("models missing or wrong type: %T", parsed["models"])
	}
	ollamaModel, ok := models["ollama"].(map[string]any)
	if !ok {
		t.Fatalf("models.ollama missing or wrong type: %T", models["ollama"])
	}
	if ollamaModel["provider"] != "ollama" {
		t.Fatalf("model provider = %v, want ollama", ollamaModel["provider"])
	}
	if ollamaModel["model"] != "llama3.2" {
		t.Fatalf("model model = %v, want llama3.2", ollamaModel["model"])
	}
	// json.Unmarshal decodes JSON numbers into float64.
	if ollamaModel["max_context_size"] != float64(65536) {
		t.Fatalf("model max_context_size = %v, want 65536", ollamaModel["max_context_size"])
	}
}
// TestBuildKimiInlineConfig_UsesConnectableHostForUnspecifiedBind checks
// that a 0.0.0.0 bind address is rewritten to a connectable loopback base
// URL in the generated config.
func TestBuildKimiInlineConfig_UsesConnectableHostForUnspecifiedBind(t *testing.T) {
	t.Setenv("OLLAMA_HOST", "http://0.0.0.0:11434")
	cfg, err := buildKimiInlineConfig("llama3.2", 65536)
	if err != nil {
		t.Fatalf("buildKimiInlineConfig() error = %v", err)
	}
	var parsed map[string]any
	if err := json.Unmarshal([]byte(cfg), &parsed); err != nil {
		t.Fatalf("config is not valid JSON: %v", err)
	}
	providers, ok := parsed["providers"].(map[string]any)
	if !ok {
		t.Fatalf("providers missing or wrong type: %T", parsed["providers"])
	}
	ollamaProvider, ok := providers["ollama"].(map[string]any)
	if !ok {
		t.Fatalf("providers.ollama missing or wrong type: %T", providers["ollama"])
	}
	const want = "http://127.0.0.1:11434/v1"
	if got, _ := ollamaProvider["base_url"].(string); got != want {
		t.Fatalf("provider base_url = %q, want %q", got, want)
	}
}
// TestResolveKimiMaxContextSize covers the three resolution tiers: hardcoded
// cloud limits, /api/show metadata, and the default fallback on failure.
func TestResolveKimiMaxContextSize(t *testing.T) {
	t.Run("uses cloud limit when known", func(t *testing.T) {
		got := resolveKimiMaxContextSize("kimi-k2.5:cloud")
		if got != 262_144 {
			t.Fatalf("resolveKimiMaxContextSize() = %d, want 262144", got)
		}
	})
	t.Run("uses model show context length for local models", func(t *testing.T) {
		// Fake Ollama server that only answers /api/show.
		srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			if r.URL.Path != "/api/show" {
				http.NotFound(w, r)
				return
			}
			fmt.Fprint(w, `{"model_info":{"llama.context_length":131072}}`)
		}))
		defer srv.Close()
		t.Setenv("OLLAMA_HOST", srv.URL)
		got := resolveKimiMaxContextSize("llama3.2")
		if got != 131_072 {
			t.Fatalf("resolveKimiMaxContextSize() = %d, want 131072", got)
		}
	})
	t.Run("falls back to default when show fails", func(t *testing.T) {
		srv := httptest.NewServer(http.NotFoundHandler())
		defer srv.Close()
		t.Setenv("OLLAMA_HOST", srv.URL)
		// Shrink the show timeout so the failing call returns quickly.
		// Raw nanoseconds because "time" is not imported in this file.
		oldTimeout := kimiModelShowTimeout
		kimiModelShowTimeout = 100 * 1000 * 1000 // 100ms
		t.Cleanup(func() { kimiModelShowTimeout = oldTimeout })
		got := resolveKimiMaxContextSize("llama3.2")
		if got != kimiDefaultMaxContextSize {
			t.Fatalf("resolveKimiMaxContextSize() = %d, want %d", got, kimiDefaultMaxContextSize)
		}
	})
}
func TestKimiRun_RejectsConflictingArgsBeforeInstall(t *testing.T) {
k := &Kimi{}
oldConfirm := DefaultConfirmPrompt
DefaultConfirmPrompt = func(prompt string, options ConfirmOptions) (bool, error) {
t.Fatalf("did not expect install prompt, got %q", prompt)
return false, nil
}
t.Cleanup(func() { DefaultConfirmPrompt = oldConfirm })
err := k.Run("llama3.2", []string{"--model", "other"})
if err == nil || !strings.Contains(err.Error(), "--model") {
t.Fatalf("expected conflict error mentioning --model, got %v", err)
}
}
// TestKimiRun_PassesInlineConfigAndExtraArgs runs Kimi.Run against a fake
// kimi binary that records its argv, then checks the managed --config JSON
// comes first and the passthrough args follow in order.
func TestKimiRun_PassesInlineConfigAndExtraArgs(t *testing.T) {
	if runtime.GOOS == "windows" {
		t.Skip("uses POSIX shell fake binary")
	}
	tmpDir := t.TempDir()
	setTestHome(t, tmpDir)
	logPath := filepath.Join(tmpDir, "kimi-args.log")
	// Fake kimi: append each argument on its own line to the log file.
	script := fmt.Sprintf(`#!/bin/sh
for arg in "$@"; do
printf "%%s\n" "$arg" >> %q
done
exit 0
`, logPath)
	if err := os.WriteFile(filepath.Join(tmpDir, "kimi"), []byte(script), 0o755); err != nil {
		t.Fatalf("failed to write fake kimi: %v", err)
	}
	t.Setenv("PATH", tmpDir)
	// 404-only server: the context-length lookup fails and the default is used.
	srv := httptest.NewServer(http.NotFoundHandler())
	defer srv.Close()
	t.Setenv("OLLAMA_HOST", srv.URL)
	k := &Kimi{}
	if err := k.Run("llama3.2", []string{"--quiet", "--print"}); err != nil {
		t.Fatalf("Run() error = %v", err)
	}
	data, err := os.ReadFile(logPath)
	if err != nil {
		t.Fatalf("failed to read args log: %v", err)
	}
	lines := strings.Split(strings.TrimSpace(string(data)), "\n")
	if len(lines) < 4 {
		t.Fatalf("expected at least 4 args, got %v", lines)
	}
	if lines[0] != "--config" {
		t.Fatalf("first arg = %q, want --config", lines[0])
	}
	var cfg map[string]any
	if err := json.Unmarshal([]byte(lines[1]), &cfg); err != nil {
		t.Fatalf("config arg is not valid JSON: %v", err)
	}
	providers := cfg["providers"].(map[string]any)
	ollamaProvider := providers["ollama"].(map[string]any)
	if ollamaProvider["type"] != "openai_legacy" {
		t.Fatalf("provider type = %v, want openai_legacy", ollamaProvider["type"])
	}
	if lines[2] != "--quiet" || lines[3] != "--print" {
		t.Fatalf("extra args = %v, want [--quiet --print]", lines[2:])
	}
}
// TestEnsureKimiInstalled covers the install decision tree: already
// installed, missing prerequisites, user declining, a successful install
// (both on PATH and only in ~/.local/bin), installer failure, and an
// installer that succeeds without producing a binary.
func TestEnsureKimiInstalled(t *testing.T) {
	oldGOOS := kimiGOOS
	t.Cleanup(func() { kimiGOOS = oldGOOS })
	// withConfirm stubs the confirmation prompt for a single subtest.
	withConfirm := func(t *testing.T, fn func(prompt string) (bool, error)) {
		t.Helper()
		oldConfirm := DefaultConfirmPrompt
		DefaultConfirmPrompt = func(prompt string, options ConfirmOptions) (bool, error) {
			return fn(prompt)
		}
		t.Cleanup(func() { DefaultConfirmPrompt = oldConfirm })
	}
	t.Run("already installed", func(t *testing.T) {
		setTestHome(t, t.TempDir())
		tmpDir := t.TempDir()
		t.Setenv("PATH", tmpDir)
		writeFakeBinary(t, tmpDir, "kimi")
		kimiGOOS = runtime.GOOS
		withConfirm(t, func(prompt string) (bool, error) {
			t.Fatalf("did not expect prompt, got %q", prompt)
			return false, nil
		})
		bin, err := ensureKimiInstalled()
		if err != nil {
			t.Fatalf("ensureKimiInstalled() error = %v", err)
		}
		assertKimiBinPath(t, bin)
	})
	t.Run("missing dependencies", func(t *testing.T) {
		setTestHome(t, t.TempDir())
		tmpDir := t.TempDir()
		// Empty PATH: no kimi, and no curl/bash either.
		t.Setenv("PATH", tmpDir)
		kimiGOOS = "linux"
		withConfirm(t, func(prompt string) (bool, error) {
			t.Fatalf("did not expect prompt, got %q", prompt)
			return false, nil
		})
		_, err := ensureKimiInstalled()
		if err == nil || !strings.Contains(err.Error(), "required dependencies are missing") {
			t.Fatalf("expected missing dependency error, got %v", err)
		}
	})
	t.Run("missing and user declines install", func(t *testing.T) {
		setTestHome(t, t.TempDir())
		tmpDir := t.TempDir()
		t.Setenv("PATH", tmpDir)
		writeFakeBinary(t, tmpDir, "curl")
		writeFakeBinary(t, tmpDir, "bash")
		kimiGOOS = "linux"
		withConfirm(t, func(prompt string) (bool, error) {
			if !strings.Contains(prompt, "Kimi is not installed.") {
				t.Fatalf("unexpected prompt: %q", prompt)
			}
			return false, nil
		})
		_, err := ensureKimiInstalled()
		if err == nil || !strings.Contains(err.Error(), "installation cancelled") {
			t.Fatalf("expected cancellation error, got %v", err)
		}
	})
	t.Run("missing and user confirms install succeeds", func(t *testing.T) {
		if runtime.GOOS == "windows" {
			t.Skip("uses POSIX shell fake binaries")
		}
		setTestHome(t, t.TempDir())
		tmpDir := t.TempDir()
		t.Setenv("PATH", tmpDir)
		kimiGOOS = "linux"
		writeFakeBinary(t, tmpDir, "curl")
		installLog := filepath.Join(tmpDir, "bash.log")
		kimiPath := filepath.Join(tmpDir, "kimi")
		// Fake bash logs its argv and, when invoked with -c (the installer
		// form), drops an executable kimi stub onto PATH.
		bashScript := fmt.Sprintf(`#!/bin/sh
echo "$@" >> %q
if [ "$1" = "-c" ]; then
/bin/cat > %q <<'EOS'
#!/bin/sh
exit 0
EOS
/bin/chmod +x %q
fi
exit 0
`, installLog, kimiPath, kimiPath)
		if err := os.WriteFile(filepath.Join(tmpDir, "bash"), []byte(bashScript), 0o755); err != nil {
			t.Fatalf("failed to write fake bash: %v", err)
		}
		withConfirm(t, func(prompt string) (bool, error) {
			return true, nil
		})
		bin, err := ensureKimiInstalled()
		if err != nil {
			t.Fatalf("ensureKimiInstalled() error = %v", err)
		}
		assertKimiBinPath(t, bin)
		logData, err := os.ReadFile(installLog)
		if err != nil {
			t.Fatalf("failed to read install log: %v", err)
		}
		if !strings.Contains(string(logData), "https://code.kimi.com/install.sh") {
			t.Fatalf("expected install.sh command in log, got:\n%s", string(logData))
		}
	})
	t.Run("install succeeds and kimi is in home local bin without PATH update", func(t *testing.T) {
		if runtime.GOOS == "windows" {
			t.Skip("uses POSIX shell fake binaries")
		}
		homeDir := t.TempDir()
		setTestHome(t, homeDir)
		tmpBin := t.TempDir()
		t.Setenv("PATH", tmpBin)
		kimiGOOS = "linux"
		writeFakeBinary(t, tmpBin, "curl")
		// Installer writes kimi into ~/.local/bin, which is NOT on PATH;
		// findKimiBinary's fallback search must still locate it.
		installedKimi := filepath.Join(homeDir, ".local", "bin", "kimi")
		bashScript := fmt.Sprintf(`#!/bin/sh
if [ "$1" = "-c" ]; then
/bin/mkdir -p %q
/bin/cat > %q <<'EOS'
#!/bin/sh
exit 0
EOS
/bin/chmod +x %q
fi
exit 0
`, filepath.Dir(installedKimi), installedKimi, installedKimi)
		if err := os.WriteFile(filepath.Join(tmpBin, "bash"), []byte(bashScript), 0o755); err != nil {
			t.Fatalf("failed to write fake bash: %v", err)
		}
		withConfirm(t, func(prompt string) (bool, error) {
			return true, nil
		})
		bin, err := ensureKimiInstalled()
		if err != nil {
			t.Fatalf("ensureKimiInstalled() error = %v", err)
		}
		if bin != installedKimi {
			t.Fatalf("bin = %q, want %q", bin, installedKimi)
		}
	})
	t.Run("install command fails", func(t *testing.T) {
		if runtime.GOOS == "windows" {
			t.Skip("uses POSIX shell fake binaries")
		}
		setTestHome(t, t.TempDir())
		tmpDir := t.TempDir()
		t.Setenv("PATH", tmpDir)
		kimiGOOS = "linux"
		writeFakeBinary(t, tmpDir, "curl")
		// bash exits nonzero, so the install must surface a failure.
		if err := os.WriteFile(filepath.Join(tmpDir, "bash"), []byte("#!/bin/sh\nexit 1\n"), 0o755); err != nil {
			t.Fatalf("failed to write fake bash: %v", err)
		}
		withConfirm(t, func(prompt string) (bool, error) {
			return true, nil
		})
		_, err := ensureKimiInstalled()
		if err == nil || !strings.Contains(err.Error(), "failed to install kimi") {
			t.Fatalf("expected install failure error, got %v", err)
		}
	})
	t.Run("install succeeds but binary missing on PATH", func(t *testing.T) {
		if runtime.GOOS == "windows" {
			t.Skip("uses POSIX shell fake binaries")
		}
		setTestHome(t, t.TempDir())
		tmpDir := t.TempDir()
		t.Setenv("PATH", tmpDir)
		kimiGOOS = "linux"
		writeFakeBinary(t, tmpDir, "curl")
		// bash succeeds but never writes a kimi binary anywhere.
		if err := os.WriteFile(filepath.Join(tmpDir, "bash"), []byte("#!/bin/sh\nexit 0\n"), 0o755); err != nil {
			t.Fatalf("failed to write fake bash: %v", err)
		}
		withConfirm(t, func(prompt string) (bool, error) {
			return true, nil
		})
		_, err := ensureKimiInstalled()
		if err == nil || !strings.Contains(err.Error(), "binary was not found on PATH") {
			t.Fatalf("expected PATH guidance error, got %v", err)
		}
	})
}
// TestKimiInstallerCommand checks the per-OS installer invocation and the
// unsupported-platform error.
func TestKimiInstallerCommand(t *testing.T) {
	cases := []struct {
		name      string
		goos      string
		wantBin   string
		wantParts []string
		wantErr   bool
	}{
		{name: "linux", goos: "linux", wantBin: "bash", wantParts: []string{"-c", "install.sh"}},
		{name: "darwin", goos: "darwin", wantBin: "bash", wantParts: []string{"-c", "install.sh"}},
		{name: "windows", goos: "windows", wantBin: "powershell", wantParts: []string{"-Command", "install.ps1"}},
		{name: "unsupported", goos: "freebsd", wantErr: true},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			bin, args, err := kimiInstallerCommand(tc.goos)
			if tc.wantErr {
				if err == nil {
					t.Fatal("expected error")
				}
				return
			}
			if err != nil {
				t.Fatalf("kimiInstallerCommand() error = %v", err)
			}
			if bin != tc.wantBin {
				t.Fatalf("bin = %q, want %q", bin, tc.wantBin)
			}
			joined := strings.Join(args, " ")
			for _, part := range tc.wantParts {
				if !strings.Contains(joined, part) {
					t.Fatalf("args %q missing %q", joined, part)
				}
			}
		})
	}
}

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -1,508 +0,0 @@
package launch
import (
"context"
"errors"
"fmt"
"net/http"
"os"
"os/exec"
"runtime"
"slices"
"strings"
"time"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/cmd/config"
"github.com/ollama/ollama/cmd/internal/fileutil"
internalcloud "github.com/ollama/ollama/internal/cloud"
"github.com/ollama/ollama/internal/modelref"
"github.com/ollama/ollama/progress"
)
// recommendedModels is the curated model list for launch; ":cloud" entries
// run remotely while the rest run locally.
var recommendedModels = []ModelItem{
	{Name: "kimi-k2.6:cloud", Description: "State-of-the-art coding, long-horizon execution, and multimodal agent swarm capability", Recommended: true},
	{Name: "qwen3.5:cloud", Description: "Reasoning, coding, and agentic tool use with vision", Recommended: true},
	{Name: "glm-5.1:cloud", Description: "Reasoning and code generation", Recommended: true},
	{Name: "minimax-m2.7:cloud", Description: "Fast, efficient coding and real-world productivity", Recommended: true},
	{Name: "gemma4", Description: "Reasoning and code generation locally", Recommended: true},
	{Name: "qwen3.5", Description: "Reasoning, coding, and visual understanding locally", Recommended: true},
}

// recommendedVRAM maps local recommended models to approximate VRAM figures.
// NOTE(review): presumably shown next to the picker entries — confirm usage
// against the callers.
var recommendedVRAM = map[string]string{
	"gemma4":  "~16GB",
	"qwen3.5": "~11GB",
}

// cloudModelLimit holds context and output token limits for a cloud model.
type cloudModelLimit struct {
	Context int // maximum context window, in tokens
	Output  int // maximum output, in tokens
}

// cloudModelLimits maps cloud model base names to their token limits.
// TODO(parthsareen): grab context/output limits from model info instead of hardcoding
var cloudModelLimits = map[string]cloudModelLimit{
	"minimax-m2.7":        {Context: 204_800, Output: 128_000},
	"cogito-2.1:671b":     {Context: 163_840, Output: 65_536},
	"deepseek-v3.1:671b":  {Context: 163_840, Output: 163_840},
	"deepseek-v3.2":       {Context: 163_840, Output: 65_536},
	"gemma4:31b":          {Context: 262_144, Output: 131_072},
	"glm-4.6":             {Context: 202_752, Output: 131_072},
	"glm-4.7":             {Context: 202_752, Output: 131_072},
	"glm-5":               {Context: 202_752, Output: 131_072},
	"glm-5.1":             {Context: 202_752, Output: 131_072},
	"gpt-oss:120b":        {Context: 131_072, Output: 131_072},
	"gpt-oss:20b":         {Context: 131_072, Output: 131_072},
	"kimi-k2:1t":          {Context: 262_144, Output: 262_144},
	"kimi-k2.5":           {Context: 262_144, Output: 262_144},
	"kimi-k2.6":           {Context: 262_144, Output: 262_144},
	"kimi-k2-thinking":    {Context: 262_144, Output: 262_144},
	"nemotron-3-nano:30b": {Context: 1_048_576, Output: 131_072},
	"qwen3-coder:480b":    {Context: 262_144, Output: 65_536},
	"qwen3-coder-next":    {Context: 262_144, Output: 32_768},
	"qwen3-next:80b":      {Context: 262_144, Output: 32_768},
	"qwen3.5":             {Context: 262_144, Output: 32_768},
}
// lookupCloudModelLimit returns the token limits for a cloud model. Only
// names carrying an explicit cloud source tag are considered; the tag is
// stripped before consulting cloudModelLimits.
func lookupCloudModelLimit(name string) (cloudModelLimit, bool) {
	base, stripped := modelref.StripCloudSourceTag(name)
	if !stripped {
		return cloudModelLimit{}, false
	}
	limit, ok := cloudModelLimits[base]
	return limit, ok
}
// missingModelPolicy controls how model-not-found errors should be handled
// when preparing a launch.
type missingModelPolicy int

const (
	// missingModelPromptPull prompts the user to download missing local models.
	missingModelPromptPull missingModelPolicy = iota
	// missingModelAutoPull downloads missing local models without prompting.
	missingModelAutoPull
	// missingModelFail returns an error for missing local models without prompting.
	missingModelFail
)
// OpenBrowser opens the URL in the user's browser. Launch errors are
// ignored; unsupported platforms are a no-op.
func OpenBrowser(url string) {
	var cmd *exec.Cmd
	switch runtime.GOOS {
	case "darwin":
		cmd = exec.Command("open", url)
	case "linux":
		// Skip on headless systems where no display server is available.
		if os.Getenv("DISPLAY") == "" && os.Getenv("WAYLAND_DISPLAY") == "" {
			return
		}
		cmd = exec.Command("xdg-open", url)
	case "windows":
		cmd = exec.Command("rundll32", "url.dll,FileProtocolHandler", url)
	default:
		return
	}
	_ = cmd.Start()
}
// ensureAuth ensures the user is signed in before cloud-backed models run.
// Selections with no cloud models return immediately; otherwise it checks
// the current identity and, when missing, walks the user through the
// browser sign-in flow, polling until the account appears or ctx is done.
func ensureAuth(ctx context.Context, client *api.Client, cloudModels map[string]bool, selected []string) error {
	// Only the selected models that are cloud-backed require auth.
	var selectedCloudModels []string
	for _, m := range selected {
		if cloudModels[m] {
			selectedCloudModels = append(selectedCloudModels, m)
		}
	}
	if len(selectedCloudModels) == 0 {
		return nil
	}
	if disabled, known := cloudStatusDisabled(ctx, client); known && disabled {
		return errors.New(internalcloud.DisabledError("remote inference is unavailable"))
	}
	user, err := client.Whoami(ctx)
	if err == nil && user != nil && user.Name != "" {
		// Already signed in.
		return nil
	}
	// Anything other than an authorization error carrying a sign-in URL is fatal.
	var aErr api.AuthorizationError
	if !errors.As(err, &aErr) || aErr.SigninURL == "" {
		return err
	}
	modelList := strings.Join(selectedCloudModels, ", ")
	// Prefer the injected sign-in handler when one is configured.
	if DefaultSignIn != nil {
		_, err := DefaultSignIn(modelList, aErr.SigninURL)
		if errors.Is(err, ErrCancelled) {
			return ErrCancelled
		}
		if err != nil {
			return fmt.Errorf("%s requires sign in", modelList)
		}
		return nil
	}
	yes, err := ConfirmPrompt(fmt.Sprintf("sign in to use %s?", modelList))
	if errors.Is(err, ErrCancelled) {
		return ErrCancelled
	}
	if err != nil {
		return err
	}
	if !yes {
		return ErrCancelled
	}
	fmt.Fprintf(os.Stderr, "\nTo sign in, navigate to:\n %s\n\n", aErr.SigninURL)
	OpenBrowser(aErr.SigninURL)
	// Animate a dim spinner on stderr, checking Whoami every 10th tick
	// (~2s at the 200ms tick interval).
	spinnerFrames := []string{"|", "/", "-", "\\"}
	frame := 0
	fmt.Fprintf(os.Stderr, "\033[90mwaiting for sign in to complete... %s\033[0m", spinnerFrames[0])
	ticker := time.NewTicker(200 * time.Millisecond)
	defer ticker.Stop()
	for {
		select {
		case <-ctx.Done():
			// Clear the spinner line before giving up.
			fmt.Fprintf(os.Stderr, "\r\033[K")
			return ctx.Err()
		case <-ticker.C:
			frame++
			fmt.Fprintf(os.Stderr, "\r\033[90mwaiting for sign in to complete... %s\033[0m", spinnerFrames[frame%len(spinnerFrames)])
			if frame%10 == 0 {
				u, err := client.Whoami(ctx)
				if err == nil && u != nil && u.Name != "" {
					// Erase the spinner line and the line above it, then
					// print the signed-in confirmation in bold.
					fmt.Fprintf(os.Stderr, "\r\033[K\033[A\r\033[K\033[1msigned in:\033[0m %s\n", u.Name)
					return nil
				}
			}
		}
	}
}
// showOrPullWithPolicy checks if a model exists and applies the provided
// missing-model policy.
//
// When the model is already present it returns nil. A non-404 error from
// Show is returned as-is. For a missing cloud model it never pulls: it
// reports cloud-disabled status when known, otherwise a not-found error.
// For a missing local model the policy decides: auto-pull, fail with a
// hint, or (default) prompt before pulling.
func showOrPullWithPolicy(ctx context.Context, client *api.Client, model string, policy missingModelPolicy, isCloudModel bool) error {
	// Early returns keep the happy path flat (no else after return).
	_, err := client.Show(ctx, &api.ShowRequest{Model: model})
	if err == nil {
		return nil // model already exists locally
	}
	var statusErr api.StatusError
	if !errors.As(err, &statusErr) || statusErr.StatusCode != http.StatusNotFound {
		return err
	}
	// Model is missing (404 from Show).
	if isCloudModel {
		if disabled, known := cloudStatusDisabled(ctx, client); known && disabled {
			return errors.New(internalcloud.DisabledError("remote inference is unavailable"))
		}
		return fmt.Errorf("model %q not found", model)
	}
	switch policy {
	case missingModelAutoPull:
		return pullMissingModel(ctx, client, model)
	case missingModelFail:
		return fmt.Errorf("model %q not found; run 'ollama pull %s' first, or use --yes to auto-pull", model, model)
	default:
		return confirmAndPull(ctx, client, model)
	}
}
// confirmAndPull asks the user whether to download model and, on approval,
// pulls it; declining yields errCancelled.
func confirmAndPull(ctx context.Context, client *api.Client, model string) error {
	ok, err := ConfirmPrompt(fmt.Sprintf("Download %s?", model))
	if err != nil {
		return err
	}
	if !ok {
		return errCancelled
	}
	fmt.Fprintf(os.Stderr, "\n")
	return pullMissingModel(ctx, client, model)
}
// pullMissingModel downloads model via pullModel, wrapping any failure with
// the model name for context.
func pullMissingModel(ctx context.Context, client *api.Client, model string) error {
	err := pullModel(ctx, client, model, false)
	if err == nil {
		return nil
	}
	return fmt.Errorf("failed to pull %s: %w", model, err)
}
// prepareEditorIntegration persists models and applies editor-managed config files.
// The user is asked to confirm before any editor config paths are touched.
func prepareEditorIntegration(name string, runner Runner, editor Editor, models []string) error {
	ok, err := confirmConfigEdit(runner, editor.Paths())
	if err != nil {
		return err
	}
	if !ok {
		return errCancelled
	}
	if err := editor.Edit(models); err != nil {
		return fmt.Errorf("setup failed: %w", err)
	}
	if err := config.SaveIntegration(name, models); err != nil {
		return fmt.Errorf("failed to save: %w", err)
	}
	return nil
}
// prepareManagedSingleIntegration configures a single-model managed
// integration after user confirmation, then records it in saved config.
func prepareManagedSingleIntegration(name string, runner Runner, managed ManagedSingleModel, model string) error {
	ok, err := confirmConfigEdit(runner, managed.Paths())
	if err != nil {
		return err
	}
	if !ok {
		return errCancelled
	}
	if err := managed.Configure(model); err != nil {
		return fmt.Errorf("setup failed: %w", err)
	}
	if err := config.SaveIntegration(name, []string{model}); err != nil {
		return fmt.Errorf("failed to save: %w", err)
	}
	return nil
}
// confirmConfigEdit lists the config files that will be modified and asks
// the user to confirm. When no paths are affected it returns true
// immediately without prompting.
func confirmConfigEdit(runner Runner, paths []string) (bool, error) {
	if len(paths) == 0 {
		return true, nil
	}
	fmt.Fprintf(os.Stderr, "This will modify your %s configuration:\n", runner)
	for _, p := range paths {
		fmt.Fprintf(os.Stderr, " %s\n", p)
	}
	fmt.Fprintf(os.Stderr, "Backups will be saved to %s/\n\n", fileutil.BackupDir())
	return ConfirmPrompt("Proceed?")
}
// buildModelList merges existing models with recommendations for selection UIs.
//
// It returns the display items, the pre-checked names reordered so the
// current default (if any) comes first, and lookup sets for existing and
// cloud-backed model names.
func buildModelList(existing []modelInfo, preChecked []string, current string) (items []ModelItem, orderedChecked []string, existingModels, cloudModels map[string]bool) {
	existingModels = make(map[string]bool)
	cloudModels = make(map[string]bool)
	recommended := make(map[string]bool)
	var hasLocalModel, hasCloudModel bool
	recDesc := make(map[string]string)
	for _, rec := range recommendedModels {
		recommended[rec.Name] = true
		recDesc[rec.Name] = rec.Description
	}
	// Installed models first. Strip ":latest" for display, but record both
	// spellings in existingModels so later lookups match either form.
	for _, m := range existing {
		existingModels[m.Name] = true
		if m.Remote {
			cloudModels[m.Name] = true
			hasCloudModel = true
		} else {
			hasLocalModel = true
		}
		displayName := strings.TrimSuffix(m.Name, ":latest")
		existingModels[displayName] = true
		item := ModelItem{Name: displayName, Recommended: recommended[displayName], Description: recDesc[displayName]}
		items = append(items, item)
	}
	// Append recommended models that aren't already installed.
	for _, rec := range recommendedModels {
		if existingModels[rec.Name] || existingModels[rec.Name+":latest"] {
			continue
		}
		items = append(items, rec)
		if isCloudModelName(rec.Name) {
			cloudModels[rec.Name] = true
		}
	}
	checked := make(map[string]bool, len(preChecked))
	for _, n := range preChecked {
		checked[n] = true
	}
	// Resolve current against the item list: exact match first, then a
	// tag-qualified prefix match (e.g. "llama3" matching "llama3:8b").
	if current != "" {
		matchedCurrent := false
		for _, item := range items {
			if item.Name == current {
				current = item.Name
				matchedCurrent = true
				break
			}
		}
		if !matchedCurrent {
			for _, item := range items {
				if strings.HasPrefix(item.Name, current+":") {
					current = item.Name
					break
				}
			}
		}
	}
	// Move the current default to the front of the pre-checked list.
	if checked[current] {
		preChecked = append([]string{current}, slices.DeleteFunc(preChecked, func(m string) bool { return m == current })...)
	}
	// Annotate items that aren't installed (and aren't cloud-backed) with
	// their description, a VRAM hint, and a "(not downloaded)" marker.
	notInstalled := make(map[string]bool)
	for i := range items {
		if !existingModels[items[i].Name] && !cloudModels[items[i].Name] {
			notInstalled[items[i].Name] = true
			var parts []string
			if items[i].Description != "" {
				parts = append(parts, items[i].Description)
			}
			if vram := recommendedVRAM[items[i].Name]; vram != "" {
				parts = append(parts, vram)
			}
			parts = append(parts, "(not downloaded)")
			items[i].Description = strings.Join(parts, ", ")
		}
	}
	// 1-based rank so a zero value means "not recommended".
	recRank := make(map[string]int)
	for i, rec := range recommendedModels {
		recRank[rec.Name] = i + 1
	}
	if hasLocalModel || hasCloudModel {
		// Keep the Recommended section pinned to recommendedModels order. Checked
		// and default-model priority only apply within the More section.
		slices.SortStableFunc(items, func(a, b ModelItem) int {
			ac, bc := checked[a.Name], checked[b.Name]
			aNew, bNew := notInstalled[a.Name], notInstalled[b.Name]
			aRec, bRec := recRank[a.Name] > 0, recRank[b.Name] > 0
			// Recommended items sort before everything else...
			if aRec != bRec {
				if aRec {
					return -1
				}
				return 1
			}
			// ...and keep their recommendedModels ordering among themselves.
			if aRec && bRec {
				return recRank[a.Name] - recRank[b.Name]
			}
			// Past this point both items are non-recommended.
			// Checked items come before unchecked ones.
			if ac != bc {
				if ac {
					return -1
				}
				return 1
			}
			// Among checked non-recommended items - put the default first
			if ac && !aRec && current != "" {
				aCurrent := a.Name == current
				bCurrent := b.Name == current
				if aCurrent != bCurrent {
					if aCurrent {
						return -1
					}
					return 1
				}
			}
			// Installed items before not-yet-downloaded ones.
			if aNew != bNew {
				if aNew {
					return 1
				}
				return -1
			}
			// Finally, case-insensitive alphabetical order.
			return strings.Compare(strings.ToLower(a.Name), strings.ToLower(b.Name))
		})
	}
	return items, preChecked, existingModels, cloudModels
}
// isCloudModelName reports whether the model name has an explicit cloud
// source, delegating to modelref.HasExplicitCloudSource. This is a pure
// name check — no server round trip.
func isCloudModelName(name string) bool {
	return modelref.HasExplicitCloudSource(name)
}
// filterCloudModels drops remote-only models from the given inventory.
// The result reuses the input slice's backing array (in-place filter), so
// the caller should not use existing afterwards.
func filterCloudModels(existing []modelInfo) []modelInfo {
	kept := existing[:0]
	for _, info := range existing {
		if info.Remote {
			continue
		}
		kept = append(kept, info)
	}
	return kept
}
// filterCloudItems removes cloud models from selection items.
// The result reuses the input slice's backing array (in-place filter), so
// the caller should not use items afterwards.
func filterCloudItems(items []ModelItem) []ModelItem {
	kept := items[:0]
	for _, it := range items {
		if isCloudModelName(it.Name) {
			continue
		}
		kept = append(kept, it)
	}
	return kept
}
// isCloudModel reports whether name resolves to a remote (cloud) model by
// querying the server. A nil client or a failed lookup yields false.
func isCloudModel(ctx context.Context, client *api.Client, name string) bool {
	if client == nil {
		return false
	}
	if resp, err := client.Show(ctx, &api.ShowRequest{Model: name}); err == nil {
		return resp.RemoteModel != ""
	}
	return false
}
// cloudStatusDisabled returns whether cloud usage is currently disabled.
// known is false when the status could not be determined — for example when
// the server predates the endpoint (404) or any other error occurs; callers
// treat unknown status as "not disabled".
func cloudStatusDisabled(ctx context.Context, client *api.Client) (disabled bool, known bool) {
	status, err := client.CloudStatusExperimental(ctx)
	if err != nil {
		// Every failure mode leaves the status unknown; the previous
		// special-casing of http.StatusNotFound returned the identical
		// values and was dead code.
		return false, false
	}
	return status.Cloud.Disabled, true
}
// TODO(parthsareen): this duplicates the pull progress UI in cmd.PullHandler.
// Move the shared pull rendering to a small utility once the package boundary settles.

// pullModel downloads model through client, rendering one progress bar per
// layer (keyed by digest) and a spinner for status-only updates on stderr.
func pullModel(ctx context.Context, client *api.Client, model string, insecure bool) error {
	p := progress.NewProgress(os.Stderr)
	defer p.Stop()
	bars := make(map[string]*progress.Bar)
	var status string
	var spinner *progress.Spinner
	fn := func(resp api.ProgressResponse) error {
		if resp.Digest != "" {
			// Layer download update: wait until bytes actually arrive
			// before creating a bar.
			if resp.Completed == 0 {
				return nil
			}
			if spinner != nil {
				spinner.Stop()
			}
			bar, ok := bars[resp.Digest]
			if !ok {
				// First update for this layer: label the bar with a short
				// (up to 12 char) form when the name is a sha256 digest.
				name, isDigest := strings.CutPrefix(resp.Digest, "sha256:")
				name = strings.TrimSpace(name)
				if isDigest {
					name = name[:min(12, len(name))]
				}
				bar = progress.NewBar(fmt.Sprintf("pulling %s:", name), resp.Total, resp.Completed)
				bars[resp.Digest] = bar
				p.Add(resp.Digest, bar)
			}
			bar.Set(resp.Completed)
		} else if status != resp.Status {
			// Status-only update: replace the spinner when the status text
			// changes.
			if spinner != nil {
				spinner.Stop()
			}
			status = resp.Status
			spinner = progress.NewSpinner(status)
			p.Add(status, spinner)
		}
		return nil
	}
	request := api.PullRequest{Name: model, Insecure: insecure}
	return client.Pull(ctx, &request, fn)
}

View File

@@ -1,248 +0,0 @@
package launch
import (
"encoding/json"
"fmt"
"os"
"os/exec"
"path/filepath"
"runtime"
"slices"
"github.com/ollama/ollama/cmd/internal/fileutil"
"github.com/ollama/ollama/envconfig"
)
// OpenCode implements Runner and Editor for OpenCode integration.
// Config is passed via OPENCODE_CONFIG_CONTENT env var at launch time
// instead of writing to opencode's config files.
type OpenCode struct {
	configContent string // JSON config built by Edit, passed to Run via env var
}
// String returns the display name of the integration.
func (o *OpenCode) String() string { return "OpenCode" }
// findOpenCode returns the opencode binary path, checking PATH first then the
// curl installer location (~/.opencode/bin) which may not be on PATH yet.
// The second result reports whether a binary was found.
func findOpenCode() (string, bool) {
	if p, err := exec.LookPath("opencode"); err == nil {
		return p, true
	}
	home, err := os.UserHomeDir()
	if err != nil {
		return "", false
	}
	bin := "opencode"
	if runtime.GOOS == "windows" {
		bin = "opencode.exe"
	}
	candidate := filepath.Join(home, ".opencode", "bin", bin)
	if _, statErr := os.Stat(candidate); statErr != nil {
		return "", false
	}
	return candidate, true
}
// Run launches opencode with args, inheriting this process's stdio and
// injecting the Ollama provider config through the OPENCODE_CONFIG_CONTENT
// environment variable when one can be resolved.
func (o *OpenCode) Run(model string, args []string) error {
	bin, ok := findOpenCode()
	if !ok {
		return fmt.Errorf("opencode is not installed, install from https://opencode.ai")
	}
	cmd := exec.Command(bin, args...)
	cmd.Stdin, cmd.Stdout, cmd.Stderr = os.Stdin, os.Stdout, os.Stderr
	env := os.Environ()
	if content := o.resolveContent(model); content != "" {
		env = append(env, "OPENCODE_CONFIG_CONTENT="+content)
	}
	cmd.Env = env
	return cmd.Run()
}
// resolveContent returns the inline config to send via OPENCODE_CONFIG_CONTENT.
// Content built by a prior Edit call wins; otherwise the config is rebuilt
// from model.json with the requested model as primary (e.g. a re-launch
// using saved state). An empty string means no config could be built.
func (o *OpenCode) resolveContent(model string) string {
	if o.configContent != "" {
		return o.configContent
	}
	models := readModelJSONModels()
	if !slices.Contains(models, model) {
		models = append([]string{model}, models...)
	}
	if content, err := buildInlineConfig(model, models); err == nil {
		return content
	}
	return ""
}
// Paths reports the config files Edit may modify: opencode's model.json
// state file, included only when it already exists on disk.
func (o *OpenCode) Paths() []string {
	statePath, err := openCodeStatePath()
	if err != nil {
		return nil
	}
	if _, err := os.Stat(statePath); err != nil {
		return nil
	}
	return []string{statePath}
}
// openCodeStatePath returns the path to opencode's model state file.
// TODO: this hardcodes the Linux/macOS XDG path. On Windows, opencode stores
// state under %LOCALAPPDATA% (or similar) — verify and branch on runtime.GOOS.
func openCodeStatePath() (string, error) {
home, err := os.UserHomeDir()
if err != nil {
return "", err
}
return filepath.Join(home, ".local", "state", "opencode", "model.json"), nil
}
// Edit builds the inline provider config for modelList (the first entry
// becomes the primary model) and merges the models into opencode's
// model.json "recent" list so they appear in its model picker. An empty
// modelList is a no-op.
func (o *OpenCode) Edit(modelList []string) error {
	if len(modelList) == 0 {
		return nil
	}
	content, err := buildInlineConfig(modelList[0], modelList)
	if err != nil {
		return err
	}
	o.configContent = content
	// Write model state file so models appear in OpenCode's model picker
	statePath, err := openCodeStatePath()
	if err != nil {
		return err
	}
	if err := os.MkdirAll(filepath.Dir(statePath), 0o755); err != nil {
		return err
	}
	// Start from a default state shape, then overlay whatever exists on disk.
	state := map[string]any{
		"recent": []any{},
		"favorite": []any{},
		"variant": map[string]any{},
	}
	if data, err := os.ReadFile(statePath); err == nil {
		_ = json.Unmarshal(data, &state) // Ignore parse errors; use defaults
	}
	recent, _ := state["recent"].([]any)
	// Set of models being (re-)added, used to deduplicate below.
	modelSet := make(map[string]bool)
	for _, m := range modelList {
		modelSet[m] = true
	}
	// Filter out existing Ollama models we're about to re-add
	newRecent := slices.DeleteFunc(slices.Clone(recent), func(entry any) bool {
		e, ok := entry.(map[string]any)
		if !ok || e["providerID"] != "ollama" {
			return false
		}
		modelID, _ := e["modelID"].(string)
		return modelSet[modelID]
	})
	// Prepend models in reverse order so first model ends up first
	for _, model := range slices.Backward(modelList) {
		newRecent = slices.Insert(newRecent, 0, any(map[string]any{
			"providerID": "ollama",
			"modelID": model,
		}))
	}
	// Cap the recent list so the picker doesn't grow without bound.
	const maxRecentModels = 10
	newRecent = newRecent[:min(len(newRecent), maxRecentModels)]
	state["recent"] = newRecent
	stateData, err := json.MarshalIndent(state, "", " ")
	if err != nil {
		return err
	}
	// WriteWithBackup preserves the previous file before overwriting.
	return fileutil.WriteWithBackup(statePath, stateData)
}
// Models returns nil: OpenCode does not expose a managed model list here;
// models are supplied inline via the launch-time config instead.
func (o *OpenCode) Models() []string {
	return nil
}
// buildInlineConfig produces the JSON string for OPENCODE_CONFIG_CONTENT.
// primary is the model to launch with, models is the full list of available
// models. Both are required; an error is returned when either is empty.
func buildInlineConfig(primary string, models []string) (string, error) {
	if primary == "" || len(models) == 0 {
		return "", fmt.Errorf("buildInlineConfig: primary and models are required")
	}
	ollamaProvider := map[string]any{
		"npm":  "@ai-sdk/openai-compatible",
		"name": "Ollama",
		"options": map[string]any{
			"baseURL": envconfig.Host().String() + "/v1",
		},
		"models": buildModelEntries(models),
	}
	cfg := map[string]any{
		"$schema":  "https://opencode.ai/config.json",
		"provider": map[string]any{"ollama": ollamaProvider},
		"model":    "ollama/" + primary,
	}
	data, err := json.Marshal(cfg)
	if err != nil {
		return "", err
	}
	return string(data), nil
}
// readModelJSONModels reads ollama model IDs from the opencode model.json
// state file, preserving the file's "recent" ordering. Missing or
// unparsable files yield nil.
func readModelJSONModels() []string {
	statePath, err := openCodeStatePath()
	if err != nil {
		return nil
	}
	data, err := os.ReadFile(statePath)
	if err != nil {
		return nil
	}
	var state map[string]any
	if json.Unmarshal(data, &state) != nil {
		return nil
	}
	recent, _ := state["recent"].([]any)
	var ids []string
	for _, raw := range recent {
		entry, ok := raw.(map[string]any)
		if !ok || entry["providerID"] != "ollama" {
			continue
		}
		if id, ok := entry["modelID"].(string); ok && id != "" {
			ids = append(ids, id)
		}
	}
	return ids
}
// buildModelEntries maps each model name to its OpenCode model entry,
// attaching context/output limits for known cloud models.
func buildModelEntries(modelList []string) map[string]any {
	entries := make(map[string]any, len(modelList))
	for _, name := range modelList {
		entry := map[string]any{"name": name}
		if isCloudModelName(name) {
			if limit, ok := lookupCloudModelLimit(name); ok {
				entry["limit"] = map[string]any{
					"context": limit.Context,
					"output":  limit.Output,
				}
			}
		}
		entries[name] = entry
	}
	return entries
}

View File

@@ -1,778 +0,0 @@
package launch
import (
"encoding/json"
"fmt"
"os"
"path/filepath"
"runtime"
"testing"
)
// TestOpenCodeIntegration verifies the OpenCode type's display name and
// that it satisfies the Runner and Editor interfaces.
func TestOpenCodeIntegration(t *testing.T) {
	oc := &OpenCode{}
	t.Run("String", func(t *testing.T) {
		const want = "OpenCode"
		if got := oc.String(); got != want {
			t.Errorf("String() = %q, want %q", got, want)
		}
	})
	t.Run("implements Runner", func(t *testing.T) {
		var _ Runner = oc
	})
	t.Run("implements Editor", func(t *testing.T) {
		var _ Editor = oc
	})
}
// TestOpenCodeEdit exercises Edit's inline-config generation: provider
// structure, default-model selection, cloud limits, and that no opencode
// config files are written. Previously several json.Unmarshal and Edit
// error returns were ignored, which could mask failures; they are now
// checked with t.Fatal so broken config JSON fails loudly.
func TestOpenCodeEdit(t *testing.T) {
	t.Run("builds config content with provider", func(t *testing.T) {
		setTestHome(t, t.TempDir())
		o := &OpenCode{}
		if err := o.Edit([]string{"llama3.2"}); err != nil {
			t.Fatal(err)
		}
		var cfg map[string]any
		if err := json.Unmarshal([]byte(o.configContent), &cfg); err != nil {
			t.Fatalf("configContent is not valid JSON: %v", err)
		}
		// Verify provider structure
		provider, _ := cfg["provider"].(map[string]any)
		ollama, _ := provider["ollama"].(map[string]any)
		if ollama["name"] != "Ollama" {
			t.Errorf("provider name = %v, want Ollama", ollama["name"])
		}
		if ollama["npm"] != "@ai-sdk/openai-compatible" {
			t.Errorf("npm = %v, want @ai-sdk/openai-compatible", ollama["npm"])
		}
		// Verify model exists
		models, _ := ollama["models"].(map[string]any)
		if models["llama3.2"] == nil {
			t.Error("model llama3.2 not found in config content")
		}
		// Verify default model
		if cfg["model"] != "ollama/llama3.2" {
			t.Errorf("model = %v, want ollama/llama3.2", cfg["model"])
		}
	})
	t.Run("multiple models", func(t *testing.T) {
		setTestHome(t, t.TempDir())
		o := &OpenCode{}
		if err := o.Edit([]string{"llama3.2", "qwen3:32b"}); err != nil {
			t.Fatal(err)
		}
		var cfg map[string]any
		if err := json.Unmarshal([]byte(o.configContent), &cfg); err != nil {
			t.Fatalf("configContent is not valid JSON: %v", err)
		}
		provider, _ := cfg["provider"].(map[string]any)
		ollama, _ := provider["ollama"].(map[string]any)
		models, _ := ollama["models"].(map[string]any)
		if models["llama3.2"] == nil {
			t.Error("model llama3.2 not found")
		}
		if models["qwen3:32b"] == nil {
			t.Error("model qwen3:32b not found")
		}
		// First model should be the default
		if cfg["model"] != "ollama/llama3.2" {
			t.Errorf("default model = %v, want ollama/llama3.2", cfg["model"])
		}
	})
	t.Run("empty models is no-op", func(t *testing.T) {
		setTestHome(t, t.TempDir())
		o := &OpenCode{}
		if err := o.Edit([]string{}); err != nil {
			t.Fatal(err)
		}
		if o.configContent != "" {
			t.Errorf("expected empty configContent for no models, got %s", o.configContent)
		}
	})
	t.Run("does not write config files", func(t *testing.T) {
		tmpDir := t.TempDir()
		setTestHome(t, tmpDir)
		o := &OpenCode{}
		if err := o.Edit([]string{"llama3.2"}); err != nil {
			t.Fatal(err)
		}
		configDir := filepath.Join(tmpDir, ".config", "opencode")
		if _, err := os.Stat(filepath.Join(configDir, "opencode.json")); !os.IsNotExist(err) {
			t.Error("opencode.json should not be created")
		}
		if _, err := os.Stat(filepath.Join(configDir, "opencode.jsonc")); !os.IsNotExist(err) {
			t.Error("opencode.jsonc should not be created")
		}
	})
	t.Run("cloud model has limits", func(t *testing.T) {
		setTestHome(t, t.TempDir())
		o := &OpenCode{}
		if err := o.Edit([]string{"glm-4.7:cloud"}); err != nil {
			t.Fatal(err)
		}
		var cfg map[string]any
		if err := json.Unmarshal([]byte(o.configContent), &cfg); err != nil {
			t.Fatalf("configContent is not valid JSON: %v", err)
		}
		provider, _ := cfg["provider"].(map[string]any)
		ollama, _ := provider["ollama"].(map[string]any)
		models, _ := ollama["models"].(map[string]any)
		entry, _ := models["glm-4.7:cloud"].(map[string]any)
		limit, ok := entry["limit"].(map[string]any)
		if !ok {
			t.Fatal("cloud model should have limit set")
		}
		expected := cloudModelLimits["glm-4.7"]
		if limit["context"] != float64(expected.Context) {
			t.Errorf("context = %v, want %d", limit["context"], expected.Context)
		}
		if limit["output"] != float64(expected.Output) {
			t.Errorf("output = %v, want %d", limit["output"], expected.Output)
		}
	})
	t.Run("local model has no limits", func(t *testing.T) {
		setTestHome(t, t.TempDir())
		o := &OpenCode{}
		if err := o.Edit([]string{"llama3.2"}); err != nil {
			t.Fatal(err)
		}
		var cfg map[string]any
		if err := json.Unmarshal([]byte(o.configContent), &cfg); err != nil {
			t.Fatalf("configContent is not valid JSON: %v", err)
		}
		provider, _ := cfg["provider"].(map[string]any)
		ollama, _ := provider["ollama"].(map[string]any)
		models, _ := ollama["models"].(map[string]any)
		entry, _ := models["llama3.2"].(map[string]any)
		if entry["limit"] != nil {
			t.Errorf("local model should not have limit, got %v", entry["limit"])
		}
	})
}
// TestOpenCodeModels_ReturnsNil confirms Models reports no managed models.
func TestOpenCodeModels_ReturnsNil(t *testing.T) {
	oc := &OpenCode{}
	got := oc.Models()
	if got != nil {
		t.Errorf("Models() = %v, want nil", got)
	}
}
// TestOpenCodePaths verifies Paths reports model.json only when the file
// exists. Fixture setup errors (MkdirAll/WriteFile) were previously
// ignored and are now checked so a broken fixture fails the test instead
// of producing a confusing assertion failure.
func TestOpenCodePaths(t *testing.T) {
	t.Run("returns nil when model.json does not exist", func(t *testing.T) {
		tmpDir := t.TempDir()
		setTestHome(t, tmpDir)
		o := &OpenCode{}
		if paths := o.Paths(); paths != nil {
			t.Errorf("Paths() = %v, want nil", paths)
		}
	})
	t.Run("returns model.json path when it exists", func(t *testing.T) {
		tmpDir := t.TempDir()
		setTestHome(t, tmpDir)
		stateDir := filepath.Join(tmpDir, ".local", "state", "opencode")
		if err := os.MkdirAll(stateDir, 0o755); err != nil {
			t.Fatal(err)
		}
		if err := os.WriteFile(filepath.Join(stateDir, "model.json"), []byte(`{}`), 0o644); err != nil {
			t.Fatal(err)
		}
		o := &OpenCode{}
		paths := o.Paths()
		if len(paths) != 1 {
			t.Fatalf("Paths() returned %d paths, want 1", len(paths))
		}
		if paths[0] != filepath.Join(stateDir, "model.json") {
			t.Errorf("Paths() = %v, want %v", paths[0], filepath.Join(stateDir, "model.json"))
		}
	})
}
// TestLookupCloudModelLimit is a table test for the cloud-limit lookup:
// only names with an explicit cloud source should match, and matches must
// carry the expected context/output token limits.
func TestLookupCloudModelLimit(t *testing.T) {
	tests := []struct {
		name string
		wantOK bool
		wantContext int
		wantOutput int
	}{
		{"glm-4.7", false, 0, 0},
		{"glm-4.7:cloud", true, 202_752, 131_072},
		{"glm-5:cloud", true, 202_752, 131_072},
		{"glm-5.1:cloud", true, 202_752, 131_072},
		{"gemma4:31b-cloud", true, 262_144, 131_072},
		{"gpt-oss:120b-cloud", true, 131_072, 131_072},
		{"gpt-oss:20b-cloud", true, 131_072, 131_072},
		{"kimi-k2.5", false, 0, 0},
		{"kimi-k2.5:cloud", true, 262_144, 262_144},
		{"deepseek-v3.2", false, 0, 0},
		{"deepseek-v3.2:cloud", true, 163_840, 65_536},
		{"qwen3.5", false, 0, 0},
		{"qwen3.5:cloud", true, 262_144, 32_768},
		{"qwen3-coder:480b", false, 0, 0},
		{"qwen3-coder:480b:cloud", true, 262_144, 65_536},
		{"qwen3-coder-next:cloud", true, 262_144, 32_768},
		{"llama3.2", false, 0, 0},
		{"unknown-model:cloud", false, 0, 0},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			l, ok := lookupCloudModelLimit(tt.name)
			if ok != tt.wantOK {
				t.Errorf("lookupCloudModelLimit(%q) ok = %v, want %v", tt.name, ok, tt.wantOK)
			}
			// Limits are only meaningful when the lookup succeeded.
			if ok {
				if l.Context != tt.wantContext {
					t.Errorf("context = %d, want %d", l.Context, tt.wantContext)
				}
				if l.Output != tt.wantOutput {
					t.Errorf("output = %d, want %d", l.Output, tt.wantOutput)
				}
			}
		})
	}
}
// TestFindOpenCode covers the ~/.opencode/bin fallback used when the
// binary is not on PATH. Fixture errors (MkdirAll/WriteFile) were
// previously ignored and are now checked.
func TestFindOpenCode(t *testing.T) {
	t.Run("fallback to ~/.opencode/bin", func(t *testing.T) {
		tmpDir := t.TempDir()
		setTestHome(t, tmpDir)
		// Ensure opencode is not on PATH
		t.Setenv("PATH", tmpDir)
		// Without the fallback binary, findOpenCode should fail
		if _, ok := findOpenCode(); ok {
			t.Fatal("findOpenCode should fail when binary is not on PATH or in fallback location")
		}
		// Create a fake binary at the curl install fallback location
		binDir := filepath.Join(tmpDir, ".opencode", "bin")
		if err := os.MkdirAll(binDir, 0o755); err != nil {
			t.Fatal(err)
		}
		name := "opencode"
		if runtime.GOOS == "windows" {
			name = "opencode.exe"
		}
		fakeBin := filepath.Join(binDir, name)
		if err := os.WriteFile(fakeBin, []byte("#!/bin/sh\n"), 0o755); err != nil {
			t.Fatal(err)
		}
		// Now findOpenCode should succeed via fallback
		path, ok := findOpenCode()
		if !ok {
			t.Fatal("findOpenCode should succeed with fallback binary")
		}
		if path != fakeBin {
			t.Errorf("findOpenCode = %q, want %q", path, fakeBin)
		}
	})
}
// Verify that the BackfillsCloudModelLimitOnExistingEntry test from the old
// file-based approach is covered by the new inline config approach.
// The json.Unmarshal error was previously ignored; it is now checked so
// invalid config JSON fails the test directly.
func TestOpenCodeEdit_CloudModelLimitStructure(t *testing.T) {
	o := &OpenCode{}
	tmpDir := t.TempDir()
	setTestHome(t, tmpDir)
	expected := cloudModelLimits["glm-4.7"]
	if err := o.Edit([]string{"glm-4.7:cloud"}); err != nil {
		t.Fatal(err)
	}
	var cfg map[string]any
	if err := json.Unmarshal([]byte(o.configContent), &cfg); err != nil {
		t.Fatalf("configContent is not valid JSON: %v", err)
	}
	provider, _ := cfg["provider"].(map[string]any)
	ollama, _ := provider["ollama"].(map[string]any)
	models, _ := ollama["models"].(map[string]any)
	entry, _ := models["glm-4.7:cloud"].(map[string]any)
	limit, ok := entry["limit"].(map[string]any)
	if !ok {
		t.Fatal("cloud model limit was not set")
	}
	if limit["context"] != float64(expected.Context) {
		t.Errorf("context = %v, want %d", limit["context"], expected.Context)
	}
	if limit["output"] != float64(expected.Output) {
		t.Errorf("output = %v, want %d", limit["output"], expected.Output)
	}
}
func TestOpenCodeEdit_SpecialCharsInModelName(t *testing.T) {
o := &OpenCode{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
specialModel := `model-with-"quotes"`
err := o.Edit([]string{specialModel})
if err != nil {
t.Fatalf("Edit with special chars failed: %v", err)
}
var cfg map[string]any
if err := json.Unmarshal([]byte(o.configContent), &cfg); err != nil {
t.Fatalf("resulting config is invalid JSON: %v", err)
}
provider, _ := cfg["provider"].(map[string]any)
ollama, _ := provider["ollama"].(map[string]any)
models, _ := ollama["models"].(map[string]any)
if models[specialModel] == nil {
t.Errorf("model with special chars not found in config")
}
}
// TestReadModelJSONModels covers parsing of opencode's model.json recent
// list: ordering, provider filtering, and resilience to missing or
// corrupt state files.
func TestReadModelJSONModels(t *testing.T) {
	t.Run("reads ollama models from model.json", func(t *testing.T) {
		tmpDir := t.TempDir()
		setTestHome(t, tmpDir)
		stateDir := filepath.Join(tmpDir, ".local", "state", "opencode")
		os.MkdirAll(stateDir, 0o755)
		state := map[string]any{
			"recent": []any{
				map[string]any{"providerID": "ollama", "modelID": "llama3.2"},
				map[string]any{"providerID": "ollama", "modelID": "qwen3:32b"},
			},
		}
		data, _ := json.MarshalIndent(state, "", " ")
		os.WriteFile(filepath.Join(stateDir, "model.json"), data, 0o644)
		models := readModelJSONModels()
		if len(models) != 2 {
			t.Fatalf("got %d models, want 2", len(models))
		}
		// Order must follow the file's recent list.
		if models[0] != "llama3.2" || models[1] != "qwen3:32b" {
			t.Errorf("got %v, want [llama3.2 qwen3:32b]", models)
		}
	})
	t.Run("skips non-ollama providers", func(t *testing.T) {
		tmpDir := t.TempDir()
		setTestHome(t, tmpDir)
		stateDir := filepath.Join(tmpDir, ".local", "state", "opencode")
		os.MkdirAll(stateDir, 0o755)
		state := map[string]any{
			"recent": []any{
				map[string]any{"providerID": "openai", "modelID": "gpt-4"},
				map[string]any{"providerID": "ollama", "modelID": "llama3.2"},
			},
		}
		data, _ := json.MarshalIndent(state, "", " ")
		os.WriteFile(filepath.Join(stateDir, "model.json"), data, 0o644)
		models := readModelJSONModels()
		if len(models) != 1 || models[0] != "llama3.2" {
			t.Errorf("got %v, want [llama3.2]", models)
		}
	})
	t.Run("returns nil when file does not exist", func(t *testing.T) {
		tmpDir := t.TempDir()
		setTestHome(t, tmpDir)
		if models := readModelJSONModels(); models != nil {
			t.Errorf("got %v, want nil", models)
		}
	})
	t.Run("returns nil for corrupt JSON", func(t *testing.T) {
		tmpDir := t.TempDir()
		setTestHome(t, tmpDir)
		stateDir := filepath.Join(tmpDir, ".local", "state", "opencode")
		os.MkdirAll(stateDir, 0o755)
		os.WriteFile(filepath.Join(stateDir, "model.json"), []byte(`{corrupt`), 0o644)
		if models := readModelJSONModels(); models != nil {
			t.Errorf("got %v, want nil", models)
		}
	})
}
// TestOpenCodeResolveContent covers the precedence rules for the inline
// config: content built by Edit always wins, otherwise model.json is read,
// and the requested model is made primary (and injected when missing).
func TestOpenCodeResolveContent(t *testing.T) {
	t.Run("returns Edit's content when set", func(t *testing.T) {
		tmpDir := t.TempDir()
		setTestHome(t, tmpDir)
		o := &OpenCode{}
		if err := o.Edit([]string{"gemma4"}); err != nil {
			t.Fatal(err)
		}
		editContent := o.configContent
		// Write a different model.json — should be ignored
		// (the Edit call above already created stateDir).
		stateDir := filepath.Join(tmpDir, ".local", "state", "opencode")
		state := map[string]any{
			"recent": []any{
				map[string]any{"providerID": "ollama", "modelID": "different-model"},
			},
		}
		data, _ := json.MarshalIndent(state, "", " ")
		os.WriteFile(filepath.Join(stateDir, "model.json"), data, 0o644)
		got := o.resolveContent("gemma4")
		if got != editContent {
			t.Errorf("resolveContent returned different content than Edit set\ngot: %s\nwant: %s", got, editContent)
		}
	})
	t.Run("falls back to model.json when Edit was not called", func(t *testing.T) {
		tmpDir := t.TempDir()
		setTestHome(t, tmpDir)
		stateDir := filepath.Join(tmpDir, ".local", "state", "opencode")
		os.MkdirAll(stateDir, 0o755)
		state := map[string]any{
			"recent": []any{
				map[string]any{"providerID": "ollama", "modelID": "llama3.2"},
				map[string]any{"providerID": "ollama", "modelID": "qwen3:32b"},
			},
		}
		data, _ := json.MarshalIndent(state, "", " ")
		os.WriteFile(filepath.Join(stateDir, "model.json"), data, 0o644)
		o := &OpenCode{}
		content := o.resolveContent("llama3.2")
		if content == "" {
			t.Fatal("resolveContent returned empty")
		}
		var cfg map[string]any
		json.Unmarshal([]byte(content), &cfg)
		if cfg["model"] != "ollama/llama3.2" {
			t.Errorf("primary = %v, want ollama/llama3.2", cfg["model"])
		}
		provider, _ := cfg["provider"].(map[string]any)
		ollama, _ := provider["ollama"].(map[string]any)
		cfgModels, _ := ollama["models"].(map[string]any)
		if cfgModels["llama3.2"] == nil || cfgModels["qwen3:32b"] == nil {
			t.Errorf("expected both models in config, got %v", cfgModels)
		}
	})
	t.Run("uses requested model as primary even when not first in model.json", func(t *testing.T) {
		tmpDir := t.TempDir()
		setTestHome(t, tmpDir)
		stateDir := filepath.Join(tmpDir, ".local", "state", "opencode")
		os.MkdirAll(stateDir, 0o755)
		state := map[string]any{
			"recent": []any{
				map[string]any{"providerID": "ollama", "modelID": "llama3.2"},
				map[string]any{"providerID": "ollama", "modelID": "qwen3:32b"},
			},
		}
		data, _ := json.MarshalIndent(state, "", " ")
		os.WriteFile(filepath.Join(stateDir, "model.json"), data, 0o644)
		o := &OpenCode{}
		content := o.resolveContent("qwen3:32b")
		var cfg map[string]any
		json.Unmarshal([]byte(content), &cfg)
		if cfg["model"] != "ollama/qwen3:32b" {
			t.Errorf("primary = %v, want ollama/qwen3:32b", cfg["model"])
		}
	})
	t.Run("injects requested model when missing from model.json", func(t *testing.T) {
		tmpDir := t.TempDir()
		setTestHome(t, tmpDir)
		stateDir := filepath.Join(tmpDir, ".local", "state", "opencode")
		os.MkdirAll(stateDir, 0o755)
		state := map[string]any{
			"recent": []any{
				map[string]any{"providerID": "ollama", "modelID": "llama3.2"},
			},
		}
		data, _ := json.MarshalIndent(state, "", " ")
		os.WriteFile(filepath.Join(stateDir, "model.json"), data, 0o644)
		o := &OpenCode{}
		content := o.resolveContent("gemma4")
		var cfg map[string]any
		json.Unmarshal([]byte(content), &cfg)
		provider, _ := cfg["provider"].(map[string]any)
		ollama, _ := provider["ollama"].(map[string]any)
		cfgModels, _ := ollama["models"].(map[string]any)
		if cfgModels["gemma4"] == nil {
			t.Error("requested model gemma4 not injected into config")
		}
		if cfg["model"] != "ollama/gemma4" {
			t.Errorf("primary = %v, want ollama/gemma4", cfg["model"])
		}
	})
	t.Run("returns empty when no model.json and no model param", func(t *testing.T) {
		tmpDir := t.TempDir()
		setTestHome(t, tmpDir)
		o := &OpenCode{}
		if got := o.resolveContent(""); got != "" {
			t.Errorf("resolveContent(\"\") = %q, want empty", got)
		}
	})
	t.Run("does not mutate configContent on fallback", func(t *testing.T) {
		tmpDir := t.TempDir()
		setTestHome(t, tmpDir)
		stateDir := filepath.Join(tmpDir, ".local", "state", "opencode")
		os.MkdirAll(stateDir, 0o755)
		state := map[string]any{
			"recent": []any{
				map[string]any{"providerID": "ollama", "modelID": "llama3.2"},
			},
		}
		data, _ := json.MarshalIndent(state, "", " ")
		os.WriteFile(filepath.Join(stateDir, "model.json"), data, 0o644)
		o := &OpenCode{}
		_ = o.resolveContent("llama3.2")
		if o.configContent != "" {
			t.Errorf("resolveContent should not mutate configContent, got %q", o.configContent)
		}
	})
}
// TestBuildInlineConfig covers argument validation and primary-model
// selection. The previously ignored json.Unmarshal error in the last
// subtest is now checked so invalid output JSON fails explicitly.
func TestBuildInlineConfig(t *testing.T) {
	t.Run("returns error for empty primary", func(t *testing.T) {
		if _, err := buildInlineConfig("", []string{"llama3.2"}); err == nil {
			t.Error("expected error for empty primary")
		}
	})
	t.Run("returns error for empty models", func(t *testing.T) {
		if _, err := buildInlineConfig("llama3.2", nil); err == nil {
			t.Error("expected error for empty models")
		}
	})
	t.Run("primary differs from first model in list", func(t *testing.T) {
		content, err := buildInlineConfig("qwen3:32b", []string{"llama3.2", "qwen3:32b"})
		if err != nil {
			t.Fatal(err)
		}
		var cfg map[string]any
		if err := json.Unmarshal([]byte(content), &cfg); err != nil {
			t.Fatalf("config is not valid JSON: %v", err)
		}
		if cfg["model"] != "ollama/qwen3:32b" {
			t.Errorf("primary = %v, want ollama/qwen3:32b", cfg["model"])
		}
	})
}
// TestOpenCodeEdit_PreservesRecentEntries verifies that OpenCode.Edit merges
// newly selected models into OpenCode's recent-model state file
// (~/.local/state/opencode/model.json): new entries are prepended in order,
// existing (including non-ollama) entries are preserved, duplicates are
// collapsed, and the list is capped at 10 entries.
func TestOpenCodeEdit_PreservesRecentEntries(t *testing.T) {
	t.Run("prepends new models to existing recent", func(t *testing.T) {
		tmpDir := t.TempDir()
		setTestHome(t, tmpDir)
		stateDir := filepath.Join(tmpDir, ".local", "state", "opencode")
		os.MkdirAll(stateDir, 0o755)
		// Seed the state file with two pre-existing ollama entries.
		initial := map[string]any{
			"recent": []any{
				map[string]any{"providerID": "ollama", "modelID": "old-A"},
				map[string]any{"providerID": "ollama", "modelID": "old-B"},
			},
		}
		data, _ := json.MarshalIndent(initial, "", " ")
		os.WriteFile(filepath.Join(stateDir, "model.json"), data, 0o644)
		o := &OpenCode{}
		if err := o.Edit([]string{"new-X"}); err != nil {
			t.Fatal(err)
		}
		stored, _ := os.ReadFile(filepath.Join(stateDir, "model.json"))
		var state map[string]any
		json.Unmarshal(stored, &state)
		recent, _ := state["recent"].([]any)
		if len(recent) != 3 {
			t.Fatalf("expected 3 entries, got %d", len(recent))
		}
		// The newly added model must be first (most recent).
		first, _ := recent[0].(map[string]any)
		if first["modelID"] != "new-X" {
			t.Errorf("first entry = %v, want new-X", first["modelID"])
		}
	})
	t.Run("prepends multiple new models in order", func(t *testing.T) {
		tmpDir := t.TempDir()
		setTestHome(t, tmpDir)
		stateDir := filepath.Join(tmpDir, ".local", "state", "opencode")
		os.MkdirAll(stateDir, 0o755)
		initial := map[string]any{
			"recent": []any{
				map[string]any{"providerID": "ollama", "modelID": "old-A"},
				map[string]any{"providerID": "ollama", "modelID": "old-B"},
			},
		}
		data, _ := json.MarshalIndent(initial, "", " ")
		os.WriteFile(filepath.Join(stateDir, "model.json"), data, 0o644)
		o := &OpenCode{}
		if err := o.Edit([]string{"X", "Y", "Z"}); err != nil {
			t.Fatal(err)
		}
		stored, _ := os.ReadFile(filepath.Join(stateDir, "model.json"))
		var state map[string]any
		json.Unmarshal(stored, &state)
		recent, _ := state["recent"].([]any)
		// New models keep their selection order, ahead of the old entries.
		want := []string{"X", "Y", "Z", "old-A", "old-B"}
		if len(recent) != len(want) {
			t.Fatalf("expected %d entries, got %d", len(want), len(recent))
		}
		for i, w := range want {
			e, _ := recent[i].(map[string]any)
			if e["modelID"] != w {
				t.Errorf("recent[%d] = %v, want %v", i, e["modelID"], w)
			}
		}
	})
	t.Run("preserves non-ollama entries", func(t *testing.T) {
		tmpDir := t.TempDir()
		setTestHome(t, tmpDir)
		stateDir := filepath.Join(tmpDir, ".local", "state", "opencode")
		os.MkdirAll(stateDir, 0o755)
		initial := map[string]any{
			"recent": []any{
				map[string]any{"providerID": "openai", "modelID": "gpt-4"},
				map[string]any{"providerID": "ollama", "modelID": "llama3.2"},
			},
		}
		data, _ := json.MarshalIndent(initial, "", " ")
		os.WriteFile(filepath.Join(stateDir, "model.json"), data, 0o644)
		o := &OpenCode{}
		if err := o.Edit([]string{"qwen3:32b"}); err != nil {
			t.Fatal(err)
		}
		stored, _ := os.ReadFile(filepath.Join(stateDir, "model.json"))
		var state map[string]any
		json.Unmarshal(stored, &state)
		recent, _ := state["recent"].([]any)
		// Should have: qwen3:32b (new), gpt-4 (preserved openai), llama3.2 (preserved ollama)
		var foundOpenAI bool
		for _, entry := range recent {
			e, _ := entry.(map[string]any)
			if e["providerID"] == "openai" && e["modelID"] == "gpt-4" {
				foundOpenAI = true
			}
		}
		if !foundOpenAI {
			t.Errorf("non-ollama gpt-4 entry was not preserved, got %v", recent)
		}
	})
	t.Run("deduplicates ollama models being re-added", func(t *testing.T) {
		tmpDir := t.TempDir()
		setTestHome(t, tmpDir)
		stateDir := filepath.Join(tmpDir, ".local", "state", "opencode")
		os.MkdirAll(stateDir, 0o755)
		initial := map[string]any{
			"recent": []any{
				map[string]any{"providerID": "ollama", "modelID": "llama3.2"},
			},
		}
		data, _ := json.MarshalIndent(initial, "", " ")
		os.WriteFile(filepath.Join(stateDir, "model.json"), data, 0o644)
		o := &OpenCode{}
		// Re-adding a model already in recent must not duplicate it.
		if err := o.Edit([]string{"llama3.2"}); err != nil {
			t.Fatal(err)
		}
		stored, _ := os.ReadFile(filepath.Join(stateDir, "model.json"))
		var state map[string]any
		json.Unmarshal(stored, &state)
		recent, _ := state["recent"].([]any)
		count := 0
		for _, entry := range recent {
			e, _ := entry.(map[string]any)
			if e["modelID"] == "llama3.2" {
				count++
			}
		}
		if count != 1 {
			t.Errorf("expected 1 llama3.2 entry, got %d", count)
		}
	})
	t.Run("caps recent list at 10", func(t *testing.T) {
		tmpDir := t.TempDir()
		setTestHome(t, tmpDir)
		stateDir := filepath.Join(tmpDir, ".local", "state", "opencode")
		os.MkdirAll(stateDir, 0o755)
		// Pre-populate with 9 distinct ollama models
		recentEntries := make([]any, 0, 9)
		for i := range 9 {
			recentEntries = append(recentEntries, map[string]any{
				"providerID": "ollama",
				"modelID":    fmt.Sprintf("old-%d", i),
			})
		}
		initial := map[string]any{"recent": recentEntries}
		data, _ := json.MarshalIndent(initial, "", " ")
		os.WriteFile(filepath.Join(stateDir, "model.json"), data, 0o644)
		// Add 5 new models — should cap at 10 total
		o := &OpenCode{}
		if err := o.Edit([]string{"new-0", "new-1", "new-2", "new-3", "new-4"}); err != nil {
			t.Fatal(err)
		}
		stored, _ := os.ReadFile(filepath.Join(stateDir, "model.json"))
		var state map[string]any
		json.Unmarshal(stored, &state)
		recent, _ := state["recent"].([]any)
		if len(recent) != 10 {
			t.Errorf("expected 10 entries (capped), got %d", len(recent))
		}
	})
}
// TestOpenCodeEdit_BaseURL verifies that OpenCode.Edit writes the ollama
// provider's baseURL option into the generated config when running with the
// default OLLAMA_HOST.
//
// Fix: the original ignored the errors from Edit and json.Unmarshal, so a
// failing Edit (or malformed config) surfaced only as a confusing
// "baseURL should be set" failure instead of the root cause.
func TestOpenCodeEdit_BaseURL(t *testing.T) {
	o := &OpenCode{}
	tmpDir := t.TempDir()
	setTestHome(t, tmpDir)
	// Default OLLAMA_HOST
	if err := o.Edit([]string{"llama3.2"}); err != nil {
		t.Fatal(err)
	}
	var cfg map[string]any
	if err := json.Unmarshal([]byte(o.configContent), &cfg); err != nil {
		t.Fatalf("config is not valid JSON: %v", err)
	}
	provider, _ := cfg["provider"].(map[string]any)
	ollama, _ := provider["ollama"].(map[string]any)
	options, _ := ollama["options"].(map[string]any)
	baseURL, _ := options["baseURL"].(string)
	if baseURL == "" {
		t.Error("baseURL should be set")
	}
}

View File

@@ -1,396 +0,0 @@
package launch
import (
"context"
"encoding/json"
"fmt"
"net/http"
"os"
"os/exec"
"path/filepath"
"slices"
"strings"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/cmd/internal/fileutil"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/types/model"
)
// Pi implements Runner and Editor for Pi (Pi Coding Agent) integration
type Pi struct{}

const (
	// piNpmPackage is the npm package that provides the pi binary.
	piNpmPackage = "@mariozechner/pi-coding-agent"
	// piWebSearchSource is the package source passed to `pi install`/`pi update`.
	piWebSearchSource = "npm:@ollama/pi-web-search"
	// piWebSearchPkg is the short package name used in user-facing messages.
	piWebSearchPkg = "@ollama/pi-web-search"
)

// String returns the integration's display name.
func (p *Pi) String() string { return "Pi" }
// Run launches the Pi coding agent with the given args, first making sure
// npm and the pi binary are installed and the web-search package is set up.
// The model parameter is unused here; model configuration happens in Edit.
func (p *Pi) Run(model string, args []string) error {
	fmt.Fprintf(os.Stderr, "\n%sPreparing Pi...%s\n", ansiGray, ansiReset)
	if err := ensureNpmInstalled(); err != nil {
		return err
	}
	fmt.Fprintf(os.Stderr, "%sChecking Pi installation...%s\n", ansiGray, ansiReset)
	bin, err := ensurePiInstalled()
	if err != nil {
		return err
	}
	// Best-effort: web-search install/update failures only warn, never abort.
	ensurePiWebSearchPackage(bin)
	fmt.Fprintf(os.Stderr, "\n%sLaunching Pi...%s\n\n", ansiGray, ansiReset)
	// Run pi interactively, wired to this process's stdio.
	cmd := exec.Command(bin, args...)
	cmd.Stdin = os.Stdin
	cmd.Stdout = os.Stdout
	cmd.Stderr = os.Stderr
	return cmd.Run()
}
// ensureNpmInstalled verifies that the npm binary is available on PATH and,
// when it is not, returns an error with installation guidance.
func ensureNpmInstalled() error {
	_, lookErr := exec.LookPath("npm")
	if lookErr == nil {
		return nil
	}
	return fmt.Errorf("npm (Node.js) is required to launch pi\n\nInstall it first:\n https://nodejs.org/\n\nThen re-run:\n ollama launch pi")
}
// ensurePiInstalled returns the name of a usable pi binary, offering to
// install it globally via npm when it is missing. The returned string is
// always "pi" (resolved through PATH); an error means pi is unavailable or
// the user declined installation.
func ensurePiInstalled() (string, error) {
	// Already on PATH: nothing to do.
	if _, err := exec.LookPath("pi"); err == nil {
		return "pi", nil
	}
	// pi is missing and npm is also missing: cannot auto-install.
	if _, err := exec.LookPath("npm"); err != nil {
		return "", fmt.Errorf("pi is not installed and required dependencies are missing\n\nInstall the following first:\n npm (Node.js): https://nodejs.org/\n\nThen re-run:\n ollama launch pi")
	}
	// Ask the user before installing anything globally.
	ok, err := ConfirmPrompt("Pi is not installed. Install with npm?")
	if err != nil {
		return "", err
	}
	if !ok {
		return "", fmt.Errorf("pi installation cancelled")
	}
	fmt.Fprintf(os.Stderr, "\nInstalling Pi...\n")
	cmd := exec.Command("npm", "install", "-g", piNpmPackage+"@latest")
	cmd.Stdout = os.Stdout
	cmd.Stderr = os.Stderr
	if err := cmd.Run(); err != nil {
		return "", fmt.Errorf("failed to install pi: %w", err)
	}
	// npm's global bin dir may not be on PATH in the current shell session.
	if _, err := exec.LookPath("pi"); err != nil {
		return "", fmt.Errorf("pi was installed but the binary was not found on PATH\n\nYou may need to restart your shell")
	}
	fmt.Fprintf(os.Stderr, "%sPi installed successfully%s\n\n", ansiGreen, ansiReset)
	return "pi", nil
}
// ensurePiWebSearchPackage installs the ollama web-search pi package if it is
// missing, or updates it if already present. All failures are reported as
// warnings on stderr; this function never blocks the launch.
func ensurePiWebSearchPackage(bin string) {
	// Skip entirely when cloud features are known to be disabled.
	if !shouldManagePiWebSearch() {
		fmt.Fprintf(os.Stderr, "%sCloud is disabled; skipping %s setup.%s\n", ansiGray, piWebSearchPkg, ansiReset)
		return
	}
	fmt.Fprintf(os.Stderr, "%sChecking Pi web search package...%s\n", ansiGray, ansiReset)
	installed, err := piPackageInstalled(bin, piWebSearchSource)
	if err != nil {
		fmt.Fprintf(os.Stderr, "%s Warning: could not check %s installation: %v%s\n", ansiYellow, piWebSearchPkg, err, ansiReset)
		return
	}
	if !installed {
		// Not installed yet: install and return.
		fmt.Fprintf(os.Stderr, "%sInstalling %s...%s\n", ansiGray, piWebSearchPkg, ansiReset)
		cmd := exec.Command(bin, "install", piWebSearchSource)
		cmd.Stdout = os.Stdout
		cmd.Stderr = os.Stderr
		if err := cmd.Run(); err != nil {
			fmt.Fprintf(os.Stderr, "%s Warning: could not install %s: %v%s\n", ansiYellow, piWebSearchPkg, err, ansiReset)
			return
		}
		fmt.Fprintf(os.Stderr, "%s ✓ Installed %s%s\n", ansiGreen, piWebSearchPkg, ansiReset)
		return
	}
	// Already installed: keep it up to date.
	fmt.Fprintf(os.Stderr, "%sUpdating %s...%s\n", ansiGray, piWebSearchPkg, ansiReset)
	cmd := exec.Command(bin, "update", piWebSearchSource)
	cmd.Stdout = os.Stdout
	cmd.Stderr = os.Stderr
	if err := cmd.Run(); err != nil {
		fmt.Fprintf(os.Stderr, "%s Warning: could not update %s: %v%s\n", ansiYellow, piWebSearchPkg, err, ansiReset)
		return
	}
	fmt.Fprintf(os.Stderr, "%s ✓ Updated %s%s\n", ansiGreen, piWebSearchPkg, ansiReset)
}
// shouldManagePiWebSearch reports whether the web-search package should be
// installed/updated. It only returns false when the server affirmatively
// reports that cloud is disabled; any uncertainty defaults to true.
func shouldManagePiWebSearch() bool {
	client, err := api.ClientFromEnvironment()
	if err != nil {
		// Cannot reach a client config; default to managing the package.
		return true
	}
	if disabled, known := cloudStatusDisabled(context.Background(), client); known && disabled {
		return false
	}
	return true
}
// piPackageInstalled reports whether the given package source appears in the
// output of `<bin> list`. Command failures are returned with any trimmed
// output appended for context.
func piPackageInstalled(bin, source string) (bool, error) {
	out, err := exec.Command(bin, "list").CombinedOutput()
	if err != nil {
		if msg := strings.TrimSpace(string(out)); msg != "" {
			return false, fmt.Errorf("%w: %s", err, msg)
		}
		return false, err
	}
	for _, raw := range strings.Split(string(out), "\n") {
		if strings.HasPrefix(strings.TrimSpace(raw), source) {
			return true, nil
		}
	}
	return false, nil
}
// Paths returns Pi's config files (models.json, settings.json) that
// currently exist on disk, or nil when the home directory is unavailable.
func (p *Pi) Paths() []string {
	home, err := os.UserHomeDir()
	if err != nil {
		return nil
	}
	var existing []string
	// Order matters: models.json first, then settings.json.
	for _, name := range []string{"models.json", "settings.json"} {
		candidate := filepath.Join(home, ".pi", "agent", name)
		if _, statErr := os.Stat(candidate); statErr == nil {
			existing = append(existing, candidate)
		}
	}
	return existing
}
// Edit writes the selected models into Pi's models.json under the "ollama"
// provider and records the first model as the default in settings.json.
// Entries written by ollama launch carry a "_launch": true marker so they can
// be managed on later runs without disturbing user-authored entries.
func (p *Pi) Edit(models []string) error {
	if len(models) == 0 {
		return nil
	}
	home, err := os.UserHomeDir()
	if err != nil {
		return err
	}
	configPath := filepath.Join(home, ".pi", "agent", "models.json")
	if err := os.MkdirAll(filepath.Dir(configPath), 0o755); err != nil {
		return err
	}
	// Best-effort read of the existing config: a missing or unparseable file
	// simply yields an empty config.
	config := make(map[string]any)
	if data, err := os.ReadFile(configPath); err == nil {
		_ = json.Unmarshal(data, &config)
	}
	providers, ok := config["providers"].(map[string]any)
	if !ok {
		providers = make(map[string]any)
	}
	// Create the ollama provider entry on first use.
	ollama, ok := providers["ollama"].(map[string]any)
	if !ok {
		ollama = map[string]any{
			"baseUrl": envconfig.Host().String() + "/v1",
			"api":     "openai-completions",
			"apiKey":  "ollama",
		}
	}
	existingModels, ok := ollama["models"].([]any)
	if !ok {
		existingModels = make([]any, 0)
	}
	// Build set of selected models to track which need to be added
	selectedSet := make(map[string]bool, len(models))
	for _, m := range models {
		selectedSet[m] = true
	}
	// Build new models list:
	// 1. Keep user-managed models (no _launch marker) - untouched
	// 2. Keep ollama-managed models (_launch marker) that are still selected,
	//    except stale cloud entries that should be rebuilt below
	// 3. Add new ollama-managed models
	var newModels []any
	for _, m := range existingModels {
		if modelObj, ok := m.(map[string]any); ok {
			if id, ok := modelObj["id"].(string); ok {
				// User-managed model (no _launch marker) - always preserve
				if !isPiOllamaModel(modelObj) {
					newModels = append(newModels, m)
				} else if selectedSet[id] {
					// Rebuild stale managed cloud entries so createConfig refreshes
					// the whole entry instead of patching it in place.
					if !hasContextWindow(modelObj) {
						if _, ok := lookupCloudModelLimit(id); ok {
							continue
						}
					}
					newModels = append(newModels, m)
					// Entry kept as-is; clear the flag so it is not re-added below.
					selectedSet[id] = false
				}
			}
		}
	}
	// Add newly selected models that weren't already in the list
	client := api.NewClient(envconfig.Host(), http.DefaultClient)
	ctx := context.Background()
	for _, model := range models {
		if selectedSet[model] {
			newModels = append(newModels, createConfig(ctx, client, model))
		}
	}
	ollama["models"] = newModels
	providers["ollama"] = ollama
	config["providers"] = providers
	configData, err := json.MarshalIndent(config, "", " ")
	if err != nil {
		return err
	}
	if err := fileutil.WriteWithBackup(configPath, configData); err != nil {
		return err
	}
	// Update settings.json with default provider and model
	settingsPath := filepath.Join(home, ".pi", "agent", "settings.json")
	settings := make(map[string]any)
	if data, err := os.ReadFile(settingsPath); err == nil {
		_ = json.Unmarshal(data, &settings)
	}
	settings["defaultProvider"] = "ollama"
	settings["defaultModel"] = models[0]
	settingsData, err := json.MarshalIndent(settings, "", " ")
	if err != nil {
		return err
	}
	return fileutil.WriteWithBackup(settingsPath, settingsData)
}
// Models lists the model IDs configured under the ollama provider in Pi's
// models.json, sorted alphabetically. It returns nil when the file is
// missing or unreadable.
func (p *Pi) Models() []string {
	home, err := os.UserHomeDir()
	if err != nil {
		return nil
	}
	cfg, err := fileutil.ReadJSON(filepath.Join(home, ".pi", "agent", "models.json"))
	if err != nil {
		return nil
	}
	providers, _ := cfg["providers"].(map[string]any)
	ollama, _ := providers["ollama"].(map[string]any)
	entries, _ := ollama["models"].([]any)
	var ids []string
	for _, entry := range entries {
		obj, ok := entry.(map[string]any)
		if !ok {
			continue
		}
		if id, ok := obj["id"].(string); ok {
			ids = append(ids, id)
		}
	}
	slices.Sort(ids)
	return ids
}
// isPiOllamaModel reports whether a model config entry is managed by ollama launch
// (i.e. carries a "_launch": true marker).
func isPiOllamaModel(cfg map[string]any) bool {
	managed, ok := cfg["_launch"].(bool)
	return ok && managed
}
// hasContextWindow reports whether cfg carries a positive contextWindow value
// in any of the numeric representations callers may produce (float64 from
// JSON decoding, or int/int64 set programmatically).
func hasContextWindow(cfg map[string]any) bool {
	if f, ok := cfg["contextWindow"].(float64); ok {
		return f > 0
	}
	if i, ok := cfg["contextWindow"].(int); ok {
		return i > 0
	}
	if i, ok := cfg["contextWindow"].(int64); ok {
		return i > 0
	}
	return false
}
// createConfig builds Pi model config with capability detection
// via the server's Show API. The entry always carries the "_launch" marker
// so Edit can identify it as managed. For known cloud models the shared
// context limit is pre-filled and re-applied whenever the server does not
// supply a positive context length.
func createConfig(ctx context.Context, client *api.Client, modelID string) map[string]any {
	cfg := map[string]any{
		"id":      modelID,
		"_launch": true,
	}
	// Pre-fill the context window for known cloud models so the entry is
	// usable even if Show fails below.
	if l, ok := lookupCloudModelLimit(modelID); ok {
		cfg["contextWindow"] = l.Context
	}
	applyCloudContextFallback := func() {
		if l, ok := lookupCloudModelLimit(modelID); ok {
			cfg["contextWindow"] = l.Context
		}
	}
	resp, err := client.Show(ctx, &api.ShowRequest{Model: modelID})
	if err != nil {
		// Server unavailable or model unknown: best-effort cloud fallback.
		applyCloudContextFallback()
		return cfg
	}
	// Set input types based on vision capability
	if slices.Contains(resp.Capabilities, model.CapabilityVision) {
		cfg["input"] = []string{"text", "image"}
	} else {
		cfg["input"] = []string{"text"}
	}
	// Set reasoning based on thinking capability
	if slices.Contains(resp.Capabilities, model.CapabilityThinking) {
		cfg["reasoning"] = true
	}
	// Extract context window from ModelInfo. For known cloud models, the
	// pre-filled shared limit remains unless the server provides a positive value.
	hasContextWindow := false
	for key, val := range resp.ModelInfo {
		if strings.HasSuffix(key, ".context_length") {
			if ctxLen, ok := val.(float64); ok && ctxLen > 0 {
				cfg["contextWindow"] = int(ctxLen)
				hasContextWindow = true
			}
			// Only the first *.context_length key is considered.
			break
		}
	}
	if !hasContextWindow {
		applyCloudContextFallback()
	}
	return cfg
}

View File

@@ -1,417 +0,0 @@
package launch
import (
"fmt"
"os"
"os/exec"
"slices"
"strings"
)
// IntegrationInstallSpec describes how launcher should detect and guide installation.
type IntegrationInstallSpec struct {
	// CheckInstalled reports whether the integration is present; a nil
	// check is treated as "installed" by integrationFor.
	CheckInstalled func() bool
	// EnsureInstalled performs an automatic install when set.
	EnsureInstalled func() error
	// URL points at manual install instructions.
	URL string
	// Command is a shell command suggested for manual installation.
	Command []string
}
// IntegrationSpec is the canonical registry entry for one integration.
type IntegrationSpec struct {
	// Name is the canonical lowercase integration name.
	Name string
	// Runner launches and (when it also implements Editor) configures the tool.
	Runner Runner
	// Aliases are alternative lookup names; they must not collide.
	Aliases []string
	// Hidden excludes the integration from interactive lists.
	Hidden bool
	// Description is the one-line blurb shown in launcher UIs.
	Description string
	// Install describes detection and installation guidance.
	Install IntegrationInstallSpec
}
// IntegrationInfo contains display information about a registered integration.
type IntegrationInfo struct {
	// Name is the canonical registry name.
	Name string
	// DisplayName is the runner's human-readable name.
	DisplayName string
	// Description is the one-line blurb shown in launcher UIs.
	Description string
}
// launcherIntegrationOrder fixes the display order of visible integrations in
// launcher UIs; rebuildIntegrationSpecIndexes validates every entry at init.
var launcherIntegrationOrder = []string{"openclaw", "claude", "opencode", "hermes", "codex", "copilot", "droid", "pi"}

// integrationSpecs is the canonical registry of all known integrations.
// Hidden entries can still be launched by name but are excluded from
// interactive selection lists.
var integrationSpecs = []*IntegrationSpec{
	{
		Name:        "claude",
		Runner:      &Claude{},
		Description: "Anthropic's coding tool with subagents",
		Install: IntegrationInstallSpec{
			CheckInstalled: func() bool {
				_, err := (&Claude{}).findPath()
				return err == nil
			},
			URL: "https://code.claude.com/docs/en/quickstart",
		},
	},
	{
		Name:        "cline",
		Runner:      &Cline{},
		Description: "Autonomous coding agent with parallel execution",
		Hidden:      true,
		Install: IntegrationInstallSpec{
			CheckInstalled: func() bool {
				_, err := exec.LookPath("cline")
				return err == nil
			},
			Command: []string{"npm", "install", "-g", "cline"},
		},
	},
	{
		Name:        "codex",
		Runner:      &Codex{},
		Description: "OpenAI's open-source coding agent",
		Install: IntegrationInstallSpec{
			CheckInstalled: func() bool {
				_, err := exec.LookPath("codex")
				return err == nil
			},
			URL:     "https://developers.openai.com/codex/cli/",
			Command: []string{"npm", "install", "-g", "@openai/codex"},
		},
	},
	{
		Name:        "kimi",
		Runner:      &Kimi{},
		Description: "Moonshot's coding agent for terminal and IDEs",
		Hidden:      true,
		Install: IntegrationInstallSpec{
			CheckInstalled: func() bool {
				_, err := exec.LookPath("kimi")
				return err == nil
			},
			EnsureInstalled: func() error {
				_, err := ensureKimiInstalled()
				return err
			},
			URL: "https://moonshotai.github.io/kimi-cli/en/guides/getting-started.html",
		},
	},
	{
		Name:        "copilot",
		Runner:      &Copilot{},
		Aliases:     []string{"copilot-cli"},
		Description: "GitHub's AI coding agent for the terminal",
		Install: IntegrationInstallSpec{
			CheckInstalled: func() bool {
				_, err := (&Copilot{}).findPath()
				return err == nil
			},
			URL: "https://github.com/features/copilot/cli/",
		},
	},
	{
		Name:        "droid",
		Runner:      &Droid{},
		Description: "Factory's coding agent across terminal and IDEs",
		Install: IntegrationInstallSpec{
			CheckInstalled: func() bool {
				_, err := exec.LookPath("droid")
				return err == nil
			},
			URL: "https://docs.factory.ai/cli/getting-started/quickstart",
		},
	},
	{
		Name:        "opencode",
		Runner:      &OpenCode{},
		Description: "Anomaly's open-source coding agent",
		Install: IntegrationInstallSpec{
			CheckInstalled: func() bool {
				_, ok := findOpenCode()
				return ok
			},
			URL: "https://opencode.ai",
		},
	},
	{
		Name:        "openclaw",
		Runner:      &Openclaw{},
		Aliases:     []string{"clawdbot", "moltbot"},
		Description: "Personal AI with 100+ skills",
		Install: IntegrationInstallSpec{
			// Either binary name counts as installed (older installs used clawdbot).
			CheckInstalled: func() bool {
				if _, err := exec.LookPath("openclaw"); err == nil {
					return true
				}
				if _, err := exec.LookPath("clawdbot"); err == nil {
					return true
				}
				return false
			},
			EnsureInstalled: func() error {
				_, err := ensureOpenclawInstalled()
				return err
			},
			URL: "https://docs.openclaw.ai",
		},
	},
	{
		Name:        "pi",
		Runner:      &Pi{},
		Description: "Minimal AI agent toolkit with plugin support",
		Install: IntegrationInstallSpec{
			CheckInstalled: func() bool {
				_, err := exec.LookPath("pi")
				return err == nil
			},
			EnsureInstalled: func() error {
				_, err := ensurePiInstalled()
				return err
			},
			Command: []string{"npm", "install", "-g", "@mariozechner/pi-coding-agent@latest"},
		},
	},
	{
		Name:        "hermes",
		Runner:      &Hermes{},
		Description: "Self-improving AI agent built by Nous Research",
		Install: IntegrationInstallSpec{
			CheckInstalled: func() bool {
				return (&Hermes{}).installed()
			},
			EnsureInstalled: func() error {
				return (&Hermes{}).ensureInstalled()
			},
			URL: "https://hermes-agent.nousresearch.com/docs/getting-started/installation/",
		},
	},
	{
		Name:        "vscode",
		Runner:      &VSCode{},
		Aliases:     []string{"code"},
		Description: "Microsoft's open-source AI code editor",
		Hidden:      true,
		Install: IntegrationInstallSpec{
			CheckInstalled: func() bool {
				return (&VSCode{}).findBinary() != ""
			},
			URL: "https://code.visualstudio.com",
		},
	},
}
// integrationSpecsByName maps lowercase canonical names and aliases to specs.
var integrationSpecsByName map[string]*IntegrationSpec

// init builds the lookup indexes once at startup so registry inconsistencies
// panic immediately rather than at first lookup.
func init() {
	rebuildIntegrationSpecIndexes()
}
// hyperlink wraps text in an OSC 8 escape sequence so supporting terminals
// render it as a clickable link to url.
func hyperlink(url, text string) string {
	const (
		osc8 = "\033]8;;" // OSC 8 hyperlink introducer
		st   = "\033\\"   // string terminator
	)
	// Open the link, emit the label, then close with an empty OSC 8.
	return osc8 + url + st + text + osc8 + st
}
// rebuildIntegrationSpecIndexes rebuilds integrationSpecsByName from
// integrationSpecs and panics on any registry inconsistency: empty or
// duplicate names, alias collisions, or a launcherIntegrationOrder entry that
// is missing, duplicated, aliased, or hidden. Running at init time makes
// misconfiguration fail fast.
func rebuildIntegrationSpecIndexes() {
	integrationSpecsByName = make(map[string]*IntegrationSpec, len(integrationSpecs))
	// Pass 1: index canonical names and detect duplicates.
	canonical := make(map[string]bool, len(integrationSpecs))
	for _, spec := range integrationSpecs {
		key := strings.ToLower(spec.Name)
		if key == "" {
			panic("launch: integration spec missing name")
		}
		if canonical[key] {
			panic(fmt.Sprintf("launch: duplicate integration name %q", key))
		}
		canonical[key] = true
		integrationSpecsByName[key] = spec
	}
	// Pass 2: index aliases, rejecting collisions with canonical names or
	// with aliases of other integrations.
	seenAliases := make(map[string]string)
	for _, spec := range integrationSpecs {
		for _, alias := range spec.Aliases {
			key := strings.ToLower(alias)
			if key == "" {
				panic(fmt.Sprintf("launch: integration %q has empty alias", spec.Name))
			}
			if canonical[key] {
				panic(fmt.Sprintf("launch: alias %q collides with canonical integration name", key))
			}
			if owner, exists := seenAliases[key]; exists {
				panic(fmt.Sprintf("launch: alias %q collides between %q and %q", key, owner, spec.Name))
			}
			seenAliases[key] = spec.Name
			integrationSpecsByName[key] = spec
		}
	}
	// Pass 3: validate the curated launcher display order.
	orderSeen := make(map[string]bool, len(launcherIntegrationOrder))
	for _, name := range launcherIntegrationOrder {
		key := strings.ToLower(name)
		if orderSeen[key] {
			panic(fmt.Sprintf("launch: duplicate launcher order entry %q", key))
		}
		orderSeen[key] = true
		spec, ok := integrationSpecsByName[key]
		if !ok {
			panic(fmt.Sprintf("launch: unknown launcher order entry %q", key))
		}
		if spec.Name != key {
			panic(fmt.Sprintf("launch: launcher order entry %q must use canonical name, not alias", key))
		}
		if spec.Hidden {
			panic(fmt.Sprintf("launch: hidden integration %q cannot appear in launcher order", key))
		}
	}
}
// LookupIntegrationSpec resolves either a canonical integration name or alias to its spec.
func LookupIntegrationSpec(name string) (*IntegrationSpec, error) {
	if spec, ok := integrationSpecsByName[strings.ToLower(name)]; ok {
		return spec, nil
	}
	return nil, fmt.Errorf("unknown integration: %s", name)
}
// LookupIntegration resolves a registry name to the canonical key and runner.
func LookupIntegration(name string) (string, Runner, error) {
	spec, err := LookupIntegrationSpec(name)
	if err == nil {
		return spec.Name, spec.Runner, nil
	}
	return "", nil, err
}
// ListVisibleIntegrationSpecs returns the canonical integrations that should appear in interactive UIs.
// Entries named in launcherIntegrationOrder come first in that order;
// everything else follows alphabetically.
func ListVisibleIntegrationSpecs() []IntegrationSpec {
	visible := make([]IntegrationSpec, 0, len(integrationSpecs))
	for _, spec := range integrationSpecs {
		if !spec.Hidden {
			visible = append(visible, *spec)
		}
	}
	// Rank 0 means "not in the curated order"; ranked entries sort first.
	rank := make(map[string]int, len(launcherIntegrationOrder))
	for i, name := range launcherIntegrationOrder {
		rank[name] = i + 1
	}
	slices.SortFunc(visible, func(a, b IntegrationSpec) int {
		ra, rb := rank[a.Name], rank[b.Name]
		switch {
		case ra > 0 && rb > 0:
			return ra - rb
		case ra > 0:
			return -1
		case rb > 0:
			return 1
		default:
			return strings.Compare(a.Name, b.Name)
		}
	})
	return visible
}
// ListIntegrationInfos returns the registered integrations in launcher display order.
func ListIntegrationInfos() []IntegrationInfo {
	specs := ListVisibleIntegrationSpecs()
	infos := make([]IntegrationInfo, len(specs))
	for i, spec := range specs {
		infos[i] = IntegrationInfo{
			Name:        spec.Name,
			DisplayName: spec.Runner.String(),
			Description: spec.Description,
		}
	}
	return infos
}
// IntegrationSelectionItems returns the sorted integration items shown by launcher selection UIs.
// When a stored config exists for an integration, its first model is appended
// to the description.
func IntegrationSelectionItems() ([]ModelItem, error) {
	specs := ListVisibleIntegrationSpecs()
	if len(specs) == 0 {
		return nil, fmt.Errorf("no integrations available")
	}
	items := make([]ModelItem, 0, len(specs))
	for _, spec := range specs {
		label := spec.Runner.String()
		if conn, err := loadStoredIntegrationConfig(spec.Name); err == nil && len(conn.Models) > 0 {
			label = fmt.Sprintf("%s (%s)", spec.Runner.String(), conn.Models[0])
		}
		items = append(items, ModelItem{Name: spec.Name, Description: label})
	}
	return items, nil
}
// IsIntegrationInstalled checks if an integration binary is installed.
// Unknown integrations are reported on stderr and treated as not installed.
func IsIntegrationInstalled(name string) bool {
	in, err := integrationFor(name)
	if err == nil {
		return in.installed
	}
	fmt.Fprintf(os.Stderr, "Ollama couldn't find integration %q, so it'll show up as not installed.\n", name)
	return false
}
// integration is resolved registry metadata used by launcher state and install checks.
// It combines immutable registry spec data with computed runtime traits.
type integration struct {
	// spec is the canonical registry entry.
	spec *IntegrationSpec
	// installed reflects the spec's CheckInstalled result (true when no check exists).
	installed bool
	// autoInstallable is true when the spec provides EnsureInstalled.
	autoInstallable bool
	// editor is true when the runner also implements Editor.
	editor bool
	// installHint is a human-readable manual-install suggestion, if any.
	installHint string
}
// integrationFor resolves an integration name into the canonical spec plus
// derived launcher/install traits used across registry and launch flows.
func integrationFor(name string) (integration, error) {
	spec, err := LookupIntegrationSpec(name)
	if err != nil {
		return integration{}, err
	}
	// No install check means the integration is assumed present.
	installed := spec.Install.CheckInstalled == nil || spec.Install.CheckInstalled()
	_, isEditor := spec.Runner.(Editor)
	var hint string
	switch {
	case spec.Install.URL != "":
		hint = "Install from " + hyperlink(spec.Install.URL, spec.Install.URL)
	case len(spec.Install.Command) > 0:
		hint = "Install with: " + strings.Join(spec.Install.Command, " ")
	}
	return integration{
		spec:            spec,
		installed:       installed,
		autoInstallable: spec.Install.EnsureInstalled != nil,
		editor:          isEditor,
		installHint:     hint,
	}, nil
}
// EnsureIntegrationInstalled installs auto-installable integrations when missing.
// Integrations that cannot self-install yield an error with manual
// installation guidance.
func EnsureIntegrationInstalled(name string, runner Runner) error {
	in, err := integrationFor(name)
	if err != nil {
		return fmt.Errorf("%s is not installed", runner)
	}
	switch {
	case in.installed:
		return nil
	case in.autoInstallable:
		return in.spec.Install.EnsureInstalled()
	case in.spec.Install.URL != "":
		return fmt.Errorf("%s is not installed, install from %s", in.spec.Name, in.spec.Install.URL)
	case len(in.spec.Install.Command) > 0:
		return fmt.Errorf("%s is not installed, install with: %s", in.spec.Name, strings.Join(in.spec.Install.Command, " "))
	default:
		return fmt.Errorf("%s is not installed", runner)
	}
}

View File

@@ -1,21 +0,0 @@
package launch
import "strings"
// OverrideIntegration replaces one registry entry's runner for tests and returns a restore function.
func OverrideIntegration(name string, runner Runner) func() {
spec, err := LookupIntegrationSpec(name)
if err != nil {
key := strings.ToLower(name)
integrationSpecsByName[key] = &IntegrationSpec{Name: key, Runner: runner}
return func() {
delete(integrationSpecsByName, key)
}
}
original := spec.Runner
spec.Runner = runner
return func() {
spec.Runner = original
}
}

View File

@@ -1,83 +0,0 @@
package launch
import (
"os"
"path/filepath"
"testing"
)
// TestEditorRunsDoNotRewriteConfig verifies that calling Run on each editor
// integration launches the tool without creating or rewriting its config
// file — config writes must only happen through Edit.
func TestEditorRunsDoNotRewriteConfig(t *testing.T) {
	tests := []struct {
		name   string
		binary string // fake executable placed on PATH
		runner Runner
		// checkPath returns the config file that must remain absent after Run.
		checkPath func(home string) string
	}{
		{
			name:   "droid",
			binary: "droid",
			runner: &Droid{},
			checkPath: func(home string) string {
				return filepath.Join(home, ".factory", "settings.json")
			},
		},
		{
			name:   "opencode",
			binary: "opencode",
			runner: &OpenCode{},
			checkPath: func(home string) string {
				return filepath.Join(home, ".local", "state", "opencode", "model.json")
			},
		},
		{
			name:   "cline",
			binary: "cline",
			runner: &Cline{},
			checkPath: func(home string) string {
				return filepath.Join(home, ".cline", "data", "globalState.json")
			},
		},
		{
			name:   "pi",
			binary: "pi",
			runner: &Pi{},
			checkPath: func(home string) string {
				return filepath.Join(home, ".pi", "agent", "models.json")
			},
		},
		{
			name:   "kimi",
			binary: "kimi",
			runner: &Kimi{},
			checkPath: func(home string) string {
				return filepath.Join(home, ".kimi", "config.toml")
			},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			home := t.TempDir()
			setTestHome(t, home)
			binDir := t.TempDir()
			writeFakeBinary(t, binDir, tt.binary)
			// Some runners shell out to extra tools during Run; fake those too.
			if tt.name == "pi" {
				writeFakeBinary(t, binDir, "npm")
			}
			if tt.name == "kimi" {
				writeFakeBinary(t, binDir, "curl")
				writeFakeBinary(t, binDir, "bash")
			}
			// Restrict PATH so only the fake binaries are found.
			t.Setenv("PATH", binDir)
			configPath := tt.checkPath(home)
			if err := tt.runner.Run("llama3.2", nil); err != nil {
				t.Fatalf("Run returned error: %v", err)
			}
			// The config file must not have been created by Run.
			if _, err := os.Stat(configPath); !os.IsNotExist(err) {
				t.Fatalf("expected Run to leave %s untouched, got err=%v", configPath, err)
			}
		})
	}
}

View File

@@ -1,115 +0,0 @@
package launch
import (
"errors"
"fmt"
"os"
"golang.org/x/term"
)
// ANSI escape sequences for terminal formatting.
const (
	ansiBold   = "\033[1m"  // bold text
	ansiReset  = "\033[0m"  // reset all attributes
	ansiGray   = "\033[37m" // light gray foreground
	ansiGreen  = "\033[32m" // green foreground
	ansiYellow = "\033[33m" // yellow foreground
)
// ErrCancelled is returned when the user cancels a selection.
var ErrCancelled = errors.New("cancelled")

// errCancelled is kept as an internal alias for existing call sites.
var errCancelled = ErrCancelled

// DefaultConfirmPrompt provides a TUI-based confirmation prompt.
// When set, ConfirmPrompt delegates to it instead of using raw terminal I/O.
var DefaultConfirmPrompt func(prompt string, options ConfirmOptions) (bool, error)

// ConfirmOptions customizes labels for confirmation prompts.
type ConfirmOptions struct {
	// YesLabel overrides the affirmative option's label.
	YesLabel string
	// NoLabel overrides the negative option's label.
	NoLabel string
}

// SingleSelector is a function type for single item selection.
// current is the name of the previously selected item to highlight; empty means no pre-selection.
type SingleSelector func(title string, items []ModelItem, current string) (string, error)

// MultiSelector is a function type for multi item selection.
type MultiSelector func(title string, items []ModelItem, preChecked []string) ([]string, error)

// DefaultSingleSelector is the default single-select implementation.
var DefaultSingleSelector SingleSelector

// DefaultMultiSelector is the default multi-select implementation.
var DefaultMultiSelector MultiSelector

// DefaultSignIn provides a TUI-based sign-in flow.
// When set, ensureAuth uses it instead of plain text prompts.
// Returns the signed-in username or an error.
var DefaultSignIn func(modelName, signInURL string) (string, error)
// launchConfirmPolicy controls how ConfirmPrompt resolves confirmations for
// the current launch invocation.
type launchConfirmPolicy struct {
	// yes auto-approves every confirmation (e.g. when --yes was passed).
	yes bool
	// requireYesMessage rejects confirmations with an error directing the
	// user to re-run with --yes (for non-interactive contexts).
	requireYesMessage bool
}

// currentLaunchConfirmPolicy is the active policy; it is scoped per launch
// via withLaunchConfirmPolicy.
var currentLaunchConfirmPolicy launchConfirmPolicy
// withLaunchConfirmPolicy installs policy as the active confirmation policy
// and returns a function that restores the previous one.
func withLaunchConfirmPolicy(policy launchConfirmPolicy) func() {
	prev := currentLaunchConfirmPolicy
	currentLaunchConfirmPolicy = policy
	return func() { currentLaunchConfirmPolicy = prev }
}
// ConfirmPrompt is the shared confirmation gate for launch flows (integration
// edits, missing-model pulls, sign-in prompts, OpenClaw install/security, etc).
// Behavior is controlled by currentLaunchConfirmPolicy, typically scoped by
// withLaunchConfirmPolicy in LaunchCmd (e.g. auto-approve with --yes).
// It is a thin wrapper over ConfirmPromptWithOptions with default labels.
func ConfirmPrompt(prompt string) (bool, error) {
	return ConfirmPromptWithOptions(prompt, ConfirmOptions{})
}
// ConfirmPromptWithOptions is the shared confirmation gate for launch flows
// that need custom yes/no labels in interactive UIs.
// Resolution order: the active launch policy (auto-yes or require---yes),
// then the pluggable DefaultConfirmPrompt TUI hook, and finally a raw
// single-keypress prompt on the controlling terminal.
func ConfirmPromptWithOptions(prompt string, options ConfirmOptions) (bool, error) {
	if currentLaunchConfirmPolicy.yes {
		return true, nil
	}
	if currentLaunchConfirmPolicy.requireYesMessage {
		return false, fmt.Errorf("%s requires confirmation; re-run with --yes to continue", prompt)
	}
	if DefaultConfirmPrompt != nil {
		return DefaultConfirmPrompt(prompt, options)
	}
	// Fallback: read a single key from stdin in raw mode.
	fd := int(os.Stdin.Fd())
	oldState, err := term.MakeRaw(fd)
	if err != nil {
		return false, err
	}
	defer term.Restore(fd, oldState)
	fmt.Fprintf(os.Stderr, "%s (\033[1my\033[0m/n) ", prompt)
	buf := make([]byte, 1)
	// Loop until a recognized key arrives; other keys are ignored.
	for {
		if _, err := os.Stdin.Read(buf); err != nil {
			return false, err
		}
		switch buf[0] {
		case 'Y', 'y', 13: // y or Enter accepts
			fmt.Fprintf(os.Stderr, "yes\r\n")
			return true, nil
		case 'N', 'n', 27, 3: // n, Esc, or Ctrl-C declines
			fmt.Fprintf(os.Stderr, "no\r\n")
			return false, nil
		}
	}
}

Some files were not shown because too many files have changed in this diff Show More