Compare commits
155 Commits
pdevine/sa
...
jmorganca/
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
8d846fdbc0 | ||
|
|
f3536a356e | ||
|
|
c89280fb0c | ||
|
|
eb5434d7fb | ||
|
|
2b949a11d9 | ||
|
|
6b013002fc | ||
|
|
5e622289c5 | ||
|
|
9c8bcecdb2 | ||
|
|
1cbe7950d6 | ||
|
|
95073400fc | ||
|
|
c29932c631 | ||
|
|
1ce101c9a0 | ||
|
|
5a7928ed38 | ||
|
|
7fdc051091 | ||
|
|
5bad871241 | ||
|
|
82437d620a | ||
|
|
570c53859d | ||
|
|
ebd70f73b7 | ||
|
|
eb5df80733 | ||
|
|
356c0b8e34 | ||
|
|
ea3c6a3cbe | ||
|
|
f6b69f3f28 | ||
|
|
e38b606e8b | ||
|
|
cb0033598e | ||
|
|
4d14b0ff92 | ||
|
|
d9cb70c270 | ||
|
|
31f968fe1f | ||
|
|
b7bda92d52 | ||
|
|
8e54823fd3 | ||
|
|
7c8da5679e | ||
|
|
6214103e66 | ||
|
|
9e7cb9697e | ||
|
|
3824e380a8 | ||
|
|
c9b2dcfc52 | ||
|
|
b00bd1dfd4 | ||
|
|
ac83ac20c4 | ||
|
|
e7ccc129ea | ||
|
|
69ed0c2729 | ||
|
|
1cefa749aa | ||
|
|
aec2fef95d | ||
|
|
366625a831 | ||
|
|
516ebd8548 | ||
|
|
f567abc63f | ||
|
|
1adfc27f04 | ||
|
|
4a2b9f9dbc | ||
|
|
e46b67a6cc | ||
|
|
c000afe76c | ||
|
|
9d7b18f81e | ||
|
|
4f5999fd3f | ||
|
|
ac5f0dbb6a | ||
|
|
d1151e18a1 | ||
|
|
ebbce136c7 | ||
|
|
26b9f53f8e | ||
|
|
7575438366 | ||
|
|
7d7c90d702 | ||
|
|
4fda69809a | ||
|
|
c9b5da6b0c | ||
|
|
de5cb7311f | ||
|
|
95ee7fbd29 | ||
|
|
ec55536734 | ||
|
|
77491439c2 | ||
|
|
b166b36cd2 | ||
|
|
c2b0bb7a52 | ||
|
|
22c2bdbd8a | ||
|
|
6df6d097d9 | ||
|
|
d7c176ab91 | ||
|
|
0ff7d724ff | ||
|
|
46cb7795e1 | ||
|
|
126d8db7f3 | ||
|
|
3f3a24b418 | ||
|
|
96e36c0d90 | ||
|
|
6f8ddbb26b | ||
|
|
b5e7888414 | ||
|
|
eab4d22269 | ||
|
|
5759c2d2d2 | ||
|
|
42b1c2642b | ||
|
|
727d69ddf3 | ||
|
|
f622b0c5fc | ||
|
|
5d0000634c | ||
|
|
676d9845ba | ||
|
|
e37a9b4c01 | ||
|
|
d727aacd04 | ||
|
|
fa69b833cd | ||
|
|
bbbad97686 | ||
|
|
bcf6d55b54 | ||
|
|
810d4f9c22 | ||
|
|
856c047a6c | ||
|
|
79c1e93c00 | ||
|
|
f8b657c967 | ||
|
|
10fefe0d57 | ||
|
|
2f9a68f9e9 | ||
|
|
3980c0217d | ||
|
|
870599f5da | ||
|
|
abf8e8e9c8 | ||
|
|
f3f31a8192 | ||
|
|
9e7ba835da | ||
|
|
347f17b8d1 | ||
|
|
081b9eb423 | ||
|
|
bb867c6fdb | ||
|
|
81f4506a61 | ||
|
|
76925f1284 | ||
|
|
f676231de9 | ||
|
|
af5f7c0a9e | ||
|
|
a6b27d776b | ||
|
|
539741199e | ||
|
|
8f45236d09 | ||
|
|
97013a190c | ||
|
|
c222735c02 | ||
|
|
87d21c7fc0 | ||
|
|
54e05172a0 | ||
|
|
464186e995 | ||
|
|
8c4d5d6c2f | ||
|
|
bc72b14016 | ||
|
|
61086083eb | ||
|
|
62d1f01ab4 | ||
|
|
10e51c5177 | ||
|
|
3e06bde643 | ||
|
|
6be2de8214 | ||
|
|
ebb1b9ec14 | ||
|
|
d126467d5d | ||
|
|
afb4c62fbf | ||
|
|
e790dc435b | ||
|
|
288077c3a3 | ||
|
|
4425c54eda | ||
|
|
778899a5d2 | ||
|
|
4eab60c1e2 | ||
|
|
1af850e6e3 | ||
|
|
9b0c7cc7b9 | ||
|
|
6928630601 | ||
|
|
9896e3627f | ||
|
|
15732f0ea7 | ||
|
|
562c76d7cc | ||
|
|
122c68c151 | ||
|
|
82848a7806 | ||
|
|
39982a954e | ||
|
|
e9f6ea232f | ||
|
|
110eff01a9 | ||
|
|
799e51d419 | ||
|
|
e8fcb29586 | ||
|
|
97d2f05a6d | ||
|
|
8207e55ec7 | ||
|
|
ad16bffc7d | ||
|
|
c1e3ef4bcc | ||
|
|
a3093cd5e5 | ||
|
|
23d4cad1a2 | ||
|
|
86513cb697 | ||
|
|
3490e9590b | ||
|
|
8da09b1e7e | ||
|
|
a60b9adcce | ||
|
|
a16f96658b | ||
|
|
18ab09b431 | ||
|
|
638faeac54 | ||
|
|
dd5eb6337d | ||
|
|
79917cf80b | ||
|
|
cc90a035a0 |
67
.github/workflows/release.yaml
vendored
@@ -117,6 +117,25 @@ jobs:
|
||||
install: https://sdk.lunarg.com/sdk/download/1.4.321.1/windows/vulkansdk-windows-X64-1.4.321.1.exe
|
||||
flags: ''
|
||||
runner_dir: 'vulkan'
|
||||
- os: windows
|
||||
arch: amd64
|
||||
preset: 'MLX CUDA 13'
|
||||
install: https://developer.download.nvidia.com/compute/cuda/13.0.0/local_installers/cuda_13.0.0_windows.exe
|
||||
cudnn-install: https://developer.download.nvidia.com/compute/cudnn/redist/cudnn/windows-x86_64/cudnn-windows-x86_64-9.18.1.3_cuda13-archive.zip
|
||||
cuda-components:
|
||||
- '"cudart"'
|
||||
- '"nvcc"'
|
||||
- '"cublas"'
|
||||
- '"cublas_dev"'
|
||||
- '"cufft"'
|
||||
- '"cufft_dev"'
|
||||
- '"nvrtc"'
|
||||
- '"nvrtc_dev"'
|
||||
- '"crt"'
|
||||
- '"nvvm"'
|
||||
- '"nvptxcompiler"'
|
||||
cuda-version: '13.0'
|
||||
flags: ''
|
||||
runs-on: ${{ matrix.arch == 'arm64' && format('{0}-{1}', matrix.os, matrix.arch) || matrix.os }}
|
||||
environment: release
|
||||
env:
|
||||
@@ -125,8 +144,10 @@ jobs:
|
||||
- name: Install system dependencies
|
||||
run: |
|
||||
choco install -y --no-progress ccache ninja
|
||||
ccache -o cache_dir=${{ github.workspace }}\.ccache
|
||||
- if: startsWith(matrix.preset, 'CUDA ') || startsWith(matrix.preset, 'ROCm ') || startsWith(matrix.preset, 'Vulkan')
|
||||
if (Get-Command ccache -ErrorAction SilentlyContinue) {
|
||||
ccache -o cache_dir=${{ github.workspace }}\.ccache
|
||||
}
|
||||
- if: startsWith(matrix.preset, 'CUDA ') || startsWith(matrix.preset, 'ROCm ') || startsWith(matrix.preset, 'Vulkan') || startsWith(matrix.preset, 'MLX ')
|
||||
id: cache-install
|
||||
uses: actions/cache/restore@v4
|
||||
with:
|
||||
@@ -134,8 +155,9 @@ jobs:
|
||||
C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA
|
||||
C:\Program Files\AMD\ROCm
|
||||
C:\VulkanSDK
|
||||
key: ${{ matrix.install }}
|
||||
- if: startsWith(matrix.preset, 'CUDA ')
|
||||
C:\Program Files\NVIDIA\CUDNN
|
||||
key: ${{ matrix.install }}-${{ matrix.cudnn-install }}
|
||||
- if: startsWith(matrix.preset, 'CUDA ') || startsWith(matrix.preset, 'MLX ')
|
||||
name: Install CUDA ${{ matrix.cuda-version }}
|
||||
run: |
|
||||
$ErrorActionPreference = "Stop"
|
||||
@@ -179,6 +201,23 @@ jobs:
|
||||
run: |
|
||||
echo "CC=clang.exe" | Out-File -FilePath $env:GITHUB_ENV -Append
|
||||
echo "CXX=clang++.exe" | Out-File -FilePath $env:GITHUB_ENV -Append
|
||||
- if: startsWith(matrix.preset, 'MLX ')
|
||||
name: Install cuDNN for MLX
|
||||
run: |
|
||||
$ErrorActionPreference = "Stop"
|
||||
$cudnnRoot = "C:\Program Files\NVIDIA\CUDNN"
|
||||
if ("${{ steps.cache-install.outputs.cache-hit }}" -ne 'true') {
|
||||
Invoke-WebRequest -Uri "${{ matrix.cudnn-install }}" -OutFile "cudnn.zip"
|
||||
Expand-Archive -Path cudnn.zip -DestinationPath cudnn-extracted
|
||||
$cudnnDir = (Get-ChildItem -Path cudnn-extracted -Directory)[0].FullName
|
||||
New-Item -ItemType Directory -Force -Path $cudnnRoot
|
||||
Copy-Item -Path "$cudnnDir\*" -Destination "$cudnnRoot\" -Recurse
|
||||
}
|
||||
|
||||
echo "CUDNN_ROOT_DIR=$cudnnRoot" | Out-File -FilePath $env:GITHUB_ENV -Append
|
||||
echo "CUDNN_INCLUDE_PATH=$cudnnRoot\include" | Out-File -FilePath $env:GITHUB_ENV -Append
|
||||
echo "CUDNN_LIBRARY_PATH=$cudnnRoot\lib\x64" | Out-File -FilePath $env:GITHUB_ENV -Append
|
||||
echo "$cudnnRoot\bin\x64" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
|
||||
- if: ${{ !cancelled() && steps.cache-install.outputs.cache-hit != 'true' }}
|
||||
uses: actions/cache/save@v4
|
||||
with:
|
||||
@@ -186,7 +225,8 @@ jobs:
|
||||
C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA
|
||||
C:\Program Files\AMD\ROCm
|
||||
C:\VulkanSDK
|
||||
key: ${{ matrix.install }}
|
||||
C:\Program Files\NVIDIA\CUDNN
|
||||
key: ${{ matrix.install }}-${{ matrix.cudnn-install }}
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/cache@v4
|
||||
with:
|
||||
@@ -198,7 +238,7 @@ jobs:
|
||||
Enter-VsDevShell -VsInstallPath 'C:\Program Files\Microsoft Visual Studio\2022\Enterprise' -SkipAutomaticLocation -DevCmdArguments '-arch=x64 -no_logo'
|
||||
cmake --preset "${{ matrix.preset }}" ${{ matrix.flags }} --install-prefix "$((pwd).Path)\dist\${{ matrix.os }}-${{ matrix.arch }}"
|
||||
cmake --build --parallel ([Environment]::ProcessorCount) --preset "${{ matrix.preset }}"
|
||||
cmake --install build --component "${{ startsWith(matrix.preset, 'CUDA ') && 'CUDA' || startsWith(matrix.preset, 'ROCm ') && 'HIP' || startsWith(matrix.preset, 'Vulkan') && 'Vulkan' || 'CPU' }}" --strip
|
||||
cmake --install build --component "${{ startsWith(matrix.preset, 'MLX ') && 'MLX' || startsWith(matrix.preset, 'CUDA ') && 'CUDA' || startsWith(matrix.preset, 'ROCm ') && 'HIP' || startsWith(matrix.preset, 'Vulkan') && 'Vulkan' || 'CPU' }}" --strip
|
||||
Remove-Item -Path dist\lib\ollama\rocm\rocblas\library\*gfx906* -ErrorAction SilentlyContinue
|
||||
env:
|
||||
CMAKE_GENERATOR: Ninja
|
||||
@@ -384,6 +424,7 @@ jobs:
|
||||
lib/ollama/cuda_v*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
|
||||
lib/ollama/vulkan*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
|
||||
lib/ollama/mlx*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
|
||||
lib/ollama/include*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
|
||||
lib/ollama/cuda_jetpack5) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-jetpack5.tar.in ;;
|
||||
lib/ollama/cuda_jetpack6) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-jetpack6.tar.in ;;
|
||||
lib/ollama/rocm) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-rocm.tar.in ;;
|
||||
@@ -543,11 +584,19 @@ jobs:
|
||||
for payload in dist/*.txt dist/*.zip dist/*.tgz dist/*.tar.zst dist/*.exe dist/*.dmg dist/*.ps1 dist/*.sh ; do
|
||||
echo "Uploading $payload"
|
||||
gh release upload ${GITHUB_REF_NAME} $payload --clobber &
|
||||
pids[$!]=$!
|
||||
pids+=($!)
|
||||
sleep 1
|
||||
done
|
||||
echo "Waiting for uploads to complete"
|
||||
for pid in "${pids[*]}"; do
|
||||
wait $pid
|
||||
failed=0
|
||||
for pid in "${pids[@]}"; do
|
||||
if ! wait $pid; then
|
||||
echo "::error::Upload failed (pid $pid)"
|
||||
failed=1
|
||||
fi
|
||||
done
|
||||
if [ $failed -ne 0 ]; then
|
||||
echo "One or more uploads failed"
|
||||
exit 1
|
||||
fi
|
||||
echo "done"
|
||||
|
||||
71
.github/workflows/test.yaml
vendored
@@ -37,7 +37,7 @@ jobs:
|
||||
| xargs python3 -c "import sys; from pathlib import Path; print(any(Path(x).match(glob) for x in sys.argv[1:] for glob in '$*'.split(' ')))"
|
||||
}
|
||||
|
||||
echo changed=$(changed 'llama/llama.cpp/**/*' 'ml/backend/ggml/ggml/**/*') | tee -a $GITHUB_OUTPUT
|
||||
echo changed=$(changed 'llama/llama.cpp/**/*' 'ml/backend/ggml/ggml/**/*' '.github/**/*') | tee -a $GITHUB_OUTPUT
|
||||
echo vendorsha=$(make -f Makefile.sync print-base) | tee -a $GITHUB_OUTPUT
|
||||
|
||||
linux:
|
||||
@@ -51,7 +51,7 @@ jobs:
|
||||
container: nvidia/cuda:13.0.0-devel-ubuntu22.04
|
||||
flags: '-DCMAKE_CUDA_ARCHITECTURES=87'
|
||||
- preset: ROCm
|
||||
container: rocm/dev-ubuntu-22.04:6.1.2
|
||||
container: rocm/dev-ubuntu-22.04:7.2
|
||||
extra-packages: rocm-libs
|
||||
flags: '-DAMDGPU_TARGETS=gfx1010 -DCMAKE_PREFIX_PATH=/opt/rocm'
|
||||
- preset: Vulkan
|
||||
@@ -60,6 +60,11 @@ jobs:
|
||||
mesa-vulkan-drivers vulkan-tools
|
||||
libvulkan1 libvulkan-dev
|
||||
vulkan-sdk cmake ccache g++ make
|
||||
- preset: 'MLX CUDA 13'
|
||||
container: nvidia/cuda:13.0.0-devel-ubuntu22.04
|
||||
extra-packages: libcudnn9-dev-cuda-13 libopenblas-dev liblapack-dev liblapacke-dev git curl
|
||||
flags: '-DCMAKE_CUDA_ARCHITECTURES=87 -DBLAS_INCLUDE_DIRS=/usr/include/x86_64-linux-gnu -DLAPACK_INCLUDE_DIRS=/usr/include/x86_64-linux-gnu'
|
||||
install-go: true
|
||||
runs-on: linux
|
||||
container: ${{ matrix.container }}
|
||||
steps:
|
||||
@@ -76,19 +81,29 @@ jobs:
|
||||
$sudo apt-get update
|
||||
fi
|
||||
$sudo apt-get install -y cmake ccache ${{ matrix.extra-packages }}
|
||||
# MLX requires CMake 3.25+, install from official releases
|
||||
if [ "${{ matrix.preset }}" = "MLX CUDA 13" ]; then
|
||||
curl -fsSL https://github.com/Kitware/CMake/releases/download/v3.31.2/cmake-3.31.2-linux-$(uname -m).tar.gz | $sudo tar xz -C /usr/local --strip-components 1
|
||||
fi
|
||||
# Export VULKAN_SDK if provided by LunarG package (defensive)
|
||||
if [ -d "/usr/lib/x86_64-linux-gnu/vulkan" ] && [ "${{ matrix.preset }}" = "Vulkan" ]; then
|
||||
echo "VULKAN_SDK=/usr" >> $GITHUB_ENV
|
||||
fi
|
||||
env:
|
||||
DEBIAN_FRONTEND: noninteractive
|
||||
- if: matrix.install-go
|
||||
name: Install Go
|
||||
run: |
|
||||
GO_VERSION=$(awk '/^go / { print $2 }' go.mod)
|
||||
curl -fsSL "https://golang.org/dl/go${GO_VERSION}.linux-$(dpkg --print-architecture).tar.gz" | tar xz -C /usr/local
|
||||
echo "/usr/local/go/bin" >> $GITHUB_PATH
|
||||
- uses: actions/cache@v4
|
||||
with:
|
||||
path: /github/home/.cache/ccache
|
||||
key: ccache-${{ runner.os }}-${{ runner.arch }}-${{ matrix.preset }}-${{ needs.changes.outputs.vendorsha }}
|
||||
- run: |
|
||||
cmake --preset ${{ matrix.preset }} ${{ matrix.flags }}
|
||||
cmake --build --preset ${{ matrix.preset }} --parallel
|
||||
cmake --preset "${{ matrix.preset }}" ${{ matrix.flags }}
|
||||
cmake --build --preset "${{ matrix.preset }}" --parallel
|
||||
|
||||
windows:
|
||||
needs: [changes]
|
||||
@@ -114,12 +129,31 @@ jobs:
|
||||
flags: '-DAMDGPU_TARGETS=gfx1010 -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -DCMAKE_C_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma" -DCMAKE_CXX_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma"'
|
||||
- preset: Vulkan
|
||||
install: https://sdk.lunarg.com/sdk/download/1.4.321.1/windows/vulkansdk-windows-X64-1.4.321.1.exe
|
||||
- preset: 'MLX CUDA 13'
|
||||
install: https://developer.download.nvidia.com/compute/cuda/13.0.0/local_installers/cuda_13.0.0_windows.exe
|
||||
cudnn-install: https://developer.download.nvidia.com/compute/cudnn/redist/cudnn/windows-x86_64/cudnn-windows-x86_64-9.18.1.3_cuda13-archive.zip
|
||||
flags: '-DCMAKE_CUDA_ARCHITECTURES=80'
|
||||
cuda-components:
|
||||
- '"cudart"'
|
||||
- '"nvcc"'
|
||||
- '"cublas"'
|
||||
- '"cublas_dev"'
|
||||
- '"cufft"'
|
||||
- '"cufft_dev"'
|
||||
- '"nvrtc"'
|
||||
- '"nvrtc_dev"'
|
||||
- '"crt"'
|
||||
- '"nvvm"'
|
||||
- '"nvptxcompiler"'
|
||||
cuda-version: '13.0'
|
||||
runs-on: windows
|
||||
steps:
|
||||
- run: |
|
||||
choco install -y --no-progress ccache ninja
|
||||
ccache -o cache_dir=${{ github.workspace }}\.ccache
|
||||
- if: matrix.preset == 'CUDA' || matrix.preset == 'ROCm' || matrix.preset == 'Vulkan'
|
||||
if (Get-Command ccache -ErrorAction SilentlyContinue) {
|
||||
ccache -o cache_dir=${{ github.workspace }}\.ccache
|
||||
}
|
||||
- if: matrix.preset == 'CUDA' || matrix.preset == 'ROCm' || matrix.preset == 'Vulkan' || matrix.preset == 'MLX CUDA 13'
|
||||
id: cache-install
|
||||
uses: actions/cache/restore@v4
|
||||
with:
|
||||
@@ -127,8 +161,9 @@ jobs:
|
||||
C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA
|
||||
C:\Program Files\AMD\ROCm
|
||||
C:\VulkanSDK
|
||||
key: ${{ matrix.install }}
|
||||
- if: matrix.preset == 'CUDA'
|
||||
C:\Program Files\NVIDIA\CUDNN
|
||||
key: ${{ matrix.install }}-${{ matrix.cudnn-install }}
|
||||
- if: matrix.preset == 'CUDA' || matrix.preset == 'MLX CUDA 13'
|
||||
name: Install CUDA ${{ matrix.cuda-version }}
|
||||
run: |
|
||||
$ErrorActionPreference = "Stop"
|
||||
@@ -168,6 +203,23 @@ jobs:
|
||||
$vulkanPath = (Resolve-Path "C:\VulkanSDK\*").path
|
||||
echo "$vulkanPath\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
|
||||
echo "VULKAN_SDK=$vulkanPath" >> $env:GITHUB_ENV
|
||||
- if: matrix.preset == 'MLX CUDA 13'
|
||||
name: Install cuDNN for MLX
|
||||
run: |
|
||||
$ErrorActionPreference = "Stop"
|
||||
$cudnnRoot = "C:\Program Files\NVIDIA\CUDNN"
|
||||
if ("${{ steps.cache-install.outputs.cache-hit }}" -ne 'true') {
|
||||
Invoke-WebRequest -Uri "${{ matrix.cudnn-install }}" -OutFile "cudnn.zip"
|
||||
Expand-Archive -Path cudnn.zip -DestinationPath cudnn-extracted
|
||||
$cudnnDir = (Get-ChildItem -Path cudnn-extracted -Directory)[0].FullName
|
||||
New-Item -ItemType Directory -Force -Path $cudnnRoot
|
||||
Copy-Item -Path "$cudnnDir\*" -Destination "$cudnnRoot\" -Recurse
|
||||
}
|
||||
|
||||
echo "CUDNN_ROOT_DIR=$cudnnRoot" | Out-File -FilePath $env:GITHUB_ENV -Append
|
||||
echo "CUDNN_INCLUDE_PATH=$cudnnRoot\include" | Out-File -FilePath $env:GITHUB_ENV -Append
|
||||
echo "CUDNN_LIBRARY_PATH=$cudnnRoot\lib\x64" | Out-File -FilePath $env:GITHUB_ENV -Append
|
||||
echo "$cudnnRoot\bin\x64" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
|
||||
- if: ${{ !cancelled() && steps.cache-install.outputs.cache-hit != 'true' }}
|
||||
uses: actions/cache/save@v4
|
||||
with:
|
||||
@@ -175,7 +227,8 @@ jobs:
|
||||
C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA
|
||||
C:\Program Files\AMD\ROCm
|
||||
C:\VulkanSDK
|
||||
key: ${{ matrix.install }}
|
||||
C:\Program Files\NVIDIA\CUDNN
|
||||
key: ${{ matrix.install }}-${{ matrix.cudnn-install }}
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/cache@v4
|
||||
with:
|
||||
|
||||
171
CMakeLists.txt
@@ -64,10 +64,15 @@ set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${OLLAMA_BUILD_DIR})
|
||||
set(CMAKE_LIBRARY_OUTPUT_DIRECTORY_DEBUG ${OLLAMA_BUILD_DIR})
|
||||
set(CMAKE_LIBRARY_OUTPUT_DIRECTORY_RELEASE ${OLLAMA_BUILD_DIR})
|
||||
|
||||
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src)
|
||||
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/include)
|
||||
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-cpu)
|
||||
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-cpu/amx)
|
||||
# Store ggml include paths for use with target_include_directories later.
|
||||
# We avoid global include_directories() to prevent polluting the include path
|
||||
# for other projects like MLX (whose openblas dependency has its own common.h).
|
||||
set(GGML_INCLUDE_DIRS
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/include
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-cpu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-cpu/amx
|
||||
)
|
||||
|
||||
add_compile_definitions(NDEBUG GGML_VERSION=0x0 GGML_COMMIT=0x0)
|
||||
|
||||
@@ -87,6 +92,14 @@ if(NOT CPU_VARIANTS)
|
||||
set(CPU_VARIANTS "ggml-cpu")
|
||||
endif()
|
||||
|
||||
# Apply ggml include directories to ggml targets only (not globally)
|
||||
target_include_directories(ggml-base PRIVATE ${GGML_INCLUDE_DIRS})
|
||||
foreach(variant ${CPU_VARIANTS})
|
||||
if(TARGET ${variant})
|
||||
target_include_directories(${variant} PRIVATE ${GGML_INCLUDE_DIRS})
|
||||
endif()
|
||||
endforeach()
|
||||
|
||||
install(TARGETS ggml-base ${CPU_VARIANTS}
|
||||
RUNTIME_DEPENDENCIES
|
||||
PRE_EXCLUDE_REGEXES ".*"
|
||||
@@ -103,6 +116,7 @@ if(CMAKE_CUDA_COMPILER)
|
||||
|
||||
find_package(CUDAToolkit)
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-cuda)
|
||||
target_include_directories(ggml-cuda PRIVATE ${GGML_INCLUDE_DIRS})
|
||||
install(TARGETS ggml-cuda
|
||||
RUNTIME_DEPENDENCIES
|
||||
DIRECTORIES ${CUDAToolkit_BIN_DIR} ${CUDAToolkit_BIN_DIR}/x64 ${CUDAToolkit_LIBRARY_DIR}
|
||||
@@ -134,6 +148,7 @@ if(CMAKE_HIP_COMPILER)
|
||||
if(AMDGPU_TARGETS)
|
||||
find_package(hip REQUIRED)
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-hip)
|
||||
target_include_directories(ggml-hip PRIVATE ${GGML_INCLUDE_DIRS})
|
||||
|
||||
if (WIN32)
|
||||
target_compile_definitions(ggml-hip PRIVATE GGML_CUDA_NO_PEER_COPY)
|
||||
@@ -148,7 +163,7 @@ if(CMAKE_HIP_COMPILER)
|
||||
)
|
||||
install(RUNTIME_DEPENDENCY_SET rocm
|
||||
DIRECTORIES ${HIP_BIN_INSTALL_DIR} ${HIP_LIB_INSTALL_DIR}
|
||||
PRE_INCLUDE_REGEXES hipblas rocblas amdhip64 rocsolver amd_comgr hsa-runtime64 rocsparse tinfo rocprofiler-register drm drm_amdgpu numa elf
|
||||
PRE_INCLUDE_REGEXES hipblas rocblas amdhip64 rocsolver amd_comgr hsa-runtime64 rocsparse tinfo rocprofiler-register roctx64 rocroller drm drm_amdgpu numa elf
|
||||
PRE_EXCLUDE_REGEXES ".*"
|
||||
POST_EXCLUDE_REGEXES "system32"
|
||||
RUNTIME DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT HIP
|
||||
@@ -168,6 +183,7 @@ if(NOT APPLE)
|
||||
find_package(Vulkan)
|
||||
if(Vulkan_FOUND)
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-vulkan)
|
||||
target_include_directories(ggml-vulkan PRIVATE ${GGML_INCLUDE_DIRS})
|
||||
install(TARGETS ggml-vulkan
|
||||
RUNTIME_DEPENDENCIES
|
||||
PRE_INCLUDE_REGEXES vulkan
|
||||
@@ -179,7 +195,6 @@ if(NOT APPLE)
|
||||
endif()
|
||||
|
||||
option(MLX_ENGINE "Enable MLX backend" OFF)
|
||||
|
||||
if(MLX_ENGINE)
|
||||
message(STATUS "Setting up MLX (this takes a while...)")
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/x/imagegen/mlx)
|
||||
@@ -187,10 +202,36 @@ if(MLX_ENGINE)
|
||||
# Find CUDA toolkit if MLX is built with CUDA support
|
||||
find_package(CUDAToolkit)
|
||||
|
||||
# Build list of directories for runtime dependency resolution
|
||||
set(MLX_RUNTIME_DIRS ${CUDAToolkit_BIN_DIR} ${CUDAToolkit_BIN_DIR}/x64 ${CUDAToolkit_LIBRARY_DIR})
|
||||
# Add cuDNN bin paths for DLLs (Windows MLX CUDA builds)
|
||||
# CUDNN_ROOT_DIR is the standard CMake variable for cuDNN location
|
||||
if(DEFINED ENV{CUDNN_ROOT_DIR})
|
||||
# cuDNN 9.x has versioned subdirectories under bin/ (e.g., bin/13.0/)
|
||||
file(GLOB CUDNN_BIN_SUBDIRS "$ENV{CUDNN_ROOT_DIR}/bin/*")
|
||||
list(APPEND MLX_RUNTIME_DIRS ${CUDNN_BIN_SUBDIRS})
|
||||
endif()
|
||||
# Add build output directory and MLX dependency build directories
|
||||
list(APPEND MLX_RUNTIME_DIRS ${OLLAMA_BUILD_DIR})
|
||||
# OpenBLAS DLL location (pre-built zip extracts into openblas-src/bin/)
|
||||
list(APPEND MLX_RUNTIME_DIRS ${CMAKE_BINARY_DIR}/_deps/openblas-src/bin)
|
||||
# NCCL: on Linux, if real NCCL is found, cmake bundles libnccl.so via the
|
||||
# regex below. If NCCL is not found, MLX links a static stub (OBJECT lib)
|
||||
# so there is no runtime dependency. This path covers the stub build dir
|
||||
# for windows so we include the DLL in our dependencies.
|
||||
list(APPEND MLX_RUNTIME_DIRS ${CMAKE_BINARY_DIR}/_deps/mlx-build/mlx/distributed/nccl/nccl_stub-prefix/src/nccl_stub-build/Release)
|
||||
|
||||
# Base regexes for runtime dependencies (cross-platform)
|
||||
set(MLX_INCLUDE_REGEXES cublas cublasLt cudart cufft nvrtc nvrtc-builtins cudnn nccl openblas gfortran)
|
||||
# On Windows, also include dl.dll (dlfcn-win32 POSIX emulation layer)
|
||||
if(WIN32)
|
||||
list(APPEND MLX_INCLUDE_REGEXES "^dl\\.dll$")
|
||||
endif()
|
||||
|
||||
install(TARGETS mlx mlxc
|
||||
RUNTIME_DEPENDENCIES
|
||||
DIRECTORIES ${CUDAToolkit_BIN_DIR} ${CUDAToolkit_BIN_DIR}/x64 ${CUDAToolkit_LIBRARY_DIR}
|
||||
PRE_INCLUDE_REGEXES cublas cublasLt cudart nvrtc nvrtc-builtins cudnn nccl openblas gfortran
|
||||
DIRECTORIES ${MLX_RUNTIME_DIRS}
|
||||
PRE_INCLUDE_REGEXES ${MLX_INCLUDE_REGEXES}
|
||||
PRE_EXCLUDE_REGEXES ".*"
|
||||
RUNTIME DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT MLX
|
||||
LIBRARY DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT MLX
|
||||
@@ -205,13 +246,117 @@ if(MLX_ENGINE)
|
||||
COMPONENT MLX)
|
||||
endif()
|
||||
|
||||
# Manually install cudart and cublas since they might not be picked up as direct dependencies
|
||||
# Install headers for NVRTC JIT compilation at runtime.
|
||||
# MLX's own install rules use the default component so they get skipped by
|
||||
# --component MLX. Headers are installed alongside libmlx in OLLAMA_INSTALL_DIR.
|
||||
#
|
||||
# Layout:
|
||||
# ${OLLAMA_INSTALL_DIR}/include/cccl/{cuda,nv}/ — CCCL headers
|
||||
# ${OLLAMA_INSTALL_DIR}/include/*.h — CUDA toolkit headers
|
||||
#
|
||||
# MLX's jit_module.cpp resolves CCCL via
|
||||
# current_binary_dir()[.parent_path()] / "include" / "cccl"
|
||||
# On Linux, MLX's jit_module.cpp resolves CCCL via
|
||||
# current_binary_dir().parent_path() / "include" / "cccl", so we create a
|
||||
# symlink from lib/ollama/include -> ${OLLAMA_RUNNER_DIR}/include
|
||||
# This will need refinement if we add multiple CUDA versions for MLX in the future.
|
||||
# CUDA runtime headers are found via CUDA_PATH env var (set by mlxrunner).
|
||||
if(EXISTS ${CMAKE_BINARY_DIR}/_deps/cccl-src/include/cuda)
|
||||
install(DIRECTORY ${CMAKE_BINARY_DIR}/_deps/cccl-src/include/cuda
|
||||
DESTINATION ${OLLAMA_INSTALL_DIR}/include/cccl
|
||||
COMPONENT MLX)
|
||||
install(DIRECTORY ${CMAKE_BINARY_DIR}/_deps/cccl-src/include/nv
|
||||
DESTINATION ${OLLAMA_INSTALL_DIR}/include/cccl
|
||||
COMPONENT MLX)
|
||||
if(NOT WIN32 AND NOT APPLE)
|
||||
install(CODE "
|
||||
set(_link \"${CMAKE_INSTALL_PREFIX}/lib/ollama/include\")
|
||||
set(_target \"${OLLAMA_RUNNER_DIR}/include\")
|
||||
if(NOT EXISTS \${_link})
|
||||
execute_process(COMMAND \${CMAKE_COMMAND} -E create_symlink \${_target} \${_link})
|
||||
endif()
|
||||
" COMPONENT MLX)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
# Install minimal CUDA toolkit headers needed by MLX JIT kernels.
|
||||
# These are the transitive closure of includes from mlx/backend/cuda/device/*.cuh.
|
||||
# The Go mlxrunner sets CUDA_PATH to OLLAMA_INSTALL_DIR so MLX finds them at
|
||||
# $CUDA_PATH/include/*.h via NVRTC --include-path.
|
||||
if(CUDAToolkit_FOUND)
|
||||
file(GLOB CUDART_LIBS
|
||||
# CUDAToolkit_INCLUDE_DIRS may be a semicolon-separated list
|
||||
# (e.g. ".../include;.../include/cccl"). Find the entry that
|
||||
# contains the CUDA runtime headers we need.
|
||||
set(_cuda_inc "")
|
||||
foreach(_dir ${CUDAToolkit_INCLUDE_DIRS})
|
||||
if(EXISTS "${_dir}/cuda_runtime_api.h")
|
||||
set(_cuda_inc "${_dir}")
|
||||
break()
|
||||
endif()
|
||||
endforeach()
|
||||
if(NOT _cuda_inc)
|
||||
message(WARNING "Could not find cuda_runtime_api.h in CUDAToolkit_INCLUDE_DIRS: ${CUDAToolkit_INCLUDE_DIRS}")
|
||||
else()
|
||||
set(_dst "${OLLAMA_INSTALL_DIR}/include")
|
||||
set(_MLX_JIT_CUDA_HEADERS
|
||||
builtin_types.h
|
||||
cooperative_groups.h
|
||||
cuda_bf16.h
|
||||
cuda_bf16.hpp
|
||||
cuda_device_runtime_api.h
|
||||
cuda_fp16.h
|
||||
cuda_fp16.hpp
|
||||
cuda_fp8.h
|
||||
cuda_fp8.hpp
|
||||
cuda_runtime_api.h
|
||||
device_types.h
|
||||
driver_types.h
|
||||
math_constants.h
|
||||
surface_types.h
|
||||
texture_types.h
|
||||
vector_functions.h
|
||||
vector_functions.hpp
|
||||
vector_types.h
|
||||
)
|
||||
foreach(_hdr ${_MLX_JIT_CUDA_HEADERS})
|
||||
install(FILES "${_cuda_inc}/${_hdr}"
|
||||
DESTINATION ${_dst}
|
||||
COMPONENT MLX)
|
||||
endforeach()
|
||||
# Subdirectory headers
|
||||
install(DIRECTORY "${_cuda_inc}/cooperative_groups"
|
||||
DESTINATION ${_dst}
|
||||
COMPONENT MLX
|
||||
FILES_MATCHING PATTERN "*.h")
|
||||
install(FILES "${_cuda_inc}/crt/host_defines.h"
|
||||
DESTINATION "${_dst}/crt"
|
||||
COMPONENT MLX)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
# On Windows, explicitly install dl.dll (dlfcn-win32 POSIX dlopen emulation)
|
||||
# RUNTIME_DEPENDENCIES auto-excludes it via POST_EXCLUDE_FILES_STRICT because
|
||||
# dlfcn-win32 is a known CMake target with its own install rules (which install
|
||||
# to the wrong destination). We must install it explicitly here.
|
||||
if(WIN32)
|
||||
install(FILES ${OLLAMA_BUILD_DIR}/dl.dll
|
||||
DESTINATION ${OLLAMA_INSTALL_DIR}
|
||||
COMPONENT MLX)
|
||||
endif()
|
||||
|
||||
# Manually install CUDA runtime libraries that MLX loads via dlopen
|
||||
# (not detected by RUNTIME_DEPENDENCIES since they aren't link-time deps)
|
||||
if(CUDAToolkit_FOUND)
|
||||
file(GLOB MLX_CUDA_LIBS
|
||||
"${CUDAToolkit_LIBRARY_DIR}/libcudart.so*"
|
||||
"${CUDAToolkit_LIBRARY_DIR}/libcublas.so*")
|
||||
if(CUDART_LIBS)
|
||||
install(FILES ${CUDART_LIBS}
|
||||
"${CUDAToolkit_LIBRARY_DIR}/libcublas.so*"
|
||||
"${CUDAToolkit_LIBRARY_DIR}/libcublasLt.so*"
|
||||
"${CUDAToolkit_LIBRARY_DIR}/libnvrtc.so*"
|
||||
"${CUDAToolkit_LIBRARY_DIR}/libnvrtc-builtins.so*"
|
||||
"${CUDAToolkit_LIBRARY_DIR}/libcufft.so*"
|
||||
"${CUDAToolkit_LIBRARY_DIR}/libcudnn.so*")
|
||||
if(MLX_CUDA_LIBS)
|
||||
install(FILES ${MLX_CUDA_LIBS}
|
||||
DESTINATION ${OLLAMA_INSTALL_DIR}
|
||||
COMPONENT MLX)
|
||||
endif()
|
||||
|
||||
@@ -77,6 +77,15 @@
|
||||
"OLLAMA_RUNNER_DIR": "rocm"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "ROCm 7",
|
||||
"inherits": [ "ROCm" ],
|
||||
"cacheVariables": {
|
||||
"CMAKE_HIP_FLAGS": "-parallel-jobs=4",
|
||||
"AMDGPU_TARGETS": "gfx942;gfx950;gfx1010;gfx1012;gfx1030;gfx1100;gfx1101;gfx1102;gfx1103;gfx1150;gfx1151;gfx1200;gfx1201;gfx908:xnack-;gfx90a:xnack+;gfx90a:xnack-",
|
||||
"OLLAMA_RUNNER_DIR": "rocm"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Vulkan",
|
||||
"inherits": [ "Default" ],
|
||||
@@ -103,6 +112,7 @@
|
||||
"name": "MLX CUDA 13",
|
||||
"inherits": [ "MLX", "CUDA 13" ],
|
||||
"cacheVariables": {
|
||||
"MLX_CUDA_ARCHITECTURES": "86;89;90;90a;100;103;75-virtual;80-virtual;110-virtual;120-virtual;121-virtual",
|
||||
"OLLAMA_RUNNER_DIR": "mlx_cuda_v13"
|
||||
}
|
||||
}
|
||||
@@ -158,6 +168,11 @@
|
||||
"inherits": [ "ROCm" ],
|
||||
"configurePreset": "ROCm 6"
|
||||
},
|
||||
{
|
||||
"name": "ROCm 7",
|
||||
"inherits": [ "ROCm" ],
|
||||
"configurePreset": "ROCm 7"
|
||||
},
|
||||
{
|
||||
"name": "Vulkan",
|
||||
"targets": [ "ggml-vulkan" ],
|
||||
|
||||
122
Dockerfile
@@ -1,28 +1,23 @@
|
||||
# vim: filetype=dockerfile
|
||||
|
||||
ARG FLAVOR=${TARGETARCH}
|
||||
ARG PARALLEL=8
|
||||
|
||||
ARG ROCMVERSION=6.3.3
|
||||
ARG ROCMVERSION=7.2
|
||||
ARG JETPACK5VERSION=r35.4.1
|
||||
ARG JETPACK6VERSION=r36.4.0
|
||||
ARG CMAKEVERSION=3.31.2
|
||||
ARG NINJAVERSION=1.12.1
|
||||
ARG VULKANVERSION=1.4.321.1
|
||||
|
||||
# Default empty stages for local MLX source overrides.
|
||||
# Override with: docker build --build-context local-mlx=../mlx --build-context local-mlx-c=../mlx-c
|
||||
FROM scratch AS local-mlx
|
||||
FROM scratch AS local-mlx-c
|
||||
|
||||
FROM --platform=linux/amd64 rocm/dev-almalinux-8:${ROCMVERSION}-complete AS base-amd64
|
||||
RUN dnf install -y yum-utils ccache gcc-toolset-11-gcc gcc-toolset-11-gcc-c++ gcc-toolset-11-binutils \
|
||||
&& yum-config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel8/x86_64/cuda-rhel8.repo
|
||||
ENV PATH=/opt/rh/gcc-toolset-11/root/usr/bin:$PATH
|
||||
ARG VULKANVERSION
|
||||
RUN wget https://sdk.lunarg.com/sdk/download/${VULKANVERSION}/linux/vulkansdk-linux-x86_64-${VULKANVERSION}.tar.xz -O /tmp/vulkansdk-linux-x86_64-${VULKANVERSION}.tar.xz \
|
||||
&& tar xvf /tmp/vulkansdk-linux-x86_64-${VULKANVERSION}.tar.xz \
|
||||
&& dnf -y install ninja-build \
|
||||
&& ln -s /usr/bin/python3 /usr/bin/python \
|
||||
&& /${VULKANVERSION}/vulkansdk -j 8 vulkan-headers \
|
||||
&& /${VULKANVERSION}/vulkansdk -j 8 shaderc
|
||||
RUN cp -r /${VULKANVERSION}/x86_64/include/* /usr/local/include/ \
|
||||
&& cp -r /${VULKANVERSION}/x86_64/lib/* /usr/local/lib
|
||||
ENV PATH=/${VULKANVERSION}/x86_64/bin:$PATH
|
||||
|
||||
FROM --platform=linux/arm64 almalinux:8 AS base-arm64
|
||||
# install epel-release for ccache
|
||||
@@ -33,100 +28,119 @@ ENV CC=clang CXX=clang++
|
||||
|
||||
FROM base-${TARGETARCH} AS base
|
||||
ARG CMAKEVERSION
|
||||
ARG NINJAVERSION
|
||||
RUN curl -fsSL https://github.com/Kitware/CMake/releases/download/v${CMAKEVERSION}/cmake-${CMAKEVERSION}-linux-$(uname -m).tar.gz | tar xz -C /usr/local --strip-components 1
|
||||
RUN dnf install -y unzip \
|
||||
&& curl -fsSL -o /tmp/ninja.zip https://github.com/ninja-build/ninja/releases/download/v${NINJAVERSION}/ninja-linux$([ "$(uname -m)" = "aarch64" ] && echo "-aarch64").zip \
|
||||
&& unzip /tmp/ninja.zip -d /usr/local/bin \
|
||||
&& rm /tmp/ninja.zip
|
||||
ENV CMAKE_GENERATOR=Ninja
|
||||
ENV LDFLAGS=-s
|
||||
|
||||
FROM base AS cpu
|
||||
RUN dnf install -y gcc-toolset-11-gcc gcc-toolset-11-gcc-c++
|
||||
ENV PATH=/opt/rh/gcc-toolset-11/root/usr/bin:$PATH
|
||||
ARG PARALLEL
|
||||
COPY CMakeLists.txt CMakePresets.json .
|
||||
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
|
||||
RUN --mount=type=cache,target=/root/.ccache \
|
||||
cmake --preset 'CPU' \
|
||||
&& cmake --build --parallel ${PARALLEL} --preset 'CPU' \
|
||||
&& cmake --install build --component CPU --strip --parallel ${PARALLEL}
|
||||
&& cmake --build --preset 'CPU' -- -l $(nproc) \
|
||||
&& cmake --install build --component CPU --strip
|
||||
|
||||
FROM base AS cuda-11
|
||||
ARG CUDA11VERSION=11.8
|
||||
RUN dnf install -y cuda-toolkit-${CUDA11VERSION//./-}
|
||||
ENV PATH=/usr/local/cuda-11/bin:$PATH
|
||||
ARG PARALLEL
|
||||
COPY CMakeLists.txt CMakePresets.json .
|
||||
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
|
||||
RUN --mount=type=cache,target=/root/.ccache \
|
||||
cmake --preset 'CUDA 11' \
|
||||
&& cmake --build --parallel ${PARALLEL} --preset 'CUDA 11' \
|
||||
&& cmake --install build --component CUDA --strip --parallel ${PARALLEL}
|
||||
&& cmake --build --preset 'CUDA 11' -- -l $(nproc) \
|
||||
&& cmake --install build --component CUDA --strip
|
||||
|
||||
FROM base AS cuda-12
|
||||
ARG CUDA12VERSION=12.8
|
||||
RUN dnf install -y cuda-toolkit-${CUDA12VERSION//./-}
|
||||
ENV PATH=/usr/local/cuda-12/bin:$PATH
|
||||
ARG PARALLEL
|
||||
COPY CMakeLists.txt CMakePresets.json .
|
||||
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
|
||||
RUN --mount=type=cache,target=/root/.ccache \
|
||||
cmake --preset 'CUDA 12' \
|
||||
&& cmake --build --parallel ${PARALLEL} --preset 'CUDA 12' \
|
||||
&& cmake --install build --component CUDA --strip --parallel ${PARALLEL}
|
||||
&& cmake --build --preset 'CUDA 12' -- -l $(nproc) \
|
||||
&& cmake --install build --component CUDA --strip
|
||||
|
||||
|
||||
FROM base AS cuda-13
|
||||
ARG CUDA13VERSION=13.0
|
||||
RUN dnf install -y cuda-toolkit-${CUDA13VERSION//./-}
|
||||
ENV PATH=/usr/local/cuda-13/bin:$PATH
|
||||
ARG PARALLEL
|
||||
COPY CMakeLists.txt CMakePresets.json .
|
||||
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
|
||||
RUN --mount=type=cache,target=/root/.ccache \
|
||||
cmake --preset 'CUDA 13' \
|
||||
&& cmake --build --parallel ${PARALLEL} --preset 'CUDA 13' \
|
||||
&& cmake --install build --component CUDA --strip --parallel ${PARALLEL}
|
||||
&& cmake --build --preset 'CUDA 13' -- -l $(nproc) \
|
||||
&& cmake --install build --component CUDA --strip
|
||||
|
||||
|
||||
FROM base AS rocm-6
|
||||
FROM base AS rocm-7
|
||||
ENV PATH=/opt/rocm/hcc/bin:/opt/rocm/hip/bin:/opt/rocm/bin:/opt/rocm/hcc/bin:$PATH
|
||||
ARG PARALLEL
|
||||
COPY CMakeLists.txt CMakePresets.json .
|
||||
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
|
||||
RUN --mount=type=cache,target=/root/.ccache \
|
||||
cmake --preset 'ROCm 6' \
|
||||
&& cmake --build --parallel ${PARALLEL} --preset 'ROCm 6' \
|
||||
&& cmake --install build --component HIP --strip --parallel ${PARALLEL}
|
||||
cmake --preset 'ROCm 7' \
|
||||
&& cmake --build --preset 'ROCm 7' -- -l $(nproc) \
|
||||
&& cmake --install build --component HIP --strip
|
||||
RUN rm -f dist/lib/ollama/rocm/rocblas/library/*gfx90[06]*
|
||||
|
||||
FROM --platform=linux/arm64 nvcr.io/nvidia/l4t-jetpack:${JETPACK5VERSION} AS jetpack-5
|
||||
ARG CMAKEVERSION
|
||||
RUN apt-get update && apt-get install -y curl ccache \
|
||||
&& curl -fsSL https://github.com/Kitware/CMake/releases/download/v${CMAKEVERSION}/cmake-${CMAKEVERSION}-linux-$(uname -m).tar.gz | tar xz -C /usr/local --strip-components 1
|
||||
ARG NINJAVERSION
|
||||
RUN apt-get update && apt-get install -y curl ccache unzip \
|
||||
&& curl -fsSL https://github.com/Kitware/CMake/releases/download/v${CMAKEVERSION}/cmake-${CMAKEVERSION}-linux-$(uname -m).tar.gz | tar xz -C /usr/local --strip-components 1 \
|
||||
&& curl -fsSL -o /tmp/ninja.zip https://github.com/ninja-build/ninja/releases/download/v${NINJAVERSION}/ninja-linux-aarch64.zip \
|
||||
&& unzip /tmp/ninja.zip -d /usr/local/bin \
|
||||
&& rm /tmp/ninja.zip
|
||||
ENV CMAKE_GENERATOR=Ninja
|
||||
COPY CMakeLists.txt CMakePresets.json .
|
||||
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
|
||||
ARG PARALLEL
|
||||
RUN --mount=type=cache,target=/root/.ccache \
|
||||
cmake --preset 'JetPack 5' \
|
||||
&& cmake --build --parallel ${PARALLEL} --preset 'JetPack 5' \
|
||||
&& cmake --install build --component CUDA --strip --parallel ${PARALLEL}
|
||||
&& cmake --build --preset 'JetPack 5' -- -l $(nproc) \
|
||||
&& cmake --install build --component CUDA --strip
|
||||
|
||||
FROM --platform=linux/arm64 nvcr.io/nvidia/l4t-jetpack:${JETPACK6VERSION} AS jetpack-6
|
||||
ARG CMAKEVERSION
|
||||
RUN apt-get update && apt-get install -y curl ccache \
|
||||
&& curl -fsSL https://github.com/Kitware/CMake/releases/download/v${CMAKEVERSION}/cmake-${CMAKEVERSION}-linux-$(uname -m).tar.gz | tar xz -C /usr/local --strip-components 1
|
||||
ARG NINJAVERSION
|
||||
RUN apt-get update && apt-get install -y curl ccache unzip \
|
||||
&& curl -fsSL https://github.com/Kitware/CMake/releases/download/v${CMAKEVERSION}/cmake-${CMAKEVERSION}-linux-$(uname -m).tar.gz | tar xz -C /usr/local --strip-components 1 \
|
||||
&& curl -fsSL -o /tmp/ninja.zip https://github.com/ninja-build/ninja/releases/download/v${NINJAVERSION}/ninja-linux-aarch64.zip \
|
||||
&& unzip /tmp/ninja.zip -d /usr/local/bin \
|
||||
&& rm /tmp/ninja.zip
|
||||
ENV CMAKE_GENERATOR=Ninja
|
||||
COPY CMakeLists.txt CMakePresets.json .
|
||||
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
|
||||
ARG PARALLEL
|
||||
RUN --mount=type=cache,target=/root/.ccache \
|
||||
cmake --preset 'JetPack 6' \
|
||||
&& cmake --build --parallel ${PARALLEL} --preset 'JetPack 6' \
|
||||
&& cmake --install build --component CUDA --strip --parallel ${PARALLEL}
|
||||
&& cmake --build --preset 'JetPack 6' -- -l $(nproc) \
|
||||
&& cmake --install build --component CUDA --strip
|
||||
|
||||
FROM base AS vulkan
|
||||
ARG VULKANVERSION
|
||||
RUN ln -s /usr/bin/python3 /usr/bin/python \
|
||||
&& wget https://sdk.lunarg.com/sdk/download/${VULKANVERSION}/linux/vulkansdk-linux-x86_64-${VULKANVERSION}.tar.xz -O /tmp/vulkansdk.tar.xz \
|
||||
&& tar xvf /tmp/vulkansdk.tar.xz -C /tmp \
|
||||
&& /tmp/${VULKANVERSION}/vulkansdk -j 8 vulkan-headers \
|
||||
&& /tmp/${VULKANVERSION}/vulkansdk -j 8 shaderc \
|
||||
&& cp -r /tmp/${VULKANVERSION}/x86_64/include/* /usr/local/include/ \
|
||||
&& cp -r /tmp/${VULKANVERSION}/x86_64/lib/* /usr/local/lib \
|
||||
&& cp -r /tmp/${VULKANVERSION}/x86_64/bin/* /usr/local/bin/ \
|
||||
&& rm -rf /tmp/${VULKANVERSION} /tmp/vulkansdk.tar.xz
|
||||
COPY CMakeLists.txt CMakePresets.json .
|
||||
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
|
||||
RUN --mount=type=cache,target=/root/.ccache \
|
||||
cmake --preset 'Vulkan' \
|
||||
&& cmake --build --parallel --preset 'Vulkan' \
|
||||
&& cmake --install build --component Vulkan --strip --parallel 8
|
||||
&& cmake --build --preset 'Vulkan' -- -l $(nproc) \
|
||||
&& cmake --install build --component Vulkan --strip
|
||||
|
||||
FROM base AS mlx
|
||||
ARG CUDA13VERSION=13.0
|
||||
@@ -138,20 +152,27 @@ ENV PATH=/usr/local/cuda-13/bin:$PATH
|
||||
ENV BLAS_INCLUDE_DIRS=/usr/include/openblas
|
||||
ENV LAPACK_INCLUDE_DIRS=/usr/include/openblas
|
||||
ENV CGO_LDFLAGS="-L/usr/local/cuda-13/lib64 -L/usr/local/cuda-13/targets/x86_64-linux/lib/stubs"
|
||||
ARG PARALLEL
|
||||
WORKDIR /go/src/github.com/ollama/ollama
|
||||
COPY CMakeLists.txt CMakePresets.json .
|
||||
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
|
||||
COPY x/imagegen/mlx x/imagegen/mlx
|
||||
COPY go.mod go.sum .
|
||||
COPY MLX_VERSION .
|
||||
COPY MLX_VERSION MLX_C_VERSION .
|
||||
RUN curl -fsSL https://golang.org/dl/go$(awk '/^go/ { print $2 }' go.mod).linux-$(case $(uname -m) in x86_64) echo amd64 ;; aarch64) echo arm64 ;; esac).tar.gz | tar xz -C /usr/local
|
||||
ENV PATH=/usr/local/go/bin:$PATH
|
||||
RUN go mod download
|
||||
RUN --mount=type=cache,target=/root/.ccache \
|
||||
cmake --preset 'MLX CUDA 13' -DBLAS_INCLUDE_DIRS=/usr/include/openblas -DLAPACK_INCLUDE_DIRS=/usr/include/openblas \
|
||||
&& cmake --build --parallel ${PARALLEL} --preset 'MLX CUDA 13' \
|
||||
&& cmake --install build --component MLX --strip --parallel ${PARALLEL}
|
||||
--mount=type=bind,from=local-mlx,target=/tmp/local-mlx \
|
||||
--mount=type=bind,from=local-mlx-c,target=/tmp/local-mlx-c \
|
||||
if [ -f /tmp/local-mlx/CMakeLists.txt ]; then \
|
||||
export OLLAMA_MLX_SOURCE=/tmp/local-mlx; \
|
||||
fi \
|
||||
&& if [ -f /tmp/local-mlx-c/CMakeLists.txt ]; then \
|
||||
export OLLAMA_MLX_C_SOURCE=/tmp/local-mlx-c; \
|
||||
fi \
|
||||
&& cmake --preset 'MLX CUDA 13' -DBLAS_INCLUDE_DIRS=/usr/include/openblas -DLAPACK_INCLUDE_DIRS=/usr/include/openblas \
|
||||
&& cmake --build --preset 'MLX CUDA 13' -- -l $(nproc) \
|
||||
&& cmake --install build --component MLX --strip
|
||||
|
||||
FROM base AS build
|
||||
WORKDIR /go/src/github.com/ollama/ollama
|
||||
@@ -160,16 +181,14 @@ RUN curl -fsSL https://golang.org/dl/go$(awk '/^go/ { print $2 }' go.mod).linux-
|
||||
ENV PATH=/usr/local/go/bin:$PATH
|
||||
RUN go mod download
|
||||
COPY . .
|
||||
# Clone mlx-c headers for CGO (version from MLX_VERSION file)
|
||||
RUN git clone --depth 1 --branch "$(cat MLX_VERSION)" https://github.com/ml-explore/mlx-c.git build/_deps/mlx-c-src
|
||||
ARG GOFLAGS="'-ldflags=-w -s'"
|
||||
ENV CGO_ENABLED=1
|
||||
ARG CGO_CFLAGS
|
||||
ARG CGO_CXXFLAGS
|
||||
ENV CGO_CFLAGS="${CGO_CFLAGS} -I/go/src/github.com/ollama/ollama/build/_deps/mlx-c-src"
|
||||
ENV CGO_CFLAGS="${CGO_CFLAGS}"
|
||||
ENV CGO_CXXFLAGS="${CGO_CXXFLAGS}"
|
||||
RUN --mount=type=cache,target=/root/.cache/go-build \
|
||||
go build -tags mlx -trimpath -buildmode=pie -o /bin/ollama .
|
||||
go build -trimpath -buildmode=pie -o /bin/ollama .
|
||||
|
||||
FROM --platform=linux/amd64 scratch AS amd64
|
||||
# COPY --from=cuda-11 dist/lib/ollama/ /lib/ollama/
|
||||
@@ -186,10 +205,9 @@ COPY --from=jetpack-5 dist/lib/ollama/ /lib/ollama/
|
||||
COPY --from=jetpack-6 dist/lib/ollama/ /lib/ollama/
|
||||
|
||||
FROM scratch AS rocm
|
||||
COPY --from=rocm-6 dist/lib/ollama /lib/ollama
|
||||
COPY --from=rocm-7 dist/lib/ollama /lib/ollama
|
||||
|
||||
FROM ${FLAVOR} AS archive
|
||||
ARG VULKANVERSION
|
||||
COPY --from=cpu dist/lib/ollama /lib/ollama
|
||||
COPY --from=build /bin/ollama /bin/ollama
|
||||
|
||||
|
||||
1
MLX_C_VERSION
Normal file
@@ -0,0 +1 @@
|
||||
0726ca922fc902c4c61ef9c27d94132be418e945
|
||||
@@ -1 +1 @@
|
||||
v0.5.0
|
||||
38ad257088fb2193ad47e527cf6534a689f30943
|
||||
|
||||
@@ -68,7 +68,7 @@ type MessagesRequest struct {
|
||||
Model string `json:"model"`
|
||||
MaxTokens int `json:"max_tokens"`
|
||||
Messages []MessageParam `json:"messages"`
|
||||
System any `json:"system,omitempty"` // string or []ContentBlock
|
||||
System any `json:"system,omitempty"` // string or []map[string]any (JSON-decoded ContentBlock)
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
Temperature *float64 `json:"temperature,omitempty"`
|
||||
TopP *float64 `json:"top_p,omitempty"`
|
||||
@@ -82,8 +82,27 @@ type MessagesRequest struct {
|
||||
|
||||
// MessageParam represents a message in the request
|
||||
type MessageParam struct {
|
||||
Role string `json:"role"` // "user" or "assistant"
|
||||
Content any `json:"content"` // string or []ContentBlock
|
||||
Role string `json:"role"` // "user" or "assistant"
|
||||
Content []ContentBlock `json:"content"` // always []ContentBlock; plain strings are normalized on unmarshal
|
||||
}
|
||||
|
||||
func (m *MessageParam) UnmarshalJSON(data []byte) error {
|
||||
var raw struct {
|
||||
Role string `json:"role"`
|
||||
Content json.RawMessage `json:"content"`
|
||||
}
|
||||
if err := json.Unmarshal(data, &raw); err != nil {
|
||||
return err
|
||||
}
|
||||
m.Role = raw.Role
|
||||
|
||||
var s string
|
||||
if err := json.Unmarshal(raw.Content, &s); err == nil {
|
||||
m.Content = []ContentBlock{{Type: "text", Text: &s}}
|
||||
return nil
|
||||
}
|
||||
|
||||
return json.Unmarshal(raw.Content, &m.Content)
|
||||
}
|
||||
|
||||
// ContentBlock represents a content block in a message.
|
||||
@@ -102,9 +121,9 @@ type ContentBlock struct {
|
||||
Source *ImageSource `json:"source,omitempty"`
|
||||
|
||||
// For tool_use and server_tool_use blocks
|
||||
ID string `json:"id,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
Input any `json:"input,omitempty"`
|
||||
ID string `json:"id,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
Input api.ToolCallFunctionArguments `json:"input,omitzero"`
|
||||
|
||||
// For tool_result and web_search_tool_result blocks
|
||||
ToolUseID string `json:"tool_use_id,omitempty"`
|
||||
@@ -377,178 +396,145 @@ func convertMessage(msg MessageParam) ([]api.Message, error) {
|
||||
var messages []api.Message
|
||||
role := strings.ToLower(msg.Role)
|
||||
|
||||
switch content := msg.Content.(type) {
|
||||
case string:
|
||||
messages = append(messages, api.Message{Role: role, Content: content})
|
||||
var textContent strings.Builder
|
||||
var images []api.ImageData
|
||||
var toolCalls []api.ToolCall
|
||||
var thinking string
|
||||
var toolResults []api.Message
|
||||
textBlocks := 0
|
||||
imageBlocks := 0
|
||||
toolUseBlocks := 0
|
||||
toolResultBlocks := 0
|
||||
serverToolUseBlocks := 0
|
||||
webSearchToolResultBlocks := 0
|
||||
thinkingBlocks := 0
|
||||
unknownBlocks := 0
|
||||
|
||||
case []any:
|
||||
var textContent strings.Builder
|
||||
var images []api.ImageData
|
||||
var toolCalls []api.ToolCall
|
||||
var thinking string
|
||||
var toolResults []api.Message
|
||||
textBlocks := 0
|
||||
imageBlocks := 0
|
||||
toolUseBlocks := 0
|
||||
toolResultBlocks := 0
|
||||
serverToolUseBlocks := 0
|
||||
webSearchToolResultBlocks := 0
|
||||
thinkingBlocks := 0
|
||||
unknownBlocks := 0
|
||||
|
||||
for _, block := range content {
|
||||
blockMap, ok := block.(map[string]any)
|
||||
if !ok {
|
||||
logutil.Trace("anthropic: invalid content block format", "role", role)
|
||||
return nil, errors.New("invalid content block format")
|
||||
for _, block := range msg.Content {
|
||||
switch block.Type {
|
||||
case "text":
|
||||
textBlocks++
|
||||
if block.Text != nil {
|
||||
textContent.WriteString(*block.Text)
|
||||
}
|
||||
|
||||
blockType, _ := blockMap["type"].(string)
|
||||
case "image":
|
||||
imageBlocks++
|
||||
if block.Source == nil {
|
||||
logutil.Trace("anthropic: invalid image source", "role", role)
|
||||
return nil, errors.New("invalid image source")
|
||||
}
|
||||
|
||||
switch blockType {
|
||||
case "text":
|
||||
textBlocks++
|
||||
if text, ok := blockMap["text"].(string); ok {
|
||||
textContent.WriteString(text)
|
||||
if block.Source.Type == "base64" {
|
||||
decoded, err := base64.StdEncoding.DecodeString(block.Source.Data)
|
||||
if err != nil {
|
||||
logutil.Trace("anthropic: invalid base64 image data", "role", role, "error", err)
|
||||
return nil, fmt.Errorf("invalid base64 image data: %w", err)
|
||||
}
|
||||
images = append(images, decoded)
|
||||
} else {
|
||||
logutil.Trace("anthropic: unsupported image source type", "role", role, "source_type", block.Source.Type)
|
||||
return nil, fmt.Errorf("invalid image source type: %s. Only base64 images are supported.", block.Source.Type)
|
||||
}
|
||||
|
||||
case "image":
|
||||
imageBlocks++
|
||||
source, ok := blockMap["source"].(map[string]any)
|
||||
if !ok {
|
||||
logutil.Trace("anthropic: invalid image source", "role", role)
|
||||
return nil, errors.New("invalid image source")
|
||||
}
|
||||
case "tool_use":
|
||||
toolUseBlocks++
|
||||
if block.ID == "" {
|
||||
logutil.Trace("anthropic: tool_use block missing id", "role", role)
|
||||
return nil, errors.New("tool_use block missing required 'id' field")
|
||||
}
|
||||
if block.Name == "" {
|
||||
logutil.Trace("anthropic: tool_use block missing name", "role", role)
|
||||
return nil, errors.New("tool_use block missing required 'name' field")
|
||||
}
|
||||
toolCalls = append(toolCalls, api.ToolCall{
|
||||
ID: block.ID,
|
||||
Function: api.ToolCallFunction{
|
||||
Name: block.Name,
|
||||
Arguments: block.Input,
|
||||
},
|
||||
})
|
||||
|
||||
sourceType, _ := source["type"].(string)
|
||||
if sourceType == "base64" {
|
||||
data, _ := source["data"].(string)
|
||||
decoded, err := base64.StdEncoding.DecodeString(data)
|
||||
if err != nil {
|
||||
logutil.Trace("anthropic: invalid base64 image data", "role", role, "error", err)
|
||||
return nil, fmt.Errorf("invalid base64 image data: %w", err)
|
||||
}
|
||||
images = append(images, decoded)
|
||||
} else {
|
||||
logutil.Trace("anthropic: unsupported image source type", "role", role, "source_type", sourceType)
|
||||
return nil, fmt.Errorf("invalid image source type: %s. Only base64 images are supported.", sourceType)
|
||||
}
|
||||
// URL images would need to be fetched - skip for now
|
||||
case "tool_result":
|
||||
toolResultBlocks++
|
||||
var resultContent string
|
||||
|
||||
case "tool_use":
|
||||
toolUseBlocks++
|
||||
id, ok := blockMap["id"].(string)
|
||||
if !ok {
|
||||
logutil.Trace("anthropic: tool_use block missing id", "role", role)
|
||||
return nil, errors.New("tool_use block missing required 'id' field")
|
||||
}
|
||||
name, ok := blockMap["name"].(string)
|
||||
if !ok {
|
||||
logutil.Trace("anthropic: tool_use block missing name", "role", role)
|
||||
return nil, errors.New("tool_use block missing required 'name' field")
|
||||
}
|
||||
tc := api.ToolCall{
|
||||
ID: id,
|
||||
Function: api.ToolCallFunction{
|
||||
Name: name,
|
||||
},
|
||||
}
|
||||
if input, ok := blockMap["input"].(map[string]any); ok {
|
||||
tc.Function.Arguments = mapToArgs(input)
|
||||
}
|
||||
toolCalls = append(toolCalls, tc)
|
||||
|
||||
case "tool_result":
|
||||
toolResultBlocks++
|
||||
toolUseID, _ := blockMap["tool_use_id"].(string)
|
||||
var resultContent string
|
||||
|
||||
switch c := blockMap["content"].(type) {
|
||||
case string:
|
||||
resultContent = c
|
||||
case []any:
|
||||
for _, cb := range c {
|
||||
if cbMap, ok := cb.(map[string]any); ok {
|
||||
if cbMap["type"] == "text" {
|
||||
if text, ok := cbMap["text"].(string); ok {
|
||||
resultContent += text
|
||||
}
|
||||
switch c := block.Content.(type) {
|
||||
case string:
|
||||
resultContent = c
|
||||
case []any:
|
||||
for _, cb := range c {
|
||||
if cbMap, ok := cb.(map[string]any); ok {
|
||||
if cbMap["type"] == "text" {
|
||||
if text, ok := cbMap["text"].(string); ok {
|
||||
resultContent += text
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
toolResults = append(toolResults, api.Message{
|
||||
Role: "tool",
|
||||
Content: resultContent,
|
||||
ToolCallID: toolUseID,
|
||||
})
|
||||
|
||||
case "thinking":
|
||||
thinkingBlocks++
|
||||
if t, ok := blockMap["thinking"].(string); ok {
|
||||
thinking = t
|
||||
}
|
||||
|
||||
case "server_tool_use":
|
||||
serverToolUseBlocks++
|
||||
id, _ := blockMap["id"].(string)
|
||||
name, _ := blockMap["name"].(string)
|
||||
tc := api.ToolCall{
|
||||
ID: id,
|
||||
Function: api.ToolCallFunction{
|
||||
Name: name,
|
||||
},
|
||||
}
|
||||
if input, ok := blockMap["input"].(map[string]any); ok {
|
||||
tc.Function.Arguments = mapToArgs(input)
|
||||
}
|
||||
toolCalls = append(toolCalls, tc)
|
||||
|
||||
case "web_search_tool_result":
|
||||
webSearchToolResultBlocks++
|
||||
toolUseID, _ := blockMap["tool_use_id"].(string)
|
||||
toolResults = append(toolResults, api.Message{
|
||||
Role: "tool",
|
||||
Content: formatWebSearchToolResultContent(blockMap["content"]),
|
||||
ToolCallID: toolUseID,
|
||||
})
|
||||
default:
|
||||
unknownBlocks++
|
||||
}
|
||||
}
|
||||
|
||||
if textContent.Len() > 0 || len(images) > 0 || len(toolCalls) > 0 || thinking != "" {
|
||||
m := api.Message{
|
||||
Role: role,
|
||||
Content: textContent.String(),
|
||||
Images: images,
|
||||
ToolCalls: toolCalls,
|
||||
Thinking: thinking,
|
||||
toolResults = append(toolResults, api.Message{
|
||||
Role: "tool",
|
||||
Content: resultContent,
|
||||
ToolCallID: block.ToolUseID,
|
||||
})
|
||||
|
||||
case "thinking":
|
||||
thinkingBlocks++
|
||||
if block.Thinking != nil {
|
||||
thinking = *block.Thinking
|
||||
}
|
||||
messages = append(messages, m)
|
||||
|
||||
case "server_tool_use":
|
||||
serverToolUseBlocks++
|
||||
toolCalls = append(toolCalls, api.ToolCall{
|
||||
ID: block.ID,
|
||||
Function: api.ToolCallFunction{
|
||||
Name: block.Name,
|
||||
Arguments: block.Input,
|
||||
},
|
||||
})
|
||||
|
||||
case "web_search_tool_result":
|
||||
webSearchToolResultBlocks++
|
||||
toolResults = append(toolResults, api.Message{
|
||||
Role: "tool",
|
||||
Content: formatWebSearchToolResultContent(block.Content),
|
||||
ToolCallID: block.ToolUseID,
|
||||
})
|
||||
default:
|
||||
unknownBlocks++
|
||||
}
|
||||
|
||||
// Add tool results as separate messages
|
||||
messages = append(messages, toolResults...)
|
||||
logutil.Trace("anthropic: converted block message",
|
||||
"role", role,
|
||||
"blocks", len(content),
|
||||
"text", textBlocks,
|
||||
"image", imageBlocks,
|
||||
"tool_use", toolUseBlocks,
|
||||
"tool_result", toolResultBlocks,
|
||||
"server_tool_use", serverToolUseBlocks,
|
||||
"web_search_result", webSearchToolResultBlocks,
|
||||
"thinking", thinkingBlocks,
|
||||
"unknown", unknownBlocks,
|
||||
"messages", TraceAPIMessages(messages),
|
||||
)
|
||||
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid message content type: %T", content)
|
||||
}
|
||||
|
||||
if textContent.Len() > 0 || len(images) > 0 || len(toolCalls) > 0 || thinking != "" {
|
||||
m := api.Message{
|
||||
Role: role,
|
||||
Content: textContent.String(),
|
||||
Images: images,
|
||||
ToolCalls: toolCalls,
|
||||
Thinking: thinking,
|
||||
}
|
||||
messages = append(messages, m)
|
||||
}
|
||||
|
||||
// Add tool results as separate messages
|
||||
messages = append(messages, toolResults...)
|
||||
logutil.Trace("anthropic: converted block message",
|
||||
"role", role,
|
||||
"blocks", len(msg.Content),
|
||||
"text", textBlocks,
|
||||
"image", imageBlocks,
|
||||
"tool_use", toolUseBlocks,
|
||||
"tool_result", toolResultBlocks,
|
||||
"server_tool_use", serverToolUseBlocks,
|
||||
"web_search_result", webSearchToolResultBlocks,
|
||||
"thinking", thinkingBlocks,
|
||||
"unknown", unknownBlocks,
|
||||
"messages", TraceAPIMessages(messages),
|
||||
)
|
||||
|
||||
return messages, nil
|
||||
}
|
||||
|
||||
@@ -852,6 +838,19 @@ func (c *StreamConverter) Process(r api.ChatResponse) []StreamEvent {
|
||||
continue
|
||||
}
|
||||
|
||||
// Close thinking block if still open (thinking → tool_use without text in between)
|
||||
if c.thinkingStarted && !c.thinkingDone {
|
||||
c.thinkingDone = true
|
||||
events = append(events, StreamEvent{
|
||||
Event: "content_block_stop",
|
||||
Data: ContentBlockStopEvent{
|
||||
Type: "content_block_stop",
|
||||
Index: c.contentIndex,
|
||||
},
|
||||
})
|
||||
c.contentIndex++
|
||||
}
|
||||
|
||||
if c.textStarted {
|
||||
events = append(events, StreamEvent{
|
||||
Event: "content_block_stop",
|
||||
@@ -869,7 +868,6 @@ func (c *StreamConverter) Process(r api.ChatResponse) []StreamEvent {
|
||||
slog.Error("failed to marshal tool arguments", "error", err, "tool_id", tc.ID)
|
||||
continue
|
||||
}
|
||||
|
||||
events = append(events, StreamEvent{
|
||||
Event: "content_block_start",
|
||||
Data: ContentBlockStartEvent{
|
||||
@@ -879,7 +877,7 @@ func (c *StreamConverter) Process(r api.ChatResponse) []StreamEvent {
|
||||
Type: "tool_use",
|
||||
ID: tc.ID,
|
||||
Name: tc.Function.Name,
|
||||
Input: map[string]any{},
|
||||
Input: api.NewToolCallFunctionArguments(),
|
||||
},
|
||||
},
|
||||
})
|
||||
@@ -976,15 +974,6 @@ func ptr(s string) *string {
|
||||
return &s
|
||||
}
|
||||
|
||||
// mapToArgs converts a map to ToolCallFunctionArguments
|
||||
func mapToArgs(m map[string]any) api.ToolCallFunctionArguments {
|
||||
args := api.NewToolCallFunctionArguments()
|
||||
for k, v := range m {
|
||||
args.Set(k, v)
|
||||
}
|
||||
return args
|
||||
}
|
||||
|
||||
// CountTokensRequest represents an Anthropic count_tokens request
|
||||
type CountTokensRequest struct {
|
||||
Model string `json:"model"`
|
||||
@@ -1017,17 +1006,13 @@ func estimateTokens(req CountTokensRequest) int {
|
||||
var totalLen int
|
||||
|
||||
// Count system prompt
|
||||
if req.System != nil {
|
||||
totalLen += countAnyContent(req.System)
|
||||
}
|
||||
totalLen += countAnyContent(req.System)
|
||||
|
||||
// Count messages
|
||||
for _, msg := range req.Messages {
|
||||
// Count role (always present)
|
||||
totalLen += len(msg.Role)
|
||||
// Count content
|
||||
contentLen := countAnyContent(msg.Content)
|
||||
totalLen += contentLen
|
||||
totalLen += countAnyContent(msg.Content)
|
||||
}
|
||||
|
||||
for _, tool := range req.Tools {
|
||||
@@ -1050,12 +1035,25 @@ func countAnyContent(content any) int {
|
||||
switch c := content.(type) {
|
||||
case string:
|
||||
return len(c)
|
||||
case []any:
|
||||
case []ContentBlock:
|
||||
total := 0
|
||||
for _, block := range c {
|
||||
total += countContentBlock(block)
|
||||
}
|
||||
return total
|
||||
case []any:
|
||||
total := 0
|
||||
for _, item := range c {
|
||||
data, err := json.Marshal(item)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
var block ContentBlock
|
||||
if err := json.Unmarshal(data, &block); err == nil {
|
||||
total += countContentBlock(block)
|
||||
}
|
||||
}
|
||||
return total
|
||||
default:
|
||||
if data, err := json.Marshal(content); err == nil {
|
||||
return len(data)
|
||||
@@ -1064,38 +1062,19 @@ func countAnyContent(content any) int {
|
||||
}
|
||||
}
|
||||
|
||||
func countContentBlock(block any) int {
|
||||
blockMap, ok := block.(map[string]any)
|
||||
if !ok {
|
||||
if s, ok := block.(string); ok {
|
||||
return len(s)
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func countContentBlock(block ContentBlock) int {
|
||||
total := 0
|
||||
blockType, _ := blockMap["type"].(string)
|
||||
|
||||
if text, ok := blockMap["text"].(string); ok {
|
||||
total += len(text)
|
||||
if block.Text != nil {
|
||||
total += len(*block.Text)
|
||||
}
|
||||
|
||||
if thinking, ok := blockMap["thinking"].(string); ok {
|
||||
total += len(thinking)
|
||||
if block.Thinking != nil {
|
||||
total += len(*block.Thinking)
|
||||
}
|
||||
|
||||
if blockType == "tool_use" {
|
||||
if data, err := json.Marshal(blockMap); err == nil {
|
||||
if block.Type == "tool_use" || block.Type == "tool_result" {
|
||||
if data, err := json.Marshal(block); err == nil {
|
||||
total += len(data)
|
||||
}
|
||||
}
|
||||
|
||||
if blockType == "tool_result" {
|
||||
if data, err := json.Marshal(blockMap); err == nil {
|
||||
total += len(data)
|
||||
}
|
||||
}
|
||||
|
||||
return total
|
||||
}
|
||||
|
||||
|
||||
@@ -15,11 +15,16 @@ const (
|
||||
testImage = `iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=`
|
||||
)
|
||||
|
||||
// testArgs creates ToolCallFunctionArguments from a map (convenience function for tests)
|
||||
func testArgs(m map[string]any) api.ToolCallFunctionArguments {
|
||||
// textContent is a convenience for constructing []ContentBlock with a single text block in tests.
|
||||
func textContent(s string) []ContentBlock {
|
||||
return []ContentBlock{{Type: "text", Text: &s}}
|
||||
}
|
||||
|
||||
// makeArgs creates ToolCallFunctionArguments from key-value pairs (convenience function for tests)
|
||||
func makeArgs(kvs ...any) api.ToolCallFunctionArguments {
|
||||
args := api.NewToolCallFunctionArguments()
|
||||
for k, v := range m {
|
||||
args.Set(k, v)
|
||||
for i := 0; i < len(kvs)-1; i += 2 {
|
||||
args.Set(kvs[i].(string), kvs[i+1])
|
||||
}
|
||||
return args
|
||||
}
|
||||
@@ -29,7 +34,7 @@ func TestFromMessagesRequest_Basic(t *testing.T) {
|
||||
Model: "test-model",
|
||||
MaxTokens: 1024,
|
||||
Messages: []MessageParam{
|
||||
{Role: "user", Content: "Hello"},
|
||||
{Role: "user", Content: textContent("Hello")},
|
||||
},
|
||||
}
|
||||
|
||||
@@ -61,7 +66,7 @@ func TestFromMessagesRequest_WithSystemPrompt(t *testing.T) {
|
||||
MaxTokens: 1024,
|
||||
System: "You are a helpful assistant.",
|
||||
Messages: []MessageParam{
|
||||
{Role: "user", Content: "Hello"},
|
||||
{Role: "user", Content: textContent("Hello")},
|
||||
},
|
||||
}
|
||||
|
||||
@@ -88,7 +93,7 @@ func TestFromMessagesRequest_WithSystemPromptArray(t *testing.T) {
|
||||
map[string]any{"type": "text", "text": " Be concise."},
|
||||
},
|
||||
Messages: []MessageParam{
|
||||
{Role: "user", Content: "Hello"},
|
||||
{Role: "user", Content: textContent("Hello")},
|
||||
},
|
||||
}
|
||||
|
||||
@@ -113,7 +118,7 @@ func TestFromMessagesRequest_WithOptions(t *testing.T) {
|
||||
req := MessagesRequest{
|
||||
Model: "test-model",
|
||||
MaxTokens: 2048,
|
||||
Messages: []MessageParam{{Role: "user", Content: "Hello"}},
|
||||
Messages: []MessageParam{{Role: "user", Content: textContent("Hello")}},
|
||||
Temperature: &temp,
|
||||
TopP: &topP,
|
||||
TopK: &topK,
|
||||
@@ -148,14 +153,14 @@ func TestFromMessagesRequest_WithImage(t *testing.T) {
|
||||
Messages: []MessageParam{
|
||||
{
|
||||
Role: "user",
|
||||
Content: []any{
|
||||
map[string]any{"type": "text", "text": "What's in this image?"},
|
||||
map[string]any{
|
||||
"type": "image",
|
||||
"source": map[string]any{
|
||||
"type": "base64",
|
||||
"media_type": "image/png",
|
||||
"data": testImage,
|
||||
Content: []ContentBlock{
|
||||
{Type: "text", Text: ptr("What's in this image?")},
|
||||
{
|
||||
Type: "image",
|
||||
Source: &ImageSource{
|
||||
Type: "base64",
|
||||
MediaType: "image/png",
|
||||
Data: testImage,
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -190,15 +195,15 @@ func TestFromMessagesRequest_WithToolUse(t *testing.T) {
|
||||
Model: "test-model",
|
||||
MaxTokens: 1024,
|
||||
Messages: []MessageParam{
|
||||
{Role: "user", Content: "What's the weather in Paris?"},
|
||||
{Role: "user", Content: textContent("What's the weather in Paris?")},
|
||||
{
|
||||
Role: "assistant",
|
||||
Content: []any{
|
||||
map[string]any{
|
||||
"type": "tool_use",
|
||||
"id": "call_123",
|
||||
"name": "get_weather",
|
||||
"input": map[string]any{"location": "Paris"},
|
||||
Content: []ContentBlock{
|
||||
{
|
||||
Type: "tool_use",
|
||||
ID: "call_123",
|
||||
Name: "get_weather",
|
||||
Input: makeArgs("location", "Paris"),
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -234,11 +239,11 @@ func TestFromMessagesRequest_WithToolResult(t *testing.T) {
|
||||
Messages: []MessageParam{
|
||||
{
|
||||
Role: "user",
|
||||
Content: []any{
|
||||
map[string]any{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "call_123",
|
||||
"content": "The weather in Paris is sunny, 22°C",
|
||||
Content: []ContentBlock{
|
||||
{
|
||||
Type: "tool_result",
|
||||
ToolUseID: "call_123",
|
||||
Content: "The weather in Paris is sunny, 22°C",
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -270,7 +275,7 @@ func TestFromMessagesRequest_WithTools(t *testing.T) {
|
||||
req := MessagesRequest{
|
||||
Model: "test-model",
|
||||
MaxTokens: 1024,
|
||||
Messages: []MessageParam{{Role: "user", Content: "Hello"}},
|
||||
Messages: []MessageParam{{Role: "user", Content: textContent("Hello")}},
|
||||
Tools: []Tool{
|
||||
{
|
||||
Name: "get_weather",
|
||||
@@ -305,7 +310,7 @@ func TestFromMessagesRequest_DropsCustomWebSearchWhenBuiltinPresent(t *testing.T
|
||||
req := MessagesRequest{
|
||||
Model: "test-model",
|
||||
MaxTokens: 1024,
|
||||
Messages: []MessageParam{{Role: "user", Content: "Hello"}},
|
||||
Messages: []MessageParam{{Role: "user", Content: textContent("Hello")}},
|
||||
Tools: []Tool{
|
||||
{
|
||||
Type: "web_search_20250305",
|
||||
@@ -346,7 +351,7 @@ func TestFromMessagesRequest_KeepsCustomWebSearchWhenBuiltinAbsent(t *testing.T)
|
||||
req := MessagesRequest{
|
||||
Model: "test-model",
|
||||
MaxTokens: 1024,
|
||||
Messages: []MessageParam{{Role: "user", Content: "Hello"}},
|
||||
Messages: []MessageParam{{Role: "user", Content: textContent("Hello")}},
|
||||
Tools: []Tool{
|
||||
{
|
||||
Type: "custom",
|
||||
@@ -377,7 +382,7 @@ func TestFromMessagesRequest_WithThinking(t *testing.T) {
|
||||
req := MessagesRequest{
|
||||
Model: "test-model",
|
||||
MaxTokens: 1024,
|
||||
Messages: []MessageParam{{Role: "user", Content: "Hello"}},
|
||||
Messages: []MessageParam{{Role: "user", Content: textContent("Hello")}},
|
||||
Thinking: &ThinkingConfig{Type: "enabled", BudgetTokens: 1000},
|
||||
}
|
||||
|
||||
@@ -399,13 +404,13 @@ func TestFromMessagesRequest_ThinkingOnlyBlock(t *testing.T) {
|
||||
Model: "test-model",
|
||||
MaxTokens: 1024,
|
||||
Messages: []MessageParam{
|
||||
{Role: "user", Content: "Hello"},
|
||||
{Role: "user", Content: textContent("Hello")},
|
||||
{
|
||||
Role: "assistant",
|
||||
Content: []any{
|
||||
map[string]any{
|
||||
"type": "thinking",
|
||||
"thinking": "Let me think about this...",
|
||||
Content: []ContentBlock{
|
||||
{
|
||||
Type: "thinking",
|
||||
Thinking: ptr("Let me think about this..."),
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -434,10 +439,10 @@ func TestFromMessagesRequest_ToolUseMissingID(t *testing.T) {
|
||||
Messages: []MessageParam{
|
||||
{
|
||||
Role: "assistant",
|
||||
Content: []any{
|
||||
map[string]any{
|
||||
"type": "tool_use",
|
||||
"name": "get_weather",
|
||||
Content: []ContentBlock{
|
||||
{
|
||||
Type: "tool_use",
|
||||
Name: "get_weather",
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -460,10 +465,10 @@ func TestFromMessagesRequest_ToolUseMissingName(t *testing.T) {
|
||||
Messages: []MessageParam{
|
||||
{
|
||||
Role: "assistant",
|
||||
Content: []any{
|
||||
map[string]any{
|
||||
"type": "tool_use",
|
||||
"id": "call_123",
|
||||
Content: []ContentBlock{
|
||||
{
|
||||
Type: "tool_use",
|
||||
ID: "call_123",
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -483,7 +488,7 @@ func TestFromMessagesRequest_InvalidToolSchema(t *testing.T) {
|
||||
req := MessagesRequest{
|
||||
Model: "test-model",
|
||||
MaxTokens: 1024,
|
||||
Messages: []MessageParam{{Role: "user", Content: "Hello"}},
|
||||
Messages: []MessageParam{{Role: "user", Content: textContent("Hello")}},
|
||||
Tools: []Tool{
|
||||
{
|
||||
Name: "bad_tool",
|
||||
@@ -548,7 +553,7 @@ func TestToMessagesResponse_WithToolCalls(t *testing.T) {
|
||||
ID: "call_123",
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: testArgs(map[string]any{"location": "Paris"}),
|
||||
Arguments: makeArgs("location", "Paris"),
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -760,7 +765,7 @@ func TestStreamConverter_WithToolCalls(t *testing.T) {
|
||||
ID: "call_123",
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: testArgs(map[string]any{"location": "Paris"}),
|
||||
Arguments: makeArgs("location", "Paris"),
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -799,6 +804,107 @@ func TestStreamConverter_WithToolCalls(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestStreamConverter_ThinkingDirectlyFollowedByToolCall verifies that when a
|
||||
// model emits a thinking block followed directly by a tool_use block (with no
|
||||
// text block in between), the streaming converter correctly closes the thinking
|
||||
// block and increments the content index before opening the tool_use block.
|
||||
// Previously, the converter reused contentIndex=0 for the tool_use block,
|
||||
// which caused "Content block not found" errors in clients. See #14816.
|
||||
func TestStreamConverter_ThinkingDirectlyFollowedByToolCall(t *testing.T) {
|
||||
conv := NewStreamConverter("msg_123", "test-model", 0)
|
||||
|
||||
// First chunk: thinking content (no text)
|
||||
resp1 := api.ChatResponse{
|
||||
Model: "test-model",
|
||||
Message: api.Message{
|
||||
Role: "assistant",
|
||||
Thinking: "I should call the tool.",
|
||||
},
|
||||
}
|
||||
events1 := conv.Process(resp1)
|
||||
|
||||
// Should have: message_start, content_block_start(thinking), content_block_delta(thinking)
|
||||
if len(events1) < 3 {
|
||||
t.Fatalf("expected at least 3 events for thinking chunk, got %d", len(events1))
|
||||
}
|
||||
if events1[0].Event != "message_start" {
|
||||
t.Errorf("expected first event 'message_start', got %q", events1[0].Event)
|
||||
}
|
||||
thinkingStart, ok := events1[1].Data.(ContentBlockStartEvent)
|
||||
if !ok || thinkingStart.ContentBlock.Type != "thinking" {
|
||||
t.Errorf("expected content_block_start(thinking) as second event, got %+v", events1[1])
|
||||
}
|
||||
if thinkingStart.Index != 0 {
|
||||
t.Errorf("expected thinking block at index 0, got %d", thinkingStart.Index)
|
||||
}
|
||||
|
||||
// Second chunk: tool call (no text between thinking and tool)
|
||||
resp2 := api.ChatResponse{
|
||||
Model: "test-model",
|
||||
Message: api.Message{
|
||||
Role: "assistant",
|
||||
ToolCalls: []api.ToolCall{
|
||||
{
|
||||
ID: "call_abc",
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "ask_user",
|
||||
Arguments: makeArgs("question", "cats or dogs?"),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
Done: true,
|
||||
DoneReason: "stop",
|
||||
Metrics: api.Metrics{PromptEvalCount: 10, EvalCount: 5},
|
||||
}
|
||||
events2 := conv.Process(resp2)
|
||||
|
||||
// Expect: content_block_stop(index=0), content_block_start(tool_use, index=1),
|
||||
// content_block_delta(input_json_delta, index=1), content_block_stop(index=1),
|
||||
// message_delta, message_stop
|
||||
var thinkingStop, toolStart, toolDelta, toolStop *StreamEvent
|
||||
for i := range events2 {
|
||||
e := &events2[i]
|
||||
switch e.Event {
|
||||
case "content_block_stop":
|
||||
if stop, ok := e.Data.(ContentBlockStopEvent); ok {
|
||||
if stop.Index == 0 && thinkingStop == nil {
|
||||
thinkingStop = e
|
||||
} else if stop.Index == 1 {
|
||||
toolStop = e
|
||||
}
|
||||
}
|
||||
case "content_block_start":
|
||||
if start, ok := e.Data.(ContentBlockStartEvent); ok && start.ContentBlock.Type == "tool_use" {
|
||||
toolStart = e
|
||||
}
|
||||
case "content_block_delta":
|
||||
if delta, ok := e.Data.(ContentBlockDeltaEvent); ok && delta.Delta.Type == "input_json_delta" {
|
||||
toolDelta = e
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if thinkingStop == nil {
|
||||
t.Error("expected content_block_stop for thinking block (index 0)")
|
||||
}
|
||||
if toolStart == nil {
|
||||
t.Fatal("expected content_block_start for tool_use block")
|
||||
}
|
||||
if start, ok := toolStart.Data.(ContentBlockStartEvent); !ok || start.Index != 1 {
|
||||
t.Errorf("expected tool_use block at index 1, got %+v", toolStart.Data)
|
||||
}
|
||||
if toolDelta == nil {
|
||||
t.Fatal("expected input_json_delta event for tool call")
|
||||
}
|
||||
if delta, ok := toolDelta.Data.(ContentBlockDeltaEvent); !ok || delta.Index != 1 {
|
||||
t.Errorf("expected tool delta at index 1, got %+v", toolDelta.Data)
|
||||
}
|
||||
if toolStop == nil {
|
||||
t.Error("expected content_block_stop for tool_use block (index 1)")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamConverter_ToolCallWithUnmarshalableArgs(t *testing.T) {
|
||||
// Test that unmarshalable arguments (like channels) are handled gracefully
|
||||
// and don't cause a panic or corrupt stream
|
||||
@@ -864,7 +970,7 @@ func TestStreamConverter_MultipleToolCallsWithMixedValidity(t *testing.T) {
|
||||
ID: "call_good",
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "good_function",
|
||||
Arguments: testArgs(map[string]any{"location": "Paris"}),
|
||||
Arguments: makeArgs("location", "Paris"),
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -966,6 +1072,57 @@ func TestContentBlockJSON_EmptyFieldsPresent(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestContentBlockJSON_NonToolBlocksDoNotIncludeInput(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
block ContentBlock
|
||||
}{
|
||||
{
|
||||
name: "text block",
|
||||
block: ContentBlock{
|
||||
Type: "text",
|
||||
Text: ptr("hello"),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "thinking block",
|
||||
block: ContentBlock{
|
||||
Type: "thinking",
|
||||
Thinking: ptr("let me think"),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "image block",
|
||||
block: ContentBlock{
|
||||
Type: "image",
|
||||
Source: &ImageSource{
|
||||
Type: "base64",
|
||||
MediaType: "image/png",
|
||||
Data: testImage,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
data, err := json.Marshal(tt.block)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to marshal: %v", err)
|
||||
}
|
||||
|
||||
var result map[string]any
|
||||
if err := json.Unmarshal(data, &result); err != nil {
|
||||
t.Fatalf("failed to unmarshal: %v", err)
|
||||
}
|
||||
|
||||
if _, ok := result["input"]; ok {
|
||||
t.Fatalf("unexpected input field in non-tool block JSON: %s", string(data))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamConverter_ContentBlockStartIncludesEmptyFields(t *testing.T) {
|
||||
t.Run("text block start includes empty text", func(t *testing.T) {
|
||||
conv := NewStreamConverter("msg_123", "test-model", 0)
|
||||
@@ -986,7 +1143,9 @@ func TestStreamConverter_ContentBlockStartIncludesEmptyFields(t *testing.T) {
|
||||
// Marshal and verify the text field is present
|
||||
data, _ := json.Marshal(start)
|
||||
var result map[string]any
|
||||
json.Unmarshal(data, &result)
|
||||
if err := json.Unmarshal(data, &result); err != nil {
|
||||
t.Fatalf("failed to unmarshal content_block_start JSON: %v", err)
|
||||
}
|
||||
cb := result["content_block"].(map[string]any)
|
||||
if _, ok := cb["text"]; !ok {
|
||||
t.Error("content_block_start for text should include 'text' field")
|
||||
@@ -1033,13 +1192,71 @@ func TestStreamConverter_ContentBlockStartIncludesEmptyFields(t *testing.T) {
|
||||
t.Error("expected thinking content_block_start event")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("tool_use block start includes empty input object", func(t *testing.T) {
|
||||
conv := NewStreamConverter("msg_123", "test-model", 0)
|
||||
|
||||
resp := api.ChatResponse{
|
||||
Model: "test-model",
|
||||
Message: api.Message{
|
||||
Role: "assistant",
|
||||
ToolCalls: []api.ToolCall{
|
||||
{
|
||||
ID: "call_123",
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: makeArgs("location", "Paris"),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
events := conv.Process(resp)
|
||||
|
||||
var foundToolStart bool
|
||||
for _, e := range events {
|
||||
if e.Event == "content_block_start" {
|
||||
if start, ok := e.Data.(ContentBlockStartEvent); ok {
|
||||
if start.ContentBlock.Type == "tool_use" {
|
||||
foundToolStart = true
|
||||
if start.ContentBlock.Input.Len() != 0 {
|
||||
t.Errorf("expected empty input object, got len=%d", start.ContentBlock.Input.Len())
|
||||
}
|
||||
|
||||
data, _ := json.Marshal(start)
|
||||
var result map[string]any
|
||||
json.Unmarshal(data, &result)
|
||||
cb := result["content_block"].(map[string]any)
|
||||
input, ok := cb["input"]
|
||||
if !ok {
|
||||
t.Error("content_block_start for tool_use should include 'input' field")
|
||||
continue
|
||||
}
|
||||
inputMap, ok := input.(map[string]any)
|
||||
if !ok {
|
||||
t.Errorf("input field should be an object, got %T", input)
|
||||
continue
|
||||
}
|
||||
if len(inputMap) != 0 {
|
||||
t.Errorf("expected empty input object in content_block_start, got %v", inputMap)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !foundToolStart {
|
||||
t.Error("expected tool_use content_block_start event")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestEstimateTokens_SimpleMessage(t *testing.T) {
|
||||
req := CountTokensRequest{
|
||||
Model: "test-model",
|
||||
Messages: []MessageParam{
|
||||
{Role: "user", Content: "Hello, world!"},
|
||||
{Role: "user", Content: textContent("Hello, world!")},
|
||||
},
|
||||
}
|
||||
|
||||
@@ -1060,7 +1277,7 @@ func TestEstimateTokens_WithSystemPrompt(t *testing.T) {
|
||||
Model: "test-model",
|
||||
System: "You are a helpful assistant.",
|
||||
Messages: []MessageParam{
|
||||
{Role: "user", Content: "Hello"},
|
||||
{Role: "user", Content: textContent("Hello")},
|
||||
},
|
||||
}
|
||||
|
||||
@@ -1076,7 +1293,7 @@ func TestEstimateTokens_WithTools(t *testing.T) {
|
||||
req := CountTokensRequest{
|
||||
Model: "test-model",
|
||||
Messages: []MessageParam{
|
||||
{Role: "user", Content: "What's the weather?"},
|
||||
{Role: "user", Content: textContent("What's the weather?")},
|
||||
},
|
||||
Tools: []Tool{
|
||||
{
|
||||
@@ -1099,17 +1316,17 @@ func TestEstimateTokens_WithThinking(t *testing.T) {
|
||||
req := CountTokensRequest{
|
||||
Model: "test-model",
|
||||
Messages: []MessageParam{
|
||||
{Role: "user", Content: "Hello"},
|
||||
{Role: "user", Content: textContent("Hello")},
|
||||
{
|
||||
Role: "assistant",
|
||||
Content: []any{
|
||||
map[string]any{
|
||||
"type": "thinking",
|
||||
"thinking": "Let me think about this carefully...",
|
||||
Content: []ContentBlock{
|
||||
{
|
||||
Type: "thinking",
|
||||
Thinking: ptr("Let me think about this carefully..."),
|
||||
},
|
||||
map[string]any{
|
||||
"type": "text",
|
||||
"text": "Here is my response.",
|
||||
{
|
||||
Type: "text",
|
||||
Text: ptr("Here is my response."),
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -1207,12 +1424,12 @@ func TestConvertTool_RegularTool(t *testing.T) {
|
||||
func TestConvertMessage_ServerToolUse(t *testing.T) {
|
||||
msg := MessageParam{
|
||||
Role: "assistant",
|
||||
Content: []any{
|
||||
map[string]any{
|
||||
"type": "server_tool_use",
|
||||
"id": "srvtoolu_123",
|
||||
"name": "web_search",
|
||||
"input": map[string]any{"query": "test query"},
|
||||
Content: []ContentBlock{
|
||||
{
|
||||
Type: "server_tool_use",
|
||||
ID: "srvtoolu_123",
|
||||
Name: "web_search",
|
||||
Input: makeArgs("query", "test query"),
|
||||
},
|
||||
},
|
||||
}
|
||||
@@ -1243,11 +1460,11 @@ func TestConvertMessage_ServerToolUse(t *testing.T) {
|
||||
func TestConvertMessage_WebSearchToolResult(t *testing.T) {
|
||||
msg := MessageParam{
|
||||
Role: "user",
|
||||
Content: []any{
|
||||
map[string]any{
|
||||
"type": "web_search_tool_result",
|
||||
"tool_use_id": "srvtoolu_123",
|
||||
"content": []any{
|
||||
Content: []ContentBlock{
|
||||
{
|
||||
Type: "web_search_tool_result",
|
||||
ToolUseID: "srvtoolu_123",
|
||||
Content: []any{
|
||||
map[string]any{
|
||||
"type": "web_search_result",
|
||||
"title": "Test Result",
|
||||
@@ -1284,11 +1501,11 @@ func TestConvertMessage_WebSearchToolResult(t *testing.T) {
|
||||
func TestConvertMessage_WebSearchToolResultEmptyStillCreatesToolMessage(t *testing.T) {
|
||||
msg := MessageParam{
|
||||
Role: "user",
|
||||
Content: []any{
|
||||
map[string]any{
|
||||
"type": "web_search_tool_result",
|
||||
"tool_use_id": "srvtoolu_empty",
|
||||
"content": []any{},
|
||||
Content: []ContentBlock{
|
||||
{
|
||||
Type: "web_search_tool_result",
|
||||
ToolUseID: "srvtoolu_empty",
|
||||
Content: []any{},
|
||||
},
|
||||
},
|
||||
}
|
||||
@@ -1315,11 +1532,11 @@ func TestConvertMessage_WebSearchToolResultEmptyStillCreatesToolMessage(t *testi
|
||||
func TestConvertMessage_WebSearchToolResultErrorStillCreatesToolMessage(t *testing.T) {
|
||||
msg := MessageParam{
|
||||
Role: "user",
|
||||
Content: []any{
|
||||
map[string]any{
|
||||
"type": "web_search_tool_result",
|
||||
"tool_use_id": "srvtoolu_error",
|
||||
"content": map[string]any{
|
||||
Content: []ContentBlock{
|
||||
{
|
||||
Type: "web_search_tool_result",
|
||||
ToolUseID: "srvtoolu_error",
|
||||
Content: map[string]any{
|
||||
"type": "web_search_tool_result_error",
|
||||
"error_code": "max_uses_exceeded",
|
||||
},
|
||||
|
||||
@@ -476,25 +476,3 @@ func (c *Client) Whoami(ctx context.Context) (*UserResponse, error) {
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// AliasRequest is the request body for creating or updating a model alias.
|
||||
type AliasRequest struct {
|
||||
Alias string `json:"alias"`
|
||||
Target string `json:"target"`
|
||||
PrefixMatching bool `json:"prefix_matching,omitempty"`
|
||||
}
|
||||
|
||||
// SetAliasExperimental creates or updates a model alias via the experimental aliases API.
|
||||
func (c *Client) SetAliasExperimental(ctx context.Context, req *AliasRequest) error {
|
||||
return c.do(ctx, http.MethodPost, "/api/experimental/aliases", req, nil)
|
||||
}
|
||||
|
||||
// AliasDeleteRequest is the request body for deleting a model alias.
|
||||
type AliasDeleteRequest struct {
|
||||
Alias string `json:"alias"`
|
||||
}
|
||||
|
||||
// DeleteAliasExperimental deletes a model alias via the experimental aliases API.
|
||||
func (c *Client) DeleteAliasExperimental(ctx context.Context, req *AliasDeleteRequest) error {
|
||||
return c.do(ctx, http.MethodDelete, "/api/experimental/aliases", req, nil)
|
||||
}
|
||||
|
||||
@@ -436,6 +436,7 @@ type ToolProperty struct {
|
||||
Description string `json:"description,omitempty"`
|
||||
Enum []any `json:"enum,omitempty"`
|
||||
Properties *ToolPropertiesMap `json:"properties,omitempty"`
|
||||
Required []string `json:"required,omitempty"`
|
||||
}
|
||||
|
||||
// ToTypeScriptType converts a ToolProperty to a TypeScript type string
|
||||
|
||||
@@ -550,14 +550,12 @@ export class Error {
|
||||
}
|
||||
}
|
||||
export class ModelUpstreamResponse {
|
||||
digest?: string;
|
||||
pushTime: number;
|
||||
stale: boolean;
|
||||
error?: string;
|
||||
|
||||
constructor(source: any = {}) {
|
||||
if ('string' === typeof source) source = JSON.parse(source);
|
||||
this.digest = source["digest"];
|
||||
this.pushTime = source["pushTime"];
|
||||
this.stale = source["stale"];
|
||||
this.error = source["error"];
|
||||
}
|
||||
}
|
||||
|
||||
@@ -161,7 +161,7 @@ export async function getModels(query?: string): Promise<Model[]> {
|
||||
// Add query if it's in the registry and not already in the list
|
||||
if (!exactMatch) {
|
||||
const result = await getModelUpstreamInfo(new Model({ model: query }));
|
||||
const existsUpstream = !!result.digest && !result.error;
|
||||
const existsUpstream = result.exists;
|
||||
if (existsUpstream) {
|
||||
filteredModels.push(new Model({ model: query }));
|
||||
}
|
||||
@@ -339,7 +339,7 @@ export async function deleteChat(chatId: string): Promise<void> {
|
||||
// Get upstream information for model staleness checking
|
||||
export async function getModelUpstreamInfo(
|
||||
model: Model,
|
||||
): Promise<{ digest?: string; pushTime: number; error?: string }> {
|
||||
): Promise<{ stale: boolean; exists: boolean; error?: string }> {
|
||||
try {
|
||||
const response = await fetch(`${API_BASE}/api/v1/model/upstream`, {
|
||||
method: "POST",
|
||||
@@ -353,22 +353,22 @@ export async function getModelUpstreamInfo(
|
||||
|
||||
if (!response.ok) {
|
||||
console.warn(
|
||||
`Failed to check upstream digest for ${model.model}: ${response.status}`,
|
||||
`Failed to check upstream for ${model.model}: ${response.status}`,
|
||||
);
|
||||
return { pushTime: 0 };
|
||||
return { stale: false, exists: false };
|
||||
}
|
||||
|
||||
const data = await response.json();
|
||||
|
||||
if (data.error) {
|
||||
console.warn(`Upstream digest check: ${data.error}`);
|
||||
return { error: data.error, pushTime: 0 };
|
||||
console.warn(`Upstream check: ${data.error}`);
|
||||
return { stale: false, exists: false, error: data.error };
|
||||
}
|
||||
|
||||
return { digest: data.digest, pushTime: data.pushTime || 0 };
|
||||
return { stale: !!data.stale, exists: true };
|
||||
} catch (error) {
|
||||
console.warn(`Error checking model staleness:`, error);
|
||||
return { pushTime: 0 };
|
||||
return { stale: false, exists: false };
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -61,24 +61,7 @@ export const ModelPicker = forwardRef<
|
||||
try {
|
||||
const upstreamInfo = await getModelUpstreamInfo(model);
|
||||
|
||||
// Compare local digest with upstream digest
|
||||
let isStale =
|
||||
model.digest &&
|
||||
upstreamInfo.digest &&
|
||||
model.digest !== upstreamInfo.digest;
|
||||
|
||||
// If the model has a modified time and upstream has a push time,
|
||||
// check if the model was modified after the push time - if so, it's not stale
|
||||
if (isStale && model.modified_at && upstreamInfo.pushTime > 0) {
|
||||
const modifiedAtTime =
|
||||
new Date(model.modified_at as string | number | Date).getTime() /
|
||||
1000;
|
||||
if (modifiedAtTime > upstreamInfo.pushTime) {
|
||||
isStale = false;
|
||||
}
|
||||
}
|
||||
|
||||
if (isStale) {
|
||||
if (upstreamInfo.stale) {
|
||||
const currentStaleModels =
|
||||
queryClient.getQueryData<Map<string, boolean>>(["staleModels"]) ||
|
||||
new Map();
|
||||
|
||||
@@ -214,6 +214,7 @@ export default function Settings() {
|
||||
Agent: false,
|
||||
Tools: false,
|
||||
ContextLength: 0,
|
||||
AutoUpdateEnabled: true,
|
||||
});
|
||||
updateSettingsMutation.mutate(defaultSettings);
|
||||
}
|
||||
|
||||
@@ -133,9 +133,8 @@ type Error struct {
|
||||
}
|
||||
|
||||
type ModelUpstreamResponse struct {
|
||||
Digest string `json:"digest,omitempty"`
|
||||
PushTime int64 `json:"pushTime"`
|
||||
Error string `json:"error,omitempty"`
|
||||
Stale bool `json:"stale"`
|
||||
Error string `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
// Serializable data for the browser state
|
||||
|
||||
20
app/ui/ui.go
@@ -32,6 +32,7 @@ import (
|
||||
"github.com/ollama/ollama/app/version"
|
||||
ollamaAuth "github.com/ollama/ollama/auth"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
"github.com/ollama/ollama/manifest"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
_ "github.com/tkrajina/typescriptify-golang-structs/typescriptify"
|
||||
)
|
||||
@@ -155,7 +156,7 @@ func (s *Server) ollamaProxy() http.Handler {
|
||||
return
|
||||
}
|
||||
|
||||
target := envconfig.Host()
|
||||
target := envconfig.ConnectableHost()
|
||||
s.log().Info("configuring ollama proxy", "target", target.String())
|
||||
|
||||
newProxy := httputil.NewSingleHostReverseProxy(target)
|
||||
@@ -193,7 +194,7 @@ func (s *Server) Handler() http.Handler {
|
||||
if CORS() {
|
||||
w.Header().Set("Access-Control-Allow-Origin", "*")
|
||||
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
|
||||
w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization, X-Requested-With")
|
||||
w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization, User-Agent, Accept, X-Requested-With")
|
||||
w.Header().Set("Access-Control-Allow-Credentials", "true")
|
||||
|
||||
// Handle preflight requests
|
||||
@@ -318,7 +319,7 @@ func (s *Server) handleError(w http.ResponseWriter, e error) {
|
||||
if CORS() {
|
||||
w.Header().Set("Access-Control-Allow-Origin", "*")
|
||||
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
|
||||
w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization, X-Requested-With")
|
||||
w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization, User-Agent, Accept, X-Requested-With")
|
||||
w.Header().Set("Access-Control-Allow-Credentials", "true")
|
||||
}
|
||||
|
||||
@@ -1572,9 +1573,18 @@ func (s *Server) modelUpstream(w http.ResponseWriter, r *http.Request) error {
|
||||
return json.NewEncoder(w).Encode(response)
|
||||
}
|
||||
|
||||
n := model.ParseName(req.Model)
|
||||
stale := true
|
||||
if m, err := manifest.ParseNamedManifest(n); err == nil {
|
||||
if m.Digest() == digest {
|
||||
stale = false
|
||||
} else if pushTime > 0 && m.FileInfo().ModTime().Unix() >= pushTime {
|
||||
stale = false
|
||||
}
|
||||
}
|
||||
|
||||
response := responses.ModelUpstreamResponse{
|
||||
Digest: digest,
|
||||
PushTime: pushTime,
|
||||
Stale: stale,
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
216
cmd/audio.go
Normal file
@@ -0,0 +1,216 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Audio capture parameters shared by every platform backend.
const (
	audioSampleRate = 16000 // Hz
	audioChannels   = 1     // mono
	audioFrameSize  = 1024  // samples delivered per platform callback
)
|
||||
|
||||
// AudioRecorder captures audio from the default microphone.
|
||||
// Platform-specific capture is provided by audioStream (audio_darwin.go, etc.).
|
||||
type AudioRecorder struct {
|
||||
stream audioStream
|
||||
mu sync.Mutex
|
||||
samples []float32
|
||||
started time.Time
|
||||
MaxChunkSeconds int // hard split limit in seconds; 0 means use default
|
||||
}
|
||||
|
||||
// audioStream abstracts platform-specific microphone capture.
type audioStream interface {
	// Start begins capture; samples arrive via the callback.
	Start(callback func(samples []float32)) error

	// Stop halts capture and releases platform resources.
	Stop() error
}
|
||||
|
||||
// NewAudioRecorder creates a recorder ready to capture from the default mic.
|
||||
func NewAudioRecorder() (*AudioRecorder, error) {
|
||||
stream, err := newAudioStream(audioSampleRate, audioChannels, audioFrameSize)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &AudioRecorder{stream: stream}, nil
|
||||
}
|
||||
|
||||
// Start begins capturing audio from the microphone.
|
||||
func (r *AudioRecorder) Start() error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
r.samples = make([]float32, 0, audioSampleRate*30) // preallocate ~30s
|
||||
r.started = time.Now()
|
||||
|
||||
return r.stream.Start(func(samples []float32) {
|
||||
r.mu.Lock()
|
||||
r.samples = append(r.samples, samples...)
|
||||
r.mu.Unlock()
|
||||
})
|
||||
}
|
||||
|
||||
// Stop ends the recording and returns the duration.
|
||||
func (r *AudioRecorder) Stop() (time.Duration, error) {
|
||||
r.mu.Lock()
|
||||
dur := time.Since(r.started)
|
||||
r.mu.Unlock()
|
||||
|
||||
if r.stream != nil {
|
||||
r.stream.Stop()
|
||||
}
|
||||
|
||||
return dur, nil
|
||||
}
|
||||
|
||||
// Duration returns how long the current recording has been running.
|
||||
func (r *AudioRecorder) Duration() time.Duration {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
if r.started.IsZero() {
|
||||
return 0
|
||||
}
|
||||
return time.Since(r.started)
|
||||
}
|
||||
|
||||
// Chunking constants for live transcription.
|
||||
const (
|
||||
chunkTargetSamples = 8 * audioSampleRate // 8s — start yielding when silence found
|
||||
chunkMinSamples = 5 * audioSampleRate // start scanning for silence at 5s
|
||||
defaultMaxAudioSeconds = 28 // default hard split (just under typical 30s model cap)
|
||||
silenceWindow = 800 // 50ms RMS window
|
||||
)
|
||||
|
||||
func (r *AudioRecorder) maxChunk() int {
|
||||
if r.MaxChunkSeconds > 0 {
|
||||
return r.MaxChunkSeconds * audioSampleRate
|
||||
}
|
||||
return defaultMaxAudioSeconds * audioSampleRate
|
||||
}
|
||||
|
||||
// TakeChunk checks if there are enough accumulated samples to yield a chunk.
|
||||
// If so, it splits at the best silence boundary, removes the consumed samples
|
||||
// from the buffer, and returns the chunk as WAV bytes. Returns nil if not enough
|
||||
// audio has accumulated yet.
|
||||
func (r *AudioRecorder) TakeChunk() []byte {
|
||||
r.mu.Lock()
|
||||
n := len(r.samples)
|
||||
if n < chunkMinSamples {
|
||||
r.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
maxSamples := r.maxChunk()
|
||||
|
||||
if n < chunkTargetSamples && n < maxSamples {
|
||||
r.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
limit := n
|
||||
if limit > maxSamples {
|
||||
limit = maxSamples
|
||||
}
|
||||
|
||||
splitAt := limit
|
||||
bestEnergy := float64(1e30)
|
||||
|
||||
scanStart := limit - silenceWindow
|
||||
scanEnd := chunkMinSamples
|
||||
for pos := scanStart; pos >= scanEnd; pos -= silenceWindow / 2 {
|
||||
end := pos + silenceWindow
|
||||
if end > n {
|
||||
end = n
|
||||
}
|
||||
var sumSq float64
|
||||
for _, s := range r.samples[pos:end] {
|
||||
sumSq += float64(s) * float64(s)
|
||||
}
|
||||
rms := sumSq / float64(end-pos)
|
||||
if rms < bestEnergy {
|
||||
bestEnergy = rms
|
||||
splitAt = pos + silenceWindow/2
|
||||
}
|
||||
}
|
||||
|
||||
chunk := make([]float32, splitAt)
|
||||
copy(chunk, r.samples[:splitAt])
|
||||
remaining := make([]float32, n-splitAt)
|
||||
copy(remaining, r.samples[splitAt:])
|
||||
r.samples = remaining
|
||||
r.mu.Unlock()
|
||||
|
||||
return encodeWAV(chunk, audioSampleRate, audioChannels)
|
||||
}
|
||||
|
||||
// FlushWAV returns any remaining samples as WAV, clearing the buffer.
|
||||
func (r *AudioRecorder) FlushWAV() []byte {
|
||||
r.mu.Lock()
|
||||
samples := r.samples
|
||||
r.samples = nil
|
||||
r.mu.Unlock()
|
||||
|
||||
if len(samples) == 0 {
|
||||
return nil
|
||||
}
|
||||
return encodeWAV(samples, audioSampleRate, audioChannels)
|
||||
}
|
||||
|
||||
// WAV encodes the captured samples as a WAV file in memory.
|
||||
func (r *AudioRecorder) WAV() ([]byte, error) {
|
||||
r.mu.Lock()
|
||||
samples := make([]float32, len(r.samples))
|
||||
copy(samples, r.samples)
|
||||
r.mu.Unlock()
|
||||
|
||||
if len(samples) == 0 {
|
||||
return nil, errNoAudio
|
||||
}
|
||||
|
||||
return encodeWAV(samples, audioSampleRate, audioChannels), nil
|
||||
}
|
||||
|
||||
// encodeWAV serializes float32 samples as a 16-bit PCM WAV file in memory.
// Samples outside [-1, 1] are clamped before conversion to int16.
func encodeWAV(samples []float32, sampleRate, channels int) []byte {
	const (
		headerSize    = 44
		bitsPerSample = 16
	)
	blockAlign := channels * bitsPerSample / 8
	byteRate := sampleRate * blockAlign
	dataSize := len(samples) * blockAlign

	buf := make([]byte, headerSize+dataSize)

	// RIFF container header.
	copy(buf[0:4], "RIFF")
	binary.LittleEndian.PutUint32(buf[4:8], uint32(36+dataSize))
	copy(buf[8:12], "WAVE")

	// "fmt " chunk: integer PCM description.
	copy(buf[12:16], "fmt ")
	binary.LittleEndian.PutUint32(buf[16:20], 16) // fmt chunk payload size
	binary.LittleEndian.PutUint16(buf[20:22], 1)  // audio format: PCM
	binary.LittleEndian.PutUint16(buf[22:24], uint16(channels))
	binary.LittleEndian.PutUint32(buf[24:28], uint32(sampleRate))
	binary.LittleEndian.PutUint32(buf[28:32], uint32(byteRate))
	binary.LittleEndian.PutUint16(buf[32:34], uint16(blockAlign))
	binary.LittleEndian.PutUint16(buf[34:36], uint16(bitsPerSample))

	// "data" chunk: clamped samples as little-endian int16.
	copy(buf[36:40], "data")
	binary.LittleEndian.PutUint32(buf[40:44], uint32(dataSize))

	pos := headerSize
	for _, s := range samples {
		switch {
		case s > 1.0:
			s = 1.0
		case s < -1.0:
			s = -1.0
		}
		binary.LittleEndian.PutUint16(buf[pos:pos+2], uint16(int16(s*32767)))
		pos += 2
	}

	return buf
}
|
||||
180
cmd/audio_darwin.go
Normal file
@@ -0,0 +1,180 @@
|
||||
package cmd
|
||||
|
||||
/*
|
||||
#cgo LDFLAGS: -framework CoreAudio -framework AudioToolbox
|
||||
#include <AudioToolbox/AudioQueue.h>
|
||||
#include <string.h>
|
||||
|
||||
// Callback context passed to AudioQueue.
|
||||
typedef struct {
|
||||
int ready; // set to 1 when a buffer is filled
|
||||
} AQContext;
|
||||
|
||||
// C callback — re-enqueues the buffer so recording continues.
|
||||
// Not static — must be visible to the linker for Go's function pointer.
|
||||
void aqInputCallback(
|
||||
void *inUserData,
|
||||
AudioQueueRef inAQ,
|
||||
AudioQueueBufferRef inBuffer,
|
||||
const AudioTimeStamp *inStartTime,
|
||||
UInt32 inNumberPacketDescriptions,
|
||||
const AudioStreamPacketDescription *inPacketDescs)
|
||||
{
|
||||
// Re-enqueue the buffer immediately so recording continues.
|
||||
AudioQueueEnqueueBuffer(inAQ, inBuffer, 0, NULL);
|
||||
}
|
||||
*/
|
||||
import "C"
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// errNoAudio is returned when a WAV export is requested before any
// samples have been captured.
var errNoAudio = fmt.Errorf("no audio recorded")

// numAQBuffers is the number of AudioQueue capture buffers kept in flight.
const numAQBuffers = 3
|
||||
|
||||
// coreAudioStream captures microphone audio on macOS through an
// AudioQueue input. Converted float32 samples are delivered to the
// registered callback from a polling goroutine.
type coreAudioStream struct {
	queue    C.AudioQueueRef                     // the input AudioQueue (nil until Start)
	buffers  [numAQBuffers]C.AudioQueueBufferRef // capture buffers enqueued on the queue
	mu       sync.Mutex                          // guards callback, running, and buffer reads
	callback func(samples []float32)             // receives converted samples; set in Start
	running  bool                                // true between Start and Stop
	pollDone chan struct{}                       // closed when pollLoop exits

	sampleRate int // capture rate in Hz
	channels   int // interleaved channel count
	frameSize  int // frames per capture buffer
}
|
||||
|
||||
func newAudioStream(sampleRate, channels, frameSize int) (audioStream, error) {
|
||||
return &coreAudioStream{
|
||||
sampleRate: sampleRate,
|
||||
channels: channels,
|
||||
frameSize: frameSize,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Start creates the AudioQueue, allocates and enqueues its capture
// buffers, begins recording, and launches pollLoop to drain captured
// data into callback. It returns a descriptive error if any AudioQueue
// call fails, disposing of the queue on partial failure.
func (s *coreAudioStream) Start(callback func(samples []float32)) error {
	s.mu.Lock()
	defer s.mu.Unlock()

	s.callback = callback

	// Set up audio format: 16-bit signed integer PCM, mono, 16kHz.
	var format C.AudioStreamBasicDescription
	format.mSampleRate = C.Float64(s.sampleRate)
	format.mFormatID = C.kAudioFormatLinearPCM
	format.mFormatFlags = C.kLinearPCMFormatFlagIsSignedInteger | C.kLinearPCMFormatFlagIsPacked
	format.mBitsPerChannel = 16
	format.mChannelsPerFrame = C.UInt32(s.channels)
	format.mBytesPerFrame = 2 * C.UInt32(s.channels) // 2 bytes per 16-bit sample
	format.mFramesPerPacket = 1
	format.mBytesPerPacket = format.mBytesPerFrame

	// Create the audio queue. aqInputCallback (cgo preamble) only
	// re-enqueues buffers; data is drained by pollLoop instead.
	var status C.OSStatus
	status = C.AudioQueueNewInput(
		&format,
		C.AudioQueueInputCallback(C.aqInputCallback),
		nil,               // user data
		C.CFRunLoopRef(0), // NULL run loop — use internal thread
		C.CFStringRef(0),  // NULL run loop mode
		0,                 // flags
		&s.queue,
	)
	if status != 0 {
		return fmt.Errorf("AudioQueueNewInput failed: %d", status)
	}

	// Allocate and enqueue buffers.
	bufferBytes := C.UInt32(s.frameSize * int(format.mBytesPerFrame))
	for i := range s.buffers {
		status = C.AudioQueueAllocateBuffer(s.queue, bufferBytes, &s.buffers[i])
		if status != 0 {
			C.AudioQueueDispose(s.queue, C.true)
			return fmt.Errorf("AudioQueueAllocateBuffer failed: %d", status)
		}
		status = C.AudioQueueEnqueueBuffer(s.queue, s.buffers[i], 0, nil)
		if status != 0 {
			C.AudioQueueDispose(s.queue, C.true)
			return fmt.Errorf("AudioQueueEnqueueBuffer failed: %d", status)
		}
	}

	// Start recording.
	status = C.AudioQueueStart(s.queue, nil)
	if status != 0 {
		C.AudioQueueDispose(s.queue, C.true)
		return fmt.Errorf("AudioQueueStart failed: %d", status)
	}

	s.running = true
	s.pollDone = make(chan struct{})

	// Poll buffers for data. AudioQueue re-enqueues in the C callback,
	// so we read the data out periodically.
	go s.pollLoop()

	return nil
}
|
||||
|
||||
// pollLoop periodically drains captured 16-bit samples out of the
// AudioQueue buffers, converts them to float32 in [-1, 1], and hands
// them to the registered callback. It exits (closing pollDone) when
// Stop clears s.running.
func (s *coreAudioStream) pollLoop() {
	defer close(s.pollDone)

	// Read at roughly frameSize intervals.
	interval := time.Duration(float64(s.frameSize) / float64(s.sampleRate) * float64(time.Second))
	if interval < 10*time.Millisecond {
		interval = 10 * time.Millisecond
	}

	ticker := time.NewTicker(interval)
	defer ticker.Stop()

	for range ticker.C {
		s.mu.Lock()
		if !s.running {
			s.mu.Unlock()
			return
		}

		// Read available data from each buffer.
		// NOTE(review): mAudioDataByteSize is read and zeroed here while
		// the C callback may concurrently re-enqueue the same buffer;
		// there is no cross-language synchronization visible — confirm.
		for i := range s.buffers {
			buf := s.buffers[i]
			if buf.mAudioDataByteSize > 0 {
				numSamples := int(buf.mAudioDataByteSize) / 2 // 16-bit samples
				if numSamples > 0 {
					// View the C buffer as an int16 slice without copying.
					raw := (*[1 << 28]int16)(buf.mAudioData)[:numSamples:numSamples]
					floats := make([]float32, numSamples)
					for j, v := range raw {
						floats[j] = float32(v) / float32(math.MaxInt16)
					}
					s.callback(floats)
				}
				// Mark the buffer as consumed.
				buf.mAudioDataByteSize = 0
			}
		}
		s.mu.Unlock()
	}
}
|
||||
|
||||
// Stop halts recording, disposes of the AudioQueue, and waits for
// pollLoop to exit. It always returns nil.
func (s *coreAudioStream) Stop() error {
	s.mu.Lock()
	s.running = false
	queue := s.queue
	s.mu.Unlock()

	if queue != nil {
		// Second argument C.true = stop/dispose immediately rather than
		// after queued buffers drain.
		C.AudioQueueStop(queue, C.true)
		C.AudioQueueDispose(queue, C.true)
	}

	// Wait for pollLoop to observe running=false and finish.
	if s.pollDone != nil {
		<-s.pollDone
	}

	return nil
}
|
||||
275
cmd/audio_linux.go
Normal file
@@ -0,0 +1,275 @@
|
||||
package cmd
|
||||
|
||||
/*
|
||||
#cgo LDFLAGS: -ldl
|
||||
#include <dlfcn.h>
|
||||
#include <stdint.h>
|
||||
#include <stdlib.h>
|
||||
|
||||
// Function pointer types for ALSA functions loaded at runtime.
|
||||
typedef int (*pcm_open_fn)(void**, const char*, int, int);
|
||||
typedef int (*pcm_simple_fn)(void*);
|
||||
typedef long (*pcm_readi_fn)(void*, void*, unsigned long);
|
||||
typedef int (*hw_malloc_fn)(void**);
|
||||
typedef void (*hw_free_fn)(void*);
|
||||
typedef int (*hw_any_fn)(void*, void*);
|
||||
typedef int (*hw_set_int_fn)(void*, void*, int);
|
||||
typedef int (*hw_set_uint_fn)(void*, void*, unsigned int);
|
||||
typedef int (*hw_set_rate_fn)(void*, void*, unsigned int*, int*);
|
||||
typedef int (*hw_set_period_fn)(void*, void*, unsigned long*, int*);
|
||||
typedef int (*hw_apply_fn)(void*, void*);
|
||||
typedef const char* (*strerror_fn)(int);
|
||||
|
||||
// Trampoline functions — call dynamically loaded ALSA symbols.
|
||||
static int alsa_pcm_open(void* fn, void** h, const char* name, int stream, int mode) {
|
||||
return ((pcm_open_fn)fn)(h, name, stream, mode);
|
||||
}
|
||||
static int alsa_pcm_close(void* fn, void* h) { return ((pcm_simple_fn)fn)(h); }
|
||||
static int alsa_pcm_prepare(void* fn, void* h) { return ((pcm_simple_fn)fn)(h); }
|
||||
static int alsa_pcm_drop(void* fn, void* h) { return ((pcm_simple_fn)fn)(h); }
|
||||
static long alsa_pcm_readi(void* fn, void* h, void* buf, unsigned long frames) {
|
||||
return ((pcm_readi_fn)fn)(h, buf, frames);
|
||||
}
|
||||
static int alsa_hw_malloc(void* fn, void** p) { return ((hw_malloc_fn)fn)(p); }
|
||||
static void alsa_hw_free(void* fn, void* p) { ((hw_free_fn)fn)(p); }
|
||||
static int alsa_hw_any(void* fn, void* h, void* p) { return ((hw_any_fn)fn)(h, p); }
|
||||
static int alsa_hw_set_access(void* fn, void* h, void* p, int v) { return ((hw_set_int_fn)fn)(h, p, v); }
|
||||
static int alsa_hw_set_format(void* fn, void* h, void* p, int v) { return ((hw_set_int_fn)fn)(h, p, v); }
|
||||
static int alsa_hw_set_channels(void* fn, void* h, void* p, unsigned int v) { return ((hw_set_uint_fn)fn)(h, p, v); }
|
||||
static int alsa_hw_set_rate(void* fn, void* h, void* p, unsigned int* v, int* d) { return ((hw_set_rate_fn)fn)(h, p, v, d); }
|
||||
static int alsa_hw_set_period(void* fn, void* h, void* p, unsigned long* v, int* d) { return ((hw_set_period_fn)fn)(h, p, v, d); }
|
||||
static int alsa_hw_apply(void* fn, void* h, void* p) { return ((hw_apply_fn)fn)(h, p); }
|
||||
static const char* alsa_strerror(void* fn, int e) { return ((strerror_fn)fn)(e); }
|
||||
*/
|
||||
import "C"
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math"
|
||||
"sync"
|
||||
"time"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
// errNoAudio is returned when a WAV export is requested before any
// samples have been captured.
var errNoAudio = fmt.Errorf("no audio recorded")

// ALSA constants; values correspond to the libasound enums of the same
// name (the header is not included because libasound is dlopen'd).
const (
	sndPCMStreamCapture       = 1 // SND_PCM_STREAM_CAPTURE
	sndPCMAccessRWInterleaved = 3 // SND_PCM_ACCESS_RW_INTERLEAVED
	sndPCMFormatS16LE         = 2 // SND_PCM_FORMAT_S16_LE
)
|
||||
|
||||
// Lazily-initialized libasound state; loadALSA runs exactly once via
// alsaOnce (see newAudioStream).
var (
	alsaLoadErr error     // non-nil if the library or a symbol was missing
	alsaOnce    sync.Once // guards the one-time loadALSA call
	alsa        alsaFuncs // resolved libasound symbols
)
|
||||
|
||||
// alsaFuncs holds dlsym-resolved libasound function pointers. They are
// invoked through the C trampoline functions in the cgo preamble.
type alsaFuncs struct {
	pcmOpen, pcmClose, pcmPrepare, pcmDrop, pcmReadi unsafe.Pointer
	hwMalloc, hwFree, hwAny                          unsafe.Pointer
	hwSetAccess, hwSetFormat, hwSetChannels          unsafe.Pointer
	hwSetRate, hwSetPeriod, hwApply                  unsafe.Pointer
	strerror                                         unsafe.Pointer
}
|
||||
|
||||
// loadALSA dlopens libasound and resolves every symbol listed in
// alsaFuncs. On any failure it records the reason in alsaLoadErr and
// returns; it never panics. Intended to be called once via alsaOnce.
func loadALSA() {
	var lib unsafe.Pointer
	// Try the versioned soname first, then the unversioned dev symlink.
	for _, name := range []string{"libasound.so.2", "libasound.so"} {
		cName := C.CString(name)
		lib = C.dlopen(cName, C.RTLD_NOW)
		C.free(unsafe.Pointer(cName))
		if lib != nil {
			break
		}
	}
	if lib == nil {
		alsaLoadErr = fmt.Errorf("audio capture unavailable: libasound.so not found")
		return
	}

	// sym resolves a single symbol, handling C string lifetime.
	sym := func(name string) unsafe.Pointer {
		cName := C.CString(name)
		defer C.free(unsafe.Pointer(cName))
		return C.dlsym(lib, cName)
	}

	// Table of destination pointers and the libasound symbols to load.
	syms := []struct {
		ptr  *unsafe.Pointer
		name string
	}{
		{&alsa.pcmOpen, "snd_pcm_open"},
		{&alsa.pcmClose, "snd_pcm_close"},
		{&alsa.pcmPrepare, "snd_pcm_prepare"},
		{&alsa.pcmDrop, "snd_pcm_drop"},
		{&alsa.pcmReadi, "snd_pcm_readi"},
		{&alsa.hwMalloc, "snd_pcm_hw_params_malloc"},
		{&alsa.hwFree, "snd_pcm_hw_params_free"},
		{&alsa.hwAny, "snd_pcm_hw_params_any"},
		{&alsa.hwSetAccess, "snd_pcm_hw_params_set_access"},
		{&alsa.hwSetFormat, "snd_pcm_hw_params_set_format"},
		{&alsa.hwSetChannels, "snd_pcm_hw_params_set_channels"},
		{&alsa.hwSetRate, "snd_pcm_hw_params_set_rate_near"},
		{&alsa.hwSetPeriod, "snd_pcm_hw_params_set_period_size_near"},
		{&alsa.hwApply, "snd_pcm_hw_params"},
		{&alsa.strerror, "snd_strerror"},
	}

	for _, s := range syms {
		*s.ptr = sym(s.name)
		if *s.ptr == nil {
			alsaLoadErr = fmt.Errorf("audio capture unavailable: missing %s in libasound", s.name)
			return
		}
	}
}
|
||||
|
||||
// alsaError renders an ALSA error code as text via snd_strerror,
// falling back to the numeric code if the symbol was never loaded.
func alsaError(code C.int) string {
	if alsa.strerror == nil {
		return fmt.Sprintf("error %d", code)
	}
	return C.GoString(C.alsa_strerror(alsa.strerror, code))
}
|
||||
|
||||
// alsaStream captures microphone audio on Linux through a dynamically
// loaded libasound, reading interleaved S16_LE frames in a goroutine.
type alsaStream struct {
	handle   unsafe.Pointer          // snd_pcm_t*; nil until Start, cleared by Stop
	mu       sync.Mutex              // guards callback, running, handle
	callback func(samples []float32) // receives converted samples; set in Start
	running  bool                    // true between Start and Stop
	done     chan struct{}           // closed when captureLoop exits

	sampleRate int // requested capture rate in Hz
	channels   int // interleaved channel count
	frameSize  int // requested period size in frames
}
|
||||
|
||||
func newAudioStream(sampleRate, channels, frameSize int) (audioStream, error) {
|
||||
alsaOnce.Do(loadALSA)
|
||||
if alsaLoadErr != nil {
|
||||
return nil, alsaLoadErr
|
||||
}
|
||||
return &alsaStream{
|
||||
sampleRate: sampleRate,
|
||||
channels: channels,
|
||||
frameSize: frameSize,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Start opens the "default" ALSA capture device, configures interleaved
// S16_LE capture at the requested rate and period size (both negotiated
// with *_near, so the device may adjust them), prepares the PCM, and
// launches captureLoop. The PCM handle is closed on any setup failure.
func (s *alsaStream) Start(callback func(samples []float32)) error {
	s.mu.Lock()
	defer s.mu.Unlock()

	s.callback = callback

	cName := C.CString("default")
	defer C.free(unsafe.Pointer(cName))

	rc := C.alsa_pcm_open(alsa.pcmOpen, (*unsafe.Pointer)(unsafe.Pointer(&s.handle)), cName, C.int(sndPCMStreamCapture), 0)
	if rc < 0 {
		return fmt.Errorf("snd_pcm_open: %s", alsaError(rc))
	}

	// Allocate a hw_params container; freed when Start returns.
	// NOTE(review): the malloc and hw_params_any return codes are not
	// checked — presumably they only fail on OOM; confirm.
	var hwParams unsafe.Pointer
	C.alsa_hw_malloc(alsa.hwMalloc, (*unsafe.Pointer)(unsafe.Pointer(&hwParams)))
	defer C.alsa_hw_free(alsa.hwFree, hwParams)

	C.alsa_hw_any(alsa.hwAny, s.handle, hwParams)

	if rc = C.alsa_hw_set_access(alsa.hwSetAccess, s.handle, hwParams, C.int(sndPCMAccessRWInterleaved)); rc < 0 {
		C.alsa_pcm_close(alsa.pcmClose, s.handle)
		return fmt.Errorf("set access: %s", alsaError(rc))
	}
	if rc = C.alsa_hw_set_format(alsa.hwSetFormat, s.handle, hwParams, C.int(sndPCMFormatS16LE)); rc < 0 {
		C.alsa_pcm_close(alsa.pcmClose, s.handle)
		return fmt.Errorf("set format: %s", alsaError(rc))
	}
	if rc = C.alsa_hw_set_channels(alsa.hwSetChannels, s.handle, hwParams, C.uint(s.channels)); rc < 0 {
		C.alsa_pcm_close(alsa.pcmClose, s.handle)
		return fmt.Errorf("set channels: %s", alsaError(rc))
	}

	// set_rate_near may adjust rate in place to the closest supported value.
	rate := C.uint(s.sampleRate)
	if rc = C.alsa_hw_set_rate(alsa.hwSetRate, s.handle, hwParams, &rate, nil); rc < 0 {
		C.alsa_pcm_close(alsa.pcmClose, s.handle)
		return fmt.Errorf("set rate: %s", alsaError(rc))
	}

	// set_period_size_near may adjust periodSize in place; the adjusted
	// value is what captureLoop reads per snd_pcm_readi call.
	periodSize := C.ulong(s.frameSize)
	if rc = C.alsa_hw_set_period(alsa.hwSetPeriod, s.handle, hwParams, &periodSize, nil); rc < 0 {
		C.alsa_pcm_close(alsa.pcmClose, s.handle)
		return fmt.Errorf("set period: %s", alsaError(rc))
	}

	if rc = C.alsa_hw_apply(alsa.hwApply, s.handle, hwParams); rc < 0 {
		C.alsa_pcm_close(alsa.pcmClose, s.handle)
		return fmt.Errorf("apply hw params: %s", alsaError(rc))
	}

	if rc = C.alsa_pcm_prepare(alsa.pcmPrepare, s.handle); rc < 0 {
		C.alsa_pcm_close(alsa.pcmClose, s.handle)
		return fmt.Errorf("prepare: %s", alsaError(rc))
	}

	s.running = true
	s.done = make(chan struct{})
	go s.captureLoop(int(periodSize))

	return nil
}
|
||||
|
||||
// captureLoop repeatedly reads one period of interleaved S16_LE frames,
// converts them to float32 in [-1, 1], and delivers them to the
// callback. On a read error it re-prepares the PCM (recovering from
// overruns) and retries. It exits, closing done, once Stop clears
// s.running.
func (s *alsaStream) captureLoop(periodSize int) {
	defer close(s.done)

	// Reusable read buffer: one period of interleaved samples.
	buf := make([]int16, periodSize*s.channels)

	for {
		s.mu.Lock()
		if !s.running {
			s.mu.Unlock()
			return
		}
		handle := s.handle
		s.mu.Unlock()

		// Blocking read of up to periodSize frames.
		frames := C.alsa_pcm_readi(alsa.pcmReadi, handle, unsafe.Pointer(&buf[0]), C.ulong(periodSize))
		if frames < 0 {
			// Error (e.g. overrun): re-prepare and retry.
			C.alsa_pcm_prepare(alsa.pcmPrepare, handle)
			continue
		}
		if frames == 0 {
			time.Sleep(5 * time.Millisecond)
			continue
		}

		// Convert the interleaved int16 samples to normalized float32.
		numSamples := int(frames) * s.channels
		floats := make([]float32, numSamples)
		for i := 0; i < numSamples; i++ {
			floats[i] = float32(buf[i]) / float32(math.MaxInt16)
		}

		// Deliver under the lock so Stop cannot race the callback swap.
		s.mu.Lock()
		if s.callback != nil {
			s.callback(floats)
		}
		s.mu.Unlock()
	}
}
|
||||
|
||||
// Stop signals captureLoop to exit, waits for it, then drops pending
// frames and closes the PCM handle. It always returns nil.
func (s *alsaStream) Stop() error {
	s.mu.Lock()
	s.running = false
	handle := s.handle
	s.handle = nil
	s.mu.Unlock()

	// Wait for captureLoop to finish before closing the handle it uses.
	if s.done != nil {
		<-s.done
	}

	if handle != nil {
		C.alsa_pcm_drop(alsa.pcmDrop, handle)
		C.alsa_pcm_close(alsa.pcmClose, handle)
	}

	return nil
}
|
||||
288
cmd/audio_windows.go
Normal file
@@ -0,0 +1,288 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math"
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
// errNoAudio is returned when a WAV export is requested before any
// samples have been captured.
var errNoAudio = fmt.Errorf("no audio recorded")

// WASAPI COM GUIDs
var (
	iidIMMDeviceEnumerator  = guid{0xA95664D2, 0x9614, 0x4F35, [8]byte{0xA7, 0x46, 0xDE, 0x8D, 0xB6, 0x36, 0x17, 0xE6}}  // IID_IMMDeviceEnumerator
	clsidMMDeviceEnumerator = guid{0xBCDE0395, 0xE52F, 0x467C, [8]byte{0x8E, 0x3D, 0xC4, 0x57, 0x92, 0x91, 0x69, 0x2E}} // CLSID_MMDeviceEnumerator
	iidIAudioClient         = guid{0x1CB9AD4C, 0xDBFA, 0x4C32, [8]byte{0xB1, 0x78, 0xC2, 0xF5, 0x68, 0xA7, 0x03, 0xB2}} // IID_IAudioClient
	iidIAudioCaptureClient  = guid{0xC8ADBD64, 0xE71E, 0x48A0, [8]byte{0xA4, 0xDE, 0x18, 0x5C, 0x39, 0x5C, 0xD3, 0x17}} // IID_IAudioCaptureClient
)
|
||||
|
||||
// guid mirrors the Windows GUID structure layout, passed by pointer to
// COM APIs such as CoCreateInstance.
type guid struct {
	Data1 uint32
	Data2 uint16
	Data3 uint16
	Data4 [8]byte
}
|
||||
|
||||
// WAVEFORMATEX structure
// waveFormatEx mirrors the Windows WAVEFORMATEX layout used to request
// a capture format from IAudioClient::Initialize.
type waveFormatEx struct {
	FormatTag      uint16 // format tag; 1 (wavePCM) = integer PCM
	Channels       uint16
	SamplesPerSec  uint32
	AvgBytesPerSec uint32 // SamplesPerSec * BlockAlign
	BlockAlign     uint16 // bytes per frame across all channels
	BitsPerSample  uint16
	CbSize         uint16 // trailing extension size; 0 for plain PCM
}
|
||||
|
||||
// Windows audio/COM constants (values match the Windows SDK headers).
const (
	wavePCM  = 1 // WAVE_FORMAT_PCM
	eCapture = 1 // EDataFlow: capture endpoint
	eConsole = 0 // ERole: console/default role

	audclntSharemode                = 0 // AUDCLNT_SHAREMODE_SHARED
	audclntStreamflagsEventcallback = 0x00040000

	coinitMultithreaded = 0x0
	clsctxAll           = 0x17 // CLSCTX_ALL

	reftimesPerSec    = 10000000 // 100ns units per second
	reftimesPerMillis = 10000
)
|
||||
|
||||
// Lazily-loaded ole32.dll entry points used to bootstrap COM.
var (
	ole32    = syscall.NewLazyDLL("ole32.dll")
	coInit   = ole32.NewProc("CoInitializeEx")
	coCreate = ole32.NewProc("CoCreateInstance")
)
|
||||
|
||||
// wasapiStream captures microphone audio on Windows through WASAPI,
// driving the COM interfaces by raw vtable calls (see comCall).
type wasapiStream struct {
	mu       sync.Mutex              // guards callback and running
	callback func(samples []float32) // receives converted samples; set in Start
	running  bool                    // true between Start and Stop
	done     chan struct{}           // closed when captureLoop exits

	sampleRate int // requested capture rate in Hz
	channels   int // interleaved channel count
	frameSize  int // frames per delivery (not used to size WASAPI buffers here)

	// COM interfaces (stored as uintptr for syscall)
	enumerator uintptr // IMMDeviceEnumerator
	device     uintptr // IMMDevice (default capture endpoint)
	client     uintptr // IAudioClient
	capture    uintptr // IAudioCaptureClient
}
|
||||
|
||||
func newAudioStream(sampleRate, channels, frameSize int) (audioStream, error) {
|
||||
return &wasapiStream{
|
||||
sampleRate: sampleRate,
|
||||
channels: channels,
|
||||
frameSize: frameSize,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Start initializes COM, acquires the default capture endpoint,
// configures a shared-mode IAudioClient for 16-bit PCM at the requested
// rate, obtains the IAudioCaptureClient, begins capture, and launches
// captureLoop. COM interfaces acquired here are released by Stop.
func (s *wasapiStream) Start(callback func(samples []float32)) error {
	s.mu.Lock()
	defer s.mu.Unlock()

	s.callback = callback

	// Initialize COM
	hr, _, _ := coInit.Call(0, uintptr(coinitMultithreaded))
	// S_OK or S_FALSE (already initialized) are both fine
	if hr != 0 && hr != 1 {
		return fmt.Errorf("CoInitializeEx failed: 0x%08x", hr)
	}

	// Create device enumerator
	hr, _, _ = coCreate.Call(
		uintptr(unsafe.Pointer(&clsidMMDeviceEnumerator)),
		0,
		uintptr(clsctxAll),
		uintptr(unsafe.Pointer(&iidIMMDeviceEnumerator)),
		uintptr(unsafe.Pointer(&s.enumerator)),
	)
	if hr != 0 {
		return fmt.Errorf("CoCreateInstance(MMDeviceEnumerator) failed: 0x%08x", hr)
	}

	// Get default capture device
	// IMMDeviceEnumerator::GetDefaultAudioEndpoint is vtable index 4
	hr = comCall(s.enumerator, 4, uintptr(eCapture), uintptr(eConsole), uintptr(unsafe.Pointer(&s.device)))
	if hr != 0 {
		return fmt.Errorf("GetDefaultAudioEndpoint failed: 0x%08x", hr)
	}

	// Activate IAudioClient
	// IMMDevice::Activate is vtable index 3
	hr = comCall(s.device, 3,
		uintptr(unsafe.Pointer(&iidIAudioClient)),
		uintptr(clsctxAll),
		0,
		uintptr(unsafe.Pointer(&s.client)),
	)
	if hr != 0 {
		return fmt.Errorf("IMMDevice::Activate failed: 0x%08x", hr)
	}

	// Set up format: 16-bit PCM mono 16kHz
	// NOTE(review): shared-mode WASAPI may reject a format that differs
	// from the device mix format; no GetMixFormat fallback is attempted
	// here — confirm this works on target devices.
	format := waveFormatEx{
		FormatTag:      wavePCM,
		Channels:       uint16(s.channels),
		SamplesPerSec:  uint32(s.sampleRate),
		BitsPerSample:  16,
		BlockAlign:     uint16(2 * s.channels),
		AvgBytesPerSec: uint32(s.sampleRate * 2 * s.channels),
		CbSize:         0,
	}

	// Initialize audio client
	// IAudioClient::Initialize is vtable index 3
	bufferDuration := int64(reftimesPerSec) // 1 second buffer
	hr = comCall(s.client, 3,
		uintptr(audclntSharemode),
		0, // stream flags
		uintptr(bufferDuration),
		0, // periodicity (0 = use default)
		uintptr(unsafe.Pointer(&format)),
		0, // audio session GUID (NULL = default)
	)
	if hr != 0 {
		return fmt.Errorf("IAudioClient::Initialize failed: 0x%08x", hr)
	}

	// Get capture client
	// IAudioClient::GetService is vtable index 8
	hr = comCall(s.client, 8,
		uintptr(unsafe.Pointer(&iidIAudioCaptureClient)),
		uintptr(unsafe.Pointer(&s.capture)),
	)
	if hr != 0 {
		return fmt.Errorf("IAudioClient::GetService failed: 0x%08x", hr)
	}

	// Start capture
	// IAudioClient::Start is vtable index 6
	hr = comCall(s.client, 6)
	if hr != 0 {
		return fmt.Errorf("IAudioClient::Start failed: 0x%08x", hr)
	}

	s.running = true
	s.done = make(chan struct{})
	go s.captureLoop()

	return nil
}
|
||||
|
||||
// captureLoop polls the IAudioCaptureClient every 20ms, draining all
// available packets: each packet's int16 samples are converted to
// float32 in [-1, 1] and delivered to the callback. It exits (closing
// done) once Stop clears s.running.
func (s *wasapiStream) captureLoop() {
	defer close(s.done)

	ticker := time.NewTicker(20 * time.Millisecond)
	defer ticker.Stop()

	for range ticker.C {
		s.mu.Lock()
		if !s.running {
			s.mu.Unlock()
			return
		}

		// Read available packets
		for {
			var data uintptr
			var numFrames uint32
			var flags uint32

			// IAudioCaptureClient::GetBuffer is vtable index 3
			hr := comCall(s.capture, 3,
				uintptr(unsafe.Pointer(&data)),
				uintptr(unsafe.Pointer(&numFrames)),
				uintptr(unsafe.Pointer(&flags)),
				0, // device position (not needed)
				0, // QPC position (not needed)
			)
			if hr != 0 || numFrames == 0 {
				break
			}

			// Convert int16 samples to float32
			// NOTE(review): this assumes the stream really delivers
			// 16-bit PCM (as requested in Start) and ignores the
			// SILENT buffer flag — confirm.
			samples := make([]float32, numFrames*uint32(s.channels))
			raw := (*[1 << 28]int16)(unsafe.Pointer(data))[:len(samples):len(samples)]
			for i, v := range raw {
				samples[i] = float32(v) / float32(math.MaxInt16)
			}

			s.callback(samples)

			// IAudioCaptureClient::ReleaseBuffer is vtable index 4
			comCall(s.capture, 4, uintptr(numFrames))
		}

		s.mu.Unlock()
	}
}
|
||||
|
||||
// Stop signals captureLoop to exit, waits for it, stops the audio
// client, and releases every COM interface acquired in Start. It
// always returns nil.
func (s *wasapiStream) Stop() error {
	s.mu.Lock()
	s.running = false
	s.mu.Unlock()

	// Wait for captureLoop before tearing down the interfaces it uses.
	if s.done != nil {
		<-s.done
	}

	// IAudioClient::Stop is vtable index 7
	if s.client != 0 {
		comCall(s.client, 7)
	}

	// Release COM interfaces (IUnknown::Release is vtable index 2)
	if s.capture != 0 {
		comCall(s.capture, 2)
	}
	if s.client != 0 {
		comCall(s.client, 2)
	}
	if s.device != 0 {
		comCall(s.device, 2)
	}
	if s.enumerator != 0 {
		comCall(s.enumerator, 2)
	}

	return nil
}
|
||||
|
||||
// comCall invokes a COM method by vtable index.
|
||||
func comCall(obj uintptr, method uintptr, args ...uintptr) uintptr {
|
||||
vtable := *(*uintptr)(unsafe.Pointer(obj))
|
||||
fn := *(*uintptr)(unsafe.Pointer(vtable + method*unsafe.Sizeof(uintptr(0))))
|
||||
|
||||
// Build syscall args: first arg is always 'this' pointer
|
||||
callArgs := make([]uintptr, 1+len(args))
|
||||
callArgs[0] = obj
|
||||
copy(callArgs[1:], args)
|
||||
|
||||
var hr uintptr
|
||||
switch len(callArgs) {
|
||||
case 1:
|
||||
hr, _, _ = syscall.SyscallN(fn, callArgs[0])
|
||||
case 2:
|
||||
hr, _, _ = syscall.SyscallN(fn, callArgs[0], callArgs[1])
|
||||
case 3:
|
||||
hr, _, _ = syscall.SyscallN(fn, callArgs[0], callArgs[1], callArgs[2])
|
||||
case 4:
|
||||
hr, _, _ = syscall.SyscallN(fn, callArgs[0], callArgs[1], callArgs[2], callArgs[3])
|
||||
case 5:
|
||||
hr, _, _ = syscall.SyscallN(fn, callArgs[0], callArgs[1], callArgs[2], callArgs[3], callArgs[4])
|
||||
case 6:
|
||||
hr, _, _ = syscall.SyscallN(fn, callArgs[0], callArgs[1], callArgs[2], callArgs[3], callArgs[4], callArgs[5])
|
||||
case 7:
|
||||
hr, _, _ = syscall.SyscallN(fn, callArgs[0], callArgs[1], callArgs[2], callArgs[3], callArgs[4], callArgs[5], callArgs[6])
|
||||
default:
|
||||
hr, _, _ = syscall.SyscallN(fn, callArgs...)
|
||||
}
|
||||
return hr
|
||||
}
|
||||
@@ -1,27 +1,31 @@
|
||||
Ollama Benchmark Tool
|
||||
---------------------
|
||||
|
||||
A Go-based command-line tool for benchmarking Ollama models with configurable parameters and multiple output formats.
|
||||
A Go-based command-line tool for benchmarking Ollama models with configurable parameters, warmup phases, TTFT tracking, VRAM monitoring, and benchstat/CSV output.
|
||||
|
||||
## Features
|
||||
|
||||
* Benchmark multiple models in a single run
|
||||
* Support for both text and image prompts
|
||||
* Configurable generation parameters (temperature, max tokens, seed, etc.)
|
||||
* Supports benchstat and CSV output formats
|
||||
* Detailed performance metrics (prefill, generate, load, total durations)
|
||||
* Warmup phase before timed epochs to stabilize measurements
|
||||
* Time-to-first-token (TTFT) tracking per epoch
|
||||
* Model metadata display (parameter size, quantization level, family)
|
||||
* VRAM and CPU memory usage tracking via running process info
|
||||
* Controlled prompt token length for reproducible benchmarks
|
||||
* Benchstat and CSV output formats
|
||||
|
||||
## Building from Source
|
||||
|
||||
```
|
||||
go build -o ollama-bench bench.go
|
||||
./ollama-bench -model gpt-oss:20b -epochs 6 -format csv
|
||||
go build -o ollama-bench ./cmd/bench
|
||||
./ollama-bench -model gemma3 -epochs 6 -format csv
|
||||
```
|
||||
|
||||
Using Go Run (without building)
|
||||
|
||||
```
|
||||
go run bench.go -model gpt-oss:20b -epochs 3
|
||||
go run ./cmd/bench -model gemma3 -epochs 3
|
||||
```
|
||||
|
||||
## Usage
|
||||
@@ -45,10 +49,16 @@ benchstat -col /name gemma.bench
|
||||
./ollama-bench -model qwen3-vl -image photo.jpg -epochs 6 -max-tokens 100 -p "Describe this image"
|
||||
```
|
||||
|
||||
### Controlled Prompt Length
|
||||
|
||||
```
|
||||
./ollama-bench -model gemma3 -epochs 6 -prompt-tokens 512
|
||||
```
|
||||
|
||||
### Advanced Example
|
||||
|
||||
```
|
||||
./ollama-bench -model llama3 -epochs 10 -temperature 0.7 -max-tokens 500 -seed 42 -format csv -output results.csv
|
||||
./ollama-bench -model llama3 -epochs 10 -temperature 0.7 -max-tokens 500 -seed 42 -warmup 2 -format csv -output results.csv
|
||||
```
|
||||
|
||||
## Command Line Options
|
||||
@@ -56,41 +66,48 @@ benchstat -col /name gemma.bench
|
||||
| Option | Description | Default |
|
||||
|----------|-------------|---------|
|
||||
| -model | Comma-separated list of models to benchmark | (required) |
|
||||
| -epochs | Number of iterations per model | 1 |
|
||||
| -max-tokens | Maximum tokens for model response | 0 (unlimited) |
|
||||
| -epochs | Number of iterations per model | 6 |
|
||||
| -max-tokens | Maximum tokens for model response | 200 |
|
||||
| -temperature | Temperature parameter | 0.0 |
|
||||
| -seed | Random seed | 0 (random) |
|
||||
| -timeout | Timeout in seconds | 300 |
|
||||
| -p | Prompt text | "Write a long story." |
|
||||
| -p | Prompt text | (default story prompt) |
|
||||
| -image | Image file to include in prompt | |
|
||||
| -k | Keep-alive duration in seconds | 0 |
|
||||
| -format | Output format (benchstat, csv) | benchstat |
|
||||
| -output | Output file for results | "" (stdout) |
|
||||
| -warmup | Number of warmup requests before timing | 1 |
|
||||
| -prompt-tokens | Generate prompt targeting ~N tokens (0 = use -p) | 0 |
|
||||
| -v | Verbose mode | false |
|
||||
| -debug | Show debug information | false |
|
||||
|
||||
## Output Formats
|
||||
|
||||
### Markdown Format
|
||||
### Benchstat Format (default)
|
||||
|
||||
The default markdown format is suitable for copying and pasting into a GitHub issue and will look like:
|
||||
```
|
||||
Model | Step | Count | Duration | nsPerToken | tokensPerSec |
|
||||
|-------|------|-------|----------|------------|--------------|
|
||||
| gpt-oss:20b | prefill | 124 | 30.006458ms | 241987.56 | 4132.44 |
|
||||
| gpt-oss:20b | generate | 200 | 2.646843954s | 13234219.77 | 75.56 |
|
||||
| gpt-oss:20b | load | 1 | 121.674208ms | - | - |
|
||||
| gpt-oss:20b | total | 1 | 2.861047625s | - | - |
|
||||
```
|
||||
|
||||
### Benchstat Format
|
||||
|
||||
Compatible with Go's benchstat tool for statistical analysis:
|
||||
Compatible with Go's benchstat tool for statistical analysis. Uses one value/unit pair per line, standard `ns/op` for timing metrics, and `ns/token` for throughput. Each epoch produces one set of lines -- benchstat aggregates across repeated runs to compute statistics.
|
||||
|
||||
```
|
||||
BenchmarkModel/name=gpt-oss:20b/step=prefill 128 78125.00 ns/token 12800.00 token/sec
|
||||
BenchmarkModel/name=gpt-oss:20b/step=generate 512 19531.25 ns/token 51200.00 token/sec
|
||||
BenchmarkModel/name=gpt-oss:20b/step=load 1 1500000000 ns/request
|
||||
# Model: gemma3 | Params: 4.3B | Quant: Q4_K_M | Family: gemma3 | Size: 4080218931 | VRAM: 4080218931
|
||||
BenchmarkModel/name=gemma3/step=prefill 1 78125.00 ns/token 12800.00 token/sec
|
||||
BenchmarkModel/name=gemma3/step=generate 1 19531.25 ns/token 51200.00 token/sec
|
||||
BenchmarkModel/name=gemma3/step=ttft 1 45123000 ns/op
|
||||
BenchmarkModel/name=gemma3/step=load 1 1500000000 ns/op
|
||||
BenchmarkModel/name=gemma3/step=total 1 2861047625 ns/op
|
||||
```
|
||||
|
||||
Use with benchstat:
|
||||
```
|
||||
./ollama-bench -model gemma3 -epochs 6 > gemma3.bench
|
||||
benchstat -col /step gemma3.bench
|
||||
```
|
||||
|
||||
Compare two runs:
|
||||
```
|
||||
./ollama-bench -model gemma3 -epochs 6 > before.bench
|
||||
# ... make changes ...
|
||||
./ollama-bench -model gemma3 -epochs 6 > after.bench
|
||||
benchstat before.bench after.bench
|
||||
```
|
||||
|
||||
### CSV Format
|
||||
@@ -99,17 +116,28 @@ Machine-readable comma-separated values:
|
||||
|
||||
```
|
||||
NAME,STEP,COUNT,NS_PER_COUNT,TOKEN_PER_SEC
|
||||
gpt-oss:20b,prefill,128,78125.00,12800.00
|
||||
gpt-oss:20b,generate,512,19531.25,51200.00
|
||||
gpt-oss:20b,load,1,1500000000,0
|
||||
# Model: gemma3 | Params: 4.3B | Quant: Q4_K_M | Family: gemma3 | Size: 4080218931 | VRAM: 4080218931
|
||||
gemma3,prefill,128,78125.00,12800.00
|
||||
gemma3,generate,512,19531.25,51200.00
|
||||
gemma3,ttft,1,45123000,0
|
||||
gemma3,load,1,1500000000,0
|
||||
gemma3,total,1,2861047625,0
|
||||
```
|
||||
|
||||
## Metrics Explained
|
||||
|
||||
The tool reports four types of metrics for each model:
|
||||
The tool reports the following metrics for each epoch:
|
||||
|
||||
* prefill: Time spent processing the prompt
|
||||
* generate: Time spent generating the response
|
||||
* load: Model loading time (one-time cost)
|
||||
* total: Total request duration
|
||||
* **prefill**: Time spent processing the prompt (ns/token)
|
||||
* **generate**: Time spent generating the response (ns/token)
|
||||
* **ttft**: Time to first token -- latency from request start to first response content
|
||||
* **load**: Model loading time (one-time cost)
|
||||
* **total**: Total request duration
|
||||
|
||||
Additionally, the model info comment line (displayed once per model before epochs) includes:
|
||||
|
||||
* **Params**: Model parameter count (e.g., 4.3B)
|
||||
* **Quant**: Quantization level (e.g., Q4_K_M)
|
||||
* **Family**: Model family (e.g., gemma3)
|
||||
* **Size**: Total model memory in bytes
|
||||
* **VRAM**: GPU memory used by the loaded model (when Size > VRAM, the difference is CPU spill)
|
||||
|
||||
@@ -17,19 +17,22 @@ import (
|
||||
)
|
||||
|
||||
// flagOptions groups the parsed command-line flag values for the
// benchmark tool. Fields are pointers because they are bound by the
// flag package.
type flagOptions struct {
	models       *string  // comma-separated models to benchmark
	epochs       *int     // timed iterations per model
	maxTokens    *int     // cap on generated tokens per response
	temperature  *float64 // sampling temperature
	seed         *int     // random seed (0 = random)
	timeout      *int     // request timeout, seconds
	prompt       *string  // prompt text
	imageFile    *string  // optional image to attach to the prompt
	keepAlive    *float64 // keep-alive duration, seconds
	format       *string  // output format: benchstat or csv
	outputFile   *string  // "" writes to stdout
	debug        *bool    // show debug information
	verbose      *bool    // verbose mode
	warmup       *int     // untimed warmup requests before the epochs
	promptTokens *int     // generate a prompt targeting ~N tokens (0 = use prompt)
	numCtx       *int     // context window size — presumably an override; confirm
}
|
||||
|
||||
type Metrics struct {
|
||||
@@ -39,48 +42,203 @@ type Metrics struct {
|
||||
Duration time.Duration
|
||||
}
|
||||
|
||||
var once sync.Once
|
||||
type ModelInfo struct {
|
||||
Name string
|
||||
ParameterSize string
|
||||
QuantizationLevel string
|
||||
Family string
|
||||
SizeBytes int64
|
||||
VRAMBytes int64
|
||||
NumCtx int64
|
||||
}
|
||||
|
||||
const DefaultPrompt = `Please write a descriptive story about a llama named Alonso who grows up to be President of the Land of Llamas. Include details about Alonso's childhood, adolescent years, and how he grew up to be a political mover and shaker. Write the story with a sense of whimsy.`
|
||||
|
||||
// Word list for generating prompts targeting a specific token count.
|
||||
var promptWordList = []string{
|
||||
"the", "quick", "brown", "fox", "jumps", "over", "lazy", "dog",
|
||||
"a", "bright", "sunny", "day", "in", "the", "meadow", "where",
|
||||
"flowers", "bloom", "and", "birds", "sing", "their", "morning",
|
||||
"songs", "while", "gentle", "breeze", "carries", "sweet", "scent",
|
||||
"of", "pine", "trees", "across", "rolling", "hills", "toward",
|
||||
"distant", "mountains", "covered", "with", "fresh", "snow",
|
||||
"beneath", "clear", "blue", "sky", "children", "play", "near",
|
||||
"old", "stone", "bridge", "that", "crosses", "winding", "river",
|
||||
}
|
||||
|
||||
// tokensPerWord is the calibrated ratio of tokens to words for the current model.
|
||||
// Initialized with a heuristic, then updated during warmup based on actual tokenization.
|
||||
var tokensPerWord = 1.3
|
||||
|
||||
func generatePromptForTokenCount(targetTokens int, epoch int) string {
|
||||
targetWords := int(float64(targetTokens) / tokensPerWord)
|
||||
if targetWords < 1 {
|
||||
targetWords = 1
|
||||
}
|
||||
|
||||
// Vary the starting offset by epoch to defeat KV cache prefix matching
|
||||
offset := epoch * 7 // stride by a prime to get good distribution
|
||||
n := len(promptWordList)
|
||||
words := make([]string, targetWords)
|
||||
for i := range words {
|
||||
words[i] = promptWordList[((i+offset)%n+n)%n]
|
||||
}
|
||||
return strings.Join(words, " ")
|
||||
}
|
||||
|
||||
// calibratePromptTokens adjusts tokensPerWord based on actual tokenization from a warmup run.
|
||||
func calibratePromptTokens(targetTokens, actualTokens, wordCount int) {
|
||||
if actualTokens <= 0 || wordCount <= 0 {
|
||||
return
|
||||
}
|
||||
tokensPerWord = float64(actualTokens) / float64(wordCount)
|
||||
newWords := int(float64(targetTokens) / tokensPerWord)
|
||||
fmt.Fprintf(os.Stderr, "bench: calibrated %.2f tokens/word (target=%d, got=%d, words=%d → %d)\n",
|
||||
tokensPerWord, targetTokens, actualTokens, wordCount, newWords)
|
||||
}
|
||||
|
||||
func buildGenerateRequest(model string, fOpt flagOptions, imgData api.ImageData, epoch int) *api.GenerateRequest {
|
||||
options := make(map[string]interface{})
|
||||
if *fOpt.maxTokens > 0 {
|
||||
options["num_predict"] = *fOpt.maxTokens
|
||||
}
|
||||
options["temperature"] = *fOpt.temperature
|
||||
if fOpt.seed != nil && *fOpt.seed > 0 {
|
||||
options["seed"] = *fOpt.seed
|
||||
}
|
||||
if fOpt.numCtx != nil && *fOpt.numCtx > 0 {
|
||||
options["num_ctx"] = *fOpt.numCtx
|
||||
}
|
||||
|
||||
var keepAliveDuration *api.Duration
|
||||
if *fOpt.keepAlive > 0 {
|
||||
duration := api.Duration{Duration: time.Duration(*fOpt.keepAlive * float64(time.Second))}
|
||||
keepAliveDuration = &duration
|
||||
}
|
||||
|
||||
prompt := *fOpt.prompt
|
||||
if *fOpt.promptTokens > 0 {
|
||||
prompt = generatePromptForTokenCount(*fOpt.promptTokens, epoch)
|
||||
} else {
|
||||
// Vary the prompt per epoch to defeat KV cache prefix matching
|
||||
prompt = fmt.Sprintf("[%d] %s", epoch, prompt)
|
||||
}
|
||||
|
||||
req := &api.GenerateRequest{
|
||||
Model: model,
|
||||
Prompt: prompt,
|
||||
Raw: true,
|
||||
Options: options,
|
||||
KeepAlive: keepAliveDuration,
|
||||
}
|
||||
|
||||
if imgData != nil {
|
||||
req.Images = []api.ImageData{imgData}
|
||||
}
|
||||
|
||||
return req
|
||||
}
|
||||
|
||||
func fetchModelInfo(ctx context.Context, client *api.Client, model string) ModelInfo {
|
||||
info := ModelInfo{Name: model}
|
||||
resp, err := client.Show(ctx, &api.ShowRequest{Model: model})
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "WARNING: Could not fetch model info for '%s': %v\n", model, err)
|
||||
return info
|
||||
}
|
||||
info.ParameterSize = resp.Details.ParameterSize
|
||||
info.QuantizationLevel = resp.Details.QuantizationLevel
|
||||
info.Family = resp.Details.Family
|
||||
return info
|
||||
}
|
||||
|
||||
func fetchMemoryUsage(ctx context.Context, client *api.Client, model string) (size, vram int64) {
|
||||
resp, err := client.ListRunning(ctx)
|
||||
if err != nil {
|
||||
if debug := os.Getenv("OLLAMA_DEBUG"); debug != "" {
|
||||
fmt.Fprintf(os.Stderr, "WARNING: Could not fetch memory usage: %v\n", err)
|
||||
}
|
||||
return 0, 0
|
||||
}
|
||||
for _, m := range resp.Models {
|
||||
if m.Name == model || m.Model == model {
|
||||
return m.Size, m.SizeVRAM
|
||||
}
|
||||
}
|
||||
for _, m := range resp.Models {
|
||||
if strings.HasPrefix(m.Name, model) || strings.HasPrefix(m.Model, model) {
|
||||
return m.Size, m.SizeVRAM
|
||||
}
|
||||
}
|
||||
return 0, 0
|
||||
}
|
||||
|
||||
func fetchContextLength(ctx context.Context, client *api.Client, model string) int64 {
|
||||
resp, err := client.ListRunning(ctx)
|
||||
if err != nil {
|
||||
return 0
|
||||
}
|
||||
for _, m := range resp.Models {
|
||||
if m.Name == model || m.Model == model || strings.HasPrefix(m.Name, model) || strings.HasPrefix(m.Model, model) {
|
||||
return int64(m.ContextLength)
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func outputFormatHeader(w io.Writer, format string, verbose bool) {
|
||||
switch format {
|
||||
case "benchstat":
|
||||
if verbose {
|
||||
fmt.Fprintf(w, "goos: %s\n", runtime.GOOS)
|
||||
fmt.Fprintf(w, "goarch: %s\n", runtime.GOARCH)
|
||||
}
|
||||
case "csv":
|
||||
headings := []string{"NAME", "STEP", "COUNT", "NS_PER_COUNT", "TOKEN_PER_SEC"}
|
||||
fmt.Fprintln(w, strings.Join(headings, ","))
|
||||
}
|
||||
}
|
||||
|
||||
func outputModelInfo(w io.Writer, format string, info ModelInfo) {
|
||||
params := cmp.Or(info.ParameterSize, "unknown")
|
||||
quant := cmp.Or(info.QuantizationLevel, "unknown")
|
||||
family := cmp.Or(info.Family, "unknown")
|
||||
|
||||
memStr := ""
|
||||
if info.SizeBytes > 0 {
|
||||
memStr = fmt.Sprintf(" | Size: %d | VRAM: %d", info.SizeBytes, info.VRAMBytes)
|
||||
}
|
||||
ctxStr := ""
|
||||
if info.NumCtx > 0 {
|
||||
ctxStr = fmt.Sprintf(" | NumCtx: %d", info.NumCtx)
|
||||
}
|
||||
fmt.Fprintf(w, "# Model: %s | Params: %s | Quant: %s | Family: %s%s%s\n",
|
||||
info.Name, params, quant, family, memStr, ctxStr)
|
||||
}
|
||||
|
||||
func OutputMetrics(w io.Writer, format string, metrics []Metrics, verbose bool) {
|
||||
switch format {
|
||||
case "benchstat":
|
||||
if verbose {
|
||||
printHeader := func() {
|
||||
fmt.Fprintf(w, "sysname: %s\n", runtime.GOOS)
|
||||
fmt.Fprintf(w, "machine: %s\n", runtime.GOARCH)
|
||||
}
|
||||
once.Do(printHeader)
|
||||
}
|
||||
for _, m := range metrics {
|
||||
if m.Step == "generate" || m.Step == "prefill" {
|
||||
if m.Count > 0 {
|
||||
nsPerToken := float64(m.Duration.Nanoseconds()) / float64(m.Count)
|
||||
tokensPerSec := float64(m.Count) / (float64(m.Duration.Nanoseconds()) + 1e-12) * 1e9
|
||||
|
||||
fmt.Fprintf(w, "BenchmarkModel/name=%s/step=%s %d %.2f ns/token %.2f token/sec\n",
|
||||
m.Model, m.Step, m.Count, nsPerToken, tokensPerSec)
|
||||
fmt.Fprintf(w, "BenchmarkModel/name=%s/step=%s 1 %.2f ns/token %.2f token/sec\n",
|
||||
m.Model, m.Step, nsPerToken, tokensPerSec)
|
||||
} else {
|
||||
fmt.Fprintf(w, "BenchmarkModel/name=%s/step=%s %d 0 ns/token 0 token/sec\n",
|
||||
m.Model, m.Step, m.Count)
|
||||
fmt.Fprintf(w, "BenchmarkModel/name=%s/step=%s 1 0 ns/token 0 token/sec\n",
|
||||
m.Model, m.Step)
|
||||
}
|
||||
} else if m.Step == "ttft" {
|
||||
fmt.Fprintf(w, "BenchmarkModel/name=%s/step=ttft 1 %d ns/op\n",
|
||||
m.Model, m.Duration.Nanoseconds())
|
||||
} else {
|
||||
var suffix string
|
||||
if m.Step == "load" {
|
||||
suffix = "/step=load"
|
||||
}
|
||||
fmt.Fprintf(w, "BenchmarkModel/name=%s%s 1 %d ns/request\n",
|
||||
m.Model, suffix, m.Duration.Nanoseconds())
|
||||
fmt.Fprintf(w, "BenchmarkModel/name=%s/step=%s 1 %d ns/op\n",
|
||||
m.Model, m.Step, m.Duration.Nanoseconds())
|
||||
}
|
||||
}
|
||||
case "csv":
|
||||
printHeader := func() {
|
||||
headings := []string{"NAME", "STEP", "COUNT", "NS_PER_COUNT", "TOKEN_PER_SEC"}
|
||||
fmt.Fprintln(w, strings.Join(headings, ","))
|
||||
}
|
||||
once.Do(printHeader)
|
||||
|
||||
for _, m := range metrics {
|
||||
if m.Step == "generate" || m.Step == "prefill" {
|
||||
var nsPerToken float64
|
||||
@@ -94,39 +252,14 @@ func OutputMetrics(w io.Writer, format string, metrics []Metrics, verbose bool)
|
||||
fmt.Fprintf(w, "%s,%s,1,%d,0\n", m.Model, m.Step, m.Duration.Nanoseconds())
|
||||
}
|
||||
}
|
||||
case "markdown":
|
||||
printHeader := func() {
|
||||
fmt.Fprintln(w, "| Model | Step | Count | Duration | nsPerToken | tokensPerSec |")
|
||||
fmt.Fprintln(w, "|-------|------|-------|----------|------------|--------------|")
|
||||
}
|
||||
once.Do(printHeader)
|
||||
|
||||
for _, m := range metrics {
|
||||
var nsPerToken, tokensPerSec float64
|
||||
var nsPerTokenStr, tokensPerSecStr string
|
||||
|
||||
if m.Step == "generate" || m.Step == "prefill" {
|
||||
nsPerToken = float64(m.Duration.Nanoseconds()) / float64(m.Count)
|
||||
tokensPerSec = float64(m.Count) / (float64(m.Duration.Nanoseconds()) + 1e-12) * 1e9
|
||||
nsPerTokenStr = fmt.Sprintf("%.2f", nsPerToken)
|
||||
tokensPerSecStr = fmt.Sprintf("%.2f", tokensPerSec)
|
||||
} else {
|
||||
nsPerTokenStr = "-"
|
||||
tokensPerSecStr = "-"
|
||||
}
|
||||
|
||||
fmt.Fprintf(w, "| %s | %s | %d | %v | %s | %s |\n",
|
||||
m.Model, m.Step, m.Count, m.Duration, nsPerTokenStr, tokensPerSecStr)
|
||||
}
|
||||
default:
|
||||
fmt.Fprintf(os.Stderr, "Unknown output format '%s'\n", format)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkChat(fOpt flagOptions) error {
|
||||
func BenchmarkModel(fOpt flagOptions) error {
|
||||
models := strings.Split(*fOpt.models, ",")
|
||||
|
||||
// todo - add multi-image support
|
||||
var imgData api.ImageData
|
||||
var err error
|
||||
if *fOpt.imageFile != "" {
|
||||
@@ -158,71 +291,141 @@ func BenchmarkChat(fOpt flagOptions) error {
|
||||
out = f
|
||||
}
|
||||
|
||||
outputFormatHeader(out, *fOpt.format, *fOpt.verbose)
|
||||
|
||||
// Log prompt-tokens info in debug mode
|
||||
if *fOpt.debug && *fOpt.promptTokens > 0 {
|
||||
prompt := generatePromptForTokenCount(*fOpt.promptTokens, 0)
|
||||
wordCount := len(strings.Fields(prompt))
|
||||
fmt.Fprintf(os.Stderr, "Generated prompt targeting ~%d tokens (%d words, varied per epoch)\n", *fOpt.promptTokens, wordCount)
|
||||
}
|
||||
|
||||
for _, model := range models {
|
||||
for range *fOpt.epochs {
|
||||
options := make(map[string]interface{})
|
||||
if *fOpt.maxTokens > 0 {
|
||||
options["num_predict"] = *fOpt.maxTokens
|
||||
}
|
||||
options["temperature"] = *fOpt.temperature
|
||||
if fOpt.seed != nil && *fOpt.seed > 0 {
|
||||
options["seed"] = *fOpt.seed
|
||||
}
|
||||
|
||||
var keepAliveDuration *api.Duration
|
||||
if *fOpt.keepAlive > 0 {
|
||||
duration := api.Duration{Duration: time.Duration(*fOpt.keepAlive * float64(time.Second))}
|
||||
keepAliveDuration = &duration
|
||||
}
|
||||
|
||||
req := &api.ChatRequest{
|
||||
Model: model,
|
||||
Messages: []api.Message{
|
||||
{
|
||||
Role: "user",
|
||||
Content: *fOpt.prompt,
|
||||
},
|
||||
},
|
||||
Options: options,
|
||||
KeepAlive: keepAliveDuration,
|
||||
}
|
||||
|
||||
if imgData != nil {
|
||||
req.Messages[0].Images = []api.ImageData{imgData}
|
||||
}
|
||||
|
||||
var responseMetrics *api.Metrics
|
||||
// Fetch model info
|
||||
infoCtx, infoCancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
info := fetchModelInfo(infoCtx, client, model)
|
||||
infoCancel()
|
||||
|
||||
// Warmup phase (uses negative epoch numbers to avoid colliding with timed epochs)
|
||||
for i := range *fOpt.warmup {
|
||||
req := buildGenerateRequest(model, fOpt, imgData, -(i + 1))
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(*fOpt.timeout)*time.Second)
|
||||
defer cancel()
|
||||
|
||||
err = client.Chat(ctx, req, func(resp api.ChatResponse) error {
|
||||
if *fOpt.debug {
|
||||
fmt.Fprintf(os.Stderr, "%s", cmp.Or(resp.Message.Thinking, resp.Message.Content))
|
||||
}
|
||||
|
||||
var warmupMetrics *api.Metrics
|
||||
err = client.Generate(ctx, req, func(resp api.GenerateResponse) error {
|
||||
if resp.Done {
|
||||
responseMetrics = &resp.Metrics
|
||||
warmupMetrics = &resp.Metrics
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
if *fOpt.debug {
|
||||
fmt.Fprintln(os.Stderr)
|
||||
}
|
||||
cancel()
|
||||
|
||||
if err != nil {
|
||||
if ctx.Err() == context.DeadlineExceeded {
|
||||
fmt.Fprintf(os.Stderr, "ERROR: Chat request timed out with model '%s' after %vs\n", model, 1)
|
||||
continue
|
||||
fmt.Fprintf(os.Stderr, "WARNING: Warmup %d/%d for %s failed: %v\n", i+1, *fOpt.warmup, model, err)
|
||||
} else {
|
||||
if *fOpt.debug {
|
||||
fmt.Fprintf(os.Stderr, "Warmup %d/%d for %s complete\n", i+1, *fOpt.warmup, model)
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, "ERROR: Couldn't chat with model '%s': %v\n", model, err)
|
||||
// Calibrate prompt token count on last warmup run
|
||||
if i == *fOpt.warmup-1 && *fOpt.promptTokens > 0 && warmupMetrics != nil {
|
||||
prompt := generatePromptForTokenCount(*fOpt.promptTokens, -(i + 1))
|
||||
wordCount := len(strings.Fields(prompt))
|
||||
calibratePromptTokens(*fOpt.promptTokens, warmupMetrics.PromptEvalCount, wordCount)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Fetch memory/context info once after warmup (model is loaded and stable)
|
||||
memCtx, memCancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
info.SizeBytes, info.VRAMBytes = fetchMemoryUsage(memCtx, client, model)
|
||||
if fOpt.numCtx != nil && *fOpt.numCtx > 0 {
|
||||
info.NumCtx = int64(*fOpt.numCtx)
|
||||
} else {
|
||||
info.NumCtx = fetchContextLength(memCtx, client, model)
|
||||
}
|
||||
memCancel()
|
||||
|
||||
outputModelInfo(out, *fOpt.format, info)
|
||||
|
||||
// Timed epoch loop
|
||||
shortCount := 0
|
||||
for epoch := range *fOpt.epochs {
|
||||
var responseMetrics *api.Metrics
|
||||
var ttft time.Duration
|
||||
short := false
|
||||
|
||||
// Retry loop: if the model hits a stop token before max-tokens,
|
||||
// retry with a different prompt (up to maxRetries times).
|
||||
const maxRetries = 3
|
||||
for attempt := range maxRetries + 1 {
|
||||
responseMetrics = nil
|
||||
ttft = 0
|
||||
var ttftOnce sync.Once
|
||||
|
||||
req := buildGenerateRequest(model, fOpt, imgData, epoch+attempt*1000)
|
||||
requestStart := time.Now()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(*fOpt.timeout)*time.Second)
|
||||
|
||||
err = client.Generate(ctx, req, func(resp api.GenerateResponse) error {
|
||||
if *fOpt.debug {
|
||||
fmt.Fprintf(os.Stderr, "%s", cmp.Or(resp.Thinking, resp.Response))
|
||||
}
|
||||
|
||||
// Capture TTFT on first content
|
||||
ttftOnce.Do(func() {
|
||||
if resp.Response != "" || resp.Thinking != "" {
|
||||
ttft = time.Since(requestStart)
|
||||
}
|
||||
})
|
||||
|
||||
if resp.Done {
|
||||
responseMetrics = &resp.Metrics
|
||||
}
|
||||
return nil
|
||||
})
|
||||
cancel()
|
||||
|
||||
if *fOpt.debug {
|
||||
fmt.Fprintln(os.Stderr)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
if ctx.Err() == context.DeadlineExceeded {
|
||||
fmt.Fprintf(os.Stderr, "ERROR: Request timed out with model '%s' after %vs\n", model, *fOpt.timeout)
|
||||
} else {
|
||||
fmt.Fprintf(os.Stderr, "ERROR: Couldn't generate with model '%s': %v\n", model, err)
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
if responseMetrics == nil {
|
||||
fmt.Fprintf(os.Stderr, "ERROR: No metrics received for model '%s'\n", model)
|
||||
break
|
||||
}
|
||||
|
||||
// Check if the response was shorter than requested
|
||||
short = *fOpt.maxTokens > 0 && responseMetrics.EvalCount < *fOpt.maxTokens
|
||||
if !short || attempt == maxRetries {
|
||||
break
|
||||
}
|
||||
|
||||
if *fOpt.debug {
|
||||
fmt.Fprintf(os.Stderr, "Short response (%d/%d tokens), retrying with different prompt (attempt %d/%d)\n",
|
||||
responseMetrics.EvalCount, *fOpt.maxTokens, attempt+1, maxRetries)
|
||||
}
|
||||
}
|
||||
|
||||
if err != nil || responseMetrics == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if responseMetrics == nil {
|
||||
fmt.Fprintf(os.Stderr, "ERROR: No metrics received for model '%s'\n", model)
|
||||
continue
|
||||
if short {
|
||||
shortCount++
|
||||
if *fOpt.debug {
|
||||
fmt.Fprintf(os.Stderr, "WARNING: Short response (%d/%d tokens) after %d retries for epoch %d\n",
|
||||
responseMetrics.EvalCount, *fOpt.maxTokens, maxRetries, epoch+1)
|
||||
}
|
||||
}
|
||||
|
||||
metrics := []Metrics{
|
||||
@@ -238,6 +441,12 @@ func BenchmarkChat(fOpt flagOptions) error {
|
||||
Count: responseMetrics.EvalCount,
|
||||
Duration: responseMetrics.EvalDuration,
|
||||
},
|
||||
{
|
||||
Model: model,
|
||||
Step: "ttft",
|
||||
Count: 1,
|
||||
Duration: ttft,
|
||||
},
|
||||
{
|
||||
Model: model,
|
||||
Step: "load",
|
||||
@@ -254,15 +463,42 @@ func BenchmarkChat(fOpt flagOptions) error {
|
||||
|
||||
OutputMetrics(out, *fOpt.format, metrics, *fOpt.verbose)
|
||||
|
||||
if *fOpt.debug && *fOpt.promptTokens > 0 {
|
||||
fmt.Fprintf(os.Stderr, "Generated prompt targeting ~%d tokens (actual: %d)\n",
|
||||
*fOpt.promptTokens, responseMetrics.PromptEvalCount)
|
||||
}
|
||||
|
||||
if *fOpt.keepAlive > 0 {
|
||||
time.Sleep(time.Duration(*fOpt.keepAlive*float64(time.Second)) + 200*time.Millisecond)
|
||||
}
|
||||
}
|
||||
|
||||
if shortCount > 0 {
|
||||
fmt.Fprintf(os.Stderr, "WARNING: %d/%d epochs for '%s' had short responses (<%d tokens). Generation metrics may be unreliable.\n",
|
||||
shortCount, *fOpt.epochs, model, *fOpt.maxTokens)
|
||||
}
|
||||
|
||||
// Unload model before moving to the next one
|
||||
unloadModel(client, model, *fOpt.timeout)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func unloadModel(client *api.Client, model string, timeout int) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(timeout)*time.Second)
|
||||
defer cancel()
|
||||
|
||||
zero := api.Duration{Duration: 0}
|
||||
req := &api.GenerateRequest{
|
||||
Model: model,
|
||||
KeepAlive: &zero,
|
||||
}
|
||||
_ = client.Generate(ctx, req, func(resp api.GenerateResponse) error {
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func readImage(filePath string) (api.ImageData, error) {
|
||||
file, err := os.Open(filePath)
|
||||
if err != nil {
|
||||
@@ -280,19 +516,22 @@ func readImage(filePath string) (api.ImageData, error) {
|
||||
|
||||
func main() {
|
||||
fOpt := flagOptions{
|
||||
models: flag.String("model", "", "Model to benchmark"),
|
||||
epochs: flag.Int("epochs", 6, "Number of epochs (iterations) per model"),
|
||||
maxTokens: flag.Int("max-tokens", 200, "Maximum tokens for model response"),
|
||||
temperature: flag.Float64("temperature", 0, "Temperature parameter"),
|
||||
seed: flag.Int("seed", 0, "Random seed"),
|
||||
timeout: flag.Int("timeout", 60*5, "Timeout in seconds (default 300s)"),
|
||||
prompt: flag.String("p", DefaultPrompt, "Prompt to use"),
|
||||
imageFile: flag.String("image", "", "Filename for an image to include"),
|
||||
keepAlive: flag.Float64("k", 0, "Keep alive duration in seconds"),
|
||||
format: flag.String("format", "markdown", "Output format [benchstat|csv] (default benchstat)"),
|
||||
outputFile: flag.String("output", "", "Output file for results (stdout if empty)"),
|
||||
verbose: flag.Bool("v", false, "Show system information"),
|
||||
debug: flag.Bool("debug", false, "Show debug information"),
|
||||
models: flag.String("model", "", "Model to benchmark"),
|
||||
epochs: flag.Int("epochs", 6, "Number of epochs (iterations) per model"),
|
||||
maxTokens: flag.Int("max-tokens", 200, "Maximum tokens for model response"),
|
||||
temperature: flag.Float64("temperature", 0, "Temperature parameter"),
|
||||
seed: flag.Int("seed", 0, "Random seed"),
|
||||
timeout: flag.Int("timeout", 60*5, "Timeout in seconds (default 300s)"),
|
||||
prompt: flag.String("p", DefaultPrompt, "Prompt to use"),
|
||||
imageFile: flag.String("image", "", "Filename for an image to include"),
|
||||
keepAlive: flag.Float64("k", 0, "Keep alive duration in seconds"),
|
||||
format: flag.String("format", "benchstat", "Output format [benchstat|csv]"),
|
||||
outputFile: flag.String("output", "", "Output file for results (stdout if empty)"),
|
||||
verbose: flag.Bool("v", false, "Show system information"),
|
||||
debug: flag.Bool("debug", false, "Show debug information"),
|
||||
warmup: flag.Int("warmup", 1, "Number of warmup requests before timing"),
|
||||
promptTokens: flag.Int("prompt-tokens", 0, "Generate prompt targeting ~N tokens (0 = use -p prompt)"),
|
||||
numCtx: flag.Int("num-ctx", 0, "Context size (0 = server default)"),
|
||||
}
|
||||
|
||||
flag.Usage = func() {
|
||||
@@ -302,11 +541,12 @@ func main() {
|
||||
fmt.Fprintf(os.Stderr, "Options:\n")
|
||||
flag.PrintDefaults()
|
||||
fmt.Fprintf(os.Stderr, "\nExamples:\n")
|
||||
fmt.Fprintf(os.Stderr, " bench -model gpt-oss:20b -epochs 3 -temperature 0.7\n")
|
||||
fmt.Fprintf(os.Stderr, " bench -model gemma3,llama3 -epochs 6\n")
|
||||
fmt.Fprintf(os.Stderr, " bench -model gemma3 -epochs 6 -prompt-tokens 512 -format csv\n")
|
||||
}
|
||||
flag.Parse()
|
||||
|
||||
if !slices.Contains([]string{"markdown", "benchstat", "csv"}, *fOpt.format) {
|
||||
if !slices.Contains([]string{"benchstat", "csv"}, *fOpt.format) {
|
||||
fmt.Fprintf(os.Stderr, "ERROR: Unknown format '%s'\n", *fOpt.format)
|
||||
os.Exit(1)
|
||||
}
|
||||
@@ -317,5 +557,5 @@ func main() {
|
||||
return
|
||||
}
|
||||
|
||||
BenchmarkChat(fOpt)
|
||||
BenchmarkModel(fOpt)
|
||||
}
|
||||
|
||||
421
cmd/cmd.go
@@ -11,6 +11,7 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"log/slog"
|
||||
"math"
|
||||
"net"
|
||||
"net/http"
|
||||
@@ -38,9 +39,12 @@ import (
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/cmd/config"
|
||||
"github.com/ollama/ollama/cmd/launch"
|
||||
"github.com/ollama/ollama/cmd/tui"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
"github.com/ollama/ollama/format"
|
||||
"github.com/ollama/ollama/internal/modelref"
|
||||
"github.com/ollama/ollama/logutil"
|
||||
"github.com/ollama/ollama/parser"
|
||||
"github.com/ollama/ollama/progress"
|
||||
"github.com/ollama/ollama/readline"
|
||||
@@ -57,36 +61,42 @@ import (
|
||||
|
||||
func init() {
|
||||
// Override default selectors to use Bubbletea TUI instead of raw terminal I/O.
|
||||
config.DefaultSingleSelector = func(title string, items []config.ModelItem, current string) (string, error) {
|
||||
launch.DefaultSingleSelector = func(title string, items []launch.ModelItem, current string) (string, error) {
|
||||
if !term.IsTerminal(int(os.Stdin.Fd())) || !term.IsTerminal(int(os.Stdout.Fd())) {
|
||||
return "", fmt.Errorf("model selection requires an interactive terminal; use --model to run in headless mode")
|
||||
}
|
||||
tuiItems := tui.ReorderItems(tui.ConvertItems(items))
|
||||
result, err := tui.SelectSingle(title, tuiItems, current)
|
||||
if errors.Is(err, tui.ErrCancelled) {
|
||||
return "", config.ErrCancelled
|
||||
return "", launch.ErrCancelled
|
||||
}
|
||||
return result, err
|
||||
}
|
||||
|
||||
config.DefaultMultiSelector = func(title string, items []config.ModelItem, preChecked []string) ([]string, error) {
|
||||
launch.DefaultMultiSelector = func(title string, items []launch.ModelItem, preChecked []string) ([]string, error) {
|
||||
if !term.IsTerminal(int(os.Stdin.Fd())) || !term.IsTerminal(int(os.Stdout.Fd())) {
|
||||
return nil, fmt.Errorf("model selection requires an interactive terminal; use --model to run in headless mode")
|
||||
}
|
||||
tuiItems := tui.ReorderItems(tui.ConvertItems(items))
|
||||
result, err := tui.SelectMultiple(title, tuiItems, preChecked)
|
||||
if errors.Is(err, tui.ErrCancelled) {
|
||||
return nil, config.ErrCancelled
|
||||
return nil, launch.ErrCancelled
|
||||
}
|
||||
return result, err
|
||||
}
|
||||
|
||||
config.DefaultSignIn = func(modelName, signInURL string) (string, error) {
|
||||
launch.DefaultSignIn = func(modelName, signInURL string) (string, error) {
|
||||
userName, err := tui.RunSignIn(modelName, signInURL)
|
||||
if errors.Is(err, tui.ErrCancelled) {
|
||||
return "", config.ErrCancelled
|
||||
return "", launch.ErrCancelled
|
||||
}
|
||||
return userName, err
|
||||
}
|
||||
|
||||
config.DefaultConfirmPrompt = func(prompt string) (bool, error) {
|
||||
launch.DefaultConfirmPrompt = func(prompt string) (bool, error) {
|
||||
ok, err := tui.RunConfirm(prompt)
|
||||
if errors.Is(err, tui.ErrCancelled) {
|
||||
return false, config.ErrCancelled
|
||||
return false, launch.ErrCancelled
|
||||
}
|
||||
return ok, err
|
||||
}
|
||||
@@ -131,6 +141,17 @@ func getModelfileName(cmd *cobra.Command) (string, error) {
|
||||
return absName, nil
|
||||
}
|
||||
|
||||
// isLocalhost returns true if the configured Ollama host is a loopback or unspecified address.
|
||||
func isLocalhost() bool {
|
||||
host := envconfig.Host()
|
||||
h, _, _ := net.SplitHostPort(host.Host)
|
||||
if h == "localhost" {
|
||||
return true
|
||||
}
|
||||
ip := net.ParseIP(h)
|
||||
return ip != nil && (ip.IsLoopback() || ip.IsUnspecified())
|
||||
}
|
||||
|
||||
func CreateHandler(cmd *cobra.Command, args []string) error {
|
||||
p := progress.NewProgress(os.Stderr)
|
||||
defer p.Stop()
|
||||
@@ -145,6 +166,9 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
|
||||
// Check for --experimental flag for safetensors model creation
|
||||
experimental, _ := cmd.Flags().GetBool("experimental")
|
||||
if experimental {
|
||||
if !isLocalhost() {
|
||||
return errors.New("remote safetensor model creation not yet supported")
|
||||
}
|
||||
// Get Modelfile content - either from -f flag or default to "FROM ."
|
||||
var reader io.Reader
|
||||
filename, err := getModelfileName(cmd)
|
||||
@@ -168,29 +192,9 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
|
||||
return fmt.Errorf("failed to parse Modelfile: %w", err)
|
||||
}
|
||||
|
||||
// Extract FROM path and configuration
|
||||
var modelDir string
|
||||
mfConfig := &xcreateclient.ModelfileConfig{}
|
||||
|
||||
for _, cmd := range modelfile.Commands {
|
||||
switch cmd.Name {
|
||||
case "model":
|
||||
modelDir = cmd.Args
|
||||
case "template":
|
||||
mfConfig.Template = cmd.Args
|
||||
case "system":
|
||||
mfConfig.System = cmd.Args
|
||||
case "license":
|
||||
mfConfig.License = cmd.Args
|
||||
case "parser":
|
||||
mfConfig.Parser = cmd.Args
|
||||
case "renderer":
|
||||
mfConfig.Renderer = cmd.Args
|
||||
}
|
||||
}
|
||||
|
||||
if modelDir == "" {
|
||||
modelDir = "."
|
||||
modelDir, mfConfig, err := xcreateclient.ConfigFromModelfile(modelfile)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Resolve relative paths based on Modelfile location
|
||||
@@ -214,6 +218,9 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
|
||||
if filename == "" {
|
||||
// No Modelfile found - check if current directory is an image gen model
|
||||
if create.IsTensorModelDir(".") {
|
||||
if !isLocalhost() {
|
||||
return errors.New("remote safetensor model creation not yet supported")
|
||||
}
|
||||
quantize, _ := cmd.Flags().GetString("quantize")
|
||||
return xcreateclient.CreateModel(xcreateclient.CreateOptions{
|
||||
ModelName: modelName,
|
||||
@@ -406,12 +413,14 @@ func loadOrUnloadModel(cmd *cobra.Command, opts *runOptions) error {
|
||||
return err
|
||||
}
|
||||
|
||||
requestedCloud := modelref.HasExplicitCloudSource(opts.Model)
|
||||
|
||||
if info, err := client.Show(cmd.Context(), &api.ShowRequest{Model: opts.Model}); err != nil {
|
||||
return err
|
||||
} else if info.RemoteHost != "" {
|
||||
} else if info.RemoteHost != "" || requestedCloud {
|
||||
// Cloud model, no need to load/unload
|
||||
|
||||
isCloud := strings.HasPrefix(info.RemoteHost, "https://ollama.com")
|
||||
isCloud := requestedCloud || strings.HasPrefix(info.RemoteHost, "https://ollama.com")
|
||||
|
||||
// Check if user is signed in for ollama.com cloud models
|
||||
if isCloud {
|
||||
@@ -422,10 +431,14 @@ func loadOrUnloadModel(cmd *cobra.Command, opts *runOptions) error {
|
||||
|
||||
if opts.ShowConnect {
|
||||
p.StopAndClear()
|
||||
remoteModel := info.RemoteModel
|
||||
if remoteModel == "" {
|
||||
remoteModel = opts.Model
|
||||
}
|
||||
if isCloud {
|
||||
fmt.Fprintf(os.Stderr, "Connecting to '%s' on 'ollama.com' ⚡\n", info.RemoteModel)
|
||||
fmt.Fprintf(os.Stderr, "Connecting to '%s' on 'ollama.com' ⚡\n", remoteModel)
|
||||
} else {
|
||||
fmt.Fprintf(os.Stderr, "Connecting to '%s' on '%s'\n", info.RemoteModel, info.RemoteHost)
|
||||
fmt.Fprintf(os.Stderr, "Connecting to '%s' on '%s'\n", remoteModel, info.RemoteHost)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -497,6 +510,64 @@ func generateEmbedding(cmd *cobra.Command, modelName, input string, keepAlive *a
|
||||
return nil
|
||||
}
|
||||
|
||||
// TODO(parthsareen): consolidate with TUI signin flow
|
||||
func handleCloudAuthorizationError(err error) bool {
|
||||
var authErr api.AuthorizationError
|
||||
if errors.As(err, &authErr) && authErr.StatusCode == http.StatusUnauthorized {
|
||||
fmt.Printf("You need to be signed in to Ollama to run Cloud models.\n\n")
|
||||
if authErr.SigninURL != "" {
|
||||
fmt.Printf(ConnectInstructions, authErr.SigninURL)
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// TEMP(drifkin): To match legacy `ollama run some-model:cloud` behavior, we
|
||||
// best-effort pull cloud stub files for any explicit cloud source models.
|
||||
// Remove this once `/api/tags` is cloud-aware.
|
||||
func ensureCloudStub(ctx context.Context, client *api.Client, modelName string) {
|
||||
if !modelref.HasExplicitCloudSource(modelName) {
|
||||
return
|
||||
}
|
||||
|
||||
normalizedName, _, err := modelref.NormalizePullName(modelName)
|
||||
if err != nil {
|
||||
slog.Warn("failed to normalize pull name", "model", modelName, "error", err, "normalizedName", normalizedName)
|
||||
return
|
||||
}
|
||||
|
||||
listResp, err := client.List(ctx)
|
||||
if err != nil {
|
||||
slog.Warn("failed to list models", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
if hasListedModelName(listResp.Models, modelName) || hasListedModelName(listResp.Models, normalizedName) {
|
||||
return
|
||||
}
|
||||
|
||||
logutil.Trace("pulling cloud stub", "model", modelName, "normalizedName", normalizedName)
|
||||
err = client.Pull(ctx, &api.PullRequest{
|
||||
Model: normalizedName,
|
||||
}, func(api.ProgressResponse) error {
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
slog.Warn("failed to pull cloud stub", "model", modelName, "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
func hasListedModelName(models []api.ListModelResponse, name string) bool {
|
||||
for _, m := range models {
|
||||
if strings.EqualFold(m.Name, name) || strings.EqualFold(m.Model, name) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func RunHandler(cmd *cobra.Command, args []string) error {
|
||||
interactive := true
|
||||
|
||||
@@ -585,17 +656,6 @@ func RunHandler(cmd *cobra.Command, args []string) error {
|
||||
}
|
||||
opts.WordWrap = !nowrap
|
||||
|
||||
useImagegen := false
|
||||
if cmd.Flags().Lookup("imagegen") != nil {
|
||||
useImagegen, err = cmd.Flags().GetBool("imagegen")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if useImagegen {
|
||||
opts.Options["use_imagegen_runner"] = true
|
||||
}
|
||||
|
||||
// Fill out the rest of the options based on information about the
|
||||
// model.
|
||||
client, err := api.ClientFromEnvironment()
|
||||
@@ -604,12 +664,16 @@ func RunHandler(cmd *cobra.Command, args []string) error {
|
||||
}
|
||||
|
||||
name := args[0]
|
||||
requestedCloud := modelref.HasExplicitCloudSource(name)
|
||||
|
||||
info, err := func() (*api.ShowResponse, error) {
|
||||
showReq := &api.ShowRequest{Name: name}
|
||||
info, err := client.Show(cmd.Context(), showReq)
|
||||
var se api.StatusError
|
||||
if errors.As(err, &se) && se.StatusCode == http.StatusNotFound {
|
||||
if requestedCloud {
|
||||
return nil, err
|
||||
}
|
||||
if err := PullHandler(cmd, []string{name}); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -618,15 +682,21 @@ func RunHandler(cmd *cobra.Command, args []string) error {
|
||||
return info, err
|
||||
}()
|
||||
if err != nil {
|
||||
if handleCloudAuthorizationError(err) {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
ensureCloudStub(cmd.Context(), client, name)
|
||||
|
||||
opts.Think, err = inferThinkingOption(&info.Capabilities, &opts, thinkFlag.Changed)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
opts.MultiModal = slices.Contains(info.Capabilities, model.CapabilityVision)
|
||||
audioCapable := slices.Contains(info.Capabilities, model.CapabilityAudio)
|
||||
opts.MultiModal = slices.Contains(info.Capabilities, model.CapabilityVision) || audioCapable
|
||||
|
||||
// TODO: remove the projector info and vision info checks below,
|
||||
// these are left in for backwards compatibility with older servers
|
||||
@@ -712,7 +782,13 @@ func RunHandler(cmd *cobra.Command, args []string) error {
|
||||
|
||||
return generateInteractive(cmd, opts)
|
||||
}
|
||||
return generate(cmd, opts)
|
||||
if err := generate(cmd, opts); err != nil {
|
||||
if handleCloudAuthorizationError(err) {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func SigninHandler(cmd *cobra.Command, args []string) error {
|
||||
@@ -1419,6 +1495,9 @@ type displayResponseState struct {
|
||||
|
||||
func displayResponse(content string, wordWrap bool, state *displayResponseState) {
|
||||
termWidth, _, _ := term.GetSize(int(os.Stdout.Fd()))
|
||||
if termWidth == 0 {
|
||||
termWidth = 80
|
||||
}
|
||||
if wordWrap && termWidth >= 10 {
|
||||
for _, ch := range content {
|
||||
if state.lineLength+1 > termWidth-5 {
|
||||
@@ -1892,6 +1971,24 @@ func ensureServerRunning(ctx context.Context) error {
|
||||
}
|
||||
}
|
||||
|
||||
func launchInteractiveModel(cmd *cobra.Command, modelName string) error {
|
||||
opts := runOptions{
|
||||
Model: modelName,
|
||||
WordWrap: os.Getenv("TERM") == "xterm-256color",
|
||||
Options: map[string]any{},
|
||||
ShowConnect: true,
|
||||
}
|
||||
// loadOrUnloadModel is cloud-safe here: remote/cloud models skip local preload
|
||||
// and only validate auth/connectivity before interactive chat starts.
|
||||
if err := loadOrUnloadModel(cmd, &opts); err != nil {
|
||||
return fmt.Errorf("error loading model: %w", err)
|
||||
}
|
||||
if err := generateInteractive(cmd, opts); err != nil {
|
||||
return fmt.Errorf("error running model: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// runInteractiveTUI runs the main interactive TUI menu.
|
||||
func runInteractiveTUI(cmd *cobra.Command) {
|
||||
// Ensure the server is running before showing the TUI
|
||||
@@ -1900,175 +1997,85 @@ func runInteractiveTUI(cmd *cobra.Command) {
|
||||
return
|
||||
}
|
||||
|
||||
// Selector adapters for tui
|
||||
singleSelector := func(title string, items []config.ModelItem, current string) (string, error) {
|
||||
tuiItems := tui.ReorderItems(tui.ConvertItems(items))
|
||||
result, err := tui.SelectSingle(title, tuiItems, current)
|
||||
if errors.Is(err, tui.ErrCancelled) {
|
||||
return "", config.ErrCancelled
|
||||
}
|
||||
return result, err
|
||||
}
|
||||
|
||||
multiSelector := func(title string, items []config.ModelItem, preChecked []string) ([]string, error) {
|
||||
tuiItems := tui.ReorderItems(tui.ConvertItems(items))
|
||||
result, err := tui.SelectMultiple(title, tuiItems, preChecked)
|
||||
if errors.Is(err, tui.ErrCancelled) {
|
||||
return nil, config.ErrCancelled
|
||||
}
|
||||
return result, err
|
||||
deps := launcherDeps{
|
||||
buildState: launch.BuildLauncherState,
|
||||
runMenu: tui.RunMenu,
|
||||
resolveRunModel: launch.ResolveRunModel,
|
||||
launchIntegration: launch.LaunchIntegration,
|
||||
runModel: launchInteractiveModel,
|
||||
}
|
||||
|
||||
for {
|
||||
result, err := tui.Run()
|
||||
continueLoop, err := runInteractiveTUIStep(cmd, deps)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
|
||||
}
|
||||
if !continueLoop {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
runModel := func(modelName string) {
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
|
||||
return
|
||||
}
|
||||
if err := config.ShowOrPull(cmd.Context(), client, modelName); err != nil {
|
||||
if errors.Is(err, config.ErrCancelled) {
|
||||
return
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
|
||||
return
|
||||
}
|
||||
_ = config.SetLastModel(modelName)
|
||||
opts := runOptions{
|
||||
Model: modelName,
|
||||
WordWrap: os.Getenv("TERM") == "xterm-256color",
|
||||
Options: map[string]any{},
|
||||
ShowConnect: true,
|
||||
}
|
||||
if err := loadOrUnloadModel(cmd, &opts); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error loading model: %v\n", err)
|
||||
return
|
||||
}
|
||||
if err := generateInteractive(cmd, opts); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error running model: %v\n", err)
|
||||
}
|
||||
}
|
||||
type launcherDeps struct {
|
||||
buildState func(context.Context) (*launch.LauncherState, error)
|
||||
runMenu func(*launch.LauncherState) (tui.TUIAction, error)
|
||||
resolveRunModel func(context.Context, launch.RunModelRequest) (string, error)
|
||||
launchIntegration func(context.Context, launch.IntegrationLaunchRequest) error
|
||||
runModel func(*cobra.Command, string) error
|
||||
}
|
||||
|
||||
launchIntegration := func(name string) bool {
|
||||
if err := config.EnsureInstalled(name); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
|
||||
return true
|
||||
}
|
||||
// If not configured or model no longer exists, prompt for model selection
|
||||
configuredModel := config.IntegrationModel(name)
|
||||
if configuredModel == "" || !config.ModelExists(cmd.Context(), configuredModel) || config.IsCloudModelDisabled(cmd.Context(), configuredModel) {
|
||||
err := config.ConfigureIntegrationWithSelectors(cmd.Context(), name, singleSelector, multiSelector)
|
||||
if errors.Is(err, config.ErrCancelled) {
|
||||
return false // Return to main menu
|
||||
}
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error configuring %s: %v\n", name, err)
|
||||
return true
|
||||
}
|
||||
}
|
||||
if err := config.LaunchIntegration(name); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error launching %s: %v\n", name, err)
|
||||
}
|
||||
return true
|
||||
}
|
||||
func runInteractiveTUIStep(cmd *cobra.Command, deps launcherDeps) (bool, error) {
|
||||
state, err := deps.buildState(cmd.Context())
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("build launcher state: %w", err)
|
||||
}
|
||||
|
||||
switch result.Selection {
|
||||
case tui.SelectionNone:
|
||||
// User quit
|
||||
return
|
||||
case tui.SelectionRunModel:
|
||||
_ = config.SetLastSelection("run")
|
||||
if modelName := config.LastModel(); modelName != "" && !config.IsCloudModelDisabled(cmd.Context(), modelName) {
|
||||
runModel(modelName)
|
||||
} else {
|
||||
modelName, err := config.SelectModelWithSelector(cmd.Context(), singleSelector)
|
||||
if errors.Is(err, config.ErrCancelled) {
|
||||
continue // Return to main menu
|
||||
}
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error selecting model: %v\n", err)
|
||||
continue
|
||||
}
|
||||
runModel(modelName)
|
||||
}
|
||||
case tui.SelectionChangeRunModel:
|
||||
_ = config.SetLastSelection("run")
|
||||
// Use model from modal if selected, otherwise show picker
|
||||
modelName := result.Model
|
||||
if modelName == "" {
|
||||
var err error
|
||||
modelName, err = config.SelectModelWithSelector(cmd.Context(), singleSelector)
|
||||
if errors.Is(err, config.ErrCancelled) {
|
||||
continue // Return to main menu
|
||||
}
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error selecting model: %v\n", err)
|
||||
continue
|
||||
}
|
||||
}
|
||||
if config.IsCloudModelDisabled(cmd.Context(), modelName) {
|
||||
continue // Return to main menu
|
||||
}
|
||||
runModel(modelName)
|
||||
case tui.SelectionIntegration:
|
||||
_ = config.SetLastSelection(result.Integration)
|
||||
if !launchIntegration(result.Integration) {
|
||||
continue // Return to main menu
|
||||
}
|
||||
case tui.SelectionChangeIntegration:
|
||||
_ = config.SetLastSelection(result.Integration)
|
||||
if len(result.Models) > 0 {
|
||||
// Filter out cloud-disabled models
|
||||
var filtered []string
|
||||
for _, m := range result.Models {
|
||||
if !config.IsCloudModelDisabled(cmd.Context(), m) {
|
||||
filtered = append(filtered, m)
|
||||
}
|
||||
}
|
||||
if len(filtered) == 0 {
|
||||
continue
|
||||
}
|
||||
result.Models = filtered
|
||||
// Multi-select from modal (Editor integrations)
|
||||
if err := config.SaveAndEditIntegration(result.Integration, result.Models); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error configuring %s: %v\n", result.Integration, err)
|
||||
continue
|
||||
}
|
||||
if err := config.LaunchIntegrationWithModel(result.Integration, result.Models[0]); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error launching %s: %v\n", result.Integration, err)
|
||||
}
|
||||
} else if result.Model != "" {
|
||||
if config.IsCloudModelDisabled(cmd.Context(), result.Model) {
|
||||
continue
|
||||
}
|
||||
// Single-select from modal - save and launch
|
||||
if err := config.SaveIntegration(result.Integration, []string{result.Model}); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error saving config: %v\n", err)
|
||||
continue
|
||||
}
|
||||
if err := config.LaunchIntegrationWithModel(result.Integration, result.Model); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error launching %s: %v\n", result.Integration, err)
|
||||
}
|
||||
} else {
|
||||
err := config.ConfigureIntegrationWithSelectors(cmd.Context(), result.Integration, singleSelector, multiSelector)
|
||||
if errors.Is(err, config.ErrCancelled) {
|
||||
continue // Return to main menu
|
||||
}
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error configuring %s: %v\n", result.Integration, err)
|
||||
continue
|
||||
}
|
||||
if err := config.LaunchIntegration(result.Integration); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error launching %s: %v\n", result.Integration, err)
|
||||
}
|
||||
}
|
||||
action, err := deps.runMenu(state)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("run launcher menu: %w", err)
|
||||
}
|
||||
|
||||
return runLauncherAction(cmd, action, deps)
|
||||
}
|
||||
|
||||
func saveLauncherSelection(action tui.TUIAction) {
|
||||
// Best effort only: this affects menu recall, not launch correctness.
|
||||
_ = config.SetLastSelection(action.LastSelection())
|
||||
}
|
||||
|
||||
func runLauncherAction(cmd *cobra.Command, action tui.TUIAction, deps launcherDeps) (bool, error) {
|
||||
switch action.Kind {
|
||||
case tui.TUIActionNone:
|
||||
return false, nil
|
||||
case tui.TUIActionRunModel:
|
||||
saveLauncherSelection(action)
|
||||
modelName, err := deps.resolveRunModel(cmd.Context(), action.RunModelRequest())
|
||||
if errors.Is(err, launch.ErrCancelled) {
|
||||
return true, nil
|
||||
}
|
||||
if err != nil {
|
||||
return true, fmt.Errorf("selecting model: %w", err)
|
||||
}
|
||||
if err := deps.runModel(cmd, modelName); err != nil {
|
||||
return true, err
|
||||
}
|
||||
return true, nil
|
||||
case tui.TUIActionLaunchIntegration:
|
||||
saveLauncherSelection(action)
|
||||
err := deps.launchIntegration(cmd.Context(), action.IntegrationLaunchRequest())
|
||||
if errors.Is(err, launch.ErrCancelled) {
|
||||
return true, nil
|
||||
}
|
||||
if err != nil {
|
||||
return true, fmt.Errorf("launching %s: %w", action.Integration, err)
|
||||
}
|
||||
// VS Code is a GUI app — exit the TUI loop after launching
|
||||
if action.Integration == "vscode" {
|
||||
return false, nil
|
||||
}
|
||||
return true, nil
|
||||
default:
|
||||
return false, fmt.Errorf("unknown launcher action: %d", action.Kind)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2338,7 +2345,7 @@ func NewCLI() *cobra.Command {
|
||||
copyCmd,
|
||||
deleteCmd,
|
||||
runnerCmd,
|
||||
config.LaunchCmd(checkServerHeartbeat, runInteractiveTUI),
|
||||
launch.LaunchCmd(checkServerHeartbeat, runInteractiveTUI),
|
||||
)
|
||||
|
||||
return rootCmd
|
||||
|
||||
270
cmd/cmd_launcher_test.go
Normal file
@@ -0,0 +1,270 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/ollama/ollama/cmd/config"
|
||||
"github.com/ollama/ollama/cmd/launch"
|
||||
"github.com/ollama/ollama/cmd/tui"
|
||||
)
|
||||
|
||||
func setCmdTestHome(t *testing.T, dir string) {
|
||||
t.Helper()
|
||||
t.Setenv("HOME", dir)
|
||||
t.Setenv("USERPROFILE", dir)
|
||||
}
|
||||
|
||||
func unexpectedRunModelResolution(t *testing.T) func(context.Context, launch.RunModelRequest) (string, error) {
|
||||
t.Helper()
|
||||
return func(ctx context.Context, req launch.RunModelRequest) (string, error) {
|
||||
t.Fatalf("did not expect run-model resolution: %+v", req)
|
||||
return "", nil
|
||||
}
|
||||
}
|
||||
|
||||
func unexpectedIntegrationLaunch(t *testing.T) func(context.Context, launch.IntegrationLaunchRequest) error {
|
||||
t.Helper()
|
||||
return func(ctx context.Context, req launch.IntegrationLaunchRequest) error {
|
||||
t.Fatalf("did not expect integration launch: %+v", req)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func unexpectedModelLaunch(t *testing.T) func(*cobra.Command, string) error {
|
||||
t.Helper()
|
||||
return func(cmd *cobra.Command, model string) error {
|
||||
t.Fatalf("did not expect chat launch: %s", model)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunInteractiveTUI_RunModelActionsUseResolveRunModel(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
action tui.TUIAction
|
||||
wantForce bool
|
||||
wantModel string
|
||||
}{
|
||||
{
|
||||
name: "enter uses saved model flow",
|
||||
action: tui.TUIAction{Kind: tui.TUIActionRunModel},
|
||||
wantModel: "qwen3:8b",
|
||||
},
|
||||
{
|
||||
name: "right forces picker",
|
||||
action: tui.TUIAction{Kind: tui.TUIActionRunModel, ForceConfigure: true},
|
||||
wantForce: true,
|
||||
wantModel: "glm-5:cloud",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
setCmdTestHome(t, t.TempDir())
|
||||
|
||||
var menuCalls int
|
||||
runMenu := func(state *launch.LauncherState) (tui.TUIAction, error) {
|
||||
menuCalls++
|
||||
if menuCalls == 1 {
|
||||
return tt.action, nil
|
||||
}
|
||||
return tui.TUIAction{Kind: tui.TUIActionNone}, nil
|
||||
}
|
||||
|
||||
var gotReq launch.RunModelRequest
|
||||
var launched string
|
||||
deps := launcherDeps{
|
||||
buildState: func(ctx context.Context) (*launch.LauncherState, error) {
|
||||
return &launch.LauncherState{}, nil
|
||||
},
|
||||
runMenu: runMenu,
|
||||
resolveRunModel: func(ctx context.Context, req launch.RunModelRequest) (string, error) {
|
||||
gotReq = req
|
||||
return tt.wantModel, nil
|
||||
},
|
||||
launchIntegration: unexpectedIntegrationLaunch(t),
|
||||
runModel: func(cmd *cobra.Command, model string) error {
|
||||
launched = model
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
cmd := &cobra.Command{}
|
||||
cmd.SetContext(context.Background())
|
||||
for {
|
||||
continueLoop, err := runInteractiveTUIStep(cmd, deps)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected step error: %v", err)
|
||||
}
|
||||
if !continueLoop {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if gotReq.ForcePicker != tt.wantForce {
|
||||
t.Fatalf("expected ForcePicker=%v, got %v", tt.wantForce, gotReq.ForcePicker)
|
||||
}
|
||||
if launched != tt.wantModel {
|
||||
t.Fatalf("expected interactive launcher to run %q, got %q", tt.wantModel, launched)
|
||||
}
|
||||
if got := config.LastSelection(); got != "run" {
|
||||
t.Fatalf("expected last selection to be run, got %q", got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunInteractiveTUI_IntegrationActionsUseLaunchIntegration(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
action tui.TUIAction
|
||||
wantForce bool
|
||||
}{
|
||||
{
|
||||
name: "enter launches integration",
|
||||
action: tui.TUIAction{Kind: tui.TUIActionLaunchIntegration, Integration: "claude"},
|
||||
},
|
||||
{
|
||||
name: "right forces configure",
|
||||
action: tui.TUIAction{Kind: tui.TUIActionLaunchIntegration, Integration: "claude", ForceConfigure: true},
|
||||
wantForce: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
setCmdTestHome(t, t.TempDir())
|
||||
|
||||
var menuCalls int
|
||||
runMenu := func(state *launch.LauncherState) (tui.TUIAction, error) {
|
||||
menuCalls++
|
||||
if menuCalls == 1 {
|
||||
return tt.action, nil
|
||||
}
|
||||
return tui.TUIAction{Kind: tui.TUIActionNone}, nil
|
||||
}
|
||||
|
||||
var gotReq launch.IntegrationLaunchRequest
|
||||
deps := launcherDeps{
|
||||
buildState: func(ctx context.Context) (*launch.LauncherState, error) {
|
||||
return &launch.LauncherState{}, nil
|
||||
},
|
||||
runMenu: runMenu,
|
||||
resolveRunModel: unexpectedRunModelResolution(t),
|
||||
launchIntegration: func(ctx context.Context, req launch.IntegrationLaunchRequest) error {
|
||||
gotReq = req
|
||||
return nil
|
||||
},
|
||||
runModel: unexpectedModelLaunch(t),
|
||||
}
|
||||
|
||||
cmd := &cobra.Command{}
|
||||
cmd.SetContext(context.Background())
|
||||
for {
|
||||
continueLoop, err := runInteractiveTUIStep(cmd, deps)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected step error: %v", err)
|
||||
}
|
||||
if !continueLoop {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if gotReq.Name != "claude" {
|
||||
t.Fatalf("expected integration name to be passed through, got %q", gotReq.Name)
|
||||
}
|
||||
if gotReq.ForceConfigure != tt.wantForce {
|
||||
t.Fatalf("expected ForceConfigure=%v, got %v", tt.wantForce, gotReq.ForceConfigure)
|
||||
}
|
||||
if got := config.LastSelection(); got != "claude" {
|
||||
t.Fatalf("expected last selection to be claude, got %q", got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunLauncherAction_RunModelContinuesAfterCancellation(t *testing.T) {
|
||||
setCmdTestHome(t, t.TempDir())
|
||||
|
||||
cmd := &cobra.Command{}
|
||||
cmd.SetContext(context.Background())
|
||||
|
||||
continueLoop, err := runLauncherAction(cmd, tui.TUIAction{Kind: tui.TUIActionRunModel}, launcherDeps{
|
||||
buildState: nil,
|
||||
runMenu: nil,
|
||||
resolveRunModel: func(ctx context.Context, req launch.RunModelRequest) (string, error) {
|
||||
return "", launch.ErrCancelled
|
||||
},
|
||||
launchIntegration: unexpectedIntegrationLaunch(t),
|
||||
runModel: unexpectedModelLaunch(t),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("expected nil error on cancellation, got %v", err)
|
||||
}
|
||||
if !continueLoop {
|
||||
t.Fatal("expected cancellation to continue the menu loop")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunLauncherAction_VSCodeExitsTUILoop(t *testing.T) {
|
||||
setCmdTestHome(t, t.TempDir())
|
||||
|
||||
cmd := &cobra.Command{}
|
||||
cmd.SetContext(context.Background())
|
||||
|
||||
// VS Code should exit the TUI loop (return false) after a successful launch.
|
||||
continueLoop, err := runLauncherAction(cmd, tui.TUIAction{Kind: tui.TUIActionLaunchIntegration, Integration: "vscode"}, launcherDeps{
|
||||
resolveRunModel: unexpectedRunModelResolution(t),
|
||||
launchIntegration: func(ctx context.Context, req launch.IntegrationLaunchRequest) error {
|
||||
return nil
|
||||
},
|
||||
runModel: unexpectedModelLaunch(t),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("expected nil error, got %v", err)
|
||||
}
|
||||
if continueLoop {
|
||||
t.Fatal("expected vscode launch to exit the TUI loop (return false)")
|
||||
}
|
||||
|
||||
// Other integrations should continue the TUI loop (return true).
|
||||
continueLoop, err = runLauncherAction(cmd, tui.TUIAction{Kind: tui.TUIActionLaunchIntegration, Integration: "claude"}, launcherDeps{
|
||||
resolveRunModel: unexpectedRunModelResolution(t),
|
||||
launchIntegration: func(ctx context.Context, req launch.IntegrationLaunchRequest) error {
|
||||
return nil
|
||||
},
|
||||
runModel: unexpectedModelLaunch(t),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("expected nil error, got %v", err)
|
||||
}
|
||||
if !continueLoop {
|
||||
t.Fatal("expected non-vscode integration to continue the TUI loop (return true)")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunLauncherAction_IntegrationContinuesAfterCancellation(t *testing.T) {
|
||||
setCmdTestHome(t, t.TempDir())
|
||||
|
||||
cmd := &cobra.Command{}
|
||||
cmd.SetContext(context.Background())
|
||||
|
||||
continueLoop, err := runLauncherAction(cmd, tui.TUIAction{Kind: tui.TUIActionLaunchIntegration, Integration: "claude"}, launcherDeps{
|
||||
buildState: nil,
|
||||
runMenu: nil,
|
||||
resolveRunModel: unexpectedRunModelResolution(t),
|
||||
launchIntegration: func(ctx context.Context, req launch.IntegrationLaunchRequest) error {
|
||||
return launch.ErrCancelled
|
||||
},
|
||||
runModel: unexpectedModelLaunch(t),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("expected nil error on cancellation, got %v", err)
|
||||
}
|
||||
if !continueLoop {
|
||||
t.Fatal("expected cancellation to continue the menu loop")
|
||||
}
|
||||
}
|
||||
501
cmd/cmd_test.go
@@ -301,7 +301,7 @@ Weigh anchor!
|
||||
ParameterSize: "7B",
|
||||
QuantizationLevel: "FP16",
|
||||
},
|
||||
Requires: "0.14.0",
|
||||
Requires: "0.19.0",
|
||||
}, false, &b); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -310,10 +310,17 @@ Weigh anchor!
|
||||
architecture test
|
||||
parameters 7B
|
||||
quantization FP16
|
||||
requires 0.14.0
|
||||
requires 0.19.0
|
||||
|
||||
`
|
||||
if diff := cmp.Diff(expect, b.String()); diff != "" {
|
||||
trimLinePadding := func(s string) string {
|
||||
lines := strings.Split(s, "\n")
|
||||
for i, line := range lines {
|
||||
lines[i] = strings.TrimRight(line, " \t\r")
|
||||
}
|
||||
return strings.Join(lines, "\n")
|
||||
}
|
||||
if diff := cmp.Diff(trimLinePadding(expect), trimLinePadding(b.String())); diff != "" {
|
||||
t.Errorf("unexpected output (-want +got):\n%s", diff)
|
||||
}
|
||||
})
|
||||
@@ -705,6 +712,347 @@ func TestRunEmbeddingModelNoInput(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunHandler_CloudAuthErrorOnShow_PrintsSigninMessage(t *testing.T) {
|
||||
var generateCalled bool
|
||||
|
||||
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch {
|
||||
case r.URL.Path == "/api/show" && r.Method == http.MethodPost:
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
if err := json.NewEncoder(w).Encode(map[string]string{
|
||||
"error": "unauthorized",
|
||||
"signin_url": "https://ollama.com/signin",
|
||||
}); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
}
|
||||
return
|
||||
case r.URL.Path == "/api/generate" && r.Method == http.MethodPost:
|
||||
generateCalled = true
|
||||
w.WriteHeader(http.StatusOK)
|
||||
if err := json.NewEncoder(w).Encode(api.GenerateResponse{Done: true}); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
}
|
||||
return
|
||||
default:
|
||||
http.NotFound(w, r)
|
||||
}
|
||||
}))
|
||||
|
||||
t.Setenv("OLLAMA_HOST", mockServer.URL)
|
||||
t.Cleanup(mockServer.Close)
|
||||
|
||||
cmd := &cobra.Command{}
|
||||
cmd.SetContext(t.Context())
|
||||
cmd.Flags().String("keepalive", "", "")
|
||||
cmd.Flags().Bool("truncate", false, "")
|
||||
cmd.Flags().Int("dimensions", 0, "")
|
||||
cmd.Flags().Bool("verbose", false, "")
|
||||
cmd.Flags().Bool("insecure", false, "")
|
||||
cmd.Flags().Bool("nowordwrap", false, "")
|
||||
cmd.Flags().String("format", "", "")
|
||||
cmd.Flags().String("think", "", "")
|
||||
cmd.Flags().Bool("hidethinking", false, "")
|
||||
|
||||
oldStdout := os.Stdout
|
||||
readOut, writeOut, _ := os.Pipe()
|
||||
os.Stdout = writeOut
|
||||
t.Cleanup(func() { os.Stdout = oldStdout })
|
||||
|
||||
err := RunHandler(cmd, []string{"gpt-oss:20b:cloud", "hi"})
|
||||
|
||||
_ = writeOut.Close()
|
||||
var out bytes.Buffer
|
||||
_, _ = io.Copy(&out, readOut)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("RunHandler returned error: %v", err)
|
||||
}
|
||||
|
||||
if generateCalled {
|
||||
t.Fatal("expected run to stop before /api/generate after unauthorized /api/show")
|
||||
}
|
||||
|
||||
if !strings.Contains(out.String(), "You need to be signed in to Ollama to run Cloud models.") {
|
||||
t.Fatalf("expected sign-in guidance message, got %q", out.String())
|
||||
}
|
||||
|
||||
if !strings.Contains(out.String(), "https://ollama.com/signin") {
|
||||
t.Fatalf("expected signin_url in output, got %q", out.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunHandler_CloudAuthErrorOnGenerate_PrintsSigninMessage(t *testing.T) {
|
||||
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch {
|
||||
case r.URL.Path == "/api/show" && r.Method == http.MethodPost:
|
||||
w.WriteHeader(http.StatusOK)
|
||||
if err := json.NewEncoder(w).Encode(api.ShowResponse{
|
||||
Capabilities: []model.Capability{model.CapabilityCompletion},
|
||||
}); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
}
|
||||
return
|
||||
case r.URL.Path == "/api/generate" && r.Method == http.MethodPost:
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
if err := json.NewEncoder(w).Encode(map[string]string{
|
||||
"error": "unauthorized",
|
||||
"signin_url": "https://ollama.com/signin",
|
||||
}); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
}
|
||||
return
|
||||
default:
|
||||
http.NotFound(w, r)
|
||||
}
|
||||
}))
|
||||
|
||||
t.Setenv("OLLAMA_HOST", mockServer.URL)
|
||||
t.Cleanup(mockServer.Close)
|
||||
|
||||
cmd := &cobra.Command{}
|
||||
cmd.SetContext(t.Context())
|
||||
cmd.Flags().String("keepalive", "", "")
|
||||
cmd.Flags().Bool("truncate", false, "")
|
||||
cmd.Flags().Int("dimensions", 0, "")
|
||||
cmd.Flags().Bool("verbose", false, "")
|
||||
cmd.Flags().Bool("insecure", false, "")
|
||||
cmd.Flags().Bool("nowordwrap", false, "")
|
||||
cmd.Flags().String("format", "", "")
|
||||
cmd.Flags().String("think", "", "")
|
||||
cmd.Flags().Bool("hidethinking", false, "")
|
||||
|
||||
oldStdout := os.Stdout
|
||||
readOut, writeOut, _ := os.Pipe()
|
||||
os.Stdout = writeOut
|
||||
t.Cleanup(func() { os.Stdout = oldStdout })
|
||||
|
||||
err := RunHandler(cmd, []string{"gpt-oss:20b:cloud", "hi"})
|
||||
|
||||
_ = writeOut.Close()
|
||||
var out bytes.Buffer
|
||||
_, _ = io.Copy(&out, readOut)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("RunHandler returned error: %v", err)
|
||||
}
|
||||
|
||||
if !strings.Contains(out.String(), "You need to be signed in to Ollama to run Cloud models.") {
|
||||
t.Fatalf("expected sign-in guidance message, got %q", out.String())
|
||||
}
|
||||
|
||||
if !strings.Contains(out.String(), "https://ollama.com/signin") {
|
||||
t.Fatalf("expected signin_url in output, got %q", out.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunHandler_ExplicitCloudStubMissing_PullsNormalizedNameTEMP(t *testing.T) {
|
||||
var pulledModel string
|
||||
var generateCalled bool
|
||||
|
||||
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch {
|
||||
case r.URL.Path == "/api/show" && r.Method == http.MethodPost:
|
||||
w.WriteHeader(http.StatusOK)
|
||||
if err := json.NewEncoder(w).Encode(api.ShowResponse{
|
||||
Capabilities: []model.Capability{model.CapabilityCompletion},
|
||||
RemoteModel: "gpt-oss:20b",
|
||||
}); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
}
|
||||
return
|
||||
case r.URL.Path == "/api/tags" && r.Method == http.MethodGet:
|
||||
w.WriteHeader(http.StatusOK)
|
||||
if err := json.NewEncoder(w).Encode(api.ListResponse{Models: nil}); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
}
|
||||
return
|
||||
case r.URL.Path == "/api/pull" && r.Method == http.MethodPost:
|
||||
var req api.PullRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
pulledModel = req.Model
|
||||
w.WriteHeader(http.StatusOK)
|
||||
if err := json.NewEncoder(w).Encode(api.ProgressResponse{Status: "success"}); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
}
|
||||
return
|
||||
case r.URL.Path == "/api/generate" && r.Method == http.MethodPost:
|
||||
generateCalled = true
|
||||
w.WriteHeader(http.StatusOK)
|
||||
if err := json.NewEncoder(w).Encode(api.GenerateResponse{Done: true}); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
}
|
||||
return
|
||||
default:
|
||||
http.NotFound(w, r)
|
||||
}
|
||||
}))
|
||||
|
||||
t.Setenv("OLLAMA_HOST", mockServer.URL)
|
||||
t.Cleanup(mockServer.Close)
|
||||
|
||||
cmd := &cobra.Command{}
|
||||
cmd.SetContext(t.Context())
|
||||
cmd.Flags().String("keepalive", "", "")
|
||||
cmd.Flags().Bool("truncate", false, "")
|
||||
cmd.Flags().Int("dimensions", 0, "")
|
||||
cmd.Flags().Bool("verbose", false, "")
|
||||
cmd.Flags().Bool("insecure", false, "")
|
||||
cmd.Flags().Bool("nowordwrap", false, "")
|
||||
cmd.Flags().String("format", "", "")
|
||||
cmd.Flags().String("think", "", "")
|
||||
cmd.Flags().Bool("hidethinking", false, "")
|
||||
|
||||
err := RunHandler(cmd, []string{"gpt-oss:20b:cloud", "hi"})
|
||||
if err != nil {
|
||||
t.Fatalf("RunHandler returned error: %v", err)
|
||||
}
|
||||
|
||||
if pulledModel != "gpt-oss:20b-cloud" {
|
||||
t.Fatalf("expected normalized pull model %q, got %q", "gpt-oss:20b-cloud", pulledModel)
|
||||
}
|
||||
|
||||
if !generateCalled {
|
||||
t.Fatal("expected /api/generate to be called")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunHandler_ExplicitCloudStubPresent_SkipsPullTEMP(t *testing.T) {
|
||||
var pullCalled bool
|
||||
var generateCalled bool
|
||||
|
||||
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch {
|
||||
case r.URL.Path == "/api/show" && r.Method == http.MethodPost:
|
||||
w.WriteHeader(http.StatusOK)
|
||||
if err := json.NewEncoder(w).Encode(api.ShowResponse{
|
||||
Capabilities: []model.Capability{model.CapabilityCompletion},
|
||||
RemoteModel: "gpt-oss:20b",
|
||||
}); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
}
|
||||
return
|
||||
case r.URL.Path == "/api/tags" && r.Method == http.MethodGet:
|
||||
w.WriteHeader(http.StatusOK)
|
||||
if err := json.NewEncoder(w).Encode(api.ListResponse{
|
||||
Models: []api.ListModelResponse{{Name: "gpt-oss:20b-cloud"}},
|
||||
}); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
}
|
||||
return
|
||||
case r.URL.Path == "/api/pull" && r.Method == http.MethodPost:
|
||||
pullCalled = true
|
||||
w.WriteHeader(http.StatusOK)
|
||||
if err := json.NewEncoder(w).Encode(api.ProgressResponse{Status: "success"}); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
}
|
||||
return
|
||||
case r.URL.Path == "/api/generate" && r.Method == http.MethodPost:
|
||||
generateCalled = true
|
||||
w.WriteHeader(http.StatusOK)
|
||||
if err := json.NewEncoder(w).Encode(api.GenerateResponse{Done: true}); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
}
|
||||
return
|
||||
default:
|
||||
http.NotFound(w, r)
|
||||
}
|
||||
}))
|
||||
|
||||
t.Setenv("OLLAMA_HOST", mockServer.URL)
|
||||
t.Cleanup(mockServer.Close)
|
||||
|
||||
cmd := &cobra.Command{}
|
||||
cmd.SetContext(t.Context())
|
||||
cmd.Flags().String("keepalive", "", "")
|
||||
cmd.Flags().Bool("truncate", false, "")
|
||||
cmd.Flags().Int("dimensions", 0, "")
|
||||
cmd.Flags().Bool("verbose", false, "")
|
||||
cmd.Flags().Bool("insecure", false, "")
|
||||
cmd.Flags().Bool("nowordwrap", false, "")
|
||||
cmd.Flags().String("format", "", "")
|
||||
cmd.Flags().String("think", "", "")
|
||||
cmd.Flags().Bool("hidethinking", false, "")
|
||||
|
||||
err := RunHandler(cmd, []string{"gpt-oss:20b:cloud", "hi"})
|
||||
if err != nil {
|
||||
t.Fatalf("RunHandler returned error: %v", err)
|
||||
}
|
||||
|
||||
if pullCalled {
|
||||
t.Fatal("expected /api/pull not to be called when cloud stub already exists")
|
||||
}
|
||||
|
||||
if !generateCalled {
|
||||
t.Fatal("expected /api/generate to be called")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunHandler_ExplicitCloudStubPullFailure_IsBestEffortTEMP(t *testing.T) {
|
||||
var generateCalled bool
|
||||
|
||||
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch {
|
||||
case r.URL.Path == "/api/show" && r.Method == http.MethodPost:
|
||||
w.WriteHeader(http.StatusOK)
|
||||
if err := json.NewEncoder(w).Encode(api.ShowResponse{
|
||||
Capabilities: []model.Capability{model.CapabilityCompletion},
|
||||
RemoteModel: "gpt-oss:20b",
|
||||
}); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
}
|
||||
return
|
||||
case r.URL.Path == "/api/tags" && r.Method == http.MethodGet:
|
||||
w.WriteHeader(http.StatusOK)
|
||||
if err := json.NewEncoder(w).Encode(api.ListResponse{Models: nil}); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
}
|
||||
return
|
||||
case r.URL.Path == "/api/pull" && r.Method == http.MethodPost:
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
if err := json.NewEncoder(w).Encode(map[string]string{"error": "pull failed"}); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
}
|
||||
return
|
||||
case r.URL.Path == "/api/generate" && r.Method == http.MethodPost:
|
||||
generateCalled = true
|
||||
w.WriteHeader(http.StatusOK)
|
||||
if err := json.NewEncoder(w).Encode(api.GenerateResponse{Done: true}); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
}
|
||||
return
|
||||
default:
|
||||
http.NotFound(w, r)
|
||||
}
|
||||
}))
|
||||
|
||||
t.Setenv("OLLAMA_HOST", mockServer.URL)
|
||||
t.Cleanup(mockServer.Close)
|
||||
|
||||
cmd := &cobra.Command{}
|
||||
cmd.SetContext(t.Context())
|
||||
cmd.Flags().String("keepalive", "", "")
|
||||
cmd.Flags().Bool("truncate", false, "")
|
||||
cmd.Flags().Int("dimensions", 0, "")
|
||||
cmd.Flags().Bool("verbose", false, "")
|
||||
cmd.Flags().Bool("insecure", false, "")
|
||||
cmd.Flags().Bool("nowordwrap", false, "")
|
||||
cmd.Flags().String("format", "", "")
|
||||
cmd.Flags().String("think", "", "")
|
||||
cmd.Flags().Bool("hidethinking", false, "")
|
||||
|
||||
err := RunHandler(cmd, []string{"gpt-oss:20b:cloud", "hi"})
|
||||
if err != nil {
|
||||
t.Fatalf("RunHandler returned error: %v", err)
|
||||
}
|
||||
|
||||
if !generateCalled {
|
||||
t.Fatal("expected /api/generate to be called despite pull failure")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetModelfileName(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -1212,6 +1560,20 @@ func TestNewCreateRequest(t *testing.T) {
|
||||
Model: "newmodel",
|
||||
},
|
||||
},
|
||||
{
|
||||
"explicit cloud model preserves source when parent lacks it",
|
||||
"newmodel",
|
||||
runOptions{
|
||||
Model: "qwen3.5:cloud",
|
||||
ParentModel: "qwen3.5",
|
||||
Messages: []api.Message{},
|
||||
WordWrap: true,
|
||||
},
|
||||
&api.CreateRequest{
|
||||
From: "qwen3.5:cloud",
|
||||
Model: "newmodel",
|
||||
},
|
||||
},
|
||||
{
|
||||
"parent model as filepath test",
|
||||
"newmodel",
|
||||
@@ -1557,7 +1919,7 @@ func TestShowInfoImageGen(t *testing.T) {
|
||||
QuantizationLevel: "Q8",
|
||||
},
|
||||
Capabilities: []model.Capability{model.CapabilityImage},
|
||||
Requires: "0.14.0",
|
||||
Requires: "0.19.0",
|
||||
}, false, &b)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
@@ -1567,7 +1929,7 @@ func TestShowInfoImageGen(t *testing.T) {
|
||||
" architecture ZImagePipeline \n" +
|
||||
" parameters 10.3B \n" +
|
||||
" quantization Q8 \n" +
|
||||
" requires 0.14.0 \n" +
|
||||
" requires 0.19.0 \n" +
|
||||
"\n" +
|
||||
" Capabilities\n" +
|
||||
" image \n" +
|
||||
@@ -1663,31 +2025,81 @@ func TestRunOptions_Copy_Independence(t *testing.T) {
|
||||
|
||||
func TestLoadOrUnloadModel_CloudModelAuth(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
remoteHost string
|
||||
whoamiStatus int
|
||||
whoamiResp any
|
||||
expectedError string
|
||||
name string
|
||||
model string
|
||||
showStatus int
|
||||
remoteHost string
|
||||
remoteModel string
|
||||
whoamiStatus int
|
||||
whoamiResp any
|
||||
expectWhoami bool
|
||||
expectedError string
|
||||
expectAuthError bool
|
||||
}{
|
||||
{
|
||||
name: "ollama.com cloud model - user signed in",
|
||||
model: "test-cloud-model",
|
||||
remoteHost: "https://ollama.com",
|
||||
remoteModel: "test-model",
|
||||
whoamiStatus: http.StatusOK,
|
||||
whoamiResp: api.UserResponse{Name: "testuser"},
|
||||
expectWhoami: true,
|
||||
},
|
||||
{
|
||||
name: "ollama.com cloud model - user not signed in",
|
||||
model: "test-cloud-model",
|
||||
remoteHost: "https://ollama.com",
|
||||
remoteModel: "test-model",
|
||||
whoamiStatus: http.StatusUnauthorized,
|
||||
whoamiResp: map[string]string{
|
||||
"error": "unauthorized",
|
||||
"signin_url": "https://ollama.com/signin",
|
||||
},
|
||||
expectedError: "unauthorized",
|
||||
expectWhoami: true,
|
||||
expectedError: "unauthorized",
|
||||
expectAuthError: true,
|
||||
},
|
||||
{
|
||||
name: "non-ollama.com remote - no auth check",
|
||||
model: "test-cloud-model",
|
||||
remoteHost: "https://other-remote.com",
|
||||
remoteModel: "test-model",
|
||||
whoamiStatus: http.StatusUnauthorized, // should not be called
|
||||
whoamiResp: nil,
|
||||
},
|
||||
{
|
||||
name: "explicit :cloud model - auth check without remote metadata",
|
||||
model: "kimi-k2.5:cloud",
|
||||
remoteHost: "",
|
||||
remoteModel: "",
|
||||
whoamiStatus: http.StatusOK,
|
||||
whoamiResp: api.UserResponse{Name: "testuser"},
|
||||
expectWhoami: true,
|
||||
},
|
||||
{
|
||||
name: "explicit :cloud model without local stub returns not found by default",
|
||||
model: "minimax-m2.7:cloud",
|
||||
showStatus: http.StatusNotFound,
|
||||
whoamiStatus: http.StatusOK,
|
||||
whoamiResp: api.UserResponse{Name: "testuser"},
|
||||
expectedError: "not found",
|
||||
expectWhoami: false,
|
||||
expectAuthError: false,
|
||||
},
|
||||
{
|
||||
name: "explicit -cloud model - auth check without remote metadata",
|
||||
model: "kimi-k2.5:latest-cloud",
|
||||
remoteHost: "",
|
||||
remoteModel: "",
|
||||
whoamiStatus: http.StatusOK,
|
||||
whoamiResp: api.UserResponse{Name: "testuser"},
|
||||
expectWhoami: true,
|
||||
},
|
||||
{
|
||||
name: "dash cloud-like name without explicit source does not require auth",
|
||||
model: "test-cloud-model",
|
||||
remoteHost: "",
|
||||
remoteModel: "",
|
||||
whoamiStatus: http.StatusUnauthorized, // should not be called
|
||||
whoamiResp: nil,
|
||||
},
|
||||
@@ -1699,10 +2111,15 @@ func TestLoadOrUnloadModel_CloudModelAuth(t *testing.T) {
|
||||
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/api/show":
|
||||
if tt.showStatus != 0 && tt.showStatus != http.StatusOK {
|
||||
w.WriteHeader(tt.showStatus)
|
||||
_ = json.NewEncoder(w).Encode(map[string]string{"error": "not found"})
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
if err := json.NewEncoder(w).Encode(api.ShowResponse{
|
||||
RemoteHost: tt.remoteHost,
|
||||
RemoteModel: "test-model",
|
||||
RemoteModel: tt.remoteModel,
|
||||
}); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
}
|
||||
@@ -1715,6 +2132,8 @@ func TestLoadOrUnloadModel_CloudModelAuth(t *testing.T) {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
case "/api/generate":
|
||||
w.WriteHeader(http.StatusOK)
|
||||
default:
|
||||
http.NotFound(w, r)
|
||||
}
|
||||
@@ -1727,29 +2146,28 @@ func TestLoadOrUnloadModel_CloudModelAuth(t *testing.T) {
|
||||
cmd.SetContext(t.Context())
|
||||
|
||||
opts := &runOptions{
|
||||
Model: "test-cloud-model",
|
||||
Model: tt.model,
|
||||
ShowConnect: false,
|
||||
}
|
||||
|
||||
err := loadOrUnloadModel(cmd, opts)
|
||||
|
||||
if strings.HasPrefix(tt.remoteHost, "https://ollama.com") {
|
||||
if !whoamiCalled {
|
||||
t.Error("expected whoami to be called for ollama.com cloud model")
|
||||
}
|
||||
} else {
|
||||
if whoamiCalled {
|
||||
t.Error("whoami should not be called for non-ollama.com remote")
|
||||
}
|
||||
if whoamiCalled != tt.expectWhoami {
|
||||
t.Errorf("whoami called = %v, want %v", whoamiCalled, tt.expectWhoami)
|
||||
}
|
||||
|
||||
if tt.expectedError != "" {
|
||||
if err == nil {
|
||||
t.Errorf("expected error containing %q, got nil", tt.expectedError)
|
||||
} else {
|
||||
var authErr api.AuthorizationError
|
||||
if !errors.As(err, &authErr) {
|
||||
t.Errorf("expected AuthorizationError, got %T: %v", err, err)
|
||||
if !tt.expectAuthError && !strings.Contains(strings.ToLower(err.Error()), strings.ToLower(tt.expectedError)) {
|
||||
t.Errorf("expected error containing %q, got %v", tt.expectedError, err)
|
||||
}
|
||||
if tt.expectAuthError {
|
||||
var authErr api.AuthorizationError
|
||||
if !errors.As(err, &authErr) {
|
||||
t.Errorf("expected AuthorizationError, got %T: %v", err, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
@@ -1760,3 +2178,38 @@ func TestLoadOrUnloadModel_CloudModelAuth(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsLocalhost(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
host string
|
||||
expected bool
|
||||
}{
|
||||
{"default empty", "", true},
|
||||
{"localhost no port", "localhost", true},
|
||||
{"localhost with port", "localhost:11435", true},
|
||||
{"127.0.0.1 no port", "127.0.0.1", true},
|
||||
{"127.0.0.1 with port", "127.0.0.1:11434", true},
|
||||
{"0.0.0.0 no port", "0.0.0.0", true},
|
||||
{"0.0.0.0 with port", "0.0.0.0:11434", true},
|
||||
{"::1 no port", "::1", true},
|
||||
{"[::1] with port", "[::1]:11434", true},
|
||||
{"loopback with scheme", "http://localhost:11434", true},
|
||||
{"remote hostname", "example.com", false},
|
||||
{"remote hostname with port", "example.com:11434", false},
|
||||
{"remote IP", "192.168.1.1", false},
|
||||
{"remote IP with port", "192.168.1.1:11434", false},
|
||||
{"remote with scheme", "http://example.com:11434", false},
|
||||
{"https remote", "https://example.com:443", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Setenv("OLLAMA_HOST", tt.host)
|
||||
got := isLocalhost()
|
||||
if got != tt.expected {
|
||||
t.Errorf("isLocalhost() with OLLAMA_HOST=%q = %v, want %v", tt.host, got, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,192 +0,0 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
)
|
||||
|
||||
// Claude implements Runner and AliasConfigurer for Claude Code integration
|
||||
type Claude struct{}
|
||||
|
||||
// Compile-time check that Claude implements AliasConfigurer
|
||||
var _ AliasConfigurer = (*Claude)(nil)
|
||||
|
||||
func (c *Claude) String() string { return "Claude Code" }
|
||||
|
||||
func (c *Claude) args(model string, extra []string) []string {
|
||||
var args []string
|
||||
if model != "" {
|
||||
args = append(args, "--model", model)
|
||||
}
|
||||
args = append(args, extra...)
|
||||
return args
|
||||
}
|
||||
|
||||
func (c *Claude) findPath() (string, error) {
|
||||
if p, err := exec.LookPath("claude"); err == nil {
|
||||
return p, nil
|
||||
}
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
name := "claude"
|
||||
if runtime.GOOS == "windows" {
|
||||
name = "claude.exe"
|
||||
}
|
||||
fallback := filepath.Join(home, ".claude", "local", name)
|
||||
if _, err := os.Stat(fallback); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return fallback, nil
|
||||
}
|
||||
|
||||
func (c *Claude) Run(model string, args []string) error {
|
||||
claudePath, err := c.findPath()
|
||||
if err != nil {
|
||||
return fmt.Errorf("claude is not installed, install from https://code.claude.com/docs/en/quickstart")
|
||||
}
|
||||
|
||||
cmd := exec.Command(claudePath, c.args(model, args)...)
|
||||
cmd.Stdin = os.Stdin
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
|
||||
env := append(os.Environ(),
|
||||
"ANTHROPIC_BASE_URL="+envconfig.Host().String(),
|
||||
"ANTHROPIC_API_KEY=",
|
||||
"ANTHROPIC_AUTH_TOKEN=ollama",
|
||||
)
|
||||
|
||||
env = append(env, c.modelEnvVars(model)...)
|
||||
|
||||
cmd.Env = env
|
||||
return cmd.Run()
|
||||
}
|
||||
|
||||
// modelEnvVars returns Claude Code env vars that route all model tiers through Ollama.
|
||||
func (c *Claude) modelEnvVars(model string) []string {
|
||||
primary := model
|
||||
fast := model
|
||||
if cfg, err := loadIntegration("claude"); err == nil && cfg.Aliases != nil {
|
||||
if p := cfg.Aliases["primary"]; p != "" {
|
||||
primary = p
|
||||
}
|
||||
if f := cfg.Aliases["fast"]; f != "" {
|
||||
fast = f
|
||||
}
|
||||
}
|
||||
return []string{
|
||||
"ANTHROPIC_DEFAULT_OPUS_MODEL=" + primary,
|
||||
"ANTHROPIC_DEFAULT_SONNET_MODEL=" + primary,
|
||||
"ANTHROPIC_DEFAULT_HAIKU_MODEL=" + fast,
|
||||
"CLAUDE_CODE_SUBAGENT_MODEL=" + primary,
|
||||
}
|
||||
}
|
||||
|
||||
// ConfigureAliases sets up model aliases for Claude Code.
|
||||
// model: the model to use (if empty, user will be prompted to select)
|
||||
// aliases: existing alias configuration to preserve/update
|
||||
// Cloud-only: subagent routing (fast model) is gated to cloud models only until
|
||||
// there is a better strategy for prompt caching on local models.
|
||||
func (c *Claude) ConfigureAliases(ctx context.Context, model string, existingAliases map[string]string, force bool) (map[string]string, bool, error) {
|
||||
aliases := make(map[string]string)
|
||||
for k, v := range existingAliases {
|
||||
aliases[k] = v
|
||||
}
|
||||
|
||||
if model != "" {
|
||||
aliases["primary"] = model
|
||||
}
|
||||
|
||||
if !force && aliases["primary"] != "" {
|
||||
client, _ := api.ClientFromEnvironment()
|
||||
if isCloudModel(ctx, client, aliases["primary"]) {
|
||||
if isCloudModel(ctx, client, aliases["fast"]) {
|
||||
return aliases, false, nil
|
||||
}
|
||||
} else {
|
||||
delete(aliases, "fast")
|
||||
return aliases, false, nil
|
||||
}
|
||||
}
|
||||
|
||||
items, existingModels, cloudModels, client, err := listModels(ctx)
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, "\n%sModel Configuration%s\n\n", ansiBold, ansiReset)
|
||||
|
||||
if aliases["primary"] == "" || force {
|
||||
primary, err := DefaultSingleSelector("Select model:", items, aliases["primary"])
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
if err := pullIfNeeded(ctx, client, existingModels, primary); err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
if err := ensureAuth(ctx, client, cloudModels, []string{primary}); err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
aliases["primary"] = primary
|
||||
}
|
||||
|
||||
if isCloudModel(ctx, client, aliases["primary"]) {
|
||||
if aliases["fast"] == "" || !isCloudModel(ctx, client, aliases["fast"]) {
|
||||
aliases["fast"] = aliases["primary"]
|
||||
}
|
||||
} else {
|
||||
delete(aliases, "fast")
|
||||
}
|
||||
|
||||
return aliases, true, nil
|
||||
}
|
||||
|
||||
// SetAliases syncs the configured aliases to the Ollama server using prefix matching.
|
||||
// Cloud-only: for local models (fast is empty), we delete any existing aliases to
|
||||
// prevent stale routing to a previous cloud model.
|
||||
func (c *Claude) SetAliases(ctx context.Context, aliases map[string]string) error {
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
prefixes := []string{"claude-sonnet-", "claude-haiku-"}
|
||||
|
||||
if aliases["fast"] == "" {
|
||||
for _, prefix := range prefixes {
|
||||
_ = client.DeleteAliasExperimental(ctx, &api.AliasDeleteRequest{Alias: prefix})
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
prefixAliases := map[string]string{
|
||||
"claude-sonnet-": aliases["primary"],
|
||||
"claude-haiku-": aliases["fast"],
|
||||
}
|
||||
|
||||
var errs []string
|
||||
for prefix, target := range prefixAliases {
|
||||
req := &api.AliasRequest{
|
||||
Alias: prefix,
|
||||
Target: target,
|
||||
PrefixMatching: true,
|
||||
}
|
||||
if err := client.SetAliasExperimental(ctx, req); err != nil {
|
||||
errs = append(errs, prefix)
|
||||
}
|
||||
}
|
||||
|
||||
if len(errs) > 0 {
|
||||
return fmt.Errorf("failed to set aliases: %v", errs)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -3,7 +3,6 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
@@ -11,7 +10,7 @@ import (
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/cmd/internal/fileutil"
|
||||
)
|
||||
|
||||
type integration struct {
|
||||
@@ -20,6 +19,9 @@ type integration struct {
|
||||
Onboarded bool `json:"onboarded,omitempty"`
|
||||
}
|
||||
|
||||
// IntegrationConfig is the persisted config for one integration.
|
||||
type IntegrationConfig = integration
|
||||
|
||||
type config struct {
|
||||
Integrations map[string]*integration `json:"integrations"`
|
||||
LastModel string `json:"last_model,omitempty"`
|
||||
@@ -124,7 +126,7 @@ func save(cfg *config) error {
|
||||
return err
|
||||
}
|
||||
|
||||
return writeWithBackup(path, data)
|
||||
return fileutil.WriteWithBackup(path, data)
|
||||
}
|
||||
|
||||
func SaveIntegration(appName string, models []string) error {
|
||||
@@ -155,8 +157,8 @@ func SaveIntegration(appName string, models []string) error {
|
||||
return save(cfg)
|
||||
}
|
||||
|
||||
// integrationOnboarded marks an integration as onboarded in ollama's config.
|
||||
func integrationOnboarded(appName string) error {
|
||||
// MarkIntegrationOnboarded marks an integration as onboarded in Ollama's config.
|
||||
func MarkIntegrationOnboarded(appName string) error {
|
||||
cfg, err := load()
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -174,7 +176,7 @@ func integrationOnboarded(appName string) error {
|
||||
|
||||
// IntegrationModel returns the first configured model for an integration, or empty string if not configured.
|
||||
func IntegrationModel(appName string) string {
|
||||
integrationConfig, err := loadIntegration(appName)
|
||||
integrationConfig, err := LoadIntegration(appName)
|
||||
if err != nil || len(integrationConfig.Models) == 0 {
|
||||
return ""
|
||||
}
|
||||
@@ -183,7 +185,7 @@ func IntegrationModel(appName string) string {
|
||||
|
||||
// IntegrationModels returns all configured models for an integration, or nil.
|
||||
func IntegrationModels(appName string) []string {
|
||||
integrationConfig, err := loadIntegration(appName)
|
||||
integrationConfig, err := LoadIntegration(appName)
|
||||
if err != nil || len(integrationConfig.Models) == 0 {
|
||||
return nil
|
||||
}
|
||||
@@ -228,28 +230,8 @@ func SetLastSelection(selection string) error {
|
||||
return save(cfg)
|
||||
}
|
||||
|
||||
// ModelExists checks if a model exists on the Ollama server.
|
||||
func ModelExists(ctx context.Context, name string) bool {
|
||||
if name == "" {
|
||||
return false
|
||||
}
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
models, err := client.List(ctx)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
for _, m := range models.Models {
|
||||
if m.Name == name || strings.HasPrefix(m.Name, name+":") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func loadIntegration(appName string) (*integration, error) {
|
||||
// LoadIntegration returns the saved config for one integration.
|
||||
func LoadIntegration(appName string) (*integration, error) {
|
||||
cfg, err := load()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -263,7 +245,8 @@ func loadIntegration(appName string) (*integration, error) {
|
||||
return integrationConfig, nil
|
||||
}
|
||||
|
||||
func saveAliases(appName string, aliases map[string]string) error {
|
||||
// SaveAliases replaces the saved aliases for one integration.
|
||||
func SaveAliases(appName string, aliases map[string]string) error {
|
||||
if appName == "" {
|
||||
return errors.New("app name cannot be empty")
|
||||
}
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"os"
|
||||
"path/filepath"
|
||||
@@ -45,12 +44,12 @@ func TestSaveAliases_ReplacesNotMerges(t *testing.T) {
|
||||
"primary": "cloud-model",
|
||||
"fast": "cloud-model",
|
||||
}
|
||||
if err := saveAliases("claude", initial); err != nil {
|
||||
if err := SaveAliases("claude", initial); err != nil {
|
||||
t.Fatalf("failed to save initial aliases: %v", err)
|
||||
}
|
||||
|
||||
// Verify both are saved
|
||||
loaded, err := loadIntegration("claude")
|
||||
loaded, err := LoadIntegration("claude")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load: %v", err)
|
||||
}
|
||||
@@ -63,12 +62,12 @@ func TestSaveAliases_ReplacesNotMerges(t *testing.T) {
|
||||
"primary": "local-model",
|
||||
// fast intentionally missing
|
||||
}
|
||||
if err := saveAliases("claude", updated); err != nil {
|
||||
if err := SaveAliases("claude", updated); err != nil {
|
||||
t.Fatalf("failed to save updated aliases: %v", err)
|
||||
}
|
||||
|
||||
// Verify fast is GONE (not merged/preserved)
|
||||
loaded, err = loadIntegration("claude")
|
||||
loaded, err = LoadIntegration("claude")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load after update: %v", err)
|
||||
}
|
||||
@@ -91,12 +90,12 @@ func TestSaveAliases_PreservesModels(t *testing.T) {
|
||||
|
||||
// Then update aliases
|
||||
aliases := map[string]string{"primary": "new-model"}
|
||||
if err := saveAliases("claude", aliases); err != nil {
|
||||
if err := SaveAliases("claude", aliases); err != nil {
|
||||
t.Fatalf("failed to save aliases: %v", err)
|
||||
}
|
||||
|
||||
// Verify models are preserved
|
||||
loaded, err := loadIntegration("claude")
|
||||
loaded, err := LoadIntegration("claude")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load: %v", err)
|
||||
}
|
||||
@@ -111,16 +110,16 @@ func TestSaveAliases_EmptyMap(t *testing.T) {
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
// Save with aliases
|
||||
if err := saveAliases("claude", map[string]string{"primary": "model", "fast": "model"}); err != nil {
|
||||
if err := SaveAliases("claude", map[string]string{"primary": "model", "fast": "model"}); err != nil {
|
||||
t.Fatalf("failed to save: %v", err)
|
||||
}
|
||||
|
||||
// Save empty map
|
||||
if err := saveAliases("claude", map[string]string{}); err != nil {
|
||||
if err := SaveAliases("claude", map[string]string{}); err != nil {
|
||||
t.Fatalf("failed to save empty: %v", err)
|
||||
}
|
||||
|
||||
loaded, err := loadIntegration("claude")
|
||||
loaded, err := LoadIntegration("claude")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load: %v", err)
|
||||
}
|
||||
@@ -135,16 +134,16 @@ func TestSaveAliases_NilMap(t *testing.T) {
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
// Save with aliases first
|
||||
if err := saveAliases("claude", map[string]string{"primary": "model"}); err != nil {
|
||||
if err := SaveAliases("claude", map[string]string{"primary": "model"}); err != nil {
|
||||
t.Fatalf("failed to save: %v", err)
|
||||
}
|
||||
|
||||
// Save nil map - should clear aliases
|
||||
if err := saveAliases("claude", nil); err != nil {
|
||||
if err := SaveAliases("claude", nil); err != nil {
|
||||
t.Fatalf("failed to save nil: %v", err)
|
||||
}
|
||||
|
||||
loaded, err := loadIntegration("claude")
|
||||
loaded, err := LoadIntegration("claude")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load: %v", err)
|
||||
}
|
||||
@@ -155,7 +154,7 @@ func TestSaveAliases_NilMap(t *testing.T) {
|
||||
|
||||
// TestSaveAliases_EmptyAppName returns error
|
||||
func TestSaveAliases_EmptyAppName(t *testing.T) {
|
||||
err := saveAliases("", map[string]string{"primary": "model"})
|
||||
err := SaveAliases("", map[string]string{"primary": "model"})
|
||||
if err == nil {
|
||||
t.Error("expected error for empty app name")
|
||||
}
|
||||
@@ -165,12 +164,12 @@ func TestSaveAliases_CaseInsensitive(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
if err := saveAliases("Claude", map[string]string{"primary": "model1"}); err != nil {
|
||||
if err := SaveAliases("Claude", map[string]string{"primary": "model1"}); err != nil {
|
||||
t.Fatalf("failed to save: %v", err)
|
||||
}
|
||||
|
||||
// Load with different case
|
||||
loaded, err := loadIntegration("claude")
|
||||
loaded, err := LoadIntegration("claude")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load: %v", err)
|
||||
}
|
||||
@@ -179,11 +178,11 @@ func TestSaveAliases_CaseInsensitive(t *testing.T) {
|
||||
}
|
||||
|
||||
// Update with different case
|
||||
if err := saveAliases("CLAUDE", map[string]string{"primary": "model2"}); err != nil {
|
||||
if err := SaveAliases("CLAUDE", map[string]string{"primary": "model2"}); err != nil {
|
||||
t.Fatalf("failed to update: %v", err)
|
||||
}
|
||||
|
||||
loaded, err = loadIntegration("claude")
|
||||
loaded, err = LoadIntegration("claude")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load after update: %v", err)
|
||||
}
|
||||
@@ -198,11 +197,11 @@ func TestSaveAliases_CreatesIntegration(t *testing.T) {
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
// Save aliases for non-existent integration
|
||||
if err := saveAliases("newintegration", map[string]string{"primary": "model"}); err != nil {
|
||||
if err := SaveAliases("newintegration", map[string]string{"primary": "model"}); err != nil {
|
||||
t.Fatalf("failed to save: %v", err)
|
||||
}
|
||||
|
||||
loaded, err := loadIntegration("newintegration")
|
||||
loaded, err := LoadIntegration("newintegration")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load: %v", err)
|
||||
}
|
||||
@@ -371,12 +370,12 @@ func TestAtomicUpdate_ServerSucceedsConfigSaved(t *testing.T) {
|
||||
t.Fatal("server should succeed")
|
||||
}
|
||||
|
||||
if err := saveAliases("claude", map[string]string{"primary": "model"}); err != nil {
|
||||
if err := SaveAliases("claude", map[string]string{"primary": "model"}); err != nil {
|
||||
t.Fatalf("saveAliases failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify it was actually saved
|
||||
loaded, err := loadIntegration("claude")
|
||||
loaded, err := LoadIntegration("claude")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load: %v", err)
|
||||
}
|
||||
@@ -408,7 +407,7 @@ func TestConfigFile_PreservesUnknownFields(t *testing.T) {
|
||||
os.WriteFile(configPath, []byte(initialConfig), 0o644)
|
||||
|
||||
// Update aliases
|
||||
if err := saveAliases("claude", map[string]string{"primary": "model2"}); err != nil {
|
||||
if err := SaveAliases("claude", map[string]string{"primary": "model2"}); err != nil {
|
||||
t.Fatalf("failed to save: %v", err)
|
||||
}
|
||||
|
||||
@@ -440,11 +439,6 @@ func containsHelper(s, substr string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func TestClaudeImplementsAliasConfigurer(t *testing.T) {
|
||||
c := &Claude{}
|
||||
var _ AliasConfigurer = c // Compile-time check
|
||||
}
|
||||
|
||||
func TestModelNameEdgeCases(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
@@ -464,11 +458,11 @@ func TestModelNameEdgeCases(t *testing.T) {
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
aliases := map[string]string{"primary": tc.model}
|
||||
if err := saveAliases("claude", aliases); err != nil {
|
||||
if err := SaveAliases("claude", aliases); err != nil {
|
||||
t.Fatalf("failed to save model %q: %v", tc.model, err)
|
||||
}
|
||||
|
||||
loaded, err := loadIntegration("claude")
|
||||
loaded, err := LoadIntegration("claude")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load: %v", err)
|
||||
}
|
||||
@@ -485,7 +479,7 @@ func TestSwitchingScenarios(t *testing.T) {
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
// Initial cloud config
|
||||
if err := saveAliases("claude", map[string]string{
|
||||
if err := SaveAliases("claude", map[string]string{
|
||||
"primary": "cloud-model",
|
||||
"fast": "cloud-model",
|
||||
}); err != nil {
|
||||
@@ -493,13 +487,13 @@ func TestSwitchingScenarios(t *testing.T) {
|
||||
}
|
||||
|
||||
// Switch to local (no fast)
|
||||
if err := saveAliases("claude", map[string]string{
|
||||
if err := SaveAliases("claude", map[string]string{
|
||||
"primary": "local-model",
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
loaded, _ := loadIntegration("claude")
|
||||
loaded, _ := LoadIntegration("claude")
|
||||
if loaded.Aliases["fast"] != "" {
|
||||
t.Errorf("fast should be removed, got %q", loaded.Aliases["fast"])
|
||||
}
|
||||
@@ -513,21 +507,21 @@ func TestSwitchingScenarios(t *testing.T) {
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
// Initial local config
|
||||
if err := saveAliases("claude", map[string]string{
|
||||
if err := SaveAliases("claude", map[string]string{
|
||||
"primary": "local-model",
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Switch to cloud (with fast)
|
||||
if err := saveAliases("claude", map[string]string{
|
||||
if err := SaveAliases("claude", map[string]string{
|
||||
"primary": "cloud-model",
|
||||
"fast": "cloud-model",
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
loaded, _ := loadIntegration("claude")
|
||||
loaded, _ := LoadIntegration("claude")
|
||||
if loaded.Aliases["fast"] != "cloud-model" {
|
||||
t.Errorf("fast should be cloud-model, got %q", loaded.Aliases["fast"])
|
||||
}
|
||||
@@ -538,7 +532,7 @@ func TestSwitchingScenarios(t *testing.T) {
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
// Initial cloud config
|
||||
if err := saveAliases("claude", map[string]string{
|
||||
if err := SaveAliases("claude", map[string]string{
|
||||
"primary": "cloud-model-1",
|
||||
"fast": "cloud-model-1",
|
||||
}); err != nil {
|
||||
@@ -546,14 +540,14 @@ func TestSwitchingScenarios(t *testing.T) {
|
||||
}
|
||||
|
||||
// Switch to different cloud
|
||||
if err := saveAliases("claude", map[string]string{
|
||||
if err := SaveAliases("claude", map[string]string{
|
||||
"primary": "cloud-model-2",
|
||||
"fast": "cloud-model-2",
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
loaded, _ := loadIntegration("claude")
|
||||
loaded, _ := LoadIntegration("claude")
|
||||
if loaded.Aliases["primary"] != "cloud-model-2" {
|
||||
t.Errorf("primary should be cloud-model-2, got %q", loaded.Aliases["primary"])
|
||||
}
|
||||
@@ -563,43 +557,13 @@ func TestSwitchingScenarios(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestToolCapabilityFiltering(t *testing.T) {
|
||||
t.Run("all models checked for tool capability", func(t *testing.T) {
|
||||
// Both cloud and local models are checked for tool capability via Show API
|
||||
// Only models with "tools" in capabilities are included
|
||||
m := modelInfo{Name: "tool-model", Remote: false, ToolCapable: true}
|
||||
if !m.ToolCapable {
|
||||
t.Error("tool capable model should be marked as such")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("modelInfo includes ToolCapable field", func(t *testing.T) {
|
||||
m := modelInfo{Name: "test", Remote: true, ToolCapable: true}
|
||||
if !m.ToolCapable {
|
||||
t.Error("ToolCapable field should be accessible")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestIsCloudModel_RequiresClient(t *testing.T) {
|
||||
t.Run("nil client always returns false", func(t *testing.T) {
|
||||
// isCloudModel now only uses Show API, no suffix detection
|
||||
if isCloudModel(context.Background(), nil, "model:cloud") {
|
||||
t.Error("nil client should return false regardless of suffix")
|
||||
}
|
||||
if isCloudModel(context.Background(), nil, "local-model") {
|
||||
t.Error("nil client should return false")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestModelsAndAliasesMustStayInSync(t *testing.T) {
|
||||
t.Run("saveAliases followed by saveIntegration keeps them in sync", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
// Save aliases with one model
|
||||
if err := saveAliases("claude", map[string]string{"primary": "model-a"}); err != nil {
|
||||
if err := SaveAliases("claude", map[string]string{"primary": "model-a"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
@@ -608,7 +572,7 @@ func TestModelsAndAliasesMustStayInSync(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
loaded, _ := loadIntegration("claude")
|
||||
loaded, _ := LoadIntegration("claude")
|
||||
if loaded.Aliases["primary"] != loaded.Models[0] {
|
||||
t.Errorf("aliases.primary (%q) != models[0] (%q)", loaded.Aliases["primary"], loaded.Models[0])
|
||||
}
|
||||
@@ -622,11 +586,11 @@ func TestModelsAndAliasesMustStayInSync(t *testing.T) {
|
||||
if err := SaveIntegration("claude", []string{"old-model"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := saveAliases("claude", map[string]string{"primary": "new-model"}); err != nil {
|
||||
if err := SaveAliases("claude", map[string]string{"primary": "new-model"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
loaded, _ := loadIntegration("claude")
|
||||
loaded, _ := LoadIntegration("claude")
|
||||
|
||||
// They should be different (this is the bug state)
|
||||
if loaded.Models[0] == loaded.Aliases["primary"] {
|
||||
@@ -638,7 +602,7 @@ func TestModelsAndAliasesMustStayInSync(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
loaded, _ = loadIntegration("claude")
|
||||
loaded, _ = LoadIntegration("claude")
|
||||
if loaded.Models[0] != loaded.Aliases["primary"] {
|
||||
t.Errorf("after fix: models[0] (%q) should equal aliases.primary (%q)",
|
||||
loaded.Models[0], loaded.Aliases["primary"])
|
||||
@@ -653,20 +617,20 @@ func TestModelsAndAliasesMustStayInSync(t *testing.T) {
|
||||
if err := SaveIntegration("claude", []string{"initial-model"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := saveAliases("claude", map[string]string{"primary": "initial-model"}); err != nil {
|
||||
if err := SaveAliases("claude", map[string]string{"primary": "initial-model"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Update aliases AND models together
|
||||
newAliases := map[string]string{"primary": "updated-model"}
|
||||
if err := saveAliases("claude", newAliases); err != nil {
|
||||
if err := SaveAliases("claude", newAliases); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := SaveIntegration("claude", []string{newAliases["primary"]}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
loaded, _ := loadIntegration("claude")
|
||||
loaded, _ := LoadIntegration("claude")
|
||||
if loaded.Models[0] != "updated-model" {
|
||||
t.Errorf("models[0] should be updated-model, got %q", loaded.Models[0])
|
||||
}
|
||||
|
||||
@@ -10,17 +10,10 @@ import (
|
||||
// setTestHome sets both HOME (Unix) and USERPROFILE (Windows) for cross-platform tests
|
||||
func setTestHome(t *testing.T, dir string) {
|
||||
t.Setenv("HOME", dir)
|
||||
t.Setenv("TMPDIR", dir)
|
||||
t.Setenv("USERPROFILE", dir)
|
||||
}
|
||||
|
||||
// editorPaths is a test helper that safely calls Paths if the runner implements Editor
|
||||
func editorPaths(r Runner) []string {
|
||||
if editor, ok := r.(Editor); ok {
|
||||
return editor.Paths()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestIntegrationConfig(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
@@ -31,7 +24,7 @@ func TestIntegrationConfig(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
config, err := loadIntegration("claude")
|
||||
config, err := LoadIntegration("claude")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -55,11 +48,11 @@ func TestIntegrationConfig(t *testing.T) {
|
||||
"primary": "llama3.2:70b",
|
||||
"fast": "llama3.2:8b",
|
||||
}
|
||||
if err := saveAliases("claude", aliases); err != nil {
|
||||
if err := SaveAliases("claude", aliases); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
config, err := loadIntegration("claude")
|
||||
config, err := LoadIntegration("claude")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -77,14 +70,14 @@ func TestIntegrationConfig(t *testing.T) {
|
||||
if err := SaveIntegration("claude", []string{"model-a"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := saveAliases("claude", map[string]string{"primary": "model-a", "fast": "model-small"}); err != nil {
|
||||
if err := SaveAliases("claude", map[string]string{"primary": "model-a", "fast": "model-small"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := SaveIntegration("claude", []string{"model-b"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
config, err := loadIntegration("claude")
|
||||
config, err := LoadIntegration("claude")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -96,7 +89,7 @@ func TestIntegrationConfig(t *testing.T) {
|
||||
t.Run("defaultModel returns first model", func(t *testing.T) {
|
||||
SaveIntegration("codex", []string{"model-a", "model-b"})
|
||||
|
||||
config, _ := loadIntegration("codex")
|
||||
config, _ := LoadIntegration("codex")
|
||||
defaultModel := ""
|
||||
if len(config.Models) > 0 {
|
||||
defaultModel = config.Models[0]
|
||||
@@ -120,7 +113,7 @@ func TestIntegrationConfig(t *testing.T) {
|
||||
t.Run("app name is case-insensitive", func(t *testing.T) {
|
||||
SaveIntegration("Claude", []string{"model-x"})
|
||||
|
||||
config, err := loadIntegration("claude")
|
||||
config, err := LoadIntegration("claude")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -137,8 +130,8 @@ func TestIntegrationConfig(t *testing.T) {
|
||||
SaveIntegration("app1", []string{"model-1"})
|
||||
SaveIntegration("app2", []string{"model-2"})
|
||||
|
||||
config1, _ := loadIntegration("app1")
|
||||
config2, _ := loadIntegration("app2")
|
||||
config1, _ := LoadIntegration("app1")
|
||||
config2, _ := LoadIntegration("app2")
|
||||
|
||||
defaultModel1 := ""
|
||||
if len(config1.Models) > 0 {
|
||||
@@ -185,64 +178,6 @@ func TestListIntegrations(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestEditorPaths(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
t.Run("returns empty for claude (no Editor)", func(t *testing.T) {
|
||||
r := integrations["claude"]
|
||||
paths := editorPaths(r)
|
||||
if len(paths) != 0 {
|
||||
t.Errorf("expected no paths for claude, got %v", paths)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("returns empty for codex (no Editor)", func(t *testing.T) {
|
||||
r := integrations["codex"]
|
||||
paths := editorPaths(r)
|
||||
if len(paths) != 0 {
|
||||
t.Errorf("expected no paths for codex, got %v", paths)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("returns empty for droid when no config exists", func(t *testing.T) {
|
||||
r := integrations["droid"]
|
||||
paths := editorPaths(r)
|
||||
if len(paths) != 0 {
|
||||
t.Errorf("expected no paths, got %v", paths)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("returns path for droid when config exists", func(t *testing.T) {
|
||||
settingsDir, _ := os.UserHomeDir()
|
||||
settingsDir = filepath.Join(settingsDir, ".factory")
|
||||
os.MkdirAll(settingsDir, 0o755)
|
||||
os.WriteFile(filepath.Join(settingsDir, "settings.json"), []byte(`{}`), 0o644)
|
||||
|
||||
r := integrations["droid"]
|
||||
paths := editorPaths(r)
|
||||
if len(paths) != 1 {
|
||||
t.Errorf("expected 1 path, got %d", len(paths))
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("returns paths for opencode when configs exist", func(t *testing.T) {
|
||||
home, _ := os.UserHomeDir()
|
||||
configDir := filepath.Join(home, ".config", "opencode")
|
||||
stateDir := filepath.Join(home, ".local", "state", "opencode")
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.MkdirAll(stateDir, 0o755)
|
||||
os.WriteFile(filepath.Join(configDir, "opencode.json"), []byte(`{}`), 0o644)
|
||||
os.WriteFile(filepath.Join(stateDir, "model.json"), []byte(`{}`), 0o644)
|
||||
|
||||
r := integrations["opencode"]
|
||||
paths := editorPaths(r)
|
||||
if len(paths) != 2 {
|
||||
t.Errorf("expected 2 paths, got %d: %v", len(paths), paths)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestLoadIntegration_CorruptedJSON(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
@@ -251,7 +186,7 @@ func TestLoadIntegration_CorruptedJSON(t *testing.T) {
|
||||
os.MkdirAll(dir, 0o755)
|
||||
os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{corrupted json`), 0o644)
|
||||
|
||||
_, err := loadIntegration("test")
|
||||
_, err := LoadIntegration("test")
|
||||
if err == nil {
|
||||
t.Error("expected error for nonexistent integration in corrupted file")
|
||||
}
|
||||
@@ -265,7 +200,7 @@ func TestSaveIntegration_NilModels(t *testing.T) {
|
||||
t.Fatalf("saveIntegration with nil models failed: %v", err)
|
||||
}
|
||||
|
||||
config, err := loadIntegration("test")
|
||||
config, err := LoadIntegration("test")
|
||||
if err != nil {
|
||||
t.Fatalf("loadIntegration failed: %v", err)
|
||||
}
|
||||
@@ -294,7 +229,7 @@ func TestLoadIntegration_NonexistentIntegration(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
_, err := loadIntegration("nonexistent")
|
||||
_, err := LoadIntegration("nonexistent")
|
||||
if err == nil {
|
||||
t.Error("expected error for nonexistent integration, got nil")
|
||||
}
|
||||
|
||||
@@ -1,59 +0,0 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"golang.org/x/term"
|
||||
)
|
||||
|
||||
// ANSI escape sequences for terminal formatting.
|
||||
const (
|
||||
ansiBold = "\033[1m"
|
||||
ansiReset = "\033[0m"
|
||||
ansiGray = "\033[37m"
|
||||
ansiGreen = "\033[32m"
|
||||
ansiYellow = "\033[33m"
|
||||
)
|
||||
|
||||
// ErrCancelled is returned when the user cancels a selection.
|
||||
var ErrCancelled = errors.New("cancelled")
|
||||
|
||||
// errCancelled is kept as an alias for backward compatibility within the package.
|
||||
var errCancelled = ErrCancelled
|
||||
|
||||
// DefaultConfirmPrompt provides a TUI-based confirmation prompt.
|
||||
// When set, confirmPrompt delegates to it instead of using raw terminal I/O.
|
||||
var DefaultConfirmPrompt func(prompt string) (bool, error)
|
||||
|
||||
func confirmPrompt(prompt string) (bool, error) {
|
||||
if DefaultConfirmPrompt != nil {
|
||||
return DefaultConfirmPrompt(prompt)
|
||||
}
|
||||
|
||||
fd := int(os.Stdin.Fd())
|
||||
oldState, err := term.MakeRaw(fd)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
defer term.Restore(fd, oldState)
|
||||
|
||||
fmt.Fprintf(os.Stderr, "%s (\033[1my\033[0m/n) ", prompt)
|
||||
|
||||
buf := make([]byte, 1)
|
||||
for {
|
||||
if _, err := os.Stdin.Read(buf); err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
switch buf[0] {
|
||||
case 'Y', 'y', 13:
|
||||
fmt.Fprintf(os.Stderr, "yes\r\n")
|
||||
return true, nil
|
||||
case 'N', 'n', 27, 3:
|
||||
fmt.Fprintf(os.Stderr, "no\r\n")
|
||||
return false, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,19 +0,0 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestErrCancelled(t *testing.T) {
|
||||
t.Run("NotNil", func(t *testing.T) {
|
||||
if errCancelled == nil {
|
||||
t.Error("errCancelled should not be nil")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Message", func(t *testing.T) {
|
||||
if errCancelled.Error() != "cancelled" {
|
||||
t.Errorf("expected 'cancelled', got %q", errCancelled.Error())
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -17,6 +17,7 @@ import (
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
"github.com/ollama/ollama/internal/modelref"
|
||||
"github.com/ollama/ollama/readline"
|
||||
"github.com/ollama/ollama/types/errtypes"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
@@ -46,7 +47,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
|
||||
fmt.Fprintln(os.Stderr, "Use \"\"\" to begin a multi-line message.")
|
||||
|
||||
if opts.MultiModal {
|
||||
fmt.Fprintf(os.Stderr, "Use %s to include .jpg, .png, or .webp images.\n", filepath.FromSlash("/path/to/file"))
|
||||
fmt.Fprintf(os.Stderr, "Use %s to include .jpg, .png, .webp images, or .wav audio files.\n", filepath.FromSlash("/path/to/file"))
|
||||
}
|
||||
|
||||
fmt.Fprintln(os.Stderr, "")
|
||||
@@ -540,6 +541,13 @@ func NewCreateRequest(name string, opts runOptions) *api.CreateRequest {
|
||||
parentModel = ""
|
||||
}
|
||||
|
||||
// Preserve explicit cloud intent for sessions started with `:cloud`.
|
||||
// Cloud model metadata can return a source-less parent_model (for example
|
||||
// "qwen3.5"), which would otherwise make `/save` create a local derivative.
|
||||
if modelref.HasExplicitCloudSource(opts.Model) && !modelref.HasExplicitCloudSource(parentModel) {
|
||||
parentModel = ""
|
||||
}
|
||||
|
||||
req := &api.CreateRequest{
|
||||
Model: name,
|
||||
From: cmp.Or(parentModel, opts.Model),
|
||||
@@ -584,7 +592,7 @@ func extractFileNames(input string) []string {
|
||||
// Regex to match file paths starting with optional drive letter, / ./ \ or .\ and include escaped or unescaped spaces (\ or %20)
|
||||
// and followed by more characters and a file extension
|
||||
// This will capture non filename strings, but we'll check for file existence to remove mismatches
|
||||
regexPattern := `(?:[a-zA-Z]:)?(?:\./|/|\\)[\S\\ ]+?\.(?i:jpg|jpeg|png|webp)\b`
|
||||
regexPattern := `(?:[a-zA-Z]:)?(?:\./|/|\\)[\S\\ ]+?\.(?i:jpg|jpeg|png|webp|wav|mp4|webm|mov|avi|mkv|m4v)\b`
|
||||
re := regexp.MustCompile(regexPattern)
|
||||
|
||||
return re.FindAllString(input, -1)
|
||||
@@ -600,10 +608,16 @@ func extractFileData(input string) (string, []api.ImageData, error) {
|
||||
if errors.Is(err, os.ErrNotExist) {
|
||||
continue
|
||||
} else if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Couldn't process image: %q\n", err)
|
||||
fmt.Fprintf(os.Stderr, "Couldn't process file: %q\n", err)
|
||||
return "", imgs, err
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, "Added image '%s'\n", nfp)
|
||||
ext := strings.ToLower(filepath.Ext(nfp))
|
||||
switch ext {
|
||||
case ".wav":
|
||||
fmt.Fprintf(os.Stderr, "Added audio '%s'\n", nfp)
|
||||
default:
|
||||
fmt.Fprintf(os.Stderr, "Added image '%s'\n", nfp)
|
||||
}
|
||||
input = strings.ReplaceAll(input, "'"+nfp+"'", "")
|
||||
input = strings.ReplaceAll(input, "'"+fp+"'", "")
|
||||
input = strings.ReplaceAll(input, fp, "")
|
||||
@@ -677,9 +691,9 @@ func getImageData(filePath string) ([]byte, error) {
|
||||
}
|
||||
|
||||
contentType := http.DetectContentType(buf)
|
||||
allowedTypes := []string{"image/jpeg", "image/jpg", "image/png", "image/webp"}
|
||||
allowedTypes := []string{"image/jpeg", "image/jpg", "image/png", "image/webp", "audio/wave"}
|
||||
if !slices.Contains(allowedTypes, contentType) {
|
||||
return nil, fmt.Errorf("invalid image type: %s", contentType)
|
||||
return nil, fmt.Errorf("invalid file type: %s", contentType)
|
||||
}
|
||||
|
||||
info, err := file.Stat()
|
||||
@@ -687,8 +701,7 @@ func getImageData(filePath string) ([]byte, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Check if the file size exceeds 100MB
|
||||
var maxSize int64 = 100 * 1024 * 1024 // 100MB in bytes
|
||||
var maxSize int64 = 100 * 1024 * 1024 // 100MB
|
||||
if info.Size() > maxSize {
|
||||
return nil, errors.New("file size exceeds maximum limit (100MB)")
|
||||
}
|
||||
|
||||
@@ -84,3 +84,33 @@ func TestExtractFileDataRemovesQuotedFilepath(t *testing.T) {
|
||||
assert.Len(t, imgs, 1)
|
||||
assert.Equal(t, cleaned, "before after")
|
||||
}
|
||||
|
||||
func TestExtractFileDataWAV(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
fp := filepath.Join(dir, "sample.wav")
|
||||
data := make([]byte, 600)
|
||||
copy(data[:44], []byte{
|
||||
'R', 'I', 'F', 'F',
|
||||
0x58, 0x02, 0x00, 0x00, // file size - 8
|
||||
'W', 'A', 'V', 'E',
|
||||
'f', 'm', 't', ' ',
|
||||
0x10, 0x00, 0x00, 0x00, // fmt chunk size
|
||||
0x01, 0x00, // PCM
|
||||
0x01, 0x00, // mono
|
||||
0x80, 0x3e, 0x00, 0x00, // 16000 Hz
|
||||
0x00, 0x7d, 0x00, 0x00, // byte rate
|
||||
0x02, 0x00, // block align
|
||||
0x10, 0x00, // 16-bit
|
||||
'd', 'a', 't', 'a',
|
||||
0x34, 0x02, 0x00, 0x00, // data size
|
||||
})
|
||||
if err := os.WriteFile(fp, data, 0o600); err != nil {
|
||||
t.Fatalf("failed to write test audio: %v", err)
|
||||
}
|
||||
|
||||
input := "before " + fp + " after"
|
||||
cleaned, imgs, err := extractFileData(input)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, imgs, 1)
|
||||
assert.Equal(t, "before after", cleaned)
|
||||
}
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
package config
|
||||
// Package fileutil provides small shared helpers for reading JSON files
|
||||
// and writing config files with backup-on-overwrite semantics.
|
||||
package fileutil
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
@@ -9,7 +11,8 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
func readJSONFile(path string) (map[string]any, error) {
|
||||
// ReadJSON reads a JSON object file into a generic map.
|
||||
func ReadJSON(path string) (map[string]any, error) {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -33,12 +36,13 @@ func copyFile(src, dst string) error {
|
||||
return os.WriteFile(dst, data, info.Mode().Perm())
|
||||
}
|
||||
|
||||
func backupDir() string {
|
||||
// BackupDir returns the shared backup directory used before overwriting files.
|
||||
func BackupDir() string {
|
||||
return filepath.Join(os.TempDir(), "ollama-backups")
|
||||
}
|
||||
|
||||
func backupToTmp(srcPath string) (string, error) {
|
||||
dir := backupDir()
|
||||
dir := BackupDir()
|
||||
if err := os.MkdirAll(dir, 0o755); err != nil {
|
||||
return "", err
|
||||
}
|
||||
@@ -50,8 +54,8 @@ func backupToTmp(srcPath string) (string, error) {
|
||||
return backupPath, nil
|
||||
}
|
||||
|
||||
// writeWithBackup writes data to path via temp file + rename, backing up any existing file first
|
||||
func writeWithBackup(path string, data []byte) error {
|
||||
// WriteWithBackup writes data to path via temp file + rename, backing up any existing file first.
|
||||
func WriteWithBackup(path string, data []byte) error {
|
||||
var backupPath string
|
||||
// backup must be created before any writes to the target file
|
||||
if existingContent, err := os.ReadFile(path); err == nil {
|
||||
@@ -1,4 +1,4 @@
|
||||
package config
|
||||
package fileutil
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
@@ -9,6 +9,21 @@ import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
tmpRoot, err := os.MkdirTemp("", "fileutil-test-*")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
if err := os.Setenv("TMPDIR", tmpRoot); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
code := m.Run()
|
||||
_ = os.RemoveAll(tmpRoot)
|
||||
os.Exit(code)
|
||||
}
|
||||
|
||||
func mustMarshal(t *testing.T, v any) []byte {
|
||||
t.Helper()
|
||||
data, err := json.MarshalIndent(v, "", " ")
|
||||
@@ -18,14 +33,19 @@ func mustMarshal(t *testing.T, v any) []byte {
|
||||
return data
|
||||
}
|
||||
|
||||
func isolatedTempDir(t *testing.T) string {
|
||||
t.Helper()
|
||||
return t.TempDir()
|
||||
}
|
||||
|
||||
func TestWriteWithBackup(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
tmpDir := isolatedTempDir(t)
|
||||
|
||||
t.Run("creates file", func(t *testing.T) {
|
||||
path := filepath.Join(tmpDir, "new.json")
|
||||
data := mustMarshal(t, map[string]string{"key": "value"})
|
||||
|
||||
if err := writeWithBackup(path, data); err != nil {
|
||||
if err := WriteWithBackup(path, data); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
@@ -43,17 +63,17 @@ func TestWriteWithBackup(t *testing.T) {
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("creates backup in /tmp/ollama-backups", func(t *testing.T) {
|
||||
t.Run("creates backup in the temp backup directory", func(t *testing.T) {
|
||||
path := filepath.Join(tmpDir, "backup.json")
|
||||
|
||||
os.WriteFile(path, []byte(`{"original": true}`), 0o644)
|
||||
|
||||
data := mustMarshal(t, map[string]bool{"updated": true})
|
||||
if err := writeWithBackup(path, data); err != nil {
|
||||
if err := WriteWithBackup(path, data); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
entries, err := os.ReadDir(backupDir())
|
||||
entries, err := os.ReadDir(BackupDir())
|
||||
if err != nil {
|
||||
t.Fatal("backup directory not created")
|
||||
}
|
||||
@@ -63,7 +83,7 @@ func TestWriteWithBackup(t *testing.T) {
|
||||
if filepath.Ext(entry.Name()) != ".json" {
|
||||
name := entry.Name()
|
||||
if len(name) > len("backup.json.") && name[:len("backup.json.")] == "backup.json." {
|
||||
backupPath := filepath.Join(backupDir(), name)
|
||||
backupPath := filepath.Join(BackupDir(), name)
|
||||
backup, err := os.ReadFile(backupPath)
|
||||
if err == nil {
|
||||
var backupData map[string]bool
|
||||
@@ -79,7 +99,7 @@ func TestWriteWithBackup(t *testing.T) {
|
||||
}
|
||||
|
||||
if !foundBackup {
|
||||
t.Error("backup file not created in /tmp/ollama-backups")
|
||||
t.Error("backup file not created in backup directory")
|
||||
}
|
||||
|
||||
current, _ := os.ReadFile(path)
|
||||
@@ -94,11 +114,11 @@ func TestWriteWithBackup(t *testing.T) {
|
||||
path := filepath.Join(tmpDir, "nobak.json")
|
||||
|
||||
data := mustMarshal(t, map[string]string{"new": "file"})
|
||||
if err := writeWithBackup(path, data); err != nil {
|
||||
if err := WriteWithBackup(path, data); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
entries, _ := os.ReadDir(backupDir())
|
||||
entries, _ := os.ReadDir(BackupDir())
|
||||
for _, entry := range entries {
|
||||
if len(entry.Name()) > len("nobak.json.") && entry.Name()[:len("nobak.json.")] == "nobak.json." {
|
||||
t.Error("backup should not exist for new file")
|
||||
@@ -111,11 +131,11 @@ func TestWriteWithBackup(t *testing.T) {
|
||||
|
||||
data := mustMarshal(t, map[string]string{"key": "value"})
|
||||
|
||||
if err := writeWithBackup(path, data); err != nil {
|
||||
if err := WriteWithBackup(path, data); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
entries1, _ := os.ReadDir(backupDir())
|
||||
entries1, _ := os.ReadDir(BackupDir())
|
||||
countBefore := 0
|
||||
for _, e := range entries1 {
|
||||
if len(e.Name()) > len("unchanged.json.") && e.Name()[:len("unchanged.json.")] == "unchanged.json." {
|
||||
@@ -123,11 +143,11 @@ func TestWriteWithBackup(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
if err := writeWithBackup(path, data); err != nil {
|
||||
if err := WriteWithBackup(path, data); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
entries2, _ := os.ReadDir(backupDir())
|
||||
entries2, _ := os.ReadDir(BackupDir())
|
||||
countAfter := 0
|
||||
for _, e := range entries2 {
|
||||
if len(e.Name()) > len("unchanged.json.") && e.Name()[:len("unchanged.json.")] == "unchanged.json." {
|
||||
@@ -145,11 +165,11 @@ func TestWriteWithBackup(t *testing.T) {
|
||||
|
||||
os.WriteFile(path, []byte(`{"v": 1}`), 0o644)
|
||||
data := mustMarshal(t, map[string]int{"v": 2})
|
||||
if err := writeWithBackup(path, data); err != nil {
|
||||
if err := WriteWithBackup(path, data); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
entries, _ := os.ReadDir(backupDir())
|
||||
entries, _ := os.ReadDir(BackupDir())
|
||||
var found bool
|
||||
for _, entry := range entries {
|
||||
name := entry.Name()
|
||||
@@ -161,7 +181,7 @@ func TestWriteWithBackup(t *testing.T) {
|
||||
}
|
||||
}
|
||||
found = true
|
||||
os.Remove(filepath.Join(backupDir(), name))
|
||||
os.Remove(filepath.Join(BackupDir(), name))
|
||||
break
|
||||
}
|
||||
}
|
||||
@@ -180,7 +200,7 @@ func TestWriteWithBackup_FailsIfBackupFails(t *testing.T) {
|
||||
t.Skip("permission tests unreliable on Windows")
|
||||
}
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
tmpDir := isolatedTempDir(t)
|
||||
path := filepath.Join(tmpDir, "config.json")
|
||||
|
||||
// Create original file
|
||||
@@ -188,13 +208,13 @@ func TestWriteWithBackup_FailsIfBackupFails(t *testing.T) {
|
||||
os.WriteFile(path, originalContent, 0o644)
|
||||
|
||||
// Make backup directory read-only to force backup failure
|
||||
backupDir := backupDir()
|
||||
backupDir := BackupDir()
|
||||
os.MkdirAll(backupDir, 0o755)
|
||||
os.Chmod(backupDir, 0o444) // Read-only
|
||||
defer os.Chmod(backupDir, 0o755)
|
||||
|
||||
newContent := []byte(`{"updated": true}`)
|
||||
err := writeWithBackup(path, newContent)
|
||||
err := WriteWithBackup(path, newContent)
|
||||
|
||||
// Should fail because backup couldn't be created
|
||||
if err == nil {
|
||||
@@ -215,7 +235,7 @@ func TestWriteWithBackup_PermissionDenied(t *testing.T) {
|
||||
t.Skip("permission tests unreliable on Windows")
|
||||
}
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
tmpDir := isolatedTempDir(t)
|
||||
|
||||
// Create a read-only directory
|
||||
readOnlyDir := filepath.Join(tmpDir, "readonly")
|
||||
@@ -224,7 +244,7 @@ func TestWriteWithBackup_PermissionDenied(t *testing.T) {
|
||||
defer os.Chmod(readOnlyDir, 0o755)
|
||||
|
||||
path := filepath.Join(readOnlyDir, "config.json")
|
||||
err := writeWithBackup(path, []byte(`{"test": true}`))
|
||||
err := WriteWithBackup(path, []byte(`{"test": true}`))
|
||||
|
||||
if err == nil {
|
||||
t.Error("expected permission error, got nil")
|
||||
@@ -234,10 +254,10 @@ func TestWriteWithBackup_PermissionDenied(t *testing.T) {
|
||||
// TestWriteWithBackup_DirectoryDoesNotExist verifies behavior when target directory doesn't exist.
|
||||
// writeWithBackup doesn't create directories - caller is responsible.
|
||||
func TestWriteWithBackup_DirectoryDoesNotExist(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
tmpDir := isolatedTempDir(t)
|
||||
path := filepath.Join(tmpDir, "nonexistent", "subdir", "config.json")
|
||||
|
||||
err := writeWithBackup(path, []byte(`{"test": true}`))
|
||||
err := WriteWithBackup(path, []byte(`{"test": true}`))
|
||||
|
||||
// Should fail because directory doesn't exist
|
||||
if err == nil {
|
||||
@@ -252,7 +272,7 @@ func TestWriteWithBackup_SymlinkTarget(t *testing.T) {
|
||||
t.Skip("symlink tests may require admin on Windows")
|
||||
}
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
tmpDir := isolatedTempDir(t)
|
||||
realFile := filepath.Join(tmpDir, "real.json")
|
||||
symlink := filepath.Join(tmpDir, "link.json")
|
||||
|
||||
@@ -261,7 +281,7 @@ func TestWriteWithBackup_SymlinkTarget(t *testing.T) {
|
||||
os.Symlink(realFile, symlink)
|
||||
|
||||
// Write through symlink
|
||||
err := writeWithBackup(symlink, []byte(`{"v": 2}`))
|
||||
err := WriteWithBackup(symlink, []byte(`{"v": 2}`))
|
||||
if err != nil {
|
||||
t.Fatalf("writeWithBackup through symlink failed: %v", err)
|
||||
}
|
||||
@@ -276,7 +296,7 @@ func TestWriteWithBackup_SymlinkTarget(t *testing.T) {
|
||||
// TestBackupToTmp_SpecialCharsInFilename verifies backup works with special characters.
|
||||
// User may have config files with unusual names.
|
||||
func TestBackupToTmp_SpecialCharsInFilename(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
tmpDir := isolatedTempDir(t)
|
||||
|
||||
// File with spaces and special chars
|
||||
path := filepath.Join(tmpDir, "my config (backup).json")
|
||||
@@ -305,7 +325,7 @@ func TestCopyFile_PreservesPermissions(t *testing.T) {
|
||||
t.Skip("permission preservation tests unreliable on Windows")
|
||||
}
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
tmpDir := isolatedTempDir(t)
|
||||
src := filepath.Join(tmpDir, "src.json")
|
||||
dst := filepath.Join(tmpDir, "dst.json")
|
||||
|
||||
@@ -327,7 +347,7 @@ func TestCopyFile_PreservesPermissions(t *testing.T) {
|
||||
|
||||
// TestCopyFile_SourceNotFound verifies clear error when source doesn't exist.
|
||||
func TestCopyFile_SourceNotFound(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
tmpDir := isolatedTempDir(t)
|
||||
src := filepath.Join(tmpDir, "nonexistent.json")
|
||||
dst := filepath.Join(tmpDir, "dst.json")
|
||||
|
||||
@@ -339,11 +359,11 @@ func TestCopyFile_SourceNotFound(t *testing.T) {
|
||||
|
||||
// TestWriteWithBackup_TargetIsDirectory verifies error when path points to a directory.
|
||||
func TestWriteWithBackup_TargetIsDirectory(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
tmpDir := isolatedTempDir(t)
|
||||
dirPath := filepath.Join(tmpDir, "actualdir")
|
||||
os.MkdirAll(dirPath, 0o755)
|
||||
|
||||
err := writeWithBackup(dirPath, []byte(`{"test": true}`))
|
||||
err := WriteWithBackup(dirPath, []byte(`{"test": true}`))
|
||||
if err == nil {
|
||||
t.Error("expected error when target is a directory, got nil")
|
||||
}
|
||||
@@ -351,10 +371,10 @@ func TestWriteWithBackup_TargetIsDirectory(t *testing.T) {
|
||||
|
||||
// TestWriteWithBackup_EmptyData verifies writing zero bytes works correctly.
|
||||
func TestWriteWithBackup_EmptyData(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
tmpDir := isolatedTempDir(t)
|
||||
path := filepath.Join(tmpDir, "empty.json")
|
||||
|
||||
err := writeWithBackup(path, []byte{})
|
||||
err := WriteWithBackup(path, []byte{})
|
||||
if err != nil {
|
||||
t.Fatalf("writeWithBackup with empty data failed: %v", err)
|
||||
}
|
||||
@@ -375,7 +395,7 @@ func TestWriteWithBackup_FileUnreadableButDirWritable(t *testing.T) {
|
||||
t.Skip("permission tests unreliable on Windows")
|
||||
}
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
tmpDir := isolatedTempDir(t)
|
||||
path := filepath.Join(tmpDir, "unreadable.json")
|
||||
|
||||
// Create file and make it unreadable
|
||||
@@ -384,7 +404,7 @@ func TestWriteWithBackup_FileUnreadableButDirWritable(t *testing.T) {
|
||||
defer os.Chmod(path, 0o644)
|
||||
|
||||
// Should fail because we can't read the file to compare/backup
|
||||
err := writeWithBackup(path, []byte(`{"updated": true}`))
|
||||
err := WriteWithBackup(path, []byte(`{"updated": true}`))
|
||||
if err == nil {
|
||||
t.Error("expected error when file is unreadable, got nil")
|
||||
}
|
||||
@@ -393,7 +413,7 @@ func TestWriteWithBackup_FileUnreadableButDirWritable(t *testing.T) {
|
||||
// TestWriteWithBackup_RapidSuccessiveWrites verifies backup works with multiple writes
|
||||
// within the same second (timestamp collision scenario).
|
||||
func TestWriteWithBackup_RapidSuccessiveWrites(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
tmpDir := isolatedTempDir(t)
|
||||
path := filepath.Join(tmpDir, "rapid.json")
|
||||
|
||||
// Create initial file
|
||||
@@ -402,7 +422,7 @@ func TestWriteWithBackup_RapidSuccessiveWrites(t *testing.T) {
|
||||
// Rapid successive writes
|
||||
for i := 1; i <= 3; i++ {
|
||||
data := []byte(fmt.Sprintf(`{"v": %d}`, i))
|
||||
if err := writeWithBackup(path, data); err != nil {
|
||||
if err := WriteWithBackup(path, data); err != nil {
|
||||
t.Fatalf("write %d failed: %v", i, err)
|
||||
}
|
||||
}
|
||||
@@ -414,7 +434,7 @@ func TestWriteWithBackup_RapidSuccessiveWrites(t *testing.T) {
|
||||
}
|
||||
|
||||
// Verify at least one backup exists
|
||||
entries, _ := os.ReadDir(backupDir())
|
||||
entries, _ := os.ReadDir(BackupDir())
|
||||
var backupCount int
|
||||
for _, e := range entries {
|
||||
if len(e.Name()) > len("rapid.json.") && e.Name()[:len("rapid.json.")] == "rapid.json." {
|
||||
@@ -432,8 +452,9 @@ func TestWriteWithBackup_BackupDirIsFile(t *testing.T) {
|
||||
t.Skip("test modifies system temp directory")
|
||||
}
|
||||
|
||||
tmpDir := isolatedTempDir(t)
|
||||
// Create a file at the backup directory path
|
||||
backupPath := backupDir()
|
||||
backupPath := BackupDir()
|
||||
// Clean up any existing directory first
|
||||
os.RemoveAll(backupPath)
|
||||
// Create a file instead of directory
|
||||
@@ -443,11 +464,10 @@ func TestWriteWithBackup_BackupDirIsFile(t *testing.T) {
|
||||
os.MkdirAll(backupPath, 0o755)
|
||||
}()
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
path := filepath.Join(tmpDir, "test.json")
|
||||
os.WriteFile(path, []byte(`{"original": true}`), 0o644)
|
||||
|
||||
err := writeWithBackup(path, []byte(`{"updated": true}`))
|
||||
err := WriteWithBackup(path, []byte(`{"updated": true}`))
|
||||
if err == nil {
|
||||
t.Error("expected error when backup dir is a file, got nil")
|
||||
}
|
||||
@@ -459,7 +479,7 @@ func TestWriteWithBackup_NoOrphanTempFiles(t *testing.T) {
|
||||
t.Skip("permission tests unreliable on Windows")
|
||||
}
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
tmpDir := isolatedTempDir(t)
|
||||
|
||||
// Count existing temp files
|
||||
countTempFiles := func() int {
|
||||
@@ -493,7 +513,7 @@ func TestWriteWithBackup_NoOrphanTempFiles(t *testing.T) {
|
||||
badPath := filepath.Join(tmpDir, "isdir")
|
||||
os.MkdirAll(badPath, 0o755)
|
||||
|
||||
_ = writeWithBackup(badPath, []byte(`{"test": true}`))
|
||||
_ = WriteWithBackup(badPath, []byte(`{"test": true}`))
|
||||
|
||||
after := countTempFiles()
|
||||
if after > before {
|
||||
87
cmd/launch/claude.go
Normal file
@@ -0,0 +1,87 @@
|
||||
package launch
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strconv"
|
||||
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
)
|
||||
|
||||
// Claude implements Runner for the Claude Code integration.
type Claude struct{}

// String returns the human-readable integration name.
func (c *Claude) String() string { return "Claude Code" }

// args builds the claude CLI argument list: an optional --model flag
// followed by any caller-supplied extra arguments, in order.
func (c *Claude) args(model string, extra []string) []string {
	var out []string
	if model != "" {
		out = append(out, "--model", model)
	}
	return append(out, extra...)
}
|
||||
|
||||
func (c *Claude) findPath() (string, error) {
|
||||
if p, err := exec.LookPath("claude"); err == nil {
|
||||
return p, nil
|
||||
}
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
name := "claude"
|
||||
if runtime.GOOS == "windows" {
|
||||
name = "claude.exe"
|
||||
}
|
||||
fallback := filepath.Join(home, ".claude", "local", name)
|
||||
if _, err := os.Stat(fallback); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return fallback, nil
|
||||
}
|
||||
|
||||
func (c *Claude) Run(model string, args []string) error {
|
||||
claudePath, err := c.findPath()
|
||||
if err != nil {
|
||||
return fmt.Errorf("claude is not installed, install from https://code.claude.com/docs/en/quickstart")
|
||||
}
|
||||
|
||||
cmd := exec.Command(claudePath, c.args(model, args)...)
|
||||
cmd.Stdin = os.Stdin
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
|
||||
env := append(os.Environ(),
|
||||
"ANTHROPIC_BASE_URL="+envconfig.Host().String(),
|
||||
"ANTHROPIC_API_KEY=",
|
||||
"ANTHROPIC_AUTH_TOKEN=ollama",
|
||||
"CLAUDE_CODE_ATTRIBUTION_HEADER=0",
|
||||
)
|
||||
|
||||
env = append(env, c.modelEnvVars(model)...)
|
||||
|
||||
cmd.Env = env
|
||||
return cmd.Run()
|
||||
}
|
||||
|
||||
// modelEnvVars returns Claude Code env vars that route all model tiers through Ollama.
|
||||
func (c *Claude) modelEnvVars(model string) []string {
|
||||
env := []string{
|
||||
"ANTHROPIC_DEFAULT_OPUS_MODEL=" + model,
|
||||
"ANTHROPIC_DEFAULT_SONNET_MODEL=" + model,
|
||||
"ANTHROPIC_DEFAULT_HAIKU_MODEL=" + model,
|
||||
"CLAUDE_CODE_SUBAGENT_MODEL=" + model,
|
||||
}
|
||||
|
||||
if isCloudModelName(model) {
|
||||
if l, ok := lookupCloudModelLimit(model); ok {
|
||||
env = append(env, "CLAUDE_CODE_AUTO_COMPACT_WINDOW="+strconv.Itoa(l.Context))
|
||||
}
|
||||
}
|
||||
|
||||
return env
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package config
|
||||
package launch
|
||||
|
||||
import (
|
||||
"os"
|
||||
@@ -117,10 +117,7 @@ func TestClaudeModelEnvVars(t *testing.T) {
|
||||
return m
|
||||
}
|
||||
|
||||
t.Run("falls back to model param when no aliases saved", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
t.Run("maps all Claude model env vars to the provided model", func(t *testing.T) {
|
||||
got := envMap(c.modelEnvVars("llama3.2"))
|
||||
if got["ANTHROPIC_DEFAULT_OPUS_MODEL"] != "llama3.2" {
|
||||
t.Errorf("OPUS = %q, want llama3.2", got["ANTHROPIC_DEFAULT_OPUS_MODEL"])
|
||||
@@ -134,65 +131,41 @@ func TestClaudeModelEnvVars(t *testing.T) {
|
||||
if got["CLAUDE_CODE_SUBAGENT_MODEL"] != "llama3.2" {
|
||||
t.Errorf("SUBAGENT = %q, want llama3.2", got["CLAUDE_CODE_SUBAGENT_MODEL"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("uses primary alias for opus sonnet and subagent", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
SaveIntegration("claude", []string{"qwen3:8b"})
|
||||
saveAliases("claude", map[string]string{"primary": "qwen3:8b"})
|
||||
|
||||
got := envMap(c.modelEnvVars("qwen3:8b"))
|
||||
if got["ANTHROPIC_DEFAULT_OPUS_MODEL"] != "qwen3:8b" {
|
||||
t.Errorf("OPUS = %q, want qwen3:8b", got["ANTHROPIC_DEFAULT_OPUS_MODEL"])
|
||||
}
|
||||
if got["ANTHROPIC_DEFAULT_SONNET_MODEL"] != "qwen3:8b" {
|
||||
t.Errorf("SONNET = %q, want qwen3:8b", got["ANTHROPIC_DEFAULT_SONNET_MODEL"])
|
||||
}
|
||||
if got["ANTHROPIC_DEFAULT_HAIKU_MODEL"] != "qwen3:8b" {
|
||||
t.Errorf("HAIKU = %q, want qwen3:8b (no fast alias)", got["ANTHROPIC_DEFAULT_HAIKU_MODEL"])
|
||||
}
|
||||
if got["CLAUDE_CODE_SUBAGENT_MODEL"] != "qwen3:8b" {
|
||||
t.Errorf("SUBAGENT = %q, want qwen3:8b", got["CLAUDE_CODE_SUBAGENT_MODEL"])
|
||||
if got["CLAUDE_CODE_AUTO_COMPACT_WINDOW"] != "" {
|
||||
t.Errorf("AUTO_COMPACT_WINDOW = %q, want empty for local models", got["CLAUDE_CODE_AUTO_COMPACT_WINDOW"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("uses fast alias for haiku", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
SaveIntegration("claude", []string{"llama3.2:70b"})
|
||||
saveAliases("claude", map[string]string{
|
||||
"primary": "llama3.2:70b",
|
||||
"fast": "llama3.2:8b",
|
||||
})
|
||||
|
||||
got := envMap(c.modelEnvVars("llama3.2:70b"))
|
||||
if got["ANTHROPIC_DEFAULT_OPUS_MODEL"] != "llama3.2:70b" {
|
||||
t.Errorf("OPUS = %q, want llama3.2:70b", got["ANTHROPIC_DEFAULT_OPUS_MODEL"])
|
||||
t.Run("supports empty model", func(t *testing.T) {
|
||||
got := envMap(c.modelEnvVars(""))
|
||||
if got["ANTHROPIC_DEFAULT_OPUS_MODEL"] != "" {
|
||||
t.Errorf("OPUS = %q, want empty", got["ANTHROPIC_DEFAULT_OPUS_MODEL"])
|
||||
}
|
||||
if got["ANTHROPIC_DEFAULT_SONNET_MODEL"] != "llama3.2:70b" {
|
||||
t.Errorf("SONNET = %q, want llama3.2:70b", got["ANTHROPIC_DEFAULT_SONNET_MODEL"])
|
||||
if got["ANTHROPIC_DEFAULT_SONNET_MODEL"] != "" {
|
||||
t.Errorf("SONNET = %q, want empty", got["ANTHROPIC_DEFAULT_SONNET_MODEL"])
|
||||
}
|
||||
if got["ANTHROPIC_DEFAULT_HAIKU_MODEL"] != "llama3.2:8b" {
|
||||
t.Errorf("HAIKU = %q, want llama3.2:8b", got["ANTHROPIC_DEFAULT_HAIKU_MODEL"])
|
||||
if got["ANTHROPIC_DEFAULT_HAIKU_MODEL"] != "" {
|
||||
t.Errorf("HAIKU = %q, want empty", got["ANTHROPIC_DEFAULT_HAIKU_MODEL"])
|
||||
}
|
||||
if got["CLAUDE_CODE_SUBAGENT_MODEL"] != "llama3.2:70b" {
|
||||
t.Errorf("SUBAGENT = %q, want llama3.2:70b", got["CLAUDE_CODE_SUBAGENT_MODEL"])
|
||||
if got["CLAUDE_CODE_SUBAGENT_MODEL"] != "" {
|
||||
t.Errorf("SUBAGENT = %q, want empty", got["CLAUDE_CODE_SUBAGENT_MODEL"])
|
||||
}
|
||||
if got["CLAUDE_CODE_AUTO_COMPACT_WINDOW"] != "" {
|
||||
t.Errorf("AUTO_COMPACT_WINDOW = %q, want empty", got["CLAUDE_CODE_AUTO_COMPACT_WINDOW"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("alias primary overrides model param", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
t.Run("sets auto compact window for known cloud models", func(t *testing.T) {
|
||||
got := envMap(c.modelEnvVars("glm-5:cloud"))
|
||||
if got["CLAUDE_CODE_AUTO_COMPACT_WINDOW"] != "202752" {
|
||||
t.Errorf("AUTO_COMPACT_WINDOW = %q, want 202752", got["CLAUDE_CODE_AUTO_COMPACT_WINDOW"])
|
||||
}
|
||||
})
|
||||
|
||||
SaveIntegration("claude", []string{"saved-model"})
|
||||
saveAliases("claude", map[string]string{"primary": "saved-model"})
|
||||
|
||||
got := envMap(c.modelEnvVars("different-model"))
|
||||
if got["ANTHROPIC_DEFAULT_OPUS_MODEL"] != "saved-model" {
|
||||
t.Errorf("OPUS = %q, want saved-model", got["ANTHROPIC_DEFAULT_OPUS_MODEL"])
|
||||
t.Run("does not set auto compact window for unknown cloud models", func(t *testing.T) {
|
||||
got := envMap(c.modelEnvVars("unknown-model:cloud"))
|
||||
if got["CLAUDE_CODE_AUTO_COMPACT_WINDOW"] != "" {
|
||||
t.Errorf("AUTO_COMPACT_WINDOW = %q, want empty", got["CLAUDE_CODE_AUTO_COMPACT_WINDOW"])
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -1,14 +1,13 @@
|
||||
package config
|
||||
package launch
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/ollama/ollama/cmd/internal/fileutil"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
)
|
||||
|
||||
@@ -22,24 +21,6 @@ func (c *Cline) Run(model string, args []string) error {
|
||||
return fmt.Errorf("cline is not installed, install with: npm install -g cline")
|
||||
}
|
||||
|
||||
models := []string{model}
|
||||
if config, err := loadIntegration("cline"); err == nil && len(config.Models) > 0 {
|
||||
models = config.Models
|
||||
}
|
||||
var err error
|
||||
models, err = resolveEditorModels("cline", models, func() ([]string, error) {
|
||||
return selectModels(context.Background(), "cline", "")
|
||||
})
|
||||
if errors.Is(err, errCancelled) {
|
||||
return nil
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := c.Edit(models); err != nil {
|
||||
return fmt.Errorf("setup failed: %w", err)
|
||||
}
|
||||
|
||||
cmd := exec.Command("cline", args...)
|
||||
cmd.Stdin = os.Stdin
|
||||
cmd.Stdout = os.Stdout
|
||||
@@ -97,7 +78,7 @@ func (c *Cline) Edit(models []string) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return writeWithBackup(configPath, data)
|
||||
return fileutil.WriteWithBackup(configPath, data)
|
||||
}
|
||||
|
||||
func (c *Cline) Models() []string {
|
||||
@@ -106,7 +87,7 @@ func (c *Cline) Models() []string {
|
||||
return nil
|
||||
}
|
||||
|
||||
config, err := readJSONFile(filepath.Join(home, ".cline", "data", "globalState.json"))
|
||||
config, err := fileutil.ReadJSON(filepath.Join(home, ".cline", "data", "globalState.json"))
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package config
|
||||
package launch
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
@@ -1,4 +1,4 @@
|
||||
package config
|
||||
package launch
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
@@ -1,4 +1,4 @@
|
||||
package config
|
||||
package launch
|
||||
|
||||
import (
|
||||
"slices"
|
||||
@@ -16,7 +16,7 @@ func TestCodexArgs(t *testing.T) {
|
||||
}{
|
||||
{"with model", "llama3.2", nil, []string{"--oss", "-m", "llama3.2"}},
|
||||
{"empty model", "", nil, []string{"--oss"}},
|
||||
{"with model and profile", "qwen3-coder", []string{"-p", "myprofile"}, []string{"--oss", "-m", "qwen3-coder", "-p", "myprofile"}},
|
||||
{"with model and profile", "qwen3.5", []string{"-p", "myprofile"}, []string{"--oss", "-m", "qwen3.5", "-p", "myprofile"}},
|
||||
{"with sandbox flag", "llama3.2", []string{"--sandbox", "workspace-write"}, []string{"--oss", "-m", "llama3.2", "--sandbox", "workspace-write"}},
|
||||
}
|
||||
|
||||
598
cmd/launch/command_test.go
Normal file
@@ -0,0 +1,598 @@
|
||||
package launch
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/ollama/ollama/cmd/config"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
// captureStderr runs fn with os.Stderr redirected through a pipe and
// returns everything fn wrote to stderr. The original stderr is restored
// before returning.
func captureStderr(t *testing.T, fn func()) string {
	t.Helper()

	saved := os.Stderr
	r, w, err := os.Pipe()
	if err != nil {
		t.Fatalf("failed to create stderr pipe: %v", err)
	}
	os.Stderr = w
	defer func() {
		os.Stderr = saved
	}()

	// Drain the read end concurrently so fn cannot block on a full pipe.
	out := make(chan string, 1)
	go func() {
		var b bytes.Buffer
		_, _ = io.Copy(&b, r)
		out <- b.String()
	}()

	fn()

	_ = w.Close()
	return <-out
}
|
||||
|
||||
func TestLaunchCmd(t *testing.T) {
|
||||
mockCheck := func(cmd *cobra.Command, args []string) error {
|
||||
return nil
|
||||
}
|
||||
mockTUI := func(cmd *cobra.Command) {}
|
||||
cmd := LaunchCmd(mockCheck, mockTUI)
|
||||
|
||||
t.Run("command structure", func(t *testing.T) {
|
||||
if cmd.Use != "launch [INTEGRATION] [-- [EXTRA_ARGS...]]" {
|
||||
t.Errorf("Use = %q, want %q", cmd.Use, "launch [INTEGRATION] [-- [EXTRA_ARGS...]]")
|
||||
}
|
||||
if cmd.Short == "" {
|
||||
t.Error("Short description should not be empty")
|
||||
}
|
||||
if cmd.Long == "" {
|
||||
t.Error("Long description should not be empty")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("flags exist", func(t *testing.T) {
|
||||
if cmd.Flags().Lookup("model") == nil {
|
||||
t.Error("--model flag should exist")
|
||||
}
|
||||
if cmd.Flags().Lookup("config") == nil {
|
||||
t.Error("--config flag should exist")
|
||||
}
|
||||
if cmd.Flags().Lookup("yes") == nil {
|
||||
t.Error("--yes flag should exist")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("PreRunE is set", func(t *testing.T) {
|
||||
if cmd.PreRunE == nil {
|
||||
t.Error("PreRunE should be set to checkServerHeartbeat")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestLaunchCmdTUICallback(t *testing.T) {
|
||||
mockCheck := func(cmd *cobra.Command, args []string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
t.Run("no args calls TUI", func(t *testing.T) {
|
||||
tuiCalled := false
|
||||
mockTUI := func(cmd *cobra.Command) {
|
||||
tuiCalled = true
|
||||
}
|
||||
|
||||
cmd := LaunchCmd(mockCheck, mockTUI)
|
||||
cmd.SetArgs([]string{})
|
||||
_ = cmd.Execute()
|
||||
|
||||
if !tuiCalled {
|
||||
t.Error("TUI callback should be called when no args provided")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("integration arg bypasses TUI", func(t *testing.T) {
|
||||
srv := httptest.NewServer(http.NotFoundHandler())
|
||||
defer srv.Close()
|
||||
t.Setenv("OLLAMA_HOST", srv.URL)
|
||||
|
||||
tuiCalled := false
|
||||
mockTUI := func(cmd *cobra.Command) {
|
||||
tuiCalled = true
|
||||
}
|
||||
|
||||
cmd := LaunchCmd(mockCheck, mockTUI)
|
||||
cmd.SetArgs([]string{"claude"})
|
||||
_ = cmd.Execute()
|
||||
|
||||
if tuiCalled {
|
||||
t.Error("TUI callback should NOT be called when integration arg provided")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("--model flag without integration returns error", func(t *testing.T) {
|
||||
tuiCalled := false
|
||||
mockTUI := func(cmd *cobra.Command) {
|
||||
tuiCalled = true
|
||||
}
|
||||
|
||||
cmd := LaunchCmd(mockCheck, mockTUI)
|
||||
cmd.SetArgs([]string{"--model", "test-model"})
|
||||
err := cmd.Execute()
|
||||
|
||||
if err == nil {
|
||||
t.Fatal("expected --model without an integration to fail")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "require an integration name") {
|
||||
t.Fatalf("expected integration-name guidance, got %v", err)
|
||||
}
|
||||
if tuiCalled {
|
||||
t.Error("TUI callback should NOT be called when --model is provided without an integration")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("--config flag without integration returns error", func(t *testing.T) {
|
||||
tuiCalled := false
|
||||
mockTUI := func(cmd *cobra.Command) {
|
||||
tuiCalled = true
|
||||
}
|
||||
|
||||
cmd := LaunchCmd(mockCheck, mockTUI)
|
||||
cmd.SetArgs([]string{"--config"})
|
||||
err := cmd.Execute()
|
||||
|
||||
if err == nil {
|
||||
t.Fatal("expected --config without an integration to fail")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "require an integration name") {
|
||||
t.Fatalf("expected integration-name guidance, got %v", err)
|
||||
}
|
||||
if tuiCalled {
|
||||
t.Error("TUI callback should NOT be called when --config is provided without an integration")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("--yes flag without integration returns error", func(t *testing.T) {
|
||||
tuiCalled := false
|
||||
mockTUI := func(cmd *cobra.Command) {
|
||||
tuiCalled = true
|
||||
}
|
||||
|
||||
cmd := LaunchCmd(mockCheck, mockTUI)
|
||||
cmd.SetArgs([]string{"--yes"})
|
||||
err := cmd.Execute()
|
||||
|
||||
if err == nil {
|
||||
t.Fatal("expected --yes without an integration to fail")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "require an integration name") {
|
||||
t.Fatalf("expected integration-name guidance, got %v", err)
|
||||
}
|
||||
if tuiCalled {
|
||||
t.Error("TUI callback should NOT be called when --yes is provided without an integration")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("extra args without integration return error", func(t *testing.T) {
|
||||
tuiCalled := false
|
||||
mockTUI := func(cmd *cobra.Command) {
|
||||
tuiCalled = true
|
||||
}
|
||||
|
||||
cmd := LaunchCmd(mockCheck, mockTUI)
|
||||
cmd.SetArgs([]string{"--model", "test-model", "--", "--sandbox", "workspace-write"})
|
||||
err := cmd.Execute()
|
||||
|
||||
if err == nil {
|
||||
t.Fatal("expected flags and extra args without an integration to fail")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "require an integration name") {
|
||||
t.Fatalf("expected integration-name guidance, got %v", err)
|
||||
}
|
||||
if tuiCalled {
|
||||
t.Error("TUI callback should NOT be called when flags or extra args are provided without an integration")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestLaunchCmdNilHeartbeat(t *testing.T) {
|
||||
cmd := LaunchCmd(nil, nil)
|
||||
if cmd == nil {
|
||||
t.Fatal("LaunchCmd returned nil")
|
||||
}
|
||||
if cmd.PreRunE != nil {
|
||||
t.Log("Note: PreRunE is set even when nil is passed (acceptable)")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLaunchCmdModelFlagFiltersDisabledCloudFromSavedConfig(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setLaunchTestHome(t, tmpDir)
|
||||
|
||||
if err := config.SaveIntegration("stubeditor", []string{"glm-5:cloud"}); err != nil {
|
||||
t.Fatalf("failed to seed saved config: %v", err)
|
||||
}
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/api/status":
|
||||
fmt.Fprintf(w, `{"cloud":{"disabled":true,"source":"config"}}`)
|
||||
case "/api/show":
|
||||
fmt.Fprintf(w, `{"model":"llama3.2"}`)
|
||||
default:
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}
|
||||
}))
|
||||
defer srv.Close()
|
||||
t.Setenv("OLLAMA_HOST", srv.URL)
|
||||
|
||||
stub := &launcherEditorRunner{}
|
||||
restore := OverrideIntegration("stubeditor", stub)
|
||||
defer restore()
|
||||
|
||||
cmd := LaunchCmd(func(cmd *cobra.Command, args []string) error { return nil }, func(cmd *cobra.Command) {})
|
||||
cmd.SetArgs([]string{"stubeditor", "--model", "llama3.2"})
|
||||
if err := cmd.Execute(); err != nil {
|
||||
t.Fatalf("launch command failed: %v", err)
|
||||
}
|
||||
|
||||
saved, err := config.LoadIntegration("stubeditor")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to reload integration config: %v", err)
|
||||
}
|
||||
if diff := cmp.Diff([]string{"llama3.2"}, saved.Models); diff != "" {
|
||||
t.Fatalf("saved models mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
if diff := cmp.Diff([][]string{{"llama3.2"}}, stub.edited); diff != "" {
|
||||
t.Fatalf("editor models mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
if stub.ranModel != "llama3.2" {
|
||||
t.Fatalf("expected launch to run with llama3.2, got %q", stub.ranModel)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLaunchCmdModelFlagClearsDisabledCloudOverride(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setLaunchTestHome(t, tmpDir)
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/api/status":
|
||||
fmt.Fprintf(w, `{"cloud":{"disabled":true,"source":"config"}}`)
|
||||
case "/api/tags":
|
||||
fmt.Fprint(w, `{"models":[{"name":"llama3.2"}]}`)
|
||||
case "/api/show":
|
||||
fmt.Fprint(w, `{"model":"llama3.2"}`)
|
||||
default:
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}
|
||||
}))
|
||||
defer srv.Close()
|
||||
t.Setenv("OLLAMA_HOST", srv.URL)
|
||||
|
||||
stub := &launcherSingleRunner{}
|
||||
restore := OverrideIntegration("stubapp", stub)
|
||||
defer restore()
|
||||
|
||||
oldSelector := DefaultSingleSelector
|
||||
defer func() { DefaultSingleSelector = oldSelector }()
|
||||
|
||||
var selectorCalls int
|
||||
var gotCurrent string
|
||||
DefaultSingleSelector = func(title string, items []ModelItem, current string) (string, error) {
|
||||
selectorCalls++
|
||||
gotCurrent = current
|
||||
return "llama3.2", nil
|
||||
}
|
||||
|
||||
cmd := LaunchCmd(func(cmd *cobra.Command, args []string) error { return nil }, func(cmd *cobra.Command) {})
|
||||
cmd.SetArgs([]string{"stubapp", "--model", "glm-5:cloud"})
|
||||
stderr := captureStderr(t, func() {
|
||||
if err := cmd.Execute(); err != nil {
|
||||
t.Fatalf("launch command failed: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
if selectorCalls != 1 {
|
||||
t.Fatalf("expected disabled cloud override to fall back to selector, got %d calls", selectorCalls)
|
||||
}
|
||||
if gotCurrent != "" {
|
||||
t.Fatalf("expected disabled override to be cleared before selection, got current %q", gotCurrent)
|
||||
}
|
||||
if stub.ranModel != "llama3.2" {
|
||||
t.Fatalf("expected launch to run with replacement local model, got %q", stub.ranModel)
|
||||
}
|
||||
if !strings.Contains(stderr, "Warning: ignoring --model glm-5:cloud because cloud is disabled") {
|
||||
t.Fatalf("expected disabled-cloud warning, got stderr: %q", stderr)
|
||||
}
|
||||
|
||||
saved, err := config.LoadIntegration("stubapp")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to reload integration config: %v", err)
|
||||
}
|
||||
if diff := cmp.Diff([]string{"llama3.2"}, saved.Models); diff != "" {
|
||||
t.Fatalf("saved models mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLaunchCmdYes_AutoConfirmsLaunchPromptPath(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setLaunchTestHome(t, tmpDir)
|
||||
withLauncherHooks(t)
|
||||
withInteractiveSession(t, false)
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/api/show":
|
||||
fmt.Fprint(w, `{"model":"llama3.2"}`)
|
||||
case "/api/status":
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
fmt.Fprint(w, `{"error":"not found"}`)
|
||||
default:
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}
|
||||
}))
|
||||
defer srv.Close()
|
||||
t.Setenv("OLLAMA_HOST", srv.URL)
|
||||
|
||||
stub := &launcherEditorRunner{paths: []string{"/tmp/stubeditor.json"}}
|
||||
restore := OverrideIntegration("stubeditor", stub)
|
||||
defer restore()
|
||||
|
||||
DefaultConfirmPrompt = func(prompt string) (bool, error) {
|
||||
t.Fatalf("unexpected prompt with --yes: %q", prompt)
|
||||
return false, nil
|
||||
}
|
||||
|
||||
cmd := LaunchCmd(func(cmd *cobra.Command, args []string) error { return nil }, func(cmd *cobra.Command) {})
|
||||
cmd.SetArgs([]string{"stubeditor", "--model", "llama3.2", "--yes"})
|
||||
if err := cmd.Execute(); err != nil {
|
||||
t.Fatalf("launch command with --yes failed: %v", err)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff([][]string{{"llama3.2"}}, stub.edited); diff != "" {
|
||||
t.Fatalf("editor models mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
if stub.ranModel != "llama3.2" {
|
||||
t.Fatalf("expected launch to run with llama3.2, got %q", stub.ranModel)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLaunchCmdHeadlessWithYes_AutoPullsMissingLocalModel(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setLaunchTestHome(t, tmpDir)
|
||||
withLauncherHooks(t)
|
||||
withInteractiveSession(t, false)
|
||||
|
||||
var pullCalled bool
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/api/show":
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
fmt.Fprint(w, `{"error":"model not found"}`)
|
||||
case "/api/pull":
|
||||
pullCalled = true
|
||||
w.WriteHeader(http.StatusOK)
|
||||
fmt.Fprint(w, `{"status":"success"}`)
|
||||
default:
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}
|
||||
}))
|
||||
defer srv.Close()
|
||||
t.Setenv("OLLAMA_HOST", srv.URL)
|
||||
|
||||
stub := &launcherSingleRunner{}
|
||||
restore := OverrideIntegration("stubapp", stub)
|
||||
defer restore()
|
||||
|
||||
DefaultConfirmPrompt = func(prompt string) (bool, error) {
|
||||
t.Fatalf("unexpected prompt with --yes in headless autopull path: %q", prompt)
|
||||
return false, nil
|
||||
}
|
||||
|
||||
cmd := LaunchCmd(func(cmd *cobra.Command, args []string) error { return nil }, func(cmd *cobra.Command) {})
|
||||
cmd.SetArgs([]string{"stubapp", "--model", "missing-model", "--yes"})
|
||||
if err := cmd.Execute(); err != nil {
|
||||
t.Fatalf("launch command with --yes failed: %v", err)
|
||||
}
|
||||
|
||||
if !pullCalled {
|
||||
t.Fatal("expected missing local model to be auto-pulled with --yes in headless mode")
|
||||
}
|
||||
if stub.ranModel != "missing-model" {
|
||||
t.Fatalf("expected launch to run with pulled model, got %q", stub.ranModel)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLaunchCmdHeadlessWithoutYes_ReturnsActionableConfirmError(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setLaunchTestHome(t, tmpDir)
|
||||
withLauncherHooks(t)
|
||||
withInteractiveSession(t, false)
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/api/show":
|
||||
fmt.Fprint(w, `{"model":"llama3.2"}`)
|
||||
case "/api/status":
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
fmt.Fprint(w, `{"error":"not found"}`)
|
||||
default:
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}
|
||||
}))
|
||||
defer srv.Close()
|
||||
t.Setenv("OLLAMA_HOST", srv.URL)
|
||||
|
||||
stub := &launcherEditorRunner{paths: []string{"/tmp/stubeditor.json"}}
|
||||
restore := OverrideIntegration("stubeditor", stub)
|
||||
defer restore()
|
||||
|
||||
DefaultConfirmPrompt = func(prompt string) (bool, error) {
|
||||
t.Fatalf("unexpected prompt in headless non-yes mode: %q", prompt)
|
||||
return false, nil
|
||||
}
|
||||
|
||||
cmd := LaunchCmd(func(cmd *cobra.Command, args []string) error { return nil }, func(cmd *cobra.Command) {})
|
||||
cmd.SetArgs([]string{"stubeditor", "--model", "llama3.2"})
|
||||
err := cmd.Execute()
|
||||
if err == nil {
|
||||
t.Fatal("expected launch command to fail without --yes in headless mode")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "re-run with --yes") {
|
||||
t.Fatalf("expected actionable --yes guidance, got %v", err)
|
||||
}
|
||||
if len(stub.edited) != 0 {
|
||||
t.Fatalf("expected no editor writes when confirmation is blocked, got %v", stub.edited)
|
||||
}
|
||||
if stub.ranModel != "" {
|
||||
t.Fatalf("expected launch to abort before run, got %q", stub.ranModel)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLaunchCmdIntegrationArgPromptsForModelWithSavedSelection(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setLaunchTestHome(t, tmpDir)
|
||||
|
||||
if err := config.SaveIntegration("stubapp", []string{"llama3.2"}); err != nil {
|
||||
t.Fatalf("failed to seed saved config: %v", err)
|
||||
}
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/api/tags":
|
||||
fmt.Fprint(w, `{"models":[{"name":"llama3.2"},{"name":"qwen3:8b"}]}`)
|
||||
case "/api/show":
|
||||
fmt.Fprint(w, `{"model":"qwen3:8b"}`)
|
||||
default:
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}
|
||||
}))
|
||||
defer srv.Close()
|
||||
t.Setenv("OLLAMA_HOST", srv.URL)
|
||||
|
||||
stub := &launcherSingleRunner{}
|
||||
restore := OverrideIntegration("stubapp", stub)
|
||||
defer restore()
|
||||
|
||||
oldSelector := DefaultSingleSelector
|
||||
defer func() { DefaultSingleSelector = oldSelector }()
|
||||
|
||||
var gotCurrent string
|
||||
DefaultSingleSelector = func(title string, items []ModelItem, current string) (string, error) {
|
||||
gotCurrent = current
|
||||
return "qwen3:8b", nil
|
||||
}
|
||||
|
||||
cmd := LaunchCmd(func(cmd *cobra.Command, args []string) error { return nil }, func(cmd *cobra.Command) {})
|
||||
cmd.SetArgs([]string{"stubapp"})
|
||||
if err := cmd.Execute(); err != nil {
|
||||
t.Fatalf("launch command failed: %v", err)
|
||||
}
|
||||
|
||||
if gotCurrent != "llama3.2" {
|
||||
t.Fatalf("expected selector current model to be saved model llama3.2, got %q", gotCurrent)
|
||||
}
|
||||
if stub.ranModel != "qwen3:8b" {
|
||||
t.Fatalf("expected launch to run selected model qwen3:8b, got %q", stub.ranModel)
|
||||
}
|
||||
|
||||
saved, err := config.LoadIntegration("stubapp")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to reload integration config: %v", err)
|
||||
}
|
||||
if diff := cmp.Diff([]string{"qwen3:8b"}, saved.Models); diff != "" {
|
||||
t.Fatalf("saved models mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLaunchCmdHeadlessYes_IntegrationRequiresModelEvenWhenSaved(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setLaunchTestHome(t, tmpDir)
|
||||
withLauncherHooks(t)
|
||||
withInteractiveSession(t, false)
|
||||
|
||||
if err := config.SaveIntegration("stubapp", []string{"llama3.2"}); err != nil {
|
||||
t.Fatalf("failed to seed saved config: %v", err)
|
||||
}
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/api/show":
|
||||
fmt.Fprint(w, `{"model":"llama3.2"}`)
|
||||
default:
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}
|
||||
}))
|
||||
defer srv.Close()
|
||||
t.Setenv("OLLAMA_HOST", srv.URL)
|
||||
|
||||
stub := &launcherSingleRunner{}
|
||||
restore := OverrideIntegration("stubapp", stub)
|
||||
defer restore()
|
||||
|
||||
oldSelector := DefaultSingleSelector
|
||||
defer func() { DefaultSingleSelector = oldSelector }()
|
||||
DefaultSingleSelector = func(title string, items []ModelItem, current string) (string, error) {
|
||||
t.Fatal("selector should not be called for headless --yes saved-model launch")
|
||||
return "", nil
|
||||
}
|
||||
|
||||
cmd := LaunchCmd(func(cmd *cobra.Command, args []string) error { return nil }, func(cmd *cobra.Command) {})
|
||||
cmd.SetArgs([]string{"stubapp", "--yes"})
|
||||
err := cmd.Execute()
|
||||
if err == nil {
|
||||
t.Fatal("expected launch command to fail when --yes is used headlessly without --model")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "requires --model <model>") {
|
||||
t.Fatalf("expected actionable --model guidance, got %v", err)
|
||||
}
|
||||
if stub.ranModel != "" {
|
||||
t.Fatalf("expected launch to abort before run, got %q", stub.ranModel)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLaunchCmdHeadlessYes_IntegrationWithoutSavedModelReturnsError(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setLaunchTestHome(t, tmpDir)
|
||||
withLauncherHooks(t)
|
||||
withInteractiveSession(t, false)
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}))
|
||||
defer srv.Close()
|
||||
t.Setenv("OLLAMA_HOST", srv.URL)
|
||||
|
||||
stub := &launcherSingleRunner{}
|
||||
restore := OverrideIntegration("stubapp", stub)
|
||||
defer restore()
|
||||
|
||||
oldSelector := DefaultSingleSelector
|
||||
defer func() { DefaultSingleSelector = oldSelector }()
|
||||
DefaultSingleSelector = func(title string, items []ModelItem, current string) (string, error) {
|
||||
t.Fatal("selector should not be called for headless --yes without saved model")
|
||||
return "", nil
|
||||
}
|
||||
|
||||
cmd := LaunchCmd(func(cmd *cobra.Command, args []string) error { return nil }, func(cmd *cobra.Command) {})
|
||||
cmd.SetArgs([]string{"stubapp", "--yes"})
|
||||
err := cmd.Execute()
|
||||
if err == nil {
|
||||
t.Fatal("expected launch command to fail when --yes is used headlessly without --model")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "requires --model <model>") {
|
||||
t.Fatalf("expected actionable --model guidance, got %v", err)
|
||||
}
|
||||
if stub.ranModel != "" {
|
||||
t.Fatalf("expected launch to abort before run, got %q", stub.ranModel)
|
||||
}
|
||||
}
|
||||
@@ -1,16 +1,14 @@
|
||||
package config
|
||||
package launch
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/cmd/internal/fileutil"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
)
|
||||
|
||||
@@ -47,25 +45,6 @@ func (d *Droid) Run(model string, args []string) error {
|
||||
return fmt.Errorf("droid is not installed, install from https://docs.factory.ai/cli/getting-started/quickstart")
|
||||
}
|
||||
|
||||
// Call Edit() to ensure config is up-to-date before launch
|
||||
models := []string{model}
|
||||
if config, err := loadIntegration("droid"); err == nil && len(config.Models) > 0 {
|
||||
models = config.Models
|
||||
}
|
||||
var err error
|
||||
models, err = resolveEditorModels("droid", models, func() ([]string, error) {
|
||||
return selectModels(context.Background(), "droid", "")
|
||||
})
|
||||
if errors.Is(err, errCancelled) {
|
||||
return nil
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := d.Edit(models); err != nil {
|
||||
return fmt.Errorf("setup failed: %w", err)
|
||||
}
|
||||
|
||||
cmd := exec.Command("droid", args...)
|
||||
cmd.Stdin = os.Stdin
|
||||
cmd.Stdout = os.Stdout
|
||||
@@ -111,6 +90,16 @@ func (d *Droid) Edit(models []string) error {
|
||||
json.Unmarshal(data, &settings) // ignore error, zero values are fine
|
||||
}
|
||||
|
||||
settingsMap = updateDroidSettings(settingsMap, settings, models)
|
||||
|
||||
data, err := json.MarshalIndent(settingsMap, "", " ")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return fileutil.WriteWithBackup(settingsPath, data)
|
||||
}
|
||||
|
||||
func updateDroidSettings(settingsMap map[string]any, settings droidSettings, models []string) map[string]any {
|
||||
// Keep only non-Ollama models from the raw map (preserves extra fields)
|
||||
// Rebuild Ollama models
|
||||
var nonOllamaModels []any
|
||||
@@ -125,13 +114,12 @@ func (d *Droid) Edit(models []string) error {
|
||||
}
|
||||
|
||||
// Build new Ollama model entries with sequential indices (0, 1, 2, ...)
|
||||
client, _ := api.ClientFromEnvironment()
|
||||
|
||||
var newModels []any
|
||||
var defaultModelID string
|
||||
for i, model := range models {
|
||||
maxOutput := 64000
|
||||
if isCloudModel(context.Background(), client, model) {
|
||||
if isCloudModelName(model) {
|
||||
if l, ok := lookupCloudModelLimit(model); ok {
|
||||
maxOutput = l.Output
|
||||
}
|
||||
@@ -167,12 +155,7 @@ func (d *Droid) Edit(models []string) error {
|
||||
}
|
||||
|
||||
settingsMap["sessionDefaultSettings"] = sessionSettings
|
||||
|
||||
data, err := json.MarshalIndent(settingsMap, "", " ")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return writeWithBackup(settingsPath, data)
|
||||
return settingsMap
|
||||
}
|
||||
|
||||
func (d *Droid) Models() []string {
|
||||
@@ -1,4 +1,4 @@
|
||||
package config
|
||||
package launch
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
@@ -6,6 +6,8 @@ import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/cmd/internal/fileutil"
|
||||
)
|
||||
|
||||
func TestDroidIntegration(t *testing.T) {
|
||||
@@ -362,7 +364,7 @@ func TestDroidEdit_DuplicateModels(t *testing.T) {
|
||||
t.Fatalf("Edit with duplicates failed: %v", err)
|
||||
}
|
||||
|
||||
settings, err := readJSONFile(settingsPath)
|
||||
settings, err := fileutil.ReadJSON(settingsPath)
|
||||
if err != nil {
|
||||
t.Fatalf("readJSONFile failed: %v", err)
|
||||
}
|
||||
@@ -392,7 +394,7 @@ func TestDroidEdit_MalformedModelEntry(t *testing.T) {
|
||||
}
|
||||
|
||||
// Malformed entries (non-object) are dropped - only valid model objects are preserved
|
||||
settings, _ := readJSONFile(settingsPath)
|
||||
settings, _ := fileutil.ReadJSON(settingsPath)
|
||||
customModels, _ := settings["customModels"].([]any)
|
||||
|
||||
// Should have: 1 new Ollama model only (malformed entries dropped)
|
||||
@@ -419,7 +421,7 @@ func TestDroidEdit_WrongTypeSessionSettings(t *testing.T) {
|
||||
}
|
||||
|
||||
// Should create proper sessionDefaultSettings
|
||||
settings, _ := readJSONFile(settingsPath)
|
||||
settings, _ := fileutil.ReadJSON(settingsPath)
|
||||
session, ok := settings["sessionDefaultSettings"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("sessionDefaultSettings should be map after setup, got %T", settings["sessionDefaultSettings"])
|
||||
@@ -1008,34 +1010,34 @@ func TestDroidEdit_ModelNamesWithSpecialCharacters(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestDroidEdit_MissingCustomModelsKey(t *testing.T) {
|
||||
d := &Droid{}
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
settingsDir := filepath.Join(tmpDir, ".factory")
|
||||
settingsPath := filepath.Join(settingsDir, "settings.json")
|
||||
|
||||
os.MkdirAll(settingsDir, 0o755)
|
||||
|
||||
// No customModels key at all
|
||||
original := `{
|
||||
"diffMode": "github",
|
||||
"sessionDefaultSettings": {"autonomyMode": "auto-high"}
|
||||
}`
|
||||
os.WriteFile(settingsPath, []byte(original), 0o644)
|
||||
|
||||
if err := d.Edit([]string{"model-a"}); err != nil {
|
||||
var settingsStruct droidSettings
|
||||
var settings map[string]any
|
||||
if err := json.Unmarshal([]byte(original), &settings); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := json.Unmarshal([]byte(original), &settingsStruct); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
data, _ := os.ReadFile(settingsPath)
|
||||
var settings map[string]any
|
||||
json.Unmarshal(data, &settings)
|
||||
settings = updateDroidSettings(settings, settingsStruct, []string{"model-a"})
|
||||
|
||||
// Original fields preserved
|
||||
if settings["diffMode"] != "github" {
|
||||
t.Error("diffMode not preserved")
|
||||
}
|
||||
session, ok := settings["sessionDefaultSettings"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatal("sessionDefaultSettings not preserved")
|
||||
}
|
||||
if session["autonomyMode"] != "auto-high" {
|
||||
t.Error("sessionDefaultSettings.autonomyMode not preserved")
|
||||
}
|
||||
|
||||
// customModels created
|
||||
models, ok := settings["customModels"].([]any)
|
||||
@@ -1276,25 +1278,17 @@ func TestDroidEdit_LocalModelDefaultMaxOutput(t *testing.T) {
|
||||
|
||||
func TestDroidEdit_CloudModelLimitsUsed(t *testing.T) {
|
||||
// Verify that every cloud model in cloudModelLimits has a valid output
|
||||
// value that would be used for maxOutputTokens when isCloudModel returns true.
|
||||
// :cloud suffix stripping must also work since that's how users specify them.
|
||||
// value that would be used for maxOutputTokens when the selected model uses
|
||||
// the explicit :cloud source tag.
|
||||
for name, expected := range cloudModelLimits {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
l, ok := lookupCloudModelLimit(name)
|
||||
if !ok {
|
||||
t.Fatalf("lookupCloudModelLimit(%q) returned false", name)
|
||||
}
|
||||
if l.Output != expected.Output {
|
||||
t.Errorf("output = %d, want %d", l.Output, expected.Output)
|
||||
}
|
||||
// Also verify :cloud suffix lookup
|
||||
cloudName := name + ":cloud"
|
||||
l2, ok := lookupCloudModelLimit(cloudName)
|
||||
l, ok := lookupCloudModelLimit(cloudName)
|
||||
if !ok {
|
||||
t.Fatalf("lookupCloudModelLimit(%q) returned false", cloudName)
|
||||
}
|
||||
if l2.Output != expected.Output {
|
||||
t.Errorf(":cloud output = %d, want %d", l2.Output, expected.Output)
|
||||
if l.Output != expected.Output {
|
||||
t.Errorf("output = %d, want %d", l.Output, expected.Output)
|
||||
}
|
||||
})
|
||||
}
|
||||
881
cmd/launch/launch.go
Normal file
@@ -0,0 +1,881 @@
|
||||
package launch
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/cmd/config"
|
||||
"github.com/spf13/cobra"
|
||||
"golang.org/x/term"
|
||||
)
|
||||
|
||||
// LauncherState is the launch-owned snapshot used to render the root launcher menu.
|
||||
type LauncherState struct {
|
||||
LastSelection string
|
||||
RunModel string
|
||||
RunModelUsable bool
|
||||
Integrations map[string]LauncherIntegrationState
|
||||
}
|
||||
|
||||
// LauncherIntegrationState is the launch-owned status for one launcher integration.
|
||||
type LauncherIntegrationState struct {
|
||||
Name string
|
||||
DisplayName string
|
||||
Description string
|
||||
Installed bool
|
||||
AutoInstallable bool
|
||||
Selectable bool
|
||||
Changeable bool
|
||||
CurrentModel string
|
||||
ModelUsable bool
|
||||
InstallHint string
|
||||
Editor bool
|
||||
}
|
||||
|
||||
// RunModelRequest controls how the root launcher resolves the chat model.
|
||||
type RunModelRequest struct {
|
||||
ForcePicker bool
|
||||
Policy *LaunchPolicy
|
||||
}
|
||||
|
||||
// LaunchConfirmMode controls confirmation behavior across launch flows.
|
||||
type LaunchConfirmMode int
|
||||
|
||||
const (
|
||||
// LaunchConfirmPrompt prompts the user for confirmation.
|
||||
LaunchConfirmPrompt LaunchConfirmMode = iota
|
||||
// LaunchConfirmAutoApprove skips prompts and treats confirmation as accepted.
|
||||
LaunchConfirmAutoApprove
|
||||
// LaunchConfirmRequireYes rejects confirmation requests with a --yes hint.
|
||||
LaunchConfirmRequireYes
|
||||
)
|
||||
|
||||
// LaunchMissingModelMode controls local missing-model handling in launch flows.
|
||||
type LaunchMissingModelMode int
|
||||
|
||||
const (
|
||||
// LaunchMissingModelPromptToPull prompts to pull a missing local model.
|
||||
LaunchMissingModelPromptToPull LaunchMissingModelMode = iota
|
||||
// LaunchMissingModelAutoPull pulls a missing local model without prompting.
|
||||
LaunchMissingModelAutoPull
|
||||
// LaunchMissingModelFail fails immediately when a local model is missing.
|
||||
LaunchMissingModelFail
|
||||
)
|
||||
|
||||
// LaunchPolicy controls launch behavior that may vary by caller context.
|
||||
type LaunchPolicy struct {
|
||||
Confirm LaunchConfirmMode
|
||||
MissingModel LaunchMissingModelMode
|
||||
}
|
||||
|
||||
func defaultLaunchPolicy(interactive bool, yes bool) LaunchPolicy {
|
||||
policy := LaunchPolicy{
|
||||
Confirm: LaunchConfirmPrompt,
|
||||
MissingModel: LaunchMissingModelPromptToPull,
|
||||
}
|
||||
switch {
|
||||
case yes:
|
||||
// if yes flag is set, auto approve and auto pull
|
||||
policy.Confirm = LaunchConfirmAutoApprove
|
||||
policy.MissingModel = LaunchMissingModelAutoPull
|
||||
case !interactive:
|
||||
// otherwise make sure to stop when needed
|
||||
policy.Confirm = LaunchConfirmRequireYes
|
||||
policy.MissingModel = LaunchMissingModelFail
|
||||
}
|
||||
return policy
|
||||
}
|
||||
|
||||
func (p LaunchPolicy) confirmPolicy() launchConfirmPolicy {
|
||||
switch p.Confirm {
|
||||
case LaunchConfirmAutoApprove:
|
||||
return launchConfirmPolicy{yes: true}
|
||||
case LaunchConfirmRequireYes:
|
||||
return launchConfirmPolicy{requireYesMessage: true}
|
||||
default:
|
||||
return launchConfirmPolicy{}
|
||||
}
|
||||
}
|
||||
|
||||
func (p LaunchPolicy) missingModelPolicy() missingModelPolicy {
|
||||
switch p.MissingModel {
|
||||
case LaunchMissingModelAutoPull:
|
||||
return missingModelAutoPull
|
||||
case LaunchMissingModelFail:
|
||||
return missingModelFail
|
||||
default:
|
||||
return missingModelPromptPull
|
||||
}
|
||||
}
|
||||
|
||||
// IntegrationLaunchRequest controls the canonical integration launcher flow.
|
||||
type IntegrationLaunchRequest struct {
|
||||
Name string
|
||||
ModelOverride string
|
||||
ForceConfigure bool
|
||||
ConfigureOnly bool
|
||||
ExtraArgs []string
|
||||
Policy *LaunchPolicy
|
||||
}
|
||||
|
||||
var isInteractiveSession = func() bool {
|
||||
return term.IsTerminal(int(os.Stdin.Fd())) && term.IsTerminal(int(os.Stdout.Fd()))
|
||||
}
|
||||
|
||||
// Runner executes a model with an integration.
|
||||
type Runner interface {
|
||||
Run(model string, args []string) error
|
||||
String() string
|
||||
}
|
||||
|
||||
// Editor can edit config files for integrations that support model configuration.
|
||||
type Editor interface {
|
||||
Paths() []string
|
||||
Edit(models []string) error
|
||||
Models() []string
|
||||
}
|
||||
|
||||
type modelInfo struct {
|
||||
Name string
|
||||
Remote bool
|
||||
ToolCapable bool
|
||||
}
|
||||
|
||||
// ModelInfo re-exports launcher model inventory details for callers.
|
||||
type ModelInfo = modelInfo
|
||||
|
||||
// ModelItem represents a model for selection UIs.
|
||||
type ModelItem struct {
|
||||
Name string
|
||||
Description string
|
||||
Recommended bool
|
||||
}
|
||||
|
||||
// LaunchCmd returns the cobra command for launching integrations.
// The runTUI callback is called when the root launcher UI should be shown
// (i.e. when no integration name is given and no launch-specific flags are
// set). checkServerHeartbeat runs as PreRunE before any launch work.
func LaunchCmd(checkServerHeartbeat func(cmd *cobra.Command, args []string) error, runTUI func(cmd *cobra.Command)) *cobra.Command {
	var modelFlag string
	var configFlag bool
	var yesFlag bool

	cmd := &cobra.Command{
		Use:   "launch [INTEGRATION] [-- [EXTRA_ARGS...]]",
		Short: "Launch the Ollama menu or an integration",
		Long: `Launch the Ollama interactive menu, or directly launch a specific integration.

Without arguments, this is equivalent to running 'ollama' directly.
Flags and extra arguments require an integration name.

Supported integrations:
  claude     Claude Code
  cline      Cline
  codex      Codex
  droid      Droid
  opencode   OpenCode
  openclaw   OpenClaw (aliases: clawdbot, moltbot)
  pi         Pi
  vscode     VS Code (aliases: code)

Examples:
  ollama launch
  ollama launch claude
  ollama launch claude --model <model>
  ollama launch droid --config (does not auto-launch)
  ollama launch codex -- -p myprofile (pass extra args to integration)
  ollama launch codex -- --sandbox workspace-write`,
		Args:    cobra.ArbitraryArgs,
		PreRunE: checkServerHeartbeat,
		RunE: func(cmd *cobra.Command, args []string) error {
			policy := defaultLaunchPolicy(isInteractiveSession(), yesFlag)
			// reset when done to make sure state doesn't leak between launches
			restoreConfirmPolicy := withLaunchConfirmPolicy(policy.confirmPolicy())
			defer restoreConfirmPolicy()

			// Split args into the optional integration name (before "--")
			// and pass-through args for the integration (after "--").
			var name string
			var passArgs []string
			dashIdx := cmd.ArgsLenAtDash()

			if dashIdx == -1 {
				// No "--" present: at most one positional arg is allowed.
				if len(args) > 1 {
					return fmt.Errorf("unexpected arguments: %v\nUse '--' to pass extra arguments to the integration", args[1:])
				}
				if len(args) == 1 {
					name = args[0]
				}
			} else {
				// "--" present: everything after it is forwarded verbatim.
				if dashIdx > 1 {
					return fmt.Errorf("expected at most 1 integration name before '--', got %d", dashIdx)
				}
				if dashIdx == 1 {
					name = args[0]
				}
				passArgs = args[dashIdx:]
			}

			if name == "" {
				// Bare "ollama launch" opens the TUI, but only when no
				// launch-specific flags or extra args were supplied —
				// those only make sense with an integration name.
				if cmd.Flags().Changed("model") || cmd.Flags().Changed("config") || cmd.Flags().Changed("yes") || len(passArgs) > 0 {
					return fmt.Errorf("flags and extra args require an integration name, for example: 'ollama launch claude --model qwen3.5'")
				}
				runTUI(cmd)
				return nil
			}

			// Drop a cloud --model override when cloud is disabled on the
			// server; best-effort (client/status errors are ignored).
			if modelFlag != "" && isCloudModelName(modelFlag) {
				if client, err := api.ClientFromEnvironment(); err == nil {
					if disabled, _ := cloudStatusDisabled(cmd.Context(), client); disabled {
						fmt.Fprintf(os.Stderr, "Warning: ignoring --model %s because cloud is disabled\n", modelFlag)
						modelFlag = ""
					}
				}
			}

			// Without an explicit model, force the configure step — except
			// in headless --yes mode, where no picker can be shown.
			headlessYes := yesFlag && !isInteractiveSession()
			err := LaunchIntegration(cmd.Context(), IntegrationLaunchRequest{
				Name:           name,
				ModelOverride:  modelFlag,
				ForceConfigure: configFlag || (modelFlag == "" && !headlessYes),
				ConfigureOnly:  configFlag,
				ExtraArgs:      passArgs,
				Policy:         &policy,
			})
			// User cancellation is a clean exit, not a command failure.
			if errors.Is(err, ErrCancelled) {
				return nil
			}
			return err
		},
	}

	cmd.Flags().StringVar(&modelFlag, "model", "", "Model to use")
	cmd.Flags().BoolVar(&configFlag, "config", false, "Configure without launching")
	cmd.Flags().BoolVarP(&yesFlag, "yes", "y", false, "Automatically answer yes to confirmation prompts")
	return cmd
}
|
||||
|
||||
type launcherClient struct {
|
||||
apiClient *api.Client
|
||||
modelInventory []ModelInfo
|
||||
inventoryLoaded bool
|
||||
policy LaunchPolicy
|
||||
}
|
||||
|
||||
func newLauncherClient(policy LaunchPolicy) (*launcherClient, error) {
|
||||
apiClient, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &launcherClient{
|
||||
apiClient: apiClient,
|
||||
policy: policy,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// BuildLauncherState returns the launch-owned root launcher menu snapshot.
|
||||
func BuildLauncherState(ctx context.Context) (*LauncherState, error) {
|
||||
launchClient, err := newLauncherClient(defaultLaunchPolicy(isInteractiveSession(), false))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return launchClient.buildLauncherState(ctx)
|
||||
}
|
||||
|
||||
// ResolveRunModel returns the model that should be used for interactive chat.
|
||||
func ResolveRunModel(ctx context.Context, req RunModelRequest) (string, error) {
|
||||
// Called by the launcher TUI "Run a model" action (cmd/runLauncherAction),
|
||||
// which resolves models separately from LaunchIntegration. Callers can pass
|
||||
// Policy directly; otherwise we fall back to ambient --yes/session defaults.
|
||||
policy := defaultLaunchPolicy(isInteractiveSession(), currentLaunchConfirmPolicy.yes)
|
||||
if req.Policy != nil {
|
||||
policy = *req.Policy
|
||||
}
|
||||
|
||||
launchClient, err := newLauncherClient(policy)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return launchClient.resolveRunModel(ctx, req)
|
||||
}
|
||||
|
||||
// LaunchIntegration runs the canonical launcher flow for one integration:
// resolve the name, ensure it is installed (unless configure-only), build a
// policy-aware client, and dispatch to the editor or single-model flow.
// Returns ErrCancelled-wrapped errors from downstream when the user cancels
// (callers typically treat that as a clean exit).
func LaunchIntegration(ctx context.Context, req IntegrationLaunchRequest) error {
	name, runner, err := LookupIntegration(req.Name)
	if err != nil {
		return err
	}
	if !req.ConfigureOnly {
		// Only install when we will actually launch; --config alone
		// must not trigger installation.
		if err := EnsureIntegrationInstalled(name, runner); err != nil {
			return err
		}
	}

	var policy LaunchPolicy
	// TUI does not set a policy, whereas ollama launch <app> does as it can have flags which change the behavior
	if req.Policy == nil {
		policy = defaultLaunchPolicy(isInteractiveSession(), false)
	} else {
		policy = *req.Policy
	}

	launchClient, err := newLauncherClient(policy)
	if err != nil {
		return err
	}
	// Saved config is optional; a load error just means "nothing saved".
	saved, _ := loadStoredIntegrationConfig(name)
	// In headless --yes mode we cannot prompt, so require an explicit --model.
	if policy.Confirm == LaunchConfirmAutoApprove && !isInteractiveSession() && req.ModelOverride == "" {
		return fmt.Errorf("headless --yes launch for %s requires --model <model>", name)
	}

	// Integrations that can edit their own config files (Editor) take the
	// multi-model flow; everything else takes the single-model flow.
	if editor, ok := runner.(Editor); ok {
		return launchClient.launchEditorIntegration(ctx, name, runner, editor, saved, req)
	}
	return launchClient.launchSingleIntegration(ctx, name, runner, saved, req)
}
|
||||
|
||||
func (c *launcherClient) buildLauncherState(ctx context.Context) (*LauncherState, error) {
|
||||
_ = c.loadModelInventoryOnce(ctx)
|
||||
|
||||
state := &LauncherState{
|
||||
LastSelection: config.LastSelection(),
|
||||
RunModel: config.LastModel(),
|
||||
Integrations: make(map[string]LauncherIntegrationState),
|
||||
}
|
||||
runModelUsable, err := c.savedModelUsable(ctx, state.RunModel)
|
||||
if err != nil {
|
||||
runModelUsable = false
|
||||
}
|
||||
state.RunModelUsable = runModelUsable
|
||||
|
||||
for _, info := range ListIntegrationInfos() {
|
||||
integrationState, err := c.buildLauncherIntegrationState(ctx, info)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
state.Integrations[info.Name] = integrationState
|
||||
}
|
||||
|
||||
return state, nil
|
||||
}
|
||||
|
||||
func (c *launcherClient) buildLauncherIntegrationState(ctx context.Context, info IntegrationInfo) (LauncherIntegrationState, error) {
|
||||
integration, err := integrationFor(info.Name)
|
||||
if err != nil {
|
||||
return LauncherIntegrationState{}, err
|
||||
}
|
||||
currentModel, usable, err := c.launcherModelState(ctx, info.Name, integration.editor)
|
||||
if err != nil {
|
||||
return LauncherIntegrationState{}, err
|
||||
}
|
||||
|
||||
return LauncherIntegrationState{
|
||||
Name: info.Name,
|
||||
DisplayName: info.DisplayName,
|
||||
Description: info.Description,
|
||||
Installed: integration.installed,
|
||||
AutoInstallable: integration.autoInstallable,
|
||||
Selectable: integration.installed || integration.autoInstallable,
|
||||
Changeable: integration.installed || integration.autoInstallable,
|
||||
CurrentModel: currentModel,
|
||||
ModelUsable: usable,
|
||||
InstallHint: integration.installHint,
|
||||
Editor: integration.editor,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// launcherModelState reports the model displayed for one integration in the
// root launcher menu and whether that model is considered usable.
//
// Editor integrations: the saved model list is filtered against disabled
// cloud models; the first surviving entry is shown as usable, otherwise the
// original first entry is shown as unusable. Single-model integrations:
// usability of the first saved model is checked via savedModelUsable.
// Load errors and empty configs are treated as "no model", not as errors.
func (c *launcherClient) launcherModelState(ctx context.Context, name string, isEditor bool) (string, bool, error) {
	cfg, loadErr := loadStoredIntegrationConfig(name)
	hasModels := loadErr == nil && len(cfg.Models) > 0
	if !hasModels {
		// No saved config (or an empty one): nothing to display.
		return "", false, nil
	}

	if isEditor {
		filtered := c.filterDisabledCloudModels(ctx, cfg.Models)
		if len(filtered) > 0 {
			return filtered[0], true, nil
		}
		// Every saved model was filtered out; still show the first one,
		// but mark it unusable.
		return cfg.Models[0], false, nil
	}

	model := cfg.Models[0]
	// A usability-check error downgrades the model to "not usable" rather
	// than failing the menu build.
	usable, usableErr := c.savedModelUsable(ctx, model)
	return model, usableErr == nil && usable, nil
}
|
||||
|
||||
// resolveRunModel picks the model for interactive chat. Resolution order:
//  1. Headless auto-approve: reuse the last used model without any UI.
//  2. Last used model, if still usable and the picker is not forced.
//  3. Otherwise, show the single-model picker.
// A newly picked model is persisted as the last used model.
func (c *launcherClient) resolveRunModel(ctx context.Context, req RunModelRequest) (string, error) {
	current := config.LastModel()
	// Headless + auto-approve: no picker is possible, so reuse the last
	// model (after making sure it is ready) and say so on stderr.
	if !req.ForcePicker && current != "" && c.policy.Confirm == LaunchConfirmAutoApprove && !isInteractiveSession() {
		if err := c.ensureModelsReady(ctx, []string{current}); err != nil {
			return "", err
		}
		fmt.Fprintf(os.Stderr, "Headless mode: auto-selected last used model %q\n", current)
		return current, nil
	}

	if !req.ForcePicker {
		// Reuse the last model silently when it is still usable.
		usable, err := c.savedModelUsable(ctx, current)
		if err != nil {
			return "", err
		}
		if usable {
			if err := c.ensureModelsReady(ctx, []string{current}); err != nil {
				return "", err
			}
			return current, nil
		}
	}

	// Fall through to the interactive picker, pre-selecting the last model.
	model, err := c.selectSingleModelWithSelector(ctx, "Select model to run:", current, DefaultSingleSelector)
	if err != nil {
		return "", err
	}
	if model != current {
		// Persist the new choice so the next run reuses it.
		if err := config.SetLastModel(model); err != nil {
			return "", err
		}
	}
	return model, nil
}
|
||||
|
||||
// launchSingleIntegration resolves the one model a non-editor integration
// runs with, persists a changed choice, and launches. The model comes from
// (in priority order) the explicit override, then the saved config; the
// picker is shown when configuration is forced or the saved model is not
// usable. An empty final selection means the user picked nothing: exit
// quietly without launching.
func (c *launcherClient) launchSingleIntegration(ctx context.Context, name string, runner Runner, saved *config.IntegrationConfig, req IntegrationLaunchRequest) error {
	current := primaryModelFromConfig(saved)
	target := req.ModelOverride
	needsConfigure := req.ForceConfigure

	if target == "" {
		// No override: fall back to the saved model, but force the picker
		// if that model is no longer usable.
		target = current
		usable, err := c.savedModelUsable(ctx, target)
		if err != nil {
			return err
		}
		if !usable {
			needsConfigure = true
		}
	}

	if needsConfigure {
		// Picker path: the selector is responsible for readying whatever
		// it returns.
		selected, err := c.selectSingleModelWithSelector(ctx, fmt.Sprintf("Select model for %s:", runner), target, DefaultSingleSelector)
		if err != nil {
			return err
		}
		target = selected
	} else if err := c.ensureModelsReady(ctx, []string{target}); err != nil {
		// Non-picker path: make sure the chosen model is present/ready.
		return err
	}

	if target == "" {
		// Nothing selected; treat as a no-op rather than an error.
		return nil
	}

	if target != current {
		// Persist only on change to avoid rewriting an identical config.
		if err := config.SaveIntegration(name, []string{target}); err != nil {
			return fmt.Errorf("failed to save: %w", err)
		}
	}

	return launchAfterConfiguration(name, runner, target, req)
}
|
||||
|
||||
func (c *launcherClient) launchEditorIntegration(ctx context.Context, name string, runner Runner, editor Editor, saved *config.IntegrationConfig, req IntegrationLaunchRequest) error {
|
||||
models, needsConfigure := c.resolveEditorLaunchModels(ctx, saved, req)
|
||||
|
||||
if needsConfigure {
|
||||
selected, err := c.selectMultiModelsForIntegration(ctx, runner, models)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
models = selected
|
||||
} else if len(models) > 0 {
|
||||
if err := c.ensureModelsReady(ctx, models[:1]); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if len(models) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
if needsConfigure || req.ModelOverride != "" {
|
||||
if err := prepareEditorIntegration(name, runner, editor, models); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return launchAfterConfiguration(name, runner, models[0], req)
|
||||
}
|
||||
|
||||
func (c *launcherClient) selectSingleModelWithSelector(ctx context.Context, title, current string, selector SingleSelector) (string, error) {
|
||||
if selector == nil {
|
||||
return "", fmt.Errorf("no selector configured")
|
||||
}
|
||||
|
||||
items, _, err := c.loadSelectableModels(ctx, nil, current, "no models available, run 'ollama pull <model>' first")
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
selected, err := selector(title, items, current)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if err := c.ensureModelsReady(ctx, []string{selected}); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return selected, nil
|
||||
}
|
||||
|
||||
func (c *launcherClient) selectMultiModelsForIntegration(ctx context.Context, runner Runner, preChecked []string) ([]string, error) {
|
||||
if DefaultMultiSelector == nil {
|
||||
return nil, fmt.Errorf("no selector configured")
|
||||
}
|
||||
|
||||
current := firstModel(preChecked)
|
||||
|
||||
items, orderedChecked, err := c.loadSelectableModels(ctx, preChecked, current, "no models available")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(preChecked) > 0 {
|
||||
// Keep list order stable in multi-select even when there are existing checks.
|
||||
// checked/default state still comes from orderedChecked.
|
||||
stableItems, _, stableErr := c.loadSelectableModels(ctx, nil, current, "no models available")
|
||||
if stableErr != nil {
|
||||
return nil, stableErr
|
||||
}
|
||||
items = stableItems
|
||||
}
|
||||
|
||||
selected, err := DefaultMultiSelector(fmt.Sprintf("Select models for %s:", runner), items, orderedChecked)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
accepted, skipped, err := c.selectReadyModelsForSave(ctx, selected)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, skip := range skipped {
|
||||
fmt.Fprintf(os.Stderr, "Skipped %s: %s\n", skip.model, skip.reason)
|
||||
}
|
||||
return accepted, nil
|
||||
}
|
||||
|
||||
// loadSelectableModels builds the model list for selection UIs from the
// cached inventory. preChecked names seed the checked/default state, current
// marks the currently active model, and emptyMessage becomes the error text
// when no models survive filtering. Cloud entries are removed when cloud is
// disabled on the server (the status check itself is best-effort).
func (c *launcherClient) loadSelectableModels(ctx context.Context, preChecked []string, current, emptyMessage string) ([]ModelItem, []string, error) {
	// Inventory is fetched at most once per launcherClient.
	if err := c.loadModelInventoryOnce(ctx); err != nil {
		return nil, nil, err
	}

	cloudDisabled, _ := cloudStatusDisabled(ctx, c.apiClient)
	items, orderedChecked, _, _ := buildModelList(c.modelInventory, preChecked, current)
	if cloudDisabled {
		// Strip cloud models from both the visible list and the
		// pre-checked set so a disabled model cannot be selected.
		items = filterCloudItems(items)
		orderedChecked = c.filterDisabledCloudModels(ctx, orderedChecked)
	}
	if len(items) == 0 {
		return nil, nil, errors.New(emptyMessage)
	}
	return items, orderedChecked, nil
}
|
||||
|
||||
func (c *launcherClient) ensureModelsReady(ctx context.Context, models []string) error {
|
||||
models = dedupeModelList(models)
|
||||
if len(models) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
cloudModels := make(map[string]bool, len(models))
|
||||
for _, model := range models {
|
||||
isCloudModel := isCloudModelName(model)
|
||||
if isCloudModel {
|
||||
cloudModels[model] = true
|
||||
}
|
||||
if err := showOrPullWithPolicy(ctx, c.apiClient, model, c.policy.missingModelPolicy(), isCloudModel); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return ensureAuth(ctx, c.apiClient, cloudModels, models)
|
||||
}
|
||||
|
||||
// dedupeModelList returns models with empty entries and duplicates removed,
// preserving first-seen order.
func dedupeModelList(models []string) []string {
	seen := make(map[string]bool, len(models))
	out := make([]string, 0, len(models))
	for _, name := range models {
		if name == "" {
			continue
		}
		if seen[name] {
			continue
		}
		seen[name] = true
		out = append(out, name)
	}
	return out
}
|
||||
|
||||
// skippedModel records a model that was dropped during a save operation,
// together with the human-readable reason shown to the user.
type skippedModel struct {
	model  string // model name as selected by the user
	reason string // why it was skipped (e.g. a cancelled download or sign-in)
}
|
||||
|
||||
func (c *launcherClient) selectReadyModelsForSave(ctx context.Context, selected []string) ([]string, []skippedModel, error) {
|
||||
selected = dedupeModelList(selected)
|
||||
accepted := make([]string, 0, len(selected))
|
||||
skipped := make([]skippedModel, 0, len(selected))
|
||||
|
||||
for _, model := range selected {
|
||||
if err := c.ensureModelsReady(ctx, []string{model}); err != nil {
|
||||
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
|
||||
return nil, nil, err
|
||||
}
|
||||
skipped = append(skipped, skippedModel{
|
||||
model: model,
|
||||
reason: skippedModelReason(model, err),
|
||||
})
|
||||
continue
|
||||
}
|
||||
accepted = append(accepted, model)
|
||||
}
|
||||
|
||||
return accepted, skipped, nil
|
||||
}
|
||||
|
||||
func skippedModelReason(model string, err error) string {
|
||||
if errors.Is(err, ErrCancelled) {
|
||||
if isCloudModelName(model) {
|
||||
return "sign in was cancelled"
|
||||
}
|
||||
return "download was cancelled"
|
||||
}
|
||||
return err.Error()
|
||||
}
|
||||
|
||||
func (c *launcherClient) resolveEditorLaunchModels(ctx context.Context, saved *config.IntegrationConfig, req IntegrationLaunchRequest) ([]string, bool) {
|
||||
if req.ForceConfigure {
|
||||
return editorPreCheckedModels(saved, req.ModelOverride), true
|
||||
}
|
||||
|
||||
if req.ModelOverride != "" {
|
||||
models := append([]string{req.ModelOverride}, additionalSavedModels(saved, req.ModelOverride)...)
|
||||
models = c.filterDisabledCloudModels(ctx, models)
|
||||
return models, len(models) == 0
|
||||
}
|
||||
|
||||
if saved == nil || len(saved.Models) == 0 {
|
||||
return nil, true
|
||||
}
|
||||
|
||||
models := c.filterDisabledCloudModels(ctx, saved.Models)
|
||||
return models, len(models) == 0
|
||||
}
|
||||
|
||||
func (c *launcherClient) filterDisabledCloudModels(ctx context.Context, models []string) []string {
|
||||
// if connection cannot be established or there is a 404, cloud models will continue to be displayed
|
||||
cloudDisabled, _ := cloudStatusDisabled(ctx, c.apiClient)
|
||||
if !cloudDisabled {
|
||||
return append([]string(nil), models...)
|
||||
}
|
||||
|
||||
filtered := make([]string, 0, len(models))
|
||||
for _, model := range models {
|
||||
if !isCloudModelName(model) {
|
||||
filtered = append(filtered, model)
|
||||
}
|
||||
}
|
||||
return filtered
|
||||
}
|
||||
|
||||
// savedModelUsable reports whether a previously saved model can still be
// used. It prefers the cached inventory; if loading the inventory fails,
// it deliberately swallows that error and falls back to a direct
// show-based check for this one model instead.
func (c *launcherClient) savedModelUsable(ctx context.Context, name string) (bool, error) {
	if err := c.loadModelInventoryOnce(ctx); err != nil {
		// Inventory unavailable (e.g. List failed); ask the server about
		// this specific model directly rather than surfacing the error.
		return c.showBasedModelUsable(ctx, name)
	}
	return c.singleModelUsable(ctx, name), nil
}
|
||||
|
||||
func (c *launcherClient) showBasedModelUsable(ctx context.Context, name string) (bool, error) {
|
||||
if name == "" {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
info, err := c.apiClient.Show(ctx, &api.ShowRequest{Model: name})
|
||||
if err != nil {
|
||||
var statusErr api.StatusError
|
||||
if errors.As(err, &statusErr) && statusErr.StatusCode == http.StatusNotFound {
|
||||
return false, nil
|
||||
}
|
||||
return false, err
|
||||
}
|
||||
|
||||
if isCloudModelName(name) || info.RemoteModel != "" {
|
||||
cloudDisabled, _ := cloudStatusDisabled(ctx, c.apiClient)
|
||||
|
||||
return !cloudDisabled, nil
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (c *launcherClient) singleModelUsable(ctx context.Context, name string) bool {
|
||||
if name == "" {
|
||||
return false
|
||||
}
|
||||
if isCloudModelName(name) {
|
||||
cloudDisabled, _ := cloudStatusDisabled(ctx, c.apiClient)
|
||||
return !cloudDisabled
|
||||
}
|
||||
return c.hasLocalModel(name)
|
||||
}
|
||||
|
||||
func (c *launcherClient) hasLocalModel(name string) bool {
|
||||
for _, model := range c.modelInventory {
|
||||
if model.Remote {
|
||||
continue
|
||||
}
|
||||
if model.Name == name || strings.HasPrefix(model.Name, name+":") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// loadModelInventoryOnce populates c.modelInventory from the server on the
// first successful call; later calls are no-ops. When cloud usage is
// disabled, remote-only models are dropped from the cached inventory.
func (c *launcherClient) loadModelInventoryOnce(ctx context.Context) error {
	if c.inventoryLoaded {
		return nil
	}

	resp, err := c.apiClient.List(ctx)
	if err != nil {
		return err
	}

	// Rebuild in place, reusing the slice's backing array.
	c.modelInventory = c.modelInventory[:0]
	for _, model := range resp.Models {
		c.modelInventory = append(c.modelInventory, ModelInfo{
			Name:   model.Name,
			Remote: model.RemoteModel != "",
		})
	}

	cloudDisabled, _ := cloudStatusDisabled(ctx, c.apiClient)
	if cloudDisabled {
		c.modelInventory = filterCloudModels(c.modelInventory)
	}
	// Set only after a successful fetch so a failed List is retried on the
	// next call.
	c.inventoryLoaded = true
	return nil
}
|
||||
|
||||
// runIntegration announces the launch on stderr and hands control to the
// runner with the chosen model and passthrough args.
func runIntegration(runner Runner, modelName string, args []string) error {
	fmt.Fprintf(os.Stderr, "\nLaunching %s with %s...\n", runner, modelName)
	return runner.Run(modelName, args)
}
|
||||
|
||||
func launchAfterConfiguration(name string, runner Runner, model string, req IntegrationLaunchRequest) error {
|
||||
if req.ConfigureOnly {
|
||||
launch, err := ConfirmPrompt(fmt.Sprintf("Launch %s now?", runner))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !launch {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
if err := EnsureIntegrationInstalled(name, runner); err != nil {
|
||||
return err
|
||||
}
|
||||
return runIntegration(runner, model, req.ExtraArgs)
|
||||
}
|
||||
|
||||
// loadStoredIntegrationConfig loads the persisted config for an integration.
// When no config exists under the canonical name, it consults the
// integration's legacy aliases, migrating the first match to the canonical
// name. If nothing is found anywhere, the original os.ErrNotExist-wrapping
// error from the canonical load is returned.
func loadStoredIntegrationConfig(name string) (*config.IntegrationConfig, error) {
	cfg, err := config.LoadIntegration(name)
	if err == nil {
		return cfg, nil
	}
	if !errors.Is(err, os.ErrNotExist) {
		return nil, err
	}

	spec, specErr := LookupIntegrationSpec(name)
	if specErr != nil {
		// Unknown integration: surface the original not-found error, not
		// the spec lookup failure.
		return nil, err
	}

	for _, alias := range spec.Aliases {
		legacy, legacyErr := config.LoadIntegration(alias)
		if legacyErr == nil {
			migrateLegacyIntegrationConfig(spec.Name, legacy)
			// Prefer the freshly migrated copy; fall back to the legacy
			// config if re-loading under the canonical name fails.
			if migrated, migratedErr := config.LoadIntegration(spec.Name); migratedErr == nil {
				return migrated, nil
			}
			return legacy, nil
		}
		if legacyErr != nil && !errors.Is(legacyErr, os.ErrNotExist) {
			return nil, legacyErr
		}
	}

	// No alias had a config either: report the original error.
	return nil, err
}
|
||||
|
||||
// migrateLegacyIntegrationConfig copies a legacy alias config (models,
// aliases, onboarded flag) to the canonical integration name. All writes are
// best-effort: a failed save is ignored and the caller falls back to the
// legacy config.
func migrateLegacyIntegrationConfig(canonical string, legacy *config.IntegrationConfig) {
	if legacy == nil {
		return
	}

	// Copy the models slice so the saved config does not alias legacy's.
	_ = config.SaveIntegration(canonical, append([]string(nil), legacy.Models...))
	if len(legacy.Aliases) > 0 {
		_ = config.SaveAliases(canonical, cloneAliases(legacy.Aliases))
	}
	if legacy.Onboarded {
		_ = config.MarkIntegrationOnboarded(canonical)
	}
}
|
||||
|
||||
func primaryModelFromConfig(cfg *config.IntegrationConfig) string {
|
||||
if cfg == nil || len(cfg.Models) == 0 {
|
||||
return ""
|
||||
}
|
||||
return cfg.Models[0]
|
||||
}
|
||||
|
||||
// cloneAliases returns an independent copy of aliases. The result is never
// nil, even for nil or empty input.
func cloneAliases(aliases map[string]string) map[string]string {
	out := make(map[string]string, len(aliases))
	for key, value := range aliases {
		out[key] = value
	}
	return out
}
|
||||
|
||||
// firstModel returns the first entry of models, or "" when the slice is
// empty.
func firstModel(models []string) string {
	for _, m := range models {
		return m
	}
	return ""
}
|
||||
|
||||
func editorPreCheckedModels(saved *config.IntegrationConfig, override string) []string {
|
||||
if override == "" {
|
||||
if saved == nil {
|
||||
return nil
|
||||
}
|
||||
return append([]string(nil), saved.Models...)
|
||||
}
|
||||
return append([]string{override}, additionalSavedModels(saved, override)...)
|
||||
}
|
||||
|
||||
func additionalSavedModels(saved *config.IntegrationConfig, exclude string) []string {
|
||||
if saved == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var models []string
|
||||
for _, model := range saved.Models {
|
||||
if model != exclude {
|
||||
models = append(models, model)
|
||||
}
|
||||
}
|
||||
return models
|
||||
}
|
||||
1990
cmd/launch/launch_test.go
Normal file
494
cmd/launch/models.go
Normal file
@@ -0,0 +1,494 @@
|
||||
package launch
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/exec"
|
||||
"runtime"
|
||||
"slices"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/cmd/config"
|
||||
"github.com/ollama/ollama/cmd/internal/fileutil"
|
||||
internalcloud "github.com/ollama/ollama/internal/cloud"
|
||||
"github.com/ollama/ollama/internal/modelref"
|
||||
"github.com/ollama/ollama/progress"
|
||||
)
|
||||
|
||||
// recommendedModels is the curated list surfaced in selection UIs, mixing
// cloud-hosted and local options. Order matters: it defines the ranking used
// when sorting recommended items in buildModelList.
var recommendedModels = []ModelItem{
	{Name: "kimi-k2.5:cloud", Description: "Multimodal reasoning with subagents", Recommended: true},
	{Name: "qwen3.5:cloud", Description: "Reasoning, coding, and agentic tool use with vision", Recommended: true},
	{Name: "glm-5:cloud", Description: "Reasoning and code generation", Recommended: true},
	{Name: "minimax-m2.7:cloud", Description: "Fast, efficient coding and real-world productivity", Recommended: true},
	{Name: "glm-4.7-flash", Description: "Reasoning and code generation locally", Recommended: true},
	{Name: "qwen3.5", Description: "Reasoning, coding, and visual understanding locally", Recommended: true},
}
|
||||
|
||||
// recommendedVRAM maps local recommended models to the approximate VRAM
// needed to run them; shown alongside the "(not downloaded)" description.
var recommendedVRAM = map[string]string{
	"glm-4.7-flash": "~25GB",
	"qwen3.5":       "~11GB",
}
|
||||
|
||||
// cloudModelLimit holds context and output token limits for a cloud model.
type cloudModelLimit struct {
	Context int // maximum context window, in tokens
	Output  int // maximum output length, in tokens
}
|
||||
|
||||
// cloudModelLimits maps cloud model base names to their token limits.
// Keys are names with any explicit cloud source suffix already stripped
// (see lookupCloudModelLimit).
// TODO(parthsareen): grab context/output limits from model info instead of hardcoding
var cloudModelLimits = map[string]cloudModelLimit{
	"minimax-m2.7":        {Context: 204_800, Output: 128_000},
	"cogito-2.1:671b":     {Context: 163_840, Output: 65_536},
	"deepseek-v3.1:671b":  {Context: 163_840, Output: 163_840},
	"deepseek-v3.2":       {Context: 163_840, Output: 65_536},
	"glm-4.6":             {Context: 202_752, Output: 131_072},
	"glm-4.7":             {Context: 202_752, Output: 131_072},
	"glm-5":               {Context: 202_752, Output: 131_072},
	"gpt-oss:120b":        {Context: 131_072, Output: 131_072},
	"gpt-oss:20b":         {Context: 131_072, Output: 131_072},
	"kimi-k2:1t":          {Context: 262_144, Output: 262_144},
	"kimi-k2.5":           {Context: 262_144, Output: 262_144},
	"kimi-k2-thinking":    {Context: 262_144, Output: 262_144},
	"nemotron-3-nano:30b": {Context: 1_048_576, Output: 131_072},
	"qwen3-coder:480b":    {Context: 262_144, Output: 65_536},
	"qwen3-coder-next":    {Context: 262_144, Output: 32_768},
	"qwen3-next:80b":      {Context: 262_144, Output: 32_768},
	"qwen3.5":             {Context: 262_144, Output: 32_768},
}
|
||||
|
||||
// lookupCloudModelLimit returns the token limits for a cloud model.
|
||||
// It normalizes explicit cloud source suffixes before checking the shared limit map.
|
||||
func lookupCloudModelLimit(name string) (cloudModelLimit, bool) {
|
||||
base, stripped := modelref.StripCloudSourceTag(name)
|
||||
if stripped {
|
||||
if l, ok := cloudModelLimits[base]; ok {
|
||||
return l, true
|
||||
}
|
||||
}
|
||||
return cloudModelLimit{}, false
|
||||
}
|
||||
|
||||
// missingModelPolicy controls how model-not-found errors should be handled.
// It only applies to local models; missing cloud models always fail.
type missingModelPolicy int

const (
	// missingModelPromptPull prompts the user to download missing local models.
	missingModelPromptPull missingModelPolicy = iota
	// missingModelAutoPull downloads missing local models without prompting.
	missingModelAutoPull
	// missingModelFail returns an error for missing local models without prompting.
	missingModelFail
)
|
||||
|
||||
// OpenBrowser opens the URL in the user's browser. Launch failures are
// deliberately ignored; unsupported platforms and headless Linux (no
// display server) are no-ops.
func OpenBrowser(url string) {
	var cmd *exec.Cmd
	switch runtime.GOOS {
	case "darwin":
		cmd = exec.Command("open", url)
	case "linux":
		// Skip on headless systems where no display server is available
		if os.Getenv("DISPLAY") == "" && os.Getenv("WAYLAND_DISPLAY") == "" {
			return
		}
		cmd = exec.Command("xdg-open", url)
	case "windows":
		cmd = exec.Command("rundll32", "url.dll,FileProtocolHandler", url)
	default:
		return
	}
	_ = cmd.Start() // best-effort; errors ignored
}
|
||||
|
||||
// ensureAuth ensures the user is signed in before cloud-backed models run.
// It is a no-op when no cloud model was selected or the user is already
// signed in. Otherwise it delegates to the pluggable DefaultSignIn hook if
// set, or runs a terminal flow: confirm, open the sign-in URL in a browser,
// and poll Whoami until sign-in completes or ctx is cancelled.
func ensureAuth(ctx context.Context, client *api.Client, cloudModels map[string]bool, selected []string) error {
	// Only the cloud models actually selected matter; local-only runs need
	// no auth at all.
	var selectedCloudModels []string
	for _, m := range selected {
		if cloudModels[m] {
			selectedCloudModels = append(selectedCloudModels, m)
		}
	}
	if len(selectedCloudModels) == 0 {
		return nil
	}
	// Cloud explicitly disabled: fail fast instead of prompting a sign-in
	// that could never work.
	if disabled, known := cloudStatusDisabled(ctx, client); known && disabled {
		return errors.New(internalcloud.DisabledError("remote inference is unavailable"))
	}

	// Already signed in? Nothing more to do.
	user, err := client.Whoami(ctx)
	if err == nil && user != nil && user.Name != "" {
		return nil
	}

	// Only an AuthorizationError carrying a sign-in URL is recoverable
	// interactively; anything else is surfaced as-is.
	var aErr api.AuthorizationError
	if !errors.As(err, &aErr) || aErr.SigninURL == "" {
		return err
	}

	modelList := strings.Join(selectedCloudModels, ", ")

	// Prefer the pluggable sign-in UI when one is installed.
	if DefaultSignIn != nil {
		_, err := DefaultSignIn(modelList, aErr.SigninURL)
		if errors.Is(err, ErrCancelled) {
			return ErrCancelled
		}
		if err != nil {
			return fmt.Errorf("%s requires sign in", modelList)
		}
		return nil
	}

	// Terminal fallback: confirm, open the browser, then poll for sign-in.
	yes, err := ConfirmPrompt(fmt.Sprintf("sign in to use %s?", modelList))
	if errors.Is(err, ErrCancelled) {
		return ErrCancelled
	}
	if err != nil {
		return err
	}
	if !yes {
		return ErrCancelled
	}

	fmt.Fprintf(os.Stderr, "\nTo sign in, navigate to:\n %s\n\n", aErr.SigninURL)
	OpenBrowser(aErr.SigninURL)

	// Animated spinner while waiting; Whoami is re-checked every 10th tick
	// (roughly every two seconds at the 200ms tick rate).
	spinnerFrames := []string{"|", "/", "-", "\\"}
	frame := 0
	fmt.Fprintf(os.Stderr, "\033[90mwaiting for sign in to complete... %s\033[0m", spinnerFrames[0])

	ticker := time.NewTicker(200 * time.Millisecond)
	defer ticker.Stop()

	for {
		select {
		case <-ctx.Done():
			// Clear the spinner line before giving up.
			fmt.Fprintf(os.Stderr, "\r\033[K")
			return ctx.Err()
		case <-ticker.C:
			frame++
			fmt.Fprintf(os.Stderr, "\r\033[90mwaiting for sign in to complete... %s\033[0m", spinnerFrames[frame%len(spinnerFrames)])

			if frame%10 == 0 {
				u, err := client.Whoami(ctx)
				if err == nil && u != nil && u.Name != "" {
					// Erase the prompt and spinner lines, then confirm.
					fmt.Fprintf(os.Stderr, "\r\033[K\033[A\r\033[K\033[1msigned in:\033[0m %s\n", u.Name)
					return nil
				}
			}
		}
	}
}
|
||||
|
||||
// showOrPullWithPolicy checks if a model exists and applies the provided missing-model policy.
|
||||
func showOrPullWithPolicy(ctx context.Context, client *api.Client, model string, policy missingModelPolicy, isCloudModel bool) error {
|
||||
if _, err := client.Show(ctx, &api.ShowRequest{Model: model}); err == nil {
|
||||
return nil
|
||||
} else {
|
||||
var statusErr api.StatusError
|
||||
if !errors.As(err, &statusErr) || statusErr.StatusCode != http.StatusNotFound {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if isCloudModel {
|
||||
if disabled, known := cloudStatusDisabled(ctx, client); known && disabled {
|
||||
return errors.New(internalcloud.DisabledError("remote inference is unavailable"))
|
||||
}
|
||||
return fmt.Errorf("model %q not found", model)
|
||||
}
|
||||
|
||||
switch policy {
|
||||
case missingModelAutoPull:
|
||||
return pullMissingModel(ctx, client, model)
|
||||
case missingModelFail:
|
||||
return fmt.Errorf("model %q not found; run 'ollama pull %s' first, or use --yes to auto-pull", model, model)
|
||||
default:
|
||||
return confirmAndPull(ctx, client, model)
|
||||
}
|
||||
}
|
||||
|
||||
func confirmAndPull(ctx context.Context, client *api.Client, model string) error {
|
||||
if ok, err := ConfirmPrompt(fmt.Sprintf("Download %s?", model)); err != nil {
|
||||
return err
|
||||
} else if !ok {
|
||||
return errCancelled
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, "\n")
|
||||
return pullMissingModel(ctx, client, model)
|
||||
}
|
||||
|
||||
func pullMissingModel(ctx context.Context, client *api.Client, model string) error {
|
||||
if err := pullModel(ctx, client, model, false); err != nil {
|
||||
return fmt.Errorf("failed to pull %s: %w", model, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// prepareEditorIntegration persists models and applies editor-managed config files.
|
||||
func prepareEditorIntegration(name string, runner Runner, editor Editor, models []string) error {
|
||||
if ok, err := confirmEditorEdit(runner, editor); err != nil {
|
||||
return err
|
||||
} else if !ok {
|
||||
return errCancelled
|
||||
}
|
||||
if err := editor.Edit(models); err != nil {
|
||||
return fmt.Errorf("setup failed: %w", err)
|
||||
}
|
||||
if err := config.SaveIntegration(name, models); err != nil {
|
||||
return fmt.Errorf("failed to save: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func confirmEditorEdit(runner Runner, editor Editor) (bool, error) {
|
||||
paths := editor.Paths()
|
||||
if len(paths) == 0 {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, "This will modify your %s configuration:\n", runner)
|
||||
for _, path := range paths {
|
||||
fmt.Fprintf(os.Stderr, " %s\n", path)
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, "Backups will be saved to %s/\n\n", fileutil.BackupDir())
|
||||
|
||||
return ConfirmPrompt("Proceed?")
|
||||
}
|
||||
|
||||
// buildModelList merges existing models with recommendations for selection UIs.
// It returns the display items, the pre-checked names in display order (with
// the current model promoted to the front when it is checked), plus lookup
// sets of existing and cloud model names.
func buildModelList(existing []modelInfo, preChecked []string, current string) (items []ModelItem, orderedChecked []string, existingModels, cloudModels map[string]bool) {
	existingModels = make(map[string]bool)
	cloudModels = make(map[string]bool)
	recommended := make(map[string]bool)
	var hasLocalModel, hasCloudModel bool

	// Index the recommendations so existing models can inherit their
	// descriptions and recommended flag.
	recDesc := make(map[string]string)
	for _, rec := range recommendedModels {
		recommended[rec.Name] = true
		recDesc[rec.Name] = rec.Description
	}

	// Installed models first. Both the raw name and its ":latest"-stripped
	// display name are recorded so either form matches later.
	for _, m := range existing {
		existingModels[m.Name] = true
		if m.Remote {
			cloudModels[m.Name] = true
			hasCloudModel = true
		} else {
			hasLocalModel = true
		}
		displayName := strings.TrimSuffix(m.Name, ":latest")
		existingModels[displayName] = true
		item := ModelItem{Name: displayName, Recommended: recommended[displayName], Description: recDesc[displayName]}
		items = append(items, item)
	}

	// Append recommendations that aren't already installed.
	for _, rec := range recommendedModels {
		if existingModels[rec.Name] || existingModels[rec.Name+":latest"] {
			continue
		}
		items = append(items, rec)
		if isCloudModelName(rec.Name) {
			cloudModels[rec.Name] = true
		}
	}

	checked := make(map[string]bool, len(preChecked))
	for _, n := range preChecked {
		checked[n] = true
	}

	// Resolve current against the item list: exact name first, then a
	// tag-qualified prefix match (e.g. "foo" -> "foo:7b").
	if current != "" {
		matchedCurrent := false
		for _, item := range items {
			if item.Name == current {
				current = item.Name
				matchedCurrent = true
				break
			}
		}
		if !matchedCurrent {
			for _, item := range items {
				if strings.HasPrefix(item.Name, current+":") {
					current = item.Name
					break
				}
			}
		}
	}

	// Promote the current model to the front of the pre-checked order.
	if checked[current] {
		preChecked = append([]string{current}, slices.DeleteFunc(preChecked, func(m string) bool { return m == current })...)
	}

	// Annotate items that are neither installed nor cloud-hosted with their
	// description, approximate VRAM, and a "(not downloaded)" marker.
	notInstalled := make(map[string]bool)
	for i := range items {
		if !existingModels[items[i].Name] && !cloudModels[items[i].Name] {
			notInstalled[items[i].Name] = true
			var parts []string
			if items[i].Description != "" {
				parts = append(parts, items[i].Description)
			}
			if vram := recommendedVRAM[items[i].Name]; vram != "" {
				parts = append(parts, vram)
			}
			parts = append(parts, "(not downloaded)")
			items[i].Description = strings.Join(parts, ", ")
		}
	}

	// Rank recommendations by their position in recommendedModels (1-based;
	// zero means "not recommended").
	recRank := make(map[string]int)
	for i, rec := range recommendedModels {
		recRank[rec.Name] = i + 1
	}

	onlyLocal := hasLocalModel && !hasCloudModel

	// Sort only when the user has any models installed. Order: checked
	// first, then recommended (cloud vs local preference flips when the
	// user only has local models), then installed before not-downloaded,
	// then case-insensitive alphabetical.
	if hasLocalModel || hasCloudModel {
		slices.SortStableFunc(items, func(a, b ModelItem) int {
			ac, bc := checked[a.Name], checked[b.Name]
			aNew, bNew := notInstalled[a.Name], notInstalled[b.Name]
			aRec, bRec := recRank[a.Name] > 0, recRank[b.Name] > 0
			aCloud, bCloud := cloudModels[a.Name], cloudModels[b.Name]

			if ac != bc {
				if ac {
					return -1
				}
				return 1
			}
			if aRec != bRec {
				if aRec {
					return -1
				}
				return 1
			}
			if aRec && bRec {
				if aCloud != bCloud {
					if onlyLocal {
						// Local-only users see local recommendations first.
						if aCloud {
							return 1
						}
						return -1
					}
					if aCloud {
						return -1
					}
					return 1
				}
				return recRank[a.Name] - recRank[b.Name]
			}
			if aNew != bNew {
				if aNew {
					return 1
				}
				return -1
			}
			return strings.Compare(strings.ToLower(a.Name), strings.ToLower(b.Name))
		})
	}

	return items, preChecked, existingModels, cloudModels
}
|
||||
|
||||
// isCloudModelName reports whether the model name has an explicit cloud source.
// It is a pure name-based check; see isCloudModel for the server-backed one.
func isCloudModelName(name string) bool {
	return modelref.HasExplicitCloudSource(name)
}
|
||||
|
||||
// filterCloudModels drops remote-only models from the given inventory.
// The filter runs in place: the result aliases existing's backing array,
// so the caller must not reuse existing afterwards.
func filterCloudModels(existing []modelInfo) []modelInfo {
	filtered := existing[:0]
	for _, m := range existing {
		if !m.Remote {
			filtered = append(filtered, m)
		}
	}
	return filtered
}
|
||||
|
||||
// filterCloudItems removes cloud models from selection items.
// The filter runs in place: the result aliases items' backing array, so the
// caller must not reuse items afterwards.
func filterCloudItems(items []ModelItem) []ModelItem {
	filtered := items[:0]
	for _, item := range items {
		if !isCloudModelName(item.Name) {
			filtered = append(filtered, item)
		}
	}
	return filtered
}
|
||||
|
||||
func isCloudModel(ctx context.Context, client *api.Client, name string) bool {
|
||||
if client == nil {
|
||||
return false
|
||||
}
|
||||
resp, err := client.Show(ctx, &api.ShowRequest{Model: name})
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return resp.RemoteModel != ""
|
||||
}
|
||||
|
||||
// cloudStatusDisabled returns whether cloud usage is currently disabled.
|
||||
func cloudStatusDisabled(ctx context.Context, client *api.Client) (disabled bool, known bool) {
|
||||
status, err := client.CloudStatusExperimental(ctx)
|
||||
if err != nil {
|
||||
var statusErr api.StatusError
|
||||
if errors.As(err, &statusErr) && statusErr.StatusCode == http.StatusNotFound {
|
||||
return false, false
|
||||
}
|
||||
return false, false
|
||||
}
|
||||
return status.Cloud.Disabled, true
|
||||
}
|
||||
|
||||
// TODO(parthsareen): this duplicates the pull progress UI in cmd.PullHandler.
// Move the shared pull rendering to a small utility once the package boundary settles.
//
// pullModel downloads model via the API, rendering download progress to
// stderr: one bar per layer digest and a spinner for status-only updates.
func pullModel(ctx context.Context, client *api.Client, model string, insecure bool) error {
	p := progress.NewProgress(os.Stderr)
	defer p.Stop()

	// Progress bars keyed by layer digest; spinner shows the latest
	// non-layer status message.
	bars := make(map[string]*progress.Bar)
	var status string
	var spinner *progress.Spinner

	fn := func(resp api.ProgressResponse) error {
		if resp.Digest != "" {
			// Ignore layers that haven't started transferring yet.
			if resp.Completed == 0 {
				return nil
			}

			if spinner != nil {
				spinner.Stop()
			}

			bar, ok := bars[resp.Digest]
			if !ok {
				// Label the bar with a shortened digest (first 12 hex chars).
				name, isDigest := strings.CutPrefix(resp.Digest, "sha256:")
				name = strings.TrimSpace(name)
				if isDigest {
					name = name[:min(12, len(name))]
				}
				bar = progress.NewBar(fmt.Sprintf("pulling %s:", name), resp.Total, resp.Completed)
				bars[resp.Digest] = bar
				p.Add(resp.Digest, bar)
			}

			bar.Set(resp.Completed)
		} else if status != resp.Status {
			// New status message: retire the previous spinner and start one
			// for the new phase.
			if spinner != nil {
				spinner.Stop()
			}

			status = resp.Status
			spinner = progress.NewSpinner(status)
			p.Add(status, spinner)
		}

		return nil
	}

	request := api.PullRequest{Name: model, Insecure: insecure}
	return client.Pull(ctx, &request, fn)
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package config
|
||||
package launch
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -14,7 +14,10 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"golang.org/x/mod/semver"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/cmd/internal/fileutil"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
)
|
||||
@@ -24,6 +27,9 @@ const defaultGatewayPort = 18789
|
||||
// Bound model capability probing so launch/config cannot hang on slow/unreachable API calls.
|
||||
var openclawModelShowTimeout = 5 * time.Second
|
||||
|
||||
// openclawFreshInstall is set to true when ensureOpenclawInstalled performs an install
|
||||
var openclawFreshInstall bool
|
||||
|
||||
type Openclaw struct{}
|
||||
|
||||
func (c *Openclaw) String() string { return "OpenClaw" }
|
||||
@@ -34,10 +40,7 @@ func (c *Openclaw) Run(model string, args []string) error {
|
||||
return err
|
||||
}
|
||||
|
||||
firstLaunch := true
|
||||
if integrationConfig, err := loadIntegration("openclaw"); err == nil {
|
||||
firstLaunch = !integrationConfig.Onboarded
|
||||
}
|
||||
firstLaunch := !c.onboarded()
|
||||
|
||||
if firstLaunch {
|
||||
fmt.Fprintf(os.Stderr, "\n%sSecurity%s\n\n", ansiBold, ansiReset)
|
||||
@@ -45,28 +48,46 @@ func (c *Openclaw) Run(model string, args []string) error {
|
||||
fmt.Fprintf(os.Stderr, " A bad prompt can trick it into doing unsafe things.\n\n")
|
||||
fmt.Fprintf(os.Stderr, "%s Learn more: https://docs.openclaw.ai/gateway/security%s\n\n", ansiGray, ansiReset)
|
||||
|
||||
ok, err := confirmPrompt("I understand the risks. Continue?")
|
||||
ok, err := ConfirmPrompt("I understand the risks. Continue?")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
if !c.onboarded() {
|
||||
// Ensure the latest version is installed before onboarding so we get
|
||||
// the newest wizard flags (e.g. --auth-choice ollama).
|
||||
if !openclawFreshInstall {
|
||||
update := exec.Command(bin, "update")
|
||||
update.Stdout = os.Stdout
|
||||
update.Stderr = os.Stderr
|
||||
_ = update.Run() // best-effort; continue even if update fails
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, "\n%sSetting up OpenClaw with Ollama...%s\n", ansiGreen, ansiReset)
|
||||
fmt.Fprintf(os.Stderr, "%s Model: %s%s\n\n", ansiGray, model, ansiReset)
|
||||
|
||||
cmd := exec.Command(bin, "onboard",
|
||||
onboardArgs := []string{
|
||||
"onboard",
|
||||
"--non-interactive",
|
||||
"--accept-risk",
|
||||
"--auth-choice", "skip",
|
||||
"--gateway-token", "ollama",
|
||||
"--install-daemon",
|
||||
"--auth-choice", "ollama",
|
||||
"--custom-base-url", envconfig.Host().String(),
|
||||
"--custom-model-id", model,
|
||||
"--skip-channels",
|
||||
"--skip-skills",
|
||||
)
|
||||
}
|
||||
if canInstallDaemon() {
|
||||
onboardArgs = append(onboardArgs, "--install-daemon")
|
||||
} else {
|
||||
// When we can't install a daemon (e.g. no systemd, sudo dropped
|
||||
// XDG_RUNTIME_DIR, or container environment), skip the gateway
|
||||
// health check so non-interactive onboarding completes. The
|
||||
// gateway is started as a foreground child process after onboarding.
|
||||
onboardArgs = append(onboardArgs, "--skip-health")
|
||||
}
|
||||
cmd := exec.Command(bin, onboardArgs...)
|
||||
cmd.Stdin = os.Stdin
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
@@ -75,25 +96,13 @@ func (c *Openclaw) Run(model string, args []string) error {
|
||||
}
|
||||
|
||||
patchDeviceScopes()
|
||||
|
||||
// Onboarding overwrites openclaw.json, so re-apply the model config
|
||||
// that Edit() wrote before Run() was called.
|
||||
if err := c.Edit([]string{model}); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "%s Warning: could not re-apply model config: %v%s\n", ansiYellow, err, ansiReset)
|
||||
}
|
||||
}
|
||||
|
||||
if strings.HasSuffix(model, ":cloud") || strings.HasSuffix(model, "-cloud") {
|
||||
if ensureWebSearchPlugin() {
|
||||
registerWebSearchPlugin()
|
||||
}
|
||||
if ensureWebSearchPlugin() {
|
||||
registerWebSearchPlugin()
|
||||
}
|
||||
|
||||
if firstLaunch {
|
||||
fmt.Fprintf(os.Stderr, "\n%sPreparing your assistant — this may take a moment...%s\n\n", ansiGray, ansiReset)
|
||||
} else {
|
||||
fmt.Fprintf(os.Stderr, "\n%sStarting your assistant — this may take a moment...%s\n\n", ansiGray, ansiReset)
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, "\n%sStarting your assistant — this may take a moment...%s\n\n", ansiGray, ansiReset)
|
||||
|
||||
// When extra args are passed through, run exactly what the user asked for
|
||||
// after setup and skip the built-in gateway+TUI convenience flow.
|
||||
@@ -106,11 +115,6 @@ func (c *Openclaw) Run(model string, args []string) error {
|
||||
if err := cmd.Run(); err != nil {
|
||||
return windowsHint(err)
|
||||
}
|
||||
if firstLaunch {
|
||||
if err := integrationOnboarded("openclaw"); err != nil {
|
||||
return fmt.Errorf("failed to save onboarding state: %w", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -118,7 +122,7 @@ func (c *Openclaw) Run(model string, args []string) error {
|
||||
addr := fmt.Sprintf("localhost:%d", port)
|
||||
|
||||
// If the gateway is already running (e.g. via the daemon), restart it
|
||||
// so it picks up any config changes from Edit() above (model, provider, etc.).
|
||||
// so it picks up any config changes (model, provider, etc.).
|
||||
if portOpen(addr) {
|
||||
restart := exec.Command(bin, "daemon", "restart")
|
||||
restart.Env = openclawEnv()
|
||||
@@ -165,11 +169,6 @@ func (c *Openclaw) Run(model string, args []string) error {
|
||||
return windowsHint(err)
|
||||
}
|
||||
|
||||
if firstLaunch {
|
||||
if err := integrationOnboarded("openclaw"); err != nil {
|
||||
return fmt.Errorf("failed to save onboarding state: %w", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -409,6 +408,25 @@ func patchScopes(obj map[string]any, key string, required []string) bool {
|
||||
return added
|
||||
}
|
||||
|
||||
// canInstallDaemon reports whether the openclaw daemon can be installed as a
|
||||
// background service. Returns false on Linux when systemd is absent (e.g.
|
||||
// containers) so that --install-daemon is omitted and the gateway is started
|
||||
// as a foreground child process instead. Returns true in all other cases.
|
||||
func canInstallDaemon() bool {
|
||||
if runtime.GOOS != "linux" {
|
||||
return true
|
||||
}
|
||||
// /run/systemd/system exists as a directory when systemd is the init system.
|
||||
// This is absent in most containers.
|
||||
fi, err := os.Stat("/run/systemd/system")
|
||||
if err != nil || !fi.IsDir() {
|
||||
return false
|
||||
}
|
||||
// Even when systemd is the init system, user services require a user
|
||||
// manager instance. XDG_RUNTIME_DIR being set is a prerequisite.
|
||||
return os.Getenv("XDG_RUNTIME_DIR") != ""
|
||||
}
|
||||
|
||||
func ensureOpenclawInstalled() (string, error) {
|
||||
if _, err := exec.LookPath("openclaw"); err == nil {
|
||||
return "openclaw", nil
|
||||
@@ -417,16 +435,20 @@ func ensureOpenclawInstalled() (string, error) {
|
||||
return "clawdbot", nil
|
||||
}
|
||||
|
||||
if _, err := exec.LookPath("npm"); err != nil {
|
||||
return "", fmt.Errorf("openclaw is not installed and npm was not found\n\n" +
|
||||
"Install Node.js first:\n" +
|
||||
" https://nodejs.org/\n\n" +
|
||||
"Then rerun:\n" +
|
||||
" ollama launch\n" +
|
||||
"and select OpenClaw")
|
||||
_, npmErr := exec.LookPath("npm")
|
||||
_, gitErr := exec.LookPath("git")
|
||||
if npmErr != nil || gitErr != nil {
|
||||
var missing []string
|
||||
if npmErr != nil {
|
||||
missing = append(missing, "npm (Node.js): https://nodejs.org/")
|
||||
}
|
||||
if gitErr != nil {
|
||||
missing = append(missing, "git: https://git-scm.com/")
|
||||
}
|
||||
return "", fmt.Errorf("openclaw is not installed and required dependencies are missing\n\nInstall the following first:\n %s", strings.Join(missing, "\n "))
|
||||
}
|
||||
|
||||
ok, err := confirmPrompt("OpenClaw is not installed. Install with npm?")
|
||||
ok, err := ConfirmPrompt("OpenClaw is not installed. Install with npm?")
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
@@ -448,6 +470,7 @@ func ensureOpenclawInstalled() (string, error) {
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, "%sOpenClaw installed successfully%s\n\n", ansiGreen, ansiReset)
|
||||
openclawFreshInstall = true
|
||||
return "openclaw", nil
|
||||
}
|
||||
|
||||
@@ -502,7 +525,7 @@ func (c *Openclaw) Edit(models []string) error {
|
||||
ollama = make(map[string]any)
|
||||
}
|
||||
|
||||
ollama["baseUrl"] = envconfig.Host().String() + "/v1"
|
||||
ollama["baseUrl"] = envconfig.Host().String()
|
||||
// needed to register provider
|
||||
ollama["apiKey"] = "ollama-local"
|
||||
ollama["api"] = "ollama"
|
||||
@@ -561,7 +584,7 @@ func (c *Openclaw) Edit(models []string) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := writeWithBackup(configPath, data); err != nil {
|
||||
if err := fileutil.WriteWithBackup(configPath, data); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -592,6 +615,8 @@ func clearSessionModelOverride(primary string) {
|
||||
if override, _ := sess["modelOverride"].(string); override != "" && override != primary {
|
||||
delete(sess, "modelOverride")
|
||||
delete(sess, "providerOverride")
|
||||
}
|
||||
if model, _ := sess["model"].(string); model != "" && model != primary {
|
||||
sess["model"] = primary
|
||||
changed = true
|
||||
}
|
||||
@@ -606,11 +631,15 @@ func clearSessionModelOverride(primary string) {
|
||||
_ = os.WriteFile(path, out, 0o600)
|
||||
}
|
||||
|
||||
const webSearchNpmPackage = "@ollama/openclaw-web-search"
|
||||
const (
|
||||
webSearchNpmPackage = "@ollama/openclaw-web-search"
|
||||
webSearchMinVersion = "0.2.1"
|
||||
)
|
||||
|
||||
// ensureWebSearchPlugin installs the openclaw-web-search extension into the
|
||||
// user-level extensions directory (~/.openclaw/extensions/) if it isn't already
|
||||
// present. Returns true if the extension is available.
|
||||
// present, or re-installs if the installed version is older than webSearchMinVersion.
|
||||
// Returns true if the extension is available.
|
||||
func ensureWebSearchPlugin() bool {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
@@ -618,8 +647,8 @@ func ensureWebSearchPlugin() bool {
|
||||
}
|
||||
|
||||
pluginDir := filepath.Join(home, ".openclaw", "extensions", "openclaw-web-search")
|
||||
if _, err := os.Stat(filepath.Join(pluginDir, "index.ts")); err == nil {
|
||||
return true // already installed
|
||||
if webSearchPluginUpToDate(pluginDir) {
|
||||
return true
|
||||
}
|
||||
|
||||
npmBin, err := exec.LookPath("npm")
|
||||
@@ -653,6 +682,34 @@ func ensureWebSearchPlugin() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// webSearchPluginUpToDate returns true if the plugin is installed and its
|
||||
// package.json version is >= webSearchMinVersion.
|
||||
func webSearchPluginUpToDate(pluginDir string) bool {
|
||||
data, err := os.ReadFile(filepath.Join(pluginDir, "package.json"))
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
var pkg struct {
|
||||
Version string `json:"version"`
|
||||
}
|
||||
if json.Unmarshal(data, &pkg) != nil || pkg.Version == "" {
|
||||
return false
|
||||
}
|
||||
return !versionLessThan(pkg.Version, webSearchMinVersion)
|
||||
}
|
||||
|
||||
// versionLessThan compares two semver version strings (major.minor.patch).
|
||||
// Inputs may omit the "v" prefix; it is added automatically for semver.Compare.
|
||||
func versionLessThan(a, b string) bool {
|
||||
if !strings.HasPrefix(a, "v") {
|
||||
a = "v" + a
|
||||
}
|
||||
if !strings.HasPrefix(b, "v") {
|
||||
b = "v" + b
|
||||
}
|
||||
return semver.Compare(a, b) < 0
|
||||
}
|
||||
|
||||
// registerWebSearchPlugin adds plugins.entries.openclaw-web-search to the OpenClaw
|
||||
// config so the gateway activates it on next start. Best-effort; silently returns
|
||||
// on any error.
|
||||
@@ -679,23 +736,67 @@ func registerWebSearchPlugin() {
|
||||
if entries == nil {
|
||||
entries = make(map[string]any)
|
||||
}
|
||||
if _, ok := entries["openclaw-web-search"]; ok {
|
||||
return // already registered
|
||||
}
|
||||
entries["openclaw-web-search"] = map[string]any{"enabled": true}
|
||||
plugins["entries"] = entries
|
||||
|
||||
// Pin trust so the gateway doesn't warn about untracked plugins.
|
||||
allow, _ := plugins["allow"].([]any)
|
||||
hasAllow := false
|
||||
for _, v := range allow {
|
||||
if s, ok := v.(string); ok && s == "openclaw-web-search" {
|
||||
hasAllow = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !hasAllow {
|
||||
allow = append(allow, "openclaw-web-search")
|
||||
}
|
||||
plugins["allow"] = allow
|
||||
|
||||
// Record install provenance so the loader can verify the plugin origin.
|
||||
installs, _ := plugins["installs"].(map[string]any)
|
||||
if installs == nil {
|
||||
installs = make(map[string]any)
|
||||
}
|
||||
pluginDir := filepath.Join(home, ".openclaw", "extensions", "openclaw-web-search")
|
||||
installs["openclaw-web-search"] = map[string]any{
|
||||
"source": "npm",
|
||||
"spec": webSearchNpmPackage,
|
||||
"installPath": pluginDir,
|
||||
}
|
||||
plugins["installs"] = installs
|
||||
|
||||
config["plugins"] = plugins
|
||||
|
||||
// Disable the built-in web search since our plugin replaces it.
|
||||
// Add plugin tools to tools.alsoAllow so they survive the coding profile's
|
||||
// policy pipeline (which has an explicit allow list of core tools only).
|
||||
tools, _ := config["tools"].(map[string]any)
|
||||
if tools == nil {
|
||||
tools = make(map[string]any)
|
||||
}
|
||||
|
||||
alsoAllow, _ := tools["alsoAllow"].([]any)
|
||||
needed := []string{"ollama_web_search", "ollama_web_fetch"}
|
||||
have := make(map[string]bool, len(alsoAllow))
|
||||
for _, v := range alsoAllow {
|
||||
if s, ok := v.(string); ok {
|
||||
have[s] = true
|
||||
}
|
||||
}
|
||||
for _, name := range needed {
|
||||
if !have[name] {
|
||||
alsoAllow = append(alsoAllow, name)
|
||||
}
|
||||
}
|
||||
tools["alsoAllow"] = alsoAllow
|
||||
|
||||
// Disable built-in web search/fetch since our plugin replaces them.
|
||||
web, _ := tools["web"].(map[string]any)
|
||||
if web == nil {
|
||||
web = make(map[string]any)
|
||||
}
|
||||
web["search"] = map[string]any{"enabled": false}
|
||||
web["fetch"] = map[string]any{"enabled": false}
|
||||
tools["web"] = web
|
||||
config["tools"] = tools
|
||||
|
||||
@@ -776,9 +877,9 @@ func (c *Openclaw) Models() []string {
|
||||
return nil
|
||||
}
|
||||
|
||||
config, err := readJSONFile(filepath.Join(home, ".openclaw", "openclaw.json"))
|
||||
config, err := fileutil.ReadJSON(filepath.Join(home, ".openclaw", "openclaw.json"))
|
||||
if err != nil {
|
||||
config, err = readJSONFile(filepath.Join(home, ".clawdbot", "clawdbot.json"))
|
||||
config, err = fileutil.ReadJSON(filepath.Join(home, ".clawdbot", "clawdbot.json"))
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package config
|
||||
package launch
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
@@ -82,78 +82,6 @@ func TestOpenclawRunPassthroughArgs(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenclawRunFirstLaunchPersistence(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("uses a POSIX shell test binary")
|
||||
}
|
||||
|
||||
oldHook := DefaultConfirmPrompt
|
||||
DefaultConfirmPrompt = func(prompt string) (bool, error) {
|
||||
return true, nil
|
||||
}
|
||||
defer func() { DefaultConfirmPrompt = oldHook }()
|
||||
|
||||
t.Run("success persists onboarding flag", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
t.Setenv("PATH", tmpDir)
|
||||
|
||||
configDir := filepath.Join(tmpDir, ".openclaw")
|
||||
if err := os.MkdirAll(configDir, 0o755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
// Mark OpenClaw onboarding complete so Run takes passthrough path directly.
|
||||
if err := os.WriteFile(filepath.Join(configDir, "openclaw.json"), []byte(`{
|
||||
"wizard": {"lastRunAt": "2026-01-01T00:00:00Z"}
|
||||
}`), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := os.WriteFile(filepath.Join(tmpDir, "openclaw"), []byte("#!/bin/sh\nexit 0\n"), 0o755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
c := &Openclaw{}
|
||||
if err := c.Run("llama3.2", []string{"gateway", "--status"}); err != nil {
|
||||
t.Fatalf("Run() error = %v", err)
|
||||
}
|
||||
integrationConfig, err := loadIntegration("openclaw")
|
||||
if err != nil {
|
||||
t.Fatalf("loadIntegration() error = %v", err)
|
||||
}
|
||||
if !integrationConfig.Onboarded {
|
||||
t.Fatal("expected onboarding flag to be persisted after successful run")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("failure does not persist onboarding flag", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
t.Setenv("PATH", tmpDir)
|
||||
|
||||
configDir := filepath.Join(tmpDir, ".openclaw")
|
||||
if err := os.MkdirAll(configDir, 0o755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := os.WriteFile(filepath.Join(configDir, "openclaw.json"), []byte(`{
|
||||
"wizard": {"lastRunAt": "2026-01-01T00:00:00Z"}
|
||||
}`), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := os.WriteFile(filepath.Join(tmpDir, "openclaw"), []byte("#!/bin/sh\nexit 1\n"), 0o755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
c := &Openclaw{}
|
||||
if err := c.Run("llama3.2", []string{"gateway", "--status"}); err == nil {
|
||||
t.Fatal("expected run failure")
|
||||
}
|
||||
integrationConfig, err := loadIntegration("openclaw")
|
||||
if err == nil && integrationConfig.Onboarded {
|
||||
t.Fatal("expected onboarding flag to remain unset after failed run")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestOpenclawEdit(t *testing.T) {
|
||||
c := &Openclaw{}
|
||||
tmpDir := t.TempDir()
|
||||
@@ -589,7 +517,7 @@ const testOpenclawFixture = `{
|
||||
"providers": {
|
||||
"anthropic": {"apiKey": "xxx"},
|
||||
"ollama": {
|
||||
"baseUrl": "http://127.0.0.1:11434/v1",
|
||||
"baseUrl": "http://127.0.0.1:11434",
|
||||
"models": [{"id": "old-model", "customField": "preserved"}]
|
||||
}
|
||||
}
|
||||
@@ -1448,7 +1376,7 @@ func TestOpenclawModelConfig(t *testing.T) {
|
||||
// report it as a remote/cloud model
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == "/api/show" {
|
||||
fmt.Fprintf(w, `{"capabilities":[],"model_info":{},"remote_model":"minimax-m2.5"}`)
|
||||
fmt.Fprintf(w, `{"capabilities":[],"model_info":{},"remote_model":"minimax-m2.7"}`)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
@@ -1458,7 +1386,7 @@ func TestOpenclawModelConfig(t *testing.T) {
|
||||
u, _ := url.Parse(srv.URL)
|
||||
client := api.NewClient(u, srv.Client())
|
||||
|
||||
cfg, isCloud := openclawModelConfig(context.Background(), client, "minimax-m2.5:cloud")
|
||||
cfg, isCloud := openclawModelConfig(context.Background(), client, "minimax-m2.7:cloud")
|
||||
|
||||
if !isCloud {
|
||||
t.Error("expected isCloud = true for cloud model")
|
||||
@@ -1528,7 +1456,7 @@ func TestIntegrationOnboarded(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
integrationConfig, err := loadIntegration("openclaw")
|
||||
integrationConfig, err := LoadIntegration("openclaw")
|
||||
if err == nil && integrationConfig.Onboarded {
|
||||
t.Error("expected false for fresh config")
|
||||
}
|
||||
@@ -1542,7 +1470,7 @@ func TestIntegrationOnboarded(t *testing.T) {
|
||||
if err := integrationOnboarded("openclaw"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
integrationConfig, err := loadIntegration("openclaw")
|
||||
integrationConfig, err := LoadIntegration("openclaw")
|
||||
if err != nil || !integrationConfig.Onboarded {
|
||||
t.Error("expected true after integrationOnboarded")
|
||||
}
|
||||
@@ -1556,7 +1484,7 @@ func TestIntegrationOnboarded(t *testing.T) {
|
||||
if err := integrationOnboarded("OpenClaw"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
integrationConfig, err := loadIntegration("openclaw")
|
||||
integrationConfig, err := LoadIntegration("openclaw")
|
||||
if err != nil || !integrationConfig.Onboarded {
|
||||
t.Error("expected true when set with different case")
|
||||
}
|
||||
@@ -1575,7 +1503,7 @@ func TestIntegrationOnboarded(t *testing.T) {
|
||||
}
|
||||
|
||||
// Verify onboarded is set
|
||||
integrationConfig, err := loadIntegration("openclaw")
|
||||
integrationConfig, err := LoadIntegration("openclaw")
|
||||
if err != nil || !integrationConfig.Onboarded {
|
||||
t.Error("expected true after integrationOnboarded")
|
||||
}
|
||||
@@ -1587,3 +1515,377 @@ func TestIntegrationOnboarded(t *testing.T) {
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestVersionLessThan(t *testing.T) {
|
||||
tests := []struct {
|
||||
a, b string
|
||||
want bool
|
||||
}{
|
||||
{"0.1.7", "0.2.1", true},
|
||||
{"0.2.0", "0.2.1", true},
|
||||
{"0.2.1", "0.2.1", false},
|
||||
{"0.2.2", "0.2.1", false},
|
||||
{"1.0.0", "0.2.1", false},
|
||||
{"0.2.1", "1.0.0", true},
|
||||
{"v0.1.7", "0.2.1", true},
|
||||
{"0.2.1", "v0.2.1", false},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.a+"_vs_"+tt.b, func(t *testing.T) {
|
||||
if got := versionLessThan(tt.a, tt.b); got != tt.want {
|
||||
t.Errorf("versionLessThan(%q, %q) = %v, want %v", tt.a, tt.b, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestWebSearchPluginUpToDate(t *testing.T) {
|
||||
t.Run("missing directory", func(t *testing.T) {
|
||||
if webSearchPluginUpToDate(filepath.Join(t.TempDir(), "nonexistent")) {
|
||||
t.Error("expected false for missing directory")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("missing package.json", func(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
if webSearchPluginUpToDate(dir) {
|
||||
t.Error("expected false for missing package.json")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("old version", func(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
if err := os.WriteFile(filepath.Join(dir, "package.json"), []byte(`{"version":"0.1.7"}`), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if webSearchPluginUpToDate(dir) {
|
||||
t.Error("expected false for old version 0.1.7")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("exact minimum version", func(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
if err := os.WriteFile(filepath.Join(dir, "package.json"), []byte(`{"version":"0.2.1"}`), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !webSearchPluginUpToDate(dir) {
|
||||
t.Error("expected true for exact minimum version 0.2.1")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("newer version", func(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
if err := os.WriteFile(filepath.Join(dir, "package.json"), []byte(`{"version":"1.0.0"}`), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !webSearchPluginUpToDate(dir) {
|
||||
t.Error("expected true for newer version 1.0.0")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("invalid json", func(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
if err := os.WriteFile(filepath.Join(dir, "package.json"), []byte(`not json`), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if webSearchPluginUpToDate(dir) {
|
||||
t.Error("expected false for invalid json")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("empty version", func(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
if err := os.WriteFile(filepath.Join(dir, "package.json"), []byte(`{"version":""}`), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if webSearchPluginUpToDate(dir) {
|
||||
t.Error("expected false for empty version")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestRegisterWebSearchPlugin(t *testing.T) {
|
||||
home := t.TempDir()
|
||||
setTestHome(t, home)
|
||||
|
||||
configDir := filepath.Join(home, ".openclaw")
|
||||
if err := os.MkdirAll(configDir, 0o755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
configPath := filepath.Join(configDir, "openclaw.json")
|
||||
|
||||
t.Run("fresh config", func(t *testing.T) {
|
||||
if err := os.WriteFile(configPath, []byte(`{}`), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
registerWebSearchPlugin()
|
||||
|
||||
data, err := os.ReadFile(configPath)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
var config map[string]any
|
||||
if err := json.Unmarshal(data, &config); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
plugins, _ := config["plugins"].(map[string]any)
|
||||
if plugins == nil {
|
||||
t.Fatal("plugins section missing")
|
||||
}
|
||||
|
||||
// Check entries
|
||||
entries, _ := plugins["entries"].(map[string]any)
|
||||
entry, _ := entries["openclaw-web-search"].(map[string]any)
|
||||
if enabled, _ := entry["enabled"].(bool); !enabled {
|
||||
t.Error("expected entries.openclaw-web-search.enabled = true")
|
||||
}
|
||||
|
||||
// Check allow list
|
||||
allow, _ := plugins["allow"].([]any)
|
||||
found := false
|
||||
for _, v := range allow {
|
||||
if s, ok := v.(string); ok && s == "openclaw-web-search" {
|
||||
found = true
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Error("expected plugins.allow to contain openclaw-web-search")
|
||||
}
|
||||
|
||||
// Check install provenance
|
||||
installs, _ := plugins["installs"].(map[string]any)
|
||||
record, _ := installs["openclaw-web-search"].(map[string]any)
|
||||
if record == nil {
|
||||
t.Fatal("expected plugins.installs.openclaw-web-search")
|
||||
}
|
||||
if source, _ := record["source"].(string); source != "npm" {
|
||||
t.Errorf("install source = %q, want %q", source, "npm")
|
||||
}
|
||||
if spec, _ := record["spec"].(string); spec != webSearchNpmPackage {
|
||||
t.Errorf("install spec = %q, want %q", spec, webSearchNpmPackage)
|
||||
}
|
||||
expectedPath := filepath.Join(home, ".openclaw", "extensions", "openclaw-web-search")
|
||||
if installPath, _ := record["installPath"].(string); installPath != expectedPath {
|
||||
t.Errorf("installPath = %q, want %q", installPath, expectedPath)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("idempotent", func(t *testing.T) {
|
||||
if err := os.WriteFile(configPath, []byte(`{}`), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
registerWebSearchPlugin()
|
||||
registerWebSearchPlugin()
|
||||
|
||||
data, err := os.ReadFile(configPath)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
var config map[string]any
|
||||
if err := json.Unmarshal(data, &config); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
plugins, _ := config["plugins"].(map[string]any)
|
||||
allow, _ := plugins["allow"].([]any)
|
||||
count := 0
|
||||
for _, v := range allow {
|
||||
if s, ok := v.(string); ok && s == "openclaw-web-search" {
|
||||
count++
|
||||
}
|
||||
}
|
||||
if count != 1 {
|
||||
t.Errorf("expected exactly 1 openclaw-web-search in allow, got %d", count)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("preserves existing config", func(t *testing.T) {
|
||||
initial := map[string]any{
|
||||
"plugins": map[string]any{
|
||||
"allow": []any{"some-other-plugin"},
|
||||
"entries": map[string]any{
|
||||
"some-other-plugin": map[string]any{"enabled": true},
|
||||
},
|
||||
"installs": map[string]any{
|
||||
"some-other-plugin": map[string]any{
|
||||
"source": "npm",
|
||||
"installPath": "/some/path",
|
||||
},
|
||||
},
|
||||
},
|
||||
"customField": "preserved",
|
||||
}
|
||||
data, _ := json.Marshal(initial)
|
||||
if err := os.WriteFile(configPath, data, 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
registerWebSearchPlugin()
|
||||
|
||||
out, err := os.ReadFile(configPath)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
var config map[string]any
|
||||
if err := json.Unmarshal(out, &config); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if config["customField"] != "preserved" {
|
||||
t.Error("customField was not preserved")
|
||||
}
|
||||
|
||||
plugins, _ := config["plugins"].(map[string]any)
|
||||
entries, _ := plugins["entries"].(map[string]any)
|
||||
if entries["some-other-plugin"] == nil {
|
||||
t.Error("existing plugin entry was lost")
|
||||
}
|
||||
|
||||
installs, _ := plugins["installs"].(map[string]any)
|
||||
if installs["some-other-plugin"] == nil {
|
||||
t.Error("existing install record was lost")
|
||||
}
|
||||
|
||||
allow, _ := plugins["allow"].([]any)
|
||||
hasOther, hasWebSearch := false, false
|
||||
for _, v := range allow {
|
||||
s, _ := v.(string)
|
||||
if s == "some-other-plugin" {
|
||||
hasOther = true
|
||||
}
|
||||
if s == "openclaw-web-search" {
|
||||
hasWebSearch = true
|
||||
}
|
||||
}
|
||||
if !hasOther {
|
||||
t.Error("existing allow entry was lost")
|
||||
}
|
||||
if !hasWebSearch {
|
||||
t.Error("openclaw-web-search not added to allow")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestClearSessionModelOverride(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
sessionsDir := filepath.Join(tmpDir, ".openclaw", "agents", "main", "sessions")
|
||||
sessionsPath := filepath.Join(sessionsDir, "sessions.json")
|
||||
|
||||
writeSessionsFile := func(t *testing.T, sessions map[string]map[string]any) {
|
||||
t.Helper()
|
||||
if err := os.MkdirAll(sessionsDir, 0o755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
data, err := json.Marshal(sessions)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := os.WriteFile(sessionsPath, data, 0o600); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
readSessionsFile := func(t *testing.T) map[string]map[string]any {
|
||||
t.Helper()
|
||||
data, err := os.ReadFile(sessionsPath)
|
||||
if err != nil {
|
||||
t.Fatalf("reading sessions file: %v", err)
|
||||
}
|
||||
var sessions map[string]map[string]any
|
||||
if err := json.Unmarshal(data, &sessions); err != nil {
|
||||
t.Fatalf("parsing sessions file: %v", err)
|
||||
}
|
||||
return sessions
|
||||
}
|
||||
|
||||
t.Run("clears modelOverride and updates model", func(t *testing.T) {
|
||||
writeSessionsFile(t, map[string]map[string]any{
|
||||
"sess1": {"model": "ollama/old-model", "modelOverride": "old-model", "providerOverride": "ollama"},
|
||||
})
|
||||
clearSessionModelOverride("new-model")
|
||||
sessions := readSessionsFile(t)
|
||||
sess := sessions["sess1"]
|
||||
if _, ok := sess["modelOverride"]; ok {
|
||||
t.Error("modelOverride should have been deleted")
|
||||
}
|
||||
if _, ok := sess["providerOverride"]; ok {
|
||||
t.Error("providerOverride should have been deleted")
|
||||
}
|
||||
if sess["model"] != "new-model" {
|
||||
t.Errorf("model = %q, want %q", sess["model"], "new-model")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("updates model field in sessions without modelOverride", func(t *testing.T) {
|
||||
// This is the bug case: session has model pointing to old primary,
|
||||
// but no explicit modelOverride. After changing primary, the session
|
||||
// model field must also be updated.
|
||||
writeSessionsFile(t, map[string]map[string]any{
|
||||
"sess1": {"model": "ollama/old-model"},
|
||||
})
|
||||
clearSessionModelOverride("new-model")
|
||||
sessions := readSessionsFile(t)
|
||||
if sessions["sess1"]["model"] != "new-model" {
|
||||
t.Errorf("model = %q, want %q", sessions["sess1"]["model"], "new-model")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("does not update session already using primary", func(t *testing.T) {
|
||||
writeSessionsFile(t, map[string]map[string]any{
|
||||
"sess1": {"model": "current-model"},
|
||||
})
|
||||
clearSessionModelOverride("current-model")
|
||||
sessions := readSessionsFile(t)
|
||||
if sessions["sess1"]["model"] != "current-model" {
|
||||
t.Errorf("model = %q, want %q", sessions["sess1"]["model"], "current-model")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("does not update session with empty model field", func(t *testing.T) {
|
||||
writeSessionsFile(t, map[string]map[string]any{
|
||||
"sess1": {"other": "data"},
|
||||
})
|
||||
clearSessionModelOverride("new-model")
|
||||
sessions := readSessionsFile(t)
|
||||
if _, ok := sessions["sess1"]["model"]; ok {
|
||||
t.Error("model field should not have been added to session with no model")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("handles multiple sessions mixed", func(t *testing.T) {
|
||||
writeSessionsFile(t, map[string]map[string]any{
|
||||
"with-override": {"model": "old", "modelOverride": "old", "providerOverride": "ollama"},
|
||||
"without-override": {"model": "old"},
|
||||
"already-current": {"model": "new-model"},
|
||||
"no-model": {"other": "data"},
|
||||
})
|
||||
clearSessionModelOverride("new-model")
|
||||
sessions := readSessionsFile(t)
|
||||
|
||||
if sessions["with-override"]["model"] != "new-model" {
|
||||
t.Errorf("with-override model = %q, want %q", sessions["with-override"]["model"], "new-model")
|
||||
}
|
||||
if _, ok := sessions["with-override"]["modelOverride"]; ok {
|
||||
t.Error("with-override: modelOverride should be deleted")
|
||||
}
|
||||
if sessions["without-override"]["model"] != "new-model" {
|
||||
t.Errorf("without-override model = %q, want %q", sessions["without-override"]["model"], "new-model")
|
||||
}
|
||||
if sessions["already-current"]["model"] != "new-model" {
|
||||
t.Errorf("already-current model = %q, want %q", sessions["already-current"]["model"], "new-model")
|
||||
}
|
||||
if _, ok := sessions["no-model"]["model"]; ok {
|
||||
t.Error("no-model: model should not have been added")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("no-op when sessions file missing", func(t *testing.T) {
|
||||
os.RemoveAll(sessionsDir)
|
||||
clearSessionModelOverride("new-model") // should not panic or error
|
||||
})
|
||||
}
|
||||
@@ -1,9 +1,7 @@
|
||||
package config
|
||||
package launch
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"maps"
|
||||
"os"
|
||||
@@ -12,34 +10,13 @@ import (
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/cmd/internal/fileutil"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
)
|
||||
|
||||
// OpenCode implements Runner and Editor for OpenCode integration
|
||||
type OpenCode struct{}
|
||||
|
||||
// cloudModelLimit holds context and output token limits for a cloud model.
|
||||
type cloudModelLimit struct {
|
||||
Context int
|
||||
Output int
|
||||
}
|
||||
|
||||
// lookupCloudModelLimit returns the token limits for a cloud model.
|
||||
// It tries the exact name first, then strips the ":cloud" suffix.
|
||||
func lookupCloudModelLimit(name string) (cloudModelLimit, bool) {
|
||||
if l, ok := cloudModelLimits[name]; ok {
|
||||
return l, true
|
||||
}
|
||||
base := strings.TrimSuffix(name, ":cloud")
|
||||
if base != name {
|
||||
if l, ok := cloudModelLimits[base]; ok {
|
||||
return l, true
|
||||
}
|
||||
}
|
||||
return cloudModelLimit{}, false
|
||||
}
|
||||
|
||||
func (o *OpenCode) String() string { return "OpenCode" }
|
||||
|
||||
func (o *OpenCode) Run(model string, args []string) error {
|
||||
@@ -47,25 +24,6 @@ func (o *OpenCode) Run(model string, args []string) error {
|
||||
return fmt.Errorf("opencode is not installed, install from https://opencode.ai")
|
||||
}
|
||||
|
||||
// Call Edit() to ensure config is up-to-date before launch
|
||||
models := []string{model}
|
||||
if config, err := loadIntegration("opencode"); err == nil && len(config.Models) > 0 {
|
||||
models = config.Models
|
||||
}
|
||||
var err error
|
||||
models, err = resolveEditorModels("opencode", models, func() ([]string, error) {
|
||||
return selectModels(context.Background(), "opencode", "")
|
||||
})
|
||||
if errors.Is(err, errCancelled) {
|
||||
return nil
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := o.Edit(models); err != nil {
|
||||
return fmt.Errorf("setup failed: %w", err)
|
||||
}
|
||||
|
||||
cmd := exec.Command("opencode", args...)
|
||||
cmd.Stdin = os.Stdin
|
||||
cmd.Stdout = os.Stdout
|
||||
@@ -122,13 +80,18 @@ func (o *OpenCode) Edit(modelList []string) error {
|
||||
if !ok {
|
||||
ollama = map[string]any{
|
||||
"npm": "@ai-sdk/openai-compatible",
|
||||
"name": "Ollama (local)",
|
||||
"name": "Ollama",
|
||||
"options": map[string]any{
|
||||
"baseURL": envconfig.Host().String() + "/v1",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Migrate legacy provider name
|
||||
if name, _ := ollama["name"].(string); name == "Ollama (local)" {
|
||||
ollama["name"] = "Ollama"
|
||||
}
|
||||
|
||||
models, ok := ollama["models"].(map[string]any)
|
||||
if !ok {
|
||||
models = make(map[string]any)
|
||||
@@ -147,8 +110,6 @@ func (o *OpenCode) Edit(modelList []string) error {
|
||||
}
|
||||
}
|
||||
|
||||
client, _ := api.ClientFromEnvironment()
|
||||
|
||||
for _, model := range modelList {
|
||||
if existing, ok := models[model].(map[string]any); ok {
|
||||
// migrate existing models without _launch marker
|
||||
@@ -158,7 +119,7 @@ func (o *OpenCode) Edit(modelList []string) error {
|
||||
existing["name"] = strings.TrimSuffix(name, " [Ollama]")
|
||||
}
|
||||
}
|
||||
if isCloudModel(context.Background(), client, model) {
|
||||
if isCloudModelName(model) {
|
||||
if l, ok := lookupCloudModelLimit(model); ok {
|
||||
existing["limit"] = map[string]any{
|
||||
"context": l.Context,
|
||||
@@ -172,7 +133,7 @@ func (o *OpenCode) Edit(modelList []string) error {
|
||||
"name": model,
|
||||
"_launch": true,
|
||||
}
|
||||
if isCloudModel(context.Background(), client, model) {
|
||||
if isCloudModelName(model) {
|
||||
if l, ok := lookupCloudModelLimit(model); ok {
|
||||
entry["limit"] = map[string]any{
|
||||
"context": l.Context,
|
||||
@@ -186,12 +147,13 @@ func (o *OpenCode) Edit(modelList []string) error {
|
||||
ollama["models"] = models
|
||||
provider["ollama"] = ollama
|
||||
config["provider"] = provider
|
||||
config["model"] = "ollama/" + modelList[0]
|
||||
|
||||
configData, err := json.MarshalIndent(config, "", " ")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := writeWithBackup(configPath, configData); err != nil {
|
||||
if err := fileutil.WriteWithBackup(configPath, configData); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -243,7 +205,7 @@ func (o *OpenCode) Edit(modelList []string) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return writeWithBackup(statePath, stateData)
|
||||
return fileutil.WriteWithBackup(statePath, stateData)
|
||||
}
|
||||
|
||||
func (o *OpenCode) Models() []string {
|
||||
@@ -251,7 +213,7 @@ func (o *OpenCode) Models() []string {
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
config, err := readJSONFile(filepath.Join(home, ".config", "opencode", "opencode.json"))
|
||||
config, err := fileutil.ReadJSON(filepath.Join(home, ".config", "opencode", "opencode.json"))
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
@@ -1,8 +1,10 @@
|
||||
package config
|
||||
package launch
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
@@ -47,6 +49,7 @@ func TestOpenCodeEdit(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assertOpenCodeModelExists(t, configPath, "llama3.2")
|
||||
assertOpenCodeDefaultModel(t, configPath, "ollama/llama3.2")
|
||||
assertOpenCodeRecentModel(t, statePath, 0, "ollama", "llama3.2")
|
||||
})
|
||||
|
||||
@@ -155,11 +158,13 @@ func TestOpenCodeEdit(t *testing.T) {
|
||||
o.Edit([]string{"llama3.2", "mistral"})
|
||||
assertOpenCodeModelExists(t, configPath, "llama3.2")
|
||||
assertOpenCodeModelExists(t, configPath, "mistral")
|
||||
assertOpenCodeDefaultModel(t, configPath, "ollama/llama3.2")
|
||||
|
||||
// Then remove one by only selecting the other
|
||||
o.Edit([]string{"llama3.2"})
|
||||
assertOpenCodeModelExists(t, configPath, "llama3.2")
|
||||
assertOpenCodeModelNotExists(t, configPath, "mistral")
|
||||
assertOpenCodeDefaultModel(t, configPath, "ollama/llama3.2")
|
||||
})
|
||||
|
||||
t.Run("preserve user customizations on managed models", func(t *testing.T) {
|
||||
@@ -232,6 +237,44 @@ func TestOpenCodeEdit(t *testing.T) {
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("migrate Ollama (local) provider name", func(t *testing.T) {
|
||||
cleanup()
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(configPath, []byte(`{"provider":{"ollama":{"name":"Ollama (local)","npm":"@ai-sdk/openai-compatible","options":{"baseURL":"http://localhost:11434/v1"}}}}`), 0o644)
|
||||
|
||||
if err := o.Edit([]string{"llama3.2"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
data, _ := os.ReadFile(configPath)
|
||||
var cfg map[string]any
|
||||
json.Unmarshal(data, &cfg)
|
||||
provider := cfg["provider"].(map[string]any)
|
||||
ollama := provider["ollama"].(map[string]any)
|
||||
if ollama["name"] != "Ollama" {
|
||||
t.Errorf("provider name not migrated: got %q, want %q", ollama["name"], "Ollama")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("preserve custom provider name", func(t *testing.T) {
|
||||
cleanup()
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(configPath, []byte(`{"provider":{"ollama":{"name":"My Custom Ollama","npm":"@ai-sdk/openai-compatible","options":{"baseURL":"http://localhost:11434/v1"}}}}`), 0o644)
|
||||
|
||||
if err := o.Edit([]string{"llama3.2"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
data, _ := os.ReadFile(configPath)
|
||||
var cfg map[string]any
|
||||
json.Unmarshal(data, &cfg)
|
||||
provider := cfg["provider"].(map[string]any)
|
||||
ollama := provider["ollama"].(map[string]any)
|
||||
if ollama["name"] != "My Custom Ollama" {
|
||||
t.Errorf("custom provider name was changed: got %q, want %q", ollama["name"], "My Custom Ollama")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("remove model preserves non-ollama models", func(t *testing.T) {
|
||||
cleanup()
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
@@ -298,6 +341,22 @@ func assertOpenCodeModelNotExists(t *testing.T, path, model string) {
|
||||
}
|
||||
}
|
||||
|
||||
func assertOpenCodeDefaultModel(t *testing.T, path, want string) {
|
||||
t.Helper()
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
var cfg map[string]any
|
||||
if err := json.Unmarshal(data, &cfg); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
got, _ := cfg["model"].(string)
|
||||
if got != want {
|
||||
t.Fatalf("default model = %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func assertOpenCodeRecentModel(t *testing.T, path string, index int, providerID, modelID string) {
|
||||
t.Helper()
|
||||
data, err := os.ReadFile(path)
|
||||
@@ -619,6 +678,54 @@ func TestOpenCodeEdit_CloudModelLimitStructure(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenCodeEdit_BackfillsCloudModelLimitOnExistingEntry(t *testing.T) {
|
||||
o := &OpenCode{}
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == "/api/show" {
|
||||
fmt.Fprintf(w, `{"capabilities":[],"model_info":{},"remote_model":"glm-5"}`)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}))
|
||||
defer srv.Close()
|
||||
t.Setenv("OLLAMA_HOST", srv.URL)
|
||||
|
||||
configDir := filepath.Join(tmpDir, ".config", "opencode")
|
||||
configPath := filepath.Join(configDir, "opencode.json")
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(configPath, []byte(`{
|
||||
"provider": {
|
||||
"ollama": {
|
||||
"models": {
|
||||
"glm-5:cloud": {
|
||||
"name": "glm-5:cloud",
|
||||
"_launch": true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}`), 0o644)
|
||||
|
||||
if err := o.Edit([]string{"glm-5:cloud"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
entry := readOpenCodeModel(t, configPath, "glm-5:cloud")
|
||||
limit, ok := entry["limit"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatal("cloud model limit was not added on re-edit")
|
||||
}
|
||||
if limit["context"] != float64(202_752) {
|
||||
t.Errorf("context = %v, want 202752", limit["context"])
|
||||
}
|
||||
if limit["output"] != float64(131_072) {
|
||||
t.Errorf("output = %v, want 131072", limit["output"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestLookupCloudModelLimit(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -626,13 +733,19 @@ func TestLookupCloudModelLimit(t *testing.T) {
|
||||
wantContext int
|
||||
wantOutput int
|
||||
}{
|
||||
{"glm-4.7", true, 202_752, 131_072},
|
||||
{"glm-4.7", false, 0, 0},
|
||||
{"glm-4.7:cloud", true, 202_752, 131_072},
|
||||
{"kimi-k2.5", true, 262_144, 262_144},
|
||||
{"glm-5:cloud", true, 202_752, 131_072},
|
||||
{"gpt-oss:120b-cloud", true, 131_072, 131_072},
|
||||
{"gpt-oss:20b-cloud", true, 131_072, 131_072},
|
||||
{"kimi-k2.5", false, 0, 0},
|
||||
{"kimi-k2.5:cloud", true, 262_144, 262_144},
|
||||
{"deepseek-v3.2", true, 163_840, 65_536},
|
||||
{"deepseek-v3.2", false, 0, 0},
|
||||
{"deepseek-v3.2:cloud", true, 163_840, 65_536},
|
||||
{"qwen3-coder:480b", true, 262_144, 65_536},
|
||||
{"qwen3.5", false, 0, 0},
|
||||
{"qwen3.5:cloud", true, 262_144, 32_768},
|
||||
{"qwen3-coder:480b", false, 0, 0},
|
||||
{"qwen3-coder:480b:cloud", true, 262_144, 65_536},
|
||||
{"qwen3-coder-next:cloud", true, 262_144, 32_768},
|
||||
{"llama3.2", false, 0, 0},
|
||||
{"unknown-model:cloud", false, 0, 0},
|
||||
@@ -1,4 +1,4 @@
|
||||
package config
|
||||
package launch
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/cmd/internal/fileutil"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
)
|
||||
@@ -19,29 +20,151 @@ import (
|
||||
// Pi implements Runner and Editor for Pi (Pi Coding Agent) integration
|
||||
type Pi struct{}
|
||||
|
||||
const (
|
||||
piNpmPackage = "@mariozechner/pi-coding-agent"
|
||||
piWebSearchSource = "npm:@ollama/pi-web-search"
|
||||
piWebSearchPkg = "@ollama/pi-web-search"
|
||||
)
|
||||
|
||||
func (p *Pi) String() string { return "Pi" }
|
||||
|
||||
func (p *Pi) Run(model string, args []string) error {
|
||||
if _, err := exec.LookPath("pi"); err != nil {
|
||||
return fmt.Errorf("pi is not installed, install with: npm install -g @mariozechner/pi-coding-agent")
|
||||
fmt.Fprintf(os.Stderr, "\n%sPreparing Pi...%s\n", ansiGray, ansiReset)
|
||||
if err := ensureNpmInstalled(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Call Edit() to ensure config is up-to-date before launch
|
||||
models := []string{model}
|
||||
if config, err := loadIntegration("pi"); err == nil && len(config.Models) > 0 {
|
||||
models = config.Models
|
||||
}
|
||||
if err := p.Edit(models); err != nil {
|
||||
return fmt.Errorf("setup failed: %w", err)
|
||||
fmt.Fprintf(os.Stderr, "%sChecking Pi installation...%s\n", ansiGray, ansiReset)
|
||||
bin, err := ensurePiInstalled()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cmd := exec.Command("pi", args...)
|
||||
ensurePiWebSearchPackage(bin)
|
||||
|
||||
fmt.Fprintf(os.Stderr, "\n%sLaunching Pi...%s\n\n", ansiGray, ansiReset)
|
||||
|
||||
cmd := exec.Command(bin, args...)
|
||||
cmd.Stdin = os.Stdin
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
return cmd.Run()
|
||||
}
|
||||
|
||||
func ensureNpmInstalled() error {
|
||||
if _, err := exec.LookPath("npm"); err != nil {
|
||||
return fmt.Errorf("npm (Node.js) is required to launch pi\n\nInstall it first:\n https://nodejs.org/")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func ensurePiInstalled() (string, error) {
|
||||
if _, err := exec.LookPath("pi"); err == nil {
|
||||
return "pi", nil
|
||||
}
|
||||
|
||||
if _, err := exec.LookPath("npm"); err != nil {
|
||||
return "", fmt.Errorf("pi is not installed and required dependencies are missing\n\nInstall the following first:\n npm (Node.js): https://nodejs.org/")
|
||||
}
|
||||
|
||||
ok, err := ConfirmPrompt("Pi is not installed. Install with npm?")
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if !ok {
|
||||
return "", fmt.Errorf("pi installation cancelled")
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, "\nInstalling Pi...\n")
|
||||
cmd := exec.Command("npm", "install", "-g", piNpmPackage+"@latest")
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
if err := cmd.Run(); err != nil {
|
||||
return "", fmt.Errorf("failed to install pi: %w", err)
|
||||
}
|
||||
|
||||
if _, err := exec.LookPath("pi"); err != nil {
|
||||
return "", fmt.Errorf("pi was installed but the binary was not found on PATH\n\nYou may need to restart your shell")
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, "%sPi installed successfully%s\n\n", ansiGreen, ansiReset)
|
||||
return "pi", nil
|
||||
}
|
||||
|
||||
func ensurePiWebSearchPackage(bin string) {
|
||||
if !shouldManagePiWebSearch() {
|
||||
fmt.Fprintf(os.Stderr, "%sCloud is disabled; skipping %s setup.%s\n", ansiGray, piWebSearchPkg, ansiReset)
|
||||
return
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, "%sChecking Pi web search package...%s\n", ansiGray, ansiReset)
|
||||
|
||||
installed, err := piPackageInstalled(bin, piWebSearchSource)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "%s Warning: could not check %s installation: %v%s\n", ansiYellow, piWebSearchPkg, err, ansiReset)
|
||||
return
|
||||
}
|
||||
|
||||
if !installed {
|
||||
fmt.Fprintf(os.Stderr, "%sInstalling %s...%s\n", ansiGray, piWebSearchPkg, ansiReset)
|
||||
cmd := exec.Command(bin, "install", piWebSearchSource)
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
if err := cmd.Run(); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "%s Warning: could not install %s: %v%s\n", ansiYellow, piWebSearchPkg, err, ansiReset)
|
||||
return
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, "%s ✓ Installed %s%s\n", ansiGreen, piWebSearchPkg, ansiReset)
|
||||
return
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, "%sUpdating %s...%s\n", ansiGray, piWebSearchPkg, ansiReset)
|
||||
cmd := exec.Command(bin, "update", piWebSearchSource)
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
if err := cmd.Run(); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "%s Warning: could not update %s: %v%s\n", ansiYellow, piWebSearchPkg, err, ansiReset)
|
||||
return
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, "%s ✓ Updated %s%s\n", ansiGreen, piWebSearchPkg, ansiReset)
|
||||
}
|
||||
|
||||
func shouldManagePiWebSearch() bool {
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
return true
|
||||
}
|
||||
|
||||
disabled, known := cloudStatusDisabled(context.Background(), client)
|
||||
if known && disabled {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func piPackageInstalled(bin, source string) (bool, error) {
|
||||
cmd := exec.Command(bin, "list")
|
||||
out, err := cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
msg := strings.TrimSpace(string(out))
|
||||
if msg == "" {
|
||||
return false, err
|
||||
}
|
||||
return false, fmt.Errorf("%w: %s", err, msg)
|
||||
}
|
||||
|
||||
for _, line := range strings.Split(string(out), "\n") {
|
||||
trimmed := strings.TrimSpace(line)
|
||||
if strings.HasPrefix(trimmed, source) {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (p *Pi) Paths() []string {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
@@ -107,7 +230,8 @@ func (p *Pi) Edit(models []string) error {
|
||||
|
||||
// Build new models list:
|
||||
// 1. Keep user-managed models (no _launch marker) - untouched
|
||||
// 2. Keep ollama-managed models (_launch marker) that are still selected
|
||||
// 2. Keep ollama-managed models (_launch marker) that are still selected,
|
||||
// except stale cloud entries that should be rebuilt below
|
||||
// 3. Add new ollama-managed models
|
||||
var newModels []any
|
||||
for _, m := range existingModels {
|
||||
@@ -117,7 +241,13 @@ func (p *Pi) Edit(models []string) error {
|
||||
if !isPiOllamaModel(modelObj) {
|
||||
newModels = append(newModels, m)
|
||||
} else if selectedSet[id] {
|
||||
// Ollama-managed and still selected - keep it
|
||||
// Rebuild stale managed cloud entries so createConfig refreshes
|
||||
// the whole entry instead of patching it in place.
|
||||
if !hasContextWindow(modelObj) {
|
||||
if _, ok := lookupCloudModelLimit(id); ok {
|
||||
continue
|
||||
}
|
||||
}
|
||||
newModels = append(newModels, m)
|
||||
selectedSet[id] = false
|
||||
}
|
||||
@@ -142,7 +272,7 @@ func (p *Pi) Edit(models []string) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := writeWithBackup(configPath, configData); err != nil {
|
||||
if err := fileutil.WriteWithBackup(configPath, configData); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -160,7 +290,7 @@ func (p *Pi) Edit(models []string) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return writeWithBackup(settingsPath, settingsData)
|
||||
return fileutil.WriteWithBackup(settingsPath, settingsData)
|
||||
}
|
||||
|
||||
func (p *Pi) Models() []string {
|
||||
@@ -170,7 +300,7 @@ func (p *Pi) Models() []string {
|
||||
}
|
||||
|
||||
configPath := filepath.Join(home, ".pi", "agent", "models.json")
|
||||
config, err := readJSONFile(configPath)
|
||||
config, err := fileutil.ReadJSON(configPath)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
@@ -199,15 +329,38 @@ func isPiOllamaModel(cfg map[string]any) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func hasContextWindow(cfg map[string]any) bool {
|
||||
switch v := cfg["contextWindow"].(type) {
|
||||
case float64:
|
||||
return v > 0
|
||||
case int:
|
||||
return v > 0
|
||||
case int64:
|
||||
return v > 0
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// createConfig builds Pi model config with capability detection
|
||||
func createConfig(ctx context.Context, client *api.Client, modelID string) map[string]any {
|
||||
cfg := map[string]any{
|
||||
"id": modelID,
|
||||
"_launch": true,
|
||||
}
|
||||
if l, ok := lookupCloudModelLimit(modelID); ok {
|
||||
cfg["contextWindow"] = l.Context
|
||||
}
|
||||
|
||||
applyCloudContextFallback := func() {
|
||||
if l, ok := lookupCloudModelLimit(modelID); ok {
|
||||
cfg["contextWindow"] = l.Context
|
||||
}
|
||||
}
|
||||
|
||||
resp, err := client.Show(ctx, &api.ShowRequest{Model: modelID})
|
||||
if err != nil {
|
||||
applyCloudContextFallback()
|
||||
return cfg
|
||||
}
|
||||
|
||||
@@ -223,15 +376,21 @@ func createConfig(ctx context.Context, client *api.Client, modelID string) map[s
|
||||
cfg["reasoning"] = true
|
||||
}
|
||||
|
||||
// Extract context window from ModelInfo
|
||||
// Extract context window from ModelInfo. For known cloud models, the
|
||||
// pre-filled shared limit remains unless the server provides a positive value.
|
||||
hasContextWindow := false
|
||||
for key, val := range resp.ModelInfo {
|
||||
if strings.HasSuffix(key, ".context_length") {
|
||||
if ctxLen, ok := val.(float64); ok && ctxLen > 0 {
|
||||
cfg["contextWindow"] = int(ctxLen)
|
||||
hasContextWindow = true
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
if !hasContextWindow {
|
||||
applyCloudContextFallback()
|
||||
}
|
||||
|
||||
return cfg
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package config
|
||||
package launch
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -9,6 +9,8 @@ import (
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
@@ -33,6 +35,339 @@ func TestPiIntegration(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestPiRun_InstallAndWebSearchLifecycle(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("uses POSIX shell test binaries")
|
||||
}
|
||||
|
||||
writeScript := func(t *testing.T, path, content string) {
|
||||
t.Helper()
|
||||
if err := os.WriteFile(path, []byte(content), 0o755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
seedPiScript := func(t *testing.T, dir string) {
|
||||
t.Helper()
|
||||
piPath := filepath.Join(dir, "pi")
|
||||
listPath := filepath.Join(dir, "pi-list.txt")
|
||||
piScript := fmt.Sprintf(`#!/bin/sh
|
||||
echo "$@" >> %q
|
||||
if [ "$1" = "list" ]; then
|
||||
if [ -f %q ]; then
|
||||
/bin/cat %q
|
||||
fi
|
||||
exit 0
|
||||
fi
|
||||
if [ "$1" = "update" ] && [ "$PI_FAIL_UPDATE" = "1" ]; then
|
||||
echo "update failed" >&2
|
||||
exit 1
|
||||
fi
|
||||
if [ "$1" = "install" ] && [ "$PI_FAIL_INSTALL" = "1" ]; then
|
||||
echo "install failed" >&2
|
||||
exit 1
|
||||
fi
|
||||
exit 0
|
||||
`, filepath.Join(dir, "pi.log"), listPath, listPath)
|
||||
writeScript(t, piPath, piScript)
|
||||
}
|
||||
|
||||
seedNpmNoop := func(t *testing.T, dir string) {
|
||||
t.Helper()
|
||||
writeScript(t, filepath.Join(dir, "npm"), "#!/bin/sh\nexit 0\n")
|
||||
}
|
||||
|
||||
withConfirm := func(t *testing.T, fn func(prompt string) (bool, error)) {
|
||||
t.Helper()
|
||||
oldConfirm := DefaultConfirmPrompt
|
||||
DefaultConfirmPrompt = fn
|
||||
t.Cleanup(func() { DefaultConfirmPrompt = oldConfirm })
|
||||
}
|
||||
|
||||
setCloudStatus := func(t *testing.T, disabled bool) {
|
||||
t.Helper()
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == "/api/status" {
|
||||
fmt.Fprintf(w, `{"cloud":{"disabled":%t,"source":"config"}}`, disabled)
|
||||
return
|
||||
}
|
||||
http.NotFound(w, r)
|
||||
}))
|
||||
t.Cleanup(srv.Close)
|
||||
t.Setenv("OLLAMA_HOST", srv.URL)
|
||||
}
|
||||
|
||||
t.Run("pi missing + user accepts install", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
t.Setenv("PATH", tmpDir)
|
||||
setCloudStatus(t, false)
|
||||
|
||||
if err := os.WriteFile(filepath.Join(tmpDir, "pi-list.txt"), []byte("User packages:\n npm:@ollama/pi-web-search\n"), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
npmScript := fmt.Sprintf(`#!/bin/sh
|
||||
echo "$@" >> %q
|
||||
if [ "$1" = "install" ] && [ "$2" = "-g" ] && [ "$3" = %q ]; then
|
||||
/bin/cat > %q <<'EOS'
|
||||
#!/bin/sh
|
||||
echo "$@" >> %q
|
||||
if [ "$1" = "list" ]; then
|
||||
if [ -f %q ]; then
|
||||
/bin/cat %q
|
||||
fi
|
||||
exit 0
|
||||
fi
|
||||
exit 0
|
||||
EOS
|
||||
/bin/chmod +x %q
|
||||
fi
|
||||
exit 0
|
||||
`, filepath.Join(tmpDir, "npm.log"), piNpmPackage+"@latest", filepath.Join(tmpDir, "pi"), filepath.Join(tmpDir, "pi.log"), filepath.Join(tmpDir, "pi-list.txt"), filepath.Join(tmpDir, "pi-list.txt"), filepath.Join(tmpDir, "pi"))
|
||||
writeScript(t, filepath.Join(tmpDir, "npm"), npmScript)
|
||||
|
||||
withConfirm(t, func(prompt string) (bool, error) {
|
||||
if strings.Contains(prompt, "Pi is not installed.") {
|
||||
return true, nil
|
||||
}
|
||||
return true, nil
|
||||
})
|
||||
|
||||
p := &Pi{}
|
||||
if err := p.Run("ignored", []string{"--version"}); err != nil {
|
||||
t.Fatalf("Run() error = %v", err)
|
||||
}
|
||||
|
||||
npmCalls, err := os.ReadFile(filepath.Join(tmpDir, "npm.log"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !strings.Contains(string(npmCalls), "install -g "+piNpmPackage+"@latest") {
|
||||
t.Fatalf("expected npm install call, got:\n%s", npmCalls)
|
||||
}
|
||||
|
||||
piCalls, err := os.ReadFile(filepath.Join(tmpDir, "pi.log"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
got := string(piCalls)
|
||||
if !strings.Contains(got, "list\n") {
|
||||
t.Fatalf("expected pi list call, got:\n%s", got)
|
||||
}
|
||||
if !strings.Contains(got, "update "+piWebSearchSource+"\n") {
|
||||
t.Fatalf("expected pi update call, got:\n%s", got)
|
||||
}
|
||||
if !strings.Contains(got, "--version\n") {
|
||||
t.Fatalf("expected final pi launch call, got:\n%s", got)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("pi missing + user declines install", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
t.Setenv("PATH", tmpDir)
|
||||
setCloudStatus(t, false)
|
||||
writeScript(t, filepath.Join(tmpDir, "npm"), "#!/bin/sh\nexit 0\n")
|
||||
|
||||
withConfirm(t, func(prompt string) (bool, error) {
|
||||
if strings.Contains(prompt, "Pi is not installed.") {
|
||||
return false, nil
|
||||
}
|
||||
return true, nil
|
||||
})
|
||||
|
||||
p := &Pi{}
|
||||
err := p.Run("ignored", nil)
|
||||
if err == nil || !strings.Contains(err.Error(), "pi installation cancelled") {
|
||||
t.Fatalf("expected install cancellation error, got %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("pi installed + web search missing auto-installs", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
t.Setenv("PATH", tmpDir)
|
||||
setCloudStatus(t, false)
|
||||
if err := os.WriteFile(filepath.Join(tmpDir, "pi-list.txt"), []byte("User packages:\n"), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
seedPiScript(t, tmpDir)
|
||||
seedNpmNoop(t, tmpDir)
|
||||
withConfirm(t, func(prompt string) (bool, error) {
|
||||
t.Fatalf("did not expect confirmation prompt, got %q", prompt)
|
||||
return false, nil
|
||||
})
|
||||
|
||||
p := &Pi{}
|
||||
if err := p.Run("ignored", []string{"session"}); err != nil {
|
||||
t.Fatalf("Run() error = %v", err)
|
||||
}
|
||||
|
||||
piCalls, err := os.ReadFile(filepath.Join(tmpDir, "pi.log"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
got := string(piCalls)
|
||||
if !strings.Contains(got, "list\n") {
|
||||
t.Fatalf("expected pi list call, got:\n%s", got)
|
||||
}
|
||||
if !strings.Contains(got, "install "+piWebSearchSource+"\n") {
|
||||
t.Fatalf("expected pi install call, got:\n%s", got)
|
||||
}
|
||||
if strings.Contains(got, "update "+piWebSearchSource+"\n") {
|
||||
t.Fatalf("did not expect pi update call when package missing, got:\n%s", got)
|
||||
}
|
||||
if !strings.Contains(got, "session\n") {
|
||||
t.Fatalf("expected final pi launch call, got:\n%s", got)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("pi installed + web search present updates every launch", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
t.Setenv("PATH", tmpDir)
|
||||
setCloudStatus(t, false)
|
||||
if err := os.WriteFile(filepath.Join(tmpDir, "pi-list.txt"), []byte("User packages:\n "+piWebSearchSource+"\n"), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
seedPiScript(t, tmpDir)
|
||||
seedNpmNoop(t, tmpDir)
|
||||
|
||||
p := &Pi{}
|
||||
if err := p.Run("ignored", []string{"doctor"}); err != nil {
|
||||
t.Fatalf("Run() error = %v", err)
|
||||
}
|
||||
|
||||
piCalls, err := os.ReadFile(filepath.Join(tmpDir, "pi.log"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
got := string(piCalls)
|
||||
if !strings.Contains(got, "update "+piWebSearchSource+"\n") {
|
||||
t.Fatalf("expected pi update call, got:\n%s", got)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("web search update failure warns and continues", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
t.Setenv("PATH", tmpDir)
|
||||
setCloudStatus(t, false)
|
||||
t.Setenv("PI_FAIL_UPDATE", "1")
|
||||
if err := os.WriteFile(filepath.Join(tmpDir, "pi-list.txt"), []byte("User packages:\n "+piWebSearchSource+"\n"), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
seedPiScript(t, tmpDir)
|
||||
seedNpmNoop(t, tmpDir)
|
||||
|
||||
p := &Pi{}
|
||||
stderr := captureStderr(t, func() {
|
||||
if err := p.Run("ignored", []string{"session"}); err != nil {
|
||||
t.Fatalf("Run() should continue after web search update failure, got %v", err)
|
||||
}
|
||||
})
|
||||
if !strings.Contains(stderr, "Warning: could not update "+piWebSearchPkg) {
|
||||
t.Fatalf("expected update warning, got:\n%s", stderr)
|
||||
}
|
||||
|
||||
piCalls, err := os.ReadFile(filepath.Join(tmpDir, "pi.log"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !strings.Contains(string(piCalls), "session\n") {
|
||||
t.Fatalf("expected final pi launch call, got:\n%s", piCalls)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("web search install failure warns and continues", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
t.Setenv("PATH", tmpDir)
|
||||
setCloudStatus(t, false)
|
||||
t.Setenv("PI_FAIL_INSTALL", "1")
|
||||
if err := os.WriteFile(filepath.Join(tmpDir, "pi-list.txt"), []byte("User packages:\n"), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
seedPiScript(t, tmpDir)
|
||||
seedNpmNoop(t, tmpDir)
|
||||
withConfirm(t, func(prompt string) (bool, error) {
|
||||
t.Fatalf("did not expect confirmation prompt, got %q", prompt)
|
||||
return false, nil
|
||||
})
|
||||
|
||||
p := &Pi{}
|
||||
stderr := captureStderr(t, func() {
|
||||
if err := p.Run("ignored", []string{"session"}); err != nil {
|
||||
t.Fatalf("Run() should continue after web search install failure, got %v", err)
|
||||
}
|
||||
})
|
||||
if !strings.Contains(stderr, "Warning: could not install "+piWebSearchPkg) {
|
||||
t.Fatalf("expected install warning, got:\n%s", stderr)
|
||||
}
|
||||
|
||||
piCalls, err := os.ReadFile(filepath.Join(tmpDir, "pi.log"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !strings.Contains(string(piCalls), "session\n") {
|
||||
t.Fatalf("expected final pi launch call, got:\n%s", piCalls)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("cloud disabled skips web search package management", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
t.Setenv("PATH", tmpDir)
|
||||
setCloudStatus(t, true)
|
||||
if err := os.WriteFile(filepath.Join(tmpDir, "pi-list.txt"), []byte("User packages:\n"), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
seedPiScript(t, tmpDir)
|
||||
seedNpmNoop(t, tmpDir)
|
||||
|
||||
p := &Pi{}
|
||||
stderr := captureStderr(t, func() {
|
||||
if err := p.Run("ignored", []string{"session"}); err != nil {
|
||||
t.Fatalf("Run() error = %v", err)
|
||||
}
|
||||
})
|
||||
if !strings.Contains(stderr, "Cloud is disabled; skipping "+piWebSearchPkg+" setup.") {
|
||||
t.Fatalf("expected cloud-disabled skip message, got:\n%s", stderr)
|
||||
}
|
||||
|
||||
piCalls, err := os.ReadFile(filepath.Join(tmpDir, "pi.log"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
got := string(piCalls)
|
||||
if strings.Contains(got, "list\n") || strings.Contains(got, "install "+piWebSearchSource+"\n") || strings.Contains(got, "update "+piWebSearchSource+"\n") {
|
||||
t.Fatalf("did not expect web search package management calls, got:\n%s", got)
|
||||
}
|
||||
if !strings.Contains(got, "session\n") {
|
||||
t.Fatalf("expected final pi launch call, got:\n%s", got)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("missing npm returns error before pi flow", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
t.Setenv("PATH", tmpDir)
|
||||
setCloudStatus(t, false)
|
||||
seedPiScript(t, tmpDir)
|
||||
|
||||
p := &Pi{}
|
||||
err := p.Run("ignored", []string{"session"})
|
||||
if err == nil || !strings.Contains(err.Error(), "npm (Node.js) is required to launch pi") {
|
||||
t.Fatalf("expected missing npm error, got %v", err)
|
||||
}
|
||||
|
||||
if _, statErr := os.Stat(filepath.Join(tmpDir, "pi.log")); !os.IsNotExist(statErr) {
|
||||
t.Fatalf("expected pi not to run when npm is missing, stat err = %v", statErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestPiPaths(t *testing.T) {
|
||||
pi := &Pi{}
|
||||
|
||||
@@ -192,6 +527,48 @@ func TestPiEdit(t *testing.T) {
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("rebuilds stale existing managed cloud model", func(t *testing.T) {
|
||||
cleanup()
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
|
||||
existingConfig := `{
|
||||
"providers": {
|
||||
"ollama": {
|
||||
"baseUrl": "http://localhost:11434/v1",
|
||||
"api": "openai-completions",
|
||||
"apiKey": "ollama",
|
||||
"models": [
|
||||
{"id": "glm-5:cloud", "_launch": true, "legacyField": "stale"}
|
||||
]
|
||||
}
|
||||
}
|
||||
}`
|
||||
if err := os.WriteFile(configPath, []byte(existingConfig), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := pi.Edit([]string{"glm-5:cloud"}); err != nil {
|
||||
t.Fatalf("Edit() error = %v", err)
|
||||
}
|
||||
|
||||
cfg := readConfig()
|
||||
providers := cfg["providers"].(map[string]any)
|
||||
ollama := providers["ollama"].(map[string]any)
|
||||
modelsArray := ollama["models"].([]any)
|
||||
modelEntry := modelsArray[0].(map[string]any)
|
||||
|
||||
if modelEntry["contextWindow"] != float64(202_752) {
|
||||
t.Errorf("contextWindow = %v, want 202752", modelEntry["contextWindow"])
|
||||
}
|
||||
input, ok := modelEntry["input"].([]any)
|
||||
if !ok || len(input) != 1 || input[0] != "text" {
|
||||
t.Errorf("input = %v, want [text]", modelEntry["input"])
|
||||
}
|
||||
if _, ok := modelEntry["legacyField"]; ok {
|
||||
t.Error("legacyField should be removed when stale managed cloud entry is rebuilt")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("replaces old models with new ones", func(t *testing.T) {
|
||||
cleanup()
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
@@ -798,6 +1175,59 @@ func TestCreateConfig(t *testing.T) {
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("cloud model falls back to hardcoded context when show fails", func(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
fmt.Fprintf(w, `{"error":"model not found"}`)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
u, _ := url.Parse(srv.URL)
|
||||
client := api.NewClient(u, srv.Client())
|
||||
|
||||
cfg := createConfig(context.Background(), client, "kimi-k2.5:cloud")
|
||||
|
||||
if cfg["contextWindow"] != 262_144 {
|
||||
t.Errorf("contextWindow = %v, want 262144", cfg["contextWindow"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("cloud model falls back to hardcoded context when show omits model info", func(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == "/api/show" {
|
||||
fmt.Fprintf(w, `{"capabilities":[],"model_info":{}}`)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
u, _ := url.Parse(srv.URL)
|
||||
client := api.NewClient(u, srv.Client())
|
||||
|
||||
cfg := createConfig(context.Background(), client, "glm-5:cloud")
|
||||
|
||||
if cfg["contextWindow"] != 202_752 {
|
||||
t.Errorf("contextWindow = %v, want 202752", cfg["contextWindow"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("cloud model with dash suffix falls back to hardcoded context", func(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
fmt.Fprintf(w, `{"error":"model not found"}`)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
u, _ := url.Parse(srv.URL)
|
||||
client := api.NewClient(u, srv.Client())
|
||||
|
||||
cfg := createConfig(context.Background(), client, "gpt-oss:120b-cloud")
|
||||
|
||||
if cfg["contextWindow"] != 131_072 {
|
||||
t.Errorf("contextWindow = %v, want 131072", cfg["contextWindow"])
|
||||
}
|
||||
})
|
||||
t.Run("skips zero context length", func(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == "/api/show" {
|
||||
373
cmd/launch/registry.go
Normal file
@@ -0,0 +1,373 @@
|
||||
package launch
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"slices"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// IntegrationInstallSpec describes how launcher should detect and guide installation.
type IntegrationInstallSpec struct {
	// CheckInstalled reports whether the integration is already present.
	// A nil check is treated as "always installed" (see integrationFor).
	CheckInstalled func() bool
	// EnsureInstalled performs an automatic install when set; specs without
	// it can only surface a hint built from URL or Command.
	EnsureInstalled func() error
	// URL is the manual-install documentation page, if any.
	URL string
	// Command is the install command in argv form (e.g. npm install -g ...), if any.
	Command []string
}
|
||||
|
||||
// IntegrationSpec is the canonical registry entry for one integration.
type IntegrationSpec struct {
	// Name is the canonical registry key; lookups lowercase it.
	Name string
	// Runner launches and configures the integration.
	Runner Runner
	// Aliases are alternate lookup names; they must not collide with any
	// canonical name or another spec's alias (enforced at init).
	Aliases []string
	// Hidden excludes the integration from interactive UIs and forbids it
	// from appearing in launcherIntegrationOrder.
	Hidden bool
	// Description is the one-line blurb shown in pickers.
	Description string
	// Install describes detection and install guidance for this integration.
	Install IntegrationInstallSpec
}
|
||||
|
||||
// IntegrationInfo contains display information about a registered integration.
type IntegrationInfo struct {
	// Name is the canonical registry key (IntegrationSpec.Name).
	Name string
	// DisplayName is the runner's human-readable name (Runner.String()).
	DisplayName string
	// Description is the one-line blurb from the spec.
	Description string
}
|
||||
|
||||
// launcherIntegrationOrder assigns an explicit display rank to a subset of
// integrations; entries must be canonical, visible names (validated at init).
var launcherIntegrationOrder = []string{"opencode", "droid", "pi"}

// integrationSpecs is the single source of truth for every supported
// integration. Index structures are derived from it in
// rebuildIntegrationSpecIndexes; keep Name values unique.
var integrationSpecs = []*IntegrationSpec{
	{
		Name:        "claude",
		Runner:      &Claude{},
		Description: "Anthropic's coding tool with subagents",
		Install: IntegrationInstallSpec{
			// Claude resolves its binary via its own findPath rather than PATH lookup.
			CheckInstalled: func() bool {
				_, err := (&Claude{}).findPath()
				return err == nil
			},
			URL: "https://code.claude.com/docs/en/quickstart",
		},
	},
	{
		Name:        "cline",
		Runner:      &Cline{},
		Description: "Autonomous coding agent with parallel execution",
		Hidden:      true,
		Install: IntegrationInstallSpec{
			CheckInstalled: func() bool {
				_, err := exec.LookPath("cline")
				return err == nil
			},
			Command: []string{"npm", "install", "-g", "cline"},
		},
	},
	{
		Name:        "codex",
		Runner:      &Codex{},
		Description: "OpenAI's open-source coding agent",
		Install: IntegrationInstallSpec{
			CheckInstalled: func() bool {
				_, err := exec.LookPath("codex")
				return err == nil
			},
			URL:     "https://developers.openai.com/codex/cli/",
			Command: []string{"npm", "install", "-g", "@openai/codex"},
		},
	},
	{
		Name:        "droid",
		Runner:      &Droid{},
		Description: "Factory's coding agent across terminal and IDEs",
		Install: IntegrationInstallSpec{
			CheckInstalled: func() bool {
				_, err := exec.LookPath("droid")
				return err == nil
			},
			URL: "https://docs.factory.ai/cli/getting-started/quickstart",
		},
	},
	{
		Name:        "opencode",
		Runner:      &OpenCode{},
		Description: "Anomaly's open-source coding agent",
		Install: IntegrationInstallSpec{
			CheckInstalled: func() bool {
				_, err := exec.LookPath("opencode")
				return err == nil
			},
			URL: "https://opencode.ai",
		},
	},
	{
		Name:        "openclaw",
		Runner:      &Openclaw{},
		Aliases:     []string{"clawdbot", "moltbot"},
		Description: "Personal AI with 100+ skills",
		Install: IntegrationInstallSpec{
			// Accept either the current "openclaw" binary or the legacy
			// "clawdbot" name as evidence of an install.
			CheckInstalled: func() bool {
				if _, err := exec.LookPath("openclaw"); err == nil {
					return true
				}
				if _, err := exec.LookPath("clawdbot"); err == nil {
					return true
				}
				return false
			},
			EnsureInstalled: func() error {
				_, err := ensureOpenclawInstalled()
				return err
			},
			URL: "https://docs.openclaw.ai",
		},
	},
	{
		Name:        "pi",
		Runner:      &Pi{},
		Description: "Minimal AI agent toolkit with plugin support",
		Install: IntegrationInstallSpec{
			CheckInstalled: func() bool {
				_, err := exec.LookPath("pi")
				return err == nil
			},
			EnsureInstalled: func() error {
				_, err := ensurePiInstalled()
				return err
			},
			Command: []string{"npm", "install", "-g", "@mariozechner/pi-coding-agent@latest"},
		},
	},
	{
		Name:        "vscode",
		Runner:      &VSCode{},
		Aliases:     []string{"code"},
		Description: "Microsoft's open-source AI code editor",
		Hidden:      true,
		Install: IntegrationInstallSpec{
			// VS Code is located via platform-specific paths, not PATH.
			CheckInstalled: func() bool {
				return (&VSCode{}).findBinary() != ""
			},
			URL: "https://code.visualstudio.com",
		},
	},
}
|
||||
|
||||
// integrationSpecsByName maps lowercased canonical names AND aliases to their spec.
// It is rebuilt from integrationSpecs by rebuildIntegrationSpecIndexes.
var integrationSpecsByName map[string]*IntegrationSpec

// init builds the lookup index once at package load. rebuildIntegrationSpecIndexes
// panics on any registry inconsistency, so a bad spec table fails fast.
func init() {
	rebuildIntegrationSpecIndexes()
}
|
||||
|
||||
// hyperlink wraps text in an OSC 8 terminal hyperlink escape sequence
// pointing at url (clickable in terminals that support OSC 8).
func hyperlink(url, text string) string {
	const (
		osc8 = "\033]8;;"
		st   = "\033\\"
	)
	return osc8 + url + st + text + osc8 + st
}
|
||||
|
||||
func rebuildIntegrationSpecIndexes() {
|
||||
integrationSpecsByName = make(map[string]*IntegrationSpec, len(integrationSpecs))
|
||||
|
||||
canonical := make(map[string]bool, len(integrationSpecs))
|
||||
for _, spec := range integrationSpecs {
|
||||
key := strings.ToLower(spec.Name)
|
||||
if key == "" {
|
||||
panic("launch: integration spec missing name")
|
||||
}
|
||||
if canonical[key] {
|
||||
panic(fmt.Sprintf("launch: duplicate integration name %q", key))
|
||||
}
|
||||
canonical[key] = true
|
||||
integrationSpecsByName[key] = spec
|
||||
}
|
||||
|
||||
seenAliases := make(map[string]string)
|
||||
for _, spec := range integrationSpecs {
|
||||
for _, alias := range spec.Aliases {
|
||||
key := strings.ToLower(alias)
|
||||
if key == "" {
|
||||
panic(fmt.Sprintf("launch: integration %q has empty alias", spec.Name))
|
||||
}
|
||||
if canonical[key] {
|
||||
panic(fmt.Sprintf("launch: alias %q collides with canonical integration name", key))
|
||||
}
|
||||
if owner, exists := seenAliases[key]; exists {
|
||||
panic(fmt.Sprintf("launch: alias %q collides between %q and %q", key, owner, spec.Name))
|
||||
}
|
||||
seenAliases[key] = spec.Name
|
||||
integrationSpecsByName[key] = spec
|
||||
}
|
||||
}
|
||||
|
||||
orderSeen := make(map[string]bool, len(launcherIntegrationOrder))
|
||||
for _, name := range launcherIntegrationOrder {
|
||||
key := strings.ToLower(name)
|
||||
if orderSeen[key] {
|
||||
panic(fmt.Sprintf("launch: duplicate launcher order entry %q", key))
|
||||
}
|
||||
orderSeen[key] = true
|
||||
|
||||
spec, ok := integrationSpecsByName[key]
|
||||
if !ok {
|
||||
panic(fmt.Sprintf("launch: unknown launcher order entry %q", key))
|
||||
}
|
||||
if spec.Name != key {
|
||||
panic(fmt.Sprintf("launch: launcher order entry %q must use canonical name, not alias", key))
|
||||
}
|
||||
if spec.Hidden {
|
||||
panic(fmt.Sprintf("launch: hidden integration %q cannot appear in launcher order", key))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// LookupIntegrationSpec resolves either a canonical integration name or alias to its spec.
|
||||
func LookupIntegrationSpec(name string) (*IntegrationSpec, error) {
|
||||
spec, ok := integrationSpecsByName[strings.ToLower(name)]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unknown integration: %s", name)
|
||||
}
|
||||
return spec, nil
|
||||
}
|
||||
|
||||
// LookupIntegration resolves a registry name to the canonical key and runner.
|
||||
func LookupIntegration(name string) (string, Runner, error) {
|
||||
spec, err := LookupIntegrationSpec(name)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
return spec.Name, spec.Runner, nil
|
||||
}
|
||||
|
||||
// ListVisibleIntegrationSpecs returns the canonical integrations that should appear in interactive UIs.
|
||||
func ListVisibleIntegrationSpecs() []IntegrationSpec {
|
||||
visible := make([]IntegrationSpec, 0, len(integrationSpecs))
|
||||
for _, spec := range integrationSpecs {
|
||||
if spec.Hidden {
|
||||
continue
|
||||
}
|
||||
visible = append(visible, *spec)
|
||||
}
|
||||
|
||||
orderRank := make(map[string]int, len(launcherIntegrationOrder))
|
||||
for i, name := range launcherIntegrationOrder {
|
||||
orderRank[name] = i + 1
|
||||
}
|
||||
|
||||
slices.SortFunc(visible, func(a, b IntegrationSpec) int {
|
||||
aRank, bRank := orderRank[a.Name], orderRank[b.Name]
|
||||
if aRank > 0 && bRank > 0 {
|
||||
return aRank - bRank
|
||||
}
|
||||
if aRank > 0 {
|
||||
return 1
|
||||
}
|
||||
if bRank > 0 {
|
||||
return -1
|
||||
}
|
||||
return strings.Compare(a.Name, b.Name)
|
||||
})
|
||||
|
||||
return visible
|
||||
}
|
||||
|
||||
// ListIntegrationInfos returns the registered integrations in launcher display order.
|
||||
func ListIntegrationInfos() []IntegrationInfo {
|
||||
visible := ListVisibleIntegrationSpecs()
|
||||
infos := make([]IntegrationInfo, 0, len(visible))
|
||||
for _, spec := range visible {
|
||||
infos = append(infos, IntegrationInfo{
|
||||
Name: spec.Name,
|
||||
DisplayName: spec.Runner.String(),
|
||||
Description: spec.Description,
|
||||
})
|
||||
}
|
||||
return infos
|
||||
}
|
||||
|
||||
// IntegrationSelectionItems returns the sorted integration items shown by launcher selection UIs.
|
||||
func IntegrationSelectionItems() ([]ModelItem, error) {
|
||||
visible := ListVisibleIntegrationSpecs()
|
||||
if len(visible) == 0 {
|
||||
return nil, fmt.Errorf("no integrations available")
|
||||
}
|
||||
|
||||
items := make([]ModelItem, 0, len(visible))
|
||||
for _, spec := range visible {
|
||||
description := spec.Runner.String()
|
||||
if conn, err := loadStoredIntegrationConfig(spec.Name); err == nil && len(conn.Models) > 0 {
|
||||
description = fmt.Sprintf("%s (%s)", spec.Runner.String(), conn.Models[0])
|
||||
}
|
||||
items = append(items, ModelItem{Name: spec.Name, Description: description})
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
// IsIntegrationInstalled checks if an integration binary is installed.
|
||||
func IsIntegrationInstalled(name string) bool {
|
||||
integration, err := integrationFor(name)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Ollama couldn't find integration %q, so it'll show up as not installed.\n", name)
|
||||
return false
|
||||
}
|
||||
return integration.installed
|
||||
}
|
||||
|
||||
// integration is resolved registry metadata used by launcher state and install checks.
// It combines immutable registry spec data with computed runtime traits.
type integration struct {
	// spec is the canonical registry entry this resolution came from.
	spec *IntegrationSpec
	// installed is the result of spec.Install.CheckInstalled (true when no check exists).
	installed bool
	// autoInstallable is true when the spec provides EnsureInstalled.
	autoInstallable bool
	// editor is true when the runner also implements the Editor interface.
	editor bool
	// installHint is a human-readable install suggestion derived from the
	// spec's URL (preferred) or install command; empty when neither is set.
	installHint string
}
|
||||
|
||||
// integrationFor resolves an integration name into the canonical spec plus
|
||||
// derived launcher/install traits used across registry and launch flows.
|
||||
func integrationFor(name string) (integration, error) {
|
||||
spec, err := LookupIntegrationSpec(name)
|
||||
if err != nil {
|
||||
return integration{}, err
|
||||
}
|
||||
|
||||
installed := true
|
||||
if spec.Install.CheckInstalled != nil {
|
||||
installed = spec.Install.CheckInstalled()
|
||||
}
|
||||
|
||||
_, editor := spec.Runner.(Editor)
|
||||
hint := ""
|
||||
if spec.Install.URL != "" {
|
||||
hint = "Install from " + hyperlink(spec.Install.URL, spec.Install.URL)
|
||||
} else if len(spec.Install.Command) > 0 {
|
||||
hint = "Install with: " + strings.Join(spec.Install.Command, " ")
|
||||
}
|
||||
|
||||
return integration{
|
||||
spec: spec,
|
||||
installed: installed,
|
||||
autoInstallable: spec.Install.EnsureInstalled != nil,
|
||||
editor: editor,
|
||||
installHint: hint,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// EnsureIntegrationInstalled installs auto-installable integrations when missing.
|
||||
func EnsureIntegrationInstalled(name string, runner Runner) error {
|
||||
integration, err := integrationFor(name)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%s is not installed", runner)
|
||||
}
|
||||
|
||||
if integration.installed {
|
||||
return nil
|
||||
}
|
||||
if integration.autoInstallable {
|
||||
return integration.spec.Install.EnsureInstalled()
|
||||
}
|
||||
|
||||
switch {
|
||||
case integration.spec.Install.URL != "":
|
||||
return fmt.Errorf("%s is not installed, install from %s", integration.spec.Name, integration.spec.Install.URL)
|
||||
case len(integration.spec.Install.Command) > 0:
|
||||
return fmt.Errorf("%s is not installed, install with: %s", integration.spec.Name, strings.Join(integration.spec.Install.Command, " "))
|
||||
default:
|
||||
return fmt.Errorf("%s is not installed", runner)
|
||||
}
|
||||
}
|
||||
21
cmd/launch/registry_test_helpers_test.go
Normal file
@@ -0,0 +1,21 @@
|
||||
package launch
|
||||
|
||||
import "strings"
|
||||
|
||||
// OverrideIntegration replaces one registry entry's runner for tests and returns a restore function.
|
||||
func OverrideIntegration(name string, runner Runner) func() {
|
||||
spec, err := LookupIntegrationSpec(name)
|
||||
if err != nil {
|
||||
key := strings.ToLower(name)
|
||||
integrationSpecsByName[key] = &IntegrationSpec{Name: key, Runner: runner}
|
||||
return func() {
|
||||
delete(integrationSpecsByName, key)
|
||||
}
|
||||
}
|
||||
|
||||
original := spec.Runner
|
||||
spec.Runner = runner
|
||||
return func() {
|
||||
spec.Runner = original
|
||||
}
|
||||
}
|
||||
71
cmd/launch/runner_exec_only_test.go
Normal file
@@ -0,0 +1,71 @@
|
||||
package launch
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestEditorRunsDoNotRewriteConfig verifies that Run for each listed
// integration does not create that integration's config file on disk
// (the file must still be absent after Run succeeds).
func TestEditorRunsDoNotRewriteConfig(t *testing.T) {
	tests := []struct {
		name      string
		binary    string               // fake executable placed on PATH
		runner    Runner               // runner under test
		checkPath func(home string) string // config file Run must NOT create
	}{
		{
			name:   "droid",
			binary: "droid",
			runner: &Droid{},
			checkPath: func(home string) string {
				return filepath.Join(home, ".factory", "settings.json")
			},
		},
		{
			name:   "opencode",
			binary: "opencode",
			runner: &OpenCode{},
			checkPath: func(home string) string {
				return filepath.Join(home, ".config", "opencode", "opencode.json")
			},
		},
		{
			name:   "cline",
			binary: "cline",
			runner: &Cline{},
			checkPath: func(home string) string {
				return filepath.Join(home, ".cline", "data", "globalState.json")
			},
		},
		{
			name:   "pi",
			binary: "pi",
			runner: &Pi{},
			checkPath: func(home string) string {
				return filepath.Join(home, ".pi", "agent", "models.json")
			},
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// Isolate HOME so no real user config is read or written.
			home := t.TempDir()
			setTestHome(t, home)

			// PATH contains only the fake binaries so Run can "launch" them.
			binDir := t.TempDir()
			writeFakeBinary(t, binDir, tt.binary)
			if tt.name == "pi" {
				// pi additionally shells out to npm.
				writeFakeBinary(t, binDir, "npm")
			}
			t.Setenv("PATH", binDir)

			configPath := tt.checkPath(home)
			if err := tt.runner.Run("llama3.2", nil); err != nil {
				t.Fatalf("Run returned error: %v", err)
			}
			// The config file must not exist; any other stat result is a failure.
			if _, err := os.Stat(configPath); !os.IsNotExist(err) {
				t.Fatalf("expected Run to leave %s untouched, got err=%v", configPath, err)
			}
		})
	}
}
|
||||
103
cmd/launch/selector_hooks.go
Normal file
@@ -0,0 +1,103 @@
|
||||
package launch
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"golang.org/x/term"
|
||||
)
|
||||
|
||||
// ANSI escape sequences for terminal formatting.
const (
	ansiBold  = "\033[1m"
	ansiReset = "\033[0m"
	// NOTE(review): SGR 37 is "white/light gray" in most palettes, despite the name.
	ansiGray   = "\033[37m"
	ansiGreen  = "\033[32m"
	ansiYellow = "\033[33m"
)
|
||||
|
||||
// ErrCancelled is returned when the user cancels a selection.
var ErrCancelled = errors.New("cancelled")

// errCancelled is kept as an internal alias for existing call sites.
var errCancelled = ErrCancelled

// DefaultConfirmPrompt provides a TUI-based confirmation prompt.
// When set, ConfirmPrompt delegates to it instead of using raw terminal I/O.
var DefaultConfirmPrompt func(prompt string) (bool, error)

// SingleSelector is a function type for single item selection.
// current is the name of the previously selected item to highlight; empty means no pre-selection.
type SingleSelector func(title string, items []ModelItem, current string) (string, error)

// MultiSelector is a function type for multi item selection.
// preChecked lists item names that start out selected.
type MultiSelector func(title string, items []ModelItem, preChecked []string) ([]string, error)

// DefaultSingleSelector is the default single-select implementation.
var DefaultSingleSelector SingleSelector

// DefaultMultiSelector is the default multi-select implementation.
var DefaultMultiSelector MultiSelector

// DefaultSignIn provides a TUI-based sign-in flow.
// When set, ensureAuth uses it instead of plain text prompts.
// Returns the signed-in username or an error.
var DefaultSignIn func(modelName, signInURL string) (string, error)
|
||||
|
||||
// launchConfirmPolicy controls how ConfirmPrompt answers confirmation requests.
type launchConfirmPolicy struct {
	// yes auto-approves every prompt (the --yes flag).
	yes bool
	// requireYesMessage fails prompts with a "re-run with --yes" error
	// instead of asking interactively.
	requireYesMessage bool
}

// currentLaunchConfirmPolicy is the active policy; scope changes with
// withLaunchConfirmPolicy rather than assigning directly.
var currentLaunchConfirmPolicy launchConfirmPolicy

// withLaunchConfirmPolicy installs policy and returns a function that
// restores the previous policy (call it when the scope ends).
func withLaunchConfirmPolicy(policy launchConfirmPolicy) func() {
	old := currentLaunchConfirmPolicy
	currentLaunchConfirmPolicy = policy
	return func() {
		currentLaunchConfirmPolicy = old
	}
}
|
||||
|
||||
// ConfirmPrompt is the shared confirmation gate for launch flows (integration
|
||||
// edits, missing-model pulls, sign-in prompts, OpenClaw install/security, etc).
|
||||
// Behavior is controlled by currentLaunchConfirmPolicy, typically scoped by
|
||||
// withLaunchConfirmPolicy in LaunchCmd (e.g. auto-approve with --yes).
|
||||
func ConfirmPrompt(prompt string) (bool, error) {
|
||||
if currentLaunchConfirmPolicy.yes {
|
||||
return true, nil
|
||||
}
|
||||
if currentLaunchConfirmPolicy.requireYesMessage {
|
||||
return false, fmt.Errorf("%s requires confirmation; re-run with --yes to continue", prompt)
|
||||
}
|
||||
|
||||
if DefaultConfirmPrompt != nil {
|
||||
return DefaultConfirmPrompt(prompt)
|
||||
}
|
||||
|
||||
fd := int(os.Stdin.Fd())
|
||||
oldState, err := term.MakeRaw(fd)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
defer term.Restore(fd, oldState)
|
||||
|
||||
fmt.Fprintf(os.Stderr, "%s (\033[1my\033[0m/n) ", prompt)
|
||||
|
||||
buf := make([]byte, 1)
|
||||
for {
|
||||
if _, err := os.Stdin.Read(buf); err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
switch buf[0] {
|
||||
case 'Y', 'y', 13:
|
||||
fmt.Fprintf(os.Stderr, "yes\r\n")
|
||||
return true, nil
|
||||
case 'N', 'n', 27, 3:
|
||||
fmt.Fprintf(os.Stderr, "no\r\n")
|
||||
return false, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
76
cmd/launch/selector_test.go
Normal file
@@ -0,0 +1,76 @@
|
||||
package launch
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestErrCancelled(t *testing.T) {
|
||||
t.Run("NotNil", func(t *testing.T) {
|
||||
if errCancelled == nil {
|
||||
t.Error("errCancelled should not be nil")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Message", func(t *testing.T) {
|
||||
if errCancelled.Error() != "cancelled" {
|
||||
t.Errorf("expected 'cancelled', got %q", errCancelled.Error())
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestWithLaunchConfirmPolicy_ScopesAndRestores verifies that nested
// withLaunchConfirmPolicy scopes apply in LIFO order and that, once all
// scopes are popped, ConfirmPrompt falls through to DefaultConfirmPrompt.
func TestWithLaunchConfirmPolicy_ScopesAndRestores(t *testing.T) {
	// Snapshot and restore package-level state so the test is hermetic.
	oldPolicy := currentLaunchConfirmPolicy
	oldHook := DefaultConfirmPrompt
	t.Cleanup(func() {
		currentLaunchConfirmPolicy = oldPolicy
		DefaultConfirmPrompt = oldHook
	})

	currentLaunchConfirmPolicy = launchConfirmPolicy{}
	var hookCalls int
	DefaultConfirmPrompt = func(prompt string) (bool, error) {
		hookCalls++
		return true, nil
	}

	// Outer scope blocks without --yes; inner scope auto-approves.
	restoreOuter := withLaunchConfirmPolicy(launchConfirmPolicy{requireYesMessage: true})
	restoreInner := withLaunchConfirmPolicy(launchConfirmPolicy{yes: true})

	// Inner (--yes) policy: auto-accept, hook must not run.
	ok, err := ConfirmPrompt("test prompt")
	if err != nil {
		t.Fatalf("expected --yes policy to allow prompt, got error: %v", err)
	}
	if !ok {
		t.Fatal("expected --yes policy to auto-accept prompt")
	}
	if hookCalls != 0 {
		t.Fatalf("expected --yes to skip hook, got %d hook calls", hookCalls)
	}

	restoreInner()

	// Outer (requireYesMessage) policy: fail with an actionable error, hook still skipped.
	_, err = ConfirmPrompt("test prompt")
	if err == nil {
		t.Fatal("expected requireYesMessage policy to block prompt")
	}
	if !strings.Contains(err.Error(), "re-run with --yes") {
		t.Fatalf("expected actionable --yes error, got: %v", err)
	}
	if hookCalls != 0 {
		t.Fatalf("expected blocking policy to skip hook, got %d hook calls", hookCalls)
	}

	restoreOuter()

	// No policy left: the hook handles the prompt.
	ok, err = ConfirmPrompt("test prompt")
	if err != nil {
		t.Fatalf("expected restored default behavior to use hook, got error: %v", err)
	}
	if !ok {
		t.Fatal("expected hook to return true")
	}
	if hookCalls != 1 {
		t.Fatalf("expected one hook call after restore, got %d", hookCalls)
	}
}
|
||||
82
cmd/launch/test_config_helpers_test.go
Normal file
@@ -0,0 +1,82 @@
|
||||
package launch
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/cmd/config"
|
||||
)
|
||||
|
||||
// Legacy lookup tables derived from the canonical registry, kept so older
// tests can keep using map-style access.
var (
	integrations       map[string]Runner  // lowercased name/alias -> runner
	integrationAliases map[string]bool    // set of lowercased aliases
	integrationOrder   = launcherIntegrationOrder // launcher display order
)

// init populates the legacy maps from the registry index built by the
// package's main init (Go runs file inits after package-level setup).
func init() {
	integrations = buildTestIntegrations()
	integrationAliases = buildTestIntegrationAliases()
}
|
||||
|
||||
func buildTestIntegrations() map[string]Runner {
|
||||
result := make(map[string]Runner, len(integrationSpecsByName))
|
||||
for name, spec := range integrationSpecsByName {
|
||||
result[strings.ToLower(name)] = spec.Runner
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func buildTestIntegrationAliases() map[string]bool {
|
||||
result := make(map[string]bool)
|
||||
for _, spec := range integrationSpecs {
|
||||
for _, alias := range spec.Aliases {
|
||||
result[strings.ToLower(alias)] = true
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// setTestHome points the launch package's home-directory resolution at dir
// for the duration of one test (thin shim over setLaunchTestHome).
func setTestHome(t *testing.T, dir string) {
	t.Helper()
	setLaunchTestHome(t, dir)
}
|
||||
|
||||
// The functions below are thin forwarding wrappers so launch-package tests
// can reach cmd/config helpers without importing that package directly.

// SaveIntegration persists the model list for appName via cmd/config.
func SaveIntegration(appName string, models []string) error {
	return config.SaveIntegration(appName, models)
}

// LoadIntegration loads the stored integration config for appName.
func LoadIntegration(appName string) (*config.IntegrationConfig, error) {
	return config.LoadIntegration(appName)
}

// SaveAliases persists the alias map for appName.
func SaveAliases(appName string, aliases map[string]string) error {
	return config.SaveAliases(appName, aliases)
}

// LastModel returns the most recently used model name.
func LastModel() string {
	return config.LastModel()
}

// SetLastModel records model as the most recently used model.
func SetLastModel(model string) error {
	return config.SetLastModel(model)
}

// LastSelection returns the most recent launcher selection.
func LastSelection() string {
	return config.LastSelection()
}

// SetLastSelection records the launcher selection.
func SetLastSelection(selection string) error {
	return config.SetLastSelection(selection)
}

// IntegrationModel returns the primary configured model for appName.
func IntegrationModel(appName string) string {
	return config.IntegrationModel(appName)
}

// IntegrationModels returns all configured models for appName.
func IntegrationModels(appName string) []string {
	return config.IntegrationModels(appName)
}

// integrationOnboarded marks appName as having completed onboarding.
func integrationOnboarded(appName string) error {
	return config.MarkIntegrationOnboarded(appName)
}
|
||||
591
cmd/launch/vscode.go
Normal file
@@ -0,0 +1,591 @@
|
||||
package launch
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/cmd/internal/fileutil"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
)
|
||||
|
||||
// VSCode implements Runner and Editor for Visual Studio Code integration.
type VSCode struct{}

// String returns the human-readable integration name shown in UIs.
func (v *VSCode) String() string { return "Visual Studio Code" }
|
||||
|
||||
// findBinary returns the path/command to launch VS Code, or "" if not found.
// It checks platform-specific locations only.
func (v *VSCode) findBinary() string {
	var candidates []string
	switch runtime.GOOS {
	case "darwin":
		// On macOS the .app bundle is returned; callers open it via `open -a`.
		candidates = []string{
			"/Applications/Visual Studio Code.app",
		}
	case "windows":
		// Default per-user install location; skipped when LOCALAPPDATA is unset.
		if localAppData := os.Getenv("LOCALAPPDATA"); localAppData != "" {
			candidates = append(candidates, filepath.Join(localAppData, "Programs", "Microsoft VS Code", "bin", "code.cmd"))
		}
	default: // linux
		candidates = []string{
			"/usr/bin/code",
			"/snap/bin/code",
		}
	}
	// First existing candidate wins.
	for _, c := range candidates {
		if _, err := os.Stat(c); err == nil {
			return c
		}
	}
	return ""
}
|
||||
|
||||
// IsRunning reports whether VS Code is currently running.
// Each platform uses a pattern specific enough to avoid matching Cursor or
// other VS Code forks.
func (v *VSCode) IsRunning() bool {
	switch runtime.GOOS {
	case "darwin":
		// pgrep by full bundle-internal path so only real VS Code matches.
		out, err := exec.Command("pgrep", "-f", "Visual Studio Code.app/Contents/MacOS/Code").Output()
		return err == nil && len(out) > 0
	case "windows":
		// Match VS Code by executable path to avoid matching Cursor or other forks.
		out, err := exec.Command("powershell", "-NoProfile", "-Command",
			`Get-Process Code -ErrorAction SilentlyContinue | Where-Object { $_.Path -like '*Microsoft VS Code*' } | Select-Object -First 1`).Output()
		return err == nil && len(strings.TrimSpace(string(out))) > 0
	default:
		// Match VS Code specifically by its install path to avoid matching
		// Cursor (/cursor/) or other forks.
		for _, pattern := range []string{"/usr/share/code/", "/snap/code/"} {
			out, err := exec.Command("pgrep", "-f", pattern).Output()
			if err == nil && len(out) > 0 {
				return true
			}
		}
		return false
	}
}
|
||||
|
||||
// Quit gracefully quits VS Code and waits for it to exit so that it flushes
// its in-memory state back to the database.
// On macOS it asks via AppleScript; on Windows/Linux the process is killed by
// install path (so forks like Cursor are untouched). It then polls up to 30s
// for the process to disappear, showing a spinner on stderr.
func (v *VSCode) Quit() {
	if !v.IsRunning() {
		return
	}
	switch runtime.GOOS {
	case "darwin":
		_ = exec.Command("osascript", "-e", `quit app "Visual Studio Code"`).Run()
	case "windows":
		// Kill VS Code by executable path to avoid killing Cursor or other forks.
		_ = exec.Command("powershell", "-NoProfile", "-Command",
			`Get-Process Code -ErrorAction SilentlyContinue | Where-Object { $_.Path -like '*Microsoft VS Code*' } | Stop-Process -Force`).Run()
	default:
		for _, pattern := range []string{"/usr/share/code/", "/snap/code/"} {
			_ = exec.Command("pkill", "-f", pattern).Run()
		}
	}
	// Wait for the process to fully exit and flush its state to disk
	// TODO(hoyyeva): update spinner to use bubble tea
	spinnerFrames := []string{"|", "/", "-", "\\"}
	frame := 0
	fmt.Fprintf(os.Stderr, "\033[90mRestarting VS Code... %s\033[0m", spinnerFrames[0])

	ticker := time.NewTicker(200 * time.Millisecond)
	defer ticker.Stop()

	for range 150 { // 150 ticks × 200ms = 30s timeout
		<-ticker.C
		frame++
		fmt.Fprintf(os.Stderr, "\r\033[90mRestarting VS Code... %s\033[0m", spinnerFrames[frame%len(spinnerFrames)])

		if frame%5 == 0 { // check every ~1s
			if !v.IsRunning() {
				// Clear the spinner line before returning.
				fmt.Fprintf(os.Stderr, "\r\033[K")
				// Give VS Code a moment to finish writing its state DB
				time.Sleep(1 * time.Second)
				return
			}
		}
	}
	// Timed out waiting; clear the spinner line and return anyway.
	fmt.Fprintf(os.Stderr, "\r\033[K")
}
|
||||
|
||||
// Minimum versions required for the Ollama integration to work; checked by
// checkVSCodeVersion/checkCopilotChatVersion (defined elsewhere in this file).
const (
	minCopilotChatVersion = "0.41.0"
	minVSCodeVersion      = "1.113"
)
|
||||
|
||||
// Run prepares VS Code for use with the selected Ollama model(s): it checks
// tool versions, registers cloud models with the server, warns when the
// default model lacks tool calling, updates the Copilot Chat model picker,
// and (re)starts/focuses VS Code. The args parameter is unused here.
// Always returns nil; problems are reported as stderr warnings.
func (v *VSCode) Run(model string, args []string) error {
	v.checkVSCodeVersion()
	v.checkCopilotChatVersion()

	// Get all configured models (saved by the launcher framework before Run is called)
	models := []string{model}
	if cfg, err := loadStoredIntegrationConfig("vscode"); err == nil && len(cfg.Models) > 0 {
		models = cfg.Models
	}

	// VS Code discovers models from ollama ls. Cloud models that pass Show
	// (the server knows about them) but aren't in ls need to be pulled to
	// register them so VS Code can find them.
	if client, err := api.ClientFromEnvironment(); err == nil {
		v.ensureModelsRegistered(context.Background(), client, models)
	}

	// Warn if the default model doesn't support tool calling
	if client, err := api.ClientFromEnvironment(); err == nil {
		if resp, err := client.Show(context.Background(), &api.ShowRequest{Model: models[0]}); err == nil {
			hasTools := false
			for _, c := range resp.Capabilities {
				if c == "tools" {
					hasTools = true
					break
				}
			}
			if !hasTools {
				fmt.Fprintf(os.Stderr, "Note: %s does not support tool calling and may not appear in the Copilot Chat model picker.\n", models[0])
			}
		}
	}

	v.printModelAccessTip()

	if v.IsRunning() {
		// A running VS Code must restart to pick up the new model config;
		// ask first (policy/--yes may auto-answer via ConfirmPrompt).
		restart, err := ConfirmPrompt("Restart VS Code?")
		if err != nil {
			restart = false
		}
		if restart {
			v.Quit()
			if err := v.ShowInModelPicker(models); err != nil {
				fmt.Fprintf(os.Stderr, "%s Warning: could not update VS Code model picker: %v%s\n", ansiYellow, err, ansiReset)
			}
			v.FocusVSCode()
		} else {
			fmt.Fprintf(os.Stderr, "\nTo get the latest model configuration, restart VS Code when you're ready.\n")
		}
	} else {
		// Not running: safe to write the picker state, then launch.
		if err := v.ShowInModelPicker(models); err != nil {
			fmt.Fprintf(os.Stderr, "%s Warning: could not update VS Code model picker: %v%s\n", ansiYellow, err, ansiReset)
		}
		v.FocusVSCode()
	}

	return nil
}
|
||||
|
||||
// ensureModelsRegistered pulls models that the server knows about (Show succeeds)
|
||||
// but aren't in ollama ls yet. This is needed for cloud models so that VS Code
|
||||
// can discover them from the Ollama API.
|
||||
func (v *VSCode) ensureModelsRegistered(ctx context.Context, client *api.Client, models []string) {
|
||||
listed, err := client.List(ctx)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
registered := make(map[string]bool, len(listed.Models))
|
||||
for _, m := range listed.Models {
|
||||
registered[m.Name] = true
|
||||
}
|
||||
|
||||
for _, model := range models {
|
||||
if registered[model] {
|
||||
continue
|
||||
}
|
||||
// Also check without :latest suffix
|
||||
if !strings.Contains(model, ":") && registered[model+":latest"] {
|
||||
continue
|
||||
}
|
||||
if err := pullModel(ctx, client, model, false); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "%s Warning: could not register model %s: %v%s\n", ansiYellow, model, err, ansiReset)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// FocusVSCode brings VS Code to the foreground.
|
||||
func (v *VSCode) FocusVSCode() {
|
||||
binary := v.findBinary()
|
||||
if binary == "" {
|
||||
return
|
||||
}
|
||||
if runtime.GOOS == "darwin" && strings.HasSuffix(binary, ".app") {
|
||||
_ = exec.Command("open", "-a", binary).Run()
|
||||
} else {
|
||||
_ = exec.Command(binary).Start()
|
||||
}
|
||||
}
|
||||
|
||||
// printModelAccessTip shows instructions for finding Ollama models in VS Code.
|
||||
func (v *VSCode) printModelAccessTip() {
|
||||
fmt.Fprintf(os.Stderr, "\nTip: To use Ollama models, open Copilot Chat and click the model picker.\n")
|
||||
fmt.Fprintf(os.Stderr, " If you don't see your models, click \"Other models\" to find them.\n\n")
|
||||
}
|
||||
|
||||
func (v *VSCode) Paths() []string {
|
||||
if p := v.chatLanguageModelsPath(); fileExists(p) {
|
||||
return []string{p}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (v *VSCode) Edit(models []string) error {
|
||||
if len(models) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Write chatLanguageModels.json with Ollama vendor entry
|
||||
clmPath := v.chatLanguageModelsPath()
|
||||
if err := os.MkdirAll(filepath.Dir(clmPath), 0o755); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var entries []map[string]any
|
||||
if data, err := os.ReadFile(clmPath); err == nil {
|
||||
_ = json.Unmarshal(data, &entries)
|
||||
}
|
||||
|
||||
// Remove any existing Ollama entries, preserve others
|
||||
filtered := make([]map[string]any, 0, len(entries))
|
||||
for _, entry := range entries {
|
||||
if vendor, _ := entry["vendor"].(string); vendor != "ollama" {
|
||||
filtered = append(filtered, entry)
|
||||
}
|
||||
}
|
||||
|
||||
// Add new Ollama entry
|
||||
filtered = append(filtered, map[string]any{
|
||||
"vendor": "ollama",
|
||||
"name": "Ollama",
|
||||
"url": envconfig.Host().String(),
|
||||
})
|
||||
|
||||
data, err := json.MarshalIndent(filtered, "", " ")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := fileutil.WriteWithBackup(clmPath, data); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Clean up legacy settings from older Ollama integrations
|
||||
v.updateSettings()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (v *VSCode) Models() []string {
|
||||
if !v.hasOllamaVendor() {
|
||||
return nil
|
||||
}
|
||||
if cfg, err := loadStoredIntegrationConfig("vscode"); err == nil {
|
||||
return cfg.Models
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// hasOllamaVendor checks if chatLanguageModels.json contains an Ollama vendor entry.
|
||||
func (v *VSCode) hasOllamaVendor() bool {
|
||||
data, err := os.ReadFile(v.chatLanguageModelsPath())
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
var entries []map[string]any
|
||||
if err := json.Unmarshal(data, &entries); err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
for _, entry := range entries {
|
||||
if vendor, _ := entry["vendor"].(string); vendor == "ollama" {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// chatLanguageModelsPath returns the location of chatLanguageModels.json
// inside the per-OS VS Code user configuration directory (see vscodePath).
func (v *VSCode) chatLanguageModelsPath() string {
	return v.vscodePath("chatLanguageModels.json")
}
|
||||
|
||||
// settingsPath returns the location of the user's settings.json inside the
// per-OS VS Code user configuration directory (see vscodePath).
func (v *VSCode) settingsPath() string {
	return v.vscodePath("settings.json")
}
|
||||
|
||||
// updateSettings cleans up legacy settings from older Ollama integrations.
|
||||
func (v *VSCode) updateSettings() {
|
||||
settingsPath := v.settingsPath()
|
||||
data, err := os.ReadFile(settingsPath)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
var settings map[string]any
|
||||
if err := json.Unmarshal(data, &settings); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
changed := false
|
||||
for _, key := range []string{"github.copilot.chat.byok.ollamaEndpoint", "ollama.launch.configured"} {
|
||||
if _, ok := settings[key]; ok {
|
||||
delete(settings, key)
|
||||
changed = true
|
||||
}
|
||||
}
|
||||
|
||||
if !changed {
|
||||
return
|
||||
}
|
||||
|
||||
updated, err := json.MarshalIndent(settings, "", " ")
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
_ = fileutil.WriteWithBackup(settingsPath, updated)
|
||||
}
|
||||
|
||||
// statePath returns the location of VS Code's state.vscdb SQLite database
// (globalStorage) inside the per-OS user configuration directory.
func (v *VSCode) statePath() string {
	return v.vscodePath("globalStorage", "state.vscdb")
}
|
||||
|
||||
// ShowInModelPicker ensures the given models are visible in VS Code's Copilot
// Chat model picker. It sets the configured models to true in the picker
// preferences so they appear in the dropdown. Models use the VS Code identifier
// format "ollama/Ollama/<name>".
//
// The preferences live in VS Code's state.vscdb SQLite database; if the
// database does not exist yet it is created with the same schema VS Code uses.
// Returns an error when the database cannot be created, opened, or written.
func (v *VSCode) ShowInModelPicker(models []string) error {
	if len(models) == 0 {
		return nil
	}

	dbPath := v.statePath()
	needsCreate := !fileExists(dbPath)
	if needsCreate {
		if err := os.MkdirAll(filepath.Dir(dbPath), 0o755); err != nil {
			return fmt.Errorf("creating state directory: %w", err)
		}
	}

	// _busy_timeout avoids an immediate SQLITE_BUSY failure if VS Code itself
	// currently holds the database.
	db, err := sql.Open("sqlite3", dbPath+"?_busy_timeout=5000")
	if err != nil {
		return fmt.Errorf("opening state database: %w", err)
	}
	defer db.Close()

	// Create the table if this is a fresh DB. Schema must match what VS Code creates.
	if needsCreate {
		if _, err := db.Exec("CREATE TABLE ItemTable (key TEXT UNIQUE ON CONFLICT REPLACE, value BLOB)"); err != nil {
			return fmt.Errorf("initializing state database: %w", err)
		}
	}

	// Read existing preferences. A missing row or unparsable JSON simply
	// means we start from an empty preference map.
	prefs := make(map[string]bool)
	var prefsJSON string
	if err := db.QueryRow("SELECT value FROM ItemTable WHERE key = 'chatModelPickerPreferences'").Scan(&prefsJSON); err == nil {
		_ = json.Unmarshal([]byte(prefsJSON), &prefs)
	}

	// Build name→ID map from VS Code's cached model list.
	// VS Code uses numeric IDs like "ollama/Ollama/4", not "ollama/Ollama/kimi-k2.5:cloud".
	nameToID := make(map[string]string)
	var cacheJSON string
	if err := db.QueryRow("SELECT value FROM ItemTable WHERE key = 'chat.cachedLanguageModels.v2'").Scan(&cacheJSON); err == nil {
		var cached []map[string]any
		if json.Unmarshal([]byte(cacheJSON), &cached) == nil {
			for _, entry := range cached {
				meta, _ := entry["metadata"].(map[string]any)
				if meta == nil {
					continue
				}
				if vendor, _ := meta["vendor"].(string); vendor == "ollama" {
					name, _ := meta["name"].(string)
					id, _ := entry["identifier"].(string)
					if name != "" && id != "" {
						nameToID[name] = id
					}
				}
			}
		}
	}

	// Ollama config is authoritative: always show configured models,
	// hide Ollama models that are no longer in the config.
	configuredIDs := make(map[string]bool)
	for _, m := range models {
		for _, id := range v.modelVSCodeIDs(m, nameToID) {
			prefs[id] = true
			configuredIDs[id] = true
		}
	}
	// Hiding (false) rather than deleting mirrors how VS Code itself records
	// picker state; non-Ollama entries are left untouched.
	for id := range prefs {
		if strings.HasPrefix(id, "ollama/") && !configuredIDs[id] {
			prefs[id] = false
		}
	}

	// Marshal of map[string]bool cannot fail, so the error is ignored.
	data, _ := json.Marshal(prefs)
	if _, err = db.Exec("INSERT OR REPLACE INTO ItemTable (key, value) VALUES ('chatModelPickerPreferences', ?)", string(data)); err != nil {
		return err
	}

	return nil
}
|
||||
|
||||
// modelVSCodeIDs returns all possible VS Code picker IDs for a model name.
|
||||
func (v *VSCode) modelVSCodeIDs(model string, nameToID map[string]string) []string {
|
||||
var ids []string
|
||||
if id, ok := nameToID[model]; ok {
|
||||
ids = append(ids, id)
|
||||
} else if !strings.Contains(model, ":") {
|
||||
if id, ok := nameToID[model+":latest"]; ok {
|
||||
ids = append(ids, id)
|
||||
}
|
||||
}
|
||||
ids = append(ids, "ollama/Ollama/"+model)
|
||||
if !strings.Contains(model, ":") {
|
||||
ids = append(ids, "ollama/Ollama/"+model+":latest")
|
||||
}
|
||||
return ids
|
||||
}
|
||||
|
||||
func (v *VSCode) vscodePath(parts ...string) string {
|
||||
home, _ := os.UserHomeDir()
|
||||
var base string
|
||||
switch runtime.GOOS {
|
||||
case "darwin":
|
||||
base = filepath.Join(home, "Library", "Application Support", "Code", "User")
|
||||
case "windows":
|
||||
base = filepath.Join(os.Getenv("APPDATA"), "Code", "User")
|
||||
default:
|
||||
base = filepath.Join(home, ".config", "Code", "User")
|
||||
}
|
||||
return filepath.Join(append([]string{base}, parts...)...)
|
||||
}
|
||||
|
||||
// checkVSCodeVersion warns if VS Code is older than minVSCodeVersion.
|
||||
func (v *VSCode) checkVSCodeVersion() {
|
||||
codeCLI := v.findCodeCLI()
|
||||
if codeCLI == "" {
|
||||
return
|
||||
}
|
||||
|
||||
out, err := exec.Command(codeCLI, "--version").Output()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// "code --version" outputs: version\ncommit\narch
|
||||
lines := strings.Split(strings.TrimSpace(string(out)), "\n")
|
||||
if len(lines) == 0 || lines[0] == "" {
|
||||
return
|
||||
}
|
||||
version := strings.TrimSpace(lines[0])
|
||||
|
||||
if compareVersions(version, minVSCodeVersion) < 0 {
|
||||
fmt.Fprintf(os.Stderr, "\n%sWarning: VS Code version (%s) is older than the recommended version (%s)%s\n", ansiYellow, version, minVSCodeVersion, ansiReset)
|
||||
fmt.Fprintf(os.Stderr, "Please update VS Code to the latest version.\n\n")
|
||||
}
|
||||
}
|
||||
|
||||
// checkCopilotChatVersion warns if the GitHub Copilot Chat extension is
|
||||
// missing or older than minCopilotChatVersion.
|
||||
func (v *VSCode) checkCopilotChatVersion() {
|
||||
codeCLI := v.findCodeCLI()
|
||||
if codeCLI == "" {
|
||||
return
|
||||
}
|
||||
|
||||
out, err := exec.Command(codeCLI, "--list-extensions", "--show-versions").Output()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
installed, version := parseCopilotChatVersion(string(out))
|
||||
if !installed {
|
||||
fmt.Fprintf(os.Stderr, "\n%sWarning: GitHub Copilot Chat extension is not installed%s\n", ansiYellow, ansiReset)
|
||||
fmt.Fprintf(os.Stderr, "Install it in VS Code: Extensions → search \"GitHub Copilot Chat\" → Install\n\n")
|
||||
return
|
||||
}
|
||||
if compareVersions(version, minCopilotChatVersion) < 0 {
|
||||
fmt.Fprintf(os.Stderr, "\n%sWarning: GitHub Copilot Chat extension version (%s) is older than the recommended version (%s)%s\n", ansiYellow, version, minCopilotChatVersion, ansiReset)
|
||||
fmt.Fprintf(os.Stderr, "Please update it in VS Code: Extensions → search \"GitHub Copilot Chat\" → Update\n\n")
|
||||
}
|
||||
}
|
||||
|
||||
// findCodeCLI returns the path to the VS Code CLI for querying extensions.
|
||||
// On macOS, findBinary may return an .app bundle which can't run --list-extensions,
|
||||
// so this resolves to the actual CLI binary inside the bundle.
|
||||
func (v *VSCode) findCodeCLI() string {
|
||||
binary := v.findBinary()
|
||||
if binary == "" {
|
||||
return ""
|
||||
}
|
||||
if runtime.GOOS == "darwin" && strings.HasSuffix(binary, ".app") {
|
||||
bundleCLI := binary + "/Contents/Resources/app/bin/code"
|
||||
if _, err := os.Stat(bundleCLI); err == nil {
|
||||
return bundleCLI
|
||||
}
|
||||
return ""
|
||||
}
|
||||
return binary
|
||||
}
|
||||
|
||||
// parseCopilotChatVersion extracts the version of the GitHub Copilot Chat
// extension from "code --list-extensions --show-versions" output. The match
// on the extension ID is case-insensitive; the returned version is trimmed.
func parseCopilotChatVersion(output string) (installed bool, version string) {
	for _, line := range strings.Split(output, "\n") {
		// Each line has the form: github.copilot-chat@0.40.1
		id, ver, found := strings.Cut(line, "@")
		if !found {
			continue
		}
		if !strings.EqualFold(id, "github.copilot-chat") {
			continue
		}
		return true, strings.TrimSpace(ver)
	}
	return false, ""
}
|
||||
|
||||
// compareVersions compares two dot-separated version strings.
|
||||
// Returns -1 if a < b, 0 if a == b, 1 if a > b.
|
||||
func compareVersions(a, b string) int {
|
||||
aParts := strings.Split(a, ".")
|
||||
bParts := strings.Split(b, ".")
|
||||
|
||||
maxLen := len(aParts)
|
||||
if len(bParts) > maxLen {
|
||||
maxLen = len(bParts)
|
||||
}
|
||||
|
||||
for i := range maxLen {
|
||||
var aNum, bNum int
|
||||
if i < len(aParts) {
|
||||
aNum, _ = strconv.Atoi(aParts[i])
|
||||
}
|
||||
if i < len(bParts) {
|
||||
bNum, _ = strconv.Atoi(bParts[i])
|
||||
}
|
||||
if aNum < bNum {
|
||||
return -1
|
||||
}
|
||||
if aNum > bNum {
|
||||
return 1
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func fileExists(path string) bool {
|
||||
_, err := os.Stat(path)
|
||||
return err == nil
|
||||
}
|
||||
486
cmd/launch/vscode_test.go
Normal file
@@ -0,0 +1,486 @@
|
||||
package launch
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"testing"
|
||||
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
)
|
||||
|
||||
func TestVSCodeIntegration(t *testing.T) {
|
||||
v := &VSCode{}
|
||||
|
||||
t.Run("String", func(t *testing.T) {
|
||||
if got := v.String(); got != "Visual Studio Code" {
|
||||
t.Errorf("String() = %q, want %q", got, "Visual Studio Code")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("implements Runner", func(t *testing.T) {
|
||||
var _ Runner = v
|
||||
})
|
||||
|
||||
t.Run("implements Editor", func(t *testing.T) {
|
||||
var _ Editor = v
|
||||
})
|
||||
}
|
||||
|
||||
// TestVSCodeEdit exercises Edit against several starting states of
// chatLanguageModels.json: no file, other vendors present, a stale Ollama
// entry, an empty model list (must be a no-op), and corrupted JSON (must be
// treated as empty rather than failing).
func TestVSCodeEdit(t *testing.T) {
	v := &VSCode{}
	tmpDir := t.TempDir()
	setTestHome(t, tmpDir)
	t.Setenv("XDG_CONFIG_HOME", "")
	clmPath := testVSCodePath(t, tmpDir, "chatLanguageModels.json")

	tests := []struct {
		name     string
		setup    string // initial chatLanguageModels.json content, empty means no file
		models   []string
		validate func(t *testing.T, data []byte)
	}{
		{
			name:   "fresh install",
			models: []string{"llama3.2"},
			validate: func(t *testing.T, data []byte) {
				assertOllamaVendorConfigured(t, data)
			},
		},
		{
			name:   "preserve other vendor entries",
			setup:  `[{"vendor": "azure", "name": "Azure", "url": "https://example.com"}]`,
			models: []string{"llama3.2"},
			validate: func(t *testing.T, data []byte) {
				var entries []map[string]any
				json.Unmarshal(data, &entries)
				if len(entries) != 2 {
					t.Errorf("expected 2 entries, got %d", len(entries))
				}
				// Check Azure entry preserved
				found := false
				for _, e := range entries {
					if v, _ := e["vendor"].(string); v == "azure" {
						found = true
					}
				}
				if !found {
					t.Error("azure vendor entry was not preserved")
				}
				assertOllamaVendorConfigured(t, data)
			},
		},
		{
			name:   "update existing ollama entry",
			setup:  `[{"vendor": "ollama", "name": "Ollama", "url": "http://old:11434"}]`,
			models: []string{"llama3.2"},
			validate: func(t *testing.T, data []byte) {
				assertOllamaVendorConfigured(t, data)
			},
		},
		{
			name:   "empty models is no-op",
			setup:  `[{"vendor": "azure", "name": "Azure"}]`,
			models: []string{},
			validate: func(t *testing.T, data []byte) {
				// File must be byte-identical to the setup content.
				if string(data) != `[{"vendor": "azure", "name": "Azure"}]` {
					t.Error("empty models should not modify file")
				}
			},
		},
		{
			name:   "corrupted JSON treated as empty",
			setup:  `{corrupted json`,
			models: []string{"llama3.2"},
			validate: func(t *testing.T, data []byte) {
				var entries []map[string]any
				if err := json.Unmarshal(data, &entries); err != nil {
					t.Errorf("result is not valid JSON: %v", err)
				}
			},
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// Start each case from a clean config directory.
			os.RemoveAll(filepath.Dir(clmPath))

			if tt.setup != "" {
				os.MkdirAll(filepath.Dir(clmPath), 0o755)
				os.WriteFile(clmPath, []byte(tt.setup), 0o644)
			}

			if err := v.Edit(tt.models); err != nil {
				t.Fatal(err)
			}

			data, _ := os.ReadFile(clmPath)
			tt.validate(t, data)
		})
	}
}
|
||||
|
||||
// TestVSCodeEditCleansUpOldSettings verifies that Edit removes the legacy
// Ollama keys from settings.json while leaving unrelated user settings alone.
func TestVSCodeEditCleansUpOldSettings(t *testing.T) {
	v := &VSCode{}
	tmpDir := t.TempDir()
	setTestHome(t, tmpDir)
	t.Setenv("XDG_CONFIG_HOME", "")
	settingsPath := testVSCodePath(t, tmpDir, "settings.json")

	// Create settings.json with old byok setting
	os.MkdirAll(filepath.Dir(settingsPath), 0o755)
	os.WriteFile(settingsPath, []byte(`{"github.copilot.chat.byok.ollamaEndpoint": "http://old:11434", "ollama.launch.configured": true, "editor.fontSize": 14}`), 0o644)

	if err := v.Edit([]string{"llama3.2"}); err != nil {
		t.Fatal(err)
	}

	// Verify old settings were removed
	data, err := os.ReadFile(settingsPath)
	if err != nil {
		t.Fatal(err)
	}

	var settings map[string]any
	json.Unmarshal(data, &settings)
	if _, ok := settings["github.copilot.chat.byok.ollamaEndpoint"]; ok {
		t.Error("github.copilot.chat.byok.ollamaEndpoint should have been removed")
	}
	if _, ok := settings["ollama.launch.configured"]; ok {
		t.Error("ollama.launch.configured should have been removed")
	}
	// JSON numbers decode as float64, hence the float64(14) comparison.
	if settings["editor.fontSize"] != float64(14) {
		t.Error("editor.fontSize should have been preserved")
	}
}
|
||||
|
||||
func TestVSCodePaths(t *testing.T) {
|
||||
v := &VSCode{}
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
t.Setenv("XDG_CONFIG_HOME", "")
|
||||
clmPath := testVSCodePath(t, tmpDir, "chatLanguageModels.json")
|
||||
|
||||
t.Run("no file returns nil", func(t *testing.T) {
|
||||
os.Remove(clmPath)
|
||||
if paths := v.Paths(); paths != nil {
|
||||
t.Errorf("expected nil, got %v", paths)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("existing file returns path", func(t *testing.T) {
|
||||
os.MkdirAll(filepath.Dir(clmPath), 0o755)
|
||||
os.WriteFile(clmPath, []byte(`[]`), 0o644)
|
||||
|
||||
if paths := v.Paths(); len(paths) != 1 {
|
||||
t.Errorf("expected 1 path, got %d", len(paths))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// testVSCodePath returns the expected VS Code config path for the given file in tests.
// On Windows it also redirects APPDATA to tmpDir so vscodePath resolves under
// the test directory instead of the real user profile.
func testVSCodePath(t *testing.T, tmpDir, filename string) string {
	t.Helper()
	switch runtime.GOOS {
	case "darwin":
		return filepath.Join(tmpDir, "Library", "Application Support", "Code", "User", filename)
	case "windows":
		t.Setenv("APPDATA", tmpDir)
		return filepath.Join(tmpDir, "Code", "User", filename)
	default:
		return filepath.Join(tmpDir, ".config", "Code", "User", filename)
	}
}
|
||||
|
||||
// assertOllamaVendorConfigured fails the test unless data is valid JSON
// containing an Ollama vendor entry named "Ollama" with a non-empty URL.
func assertOllamaVendorConfigured(t *testing.T, data []byte) {
	t.Helper()
	var entries []map[string]any
	if err := json.Unmarshal(data, &entries); err != nil {
		t.Fatalf("invalid JSON: %v", err)
	}

	for _, entry := range entries {
		vendor, _ := entry["vendor"].(string)
		if vendor != "ollama" {
			continue
		}
		if name, _ := entry["name"].(string); name != "Ollama" {
			t.Errorf("expected name %q, got %q", "Ollama", name)
		}
		if url, _ := entry["url"].(string); url == "" {
			t.Error("url not set")
		}
		return
	}
	t.Error("no ollama vendor entry found")
}
|
||||
|
||||
// TestShowInModelPicker exercises ShowInModelPicker against a real SQLite
// state.vscdb: fresh-DB creation, showing configured models, hiding removed
// Ollama models, preserving non-Ollama picker prefs, resolving cached numeric
// IDs, the empty-input no-op, and re-showing previously hidden models.
func TestShowInModelPicker(t *testing.T) {
	v := &VSCode{}

	// helper to create a state DB with optional seed data
	setupDB := func(t *testing.T, tmpDir string, seedPrefs map[string]bool, seedCache []map[string]any) string {
		t.Helper()
		dbDir := filepath.Join(tmpDir, "globalStorage")
		os.MkdirAll(dbDir, 0o755)
		dbPath := filepath.Join(dbDir, "state.vscdb")

		db, err := sql.Open("sqlite3", dbPath)
		if err != nil {
			t.Fatal(err)
		}
		defer db.Close()

		// Same schema VS Code itself creates.
		if _, err := db.Exec("CREATE TABLE ItemTable (key TEXT UNIQUE ON CONFLICT REPLACE, value BLOB)"); err != nil {
			t.Fatal(err)
		}
		if seedPrefs != nil {
			data, _ := json.Marshal(seedPrefs)
			db.Exec("INSERT INTO ItemTable (key, value) VALUES ('chatModelPickerPreferences', ?)", string(data))
		}
		if seedCache != nil {
			data, _ := json.Marshal(seedCache)
			db.Exec("INSERT INTO ItemTable (key, value) VALUES ('chat.cachedLanguageModels.v2', ?)", string(data))
		}
		return dbPath
	}

	// helper to read prefs back from DB
	readPrefs := func(t *testing.T, dbPath string) map[string]bool {
		t.Helper()
		db, err := sql.Open("sqlite3", dbPath)
		if err != nil {
			t.Fatal(err)
		}
		defer db.Close()

		var raw string
		if err := db.QueryRow("SELECT value FROM ItemTable WHERE key = 'chatModelPickerPreferences'").Scan(&raw); err != nil {
			t.Fatal(err)
		}
		prefs := make(map[string]bool)
		json.Unmarshal([]byte(raw), &prefs)
		return prefs
	}

	t.Run("fresh DB creates table and shows models", func(t *testing.T) {
		tmpDir := t.TempDir()
		setTestHome(t, tmpDir)
		t.Setenv("XDG_CONFIG_HOME", "")
		if runtime.GOOS == "windows" {
			t.Setenv("APPDATA", tmpDir)
		}

		err := v.ShowInModelPicker([]string{"llama3.2"})
		if err != nil {
			t.Fatal(err)
		}

		dbPath := testVSCodePath(t, tmpDir, filepath.Join("globalStorage", "state.vscdb"))
		prefs := readPrefs(t, dbPath)
		if !prefs["ollama/Ollama/llama3.2"] {
			t.Error("expected llama3.2 to be shown")
		}
		// Untagged names also get an implicit ":latest" picker entry.
		if !prefs["ollama/Ollama/llama3.2:latest"] {
			t.Error("expected llama3.2:latest to be shown")
		}
	})

	t.Run("configured models are shown", func(t *testing.T) {
		tmpDir := t.TempDir()
		setTestHome(t, tmpDir)
		t.Setenv("XDG_CONFIG_HOME", "")
		dbPath := setupDB(t, testVSCodePath(t, tmpDir, ""), nil, nil)

		err := v.ShowInModelPicker([]string{"llama3.2", "qwen3:8b"})
		if err != nil {
			t.Fatal(err)
		}

		prefs := readPrefs(t, dbPath)
		if !prefs["ollama/Ollama/llama3.2"] {
			t.Error("expected llama3.2 to be shown")
		}
		if !prefs["ollama/Ollama/qwen3:8b"] {
			t.Error("expected qwen3:8b to be shown")
		}
	})

	t.Run("removed models are hidden", func(t *testing.T) {
		tmpDir := t.TempDir()
		setTestHome(t, tmpDir)
		t.Setenv("XDG_CONFIG_HOME", "")
		dbPath := setupDB(t, testVSCodePath(t, tmpDir, ""), map[string]bool{
			"ollama/Ollama/llama3.2":        true,
			"ollama/Ollama/llama3.2:latest": true,
			"ollama/Ollama/mistral":         true,
			"ollama/Ollama/mistral:latest":  true,
		}, nil)

		// Only configure llama3.2 — mistral should get hidden
		err := v.ShowInModelPicker([]string{"llama3.2"})
		if err != nil {
			t.Fatal(err)
		}

		prefs := readPrefs(t, dbPath)
		if !prefs["ollama/Ollama/llama3.2"] {
			t.Error("expected llama3.2 to stay shown")
		}
		if prefs["ollama/Ollama/mistral"] {
			t.Error("expected mistral to be hidden")
		}
		if prefs["ollama/Ollama/mistral:latest"] {
			t.Error("expected mistral:latest to be hidden")
		}
	})

	t.Run("non-ollama prefs are preserved", func(t *testing.T) {
		tmpDir := t.TempDir()
		setTestHome(t, tmpDir)
		t.Setenv("XDG_CONFIG_HOME", "")
		dbPath := setupDB(t, testVSCodePath(t, tmpDir, ""), map[string]bool{
			"copilot/gpt-4o": true,
		}, nil)

		err := v.ShowInModelPicker([]string{"llama3.2"})
		if err != nil {
			t.Fatal(err)
		}

		prefs := readPrefs(t, dbPath)
		if !prefs["copilot/gpt-4o"] {
			t.Error("expected copilot/gpt-4o to stay shown")
		}
	})

	t.Run("uses cached numeric IDs when available", func(t *testing.T) {
		tmpDir := t.TempDir()
		setTestHome(t, tmpDir)
		t.Setenv("XDG_CONFIG_HOME", "")
		cache := []map[string]any{
			{
				"identifier": "ollama/Ollama/4",
				"metadata":   map[string]any{"vendor": "ollama", "name": "llama3.2"},
			},
		}
		dbPath := setupDB(t, testVSCodePath(t, tmpDir, ""), nil, cache)

		err := v.ShowInModelPicker([]string{"llama3.2"})
		if err != nil {
			t.Fatal(err)
		}

		prefs := readPrefs(t, dbPath)
		if !prefs["ollama/Ollama/4"] {
			t.Error("expected numeric ID ollama/Ollama/4 to be shown")
		}
		// Name-based fallback should also be set
		if !prefs["ollama/Ollama/llama3.2"] {
			t.Error("expected name-based ID to also be shown")
		}
	})

	t.Run("empty models is no-op", func(t *testing.T) {
		err := v.ShowInModelPicker([]string{})
		if err != nil {
			t.Fatal(err)
		}
	})

	t.Run("previously hidden model is re-shown when configured", func(t *testing.T) {
		tmpDir := t.TempDir()
		setTestHome(t, tmpDir)
		t.Setenv("XDG_CONFIG_HOME", "")
		dbPath := setupDB(t, testVSCodePath(t, tmpDir, ""), map[string]bool{
			"ollama/Ollama/llama3.2":        false,
			"ollama/Ollama/llama3.2:latest": false,
		}, nil)

		// Ollama config is authoritative — should override the hidden state
		err := v.ShowInModelPicker([]string{"llama3.2"})
		if err != nil {
			t.Fatal(err)
		}

		prefs := readPrefs(t, dbPath)
		if !prefs["ollama/Ollama/llama3.2"] {
			t.Error("expected llama3.2 to be re-shown")
		}
	})
}
|
||||
|
||||
// TestParseCopilotChatVersion covers matching among other extensions, exact
// single-line output, missing extension, empty output, and case-insensitive
// extension IDs.
func TestParseCopilotChatVersion(t *testing.T) {
	tests := []struct {
		name          string
		output        string
		wantInstalled bool
		wantVersion   string
	}{
		{
			name:          "found among other extensions",
			output:        "ms-python.python@2024.1.1\ngithub.copilot-chat@0.40.1\ngithub.copilot@1.200.0\n",
			wantInstalled: true,
			wantVersion:   "0.40.1",
		},
		{
			name:          "only extension",
			output:        "GitHub.copilot-chat@0.41.0\n",
			wantInstalled: true,
			wantVersion:   "0.41.0",
		},
		{
			name:          "not installed",
			output:        "ms-python.python@2024.1.1\ngithub.copilot@1.200.0\n",
			wantInstalled: false,
		},
		{
			name:          "empty output",
			output:        "",
			wantInstalled: false,
		},
		{
			name:          "case insensitive match",
			output:        "GitHub.Copilot-Chat@0.39.0\n",
			wantInstalled: true,
			wantVersion:   "0.39.0",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			installed, version := parseCopilotChatVersion(tt.output)
			if installed != tt.wantInstalled {
				t.Errorf("installed = %v, want %v", installed, tt.wantInstalled)
			}
			// Version is only meaningful when the extension was found.
			if installed && version != tt.wantVersion {
				t.Errorf("version = %q, want %q", version, tt.wantVersion)
			}
		})
	}
}
|
||||
|
||||
// TestCompareVersions covers equality, per-component ordering, and
// mismatched component counts (shorter versions compare as zero-padded).
func TestCompareVersions(t *testing.T) {
	tests := []struct {
		a, b string
		want int
	}{
		{"0.40.1", "0.40.1", 0},
		{"0.40.2", "0.40.1", 1},
		{"0.40.0", "0.40.1", -1},
		{"0.41.0", "0.40.1", 1},
		{"0.39.9", "0.40.1", -1},
		{"1.0.0", "0.40.1", 1},
		{"0.40", "0.40.1", -1},
		{"0.40.1.1", "0.40.1", 1},
	}

	for _, tt := range tests {
		t.Run(tt.a+"_vs_"+tt.b, func(t *testing.T) {
			got := compareVersions(tt.a, tt.b)
			if got != tt.want {
				t.Errorf("compareVersions(%q, %q) = %d, want %d", tt.a, tt.b, got, tt.want)
			}
		})
	}
}
|
||||
@@ -7,7 +7,7 @@ import (
|
||||
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
"github.com/charmbracelet/lipgloss"
|
||||
"github.com/ollama/ollama/cmd/config"
|
||||
"github.com/ollama/ollama/cmd/launch"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -64,8 +64,8 @@ type SelectItem struct {
|
||||
Recommended bool
|
||||
}
|
||||
|
||||
// ConvertItems converts config.ModelItem slice to SelectItem slice.
|
||||
func ConvertItems(items []config.ModelItem) []SelectItem {
|
||||
// ConvertItems converts launch.ModelItem slice to SelectItem slice.
|
||||
func ConvertItems(items []launch.ModelItem) []SelectItem {
|
||||
out := make([]SelectItem, len(items))
|
||||
for i, item := range items {
|
||||
out[i] = SelectItem{Name: item.Name, Description: item.Description, Recommended: item.Recommended}
|
||||
@@ -101,6 +101,16 @@ type selectorModel struct {
|
||||
width int
|
||||
}
|
||||
|
||||
func selectorModelWithCurrent(title string, items []SelectItem, current string) selectorModel {
|
||||
m := selectorModel{
|
||||
title: title,
|
||||
items: items,
|
||||
cursor: cursorForCurrent(items, current),
|
||||
}
|
||||
m.updateScroll(m.otherStart())
|
||||
return m
|
||||
}
|
||||
|
||||
func (m selectorModel) filteredItems() []SelectItem {
|
||||
if m.filter == "" {
|
||||
return m.items
|
||||
@@ -232,6 +242,10 @@ func (m selectorModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
m.cancelled = true
|
||||
return m, tea.Quit
|
||||
|
||||
case tea.KeyLeft:
|
||||
m.cancelled = true
|
||||
return m, tea.Quit
|
||||
|
||||
case tea.KeyEnter:
|
||||
filtered := m.filteredItems()
|
||||
if len(filtered) > 0 && m.cursor < len(filtered) {
|
||||
@@ -344,7 +358,7 @@ func (m selectorModel) renderContent() string {
|
||||
}
|
||||
|
||||
s.WriteString("\n")
|
||||
help := "↑/↓ navigate • enter select • esc cancel"
|
||||
help := "↑/↓ navigate • enter select • ← back"
|
||||
if m.helpText != "" {
|
||||
help = m.helpText
|
||||
}
|
||||
@@ -367,13 +381,24 @@ func (m selectorModel) View() string {
|
||||
|
||||
// cursorForCurrent returns the item index matching current, or 0 if not found.
|
||||
func cursorForCurrent(items []SelectItem, current string) int {
|
||||
if current != "" {
|
||||
for i, item := range items {
|
||||
if item.Name == current || strings.HasPrefix(item.Name, current+":") || strings.HasPrefix(current, item.Name+":") {
|
||||
return i
|
||||
}
|
||||
if current == "" {
|
||||
return 0
|
||||
}
|
||||
|
||||
// Prefer exact name matches before tag-prefix fallback so "qwen3.5" does not
|
||||
// incorrectly select "qwen3.5:cloud" (and vice versa) based on list order.
|
||||
for i, item := range items {
|
||||
if item.Name == current {
|
||||
return i
|
||||
}
|
||||
}
|
||||
|
||||
for i, item := range items {
|
||||
if strings.HasPrefix(item.Name, current+":") || strings.HasPrefix(current, item.Name+":") {
|
||||
return i
|
||||
}
|
||||
}
|
||||
|
||||
return 0
|
||||
}
|
||||
|
||||
@@ -382,11 +407,7 @@ func SelectSingle(title string, items []SelectItem, current string) (string, err
|
||||
return "", fmt.Errorf("no items to select from")
|
||||
}
|
||||
|
||||
m := selectorModel{
|
||||
title: title,
|
||||
items: items,
|
||||
cursor: cursorForCurrent(items, current),
|
||||
}
|
||||
m := selectorModelWithCurrent(title, items, current)
|
||||
|
||||
p := tea.NewProgram(m)
|
||||
finalModel, err := p.Run()
|
||||
@@ -523,6 +544,7 @@ func (m *multiSelectorModel) toggleItem() {
|
||||
origIdx := m.itemIndex[item.Name]
|
||||
|
||||
if m.checked[origIdx] {
|
||||
wasDefault := len(m.checkOrder) > 0 && m.checkOrder[len(m.checkOrder)-1] == origIdx
|
||||
delete(m.checked, origIdx)
|
||||
for i, idx := range m.checkOrder {
|
||||
if idx == origIdx {
|
||||
@@ -530,6 +552,34 @@ func (m *multiSelectorModel) toggleItem() {
|
||||
break
|
||||
}
|
||||
}
|
||||
if wasDefault {
|
||||
// When removing the default, pick the nearest checked model above it
|
||||
// (or below if none above) so default fallback follows list order.
|
||||
newDefault := -1
|
||||
for i := origIdx - 1; i >= 0; i-- {
|
||||
if m.checked[i] {
|
||||
newDefault = i
|
||||
break
|
||||
}
|
||||
}
|
||||
if newDefault == -1 {
|
||||
for i := origIdx + 1; i < len(m.items); i++ {
|
||||
if m.checked[i] {
|
||||
newDefault = i
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
if newDefault != -1 {
|
||||
for i, idx := range m.checkOrder {
|
||||
if idx == newDefault {
|
||||
m.checkOrder = append(m.checkOrder[:i], m.checkOrder[i+1:]...)
|
||||
break
|
||||
}
|
||||
}
|
||||
m.checkOrder = append(m.checkOrder, newDefault)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
m.checked[origIdx] = true
|
||||
m.checkOrder = append(m.checkOrder, origIdx)
|
||||
@@ -562,6 +612,10 @@ func (m multiSelectorModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
m.cancelled = true
|
||||
return m, tea.Quit
|
||||
|
||||
case tea.KeyLeft:
|
||||
m.cancelled = true
|
||||
return m, tea.Quit
|
||||
|
||||
case tea.KeyTab:
|
||||
m.multi = !m.multi
|
||||
|
||||
@@ -764,7 +818,7 @@ func (m multiSelectorModel) View() string {
|
||||
s.WriteString("\n")
|
||||
|
||||
if !m.multi {
|
||||
s.WriteString(selectorHelpStyle.Render("↑/↓ navigate • enter select • tab add multiple • esc cancel"))
|
||||
s.WriteString(selectorHelpStyle.Render("↑/↓ navigate • enter select • tab add multiple • ← back"))
|
||||
} else {
|
||||
count := m.selectedCount()
|
||||
if count == 0 {
|
||||
@@ -773,7 +827,7 @@ func (m multiSelectorModel) View() string {
|
||||
s.WriteString(selectorDescStyle.Render(fmt.Sprintf(" %d selected - press enter to continue", count)))
|
||||
}
|
||||
s.WriteString("\n\n")
|
||||
s.WriteString(selectorHelpStyle.Render("↑/↓ navigate • space toggle • tab select single • enter confirm • esc cancel"))
|
||||
s.WriteString(selectorHelpStyle.Render("↑/↓ navigate • space toggle • tab select single • enter confirm • ← back"))
|
||||
}
|
||||
|
||||
result := s.String()
|
||||
|
||||
@@ -216,6 +216,41 @@ func TestUpdateScroll(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestSelectorModelWithCurrent_ScrollsToCurrentInMoreSection(t *testing.T) {
|
||||
m := selectorModelWithCurrent("Pick:", mixedItems(), "other-10")
|
||||
|
||||
if m.cursor != 11 {
|
||||
t.Fatalf("cursor = %d, want 11", m.cursor)
|
||||
}
|
||||
if m.scrollOffset == 0 {
|
||||
t.Fatal("scrollOffset should move to reveal current item in More section")
|
||||
}
|
||||
|
||||
content := m.renderContent()
|
||||
if !strings.Contains(content, "▸ other-10") {
|
||||
t.Fatalf("expected current item to be visible and highlighted\n%s", content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSelectorModelWithCurrent_HighlightsExactLocalWhenCloudVariantExists(t *testing.T) {
|
||||
m := selectorModelWithCurrent("Pick:", []SelectItem{
|
||||
{Name: "qwen3.5:cloud", Recommended: true},
|
||||
{Name: "qwen3.5", Recommended: true},
|
||||
}, "qwen3.5")
|
||||
|
||||
if m.cursor != 1 {
|
||||
t.Fatalf("cursor = %d, want 1", m.cursor)
|
||||
}
|
||||
|
||||
content := m.renderContent()
|
||||
if !strings.Contains(content, "▸ qwen3.5") {
|
||||
t.Fatalf("expected local qwen3.5 to be highlighted\n%s", content)
|
||||
}
|
||||
if strings.Contains(content, "▸ qwen3.5:cloud") {
|
||||
t.Fatalf("did not expect cloud qwen3.5:cloud to be highlighted\n%s", content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRenderContent_SectionHeaders(t *testing.T) {
|
||||
m := selectorModel{
|
||||
title: "Pick:",
|
||||
@@ -418,6 +453,28 @@ func TestCursorForCurrent(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestCursorForCurrent_PrefersExactLocalOverCloudPrefix(t *testing.T) {
|
||||
testItems := []SelectItem{
|
||||
{Name: "qwen3.5:cloud", Recommended: true},
|
||||
{Name: "qwen3.5", Recommended: true},
|
||||
}
|
||||
|
||||
if got := cursorForCurrent(testItems, "qwen3.5"); got != 1 {
|
||||
t.Errorf("cursorForCurrent(%q) = %d, want %d", "qwen3.5", got, 1)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCursorForCurrent_PrefersExactCloudOverLocalPrefix(t *testing.T) {
|
||||
testItems := []SelectItem{
|
||||
{Name: "qwen3.5", Recommended: true},
|
||||
{Name: "qwen3.5:cloud", Recommended: true},
|
||||
}
|
||||
|
||||
if got := cursorForCurrent(testItems, "qwen3.5:cloud"); got != 1 {
|
||||
t.Errorf("cursorForCurrent(%q) = %d, want %d", "qwen3.5:cloud", got, 1)
|
||||
}
|
||||
}
|
||||
|
||||
// --- ReorderItems ---
|
||||
|
||||
func TestReorderItems(t *testing.T) {
|
||||
@@ -725,6 +782,9 @@ func TestMulti_MultiModeHelpText(t *testing.T) {
|
||||
if !strings.Contains(content, "tab select single") {
|
||||
t.Error("multi mode should show 'tab select single' in help")
|
||||
}
|
||||
if !strings.Contains(content, "← back") {
|
||||
t.Error("multi mode should show '← back' in help")
|
||||
}
|
||||
}
|
||||
|
||||
// --- preChecked initialization order ---
|
||||
@@ -783,6 +843,74 @@ func TestMulti_LastCheckedIsDefault(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestMulti_UncheckingDefaultFallsBackToNearestCheckedAbove(t *testing.T) {
|
||||
// Default is "b", and checked models are "a", "b", "c".
|
||||
// Unticking default should make "a" (the nearest checked item above) default.
|
||||
m := newMultiSelectorModel("Pick:", items("a", "b", "c"), []string{"b", "c", "a"})
|
||||
m.multi = true
|
||||
m.cursor = 1 // "b"
|
||||
m.toggleItem()
|
||||
|
||||
lastIdx := m.checkOrder[len(m.checkOrder)-1]
|
||||
if m.items[lastIdx].Name != "a" {
|
||||
t.Fatalf("expected default to fall back to 'a', got %q", m.items[lastIdx].Name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMulti_UncheckingTopDefaultFallsBackToNearestCheckedBelow(t *testing.T) {
|
||||
// Default is top item "a". With no checked item above, fallback should pick
|
||||
// the nearest checked item below ("b").
|
||||
m := newMultiSelectorModel("Pick:", items("a", "b", "c"), []string{"a", "c", "b"})
|
||||
m.multi = true
|
||||
m.cursor = 0 // "a"
|
||||
m.toggleItem()
|
||||
|
||||
lastIdx := m.checkOrder[len(m.checkOrder)-1]
|
||||
if m.items[lastIdx].Name != "b" {
|
||||
t.Fatalf("expected default to fall back to 'b', got %q", m.items[lastIdx].Name)
|
||||
}
|
||||
}
|
||||
|
||||
// --- Left arrow back navigation ---
|
||||
|
||||
func TestSelectorLeftArrowCancelsWhenNoFilter(t *testing.T) {
|
||||
m := selectorModelWithCurrent("Pick:", items("a", "b", "c"), "")
|
||||
updated, _ := m.Update(tea.KeyMsg{Type: tea.KeyLeft})
|
||||
got := updated.(selectorModel)
|
||||
if !got.cancelled {
|
||||
t.Error("left arrow with empty filter should cancel (go back)")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSelectorLeftArrowCancelsWhenFiltering(t *testing.T) {
|
||||
m := selectorModelWithCurrent("Pick:", items("a", "b", "c"), "")
|
||||
m.filter = "a"
|
||||
updated, _ := m.Update(tea.KeyMsg{Type: tea.KeyLeft})
|
||||
got := updated.(selectorModel)
|
||||
if !got.cancelled {
|
||||
t.Error("left arrow with active filter should still cancel (go back)")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMultiSelectorLeftArrowCancelsWhenNoFilter(t *testing.T) {
|
||||
m := newMultiSelectorModel("Pick:", items("a", "b", "c"), nil)
|
||||
updated, _ := m.Update(tea.KeyMsg{Type: tea.KeyLeft})
|
||||
got := updated.(multiSelectorModel)
|
||||
if !got.cancelled {
|
||||
t.Error("left arrow with empty filter should cancel (go back)")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMultiSelectorLeftArrowCancelsWhenFiltering(t *testing.T) {
|
||||
m := newMultiSelectorModel("Pick:", items("a", "b", "c"), nil)
|
||||
m.filter = "a"
|
||||
updated, _ := m.Update(tea.KeyMsg{Type: tea.KeyLeft})
|
||||
got := updated.(multiSelectorModel)
|
||||
if !got.cancelled {
|
||||
t.Error("left arrow with active filter should still cancel (go back)")
|
||||
}
|
||||
}
|
||||
|
||||
// Key message helpers for testing
|
||||
|
||||
type keyType = int
|
||||
|
||||
@@ -1,15 +1,24 @@
|
||||
package tui
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
"github.com/charmbracelet/lipgloss"
|
||||
"github.com/ollama/ollama/cmd/config"
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/cmd/launch"
|
||||
)
|
||||
|
||||
type signInTickMsg struct{}
|
||||
|
||||
type signInCheckMsg struct {
|
||||
signedIn bool
|
||||
userName string
|
||||
}
|
||||
|
||||
type signInModel struct {
|
||||
modelName string
|
||||
signInURL string
|
||||
@@ -88,11 +97,8 @@ func renderSignIn(modelName, signInURL string, spinner, width int) string {
|
||||
|
||||
fmt.Fprintf(&s, "To use %s, please sign in.\n\n", selectorSelectedItemStyle.Render(modelName))
|
||||
|
||||
// Wrap in OSC 8 hyperlink so the entire URL is clickable even when wrapped.
|
||||
// Padding is outside the hyperlink so spaces don't get underlined.
|
||||
link := fmt.Sprintf("\033]8;;%s\033\\%s\033]8;;\033\\", signInURL, urlColor.Render(signInURL))
|
||||
s.WriteString("Navigate to:\n")
|
||||
s.WriteString(urlWrap.Render(link))
|
||||
s.WriteString(urlWrap.Render(urlColor.Render(signInURL)))
|
||||
s.WriteString("\n\n")
|
||||
|
||||
s.WriteString(lipgloss.NewStyle().Foreground(lipgloss.AdaptiveColor{Light: "242", Dark: "246"}).Render(
|
||||
@@ -104,9 +110,21 @@ func renderSignIn(modelName, signInURL string, spinner, width int) string {
|
||||
return lipgloss.NewStyle().PaddingLeft(2).Render(s.String())
|
||||
}
|
||||
|
||||
func checkSignIn() tea.Msg {
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
return signInCheckMsg{signedIn: false}
|
||||
}
|
||||
user, err := client.Whoami(context.Background())
|
||||
if err == nil && user != nil && user.Name != "" {
|
||||
return signInCheckMsg{signedIn: true, userName: user.Name}
|
||||
}
|
||||
return signInCheckMsg{signedIn: false}
|
||||
}
|
||||
|
||||
// RunSignIn shows a bubbletea sign-in dialog and polls until the user signs in or cancels.
|
||||
func RunSignIn(modelName, signInURL string) (string, error) {
|
||||
config.OpenBrowser(signInURL)
|
||||
launch.OpenBrowser(signInURL)
|
||||
|
||||
m := signInModel{
|
||||
modelName: modelName,
|
||||
|
||||
@@ -25,22 +25,6 @@ func TestRenderSignIn_ContainsURL(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestRenderSignIn_OSC8Hyperlink(t *testing.T) {
|
||||
url := "https://ollama.com/connect?key=abc123"
|
||||
got := renderSignIn("test:cloud", url, 0, 120)
|
||||
|
||||
// Should contain OSC 8 open sequence with the URL
|
||||
osc8Open := "\033]8;;" + url + "\033\\"
|
||||
if !strings.Contains(got, osc8Open) {
|
||||
t.Error("should contain OSC 8 open sequence with URL")
|
||||
}
|
||||
|
||||
// Should contain OSC 8 close sequence
|
||||
osc8Close := "\033]8;;\033\\"
|
||||
if !strings.Contains(got, osc8Close) {
|
||||
t.Error("should contain OSC 8 close sequence")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRenderSignIn_ContainsSpinner(t *testing.T) {
|
||||
got := renderSignIn("test:cloud", "https://example.com", 0, 80)
|
||||
|
||||
841
cmd/tui/tui.go
@@ -1,16 +1,11 @@
|
||||
package tui
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
"github.com/charmbracelet/lipgloss"
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/cmd/config"
|
||||
"github.com/ollama/ollama/cmd/launch"
|
||||
"github.com/ollama/ollama/version"
|
||||
)
|
||||
|
||||
@@ -45,30 +40,24 @@ var (
|
||||
type menuItem struct {
|
||||
title string
|
||||
description string
|
||||
integration string // integration name for loading model config, empty if not an integration
|
||||
integration string
|
||||
isRunModel bool
|
||||
isOthers bool
|
||||
}
|
||||
|
||||
var mainMenuItems = []menuItem{
|
||||
{
|
||||
title: "Run a model",
|
||||
title: "Chat with a model",
|
||||
description: "Start an interactive chat with a model",
|
||||
isRunModel: true,
|
||||
},
|
||||
{
|
||||
title: "Launch Claude Code",
|
||||
description: "Agentic coding across large codebases",
|
||||
integration: "claude",
|
||||
},
|
||||
{
|
||||
title: "Launch Codex",
|
||||
description: "OpenAI's open-source coding agent",
|
||||
integration: "codex",
|
||||
},
|
||||
{
|
||||
title: "Launch OpenClaw",
|
||||
description: "Personal AI with 100+ skills",
|
||||
integration: "openclaw",
|
||||
},
|
||||
}
|
||||
@@ -79,277 +68,106 @@ var othersMenuItem = menuItem{
|
||||
isOthers: true,
|
||||
}
|
||||
|
||||
// getOtherIntegrations dynamically builds the "Others" list from the integration
|
||||
// registry, excluding any integrations already present in the pinned mainMenuItems.
|
||||
func getOtherIntegrations() []menuItem {
|
||||
pinned := map[string]bool{
|
||||
"run": true, // not an integration but in the pinned list
|
||||
type model struct {
|
||||
state *launch.LauncherState
|
||||
items []menuItem
|
||||
cursor int
|
||||
showOthers bool
|
||||
width int
|
||||
quitting bool
|
||||
selected bool
|
||||
action TUIAction
|
||||
}
|
||||
|
||||
func newModel(state *launch.LauncherState) model {
|
||||
m := model{
|
||||
state: state,
|
||||
}
|
||||
m.showOthers = shouldExpandOthers(state)
|
||||
m.items = buildMenuItems(state, m.showOthers)
|
||||
m.cursor = initialCursor(state, m.items)
|
||||
return m
|
||||
}
|
||||
|
||||
func shouldExpandOthers(state *launch.LauncherState) bool {
|
||||
if state == nil {
|
||||
return false
|
||||
}
|
||||
for _, item := range otherIntegrationItems(state) {
|
||||
if item.integration == state.LastSelection {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func buildMenuItems(state *launch.LauncherState, showOthers bool) []menuItem {
|
||||
items := make([]menuItem, 0, len(mainMenuItems)+1)
|
||||
for _, item := range mainMenuItems {
|
||||
if item.integration != "" {
|
||||
pinned[item.integration] = true
|
||||
if item.integration == "" {
|
||||
items = append(items, item)
|
||||
continue
|
||||
}
|
||||
if integrationState, ok := state.Integrations[item.integration]; ok {
|
||||
items = append(items, integrationMenuItem(integrationState))
|
||||
}
|
||||
}
|
||||
|
||||
var others []menuItem
|
||||
for _, info := range config.ListIntegrationInfos() {
|
||||
if showOthers {
|
||||
items = append(items, otherIntegrationItems(state)...)
|
||||
} else {
|
||||
items = append(items, othersMenuItem)
|
||||
}
|
||||
|
||||
return items
|
||||
}
|
||||
|
||||
func integrationMenuItem(state launch.LauncherIntegrationState) menuItem {
|
||||
description := state.Description
|
||||
if description == "" {
|
||||
description = "Open " + state.DisplayName + " integration"
|
||||
}
|
||||
return menuItem{
|
||||
title: "Launch " + state.DisplayName,
|
||||
description: description,
|
||||
integration: state.Name,
|
||||
}
|
||||
}
|
||||
|
||||
func otherIntegrationItems(state *launch.LauncherState) []menuItem {
|
||||
pinned := map[string]bool{
|
||||
"claude": true,
|
||||
"codex": true,
|
||||
"openclaw": true,
|
||||
}
|
||||
|
||||
var items []menuItem
|
||||
for _, info := range launch.ListIntegrationInfos() {
|
||||
if pinned[info.Name] {
|
||||
continue
|
||||
}
|
||||
desc := info.Description
|
||||
if desc == "" {
|
||||
desc = "Open " + info.DisplayName + " integration"
|
||||
}
|
||||
others = append(others, menuItem{
|
||||
title: "Launch " + info.DisplayName,
|
||||
description: desc,
|
||||
integration: info.Name,
|
||||
})
|
||||
}
|
||||
return others
|
||||
}
|
||||
|
||||
type model struct {
|
||||
items []menuItem
|
||||
cursor int
|
||||
quitting bool
|
||||
selected bool
|
||||
changeModel bool
|
||||
changeModels []string // multi-select result for Editor integrations
|
||||
showOthers bool
|
||||
availableModels map[string]bool
|
||||
err error
|
||||
|
||||
showingModal bool
|
||||
modalSelector selectorModel
|
||||
modalItems []SelectItem
|
||||
|
||||
showingMultiModal bool
|
||||
multiModalSelector multiSelectorModel
|
||||
|
||||
showingSignIn bool
|
||||
signInURL string
|
||||
signInModel string
|
||||
signInSpinner int
|
||||
signInFromModal bool // true if sign-in was triggered from modal (not main menu)
|
||||
|
||||
width int // terminal width from WindowSizeMsg
|
||||
statusMsg string // temporary status message shown near help text
|
||||
}
|
||||
|
||||
type signInTickMsg struct{}
|
||||
|
||||
type signInCheckMsg struct {
|
||||
signedIn bool
|
||||
userName string
|
||||
}
|
||||
|
||||
type clearStatusMsg struct{}
|
||||
|
||||
func (m *model) modelExists(name string) bool {
|
||||
if m.availableModels == nil || name == "" {
|
||||
return false
|
||||
}
|
||||
if m.availableModels[name] {
|
||||
return true
|
||||
}
|
||||
// Check for prefix match (e.g., "llama2" matches "llama2:latest")
|
||||
for modelName := range m.availableModels {
|
||||
if strings.HasPrefix(modelName, name+":") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (m *model) buildModalItems() []SelectItem {
|
||||
modelItems, _ := config.GetModelItems(context.Background())
|
||||
return ReorderItems(ConvertItems(modelItems))
|
||||
}
|
||||
|
||||
func (m *model) openModelModal(currentModel string) {
|
||||
m.modalItems = m.buildModalItems()
|
||||
cursor := 0
|
||||
if currentModel != "" {
|
||||
for i, item := range m.modalItems {
|
||||
if item.Name == currentModel || strings.HasPrefix(item.Name, currentModel+":") || strings.HasPrefix(currentModel, item.Name+":") {
|
||||
cursor = i
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
m.modalSelector = selectorModel{
|
||||
title: "Select model:",
|
||||
items: m.modalItems,
|
||||
cursor: cursor,
|
||||
helpText: "↑/↓ navigate • enter select • ← back",
|
||||
}
|
||||
m.modalSelector.updateScroll(m.modalSelector.otherStart())
|
||||
m.showingModal = true
|
||||
}
|
||||
|
||||
func (m *model) openMultiModelModal(integration string) {
|
||||
items := m.buildModalItems()
|
||||
var preChecked []string
|
||||
if models := config.IntegrationModels(integration); len(models) > 0 {
|
||||
preChecked = models
|
||||
}
|
||||
m.multiModalSelector = newMultiSelectorModel("Select models:", items, preChecked)
|
||||
// Set cursor to the first pre-checked (last used) model
|
||||
if len(preChecked) > 0 {
|
||||
for i, item := range items {
|
||||
if item.Name == preChecked[0] {
|
||||
m.multiModalSelector.cursor = i
|
||||
m.multiModalSelector.updateScroll(m.multiModalSelector.otherStart())
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
m.showingMultiModal = true
|
||||
}
|
||||
|
||||
func isCloudModel(name string) bool {
|
||||
return strings.HasSuffix(name, ":cloud") || strings.HasSuffix(name, "-cloud")
|
||||
}
|
||||
|
||||
func cloudStatusDisabled(client *api.Client) bool {
|
||||
status, err := client.CloudStatusExperimental(context.Background())
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return status.Cloud.Disabled
|
||||
}
|
||||
|
||||
func cloudModelDisabled(name string) bool {
|
||||
if !isCloudModel(name) {
|
||||
return false
|
||||
}
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return cloudStatusDisabled(client)
|
||||
}
|
||||
|
||||
// checkCloudSignIn checks if a cloud model needs sign-in.
|
||||
// Returns a command to start sign-in if needed, or nil if already signed in.
|
||||
func (m *model) checkCloudSignIn(modelName string, fromModal bool) tea.Cmd {
|
||||
if modelName == "" || !isCloudModel(modelName) {
|
||||
return nil
|
||||
}
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
if cloudStatusDisabled(client) {
|
||||
return nil
|
||||
}
|
||||
user, err := client.Whoami(context.Background())
|
||||
if err == nil && user != nil && user.Name != "" {
|
||||
return nil
|
||||
}
|
||||
var aErr api.AuthorizationError
|
||||
if errors.As(err, &aErr) && aErr.SigninURL != "" {
|
||||
return m.startSignIn(modelName, aErr.SigninURL, fromModal)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// startSignIn initiates the sign-in flow for a cloud model.
|
||||
// fromModal indicates if this was triggered from the model picker modal.
|
||||
func (m *model) startSignIn(modelName, signInURL string, fromModal bool) tea.Cmd {
|
||||
m.showingModal = false
|
||||
m.showingSignIn = true
|
||||
m.signInURL = signInURL
|
||||
m.signInModel = modelName
|
||||
m.signInSpinner = 0
|
||||
m.signInFromModal = fromModal
|
||||
|
||||
config.OpenBrowser(signInURL)
|
||||
|
||||
return tea.Tick(200*time.Millisecond, func(t time.Time) tea.Msg {
|
||||
return signInTickMsg{}
|
||||
})
|
||||
}
|
||||
|
||||
func checkSignIn() tea.Msg {
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
return signInCheckMsg{signedIn: false}
|
||||
}
|
||||
user, err := client.Whoami(context.Background())
|
||||
if err == nil && user != nil && user.Name != "" {
|
||||
return signInCheckMsg{signedIn: true, userName: user.Name}
|
||||
}
|
||||
return signInCheckMsg{signedIn: false}
|
||||
}
|
||||
|
||||
func (m *model) loadAvailableModels() {
|
||||
m.availableModels = make(map[string]bool)
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
models, err := client.List(context.Background())
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
cloudDisabled := cloudStatusDisabled(client)
|
||||
for _, mdl := range models.Models {
|
||||
if cloudDisabled && mdl.RemoteModel != "" {
|
||||
integrationState, ok := state.Integrations[info.Name]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
m.availableModels[mdl.Name] = true
|
||||
items = append(items, integrationMenuItem(integrationState))
|
||||
}
|
||||
return items
|
||||
}
|
||||
|
||||
func (m *model) buildItems() {
|
||||
others := getOtherIntegrations()
|
||||
m.items = make([]menuItem, 0, len(mainMenuItems)+1+len(others))
|
||||
m.items = append(m.items, mainMenuItems...)
|
||||
|
||||
if m.showOthers {
|
||||
m.items = append(m.items, others...)
|
||||
} else {
|
||||
m.items = append(m.items, othersMenuItem)
|
||||
func initialCursor(state *launch.LauncherState, items []menuItem) int {
|
||||
if state == nil || state.LastSelection == "" {
|
||||
return 0
|
||||
}
|
||||
}
|
||||
|
||||
func isOthersIntegration(name string) bool {
|
||||
for _, item := range getOtherIntegrations() {
|
||||
if item.integration == name {
|
||||
return true
|
||||
for i, item := range items {
|
||||
if state.LastSelection == "run" && item.isRunModel {
|
||||
return i
|
||||
}
|
||||
if item.integration == state.LastSelection {
|
||||
return i
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func initialModel() model {
|
||||
m := model{
|
||||
cursor: 0,
|
||||
}
|
||||
m.loadAvailableModels()
|
||||
|
||||
lastSelection := config.LastSelection()
|
||||
if isOthersIntegration(lastSelection) {
|
||||
m.showOthers = true
|
||||
}
|
||||
|
||||
m.buildItems()
|
||||
|
||||
if lastSelection != "" {
|
||||
for i, item := range m.items {
|
||||
if lastSelection == "run" && item.isRunModel {
|
||||
m.cursor = i
|
||||
break
|
||||
} else if item.integration == lastSelection {
|
||||
m.cursor = i
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return m
|
||||
return 0
|
||||
}
|
||||
|
||||
func (m model) Init() tea.Cmd {
|
||||
@@ -357,143 +175,11 @@ func (m model) Init() tea.Cmd {
|
||||
}
|
||||
|
||||
func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
if wmsg, ok := msg.(tea.WindowSizeMsg); ok {
|
||||
wasSet := m.width > 0
|
||||
m.width = wmsg.Width
|
||||
if wasSet {
|
||||
return m, tea.EnterAltScreen
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
if _, ok := msg.(clearStatusMsg); ok {
|
||||
m.statusMsg = ""
|
||||
return m, nil
|
||||
}
|
||||
|
||||
if m.showingSignIn {
|
||||
switch msg := msg.(type) {
|
||||
case tea.KeyMsg:
|
||||
switch msg.Type {
|
||||
case tea.KeyCtrlC, tea.KeyEsc:
|
||||
m.showingSignIn = false
|
||||
if m.signInFromModal {
|
||||
m.showingModal = true
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
case signInTickMsg:
|
||||
m.signInSpinner++
|
||||
// Check sign-in status every 5th tick (~1 second)
|
||||
if m.signInSpinner%5 == 0 {
|
||||
return m, tea.Batch(
|
||||
tea.Tick(200*time.Millisecond, func(t time.Time) tea.Msg {
|
||||
return signInTickMsg{}
|
||||
}),
|
||||
checkSignIn,
|
||||
)
|
||||
}
|
||||
return m, tea.Tick(200*time.Millisecond, func(t time.Time) tea.Msg {
|
||||
return signInTickMsg{}
|
||||
})
|
||||
|
||||
case signInCheckMsg:
|
||||
if msg.signedIn {
|
||||
if m.signInFromModal {
|
||||
m.modalSelector.selected = m.signInModel
|
||||
m.changeModel = true
|
||||
} else {
|
||||
m.selected = true
|
||||
}
|
||||
m.quitting = true
|
||||
return m, tea.Quit
|
||||
}
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
if m.showingMultiModal {
|
||||
switch msg := msg.(type) {
|
||||
case tea.KeyMsg:
|
||||
if msg.Type == tea.KeyLeft {
|
||||
m.showingMultiModal = false
|
||||
return m, nil
|
||||
}
|
||||
updated, cmd := m.multiModalSelector.Update(msg)
|
||||
m.multiModalSelector = updated.(multiSelectorModel)
|
||||
|
||||
if m.multiModalSelector.cancelled {
|
||||
m.showingMultiModal = false
|
||||
return m, nil
|
||||
}
|
||||
if m.multiModalSelector.confirmed {
|
||||
var selected []string
|
||||
if m.multiModalSelector.singleAdd != "" {
|
||||
// Single-add mode: prepend picked model, keep existing deduped
|
||||
selected = []string{m.multiModalSelector.singleAdd}
|
||||
for _, name := range config.IntegrationModels(m.items[m.cursor].integration) {
|
||||
if name != m.multiModalSelector.singleAdd {
|
||||
selected = append(selected, name)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Last checked is default (first in result)
|
||||
co := m.multiModalSelector.checkOrder
|
||||
last := co[len(co)-1]
|
||||
selected = []string{m.multiModalSelector.items[last].Name}
|
||||
for _, idx := range co {
|
||||
if idx != last {
|
||||
selected = append(selected, m.multiModalSelector.items[idx].Name)
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(selected) > 0 {
|
||||
m.changeModels = selected
|
||||
m.changeModel = true
|
||||
m.quitting = true
|
||||
return m, tea.Quit
|
||||
}
|
||||
m.multiModalSelector.confirmed = false
|
||||
return m, nil
|
||||
}
|
||||
return m, cmd
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
if m.showingModal {
|
||||
switch msg := msg.(type) {
|
||||
case tea.KeyMsg:
|
||||
switch msg.Type {
|
||||
case tea.KeyCtrlC, tea.KeyEsc, tea.KeyLeft:
|
||||
m.showingModal = false
|
||||
return m, nil
|
||||
|
||||
case tea.KeyEnter:
|
||||
filtered := m.modalSelector.filteredItems()
|
||||
if len(filtered) > 0 && m.modalSelector.cursor < len(filtered) {
|
||||
m.modalSelector.selected = filtered[m.modalSelector.cursor].Name
|
||||
}
|
||||
if m.modalSelector.selected != "" {
|
||||
if cmd := m.checkCloudSignIn(m.modalSelector.selected, true); cmd != nil {
|
||||
return m, cmd
|
||||
}
|
||||
m.changeModel = true
|
||||
m.quitting = true
|
||||
return m, tea.Quit
|
||||
}
|
||||
return m, nil
|
||||
|
||||
default:
|
||||
// Delegate navigation (up/down/pgup/pgdown/filter/backspace) to selectorModel
|
||||
m.modalSelector.updateNavigation(msg)
|
||||
}
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
switch msg := msg.(type) {
|
||||
case tea.WindowSizeMsg:
|
||||
m.width = msg.Width
|
||||
return m, nil
|
||||
|
||||
case tea.KeyMsg:
|
||||
switch msg.String() {
|
||||
case "ctrl+c", "q", "esc":
|
||||
@@ -504,162 +190,78 @@ func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
if m.cursor > 0 {
|
||||
m.cursor--
|
||||
}
|
||||
// Auto-collapse "Others" when cursor moves back into pinned items
|
||||
if m.showOthers && m.cursor < len(mainMenuItems) {
|
||||
m.showOthers = false
|
||||
m.buildItems()
|
||||
m.items = buildMenuItems(m.state, false)
|
||||
m.cursor = min(m.cursor, len(m.items)-1)
|
||||
}
|
||||
return m, nil
|
||||
|
||||
case "down", "j":
|
||||
if m.cursor < len(m.items)-1 {
|
||||
m.cursor++
|
||||
}
|
||||
// Auto-expand "Others..." when cursor lands on it
|
||||
if m.cursor < len(m.items) && m.items[m.cursor].isOthers && !m.showOthers {
|
||||
m.showOthers = true
|
||||
m.buildItems()
|
||||
// cursor now points at the first "other" integration
|
||||
m.items = buildMenuItems(m.state, true)
|
||||
}
|
||||
return m, nil
|
||||
|
||||
case "enter", " ":
|
||||
item := m.items[m.cursor]
|
||||
|
||||
if item.integration != "" && !config.IsIntegrationInstalled(item.integration) && !config.AutoInstallable(item.integration) {
|
||||
return m, nil
|
||||
if m.selectableItem(m.items[m.cursor]) {
|
||||
m.selected = true
|
||||
m.action = actionForMenuItem(m.items[m.cursor], false)
|
||||
m.quitting = true
|
||||
return m, tea.Quit
|
||||
}
|
||||
|
||||
var configuredModel string
|
||||
if item.isRunModel {
|
||||
configuredModel = config.LastModel()
|
||||
} else if item.integration != "" {
|
||||
configuredModel = config.IntegrationModel(item.integration)
|
||||
}
|
||||
if cmd := m.checkCloudSignIn(configuredModel, false); cmd != nil {
|
||||
return m, cmd
|
||||
}
|
||||
|
||||
if configuredModel != "" && isCloudModel(configuredModel) && cloudModelDisabled(configuredModel) {
|
||||
if item.integration != "" && config.IsEditorIntegration(item.integration) {
|
||||
m.openMultiModelModal(item.integration)
|
||||
} else {
|
||||
m.openModelModal(configuredModel)
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
m.selected = true
|
||||
m.quitting = true
|
||||
return m, tea.Quit
|
||||
return m, nil
|
||||
|
||||
case "right", "l":
|
||||
item := m.items[m.cursor]
|
||||
if item.integration != "" || item.isRunModel {
|
||||
if item.integration != "" && !config.IsIntegrationInstalled(item.integration) {
|
||||
if config.AutoInstallable(item.integration) {
|
||||
// Auto-installable: select to trigger install flow
|
||||
m.selected = true
|
||||
m.quitting = true
|
||||
return m, tea.Quit
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
if item.integration != "" && config.IsEditorIntegration(item.integration) {
|
||||
m.openMultiModelModal(item.integration)
|
||||
} else {
|
||||
var currentModel string
|
||||
if item.isRunModel {
|
||||
currentModel = config.LastModel()
|
||||
} else if item.integration != "" {
|
||||
currentModel = config.IntegrationModel(item.integration)
|
||||
}
|
||||
m.openModelModal(currentModel)
|
||||
}
|
||||
if item.isRunModel || m.changeableItem(item) {
|
||||
m.selected = true
|
||||
m.action = actionForMenuItem(item, true)
|
||||
m.quitting = true
|
||||
return m, tea.Quit
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
}
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func (m model) selectableItem(item menuItem) bool {
|
||||
if item.isRunModel {
|
||||
return true
|
||||
}
|
||||
if item.integration == "" || item.isOthers {
|
||||
return false
|
||||
}
|
||||
state, ok := m.state.Integrations[item.integration]
|
||||
return ok && state.Selectable
|
||||
}
|
||||
|
||||
func (m model) changeableItem(item menuItem) bool {
|
||||
if item.integration == "" || item.isOthers {
|
||||
return false
|
||||
}
|
||||
state, ok := m.state.Integrations[item.integration]
|
||||
return ok && state.Changeable
|
||||
}
|
||||
|
||||
func (m model) View() string {
|
||||
if m.quitting {
|
||||
return ""
|
||||
}
|
||||
|
||||
if m.showingSignIn {
|
||||
return m.renderSignInDialog()
|
||||
}
|
||||
|
||||
if m.showingMultiModal {
|
||||
return m.multiModalSelector.View()
|
||||
}
|
||||
|
||||
if m.showingModal {
|
||||
return m.renderModal()
|
||||
}
|
||||
|
||||
s := selectorTitleStyle.Render("Ollama "+versionStyle.Render(version.Version)) + "\n\n"
|
||||
|
||||
for i, item := range m.items {
|
||||
cursor := ""
|
||||
style := menuItemStyle
|
||||
isInstalled := true
|
||||
|
||||
if item.integration != "" {
|
||||
isInstalled = config.IsIntegrationInstalled(item.integration)
|
||||
}
|
||||
|
||||
if m.cursor == i {
|
||||
cursor = "▸ "
|
||||
if isInstalled {
|
||||
style = menuSelectedItemStyle
|
||||
} else {
|
||||
style = greyedSelectedStyle
|
||||
}
|
||||
} else if !isInstalled && item.integration != "" {
|
||||
style = greyedStyle
|
||||
}
|
||||
|
||||
title := item.title
|
||||
var modelSuffix string
|
||||
if item.integration != "" {
|
||||
if !isInstalled {
|
||||
if config.AutoInstallable(item.integration) {
|
||||
title += " " + notInstalledStyle.Render("(install)")
|
||||
} else {
|
||||
title += " " + notInstalledStyle.Render("(not installed)")
|
||||
}
|
||||
} else if m.cursor == i {
|
||||
if mdl := config.IntegrationModel(item.integration); mdl != "" && m.modelExists(mdl) {
|
||||
modelSuffix = " " + modelStyle.Render("("+mdl+")")
|
||||
}
|
||||
}
|
||||
} else if item.isRunModel && m.cursor == i {
|
||||
if mdl := config.LastModel(); mdl != "" && m.modelExists(mdl) {
|
||||
modelSuffix = " " + modelStyle.Render("("+mdl+")")
|
||||
}
|
||||
}
|
||||
|
||||
s += style.Render(cursor+title) + modelSuffix + "\n"
|
||||
|
||||
desc := item.description
|
||||
if !isInstalled && item.integration != "" && m.cursor == i {
|
||||
if config.AutoInstallable(item.integration) {
|
||||
desc = "Press enter to install"
|
||||
} else if hint := config.IntegrationInstallHint(item.integration); hint != "" {
|
||||
desc = hint
|
||||
} else {
|
||||
desc = "not installed"
|
||||
}
|
||||
}
|
||||
s += menuDescStyle.Render(desc) + "\n\n"
|
||||
s += m.renderMenuItem(i, item)
|
||||
}
|
||||
|
||||
if m.statusMsg != "" {
|
||||
s += "\n" + lipgloss.NewStyle().Foreground(lipgloss.AdaptiveColor{Light: "124", Dark: "210"}).Render(m.statusMsg) + "\n"
|
||||
}
|
||||
|
||||
s += "\n" + selectorHelpStyle.Render("↑/↓ navigate • enter launch • → change model • esc quit")
|
||||
s += "\n" + selectorHelpStyle.Render("↑/↓ navigate • enter launch • → configure • esc quit")
|
||||
|
||||
if m.width > 0 {
|
||||
return lipgloss.NewStyle().MaxWidth(m.width).Render(s)
|
||||
@@ -667,80 +269,125 @@ func (m model) View() string {
|
||||
return s
|
||||
}
|
||||
|
||||
func (m model) renderModal() string {
|
||||
modalStyle := lipgloss.NewStyle().
|
||||
PaddingBottom(1).
|
||||
PaddingRight(2)
|
||||
func (m model) renderMenuItem(index int, item menuItem) string {
|
||||
cursor := ""
|
||||
style := menuItemStyle
|
||||
title := item.title
|
||||
description := item.description
|
||||
modelSuffix := ""
|
||||
|
||||
s := modalStyle.Render(m.modalSelector.renderContent())
|
||||
if m.width > 0 {
|
||||
return lipgloss.NewStyle().MaxWidth(m.width).Render(s)
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
// renderSignInDialog renders the cloud sign-in prompt via the shared
// renderSignIn helper, constrained to the current terminal width.
func (m model) renderSignInDialog() string {
	return renderSignIn(m.signInModel, m.signInURL, m.signInSpinner, m.width)
}
|
||||
|
||||
// Selection identifies which kind of menu entry the user picked in the TUI.
type Selection int

const (
	// SelectionNone means the user quit without choosing anything.
	SelectionNone Selection = iota
	// SelectionRunModel launches an interactive model session.
	SelectionRunModel
	// SelectionChangeRunModel changes the model used for "run".
	SelectionChangeRunModel
	SelectionIntegration       // Generic integration selection
	SelectionChangeIntegration // Generic change model for integration
)
|
||||
|
||||
// Result describes the outcome of a completed TUI session: which kind of
// selection was made and, where relevant, the integration and model(s)
// involved.
type Result struct {
	Selection   Selection
	Integration string   // integration name if applicable
	Model       string   // model name if selected from single-select modal
	Models      []string // models selected from multi-select modal (Editor integrations)
}
|
||||
|
||||
func Run() (Result, error) {
|
||||
m := initialModel()
|
||||
p := tea.NewProgram(m)
|
||||
|
||||
finalModel, err := p.Run()
|
||||
if err != nil {
|
||||
return Result{Selection: SelectionNone}, fmt.Errorf("error running TUI: %w", err)
|
||||
}
|
||||
|
||||
fm := finalModel.(model)
|
||||
if fm.err != nil {
|
||||
return Result{Selection: SelectionNone}, fm.err
|
||||
}
|
||||
|
||||
if !fm.selected && !fm.changeModel {
|
||||
return Result{Selection: SelectionNone}, nil
|
||||
}
|
||||
|
||||
item := fm.items[fm.cursor]
|
||||
|
||||
if fm.changeModel {
|
||||
if item.isRunModel {
|
||||
return Result{
|
||||
Selection: SelectionChangeRunModel,
|
||||
Model: fm.modalSelector.selected,
|
||||
}, nil
|
||||
}
|
||||
return Result{
|
||||
Selection: SelectionChangeIntegration,
|
||||
Integration: item.integration,
|
||||
Model: fm.modalSelector.selected,
|
||||
Models: fm.changeModels,
|
||||
}, nil
|
||||
if m.cursor == index {
|
||||
cursor = "▸ "
|
||||
}
|
||||
|
||||
if item.isRunModel {
|
||||
return Result{Selection: SelectionRunModel}, nil
|
||||
if m.cursor == index && m.state.RunModel != "" {
|
||||
modelSuffix = " " + modelStyle.Render("("+m.state.RunModel+")")
|
||||
}
|
||||
if m.cursor == index {
|
||||
style = menuSelectedItemStyle
|
||||
}
|
||||
} else if item.isOthers {
|
||||
if m.cursor == index {
|
||||
style = menuSelectedItemStyle
|
||||
}
|
||||
} else {
|
||||
integrationState := m.state.Integrations[item.integration]
|
||||
if !integrationState.Selectable {
|
||||
if m.cursor == index {
|
||||
style = greyedSelectedStyle
|
||||
} else {
|
||||
style = greyedStyle
|
||||
}
|
||||
} else if m.cursor == index {
|
||||
style = menuSelectedItemStyle
|
||||
}
|
||||
|
||||
if m.cursor == index && integrationState.CurrentModel != "" {
|
||||
modelSuffix = " " + modelStyle.Render("("+integrationState.CurrentModel+")")
|
||||
}
|
||||
|
||||
if !integrationState.Installed {
|
||||
if integrationState.AutoInstallable {
|
||||
title += " " + notInstalledStyle.Render("(install)")
|
||||
} else {
|
||||
title += " " + notInstalledStyle.Render("(not installed)")
|
||||
}
|
||||
if m.cursor == index {
|
||||
if integrationState.AutoInstallable {
|
||||
description = "Press enter to install"
|
||||
} else if integrationState.InstallHint != "" {
|
||||
description = integrationState.InstallHint
|
||||
} else {
|
||||
description = "not installed"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return Result{
|
||||
Selection: SelectionIntegration,
|
||||
Integration: item.integration,
|
||||
}, nil
|
||||
return style.Render(cursor+title) + modelSuffix + "\n" + menuDescStyle.Render(description) + "\n\n"
|
||||
}
|
||||
|
||||
// TUIActionKind enumerates the actions the launcher menu can produce.
type TUIActionKind int

const (
	// TUIActionNone indicates the user quit without selecting anything.
	TUIActionNone TUIActionKind = iota
	// TUIActionRunModel starts an interactive model session.
	TUIActionRunModel
	// TUIActionLaunchIntegration launches a named integration.
	TUIActionLaunchIntegration
)
|
||||
|
||||
// TUIAction is the action the user selected from the launcher menu,
// returned by RunMenu for the caller to execute.
type TUIAction struct {
	Kind           TUIActionKind
	Integration    string // integration to launch when Kind is TUIActionLaunchIntegration
	ForceConfigure bool   // set when the user asked to (re)configure (right arrow) rather than launch
}
|
||||
|
||||
func (a TUIAction) LastSelection() string {
|
||||
switch a.Kind {
|
||||
case TUIActionRunModel:
|
||||
return "run"
|
||||
case TUIActionLaunchIntegration:
|
||||
return a.Integration
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
// RunModelRequest converts the action into a launch request for an
// interactive model session, forcing the model picker when the user asked
// to reconfigure.
func (a TUIAction) RunModelRequest() launch.RunModelRequest {
	return launch.RunModelRequest{ForcePicker: a.ForceConfigure}
}
|
||||
|
||||
// IntegrationLaunchRequest converts the action into a launch request for
// the selected integration, carrying through the force-configure flag.
func (a TUIAction) IntegrationLaunchRequest() launch.IntegrationLaunchRequest {
	return launch.IntegrationLaunchRequest{
		Name:           a.Integration,
		ForceConfigure: a.ForceConfigure,
	}
}
|
||||
|
||||
func actionForMenuItem(item menuItem, forceConfigure bool) TUIAction {
|
||||
switch {
|
||||
case item.isRunModel:
|
||||
return TUIAction{Kind: TUIActionRunModel, ForceConfigure: forceConfigure}
|
||||
case item.integration != "":
|
||||
return TUIAction{Kind: TUIActionLaunchIntegration, Integration: item.integration, ForceConfigure: forceConfigure}
|
||||
default:
|
||||
return TUIAction{Kind: TUIActionNone}
|
||||
}
|
||||
}
|
||||
|
||||
func RunMenu(state *launch.LauncherState) (TUIAction, error) {
|
||||
menu := newModel(state)
|
||||
program := tea.NewProgram(menu)
|
||||
|
||||
finalModel, err := program.Run()
|
||||
if err != nil {
|
||||
return TUIAction{Kind: TUIActionNone}, fmt.Errorf("error running TUI: %w", err)
|
||||
}
|
||||
|
||||
finalMenu := finalModel.(model)
|
||||
if !finalMenu.selected {
|
||||
return TUIAction{Kind: TUIActionNone}, nil
|
||||
}
|
||||
|
||||
return finalMenu.action, nil
|
||||
}
|
||||
|
||||
178
cmd/tui/tui_test.go
Normal file
@@ -0,0 +1,178 @@
|
||||
package tui
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
"github.com/ollama/ollama/cmd/launch"
|
||||
)
|
||||
|
||||
// launcherTestState builds a representative LauncherState fixture for the
// menu tests: "run" as the last selection, a current run model, and five
// selectable/changeable integrations — one with a configured model
// (claude) and one auto-installable (openclaw).
func launcherTestState() *launch.LauncherState {
	return &launch.LauncherState{
		LastSelection: "run",
		RunModel:      "qwen3:8b",
		Integrations: map[string]launch.LauncherIntegrationState{
			"claude": {
				Name:         "claude",
				DisplayName:  "Claude Code",
				Description:  "Anthropic's coding tool with subagents",
				Selectable:   true,
				Changeable:   true,
				CurrentModel: "glm-5:cloud",
			},
			"codex": {
				Name:        "codex",
				DisplayName: "Codex",
				Description: "OpenAI's open-source coding agent",
				Selectable:  true,
				Changeable:  true,
			},
			"openclaw": {
				Name:            "openclaw",
				DisplayName:     "OpenClaw",
				Description:     "Personal AI with 100+ skills",
				Selectable:      true,
				Changeable:      true,
				AutoInstallable: true,
			},
			"droid": {
				Name:        "droid",
				DisplayName: "Droid",
				Description: "Factory's coding agent across terminal and IDEs",
				Selectable:  true,
				Changeable:  true,
			},
			"pi": {
				Name:        "pi",
				DisplayName: "Pi",
				Description: "Minimal AI agent toolkit with plugin support",
				Selectable:  true,
				Changeable:  true,
			},
		},
	}
}
|
||||
|
||||
func TestMenuRendersPinnedItemsAndMore(t *testing.T) {
|
||||
view := newModel(launcherTestState()).View()
|
||||
for _, want := range []string{"Chat with a model", "Launch Claude Code", "Launch Codex", "Launch OpenClaw", "More..."} {
|
||||
if !strings.Contains(view, want) {
|
||||
t.Fatalf("expected menu view to contain %q\n%s", want, view)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestMenuExpandsOthersFromLastSelection(t *testing.T) {
|
||||
state := launcherTestState()
|
||||
state.LastSelection = "pi"
|
||||
|
||||
menu := newModel(state)
|
||||
if !menu.showOthers {
|
||||
t.Fatal("expected others section to expand when last selection is in the overflow list")
|
||||
}
|
||||
view := menu.View()
|
||||
if !strings.Contains(view, "Launch Pi") {
|
||||
t.Fatalf("expected expanded view to contain overflow integration\n%s", view)
|
||||
}
|
||||
if strings.Contains(view, "More...") {
|
||||
t.Fatalf("expected expanded view to replace More... item\n%s", view)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMenuEnterOnRunSelectsRun(t *testing.T) {
|
||||
menu := newModel(launcherTestState())
|
||||
updated, _ := menu.Update(tea.KeyMsg{Type: tea.KeyEnter})
|
||||
got := updated.(model)
|
||||
want := TUIAction{Kind: TUIActionRunModel}
|
||||
if !got.selected || got.action != want {
|
||||
t.Fatalf("expected enter on run to select run action, got selected=%v action=%v", got.selected, got.action)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMenuRightOnRunSelectsChangeRun(t *testing.T) {
|
||||
menu := newModel(launcherTestState())
|
||||
updated, _ := menu.Update(tea.KeyMsg{Type: tea.KeyRight})
|
||||
got := updated.(model)
|
||||
want := TUIAction{Kind: TUIActionRunModel, ForceConfigure: true}
|
||||
if !got.selected || got.action != want {
|
||||
t.Fatalf("expected right on run to select change-run action, got selected=%v action=%v", got.selected, got.action)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMenuEnterOnIntegrationSelectsLaunch(t *testing.T) {
|
||||
menu := newModel(launcherTestState())
|
||||
menu.cursor = 1
|
||||
updated, _ := menu.Update(tea.KeyMsg{Type: tea.KeyEnter})
|
||||
got := updated.(model)
|
||||
want := TUIAction{Kind: TUIActionLaunchIntegration, Integration: "claude"}
|
||||
if !got.selected || got.action != want {
|
||||
t.Fatalf("expected enter on integration to launch, got selected=%v action=%v", got.selected, got.action)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMenuRightOnIntegrationSelectsConfigure(t *testing.T) {
|
||||
menu := newModel(launcherTestState())
|
||||
menu.cursor = 1
|
||||
updated, _ := menu.Update(tea.KeyMsg{Type: tea.KeyRight})
|
||||
got := updated.(model)
|
||||
want := TUIAction{Kind: TUIActionLaunchIntegration, Integration: "claude", ForceConfigure: true}
|
||||
if !got.selected || got.action != want {
|
||||
t.Fatalf("expected right on integration to configure, got selected=%v action=%v", got.selected, got.action)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMenuIgnoresDisabledActions(t *testing.T) {
|
||||
state := launcherTestState()
|
||||
claude := state.Integrations["claude"]
|
||||
claude.Selectable = false
|
||||
claude.Changeable = false
|
||||
state.Integrations["claude"] = claude
|
||||
|
||||
menu := newModel(state)
|
||||
menu.cursor = 1
|
||||
|
||||
updatedEnter, _ := menu.Update(tea.KeyMsg{Type: tea.KeyEnter})
|
||||
if updatedEnter.(model).selected {
|
||||
t.Fatal("expected non-selectable integration to ignore enter")
|
||||
}
|
||||
|
||||
updatedRight, _ := menu.Update(tea.KeyMsg{Type: tea.KeyRight})
|
||||
if updatedRight.(model).selected {
|
||||
t.Fatal("expected non-changeable integration to ignore right")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMenuShowsCurrentModelSuffixes(t *testing.T) {
|
||||
menu := newModel(launcherTestState())
|
||||
runView := menu.View()
|
||||
if !strings.Contains(runView, "(qwen3:8b)") {
|
||||
t.Fatalf("expected run row to show current model suffix\n%s", runView)
|
||||
}
|
||||
|
||||
menu.cursor = 1
|
||||
integrationView := menu.View()
|
||||
if !strings.Contains(integrationView, "(glm-5:cloud)") {
|
||||
t.Fatalf("expected integration row to show current model suffix\n%s", integrationView)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMenuShowsInstallStatusAndHint(t *testing.T) {
|
||||
state := launcherTestState()
|
||||
codex := state.Integrations["codex"]
|
||||
codex.Installed = false
|
||||
codex.Selectable = false
|
||||
codex.Changeable = false
|
||||
codex.InstallHint = "Install from https://example.com/codex"
|
||||
state.Integrations["codex"] = codex
|
||||
|
||||
menu := newModel(state)
|
||||
menu.cursor = 2
|
||||
view := menu.View()
|
||||
if !strings.Contains(view, "(not installed)") {
|
||||
t.Fatalf("expected not-installed marker\n%s", view)
|
||||
}
|
||||
if !strings.Contains(view, codex.InstallHint) {
|
||||
t.Fatalf("expected install hint in description\n%s", view)
|
||||
}
|
||||
}
|
||||
@@ -290,6 +290,8 @@ func LoadModelMetadata(fsys fs.FS) (ModelKV, *Tokenizer, error) {
|
||||
conv = &gemma3Model{Architecture: p.Architectures[0]}
|
||||
case "Gemma3nForConditionalGeneration":
|
||||
conv = &gemma3nModel{}
|
||||
case "Gemma4ForCausalLM", "Gemma4ForConditionalGeneration":
|
||||
conv = &gemma4Model{Architecture: p.Architectures[0]}
|
||||
case "Phi3ForCausalLM":
|
||||
conv = &phi3Model{}
|
||||
case "Qwen2ForCausalLM":
|
||||
|
||||
556
convert/convert_gemma4.go
Normal file
@@ -0,0 +1,556 @@
|
||||
package convert
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"math"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
)
|
||||
|
||||
// gemma4Model converts a Hugging Face Gemma 4 checkpoint
// ("Gemma4ForCausalLM" / "Gemma4ForConditionalGeneration") to GGUF.
// It embeds gemmaModel and adds the text, vision, and optional audio
// sub-configs parsed from the model's config.json.
type gemma4Model struct {
	gemmaModel
	Architecture string
	// TextModel mirrors the "text_config" section of config.json.
	// Pointer fields distinguish "absent" from zero in the JSON.
	TextModel struct {
		HiddenSize              uint32   `json:"hidden_size"`
		NumHiddenLayers         uint32   `json:"num_hidden_layers"`
		IntermediateSize        uint32   `json:"intermediate_size"`
		NumAttentionHeads       uint32   `json:"num_attention_heads"`
		NumKeyValueHeads        uint32   `json:"num_key_value_heads"`
		HeadDim                 uint32   `json:"head_dim"`
		GlobalHeadDim           uint32   `json:"global_head_dim"`
		VocabSize               uint32   `json:"vocab_size"`
		RMSNormEps              float32  `json:"rms_norm_eps"`
		MaxPositionEmbeddings   uint32   `json:"max_position_embeddings"`
		SlidingWindow           uint32   `json:"sliding_window"`
		SlidingWindowPattern    *int32   `json:"_sliding_window_pattern"`
		LayerTypes              []string `json:"layer_types"`
		FinalLogitSoftcapping   float32  `json:"final_logit_softcapping"`
		EnableMoeBlock          bool     `json:"enable_moe_block"`
		NumExperts              *uint32  `json:"num_experts"`
		TopKExperts             *uint32  `json:"top_k_experts"`
		ExpertIntermediateSize  *uint32  `json:"moe_intermediate_size"`
		HiddenSizePerLayerInput *uint32  `json:"hidden_size_per_layer_input"`
		NumKVSharedLayers       uint32   `json:"num_kv_shared_layers"`
		AttentionKEqV           bool     `json:"attention_k_eq_v"`
		NumGlobalKeyValueHeads  *uint32  `json:"num_global_key_value_heads"`
		QueryPreAttnScalar      *uint32  `json:"query_pre_attn_scalar"`
		UseDoubleWideMLP        bool     `json:"use_double_wide_mlp"`
		// RopeParameters is keyed by attention type; KV() reads the
		// "full_attention" and "sliding_attention" entries.
		RopeParameters map[string]*struct {
			RopeTheta           float32  `json:"rope_theta"`
			PartialRotaryFactor *float32 `json:"partial_rotary_factor"`
		} `json:"rope_parameters"`
	} `json:"text_config"`

	// VisionModel mirrors "vision_config". A zero NumHiddenLayers means
	// the checkpoint has no vision tower.
	VisionModel struct {
		HiddenSize        uint32  `json:"hidden_size"`
		NumHiddenLayers   uint32  `json:"num_hidden_layers"`
		NumAttentionHeads uint32  `json:"num_attention_heads"`
		IntermediateSize  uint32  `json:"intermediate_size"`
		PatchSize         uint32  `json:"patch_size"`
		NumChannels       uint32  `json:"num_channels"`
		PoolingKernelSize uint32  `json:"pooling_kernel_size"`
		LayerNormEps      float32 `json:"layer_norm_eps"`
	} `json:"vision_config"`

	// AudioModel mirrors "audio_config"; nil when the checkpoint has no
	// audio encoder.
	AudioModel *struct {
		HiddenSize        uint32  `json:"hidden_size"`
		OutputProjDims    uint32  `json:"output_proj_dims"`
		NumHiddenLayers   uint32  `json:"num_hidden_layers"`
		NumAttentionHeads uint32  `json:"num_attention_heads"`
		ConvKernelSize    uint32  `json:"conv_kernel_size"`
		RMSNormEps        float32 `json:"rms_norm_eps"`
	} `json:"audio_config"`
}
|
||||
|
||||
// KV builds the GGUF key/value metadata for a Gemma 4 checkpoint on top
// of the shared ModelParameters metadata: text-model hyperparameters
// (including per-layer FFN widths and KV head counts when they vary by
// layer), RoPE settings, MoE and per-layer-embedding sizes, and vision
// and audio sub-model metadata when those towers are present.
func (p *gemma4Model) KV(t *Tokenizer) KV {
	kv := p.ModelParameters.KV(t)
	kv["general.architecture"] = "gemma4"
	kv["tokenizer.ggml.model"] = "llama"
	kv["tokenizer.ggml.pre"] = "gemma4"

	tc := p.TextModel

	kv["gemma4.block_count"] = tc.NumHiddenLayers
	kv["gemma4.embedding_length"] = tc.HiddenSize

	// Per-layer FFN width: when use_double_wide_mlp is set, KV-shared layers get 2x FFN width.
	// KV-shared layers are the trailing NumKVSharedLayers layers.
	if tc.UseDoubleWideMLP && tc.NumKVSharedLayers > 0 {
		firstShared := int(tc.NumHiddenLayers) - int(tc.NumKVSharedLayers)
		ffnWidths := make([]int32, tc.NumHiddenLayers)
		for i := range ffnWidths {
			if i >= firstShared {
				ffnWidths[i] = int32(tc.IntermediateSize * 2)
			} else {
				ffnWidths[i] = int32(tc.IntermediateSize)
			}
		}
		kv["gemma4.feed_forward_length"] = ffnWidths
	} else {
		kv["gemma4.feed_forward_length"] = tc.IntermediateSize
	}
	kv["gemma4.context_length"] = tc.MaxPositionEmbeddings
	kv["gemma4.attention.head_count"] = tc.NumAttentionHeads
	// Per-layer KV head count array: SWA layers use NumKeyValueHeads, global layers use NumGlobalKeyValueHeads
	if tc.NumGlobalKeyValueHeads != nil && *tc.NumGlobalKeyValueHeads != tc.NumKeyValueHeads && len(tc.LayerTypes) > 0 {
		kvHeads := make([]int32, len(tc.LayerTypes))
		for i, lt := range tc.LayerTypes {
			if lt == "sliding_attention" {
				kvHeads[i] = int32(tc.NumKeyValueHeads)
			} else {
				kvHeads[i] = int32(*tc.NumGlobalKeyValueHeads)
			}
		}
		kv["gemma4.attention.head_count_kv"] = kvHeads
	} else {
		kv["gemma4.attention.head_count_kv"] = tc.NumKeyValueHeads
	}
	// key_length = global head dim, key_length_swa = local (SWA) head dim
	kv["gemma4.attention.key_length"] = tc.GlobalHeadDim
	kv["gemma4.attention.value_length"] = tc.GlobalHeadDim
	kv["gemma4.attention.key_length_swa"] = tc.HeadDim
	kv["gemma4.attention.value_length_swa"] = tc.HeadDim
	kv["gemma4.attention.layer_norm_rms_epsilon"] = tc.RMSNormEps
	kv["gemma4.attention.sliding_window"] = tc.SlidingWindow

	// Sliding window pattern from layer_types: one bool per layer,
	// true where the layer uses sliding (local) attention.
	if len(tc.LayerTypes) > 0 {
		kv["gemma4.attention.sliding_window_pattern"] = slices.Collect(func(yield func(bool) bool) {
			for _, lt := range tc.LayerTypes {
				if !yield(lt == "sliding_attention") {
					break
				}
			}
		})
	}

	kv["gemma4.attention.shared_kv_layers"] = tc.NumKVSharedLayers

	// RoPE: dimension_count is the full global head dim (freq_factors handle partial rotation)
	if rp, ok := tc.RopeParameters["full_attention"]; ok && rp != nil {
		kv["gemma4.rope.freq_base"] = rp.RopeTheta
		kv["gemma4.rope.dimension_count"] = tc.GlobalHeadDim
	}
	if rp, ok := tc.RopeParameters["sliding_attention"]; ok && rp != nil {
		kv["gemma4.rope.freq_base_swa"] = rp.RopeTheta
		kv["gemma4.rope.dimension_count_swa"] = tc.HeadDim
	}

	if tc.FinalLogitSoftcapping > 0 {
		kv["gemma4.final_logit_softcapping"] = tc.FinalLogitSoftcapping
	}

	// MoE: emitted only when the config enables the MoE block and gives
	// an expert count; the optional counts are emitted when present.
	if tc.EnableMoeBlock && tc.NumExperts != nil {
		kv["gemma4.expert_count"] = *tc.NumExperts
		if tc.TopKExperts != nil {
			kv["gemma4.expert_used_count"] = *tc.TopKExperts
		}
		if tc.ExpertIntermediateSize != nil {
			kv["gemma4.expert_feed_forward_length"] = *tc.ExpertIntermediateSize
		}
	}

	// PLE — always emit, even when 0
	pleSize := uint32(0)
	if tc.HiddenSizePerLayerInput != nil {
		pleSize = *tc.HiddenSizePerLayerInput
	}
	kv["gemma4.embedding_length_per_layer_input"] = pleSize

	// Vision model KV metadata (only when a vision tower exists).
	// Missing config values fall back to defaults: 3 channels,
	// pooling kernel 3, layer-norm epsilon 1e-6.
	vc := p.VisionModel
	if vc.NumHiddenLayers > 0 {
		kv["gemma4.vision.block_count"] = vc.NumHiddenLayers
		kv["gemma4.vision.embedding_length"] = vc.HiddenSize
		kv["gemma4.vision.attention.head_count"] = vc.NumAttentionHeads
		kv["gemma4.vision.feed_forward_length"] = vc.IntermediateSize
		kv["gemma4.vision.patch_size"] = vc.PatchSize
		numCh := vc.NumChannels
		if numCh == 0 {
			numCh = 3
		}
		kv["gemma4.vision.num_channels"] = numCh
		nMerge := vc.PoolingKernelSize
		if nMerge == 0 {
			nMerge = 3
		}
		kv["gemma4.vision.projector.scale_factor"] = nMerge
		eps := vc.LayerNormEps
		if eps == 0 {
			eps = 1e-6
		}
		kv["gemma4.vision.attention.layer_norm_epsilon"] = eps
	}

	// Audio model KV metadata (only when an audio encoder exists).
	// FFN width is fixed at 4x hidden size here; epsilon defaults to 1e-6.
	if p.AudioModel != nil && p.AudioModel.NumHiddenLayers > 0 {
		ac := p.AudioModel
		kv["gemma4.audio.block_count"] = ac.NumHiddenLayers
		kv["gemma4.audio.embedding_length"] = ac.HiddenSize
		kv["gemma4.audio.feed_forward_length"] = ac.HiddenSize * 4
		kv["gemma4.audio.attention.head_count"] = ac.NumAttentionHeads
		eps := ac.RMSNormEps
		if eps == 0 {
			eps = 1e-6
		}
		kv["gemma4.audio.attention.layer_norm_epsilon"] = eps
		if ac.ConvKernelSize > 0 {
			kv["gemma4.audio.conv_kernel_size"] = ac.ConvKernelSize
		}
	}

	return kv
}
|
||||
|
||||
func (p *gemma4Model) Tensors(ts []Tensor) []*ggml.Tensor {
|
||||
// First pass: collect vision clamp scalar values into a packed tensor.
|
||||
// Layout: per vision layer (0..N-1), 7 linears (q,k,v,out,gate,up,down) × 4 values (inMin,inMax,outMin,outMax).
|
||||
// Then 4 values for the projector (mm.input_projection).
|
||||
clampSuffixes := []string{".input_min", ".input_max", ".output_min", ".output_max"}
|
||||
clampMap := make(map[string]float32)
|
||||
for _, t := range ts {
|
||||
name := t.Name()
|
||||
for _, sfx := range clampSuffixes {
|
||||
if strings.HasSuffix(name, sfx) && (strings.Contains(name, "vision_tower") || strings.Contains(name, "embed_vision")) {
|
||||
var buf bytes.Buffer
|
||||
t.WriteTo(&buf)
|
||||
data := buf.Bytes()
|
||||
if len(data) >= 4 {
|
||||
clampMap[name] = math.Float32frombits(uint32(data[0]) | uint32(data[1])<<8 | uint32(data[2])<<16 | uint32(data[3])<<24)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var out []*ggml.Tensor
|
||||
for _, t := range ts {
|
||||
name := t.Name()
|
||||
|
||||
// Skip embedding_post_projection_norm — used as weightless RMS norm in inference
|
||||
if strings.Contains(name, "embedding_post_projection_norm") {
|
||||
continue
|
||||
}
|
||||
|
||||
// Vision tensor renaming: match published mmproj GGUF names
|
||||
if strings.HasPrefix(name, "v.blk.") {
|
||||
name = strings.Replace(name, ".attn_norm.", ".ln1.", 1)
|
||||
name = strings.Replace(name, ".ffn_norm.", ".ln2.", 1)
|
||||
name = strings.Replace(name, ".attn_output.", ".attn_out.", 1)
|
||||
name = strings.Replace(name, ".post_attention_norm.", ".attn_post_norm.", 1)
|
||||
name = strings.Replace(name, ".post_ffw_norm.", ".ffn_post_norm.", 1)
|
||||
name = strings.Replace(name, ".layer_output_scale.", ".out_scale.", 1)
|
||||
}
|
||||
|
||||
// per_dim_scale: apply softplus to weight data and add .weight suffix.
|
||||
if strings.HasPrefix(name, "a.blk.") && strings.HasSuffix(name, "per_dim_scale") {
|
||||
name = name + ".weight"
|
||||
t.SetRepacker(softplusRepacker)
|
||||
}
|
||||
|
||||
// Depthwise conv1d: squeeze middle dimension [C, 1, K] → [C, K].
|
||||
if strings.HasPrefix(name, "a.blk.") && strings.Contains(name, "conv_dw") && strings.HasSuffix(name, ".weight") {
|
||||
t.SetRepacker(squeezeMiddleDim)
|
||||
}
|
||||
|
||||
shape := t.Shape()
|
||||
|
||||
// Convert scalar tensors (input_min/max, output_min/max) to 1D
|
||||
if len(shape) == 0 {
|
||||
shape = []uint64{1}
|
||||
}
|
||||
|
||||
// Depthwise conv1d shape: safetensors [C, 1, K] → GGUF ne[K, C].
|
||||
// Shape array here maps to GGUF ne[] directly, but safetensors reader
|
||||
// stores shape in PyTorch order [C, 1, K] which the GGUF writer inverts.
|
||||
// Published GGUF has ne[0]=K, ne[1]=C → shape array must be [K, C].
|
||||
if strings.HasPrefix(name, "a.blk.") && strings.Contains(name, "conv_dw") && strings.HasSuffix(name, ".weight") && len(shape) == 3 {
|
||||
shape = []uint64{shape[0], shape[2]}
|
||||
}
|
||||
|
||||
// MoE expert weights: no transpose needed. Safetensors stores [experts, out, in]
|
||||
// which the framework reverses to GGUF ne=[in, out, experts], matching ggml_mul_mat_id.
|
||||
// (transposeExperts was incorrectly swapping dims — removed)
|
||||
|
||||
// Audio conv weights are forced to F32 via tensorBase.Kind() in reader.go
|
||||
// (im2col doesn't support BF16). No kindOverride needed — the Kind() method
|
||||
// controls both the GGUF header type AND the WriteTo data encoding path.
|
||||
var kindOverride *uint32
|
||||
|
||||
// Vision patch embedding: reshape from [n_embd, ksize_sq_c] to [n_embd, 3, patch_size, patch_size]
|
||||
// Must be stored as F16 (not BF16) because the Conv2D im2col kernel requires F16/F32.
|
||||
if strings.Contains(name, "v.patch_embd.weight") && len(shape) == 2 {
|
||||
nEmbd := shape[0]
|
||||
patchSize := uint64(p.VisionModel.PatchSize)
|
||||
if patchSize == 0 {
|
||||
patchSize = 16
|
||||
}
|
||||
numCh := uint64(p.VisionModel.NumChannels)
|
||||
if numCh == 0 {
|
||||
numCh = 3
|
||||
}
|
||||
t.SetRepacker(p.reshapePatchEmbed)
|
||||
shape = []uint64{nEmbd, numCh, patchSize, patchSize}
|
||||
f16Kind := uint32(1) // tensorKindFP16
|
||||
kindOverride = &f16Kind
|
||||
}
|
||||
|
||||
// Vision position embedding: keep 3D [2, maxPos, nEmbd] — matching published mmproj format.
|
||||
// The framework reverses shape to GGUF ne=[nEmbd, maxPos, 2]. No data repacking needed.
|
||||
|
||||
kind := t.Kind()
|
||||
if kindOverride != nil {
|
||||
kind = *kindOverride
|
||||
}
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: name,
|
||||
Kind: kind,
|
||||
Shape: shape,
|
||||
WriterTo: t,
|
||||
})
|
||||
}
|
||||
|
||||
// Generate a single global rope_freqs.weight for proportional RoPE on global attention layers.
|
||||
// This matches the published GGUF format: one global tensor shared by all layers.
|
||||
// Global layers use partial_rotary_factor (0.25) — only rotate that fraction of dims.
|
||||
// Dimensions beyond the rotated portion get freq_factor=1e30 (effectively no rotation).
|
||||
tc := p.TextModel
|
||||
if tc.GlobalHeadDim > 0 {
|
||||
globalFreqsSize := tc.GlobalHeadDim / 2 // freq_factors are per dimension pair
|
||||
|
||||
// Compute number of rotated pairs for global layers
|
||||
partialRotaryFactor := float32(0.25) // default
|
||||
if rp, ok := tc.RopeParameters["full_attention"]; ok && rp != nil && rp.PartialRotaryFactor != nil {
|
||||
partialRotaryFactor = *rp.PartialRotaryFactor
|
||||
}
|
||||
nRotFull := int(float32(tc.GlobalHeadDim) * partialRotaryFactor / 2)
|
||||
|
||||
freqs := make(ropeFactor, globalFreqsSize)
|
||||
for j := range freqs {
|
||||
if j < nRotFull {
|
||||
freqs[j] = 1.0
|
||||
} else {
|
||||
freqs[j] = 1e30 // effectively disable rotation
|
||||
}
|
||||
}
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: "rope_freqs.weight",
|
||||
Kind: 0, // F32
|
||||
Shape: []uint64{uint64(len(freqs))},
|
||||
WriterTo: freqs,
|
||||
})
|
||||
}
|
||||
|
||||
// Emit packed vision clamp data as a single F32 tensor.
|
||||
// Layout: numLayers × 7 linears (q,k,v,out,gate,up,down) × 4 floats (inMin,inMax,outMin,outMax)
|
||||
// then 4 floats for the projector. Total = (numLayers*7 + 1) * 4 floats.
|
||||
if len(clampMap) > 0 {
|
||||
numLayers := int(p.VisionModel.NumHiddenLayers)
|
||||
linearNames := []string{"attn_q", "attn_k", "attn_v", "attn_out", "ffn_gate", "ffn_up", "ffn_down"}
|
||||
suffixes := []string{".input_min", ".input_max", ".output_min", ".output_max"}
|
||||
|
||||
totalFloats := (numLayers*len(linearNames) + 1) * 4 // +1 for projector
|
||||
clampData := make([]float32, totalFloats)
|
||||
|
||||
for layer := range numLayers {
|
||||
for li, ln := range linearNames {
|
||||
for si, sfx := range suffixes {
|
||||
sfxMap := map[string]string{"attn_q": "q_proj", "attn_k": "k_proj", "attn_v": "v_proj", "attn_out": "o_proj", "ffn_gate": "gate_proj", "ffn_up": "up_proj", "ffn_down": "down_proj"}
|
||||
for origName, val := range clampMap {
|
||||
if strings.Contains(origName, fmt.Sprintf("layers.%d.", layer)) && strings.HasSuffix(origName, sfx) && strings.Contains(origName, sfxMap[ln]) {
|
||||
idx := (layer*len(linearNames)+li)*4 + si
|
||||
clampData[idx] = val
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// Projector clamp values
|
||||
projIdx := numLayers * len(linearNames) * 4
|
||||
for si, sfx := range suffixes {
|
||||
for origName, val := range clampMap {
|
||||
if strings.Contains(origName, "input_projection") && strings.HasSuffix(origName, sfx) {
|
||||
clampData[projIdx+si] = val
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
binary.Write(&buf, binary.LittleEndian, clampData)
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: "v.clamp_data",
|
||||
Kind: 0, // F32
|
||||
Shape: []uint64{uint64(totalFloats)},
|
||||
WriterTo: &buf,
|
||||
})
|
||||
}
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
// reshapePatchEmbed reshapes the vision patch embedding from HF layout [n_embd, ksize*ksize*channels]
|
||||
// to GGUF layout [n_embd, channels, patch_size, patch_size].
|
||||
func (*gemma4Model) reshapePatchEmbed(_ string, data []float32, shape []uint64) ([]float32, error) {
|
||||
if len(shape) != 2 {
|
||||
return data, nil
|
||||
}
|
||||
nEmbd := int(shape[0])
|
||||
ksqC := int(shape[1])
|
||||
nChannels := 3
|
||||
patchSize := int(math.Sqrt(float64(ksqC / nChannels)))
|
||||
|
||||
// HF layout: [n_embd, patch_size * patch_size * channels] (row-major)
|
||||
// Need: [n_embd, channels, patch_size, patch_size]
|
||||
result := make([]float32, len(data))
|
||||
for e := range nEmbd {
|
||||
for c := range nChannels {
|
||||
for h := range patchSize {
|
||||
for w := range patchSize {
|
||||
srcIdx := e*ksqC + h*patchSize*nChannels + w*nChannels + c
|
||||
dstIdx := e*nChannels*patchSize*patchSize + c*patchSize*patchSize + h*patchSize + w
|
||||
result[dstIdx] = data[srcIdx]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
shape[0] = uint64(nEmbd)
|
||||
shape[1] = uint64(nChannels * patchSize * patchSize)
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// softplusRepacker applies softplus (ln(1 + exp(x))) to tensor data.
// Used for per_dim_scale tensors which the published GGUF stores pre-activated.
//
// The naive form ln(1+exp(x)) overflows to +Inf once exp(x) overflows float64
// (x ≳ 709) even though softplus(x) → x there, and it underflows to 0 for very
// negative x. Use the standard numerically stable split instead:
//
//	x >= 0: softplus(x) = x + log1p(exp(-x))
//	x <  0: softplus(x) = log1p(exp(x))
func softplusRepacker(_ string, data []float32, _ []uint64) ([]float32, error) {
	result := make([]float32, len(data))
	for i, v := range data {
		x := float64(v)
		if x >= 0 {
			result[i] = float32(x + math.Log1p(math.Exp(-x)))
		} else {
			result[i] = float32(math.Log1p(math.Exp(x)))
		}
	}
	return result, nil
}
|
||||
|
||||
// squeezeMiddleDim squeezes the middle dimension from [C, 1, K] → [C, K] for depthwise conv1d weights.
// Data layout stays the same since the middle dim is 1 — just a shape change.
//
// The element data is contiguous and unaffected by dropping a size-1 axis,
// so it is returned unmodified.
//
// NOTE(review): unlike reshapePatchEmbed, the shape slice is ignored here and
// never rewritten — presumably the [C, K] shape is set by the caller; confirm.
func squeezeMiddleDim(_ string, data []float32, _ []uint64) ([]float32, error) {
	return data, nil
}
|
||||
|
||||
// Replacements returns the flat old/new pair list used to translate HF
// safetensors tensor names into GGUF tensor names for gemma4. The list is
// consumed by strings.NewReplacer.
//
// Ordering is significant: strings.Replacer tries patterns in argument order
// at each position of the input, so more specific keys (e.g. the
// ".weight"-suffixed MoE names) must stay listed before their shorter
// prefixes.
func (p *gemma4Model) Replacements() []string {
	return []string{
		// ClippableLinear wraps nn.Linear — strip .linear. from weight path
		".linear.weight", ".weight",
		".linear.bias", ".bias",

		// Audio SSCP (Sub-Sample Convolution Projection)
		"model.audio_tower.subsample_conv_projection.conv_0.conv", "a.conv1d.0",
		"model.audio_tower.subsample_conv_projection.conv_0.norm", "a.conv1d.0.norm",
		"model.audio_tower.subsample_conv_projection.conv_1.conv", "a.conv1d.1",
		"model.audio_tower.subsample_conv_projection.conv_1.norm", "a.conv1d.1.norm",
		"model.audio_tower.subsample_conv_projection.input_proj_linear", "a.pre_encode.out",

		// Audio conformer blocks
		"model.audio_tower.conformer", "a.blk",

		// Audio conformer attention
		"attention.attn.relative_position_embedding.pos_proj", "linear_pos",
		"attention.attn.per_dim_key_scale", "per_dim_k_scale",
		"attention.attn.per_dim_scale", "per_dim_scale",
		"attention.attn.q_proj", "attn_q",
		"attention.attn.k_proj", "attn_k",
		"attention.attn.v_proj", "attn_v",
		// "attention.post" must come after the longer "attention.post_norm"
		// so the norm is not misrouted to attn_out.
		"attention.pre_attn_norm", "ln1",
		"attention.post_norm", "ln2",
		"attention.post", "attn_out",

		// Audio conformer feedforward
		"ffw_layer_start.pre_layer_norm", "ffn_norm",
		"ffw_layer_start.post_layer_norm", "ffn_post_norm",
		"ffw_layer_start.ffw_layer_1", "ffn_up",
		"ffw_layer_start.ffw_layer_2", "ffn_down",
		"ffw_layer_end.pre_layer_norm", "ffn_norm_1",
		"ffw_layer_end.post_layer_norm", "ffn_post_norm_1",
		"ffw_layer_end.ffw_layer_1", "ffn_up_1",
		"ffw_layer_end.ffw_layer_2", "ffn_down_1",

		// Audio conformer lightweight conv1d
		"lconv1d.depthwise_conv1d", "conv_dw",
		"lconv1d.pre_layer_norm", "conv_norm",
		"lconv1d.conv_norm", "norm_conv",
		"lconv1d.linear_start", "conv_pw1",
		"lconv1d.linear_end", "conv_pw2",

		// Audio block final norm
		"norm_out", "layer_pre_norm",

		// Audio embedder and output projection
		"model.embed_audio.embedding_projection", "mm.a.input_projection",
		"model.audio_tower.output_proj", "mm.a.fc",

		// Vision encoder
		"model.vision_tower.encoder.layers", "v.blk",
		"model.vision_tower.patch_embedder.input_proj", "v.patch_embd",
		"model.vision_tower.patch_embedder.position_embedding_table", "v.position_embd.weight",
		"model.vision_tower.std_bias", "v.std_bias",
		"model.vision_tower.std_scale", "v.std_scale",

		// Vision multimodal projector
		"model.embed_vision.embedding_projection", "mm.input_projection",

		// Text model
		"model.language_model.embed_tokens_per_layer", "per_layer_token_embd",
		"model.language_model.embed_tokens", "token_embd",
		"model.language_model.per_layer_model_projection", "per_layer_model_proj",
		"model.language_model.per_layer_projection_norm", "per_layer_proj_norm",
		"model.language_model.norm", "output_norm",
		"model.language_model.layers", "blk",

		// Shared attention replacements (work for both text and vision tensors)
		"input_layernorm", "attn_norm",
		"self_attn.q_proj", "attn_q",
		"self_attn.q_norm", "attn_q_norm",
		"self_attn.k_proj", "attn_k",
		"self_attn.k_norm", "attn_k_norm",
		"self_attn.v_proj", "attn_v",
		"self_attn.o_proj", "attn_output",
		"mlp.gate_proj", "ffn_gate",
		"mlp.down_proj", "ffn_down",
		"mlp.up_proj", "ffn_up",

		// Post norms — numbered variants listed before the bare names so the
		// longer keys win.
		"post_attention_layernorm", "post_attention_norm",
		"pre_feedforward_layernorm_2", "pre_ffw_norm_2",
		"pre_feedforward_layernorm", "ffn_norm",
		"post_feedforward_layernorm_1", "post_ffw_norm_1",
		"post_feedforward_layernorm_2", "post_ffw_norm_2",
		"post_feedforward_layernorm", "post_ffw_norm",

		// PLE (per-layer embeddings)
		"per_layer_input_gate", "inp_gate",
		"per_layer_projection", "proj",
		"post_per_layer_input_norm", "post_norm",

		// MoE — ".weight"-suffixed keys precede their prefixes so names with
		// and without the suffix collapse to the same GGUF name.
		"router.proj", "ffn_gate_inp",
		"router.scale", "ffn_gate_inp.scale",
		"router.per_expert_scale.weight", "ffn_down_exps.scale",
		"router.per_expert_scale", "ffn_down_exps.scale",
		"experts.gate_up_proj.weight", "ffn_gate_up_exps.weight",
		"experts.gate_up_proj", "ffn_gate_up_exps.weight",
		"experts.down_proj.weight", "ffn_down_exps.weight",
		"experts.down_proj", "ffn_down_exps.weight",
		"moe.gate_proj", "ffn_gate_exps.weight",
		"moe.up_proj", "ffn_up_exps.weight",
		"moe.gate_up_proj.weight", "ffn_gate_up_exps.weight",
		"moe.gate_up_proj", "ffn_gate_up_exps.weight",
		"moe.down_proj", "ffn_down_exps.weight",
		"moe.per_expert_scale.weight", "ffn_down_exps.scale",
		"moe.per_expert_scale", "ffn_down_exps.scale",

		// Layer scalar
		"layer_scalar", "layer_output_scale.weight",
	}
}
|
||||
263
convert/convert_gemma4_test.go
Normal file
@@ -0,0 +1,263 @@
|
||||
package convert
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestGemma4AudioReplacements verifies that gemma4Model.Replacements maps HF
// safetensors tensor names to the expected GGUF names. It focuses on the
// audio tower (SSCP convs, conformer attention/feedforward/lconv1d, embedder
// and output projection) and spot-checks that the vision and text mappings
// are not broken by the audio-specific pairs.
func TestGemma4AudioReplacements(t *testing.T) {
	p := gemma4Model{}
	// Apply the pairs through a single Replacer, exactly as the conversion
	// driver does, so ordering effects are exercised too.
	r := strings.NewReplacer(p.Replacements()...)

	tests := []struct {
		name string // subtest label
		in   string // original HF tensor name
		want string // expected GGUF tensor name
	}{
		// SSCP convolution blocks
		{"sscp conv0 weight", "model.audio_tower.subsample_conv_projection.conv_0.conv.weight", "a.conv1d.0.weight"},
		{"sscp conv0 norm", "model.audio_tower.subsample_conv_projection.conv_0.norm.weight", "a.conv1d.0.norm.weight"},
		{"sscp conv1 weight", "model.audio_tower.subsample_conv_projection.conv_1.conv.weight", "a.conv1d.1.weight"},
		{"sscp input proj weight", "model.audio_tower.subsample_conv_projection.input_proj_linear.weight", "a.pre_encode.out.weight"},
		{"sscp input proj bias", "model.audio_tower.subsample_conv_projection.input_proj_linear.bias", "a.pre_encode.out.bias"},

		// Conformer attention
		{"attn q weight", "model.audio_tower.conformer.0.attention.attn.q_proj.linear.weight", "a.blk.0.attn_q.weight"},
		{"attn k weight", "model.audio_tower.conformer.5.attention.attn.k_proj.linear.weight", "a.blk.5.attn_k.weight"},
		{"attn v clamp input_min", "model.audio_tower.conformer.0.attention.attn.v_proj.input_min", "a.blk.0.attn_v.input_min"},
		{"attn out weight (ClippableLinear)", "model.audio_tower.conformer.0.attention.post.linear.weight", "a.blk.0.attn_out.weight"},
		{"attn out clamp output_max", "model.audio_tower.conformer.0.attention.post.output_max", "a.blk.0.attn_out.output_max"},
		{"attn pre norm", "model.audio_tower.conformer.0.attention.pre_attn_norm.weight", "a.blk.0.ln1.weight"},
		{"attn post norm", "model.audio_tower.conformer.0.attention.post_norm.weight", "a.blk.0.ln2.weight"},
		{"linear pos", "model.audio_tower.conformer.0.attention.attn.relative_position_embedding.pos_proj.weight", "a.blk.0.linear_pos.weight"},
		{"per dim scale", "model.audio_tower.conformer.0.attention.attn.per_dim_scale", "a.blk.0.per_dim_scale"},
		{"per dim key scale", "model.audio_tower.conformer.0.attention.attn.per_dim_key_scale", "a.blk.0.per_dim_k_scale"},

		// Conformer feedforward start
		{"ffn up weight", "model.audio_tower.conformer.0.ffw_layer_start.ffw_layer_1.linear.weight", "a.blk.0.ffn_up.weight"},
		{"ffn down weight", "model.audio_tower.conformer.0.ffw_layer_start.ffw_layer_2.linear.weight", "a.blk.0.ffn_down.weight"},
		{"ffn norm", "model.audio_tower.conformer.0.ffw_layer_start.pre_layer_norm.weight", "a.blk.0.ffn_norm.weight"},
		{"ffn post norm", "model.audio_tower.conformer.0.ffw_layer_start.post_layer_norm.weight", "a.blk.0.ffn_post_norm.weight"},

		// Conformer feedforward end
		{"ffn up 1 weight", "model.audio_tower.conformer.0.ffw_layer_end.ffw_layer_1.linear.weight", "a.blk.0.ffn_up_1.weight"},
		{"ffn down 1 weight", "model.audio_tower.conformer.0.ffw_layer_end.ffw_layer_2.linear.weight", "a.blk.0.ffn_down_1.weight"},
		{"ffn norm 1", "model.audio_tower.conformer.0.ffw_layer_end.pre_layer_norm.weight", "a.blk.0.ffn_norm_1.weight"},
		{"ffn post norm 1", "model.audio_tower.conformer.0.ffw_layer_end.post_layer_norm.weight", "a.blk.0.ffn_post_norm_1.weight"},

		// Conformer lightweight conv1d
		{"conv dw weight", "model.audio_tower.conformer.0.lconv1d.depthwise_conv1d.weight", "a.blk.0.conv_dw.weight"},
		{"conv norm (pre_layer_norm)", "model.audio_tower.conformer.0.lconv1d.pre_layer_norm.weight", "a.blk.0.conv_norm.weight"},
		{"norm conv (conv_norm)", "model.audio_tower.conformer.0.lconv1d.conv_norm.weight", "a.blk.0.norm_conv.weight"},
		{"conv pw1 weight", "model.audio_tower.conformer.0.lconv1d.linear_start.linear.weight", "a.blk.0.conv_pw1.weight"},
		{"conv pw2 weight", "model.audio_tower.conformer.0.lconv1d.linear_end.linear.weight", "a.blk.0.conv_pw2.weight"},

		// Audio embedder
		{"audio embedder projection weight", "model.embed_audio.embedding_projection.linear.weight", "mm.a.input_projection.weight"},
		{"audio embedder projection bias", "model.embed_audio.embedding_projection.linear.bias", "mm.a.input_projection.bias"},

		// Audio output projection
		{"audio output proj weight", "model.audio_tower.output_proj.weight", "mm.a.fc.weight"},
		{"audio output proj bias", "model.audio_tower.output_proj.bias", "mm.a.fc.bias"},

		// Verify vision tensors still work
		{"vision q weight", "model.vision_tower.encoder.layers.0.self_attn.q_proj.linear.weight", "v.blk.0.attn_q.weight"},
		{"vision std bias", "model.vision_tower.std_bias", "v.std_bias"},
		{"vision std scale", "model.vision_tower.std_scale", "v.std_scale"},
		{"vision patch embd", "model.vision_tower.patch_embedder.input_proj.weight", "v.patch_embd.weight"},
		{"vision projector", "model.embed_vision.embedding_projection.linear.weight", "mm.input_projection.weight"},

		// Verify text tensors still work
		{"text attn q", "model.language_model.layers.0.self_attn.q_proj.weight", "blk.0.attn_q.weight"},
		{"text token embd", "model.language_model.embed_tokens.weight", "token_embd.weight"},
		{"text moe gate up fused", "model.language_model.layers.0.experts.gate_up_proj", "blk.0.ffn_gate_up_exps.weight"},
		{"text moe down", "model.language_model.layers.0.experts.down_proj", "blk.0.ffn_down_exps.weight"},
		{"text moe down with weight suffix", "model.language_model.layers.0.experts.down_proj.weight", "blk.0.ffn_down_exps.weight"},
		{"text moe per expert scale", "model.language_model.layers.0.router.per_expert_scale", "blk.0.ffn_down_exps.scale"},
		{"text moe per expert scale with weight suffix", "model.language_model.layers.0.router.per_expert_scale.weight", "blk.0.ffn_down_exps.scale"},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			if got := r.Replace(tt.in); got != tt.want {
				t.Errorf("Replace(%q) = %q, want %q", tt.in, got, tt.want)
			}
		})
	}
}
|
||||
@@ -205,8 +205,8 @@ func TestConvertInvalidDatatype(t *testing.T) {
|
||||
generateSafetensorTestData(t, tempDir, td)
|
||||
|
||||
err = ConvertModel(os.DirFS(tempDir), f)
|
||||
if err == nil || err.Error() != "unsupported safetensors model" {
|
||||
t.Errorf("expected error but didn't get one")
|
||||
if err == nil || !strings.Contains(err.Error(), "unknown data type") {
|
||||
t.Errorf("expected 'unknown data type' error but got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -42,8 +42,11 @@ func (t tensorBase) Kind() uint32 {
|
||||
strings.HasSuffix(t.name, ".bias") ||
|
||||
strings.HasSuffix(t.name, ".shortconv.conv.weight") ||
|
||||
strings.HasSuffix(t.name, ".ssm_conv1d.weight") || // SSM conv kernel must be F32 for Metal
|
||||
strings.HasPrefix(t.name, "a.conv1d.") || // audio SSCP conv weights must be F32 for im2col
|
||||
strings.Contains(t.name, ".conv_dw.") || // audio depthwise conv weights must be F32
|
||||
t.name == "token_types.weight" ||
|
||||
t.name == "v.positional_embedding_vlm" ||
|
||||
t.name == "v.position_embd.weight" ||
|
||||
t.name == "v.tile_position_embd.weight" ||
|
||||
t.name == "v.pre_tile_position_embd.weight" ||
|
||||
t.name == "v.post_tile_position_embd.weight" ||
|
||||
|
||||
@@ -5,7 +5,6 @@ import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/fs"
|
||||
@@ -53,9 +52,10 @@ func parseSafetensors(fsys fs.FS, replacer *strings.Replacer, ps ...string) ([]T
|
||||
|
||||
for _, key := range keys {
|
||||
if value := headers[key]; value.Type != "" {
|
||||
// bitsandbytes quantized models are unsupported
|
||||
// Scalar tensors (e.g. clipped linear min/max) are 0-dim in safetensors.
|
||||
// Promote them to 1-dim so they can be stored in GGUF.
|
||||
if len(value.Shape) == 0 {
|
||||
return nil, errors.New("unsupported safetensors model")
|
||||
value.Shape = []uint64{1}
|
||||
}
|
||||
ggufName := replacer.Replace(key)
|
||||
if _, ok := names[ggufName]; ok {
|
||||
|
||||
@@ -12,7 +12,6 @@ To use Ollama with tools that expect the Anthropic API (like Claude Code), set t
|
||||
|
||||
```shell
|
||||
export ANTHROPIC_AUTH_TOKEN=ollama # required but ignored
|
||||
export ANTHROPIC_API_KEY="" # required but ignored
|
||||
export ANTHROPIC_BASE_URL=http://localhost:11434
|
||||
```
|
||||
|
||||
@@ -269,7 +268,7 @@ ollama launch claude --config
|
||||
Set the environment variables and run Claude Code:
|
||||
|
||||
```shell
|
||||
ANTHROPIC_AUTH_TOKEN=ollama ANTHROPIC_BASE_URL=http://localhost:11434 ANTHROPIC_API_KEY="" claude --model qwen3-coder
|
||||
ANTHROPIC_AUTH_TOKEN=ollama ANTHROPIC_BASE_URL=http://localhost:11434 claude --model qwen3-coder
|
||||
```
|
||||
|
||||
Or set the environment variables in your shell profile:
|
||||
@@ -277,7 +276,6 @@ Or set the environment variables in your shell profile:
|
||||
```shell
|
||||
export ANTHROPIC_AUTH_TOKEN=ollama
|
||||
export ANTHROPIC_BASE_URL=http://localhost:11434
|
||||
export ANTHROPIC_API_KEY=""
|
||||
```
|
||||
|
||||
Then run Claude Code with any Ollama model:
|
||||
|
||||
@@ -6,7 +6,7 @@ Ollama provides compatibility with parts of the [OpenAI API](https://platform.op
|
||||
|
||||
## Usage
|
||||
|
||||
### Simple `v1/chat/completions` example
|
||||
### Simple `/v1/chat/completions` example
|
||||
|
||||
<CodeGroup dropdown>
|
||||
|
||||
@@ -57,7 +57,7 @@ curl -X POST http://localhost:11434/v1/chat/completions \
|
||||
|
||||
</CodeGroup>
|
||||
|
||||
### Simple `v1/responses` example
|
||||
### Simple `/v1/responses` example
|
||||
|
||||
<CodeGroup dropdown>
|
||||
|
||||
@@ -103,7 +103,7 @@ curl -X POST http://localhost:11434/v1/responses \
|
||||
|
||||
</CodeGroup>
|
||||
|
||||
### v1/chat/completions with vision example
|
||||
### `/v1/chat/completions` with vision example
|
||||
|
||||
<CodeGroup dropdown>
|
||||
|
||||
@@ -184,6 +184,7 @@ curl -X POST http://localhost:11434/v1/chat/completions \
|
||||
- [x] Reproducible outputs
|
||||
- [x] Vision
|
||||
- [x] Tools
|
||||
- [x] Reasoning/thinking control (for thinking models)
|
||||
- [ ] Logprobs
|
||||
|
||||
#### Supported request fields
|
||||
@@ -207,6 +208,9 @@ curl -X POST http://localhost:11434/v1/chat/completions \
|
||||
- [x] `top_p`
|
||||
- [x] `max_tokens`
|
||||
- [x] `tools`
|
||||
- [x] `reasoning_effort` (`"high"`, `"medium"`, `"low"`, `"none"`)
|
||||
- [x] `reasoning`
|
||||
- [x] `effort` (`"high"`, `"medium"`, `"low"`, `"none"`)
|
||||
- [ ] `tool_choice`
|
||||
- [ ] `logit_bias`
|
||||
- [ ] `user`
|
||||
|
||||
@@ -21,6 +21,7 @@ Configure and launch external applications to use Ollama models. This provides a
|
||||
- **OpenCode** - Open-source coding assistant
|
||||
- **Claude Code** - Anthropic's agentic coding tool
|
||||
- **Codex** - OpenAI's coding assistant
|
||||
- **VS Code** - Microsoft's IDE with built-in AI chat
|
||||
- **Droid** - Factory's AI coding agent
|
||||
|
||||
#### Examples
|
||||
@@ -40,7 +41,7 @@ ollama launch claude
|
||||
Launch with a specific model:
|
||||
|
||||
```
|
||||
ollama launch claude --model qwen3-coder
|
||||
ollama launch claude --model qwen3.5
|
||||
```
|
||||
|
||||
Configure without launching:
|
||||
|
||||
@@ -51,6 +51,9 @@ Install prerequisites:
|
||||
- [CUDA SDK](https://developer.nvidia.com/cuda-downloads?target_os=Windows&target_arch=x86_64&target_version=11&target_type=exe_network)
|
||||
- (Optional) VULKAN GPU support
|
||||
- [VULKAN SDK](https://vulkan.lunarg.com/sdk/home) - useful for AMD/Intel GPUs
|
||||
- (Optional) MLX engine support
|
||||
- [CUDA 13+ SDK](https://developer.nvidia.com/cuda-downloads)
|
||||
- [cuDNN 9+](https://developer.nvidia.com/cudnn)
|
||||
|
||||
Then, configure and build the project:
|
||||
|
||||
@@ -101,6 +104,10 @@ Install prerequisites:
|
||||
- (Optional) VULKAN GPU support
|
||||
- [VULKAN SDK](https://vulkan.lunarg.com/sdk/home) - useful for AMD/Intel GPUs
|
||||
- Or install via package manager: `sudo apt install vulkan-sdk` (Ubuntu/Debian) or `sudo dnf install vulkan-sdk` (Fedora/CentOS)
|
||||
- (Optional) MLX engine support
|
||||
- [CUDA 13+ SDK](https://developer.nvidia.com/cuda-downloads)
|
||||
- [cuDNN 9+](https://developer.nvidia.com/cudnn)
|
||||
- OpenBLAS/LAPACK: `sudo apt install libopenblas-dev liblapack-dev liblapacke-dev` (Ubuntu/Debian)
|
||||
> [!IMPORTANT]
|
||||
> Ensure prerequisites are in `PATH` before running CMake.
|
||||
|
||||
@@ -118,6 +125,67 @@ Lastly, run Ollama:
|
||||
go run . serve
|
||||
```
|
||||
|
||||
## MLX Engine (Optional)
|
||||
|
||||
The MLX engine enables running safetensors-based models. It requires building the [MLX](https://github.com/ml-explore/mlx) and [MLX-C](https://github.com/ml-explore/mlx-c) shared libraries separately via CMake. On macOS, MLX leverages the Metal library to run on the GPU, and on Windows and Linux, runs on NVIDIA GPUs via CUDA v13.
|
||||
|
||||
### macOS (Apple Silicon)
|
||||
|
||||
Requires the Metal toolchain. Install [Xcode](https://developer.apple.com/xcode/) first, then:
|
||||
|
||||
```shell
|
||||
xcodebuild -downloadComponent MetalToolchain
|
||||
```
|
||||
|
||||
Verify it's installed correctly (should print "no input files"):
|
||||
|
||||
```shell
|
||||
xcrun metal
|
||||
```
|
||||
|
||||
Then build:
|
||||
|
||||
```shell
|
||||
cmake -B build --preset MLX
|
||||
cmake --build build --preset MLX --parallel
|
||||
cmake --install build --component MLX
|
||||
```
|
||||
|
||||
> [!NOTE]
|
||||
> Without the Metal toolchain, cmake will silently complete with Metal disabled. Check the cmake output for `Setting MLX_BUILD_METAL=OFF` which indicates the toolchain is missing.
|
||||
|
||||
### Windows / Linux (CUDA)
|
||||
|
||||
Requires CUDA 13+ and [cuDNN](https://developer.nvidia.com/cudnn) 9+.
|
||||
|
||||
```shell
|
||||
cmake -B build --preset "MLX CUDA 13"
|
||||
cmake --build build --target mlx --target mlxc --config Release --parallel
|
||||
cmake --install build --component MLX --strip
|
||||
```
|
||||
|
||||
### Local MLX source overrides
|
||||
|
||||
To build against a local checkout of MLX and/or MLX-C (useful for development), set environment variables before running CMake:
|
||||
|
||||
```shell
|
||||
export OLLAMA_MLX_SOURCE=/path/to/mlx
|
||||
export OLLAMA_MLX_C_SOURCE=/path/to/mlx-c
|
||||
```
|
||||
|
||||
For example, using the helper scripts with local mlx and mlx-c repos:
|
||||
```shell
|
||||
OLLAMA_MLX_SOURCE=../mlx OLLAMA_MLX_C_SOURCE=../mlx-c ./scripts/build_linux.sh
|
||||
|
||||
OLLAMA_MLX_SOURCE=../mlx OLLAMA_MLX_C_SOURCE=../mlx-c ./scripts/build_darwin.sh
|
||||
```
|
||||
|
||||
```powershell
|
||||
$env:OLLAMA_MLX_SOURCE="../mlx"
|
||||
$env:OLLAMA_MLX_C_SOURCE="../mlx-c"
|
||||
./scripts/build_darwin.ps1
|
||||
```
|
||||
|
||||
## Docker
|
||||
|
||||
```shell
|
||||
|
||||
@@ -127,6 +127,7 @@
|
||||
},
|
||||
{
|
||||
"group": "IDEs & Editors",
|
||||
"expanded": true,
|
||||
"pages": [
|
||||
"/integrations/cline",
|
||||
"/integrations/jetbrains",
|
||||
@@ -160,6 +161,12 @@
|
||||
"group": "More information",
|
||||
"pages": [
|
||||
"/cli",
|
||||
{
|
||||
"group": "Assistant Sandboxing",
|
||||
"pages": [
|
||||
"/integrations/nemoclaw"
|
||||
]
|
||||
},
|
||||
"/modelfile",
|
||||
"/context-length",
|
||||
"/linux",
|
||||
|
||||
33
docs/gpu.mdx
@@ -61,11 +61,17 @@ Ollama supports the following AMD GPUs via the ROCm library:
|
||||
|
||||
### Linux Support
|
||||
|
||||
| Family | Cards and accelerators |
|
||||
| -------------- | ---------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| AMD Radeon RX | `7900 XTX` `7900 XT` `7900 GRE` `7800 XT` `7700 XT` `7600 XT` `7600` `6950 XT` `6900 XTX` `6900XT` `6800 XT` `6800` `Vega 64` |
|
||||
| AMD Radeon PRO | `W7900` `W7800` `W7700` `W7600` `W7500` `W6900X` `W6800X Duo` `W6800X` `W6800` `V620` `V420` `V340` `V320` `Vega II Duo` `Vega II` `SSG` |
|
||||
| AMD Instinct | `MI300X` `MI300A` `MI300` `MI250X` `MI250` `MI210` `MI200` `MI100` `MI60` |
|
||||
Ollama requires the AMD ROCm v7 driver on Linux. You can install or upgrade
|
||||
using the `amdgpu-install` utility from
|
||||
[AMD's ROCm documentation](https://rocm.docs.amd.com/projects/install-on-linux/en/latest/).
|
||||
|
||||
| Family | Cards and accelerators |
|
||||
| -------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| AMD Radeon RX | `9070 XT` `9070 GRE` `9070` `9060 XT` `9060 XT LP` `9060` `7900 XTX` `7900 XT` `7900 GRE` `7800 XT` `7700 XT` `7700` `7600 XT` `7600` `6950 XT` `6900 XTX` `6900XT` `6800 XT` `6800` `5700 XT` `5700` `5600 XT` `5500 XT` |
|
||||
| AMD Radeon AI PRO | `R9700` `R9600D` |
|
||||
| AMD Radeon PRO | `W7900` `W7800` `W7700` `W7600` `W7500` `W6900X` `W6800X Duo` `W6800X` `W6800` `V620` |
|
||||
| AMD Ryzen AI | `Ryzen AI Max+ 395` `Ryzen AI Max 390` `Ryzen AI Max 385` `Ryzen AI 9 HX 475` `Ryzen AI 9 HX 470` `Ryzen AI 9 465` `Ryzen AI 9 HX 375` `Ryzen AI 9 HX 370` `Ryzen AI 9 365` |
|
||||
| AMD Instinct | `MI350X` `MI300X` `MI300A` `MI250X` `MI250` `MI210` `MI100` |
|
||||
|
||||
### Windows Support
|
||||
|
||||
@@ -97,17 +103,20 @@ This table shows some example GPUs that map to these LLVM targets:
|
||||
| **LLVM Target** | **An Example GPU** |
|
||||
|-----------------|---------------------|
|
||||
| gfx908 | Radeon Instinct MI100 |
|
||||
| gfx90a | Radeon Instinct MI210 |
|
||||
| gfx940 | Radeon Instinct MI300 |
|
||||
| gfx941 | |
|
||||
| gfx942 | |
|
||||
| gfx90a | Radeon Instinct MI210/MI250 |
|
||||
| gfx942 | Radeon Instinct MI300X/MI300A |
|
||||
| gfx950 | Radeon Instinct MI350X |
|
||||
| gfx1010 | Radeon RX 5700 XT |
|
||||
| gfx1012 | Radeon RX 5500 XT |
|
||||
| gfx1030 | Radeon PRO V620 |
|
||||
| gfx1100 | Radeon PRO W7900 |
|
||||
| gfx1101 | Radeon PRO W7700 |
|
||||
| gfx1102 | Radeon RX 7600 |
|
||||
|
||||
AMD is working on enhancing ROCm v6 to broaden support for families of GPUs in a
|
||||
future release which should increase support for more GPUs.
|
||||
| gfx1103 | Radeon 780M |
|
||||
| gfx1150 | Ryzen AI 9 HX 375 |
|
||||
| gfx1151 | Ryzen AI Max+ 395 |
|
||||
| gfx1200 | Radeon RX 9070 |
|
||||
| gfx1201 | Radeon RX 9070 XT |
|
||||
|
||||
Reach out on [Discord](https://discord.gg/ollama) or file an
|
||||
[issue](https://github.com/ollama/ollama/issues) for additional help.
|
||||
|
||||
BIN
docs/images/local.png
Normal file
|
After Width: | Height: | Size: 29 KiB |
BIN
docs/images/vscode-add-ollama.png
Normal file
|
After Width: | Height: | Size: 64 KiB |
|
Before Width: | Height: | Size: 77 KiB |
|
Before Width: | Height: | Size: 56 KiB |
BIN
docs/images/vscode-other-models.png
Normal file
|
After Width: | Height: | Size: 52 KiB |
BIN
docs/images/vscode-unhide.png
Normal file
|
After Width: | Height: | Size: 67 KiB |
BIN
docs/images/vscode.png
Normal file
|
After Width: | Height: | Size: 2.7 MiB |
@@ -4,7 +4,7 @@ title: Claude Code
|
||||
|
||||
Claude Code is Anthropic's agentic coding tool that can read, modify, and execute code in your working directory.
|
||||
|
||||
Open models can be used with Claude Code through Ollama's Anthropic-compatible API, enabling you to use models such as `glm-4.7`, `qwen3-coder`, `gpt-oss`.
|
||||
Open models can be used with Claude Code through Ollama's Anthropic-compatible API, enabling you to use models such as `qwen3.5`, `glm-5:cloud`, `kimi-k2.5:cloud`.
|
||||
|
||||

|
||||
|
||||
@@ -32,13 +32,83 @@ irm https://claude.ai/install.ps1 | iex
|
||||
ollama launch claude
|
||||
```
|
||||
|
||||
To configure without launching:
|
||||
|
||||
### Run directly with a model
|
||||
```shell
|
||||
ollama launch claude --config
|
||||
ollama launch claude --model kimi-k2.5:cloud
|
||||
```
|
||||
|
||||
### Manual setup
|
||||
## Recommended Models
|
||||
|
||||
- `kimi-k2.5:cloud`
|
||||
- `glm-5:cloud`
|
||||
- `minimax-m2.7:cloud`
|
||||
- `qwen3.5:cloud`
|
||||
- `glm-4.7-flash`
|
||||
- `qwen3.5`
|
||||
|
||||
Cloud models are also available at [ollama.com/search?c=cloud](https://ollama.com/search?c=cloud).
|
||||
|
||||
## Non-interactive (headless) mode
|
||||
|
||||
Run Claude Code without interaction for use in Docker, CI/CD, or scripts:
|
||||
|
||||
```shell
|
||||
ollama launch claude --model kimi-k2.5:cloud --yes -- -p "how does this repository work?"
|
||||
```
|
||||
|
||||
The `--yes` flag auto-pulls the model, skips selectors, and requires `--model` to be specified. Arguments after `--` are passed directly to Claude Code.
|
||||
|
||||
## Web search
|
||||
|
||||
Claude Code can search the web through Ollama's web search API. See the [web search documentation](/capabilities/web-search) for setup and usage.
|
||||
|
||||
## Scheduled Tasks with `/loop`
|
||||
|
||||
The `/loop` command runs a prompt or slash command on a recurring schedule inside Claude Code. This is useful for automating repetitive tasks like checking PRs, running research, or setting reminders.
|
||||
|
||||
```
|
||||
/loop <interval> <prompt or /command>
|
||||
```
|
||||
|
||||
### Examples
|
||||
|
||||
**Check in on your PRs**
|
||||
|
||||
```
|
||||
/loop 30m Check my open PRs and summarize their status
|
||||
```
|
||||
|
||||
**Automate research tasks**
|
||||
|
||||
```
|
||||
/loop 1h Research the latest AI news and summarize key developments
|
||||
```
|
||||
|
||||
**Automate bug reporting and triaging**
|
||||
|
||||
```
|
||||
/loop 15m Check for new GitHub issues and triage by priority
|
||||
```
|
||||
|
||||
**Set reminders**
|
||||
|
||||
```
|
||||
/loop 1h Remind me to review the deploy status
|
||||
```
|
||||
|
||||
## Telegram
|
||||
|
||||
Chat with Claude Code from Telegram by connecting a bot to your session. Install the [Telegram plugin](https://github.com/anthropics/claude-plugins-official), create a bot via [@BotFather](https://t.me/BotFather), then launch with the channel flag:
|
||||
|
||||
```shell
|
||||
ollama launch claude -- --channels plugin:telegram@claude-plugins-official
|
||||
```
|
||||
|
||||
Claude Code will prompt for permission on most actions. To allow the bot to work autonomously, configure [permission rules](https://code.claude.com/docs/en/permissions) or pass `--dangerously-skip-permissions` in isolated environments.
|
||||
|
||||
See the [plugin README](https://github.com/anthropics/claude-plugins-official/tree/main/external_plugins/telegram) for full setup instructions including pairing and access control.
|
||||
|
||||
## Manual setup
|
||||
|
||||
Claude Code connects to Ollama using the Anthropic-compatible API.
|
||||
|
||||
@@ -53,23 +123,14 @@ export ANTHROPIC_BASE_URL=http://localhost:11434
|
||||
2. Run Claude Code with an Ollama model:
|
||||
|
||||
```shell
|
||||
claude --model gpt-oss:20b
|
||||
claude --model qwen3.5
|
||||
```
|
||||
|
||||
Or run with environment variables inline:
|
||||
|
||||
```shell
|
||||
ANTHROPIC_AUTH_TOKEN=ollama ANTHROPIC_BASE_URL=http://localhost:11434 ANTHROPIC_API_KEY="" claude --model qwen3-coder
|
||||
ANTHROPIC_AUTH_TOKEN=ollama ANTHROPIC_BASE_URL=http://localhost:11434 ANTHROPIC_API_KEY="" claude --model glm-5:cloud
|
||||
```
|
||||
|
||||
**Note:** Claude Code requires a large context window. We recommend at least 64k tokens. See the [context length documentation](/context-length) for how to adjust context length in Ollama.
|
||||
|
||||
## Recommended Models
|
||||
|
||||
- `qwen3-coder`
|
||||
- `glm-4.7`
|
||||
- `gpt-oss:20b`
|
||||
- `gpt-oss:120b`
|
||||
|
||||
Cloud models are also available at [ollama.com/search?c=cloud](https://ollama.com/search?c=cloud).
|
||||
|
||||
|
||||
67
docs/integrations/nemoclaw.mdx
Normal file
@@ -0,0 +1,67 @@
|
||||
---
|
||||
title: NemoClaw
|
||||
---
|
||||
|
||||
NemoClaw is NVIDIA's open source security stack for [OpenClaw](/integrations/openclaw). It wraps OpenClaw with the NVIDIA OpenShell runtime to provide kernel-level sandboxing, network policy controls, and audit trails for AI agents.
|
||||
|
||||
## Quick start
|
||||
|
||||
Pull a model:
|
||||
|
||||
```bash
|
||||
ollama pull nemotron-3-nano:30b
|
||||
```
|
||||
|
||||
Run the installer:
|
||||
|
||||
```bash
|
||||
curl -fsSL https://www.nvidia.com/nemoclaw.sh | \
|
||||
NEMOCLAW_NON_INTERACTIVE=1 \
|
||||
NEMOCLAW_PROVIDER=ollama \
|
||||
NEMOCLAW_MODEL=nemotron-3-nano:30b \
|
||||
bash
|
||||
```
|
||||
|
||||
Connect to your sandbox:
|
||||
|
||||
```bash
|
||||
nemoclaw my-assistant connect
|
||||
```
|
||||
|
||||
Open the TUI:
|
||||
|
||||
```bash
|
||||
openclaw tui
|
||||
```
|
||||
|
||||
<Note>Ollama support in NemoClaw is still experimental.</Note>
|
||||
|
||||
## Platform support
|
||||
|
||||
| Platform | Runtime | Status |
|
||||
|----------|---------|--------|
|
||||
| Linux (Ubuntu 22.04+) | Docker | Primary |
|
||||
| macOS (Apple Silicon) | Colima or Docker Desktop | Supported |
|
||||
| Windows | WSL2 with Docker Desktop | Supported |
|
||||
|
||||
CMD and PowerShell are not supported on Windows — WSL2 is required.
|
||||
|
||||
<Note>Ollama must be installed and running before the installer runs. When running inside WSL2 or a container, ensure Ollama is reachable from the sandbox (e.g. `OLLAMA_HOST=0.0.0.0`).</Note>
|
||||
|
||||
## System requirements
|
||||
|
||||
- CPU: 4 vCPU minimum
|
||||
- RAM: 8 GB minimum (16 GB recommended)
|
||||
- Disk: 20 GB free (40 GB recommended for local models)
|
||||
- Node.js 20+ and npm 10+
|
||||
- Container runtime (Docker preferred)
|
||||
|
||||
## Recommended models
|
||||
|
||||
- `nemotron-3-super:cloud` — Strong reasoning and coding
|
||||
- `qwen3.5:cloud` — 397B; reasoning and code generation
|
||||
- `nemotron-3-nano:30b` — Recommended local model; fits in 24 GB VRAM
|
||||
- `qwen3.5:27b` — Fast local reasoning (~18 GB VRAM)
|
||||
- `glm-4.7-flash` — Reasoning and code generation (~25 GB VRAM)
|
||||
|
||||
More models at [ollama.com/search](https://ollama.com/search).
|
||||
@@ -15,13 +15,29 @@ Ollama handles everything automatically:
|
||||
1. **Install** — If OpenClaw isn't installed, Ollama prompts to install it via npm
|
||||
2. **Security** — On the first launch, a security notice explains the risks of tool access
|
||||
3. **Model** — Pick a model from the selector (local or cloud)
|
||||
4. **Onboarding** — Ollama configures the provider, installs the gateway daemon, and sets your model as the primary
|
||||
4. **Onboarding** — Ollama configures the provider, installs the gateway daemon, sets your model as the primary, and installs the web search and fetch plugin
|
||||
5. **Gateway** — Starts in the background and opens the OpenClaw TUI
|
||||
|
||||
<Note>OpenClaw requires a larger context window. It is recommended to use a context window of at least 64k tokens if using local models. See [Context length](/context-length) for more information.</Note>
|
||||
|
||||
<Note>Previously known as Clawdbot. `ollama launch clawdbot` still works as an alias.</Note>
|
||||
|
||||
## Web search and fetch
|
||||
|
||||
OpenClaw ships with a web search and fetch plugin that gives local or cloud models the ability to search the web and extract readable page content.
|
||||
|
||||
```bash
|
||||
ollama launch openclaw
|
||||
```
|
||||
|
||||
Web search and fetch is enabled automatically when launching OpenClaw through Ollama. To install the plugin directly:
|
||||
|
||||
```bash
|
||||
openclaw plugins install @ollama/openclaw-web-search
|
||||
```
|
||||
|
||||
<Note>Web search for local models requires `ollama signin`.</Note>
|
||||
|
||||
## Configure without launching
|
||||
|
||||
To change the model without starting the gateway and TUI:
|
||||
@@ -43,7 +59,7 @@ If the gateway is already running, it restarts automatically to pick up the new
|
||||
**Cloud models**:
|
||||
|
||||
- `kimi-k2.5:cloud` — Multimodal reasoning with subagents
|
||||
- `minimax-m2.5:cloud` — Fast, efficient coding and real-world productivity
|
||||
- `minimax-m2.7:cloud` — Fast, efficient coding and real-world productivity
|
||||
- `glm-5:cloud` — Reasoning and code generation
|
||||
|
||||
**Local models:**
|
||||
@@ -52,6 +68,16 @@ If the gateway is already running, it restarts automatically to pick up the new
|
||||
|
||||
More models at [ollama.com/search](https://ollama.com/search?c=cloud).
|
||||
|
||||
## Non-interactive (headless) mode
|
||||
|
||||
Run OpenClaw without interaction for use in Docker, CI/CD, or scripts:
|
||||
|
||||
```bash
|
||||
ollama launch openclaw --model kimi-k2.5:cloud --yes
|
||||
```
|
||||
|
||||
The `--yes` flag auto-pulls the model, skips selectors, and requires `--model` to be specified.
|
||||
|
||||
## Connect messaging apps
|
||||
|
||||
```bash
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
title: Pi
|
||||
---
|
||||
|
||||
Pi is a minimal AI agent toolkit with plugin support.
|
||||
Pi is a minimal and extensible coding agent.
|
||||
|
||||
## Install
|
||||
|
||||
@@ -20,13 +20,65 @@ npm install -g @mariozechner/pi-coding-agent
|
||||
ollama launch pi
|
||||
```
|
||||
|
||||
This installs Pi, configures Ollama as a provider including web tools, and drops you into an interactive session.
|
||||
|
||||
To configure without launching:
|
||||
|
||||
```shell
|
||||
ollama launch pi --config
|
||||
```
|
||||
|
||||
### Manual setup
|
||||
### Run directly with a model
|
||||
|
||||
```shell
|
||||
ollama launch pi --model qwen3.5:cloud
|
||||
```
|
||||
|
||||
Cloud models are also available at [ollama.com](https://ollama.com/search?c=cloud).
|
||||
|
||||
## Extensions
|
||||
|
||||
Pi ships with four core tools: `read`, `write`, `edit`, and `bash`. All other capabilities are added through its extension system.
|
||||
|
||||
Skills are on-demand capability packages invoked via `/skill:name` commands.
|
||||
|
||||
Install from npm or git:
|
||||
|
||||
```bash
|
||||
pi install npm:@foo/some-tools
|
||||
pi install git:github.com/user/repo@v1
|
||||
```
|
||||
|
||||
See all packages at [pi.dev](https://pi.dev/packages).
|
||||
|
||||
### Web search
|
||||
|
||||
Pi can use web search and fetch tools via the `@ollama/pi-web-search` package.
|
||||
|
||||
When launching Pi through Ollama, package install/update is managed automatically.
|
||||
To install manually:
|
||||
|
||||
```bash
|
||||
pi install npm:@ollama/pi-web-search
|
||||
```
|
||||
|
||||
### Autoresearch with `pi-autoresearch`
|
||||
|
||||
[pi-autoresearch](https://github.com/davebcn87/pi-autoresearch) brings autonomous experiment loops to Pi. Inspired by Karpathy's autoresearch, it turns any measurable metric into an optimization target: test speed, bundle size, build time, model training loss, Lighthouse scores.
|
||||
|
||||
```bash
|
||||
pi install https://github.com/davebcn87/pi-autoresearch
|
||||
```
|
||||
|
||||
Tell Pi what to optimize. It runs experiments, benchmarks each one, keeps improvements, reverts regressions, and repeats — all autonomously. A built-in dashboard tracks every run with confidence scoring to distinguish real gains from benchmark noise.
|
||||
|
||||
```bash
|
||||
/autoresearch optimize unit test runtime
|
||||
```
|
||||
|
||||
Each kept experiment is automatically committed. Each failed one is reverted. When you're done, Pi can group improvements into independent branches for clean review and merge.
|
||||
|
||||
## Manual setup
|
||||
|
||||
Add a configuration block to `~/.pi/agent/models.json`:
|
||||
|
||||
|
||||
@@ -2,33 +2,84 @@
|
||||
title: VS Code
|
||||
---
|
||||
|
||||
## Install
|
||||
VS Code includes built-in AI chat through GitHub Copilot Chat. Ollama models can be used directly in the Copilot Chat model picker.
|
||||
|
||||
Install [VS Code](https://code.visualstudio.com/download).
|
||||
|
||||
## Usage with Ollama
|
||||

|
||||
|
||||
1. Open the Copilot side bar, found at the top right of the window
|
||||
|
||||
## Prerequisites
|
||||
|
||||
- Ollama v0.18.3+
|
||||
- [VS Code 1.113+](https://code.visualstudio.com/download)
|
||||
- [GitHub Copilot Chat extension 0.41.0+](https://marketplace.visualstudio.com/items?itemName=GitHub.copilot-chat)
|
||||
|
||||
<Note> VS Code requires you to be logged in to use its model selector, even for custom models. This doesn't require a paid GitHub Copilot account; GitHub Copilot Free will enable model selection for custom models.</Note>
|
||||
|
||||
## Quick setup
|
||||
|
||||
```shell
|
||||
ollama launch vscode
|
||||
```
|
||||
|
||||
Recommended models will be shown after running the command. See the latest models at [ollama.com](https://ollama.com/search?c=tools).
|
||||
|
||||
Make sure **Local** is selected at the bottom of the Copilot Chat panel to use your Ollama models.
|
||||
<div style={{ display: "flex", justifyContent: "center" }}>
|
||||
<img
|
||||
src="/images/local.png"
|
||||
alt="Ollama Local Models"
|
||||
width="60%"
|
||||
style={{ borderRadius: "4px", marginTop: "10px", marginBottom: "10px" }}
|
||||
/>
|
||||
</div>
|
||||
|
||||
|
||||
## Run directly with a model
|
||||
|
||||
```shell
|
||||
ollama launch vscode --model qwen3.5:cloud
|
||||
```
|
||||
Cloud models are also available at [ollama.com](https://ollama.com/search?c=cloud).
|
||||
|
||||
## Manual setup
|
||||
|
||||
To configure Ollama manually without `ollama launch`:
|
||||
|
||||
1. Open the **Copilot Chat** side bar from the top right corner
|
||||
<div style={{ display: "flex", justifyContent: "center" }}>
|
||||
<img
|
||||
src="/images/vscode-sidebar.png"
|
||||
alt="VS Code chat Sidebar"
|
||||
width="75%"
|
||||
style={{ borderRadius: "4px" }}
|
||||
/>
|
||||
</div>
|
||||
2. Select the model dropdown > **Manage models**
|
||||
2. Click the **settings gear icon** (<Icon icon="gear" />) to bring up the Language Models window
|
||||
<div style={{ display: "flex", justifyContent: "center" }}>
|
||||
<img
|
||||
src="/images/vscode-models.png"
|
||||
src="/images/vscode-other-models.png"
|
||||
alt="VS Code model picker"
|
||||
width="75%"
|
||||
style={{ borderRadius: "4px" }}
|
||||
/>
|
||||
</div>
|
||||
3. Enter **Ollama** under **Provider Dropdown** and select desired models (e.g `qwen3, qwen3-coder:480b-cloud`)
|
||||
3. Click **Add Models** and select **Ollama** to load all your Ollama models into VS Code
|
||||
<div style={{ display: "flex", justifyContent: "center" }}>
|
||||
<img
|
||||
src="/images/vscode-model-options.png"
|
||||
alt="VS Code model options dropdown"
|
||||
src="/images/vscode-add-ollama.png"
|
||||
alt="VS Code model options dropdown to add ollama models"
|
||||
width="75%"
|
||||
style={{ borderRadius: "4px" }}
|
||||
/>
|
||||
</div>
|
||||
|
||||
4. Click the **Unhide** button in the model picker to show your Ollama models
|
||||
<div style={{ display: "flex", justifyContent: "center" }}>
|
||||
<img
|
||||
src="/images/vscode-unhide.png"
|
||||
alt="VS Code unhide models button"
|
||||
width="75%"
|
||||
style={{ borderRadius: "4px" }}
|
||||
/>
|
||||
</div>
|
||||
|
||||
@@ -101,7 +101,7 @@ nvidia-smi
|
||||
|
||||
### Install AMD ROCm drivers (optional)
|
||||
|
||||
[Download and Install](https://rocm.docs.amd.com/projects/install-on-linux/en/latest/tutorial/quick-start.html) ROCm v6.
|
||||
[Download and Install](https://rocm.docs.amd.com/projects/install-on-linux/en/latest/tutorial/quick-start.html) ROCm v7.
|
||||
|
||||
### Start Ollama
|
||||
|
||||
|
||||
@@ -114,6 +114,25 @@ If you are experiencing problems getting Ollama to correctly discover or use you
|
||||
- `OLLAMA_DEBUG=1` During GPU discovery additional information will be reported
|
||||
- Check dmesg for any errors from amdgpu or kfd drivers `sudo dmesg | grep -i amdgpu` and `sudo dmesg | grep -i kfd`
|
||||
|
||||
### AMD Driver Version Mismatch
|
||||
|
||||
If your AMD GPU is not detected on Linux and the server logs contain messages like:
|
||||
|
||||
```
|
||||
msg="failure during GPU discovery" ... error="failed to finish discovery before timeout"
|
||||
msg="bootstrap discovery took" duration=30s ...
|
||||
```
|
||||
|
||||
This typically means the system's AMD GPU driver is too old. Ollama bundles
|
||||
ROCm 7 linux libraries which require a compatible ROCm 7 kernel driver. If the
|
||||
system is running an older driver (ROCm 6.x or earlier), GPU initialization
|
||||
will hang during device discovery and eventually time out, causing Ollama to
|
||||
fall back to CPU.
|
||||
|
||||
To resolve this, upgrade to the ROCm v7 driver using the `amdgpu-install`
|
||||
utility from [AMD's ROCm documentation](https://rocm.docs.amd.com/projects/install-on-linux/en/latest/).
|
||||
After upgrading, reboot and restart Ollama.
|
||||
|
||||
## Multiple AMD GPUs
|
||||
|
||||
If you experience gibberish responses when models load across multiple AMD GPUs on Linux, see the following guide.
|
||||
|
||||
@@ -80,9 +80,13 @@ help you keep up to date.
|
||||
|
||||
If you'd like to install or integrate Ollama as a service, a standalone
|
||||
`ollama-windows-amd64.zip` zip file is available containing only the Ollama CLI
|
||||
and GPU library dependencies for Nvidia. If you have an AMD GPU, also download
|
||||
and extract the additional ROCm package `ollama-windows-amd64-rocm.zip` into the
|
||||
same directory. This allows for embedding Ollama in existing applications, or
|
||||
and GPU library dependencies for Nvidia. Depending on your hardware, you may also
|
||||
need to download and extract additional packages into the same directory:
|
||||
|
||||
- **AMD GPU**: `ollama-windows-amd64-rocm.zip`
|
||||
- **MLX (CUDA)**: `ollama-windows-amd64-mlx.zip`
|
||||
|
||||
This allows for embedding Ollama in existing applications, or
|
||||
running it as a system service via `ollama serve` with tools such as
|
||||
[NSSM](https://nssm.cc/).
|
||||
|
||||
|
||||
@@ -59,6 +59,29 @@ func Host() *url.URL {
|
||||
}
|
||||
}
|
||||
|
||||
// ConnectableHost returns Host() with unspecified bind addresses (0.0.0.0, ::)
|
||||
// replaced by the corresponding loopback address (127.0.0.1, ::1).
|
||||
// Unspecified addresses are valid for binding a server socket but not for
|
||||
// connecting as a client, which fails on Windows.
|
||||
func ConnectableHost() *url.URL {
|
||||
u := Host()
|
||||
host, port, err := net.SplitHostPort(u.Host)
|
||||
if err != nil {
|
||||
return u
|
||||
}
|
||||
|
||||
if ip := net.ParseIP(host); ip != nil && ip.IsUnspecified() {
|
||||
if ip.To4() != nil {
|
||||
host = "127.0.0.1"
|
||||
} else {
|
||||
host = "::1"
|
||||
}
|
||||
u.Host = net.JoinHostPort(host, port)
|
||||
}
|
||||
|
||||
return u
|
||||
}
|
||||
|
||||
// AllowedOrigins returns a list of allowed origins. AllowedOrigins can be configured via the OLLAMA_ORIGINS environment variable.
|
||||
func AllowedOrigins() (origins []string) {
|
||||
if s := Var("OLLAMA_ORIGINS"); s != "" {
|
||||
@@ -191,6 +214,8 @@ func LogLevel() slog.Level {
|
||||
var (
|
||||
// FlashAttention enables the experimental flash attention feature.
|
||||
FlashAttention = BoolWithDefault("OLLAMA_FLASH_ATTENTION")
|
||||
// DebugLogRequests logs inference requests to disk for replay/debugging.
|
||||
DebugLogRequests = Bool("OLLAMA_DEBUG_LOG_REQUESTS")
|
||||
// KvCacheType is the quantization type for the K/V cache.
|
||||
KvCacheType = String("OLLAMA_KV_CACHE_TYPE")
|
||||
// NoHistory disables readline history.
|
||||
@@ -279,28 +304,29 @@ type EnvVar struct {
|
||||
|
||||
func AsMap() map[string]EnvVar {
|
||||
ret := map[string]EnvVar{
|
||||
"OLLAMA_DEBUG": {"OLLAMA_DEBUG", LogLevel(), "Show additional debug information (e.g. OLLAMA_DEBUG=1)"},
|
||||
"OLLAMA_FLASH_ATTENTION": {"OLLAMA_FLASH_ATTENTION", FlashAttention(false), "Enabled flash attention"},
|
||||
"OLLAMA_KV_CACHE_TYPE": {"OLLAMA_KV_CACHE_TYPE", KvCacheType(), "Quantization type for the K/V cache (default: f16)"},
|
||||
"OLLAMA_GPU_OVERHEAD": {"OLLAMA_GPU_OVERHEAD", GpuOverhead(), "Reserve a portion of VRAM per GPU (bytes)"},
|
||||
"OLLAMA_HOST": {"OLLAMA_HOST", Host(), "IP Address for the ollama server (default 127.0.0.1:11434)"},
|
||||
"OLLAMA_KEEP_ALIVE": {"OLLAMA_KEEP_ALIVE", KeepAlive(), "The duration that models stay loaded in memory (default \"5m\")"},
|
||||
"OLLAMA_LLM_LIBRARY": {"OLLAMA_LLM_LIBRARY", LLMLibrary(), "Set LLM library to bypass autodetection"},
|
||||
"OLLAMA_LOAD_TIMEOUT": {"OLLAMA_LOAD_TIMEOUT", LoadTimeout(), "How long to allow model loads to stall before giving up (default \"5m\")"},
|
||||
"OLLAMA_MAX_LOADED_MODELS": {"OLLAMA_MAX_LOADED_MODELS", MaxRunners(), "Maximum number of loaded models per GPU"},
|
||||
"OLLAMA_MAX_QUEUE": {"OLLAMA_MAX_QUEUE", MaxQueue(), "Maximum number of queued requests"},
|
||||
"OLLAMA_MODELS": {"OLLAMA_MODELS", Models(), "The path to the models directory"},
|
||||
"OLLAMA_NO_CLOUD": {"OLLAMA_NO_CLOUD", NoCloud(), "Disable Ollama cloud features (remote inference and web search)"},
|
||||
"OLLAMA_NOHISTORY": {"OLLAMA_NOHISTORY", NoHistory(), "Do not preserve readline history"},
|
||||
"OLLAMA_NOPRUNE": {"OLLAMA_NOPRUNE", NoPrune(), "Do not prune model blobs on startup"},
|
||||
"OLLAMA_NUM_PARALLEL": {"OLLAMA_NUM_PARALLEL", NumParallel(), "Maximum number of parallel requests"},
|
||||
"OLLAMA_ORIGINS": {"OLLAMA_ORIGINS", AllowedOrigins(), "A comma separated list of allowed origins"},
|
||||
"OLLAMA_SCHED_SPREAD": {"OLLAMA_SCHED_SPREAD", SchedSpread(), "Always schedule model across all GPUs"},
|
||||
"OLLAMA_MULTIUSER_CACHE": {"OLLAMA_MULTIUSER_CACHE", MultiUserCache(), "Optimize prompt caching for multi-user scenarios"},
|
||||
"OLLAMA_CONTEXT_LENGTH": {"OLLAMA_CONTEXT_LENGTH", ContextLength(), "Context length to use unless otherwise specified (default: 4k/32k/256k based on VRAM)"},
|
||||
"OLLAMA_EDITOR": {"OLLAMA_EDITOR", Editor(), "Path to editor for interactive prompt editing (Ctrl+G)"},
|
||||
"OLLAMA_NEW_ENGINE": {"OLLAMA_NEW_ENGINE", NewEngine(), "Enable the new Ollama engine"},
|
||||
"OLLAMA_REMOTES": {"OLLAMA_REMOTES", Remotes(), "Allowed hosts for remote models (default \"ollama.com\")"},
|
||||
"OLLAMA_DEBUG": {"OLLAMA_DEBUG", LogLevel(), "Show additional debug information (e.g. OLLAMA_DEBUG=1)"},
|
||||
"OLLAMA_DEBUG_LOG_REQUESTS": {"OLLAMA_DEBUG_LOG_REQUESTS", DebugLogRequests(), "Log inference request bodies and replay curl commands to a temp directory"},
|
||||
"OLLAMA_FLASH_ATTENTION": {"OLLAMA_FLASH_ATTENTION", FlashAttention(false), "Enabled flash attention"},
|
||||
"OLLAMA_KV_CACHE_TYPE": {"OLLAMA_KV_CACHE_TYPE", KvCacheType(), "Quantization type for the K/V cache (default: f16)"},
|
||||
"OLLAMA_GPU_OVERHEAD": {"OLLAMA_GPU_OVERHEAD", GpuOverhead(), "Reserve a portion of VRAM per GPU (bytes)"},
|
||||
"OLLAMA_HOST": {"OLLAMA_HOST", Host(), "IP Address for the ollama server (default 127.0.0.1:11434)"},
|
||||
"OLLAMA_KEEP_ALIVE": {"OLLAMA_KEEP_ALIVE", KeepAlive(), "The duration that models stay loaded in memory (default \"5m\")"},
|
||||
"OLLAMA_LLM_LIBRARY": {"OLLAMA_LLM_LIBRARY", LLMLibrary(), "Set LLM library to bypass autodetection"},
|
||||
"OLLAMA_LOAD_TIMEOUT": {"OLLAMA_LOAD_TIMEOUT", LoadTimeout(), "How long to allow model loads to stall before giving up (default \"5m\")"},
|
||||
"OLLAMA_MAX_LOADED_MODELS": {"OLLAMA_MAX_LOADED_MODELS", MaxRunners(), "Maximum number of loaded models per GPU"},
|
||||
"OLLAMA_MAX_QUEUE": {"OLLAMA_MAX_QUEUE", MaxQueue(), "Maximum number of queued requests"},
|
||||
"OLLAMA_MODELS": {"OLLAMA_MODELS", Models(), "The path to the models directory"},
|
||||
"OLLAMA_NO_CLOUD": {"OLLAMA_NO_CLOUD", NoCloud(), "Disable Ollama cloud features (remote inference and web search)"},
|
||||
"OLLAMA_NOHISTORY": {"OLLAMA_NOHISTORY", NoHistory(), "Do not preserve readline history"},
|
||||
"OLLAMA_NOPRUNE": {"OLLAMA_NOPRUNE", NoPrune(), "Do not prune model blobs on startup"},
|
||||
"OLLAMA_NUM_PARALLEL": {"OLLAMA_NUM_PARALLEL", NumParallel(), "Maximum number of parallel requests"},
|
||||
"OLLAMA_ORIGINS": {"OLLAMA_ORIGINS", AllowedOrigins(), "A comma separated list of allowed origins"},
|
||||
"OLLAMA_SCHED_SPREAD": {"OLLAMA_SCHED_SPREAD", SchedSpread(), "Always schedule model across all GPUs"},
|
||||
"OLLAMA_MULTIUSER_CACHE": {"OLLAMA_MULTIUSER_CACHE", MultiUserCache(), "Optimize prompt caching for multi-user scenarios"},
|
||||
"OLLAMA_CONTEXT_LENGTH": {"OLLAMA_CONTEXT_LENGTH", ContextLength(), "Context length to use unless otherwise specified (default: 4k/32k/256k based on VRAM)"},
|
||||
"OLLAMA_EDITOR": {"OLLAMA_EDITOR", Editor(), "Path to editor for interactive prompt editing (Ctrl+G)"},
|
||||
"OLLAMA_NEW_ENGINE": {"OLLAMA_NEW_ENGINE", NewEngine(), "Enable the new Ollama engine"},
|
||||
"OLLAMA_REMOTES": {"OLLAMA_REMOTES", Remotes(), "Allowed hosts for remote models (default \"ollama.com\")"},
|
||||
|
||||
// Informational
|
||||
"HTTP_PROXY": {"HTTP_PROXY", String("HTTP_PROXY")(), "HTTP proxy"},
|
||||
|
||||
@@ -52,6 +52,37 @@ func TestHost(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnectableHost(t *testing.T) {
|
||||
cases := map[string]struct {
|
||||
value string
|
||||
expect string
|
||||
}{
|
||||
"empty": {"", "http://127.0.0.1:11434"},
|
||||
"localhost": {"127.0.0.1", "http://127.0.0.1:11434"},
|
||||
"localhost and port": {"127.0.0.1:1234", "http://127.0.0.1:1234"},
|
||||
"ipv4 unspecified": {"0.0.0.0", "http://127.0.0.1:11434"},
|
||||
"ipv4 unspecified + port": {"0.0.0.0:1234", "http://127.0.0.1:1234"},
|
||||
"ipv6 unspecified": {"[::]", "http://[::1]:11434"},
|
||||
"ipv6 unspecified + port": {"[::]:1234", "http://[::1]:1234"},
|
||||
"ipv6 localhost": {"[::1]", "http://[::1]:11434"},
|
||||
"ipv6 localhost + port": {"[::1]:1234", "http://[::1]:1234"},
|
||||
"specific address": {"192.168.1.5", "http://192.168.1.5:11434"},
|
||||
"specific address + port": {"192.168.1.5:8080", "http://192.168.1.5:8080"},
|
||||
"hostname": {"example.com", "http://example.com:11434"},
|
||||
"hostname and port": {"example.com:1234", "http://example.com:1234"},
|
||||
"https unspecified + port": {"https://0.0.0.0:4321", "https://127.0.0.1:4321"},
|
||||
}
|
||||
|
||||
for name, tt := range cases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Setenv("OLLAMA_HOST", tt.value)
|
||||
if host := ConnectableHost(); host.String() != tt.expect {
|
||||
t.Errorf("%s: expected %s, got %s", name, tt.expect, host.String())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestOrigins(t *testing.T) {
|
||||
cases := []struct {
|
||||
value string
|
||||
|
||||