Compare commits
116 Commits
pdevine/sa
...
v0.19.0-rc
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
aec2fef95d | ||
|
|
366625a831 | ||
|
|
516ebd8548 | ||
|
|
f567abc63f | ||
|
|
1adfc27f04 | ||
|
|
4a2b9f9dbc | ||
|
|
e46b67a6cc | ||
|
|
c000afe76c | ||
|
|
9d7b18f81e | ||
|
|
4f5999fd3f | ||
|
|
ac5f0dbb6a | ||
|
|
d1151e18a1 | ||
|
|
ebbce136c7 | ||
|
|
26b9f53f8e | ||
|
|
7575438366 | ||
|
|
7d7c90d702 | ||
|
|
4fda69809a | ||
|
|
c9b5da6b0c | ||
|
|
de5cb7311f | ||
|
|
95ee7fbd29 | ||
|
|
ec55536734 | ||
|
|
77491439c2 | ||
|
|
b166b36cd2 | ||
|
|
c2b0bb7a52 | ||
|
|
22c2bdbd8a | ||
|
|
6df6d097d9 | ||
|
|
d7c176ab91 | ||
|
|
0ff7d724ff | ||
|
|
46cb7795e1 | ||
|
|
126d8db7f3 | ||
|
|
3f3a24b418 | ||
|
|
96e36c0d90 | ||
|
|
6f8ddbb26b | ||
|
|
b5e7888414 | ||
|
|
eab4d22269 | ||
|
|
5759c2d2d2 | ||
|
|
42b1c2642b | ||
|
|
727d69ddf3 | ||
|
|
f622b0c5fc | ||
|
|
5d0000634c | ||
|
|
676d9845ba | ||
|
|
e37a9b4c01 | ||
|
|
d727aacd04 | ||
|
|
fa69b833cd | ||
|
|
bbbad97686 | ||
|
|
bcf6d55b54 | ||
|
|
810d4f9c22 | ||
|
|
856c047a6c | ||
|
|
79c1e93c00 | ||
|
|
f8b657c967 | ||
|
|
10fefe0d57 | ||
|
|
2f9a68f9e9 | ||
|
|
3980c0217d | ||
|
|
870599f5da | ||
|
|
abf8e8e9c8 | ||
|
|
f3f31a8192 | ||
|
|
9e7ba835da | ||
|
|
347f17b8d1 | ||
|
|
081b9eb423 | ||
|
|
bb867c6fdb | ||
|
|
81f4506a61 | ||
|
|
76925f1284 | ||
|
|
f676231de9 | ||
|
|
af5f7c0a9e | ||
|
|
a6b27d776b | ||
|
|
539741199e | ||
|
|
8f45236d09 | ||
|
|
97013a190c | ||
|
|
c222735c02 | ||
|
|
87d21c7fc0 | ||
|
|
54e05172a0 | ||
|
|
464186e995 | ||
|
|
8c4d5d6c2f | ||
|
|
bc72b14016 | ||
|
|
61086083eb | ||
|
|
62d1f01ab4 | ||
|
|
10e51c5177 | ||
|
|
3e06bde643 | ||
|
|
6be2de8214 | ||
|
|
ebb1b9ec14 | ||
|
|
d126467d5d | ||
|
|
afb4c62fbf | ||
|
|
e790dc435b | ||
|
|
288077c3a3 | ||
|
|
4425c54eda | ||
|
|
778899a5d2 | ||
|
|
4eab60c1e2 | ||
|
|
1af850e6e3 | ||
|
|
9b0c7cc7b9 | ||
|
|
6928630601 | ||
|
|
9896e3627f | ||
|
|
15732f0ea7 | ||
|
|
562c76d7cc | ||
|
|
122c68c151 | ||
|
|
82848a7806 | ||
|
|
39982a954e | ||
|
|
e9f6ea232f | ||
|
|
110eff01a9 | ||
|
|
799e51d419 | ||
|
|
e8fcb29586 | ||
|
|
97d2f05a6d | ||
|
|
8207e55ec7 | ||
|
|
ad16bffc7d | ||
|
|
c1e3ef4bcc | ||
|
|
a3093cd5e5 | ||
|
|
23d4cad1a2 | ||
|
|
86513cb697 | ||
|
|
3490e9590b | ||
|
|
8da09b1e7e | ||
|
|
a60b9adcce | ||
|
|
a16f96658b | ||
|
|
18ab09b431 | ||
|
|
638faeac54 | ||
|
|
dd5eb6337d | ||
|
|
79917cf80b | ||
|
|
cc90a035a0 |
67
.github/workflows/release.yaml
vendored
@@ -117,6 +117,25 @@ jobs:
|
|||||||
install: https://sdk.lunarg.com/sdk/download/1.4.321.1/windows/vulkansdk-windows-X64-1.4.321.1.exe
|
install: https://sdk.lunarg.com/sdk/download/1.4.321.1/windows/vulkansdk-windows-X64-1.4.321.1.exe
|
||||||
flags: ''
|
flags: ''
|
||||||
runner_dir: 'vulkan'
|
runner_dir: 'vulkan'
|
||||||
|
- os: windows
|
||||||
|
arch: amd64
|
||||||
|
preset: 'MLX CUDA 13'
|
||||||
|
install: https://developer.download.nvidia.com/compute/cuda/13.0.0/local_installers/cuda_13.0.0_windows.exe
|
||||||
|
cudnn-install: https://developer.download.nvidia.com/compute/cudnn/redist/cudnn/windows-x86_64/cudnn-windows-x86_64-9.18.1.3_cuda13-archive.zip
|
||||||
|
cuda-components:
|
||||||
|
- '"cudart"'
|
||||||
|
- '"nvcc"'
|
||||||
|
- '"cublas"'
|
||||||
|
- '"cublas_dev"'
|
||||||
|
- '"cufft"'
|
||||||
|
- '"cufft_dev"'
|
||||||
|
- '"nvrtc"'
|
||||||
|
- '"nvrtc_dev"'
|
||||||
|
- '"crt"'
|
||||||
|
- '"nvvm"'
|
||||||
|
- '"nvptxcompiler"'
|
||||||
|
cuda-version: '13.0'
|
||||||
|
flags: ''
|
||||||
runs-on: ${{ matrix.arch == 'arm64' && format('{0}-{1}', matrix.os, matrix.arch) || matrix.os }}
|
runs-on: ${{ matrix.arch == 'arm64' && format('{0}-{1}', matrix.os, matrix.arch) || matrix.os }}
|
||||||
environment: release
|
environment: release
|
||||||
env:
|
env:
|
||||||
@@ -125,8 +144,10 @@ jobs:
|
|||||||
- name: Install system dependencies
|
- name: Install system dependencies
|
||||||
run: |
|
run: |
|
||||||
choco install -y --no-progress ccache ninja
|
choco install -y --no-progress ccache ninja
|
||||||
ccache -o cache_dir=${{ github.workspace }}\.ccache
|
if (Get-Command ccache -ErrorAction SilentlyContinue) {
|
||||||
- if: startsWith(matrix.preset, 'CUDA ') || startsWith(matrix.preset, 'ROCm ') || startsWith(matrix.preset, 'Vulkan')
|
ccache -o cache_dir=${{ github.workspace }}\.ccache
|
||||||
|
}
|
||||||
|
- if: startsWith(matrix.preset, 'CUDA ') || startsWith(matrix.preset, 'ROCm ') || startsWith(matrix.preset, 'Vulkan') || startsWith(matrix.preset, 'MLX ')
|
||||||
id: cache-install
|
id: cache-install
|
||||||
uses: actions/cache/restore@v4
|
uses: actions/cache/restore@v4
|
||||||
with:
|
with:
|
||||||
@@ -134,8 +155,9 @@ jobs:
|
|||||||
C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA
|
C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA
|
||||||
C:\Program Files\AMD\ROCm
|
C:\Program Files\AMD\ROCm
|
||||||
C:\VulkanSDK
|
C:\VulkanSDK
|
||||||
key: ${{ matrix.install }}
|
C:\Program Files\NVIDIA\CUDNN
|
||||||
- if: startsWith(matrix.preset, 'CUDA ')
|
key: ${{ matrix.install }}-${{ matrix.cudnn-install }}
|
||||||
|
- if: startsWith(matrix.preset, 'CUDA ') || startsWith(matrix.preset, 'MLX ')
|
||||||
name: Install CUDA ${{ matrix.cuda-version }}
|
name: Install CUDA ${{ matrix.cuda-version }}
|
||||||
run: |
|
run: |
|
||||||
$ErrorActionPreference = "Stop"
|
$ErrorActionPreference = "Stop"
|
||||||
@@ -179,6 +201,23 @@ jobs:
|
|||||||
run: |
|
run: |
|
||||||
echo "CC=clang.exe" | Out-File -FilePath $env:GITHUB_ENV -Append
|
echo "CC=clang.exe" | Out-File -FilePath $env:GITHUB_ENV -Append
|
||||||
echo "CXX=clang++.exe" | Out-File -FilePath $env:GITHUB_ENV -Append
|
echo "CXX=clang++.exe" | Out-File -FilePath $env:GITHUB_ENV -Append
|
||||||
|
- if: startsWith(matrix.preset, 'MLX ')
|
||||||
|
name: Install cuDNN for MLX
|
||||||
|
run: |
|
||||||
|
$ErrorActionPreference = "Stop"
|
||||||
|
$cudnnRoot = "C:\Program Files\NVIDIA\CUDNN"
|
||||||
|
if ("${{ steps.cache-install.outputs.cache-hit }}" -ne 'true') {
|
||||||
|
Invoke-WebRequest -Uri "${{ matrix.cudnn-install }}" -OutFile "cudnn.zip"
|
||||||
|
Expand-Archive -Path cudnn.zip -DestinationPath cudnn-extracted
|
||||||
|
$cudnnDir = (Get-ChildItem -Path cudnn-extracted -Directory)[0].FullName
|
||||||
|
New-Item -ItemType Directory -Force -Path $cudnnRoot
|
||||||
|
Copy-Item -Path "$cudnnDir\*" -Destination "$cudnnRoot\" -Recurse
|
||||||
|
}
|
||||||
|
|
||||||
|
echo "CUDNN_ROOT_DIR=$cudnnRoot" | Out-File -FilePath $env:GITHUB_ENV -Append
|
||||||
|
echo "CUDNN_INCLUDE_PATH=$cudnnRoot\include" | Out-File -FilePath $env:GITHUB_ENV -Append
|
||||||
|
echo "CUDNN_LIBRARY_PATH=$cudnnRoot\lib\x64" | Out-File -FilePath $env:GITHUB_ENV -Append
|
||||||
|
echo "$cudnnRoot\bin\x64" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
|
||||||
- if: ${{ !cancelled() && steps.cache-install.outputs.cache-hit != 'true' }}
|
- if: ${{ !cancelled() && steps.cache-install.outputs.cache-hit != 'true' }}
|
||||||
uses: actions/cache/save@v4
|
uses: actions/cache/save@v4
|
||||||
with:
|
with:
|
||||||
@@ -186,7 +225,8 @@ jobs:
|
|||||||
C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA
|
C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA
|
||||||
C:\Program Files\AMD\ROCm
|
C:\Program Files\AMD\ROCm
|
||||||
C:\VulkanSDK
|
C:\VulkanSDK
|
||||||
key: ${{ matrix.install }}
|
C:\Program Files\NVIDIA\CUDNN
|
||||||
|
key: ${{ matrix.install }}-${{ matrix.cudnn-install }}
|
||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v4
|
||||||
- uses: actions/cache@v4
|
- uses: actions/cache@v4
|
||||||
with:
|
with:
|
||||||
@@ -198,7 +238,7 @@ jobs:
|
|||||||
Enter-VsDevShell -VsInstallPath 'C:\Program Files\Microsoft Visual Studio\2022\Enterprise' -SkipAutomaticLocation -DevCmdArguments '-arch=x64 -no_logo'
|
Enter-VsDevShell -VsInstallPath 'C:\Program Files\Microsoft Visual Studio\2022\Enterprise' -SkipAutomaticLocation -DevCmdArguments '-arch=x64 -no_logo'
|
||||||
cmake --preset "${{ matrix.preset }}" ${{ matrix.flags }} --install-prefix "$((pwd).Path)\dist\${{ matrix.os }}-${{ matrix.arch }}"
|
cmake --preset "${{ matrix.preset }}" ${{ matrix.flags }} --install-prefix "$((pwd).Path)\dist\${{ matrix.os }}-${{ matrix.arch }}"
|
||||||
cmake --build --parallel ([Environment]::ProcessorCount) --preset "${{ matrix.preset }}"
|
cmake --build --parallel ([Environment]::ProcessorCount) --preset "${{ matrix.preset }}"
|
||||||
cmake --install build --component "${{ startsWith(matrix.preset, 'CUDA ') && 'CUDA' || startsWith(matrix.preset, 'ROCm ') && 'HIP' || startsWith(matrix.preset, 'Vulkan') && 'Vulkan' || 'CPU' }}" --strip
|
cmake --install build --component "${{ startsWith(matrix.preset, 'MLX ') && 'MLX' || startsWith(matrix.preset, 'CUDA ') && 'CUDA' || startsWith(matrix.preset, 'ROCm ') && 'HIP' || startsWith(matrix.preset, 'Vulkan') && 'Vulkan' || 'CPU' }}" --strip
|
||||||
Remove-Item -Path dist\lib\ollama\rocm\rocblas\library\*gfx906* -ErrorAction SilentlyContinue
|
Remove-Item -Path dist\lib\ollama\rocm\rocblas\library\*gfx906* -ErrorAction SilentlyContinue
|
||||||
env:
|
env:
|
||||||
CMAKE_GENERATOR: Ninja
|
CMAKE_GENERATOR: Ninja
|
||||||
@@ -384,6 +424,7 @@ jobs:
|
|||||||
lib/ollama/cuda_v*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
|
lib/ollama/cuda_v*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
|
||||||
lib/ollama/vulkan*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
|
lib/ollama/vulkan*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
|
||||||
lib/ollama/mlx*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
|
lib/ollama/mlx*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
|
||||||
|
lib/ollama/include*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
|
||||||
lib/ollama/cuda_jetpack5) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-jetpack5.tar.in ;;
|
lib/ollama/cuda_jetpack5) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-jetpack5.tar.in ;;
|
||||||
lib/ollama/cuda_jetpack6) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-jetpack6.tar.in ;;
|
lib/ollama/cuda_jetpack6) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-jetpack6.tar.in ;;
|
||||||
lib/ollama/rocm) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-rocm.tar.in ;;
|
lib/ollama/rocm) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-rocm.tar.in ;;
|
||||||
@@ -543,11 +584,19 @@ jobs:
|
|||||||
for payload in dist/*.txt dist/*.zip dist/*.tgz dist/*.tar.zst dist/*.exe dist/*.dmg dist/*.ps1 dist/*.sh ; do
|
for payload in dist/*.txt dist/*.zip dist/*.tgz dist/*.tar.zst dist/*.exe dist/*.dmg dist/*.ps1 dist/*.sh ; do
|
||||||
echo "Uploading $payload"
|
echo "Uploading $payload"
|
||||||
gh release upload ${GITHUB_REF_NAME} $payload --clobber &
|
gh release upload ${GITHUB_REF_NAME} $payload --clobber &
|
||||||
pids[$!]=$!
|
pids+=($!)
|
||||||
sleep 1
|
sleep 1
|
||||||
done
|
done
|
||||||
echo "Waiting for uploads to complete"
|
echo "Waiting for uploads to complete"
|
||||||
for pid in "${pids[*]}"; do
|
failed=0
|
||||||
wait $pid
|
for pid in "${pids[@]}"; do
|
||||||
|
if ! wait $pid; then
|
||||||
|
echo "::error::Upload failed (pid $pid)"
|
||||||
|
failed=1
|
||||||
|
fi
|
||||||
done
|
done
|
||||||
|
if [ $failed -ne 0 ]; then
|
||||||
|
echo "One or more uploads failed"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
echo "done"
|
echo "done"
|
||||||
|
|||||||
73
.github/workflows/test.yaml
vendored
@@ -37,7 +37,7 @@ jobs:
|
|||||||
| xargs python3 -c "import sys; from pathlib import Path; print(any(Path(x).match(glob) for x in sys.argv[1:] for glob in '$*'.split(' ')))"
|
| xargs python3 -c "import sys; from pathlib import Path; print(any(Path(x).match(glob) for x in sys.argv[1:] for glob in '$*'.split(' ')))"
|
||||||
}
|
}
|
||||||
|
|
||||||
echo changed=$(changed 'llama/llama.cpp/**/*' 'ml/backend/ggml/ggml/**/*') | tee -a $GITHUB_OUTPUT
|
echo changed=$(changed 'llama/llama.cpp/**/*' 'ml/backend/ggml/ggml/**/*' '.github/**/*') | tee -a $GITHUB_OUTPUT
|
||||||
echo vendorsha=$(make -f Makefile.sync print-base) | tee -a $GITHUB_OUTPUT
|
echo vendorsha=$(make -f Makefile.sync print-base) | tee -a $GITHUB_OUTPUT
|
||||||
|
|
||||||
linux:
|
linux:
|
||||||
@@ -51,7 +51,7 @@ jobs:
|
|||||||
container: nvidia/cuda:13.0.0-devel-ubuntu22.04
|
container: nvidia/cuda:13.0.0-devel-ubuntu22.04
|
||||||
flags: '-DCMAKE_CUDA_ARCHITECTURES=87'
|
flags: '-DCMAKE_CUDA_ARCHITECTURES=87'
|
||||||
- preset: ROCm
|
- preset: ROCm
|
||||||
container: rocm/dev-ubuntu-22.04:6.1.2
|
container: rocm/dev-ubuntu-22.04:7.2
|
||||||
extra-packages: rocm-libs
|
extra-packages: rocm-libs
|
||||||
flags: '-DAMDGPU_TARGETS=gfx1010 -DCMAKE_PREFIX_PATH=/opt/rocm'
|
flags: '-DAMDGPU_TARGETS=gfx1010 -DCMAKE_PREFIX_PATH=/opt/rocm'
|
||||||
- preset: Vulkan
|
- preset: Vulkan
|
||||||
@@ -60,6 +60,11 @@ jobs:
|
|||||||
mesa-vulkan-drivers vulkan-tools
|
mesa-vulkan-drivers vulkan-tools
|
||||||
libvulkan1 libvulkan-dev
|
libvulkan1 libvulkan-dev
|
||||||
vulkan-sdk cmake ccache g++ make
|
vulkan-sdk cmake ccache g++ make
|
||||||
|
- preset: 'MLX CUDA 13'
|
||||||
|
container: nvidia/cuda:13.0.0-devel-ubuntu22.04
|
||||||
|
extra-packages: libcudnn9-dev-cuda-13 libopenblas-dev liblapack-dev liblapacke-dev git curl
|
||||||
|
flags: '-DCMAKE_CUDA_ARCHITECTURES=87 -DBLAS_INCLUDE_DIRS=/usr/include/x86_64-linux-gnu -DLAPACK_INCLUDE_DIRS=/usr/include/x86_64-linux-gnu'
|
||||||
|
install-go: true
|
||||||
runs-on: linux
|
runs-on: linux
|
||||||
container: ${{ matrix.container }}
|
container: ${{ matrix.container }}
|
||||||
steps:
|
steps:
|
||||||
@@ -76,19 +81,29 @@ jobs:
|
|||||||
$sudo apt-get update
|
$sudo apt-get update
|
||||||
fi
|
fi
|
||||||
$sudo apt-get install -y cmake ccache ${{ matrix.extra-packages }}
|
$sudo apt-get install -y cmake ccache ${{ matrix.extra-packages }}
|
||||||
|
# MLX requires CMake 3.25+, install from official releases
|
||||||
|
if [ "${{ matrix.preset }}" = "MLX CUDA 13" ]; then
|
||||||
|
curl -fsSL https://github.com/Kitware/CMake/releases/download/v3.31.2/cmake-3.31.2-linux-$(uname -m).tar.gz | $sudo tar xz -C /usr/local --strip-components 1
|
||||||
|
fi
|
||||||
# Export VULKAN_SDK if provided by LunarG package (defensive)
|
# Export VULKAN_SDK if provided by LunarG package (defensive)
|
||||||
if [ -d "/usr/lib/x86_64-linux-gnu/vulkan" ] && [ "${{ matrix.preset }}" = "Vulkan" ]; then
|
if [ -d "/usr/lib/x86_64-linux-gnu/vulkan" ] && [ "${{ matrix.preset }}" = "Vulkan" ]; then
|
||||||
echo "VULKAN_SDK=/usr" >> $GITHUB_ENV
|
echo "VULKAN_SDK=/usr" >> $GITHUB_ENV
|
||||||
fi
|
fi
|
||||||
env:
|
env:
|
||||||
DEBIAN_FRONTEND: noninteractive
|
DEBIAN_FRONTEND: noninteractive
|
||||||
|
- if: matrix.install-go
|
||||||
|
name: Install Go
|
||||||
|
run: |
|
||||||
|
GO_VERSION=$(awk '/^go / { print $2 }' go.mod)
|
||||||
|
curl -fsSL "https://golang.org/dl/go${GO_VERSION}.linux-$(dpkg --print-architecture).tar.gz" | tar xz -C /usr/local
|
||||||
|
echo "/usr/local/go/bin" >> $GITHUB_PATH
|
||||||
- uses: actions/cache@v4
|
- uses: actions/cache@v4
|
||||||
with:
|
with:
|
||||||
path: /github/home/.cache/ccache
|
path: /github/home/.cache/ccache
|
||||||
key: ccache-${{ runner.os }}-${{ runner.arch }}-${{ matrix.preset }}-${{ needs.changes.outputs.vendorsha }}
|
key: ccache-${{ runner.os }}-${{ runner.arch }}-${{ matrix.preset }}-${{ needs.changes.outputs.vendorsha }}
|
||||||
- run: |
|
- run: |
|
||||||
cmake --preset ${{ matrix.preset }} ${{ matrix.flags }}
|
cmake --preset "${{ matrix.preset }}" ${{ matrix.flags }}
|
||||||
cmake --build --preset ${{ matrix.preset }} --parallel
|
cmake --build --preset "${{ matrix.preset }}" --parallel
|
||||||
|
|
||||||
windows:
|
windows:
|
||||||
needs: [changes]
|
needs: [changes]
|
||||||
@@ -114,12 +129,31 @@ jobs:
|
|||||||
flags: '-DAMDGPU_TARGETS=gfx1010 -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -DCMAKE_C_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma" -DCMAKE_CXX_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma"'
|
flags: '-DAMDGPU_TARGETS=gfx1010 -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -DCMAKE_C_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma" -DCMAKE_CXX_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma"'
|
||||||
- preset: Vulkan
|
- preset: Vulkan
|
||||||
install: https://sdk.lunarg.com/sdk/download/1.4.321.1/windows/vulkansdk-windows-X64-1.4.321.1.exe
|
install: https://sdk.lunarg.com/sdk/download/1.4.321.1/windows/vulkansdk-windows-X64-1.4.321.1.exe
|
||||||
|
- preset: 'MLX CUDA 13'
|
||||||
|
install: https://developer.download.nvidia.com/compute/cuda/13.0.0/local_installers/cuda_13.0.0_windows.exe
|
||||||
|
cudnn-install: https://developer.download.nvidia.com/compute/cudnn/redist/cudnn/windows-x86_64/cudnn-windows-x86_64-9.18.1.3_cuda13-archive.zip
|
||||||
|
flags: '-DCMAKE_CUDA_ARCHITECTURES=80'
|
||||||
|
cuda-components:
|
||||||
|
- '"cudart"'
|
||||||
|
- '"nvcc"'
|
||||||
|
- '"cublas"'
|
||||||
|
- '"cublas_dev"'
|
||||||
|
- '"cufft"'
|
||||||
|
- '"cufft_dev"'
|
||||||
|
- '"nvrtc"'
|
||||||
|
- '"nvrtc_dev"'
|
||||||
|
- '"crt"'
|
||||||
|
- '"nvvm"'
|
||||||
|
- '"nvptxcompiler"'
|
||||||
|
cuda-version: '13.0'
|
||||||
runs-on: windows
|
runs-on: windows
|
||||||
steps:
|
steps:
|
||||||
- run: |
|
- run: |
|
||||||
choco install -y --no-progress ccache ninja
|
choco install -y --no-progress ccache ninja
|
||||||
ccache -o cache_dir=${{ github.workspace }}\.ccache
|
if (Get-Command ccache -ErrorAction SilentlyContinue) {
|
||||||
- if: matrix.preset == 'CUDA' || matrix.preset == 'ROCm' || matrix.preset == 'Vulkan'
|
ccache -o cache_dir=${{ github.workspace }}\.ccache
|
||||||
|
}
|
||||||
|
- if: matrix.preset == 'CUDA' || matrix.preset == 'ROCm' || matrix.preset == 'Vulkan' || matrix.preset == 'MLX CUDA 13'
|
||||||
id: cache-install
|
id: cache-install
|
||||||
uses: actions/cache/restore@v4
|
uses: actions/cache/restore@v4
|
||||||
with:
|
with:
|
||||||
@@ -127,8 +161,9 @@ jobs:
|
|||||||
C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA
|
C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA
|
||||||
C:\Program Files\AMD\ROCm
|
C:\Program Files\AMD\ROCm
|
||||||
C:\VulkanSDK
|
C:\VulkanSDK
|
||||||
key: ${{ matrix.install }}
|
C:\Program Files\NVIDIA\CUDNN
|
||||||
- if: matrix.preset == 'CUDA'
|
key: ${{ matrix.install }}-${{ matrix.cudnn-install }}
|
||||||
|
- if: matrix.preset == 'CUDA' || matrix.preset == 'MLX CUDA 13'
|
||||||
name: Install CUDA ${{ matrix.cuda-version }}
|
name: Install CUDA ${{ matrix.cuda-version }}
|
||||||
run: |
|
run: |
|
||||||
$ErrorActionPreference = "Stop"
|
$ErrorActionPreference = "Stop"
|
||||||
@@ -164,10 +199,27 @@ jobs:
|
|||||||
Invoke-WebRequest -Uri "${{ matrix.install }}" -OutFile "install.exe"
|
Invoke-WebRequest -Uri "${{ matrix.install }}" -OutFile "install.exe"
|
||||||
Start-Process -FilePath .\install.exe -ArgumentList "-c","--am","--al","in" -NoNewWindow -Wait
|
Start-Process -FilePath .\install.exe -ArgumentList "-c","--am","--al","in" -NoNewWindow -Wait
|
||||||
}
|
}
|
||||||
|
|
||||||
$vulkanPath = (Resolve-Path "C:\VulkanSDK\*").path
|
$vulkanPath = (Resolve-Path "C:\VulkanSDK\*").path
|
||||||
echo "$vulkanPath\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
|
echo "$vulkanPath\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
|
||||||
echo "VULKAN_SDK=$vulkanPath" >> $env:GITHUB_ENV
|
echo "VULKAN_SDK=$vulkanPath" >> $env:GITHUB_ENV
|
||||||
|
- if: matrix.preset == 'MLX CUDA 13'
|
||||||
|
name: Install cuDNN for MLX
|
||||||
|
run: |
|
||||||
|
$ErrorActionPreference = "Stop"
|
||||||
|
$cudnnRoot = "C:\Program Files\NVIDIA\CUDNN"
|
||||||
|
if ("${{ steps.cache-install.outputs.cache-hit }}" -ne 'true') {
|
||||||
|
Invoke-WebRequest -Uri "${{ matrix.cudnn-install }}" -OutFile "cudnn.zip"
|
||||||
|
Expand-Archive -Path cudnn.zip -DestinationPath cudnn-extracted
|
||||||
|
$cudnnDir = (Get-ChildItem -Path cudnn-extracted -Directory)[0].FullName
|
||||||
|
New-Item -ItemType Directory -Force -Path $cudnnRoot
|
||||||
|
Copy-Item -Path "$cudnnDir\*" -Destination "$cudnnRoot\" -Recurse
|
||||||
|
}
|
||||||
|
|
||||||
|
echo "CUDNN_ROOT_DIR=$cudnnRoot" | Out-File -FilePath $env:GITHUB_ENV -Append
|
||||||
|
echo "CUDNN_INCLUDE_PATH=$cudnnRoot\include" | Out-File -FilePath $env:GITHUB_ENV -Append
|
||||||
|
echo "CUDNN_LIBRARY_PATH=$cudnnRoot\lib\x64" | Out-File -FilePath $env:GITHUB_ENV -Append
|
||||||
|
echo "$cudnnRoot\bin\x64" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
|
||||||
- if: ${{ !cancelled() && steps.cache-install.outputs.cache-hit != 'true' }}
|
- if: ${{ !cancelled() && steps.cache-install.outputs.cache-hit != 'true' }}
|
||||||
uses: actions/cache/save@v4
|
uses: actions/cache/save@v4
|
||||||
with:
|
with:
|
||||||
@@ -175,7 +227,8 @@ jobs:
|
|||||||
C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA
|
C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA
|
||||||
C:\Program Files\AMD\ROCm
|
C:\Program Files\AMD\ROCm
|
||||||
C:\VulkanSDK
|
C:\VulkanSDK
|
||||||
key: ${{ matrix.install }}
|
C:\Program Files\NVIDIA\CUDNN
|
||||||
|
key: ${{ matrix.install }}-${{ matrix.cudnn-install }}
|
||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v4
|
||||||
- uses: actions/cache@v4
|
- uses: actions/cache@v4
|
||||||
with:
|
with:
|
||||||
|
|||||||
171
CMakeLists.txt
@@ -64,10 +64,15 @@ set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${OLLAMA_BUILD_DIR})
|
|||||||
set(CMAKE_LIBRARY_OUTPUT_DIRECTORY_DEBUG ${OLLAMA_BUILD_DIR})
|
set(CMAKE_LIBRARY_OUTPUT_DIRECTORY_DEBUG ${OLLAMA_BUILD_DIR})
|
||||||
set(CMAKE_LIBRARY_OUTPUT_DIRECTORY_RELEASE ${OLLAMA_BUILD_DIR})
|
set(CMAKE_LIBRARY_OUTPUT_DIRECTORY_RELEASE ${OLLAMA_BUILD_DIR})
|
||||||
|
|
||||||
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src)
|
# Store ggml include paths for use with target_include_directories later.
|
||||||
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/include)
|
# We avoid global include_directories() to prevent polluting the include path
|
||||||
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-cpu)
|
# for other projects like MLX (whose openblas dependency has its own common.h).
|
||||||
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-cpu/amx)
|
set(GGML_INCLUDE_DIRS
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/include
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-cpu
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-cpu/amx
|
||||||
|
)
|
||||||
|
|
||||||
add_compile_definitions(NDEBUG GGML_VERSION=0x0 GGML_COMMIT=0x0)
|
add_compile_definitions(NDEBUG GGML_VERSION=0x0 GGML_COMMIT=0x0)
|
||||||
|
|
||||||
@@ -87,6 +92,14 @@ if(NOT CPU_VARIANTS)
|
|||||||
set(CPU_VARIANTS "ggml-cpu")
|
set(CPU_VARIANTS "ggml-cpu")
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
|
# Apply ggml include directories to ggml targets only (not globally)
|
||||||
|
target_include_directories(ggml-base PRIVATE ${GGML_INCLUDE_DIRS})
|
||||||
|
foreach(variant ${CPU_VARIANTS})
|
||||||
|
if(TARGET ${variant})
|
||||||
|
target_include_directories(${variant} PRIVATE ${GGML_INCLUDE_DIRS})
|
||||||
|
endif()
|
||||||
|
endforeach()
|
||||||
|
|
||||||
install(TARGETS ggml-base ${CPU_VARIANTS}
|
install(TARGETS ggml-base ${CPU_VARIANTS}
|
||||||
RUNTIME_DEPENDENCIES
|
RUNTIME_DEPENDENCIES
|
||||||
PRE_EXCLUDE_REGEXES ".*"
|
PRE_EXCLUDE_REGEXES ".*"
|
||||||
@@ -103,6 +116,7 @@ if(CMAKE_CUDA_COMPILER)
|
|||||||
|
|
||||||
find_package(CUDAToolkit)
|
find_package(CUDAToolkit)
|
||||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-cuda)
|
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-cuda)
|
||||||
|
target_include_directories(ggml-cuda PRIVATE ${GGML_INCLUDE_DIRS})
|
||||||
install(TARGETS ggml-cuda
|
install(TARGETS ggml-cuda
|
||||||
RUNTIME_DEPENDENCIES
|
RUNTIME_DEPENDENCIES
|
||||||
DIRECTORIES ${CUDAToolkit_BIN_DIR} ${CUDAToolkit_BIN_DIR}/x64 ${CUDAToolkit_LIBRARY_DIR}
|
DIRECTORIES ${CUDAToolkit_BIN_DIR} ${CUDAToolkit_BIN_DIR}/x64 ${CUDAToolkit_LIBRARY_DIR}
|
||||||
@@ -134,6 +148,7 @@ if(CMAKE_HIP_COMPILER)
|
|||||||
if(AMDGPU_TARGETS)
|
if(AMDGPU_TARGETS)
|
||||||
find_package(hip REQUIRED)
|
find_package(hip REQUIRED)
|
||||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-hip)
|
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-hip)
|
||||||
|
target_include_directories(ggml-hip PRIVATE ${GGML_INCLUDE_DIRS})
|
||||||
|
|
||||||
if (WIN32)
|
if (WIN32)
|
||||||
target_compile_definitions(ggml-hip PRIVATE GGML_CUDA_NO_PEER_COPY)
|
target_compile_definitions(ggml-hip PRIVATE GGML_CUDA_NO_PEER_COPY)
|
||||||
@@ -148,7 +163,7 @@ if(CMAKE_HIP_COMPILER)
|
|||||||
)
|
)
|
||||||
install(RUNTIME_DEPENDENCY_SET rocm
|
install(RUNTIME_DEPENDENCY_SET rocm
|
||||||
DIRECTORIES ${HIP_BIN_INSTALL_DIR} ${HIP_LIB_INSTALL_DIR}
|
DIRECTORIES ${HIP_BIN_INSTALL_DIR} ${HIP_LIB_INSTALL_DIR}
|
||||||
PRE_INCLUDE_REGEXES hipblas rocblas amdhip64 rocsolver amd_comgr hsa-runtime64 rocsparse tinfo rocprofiler-register drm drm_amdgpu numa elf
|
PRE_INCLUDE_REGEXES hipblas rocblas amdhip64 rocsolver amd_comgr hsa-runtime64 rocsparse tinfo rocprofiler-register roctx64 rocroller drm drm_amdgpu numa elf
|
||||||
PRE_EXCLUDE_REGEXES ".*"
|
PRE_EXCLUDE_REGEXES ".*"
|
||||||
POST_EXCLUDE_REGEXES "system32"
|
POST_EXCLUDE_REGEXES "system32"
|
||||||
RUNTIME DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT HIP
|
RUNTIME DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT HIP
|
||||||
@@ -168,6 +183,7 @@ if(NOT APPLE)
|
|||||||
find_package(Vulkan)
|
find_package(Vulkan)
|
||||||
if(Vulkan_FOUND)
|
if(Vulkan_FOUND)
|
||||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-vulkan)
|
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-vulkan)
|
||||||
|
target_include_directories(ggml-vulkan PRIVATE ${GGML_INCLUDE_DIRS})
|
||||||
install(TARGETS ggml-vulkan
|
install(TARGETS ggml-vulkan
|
||||||
RUNTIME_DEPENDENCIES
|
RUNTIME_DEPENDENCIES
|
||||||
PRE_INCLUDE_REGEXES vulkan
|
PRE_INCLUDE_REGEXES vulkan
|
||||||
@@ -179,7 +195,6 @@ if(NOT APPLE)
|
|||||||
endif()
|
endif()
|
||||||
|
|
||||||
option(MLX_ENGINE "Enable MLX backend" OFF)
|
option(MLX_ENGINE "Enable MLX backend" OFF)
|
||||||
|
|
||||||
if(MLX_ENGINE)
|
if(MLX_ENGINE)
|
||||||
message(STATUS "Setting up MLX (this takes a while...)")
|
message(STATUS "Setting up MLX (this takes a while...)")
|
||||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/x/imagegen/mlx)
|
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/x/imagegen/mlx)
|
||||||
@@ -187,10 +202,36 @@ if(MLX_ENGINE)
|
|||||||
# Find CUDA toolkit if MLX is built with CUDA support
|
# Find CUDA toolkit if MLX is built with CUDA support
|
||||||
find_package(CUDAToolkit)
|
find_package(CUDAToolkit)
|
||||||
|
|
||||||
|
# Build list of directories for runtime dependency resolution
|
||||||
|
set(MLX_RUNTIME_DIRS ${CUDAToolkit_BIN_DIR} ${CUDAToolkit_BIN_DIR}/x64 ${CUDAToolkit_LIBRARY_DIR})
|
||||||
|
# Add cuDNN bin paths for DLLs (Windows MLX CUDA builds)
|
||||||
|
# CUDNN_ROOT_DIR is the standard CMake variable for cuDNN location
|
||||||
|
if(DEFINED ENV{CUDNN_ROOT_DIR})
|
||||||
|
# cuDNN 9.x has versioned subdirectories under bin/ (e.g., bin/13.0/)
|
||||||
|
file(GLOB CUDNN_BIN_SUBDIRS "$ENV{CUDNN_ROOT_DIR}/bin/*")
|
||||||
|
list(APPEND MLX_RUNTIME_DIRS ${CUDNN_BIN_SUBDIRS})
|
||||||
|
endif()
|
||||||
|
# Add build output directory and MLX dependency build directories
|
||||||
|
list(APPEND MLX_RUNTIME_DIRS ${OLLAMA_BUILD_DIR})
|
||||||
|
# OpenBLAS DLL location (pre-built zip extracts into openblas-src/bin/)
|
||||||
|
list(APPEND MLX_RUNTIME_DIRS ${CMAKE_BINARY_DIR}/_deps/openblas-src/bin)
|
||||||
|
# NCCL: on Linux, if real NCCL is found, cmake bundles libnccl.so via the
|
||||||
|
# regex below. If NCCL is not found, MLX links a static stub (OBJECT lib)
|
||||||
|
# so there is no runtime dependency. This path covers the stub build dir
|
||||||
|
# for windows so we include the DLL in our dependencies.
|
||||||
|
list(APPEND MLX_RUNTIME_DIRS ${CMAKE_BINARY_DIR}/_deps/mlx-build/mlx/distributed/nccl/nccl_stub-prefix/src/nccl_stub-build/Release)
|
||||||
|
|
||||||
|
# Base regexes for runtime dependencies (cross-platform)
|
||||||
|
set(MLX_INCLUDE_REGEXES cublas cublasLt cudart cufft nvrtc nvrtc-builtins cudnn nccl openblas gfortran)
|
||||||
|
# On Windows, also include dl.dll (dlfcn-win32 POSIX emulation layer)
|
||||||
|
if(WIN32)
|
||||||
|
list(APPEND MLX_INCLUDE_REGEXES "^dl\\.dll$")
|
||||||
|
endif()
|
||||||
|
|
||||||
install(TARGETS mlx mlxc
|
install(TARGETS mlx mlxc
|
||||||
RUNTIME_DEPENDENCIES
|
RUNTIME_DEPENDENCIES
|
||||||
DIRECTORIES ${CUDAToolkit_BIN_DIR} ${CUDAToolkit_BIN_DIR}/x64 ${CUDAToolkit_LIBRARY_DIR}
|
DIRECTORIES ${MLX_RUNTIME_DIRS}
|
||||||
PRE_INCLUDE_REGEXES cublas cublasLt cudart nvrtc nvrtc-builtins cudnn nccl openblas gfortran
|
PRE_INCLUDE_REGEXES ${MLX_INCLUDE_REGEXES}
|
||||||
PRE_EXCLUDE_REGEXES ".*"
|
PRE_EXCLUDE_REGEXES ".*"
|
||||||
RUNTIME DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT MLX
|
RUNTIME DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT MLX
|
||||||
LIBRARY DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT MLX
|
LIBRARY DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT MLX
|
||||||
@@ -205,13 +246,117 @@ if(MLX_ENGINE)
|
|||||||
COMPONENT MLX)
|
COMPONENT MLX)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
# Manually install cudart and cublas since they might not be picked up as direct dependencies
|
# Install headers for NVRTC JIT compilation at runtime.
|
||||||
|
# MLX's own install rules use the default component so they get skipped by
|
||||||
|
# --component MLX. Headers are installed alongside libmlx in OLLAMA_INSTALL_DIR.
|
||||||
|
#
|
||||||
|
# Layout:
|
||||||
|
# ${OLLAMA_INSTALL_DIR}/include/cccl/{cuda,nv}/ — CCCL headers
|
||||||
|
# ${OLLAMA_INSTALL_DIR}/include/*.h — CUDA toolkit headers
|
||||||
|
#
|
||||||
|
# MLX's jit_module.cpp resolves CCCL via
|
||||||
|
# current_binary_dir()[.parent_path()] / "include" / "cccl"
|
||||||
|
# On Linux, MLX's jit_module.cpp resolves CCCL via
|
||||||
|
# current_binary_dir().parent_path() / "include" / "cccl", so we create a
|
||||||
|
# symlink from lib/ollama/include -> ${OLLAMA_RUNNER_DIR}/include
|
||||||
|
# This will need refinement if we add multiple CUDA versions for MLX in the future.
|
||||||
|
# CUDA runtime headers are found via CUDA_PATH env var (set by mlxrunner).
|
||||||
|
if(EXISTS ${CMAKE_BINARY_DIR}/_deps/cccl-src/include/cuda)
|
||||||
|
install(DIRECTORY ${CMAKE_BINARY_DIR}/_deps/cccl-src/include/cuda
|
||||||
|
DESTINATION ${OLLAMA_INSTALL_DIR}/include/cccl
|
||||||
|
COMPONENT MLX)
|
||||||
|
install(DIRECTORY ${CMAKE_BINARY_DIR}/_deps/cccl-src/include/nv
|
||||||
|
DESTINATION ${OLLAMA_INSTALL_DIR}/include/cccl
|
||||||
|
COMPONENT MLX)
|
||||||
|
if(NOT WIN32 AND NOT APPLE)
|
||||||
|
install(CODE "
|
||||||
|
set(_link \"${CMAKE_INSTALL_PREFIX}/lib/ollama/include\")
|
||||||
|
set(_target \"${OLLAMA_RUNNER_DIR}/include\")
|
||||||
|
if(NOT EXISTS \${_link})
|
||||||
|
execute_process(COMMAND \${CMAKE_COMMAND} -E create_symlink \${_target} \${_link})
|
||||||
|
endif()
|
||||||
|
" COMPONENT MLX)
|
||||||
|
endif()
|
||||||
|
endif()
|
||||||
|
|
||||||
|
# Install minimal CUDA toolkit headers needed by MLX JIT kernels.
|
||||||
|
# These are the transitive closure of includes from mlx/backend/cuda/device/*.cuh.
|
||||||
|
# The Go mlxrunner sets CUDA_PATH to OLLAMA_INSTALL_DIR so MLX finds them at
|
||||||
|
# $CUDA_PATH/include/*.h via NVRTC --include-path.
|
||||||
if(CUDAToolkit_FOUND)
|
if(CUDAToolkit_FOUND)
|
||||||
file(GLOB CUDART_LIBS
|
# CUDAToolkit_INCLUDE_DIRS may be a semicolon-separated list
|
||||||
|
# (e.g. ".../include;.../include/cccl"). Find the entry that
|
||||||
|
# contains the CUDA runtime headers we need.
|
||||||
|
set(_cuda_inc "")
|
||||||
|
foreach(_dir ${CUDAToolkit_INCLUDE_DIRS})
|
||||||
|
if(EXISTS "${_dir}/cuda_runtime_api.h")
|
||||||
|
set(_cuda_inc "${_dir}")
|
||||||
|
break()
|
||||||
|
endif()
|
||||||
|
endforeach()
|
||||||
|
if(NOT _cuda_inc)
|
||||||
|
message(WARNING "Could not find cuda_runtime_api.h in CUDAToolkit_INCLUDE_DIRS: ${CUDAToolkit_INCLUDE_DIRS}")
|
||||||
|
else()
|
||||||
|
set(_dst "${OLLAMA_INSTALL_DIR}/include")
|
||||||
|
set(_MLX_JIT_CUDA_HEADERS
|
||||||
|
builtin_types.h
|
||||||
|
cooperative_groups.h
|
||||||
|
cuda_bf16.h
|
||||||
|
cuda_bf16.hpp
|
||||||
|
cuda_device_runtime_api.h
|
||||||
|
cuda_fp16.h
|
||||||
|
cuda_fp16.hpp
|
||||||
|
cuda_fp8.h
|
||||||
|
cuda_fp8.hpp
|
||||||
|
cuda_runtime_api.h
|
||||||
|
device_types.h
|
||||||
|
driver_types.h
|
||||||
|
math_constants.h
|
||||||
|
surface_types.h
|
||||||
|
texture_types.h
|
||||||
|
vector_functions.h
|
||||||
|
vector_functions.hpp
|
||||||
|
vector_types.h
|
||||||
|
)
|
||||||
|
foreach(_hdr ${_MLX_JIT_CUDA_HEADERS})
|
||||||
|
install(FILES "${_cuda_inc}/${_hdr}"
|
||||||
|
DESTINATION ${_dst}
|
||||||
|
COMPONENT MLX)
|
||||||
|
endforeach()
|
||||||
|
# Subdirectory headers
|
||||||
|
install(DIRECTORY "${_cuda_inc}/cooperative_groups"
|
||||||
|
DESTINATION ${_dst}
|
||||||
|
COMPONENT MLX
|
||||||
|
FILES_MATCHING PATTERN "*.h")
|
||||||
|
install(FILES "${_cuda_inc}/crt/host_defines.h"
|
||||||
|
DESTINATION "${_dst}/crt"
|
||||||
|
COMPONENT MLX)
|
||||||
|
endif()
|
||||||
|
endif()
|
||||||
|
|
||||||
|
# On Windows, explicitly install dl.dll (dlfcn-win32 POSIX dlopen emulation)
|
||||||
|
# RUNTIME_DEPENDENCIES auto-excludes it via POST_EXCLUDE_FILES_STRICT because
|
||||||
|
# dlfcn-win32 is a known CMake target with its own install rules (which install
|
||||||
|
# to the wrong destination). We must install it explicitly here.
|
||||||
|
if(WIN32)
|
||||||
|
install(FILES ${OLLAMA_BUILD_DIR}/dl.dll
|
||||||
|
DESTINATION ${OLLAMA_INSTALL_DIR}
|
||||||
|
COMPONENT MLX)
|
||||||
|
endif()
|
||||||
|
|
||||||
|
# Manually install CUDA runtime libraries that MLX loads via dlopen
|
||||||
|
# (not detected by RUNTIME_DEPENDENCIES since they aren't link-time deps)
|
||||||
|
if(CUDAToolkit_FOUND)
|
||||||
|
file(GLOB MLX_CUDA_LIBS
|
||||||
"${CUDAToolkit_LIBRARY_DIR}/libcudart.so*"
|
"${CUDAToolkit_LIBRARY_DIR}/libcudart.so*"
|
||||||
"${CUDAToolkit_LIBRARY_DIR}/libcublas.so*")
|
"${CUDAToolkit_LIBRARY_DIR}/libcublas.so*"
|
||||||
if(CUDART_LIBS)
|
"${CUDAToolkit_LIBRARY_DIR}/libcublasLt.so*"
|
||||||
install(FILES ${CUDART_LIBS}
|
"${CUDAToolkit_LIBRARY_DIR}/libnvrtc.so*"
|
||||||
|
"${CUDAToolkit_LIBRARY_DIR}/libnvrtc-builtins.so*"
|
||||||
|
"${CUDAToolkit_LIBRARY_DIR}/libcufft.so*"
|
||||||
|
"${CUDAToolkit_LIBRARY_DIR}/libcudnn.so*")
|
||||||
|
if(MLX_CUDA_LIBS)
|
||||||
|
install(FILES ${MLX_CUDA_LIBS}
|
||||||
DESTINATION ${OLLAMA_INSTALL_DIR}
|
DESTINATION ${OLLAMA_INSTALL_DIR}
|
||||||
COMPONENT MLX)
|
COMPONENT MLX)
|
||||||
endif()
|
endif()
|
||||||
|
|||||||
@@ -77,6 +77,15 @@
|
|||||||
"OLLAMA_RUNNER_DIR": "rocm"
|
"OLLAMA_RUNNER_DIR": "rocm"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"name": "ROCm 7",
|
||||||
|
"inherits": [ "ROCm" ],
|
||||||
|
"cacheVariables": {
|
||||||
|
"CMAKE_HIP_FLAGS": "-parallel-jobs=4",
|
||||||
|
"AMDGPU_TARGETS": "gfx942;gfx950;gfx1010;gfx1012;gfx1030;gfx1100;gfx1101;gfx1102;gfx1103;gfx1150;gfx1151;gfx1200;gfx1201;gfx908:xnack-;gfx90a:xnack+;gfx90a:xnack-",
|
||||||
|
"OLLAMA_RUNNER_DIR": "rocm"
|
||||||
|
}
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"name": "Vulkan",
|
"name": "Vulkan",
|
||||||
"inherits": [ "Default" ],
|
"inherits": [ "Default" ],
|
||||||
@@ -103,6 +112,7 @@
|
|||||||
"name": "MLX CUDA 13",
|
"name": "MLX CUDA 13",
|
||||||
"inherits": [ "MLX", "CUDA 13" ],
|
"inherits": [ "MLX", "CUDA 13" ],
|
||||||
"cacheVariables": {
|
"cacheVariables": {
|
||||||
|
"MLX_CUDA_ARCHITECTURES": "86;89;90;90a;100;103;75-virtual;80-virtual;110-virtual;120-virtual;121-virtual",
|
||||||
"OLLAMA_RUNNER_DIR": "mlx_cuda_v13"
|
"OLLAMA_RUNNER_DIR": "mlx_cuda_v13"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -158,6 +168,11 @@
|
|||||||
"inherits": [ "ROCm" ],
|
"inherits": [ "ROCm" ],
|
||||||
"configurePreset": "ROCm 6"
|
"configurePreset": "ROCm 6"
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"name": "ROCm 7",
|
||||||
|
"inherits": [ "ROCm" ],
|
||||||
|
"configurePreset": "ROCm 7"
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"name": "Vulkan",
|
"name": "Vulkan",
|
||||||
"targets": [ "ggml-vulkan" ],
|
"targets": [ "ggml-vulkan" ],
|
||||||
|
|||||||
122
Dockerfile
@@ -1,28 +1,23 @@
|
|||||||
# vim: filetype=dockerfile
|
# vim: filetype=dockerfile
|
||||||
|
|
||||||
ARG FLAVOR=${TARGETARCH}
|
ARG FLAVOR=${TARGETARCH}
|
||||||
ARG PARALLEL=8
|
|
||||||
|
|
||||||
ARG ROCMVERSION=6.3.3
|
ARG ROCMVERSION=7.2
|
||||||
ARG JETPACK5VERSION=r35.4.1
|
ARG JETPACK5VERSION=r35.4.1
|
||||||
ARG JETPACK6VERSION=r36.4.0
|
ARG JETPACK6VERSION=r36.4.0
|
||||||
ARG CMAKEVERSION=3.31.2
|
ARG CMAKEVERSION=3.31.2
|
||||||
|
ARG NINJAVERSION=1.12.1
|
||||||
ARG VULKANVERSION=1.4.321.1
|
ARG VULKANVERSION=1.4.321.1
|
||||||
|
|
||||||
|
# Default empty stages for local MLX source overrides.
|
||||||
|
# Override with: docker build --build-context local-mlx=../mlx --build-context local-mlx-c=../mlx-c
|
||||||
|
FROM scratch AS local-mlx
|
||||||
|
FROM scratch AS local-mlx-c
|
||||||
|
|
||||||
FROM --platform=linux/amd64 rocm/dev-almalinux-8:${ROCMVERSION}-complete AS base-amd64
|
FROM --platform=linux/amd64 rocm/dev-almalinux-8:${ROCMVERSION}-complete AS base-amd64
|
||||||
RUN dnf install -y yum-utils ccache gcc-toolset-11-gcc gcc-toolset-11-gcc-c++ gcc-toolset-11-binutils \
|
RUN dnf install -y yum-utils ccache gcc-toolset-11-gcc gcc-toolset-11-gcc-c++ gcc-toolset-11-binutils \
|
||||||
&& yum-config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel8/x86_64/cuda-rhel8.repo
|
&& yum-config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel8/x86_64/cuda-rhel8.repo
|
||||||
ENV PATH=/opt/rh/gcc-toolset-11/root/usr/bin:$PATH
|
ENV PATH=/opt/rh/gcc-toolset-11/root/usr/bin:$PATH
|
||||||
ARG VULKANVERSION
|
|
||||||
RUN wget https://sdk.lunarg.com/sdk/download/${VULKANVERSION}/linux/vulkansdk-linux-x86_64-${VULKANVERSION}.tar.xz -O /tmp/vulkansdk-linux-x86_64-${VULKANVERSION}.tar.xz \
|
|
||||||
&& tar xvf /tmp/vulkansdk-linux-x86_64-${VULKANVERSION}.tar.xz \
|
|
||||||
&& dnf -y install ninja-build \
|
|
||||||
&& ln -s /usr/bin/python3 /usr/bin/python \
|
|
||||||
&& /${VULKANVERSION}/vulkansdk -j 8 vulkan-headers \
|
|
||||||
&& /${VULKANVERSION}/vulkansdk -j 8 shaderc
|
|
||||||
RUN cp -r /${VULKANVERSION}/x86_64/include/* /usr/local/include/ \
|
|
||||||
&& cp -r /${VULKANVERSION}/x86_64/lib/* /usr/local/lib
|
|
||||||
ENV PATH=/${VULKANVERSION}/x86_64/bin:$PATH
|
|
||||||
|
|
||||||
FROM --platform=linux/arm64 almalinux:8 AS base-arm64
|
FROM --platform=linux/arm64 almalinux:8 AS base-arm64
|
||||||
# install epel-release for ccache
|
# install epel-release for ccache
|
||||||
@@ -33,100 +28,119 @@ ENV CC=clang CXX=clang++
|
|||||||
|
|
||||||
FROM base-${TARGETARCH} AS base
|
FROM base-${TARGETARCH} AS base
|
||||||
ARG CMAKEVERSION
|
ARG CMAKEVERSION
|
||||||
|
ARG NINJAVERSION
|
||||||
RUN curl -fsSL https://github.com/Kitware/CMake/releases/download/v${CMAKEVERSION}/cmake-${CMAKEVERSION}-linux-$(uname -m).tar.gz | tar xz -C /usr/local --strip-components 1
|
RUN curl -fsSL https://github.com/Kitware/CMake/releases/download/v${CMAKEVERSION}/cmake-${CMAKEVERSION}-linux-$(uname -m).tar.gz | tar xz -C /usr/local --strip-components 1
|
||||||
|
RUN dnf install -y unzip \
|
||||||
|
&& curl -fsSL -o /tmp/ninja.zip https://github.com/ninja-build/ninja/releases/download/v${NINJAVERSION}/ninja-linux$([ "$(uname -m)" = "aarch64" ] && echo "-aarch64").zip \
|
||||||
|
&& unzip /tmp/ninja.zip -d /usr/local/bin \
|
||||||
|
&& rm /tmp/ninja.zip
|
||||||
|
ENV CMAKE_GENERATOR=Ninja
|
||||||
ENV LDFLAGS=-s
|
ENV LDFLAGS=-s
|
||||||
|
|
||||||
FROM base AS cpu
|
FROM base AS cpu
|
||||||
RUN dnf install -y gcc-toolset-11-gcc gcc-toolset-11-gcc-c++
|
RUN dnf install -y gcc-toolset-11-gcc gcc-toolset-11-gcc-c++
|
||||||
ENV PATH=/opt/rh/gcc-toolset-11/root/usr/bin:$PATH
|
ENV PATH=/opt/rh/gcc-toolset-11/root/usr/bin:$PATH
|
||||||
ARG PARALLEL
|
|
||||||
COPY CMakeLists.txt CMakePresets.json .
|
COPY CMakeLists.txt CMakePresets.json .
|
||||||
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
|
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
|
||||||
RUN --mount=type=cache,target=/root/.ccache \
|
RUN --mount=type=cache,target=/root/.ccache \
|
||||||
cmake --preset 'CPU' \
|
cmake --preset 'CPU' \
|
||||||
&& cmake --build --parallel ${PARALLEL} --preset 'CPU' \
|
&& cmake --build --preset 'CPU' -- -l $(nproc) \
|
||||||
&& cmake --install build --component CPU --strip --parallel ${PARALLEL}
|
&& cmake --install build --component CPU --strip
|
||||||
|
|
||||||
FROM base AS cuda-11
|
FROM base AS cuda-11
|
||||||
ARG CUDA11VERSION=11.8
|
ARG CUDA11VERSION=11.8
|
||||||
RUN dnf install -y cuda-toolkit-${CUDA11VERSION//./-}
|
RUN dnf install -y cuda-toolkit-${CUDA11VERSION//./-}
|
||||||
ENV PATH=/usr/local/cuda-11/bin:$PATH
|
ENV PATH=/usr/local/cuda-11/bin:$PATH
|
||||||
ARG PARALLEL
|
|
||||||
COPY CMakeLists.txt CMakePresets.json .
|
COPY CMakeLists.txt CMakePresets.json .
|
||||||
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
|
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
|
||||||
RUN --mount=type=cache,target=/root/.ccache \
|
RUN --mount=type=cache,target=/root/.ccache \
|
||||||
cmake --preset 'CUDA 11' \
|
cmake --preset 'CUDA 11' \
|
||||||
&& cmake --build --parallel ${PARALLEL} --preset 'CUDA 11' \
|
&& cmake --build --preset 'CUDA 11' -- -l $(nproc) \
|
||||||
&& cmake --install build --component CUDA --strip --parallel ${PARALLEL}
|
&& cmake --install build --component CUDA --strip
|
||||||
|
|
||||||
FROM base AS cuda-12
|
FROM base AS cuda-12
|
||||||
ARG CUDA12VERSION=12.8
|
ARG CUDA12VERSION=12.8
|
||||||
RUN dnf install -y cuda-toolkit-${CUDA12VERSION//./-}
|
RUN dnf install -y cuda-toolkit-${CUDA12VERSION//./-}
|
||||||
ENV PATH=/usr/local/cuda-12/bin:$PATH
|
ENV PATH=/usr/local/cuda-12/bin:$PATH
|
||||||
ARG PARALLEL
|
|
||||||
COPY CMakeLists.txt CMakePresets.json .
|
COPY CMakeLists.txt CMakePresets.json .
|
||||||
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
|
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
|
||||||
RUN --mount=type=cache,target=/root/.ccache \
|
RUN --mount=type=cache,target=/root/.ccache \
|
||||||
cmake --preset 'CUDA 12' \
|
cmake --preset 'CUDA 12' \
|
||||||
&& cmake --build --parallel ${PARALLEL} --preset 'CUDA 12' \
|
&& cmake --build --preset 'CUDA 12' -- -l $(nproc) \
|
||||||
&& cmake --install build --component CUDA --strip --parallel ${PARALLEL}
|
&& cmake --install build --component CUDA --strip
|
||||||
|
|
||||||
|
|
||||||
FROM base AS cuda-13
|
FROM base AS cuda-13
|
||||||
ARG CUDA13VERSION=13.0
|
ARG CUDA13VERSION=13.0
|
||||||
RUN dnf install -y cuda-toolkit-${CUDA13VERSION//./-}
|
RUN dnf install -y cuda-toolkit-${CUDA13VERSION//./-}
|
||||||
ENV PATH=/usr/local/cuda-13/bin:$PATH
|
ENV PATH=/usr/local/cuda-13/bin:$PATH
|
||||||
ARG PARALLEL
|
|
||||||
COPY CMakeLists.txt CMakePresets.json .
|
COPY CMakeLists.txt CMakePresets.json .
|
||||||
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
|
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
|
||||||
RUN --mount=type=cache,target=/root/.ccache \
|
RUN --mount=type=cache,target=/root/.ccache \
|
||||||
cmake --preset 'CUDA 13' \
|
cmake --preset 'CUDA 13' \
|
||||||
&& cmake --build --parallel ${PARALLEL} --preset 'CUDA 13' \
|
&& cmake --build --preset 'CUDA 13' -- -l $(nproc) \
|
||||||
&& cmake --install build --component CUDA --strip --parallel ${PARALLEL}
|
&& cmake --install build --component CUDA --strip
|
||||||
|
|
||||||
|
|
||||||
FROM base AS rocm-6
|
FROM base AS rocm-7
|
||||||
ENV PATH=/opt/rocm/hcc/bin:/opt/rocm/hip/bin:/opt/rocm/bin:/opt/rocm/hcc/bin:$PATH
|
ENV PATH=/opt/rocm/hcc/bin:/opt/rocm/hip/bin:/opt/rocm/bin:/opt/rocm/hcc/bin:$PATH
|
||||||
ARG PARALLEL
|
|
||||||
COPY CMakeLists.txt CMakePresets.json .
|
COPY CMakeLists.txt CMakePresets.json .
|
||||||
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
|
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
|
||||||
RUN --mount=type=cache,target=/root/.ccache \
|
RUN --mount=type=cache,target=/root/.ccache \
|
||||||
cmake --preset 'ROCm 6' \
|
cmake --preset 'ROCm 7' \
|
||||||
&& cmake --build --parallel ${PARALLEL} --preset 'ROCm 6' \
|
&& cmake --build --preset 'ROCm 7' -- -l $(nproc) \
|
||||||
&& cmake --install build --component HIP --strip --parallel ${PARALLEL}
|
&& cmake --install build --component HIP --strip
|
||||||
RUN rm -f dist/lib/ollama/rocm/rocblas/library/*gfx90[06]*
|
RUN rm -f dist/lib/ollama/rocm/rocblas/library/*gfx90[06]*
|
||||||
|
|
||||||
FROM --platform=linux/arm64 nvcr.io/nvidia/l4t-jetpack:${JETPACK5VERSION} AS jetpack-5
|
FROM --platform=linux/arm64 nvcr.io/nvidia/l4t-jetpack:${JETPACK5VERSION} AS jetpack-5
|
||||||
ARG CMAKEVERSION
|
ARG CMAKEVERSION
|
||||||
RUN apt-get update && apt-get install -y curl ccache \
|
ARG NINJAVERSION
|
||||||
&& curl -fsSL https://github.com/Kitware/CMake/releases/download/v${CMAKEVERSION}/cmake-${CMAKEVERSION}-linux-$(uname -m).tar.gz | tar xz -C /usr/local --strip-components 1
|
RUN apt-get update && apt-get install -y curl ccache unzip \
|
||||||
|
&& curl -fsSL https://github.com/Kitware/CMake/releases/download/v${CMAKEVERSION}/cmake-${CMAKEVERSION}-linux-$(uname -m).tar.gz | tar xz -C /usr/local --strip-components 1 \
|
||||||
|
&& curl -fsSL -o /tmp/ninja.zip https://github.com/ninja-build/ninja/releases/download/v${NINJAVERSION}/ninja-linux-aarch64.zip \
|
||||||
|
&& unzip /tmp/ninja.zip -d /usr/local/bin \
|
||||||
|
&& rm /tmp/ninja.zip
|
||||||
|
ENV CMAKE_GENERATOR=Ninja
|
||||||
COPY CMakeLists.txt CMakePresets.json .
|
COPY CMakeLists.txt CMakePresets.json .
|
||||||
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
|
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
|
||||||
ARG PARALLEL
|
|
||||||
RUN --mount=type=cache,target=/root/.ccache \
|
RUN --mount=type=cache,target=/root/.ccache \
|
||||||
cmake --preset 'JetPack 5' \
|
cmake --preset 'JetPack 5' \
|
||||||
&& cmake --build --parallel ${PARALLEL} --preset 'JetPack 5' \
|
&& cmake --build --preset 'JetPack 5' -- -l $(nproc) \
|
||||||
&& cmake --install build --component CUDA --strip --parallel ${PARALLEL}
|
&& cmake --install build --component CUDA --strip
|
||||||
|
|
||||||
FROM --platform=linux/arm64 nvcr.io/nvidia/l4t-jetpack:${JETPACK6VERSION} AS jetpack-6
|
FROM --platform=linux/arm64 nvcr.io/nvidia/l4t-jetpack:${JETPACK6VERSION} AS jetpack-6
|
||||||
ARG CMAKEVERSION
|
ARG CMAKEVERSION
|
||||||
RUN apt-get update && apt-get install -y curl ccache \
|
ARG NINJAVERSION
|
||||||
&& curl -fsSL https://github.com/Kitware/CMake/releases/download/v${CMAKEVERSION}/cmake-${CMAKEVERSION}-linux-$(uname -m).tar.gz | tar xz -C /usr/local --strip-components 1
|
RUN apt-get update && apt-get install -y curl ccache unzip \
|
||||||
|
&& curl -fsSL https://github.com/Kitware/CMake/releases/download/v${CMAKEVERSION}/cmake-${CMAKEVERSION}-linux-$(uname -m).tar.gz | tar xz -C /usr/local --strip-components 1 \
|
||||||
|
&& curl -fsSL -o /tmp/ninja.zip https://github.com/ninja-build/ninja/releases/download/v${NINJAVERSION}/ninja-linux-aarch64.zip \
|
||||||
|
&& unzip /tmp/ninja.zip -d /usr/local/bin \
|
||||||
|
&& rm /tmp/ninja.zip
|
||||||
|
ENV CMAKE_GENERATOR=Ninja
|
||||||
COPY CMakeLists.txt CMakePresets.json .
|
COPY CMakeLists.txt CMakePresets.json .
|
||||||
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
|
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
|
||||||
ARG PARALLEL
|
|
||||||
RUN --mount=type=cache,target=/root/.ccache \
|
RUN --mount=type=cache,target=/root/.ccache \
|
||||||
cmake --preset 'JetPack 6' \
|
cmake --preset 'JetPack 6' \
|
||||||
&& cmake --build --parallel ${PARALLEL} --preset 'JetPack 6' \
|
&& cmake --build --preset 'JetPack 6' -- -l $(nproc) \
|
||||||
&& cmake --install build --component CUDA --strip --parallel ${PARALLEL}
|
&& cmake --install build --component CUDA --strip
|
||||||
|
|
||||||
FROM base AS vulkan
|
FROM base AS vulkan
|
||||||
|
ARG VULKANVERSION
|
||||||
|
RUN ln -s /usr/bin/python3 /usr/bin/python \
|
||||||
|
&& wget https://sdk.lunarg.com/sdk/download/${VULKANVERSION}/linux/vulkansdk-linux-x86_64-${VULKANVERSION}.tar.xz -O /tmp/vulkansdk.tar.xz \
|
||||||
|
&& tar xvf /tmp/vulkansdk.tar.xz -C /tmp \
|
||||||
|
&& /tmp/${VULKANVERSION}/vulkansdk -j 8 vulkan-headers \
|
||||||
|
&& /tmp/${VULKANVERSION}/vulkansdk -j 8 shaderc \
|
||||||
|
&& cp -r /tmp/${VULKANVERSION}/x86_64/include/* /usr/local/include/ \
|
||||||
|
&& cp -r /tmp/${VULKANVERSION}/x86_64/lib/* /usr/local/lib \
|
||||||
|
&& cp -r /tmp/${VULKANVERSION}/x86_64/bin/* /usr/local/bin/ \
|
||||||
|
&& rm -rf /tmp/${VULKANVERSION} /tmp/vulkansdk.tar.xz
|
||||||
COPY CMakeLists.txt CMakePresets.json .
|
COPY CMakeLists.txt CMakePresets.json .
|
||||||
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
|
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
|
||||||
RUN --mount=type=cache,target=/root/.ccache \
|
RUN --mount=type=cache,target=/root/.ccache \
|
||||||
cmake --preset 'Vulkan' \
|
cmake --preset 'Vulkan' \
|
||||||
&& cmake --build --parallel --preset 'Vulkan' \
|
&& cmake --build --preset 'Vulkan' -- -l $(nproc) \
|
||||||
&& cmake --install build --component Vulkan --strip --parallel 8
|
&& cmake --install build --component Vulkan --strip
|
||||||
|
|
||||||
FROM base AS mlx
|
FROM base AS mlx
|
||||||
ARG CUDA13VERSION=13.0
|
ARG CUDA13VERSION=13.0
|
||||||
@@ -138,20 +152,27 @@ ENV PATH=/usr/local/cuda-13/bin:$PATH
|
|||||||
ENV BLAS_INCLUDE_DIRS=/usr/include/openblas
|
ENV BLAS_INCLUDE_DIRS=/usr/include/openblas
|
||||||
ENV LAPACK_INCLUDE_DIRS=/usr/include/openblas
|
ENV LAPACK_INCLUDE_DIRS=/usr/include/openblas
|
||||||
ENV CGO_LDFLAGS="-L/usr/local/cuda-13/lib64 -L/usr/local/cuda-13/targets/x86_64-linux/lib/stubs"
|
ENV CGO_LDFLAGS="-L/usr/local/cuda-13/lib64 -L/usr/local/cuda-13/targets/x86_64-linux/lib/stubs"
|
||||||
ARG PARALLEL
|
|
||||||
WORKDIR /go/src/github.com/ollama/ollama
|
WORKDIR /go/src/github.com/ollama/ollama
|
||||||
COPY CMakeLists.txt CMakePresets.json .
|
COPY CMakeLists.txt CMakePresets.json .
|
||||||
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
|
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
|
||||||
COPY x/imagegen/mlx x/imagegen/mlx
|
COPY x/imagegen/mlx x/imagegen/mlx
|
||||||
COPY go.mod go.sum .
|
COPY go.mod go.sum .
|
||||||
COPY MLX_VERSION .
|
COPY MLX_VERSION MLX_C_VERSION .
|
||||||
RUN curl -fsSL https://golang.org/dl/go$(awk '/^go/ { print $2 }' go.mod).linux-$(case $(uname -m) in x86_64) echo amd64 ;; aarch64) echo arm64 ;; esac).tar.gz | tar xz -C /usr/local
|
RUN curl -fsSL https://golang.org/dl/go$(awk '/^go/ { print $2 }' go.mod).linux-$(case $(uname -m) in x86_64) echo amd64 ;; aarch64) echo arm64 ;; esac).tar.gz | tar xz -C /usr/local
|
||||||
ENV PATH=/usr/local/go/bin:$PATH
|
ENV PATH=/usr/local/go/bin:$PATH
|
||||||
RUN go mod download
|
RUN go mod download
|
||||||
RUN --mount=type=cache,target=/root/.ccache \
|
RUN --mount=type=cache,target=/root/.ccache \
|
||||||
cmake --preset 'MLX CUDA 13' -DBLAS_INCLUDE_DIRS=/usr/include/openblas -DLAPACK_INCLUDE_DIRS=/usr/include/openblas \
|
--mount=type=bind,from=local-mlx,target=/tmp/local-mlx \
|
||||||
&& cmake --build --parallel ${PARALLEL} --preset 'MLX CUDA 13' \
|
--mount=type=bind,from=local-mlx-c,target=/tmp/local-mlx-c \
|
||||||
&& cmake --install build --component MLX --strip --parallel ${PARALLEL}
|
if [ -f /tmp/local-mlx/CMakeLists.txt ]; then \
|
||||||
|
export OLLAMA_MLX_SOURCE=/tmp/local-mlx; \
|
||||||
|
fi \
|
||||||
|
&& if [ -f /tmp/local-mlx-c/CMakeLists.txt ]; then \
|
||||||
|
export OLLAMA_MLX_C_SOURCE=/tmp/local-mlx-c; \
|
||||||
|
fi \
|
||||||
|
&& cmake --preset 'MLX CUDA 13' -DBLAS_INCLUDE_DIRS=/usr/include/openblas -DLAPACK_INCLUDE_DIRS=/usr/include/openblas \
|
||||||
|
&& cmake --build --preset 'MLX CUDA 13' -- -l $(nproc) \
|
||||||
|
&& cmake --install build --component MLX --strip
|
||||||
|
|
||||||
FROM base AS build
|
FROM base AS build
|
||||||
WORKDIR /go/src/github.com/ollama/ollama
|
WORKDIR /go/src/github.com/ollama/ollama
|
||||||
@@ -160,16 +181,14 @@ RUN curl -fsSL https://golang.org/dl/go$(awk '/^go/ { print $2 }' go.mod).linux-
|
|||||||
ENV PATH=/usr/local/go/bin:$PATH
|
ENV PATH=/usr/local/go/bin:$PATH
|
||||||
RUN go mod download
|
RUN go mod download
|
||||||
COPY . .
|
COPY . .
|
||||||
# Clone mlx-c headers for CGO (version from MLX_VERSION file)
|
|
||||||
RUN git clone --depth 1 --branch "$(cat MLX_VERSION)" https://github.com/ml-explore/mlx-c.git build/_deps/mlx-c-src
|
|
||||||
ARG GOFLAGS="'-ldflags=-w -s'"
|
ARG GOFLAGS="'-ldflags=-w -s'"
|
||||||
ENV CGO_ENABLED=1
|
ENV CGO_ENABLED=1
|
||||||
ARG CGO_CFLAGS
|
ARG CGO_CFLAGS
|
||||||
ARG CGO_CXXFLAGS
|
ARG CGO_CXXFLAGS
|
||||||
ENV CGO_CFLAGS="${CGO_CFLAGS} -I/go/src/github.com/ollama/ollama/build/_deps/mlx-c-src"
|
ENV CGO_CFLAGS="${CGO_CFLAGS}"
|
||||||
ENV CGO_CXXFLAGS="${CGO_CXXFLAGS}"
|
ENV CGO_CXXFLAGS="${CGO_CXXFLAGS}"
|
||||||
RUN --mount=type=cache,target=/root/.cache/go-build \
|
RUN --mount=type=cache,target=/root/.cache/go-build \
|
||||||
go build -tags mlx -trimpath -buildmode=pie -o /bin/ollama .
|
go build -trimpath -buildmode=pie -o /bin/ollama .
|
||||||
|
|
||||||
FROM --platform=linux/amd64 scratch AS amd64
|
FROM --platform=linux/amd64 scratch AS amd64
|
||||||
# COPY --from=cuda-11 dist/lib/ollama/ /lib/ollama/
|
# COPY --from=cuda-11 dist/lib/ollama/ /lib/ollama/
|
||||||
@@ -186,10 +205,9 @@ COPY --from=jetpack-5 dist/lib/ollama/ /lib/ollama/
|
|||||||
COPY --from=jetpack-6 dist/lib/ollama/ /lib/ollama/
|
COPY --from=jetpack-6 dist/lib/ollama/ /lib/ollama/
|
||||||
|
|
||||||
FROM scratch AS rocm
|
FROM scratch AS rocm
|
||||||
COPY --from=rocm-6 dist/lib/ollama /lib/ollama
|
COPY --from=rocm-7 dist/lib/ollama /lib/ollama
|
||||||
|
|
||||||
FROM ${FLAVOR} AS archive
|
FROM ${FLAVOR} AS archive
|
||||||
ARG VULKANVERSION
|
|
||||||
COPY --from=cpu dist/lib/ollama /lib/ollama
|
COPY --from=cpu dist/lib/ollama /lib/ollama
|
||||||
COPY --from=build /bin/ollama /bin/ollama
|
COPY --from=build /bin/ollama /bin/ollama
|
||||||
|
|
||||||
|
|||||||
1
MLX_C_VERSION
Normal file
@@ -0,0 +1 @@
|
|||||||
|
0726ca922fc902c4c61ef9c27d94132be418e945
|
||||||
@@ -1 +1 @@
|
|||||||
v0.5.0
|
38ad257088fb2193ad47e527cf6534a689f30943
|
||||||
|
|||||||
@@ -852,6 +852,19 @@ func (c *StreamConverter) Process(r api.ChatResponse) []StreamEvent {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Close thinking block if still open (thinking → tool_use without text in between)
|
||||||
|
if c.thinkingStarted && !c.thinkingDone {
|
||||||
|
c.thinkingDone = true
|
||||||
|
events = append(events, StreamEvent{
|
||||||
|
Event: "content_block_stop",
|
||||||
|
Data: ContentBlockStopEvent{
|
||||||
|
Type: "content_block_stop",
|
||||||
|
Index: c.contentIndex,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
c.contentIndex++
|
||||||
|
}
|
||||||
|
|
||||||
if c.textStarted {
|
if c.textStarted {
|
||||||
events = append(events, StreamEvent{
|
events = append(events, StreamEvent{
|
||||||
Event: "content_block_stop",
|
Event: "content_block_stop",
|
||||||
|
|||||||
@@ -799,6 +799,107 @@ func TestStreamConverter_WithToolCalls(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestStreamConverter_ThinkingDirectlyFollowedByToolCall verifies that when a
|
||||||
|
// model emits a thinking block followed directly by a tool_use block (with no
|
||||||
|
// text block in between), the streaming converter correctly closes the thinking
|
||||||
|
// block and increments the content index before opening the tool_use block.
|
||||||
|
// Previously, the converter reused contentIndex=0 for the tool_use block,
|
||||||
|
// which caused "Content block not found" errors in clients. See #14816.
|
||||||
|
func TestStreamConverter_ThinkingDirectlyFollowedByToolCall(t *testing.T) {
|
||||||
|
conv := NewStreamConverter("msg_123", "test-model", 0)
|
||||||
|
|
||||||
|
// First chunk: thinking content (no text)
|
||||||
|
resp1 := api.ChatResponse{
|
||||||
|
Model: "test-model",
|
||||||
|
Message: api.Message{
|
||||||
|
Role: "assistant",
|
||||||
|
Thinking: "I should call the tool.",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
events1 := conv.Process(resp1)
|
||||||
|
|
||||||
|
// Should have: message_start, content_block_start(thinking), content_block_delta(thinking)
|
||||||
|
if len(events1) < 3 {
|
||||||
|
t.Fatalf("expected at least 3 events for thinking chunk, got %d", len(events1))
|
||||||
|
}
|
||||||
|
if events1[0].Event != "message_start" {
|
||||||
|
t.Errorf("expected first event 'message_start', got %q", events1[0].Event)
|
||||||
|
}
|
||||||
|
thinkingStart, ok := events1[1].Data.(ContentBlockStartEvent)
|
||||||
|
if !ok || thinkingStart.ContentBlock.Type != "thinking" {
|
||||||
|
t.Errorf("expected content_block_start(thinking) as second event, got %+v", events1[1])
|
||||||
|
}
|
||||||
|
if thinkingStart.Index != 0 {
|
||||||
|
t.Errorf("expected thinking block at index 0, got %d", thinkingStart.Index)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Second chunk: tool call (no text between thinking and tool)
|
||||||
|
resp2 := api.ChatResponse{
|
||||||
|
Model: "test-model",
|
||||||
|
Message: api.Message{
|
||||||
|
Role: "assistant",
|
||||||
|
ToolCalls: []api.ToolCall{
|
||||||
|
{
|
||||||
|
ID: "call_abc",
|
||||||
|
Function: api.ToolCallFunction{
|
||||||
|
Name: "ask_user",
|
||||||
|
Arguments: testArgs(map[string]any{"question": "cats or dogs?"}),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Done: true,
|
||||||
|
DoneReason: "stop",
|
||||||
|
Metrics: api.Metrics{PromptEvalCount: 10, EvalCount: 5},
|
||||||
|
}
|
||||||
|
events2 := conv.Process(resp2)
|
||||||
|
|
||||||
|
// Expect: content_block_stop(index=0), content_block_start(tool_use, index=1),
|
||||||
|
// content_block_delta(input_json_delta, index=1), content_block_stop(index=1),
|
||||||
|
// message_delta, message_stop
|
||||||
|
var thinkingStop, toolStart, toolDelta, toolStop *StreamEvent
|
||||||
|
for i := range events2 {
|
||||||
|
e := &events2[i]
|
||||||
|
switch e.Event {
|
||||||
|
case "content_block_stop":
|
||||||
|
if stop, ok := e.Data.(ContentBlockStopEvent); ok {
|
||||||
|
if stop.Index == 0 && thinkingStop == nil {
|
||||||
|
thinkingStop = e
|
||||||
|
} else if stop.Index == 1 {
|
||||||
|
toolStop = e
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case "content_block_start":
|
||||||
|
if start, ok := e.Data.(ContentBlockStartEvent); ok && start.ContentBlock.Type == "tool_use" {
|
||||||
|
toolStart = e
|
||||||
|
}
|
||||||
|
case "content_block_delta":
|
||||||
|
if delta, ok := e.Data.(ContentBlockDeltaEvent); ok && delta.Delta.Type == "input_json_delta" {
|
||||||
|
toolDelta = e
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if thinkingStop == nil {
|
||||||
|
t.Error("expected content_block_stop for thinking block (index 0)")
|
||||||
|
}
|
||||||
|
if toolStart == nil {
|
||||||
|
t.Fatal("expected content_block_start for tool_use block")
|
||||||
|
}
|
||||||
|
if start, ok := toolStart.Data.(ContentBlockStartEvent); !ok || start.Index != 1 {
|
||||||
|
t.Errorf("expected tool_use block at index 1, got %+v", toolStart.Data)
|
||||||
|
}
|
||||||
|
if toolDelta == nil {
|
||||||
|
t.Fatal("expected input_json_delta event for tool call")
|
||||||
|
}
|
||||||
|
if delta, ok := toolDelta.Data.(ContentBlockDeltaEvent); !ok || delta.Index != 1 {
|
||||||
|
t.Errorf("expected tool delta at index 1, got %+v", toolDelta.Data)
|
||||||
|
}
|
||||||
|
if toolStop == nil {
|
||||||
|
t.Error("expected content_block_stop for tool_use block (index 1)")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestStreamConverter_ToolCallWithUnmarshalableArgs(t *testing.T) {
|
func TestStreamConverter_ToolCallWithUnmarshalableArgs(t *testing.T) {
|
||||||
// Test that unmarshalable arguments (like channels) are handled gracefully
|
// Test that unmarshalable arguments (like channels) are handled gracefully
|
||||||
// and don't cause a panic or corrupt stream
|
// and don't cause a panic or corrupt stream
|
||||||
|
|||||||
@@ -476,25 +476,3 @@ func (c *Client) Whoami(ctx context.Context) (*UserResponse, error) {
|
|||||||
}
|
}
|
||||||
return &resp, nil
|
return &resp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// AliasRequest is the request body for creating or updating a model alias.
|
|
||||||
type AliasRequest struct {
|
|
||||||
Alias string `json:"alias"`
|
|
||||||
Target string `json:"target"`
|
|
||||||
PrefixMatching bool `json:"prefix_matching,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetAliasExperimental creates or updates a model alias via the experimental aliases API.
|
|
||||||
func (c *Client) SetAliasExperimental(ctx context.Context, req *AliasRequest) error {
|
|
||||||
return c.do(ctx, http.MethodPost, "/api/experimental/aliases", req, nil)
|
|
||||||
}
|
|
||||||
|
|
||||||
// AliasDeleteRequest is the request body for deleting a model alias.
|
|
||||||
type AliasDeleteRequest struct {
|
|
||||||
Alias string `json:"alias"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// DeleteAliasExperimental deletes a model alias via the experimental aliases API.
|
|
||||||
func (c *Client) DeleteAliasExperimental(ctx context.Context, req *AliasDeleteRequest) error {
|
|
||||||
return c.do(ctx, http.MethodDelete, "/api/experimental/aliases", req, nil)
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -841,7 +841,8 @@ type CloudStatus struct {
|
|||||||
|
|
||||||
// StatusResponse is the response from [Client.CloudStatusExperimental].
|
// StatusResponse is the response from [Client.CloudStatusExperimental].
|
||||||
type StatusResponse struct {
|
type StatusResponse struct {
|
||||||
Cloud CloudStatus `json:"cloud"`
|
Cloud CloudStatus `json:"cloud"`
|
||||||
|
ContextLength int `json:"context_length,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// GenerateResponse is the response passed into [GenerateResponseFunc].
|
// GenerateResponse is the response passed into [GenerateResponseFunc].
|
||||||
|
|||||||
@@ -214,6 +214,7 @@ export default function Settings() {
|
|||||||
Agent: false,
|
Agent: false,
|
||||||
Tools: false,
|
Tools: false,
|
||||||
ContextLength: 0,
|
ContextLength: 0,
|
||||||
|
AutoUpdateEnabled: true,
|
||||||
});
|
});
|
||||||
updateSettingsMutation.mutate(defaultSettings);
|
updateSettingsMutation.mutate(defaultSettings);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -155,7 +155,7 @@ func (s *Server) ollamaProxy() http.Handler {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
target := envconfig.Host()
|
target := envconfig.ConnectableHost()
|
||||||
s.log().Info("configuring ollama proxy", "target", target.String())
|
s.log().Info("configuring ollama proxy", "target", target.String())
|
||||||
|
|
||||||
newProxy := httputil.NewSingleHostReverseProxy(target)
|
newProxy := httputil.NewSingleHostReverseProxy(target)
|
||||||
|
|||||||
@@ -1,27 +1,31 @@
|
|||||||
Ollama Benchmark Tool
|
Ollama Benchmark Tool
|
||||||
---------------------
|
---------------------
|
||||||
|
|
||||||
A Go-based command-line tool for benchmarking Ollama models with configurable parameters and multiple output formats.
|
A Go-based command-line tool for benchmarking Ollama models with configurable parameters, warmup phases, TTFT tracking, VRAM monitoring, and benchstat/CSV output.
|
||||||
|
|
||||||
## Features
|
## Features
|
||||||
|
|
||||||
* Benchmark multiple models in a single run
|
* Benchmark multiple models in a single run
|
||||||
* Support for both text and image prompts
|
* Support for both text and image prompts
|
||||||
* Configurable generation parameters (temperature, max tokens, seed, etc.)
|
* Configurable generation parameters (temperature, max tokens, seed, etc.)
|
||||||
* Supports benchstat and CSV output formats
|
* Warmup phase before timed epochs to stabilize measurements
|
||||||
* Detailed performance metrics (prefill, generate, load, total durations)
|
* Time-to-first-token (TTFT) tracking per epoch
|
||||||
|
* Model metadata display (parameter size, quantization level, family)
|
||||||
|
* VRAM and CPU memory usage tracking via running process info
|
||||||
|
* Controlled prompt token length for reproducible benchmarks
|
||||||
|
* Benchstat and CSV output formats
|
||||||
|
|
||||||
## Building from Source
|
## Building from Source
|
||||||
|
|
||||||
```
|
```
|
||||||
go build -o ollama-bench bench.go
|
go build -o ollama-bench ./cmd/bench
|
||||||
./ollama-bench -model gpt-oss:20b -epochs 6 -format csv
|
./ollama-bench -model gemma3 -epochs 6 -format csv
|
||||||
```
|
```
|
||||||
|
|
||||||
Using Go Run (without building)
|
Using Go Run (without building)
|
||||||
|
|
||||||
```
|
```
|
||||||
go run bench.go -model gpt-oss:20b -epochs 3
|
go run ./cmd/bench -model gemma3 -epochs 3
|
||||||
```
|
```
|
||||||
|
|
||||||
## Usage
|
## Usage
|
||||||
@@ -45,10 +49,16 @@ benchstat -col /name gemma.bench
|
|||||||
./ollama-bench -model qwen3-vl -image photo.jpg -epochs 6 -max-tokens 100 -p "Describe this image"
|
./ollama-bench -model qwen3-vl -image photo.jpg -epochs 6 -max-tokens 100 -p "Describe this image"
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### Controlled Prompt Length
|
||||||
|
|
||||||
|
```
|
||||||
|
./ollama-bench -model gemma3 -epochs 6 -prompt-tokens 512
|
||||||
|
```
|
||||||
|
|
||||||
### Advanced Example
|
### Advanced Example
|
||||||
|
|
||||||
```
|
```
|
||||||
./ollama-bench -model llama3 -epochs 10 -temperature 0.7 -max-tokens 500 -seed 42 -format csv -output results.csv
|
./ollama-bench -model llama3 -epochs 10 -temperature 0.7 -max-tokens 500 -seed 42 -warmup 2 -format csv -output results.csv
|
||||||
```
|
```
|
||||||
|
|
||||||
## Command Line Options
|
## Command Line Options
|
||||||
@@ -56,41 +66,48 @@ benchstat -col /name gemma.bench
|
|||||||
| Option | Description | Default |
|
| Option | Description | Default |
|
||||||
|----------|-------------|---------|
|
|----------|-------------|---------|
|
||||||
| -model | Comma-separated list of models to benchmark | (required) |
|
| -model | Comma-separated list of models to benchmark | (required) |
|
||||||
| -epochs | Number of iterations per model | 1 |
|
| -epochs | Number of iterations per model | 6 |
|
||||||
| -max-tokens | Maximum tokens for model response | 0 (unlimited) |
|
| -max-tokens | Maximum tokens for model response | 200 |
|
||||||
| -temperature | Temperature parameter | 0.0 |
|
| -temperature | Temperature parameter | 0.0 |
|
||||||
| -seed | Random seed | 0 (random) |
|
| -seed | Random seed | 0 (random) |
|
||||||
| -timeout | Timeout in seconds | 300 |
|
| -timeout | Timeout in seconds | 300 |
|
||||||
| -p | Prompt text | "Write a long story." |
|
| -p | Prompt text | (default story prompt) |
|
||||||
| -image | Image file to include in prompt | |
|
| -image | Image file to include in prompt | |
|
||||||
| -k | Keep-alive duration in seconds | 0 |
|
| -k | Keep-alive duration in seconds | 0 |
|
||||||
| -format | Output format (benchstat, csv) | benchstat |
|
| -format | Output format (benchstat, csv) | benchstat |
|
||||||
| -output | Output file for results | "" (stdout) |
|
| -output | Output file for results | "" (stdout) |
|
||||||
|
| -warmup | Number of warmup requests before timing | 1 |
|
||||||
|
| -prompt-tokens | Generate prompt targeting ~N tokens (0 = use -p) | 0 |
|
||||||
| -v | Verbose mode | false |
|
| -v | Verbose mode | false |
|
||||||
| -debug | Show debug information | false |
|
| -debug | Show debug information | false |
|
||||||
|
|
||||||
## Output Formats
|
## Output Formats
|
||||||
|
|
||||||
### Markdown Format
|
### Benchstat Format (default)
|
||||||
|
|
||||||
The default markdown format is suitable for copying and pasting into a GitHub issue and will look like:
|
Compatible with Go's benchstat tool for statistical analysis. Uses one value/unit pair per line, standard `ns/op` for timing metrics, and `ns/token` for throughput. Each epoch produces one set of lines -- benchstat aggregates across repeated runs to compute statistics.
|
||||||
```
|
|
||||||
Model | Step | Count | Duration | nsPerToken | tokensPerSec |
|
|
||||||
|-------|------|-------|----------|------------|--------------|
|
|
||||||
| gpt-oss:20b | prefill | 124 | 30.006458ms | 241987.56 | 4132.44 |
|
|
||||||
| gpt-oss:20b | generate | 200 | 2.646843954s | 13234219.77 | 75.56 |
|
|
||||||
| gpt-oss:20b | load | 1 | 121.674208ms | - | - |
|
|
||||||
| gpt-oss:20b | total | 1 | 2.861047625s | - | - |
|
|
||||||
```
|
|
||||||
|
|
||||||
### Benchstat Format
|
|
||||||
|
|
||||||
Compatible with Go's benchstat tool for statistical analysis:
|
|
||||||
|
|
||||||
```
|
```
|
||||||
BenchmarkModel/name=gpt-oss:20b/step=prefill 128 78125.00 ns/token 12800.00 token/sec
|
# Model: gemma3 | Params: 4.3B | Quant: Q4_K_M | Family: gemma3 | Size: 4080218931 | VRAM: 4080218931
|
||||||
BenchmarkModel/name=gpt-oss:20b/step=generate 512 19531.25 ns/token 51200.00 token/sec
|
BenchmarkModel/name=gemma3/step=prefill 1 78125.00 ns/token 12800.00 token/sec
|
||||||
BenchmarkModel/name=gpt-oss:20b/step=load 1 1500000000 ns/request
|
BenchmarkModel/name=gemma3/step=generate 1 19531.25 ns/token 51200.00 token/sec
|
||||||
|
BenchmarkModel/name=gemma3/step=ttft 1 45123000 ns/op
|
||||||
|
BenchmarkModel/name=gemma3/step=load 1 1500000000 ns/op
|
||||||
|
BenchmarkModel/name=gemma3/step=total 1 2861047625 ns/op
|
||||||
|
```
|
||||||
|
|
||||||
|
Use with benchstat:
|
||||||
|
```
|
||||||
|
./ollama-bench -model gemma3 -epochs 6 > gemma3.bench
|
||||||
|
benchstat -col /step gemma3.bench
|
||||||
|
```
|
||||||
|
|
||||||
|
Compare two runs:
|
||||||
|
```
|
||||||
|
./ollama-bench -model gemma3 -epochs 6 > before.bench
|
||||||
|
# ... make changes ...
|
||||||
|
./ollama-bench -model gemma3 -epochs 6 > after.bench
|
||||||
|
benchstat before.bench after.bench
|
||||||
```
|
```
|
||||||
|
|
||||||
### CSV Format
|
### CSV Format
|
||||||
@@ -99,17 +116,28 @@ Machine-readable comma-separated values:
|
|||||||
|
|
||||||
```
|
```
|
||||||
NAME,STEP,COUNT,NS_PER_COUNT,TOKEN_PER_SEC
|
NAME,STEP,COUNT,NS_PER_COUNT,TOKEN_PER_SEC
|
||||||
gpt-oss:20b,prefill,128,78125.00,12800.00
|
# Model: gemma3 | Params: 4.3B | Quant: Q4_K_M | Family: gemma3 | Size: 4080218931 | VRAM: 4080218931
|
||||||
gpt-oss:20b,generate,512,19531.25,51200.00
|
gemma3,prefill,128,78125.00,12800.00
|
||||||
gpt-oss:20b,load,1,1500000000,0
|
gemma3,generate,512,19531.25,51200.00
|
||||||
|
gemma3,ttft,1,45123000,0
|
||||||
|
gemma3,load,1,1500000000,0
|
||||||
|
gemma3,total,1,2861047625,0
|
||||||
```
|
```
|
||||||
|
|
||||||
## Metrics Explained
|
## Metrics Explained
|
||||||
|
|
||||||
The tool reports four types of metrics for each model:
|
The tool reports the following metrics for each epoch:
|
||||||
|
|
||||||
* prefill: Time spent processing the prompt
|
* **prefill**: Time spent processing the prompt (ns/token)
|
||||||
* generate: Time spent generating the response
|
* **generate**: Time spent generating the response (ns/token)
|
||||||
* load: Model loading time (one-time cost)
|
* **ttft**: Time to first token -- latency from request start to first response content
|
||||||
* total: Total request duration
|
* **load**: Model loading time (one-time cost)
|
||||||
|
* **total**: Total request duration
|
||||||
|
|
||||||
|
Additionally, the model info comment line (displayed once per model before epochs) includes:
|
||||||
|
|
||||||
|
* **Params**: Model parameter count (e.g., 4.3B)
|
||||||
|
* **Quant**: Quantization level (e.g., Q4_K_M)
|
||||||
|
* **Family**: Model family (e.g., gemma3)
|
||||||
|
* **Size**: Total model memory in bytes
|
||||||
|
* **VRAM**: GPU memory used by the loaded model (when Size > VRAM, the difference is CPU spill)
|
||||||
|
|||||||
@@ -17,19 +17,21 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type flagOptions struct {
|
type flagOptions struct {
|
||||||
models *string
|
models *string
|
||||||
epochs *int
|
epochs *int
|
||||||
maxTokens *int
|
maxTokens *int
|
||||||
temperature *float64
|
temperature *float64
|
||||||
seed *int
|
seed *int
|
||||||
timeout *int
|
timeout *int
|
||||||
prompt *string
|
prompt *string
|
||||||
imageFile *string
|
imageFile *string
|
||||||
keepAlive *float64
|
keepAlive *float64
|
||||||
format *string
|
format *string
|
||||||
outputFile *string
|
outputFile *string
|
||||||
debug *bool
|
debug *bool
|
||||||
verbose *bool
|
verbose *bool
|
||||||
|
warmup *int
|
||||||
|
promptTokens *int
|
||||||
}
|
}
|
||||||
|
|
||||||
type Metrics struct {
|
type Metrics struct {
|
||||||
@@ -39,48 +41,169 @@ type Metrics struct {
|
|||||||
Duration time.Duration
|
Duration time.Duration
|
||||||
}
|
}
|
||||||
|
|
||||||
var once sync.Once
|
type ModelInfo struct {
|
||||||
|
Name string
|
||||||
|
ParameterSize string
|
||||||
|
QuantizationLevel string
|
||||||
|
Family string
|
||||||
|
SizeBytes int64
|
||||||
|
VRAMBytes int64
|
||||||
|
}
|
||||||
|
|
||||||
const DefaultPrompt = `Please write a descriptive story about a llama named Alonso who grows up to be President of the Land of Llamas. Include details about Alonso's childhood, adolescent years, and how he grew up to be a political mover and shaker. Write the story with a sense of whimsy.`
|
const DefaultPrompt = `Please write a descriptive story about a llama named Alonso who grows up to be President of the Land of Llamas. Include details about Alonso's childhood, adolescent years, and how he grew up to be a political mover and shaker. Write the story with a sense of whimsy.`
|
||||||
|
|
||||||
|
// Word list for generating prompts targeting a specific token count.
|
||||||
|
var promptWordList = []string{
|
||||||
|
"the", "quick", "brown", "fox", "jumps", "over", "lazy", "dog",
|
||||||
|
"a", "bright", "sunny", "day", "in", "the", "meadow", "where",
|
||||||
|
"flowers", "bloom", "and", "birds", "sing", "their", "morning",
|
||||||
|
"songs", "while", "gentle", "breeze", "carries", "sweet", "scent",
|
||||||
|
"of", "pine", "trees", "across", "rolling", "hills", "toward",
|
||||||
|
"distant", "mountains", "covered", "with", "fresh", "snow",
|
||||||
|
"beneath", "clear", "blue", "sky", "children", "play", "near",
|
||||||
|
"old", "stone", "bridge", "that", "crosses", "winding", "river",
|
||||||
|
}
|
||||||
|
|
||||||
|
func generatePromptForTokenCount(targetTokens int, epoch int) string {
|
||||||
|
// ~1.3 tokens per word heuristic
|
||||||
|
targetWords := int(float64(targetTokens) / 1.3)
|
||||||
|
if targetWords < 1 {
|
||||||
|
targetWords = 1
|
||||||
|
}
|
||||||
|
|
||||||
|
// Vary the starting offset by epoch to defeat KV cache prefix matching
|
||||||
|
offset := epoch * 7 // stride by a prime to get good distribution
|
||||||
|
n := len(promptWordList)
|
||||||
|
words := make([]string, targetWords)
|
||||||
|
for i := range words {
|
||||||
|
words[i] = promptWordList[((i+offset)%n+n)%n]
|
||||||
|
}
|
||||||
|
return strings.Join(words, " ")
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildGenerateRequest(model string, fOpt flagOptions, imgData api.ImageData, epoch int) *api.GenerateRequest {
|
||||||
|
options := make(map[string]interface{})
|
||||||
|
if *fOpt.maxTokens > 0 {
|
||||||
|
options["num_predict"] = *fOpt.maxTokens
|
||||||
|
}
|
||||||
|
options["temperature"] = *fOpt.temperature
|
||||||
|
if fOpt.seed != nil && *fOpt.seed > 0 {
|
||||||
|
options["seed"] = *fOpt.seed
|
||||||
|
}
|
||||||
|
|
||||||
|
var keepAliveDuration *api.Duration
|
||||||
|
if *fOpt.keepAlive > 0 {
|
||||||
|
duration := api.Duration{Duration: time.Duration(*fOpt.keepAlive * float64(time.Second))}
|
||||||
|
keepAliveDuration = &duration
|
||||||
|
}
|
||||||
|
|
||||||
|
prompt := *fOpt.prompt
|
||||||
|
if *fOpt.promptTokens > 0 {
|
||||||
|
prompt = generatePromptForTokenCount(*fOpt.promptTokens, epoch)
|
||||||
|
} else {
|
||||||
|
// Vary the prompt per epoch to defeat KV cache prefix matching
|
||||||
|
prompt = fmt.Sprintf("[%d] %s", epoch, prompt)
|
||||||
|
}
|
||||||
|
|
||||||
|
req := &api.GenerateRequest{
|
||||||
|
Model: model,
|
||||||
|
Prompt: prompt,
|
||||||
|
Raw: true,
|
||||||
|
Options: options,
|
||||||
|
KeepAlive: keepAliveDuration,
|
||||||
|
}
|
||||||
|
|
||||||
|
if imgData != nil {
|
||||||
|
req.Images = []api.ImageData{imgData}
|
||||||
|
}
|
||||||
|
|
||||||
|
return req
|
||||||
|
}
|
||||||
|
|
||||||
|
func fetchModelInfo(ctx context.Context, client *api.Client, model string) ModelInfo {
|
||||||
|
info := ModelInfo{Name: model}
|
||||||
|
resp, err := client.Show(ctx, &api.ShowRequest{Model: model})
|
||||||
|
if err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "WARNING: Could not fetch model info for '%s': %v\n", model, err)
|
||||||
|
return info
|
||||||
|
}
|
||||||
|
info.ParameterSize = resp.Details.ParameterSize
|
||||||
|
info.QuantizationLevel = resp.Details.QuantizationLevel
|
||||||
|
info.Family = resp.Details.Family
|
||||||
|
return info
|
||||||
|
}
|
||||||
|
|
||||||
|
func fetchMemoryUsage(ctx context.Context, client *api.Client, model string) (size, vram int64) {
|
||||||
|
resp, err := client.ListRunning(ctx)
|
||||||
|
if err != nil {
|
||||||
|
if debug := os.Getenv("OLLAMA_DEBUG"); debug != "" {
|
||||||
|
fmt.Fprintf(os.Stderr, "WARNING: Could not fetch memory usage: %v\n", err)
|
||||||
|
}
|
||||||
|
return 0, 0
|
||||||
|
}
|
||||||
|
for _, m := range resp.Models {
|
||||||
|
if m.Name == model || m.Model == model {
|
||||||
|
return m.Size, m.SizeVRAM
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Try prefix match (model names may include :latest or tags)
|
||||||
|
for _, m := range resp.Models {
|
||||||
|
if strings.HasPrefix(m.Name, model) || strings.HasPrefix(m.Model, model) {
|
||||||
|
return m.Size, m.SizeVRAM
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return 0, 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func outputFormatHeader(w io.Writer, format string, verbose bool) {
|
||||||
|
switch format {
|
||||||
|
case "benchstat":
|
||||||
|
if verbose {
|
||||||
|
fmt.Fprintf(w, "goos: %s\n", runtime.GOOS)
|
||||||
|
fmt.Fprintf(w, "goarch: %s\n", runtime.GOARCH)
|
||||||
|
}
|
||||||
|
case "csv":
|
||||||
|
headings := []string{"NAME", "STEP", "COUNT", "NS_PER_COUNT", "TOKEN_PER_SEC"}
|
||||||
|
fmt.Fprintln(w, strings.Join(headings, ","))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func outputModelInfo(w io.Writer, format string, info ModelInfo) {
|
||||||
|
params := cmp.Or(info.ParameterSize, "unknown")
|
||||||
|
quant := cmp.Or(info.QuantizationLevel, "unknown")
|
||||||
|
family := cmp.Or(info.Family, "unknown")
|
||||||
|
|
||||||
|
memStr := ""
|
||||||
|
if info.SizeBytes > 0 {
|
||||||
|
memStr = fmt.Sprintf(" | Size: %d | VRAM: %d", info.SizeBytes, info.VRAMBytes)
|
||||||
|
}
|
||||||
|
fmt.Fprintf(w, "# Model: %s | Params: %s | Quant: %s | Family: %s%s\n",
|
||||||
|
info.Name, params, quant, family, memStr)
|
||||||
|
}
|
||||||
|
|
||||||
func OutputMetrics(w io.Writer, format string, metrics []Metrics, verbose bool) {
|
func OutputMetrics(w io.Writer, format string, metrics []Metrics, verbose bool) {
|
||||||
switch format {
|
switch format {
|
||||||
case "benchstat":
|
case "benchstat":
|
||||||
if verbose {
|
|
||||||
printHeader := func() {
|
|
||||||
fmt.Fprintf(w, "sysname: %s\n", runtime.GOOS)
|
|
||||||
fmt.Fprintf(w, "machine: %s\n", runtime.GOARCH)
|
|
||||||
}
|
|
||||||
once.Do(printHeader)
|
|
||||||
}
|
|
||||||
for _, m := range metrics {
|
for _, m := range metrics {
|
||||||
if m.Step == "generate" || m.Step == "prefill" {
|
if m.Step == "generate" || m.Step == "prefill" {
|
||||||
if m.Count > 0 {
|
if m.Count > 0 {
|
||||||
nsPerToken := float64(m.Duration.Nanoseconds()) / float64(m.Count)
|
nsPerToken := float64(m.Duration.Nanoseconds()) / float64(m.Count)
|
||||||
tokensPerSec := float64(m.Count) / (float64(m.Duration.Nanoseconds()) + 1e-12) * 1e9
|
tokensPerSec := float64(m.Count) / (float64(m.Duration.Nanoseconds()) + 1e-12) * 1e9
|
||||||
|
fmt.Fprintf(w, "BenchmarkModel/name=%s/step=%s 1 %.2f ns/token %.2f token/sec\n",
|
||||||
fmt.Fprintf(w, "BenchmarkModel/name=%s/step=%s %d %.2f ns/token %.2f token/sec\n",
|
m.Model, m.Step, nsPerToken, tokensPerSec)
|
||||||
m.Model, m.Step, m.Count, nsPerToken, tokensPerSec)
|
|
||||||
} else {
|
} else {
|
||||||
fmt.Fprintf(w, "BenchmarkModel/name=%s/step=%s %d 0 ns/token 0 token/sec\n",
|
fmt.Fprintf(w, "BenchmarkModel/name=%s/step=%s 1 0 ns/token 0 token/sec\n",
|
||||||
m.Model, m.Step, m.Count)
|
m.Model, m.Step)
|
||||||
}
|
}
|
||||||
|
} else if m.Step == "ttft" {
|
||||||
|
fmt.Fprintf(w, "BenchmarkModel/name=%s/step=ttft 1 %d ns/op\n",
|
||||||
|
m.Model, m.Duration.Nanoseconds())
|
||||||
} else {
|
} else {
|
||||||
var suffix string
|
fmt.Fprintf(w, "BenchmarkModel/name=%s/step=%s 1 %d ns/op\n",
|
||||||
if m.Step == "load" {
|
m.Model, m.Step, m.Duration.Nanoseconds())
|
||||||
suffix = "/step=load"
|
|
||||||
}
|
|
||||||
fmt.Fprintf(w, "BenchmarkModel/name=%s%s 1 %d ns/request\n",
|
|
||||||
m.Model, suffix, m.Duration.Nanoseconds())
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
case "csv":
|
case "csv":
|
||||||
printHeader := func() {
|
|
||||||
headings := []string{"NAME", "STEP", "COUNT", "NS_PER_COUNT", "TOKEN_PER_SEC"}
|
|
||||||
fmt.Fprintln(w, strings.Join(headings, ","))
|
|
||||||
}
|
|
||||||
once.Do(printHeader)
|
|
||||||
|
|
||||||
for _, m := range metrics {
|
for _, m := range metrics {
|
||||||
if m.Step == "generate" || m.Step == "prefill" {
|
if m.Step == "generate" || m.Step == "prefill" {
|
||||||
var nsPerToken float64
|
var nsPerToken float64
|
||||||
@@ -94,39 +217,14 @@ func OutputMetrics(w io.Writer, format string, metrics []Metrics, verbose bool)
|
|||||||
fmt.Fprintf(w, "%s,%s,1,%d,0\n", m.Model, m.Step, m.Duration.Nanoseconds())
|
fmt.Fprintf(w, "%s,%s,1,%d,0\n", m.Model, m.Step, m.Duration.Nanoseconds())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
case "markdown":
|
|
||||||
printHeader := func() {
|
|
||||||
fmt.Fprintln(w, "| Model | Step | Count | Duration | nsPerToken | tokensPerSec |")
|
|
||||||
fmt.Fprintln(w, "|-------|------|-------|----------|------------|--------------|")
|
|
||||||
}
|
|
||||||
once.Do(printHeader)
|
|
||||||
|
|
||||||
for _, m := range metrics {
|
|
||||||
var nsPerToken, tokensPerSec float64
|
|
||||||
var nsPerTokenStr, tokensPerSecStr string
|
|
||||||
|
|
||||||
if m.Step == "generate" || m.Step == "prefill" {
|
|
||||||
nsPerToken = float64(m.Duration.Nanoseconds()) / float64(m.Count)
|
|
||||||
tokensPerSec = float64(m.Count) / (float64(m.Duration.Nanoseconds()) + 1e-12) * 1e9
|
|
||||||
nsPerTokenStr = fmt.Sprintf("%.2f", nsPerToken)
|
|
||||||
tokensPerSecStr = fmt.Sprintf("%.2f", tokensPerSec)
|
|
||||||
} else {
|
|
||||||
nsPerTokenStr = "-"
|
|
||||||
tokensPerSecStr = "-"
|
|
||||||
}
|
|
||||||
|
|
||||||
fmt.Fprintf(w, "| %s | %s | %d | %v | %s | %s |\n",
|
|
||||||
m.Model, m.Step, m.Count, m.Duration, nsPerTokenStr, tokensPerSecStr)
|
|
||||||
}
|
|
||||||
default:
|
default:
|
||||||
fmt.Fprintf(os.Stderr, "Unknown output format '%s'\n", format)
|
fmt.Fprintf(os.Stderr, "Unknown output format '%s'\n", format)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func BenchmarkChat(fOpt flagOptions) error {
|
func BenchmarkModel(fOpt flagOptions) error {
|
||||||
models := strings.Split(*fOpt.models, ",")
|
models := strings.Split(*fOpt.models, ",")
|
||||||
|
|
||||||
// todo - add multi-image support
|
|
||||||
var imgData api.ImageData
|
var imgData api.ImageData
|
||||||
var err error
|
var err error
|
||||||
if *fOpt.imageFile != "" {
|
if *fOpt.imageFile != "" {
|
||||||
@@ -158,71 +256,124 @@ func BenchmarkChat(fOpt flagOptions) error {
|
|||||||
out = f
|
out = f
|
||||||
}
|
}
|
||||||
|
|
||||||
|
outputFormatHeader(out, *fOpt.format, *fOpt.verbose)
|
||||||
|
|
||||||
|
// Log prompt-tokens info in debug mode
|
||||||
|
if *fOpt.debug && *fOpt.promptTokens > 0 {
|
||||||
|
prompt := generatePromptForTokenCount(*fOpt.promptTokens, 0)
|
||||||
|
wordCount := len(strings.Fields(prompt))
|
||||||
|
fmt.Fprintf(os.Stderr, "Generated prompt targeting ~%d tokens (%d words, varied per epoch)\n", *fOpt.promptTokens, wordCount)
|
||||||
|
}
|
||||||
|
|
||||||
for _, model := range models {
|
for _, model := range models {
|
||||||
for range *fOpt.epochs {
|
// Fetch model info
|
||||||
options := make(map[string]interface{})
|
infoCtx, infoCancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||||
if *fOpt.maxTokens > 0 {
|
info := fetchModelInfo(infoCtx, client, model)
|
||||||
options["num_predict"] = *fOpt.maxTokens
|
infoCancel()
|
||||||
}
|
|
||||||
options["temperature"] = *fOpt.temperature
|
|
||||||
if fOpt.seed != nil && *fOpt.seed > 0 {
|
|
||||||
options["seed"] = *fOpt.seed
|
|
||||||
}
|
|
||||||
|
|
||||||
var keepAliveDuration *api.Duration
|
|
||||||
if *fOpt.keepAlive > 0 {
|
|
||||||
duration := api.Duration{Duration: time.Duration(*fOpt.keepAlive * float64(time.Second))}
|
|
||||||
keepAliveDuration = &duration
|
|
||||||
}
|
|
||||||
|
|
||||||
req := &api.ChatRequest{
|
|
||||||
Model: model,
|
|
||||||
Messages: []api.Message{
|
|
||||||
{
|
|
||||||
Role: "user",
|
|
||||||
Content: *fOpt.prompt,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
Options: options,
|
|
||||||
KeepAlive: keepAliveDuration,
|
|
||||||
}
|
|
||||||
|
|
||||||
if imgData != nil {
|
|
||||||
req.Messages[0].Images = []api.ImageData{imgData}
|
|
||||||
}
|
|
||||||
|
|
||||||
var responseMetrics *api.Metrics
|
|
||||||
|
|
||||||
|
// Warmup phase (uses negative epoch numbers to avoid colliding with timed epochs)
|
||||||
|
for i := range *fOpt.warmup {
|
||||||
|
req := buildGenerateRequest(model, fOpt, imgData, -(i + 1))
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(*fOpt.timeout)*time.Second)
|
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(*fOpt.timeout)*time.Second)
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
err = client.Chat(ctx, req, func(resp api.ChatResponse) error {
|
err = client.Generate(ctx, req, func(resp api.GenerateResponse) error {
|
||||||
if *fOpt.debug {
|
|
||||||
fmt.Fprintf(os.Stderr, "%s", cmp.Or(resp.Message.Thinking, resp.Message.Content))
|
|
||||||
}
|
|
||||||
|
|
||||||
if resp.Done {
|
|
||||||
responseMetrics = &resp.Metrics
|
|
||||||
}
|
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
|
cancel()
|
||||||
if *fOpt.debug {
|
|
||||||
fmt.Fprintln(os.Stderr)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if ctx.Err() == context.DeadlineExceeded {
|
fmt.Fprintf(os.Stderr, "WARNING: Warmup %d/%d for %s failed: %v\n", i+1, *fOpt.warmup, model, err)
|
||||||
fmt.Fprintf(os.Stderr, "ERROR: Chat request timed out with model '%s' after %vs\n", model, 1)
|
} else if *fOpt.debug {
|
||||||
continue
|
fmt.Fprintf(os.Stderr, "Warmup %d/%d for %s complete\n", i+1, *fOpt.warmup, model)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fetch memory usage once after warmup (model is loaded and stable)
|
||||||
|
memCtx, memCancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
|
info.SizeBytes, info.VRAMBytes = fetchMemoryUsage(memCtx, client, model)
|
||||||
|
memCancel()
|
||||||
|
|
||||||
|
outputModelInfo(out, *fOpt.format, info)
|
||||||
|
|
||||||
|
// Timed epoch loop
|
||||||
|
shortCount := 0
|
||||||
|
for epoch := range *fOpt.epochs {
|
||||||
|
var responseMetrics *api.Metrics
|
||||||
|
var ttft time.Duration
|
||||||
|
short := false
|
||||||
|
|
||||||
|
// Retry loop: if the model hits a stop token before max-tokens,
|
||||||
|
// retry with a different prompt (up to maxRetries times).
|
||||||
|
const maxRetries = 3
|
||||||
|
for attempt := range maxRetries + 1 {
|
||||||
|
responseMetrics = nil
|
||||||
|
ttft = 0
|
||||||
|
var ttftOnce sync.Once
|
||||||
|
|
||||||
|
req := buildGenerateRequest(model, fOpt, imgData, epoch+attempt*1000)
|
||||||
|
requestStart := time.Now()
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(*fOpt.timeout)*time.Second)
|
||||||
|
|
||||||
|
err = client.Generate(ctx, req, func(resp api.GenerateResponse) error {
|
||||||
|
if *fOpt.debug {
|
||||||
|
fmt.Fprintf(os.Stderr, "%s", cmp.Or(resp.Thinking, resp.Response))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Capture TTFT on first content
|
||||||
|
ttftOnce.Do(func() {
|
||||||
|
if resp.Response != "" || resp.Thinking != "" {
|
||||||
|
ttft = time.Since(requestStart)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
if resp.Done {
|
||||||
|
responseMetrics = &resp.Metrics
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
cancel()
|
||||||
|
|
||||||
|
if *fOpt.debug {
|
||||||
|
fmt.Fprintln(os.Stderr)
|
||||||
}
|
}
|
||||||
fmt.Fprintf(os.Stderr, "ERROR: Couldn't chat with model '%s': %v\n", model, err)
|
|
||||||
|
if err != nil {
|
||||||
|
if ctx.Err() == context.DeadlineExceeded {
|
||||||
|
fmt.Fprintf(os.Stderr, "ERROR: Request timed out with model '%s' after %vs\n", model, *fOpt.timeout)
|
||||||
|
} else {
|
||||||
|
fmt.Fprintf(os.Stderr, "ERROR: Couldn't generate with model '%s': %v\n", model, err)
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
if responseMetrics == nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "ERROR: No metrics received for model '%s'\n", model)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if the response was shorter than requested
|
||||||
|
short = *fOpt.maxTokens > 0 && responseMetrics.EvalCount < *fOpt.maxTokens
|
||||||
|
if !short || attempt == maxRetries {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
if *fOpt.debug {
|
||||||
|
fmt.Fprintf(os.Stderr, "Short response (%d/%d tokens), retrying with different prompt (attempt %d/%d)\n",
|
||||||
|
responseMetrics.EvalCount, *fOpt.maxTokens, attempt+1, maxRetries)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil || responseMetrics == nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if responseMetrics == nil {
|
if short {
|
||||||
fmt.Fprintf(os.Stderr, "ERROR: No metrics received for model '%s'\n", model)
|
shortCount++
|
||||||
continue
|
if *fOpt.debug {
|
||||||
|
fmt.Fprintf(os.Stderr, "WARNING: Short response (%d/%d tokens) after %d retries for epoch %d\n",
|
||||||
|
responseMetrics.EvalCount, *fOpt.maxTokens, maxRetries, epoch+1)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
metrics := []Metrics{
|
metrics := []Metrics{
|
||||||
@@ -238,6 +389,12 @@ func BenchmarkChat(fOpt flagOptions) error {
|
|||||||
Count: responseMetrics.EvalCount,
|
Count: responseMetrics.EvalCount,
|
||||||
Duration: responseMetrics.EvalDuration,
|
Duration: responseMetrics.EvalDuration,
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
Model: model,
|
||||||
|
Step: "ttft",
|
||||||
|
Count: 1,
|
||||||
|
Duration: ttft,
|
||||||
|
},
|
||||||
{
|
{
|
||||||
Model: model,
|
Model: model,
|
||||||
Step: "load",
|
Step: "load",
|
||||||
@@ -254,15 +411,42 @@ func BenchmarkChat(fOpt flagOptions) error {
|
|||||||
|
|
||||||
OutputMetrics(out, *fOpt.format, metrics, *fOpt.verbose)
|
OutputMetrics(out, *fOpt.format, metrics, *fOpt.verbose)
|
||||||
|
|
||||||
|
if *fOpt.debug && *fOpt.promptTokens > 0 {
|
||||||
|
fmt.Fprintf(os.Stderr, "Generated prompt targeting ~%d tokens (actual: %d)\n",
|
||||||
|
*fOpt.promptTokens, responseMetrics.PromptEvalCount)
|
||||||
|
}
|
||||||
|
|
||||||
if *fOpt.keepAlive > 0 {
|
if *fOpt.keepAlive > 0 {
|
||||||
time.Sleep(time.Duration(*fOpt.keepAlive*float64(time.Second)) + 200*time.Millisecond)
|
time.Sleep(time.Duration(*fOpt.keepAlive*float64(time.Second)) + 200*time.Millisecond)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if shortCount > 0 {
|
||||||
|
fmt.Fprintf(os.Stderr, "WARNING: %d/%d epochs for '%s' had short responses (<%d tokens). Generation metrics may be unreliable.\n",
|
||||||
|
shortCount, *fOpt.epochs, model, *fOpt.maxTokens)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unload model before moving to the next one
|
||||||
|
unloadModel(client, model, *fOpt.timeout)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func unloadModel(client *api.Client, model string, timeout int) {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(timeout)*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
zero := api.Duration{Duration: 0}
|
||||||
|
req := &api.GenerateRequest{
|
||||||
|
Model: model,
|
||||||
|
KeepAlive: &zero,
|
||||||
|
}
|
||||||
|
_ = client.Generate(ctx, req, func(resp api.GenerateResponse) error {
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func readImage(filePath string) (api.ImageData, error) {
|
func readImage(filePath string) (api.ImageData, error) {
|
||||||
file, err := os.Open(filePath)
|
file, err := os.Open(filePath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -280,19 +464,21 @@ func readImage(filePath string) (api.ImageData, error) {
|
|||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
fOpt := flagOptions{
|
fOpt := flagOptions{
|
||||||
models: flag.String("model", "", "Model to benchmark"),
|
models: flag.String("model", "", "Model to benchmark"),
|
||||||
epochs: flag.Int("epochs", 6, "Number of epochs (iterations) per model"),
|
epochs: flag.Int("epochs", 6, "Number of epochs (iterations) per model"),
|
||||||
maxTokens: flag.Int("max-tokens", 200, "Maximum tokens for model response"),
|
maxTokens: flag.Int("max-tokens", 200, "Maximum tokens for model response"),
|
||||||
temperature: flag.Float64("temperature", 0, "Temperature parameter"),
|
temperature: flag.Float64("temperature", 0, "Temperature parameter"),
|
||||||
seed: flag.Int("seed", 0, "Random seed"),
|
seed: flag.Int("seed", 0, "Random seed"),
|
||||||
timeout: flag.Int("timeout", 60*5, "Timeout in seconds (default 300s)"),
|
timeout: flag.Int("timeout", 60*5, "Timeout in seconds (default 300s)"),
|
||||||
prompt: flag.String("p", DefaultPrompt, "Prompt to use"),
|
prompt: flag.String("p", DefaultPrompt, "Prompt to use"),
|
||||||
imageFile: flag.String("image", "", "Filename for an image to include"),
|
imageFile: flag.String("image", "", "Filename for an image to include"),
|
||||||
keepAlive: flag.Float64("k", 0, "Keep alive duration in seconds"),
|
keepAlive: flag.Float64("k", 0, "Keep alive duration in seconds"),
|
||||||
format: flag.String("format", "markdown", "Output format [benchstat|csv] (default benchstat)"),
|
format: flag.String("format", "benchstat", "Output format [benchstat|csv]"),
|
||||||
outputFile: flag.String("output", "", "Output file for results (stdout if empty)"),
|
outputFile: flag.String("output", "", "Output file for results (stdout if empty)"),
|
||||||
verbose: flag.Bool("v", false, "Show system information"),
|
verbose: flag.Bool("v", false, "Show system information"),
|
||||||
debug: flag.Bool("debug", false, "Show debug information"),
|
debug: flag.Bool("debug", false, "Show debug information"),
|
||||||
|
warmup: flag.Int("warmup", 1, "Number of warmup requests before timing"),
|
||||||
|
promptTokens: flag.Int("prompt-tokens", 0, "Generate prompt targeting ~N tokens (0 = use -p prompt)"),
|
||||||
}
|
}
|
||||||
|
|
||||||
flag.Usage = func() {
|
flag.Usage = func() {
|
||||||
@@ -302,11 +488,12 @@ func main() {
|
|||||||
fmt.Fprintf(os.Stderr, "Options:\n")
|
fmt.Fprintf(os.Stderr, "Options:\n")
|
||||||
flag.PrintDefaults()
|
flag.PrintDefaults()
|
||||||
fmt.Fprintf(os.Stderr, "\nExamples:\n")
|
fmt.Fprintf(os.Stderr, "\nExamples:\n")
|
||||||
fmt.Fprintf(os.Stderr, " bench -model gpt-oss:20b -epochs 3 -temperature 0.7\n")
|
fmt.Fprintf(os.Stderr, " bench -model gemma3,llama3 -epochs 6\n")
|
||||||
|
fmt.Fprintf(os.Stderr, " bench -model gemma3 -epochs 6 -prompt-tokens 512 -format csv\n")
|
||||||
}
|
}
|
||||||
flag.Parse()
|
flag.Parse()
|
||||||
|
|
||||||
if !slices.Contains([]string{"markdown", "benchstat", "csv"}, *fOpt.format) {
|
if !slices.Contains([]string{"benchstat", "csv"}, *fOpt.format) {
|
||||||
fmt.Fprintf(os.Stderr, "ERROR: Unknown format '%s'\n", *fOpt.format)
|
fmt.Fprintf(os.Stderr, "ERROR: Unknown format '%s'\n", *fOpt.format)
|
||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
@@ -317,5 +504,5 @@ func main() {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
BenchmarkChat(fOpt)
|
BenchmarkModel(fOpt)
|
||||||
}
|
}
|
||||||
|
|||||||
415
cmd/cmd.go
@@ -11,6 +11,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"log"
|
"log"
|
||||||
|
"log/slog"
|
||||||
"math"
|
"math"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
@@ -38,9 +39,12 @@ import (
|
|||||||
|
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
"github.com/ollama/ollama/cmd/config"
|
"github.com/ollama/ollama/cmd/config"
|
||||||
|
"github.com/ollama/ollama/cmd/launch"
|
||||||
"github.com/ollama/ollama/cmd/tui"
|
"github.com/ollama/ollama/cmd/tui"
|
||||||
"github.com/ollama/ollama/envconfig"
|
"github.com/ollama/ollama/envconfig"
|
||||||
"github.com/ollama/ollama/format"
|
"github.com/ollama/ollama/format"
|
||||||
|
"github.com/ollama/ollama/internal/modelref"
|
||||||
|
"github.com/ollama/ollama/logutil"
|
||||||
"github.com/ollama/ollama/parser"
|
"github.com/ollama/ollama/parser"
|
||||||
"github.com/ollama/ollama/progress"
|
"github.com/ollama/ollama/progress"
|
||||||
"github.com/ollama/ollama/readline"
|
"github.com/ollama/ollama/readline"
|
||||||
@@ -57,36 +61,42 @@ import (
|
|||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
// Override default selectors to use Bubbletea TUI instead of raw terminal I/O.
|
// Override default selectors to use Bubbletea TUI instead of raw terminal I/O.
|
||||||
config.DefaultSingleSelector = func(title string, items []config.ModelItem, current string) (string, error) {
|
launch.DefaultSingleSelector = func(title string, items []launch.ModelItem, current string) (string, error) {
|
||||||
|
if !term.IsTerminal(int(os.Stdin.Fd())) || !term.IsTerminal(int(os.Stdout.Fd())) {
|
||||||
|
return "", fmt.Errorf("model selection requires an interactive terminal; use --model to run in headless mode")
|
||||||
|
}
|
||||||
tuiItems := tui.ReorderItems(tui.ConvertItems(items))
|
tuiItems := tui.ReorderItems(tui.ConvertItems(items))
|
||||||
result, err := tui.SelectSingle(title, tuiItems, current)
|
result, err := tui.SelectSingle(title, tuiItems, current)
|
||||||
if errors.Is(err, tui.ErrCancelled) {
|
if errors.Is(err, tui.ErrCancelled) {
|
||||||
return "", config.ErrCancelled
|
return "", launch.ErrCancelled
|
||||||
}
|
}
|
||||||
return result, err
|
return result, err
|
||||||
}
|
}
|
||||||
|
|
||||||
config.DefaultMultiSelector = func(title string, items []config.ModelItem, preChecked []string) ([]string, error) {
|
launch.DefaultMultiSelector = func(title string, items []launch.ModelItem, preChecked []string) ([]string, error) {
|
||||||
|
if !term.IsTerminal(int(os.Stdin.Fd())) || !term.IsTerminal(int(os.Stdout.Fd())) {
|
||||||
|
return nil, fmt.Errorf("model selection requires an interactive terminal; use --model to run in headless mode")
|
||||||
|
}
|
||||||
tuiItems := tui.ReorderItems(tui.ConvertItems(items))
|
tuiItems := tui.ReorderItems(tui.ConvertItems(items))
|
||||||
result, err := tui.SelectMultiple(title, tuiItems, preChecked)
|
result, err := tui.SelectMultiple(title, tuiItems, preChecked)
|
||||||
if errors.Is(err, tui.ErrCancelled) {
|
if errors.Is(err, tui.ErrCancelled) {
|
||||||
return nil, config.ErrCancelled
|
return nil, launch.ErrCancelled
|
||||||
}
|
}
|
||||||
return result, err
|
return result, err
|
||||||
}
|
}
|
||||||
|
|
||||||
config.DefaultSignIn = func(modelName, signInURL string) (string, error) {
|
launch.DefaultSignIn = func(modelName, signInURL string) (string, error) {
|
||||||
userName, err := tui.RunSignIn(modelName, signInURL)
|
userName, err := tui.RunSignIn(modelName, signInURL)
|
||||||
if errors.Is(err, tui.ErrCancelled) {
|
if errors.Is(err, tui.ErrCancelled) {
|
||||||
return "", config.ErrCancelled
|
return "", launch.ErrCancelled
|
||||||
}
|
}
|
||||||
return userName, err
|
return userName, err
|
||||||
}
|
}
|
||||||
|
|
||||||
config.DefaultConfirmPrompt = func(prompt string) (bool, error) {
|
launch.DefaultConfirmPrompt = func(prompt string) (bool, error) {
|
||||||
ok, err := tui.RunConfirm(prompt)
|
ok, err := tui.RunConfirm(prompt)
|
||||||
if errors.Is(err, tui.ErrCancelled) {
|
if errors.Is(err, tui.ErrCancelled) {
|
||||||
return false, config.ErrCancelled
|
return false, launch.ErrCancelled
|
||||||
}
|
}
|
||||||
return ok, err
|
return ok, err
|
||||||
}
|
}
|
||||||
@@ -131,6 +141,17 @@ func getModelfileName(cmd *cobra.Command) (string, error) {
|
|||||||
return absName, nil
|
return absName, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// isLocalhost returns true if the configured Ollama host is a loopback or unspecified address.
|
||||||
|
func isLocalhost() bool {
|
||||||
|
host := envconfig.Host()
|
||||||
|
h, _, _ := net.SplitHostPort(host.Host)
|
||||||
|
if h == "localhost" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
ip := net.ParseIP(h)
|
||||||
|
return ip != nil && (ip.IsLoopback() || ip.IsUnspecified())
|
||||||
|
}
|
||||||
|
|
||||||
func CreateHandler(cmd *cobra.Command, args []string) error {
|
func CreateHandler(cmd *cobra.Command, args []string) error {
|
||||||
p := progress.NewProgress(os.Stderr)
|
p := progress.NewProgress(os.Stderr)
|
||||||
defer p.Stop()
|
defer p.Stop()
|
||||||
@@ -145,6 +166,9 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
|
|||||||
// Check for --experimental flag for safetensors model creation
|
// Check for --experimental flag for safetensors model creation
|
||||||
experimental, _ := cmd.Flags().GetBool("experimental")
|
experimental, _ := cmd.Flags().GetBool("experimental")
|
||||||
if experimental {
|
if experimental {
|
||||||
|
if !isLocalhost() {
|
||||||
|
return errors.New("remote safetensor model creation not yet supported")
|
||||||
|
}
|
||||||
// Get Modelfile content - either from -f flag or default to "FROM ."
|
// Get Modelfile content - either from -f flag or default to "FROM ."
|
||||||
var reader io.Reader
|
var reader io.Reader
|
||||||
filename, err := getModelfileName(cmd)
|
filename, err := getModelfileName(cmd)
|
||||||
@@ -168,29 +192,9 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
|
|||||||
return fmt.Errorf("failed to parse Modelfile: %w", err)
|
return fmt.Errorf("failed to parse Modelfile: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Extract FROM path and configuration
|
modelDir, mfConfig, err := xcreateclient.ConfigFromModelfile(modelfile)
|
||||||
var modelDir string
|
if err != nil {
|
||||||
mfConfig := &xcreateclient.ModelfileConfig{}
|
return err
|
||||||
|
|
||||||
for _, cmd := range modelfile.Commands {
|
|
||||||
switch cmd.Name {
|
|
||||||
case "model":
|
|
||||||
modelDir = cmd.Args
|
|
||||||
case "template":
|
|
||||||
mfConfig.Template = cmd.Args
|
|
||||||
case "system":
|
|
||||||
mfConfig.System = cmd.Args
|
|
||||||
case "license":
|
|
||||||
mfConfig.License = cmd.Args
|
|
||||||
case "parser":
|
|
||||||
mfConfig.Parser = cmd.Args
|
|
||||||
case "renderer":
|
|
||||||
mfConfig.Renderer = cmd.Args
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if modelDir == "" {
|
|
||||||
modelDir = "."
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Resolve relative paths based on Modelfile location
|
// Resolve relative paths based on Modelfile location
|
||||||
@@ -214,6 +218,9 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
|
|||||||
if filename == "" {
|
if filename == "" {
|
||||||
// No Modelfile found - check if current directory is an image gen model
|
// No Modelfile found - check if current directory is an image gen model
|
||||||
if create.IsTensorModelDir(".") {
|
if create.IsTensorModelDir(".") {
|
||||||
|
if !isLocalhost() {
|
||||||
|
return errors.New("remote safetensor model creation not yet supported")
|
||||||
|
}
|
||||||
quantize, _ := cmd.Flags().GetString("quantize")
|
quantize, _ := cmd.Flags().GetString("quantize")
|
||||||
return xcreateclient.CreateModel(xcreateclient.CreateOptions{
|
return xcreateclient.CreateModel(xcreateclient.CreateOptions{
|
||||||
ModelName: modelName,
|
ModelName: modelName,
|
||||||
@@ -406,12 +413,14 @@ func loadOrUnloadModel(cmd *cobra.Command, opts *runOptions) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
requestedCloud := modelref.HasExplicitCloudSource(opts.Model)
|
||||||
|
|
||||||
if info, err := client.Show(cmd.Context(), &api.ShowRequest{Model: opts.Model}); err != nil {
|
if info, err := client.Show(cmd.Context(), &api.ShowRequest{Model: opts.Model}); err != nil {
|
||||||
return err
|
return err
|
||||||
} else if info.RemoteHost != "" {
|
} else if info.RemoteHost != "" || requestedCloud {
|
||||||
// Cloud model, no need to load/unload
|
// Cloud model, no need to load/unload
|
||||||
|
|
||||||
isCloud := strings.HasPrefix(info.RemoteHost, "https://ollama.com")
|
isCloud := requestedCloud || strings.HasPrefix(info.RemoteHost, "https://ollama.com")
|
||||||
|
|
||||||
// Check if user is signed in for ollama.com cloud models
|
// Check if user is signed in for ollama.com cloud models
|
||||||
if isCloud {
|
if isCloud {
|
||||||
@@ -422,10 +431,14 @@ func loadOrUnloadModel(cmd *cobra.Command, opts *runOptions) error {
|
|||||||
|
|
||||||
if opts.ShowConnect {
|
if opts.ShowConnect {
|
||||||
p.StopAndClear()
|
p.StopAndClear()
|
||||||
|
remoteModel := info.RemoteModel
|
||||||
|
if remoteModel == "" {
|
||||||
|
remoteModel = opts.Model
|
||||||
|
}
|
||||||
if isCloud {
|
if isCloud {
|
||||||
fmt.Fprintf(os.Stderr, "Connecting to '%s' on 'ollama.com' ⚡\n", info.RemoteModel)
|
fmt.Fprintf(os.Stderr, "Connecting to '%s' on 'ollama.com' ⚡\n", remoteModel)
|
||||||
} else {
|
} else {
|
||||||
fmt.Fprintf(os.Stderr, "Connecting to '%s' on '%s'\n", info.RemoteModel, info.RemoteHost)
|
fmt.Fprintf(os.Stderr, "Connecting to '%s' on '%s'\n", remoteModel, info.RemoteHost)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -497,6 +510,64 @@ func generateEmbedding(cmd *cobra.Command, modelName, input string, keepAlive *a
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO(parthsareen): consolidate with TUI signin flow
|
||||||
|
func handleCloudAuthorizationError(err error) bool {
|
||||||
|
var authErr api.AuthorizationError
|
||||||
|
if errors.As(err, &authErr) && authErr.StatusCode == http.StatusUnauthorized {
|
||||||
|
fmt.Printf("You need to be signed in to Ollama to run Cloud models.\n\n")
|
||||||
|
if authErr.SigninURL != "" {
|
||||||
|
fmt.Printf(ConnectInstructions, authErr.SigninURL)
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// TEMP(drifkin): To match legacy `ollama run some-model:cloud` behavior, we
|
||||||
|
// best-effort pull cloud stub files for any explicit cloud source models.
|
||||||
|
// Remove this once `/api/tags` is cloud-aware.
|
||||||
|
func ensureCloudStub(ctx context.Context, client *api.Client, modelName string) {
|
||||||
|
if !modelref.HasExplicitCloudSource(modelName) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
normalizedName, _, err := modelref.NormalizePullName(modelName)
|
||||||
|
if err != nil {
|
||||||
|
slog.Warn("failed to normalize pull name", "model", modelName, "error", err, "normalizedName", normalizedName)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
listResp, err := client.List(ctx)
|
||||||
|
if err != nil {
|
||||||
|
slog.Warn("failed to list models", "error", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if hasListedModelName(listResp.Models, modelName) || hasListedModelName(listResp.Models, normalizedName) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
logutil.Trace("pulling cloud stub", "model", modelName, "normalizedName", normalizedName)
|
||||||
|
err = client.Pull(ctx, &api.PullRequest{
|
||||||
|
Model: normalizedName,
|
||||||
|
}, func(api.ProgressResponse) error {
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
slog.Warn("failed to pull cloud stub", "model", modelName, "error", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func hasListedModelName(models []api.ListModelResponse, name string) bool {
|
||||||
|
for _, m := range models {
|
||||||
|
if strings.EqualFold(m.Name, name) || strings.EqualFold(m.Model, name) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
func RunHandler(cmd *cobra.Command, args []string) error {
|
func RunHandler(cmd *cobra.Command, args []string) error {
|
||||||
interactive := true
|
interactive := true
|
||||||
|
|
||||||
@@ -585,17 +656,6 @@ func RunHandler(cmd *cobra.Command, args []string) error {
|
|||||||
}
|
}
|
||||||
opts.WordWrap = !nowrap
|
opts.WordWrap = !nowrap
|
||||||
|
|
||||||
useImagegen := false
|
|
||||||
if cmd.Flags().Lookup("imagegen") != nil {
|
|
||||||
useImagegen, err = cmd.Flags().GetBool("imagegen")
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if useImagegen {
|
|
||||||
opts.Options["use_imagegen_runner"] = true
|
|
||||||
}
|
|
||||||
|
|
||||||
// Fill out the rest of the options based on information about the
|
// Fill out the rest of the options based on information about the
|
||||||
// model.
|
// model.
|
||||||
client, err := api.ClientFromEnvironment()
|
client, err := api.ClientFromEnvironment()
|
||||||
@@ -604,12 +664,16 @@ func RunHandler(cmd *cobra.Command, args []string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
name := args[0]
|
name := args[0]
|
||||||
|
requestedCloud := modelref.HasExplicitCloudSource(name)
|
||||||
|
|
||||||
info, err := func() (*api.ShowResponse, error) {
|
info, err := func() (*api.ShowResponse, error) {
|
||||||
showReq := &api.ShowRequest{Name: name}
|
showReq := &api.ShowRequest{Name: name}
|
||||||
info, err := client.Show(cmd.Context(), showReq)
|
info, err := client.Show(cmd.Context(), showReq)
|
||||||
var se api.StatusError
|
var se api.StatusError
|
||||||
if errors.As(err, &se) && se.StatusCode == http.StatusNotFound {
|
if errors.As(err, &se) && se.StatusCode == http.StatusNotFound {
|
||||||
|
if requestedCloud {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
if err := PullHandler(cmd, []string{name}); err != nil {
|
if err := PullHandler(cmd, []string{name}); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -618,9 +682,14 @@ func RunHandler(cmd *cobra.Command, args []string) error {
|
|||||||
return info, err
|
return info, err
|
||||||
}()
|
}()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
if handleCloudAuthorizationError(err) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ensureCloudStub(cmd.Context(), client, name)
|
||||||
|
|
||||||
opts.Think, err = inferThinkingOption(&info.Capabilities, &opts, thinkFlag.Changed)
|
opts.Think, err = inferThinkingOption(&info.Capabilities, &opts, thinkFlag.Changed)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -712,7 +781,13 @@ func RunHandler(cmd *cobra.Command, args []string) error {
|
|||||||
|
|
||||||
return generateInteractive(cmd, opts)
|
return generateInteractive(cmd, opts)
|
||||||
}
|
}
|
||||||
return generate(cmd, opts)
|
if err := generate(cmd, opts); err != nil {
|
||||||
|
if handleCloudAuthorizationError(err) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func SigninHandler(cmd *cobra.Command, args []string) error {
|
func SigninHandler(cmd *cobra.Command, args []string) error {
|
||||||
@@ -1892,6 +1967,24 @@ func ensureServerRunning(ctx context.Context) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func launchInteractiveModel(cmd *cobra.Command, modelName string) error {
|
||||||
|
opts := runOptions{
|
||||||
|
Model: modelName,
|
||||||
|
WordWrap: os.Getenv("TERM") == "xterm-256color",
|
||||||
|
Options: map[string]any{},
|
||||||
|
ShowConnect: true,
|
||||||
|
}
|
||||||
|
// loadOrUnloadModel is cloud-safe here: remote/cloud models skip local preload
|
||||||
|
// and only validate auth/connectivity before interactive chat starts.
|
||||||
|
if err := loadOrUnloadModel(cmd, &opts); err != nil {
|
||||||
|
return fmt.Errorf("error loading model: %w", err)
|
||||||
|
}
|
||||||
|
if err := generateInteractive(cmd, opts); err != nil {
|
||||||
|
return fmt.Errorf("error running model: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// runInteractiveTUI runs the main interactive TUI menu.
|
// runInteractiveTUI runs the main interactive TUI menu.
|
||||||
func runInteractiveTUI(cmd *cobra.Command) {
|
func runInteractiveTUI(cmd *cobra.Command) {
|
||||||
// Ensure the server is running before showing the TUI
|
// Ensure the server is running before showing the TUI
|
||||||
@@ -1900,175 +1993,85 @@ func runInteractiveTUI(cmd *cobra.Command) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Selector adapters for tui
|
deps := launcherDeps{
|
||||||
singleSelector := func(title string, items []config.ModelItem, current string) (string, error) {
|
buildState: launch.BuildLauncherState,
|
||||||
tuiItems := tui.ReorderItems(tui.ConvertItems(items))
|
runMenu: tui.RunMenu,
|
||||||
result, err := tui.SelectSingle(title, tuiItems, current)
|
resolveRunModel: launch.ResolveRunModel,
|
||||||
if errors.Is(err, tui.ErrCancelled) {
|
launchIntegration: launch.LaunchIntegration,
|
||||||
return "", config.ErrCancelled
|
runModel: launchInteractiveModel,
|
||||||
}
|
|
||||||
return result, err
|
|
||||||
}
|
|
||||||
|
|
||||||
multiSelector := func(title string, items []config.ModelItem, preChecked []string) ([]string, error) {
|
|
||||||
tuiItems := tui.ReorderItems(tui.ConvertItems(items))
|
|
||||||
result, err := tui.SelectMultiple(title, tuiItems, preChecked)
|
|
||||||
if errors.Is(err, tui.ErrCancelled) {
|
|
||||||
return nil, config.ErrCancelled
|
|
||||||
}
|
|
||||||
return result, err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for {
|
for {
|
||||||
result, err := tui.Run()
|
continueLoop, err := runInteractiveTUIStep(cmd, deps)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
|
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
|
||||||
|
}
|
||||||
|
if !continueLoop {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
runModel := func(modelName string) {
|
type launcherDeps struct {
|
||||||
client, err := api.ClientFromEnvironment()
|
buildState func(context.Context) (*launch.LauncherState, error)
|
||||||
if err != nil {
|
runMenu func(*launch.LauncherState) (tui.TUIAction, error)
|
||||||
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
|
resolveRunModel func(context.Context, launch.RunModelRequest) (string, error)
|
||||||
return
|
launchIntegration func(context.Context, launch.IntegrationLaunchRequest) error
|
||||||
}
|
runModel func(*cobra.Command, string) error
|
||||||
if err := config.ShowOrPull(cmd.Context(), client, modelName); err != nil {
|
}
|
||||||
if errors.Is(err, config.ErrCancelled) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
_ = config.SetLastModel(modelName)
|
|
||||||
opts := runOptions{
|
|
||||||
Model: modelName,
|
|
||||||
WordWrap: os.Getenv("TERM") == "xterm-256color",
|
|
||||||
Options: map[string]any{},
|
|
||||||
ShowConnect: true,
|
|
||||||
}
|
|
||||||
if err := loadOrUnloadModel(cmd, &opts); err != nil {
|
|
||||||
fmt.Fprintf(os.Stderr, "Error loading model: %v\n", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if err := generateInteractive(cmd, opts); err != nil {
|
|
||||||
fmt.Fprintf(os.Stderr, "Error running model: %v\n", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
launchIntegration := func(name string) bool {
|
func runInteractiveTUIStep(cmd *cobra.Command, deps launcherDeps) (bool, error) {
|
||||||
if err := config.EnsureInstalled(name); err != nil {
|
state, err := deps.buildState(cmd.Context())
|
||||||
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
|
if err != nil {
|
||||||
return true
|
return false, fmt.Errorf("build launcher state: %w", err)
|
||||||
}
|
}
|
||||||
// If not configured or model no longer exists, prompt for model selection
|
|
||||||
configuredModel := config.IntegrationModel(name)
|
|
||||||
if configuredModel == "" || !config.ModelExists(cmd.Context(), configuredModel) || config.IsCloudModelDisabled(cmd.Context(), configuredModel) {
|
|
||||||
err := config.ConfigureIntegrationWithSelectors(cmd.Context(), name, singleSelector, multiSelector)
|
|
||||||
if errors.Is(err, config.ErrCancelled) {
|
|
||||||
return false // Return to main menu
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
fmt.Fprintf(os.Stderr, "Error configuring %s: %v\n", name, err)
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if err := config.LaunchIntegration(name); err != nil {
|
|
||||||
fmt.Fprintf(os.Stderr, "Error launching %s: %v\n", name, err)
|
|
||||||
}
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
switch result.Selection {
|
action, err := deps.runMenu(state)
|
||||||
case tui.SelectionNone:
|
if err != nil {
|
||||||
// User quit
|
return false, fmt.Errorf("run launcher menu: %w", err)
|
||||||
return
|
}
|
||||||
case tui.SelectionRunModel:
|
|
||||||
_ = config.SetLastSelection("run")
|
return runLauncherAction(cmd, action, deps)
|
||||||
if modelName := config.LastModel(); modelName != "" && !config.IsCloudModelDisabled(cmd.Context(), modelName) {
|
}
|
||||||
runModel(modelName)
|
|
||||||
} else {
|
func saveLauncherSelection(action tui.TUIAction) {
|
||||||
modelName, err := config.SelectModelWithSelector(cmd.Context(), singleSelector)
|
// Best effort only: this affects menu recall, not launch correctness.
|
||||||
if errors.Is(err, config.ErrCancelled) {
|
_ = config.SetLastSelection(action.LastSelection())
|
||||||
continue // Return to main menu
|
}
|
||||||
}
|
|
||||||
if err != nil {
|
func runLauncherAction(cmd *cobra.Command, action tui.TUIAction, deps launcherDeps) (bool, error) {
|
||||||
fmt.Fprintf(os.Stderr, "Error selecting model: %v\n", err)
|
switch action.Kind {
|
||||||
continue
|
case tui.TUIActionNone:
|
||||||
}
|
return false, nil
|
||||||
runModel(modelName)
|
case tui.TUIActionRunModel:
|
||||||
}
|
saveLauncherSelection(action)
|
||||||
case tui.SelectionChangeRunModel:
|
modelName, err := deps.resolveRunModel(cmd.Context(), action.RunModelRequest())
|
||||||
_ = config.SetLastSelection("run")
|
if errors.Is(err, launch.ErrCancelled) {
|
||||||
// Use model from modal if selected, otherwise show picker
|
return true, nil
|
||||||
modelName := result.Model
|
|
||||||
if modelName == "" {
|
|
||||||
var err error
|
|
||||||
modelName, err = config.SelectModelWithSelector(cmd.Context(), singleSelector)
|
|
||||||
if errors.Is(err, config.ErrCancelled) {
|
|
||||||
continue // Return to main menu
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
fmt.Fprintf(os.Stderr, "Error selecting model: %v\n", err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if config.IsCloudModelDisabled(cmd.Context(), modelName) {
|
|
||||||
continue // Return to main menu
|
|
||||||
}
|
|
||||||
runModel(modelName)
|
|
||||||
case tui.SelectionIntegration:
|
|
||||||
_ = config.SetLastSelection(result.Integration)
|
|
||||||
if !launchIntegration(result.Integration) {
|
|
||||||
continue // Return to main menu
|
|
||||||
}
|
|
||||||
case tui.SelectionChangeIntegration:
|
|
||||||
_ = config.SetLastSelection(result.Integration)
|
|
||||||
if len(result.Models) > 0 {
|
|
||||||
// Filter out cloud-disabled models
|
|
||||||
var filtered []string
|
|
||||||
for _, m := range result.Models {
|
|
||||||
if !config.IsCloudModelDisabled(cmd.Context(), m) {
|
|
||||||
filtered = append(filtered, m)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if len(filtered) == 0 {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
result.Models = filtered
|
|
||||||
// Multi-select from modal (Editor integrations)
|
|
||||||
if err := config.SaveAndEditIntegration(result.Integration, result.Models); err != nil {
|
|
||||||
fmt.Fprintf(os.Stderr, "Error configuring %s: %v\n", result.Integration, err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if err := config.LaunchIntegrationWithModel(result.Integration, result.Models[0]); err != nil {
|
|
||||||
fmt.Fprintf(os.Stderr, "Error launching %s: %v\n", result.Integration, err)
|
|
||||||
}
|
|
||||||
} else if result.Model != "" {
|
|
||||||
if config.IsCloudModelDisabled(cmd.Context(), result.Model) {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
// Single-select from modal - save and launch
|
|
||||||
if err := config.SaveIntegration(result.Integration, []string{result.Model}); err != nil {
|
|
||||||
fmt.Fprintf(os.Stderr, "Error saving config: %v\n", err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if err := config.LaunchIntegrationWithModel(result.Integration, result.Model); err != nil {
|
|
||||||
fmt.Fprintf(os.Stderr, "Error launching %s: %v\n", result.Integration, err)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
err := config.ConfigureIntegrationWithSelectors(cmd.Context(), result.Integration, singleSelector, multiSelector)
|
|
||||||
if errors.Is(err, config.ErrCancelled) {
|
|
||||||
continue // Return to main menu
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
fmt.Fprintf(os.Stderr, "Error configuring %s: %v\n", result.Integration, err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if err := config.LaunchIntegration(result.Integration); err != nil {
|
|
||||||
fmt.Fprintf(os.Stderr, "Error launching %s: %v\n", result.Integration, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
if err != nil {
|
||||||
|
return true, fmt.Errorf("selecting model: %w", err)
|
||||||
|
}
|
||||||
|
if err := deps.runModel(cmd, modelName); err != nil {
|
||||||
|
return true, err
|
||||||
|
}
|
||||||
|
return true, nil
|
||||||
|
case tui.TUIActionLaunchIntegration:
|
||||||
|
saveLauncherSelection(action)
|
||||||
|
err := deps.launchIntegration(cmd.Context(), action.IntegrationLaunchRequest())
|
||||||
|
if errors.Is(err, launch.ErrCancelled) {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return true, fmt.Errorf("launching %s: %w", action.Integration, err)
|
||||||
|
}
|
||||||
|
// VS Code is a GUI app — exit the TUI loop after launching
|
||||||
|
if action.Integration == "vscode" {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
return true, nil
|
||||||
|
default:
|
||||||
|
return false, fmt.Errorf("unknown launcher action: %d", action.Kind)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -2338,7 +2341,7 @@ func NewCLI() *cobra.Command {
|
|||||||
copyCmd,
|
copyCmd,
|
||||||
deleteCmd,
|
deleteCmd,
|
||||||
runnerCmd,
|
runnerCmd,
|
||||||
config.LaunchCmd(checkServerHeartbeat, runInteractiveTUI),
|
launch.LaunchCmd(checkServerHeartbeat, runInteractiveTUI),
|
||||||
)
|
)
|
||||||
|
|
||||||
return rootCmd
|
return rootCmd
|
||||||
|
|||||||
270
cmd/cmd_launcher_test.go
Normal file
@@ -0,0 +1,270 @@
|
|||||||
|
package cmd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/spf13/cobra"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/cmd/config"
|
||||||
|
"github.com/ollama/ollama/cmd/launch"
|
||||||
|
"github.com/ollama/ollama/cmd/tui"
|
||||||
|
)
|
||||||
|
|
||||||
|
func setCmdTestHome(t *testing.T, dir string) {
|
||||||
|
t.Helper()
|
||||||
|
t.Setenv("HOME", dir)
|
||||||
|
t.Setenv("USERPROFILE", dir)
|
||||||
|
}
|
||||||
|
|
||||||
|
func unexpectedRunModelResolution(t *testing.T) func(context.Context, launch.RunModelRequest) (string, error) {
|
||||||
|
t.Helper()
|
||||||
|
return func(ctx context.Context, req launch.RunModelRequest) (string, error) {
|
||||||
|
t.Fatalf("did not expect run-model resolution: %+v", req)
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func unexpectedIntegrationLaunch(t *testing.T) func(context.Context, launch.IntegrationLaunchRequest) error {
|
||||||
|
t.Helper()
|
||||||
|
return func(ctx context.Context, req launch.IntegrationLaunchRequest) error {
|
||||||
|
t.Fatalf("did not expect integration launch: %+v", req)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func unexpectedModelLaunch(t *testing.T) func(*cobra.Command, string) error {
|
||||||
|
t.Helper()
|
||||||
|
return func(cmd *cobra.Command, model string) error {
|
||||||
|
t.Fatalf("did not expect chat launch: %s", model)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRunInteractiveTUI_RunModelActionsUseResolveRunModel(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
action tui.TUIAction
|
||||||
|
wantForce bool
|
||||||
|
wantModel string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "enter uses saved model flow",
|
||||||
|
action: tui.TUIAction{Kind: tui.TUIActionRunModel},
|
||||||
|
wantModel: "qwen3:8b",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "right forces picker",
|
||||||
|
action: tui.TUIAction{Kind: tui.TUIActionRunModel, ForceConfigure: true},
|
||||||
|
wantForce: true,
|
||||||
|
wantModel: "glm-5:cloud",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
setCmdTestHome(t, t.TempDir())
|
||||||
|
|
||||||
|
var menuCalls int
|
||||||
|
runMenu := func(state *launch.LauncherState) (tui.TUIAction, error) {
|
||||||
|
menuCalls++
|
||||||
|
if menuCalls == 1 {
|
||||||
|
return tt.action, nil
|
||||||
|
}
|
||||||
|
return tui.TUIAction{Kind: tui.TUIActionNone}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var gotReq launch.RunModelRequest
|
||||||
|
var launched string
|
||||||
|
deps := launcherDeps{
|
||||||
|
buildState: func(ctx context.Context) (*launch.LauncherState, error) {
|
||||||
|
return &launch.LauncherState{}, nil
|
||||||
|
},
|
||||||
|
runMenu: runMenu,
|
||||||
|
resolveRunModel: func(ctx context.Context, req launch.RunModelRequest) (string, error) {
|
||||||
|
gotReq = req
|
||||||
|
return tt.wantModel, nil
|
||||||
|
},
|
||||||
|
launchIntegration: unexpectedIntegrationLaunch(t),
|
||||||
|
runModel: func(cmd *cobra.Command, model string) error {
|
||||||
|
launched = model
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd := &cobra.Command{}
|
||||||
|
cmd.SetContext(context.Background())
|
||||||
|
for {
|
||||||
|
continueLoop, err := runInteractiveTUIStep(cmd, deps)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected step error: %v", err)
|
||||||
|
}
|
||||||
|
if !continueLoop {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if gotReq.ForcePicker != tt.wantForce {
|
||||||
|
t.Fatalf("expected ForcePicker=%v, got %v", tt.wantForce, gotReq.ForcePicker)
|
||||||
|
}
|
||||||
|
if launched != tt.wantModel {
|
||||||
|
t.Fatalf("expected interactive launcher to run %q, got %q", tt.wantModel, launched)
|
||||||
|
}
|
||||||
|
if got := config.LastSelection(); got != "run" {
|
||||||
|
t.Fatalf("expected last selection to be run, got %q", got)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRunInteractiveTUI_IntegrationActionsUseLaunchIntegration(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
action tui.TUIAction
|
||||||
|
wantForce bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "enter launches integration",
|
||||||
|
action: tui.TUIAction{Kind: tui.TUIActionLaunchIntegration, Integration: "claude"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "right forces configure",
|
||||||
|
action: tui.TUIAction{Kind: tui.TUIActionLaunchIntegration, Integration: "claude", ForceConfigure: true},
|
||||||
|
wantForce: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
setCmdTestHome(t, t.TempDir())
|
||||||
|
|
||||||
|
var menuCalls int
|
||||||
|
runMenu := func(state *launch.LauncherState) (tui.TUIAction, error) {
|
||||||
|
menuCalls++
|
||||||
|
if menuCalls == 1 {
|
||||||
|
return tt.action, nil
|
||||||
|
}
|
||||||
|
return tui.TUIAction{Kind: tui.TUIActionNone}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var gotReq launch.IntegrationLaunchRequest
|
||||||
|
deps := launcherDeps{
|
||||||
|
buildState: func(ctx context.Context) (*launch.LauncherState, error) {
|
||||||
|
return &launch.LauncherState{}, nil
|
||||||
|
},
|
||||||
|
runMenu: runMenu,
|
||||||
|
resolveRunModel: unexpectedRunModelResolution(t),
|
||||||
|
launchIntegration: func(ctx context.Context, req launch.IntegrationLaunchRequest) error {
|
||||||
|
gotReq = req
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
runModel: unexpectedModelLaunch(t),
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd := &cobra.Command{}
|
||||||
|
cmd.SetContext(context.Background())
|
||||||
|
for {
|
||||||
|
continueLoop, err := runInteractiveTUIStep(cmd, deps)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected step error: %v", err)
|
||||||
|
}
|
||||||
|
if !continueLoop {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if gotReq.Name != "claude" {
|
||||||
|
t.Fatalf("expected integration name to be passed through, got %q", gotReq.Name)
|
||||||
|
}
|
||||||
|
if gotReq.ForceConfigure != tt.wantForce {
|
||||||
|
t.Fatalf("expected ForceConfigure=%v, got %v", tt.wantForce, gotReq.ForceConfigure)
|
||||||
|
}
|
||||||
|
if got := config.LastSelection(); got != "claude" {
|
||||||
|
t.Fatalf("expected last selection to be claude, got %q", got)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRunLauncherAction_RunModelContinuesAfterCancellation(t *testing.T) {
|
||||||
|
setCmdTestHome(t, t.TempDir())
|
||||||
|
|
||||||
|
cmd := &cobra.Command{}
|
||||||
|
cmd.SetContext(context.Background())
|
||||||
|
|
||||||
|
continueLoop, err := runLauncherAction(cmd, tui.TUIAction{Kind: tui.TUIActionRunModel}, launcherDeps{
|
||||||
|
buildState: nil,
|
||||||
|
runMenu: nil,
|
||||||
|
resolveRunModel: func(ctx context.Context, req launch.RunModelRequest) (string, error) {
|
||||||
|
return "", launch.ErrCancelled
|
||||||
|
},
|
||||||
|
launchIntegration: unexpectedIntegrationLaunch(t),
|
||||||
|
runModel: unexpectedModelLaunch(t),
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("expected nil error on cancellation, got %v", err)
|
||||||
|
}
|
||||||
|
if !continueLoop {
|
||||||
|
t.Fatal("expected cancellation to continue the menu loop")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRunLauncherAction_VSCodeExitsTUILoop(t *testing.T) {
|
||||||
|
setCmdTestHome(t, t.TempDir())
|
||||||
|
|
||||||
|
cmd := &cobra.Command{}
|
||||||
|
cmd.SetContext(context.Background())
|
||||||
|
|
||||||
|
// VS Code should exit the TUI loop (return false) after a successful launch.
|
||||||
|
continueLoop, err := runLauncherAction(cmd, tui.TUIAction{Kind: tui.TUIActionLaunchIntegration, Integration: "vscode"}, launcherDeps{
|
||||||
|
resolveRunModel: unexpectedRunModelResolution(t),
|
||||||
|
launchIntegration: func(ctx context.Context, req launch.IntegrationLaunchRequest) error {
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
runModel: unexpectedModelLaunch(t),
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("expected nil error, got %v", err)
|
||||||
|
}
|
||||||
|
if continueLoop {
|
||||||
|
t.Fatal("expected vscode launch to exit the TUI loop (return false)")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Other integrations should continue the TUI loop (return true).
|
||||||
|
continueLoop, err = runLauncherAction(cmd, tui.TUIAction{Kind: tui.TUIActionLaunchIntegration, Integration: "claude"}, launcherDeps{
|
||||||
|
resolveRunModel: unexpectedRunModelResolution(t),
|
||||||
|
launchIntegration: func(ctx context.Context, req launch.IntegrationLaunchRequest) error {
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
runModel: unexpectedModelLaunch(t),
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("expected nil error, got %v", err)
|
||||||
|
}
|
||||||
|
if !continueLoop {
|
||||||
|
t.Fatal("expected non-vscode integration to continue the TUI loop (return true)")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRunLauncherAction_IntegrationContinuesAfterCancellation(t *testing.T) {
|
||||||
|
setCmdTestHome(t, t.TempDir())
|
||||||
|
|
||||||
|
cmd := &cobra.Command{}
|
||||||
|
cmd.SetContext(context.Background())
|
||||||
|
|
||||||
|
continueLoop, err := runLauncherAction(cmd, tui.TUIAction{Kind: tui.TUIActionLaunchIntegration, Integration: "claude"}, launcherDeps{
|
||||||
|
buildState: nil,
|
||||||
|
runMenu: nil,
|
||||||
|
resolveRunModel: unexpectedRunModelResolution(t),
|
||||||
|
launchIntegration: func(ctx context.Context, req launch.IntegrationLaunchRequest) error {
|
||||||
|
return launch.ErrCancelled
|
||||||
|
},
|
||||||
|
runModel: unexpectedModelLaunch(t),
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("expected nil error on cancellation, got %v", err)
|
||||||
|
}
|
||||||
|
if !continueLoop {
|
||||||
|
t.Fatal("expected cancellation to continue the menu loop")
|
||||||
|
}
|
||||||
|
}
|
||||||
484
cmd/cmd_test.go
@@ -705,6 +705,347 @@ func TestRunEmbeddingModelNoInput(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestRunHandler_CloudAuthErrorOnShow_PrintsSigninMessage(t *testing.T) {
|
||||||
|
var generateCalled bool
|
||||||
|
|
||||||
|
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
switch {
|
||||||
|
case r.URL.Path == "/api/show" && r.Method == http.MethodPost:
|
||||||
|
w.WriteHeader(http.StatusUnauthorized)
|
||||||
|
if err := json.NewEncoder(w).Encode(map[string]string{
|
||||||
|
"error": "unauthorized",
|
||||||
|
"signin_url": "https://ollama.com/signin",
|
||||||
|
}); err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
case r.URL.Path == "/api/generate" && r.Method == http.MethodPost:
|
||||||
|
generateCalled = true
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
if err := json.NewEncoder(w).Encode(api.GenerateResponse{Done: true}); err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
http.NotFound(w, r)
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
|
||||||
|
t.Setenv("OLLAMA_HOST", mockServer.URL)
|
||||||
|
t.Cleanup(mockServer.Close)
|
||||||
|
|
||||||
|
cmd := &cobra.Command{}
|
||||||
|
cmd.SetContext(t.Context())
|
||||||
|
cmd.Flags().String("keepalive", "", "")
|
||||||
|
cmd.Flags().Bool("truncate", false, "")
|
||||||
|
cmd.Flags().Int("dimensions", 0, "")
|
||||||
|
cmd.Flags().Bool("verbose", false, "")
|
||||||
|
cmd.Flags().Bool("insecure", false, "")
|
||||||
|
cmd.Flags().Bool("nowordwrap", false, "")
|
||||||
|
cmd.Flags().String("format", "", "")
|
||||||
|
cmd.Flags().String("think", "", "")
|
||||||
|
cmd.Flags().Bool("hidethinking", false, "")
|
||||||
|
|
||||||
|
oldStdout := os.Stdout
|
||||||
|
readOut, writeOut, _ := os.Pipe()
|
||||||
|
os.Stdout = writeOut
|
||||||
|
t.Cleanup(func() { os.Stdout = oldStdout })
|
||||||
|
|
||||||
|
err := RunHandler(cmd, []string{"gpt-oss:20b:cloud", "hi"})
|
||||||
|
|
||||||
|
_ = writeOut.Close()
|
||||||
|
var out bytes.Buffer
|
||||||
|
_, _ = io.Copy(&out, readOut)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("RunHandler returned error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if generateCalled {
|
||||||
|
t.Fatal("expected run to stop before /api/generate after unauthorized /api/show")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !strings.Contains(out.String(), "You need to be signed in to Ollama to run Cloud models.") {
|
||||||
|
t.Fatalf("expected sign-in guidance message, got %q", out.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
if !strings.Contains(out.String(), "https://ollama.com/signin") {
|
||||||
|
t.Fatalf("expected signin_url in output, got %q", out.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRunHandler_CloudAuthErrorOnGenerate_PrintsSigninMessage(t *testing.T) {
|
||||||
|
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
switch {
|
||||||
|
case r.URL.Path == "/api/show" && r.Method == http.MethodPost:
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
if err := json.NewEncoder(w).Encode(api.ShowResponse{
|
||||||
|
Capabilities: []model.Capability{model.CapabilityCompletion},
|
||||||
|
}); err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
case r.URL.Path == "/api/generate" && r.Method == http.MethodPost:
|
||||||
|
w.WriteHeader(http.StatusUnauthorized)
|
||||||
|
if err := json.NewEncoder(w).Encode(map[string]string{
|
||||||
|
"error": "unauthorized",
|
||||||
|
"signin_url": "https://ollama.com/signin",
|
||||||
|
}); err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
http.NotFound(w, r)
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
|
||||||
|
t.Setenv("OLLAMA_HOST", mockServer.URL)
|
||||||
|
t.Cleanup(mockServer.Close)
|
||||||
|
|
||||||
|
cmd := &cobra.Command{}
|
||||||
|
cmd.SetContext(t.Context())
|
||||||
|
cmd.Flags().String("keepalive", "", "")
|
||||||
|
cmd.Flags().Bool("truncate", false, "")
|
||||||
|
cmd.Flags().Int("dimensions", 0, "")
|
||||||
|
cmd.Flags().Bool("verbose", false, "")
|
||||||
|
cmd.Flags().Bool("insecure", false, "")
|
||||||
|
cmd.Flags().Bool("nowordwrap", false, "")
|
||||||
|
cmd.Flags().String("format", "", "")
|
||||||
|
cmd.Flags().String("think", "", "")
|
||||||
|
cmd.Flags().Bool("hidethinking", false, "")
|
||||||
|
|
||||||
|
oldStdout := os.Stdout
|
||||||
|
readOut, writeOut, _ := os.Pipe()
|
||||||
|
os.Stdout = writeOut
|
||||||
|
t.Cleanup(func() { os.Stdout = oldStdout })
|
||||||
|
|
||||||
|
err := RunHandler(cmd, []string{"gpt-oss:20b:cloud", "hi"})
|
||||||
|
|
||||||
|
_ = writeOut.Close()
|
||||||
|
var out bytes.Buffer
|
||||||
|
_, _ = io.Copy(&out, readOut)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("RunHandler returned error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !strings.Contains(out.String(), "You need to be signed in to Ollama to run Cloud models.") {
|
||||||
|
t.Fatalf("expected sign-in guidance message, got %q", out.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
if !strings.Contains(out.String(), "https://ollama.com/signin") {
|
||||||
|
t.Fatalf("expected signin_url in output, got %q", out.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRunHandler_ExplicitCloudStubMissing_PullsNormalizedNameTEMP(t *testing.T) {
|
||||||
|
var pulledModel string
|
||||||
|
var generateCalled bool
|
||||||
|
|
||||||
|
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
switch {
|
||||||
|
case r.URL.Path == "/api/show" && r.Method == http.MethodPost:
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
if err := json.NewEncoder(w).Encode(api.ShowResponse{
|
||||||
|
Capabilities: []model.Capability{model.CapabilityCompletion},
|
||||||
|
RemoteModel: "gpt-oss:20b",
|
||||||
|
}); err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
case r.URL.Path == "/api/tags" && r.Method == http.MethodGet:
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
if err := json.NewEncoder(w).Encode(api.ListResponse{Models: nil}); err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
case r.URL.Path == "/api/pull" && r.Method == http.MethodPost:
|
||||||
|
var req api.PullRequest
|
||||||
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
pulledModel = req.Model
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
if err := json.NewEncoder(w).Encode(api.ProgressResponse{Status: "success"}); err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
case r.URL.Path == "/api/generate" && r.Method == http.MethodPost:
|
||||||
|
generateCalled = true
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
if err := json.NewEncoder(w).Encode(api.GenerateResponse{Done: true}); err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
http.NotFound(w, r)
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
|
||||||
|
t.Setenv("OLLAMA_HOST", mockServer.URL)
|
||||||
|
t.Cleanup(mockServer.Close)
|
||||||
|
|
||||||
|
cmd := &cobra.Command{}
|
||||||
|
cmd.SetContext(t.Context())
|
||||||
|
cmd.Flags().String("keepalive", "", "")
|
||||||
|
cmd.Flags().Bool("truncate", false, "")
|
||||||
|
cmd.Flags().Int("dimensions", 0, "")
|
||||||
|
cmd.Flags().Bool("verbose", false, "")
|
||||||
|
cmd.Flags().Bool("insecure", false, "")
|
||||||
|
cmd.Flags().Bool("nowordwrap", false, "")
|
||||||
|
cmd.Flags().String("format", "", "")
|
||||||
|
cmd.Flags().String("think", "", "")
|
||||||
|
cmd.Flags().Bool("hidethinking", false, "")
|
||||||
|
|
||||||
|
err := RunHandler(cmd, []string{"gpt-oss:20b:cloud", "hi"})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("RunHandler returned error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if pulledModel != "gpt-oss:20b-cloud" {
|
||||||
|
t.Fatalf("expected normalized pull model %q, got %q", "gpt-oss:20b-cloud", pulledModel)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !generateCalled {
|
||||||
|
t.Fatal("expected /api/generate to be called")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRunHandler_ExplicitCloudStubPresent_SkipsPullTEMP(t *testing.T) {
|
||||||
|
var pullCalled bool
|
||||||
|
var generateCalled bool
|
||||||
|
|
||||||
|
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
switch {
|
||||||
|
case r.URL.Path == "/api/show" && r.Method == http.MethodPost:
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
if err := json.NewEncoder(w).Encode(api.ShowResponse{
|
||||||
|
Capabilities: []model.Capability{model.CapabilityCompletion},
|
||||||
|
RemoteModel: "gpt-oss:20b",
|
||||||
|
}); err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
case r.URL.Path == "/api/tags" && r.Method == http.MethodGet:
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
if err := json.NewEncoder(w).Encode(api.ListResponse{
|
||||||
|
Models: []api.ListModelResponse{{Name: "gpt-oss:20b-cloud"}},
|
||||||
|
}); err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
case r.URL.Path == "/api/pull" && r.Method == http.MethodPost:
|
||||||
|
pullCalled = true
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
if err := json.NewEncoder(w).Encode(api.ProgressResponse{Status: "success"}); err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
case r.URL.Path == "/api/generate" && r.Method == http.MethodPost:
|
||||||
|
generateCalled = true
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
if err := json.NewEncoder(w).Encode(api.GenerateResponse{Done: true}); err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
http.NotFound(w, r)
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
|
||||||
|
t.Setenv("OLLAMA_HOST", mockServer.URL)
|
||||||
|
t.Cleanup(mockServer.Close)
|
||||||
|
|
||||||
|
cmd := &cobra.Command{}
|
||||||
|
cmd.SetContext(t.Context())
|
||||||
|
cmd.Flags().String("keepalive", "", "")
|
||||||
|
cmd.Flags().Bool("truncate", false, "")
|
||||||
|
cmd.Flags().Int("dimensions", 0, "")
|
||||||
|
cmd.Flags().Bool("verbose", false, "")
|
||||||
|
cmd.Flags().Bool("insecure", false, "")
|
||||||
|
cmd.Flags().Bool("nowordwrap", false, "")
|
||||||
|
cmd.Flags().String("format", "", "")
|
||||||
|
cmd.Flags().String("think", "", "")
|
||||||
|
cmd.Flags().Bool("hidethinking", false, "")
|
||||||
|
|
||||||
|
err := RunHandler(cmd, []string{"gpt-oss:20b:cloud", "hi"})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("RunHandler returned error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if pullCalled {
|
||||||
|
t.Fatal("expected /api/pull not to be called when cloud stub already exists")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !generateCalled {
|
||||||
|
t.Fatal("expected /api/generate to be called")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRunHandler_ExplicitCloudStubPullFailure_IsBestEffortTEMP(t *testing.T) {
|
||||||
|
var generateCalled bool
|
||||||
|
|
||||||
|
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
switch {
|
||||||
|
case r.URL.Path == "/api/show" && r.Method == http.MethodPost:
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
if err := json.NewEncoder(w).Encode(api.ShowResponse{
|
||||||
|
Capabilities: []model.Capability{model.CapabilityCompletion},
|
||||||
|
RemoteModel: "gpt-oss:20b",
|
||||||
|
}); err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
case r.URL.Path == "/api/tags" && r.Method == http.MethodGet:
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
if err := json.NewEncoder(w).Encode(api.ListResponse{Models: nil}); err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
case r.URL.Path == "/api/pull" && r.Method == http.MethodPost:
|
||||||
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
|
if err := json.NewEncoder(w).Encode(map[string]string{"error": "pull failed"}); err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
case r.URL.Path == "/api/generate" && r.Method == http.MethodPost:
|
||||||
|
generateCalled = true
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
if err := json.NewEncoder(w).Encode(api.GenerateResponse{Done: true}); err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
http.NotFound(w, r)
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
|
||||||
|
t.Setenv("OLLAMA_HOST", mockServer.URL)
|
||||||
|
t.Cleanup(mockServer.Close)
|
||||||
|
|
||||||
|
cmd := &cobra.Command{}
|
||||||
|
cmd.SetContext(t.Context())
|
||||||
|
cmd.Flags().String("keepalive", "", "")
|
||||||
|
cmd.Flags().Bool("truncate", false, "")
|
||||||
|
cmd.Flags().Int("dimensions", 0, "")
|
||||||
|
cmd.Flags().Bool("verbose", false, "")
|
||||||
|
cmd.Flags().Bool("insecure", false, "")
|
||||||
|
cmd.Flags().Bool("nowordwrap", false, "")
|
||||||
|
cmd.Flags().String("format", "", "")
|
||||||
|
cmd.Flags().String("think", "", "")
|
||||||
|
cmd.Flags().Bool("hidethinking", false, "")
|
||||||
|
|
||||||
|
err := RunHandler(cmd, []string{"gpt-oss:20b:cloud", "hi"})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("RunHandler returned error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !generateCalled {
|
||||||
|
t.Fatal("expected /api/generate to be called despite pull failure")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestGetModelfileName(t *testing.T) {
|
func TestGetModelfileName(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
@@ -1212,6 +1553,20 @@ func TestNewCreateRequest(t *testing.T) {
|
|||||||
Model: "newmodel",
|
Model: "newmodel",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"explicit cloud model preserves source when parent lacks it",
|
||||||
|
"newmodel",
|
||||||
|
runOptions{
|
||||||
|
Model: "qwen3.5:cloud",
|
||||||
|
ParentModel: "qwen3.5",
|
||||||
|
Messages: []api.Message{},
|
||||||
|
WordWrap: true,
|
||||||
|
},
|
||||||
|
&api.CreateRequest{
|
||||||
|
From: "qwen3.5:cloud",
|
||||||
|
Model: "newmodel",
|
||||||
|
},
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"parent model as filepath test",
|
"parent model as filepath test",
|
||||||
"newmodel",
|
"newmodel",
|
||||||
@@ -1663,31 +2018,81 @@ func TestRunOptions_Copy_Independence(t *testing.T) {
|
|||||||
|
|
||||||
func TestLoadOrUnloadModel_CloudModelAuth(t *testing.T) {
|
func TestLoadOrUnloadModel_CloudModelAuth(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
remoteHost string
|
model string
|
||||||
whoamiStatus int
|
showStatus int
|
||||||
whoamiResp any
|
remoteHost string
|
||||||
expectedError string
|
remoteModel string
|
||||||
|
whoamiStatus int
|
||||||
|
whoamiResp any
|
||||||
|
expectWhoami bool
|
||||||
|
expectedError string
|
||||||
|
expectAuthError bool
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "ollama.com cloud model - user signed in",
|
name: "ollama.com cloud model - user signed in",
|
||||||
|
model: "test-cloud-model",
|
||||||
remoteHost: "https://ollama.com",
|
remoteHost: "https://ollama.com",
|
||||||
|
remoteModel: "test-model",
|
||||||
whoamiStatus: http.StatusOK,
|
whoamiStatus: http.StatusOK,
|
||||||
whoamiResp: api.UserResponse{Name: "testuser"},
|
whoamiResp: api.UserResponse{Name: "testuser"},
|
||||||
|
expectWhoami: true,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "ollama.com cloud model - user not signed in",
|
name: "ollama.com cloud model - user not signed in",
|
||||||
|
model: "test-cloud-model",
|
||||||
remoteHost: "https://ollama.com",
|
remoteHost: "https://ollama.com",
|
||||||
|
remoteModel: "test-model",
|
||||||
whoamiStatus: http.StatusUnauthorized,
|
whoamiStatus: http.StatusUnauthorized,
|
||||||
whoamiResp: map[string]string{
|
whoamiResp: map[string]string{
|
||||||
"error": "unauthorized",
|
"error": "unauthorized",
|
||||||
"signin_url": "https://ollama.com/signin",
|
"signin_url": "https://ollama.com/signin",
|
||||||
},
|
},
|
||||||
expectedError: "unauthorized",
|
expectWhoami: true,
|
||||||
|
expectedError: "unauthorized",
|
||||||
|
expectAuthError: true,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "non-ollama.com remote - no auth check",
|
name: "non-ollama.com remote - no auth check",
|
||||||
|
model: "test-cloud-model",
|
||||||
remoteHost: "https://other-remote.com",
|
remoteHost: "https://other-remote.com",
|
||||||
|
remoteModel: "test-model",
|
||||||
|
whoamiStatus: http.StatusUnauthorized, // should not be called
|
||||||
|
whoamiResp: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "explicit :cloud model - auth check without remote metadata",
|
||||||
|
model: "kimi-k2.5:cloud",
|
||||||
|
remoteHost: "",
|
||||||
|
remoteModel: "",
|
||||||
|
whoamiStatus: http.StatusOK,
|
||||||
|
whoamiResp: api.UserResponse{Name: "testuser"},
|
||||||
|
expectWhoami: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "explicit :cloud model without local stub returns not found by default",
|
||||||
|
model: "minimax-m2.7:cloud",
|
||||||
|
showStatus: http.StatusNotFound,
|
||||||
|
whoamiStatus: http.StatusOK,
|
||||||
|
whoamiResp: api.UserResponse{Name: "testuser"},
|
||||||
|
expectedError: "not found",
|
||||||
|
expectWhoami: false,
|
||||||
|
expectAuthError: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "explicit -cloud model - auth check without remote metadata",
|
||||||
|
model: "kimi-k2.5:latest-cloud",
|
||||||
|
remoteHost: "",
|
||||||
|
remoteModel: "",
|
||||||
|
whoamiStatus: http.StatusOK,
|
||||||
|
whoamiResp: api.UserResponse{Name: "testuser"},
|
||||||
|
expectWhoami: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "dash cloud-like name without explicit source does not require auth",
|
||||||
|
model: "test-cloud-model",
|
||||||
|
remoteHost: "",
|
||||||
|
remoteModel: "",
|
||||||
whoamiStatus: http.StatusUnauthorized, // should not be called
|
whoamiStatus: http.StatusUnauthorized, // should not be called
|
||||||
whoamiResp: nil,
|
whoamiResp: nil,
|
||||||
},
|
},
|
||||||
@@ -1699,10 +2104,15 @@ func TestLoadOrUnloadModel_CloudModelAuth(t *testing.T) {
|
|||||||
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
switch r.URL.Path {
|
switch r.URL.Path {
|
||||||
case "/api/show":
|
case "/api/show":
|
||||||
|
if tt.showStatus != 0 && tt.showStatus != http.StatusOK {
|
||||||
|
w.WriteHeader(tt.showStatus)
|
||||||
|
_ = json.NewEncoder(w).Encode(map[string]string{"error": "not found"})
|
||||||
|
return
|
||||||
|
}
|
||||||
w.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
if err := json.NewEncoder(w).Encode(api.ShowResponse{
|
if err := json.NewEncoder(w).Encode(api.ShowResponse{
|
||||||
RemoteHost: tt.remoteHost,
|
RemoteHost: tt.remoteHost,
|
||||||
RemoteModel: "test-model",
|
RemoteModel: tt.remoteModel,
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
@@ -1715,6 +2125,8 @@ func TestLoadOrUnloadModel_CloudModelAuth(t *testing.T) {
|
|||||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
case "/api/generate":
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
default:
|
default:
|
||||||
http.NotFound(w, r)
|
http.NotFound(w, r)
|
||||||
}
|
}
|
||||||
@@ -1727,29 +2139,28 @@ func TestLoadOrUnloadModel_CloudModelAuth(t *testing.T) {
|
|||||||
cmd.SetContext(t.Context())
|
cmd.SetContext(t.Context())
|
||||||
|
|
||||||
opts := &runOptions{
|
opts := &runOptions{
|
||||||
Model: "test-cloud-model",
|
Model: tt.model,
|
||||||
ShowConnect: false,
|
ShowConnect: false,
|
||||||
}
|
}
|
||||||
|
|
||||||
err := loadOrUnloadModel(cmd, opts)
|
err := loadOrUnloadModel(cmd, opts)
|
||||||
|
|
||||||
if strings.HasPrefix(tt.remoteHost, "https://ollama.com") {
|
if whoamiCalled != tt.expectWhoami {
|
||||||
if !whoamiCalled {
|
t.Errorf("whoami called = %v, want %v", whoamiCalled, tt.expectWhoami)
|
||||||
t.Error("expected whoami to be called for ollama.com cloud model")
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
if whoamiCalled {
|
|
||||||
t.Error("whoami should not be called for non-ollama.com remote")
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if tt.expectedError != "" {
|
if tt.expectedError != "" {
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Errorf("expected error containing %q, got nil", tt.expectedError)
|
t.Errorf("expected error containing %q, got nil", tt.expectedError)
|
||||||
} else {
|
} else {
|
||||||
var authErr api.AuthorizationError
|
if !tt.expectAuthError && !strings.Contains(strings.ToLower(err.Error()), strings.ToLower(tt.expectedError)) {
|
||||||
if !errors.As(err, &authErr) {
|
t.Errorf("expected error containing %q, got %v", tt.expectedError, err)
|
||||||
t.Errorf("expected AuthorizationError, got %T: %v", err, err)
|
}
|
||||||
|
if tt.expectAuthError {
|
||||||
|
var authErr api.AuthorizationError
|
||||||
|
if !errors.As(err, &authErr) {
|
||||||
|
t.Errorf("expected AuthorizationError, got %T: %v", err, err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
@@ -1760,3 +2171,38 @@ func TestLoadOrUnloadModel_CloudModelAuth(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestIsLocalhost(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
host string
|
||||||
|
expected bool
|
||||||
|
}{
|
||||||
|
{"default empty", "", true},
|
||||||
|
{"localhost no port", "localhost", true},
|
||||||
|
{"localhost with port", "localhost:11435", true},
|
||||||
|
{"127.0.0.1 no port", "127.0.0.1", true},
|
||||||
|
{"127.0.0.1 with port", "127.0.0.1:11434", true},
|
||||||
|
{"0.0.0.0 no port", "0.0.0.0", true},
|
||||||
|
{"0.0.0.0 with port", "0.0.0.0:11434", true},
|
||||||
|
{"::1 no port", "::1", true},
|
||||||
|
{"[::1] with port", "[::1]:11434", true},
|
||||||
|
{"loopback with scheme", "http://localhost:11434", true},
|
||||||
|
{"remote hostname", "example.com", false},
|
||||||
|
{"remote hostname with port", "example.com:11434", false},
|
||||||
|
{"remote IP", "192.168.1.1", false},
|
||||||
|
{"remote IP with port", "192.168.1.1:11434", false},
|
||||||
|
{"remote with scheme", "http://example.com:11434", false},
|
||||||
|
{"https remote", "https://example.com:443", false},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
t.Setenv("OLLAMA_HOST", tt.host)
|
||||||
|
got := isLocalhost()
|
||||||
|
if got != tt.expected {
|
||||||
|
t.Errorf("isLocalhost() with OLLAMA_HOST=%q = %v, want %v", tt.host, got, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,192 +0,0 @@
|
|||||||
package config
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
"os"
|
|
||||||
"os/exec"
|
|
||||||
"path/filepath"
|
|
||||||
"runtime"
|
|
||||||
|
|
||||||
"github.com/ollama/ollama/api"
|
|
||||||
"github.com/ollama/ollama/envconfig"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Claude implements Runner and AliasConfigurer for Claude Code integration
|
|
||||||
type Claude struct{}
|
|
||||||
|
|
||||||
// Compile-time check that Claude implements AliasConfigurer
|
|
||||||
var _ AliasConfigurer = (*Claude)(nil)
|
|
||||||
|
|
||||||
func (c *Claude) String() string { return "Claude Code" }
|
|
||||||
|
|
||||||
func (c *Claude) args(model string, extra []string) []string {
|
|
||||||
var args []string
|
|
||||||
if model != "" {
|
|
||||||
args = append(args, "--model", model)
|
|
||||||
}
|
|
||||||
args = append(args, extra...)
|
|
||||||
return args
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Claude) findPath() (string, error) {
|
|
||||||
if p, err := exec.LookPath("claude"); err == nil {
|
|
||||||
return p, nil
|
|
||||||
}
|
|
||||||
home, err := os.UserHomeDir()
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
name := "claude"
|
|
||||||
if runtime.GOOS == "windows" {
|
|
||||||
name = "claude.exe"
|
|
||||||
}
|
|
||||||
fallback := filepath.Join(home, ".claude", "local", name)
|
|
||||||
if _, err := os.Stat(fallback); err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
return fallback, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Claude) Run(model string, args []string) error {
|
|
||||||
claudePath, err := c.findPath()
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("claude is not installed, install from https://code.claude.com/docs/en/quickstart")
|
|
||||||
}
|
|
||||||
|
|
||||||
cmd := exec.Command(claudePath, c.args(model, args)...)
|
|
||||||
cmd.Stdin = os.Stdin
|
|
||||||
cmd.Stdout = os.Stdout
|
|
||||||
cmd.Stderr = os.Stderr
|
|
||||||
|
|
||||||
env := append(os.Environ(),
|
|
||||||
"ANTHROPIC_BASE_URL="+envconfig.Host().String(),
|
|
||||||
"ANTHROPIC_API_KEY=",
|
|
||||||
"ANTHROPIC_AUTH_TOKEN=ollama",
|
|
||||||
)
|
|
||||||
|
|
||||||
env = append(env, c.modelEnvVars(model)...)
|
|
||||||
|
|
||||||
cmd.Env = env
|
|
||||||
return cmd.Run()
|
|
||||||
}
|
|
||||||
|
|
||||||
// modelEnvVars returns Claude Code env vars that route all model tiers through Ollama.
|
|
||||||
func (c *Claude) modelEnvVars(model string) []string {
|
|
||||||
primary := model
|
|
||||||
fast := model
|
|
||||||
if cfg, err := loadIntegration("claude"); err == nil && cfg.Aliases != nil {
|
|
||||||
if p := cfg.Aliases["primary"]; p != "" {
|
|
||||||
primary = p
|
|
||||||
}
|
|
||||||
if f := cfg.Aliases["fast"]; f != "" {
|
|
||||||
fast = f
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return []string{
|
|
||||||
"ANTHROPIC_DEFAULT_OPUS_MODEL=" + primary,
|
|
||||||
"ANTHROPIC_DEFAULT_SONNET_MODEL=" + primary,
|
|
||||||
"ANTHROPIC_DEFAULT_HAIKU_MODEL=" + fast,
|
|
||||||
"CLAUDE_CODE_SUBAGENT_MODEL=" + primary,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ConfigureAliases sets up model aliases for Claude Code.
|
|
||||||
// model: the model to use (if empty, user will be prompted to select)
|
|
||||||
// aliases: existing alias configuration to preserve/update
|
|
||||||
// Cloud-only: subagent routing (fast model) is gated to cloud models only until
|
|
||||||
// there is a better strategy for prompt caching on local models.
|
|
||||||
func (c *Claude) ConfigureAliases(ctx context.Context, model string, existingAliases map[string]string, force bool) (map[string]string, bool, error) {
|
|
||||||
aliases := make(map[string]string)
|
|
||||||
for k, v := range existingAliases {
|
|
||||||
aliases[k] = v
|
|
||||||
}
|
|
||||||
|
|
||||||
if model != "" {
|
|
||||||
aliases["primary"] = model
|
|
||||||
}
|
|
||||||
|
|
||||||
if !force && aliases["primary"] != "" {
|
|
||||||
client, _ := api.ClientFromEnvironment()
|
|
||||||
if isCloudModel(ctx, client, aliases["primary"]) {
|
|
||||||
if isCloudModel(ctx, client, aliases["fast"]) {
|
|
||||||
return aliases, false, nil
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
delete(aliases, "fast")
|
|
||||||
return aliases, false, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
items, existingModels, cloudModels, client, err := listModels(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return nil, false, err
|
|
||||||
}
|
|
||||||
|
|
||||||
fmt.Fprintf(os.Stderr, "\n%sModel Configuration%s\n\n", ansiBold, ansiReset)
|
|
||||||
|
|
||||||
if aliases["primary"] == "" || force {
|
|
||||||
primary, err := DefaultSingleSelector("Select model:", items, aliases["primary"])
|
|
||||||
if err != nil {
|
|
||||||
return nil, false, err
|
|
||||||
}
|
|
||||||
if err := pullIfNeeded(ctx, client, existingModels, primary); err != nil {
|
|
||||||
return nil, false, err
|
|
||||||
}
|
|
||||||
if err := ensureAuth(ctx, client, cloudModels, []string{primary}); err != nil {
|
|
||||||
return nil, false, err
|
|
||||||
}
|
|
||||||
aliases["primary"] = primary
|
|
||||||
}
|
|
||||||
|
|
||||||
if isCloudModel(ctx, client, aliases["primary"]) {
|
|
||||||
if aliases["fast"] == "" || !isCloudModel(ctx, client, aliases["fast"]) {
|
|
||||||
aliases["fast"] = aliases["primary"]
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
delete(aliases, "fast")
|
|
||||||
}
|
|
||||||
|
|
||||||
return aliases, true, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetAliases syncs the configured aliases to the Ollama server using prefix matching.
|
|
||||||
// Cloud-only: for local models (fast is empty), we delete any existing aliases to
|
|
||||||
// prevent stale routing to a previous cloud model.
|
|
||||||
func (c *Claude) SetAliases(ctx context.Context, aliases map[string]string) error {
|
|
||||||
client, err := api.ClientFromEnvironment()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
prefixes := []string{"claude-sonnet-", "claude-haiku-"}
|
|
||||||
|
|
||||||
if aliases["fast"] == "" {
|
|
||||||
for _, prefix := range prefixes {
|
|
||||||
_ = client.DeleteAliasExperimental(ctx, &api.AliasDeleteRequest{Alias: prefix})
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
prefixAliases := map[string]string{
|
|
||||||
"claude-sonnet-": aliases["primary"],
|
|
||||||
"claude-haiku-": aliases["fast"],
|
|
||||||
}
|
|
||||||
|
|
||||||
var errs []string
|
|
||||||
for prefix, target := range prefixAliases {
|
|
||||||
req := &api.AliasRequest{
|
|
||||||
Alias: prefix,
|
|
||||||
Target: target,
|
|
||||||
PrefixMatching: true,
|
|
||||||
}
|
|
||||||
if err := client.SetAliasExperimental(ctx, req); err != nil {
|
|
||||||
errs = append(errs, prefix)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(errs) > 0 {
|
|
||||||
return fmt.Errorf("failed to set aliases: %v", errs)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
@@ -3,7 +3,6 @@
|
|||||||
package config
|
package config
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
@@ -11,7 +10,7 @@ import (
|
|||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/cmd/internal/fileutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
type integration struct {
|
type integration struct {
|
||||||
@@ -20,6 +19,9 @@ type integration struct {
|
|||||||
Onboarded bool `json:"onboarded,omitempty"`
|
Onboarded bool `json:"onboarded,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// IntegrationConfig is the persisted config for one integration.
|
||||||
|
type IntegrationConfig = integration
|
||||||
|
|
||||||
type config struct {
|
type config struct {
|
||||||
Integrations map[string]*integration `json:"integrations"`
|
Integrations map[string]*integration `json:"integrations"`
|
||||||
LastModel string `json:"last_model,omitempty"`
|
LastModel string `json:"last_model,omitempty"`
|
||||||
@@ -124,7 +126,7 @@ func save(cfg *config) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return writeWithBackup(path, data)
|
return fileutil.WriteWithBackup(path, data)
|
||||||
}
|
}
|
||||||
|
|
||||||
func SaveIntegration(appName string, models []string) error {
|
func SaveIntegration(appName string, models []string) error {
|
||||||
@@ -155,8 +157,8 @@ func SaveIntegration(appName string, models []string) error {
|
|||||||
return save(cfg)
|
return save(cfg)
|
||||||
}
|
}
|
||||||
|
|
||||||
// integrationOnboarded marks an integration as onboarded in ollama's config.
|
// MarkIntegrationOnboarded marks an integration as onboarded in Ollama's config.
|
||||||
func integrationOnboarded(appName string) error {
|
func MarkIntegrationOnboarded(appName string) error {
|
||||||
cfg, err := load()
|
cfg, err := load()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -174,7 +176,7 @@ func integrationOnboarded(appName string) error {
|
|||||||
|
|
||||||
// IntegrationModel returns the first configured model for an integration, or empty string if not configured.
|
// IntegrationModel returns the first configured model for an integration, or empty string if not configured.
|
||||||
func IntegrationModel(appName string) string {
|
func IntegrationModel(appName string) string {
|
||||||
integrationConfig, err := loadIntegration(appName)
|
integrationConfig, err := LoadIntegration(appName)
|
||||||
if err != nil || len(integrationConfig.Models) == 0 {
|
if err != nil || len(integrationConfig.Models) == 0 {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
@@ -183,7 +185,7 @@ func IntegrationModel(appName string) string {
|
|||||||
|
|
||||||
// IntegrationModels returns all configured models for an integration, or nil.
|
// IntegrationModels returns all configured models for an integration, or nil.
|
||||||
func IntegrationModels(appName string) []string {
|
func IntegrationModels(appName string) []string {
|
||||||
integrationConfig, err := loadIntegration(appName)
|
integrationConfig, err := LoadIntegration(appName)
|
||||||
if err != nil || len(integrationConfig.Models) == 0 {
|
if err != nil || len(integrationConfig.Models) == 0 {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -228,28 +230,8 @@ func SetLastSelection(selection string) error {
|
|||||||
return save(cfg)
|
return save(cfg)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ModelExists checks if a model exists on the Ollama server.
|
// LoadIntegration returns the saved config for one integration.
|
||||||
func ModelExists(ctx context.Context, name string) bool {
|
func LoadIntegration(appName string) (*integration, error) {
|
||||||
if name == "" {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
client, err := api.ClientFromEnvironment()
|
|
||||||
if err != nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
models, err := client.List(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
for _, m := range models.Models {
|
|
||||||
if m.Name == name || strings.HasPrefix(m.Name, name+":") {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func loadIntegration(appName string) (*integration, error) {
|
|
||||||
cfg, err := load()
|
cfg, err := load()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -263,7 +245,8 @@ func loadIntegration(appName string) (*integration, error) {
|
|||||||
return integrationConfig, nil
|
return integrationConfig, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func saveAliases(appName string, aliases map[string]string) error {
|
// SaveAliases replaces the saved aliases for one integration.
|
||||||
|
func SaveAliases(appName string, aliases map[string]string) error {
|
||||||
if appName == "" {
|
if appName == "" {
|
||||||
return errors.New("app name cannot be empty")
|
return errors.New("app name cannot be empty")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
package config
|
package config
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"errors"
|
"errors"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
@@ -45,12 +44,12 @@ func TestSaveAliases_ReplacesNotMerges(t *testing.T) {
|
|||||||
"primary": "cloud-model",
|
"primary": "cloud-model",
|
||||||
"fast": "cloud-model",
|
"fast": "cloud-model",
|
||||||
}
|
}
|
||||||
if err := saveAliases("claude", initial); err != nil {
|
if err := SaveAliases("claude", initial); err != nil {
|
||||||
t.Fatalf("failed to save initial aliases: %v", err)
|
t.Fatalf("failed to save initial aliases: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Verify both are saved
|
// Verify both are saved
|
||||||
loaded, err := loadIntegration("claude")
|
loaded, err := LoadIntegration("claude")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to load: %v", err)
|
t.Fatalf("failed to load: %v", err)
|
||||||
}
|
}
|
||||||
@@ -63,12 +62,12 @@ func TestSaveAliases_ReplacesNotMerges(t *testing.T) {
|
|||||||
"primary": "local-model",
|
"primary": "local-model",
|
||||||
// fast intentionally missing
|
// fast intentionally missing
|
||||||
}
|
}
|
||||||
if err := saveAliases("claude", updated); err != nil {
|
if err := SaveAliases("claude", updated); err != nil {
|
||||||
t.Fatalf("failed to save updated aliases: %v", err)
|
t.Fatalf("failed to save updated aliases: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Verify fast is GONE (not merged/preserved)
|
// Verify fast is GONE (not merged/preserved)
|
||||||
loaded, err = loadIntegration("claude")
|
loaded, err = LoadIntegration("claude")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to load after update: %v", err)
|
t.Fatalf("failed to load after update: %v", err)
|
||||||
}
|
}
|
||||||
@@ -91,12 +90,12 @@ func TestSaveAliases_PreservesModels(t *testing.T) {
|
|||||||
|
|
||||||
// Then update aliases
|
// Then update aliases
|
||||||
aliases := map[string]string{"primary": "new-model"}
|
aliases := map[string]string{"primary": "new-model"}
|
||||||
if err := saveAliases("claude", aliases); err != nil {
|
if err := SaveAliases("claude", aliases); err != nil {
|
||||||
t.Fatalf("failed to save aliases: %v", err)
|
t.Fatalf("failed to save aliases: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Verify models are preserved
|
// Verify models are preserved
|
||||||
loaded, err := loadIntegration("claude")
|
loaded, err := LoadIntegration("claude")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to load: %v", err)
|
t.Fatalf("failed to load: %v", err)
|
||||||
}
|
}
|
||||||
@@ -111,16 +110,16 @@ func TestSaveAliases_EmptyMap(t *testing.T) {
|
|||||||
setTestHome(t, tmpDir)
|
setTestHome(t, tmpDir)
|
||||||
|
|
||||||
// Save with aliases
|
// Save with aliases
|
||||||
if err := saveAliases("claude", map[string]string{"primary": "model", "fast": "model"}); err != nil {
|
if err := SaveAliases("claude", map[string]string{"primary": "model", "fast": "model"}); err != nil {
|
||||||
t.Fatalf("failed to save: %v", err)
|
t.Fatalf("failed to save: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Save empty map
|
// Save empty map
|
||||||
if err := saveAliases("claude", map[string]string{}); err != nil {
|
if err := SaveAliases("claude", map[string]string{}); err != nil {
|
||||||
t.Fatalf("failed to save empty: %v", err)
|
t.Fatalf("failed to save empty: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
loaded, err := loadIntegration("claude")
|
loaded, err := LoadIntegration("claude")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to load: %v", err)
|
t.Fatalf("failed to load: %v", err)
|
||||||
}
|
}
|
||||||
@@ -135,16 +134,16 @@ func TestSaveAliases_NilMap(t *testing.T) {
|
|||||||
setTestHome(t, tmpDir)
|
setTestHome(t, tmpDir)
|
||||||
|
|
||||||
// Save with aliases first
|
// Save with aliases first
|
||||||
if err := saveAliases("claude", map[string]string{"primary": "model"}); err != nil {
|
if err := SaveAliases("claude", map[string]string{"primary": "model"}); err != nil {
|
||||||
t.Fatalf("failed to save: %v", err)
|
t.Fatalf("failed to save: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Save nil map - should clear aliases
|
// Save nil map - should clear aliases
|
||||||
if err := saveAliases("claude", nil); err != nil {
|
if err := SaveAliases("claude", nil); err != nil {
|
||||||
t.Fatalf("failed to save nil: %v", err)
|
t.Fatalf("failed to save nil: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
loaded, err := loadIntegration("claude")
|
loaded, err := LoadIntegration("claude")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to load: %v", err)
|
t.Fatalf("failed to load: %v", err)
|
||||||
}
|
}
|
||||||
@@ -155,7 +154,7 @@ func TestSaveAliases_NilMap(t *testing.T) {
|
|||||||
|
|
||||||
// TestSaveAliases_EmptyAppName returns error
|
// TestSaveAliases_EmptyAppName returns error
|
||||||
func TestSaveAliases_EmptyAppName(t *testing.T) {
|
func TestSaveAliases_EmptyAppName(t *testing.T) {
|
||||||
err := saveAliases("", map[string]string{"primary": "model"})
|
err := SaveAliases("", map[string]string{"primary": "model"})
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Error("expected error for empty app name")
|
t.Error("expected error for empty app name")
|
||||||
}
|
}
|
||||||
@@ -165,12 +164,12 @@ func TestSaveAliases_CaseInsensitive(t *testing.T) {
|
|||||||
tmpDir := t.TempDir()
|
tmpDir := t.TempDir()
|
||||||
setTestHome(t, tmpDir)
|
setTestHome(t, tmpDir)
|
||||||
|
|
||||||
if err := saveAliases("Claude", map[string]string{"primary": "model1"}); err != nil {
|
if err := SaveAliases("Claude", map[string]string{"primary": "model1"}); err != nil {
|
||||||
t.Fatalf("failed to save: %v", err)
|
t.Fatalf("failed to save: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Load with different case
|
// Load with different case
|
||||||
loaded, err := loadIntegration("claude")
|
loaded, err := LoadIntegration("claude")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to load: %v", err)
|
t.Fatalf("failed to load: %v", err)
|
||||||
}
|
}
|
||||||
@@ -179,11 +178,11 @@ func TestSaveAliases_CaseInsensitive(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Update with different case
|
// Update with different case
|
||||||
if err := saveAliases("CLAUDE", map[string]string{"primary": "model2"}); err != nil {
|
if err := SaveAliases("CLAUDE", map[string]string{"primary": "model2"}); err != nil {
|
||||||
t.Fatalf("failed to update: %v", err)
|
t.Fatalf("failed to update: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
loaded, err = loadIntegration("claude")
|
loaded, err = LoadIntegration("claude")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to load after update: %v", err)
|
t.Fatalf("failed to load after update: %v", err)
|
||||||
}
|
}
|
||||||
@@ -198,11 +197,11 @@ func TestSaveAliases_CreatesIntegration(t *testing.T) {
|
|||||||
setTestHome(t, tmpDir)
|
setTestHome(t, tmpDir)
|
||||||
|
|
||||||
// Save aliases for non-existent integration
|
// Save aliases for non-existent integration
|
||||||
if err := saveAliases("newintegration", map[string]string{"primary": "model"}); err != nil {
|
if err := SaveAliases("newintegration", map[string]string{"primary": "model"}); err != nil {
|
||||||
t.Fatalf("failed to save: %v", err)
|
t.Fatalf("failed to save: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
loaded, err := loadIntegration("newintegration")
|
loaded, err := LoadIntegration("newintegration")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to load: %v", err)
|
t.Fatalf("failed to load: %v", err)
|
||||||
}
|
}
|
||||||
@@ -371,12 +370,12 @@ func TestAtomicUpdate_ServerSucceedsConfigSaved(t *testing.T) {
|
|||||||
t.Fatal("server should succeed")
|
t.Fatal("server should succeed")
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := saveAliases("claude", map[string]string{"primary": "model"}); err != nil {
|
if err := SaveAliases("claude", map[string]string{"primary": "model"}); err != nil {
|
||||||
t.Fatalf("saveAliases failed: %v", err)
|
t.Fatalf("saveAliases failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Verify it was actually saved
|
// Verify it was actually saved
|
||||||
loaded, err := loadIntegration("claude")
|
loaded, err := LoadIntegration("claude")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to load: %v", err)
|
t.Fatalf("failed to load: %v", err)
|
||||||
}
|
}
|
||||||
@@ -408,7 +407,7 @@ func TestConfigFile_PreservesUnknownFields(t *testing.T) {
|
|||||||
os.WriteFile(configPath, []byte(initialConfig), 0o644)
|
os.WriteFile(configPath, []byte(initialConfig), 0o644)
|
||||||
|
|
||||||
// Update aliases
|
// Update aliases
|
||||||
if err := saveAliases("claude", map[string]string{"primary": "model2"}); err != nil {
|
if err := SaveAliases("claude", map[string]string{"primary": "model2"}); err != nil {
|
||||||
t.Fatalf("failed to save: %v", err)
|
t.Fatalf("failed to save: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -440,11 +439,6 @@ func containsHelper(s, substr string) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestClaudeImplementsAliasConfigurer(t *testing.T) {
|
|
||||||
c := &Claude{}
|
|
||||||
var _ AliasConfigurer = c // Compile-time check
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestModelNameEdgeCases(t *testing.T) {
|
func TestModelNameEdgeCases(t *testing.T) {
|
||||||
testCases := []struct {
|
testCases := []struct {
|
||||||
name string
|
name string
|
||||||
@@ -464,11 +458,11 @@ func TestModelNameEdgeCases(t *testing.T) {
|
|||||||
setTestHome(t, tmpDir)
|
setTestHome(t, tmpDir)
|
||||||
|
|
||||||
aliases := map[string]string{"primary": tc.model}
|
aliases := map[string]string{"primary": tc.model}
|
||||||
if err := saveAliases("claude", aliases); err != nil {
|
if err := SaveAliases("claude", aliases); err != nil {
|
||||||
t.Fatalf("failed to save model %q: %v", tc.model, err)
|
t.Fatalf("failed to save model %q: %v", tc.model, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
loaded, err := loadIntegration("claude")
|
loaded, err := LoadIntegration("claude")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to load: %v", err)
|
t.Fatalf("failed to load: %v", err)
|
||||||
}
|
}
|
||||||
@@ -485,7 +479,7 @@ func TestSwitchingScenarios(t *testing.T) {
|
|||||||
setTestHome(t, tmpDir)
|
setTestHome(t, tmpDir)
|
||||||
|
|
||||||
// Initial cloud config
|
// Initial cloud config
|
||||||
if err := saveAliases("claude", map[string]string{
|
if err := SaveAliases("claude", map[string]string{
|
||||||
"primary": "cloud-model",
|
"primary": "cloud-model",
|
||||||
"fast": "cloud-model",
|
"fast": "cloud-model",
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
@@ -493,13 +487,13 @@ func TestSwitchingScenarios(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Switch to local (no fast)
|
// Switch to local (no fast)
|
||||||
if err := saveAliases("claude", map[string]string{
|
if err := SaveAliases("claude", map[string]string{
|
||||||
"primary": "local-model",
|
"primary": "local-model",
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
loaded, _ := loadIntegration("claude")
|
loaded, _ := LoadIntegration("claude")
|
||||||
if loaded.Aliases["fast"] != "" {
|
if loaded.Aliases["fast"] != "" {
|
||||||
t.Errorf("fast should be removed, got %q", loaded.Aliases["fast"])
|
t.Errorf("fast should be removed, got %q", loaded.Aliases["fast"])
|
||||||
}
|
}
|
||||||
@@ -513,21 +507,21 @@ func TestSwitchingScenarios(t *testing.T) {
|
|||||||
setTestHome(t, tmpDir)
|
setTestHome(t, tmpDir)
|
||||||
|
|
||||||
// Initial local config
|
// Initial local config
|
||||||
if err := saveAliases("claude", map[string]string{
|
if err := SaveAliases("claude", map[string]string{
|
||||||
"primary": "local-model",
|
"primary": "local-model",
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Switch to cloud (with fast)
|
// Switch to cloud (with fast)
|
||||||
if err := saveAliases("claude", map[string]string{
|
if err := SaveAliases("claude", map[string]string{
|
||||||
"primary": "cloud-model",
|
"primary": "cloud-model",
|
||||||
"fast": "cloud-model",
|
"fast": "cloud-model",
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
loaded, _ := loadIntegration("claude")
|
loaded, _ := LoadIntegration("claude")
|
||||||
if loaded.Aliases["fast"] != "cloud-model" {
|
if loaded.Aliases["fast"] != "cloud-model" {
|
||||||
t.Errorf("fast should be cloud-model, got %q", loaded.Aliases["fast"])
|
t.Errorf("fast should be cloud-model, got %q", loaded.Aliases["fast"])
|
||||||
}
|
}
|
||||||
@@ -538,7 +532,7 @@ func TestSwitchingScenarios(t *testing.T) {
|
|||||||
setTestHome(t, tmpDir)
|
setTestHome(t, tmpDir)
|
||||||
|
|
||||||
// Initial cloud config
|
// Initial cloud config
|
||||||
if err := saveAliases("claude", map[string]string{
|
if err := SaveAliases("claude", map[string]string{
|
||||||
"primary": "cloud-model-1",
|
"primary": "cloud-model-1",
|
||||||
"fast": "cloud-model-1",
|
"fast": "cloud-model-1",
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
@@ -546,14 +540,14 @@ func TestSwitchingScenarios(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Switch to different cloud
|
// Switch to different cloud
|
||||||
if err := saveAliases("claude", map[string]string{
|
if err := SaveAliases("claude", map[string]string{
|
||||||
"primary": "cloud-model-2",
|
"primary": "cloud-model-2",
|
||||||
"fast": "cloud-model-2",
|
"fast": "cloud-model-2",
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
loaded, _ := loadIntegration("claude")
|
loaded, _ := LoadIntegration("claude")
|
||||||
if loaded.Aliases["primary"] != "cloud-model-2" {
|
if loaded.Aliases["primary"] != "cloud-model-2" {
|
||||||
t.Errorf("primary should be cloud-model-2, got %q", loaded.Aliases["primary"])
|
t.Errorf("primary should be cloud-model-2, got %q", loaded.Aliases["primary"])
|
||||||
}
|
}
|
||||||
@@ -563,43 +557,13 @@ func TestSwitchingScenarios(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestToolCapabilityFiltering(t *testing.T) {
|
|
||||||
t.Run("all models checked for tool capability", func(t *testing.T) {
|
|
||||||
// Both cloud and local models are checked for tool capability via Show API
|
|
||||||
// Only models with "tools" in capabilities are included
|
|
||||||
m := modelInfo{Name: "tool-model", Remote: false, ToolCapable: true}
|
|
||||||
if !m.ToolCapable {
|
|
||||||
t.Error("tool capable model should be marked as such")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("modelInfo includes ToolCapable field", func(t *testing.T) {
|
|
||||||
m := modelInfo{Name: "test", Remote: true, ToolCapable: true}
|
|
||||||
if !m.ToolCapable {
|
|
||||||
t.Error("ToolCapable field should be accessible")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestIsCloudModel_RequiresClient(t *testing.T) {
|
|
||||||
t.Run("nil client always returns false", func(t *testing.T) {
|
|
||||||
// isCloudModel now only uses Show API, no suffix detection
|
|
||||||
if isCloudModel(context.Background(), nil, "model:cloud") {
|
|
||||||
t.Error("nil client should return false regardless of suffix")
|
|
||||||
}
|
|
||||||
if isCloudModel(context.Background(), nil, "local-model") {
|
|
||||||
t.Error("nil client should return false")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestModelsAndAliasesMustStayInSync(t *testing.T) {
|
func TestModelsAndAliasesMustStayInSync(t *testing.T) {
|
||||||
t.Run("saveAliases followed by saveIntegration keeps them in sync", func(t *testing.T) {
|
t.Run("saveAliases followed by saveIntegration keeps them in sync", func(t *testing.T) {
|
||||||
tmpDir := t.TempDir()
|
tmpDir := t.TempDir()
|
||||||
setTestHome(t, tmpDir)
|
setTestHome(t, tmpDir)
|
||||||
|
|
||||||
// Save aliases with one model
|
// Save aliases with one model
|
||||||
if err := saveAliases("claude", map[string]string{"primary": "model-a"}); err != nil {
|
if err := SaveAliases("claude", map[string]string{"primary": "model-a"}); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -608,7 +572,7 @@ func TestModelsAndAliasesMustStayInSync(t *testing.T) {
|
|||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
loaded, _ := loadIntegration("claude")
|
loaded, _ := LoadIntegration("claude")
|
||||||
if loaded.Aliases["primary"] != loaded.Models[0] {
|
if loaded.Aliases["primary"] != loaded.Models[0] {
|
||||||
t.Errorf("aliases.primary (%q) != models[0] (%q)", loaded.Aliases["primary"], loaded.Models[0])
|
t.Errorf("aliases.primary (%q) != models[0] (%q)", loaded.Aliases["primary"], loaded.Models[0])
|
||||||
}
|
}
|
||||||
@@ -622,11 +586,11 @@ func TestModelsAndAliasesMustStayInSync(t *testing.T) {
|
|||||||
if err := SaveIntegration("claude", []string{"old-model"}); err != nil {
|
if err := SaveIntegration("claude", []string{"old-model"}); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
if err := saveAliases("claude", map[string]string{"primary": "new-model"}); err != nil {
|
if err := SaveAliases("claude", map[string]string{"primary": "new-model"}); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
loaded, _ := loadIntegration("claude")
|
loaded, _ := LoadIntegration("claude")
|
||||||
|
|
||||||
// They should be different (this is the bug state)
|
// They should be different (this is the bug state)
|
||||||
if loaded.Models[0] == loaded.Aliases["primary"] {
|
if loaded.Models[0] == loaded.Aliases["primary"] {
|
||||||
@@ -638,7 +602,7 @@ func TestModelsAndAliasesMustStayInSync(t *testing.T) {
|
|||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
loaded, _ = loadIntegration("claude")
|
loaded, _ = LoadIntegration("claude")
|
||||||
if loaded.Models[0] != loaded.Aliases["primary"] {
|
if loaded.Models[0] != loaded.Aliases["primary"] {
|
||||||
t.Errorf("after fix: models[0] (%q) should equal aliases.primary (%q)",
|
t.Errorf("after fix: models[0] (%q) should equal aliases.primary (%q)",
|
||||||
loaded.Models[0], loaded.Aliases["primary"])
|
loaded.Models[0], loaded.Aliases["primary"])
|
||||||
@@ -653,20 +617,20 @@ func TestModelsAndAliasesMustStayInSync(t *testing.T) {
|
|||||||
if err := SaveIntegration("claude", []string{"initial-model"}); err != nil {
|
if err := SaveIntegration("claude", []string{"initial-model"}); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
if err := saveAliases("claude", map[string]string{"primary": "initial-model"}); err != nil {
|
if err := SaveAliases("claude", map[string]string{"primary": "initial-model"}); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Update aliases AND models together
|
// Update aliases AND models together
|
||||||
newAliases := map[string]string{"primary": "updated-model"}
|
newAliases := map[string]string{"primary": "updated-model"}
|
||||||
if err := saveAliases("claude", newAliases); err != nil {
|
if err := SaveAliases("claude", newAliases); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
if err := SaveIntegration("claude", []string{newAliases["primary"]}); err != nil {
|
if err := SaveIntegration("claude", []string{newAliases["primary"]}); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
loaded, _ := loadIntegration("claude")
|
loaded, _ := LoadIntegration("claude")
|
||||||
if loaded.Models[0] != "updated-model" {
|
if loaded.Models[0] != "updated-model" {
|
||||||
t.Errorf("models[0] should be updated-model, got %q", loaded.Models[0])
|
t.Errorf("models[0] should be updated-model, got %q", loaded.Models[0])
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -10,17 +10,10 @@ import (
|
|||||||
// setTestHome sets both HOME (Unix) and USERPROFILE (Windows) for cross-platform tests
|
// setTestHome sets both HOME (Unix) and USERPROFILE (Windows) for cross-platform tests
|
||||||
func setTestHome(t *testing.T, dir string) {
|
func setTestHome(t *testing.T, dir string) {
|
||||||
t.Setenv("HOME", dir)
|
t.Setenv("HOME", dir)
|
||||||
|
t.Setenv("TMPDIR", dir)
|
||||||
t.Setenv("USERPROFILE", dir)
|
t.Setenv("USERPROFILE", dir)
|
||||||
}
|
}
|
||||||
|
|
||||||
// editorPaths is a test helper that safely calls Paths if the runner implements Editor
|
|
||||||
func editorPaths(r Runner) []string {
|
|
||||||
if editor, ok := r.(Editor); ok {
|
|
||||||
return editor.Paths()
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestIntegrationConfig(t *testing.T) {
|
func TestIntegrationConfig(t *testing.T) {
|
||||||
tmpDir := t.TempDir()
|
tmpDir := t.TempDir()
|
||||||
setTestHome(t, tmpDir)
|
setTestHome(t, tmpDir)
|
||||||
@@ -31,7 +24,7 @@ func TestIntegrationConfig(t *testing.T) {
|
|||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
config, err := loadIntegration("claude")
|
config, err := LoadIntegration("claude")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@@ -55,11 +48,11 @@ func TestIntegrationConfig(t *testing.T) {
|
|||||||
"primary": "llama3.2:70b",
|
"primary": "llama3.2:70b",
|
||||||
"fast": "llama3.2:8b",
|
"fast": "llama3.2:8b",
|
||||||
}
|
}
|
||||||
if err := saveAliases("claude", aliases); err != nil {
|
if err := SaveAliases("claude", aliases); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
config, err := loadIntegration("claude")
|
config, err := LoadIntegration("claude")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@@ -77,14 +70,14 @@ func TestIntegrationConfig(t *testing.T) {
|
|||||||
if err := SaveIntegration("claude", []string{"model-a"}); err != nil {
|
if err := SaveIntegration("claude", []string{"model-a"}); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
if err := saveAliases("claude", map[string]string{"primary": "model-a", "fast": "model-small"}); err != nil {
|
if err := SaveAliases("claude", map[string]string{"primary": "model-a", "fast": "model-small"}); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := SaveIntegration("claude", []string{"model-b"}); err != nil {
|
if err := SaveIntegration("claude", []string{"model-b"}); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
config, err := loadIntegration("claude")
|
config, err := LoadIntegration("claude")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@@ -96,7 +89,7 @@ func TestIntegrationConfig(t *testing.T) {
|
|||||||
t.Run("defaultModel returns first model", func(t *testing.T) {
|
t.Run("defaultModel returns first model", func(t *testing.T) {
|
||||||
SaveIntegration("codex", []string{"model-a", "model-b"})
|
SaveIntegration("codex", []string{"model-a", "model-b"})
|
||||||
|
|
||||||
config, _ := loadIntegration("codex")
|
config, _ := LoadIntegration("codex")
|
||||||
defaultModel := ""
|
defaultModel := ""
|
||||||
if len(config.Models) > 0 {
|
if len(config.Models) > 0 {
|
||||||
defaultModel = config.Models[0]
|
defaultModel = config.Models[0]
|
||||||
@@ -120,7 +113,7 @@ func TestIntegrationConfig(t *testing.T) {
|
|||||||
t.Run("app name is case-insensitive", func(t *testing.T) {
|
t.Run("app name is case-insensitive", func(t *testing.T) {
|
||||||
SaveIntegration("Claude", []string{"model-x"})
|
SaveIntegration("Claude", []string{"model-x"})
|
||||||
|
|
||||||
config, err := loadIntegration("claude")
|
config, err := LoadIntegration("claude")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@@ -137,8 +130,8 @@ func TestIntegrationConfig(t *testing.T) {
|
|||||||
SaveIntegration("app1", []string{"model-1"})
|
SaveIntegration("app1", []string{"model-1"})
|
||||||
SaveIntegration("app2", []string{"model-2"})
|
SaveIntegration("app2", []string{"model-2"})
|
||||||
|
|
||||||
config1, _ := loadIntegration("app1")
|
config1, _ := LoadIntegration("app1")
|
||||||
config2, _ := loadIntegration("app2")
|
config2, _ := LoadIntegration("app2")
|
||||||
|
|
||||||
defaultModel1 := ""
|
defaultModel1 := ""
|
||||||
if len(config1.Models) > 0 {
|
if len(config1.Models) > 0 {
|
||||||
@@ -185,64 +178,6 @@ func TestListIntegrations(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestEditorPaths(t *testing.T) {
|
|
||||||
tmpDir := t.TempDir()
|
|
||||||
setTestHome(t, tmpDir)
|
|
||||||
|
|
||||||
t.Run("returns empty for claude (no Editor)", func(t *testing.T) {
|
|
||||||
r := integrations["claude"]
|
|
||||||
paths := editorPaths(r)
|
|
||||||
if len(paths) != 0 {
|
|
||||||
t.Errorf("expected no paths for claude, got %v", paths)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("returns empty for codex (no Editor)", func(t *testing.T) {
|
|
||||||
r := integrations["codex"]
|
|
||||||
paths := editorPaths(r)
|
|
||||||
if len(paths) != 0 {
|
|
||||||
t.Errorf("expected no paths for codex, got %v", paths)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("returns empty for droid when no config exists", func(t *testing.T) {
|
|
||||||
r := integrations["droid"]
|
|
||||||
paths := editorPaths(r)
|
|
||||||
if len(paths) != 0 {
|
|
||||||
t.Errorf("expected no paths, got %v", paths)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("returns path for droid when config exists", func(t *testing.T) {
|
|
||||||
settingsDir, _ := os.UserHomeDir()
|
|
||||||
settingsDir = filepath.Join(settingsDir, ".factory")
|
|
||||||
os.MkdirAll(settingsDir, 0o755)
|
|
||||||
os.WriteFile(filepath.Join(settingsDir, "settings.json"), []byte(`{}`), 0o644)
|
|
||||||
|
|
||||||
r := integrations["droid"]
|
|
||||||
paths := editorPaths(r)
|
|
||||||
if len(paths) != 1 {
|
|
||||||
t.Errorf("expected 1 path, got %d", len(paths))
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("returns paths for opencode when configs exist", func(t *testing.T) {
|
|
||||||
home, _ := os.UserHomeDir()
|
|
||||||
configDir := filepath.Join(home, ".config", "opencode")
|
|
||||||
stateDir := filepath.Join(home, ".local", "state", "opencode")
|
|
||||||
os.MkdirAll(configDir, 0o755)
|
|
||||||
os.MkdirAll(stateDir, 0o755)
|
|
||||||
os.WriteFile(filepath.Join(configDir, "opencode.json"), []byte(`{}`), 0o644)
|
|
||||||
os.WriteFile(filepath.Join(stateDir, "model.json"), []byte(`{}`), 0o644)
|
|
||||||
|
|
||||||
r := integrations["opencode"]
|
|
||||||
paths := editorPaths(r)
|
|
||||||
if len(paths) != 2 {
|
|
||||||
t.Errorf("expected 2 paths, got %d: %v", len(paths), paths)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestLoadIntegration_CorruptedJSON(t *testing.T) {
|
func TestLoadIntegration_CorruptedJSON(t *testing.T) {
|
||||||
tmpDir := t.TempDir()
|
tmpDir := t.TempDir()
|
||||||
setTestHome(t, tmpDir)
|
setTestHome(t, tmpDir)
|
||||||
@@ -251,7 +186,7 @@ func TestLoadIntegration_CorruptedJSON(t *testing.T) {
|
|||||||
os.MkdirAll(dir, 0o755)
|
os.MkdirAll(dir, 0o755)
|
||||||
os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{corrupted json`), 0o644)
|
os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{corrupted json`), 0o644)
|
||||||
|
|
||||||
_, err := loadIntegration("test")
|
_, err := LoadIntegration("test")
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Error("expected error for nonexistent integration in corrupted file")
|
t.Error("expected error for nonexistent integration in corrupted file")
|
||||||
}
|
}
|
||||||
@@ -265,7 +200,7 @@ func TestSaveIntegration_NilModels(t *testing.T) {
|
|||||||
t.Fatalf("saveIntegration with nil models failed: %v", err)
|
t.Fatalf("saveIntegration with nil models failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
config, err := loadIntegration("test")
|
config, err := LoadIntegration("test")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("loadIntegration failed: %v", err)
|
t.Fatalf("loadIntegration failed: %v", err)
|
||||||
}
|
}
|
||||||
@@ -294,7 +229,7 @@ func TestLoadIntegration_NonexistentIntegration(t *testing.T) {
|
|||||||
tmpDir := t.TempDir()
|
tmpDir := t.TempDir()
|
||||||
setTestHome(t, tmpDir)
|
setTestHome(t, tmpDir)
|
||||||
|
|
||||||
_, err := loadIntegration("nonexistent")
|
_, err := LoadIntegration("nonexistent")
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Error("expected error for nonexistent integration, got nil")
|
t.Error("expected error for nonexistent integration, got nil")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,59 +0,0 @@
|
|||||||
package config
|
|
||||||
|
|
||||||
import (
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"os"
|
|
||||||
|
|
||||||
"golang.org/x/term"
|
|
||||||
)
|
|
||||||
|
|
||||||
// ANSI escape sequences for terminal formatting.
|
|
||||||
const (
|
|
||||||
ansiBold = "\033[1m"
|
|
||||||
ansiReset = "\033[0m"
|
|
||||||
ansiGray = "\033[37m"
|
|
||||||
ansiGreen = "\033[32m"
|
|
||||||
ansiYellow = "\033[33m"
|
|
||||||
)
|
|
||||||
|
|
||||||
// ErrCancelled is returned when the user cancels a selection.
|
|
||||||
var ErrCancelled = errors.New("cancelled")
|
|
||||||
|
|
||||||
// errCancelled is kept as an alias for backward compatibility within the package.
|
|
||||||
var errCancelled = ErrCancelled
|
|
||||||
|
|
||||||
// DefaultConfirmPrompt provides a TUI-based confirmation prompt.
|
|
||||||
// When set, confirmPrompt delegates to it instead of using raw terminal I/O.
|
|
||||||
var DefaultConfirmPrompt func(prompt string) (bool, error)
|
|
||||||
|
|
||||||
func confirmPrompt(prompt string) (bool, error) {
|
|
||||||
if DefaultConfirmPrompt != nil {
|
|
||||||
return DefaultConfirmPrompt(prompt)
|
|
||||||
}
|
|
||||||
|
|
||||||
fd := int(os.Stdin.Fd())
|
|
||||||
oldState, err := term.MakeRaw(fd)
|
|
||||||
if err != nil {
|
|
||||||
return false, err
|
|
||||||
}
|
|
||||||
defer term.Restore(fd, oldState)
|
|
||||||
|
|
||||||
fmt.Fprintf(os.Stderr, "%s (\033[1my\033[0m/n) ", prompt)
|
|
||||||
|
|
||||||
buf := make([]byte, 1)
|
|
||||||
for {
|
|
||||||
if _, err := os.Stdin.Read(buf); err != nil {
|
|
||||||
return false, err
|
|
||||||
}
|
|
||||||
|
|
||||||
switch buf[0] {
|
|
||||||
case 'Y', 'y', 13:
|
|
||||||
fmt.Fprintf(os.Stderr, "yes\r\n")
|
|
||||||
return true, nil
|
|
||||||
case 'N', 'n', 27, 3:
|
|
||||||
fmt.Fprintf(os.Stderr, "no\r\n")
|
|
||||||
return false, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,19 +0,0 @@
|
|||||||
package config
|
|
||||||
|
|
||||||
import (
|
|
||||||
"testing"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestErrCancelled(t *testing.T) {
|
|
||||||
t.Run("NotNil", func(t *testing.T) {
|
|
||||||
if errCancelled == nil {
|
|
||||||
t.Error("errCancelled should not be nil")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("Message", func(t *testing.T) {
|
|
||||||
if errCancelled.Error() != "cancelled" {
|
|
||||||
t.Errorf("expected 'cancelled', got %q", errCancelled.Error())
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
@@ -17,6 +17,7 @@ import (
|
|||||||
|
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
"github.com/ollama/ollama/envconfig"
|
"github.com/ollama/ollama/envconfig"
|
||||||
|
"github.com/ollama/ollama/internal/modelref"
|
||||||
"github.com/ollama/ollama/readline"
|
"github.com/ollama/ollama/readline"
|
||||||
"github.com/ollama/ollama/types/errtypes"
|
"github.com/ollama/ollama/types/errtypes"
|
||||||
"github.com/ollama/ollama/types/model"
|
"github.com/ollama/ollama/types/model"
|
||||||
@@ -540,6 +541,13 @@ func NewCreateRequest(name string, opts runOptions) *api.CreateRequest {
|
|||||||
parentModel = ""
|
parentModel = ""
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Preserve explicit cloud intent for sessions started with `:cloud`.
|
||||||
|
// Cloud model metadata can return a source-less parent_model (for example
|
||||||
|
// "qwen3.5"), which would otherwise make `/save` create a local derivative.
|
||||||
|
if modelref.HasExplicitCloudSource(opts.Model) && !modelref.HasExplicitCloudSource(parentModel) {
|
||||||
|
parentModel = ""
|
||||||
|
}
|
||||||
|
|
||||||
req := &api.CreateRequest{
|
req := &api.CreateRequest{
|
||||||
Model: name,
|
Model: name,
|
||||||
From: cmp.Or(parentModel, opts.Model),
|
From: cmp.Or(parentModel, opts.Model),
|
||||||
|
|||||||
@@ -1,4 +1,6 @@
|
|||||||
package config
|
// Package fileutil provides small shared helpers for reading JSON files
|
||||||
|
// and writing config files with backup-on-overwrite semantics.
|
||||||
|
package fileutil
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
@@ -9,7 +11,8 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
func readJSONFile(path string) (map[string]any, error) {
|
// ReadJSON reads a JSON object file into a generic map.
|
||||||
|
func ReadJSON(path string) (map[string]any, error) {
|
||||||
data, err := os.ReadFile(path)
|
data, err := os.ReadFile(path)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -33,12 +36,13 @@ func copyFile(src, dst string) error {
|
|||||||
return os.WriteFile(dst, data, info.Mode().Perm())
|
return os.WriteFile(dst, data, info.Mode().Perm())
|
||||||
}
|
}
|
||||||
|
|
||||||
func backupDir() string {
|
// BackupDir returns the shared backup directory used before overwriting files.
|
||||||
|
func BackupDir() string {
|
||||||
return filepath.Join(os.TempDir(), "ollama-backups")
|
return filepath.Join(os.TempDir(), "ollama-backups")
|
||||||
}
|
}
|
||||||
|
|
||||||
func backupToTmp(srcPath string) (string, error) {
|
func backupToTmp(srcPath string) (string, error) {
|
||||||
dir := backupDir()
|
dir := BackupDir()
|
||||||
if err := os.MkdirAll(dir, 0o755); err != nil {
|
if err := os.MkdirAll(dir, 0o755); err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
@@ -50,8 +54,8 @@ func backupToTmp(srcPath string) (string, error) {
|
|||||||
return backupPath, nil
|
return backupPath, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// writeWithBackup writes data to path via temp file + rename, backing up any existing file first
|
// WriteWithBackup writes data to path via temp file + rename, backing up any existing file first.
|
||||||
func writeWithBackup(path string, data []byte) error {
|
func WriteWithBackup(path string, data []byte) error {
|
||||||
var backupPath string
|
var backupPath string
|
||||||
// backup must be created before any writes to the target file
|
// backup must be created before any writes to the target file
|
||||||
if existingContent, err := os.ReadFile(path); err == nil {
|
if existingContent, err := os.ReadFile(path); err == nil {
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package config
|
package fileutil
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
@@ -9,6 +9,21 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func TestMain(m *testing.M) {
|
||||||
|
tmpRoot, err := os.MkdirTemp("", "fileutil-test-*")
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := os.Setenv("TMPDIR", tmpRoot); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
code := m.Run()
|
||||||
|
_ = os.RemoveAll(tmpRoot)
|
||||||
|
os.Exit(code)
|
||||||
|
}
|
||||||
|
|
||||||
func mustMarshal(t *testing.T, v any) []byte {
|
func mustMarshal(t *testing.T, v any) []byte {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
data, err := json.MarshalIndent(v, "", " ")
|
data, err := json.MarshalIndent(v, "", " ")
|
||||||
@@ -18,14 +33,19 @@ func mustMarshal(t *testing.T, v any) []byte {
|
|||||||
return data
|
return data
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func isolatedTempDir(t *testing.T) string {
|
||||||
|
t.Helper()
|
||||||
|
return t.TempDir()
|
||||||
|
}
|
||||||
|
|
||||||
func TestWriteWithBackup(t *testing.T) {
|
func TestWriteWithBackup(t *testing.T) {
|
||||||
tmpDir := t.TempDir()
|
tmpDir := isolatedTempDir(t)
|
||||||
|
|
||||||
t.Run("creates file", func(t *testing.T) {
|
t.Run("creates file", func(t *testing.T) {
|
||||||
path := filepath.Join(tmpDir, "new.json")
|
path := filepath.Join(tmpDir, "new.json")
|
||||||
data := mustMarshal(t, map[string]string{"key": "value"})
|
data := mustMarshal(t, map[string]string{"key": "value"})
|
||||||
|
|
||||||
if err := writeWithBackup(path, data); err != nil {
|
if err := WriteWithBackup(path, data); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -43,17 +63,17 @@ func TestWriteWithBackup(t *testing.T) {
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("creates backup in /tmp/ollama-backups", func(t *testing.T) {
|
t.Run("creates backup in the temp backup directory", func(t *testing.T) {
|
||||||
path := filepath.Join(tmpDir, "backup.json")
|
path := filepath.Join(tmpDir, "backup.json")
|
||||||
|
|
||||||
os.WriteFile(path, []byte(`{"original": true}`), 0o644)
|
os.WriteFile(path, []byte(`{"original": true}`), 0o644)
|
||||||
|
|
||||||
data := mustMarshal(t, map[string]bool{"updated": true})
|
data := mustMarshal(t, map[string]bool{"updated": true})
|
||||||
if err := writeWithBackup(path, data); err != nil {
|
if err := WriteWithBackup(path, data); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
entries, err := os.ReadDir(backupDir())
|
entries, err := os.ReadDir(BackupDir())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal("backup directory not created")
|
t.Fatal("backup directory not created")
|
||||||
}
|
}
|
||||||
@@ -63,7 +83,7 @@ func TestWriteWithBackup(t *testing.T) {
|
|||||||
if filepath.Ext(entry.Name()) != ".json" {
|
if filepath.Ext(entry.Name()) != ".json" {
|
||||||
name := entry.Name()
|
name := entry.Name()
|
||||||
if len(name) > len("backup.json.") && name[:len("backup.json.")] == "backup.json." {
|
if len(name) > len("backup.json.") && name[:len("backup.json.")] == "backup.json." {
|
||||||
backupPath := filepath.Join(backupDir(), name)
|
backupPath := filepath.Join(BackupDir(), name)
|
||||||
backup, err := os.ReadFile(backupPath)
|
backup, err := os.ReadFile(backupPath)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
var backupData map[string]bool
|
var backupData map[string]bool
|
||||||
@@ -79,7 +99,7 @@ func TestWriteWithBackup(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if !foundBackup {
|
if !foundBackup {
|
||||||
t.Error("backup file not created in /tmp/ollama-backups")
|
t.Error("backup file not created in backup directory")
|
||||||
}
|
}
|
||||||
|
|
||||||
current, _ := os.ReadFile(path)
|
current, _ := os.ReadFile(path)
|
||||||
@@ -94,11 +114,11 @@ func TestWriteWithBackup(t *testing.T) {
|
|||||||
path := filepath.Join(tmpDir, "nobak.json")
|
path := filepath.Join(tmpDir, "nobak.json")
|
||||||
|
|
||||||
data := mustMarshal(t, map[string]string{"new": "file"})
|
data := mustMarshal(t, map[string]string{"new": "file"})
|
||||||
if err := writeWithBackup(path, data); err != nil {
|
if err := WriteWithBackup(path, data); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
entries, _ := os.ReadDir(backupDir())
|
entries, _ := os.ReadDir(BackupDir())
|
||||||
for _, entry := range entries {
|
for _, entry := range entries {
|
||||||
if len(entry.Name()) > len("nobak.json.") && entry.Name()[:len("nobak.json.")] == "nobak.json." {
|
if len(entry.Name()) > len("nobak.json.") && entry.Name()[:len("nobak.json.")] == "nobak.json." {
|
||||||
t.Error("backup should not exist for new file")
|
t.Error("backup should not exist for new file")
|
||||||
@@ -111,11 +131,11 @@ func TestWriteWithBackup(t *testing.T) {
|
|||||||
|
|
||||||
data := mustMarshal(t, map[string]string{"key": "value"})
|
data := mustMarshal(t, map[string]string{"key": "value"})
|
||||||
|
|
||||||
if err := writeWithBackup(path, data); err != nil {
|
if err := WriteWithBackup(path, data); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
entries1, _ := os.ReadDir(backupDir())
|
entries1, _ := os.ReadDir(BackupDir())
|
||||||
countBefore := 0
|
countBefore := 0
|
||||||
for _, e := range entries1 {
|
for _, e := range entries1 {
|
||||||
if len(e.Name()) > len("unchanged.json.") && e.Name()[:len("unchanged.json.")] == "unchanged.json." {
|
if len(e.Name()) > len("unchanged.json.") && e.Name()[:len("unchanged.json.")] == "unchanged.json." {
|
||||||
@@ -123,11 +143,11 @@ func TestWriteWithBackup(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := writeWithBackup(path, data); err != nil {
|
if err := WriteWithBackup(path, data); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
entries2, _ := os.ReadDir(backupDir())
|
entries2, _ := os.ReadDir(BackupDir())
|
||||||
countAfter := 0
|
countAfter := 0
|
||||||
for _, e := range entries2 {
|
for _, e := range entries2 {
|
||||||
if len(e.Name()) > len("unchanged.json.") && e.Name()[:len("unchanged.json.")] == "unchanged.json." {
|
if len(e.Name()) > len("unchanged.json.") && e.Name()[:len("unchanged.json.")] == "unchanged.json." {
|
||||||
@@ -145,11 +165,11 @@ func TestWriteWithBackup(t *testing.T) {
|
|||||||
|
|
||||||
os.WriteFile(path, []byte(`{"v": 1}`), 0o644)
|
os.WriteFile(path, []byte(`{"v": 1}`), 0o644)
|
||||||
data := mustMarshal(t, map[string]int{"v": 2})
|
data := mustMarshal(t, map[string]int{"v": 2})
|
||||||
if err := writeWithBackup(path, data); err != nil {
|
if err := WriteWithBackup(path, data); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
entries, _ := os.ReadDir(backupDir())
|
entries, _ := os.ReadDir(BackupDir())
|
||||||
var found bool
|
var found bool
|
||||||
for _, entry := range entries {
|
for _, entry := range entries {
|
||||||
name := entry.Name()
|
name := entry.Name()
|
||||||
@@ -161,7 +181,7 @@ func TestWriteWithBackup(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
found = true
|
found = true
|
||||||
os.Remove(filepath.Join(backupDir(), name))
|
os.Remove(filepath.Join(BackupDir(), name))
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -180,7 +200,7 @@ func TestWriteWithBackup_FailsIfBackupFails(t *testing.T) {
|
|||||||
t.Skip("permission tests unreliable on Windows")
|
t.Skip("permission tests unreliable on Windows")
|
||||||
}
|
}
|
||||||
|
|
||||||
tmpDir := t.TempDir()
|
tmpDir := isolatedTempDir(t)
|
||||||
path := filepath.Join(tmpDir, "config.json")
|
path := filepath.Join(tmpDir, "config.json")
|
||||||
|
|
||||||
// Create original file
|
// Create original file
|
||||||
@@ -188,13 +208,13 @@ func TestWriteWithBackup_FailsIfBackupFails(t *testing.T) {
|
|||||||
os.WriteFile(path, originalContent, 0o644)
|
os.WriteFile(path, originalContent, 0o644)
|
||||||
|
|
||||||
// Make backup directory read-only to force backup failure
|
// Make backup directory read-only to force backup failure
|
||||||
backupDir := backupDir()
|
backupDir := BackupDir()
|
||||||
os.MkdirAll(backupDir, 0o755)
|
os.MkdirAll(backupDir, 0o755)
|
||||||
os.Chmod(backupDir, 0o444) // Read-only
|
os.Chmod(backupDir, 0o444) // Read-only
|
||||||
defer os.Chmod(backupDir, 0o755)
|
defer os.Chmod(backupDir, 0o755)
|
||||||
|
|
||||||
newContent := []byte(`{"updated": true}`)
|
newContent := []byte(`{"updated": true}`)
|
||||||
err := writeWithBackup(path, newContent)
|
err := WriteWithBackup(path, newContent)
|
||||||
|
|
||||||
// Should fail because backup couldn't be created
|
// Should fail because backup couldn't be created
|
||||||
if err == nil {
|
if err == nil {
|
||||||
@@ -215,7 +235,7 @@ func TestWriteWithBackup_PermissionDenied(t *testing.T) {
|
|||||||
t.Skip("permission tests unreliable on Windows")
|
t.Skip("permission tests unreliable on Windows")
|
||||||
}
|
}
|
||||||
|
|
||||||
tmpDir := t.TempDir()
|
tmpDir := isolatedTempDir(t)
|
||||||
|
|
||||||
// Create a read-only directory
|
// Create a read-only directory
|
||||||
readOnlyDir := filepath.Join(tmpDir, "readonly")
|
readOnlyDir := filepath.Join(tmpDir, "readonly")
|
||||||
@@ -224,7 +244,7 @@ func TestWriteWithBackup_PermissionDenied(t *testing.T) {
|
|||||||
defer os.Chmod(readOnlyDir, 0o755)
|
defer os.Chmod(readOnlyDir, 0o755)
|
||||||
|
|
||||||
path := filepath.Join(readOnlyDir, "config.json")
|
path := filepath.Join(readOnlyDir, "config.json")
|
||||||
err := writeWithBackup(path, []byte(`{"test": true}`))
|
err := WriteWithBackup(path, []byte(`{"test": true}`))
|
||||||
|
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Error("expected permission error, got nil")
|
t.Error("expected permission error, got nil")
|
||||||
@@ -234,10 +254,10 @@ func TestWriteWithBackup_PermissionDenied(t *testing.T) {
|
|||||||
// TestWriteWithBackup_DirectoryDoesNotExist verifies behavior when target directory doesn't exist.
|
// TestWriteWithBackup_DirectoryDoesNotExist verifies behavior when target directory doesn't exist.
|
||||||
// writeWithBackup doesn't create directories - caller is responsible.
|
// writeWithBackup doesn't create directories - caller is responsible.
|
||||||
func TestWriteWithBackup_DirectoryDoesNotExist(t *testing.T) {
|
func TestWriteWithBackup_DirectoryDoesNotExist(t *testing.T) {
|
||||||
tmpDir := t.TempDir()
|
tmpDir := isolatedTempDir(t)
|
||||||
path := filepath.Join(tmpDir, "nonexistent", "subdir", "config.json")
|
path := filepath.Join(tmpDir, "nonexistent", "subdir", "config.json")
|
||||||
|
|
||||||
err := writeWithBackup(path, []byte(`{"test": true}`))
|
err := WriteWithBackup(path, []byte(`{"test": true}`))
|
||||||
|
|
||||||
// Should fail because directory doesn't exist
|
// Should fail because directory doesn't exist
|
||||||
if err == nil {
|
if err == nil {
|
||||||
@@ -252,7 +272,7 @@ func TestWriteWithBackup_SymlinkTarget(t *testing.T) {
|
|||||||
t.Skip("symlink tests may require admin on Windows")
|
t.Skip("symlink tests may require admin on Windows")
|
||||||
}
|
}
|
||||||
|
|
||||||
tmpDir := t.TempDir()
|
tmpDir := isolatedTempDir(t)
|
||||||
realFile := filepath.Join(tmpDir, "real.json")
|
realFile := filepath.Join(tmpDir, "real.json")
|
||||||
symlink := filepath.Join(tmpDir, "link.json")
|
symlink := filepath.Join(tmpDir, "link.json")
|
||||||
|
|
||||||
@@ -261,7 +281,7 @@ func TestWriteWithBackup_SymlinkTarget(t *testing.T) {
|
|||||||
os.Symlink(realFile, symlink)
|
os.Symlink(realFile, symlink)
|
||||||
|
|
||||||
// Write through symlink
|
// Write through symlink
|
||||||
err := writeWithBackup(symlink, []byte(`{"v": 2}`))
|
err := WriteWithBackup(symlink, []byte(`{"v": 2}`))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("writeWithBackup through symlink failed: %v", err)
|
t.Fatalf("writeWithBackup through symlink failed: %v", err)
|
||||||
}
|
}
|
||||||
@@ -276,7 +296,7 @@ func TestWriteWithBackup_SymlinkTarget(t *testing.T) {
|
|||||||
// TestBackupToTmp_SpecialCharsInFilename verifies backup works with special characters.
|
// TestBackupToTmp_SpecialCharsInFilename verifies backup works with special characters.
|
||||||
// User may have config files with unusual names.
|
// User may have config files with unusual names.
|
||||||
func TestBackupToTmp_SpecialCharsInFilename(t *testing.T) {
|
func TestBackupToTmp_SpecialCharsInFilename(t *testing.T) {
|
||||||
tmpDir := t.TempDir()
|
tmpDir := isolatedTempDir(t)
|
||||||
|
|
||||||
// File with spaces and special chars
|
// File with spaces and special chars
|
||||||
path := filepath.Join(tmpDir, "my config (backup).json")
|
path := filepath.Join(tmpDir, "my config (backup).json")
|
||||||
@@ -305,7 +325,7 @@ func TestCopyFile_PreservesPermissions(t *testing.T) {
|
|||||||
t.Skip("permission preservation tests unreliable on Windows")
|
t.Skip("permission preservation tests unreliable on Windows")
|
||||||
}
|
}
|
||||||
|
|
||||||
tmpDir := t.TempDir()
|
tmpDir := isolatedTempDir(t)
|
||||||
src := filepath.Join(tmpDir, "src.json")
|
src := filepath.Join(tmpDir, "src.json")
|
||||||
dst := filepath.Join(tmpDir, "dst.json")
|
dst := filepath.Join(tmpDir, "dst.json")
|
||||||
|
|
||||||
@@ -327,7 +347,7 @@ func TestCopyFile_PreservesPermissions(t *testing.T) {
|
|||||||
|
|
||||||
// TestCopyFile_SourceNotFound verifies clear error when source doesn't exist.
|
// TestCopyFile_SourceNotFound verifies clear error when source doesn't exist.
|
||||||
func TestCopyFile_SourceNotFound(t *testing.T) {
|
func TestCopyFile_SourceNotFound(t *testing.T) {
|
||||||
tmpDir := t.TempDir()
|
tmpDir := isolatedTempDir(t)
|
||||||
src := filepath.Join(tmpDir, "nonexistent.json")
|
src := filepath.Join(tmpDir, "nonexistent.json")
|
||||||
dst := filepath.Join(tmpDir, "dst.json")
|
dst := filepath.Join(tmpDir, "dst.json")
|
||||||
|
|
||||||
@@ -339,11 +359,11 @@ func TestCopyFile_SourceNotFound(t *testing.T) {
|
|||||||
|
|
||||||
// TestWriteWithBackup_TargetIsDirectory verifies error when path points to a directory.
|
// TestWriteWithBackup_TargetIsDirectory verifies error when path points to a directory.
|
||||||
func TestWriteWithBackup_TargetIsDirectory(t *testing.T) {
|
func TestWriteWithBackup_TargetIsDirectory(t *testing.T) {
|
||||||
tmpDir := t.TempDir()
|
tmpDir := isolatedTempDir(t)
|
||||||
dirPath := filepath.Join(tmpDir, "actualdir")
|
dirPath := filepath.Join(tmpDir, "actualdir")
|
||||||
os.MkdirAll(dirPath, 0o755)
|
os.MkdirAll(dirPath, 0o755)
|
||||||
|
|
||||||
err := writeWithBackup(dirPath, []byte(`{"test": true}`))
|
err := WriteWithBackup(dirPath, []byte(`{"test": true}`))
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Error("expected error when target is a directory, got nil")
|
t.Error("expected error when target is a directory, got nil")
|
||||||
}
|
}
|
||||||
@@ -351,10 +371,10 @@ func TestWriteWithBackup_TargetIsDirectory(t *testing.T) {
|
|||||||
|
|
||||||
// TestWriteWithBackup_EmptyData verifies writing zero bytes works correctly.
|
// TestWriteWithBackup_EmptyData verifies writing zero bytes works correctly.
|
||||||
func TestWriteWithBackup_EmptyData(t *testing.T) {
|
func TestWriteWithBackup_EmptyData(t *testing.T) {
|
||||||
tmpDir := t.TempDir()
|
tmpDir := isolatedTempDir(t)
|
||||||
path := filepath.Join(tmpDir, "empty.json")
|
path := filepath.Join(tmpDir, "empty.json")
|
||||||
|
|
||||||
err := writeWithBackup(path, []byte{})
|
err := WriteWithBackup(path, []byte{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("writeWithBackup with empty data failed: %v", err)
|
t.Fatalf("writeWithBackup with empty data failed: %v", err)
|
||||||
}
|
}
|
||||||
@@ -375,7 +395,7 @@ func TestWriteWithBackup_FileUnreadableButDirWritable(t *testing.T) {
|
|||||||
t.Skip("permission tests unreliable on Windows")
|
t.Skip("permission tests unreliable on Windows")
|
||||||
}
|
}
|
||||||
|
|
||||||
tmpDir := t.TempDir()
|
tmpDir := isolatedTempDir(t)
|
||||||
path := filepath.Join(tmpDir, "unreadable.json")
|
path := filepath.Join(tmpDir, "unreadable.json")
|
||||||
|
|
||||||
// Create file and make it unreadable
|
// Create file and make it unreadable
|
||||||
@@ -384,7 +404,7 @@ func TestWriteWithBackup_FileUnreadableButDirWritable(t *testing.T) {
|
|||||||
defer os.Chmod(path, 0o644)
|
defer os.Chmod(path, 0o644)
|
||||||
|
|
||||||
// Should fail because we can't read the file to compare/backup
|
// Should fail because we can't read the file to compare/backup
|
||||||
err := writeWithBackup(path, []byte(`{"updated": true}`))
|
err := WriteWithBackup(path, []byte(`{"updated": true}`))
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Error("expected error when file is unreadable, got nil")
|
t.Error("expected error when file is unreadable, got nil")
|
||||||
}
|
}
|
||||||
@@ -393,7 +413,7 @@ func TestWriteWithBackup_FileUnreadableButDirWritable(t *testing.T) {
|
|||||||
// TestWriteWithBackup_RapidSuccessiveWrites verifies backup works with multiple writes
|
// TestWriteWithBackup_RapidSuccessiveWrites verifies backup works with multiple writes
|
||||||
// within the same second (timestamp collision scenario).
|
// within the same second (timestamp collision scenario).
|
||||||
func TestWriteWithBackup_RapidSuccessiveWrites(t *testing.T) {
|
func TestWriteWithBackup_RapidSuccessiveWrites(t *testing.T) {
|
||||||
tmpDir := t.TempDir()
|
tmpDir := isolatedTempDir(t)
|
||||||
path := filepath.Join(tmpDir, "rapid.json")
|
path := filepath.Join(tmpDir, "rapid.json")
|
||||||
|
|
||||||
// Create initial file
|
// Create initial file
|
||||||
@@ -402,7 +422,7 @@ func TestWriteWithBackup_RapidSuccessiveWrites(t *testing.T) {
|
|||||||
// Rapid successive writes
|
// Rapid successive writes
|
||||||
for i := 1; i <= 3; i++ {
|
for i := 1; i <= 3; i++ {
|
||||||
data := []byte(fmt.Sprintf(`{"v": %d}`, i))
|
data := []byte(fmt.Sprintf(`{"v": %d}`, i))
|
||||||
if err := writeWithBackup(path, data); err != nil {
|
if err := WriteWithBackup(path, data); err != nil {
|
||||||
t.Fatalf("write %d failed: %v", i, err)
|
t.Fatalf("write %d failed: %v", i, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -414,7 +434,7 @@ func TestWriteWithBackup_RapidSuccessiveWrites(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Verify at least one backup exists
|
// Verify at least one backup exists
|
||||||
entries, _ := os.ReadDir(backupDir())
|
entries, _ := os.ReadDir(BackupDir())
|
||||||
var backupCount int
|
var backupCount int
|
||||||
for _, e := range entries {
|
for _, e := range entries {
|
||||||
if len(e.Name()) > len("rapid.json.") && e.Name()[:len("rapid.json.")] == "rapid.json." {
|
if len(e.Name()) > len("rapid.json.") && e.Name()[:len("rapid.json.")] == "rapid.json." {
|
||||||
@@ -432,8 +452,9 @@ func TestWriteWithBackup_BackupDirIsFile(t *testing.T) {
|
|||||||
t.Skip("test modifies system temp directory")
|
t.Skip("test modifies system temp directory")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
tmpDir := isolatedTempDir(t)
|
||||||
// Create a file at the backup directory path
|
// Create a file at the backup directory path
|
||||||
backupPath := backupDir()
|
backupPath := BackupDir()
|
||||||
// Clean up any existing directory first
|
// Clean up any existing directory first
|
||||||
os.RemoveAll(backupPath)
|
os.RemoveAll(backupPath)
|
||||||
// Create a file instead of directory
|
// Create a file instead of directory
|
||||||
@@ -443,11 +464,10 @@ func TestWriteWithBackup_BackupDirIsFile(t *testing.T) {
|
|||||||
os.MkdirAll(backupPath, 0o755)
|
os.MkdirAll(backupPath, 0o755)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
tmpDir := t.TempDir()
|
|
||||||
path := filepath.Join(tmpDir, "test.json")
|
path := filepath.Join(tmpDir, "test.json")
|
||||||
os.WriteFile(path, []byte(`{"original": true}`), 0o644)
|
os.WriteFile(path, []byte(`{"original": true}`), 0o644)
|
||||||
|
|
||||||
err := writeWithBackup(path, []byte(`{"updated": true}`))
|
err := WriteWithBackup(path, []byte(`{"updated": true}`))
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Error("expected error when backup dir is a file, got nil")
|
t.Error("expected error when backup dir is a file, got nil")
|
||||||
}
|
}
|
||||||
@@ -459,7 +479,7 @@ func TestWriteWithBackup_NoOrphanTempFiles(t *testing.T) {
|
|||||||
t.Skip("permission tests unreliable on Windows")
|
t.Skip("permission tests unreliable on Windows")
|
||||||
}
|
}
|
||||||
|
|
||||||
tmpDir := t.TempDir()
|
tmpDir := isolatedTempDir(t)
|
||||||
|
|
||||||
// Count existing temp files
|
// Count existing temp files
|
||||||
countTempFiles := func() int {
|
countTempFiles := func() int {
|
||||||
@@ -493,7 +513,7 @@ func TestWriteWithBackup_NoOrphanTempFiles(t *testing.T) {
|
|||||||
badPath := filepath.Join(tmpDir, "isdir")
|
badPath := filepath.Join(tmpDir, "isdir")
|
||||||
os.MkdirAll(badPath, 0o755)
|
os.MkdirAll(badPath, 0o755)
|
||||||
|
|
||||||
_ = writeWithBackup(badPath, []byte(`{"test": true}`))
|
_ = WriteWithBackup(badPath, []byte(`{"test": true}`))
|
||||||
|
|
||||||
after := countTempFiles()
|
after := countTempFiles()
|
||||||
if after > before {
|
if after > before {
|
||||||
87
cmd/launch/claude.go
Normal file
@@ -0,0 +1,87 @@
|
|||||||
|
package launch
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"os/exec"
|
||||||
|
"path/filepath"
|
||||||
|
"runtime"
|
||||||
|
"strconv"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/envconfig"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Claude implements Runner for Claude Code integration.
|
||||||
|
type Claude struct{}
|
||||||
|
|
||||||
|
func (c *Claude) String() string { return "Claude Code" }
|
||||||
|
|
||||||
|
func (c *Claude) args(model string, extra []string) []string {
|
||||||
|
var args []string
|
||||||
|
if model != "" {
|
||||||
|
args = append(args, "--model", model)
|
||||||
|
}
|
||||||
|
args = append(args, extra...)
|
||||||
|
return args
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Claude) findPath() (string, error) {
|
||||||
|
if p, err := exec.LookPath("claude"); err == nil {
|
||||||
|
return p, nil
|
||||||
|
}
|
||||||
|
home, err := os.UserHomeDir()
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
name := "claude"
|
||||||
|
if runtime.GOOS == "windows" {
|
||||||
|
name = "claude.exe"
|
||||||
|
}
|
||||||
|
fallback := filepath.Join(home, ".claude", "local", name)
|
||||||
|
if _, err := os.Stat(fallback); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return fallback, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Claude) Run(model string, args []string) error {
|
||||||
|
claudePath, err := c.findPath()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("claude is not installed, install from https://code.claude.com/docs/en/quickstart")
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd := exec.Command(claudePath, c.args(model, args)...)
|
||||||
|
cmd.Stdin = os.Stdin
|
||||||
|
cmd.Stdout = os.Stdout
|
||||||
|
cmd.Stderr = os.Stderr
|
||||||
|
|
||||||
|
env := append(os.Environ(),
|
||||||
|
"ANTHROPIC_BASE_URL="+envconfig.Host().String(),
|
||||||
|
"ANTHROPIC_API_KEY=",
|
||||||
|
"ANTHROPIC_AUTH_TOKEN=ollama",
|
||||||
|
"CLAUDE_CODE_ATTRIBUTION_HEADER=0",
|
||||||
|
)
|
||||||
|
|
||||||
|
env = append(env, c.modelEnvVars(model)...)
|
||||||
|
|
||||||
|
cmd.Env = env
|
||||||
|
return cmd.Run()
|
||||||
|
}
|
||||||
|
|
||||||
|
// modelEnvVars returns Claude Code env vars that route all model tiers through Ollama.
|
||||||
|
func (c *Claude) modelEnvVars(model string) []string {
|
||||||
|
env := []string{
|
||||||
|
"ANTHROPIC_DEFAULT_OPUS_MODEL=" + model,
|
||||||
|
"ANTHROPIC_DEFAULT_SONNET_MODEL=" + model,
|
||||||
|
"ANTHROPIC_DEFAULT_HAIKU_MODEL=" + model,
|
||||||
|
"CLAUDE_CODE_SUBAGENT_MODEL=" + model,
|
||||||
|
}
|
||||||
|
|
||||||
|
if isCloudModelName(model) {
|
||||||
|
if l, ok := lookupCloudModelLimit(model); ok {
|
||||||
|
env = append(env, "CLAUDE_CODE_AUTO_COMPACT_WINDOW="+strconv.Itoa(l.Context))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return env
|
||||||
|
}
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package config
|
package launch
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"os"
|
"os"
|
||||||
@@ -117,10 +117,7 @@ func TestClaudeModelEnvVars(t *testing.T) {
|
|||||||
return m
|
return m
|
||||||
}
|
}
|
||||||
|
|
||||||
t.Run("falls back to model param when no aliases saved", func(t *testing.T) {
|
t.Run("maps all Claude model env vars to the provided model", func(t *testing.T) {
|
||||||
tmpDir := t.TempDir()
|
|
||||||
setTestHome(t, tmpDir)
|
|
||||||
|
|
||||||
got := envMap(c.modelEnvVars("llama3.2"))
|
got := envMap(c.modelEnvVars("llama3.2"))
|
||||||
if got["ANTHROPIC_DEFAULT_OPUS_MODEL"] != "llama3.2" {
|
if got["ANTHROPIC_DEFAULT_OPUS_MODEL"] != "llama3.2" {
|
||||||
t.Errorf("OPUS = %q, want llama3.2", got["ANTHROPIC_DEFAULT_OPUS_MODEL"])
|
t.Errorf("OPUS = %q, want llama3.2", got["ANTHROPIC_DEFAULT_OPUS_MODEL"])
|
||||||
@@ -134,65 +131,41 @@ func TestClaudeModelEnvVars(t *testing.T) {
|
|||||||
if got["CLAUDE_CODE_SUBAGENT_MODEL"] != "llama3.2" {
|
if got["CLAUDE_CODE_SUBAGENT_MODEL"] != "llama3.2" {
|
||||||
t.Errorf("SUBAGENT = %q, want llama3.2", got["CLAUDE_CODE_SUBAGENT_MODEL"])
|
t.Errorf("SUBAGENT = %q, want llama3.2", got["CLAUDE_CODE_SUBAGENT_MODEL"])
|
||||||
}
|
}
|
||||||
})
|
if got["CLAUDE_CODE_AUTO_COMPACT_WINDOW"] != "" {
|
||||||
|
t.Errorf("AUTO_COMPACT_WINDOW = %q, want empty for local models", got["CLAUDE_CODE_AUTO_COMPACT_WINDOW"])
|
||||||
t.Run("uses primary alias for opus sonnet and subagent", func(t *testing.T) {
|
|
||||||
tmpDir := t.TempDir()
|
|
||||||
setTestHome(t, tmpDir)
|
|
||||||
|
|
||||||
SaveIntegration("claude", []string{"qwen3:8b"})
|
|
||||||
saveAliases("claude", map[string]string{"primary": "qwen3:8b"})
|
|
||||||
|
|
||||||
got := envMap(c.modelEnvVars("qwen3:8b"))
|
|
||||||
if got["ANTHROPIC_DEFAULT_OPUS_MODEL"] != "qwen3:8b" {
|
|
||||||
t.Errorf("OPUS = %q, want qwen3:8b", got["ANTHROPIC_DEFAULT_OPUS_MODEL"])
|
|
||||||
}
|
|
||||||
if got["ANTHROPIC_DEFAULT_SONNET_MODEL"] != "qwen3:8b" {
|
|
||||||
t.Errorf("SONNET = %q, want qwen3:8b", got["ANTHROPIC_DEFAULT_SONNET_MODEL"])
|
|
||||||
}
|
|
||||||
if got["ANTHROPIC_DEFAULT_HAIKU_MODEL"] != "qwen3:8b" {
|
|
||||||
t.Errorf("HAIKU = %q, want qwen3:8b (no fast alias)", got["ANTHROPIC_DEFAULT_HAIKU_MODEL"])
|
|
||||||
}
|
|
||||||
if got["CLAUDE_CODE_SUBAGENT_MODEL"] != "qwen3:8b" {
|
|
||||||
t.Errorf("SUBAGENT = %q, want qwen3:8b", got["CLAUDE_CODE_SUBAGENT_MODEL"])
|
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("uses fast alias for haiku", func(t *testing.T) {
|
t.Run("supports empty model", func(t *testing.T) {
|
||||||
tmpDir := t.TempDir()
|
got := envMap(c.modelEnvVars(""))
|
||||||
setTestHome(t, tmpDir)
|
if got["ANTHROPIC_DEFAULT_OPUS_MODEL"] != "" {
|
||||||
|
t.Errorf("OPUS = %q, want empty", got["ANTHROPIC_DEFAULT_OPUS_MODEL"])
|
||||||
SaveIntegration("claude", []string{"llama3.2:70b"})
|
|
||||||
saveAliases("claude", map[string]string{
|
|
||||||
"primary": "llama3.2:70b",
|
|
||||||
"fast": "llama3.2:8b",
|
|
||||||
})
|
|
||||||
|
|
||||||
got := envMap(c.modelEnvVars("llama3.2:70b"))
|
|
||||||
if got["ANTHROPIC_DEFAULT_OPUS_MODEL"] != "llama3.2:70b" {
|
|
||||||
t.Errorf("OPUS = %q, want llama3.2:70b", got["ANTHROPIC_DEFAULT_OPUS_MODEL"])
|
|
||||||
}
|
}
|
||||||
if got["ANTHROPIC_DEFAULT_SONNET_MODEL"] != "llama3.2:70b" {
|
if got["ANTHROPIC_DEFAULT_SONNET_MODEL"] != "" {
|
||||||
t.Errorf("SONNET = %q, want llama3.2:70b", got["ANTHROPIC_DEFAULT_SONNET_MODEL"])
|
t.Errorf("SONNET = %q, want empty", got["ANTHROPIC_DEFAULT_SONNET_MODEL"])
|
||||||
}
|
}
|
||||||
if got["ANTHROPIC_DEFAULT_HAIKU_MODEL"] != "llama3.2:8b" {
|
if got["ANTHROPIC_DEFAULT_HAIKU_MODEL"] != "" {
|
||||||
t.Errorf("HAIKU = %q, want llama3.2:8b", got["ANTHROPIC_DEFAULT_HAIKU_MODEL"])
|
t.Errorf("HAIKU = %q, want empty", got["ANTHROPIC_DEFAULT_HAIKU_MODEL"])
|
||||||
}
|
}
|
||||||
if got["CLAUDE_CODE_SUBAGENT_MODEL"] != "llama3.2:70b" {
|
if got["CLAUDE_CODE_SUBAGENT_MODEL"] != "" {
|
||||||
t.Errorf("SUBAGENT = %q, want llama3.2:70b", got["CLAUDE_CODE_SUBAGENT_MODEL"])
|
t.Errorf("SUBAGENT = %q, want empty", got["CLAUDE_CODE_SUBAGENT_MODEL"])
|
||||||
|
}
|
||||||
|
if got["CLAUDE_CODE_AUTO_COMPACT_WINDOW"] != "" {
|
||||||
|
t.Errorf("AUTO_COMPACT_WINDOW = %q, want empty", got["CLAUDE_CODE_AUTO_COMPACT_WINDOW"])
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("alias primary overrides model param", func(t *testing.T) {
|
t.Run("sets auto compact window for known cloud models", func(t *testing.T) {
|
||||||
tmpDir := t.TempDir()
|
got := envMap(c.modelEnvVars("glm-5:cloud"))
|
||||||
setTestHome(t, tmpDir)
|
if got["CLAUDE_CODE_AUTO_COMPACT_WINDOW"] != "202752" {
|
||||||
|
t.Errorf("AUTO_COMPACT_WINDOW = %q, want 202752", got["CLAUDE_CODE_AUTO_COMPACT_WINDOW"])
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
SaveIntegration("claude", []string{"saved-model"})
|
t.Run("does not set auto compact window for unknown cloud models", func(t *testing.T) {
|
||||||
saveAliases("claude", map[string]string{"primary": "saved-model"})
|
got := envMap(c.modelEnvVars("unknown-model:cloud"))
|
||||||
|
if got["CLAUDE_CODE_AUTO_COMPACT_WINDOW"] != "" {
|
||||||
got := envMap(c.modelEnvVars("different-model"))
|
t.Errorf("AUTO_COMPACT_WINDOW = %q, want empty", got["CLAUDE_CODE_AUTO_COMPACT_WINDOW"])
|
||||||
if got["ANTHROPIC_DEFAULT_OPUS_MODEL"] != "saved-model" {
|
|
||||||
t.Errorf("OPUS = %q, want saved-model", got["ANTHROPIC_DEFAULT_OPUS_MODEL"])
|
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -1,14 +1,13 @@
|
|||||||
package config
|
package launch
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
"os/exec"
|
"os/exec"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/cmd/internal/fileutil"
|
||||||
"github.com/ollama/ollama/envconfig"
|
"github.com/ollama/ollama/envconfig"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -22,24 +21,6 @@ func (c *Cline) Run(model string, args []string) error {
|
|||||||
return fmt.Errorf("cline is not installed, install with: npm install -g cline")
|
return fmt.Errorf("cline is not installed, install with: npm install -g cline")
|
||||||
}
|
}
|
||||||
|
|
||||||
models := []string{model}
|
|
||||||
if config, err := loadIntegration("cline"); err == nil && len(config.Models) > 0 {
|
|
||||||
models = config.Models
|
|
||||||
}
|
|
||||||
var err error
|
|
||||||
models, err = resolveEditorModels("cline", models, func() ([]string, error) {
|
|
||||||
return selectModels(context.Background(), "cline", "")
|
|
||||||
})
|
|
||||||
if errors.Is(err, errCancelled) {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if err := c.Edit(models); err != nil {
|
|
||||||
return fmt.Errorf("setup failed: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
cmd := exec.Command("cline", args...)
|
cmd := exec.Command("cline", args...)
|
||||||
cmd.Stdin = os.Stdin
|
cmd.Stdin = os.Stdin
|
||||||
cmd.Stdout = os.Stdout
|
cmd.Stdout = os.Stdout
|
||||||
@@ -97,7 +78,7 @@ func (c *Cline) Edit(models []string) error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return writeWithBackup(configPath, data)
|
return fileutil.WriteWithBackup(configPath, data)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Cline) Models() []string {
|
func (c *Cline) Models() []string {
|
||||||
@@ -106,7 +87,7 @@ func (c *Cline) Models() []string {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
config, err := readJSONFile(filepath.Join(home, ".cline", "data", "globalState.json"))
|
config, err := fileutil.ReadJSON(filepath.Join(home, ".cline", "data", "globalState.json"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package config
|
package launch
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package config
|
package launch
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package config
|
package launch
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"slices"
|
"slices"
|
||||||
@@ -16,7 +16,7 @@ func TestCodexArgs(t *testing.T) {
|
|||||||
}{
|
}{
|
||||||
{"with model", "llama3.2", nil, []string{"--oss", "-m", "llama3.2"}},
|
{"with model", "llama3.2", nil, []string{"--oss", "-m", "llama3.2"}},
|
||||||
{"empty model", "", nil, []string{"--oss"}},
|
{"empty model", "", nil, []string{"--oss"}},
|
||||||
{"with model and profile", "qwen3-coder", []string{"-p", "myprofile"}, []string{"--oss", "-m", "qwen3-coder", "-p", "myprofile"}},
|
{"with model and profile", "qwen3.5", []string{"-p", "myprofile"}, []string{"--oss", "-m", "qwen3.5", "-p", "myprofile"}},
|
||||||
{"with sandbox flag", "llama3.2", []string{"--sandbox", "workspace-write"}, []string{"--oss", "-m", "llama3.2", "--sandbox", "workspace-write"}},
|
{"with sandbox flag", "llama3.2", []string{"--sandbox", "workspace-write"}, []string{"--oss", "-m", "llama3.2", "--sandbox", "workspace-write"}},
|
||||||
}
|
}
|
||||||
|
|
||||||
598
cmd/launch/command_test.go
Normal file
@@ -0,0 +1,598 @@
|
|||||||
|
package launch
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"os"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/google/go-cmp/cmp"
|
||||||
|
"github.com/ollama/ollama/cmd/config"
|
||||||
|
"github.com/spf13/cobra"
|
||||||
|
)
|
||||||
|
|
||||||
|
func captureStderr(t *testing.T, fn func()) string {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
oldStderr := os.Stderr
|
||||||
|
r, w, err := os.Pipe()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to create stderr pipe: %v", err)
|
||||||
|
}
|
||||||
|
os.Stderr = w
|
||||||
|
defer func() {
|
||||||
|
os.Stderr = oldStderr
|
||||||
|
}()
|
||||||
|
|
||||||
|
done := make(chan string, 1)
|
||||||
|
go func() {
|
||||||
|
var buf bytes.Buffer
|
||||||
|
_, _ = io.Copy(&buf, r)
|
||||||
|
done <- buf.String()
|
||||||
|
}()
|
||||||
|
|
||||||
|
fn()
|
||||||
|
|
||||||
|
_ = w.Close()
|
||||||
|
return <-done
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLaunchCmd(t *testing.T) {
|
||||||
|
mockCheck := func(cmd *cobra.Command, args []string) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
mockTUI := func(cmd *cobra.Command) {}
|
||||||
|
cmd := LaunchCmd(mockCheck, mockTUI)
|
||||||
|
|
||||||
|
t.Run("command structure", func(t *testing.T) {
|
||||||
|
if cmd.Use != "launch [INTEGRATION] [-- [EXTRA_ARGS...]]" {
|
||||||
|
t.Errorf("Use = %q, want %q", cmd.Use, "launch [INTEGRATION] [-- [EXTRA_ARGS...]]")
|
||||||
|
}
|
||||||
|
if cmd.Short == "" {
|
||||||
|
t.Error("Short description should not be empty")
|
||||||
|
}
|
||||||
|
if cmd.Long == "" {
|
||||||
|
t.Error("Long description should not be empty")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("flags exist", func(t *testing.T) {
|
||||||
|
if cmd.Flags().Lookup("model") == nil {
|
||||||
|
t.Error("--model flag should exist")
|
||||||
|
}
|
||||||
|
if cmd.Flags().Lookup("config") == nil {
|
||||||
|
t.Error("--config flag should exist")
|
||||||
|
}
|
||||||
|
if cmd.Flags().Lookup("yes") == nil {
|
||||||
|
t.Error("--yes flag should exist")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("PreRunE is set", func(t *testing.T) {
|
||||||
|
if cmd.PreRunE == nil {
|
||||||
|
t.Error("PreRunE should be set to checkServerHeartbeat")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLaunchCmdTUICallback(t *testing.T) {
|
||||||
|
mockCheck := func(cmd *cobra.Command, args []string) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("no args calls TUI", func(t *testing.T) {
|
||||||
|
tuiCalled := false
|
||||||
|
mockTUI := func(cmd *cobra.Command) {
|
||||||
|
tuiCalled = true
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd := LaunchCmd(mockCheck, mockTUI)
|
||||||
|
cmd.SetArgs([]string{})
|
||||||
|
_ = cmd.Execute()
|
||||||
|
|
||||||
|
if !tuiCalled {
|
||||||
|
t.Error("TUI callback should be called when no args provided")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("integration arg bypasses TUI", func(t *testing.T) {
|
||||||
|
srv := httptest.NewServer(http.NotFoundHandler())
|
||||||
|
defer srv.Close()
|
||||||
|
t.Setenv("OLLAMA_HOST", srv.URL)
|
||||||
|
|
||||||
|
tuiCalled := false
|
||||||
|
mockTUI := func(cmd *cobra.Command) {
|
||||||
|
tuiCalled = true
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd := LaunchCmd(mockCheck, mockTUI)
|
||||||
|
cmd.SetArgs([]string{"claude"})
|
||||||
|
_ = cmd.Execute()
|
||||||
|
|
||||||
|
if tuiCalled {
|
||||||
|
t.Error("TUI callback should NOT be called when integration arg provided")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("--model flag without integration returns error", func(t *testing.T) {
|
||||||
|
tuiCalled := false
|
||||||
|
mockTUI := func(cmd *cobra.Command) {
|
||||||
|
tuiCalled = true
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd := LaunchCmd(mockCheck, mockTUI)
|
||||||
|
cmd.SetArgs([]string{"--model", "test-model"})
|
||||||
|
err := cmd.Execute()
|
||||||
|
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected --model without an integration to fail")
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "require an integration name") {
|
||||||
|
t.Fatalf("expected integration-name guidance, got %v", err)
|
||||||
|
}
|
||||||
|
if tuiCalled {
|
||||||
|
t.Error("TUI callback should NOT be called when --model is provided without an integration")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("--config flag without integration returns error", func(t *testing.T) {
|
||||||
|
tuiCalled := false
|
||||||
|
mockTUI := func(cmd *cobra.Command) {
|
||||||
|
tuiCalled = true
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd := LaunchCmd(mockCheck, mockTUI)
|
||||||
|
cmd.SetArgs([]string{"--config"})
|
||||||
|
err := cmd.Execute()
|
||||||
|
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected --config without an integration to fail")
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "require an integration name") {
|
||||||
|
t.Fatalf("expected integration-name guidance, got %v", err)
|
||||||
|
}
|
||||||
|
if tuiCalled {
|
||||||
|
t.Error("TUI callback should NOT be called when --config is provided without an integration")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("--yes flag without integration returns error", func(t *testing.T) {
|
||||||
|
tuiCalled := false
|
||||||
|
mockTUI := func(cmd *cobra.Command) {
|
||||||
|
tuiCalled = true
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd := LaunchCmd(mockCheck, mockTUI)
|
||||||
|
cmd.SetArgs([]string{"--yes"})
|
||||||
|
err := cmd.Execute()
|
||||||
|
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected --yes without an integration to fail")
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "require an integration name") {
|
||||||
|
t.Fatalf("expected integration-name guidance, got %v", err)
|
||||||
|
}
|
||||||
|
if tuiCalled {
|
||||||
|
t.Error("TUI callback should NOT be called when --yes is provided without an integration")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("extra args without integration return error", func(t *testing.T) {
|
||||||
|
tuiCalled := false
|
||||||
|
mockTUI := func(cmd *cobra.Command) {
|
||||||
|
tuiCalled = true
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd := LaunchCmd(mockCheck, mockTUI)
|
||||||
|
cmd.SetArgs([]string{"--model", "test-model", "--", "--sandbox", "workspace-write"})
|
||||||
|
err := cmd.Execute()
|
||||||
|
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected flags and extra args without an integration to fail")
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "require an integration name") {
|
||||||
|
t.Fatalf("expected integration-name guidance, got %v", err)
|
||||||
|
}
|
||||||
|
if tuiCalled {
|
||||||
|
t.Error("TUI callback should NOT be called when flags or extra args are provided without an integration")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLaunchCmdNilHeartbeat(t *testing.T) {
|
||||||
|
cmd := LaunchCmd(nil, nil)
|
||||||
|
if cmd == nil {
|
||||||
|
t.Fatal("LaunchCmd returned nil")
|
||||||
|
}
|
||||||
|
if cmd.PreRunE != nil {
|
||||||
|
t.Log("Note: PreRunE is set even when nil is passed (acceptable)")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLaunchCmdModelFlagFiltersDisabledCloudFromSavedConfig(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
setLaunchTestHome(t, tmpDir)
|
||||||
|
|
||||||
|
if err := config.SaveIntegration("stubeditor", []string{"glm-5:cloud"}); err != nil {
|
||||||
|
t.Fatalf("failed to seed saved config: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
switch r.URL.Path {
|
||||||
|
case "/api/status":
|
||||||
|
fmt.Fprintf(w, `{"cloud":{"disabled":true,"source":"config"}}`)
|
||||||
|
case "/api/show":
|
||||||
|
fmt.Fprintf(w, `{"model":"llama3.2"}`)
|
||||||
|
default:
|
||||||
|
w.WriteHeader(http.StatusNotFound)
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
t.Setenv("OLLAMA_HOST", srv.URL)
|
||||||
|
|
||||||
|
stub := &launcherEditorRunner{}
|
||||||
|
restore := OverrideIntegration("stubeditor", stub)
|
||||||
|
defer restore()
|
||||||
|
|
||||||
|
cmd := LaunchCmd(func(cmd *cobra.Command, args []string) error { return nil }, func(cmd *cobra.Command) {})
|
||||||
|
cmd.SetArgs([]string{"stubeditor", "--model", "llama3.2"})
|
||||||
|
if err := cmd.Execute(); err != nil {
|
||||||
|
t.Fatalf("launch command failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
saved, err := config.LoadIntegration("stubeditor")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to reload integration config: %v", err)
|
||||||
|
}
|
||||||
|
if diff := cmp.Diff([]string{"llama3.2"}, saved.Models); diff != "" {
|
||||||
|
t.Fatalf("saved models mismatch (-want +got):\n%s", diff)
|
||||||
|
}
|
||||||
|
if diff := cmp.Diff([][]string{{"llama3.2"}}, stub.edited); diff != "" {
|
||||||
|
t.Fatalf("editor models mismatch (-want +got):\n%s", diff)
|
||||||
|
}
|
||||||
|
if stub.ranModel != "llama3.2" {
|
||||||
|
t.Fatalf("expected launch to run with llama3.2, got %q", stub.ranModel)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLaunchCmdModelFlagClearsDisabledCloudOverride(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
setLaunchTestHome(t, tmpDir)
|
||||||
|
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
switch r.URL.Path {
|
||||||
|
case "/api/status":
|
||||||
|
fmt.Fprintf(w, `{"cloud":{"disabled":true,"source":"config"}}`)
|
||||||
|
case "/api/tags":
|
||||||
|
fmt.Fprint(w, `{"models":[{"name":"llama3.2"}]}`)
|
||||||
|
case "/api/show":
|
||||||
|
fmt.Fprint(w, `{"model":"llama3.2"}`)
|
||||||
|
default:
|
||||||
|
w.WriteHeader(http.StatusNotFound)
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
t.Setenv("OLLAMA_HOST", srv.URL)
|
||||||
|
|
||||||
|
stub := &launcherSingleRunner{}
|
||||||
|
restore := OverrideIntegration("stubapp", stub)
|
||||||
|
defer restore()
|
||||||
|
|
||||||
|
oldSelector := DefaultSingleSelector
|
||||||
|
defer func() { DefaultSingleSelector = oldSelector }()
|
||||||
|
|
||||||
|
var selectorCalls int
|
||||||
|
var gotCurrent string
|
||||||
|
DefaultSingleSelector = func(title string, items []ModelItem, current string) (string, error) {
|
||||||
|
selectorCalls++
|
||||||
|
gotCurrent = current
|
||||||
|
return "llama3.2", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd := LaunchCmd(func(cmd *cobra.Command, args []string) error { return nil }, func(cmd *cobra.Command) {})
|
||||||
|
cmd.SetArgs([]string{"stubapp", "--model", "glm-5:cloud"})
|
||||||
|
stderr := captureStderr(t, func() {
|
||||||
|
if err := cmd.Execute(); err != nil {
|
||||||
|
t.Fatalf("launch command failed: %v", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
if selectorCalls != 1 {
|
||||||
|
t.Fatalf("expected disabled cloud override to fall back to selector, got %d calls", selectorCalls)
|
||||||
|
}
|
||||||
|
if gotCurrent != "" {
|
||||||
|
t.Fatalf("expected disabled override to be cleared before selection, got current %q", gotCurrent)
|
||||||
|
}
|
||||||
|
if stub.ranModel != "llama3.2" {
|
||||||
|
t.Fatalf("expected launch to run with replacement local model, got %q", stub.ranModel)
|
||||||
|
}
|
||||||
|
if !strings.Contains(stderr, "Warning: ignoring --model glm-5:cloud because cloud is disabled") {
|
||||||
|
t.Fatalf("expected disabled-cloud warning, got stderr: %q", stderr)
|
||||||
|
}
|
||||||
|
|
||||||
|
saved, err := config.LoadIntegration("stubapp")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to reload integration config: %v", err)
|
||||||
|
}
|
||||||
|
if diff := cmp.Diff([]string{"llama3.2"}, saved.Models); diff != "" {
|
||||||
|
t.Fatalf("saved models mismatch (-want +got):\n%s", diff)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLaunchCmdYes_AutoConfirmsLaunchPromptPath(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
setLaunchTestHome(t, tmpDir)
|
||||||
|
withLauncherHooks(t)
|
||||||
|
withInteractiveSession(t, false)
|
||||||
|
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
switch r.URL.Path {
|
||||||
|
case "/api/show":
|
||||||
|
fmt.Fprint(w, `{"model":"llama3.2"}`)
|
||||||
|
case "/api/status":
|
||||||
|
w.WriteHeader(http.StatusNotFound)
|
||||||
|
fmt.Fprint(w, `{"error":"not found"}`)
|
||||||
|
default:
|
||||||
|
w.WriteHeader(http.StatusNotFound)
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
t.Setenv("OLLAMA_HOST", srv.URL)
|
||||||
|
|
||||||
|
stub := &launcherEditorRunner{paths: []string{"/tmp/stubeditor.json"}}
|
||||||
|
restore := OverrideIntegration("stubeditor", stub)
|
||||||
|
defer restore()
|
||||||
|
|
||||||
|
DefaultConfirmPrompt = func(prompt string) (bool, error) {
|
||||||
|
t.Fatalf("unexpected prompt with --yes: %q", prompt)
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd := LaunchCmd(func(cmd *cobra.Command, args []string) error { return nil }, func(cmd *cobra.Command) {})
|
||||||
|
cmd.SetArgs([]string{"stubeditor", "--model", "llama3.2", "--yes"})
|
||||||
|
if err := cmd.Execute(); err != nil {
|
||||||
|
t.Fatalf("launch command with --yes failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := cmp.Diff([][]string{{"llama3.2"}}, stub.edited); diff != "" {
|
||||||
|
t.Fatalf("editor models mismatch (-want +got):\n%s", diff)
|
||||||
|
}
|
||||||
|
if stub.ranModel != "llama3.2" {
|
||||||
|
t.Fatalf("expected launch to run with llama3.2, got %q", stub.ranModel)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLaunchCmdHeadlessWithYes_AutoPullsMissingLocalModel(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
setLaunchTestHome(t, tmpDir)
|
||||||
|
withLauncherHooks(t)
|
||||||
|
withInteractiveSession(t, false)
|
||||||
|
|
||||||
|
var pullCalled bool
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
switch r.URL.Path {
|
||||||
|
case "/api/show":
|
||||||
|
w.WriteHeader(http.StatusNotFound)
|
||||||
|
fmt.Fprint(w, `{"error":"model not found"}`)
|
||||||
|
case "/api/pull":
|
||||||
|
pullCalled = true
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
fmt.Fprint(w, `{"status":"success"}`)
|
||||||
|
default:
|
||||||
|
w.WriteHeader(http.StatusNotFound)
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
t.Setenv("OLLAMA_HOST", srv.URL)
|
||||||
|
|
||||||
|
stub := &launcherSingleRunner{}
|
||||||
|
restore := OverrideIntegration("stubapp", stub)
|
||||||
|
defer restore()
|
||||||
|
|
||||||
|
DefaultConfirmPrompt = func(prompt string) (bool, error) {
|
||||||
|
t.Fatalf("unexpected prompt with --yes in headless autopull path: %q", prompt)
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd := LaunchCmd(func(cmd *cobra.Command, args []string) error { return nil }, func(cmd *cobra.Command) {})
|
||||||
|
cmd.SetArgs([]string{"stubapp", "--model", "missing-model", "--yes"})
|
||||||
|
if err := cmd.Execute(); err != nil {
|
||||||
|
t.Fatalf("launch command with --yes failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !pullCalled {
|
||||||
|
t.Fatal("expected missing local model to be auto-pulled with --yes in headless mode")
|
||||||
|
}
|
||||||
|
if stub.ranModel != "missing-model" {
|
||||||
|
t.Fatalf("expected launch to run with pulled model, got %q", stub.ranModel)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLaunchCmdHeadlessWithoutYes_ReturnsActionableConfirmError(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
setLaunchTestHome(t, tmpDir)
|
||||||
|
withLauncherHooks(t)
|
||||||
|
withInteractiveSession(t, false)
|
||||||
|
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
switch r.URL.Path {
|
||||||
|
case "/api/show":
|
||||||
|
fmt.Fprint(w, `{"model":"llama3.2"}`)
|
||||||
|
case "/api/status":
|
||||||
|
w.WriteHeader(http.StatusNotFound)
|
||||||
|
fmt.Fprint(w, `{"error":"not found"}`)
|
||||||
|
default:
|
||||||
|
w.WriteHeader(http.StatusNotFound)
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
t.Setenv("OLLAMA_HOST", srv.URL)
|
||||||
|
|
||||||
|
stub := &launcherEditorRunner{paths: []string{"/tmp/stubeditor.json"}}
|
||||||
|
restore := OverrideIntegration("stubeditor", stub)
|
||||||
|
defer restore()
|
||||||
|
|
||||||
|
DefaultConfirmPrompt = func(prompt string) (bool, error) {
|
||||||
|
t.Fatalf("unexpected prompt in headless non-yes mode: %q", prompt)
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd := LaunchCmd(func(cmd *cobra.Command, args []string) error { return nil }, func(cmd *cobra.Command) {})
|
||||||
|
cmd.SetArgs([]string{"stubeditor", "--model", "llama3.2"})
|
||||||
|
err := cmd.Execute()
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected launch command to fail without --yes in headless mode")
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "re-run with --yes") {
|
||||||
|
t.Fatalf("expected actionable --yes guidance, got %v", err)
|
||||||
|
}
|
||||||
|
if len(stub.edited) != 0 {
|
||||||
|
t.Fatalf("expected no editor writes when confirmation is blocked, got %v", stub.edited)
|
||||||
|
}
|
||||||
|
if stub.ranModel != "" {
|
||||||
|
t.Fatalf("expected launch to abort before run, got %q", stub.ranModel)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLaunchCmdIntegrationArgPromptsForModelWithSavedSelection(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
setLaunchTestHome(t, tmpDir)
|
||||||
|
|
||||||
|
if err := config.SaveIntegration("stubapp", []string{"llama3.2"}); err != nil {
|
||||||
|
t.Fatalf("failed to seed saved config: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
switch r.URL.Path {
|
||||||
|
case "/api/tags":
|
||||||
|
fmt.Fprint(w, `{"models":[{"name":"llama3.2"},{"name":"qwen3:8b"}]}`)
|
||||||
|
case "/api/show":
|
||||||
|
fmt.Fprint(w, `{"model":"qwen3:8b"}`)
|
||||||
|
default:
|
||||||
|
w.WriteHeader(http.StatusNotFound)
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
t.Setenv("OLLAMA_HOST", srv.URL)
|
||||||
|
|
||||||
|
stub := &launcherSingleRunner{}
|
||||||
|
restore := OverrideIntegration("stubapp", stub)
|
||||||
|
defer restore()
|
||||||
|
|
||||||
|
oldSelector := DefaultSingleSelector
|
||||||
|
defer func() { DefaultSingleSelector = oldSelector }()
|
||||||
|
|
||||||
|
var gotCurrent string
|
||||||
|
DefaultSingleSelector = func(title string, items []ModelItem, current string) (string, error) {
|
||||||
|
gotCurrent = current
|
||||||
|
return "qwen3:8b", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd := LaunchCmd(func(cmd *cobra.Command, args []string) error { return nil }, func(cmd *cobra.Command) {})
|
||||||
|
cmd.SetArgs([]string{"stubapp"})
|
||||||
|
if err := cmd.Execute(); err != nil {
|
||||||
|
t.Fatalf("launch command failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if gotCurrent != "llama3.2" {
|
||||||
|
t.Fatalf("expected selector current model to be saved model llama3.2, got %q", gotCurrent)
|
||||||
|
}
|
||||||
|
if stub.ranModel != "qwen3:8b" {
|
||||||
|
t.Fatalf("expected launch to run selected model qwen3:8b, got %q", stub.ranModel)
|
||||||
|
}
|
||||||
|
|
||||||
|
saved, err := config.LoadIntegration("stubapp")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to reload integration config: %v", err)
|
||||||
|
}
|
||||||
|
if diff := cmp.Diff([]string{"qwen3:8b"}, saved.Models); diff != "" {
|
||||||
|
t.Fatalf("saved models mismatch (-want +got):\n%s", diff)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLaunchCmdHeadlessYes_IntegrationRequiresModelEvenWhenSaved(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
setLaunchTestHome(t, tmpDir)
|
||||||
|
withLauncherHooks(t)
|
||||||
|
withInteractiveSession(t, false)
|
||||||
|
|
||||||
|
if err := config.SaveIntegration("stubapp", []string{"llama3.2"}); err != nil {
|
||||||
|
t.Fatalf("failed to seed saved config: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
switch r.URL.Path {
|
||||||
|
case "/api/show":
|
||||||
|
fmt.Fprint(w, `{"model":"llama3.2"}`)
|
||||||
|
default:
|
||||||
|
w.WriteHeader(http.StatusNotFound)
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
t.Setenv("OLLAMA_HOST", srv.URL)
|
||||||
|
|
||||||
|
stub := &launcherSingleRunner{}
|
||||||
|
restore := OverrideIntegration("stubapp", stub)
|
||||||
|
defer restore()
|
||||||
|
|
||||||
|
oldSelector := DefaultSingleSelector
|
||||||
|
defer func() { DefaultSingleSelector = oldSelector }()
|
||||||
|
DefaultSingleSelector = func(title string, items []ModelItem, current string) (string, error) {
|
||||||
|
t.Fatal("selector should not be called for headless --yes saved-model launch")
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd := LaunchCmd(func(cmd *cobra.Command, args []string) error { return nil }, func(cmd *cobra.Command) {})
|
||||||
|
cmd.SetArgs([]string{"stubapp", "--yes"})
|
||||||
|
err := cmd.Execute()
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected launch command to fail when --yes is used headlessly without --model")
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "requires --model <model>") {
|
||||||
|
t.Fatalf("expected actionable --model guidance, got %v", err)
|
||||||
|
}
|
||||||
|
if stub.ranModel != "" {
|
||||||
|
t.Fatalf("expected launch to abort before run, got %q", stub.ranModel)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLaunchCmdHeadlessYes_IntegrationWithoutSavedModelReturnsError(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
setLaunchTestHome(t, tmpDir)
|
||||||
|
withLauncherHooks(t)
|
||||||
|
withInteractiveSession(t, false)
|
||||||
|
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusNotFound)
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
t.Setenv("OLLAMA_HOST", srv.URL)
|
||||||
|
|
||||||
|
stub := &launcherSingleRunner{}
|
||||||
|
restore := OverrideIntegration("stubapp", stub)
|
||||||
|
defer restore()
|
||||||
|
|
||||||
|
oldSelector := DefaultSingleSelector
|
||||||
|
defer func() { DefaultSingleSelector = oldSelector }()
|
||||||
|
DefaultSingleSelector = func(title string, items []ModelItem, current string) (string, error) {
|
||||||
|
t.Fatal("selector should not be called for headless --yes without saved model")
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd := LaunchCmd(func(cmd *cobra.Command, args []string) error { return nil }, func(cmd *cobra.Command) {})
|
||||||
|
cmd.SetArgs([]string{"stubapp", "--yes"})
|
||||||
|
err := cmd.Execute()
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected launch command to fail when --yes is used headlessly without --model")
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "requires --model <model>") {
|
||||||
|
t.Fatalf("expected actionable --model guidance, got %v", err)
|
||||||
|
}
|
||||||
|
if stub.ranModel != "" {
|
||||||
|
t.Fatalf("expected launch to abort before run, got %q", stub.ranModel)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,16 +1,14 @@
|
|||||||
package config
|
package launch
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
"os/exec"
|
"os/exec"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"slices"
|
"slices"
|
||||||
|
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/cmd/internal/fileutil"
|
||||||
"github.com/ollama/ollama/envconfig"
|
"github.com/ollama/ollama/envconfig"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -47,25 +45,6 @@ func (d *Droid) Run(model string, args []string) error {
|
|||||||
return fmt.Errorf("droid is not installed, install from https://docs.factory.ai/cli/getting-started/quickstart")
|
return fmt.Errorf("droid is not installed, install from https://docs.factory.ai/cli/getting-started/quickstart")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Call Edit() to ensure config is up-to-date before launch
|
|
||||||
models := []string{model}
|
|
||||||
if config, err := loadIntegration("droid"); err == nil && len(config.Models) > 0 {
|
|
||||||
models = config.Models
|
|
||||||
}
|
|
||||||
var err error
|
|
||||||
models, err = resolveEditorModels("droid", models, func() ([]string, error) {
|
|
||||||
return selectModels(context.Background(), "droid", "")
|
|
||||||
})
|
|
||||||
if errors.Is(err, errCancelled) {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if err := d.Edit(models); err != nil {
|
|
||||||
return fmt.Errorf("setup failed: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
cmd := exec.Command("droid", args...)
|
cmd := exec.Command("droid", args...)
|
||||||
cmd.Stdin = os.Stdin
|
cmd.Stdin = os.Stdin
|
||||||
cmd.Stdout = os.Stdout
|
cmd.Stdout = os.Stdout
|
||||||
@@ -111,6 +90,16 @@ func (d *Droid) Edit(models []string) error {
|
|||||||
json.Unmarshal(data, &settings) // ignore error, zero values are fine
|
json.Unmarshal(data, &settings) // ignore error, zero values are fine
|
||||||
}
|
}
|
||||||
|
|
||||||
|
settingsMap = updateDroidSettings(settingsMap, settings, models)
|
||||||
|
|
||||||
|
data, err := json.MarshalIndent(settingsMap, "", " ")
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return fileutil.WriteWithBackup(settingsPath, data)
|
||||||
|
}
|
||||||
|
|
||||||
|
func updateDroidSettings(settingsMap map[string]any, settings droidSettings, models []string) map[string]any {
|
||||||
// Keep only non-Ollama models from the raw map (preserves extra fields)
|
// Keep only non-Ollama models from the raw map (preserves extra fields)
|
||||||
// Rebuild Ollama models
|
// Rebuild Ollama models
|
||||||
var nonOllamaModels []any
|
var nonOllamaModels []any
|
||||||
@@ -125,13 +114,12 @@ func (d *Droid) Edit(models []string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Build new Ollama model entries with sequential indices (0, 1, 2, ...)
|
// Build new Ollama model entries with sequential indices (0, 1, 2, ...)
|
||||||
client, _ := api.ClientFromEnvironment()
|
|
||||||
|
|
||||||
var newModels []any
|
var newModels []any
|
||||||
var defaultModelID string
|
var defaultModelID string
|
||||||
for i, model := range models {
|
for i, model := range models {
|
||||||
maxOutput := 64000
|
maxOutput := 64000
|
||||||
if isCloudModel(context.Background(), client, model) {
|
if isCloudModelName(model) {
|
||||||
if l, ok := lookupCloudModelLimit(model); ok {
|
if l, ok := lookupCloudModelLimit(model); ok {
|
||||||
maxOutput = l.Output
|
maxOutput = l.Output
|
||||||
}
|
}
|
||||||
@@ -167,12 +155,7 @@ func (d *Droid) Edit(models []string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
settingsMap["sessionDefaultSettings"] = sessionSettings
|
settingsMap["sessionDefaultSettings"] = sessionSettings
|
||||||
|
return settingsMap
|
||||||
data, err := json.MarshalIndent(settingsMap, "", " ")
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return writeWithBackup(settingsPath, data)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Droid) Models() []string {
|
func (d *Droid) Models() []string {
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package config
|
package launch
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
@@ -6,6 +6,8 @@ import (
|
|||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/cmd/internal/fileutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestDroidIntegration(t *testing.T) {
|
func TestDroidIntegration(t *testing.T) {
|
||||||
@@ -362,7 +364,7 @@ func TestDroidEdit_DuplicateModels(t *testing.T) {
|
|||||||
t.Fatalf("Edit with duplicates failed: %v", err)
|
t.Fatalf("Edit with duplicates failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
settings, err := readJSONFile(settingsPath)
|
settings, err := fileutil.ReadJSON(settingsPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("readJSONFile failed: %v", err)
|
t.Fatalf("readJSONFile failed: %v", err)
|
||||||
}
|
}
|
||||||
@@ -392,7 +394,7 @@ func TestDroidEdit_MalformedModelEntry(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Malformed entries (non-object) are dropped - only valid model objects are preserved
|
// Malformed entries (non-object) are dropped - only valid model objects are preserved
|
||||||
settings, _ := readJSONFile(settingsPath)
|
settings, _ := fileutil.ReadJSON(settingsPath)
|
||||||
customModels, _ := settings["customModels"].([]any)
|
customModels, _ := settings["customModels"].([]any)
|
||||||
|
|
||||||
// Should have: 1 new Ollama model only (malformed entries dropped)
|
// Should have: 1 new Ollama model only (malformed entries dropped)
|
||||||
@@ -419,7 +421,7 @@ func TestDroidEdit_WrongTypeSessionSettings(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Should create proper sessionDefaultSettings
|
// Should create proper sessionDefaultSettings
|
||||||
settings, _ := readJSONFile(settingsPath)
|
settings, _ := fileutil.ReadJSON(settingsPath)
|
||||||
session, ok := settings["sessionDefaultSettings"].(map[string]any)
|
session, ok := settings["sessionDefaultSettings"].(map[string]any)
|
||||||
if !ok {
|
if !ok {
|
||||||
t.Fatalf("sessionDefaultSettings should be map after setup, got %T", settings["sessionDefaultSettings"])
|
t.Fatalf("sessionDefaultSettings should be map after setup, got %T", settings["sessionDefaultSettings"])
|
||||||
@@ -1008,34 +1010,34 @@ func TestDroidEdit_ModelNamesWithSpecialCharacters(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestDroidEdit_MissingCustomModelsKey(t *testing.T) {
|
func TestDroidEdit_MissingCustomModelsKey(t *testing.T) {
|
||||||
d := &Droid{}
|
|
||||||
tmpDir := t.TempDir()
|
|
||||||
setTestHome(t, tmpDir)
|
|
||||||
|
|
||||||
settingsDir := filepath.Join(tmpDir, ".factory")
|
|
||||||
settingsPath := filepath.Join(settingsDir, "settings.json")
|
|
||||||
|
|
||||||
os.MkdirAll(settingsDir, 0o755)
|
|
||||||
|
|
||||||
// No customModels key at all
|
// No customModels key at all
|
||||||
original := `{
|
original := `{
|
||||||
"diffMode": "github",
|
"diffMode": "github",
|
||||||
"sessionDefaultSettings": {"autonomyMode": "auto-high"}
|
"sessionDefaultSettings": {"autonomyMode": "auto-high"}
|
||||||
}`
|
}`
|
||||||
os.WriteFile(settingsPath, []byte(original), 0o644)
|
|
||||||
|
|
||||||
if err := d.Edit([]string{"model-a"}); err != nil {
|
var settingsStruct droidSettings
|
||||||
|
var settings map[string]any
|
||||||
|
if err := json.Unmarshal([]byte(original), &settings); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal([]byte(original), &settingsStruct); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
data, _ := os.ReadFile(settingsPath)
|
settings = updateDroidSettings(settings, settingsStruct, []string{"model-a"})
|
||||||
var settings map[string]any
|
|
||||||
json.Unmarshal(data, &settings)
|
|
||||||
|
|
||||||
// Original fields preserved
|
// Original fields preserved
|
||||||
if settings["diffMode"] != "github" {
|
if settings["diffMode"] != "github" {
|
||||||
t.Error("diffMode not preserved")
|
t.Error("diffMode not preserved")
|
||||||
}
|
}
|
||||||
|
session, ok := settings["sessionDefaultSettings"].(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("sessionDefaultSettings not preserved")
|
||||||
|
}
|
||||||
|
if session["autonomyMode"] != "auto-high" {
|
||||||
|
t.Error("sessionDefaultSettings.autonomyMode not preserved")
|
||||||
|
}
|
||||||
|
|
||||||
// customModels created
|
// customModels created
|
||||||
models, ok := settings["customModels"].([]any)
|
models, ok := settings["customModels"].([]any)
|
||||||
@@ -1276,25 +1278,17 @@ func TestDroidEdit_LocalModelDefaultMaxOutput(t *testing.T) {
|
|||||||
|
|
||||||
func TestDroidEdit_CloudModelLimitsUsed(t *testing.T) {
|
func TestDroidEdit_CloudModelLimitsUsed(t *testing.T) {
|
||||||
// Verify that every cloud model in cloudModelLimits has a valid output
|
// Verify that every cloud model in cloudModelLimits has a valid output
|
||||||
// value that would be used for maxOutputTokens when isCloudModel returns true.
|
// value that would be used for maxOutputTokens when the selected model uses
|
||||||
// :cloud suffix stripping must also work since that's how users specify them.
|
// the explicit :cloud source tag.
|
||||||
for name, expected := range cloudModelLimits {
|
for name, expected := range cloudModelLimits {
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
l, ok := lookupCloudModelLimit(name)
|
|
||||||
if !ok {
|
|
||||||
t.Fatalf("lookupCloudModelLimit(%q) returned false", name)
|
|
||||||
}
|
|
||||||
if l.Output != expected.Output {
|
|
||||||
t.Errorf("output = %d, want %d", l.Output, expected.Output)
|
|
||||||
}
|
|
||||||
// Also verify :cloud suffix lookup
|
|
||||||
cloudName := name + ":cloud"
|
cloudName := name + ":cloud"
|
||||||
l2, ok := lookupCloudModelLimit(cloudName)
|
l, ok := lookupCloudModelLimit(cloudName)
|
||||||
if !ok {
|
if !ok {
|
||||||
t.Fatalf("lookupCloudModelLimit(%q) returned false", cloudName)
|
t.Fatalf("lookupCloudModelLimit(%q) returned false", cloudName)
|
||||||
}
|
}
|
||||||
if l2.Output != expected.Output {
|
if l.Output != expected.Output {
|
||||||
t.Errorf(":cloud output = %d, want %d", l2.Output, expected.Output)
|
t.Errorf("output = %d, want %d", l.Output, expected.Output)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
842
cmd/launch/launch.go
Normal file
@@ -0,0 +1,842 @@
|
|||||||
|
package launch
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/api"
|
||||||
|
"github.com/ollama/ollama/cmd/config"
|
||||||
|
"github.com/spf13/cobra"
|
||||||
|
"golang.org/x/term"
|
||||||
|
)
|
||||||
|
|
||||||
|
// LauncherState is the launch-owned snapshot used to render the root launcher menu.
|
||||||
|
type LauncherState struct {
|
||||||
|
LastSelection string
|
||||||
|
RunModel string
|
||||||
|
RunModelUsable bool
|
||||||
|
Integrations map[string]LauncherIntegrationState
|
||||||
|
}
|
||||||
|
|
||||||
|
// LauncherIntegrationState is the launch-owned status for one launcher integration.
|
||||||
|
type LauncherIntegrationState struct {
|
||||||
|
Name string
|
||||||
|
DisplayName string
|
||||||
|
Description string
|
||||||
|
Installed bool
|
||||||
|
AutoInstallable bool
|
||||||
|
Selectable bool
|
||||||
|
Changeable bool
|
||||||
|
CurrentModel string
|
||||||
|
ModelUsable bool
|
||||||
|
InstallHint string
|
||||||
|
Editor bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// RunModelRequest controls how the root launcher resolves the chat model.
|
||||||
|
type RunModelRequest struct {
|
||||||
|
ForcePicker bool
|
||||||
|
Policy *LaunchPolicy
|
||||||
|
}
|
||||||
|
|
||||||
|
// LaunchConfirmMode controls confirmation behavior across launch flows.
|
||||||
|
type LaunchConfirmMode int
|
||||||
|
|
||||||
|
const (
|
||||||
|
// LaunchConfirmPrompt prompts the user for confirmation.
|
||||||
|
LaunchConfirmPrompt LaunchConfirmMode = iota
|
||||||
|
// LaunchConfirmAutoApprove skips prompts and treats confirmation as accepted.
|
||||||
|
LaunchConfirmAutoApprove
|
||||||
|
// LaunchConfirmRequireYes rejects confirmation requests with a --yes hint.
|
||||||
|
LaunchConfirmRequireYes
|
||||||
|
)
|
||||||
|
|
||||||
|
// LaunchMissingModelMode controls local missing-model handling in launch flows.
|
||||||
|
type LaunchMissingModelMode int
|
||||||
|
|
||||||
|
const (
|
||||||
|
// LaunchMissingModelPromptToPull prompts to pull a missing local model.
|
||||||
|
LaunchMissingModelPromptToPull LaunchMissingModelMode = iota
|
||||||
|
// LaunchMissingModelAutoPull pulls a missing local model without prompting.
|
||||||
|
LaunchMissingModelAutoPull
|
||||||
|
// LaunchMissingModelFail fails immediately when a local model is missing.
|
||||||
|
LaunchMissingModelFail
|
||||||
|
)
|
||||||
|
|
||||||
|
// LaunchPolicy controls launch behavior that may vary by caller context.
|
||||||
|
type LaunchPolicy struct {
|
||||||
|
Confirm LaunchConfirmMode
|
||||||
|
MissingModel LaunchMissingModelMode
|
||||||
|
}
|
||||||
|
|
||||||
|
func defaultLaunchPolicy(interactive bool, yes bool) LaunchPolicy {
|
||||||
|
policy := LaunchPolicy{
|
||||||
|
Confirm: LaunchConfirmPrompt,
|
||||||
|
MissingModel: LaunchMissingModelPromptToPull,
|
||||||
|
}
|
||||||
|
switch {
|
||||||
|
case yes:
|
||||||
|
// if yes flag is set, auto approve and auto pull
|
||||||
|
policy.Confirm = LaunchConfirmAutoApprove
|
||||||
|
policy.MissingModel = LaunchMissingModelAutoPull
|
||||||
|
case !interactive:
|
||||||
|
// otherwise make sure to stop when needed
|
||||||
|
policy.Confirm = LaunchConfirmRequireYes
|
||||||
|
policy.MissingModel = LaunchMissingModelFail
|
||||||
|
}
|
||||||
|
return policy
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p LaunchPolicy) confirmPolicy() launchConfirmPolicy {
|
||||||
|
switch p.Confirm {
|
||||||
|
case LaunchConfirmAutoApprove:
|
||||||
|
return launchConfirmPolicy{yes: true}
|
||||||
|
case LaunchConfirmRequireYes:
|
||||||
|
return launchConfirmPolicy{requireYesMessage: true}
|
||||||
|
default:
|
||||||
|
return launchConfirmPolicy{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p LaunchPolicy) missingModelPolicy() missingModelPolicy {
|
||||||
|
switch p.MissingModel {
|
||||||
|
case LaunchMissingModelAutoPull:
|
||||||
|
return missingModelAutoPull
|
||||||
|
case LaunchMissingModelFail:
|
||||||
|
return missingModelFail
|
||||||
|
default:
|
||||||
|
return missingModelPromptPull
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// IntegrationLaunchRequest controls the canonical integration launcher flow.
|
||||||
|
type IntegrationLaunchRequest struct {
|
||||||
|
Name string
|
||||||
|
ModelOverride string
|
||||||
|
ForceConfigure bool
|
||||||
|
ConfigureOnly bool
|
||||||
|
ExtraArgs []string
|
||||||
|
Policy *LaunchPolicy
|
||||||
|
}
|
||||||
|
|
||||||
|
var isInteractiveSession = func() bool {
|
||||||
|
return term.IsTerminal(int(os.Stdin.Fd())) && term.IsTerminal(int(os.Stdout.Fd()))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Runner executes a model with an integration.
|
||||||
|
type Runner interface {
|
||||||
|
Run(model string, args []string) error
|
||||||
|
String() string
|
||||||
|
}
|
||||||
|
|
||||||
|
// Editor can edit config files for integrations that support model configuration.
|
||||||
|
type Editor interface {
|
||||||
|
Paths() []string
|
||||||
|
Edit(models []string) error
|
||||||
|
Models() []string
|
||||||
|
}
|
||||||
|
|
||||||
|
type modelInfo struct {
|
||||||
|
Name string
|
||||||
|
Remote bool
|
||||||
|
ToolCapable bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// ModelInfo re-exports launcher model inventory details for callers.
|
||||||
|
type ModelInfo = modelInfo
|
||||||
|
|
||||||
|
// ModelItem represents a model for selection UIs.
|
||||||
|
type ModelItem struct {
|
||||||
|
Name string
|
||||||
|
Description string
|
||||||
|
Recommended bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// LaunchCmd returns the cobra command for launching integrations.
|
||||||
|
// The runTUI callback is called when the root launcher UI should be shown.
|
||||||
|
func LaunchCmd(checkServerHeartbeat func(cmd *cobra.Command, args []string) error, runTUI func(cmd *cobra.Command)) *cobra.Command {
|
||||||
|
var modelFlag string
|
||||||
|
var configFlag bool
|
||||||
|
var yesFlag bool
|
||||||
|
|
||||||
|
cmd := &cobra.Command{
|
||||||
|
Use: "launch [INTEGRATION] [-- [EXTRA_ARGS...]]",
|
||||||
|
Short: "Launch the Ollama menu or an integration",
|
||||||
|
Long: `Launch the Ollama interactive menu, or directly launch a specific integration.
|
||||||
|
|
||||||
|
Without arguments, this is equivalent to running 'ollama' directly.
|
||||||
|
Flags and extra arguments require an integration name.
|
||||||
|
|
||||||
|
Supported integrations:
|
||||||
|
claude Claude Code
|
||||||
|
cline Cline
|
||||||
|
codex Codex
|
||||||
|
droid Droid
|
||||||
|
opencode OpenCode
|
||||||
|
openclaw OpenClaw (aliases: clawdbot, moltbot)
|
||||||
|
pi Pi
|
||||||
|
vscode VS Code (aliases: code)
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
ollama launch
|
||||||
|
ollama launch claude
|
||||||
|
ollama launch claude --model <model>
|
||||||
|
ollama launch droid --config (does not auto-launch)
|
||||||
|
ollama launch codex -- -p myprofile (pass extra args to integration)
|
||||||
|
ollama launch codex -- --sandbox workspace-write`,
|
||||||
|
Args: cobra.ArbitraryArgs,
|
||||||
|
PreRunE: checkServerHeartbeat,
|
||||||
|
RunE: func(cmd *cobra.Command, args []string) error {
|
||||||
|
policy := defaultLaunchPolicy(isInteractiveSession(), yesFlag)
|
||||||
|
// reset when done to make sure state doens't leak between launches
|
||||||
|
restoreConfirmPolicy := withLaunchConfirmPolicy(policy.confirmPolicy())
|
||||||
|
defer restoreConfirmPolicy()
|
||||||
|
|
||||||
|
var name string
|
||||||
|
var passArgs []string
|
||||||
|
dashIdx := cmd.ArgsLenAtDash()
|
||||||
|
|
||||||
|
if dashIdx == -1 {
|
||||||
|
if len(args) > 1 {
|
||||||
|
return fmt.Errorf("unexpected arguments: %v\nUse '--' to pass extra arguments to the integration", args[1:])
|
||||||
|
}
|
||||||
|
if len(args) == 1 {
|
||||||
|
name = args[0]
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if dashIdx > 1 {
|
||||||
|
return fmt.Errorf("expected at most 1 integration name before '--', got %d", dashIdx)
|
||||||
|
}
|
||||||
|
if dashIdx == 1 {
|
||||||
|
name = args[0]
|
||||||
|
}
|
||||||
|
passArgs = args[dashIdx:]
|
||||||
|
}
|
||||||
|
|
||||||
|
if name == "" {
|
||||||
|
if cmd.Flags().Changed("model") || cmd.Flags().Changed("config") || cmd.Flags().Changed("yes") || len(passArgs) > 0 {
|
||||||
|
return fmt.Errorf("flags and extra args require an integration name, for example: 'ollama launch claude --model qwen3.5'")
|
||||||
|
}
|
||||||
|
runTUI(cmd)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if modelFlag != "" && isCloudModelName(modelFlag) {
|
||||||
|
if client, err := api.ClientFromEnvironment(); err == nil {
|
||||||
|
if disabled, _ := cloudStatusDisabled(cmd.Context(), client); disabled {
|
||||||
|
fmt.Fprintf(os.Stderr, "Warning: ignoring --model %s because cloud is disabled\n", modelFlag)
|
||||||
|
modelFlag = ""
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
headlessYes := yesFlag && !isInteractiveSession()
|
||||||
|
err := LaunchIntegration(cmd.Context(), IntegrationLaunchRequest{
|
||||||
|
Name: name,
|
||||||
|
ModelOverride: modelFlag,
|
||||||
|
ForceConfigure: configFlag || (modelFlag == "" && !headlessYes),
|
||||||
|
ConfigureOnly: configFlag,
|
||||||
|
ExtraArgs: passArgs,
|
||||||
|
Policy: &policy,
|
||||||
|
})
|
||||||
|
if errors.Is(err, ErrCancelled) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd.Flags().StringVar(&modelFlag, "model", "", "Model to use")
|
||||||
|
cmd.Flags().BoolVar(&configFlag, "config", false, "Configure without launching")
|
||||||
|
cmd.Flags().BoolVarP(&yesFlag, "yes", "y", false, "Automatically answer yes to confirmation prompts")
|
||||||
|
return cmd
|
||||||
|
}
|
||||||
|
|
||||||
|
type launcherClient struct {
|
||||||
|
apiClient *api.Client
|
||||||
|
modelInventory []ModelInfo
|
||||||
|
inventoryLoaded bool
|
||||||
|
policy LaunchPolicy
|
||||||
|
}
|
||||||
|
|
||||||
|
func newLauncherClient(policy LaunchPolicy) (*launcherClient, error) {
|
||||||
|
apiClient, err := api.ClientFromEnvironment()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &launcherClient{
|
||||||
|
apiClient: apiClient,
|
||||||
|
policy: policy,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// BuildLauncherState returns the launch-owned root launcher menu snapshot.
|
||||||
|
func BuildLauncherState(ctx context.Context) (*LauncherState, error) {
|
||||||
|
launchClient, err := newLauncherClient(defaultLaunchPolicy(isInteractiveSession(), false))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return launchClient.buildLauncherState(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResolveRunModel returns the model that should be used for interactive chat.
|
||||||
|
func ResolveRunModel(ctx context.Context, req RunModelRequest) (string, error) {
|
||||||
|
// Called by the launcher TUI "Run a model" action (cmd/runLauncherAction),
|
||||||
|
// which resolves models separately from LaunchIntegration. Callers can pass
|
||||||
|
// Policy directly; otherwise we fall back to ambient --yes/session defaults.
|
||||||
|
policy := defaultLaunchPolicy(isInteractiveSession(), currentLaunchConfirmPolicy.yes)
|
||||||
|
if req.Policy != nil {
|
||||||
|
policy = *req.Policy
|
||||||
|
}
|
||||||
|
|
||||||
|
launchClient, err := newLauncherClient(policy)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return launchClient.resolveRunModel(ctx, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
// LaunchIntegration runs the canonical launcher flow for one integration.
|
||||||
|
func LaunchIntegration(ctx context.Context, req IntegrationLaunchRequest) error {
|
||||||
|
name, runner, err := LookupIntegration(req.Name)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if !req.ConfigureOnly {
|
||||||
|
if err := EnsureIntegrationInstalled(name, runner); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var policy LaunchPolicy
|
||||||
|
// TUI does not set a policy, whereas ollama launch <app> does as it can have flags which change the behavior
|
||||||
|
if req.Policy == nil {
|
||||||
|
policy = defaultLaunchPolicy(isInteractiveSession(), false)
|
||||||
|
} else {
|
||||||
|
policy = *req.Policy
|
||||||
|
}
|
||||||
|
|
||||||
|
launchClient, err := newLauncherClient(policy)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
saved, _ := loadStoredIntegrationConfig(name)
|
||||||
|
// In headless --yes mode we cannot prompt, so require an explicit --model.
|
||||||
|
if policy.Confirm == LaunchConfirmAutoApprove && !isInteractiveSession() && req.ModelOverride == "" {
|
||||||
|
return fmt.Errorf("headless --yes launch for %s requires --model <model>", name)
|
||||||
|
}
|
||||||
|
|
||||||
|
if editor, ok := runner.(Editor); ok {
|
||||||
|
return launchClient.launchEditorIntegration(ctx, name, runner, editor, saved, req)
|
||||||
|
}
|
||||||
|
return launchClient.launchSingleIntegration(ctx, name, runner, saved, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *launcherClient) buildLauncherState(ctx context.Context) (*LauncherState, error) {
|
||||||
|
_ = c.loadModelInventoryOnce(ctx)
|
||||||
|
|
||||||
|
state := &LauncherState{
|
||||||
|
LastSelection: config.LastSelection(),
|
||||||
|
RunModel: config.LastModel(),
|
||||||
|
Integrations: make(map[string]LauncherIntegrationState),
|
||||||
|
}
|
||||||
|
runModelUsable, err := c.savedModelUsable(ctx, state.RunModel)
|
||||||
|
if err != nil {
|
||||||
|
runModelUsable = false
|
||||||
|
}
|
||||||
|
state.RunModelUsable = runModelUsable
|
||||||
|
|
||||||
|
for _, info := range ListIntegrationInfos() {
|
||||||
|
integrationState, err := c.buildLauncherIntegrationState(ctx, info)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
state.Integrations[info.Name] = integrationState
|
||||||
|
}
|
||||||
|
|
||||||
|
return state, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *launcherClient) buildLauncherIntegrationState(ctx context.Context, info IntegrationInfo) (LauncherIntegrationState, error) {
|
||||||
|
integration, err := integrationFor(info.Name)
|
||||||
|
if err != nil {
|
||||||
|
return LauncherIntegrationState{}, err
|
||||||
|
}
|
||||||
|
currentModel, usable, err := c.launcherModelState(ctx, info.Name, integration.editor)
|
||||||
|
if err != nil {
|
||||||
|
return LauncherIntegrationState{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return LauncherIntegrationState{
|
||||||
|
Name: info.Name,
|
||||||
|
DisplayName: info.DisplayName,
|
||||||
|
Description: info.Description,
|
||||||
|
Installed: integration.installed,
|
||||||
|
AutoInstallable: integration.autoInstallable,
|
||||||
|
Selectable: integration.installed || integration.autoInstallable,
|
||||||
|
Changeable: integration.installed || integration.autoInstallable,
|
||||||
|
CurrentModel: currentModel,
|
||||||
|
ModelUsable: usable,
|
||||||
|
InstallHint: integration.installHint,
|
||||||
|
Editor: integration.editor,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *launcherClient) launcherModelState(ctx context.Context, name string, isEditor bool) (string, bool, error) {
|
||||||
|
cfg, loadErr := loadStoredIntegrationConfig(name)
|
||||||
|
hasModels := loadErr == nil && len(cfg.Models) > 0
|
||||||
|
if !hasModels {
|
||||||
|
return "", false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if isEditor {
|
||||||
|
filtered := c.filterDisabledCloudModels(ctx, cfg.Models)
|
||||||
|
if len(filtered) > 0 {
|
||||||
|
return filtered[0], true, nil
|
||||||
|
}
|
||||||
|
return cfg.Models[0], false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
model := cfg.Models[0]
|
||||||
|
usable, usableErr := c.savedModelUsable(ctx, model)
|
||||||
|
return model, usableErr == nil && usable, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// resolveRunModel picks the model for a run-style launch. Preference order:
// the last used model (in headless auto-approve mode, taken without a
// usability probe so no prompt can fire; interactively, only when still
// usable), then the interactive picker. The picked model is persisted.
func (c *launcherClient) resolveRunModel(ctx context.Context, req RunModelRequest) (string, error) {
	current := config.LastModel()
	// Headless --yes: reuse the last model directly; ensureModelsReady
	// handles a missing model per the client's policy.
	if !req.ForcePicker && current != "" && c.policy.Confirm == LaunchConfirmAutoApprove && !isInteractiveSession() {
		if err := c.ensureModelsReady(ctx, []string{current}); err != nil {
			return "", err
		}
		fmt.Fprintf(os.Stderr, "Headless mode: auto-selected last used model %q\n", current)
		return current, nil
	}

	// Interactive path: keep the last model only if it is still usable.
	if !req.ForcePicker {
		usable, err := c.savedModelUsable(ctx, current)
		if err != nil {
			return "", err
		}
		if usable {
			if err := c.ensureModelsReady(ctx, []string{current}); err != nil {
				return "", err
			}
			return current, nil
		}
	}

	model, err := c.selectSingleModelWithSelector(ctx, "Select model to run:", current, DefaultSingleSelector)
	if err != nil {
		return "", err
	}
	// Persist the choice so the next launch can skip the picker.
	if model != current {
		if err := config.SetLastModel(model); err != nil {
			return "", err
		}
	}
	return model, nil
}
|
||||||
|
|
||||||
|
// launchSingleIntegration configures (when needed) and launches a
// single-model integration, persisting the chosen model when it changed.
func (c *launcherClient) launchSingleIntegration(ctx context.Context, name string, runner Runner, saved *config.IntegrationConfig, req IntegrationLaunchRequest) error {
	current := primaryModelFromConfig(saved)
	target := req.ModelOverride
	needsConfigure := req.ForceConfigure

	// Without an explicit override, fall back to the saved model and require
	// reconfiguration when it is missing or unusable.
	if target == "" {
		target = current
		usable, err := c.savedModelUsable(ctx, target)
		if err != nil {
			return err
		}
		if !usable {
			needsConfigure = true
		}
	}

	if needsConfigure {
		selected, err := c.selectSingleModelWithSelector(ctx, fmt.Sprintf("Select model for %s:", runner), target, DefaultSingleSelector)
		if err != nil {
			return err
		}
		target = selected
	} else if err := c.ensureModelsReady(ctx, []string{target}); err != nil {
		return err
	}

	// No model resolved: treat as a no-op rather than an error.
	if target == "" {
		return nil
	}

	// Surface context-length concerns (see lowContextLength) before launch.
	if err := lowContextLength(ctx, c.apiClient, []string{target}); err != nil {
		return err
	}

	// Only write config when the selection actually changed.
	if target != current {
		if err := config.SaveIntegration(name, []string{target}); err != nil {
			return fmt.Errorf("failed to save: %w", err)
		}
	}

	return launchAfterConfiguration(name, runner, target, req)
}
|
||||||
|
|
||||||
|
// launchEditorIntegration configures (when needed) and launches an editor
// integration, which supports multiple models. Editor config files are only
// rewritten when the model set was (re)configured or explicitly overridden.
func (c *launcherClient) launchEditorIntegration(ctx context.Context, name string, runner Runner, editor Editor, saved *config.IntegrationConfig, req IntegrationLaunchRequest) error {
	models, needsConfigure := c.resolveEditorLaunchModels(ctx, saved, req)

	if needsConfigure {
		selected, err := c.selectMultiModelsForIntegration(ctx, runner, models)
		if err != nil {
			return err
		}
		models = selected
	} else if err := c.ensureModelsReady(ctx, models); err != nil {
		return err
	}

	// Nothing selected: treat as a no-op rather than an error.
	if len(models) == 0 {
		return nil
	}

	// Surface context-length concerns (see lowContextLength) before launch.
	if err := lowContextLength(ctx, c.apiClient, models); err != nil {
		return err
	}

	// Only touch the editor's config files when the selection changed or was
	// explicitly overridden; prepareEditorIntegration prompts before editing.
	if needsConfigure || req.ModelOverride != "" {
		if err := prepareEditorIntegration(name, runner, editor, models); err != nil {
			return err
		}
	}

	// The first model is the primary one handed to the runner.
	return launchAfterConfiguration(name, runner, models[0], req)
}
|
||||||
|
|
||||||
|
func (c *launcherClient) selectSingleModelWithSelector(ctx context.Context, title, current string, selector SingleSelector) (string, error) {
|
||||||
|
if selector == nil {
|
||||||
|
return "", fmt.Errorf("no selector configured")
|
||||||
|
}
|
||||||
|
|
||||||
|
items, _, err := c.loadSelectableModels(ctx, nil, current, "no models available, run 'ollama pull <model>' first")
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
selected, err := selector(title, items, current)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
if err := c.ensureModelsReady(ctx, []string{selected}); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return selected, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// selectMultiModelsForIntegration presents a multi-select of models for an
// integration, pre-checking any previously configured ones, and ensures the
// final selection is downloaded and authorized.
func (c *launcherClient) selectMultiModelsForIntegration(ctx context.Context, runner Runner, preChecked []string) ([]string, error) {
	if DefaultMultiSelector == nil {
		return nil, fmt.Errorf("no selector configured")
	}

	current := firstModel(preChecked)

	items, orderedChecked, err := c.loadSelectableModels(ctx, preChecked, current, "no models available")
	if err != nil {
		return nil, err
	}
	if len(preChecked) > 0 {
		// Keep list order stable in multi-select even when there are existing checks.
		// checked/default state still comes from orderedChecked.
		// (A second load without preChecked yields the unbiased ordering.)
		stableItems, _, stableErr := c.loadSelectableModels(ctx, nil, current, "no models available")
		if stableErr != nil {
			return nil, stableErr
		}
		items = stableItems
	}

	selected, err := DefaultMultiSelector(fmt.Sprintf("Select models for %s:", runner), items, orderedChecked)
	if err != nil {
		return nil, err
	}
	if err := c.ensureModelsReady(ctx, selected); err != nil {
		return nil, err
	}
	return selected, nil
}
|
||||||
|
|
||||||
|
func (c *launcherClient) loadSelectableModels(ctx context.Context, preChecked []string, current, emptyMessage string) ([]ModelItem, []string, error) {
|
||||||
|
if err := c.loadModelInventoryOnce(ctx); err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
cloudDisabled, _ := cloudStatusDisabled(ctx, c.apiClient)
|
||||||
|
items, orderedChecked, _, _ := buildModelList(c.modelInventory, preChecked, current)
|
||||||
|
if cloudDisabled {
|
||||||
|
items = filterCloudItems(items)
|
||||||
|
orderedChecked = c.filterDisabledCloudModels(ctx, orderedChecked)
|
||||||
|
}
|
||||||
|
if len(items) == 0 {
|
||||||
|
return nil, nil, errors.New(emptyMessage)
|
||||||
|
}
|
||||||
|
return items, orderedChecked, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *launcherClient) ensureModelsReady(ctx context.Context, models []string) error {
|
||||||
|
var deduped []string
|
||||||
|
seen := make(map[string]bool, len(models))
|
||||||
|
for _, model := range models {
|
||||||
|
if model == "" || seen[model] {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
seen[model] = true
|
||||||
|
deduped = append(deduped, model)
|
||||||
|
}
|
||||||
|
models = deduped
|
||||||
|
if len(models) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
cloudModels := make(map[string]bool, len(models))
|
||||||
|
for _, model := range models {
|
||||||
|
isCloudModel := isCloudModelName(model)
|
||||||
|
if isCloudModel {
|
||||||
|
cloudModels[model] = true
|
||||||
|
}
|
||||||
|
if err := showOrPullWithPolicy(ctx, c.apiClient, model, c.policy.missingModelPolicy(), isCloudModel); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ensureAuth(ctx, c.apiClient, cloudModels, models)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *launcherClient) resolveEditorLaunchModels(ctx context.Context, saved *config.IntegrationConfig, req IntegrationLaunchRequest) ([]string, bool) {
|
||||||
|
if req.ForceConfigure {
|
||||||
|
return editorPreCheckedModels(saved, req.ModelOverride), true
|
||||||
|
}
|
||||||
|
|
||||||
|
if req.ModelOverride != "" {
|
||||||
|
models := append([]string{req.ModelOverride}, additionalSavedModels(saved, req.ModelOverride)...)
|
||||||
|
models = c.filterDisabledCloudModels(ctx, models)
|
||||||
|
return models, len(models) == 0
|
||||||
|
}
|
||||||
|
|
||||||
|
if saved == nil || len(saved.Models) == 0 {
|
||||||
|
return nil, true
|
||||||
|
}
|
||||||
|
|
||||||
|
models := c.filterDisabledCloudModels(ctx, saved.Models)
|
||||||
|
return models, len(models) == 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *launcherClient) filterDisabledCloudModels(ctx context.Context, models []string) []string {
|
||||||
|
// if connection cannot be established or there is a 404, cloud models will continue to be displayed
|
||||||
|
cloudDisabled, _ := cloudStatusDisabled(ctx, c.apiClient)
|
||||||
|
if !cloudDisabled {
|
||||||
|
return append([]string(nil), models...)
|
||||||
|
}
|
||||||
|
|
||||||
|
filtered := make([]string, 0, len(models))
|
||||||
|
for _, model := range models {
|
||||||
|
if !isCloudModelName(model) {
|
||||||
|
filtered = append(filtered, model)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return filtered
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *launcherClient) savedModelUsable(ctx context.Context, name string) (bool, error) {
|
||||||
|
if err := c.loadModelInventoryOnce(ctx); err != nil {
|
||||||
|
return c.showBasedModelUsable(ctx, name)
|
||||||
|
}
|
||||||
|
return c.singleModelUsable(ctx, name), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *launcherClient) showBasedModelUsable(ctx context.Context, name string) (bool, error) {
|
||||||
|
if name == "" {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
info, err := c.apiClient.Show(ctx, &api.ShowRequest{Model: name})
|
||||||
|
if err != nil {
|
||||||
|
var statusErr api.StatusError
|
||||||
|
if errors.As(err, &statusErr) && statusErr.StatusCode == http.StatusNotFound {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if isCloudModelName(name) || info.RemoteModel != "" {
|
||||||
|
cloudDisabled, _ := cloudStatusDisabled(ctx, c.apiClient)
|
||||||
|
|
||||||
|
return !cloudDisabled, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *launcherClient) singleModelUsable(ctx context.Context, name string) bool {
|
||||||
|
if name == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if isCloudModelName(name) {
|
||||||
|
cloudDisabled, _ := cloudStatusDisabled(ctx, c.apiClient)
|
||||||
|
return !cloudDisabled
|
||||||
|
}
|
||||||
|
return c.hasLocalModel(name)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *launcherClient) hasLocalModel(name string) bool {
|
||||||
|
for _, model := range c.modelInventory {
|
||||||
|
if model.Remote {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if model.Name == name || strings.HasPrefix(model.Name, name+":") {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *launcherClient) loadModelInventoryOnce(ctx context.Context) error {
|
||||||
|
if c.inventoryLoaded {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := c.apiClient.List(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
c.modelInventory = c.modelInventory[:0]
|
||||||
|
for _, model := range resp.Models {
|
||||||
|
c.modelInventory = append(c.modelInventory, ModelInfo{
|
||||||
|
Name: model.Name,
|
||||||
|
Remote: model.RemoteModel != "",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
cloudDisabled, _ := cloudStatusDisabled(ctx, c.apiClient)
|
||||||
|
if cloudDisabled {
|
||||||
|
c.modelInventory = filterCloudModels(c.modelInventory)
|
||||||
|
}
|
||||||
|
c.inventoryLoaded = true
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// runIntegration announces the launch on stderr and hands control to the
// runner with the chosen model and passthrough args.
func runIntegration(runner Runner, modelName string, args []string) error {
	fmt.Fprintf(os.Stderr, "\nLaunching %s with %s...\n", runner, modelName)
	return runner.Run(modelName, args)
}
|
||||||
|
|
||||||
|
func launchAfterConfiguration(name string, runner Runner, model string, req IntegrationLaunchRequest) error {
|
||||||
|
if req.ConfigureOnly {
|
||||||
|
launch, err := ConfirmPrompt(fmt.Sprintf("Launch %s now?", runner))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if !launch {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err := EnsureIntegrationInstalled(name, runner); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return runIntegration(runner, model, req.ExtraArgs)
|
||||||
|
}
|
||||||
|
|
||||||
|
// loadStoredIntegrationConfig loads the saved config for an integration,
// falling back to (and migrating) configs stored under legacy alias names.
// When neither the canonical name nor any alias has a config, the original
// os.ErrNotExist-carrying error is returned so callers keep "not found"
// semantics.
func loadStoredIntegrationConfig(name string) (*config.IntegrationConfig, error) {
	cfg, err := config.LoadIntegration(name)
	if err == nil {
		return cfg, nil
	}
	if !errors.Is(err, os.ErrNotExist) {
		return nil, err
	}

	spec, specErr := LookupIntegrationSpec(name)
	if specErr != nil {
		// Deliberately return the original not-exist error, not specErr,
		// so callers can still detect the missing-config case.
		return nil, err
	}

	for _, alias := range spec.Aliases {
		legacy, legacyErr := config.LoadIntegration(alias)
		if legacyErr == nil {
			// Copy the legacy config under the canonical name, then prefer
			// the migrated copy; fall back to the legacy one if re-load fails.
			migrateLegacyIntegrationConfig(spec.Name, legacy)
			if migrated, migratedErr := config.LoadIntegration(spec.Name); migratedErr == nil {
				return migrated, nil
			}
			return legacy, nil
		}
		// A missing alias config is expected; any other error is fatal.
		if legacyErr != nil && !errors.Is(legacyErr, os.ErrNotExist) {
			return nil, legacyErr
		}
	}

	return nil, err
}
|
||||||
|
|
||||||
|
func migrateLegacyIntegrationConfig(canonical string, legacy *config.IntegrationConfig) {
|
||||||
|
if legacy == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
_ = config.SaveIntegration(canonical, append([]string(nil), legacy.Models...))
|
||||||
|
if len(legacy.Aliases) > 0 {
|
||||||
|
_ = config.SaveAliases(canonical, cloneAliases(legacy.Aliases))
|
||||||
|
}
|
||||||
|
if legacy.Onboarded {
|
||||||
|
_ = config.MarkIntegrationOnboarded(canonical)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func primaryModelFromConfig(cfg *config.IntegrationConfig) string {
|
||||||
|
if cfg == nil || len(cfg.Models) == 0 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return cfg.Models[0]
|
||||||
|
}
|
||||||
|
|
||||||
|
// cloneAliases returns a shallow copy of aliases. A nil or empty input
// yields a fresh, non-nil empty map.
func cloneAliases(aliases map[string]string) map[string]string {
	out := make(map[string]string, len(aliases))
	for k, v := range aliases {
		out[k] = v
	}
	return out
}
|
||||||
|
|
||||||
|
// firstModel returns the first entry of models, or "" for an empty slice.
func firstModel(models []string) string {
	if len(models) > 0 {
		return models[0]
	}
	return ""
}
|
||||||
|
|
||||||
|
func editorPreCheckedModels(saved *config.IntegrationConfig, override string) []string {
|
||||||
|
if override == "" {
|
||||||
|
if saved == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return append([]string(nil), saved.Models...)
|
||||||
|
}
|
||||||
|
return append([]string{override}, additionalSavedModels(saved, override)...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func additionalSavedModels(saved *config.IntegrationConfig, exclude string) []string {
|
||||||
|
if saved == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var models []string
|
||||||
|
for _, model := range saved.Models {
|
||||||
|
if model != exclude {
|
||||||
|
models = append(models, model)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return models
|
||||||
|
}
|
||||||
1729
cmd/launch/launch_test.go
Normal file
566
cmd/launch/models.go
Normal file
@@ -0,0 +1,566 @@
|
|||||||
|
package launch
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
"os/exec"
|
||||||
|
"runtime"
|
||||||
|
"slices"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/api"
|
||||||
|
"github.com/ollama/ollama/cmd/config"
|
||||||
|
"github.com/ollama/ollama/cmd/internal/fileutil"
|
||||||
|
internalcloud "github.com/ollama/ollama/internal/cloud"
|
||||||
|
"github.com/ollama/ollama/internal/modelref"
|
||||||
|
"github.com/ollama/ollama/progress"
|
||||||
|
)
|
||||||
|
|
||||||
|
// recommendedModels is the curated list surfaced in model pickers and merged
// with locally available models by buildModelList. Entries with a ":cloud"
// suffix run remotely; the rest run locally.
var recommendedModels = []ModelItem{
	{Name: "kimi-k2.5:cloud", Description: "Multimodal reasoning with subagents", Recommended: true},
	{Name: "qwen3.5:cloud", Description: "Reasoning, coding, and agentic tool use with vision", Recommended: true},
	{Name: "glm-5:cloud", Description: "Reasoning and code generation", Recommended: true},
	{Name: "minimax-m2.7:cloud", Description: "Fast, efficient coding and real-world productivity", Recommended: true},
	{Name: "glm-4.7-flash", Description: "Reasoning and code generation locally", Recommended: true},
	{Name: "qwen3.5", Description: "Reasoning, coding, and visual understanding locally", Recommended: true},
}
|
||||||
|
|
||||||
|
// recommendedVRAM maps local recommended models to the approximate VRAM they
// need; buildModelList appends these hints to undownloaded picker entries.
var recommendedVRAM = map[string]string{
	"glm-4.7-flash": "~25GB",
	"qwen3.5":       "~11GB",
}
|
||||||
|
|
||||||
|
// cloudModelLimit holds context and output token limits for a cloud model.
type cloudModelLimit struct {
	Context int // maximum context window, in tokens
	Output  int // maximum output length, in tokens
}
|
||||||
|
|
||||||
|
// cloudModelLimits maps cloud model base names to their token limits.
// Keys are the base name with any cloud source tag stripped (see
// lookupCloudModelLimit).
// TODO(parthsareen): grab context/output limits from model info instead of hardcoding
var cloudModelLimits = map[string]cloudModelLimit{
	"minimax-m2.7":        {Context: 204_800, Output: 128_000},
	"cogito-2.1:671b":     {Context: 163_840, Output: 65_536},
	"deepseek-v3.1:671b":  {Context: 163_840, Output: 163_840},
	"deepseek-v3.2":       {Context: 163_840, Output: 65_536},
	"glm-4.6":             {Context: 202_752, Output: 131_072},
	"glm-4.7":             {Context: 202_752, Output: 131_072},
	"glm-5":               {Context: 202_752, Output: 131_072},
	"gpt-oss:120b":        {Context: 131_072, Output: 131_072},
	"gpt-oss:20b":         {Context: 131_072, Output: 131_072},
	"kimi-k2:1t":          {Context: 262_144, Output: 262_144},
	"kimi-k2.5":           {Context: 262_144, Output: 262_144},
	"kimi-k2-thinking":    {Context: 262_144, Output: 262_144},
	"nemotron-3-nano:30b": {Context: 1_048_576, Output: 131_072},
	"qwen3-coder:480b":    {Context: 262_144, Output: 65_536},
	"qwen3-coder-next":    {Context: 262_144, Output: 32_768},
	"qwen3-next:80b":      {Context: 262_144, Output: 32_768},
	"qwen3.5":             {Context: 262_144, Output: 32_768},
}
|
||||||
|
|
||||||
|
// lookupCloudModelLimit returns the token limits for a cloud model.
|
||||||
|
// It normalizes explicit cloud source suffixes before checking the shared limit map.
|
||||||
|
func lookupCloudModelLimit(name string) (cloudModelLimit, bool) {
|
||||||
|
base, stripped := modelref.StripCloudSourceTag(name)
|
||||||
|
if stripped {
|
||||||
|
if l, ok := cloudModelLimits[base]; ok {
|
||||||
|
return l, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return cloudModelLimit{}, false
|
||||||
|
}
|
||||||
|
|
||||||
|
// missingModelPolicy controls how model-not-found errors should be handled.
// It only applies to local models; cloud models are never pulled (see
// showOrPullWithPolicy).
type missingModelPolicy int

const (
	// missingModelPromptPull prompts the user to download missing local models.
	// This is the zero value and the default behavior.
	missingModelPromptPull missingModelPolicy = iota
	// missingModelAutoPull downloads missing local models without prompting.
	missingModelAutoPull
	// missingModelFail returns an error for missing local models without prompting.
	missingModelFail
)
|
||||||
|
|
||||||
|
// OpenBrowser opens the URL in the user's browser.
// Launch failures and unsupported platforms are silently ignored: this is a
// best-effort convenience and the caller always prints the URL as well.
func OpenBrowser(url string) {
	switch runtime.GOOS {
	case "darwin":
		_ = exec.Command("open", url).Start()
	case "linux":
		// Skip on headless systems where no display server is available
		if os.Getenv("DISPLAY") == "" && os.Getenv("WAYLAND_DISPLAY") == "" {
			return
		}
		_ = exec.Command("xdg-open", url).Start()
	case "windows":
		_ = exec.Command("rundll32", "url.dll,FileProtocolHandler", url).Start()
	}
}
|
||||||
|
|
||||||
|
// ensureAuth ensures the user is signed in before cloud-backed models run.
// It no-ops when no selected model is cloud-backed or the user is already
// signed in; otherwise it drives a sign-in flow (pluggable UI when
// DefaultSignIn is set, CLI prompt + browser + polling spinner otherwise).
func ensureAuth(ctx context.Context, client *api.Client, cloudModels map[string]bool, selected []string) error {
	// Only the selected models that are cloud-backed require auth.
	var selectedCloudModels []string
	for _, m := range selected {
		if cloudModels[m] {
			selectedCloudModels = append(selectedCloudModels, m)
		}
	}
	if len(selectedCloudModels) == 0 {
		return nil
	}
	// When cloud is known to be disabled, fail fast instead of prompting.
	if disabled, known := cloudStatusDisabled(ctx, client); known && disabled {
		return errors.New(internalcloud.DisabledError("remote inference is unavailable"))
	}

	// Already signed in: nothing to do.
	user, err := client.Whoami(ctx)
	if err == nil && user != nil && user.Name != "" {
		return nil
	}

	// Only an AuthorizationError carrying a sign-in URL is recoverable
	// interactively; anything else propagates as-is.
	var aErr api.AuthorizationError
	if !errors.As(err, &aErr) || aErr.SigninURL == "" {
		return err
	}

	modelList := strings.Join(selectedCloudModels, ", ")

	// Prefer the pluggable sign-in UI (e.g. the TUI) when one is installed.
	if DefaultSignIn != nil {
		_, err := DefaultSignIn(modelList, aErr.SigninURL)
		if errors.Is(err, ErrCancelled) {
			return ErrCancelled
		}
		if err != nil {
			return fmt.Errorf("%s requires sign in", modelList)
		}
		return nil
	}

	// CLI fallback: confirm, open the browser, then poll Whoami until the
	// sign-in completes or the context is cancelled.
	yes, err := ConfirmPrompt(fmt.Sprintf("sign in to use %s?", modelList))
	if errors.Is(err, ErrCancelled) {
		return ErrCancelled
	}
	if err != nil {
		return err
	}
	if !yes {
		return ErrCancelled
	}

	fmt.Fprintf(os.Stderr, "\nTo sign in, navigate to:\n  %s\n\n", aErr.SigninURL)
	OpenBrowser(aErr.SigninURL)

	// Spinner on stderr; ESC[90m renders it dim gray, ESC[0m resets.
	spinnerFrames := []string{"|", "/", "-", "\\"}
	frame := 0
	fmt.Fprintf(os.Stderr, "\033[90mwaiting for sign in to complete... %s\033[0m", spinnerFrames[0])

	ticker := time.NewTicker(200 * time.Millisecond)
	defer ticker.Stop()

	for {
		select {
		case <-ctx.Done():
			// Clear the spinner line before giving up.
			fmt.Fprintf(os.Stderr, "\r\033[K")
			return ctx.Err()
		case <-ticker.C:
			frame++
			fmt.Fprintf(os.Stderr, "\r\033[90mwaiting for sign in to complete... %s\033[0m", spinnerFrames[frame%len(spinnerFrames)])

			// Poll the server every 10th frame (roughly every 2 seconds).
			if frame%10 == 0 {
				u, err := client.Whoami(ctx)
				if err == nil && u != nil && u.Name != "" {
					// Clear the spinner and the URL prompt line, then
					// announce the signed-in account in bold.
					fmt.Fprintf(os.Stderr, "\r\033[K\033[A\r\033[K\033[1msigned in:\033[0m %s\n", u.Name)
					return nil
				}
			}
		}
	}
}
|
||||||
|
|
||||||
|
// showOrPullWithPolicy checks if a model exists and applies the provided missing-model policy.
|
||||||
|
func showOrPullWithPolicy(ctx context.Context, client *api.Client, model string, policy missingModelPolicy, isCloudModel bool) error {
|
||||||
|
if _, err := client.Show(ctx, &api.ShowRequest{Model: model}); err == nil {
|
||||||
|
return nil
|
||||||
|
} else {
|
||||||
|
var statusErr api.StatusError
|
||||||
|
if !errors.As(err, &statusErr) || statusErr.StatusCode != http.StatusNotFound {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if isCloudModel {
|
||||||
|
if disabled, known := cloudStatusDisabled(ctx, client); known && disabled {
|
||||||
|
return errors.New(internalcloud.DisabledError("remote inference is unavailable"))
|
||||||
|
}
|
||||||
|
return fmt.Errorf("model %q not found", model)
|
||||||
|
}
|
||||||
|
|
||||||
|
switch policy {
|
||||||
|
case missingModelAutoPull:
|
||||||
|
return pullMissingModel(ctx, client, model)
|
||||||
|
case missingModelFail:
|
||||||
|
return fmt.Errorf("model %q not found; run 'ollama pull %s' first, or use --yes to auto-pull", model, model)
|
||||||
|
default:
|
||||||
|
return confirmAndPull(ctx, client, model)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func confirmAndPull(ctx context.Context, client *api.Client, model string) error {
|
||||||
|
if ok, err := ConfirmPrompt(fmt.Sprintf("Download %s?", model)); err != nil {
|
||||||
|
return err
|
||||||
|
} else if !ok {
|
||||||
|
return errCancelled
|
||||||
|
}
|
||||||
|
fmt.Fprintf(os.Stderr, "\n")
|
||||||
|
return pullMissingModel(ctx, client, model)
|
||||||
|
}
|
||||||
|
|
||||||
|
// pullMissingModel downloads a model that is not present locally, wrapping
// any failure with the model name for context.
func pullMissingModel(ctx context.Context, client *api.Client, model string) error {
	if err := pullModel(ctx, client, model, false); err != nil {
		return fmt.Errorf("failed to pull %s: %w", model, err)
	}
	return nil
}
|
||||||
|
|
||||||
|
// prepareEditorIntegration persists models and applies editor-managed config files.
|
||||||
|
func prepareEditorIntegration(name string, runner Runner, editor Editor, models []string) error {
|
||||||
|
if ok, err := confirmEditorEdit(runner, editor); err != nil {
|
||||||
|
return err
|
||||||
|
} else if !ok {
|
||||||
|
return errCancelled
|
||||||
|
}
|
||||||
|
if err := editor.Edit(models); err != nil {
|
||||||
|
return fmt.Errorf("setup failed: %w", err)
|
||||||
|
}
|
||||||
|
if err := config.SaveIntegration(name, models); err != nil {
|
||||||
|
return fmt.Errorf("failed to save: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func confirmEditorEdit(runner Runner, editor Editor) (bool, error) {
|
||||||
|
paths := editor.Paths()
|
||||||
|
if len(paths) == 0 {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Fprintf(os.Stderr, "This will modify your %s configuration:\n", runner)
|
||||||
|
for _, path := range paths {
|
||||||
|
fmt.Fprintf(os.Stderr, " %s\n", path)
|
||||||
|
}
|
||||||
|
fmt.Fprintf(os.Stderr, "Backups will be saved to %s/\n\n", fileutil.BackupDir())
|
||||||
|
|
||||||
|
return ConfirmPrompt("Proceed?")
|
||||||
|
}
|
||||||
|
|
||||||
|
// buildModelList merges existing models with recommendations for selection UIs.
// It returns the display items, the pre-checked names reordered so `current`
// leads, and lookup sets for existing and cloud models. Items are sorted
// (checked, recommended, installed, then alphabetical) only when at least
// one model already exists locally or in the cloud.
func buildModelList(existing []modelInfo, preChecked []string, current string) (items []ModelItem, orderedChecked []string, existingModels, cloudModels map[string]bool) {
	existingModels = make(map[string]bool)
	cloudModels = make(map[string]bool)
	recommended := make(map[string]bool)
	var hasLocalModel, hasCloudModel bool

	// Index the curated recommendations by name for quick lookups.
	recDesc := make(map[string]string)
	for _, rec := range recommendedModels {
		recommended[rec.Name] = true
		recDesc[rec.Name] = rec.Description
	}

	// Existing models come first. ":latest" is hidden in the display name,
	// and both the full and trimmed names count as "existing".
	// NOTE(review): cloudModels is keyed by the untrimmed name here, while
	// items carry the trimmed name — verify that remote ":latest" models
	// sort/annotate as intended.
	for _, m := range existing {
		existingModels[m.Name] = true
		if m.Remote {
			cloudModels[m.Name] = true
			hasCloudModel = true
		} else {
			hasLocalModel = true
		}
		displayName := strings.TrimSuffix(m.Name, ":latest")
		existingModels[displayName] = true
		item := ModelItem{Name: displayName, Recommended: recommended[displayName], Description: recDesc[displayName]}
		items = append(items, item)
	}

	// Append recommendations the user does not already have.
	for _, rec := range recommendedModels {
		if existingModels[rec.Name] || existingModels[rec.Name+":latest"] {
			continue
		}
		items = append(items, rec)
		if isCloudModelName(rec.Name) {
			cloudModels[rec.Name] = true
		}
	}

	checked := make(map[string]bool, len(preChecked))
	for _, n := range preChecked {
		checked[n] = true
	}

	// Resolve `current` against the item list: exact match first, then a
	// bare-name match against any tagged variant (e.g. "x" -> "x:7b").
	if current != "" {
		matchedCurrent := false
		for _, item := range items {
			if item.Name == current {
				current = item.Name
				matchedCurrent = true
				break
			}
		}
		if !matchedCurrent {
			for _, item := range items {
				if strings.HasPrefix(item.Name, current+":") {
					current = item.Name
					break
				}
			}
		}
	}

	// Move `current` to the front of the pre-checked order when it is checked.
	if checked[current] {
		preChecked = append([]string{current}, slices.DeleteFunc(preChecked, func(m string) bool { return m == current })...)
	}

	// Annotate entries that are neither installed nor cloud-backed with a
	// VRAM hint (when known) and a "(not downloaded)" marker.
	notInstalled := make(map[string]bool)
	for i := range items {
		if !existingModels[items[i].Name] && !cloudModels[items[i].Name] {
			notInstalled[items[i].Name] = true
			var parts []string
			if items[i].Description != "" {
				parts = append(parts, items[i].Description)
			}
			if vram := recommendedVRAM[items[i].Name]; vram != "" {
				parts = append(parts, vram)
			}
			parts = append(parts, "(not downloaded)")
			items[i].Description = strings.Join(parts, ", ")
		}
	}

	// Rank recommendations by curated order (1-based; 0 = not recommended).
	recRank := make(map[string]int)
	for i, rec := range recommendedModels {
		recRank[rec.Name] = i + 1
	}

	onlyLocal := hasLocalModel && !hasCloudModel

	if hasLocalModel || hasCloudModel {
		slices.SortStableFunc(items, func(a, b ModelItem) int {
			ac, bc := checked[a.Name], checked[b.Name]
			aNew, bNew := notInstalled[a.Name], notInstalled[b.Name]
			aRec, bRec := recRank[a.Name] > 0, recRank[b.Name] > 0
			aCloud, bCloud := cloudModels[a.Name], cloudModels[b.Name]

			// 1) checked entries first.
			if ac != bc {
				if ac {
					return -1
				}
				return 1
			}
			// 2) recommended entries before everything else.
			if aRec != bRec {
				if aRec {
					return -1
				}
				return 1
			}
			// 3) among recommendations: local-only users see local recs
			//    first, otherwise cloud recs lead; ties use curated order.
			if aRec && bRec {
				if aCloud != bCloud {
					if onlyLocal {
						if aCloud {
							return 1
						}
						return -1
					}
					if aCloud {
						return -1
					}
					return 1
				}
				return recRank[a.Name] - recRank[b.Name]
			}
			// 4) downloaded models before not-yet-downloaded ones.
			if aNew != bNew {
				if aNew {
					return 1
				}
				return -1
			}
			// 5) case-insensitive alphabetical fallback.
			return strings.Compare(strings.ToLower(a.Name), strings.ToLower(b.Name))
		})
	}

	return items, preChecked, existingModels, cloudModels
}
|
||||||
|
|
||||||
|
// isCloudModelName reports whether the model name has an explicit cloud source.
|
||||||
|
func isCloudModelName(name string) bool {
|
||||||
|
return modelref.HasExplicitCloudSource(name)
|
||||||
|
}
|
||||||
|
|
||||||
|
// filterCloudModels drops remote-only models from the given inventory.
|
||||||
|
func filterCloudModels(existing []modelInfo) []modelInfo {
|
||||||
|
filtered := existing[:0]
|
||||||
|
for _, m := range existing {
|
||||||
|
if !m.Remote {
|
||||||
|
filtered = append(filtered, m)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return filtered
|
||||||
|
}
|
||||||
|
|
||||||
|
// filterCloudItems removes cloud models from selection items.
|
||||||
|
func filterCloudItems(items []ModelItem) []ModelItem {
|
||||||
|
filtered := items[:0]
|
||||||
|
for _, item := range items {
|
||||||
|
if !isCloudModelName(item.Name) {
|
||||||
|
filtered = append(filtered, item)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return filtered
|
||||||
|
}
|
||||||
|
|
||||||
|
func isCloudModel(ctx context.Context, client *api.Client, name string) bool {
|
||||||
|
if client == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
resp, err := client.Show(ctx, &api.ShowRequest{Model: name})
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return resp.RemoteModel != ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// cloudStatusDisabled returns whether cloud usage is currently disabled.
|
||||||
|
func cloudStatusDisabled(ctx context.Context, client *api.Client) (disabled bool, known bool) {
|
||||||
|
status, err := client.CloudStatusExperimental(ctx)
|
||||||
|
if err != nil {
|
||||||
|
var statusErr api.StatusError
|
||||||
|
if errors.As(err, &statusErr) && statusErr.StatusCode == http.StatusNotFound {
|
||||||
|
return false, false
|
||||||
|
}
|
||||||
|
return false, false
|
||||||
|
}
|
||||||
|
return status.Cloud.Disabled, true
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO(ParthSareen): make this controllable on an integration level as well
|
||||||
|
const recommendedContextLength = 64000
|
||||||
|
|
||||||
|
func hasLocalModel(models []string) bool {
|
||||||
|
for _, m := range models {
|
||||||
|
if !isCloudModelName(m) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func lowContextLength(ctx context.Context, client *api.Client, models []string) error {
|
||||||
|
if !hasLocalModel(models) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
status, err := client.CloudStatusExperimental(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil //nolint:nilerr // best-effort check; ignore if status endpoint is unavailable
|
||||||
|
}
|
||||||
|
serverCtx := status.ContextLength
|
||||||
|
if serverCtx == 0 {
|
||||||
|
return nil // couldn't determine context length, skip check
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, m := range models {
|
||||||
|
if isCloudModelName(m) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// A Modelfile can override num_ctx, which takes precedence over the server default.
|
||||||
|
effectiveCtx := serverCtx
|
||||||
|
modelfileOverride := false
|
||||||
|
var info *api.ShowResponse
|
||||||
|
if info, err = client.Show(ctx, &api.ShowRequest{Model: m}); err == nil {
|
||||||
|
if numCtx := parseNumCtx(info.Parameters); numCtx > 0 {
|
||||||
|
effectiveCtx = numCtx
|
||||||
|
modelfileOverride = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if effectiveCtx < recommendedContextLength {
|
||||||
|
fmt.Fprintf(os.Stderr, "\n%sWarning: context window is %d tokens (recommended: %d+)%s\n", ansiYellow, effectiveCtx, recommendedContextLength, ansiReset)
|
||||||
|
if modelfileOverride {
|
||||||
|
parentModel := info.Details.ParentModel
|
||||||
|
fmt.Fprintf(os.Stderr, "%sUse the model: %s and increase the context length to at least %d in Ollama App Settings.%s\n\n", ansiYellow, parentModel, recommendedContextLength, ansiReset)
|
||||||
|
} else {
|
||||||
|
if runtime.GOOS == "windows" {
|
||||||
|
fmt.Fprintf(os.Stderr, "%sIncrease it in Ollama App Settings or with $env:OLLAMA_CONTEXT_LENGTH=%d; ollama serve%s\n\n", ansiYellow, recommendedContextLength, ansiReset)
|
||||||
|
} else {
|
||||||
|
fmt.Fprintf(os.Stderr, "%sIncrease it in Ollama App Settings or with OLLAMA_CONTEXT_LENGTH=%d ollama serve%s\n\n", ansiYellow, recommendedContextLength, ansiReset)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseNumCtx extracts num_ctx from the Show response Parameters string.
|
||||||
|
func parseNumCtx(parameters string) int {
|
||||||
|
for _, line := range strings.Split(parameters, "\n") {
|
||||||
|
fields := strings.Fields(line)
|
||||||
|
if len(fields) == 2 && fields[0] == "num_ctx" {
|
||||||
|
if v, err := strconv.ParseFloat(fields[1], 64); err == nil {
|
||||||
|
return int(v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO(parthsareen): this duplicates the pull progress UI in cmd.PullHandler.
|
||||||
|
// Move the shared pull rendering to a small utility once the package boundary settles.
|
||||||
|
func pullModel(ctx context.Context, client *api.Client, model string, insecure bool) error {
|
||||||
|
p := progress.NewProgress(os.Stderr)
|
||||||
|
defer p.Stop()
|
||||||
|
|
||||||
|
bars := make(map[string]*progress.Bar)
|
||||||
|
var status string
|
||||||
|
var spinner *progress.Spinner
|
||||||
|
|
||||||
|
fn := func(resp api.ProgressResponse) error {
|
||||||
|
if resp.Digest != "" {
|
||||||
|
if resp.Completed == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if spinner != nil {
|
||||||
|
spinner.Stop()
|
||||||
|
}
|
||||||
|
|
||||||
|
bar, ok := bars[resp.Digest]
|
||||||
|
if !ok {
|
||||||
|
name, isDigest := strings.CutPrefix(resp.Digest, "sha256:")
|
||||||
|
name = strings.TrimSpace(name)
|
||||||
|
if isDigest {
|
||||||
|
name = name[:min(12, len(name))]
|
||||||
|
}
|
||||||
|
bar = progress.NewBar(fmt.Sprintf("pulling %s:", name), resp.Total, resp.Completed)
|
||||||
|
bars[resp.Digest] = bar
|
||||||
|
p.Add(resp.Digest, bar)
|
||||||
|
}
|
||||||
|
|
||||||
|
bar.Set(resp.Completed)
|
||||||
|
} else if status != resp.Status {
|
||||||
|
if spinner != nil {
|
||||||
|
spinner.Stop()
|
||||||
|
}
|
||||||
|
|
||||||
|
status = resp.Status
|
||||||
|
spinner = progress.NewSpinner(status)
|
||||||
|
p.Add(status, spinner)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
request := api.PullRequest{Name: model, Insecure: insecure}
|
||||||
|
return client.Pull(ctx, &request, fn)
|
||||||
|
}
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package config
|
package launch
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
@@ -14,7 +14,10 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"golang.org/x/mod/semver"
|
||||||
|
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
|
"github.com/ollama/ollama/cmd/internal/fileutil"
|
||||||
"github.com/ollama/ollama/envconfig"
|
"github.com/ollama/ollama/envconfig"
|
||||||
"github.com/ollama/ollama/types/model"
|
"github.com/ollama/ollama/types/model"
|
||||||
)
|
)
|
||||||
@@ -24,6 +27,9 @@ const defaultGatewayPort = 18789
|
|||||||
// Bound model capability probing so launch/config cannot hang on slow/unreachable API calls.
|
// Bound model capability probing so launch/config cannot hang on slow/unreachable API calls.
|
||||||
var openclawModelShowTimeout = 5 * time.Second
|
var openclawModelShowTimeout = 5 * time.Second
|
||||||
|
|
||||||
|
// openclawFreshInstall is set to true when ensureOpenclawInstalled performs an install
|
||||||
|
var openclawFreshInstall bool
|
||||||
|
|
||||||
type Openclaw struct{}
|
type Openclaw struct{}
|
||||||
|
|
||||||
func (c *Openclaw) String() string { return "OpenClaw" }
|
func (c *Openclaw) String() string { return "OpenClaw" }
|
||||||
@@ -34,10 +40,7 @@ func (c *Openclaw) Run(model string, args []string) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
firstLaunch := true
|
firstLaunch := !c.onboarded()
|
||||||
if integrationConfig, err := loadIntegration("openclaw"); err == nil {
|
|
||||||
firstLaunch = !integrationConfig.Onboarded
|
|
||||||
}
|
|
||||||
|
|
||||||
if firstLaunch {
|
if firstLaunch {
|
||||||
fmt.Fprintf(os.Stderr, "\n%sSecurity%s\n\n", ansiBold, ansiReset)
|
fmt.Fprintf(os.Stderr, "\n%sSecurity%s\n\n", ansiBold, ansiReset)
|
||||||
@@ -45,28 +48,46 @@ func (c *Openclaw) Run(model string, args []string) error {
|
|||||||
fmt.Fprintf(os.Stderr, " A bad prompt can trick it into doing unsafe things.\n\n")
|
fmt.Fprintf(os.Stderr, " A bad prompt can trick it into doing unsafe things.\n\n")
|
||||||
fmt.Fprintf(os.Stderr, "%s Learn more: https://docs.openclaw.ai/gateway/security%s\n\n", ansiGray, ansiReset)
|
fmt.Fprintf(os.Stderr, "%s Learn more: https://docs.openclaw.ai/gateway/security%s\n\n", ansiGray, ansiReset)
|
||||||
|
|
||||||
ok, err := confirmPrompt("I understand the risks. Continue?")
|
ok, err := ConfirmPrompt("I understand the risks. Continue?")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
if !c.onboarded() {
|
// Ensure the latest version is installed before onboarding so we get
|
||||||
|
// the newest wizard flags (e.g. --auth-choice ollama).
|
||||||
|
if !openclawFreshInstall {
|
||||||
|
update := exec.Command(bin, "update")
|
||||||
|
update.Stdout = os.Stdout
|
||||||
|
update.Stderr = os.Stderr
|
||||||
|
_ = update.Run() // best-effort; continue even if update fails
|
||||||
|
}
|
||||||
|
|
||||||
fmt.Fprintf(os.Stderr, "\n%sSetting up OpenClaw with Ollama...%s\n", ansiGreen, ansiReset)
|
fmt.Fprintf(os.Stderr, "\n%sSetting up OpenClaw with Ollama...%s\n", ansiGreen, ansiReset)
|
||||||
fmt.Fprintf(os.Stderr, "%s Model: %s%s\n\n", ansiGray, model, ansiReset)
|
fmt.Fprintf(os.Stderr, "%s Model: %s%s\n\n", ansiGray, model, ansiReset)
|
||||||
|
|
||||||
cmd := exec.Command(bin, "onboard",
|
onboardArgs := []string{
|
||||||
|
"onboard",
|
||||||
"--non-interactive",
|
"--non-interactive",
|
||||||
"--accept-risk",
|
"--accept-risk",
|
||||||
"--auth-choice", "skip",
|
"--auth-choice", "ollama",
|
||||||
"--gateway-token", "ollama",
|
"--custom-base-url", envconfig.Host().String(),
|
||||||
"--install-daemon",
|
"--custom-model-id", model,
|
||||||
"--skip-channels",
|
"--skip-channels",
|
||||||
"--skip-skills",
|
"--skip-skills",
|
||||||
)
|
}
|
||||||
|
if canInstallDaemon() {
|
||||||
|
onboardArgs = append(onboardArgs, "--install-daemon")
|
||||||
|
} else {
|
||||||
|
// When we can't install a daemon (e.g. no systemd, sudo dropped
|
||||||
|
// XDG_RUNTIME_DIR, or container environment), skip the gateway
|
||||||
|
// health check so non-interactive onboarding completes. The
|
||||||
|
// gateway is started as a foreground child process after onboarding.
|
||||||
|
onboardArgs = append(onboardArgs, "--skip-health")
|
||||||
|
}
|
||||||
|
cmd := exec.Command(bin, onboardArgs...)
|
||||||
cmd.Stdin = os.Stdin
|
cmd.Stdin = os.Stdin
|
||||||
cmd.Stdout = os.Stdout
|
cmd.Stdout = os.Stdout
|
||||||
cmd.Stderr = os.Stderr
|
cmd.Stderr = os.Stderr
|
||||||
@@ -75,25 +96,13 @@ func (c *Openclaw) Run(model string, args []string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
patchDeviceScopes()
|
patchDeviceScopes()
|
||||||
|
|
||||||
// Onboarding overwrites openclaw.json, so re-apply the model config
|
|
||||||
// that Edit() wrote before Run() was called.
|
|
||||||
if err := c.Edit([]string{model}); err != nil {
|
|
||||||
fmt.Fprintf(os.Stderr, "%s Warning: could not re-apply model config: %v%s\n", ansiYellow, err, ansiReset)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if strings.HasSuffix(model, ":cloud") || strings.HasSuffix(model, "-cloud") {
|
if ensureWebSearchPlugin() {
|
||||||
if ensureWebSearchPlugin() {
|
registerWebSearchPlugin()
|
||||||
registerWebSearchPlugin()
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if firstLaunch {
|
fmt.Fprintf(os.Stderr, "\n%sStarting your assistant — this may take a moment...%s\n\n", ansiGray, ansiReset)
|
||||||
fmt.Fprintf(os.Stderr, "\n%sPreparing your assistant — this may take a moment...%s\n\n", ansiGray, ansiReset)
|
|
||||||
} else {
|
|
||||||
fmt.Fprintf(os.Stderr, "\n%sStarting your assistant — this may take a moment...%s\n\n", ansiGray, ansiReset)
|
|
||||||
}
|
|
||||||
|
|
||||||
// When extra args are passed through, run exactly what the user asked for
|
// When extra args are passed through, run exactly what the user asked for
|
||||||
// after setup and skip the built-in gateway+TUI convenience flow.
|
// after setup and skip the built-in gateway+TUI convenience flow.
|
||||||
@@ -106,11 +115,6 @@ func (c *Openclaw) Run(model string, args []string) error {
|
|||||||
if err := cmd.Run(); err != nil {
|
if err := cmd.Run(); err != nil {
|
||||||
return windowsHint(err)
|
return windowsHint(err)
|
||||||
}
|
}
|
||||||
if firstLaunch {
|
|
||||||
if err := integrationOnboarded("openclaw"); err != nil {
|
|
||||||
return fmt.Errorf("failed to save onboarding state: %w", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -118,7 +122,7 @@ func (c *Openclaw) Run(model string, args []string) error {
|
|||||||
addr := fmt.Sprintf("localhost:%d", port)
|
addr := fmt.Sprintf("localhost:%d", port)
|
||||||
|
|
||||||
// If the gateway is already running (e.g. via the daemon), restart it
|
// If the gateway is already running (e.g. via the daemon), restart it
|
||||||
// so it picks up any config changes from Edit() above (model, provider, etc.).
|
// so it picks up any config changes (model, provider, etc.).
|
||||||
if portOpen(addr) {
|
if portOpen(addr) {
|
||||||
restart := exec.Command(bin, "daemon", "restart")
|
restart := exec.Command(bin, "daemon", "restart")
|
||||||
restart.Env = openclawEnv()
|
restart.Env = openclawEnv()
|
||||||
@@ -165,11 +169,6 @@ func (c *Openclaw) Run(model string, args []string) error {
|
|||||||
return windowsHint(err)
|
return windowsHint(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if firstLaunch {
|
|
||||||
if err := integrationOnboarded("openclaw"); err != nil {
|
|
||||||
return fmt.Errorf("failed to save onboarding state: %w", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -409,6 +408,25 @@ func patchScopes(obj map[string]any, key string, required []string) bool {
|
|||||||
return added
|
return added
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// canInstallDaemon reports whether the openclaw daemon can be installed as a
|
||||||
|
// background service. Returns false on Linux when systemd is absent (e.g.
|
||||||
|
// containers) so that --install-daemon is omitted and the gateway is started
|
||||||
|
// as a foreground child process instead. Returns true in all other cases.
|
||||||
|
func canInstallDaemon() bool {
|
||||||
|
if runtime.GOOS != "linux" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
// /run/systemd/system exists as a directory when systemd is the init system.
|
||||||
|
// This is absent in most containers.
|
||||||
|
fi, err := os.Stat("/run/systemd/system")
|
||||||
|
if err != nil || !fi.IsDir() {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
// Even when systemd is the init system, user services require a user
|
||||||
|
// manager instance. XDG_RUNTIME_DIR being set is a prerequisite.
|
||||||
|
return os.Getenv("XDG_RUNTIME_DIR") != ""
|
||||||
|
}
|
||||||
|
|
||||||
func ensureOpenclawInstalled() (string, error) {
|
func ensureOpenclawInstalled() (string, error) {
|
||||||
if _, err := exec.LookPath("openclaw"); err == nil {
|
if _, err := exec.LookPath("openclaw"); err == nil {
|
||||||
return "openclaw", nil
|
return "openclaw", nil
|
||||||
@@ -417,16 +435,20 @@ func ensureOpenclawInstalled() (string, error) {
|
|||||||
return "clawdbot", nil
|
return "clawdbot", nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, err := exec.LookPath("npm"); err != nil {
|
_, npmErr := exec.LookPath("npm")
|
||||||
return "", fmt.Errorf("openclaw is not installed and npm was not found\n\n" +
|
_, gitErr := exec.LookPath("git")
|
||||||
"Install Node.js first:\n" +
|
if npmErr != nil || gitErr != nil {
|
||||||
" https://nodejs.org/\n\n" +
|
var missing []string
|
||||||
"Then rerun:\n" +
|
if npmErr != nil {
|
||||||
" ollama launch\n" +
|
missing = append(missing, "npm (Node.js): https://nodejs.org/")
|
||||||
"and select OpenClaw")
|
}
|
||||||
|
if gitErr != nil {
|
||||||
|
missing = append(missing, "git: https://git-scm.com/")
|
||||||
|
}
|
||||||
|
return "", fmt.Errorf("openclaw is not installed and required dependencies are missing\n\nInstall the following first:\n %s", strings.Join(missing, "\n "))
|
||||||
}
|
}
|
||||||
|
|
||||||
ok, err := confirmPrompt("OpenClaw is not installed. Install with npm?")
|
ok, err := ConfirmPrompt("OpenClaw is not installed. Install with npm?")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
@@ -448,6 +470,7 @@ func ensureOpenclawInstalled() (string, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fmt.Fprintf(os.Stderr, "%sOpenClaw installed successfully%s\n\n", ansiGreen, ansiReset)
|
fmt.Fprintf(os.Stderr, "%sOpenClaw installed successfully%s\n\n", ansiGreen, ansiReset)
|
||||||
|
openclawFreshInstall = true
|
||||||
return "openclaw", nil
|
return "openclaw", nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -502,7 +525,7 @@ func (c *Openclaw) Edit(models []string) error {
|
|||||||
ollama = make(map[string]any)
|
ollama = make(map[string]any)
|
||||||
}
|
}
|
||||||
|
|
||||||
ollama["baseUrl"] = envconfig.Host().String() + "/v1"
|
ollama["baseUrl"] = envconfig.Host().String()
|
||||||
// needed to register provider
|
// needed to register provider
|
||||||
ollama["apiKey"] = "ollama-local"
|
ollama["apiKey"] = "ollama-local"
|
||||||
ollama["api"] = "ollama"
|
ollama["api"] = "ollama"
|
||||||
@@ -561,7 +584,7 @@ func (c *Openclaw) Edit(models []string) error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if err := writeWithBackup(configPath, data); err != nil {
|
if err := fileutil.WriteWithBackup(configPath, data); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -592,6 +615,8 @@ func clearSessionModelOverride(primary string) {
|
|||||||
if override, _ := sess["modelOverride"].(string); override != "" && override != primary {
|
if override, _ := sess["modelOverride"].(string); override != "" && override != primary {
|
||||||
delete(sess, "modelOverride")
|
delete(sess, "modelOverride")
|
||||||
delete(sess, "providerOverride")
|
delete(sess, "providerOverride")
|
||||||
|
}
|
||||||
|
if model, _ := sess["model"].(string); model != "" && model != primary {
|
||||||
sess["model"] = primary
|
sess["model"] = primary
|
||||||
changed = true
|
changed = true
|
||||||
}
|
}
|
||||||
@@ -606,11 +631,15 @@ func clearSessionModelOverride(primary string) {
|
|||||||
_ = os.WriteFile(path, out, 0o600)
|
_ = os.WriteFile(path, out, 0o600)
|
||||||
}
|
}
|
||||||
|
|
||||||
const webSearchNpmPackage = "@ollama/openclaw-web-search"
|
const (
|
||||||
|
webSearchNpmPackage = "@ollama/openclaw-web-search"
|
||||||
|
webSearchMinVersion = "0.2.1"
|
||||||
|
)
|
||||||
|
|
||||||
// ensureWebSearchPlugin installs the openclaw-web-search extension into the
|
// ensureWebSearchPlugin installs the openclaw-web-search extension into the
|
||||||
// user-level extensions directory (~/.openclaw/extensions/) if it isn't already
|
// user-level extensions directory (~/.openclaw/extensions/) if it isn't already
|
||||||
// present. Returns true if the extension is available.
|
// present, or re-installs if the installed version is older than webSearchMinVersion.
|
||||||
|
// Returns true if the extension is available.
|
||||||
func ensureWebSearchPlugin() bool {
|
func ensureWebSearchPlugin() bool {
|
||||||
home, err := os.UserHomeDir()
|
home, err := os.UserHomeDir()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -618,8 +647,8 @@ func ensureWebSearchPlugin() bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pluginDir := filepath.Join(home, ".openclaw", "extensions", "openclaw-web-search")
|
pluginDir := filepath.Join(home, ".openclaw", "extensions", "openclaw-web-search")
|
||||||
if _, err := os.Stat(filepath.Join(pluginDir, "index.ts")); err == nil {
|
if webSearchPluginUpToDate(pluginDir) {
|
||||||
return true // already installed
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
npmBin, err := exec.LookPath("npm")
|
npmBin, err := exec.LookPath("npm")
|
||||||
@@ -653,6 +682,34 @@ func ensureWebSearchPlugin() bool {
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// webSearchPluginUpToDate returns true if the plugin is installed and its
|
||||||
|
// package.json version is >= webSearchMinVersion.
|
||||||
|
func webSearchPluginUpToDate(pluginDir string) bool {
|
||||||
|
data, err := os.ReadFile(filepath.Join(pluginDir, "package.json"))
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
var pkg struct {
|
||||||
|
Version string `json:"version"`
|
||||||
|
}
|
||||||
|
if json.Unmarshal(data, &pkg) != nil || pkg.Version == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return !versionLessThan(pkg.Version, webSearchMinVersion)
|
||||||
|
}
|
||||||
|
|
||||||
|
// versionLessThan compares two semver version strings (major.minor.patch).
|
||||||
|
// Inputs may omit the "v" prefix; it is added automatically for semver.Compare.
|
||||||
|
func versionLessThan(a, b string) bool {
|
||||||
|
if !strings.HasPrefix(a, "v") {
|
||||||
|
a = "v" + a
|
||||||
|
}
|
||||||
|
if !strings.HasPrefix(b, "v") {
|
||||||
|
b = "v" + b
|
||||||
|
}
|
||||||
|
return semver.Compare(a, b) < 0
|
||||||
|
}
|
||||||
|
|
||||||
// registerWebSearchPlugin adds plugins.entries.openclaw-web-search to the OpenClaw
|
// registerWebSearchPlugin adds plugins.entries.openclaw-web-search to the OpenClaw
|
||||||
// config so the gateway activates it on next start. Best-effort; silently returns
|
// config so the gateway activates it on next start. Best-effort; silently returns
|
||||||
// on any error.
|
// on any error.
|
||||||
@@ -679,23 +736,67 @@ func registerWebSearchPlugin() {
|
|||||||
if entries == nil {
|
if entries == nil {
|
||||||
entries = make(map[string]any)
|
entries = make(map[string]any)
|
||||||
}
|
}
|
||||||
if _, ok := entries["openclaw-web-search"]; ok {
|
|
||||||
return // already registered
|
|
||||||
}
|
|
||||||
entries["openclaw-web-search"] = map[string]any{"enabled": true}
|
entries["openclaw-web-search"] = map[string]any{"enabled": true}
|
||||||
plugins["entries"] = entries
|
plugins["entries"] = entries
|
||||||
|
|
||||||
|
// Pin trust so the gateway doesn't warn about untracked plugins.
|
||||||
|
allow, _ := plugins["allow"].([]any)
|
||||||
|
hasAllow := false
|
||||||
|
for _, v := range allow {
|
||||||
|
if s, ok := v.(string); ok && s == "openclaw-web-search" {
|
||||||
|
hasAllow = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !hasAllow {
|
||||||
|
allow = append(allow, "openclaw-web-search")
|
||||||
|
}
|
||||||
|
plugins["allow"] = allow
|
||||||
|
|
||||||
|
// Record install provenance so the loader can verify the plugin origin.
|
||||||
|
installs, _ := plugins["installs"].(map[string]any)
|
||||||
|
if installs == nil {
|
||||||
|
installs = make(map[string]any)
|
||||||
|
}
|
||||||
|
pluginDir := filepath.Join(home, ".openclaw", "extensions", "openclaw-web-search")
|
||||||
|
installs["openclaw-web-search"] = map[string]any{
|
||||||
|
"source": "npm",
|
||||||
|
"spec": webSearchNpmPackage,
|
||||||
|
"installPath": pluginDir,
|
||||||
|
}
|
||||||
|
plugins["installs"] = installs
|
||||||
|
|
||||||
config["plugins"] = plugins
|
config["plugins"] = plugins
|
||||||
|
|
||||||
// Disable the built-in web search since our plugin replaces it.
|
// Add plugin tools to tools.alsoAllow so they survive the coding profile's
|
||||||
|
// policy pipeline (which has an explicit allow list of core tools only).
|
||||||
tools, _ := config["tools"].(map[string]any)
|
tools, _ := config["tools"].(map[string]any)
|
||||||
if tools == nil {
|
if tools == nil {
|
||||||
tools = make(map[string]any)
|
tools = make(map[string]any)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
alsoAllow, _ := tools["alsoAllow"].([]any)
|
||||||
|
needed := []string{"ollama_web_search", "ollama_web_fetch"}
|
||||||
|
have := make(map[string]bool, len(alsoAllow))
|
||||||
|
for _, v := range alsoAllow {
|
||||||
|
if s, ok := v.(string); ok {
|
||||||
|
have[s] = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for _, name := range needed {
|
||||||
|
if !have[name] {
|
||||||
|
alsoAllow = append(alsoAllow, name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
tools["alsoAllow"] = alsoAllow
|
||||||
|
|
||||||
|
// Disable built-in web search/fetch since our plugin replaces them.
|
||||||
web, _ := tools["web"].(map[string]any)
|
web, _ := tools["web"].(map[string]any)
|
||||||
if web == nil {
|
if web == nil {
|
||||||
web = make(map[string]any)
|
web = make(map[string]any)
|
||||||
}
|
}
|
||||||
web["search"] = map[string]any{"enabled": false}
|
web["search"] = map[string]any{"enabled": false}
|
||||||
|
web["fetch"] = map[string]any{"enabled": false}
|
||||||
tools["web"] = web
|
tools["web"] = web
|
||||||
config["tools"] = tools
|
config["tools"] = tools
|
||||||
|
|
||||||
@@ -776,9 +877,9 @@ func (c *Openclaw) Models() []string {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
config, err := readJSONFile(filepath.Join(home, ".openclaw", "openclaw.json"))
|
config, err := fileutil.ReadJSON(filepath.Join(home, ".openclaw", "openclaw.json"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
config, err = readJSONFile(filepath.Join(home, ".clawdbot", "clawdbot.json"))
|
config, err = fileutil.ReadJSON(filepath.Join(home, ".clawdbot", "clawdbot.json"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package config
|
package launch
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
@@ -82,78 +82,6 @@ func TestOpenclawRunPassthroughArgs(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestOpenclawRunFirstLaunchPersistence(t *testing.T) {
|
|
||||||
if runtime.GOOS == "windows" {
|
|
||||||
t.Skip("uses a POSIX shell test binary")
|
|
||||||
}
|
|
||||||
|
|
||||||
oldHook := DefaultConfirmPrompt
|
|
||||||
DefaultConfirmPrompt = func(prompt string) (bool, error) {
|
|
||||||
return true, nil
|
|
||||||
}
|
|
||||||
defer func() { DefaultConfirmPrompt = oldHook }()
|
|
||||||
|
|
||||||
t.Run("success persists onboarding flag", func(t *testing.T) {
|
|
||||||
tmpDir := t.TempDir()
|
|
||||||
setTestHome(t, tmpDir)
|
|
||||||
t.Setenv("PATH", tmpDir)
|
|
||||||
|
|
||||||
configDir := filepath.Join(tmpDir, ".openclaw")
|
|
||||||
if err := os.MkdirAll(configDir, 0o755); err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
// Mark OpenClaw onboarding complete so Run takes passthrough path directly.
|
|
||||||
if err := os.WriteFile(filepath.Join(configDir, "openclaw.json"), []byte(`{
|
|
||||||
"wizard": {"lastRunAt": "2026-01-01T00:00:00Z"}
|
|
||||||
}`), 0o644); err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
if err := os.WriteFile(filepath.Join(tmpDir, "openclaw"), []byte("#!/bin/sh\nexit 0\n"), 0o755); err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
c := &Openclaw{}
|
|
||||||
if err := c.Run("llama3.2", []string{"gateway", "--status"}); err != nil {
|
|
||||||
t.Fatalf("Run() error = %v", err)
|
|
||||||
}
|
|
||||||
integrationConfig, err := loadIntegration("openclaw")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("loadIntegration() error = %v", err)
|
|
||||||
}
|
|
||||||
if !integrationConfig.Onboarded {
|
|
||||||
t.Fatal("expected onboarding flag to be persisted after successful run")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("failure does not persist onboarding flag", func(t *testing.T) {
|
|
||||||
tmpDir := t.TempDir()
|
|
||||||
setTestHome(t, tmpDir)
|
|
||||||
t.Setenv("PATH", tmpDir)
|
|
||||||
|
|
||||||
configDir := filepath.Join(tmpDir, ".openclaw")
|
|
||||||
if err := os.MkdirAll(configDir, 0o755); err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
if err := os.WriteFile(filepath.Join(configDir, "openclaw.json"), []byte(`{
|
|
||||||
"wizard": {"lastRunAt": "2026-01-01T00:00:00Z"}
|
|
||||||
}`), 0o644); err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
if err := os.WriteFile(filepath.Join(tmpDir, "openclaw"), []byte("#!/bin/sh\nexit 1\n"), 0o755); err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
c := &Openclaw{}
|
|
||||||
if err := c.Run("llama3.2", []string{"gateway", "--status"}); err == nil {
|
|
||||||
t.Fatal("expected run failure")
|
|
||||||
}
|
|
||||||
integrationConfig, err := loadIntegration("openclaw")
|
|
||||||
if err == nil && integrationConfig.Onboarded {
|
|
||||||
t.Fatal("expected onboarding flag to remain unset after failed run")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestOpenclawEdit(t *testing.T) {
|
func TestOpenclawEdit(t *testing.T) {
|
||||||
c := &Openclaw{}
|
c := &Openclaw{}
|
||||||
tmpDir := t.TempDir()
|
tmpDir := t.TempDir()
|
||||||
@@ -589,7 +517,7 @@ const testOpenclawFixture = `{
|
|||||||
"providers": {
|
"providers": {
|
||||||
"anthropic": {"apiKey": "xxx"},
|
"anthropic": {"apiKey": "xxx"},
|
||||||
"ollama": {
|
"ollama": {
|
||||||
"baseUrl": "http://127.0.0.1:11434/v1",
|
"baseUrl": "http://127.0.0.1:11434",
|
||||||
"models": [{"id": "old-model", "customField": "preserved"}]
|
"models": [{"id": "old-model", "customField": "preserved"}]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -1448,7 +1376,7 @@ func TestOpenclawModelConfig(t *testing.T) {
|
|||||||
// report it as a remote/cloud model
|
// report it as a remote/cloud model
|
||||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
if r.URL.Path == "/api/show" {
|
if r.URL.Path == "/api/show" {
|
||||||
fmt.Fprintf(w, `{"capabilities":[],"model_info":{},"remote_model":"minimax-m2.5"}`)
|
fmt.Fprintf(w, `{"capabilities":[],"model_info":{},"remote_model":"minimax-m2.7"}`)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
w.WriteHeader(http.StatusNotFound)
|
w.WriteHeader(http.StatusNotFound)
|
||||||
@@ -1458,7 +1386,7 @@ func TestOpenclawModelConfig(t *testing.T) {
|
|||||||
u, _ := url.Parse(srv.URL)
|
u, _ := url.Parse(srv.URL)
|
||||||
client := api.NewClient(u, srv.Client())
|
client := api.NewClient(u, srv.Client())
|
||||||
|
|
||||||
cfg, isCloud := openclawModelConfig(context.Background(), client, "minimax-m2.5:cloud")
|
cfg, isCloud := openclawModelConfig(context.Background(), client, "minimax-m2.7:cloud")
|
||||||
|
|
||||||
if !isCloud {
|
if !isCloud {
|
||||||
t.Error("expected isCloud = true for cloud model")
|
t.Error("expected isCloud = true for cloud model")
|
||||||
@@ -1528,7 +1456,7 @@ func TestIntegrationOnboarded(t *testing.T) {
|
|||||||
tmpDir := t.TempDir()
|
tmpDir := t.TempDir()
|
||||||
setTestHome(t, tmpDir)
|
setTestHome(t, tmpDir)
|
||||||
|
|
||||||
integrationConfig, err := loadIntegration("openclaw")
|
integrationConfig, err := LoadIntegration("openclaw")
|
||||||
if err == nil && integrationConfig.Onboarded {
|
if err == nil && integrationConfig.Onboarded {
|
||||||
t.Error("expected false for fresh config")
|
t.Error("expected false for fresh config")
|
||||||
}
|
}
|
||||||
@@ -1542,7 +1470,7 @@ func TestIntegrationOnboarded(t *testing.T) {
|
|||||||
if err := integrationOnboarded("openclaw"); err != nil {
|
if err := integrationOnboarded("openclaw"); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
integrationConfig, err := loadIntegration("openclaw")
|
integrationConfig, err := LoadIntegration("openclaw")
|
||||||
if err != nil || !integrationConfig.Onboarded {
|
if err != nil || !integrationConfig.Onboarded {
|
||||||
t.Error("expected true after integrationOnboarded")
|
t.Error("expected true after integrationOnboarded")
|
||||||
}
|
}
|
||||||
@@ -1556,7 +1484,7 @@ func TestIntegrationOnboarded(t *testing.T) {
|
|||||||
if err := integrationOnboarded("OpenClaw"); err != nil {
|
if err := integrationOnboarded("OpenClaw"); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
integrationConfig, err := loadIntegration("openclaw")
|
integrationConfig, err := LoadIntegration("openclaw")
|
||||||
if err != nil || !integrationConfig.Onboarded {
|
if err != nil || !integrationConfig.Onboarded {
|
||||||
t.Error("expected true when set with different case")
|
t.Error("expected true when set with different case")
|
||||||
}
|
}
|
||||||
@@ -1575,7 +1503,7 @@ func TestIntegrationOnboarded(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Verify onboarded is set
|
// Verify onboarded is set
|
||||||
integrationConfig, err := loadIntegration("openclaw")
|
integrationConfig, err := LoadIntegration("openclaw")
|
||||||
if err != nil || !integrationConfig.Onboarded {
|
if err != nil || !integrationConfig.Onboarded {
|
||||||
t.Error("expected true after integrationOnboarded")
|
t.Error("expected true after integrationOnboarded")
|
||||||
}
|
}
|
||||||
@@ -1587,3 +1515,377 @@ func TestIntegrationOnboarded(t *testing.T) {
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestVersionLessThan(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
a, b string
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
{"0.1.7", "0.2.1", true},
|
||||||
|
{"0.2.0", "0.2.1", true},
|
||||||
|
{"0.2.1", "0.2.1", false},
|
||||||
|
{"0.2.2", "0.2.1", false},
|
||||||
|
{"1.0.0", "0.2.1", false},
|
||||||
|
{"0.2.1", "1.0.0", true},
|
||||||
|
{"v0.1.7", "0.2.1", true},
|
||||||
|
{"0.2.1", "v0.2.1", false},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.a+"_vs_"+tt.b, func(t *testing.T) {
|
||||||
|
if got := versionLessThan(tt.a, tt.b); got != tt.want {
|
||||||
|
t.Errorf("versionLessThan(%q, %q) = %v, want %v", tt.a, tt.b, got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWebSearchPluginUpToDate(t *testing.T) {
|
||||||
|
t.Run("missing directory", func(t *testing.T) {
|
||||||
|
if webSearchPluginUpToDate(filepath.Join(t.TempDir(), "nonexistent")) {
|
||||||
|
t.Error("expected false for missing directory")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("missing package.json", func(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
if webSearchPluginUpToDate(dir) {
|
||||||
|
t.Error("expected false for missing package.json")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("old version", func(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
if err := os.WriteFile(filepath.Join(dir, "package.json"), []byte(`{"version":"0.1.7"}`), 0o644); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if webSearchPluginUpToDate(dir) {
|
||||||
|
t.Error("expected false for old version 0.1.7")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("exact minimum version", func(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
if err := os.WriteFile(filepath.Join(dir, "package.json"), []byte(`{"version":"0.2.1"}`), 0o644); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if !webSearchPluginUpToDate(dir) {
|
||||||
|
t.Error("expected true for exact minimum version 0.2.1")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("newer version", func(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
if err := os.WriteFile(filepath.Join(dir, "package.json"), []byte(`{"version":"1.0.0"}`), 0o644); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if !webSearchPluginUpToDate(dir) {
|
||||||
|
t.Error("expected true for newer version 1.0.0")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("invalid json", func(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
if err := os.WriteFile(filepath.Join(dir, "package.json"), []byte(`not json`), 0o644); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if webSearchPluginUpToDate(dir) {
|
||||||
|
t.Error("expected false for invalid json")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("empty version", func(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
if err := os.WriteFile(filepath.Join(dir, "package.json"), []byte(`{"version":""}`), 0o644); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if webSearchPluginUpToDate(dir) {
|
||||||
|
t.Error("expected false for empty version")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRegisterWebSearchPlugin(t *testing.T) {
|
||||||
|
home := t.TempDir()
|
||||||
|
setTestHome(t, home)
|
||||||
|
|
||||||
|
configDir := filepath.Join(home, ".openclaw")
|
||||||
|
if err := os.MkdirAll(configDir, 0o755); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
configPath := filepath.Join(configDir, "openclaw.json")
|
||||||
|
|
||||||
|
t.Run("fresh config", func(t *testing.T) {
|
||||||
|
if err := os.WriteFile(configPath, []byte(`{}`), 0o644); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
registerWebSearchPlugin()
|
||||||
|
|
||||||
|
data, err := os.ReadFile(configPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
var config map[string]any
|
||||||
|
if err := json.Unmarshal(data, &config); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
plugins, _ := config["plugins"].(map[string]any)
|
||||||
|
if plugins == nil {
|
||||||
|
t.Fatal("plugins section missing")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check entries
|
||||||
|
entries, _ := plugins["entries"].(map[string]any)
|
||||||
|
entry, _ := entries["openclaw-web-search"].(map[string]any)
|
||||||
|
if enabled, _ := entry["enabled"].(bool); !enabled {
|
||||||
|
t.Error("expected entries.openclaw-web-search.enabled = true")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check allow list
|
||||||
|
allow, _ := plugins["allow"].([]any)
|
||||||
|
found := false
|
||||||
|
for _, v := range allow {
|
||||||
|
if s, ok := v.(string); ok && s == "openclaw-web-search" {
|
||||||
|
found = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !found {
|
||||||
|
t.Error("expected plugins.allow to contain openclaw-web-search")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check install provenance
|
||||||
|
installs, _ := plugins["installs"].(map[string]any)
|
||||||
|
record, _ := installs["openclaw-web-search"].(map[string]any)
|
||||||
|
if record == nil {
|
||||||
|
t.Fatal("expected plugins.installs.openclaw-web-search")
|
||||||
|
}
|
||||||
|
if source, _ := record["source"].(string); source != "npm" {
|
||||||
|
t.Errorf("install source = %q, want %q", source, "npm")
|
||||||
|
}
|
||||||
|
if spec, _ := record["spec"].(string); spec != webSearchNpmPackage {
|
||||||
|
t.Errorf("install spec = %q, want %q", spec, webSearchNpmPackage)
|
||||||
|
}
|
||||||
|
expectedPath := filepath.Join(home, ".openclaw", "extensions", "openclaw-web-search")
|
||||||
|
if installPath, _ := record["installPath"].(string); installPath != expectedPath {
|
||||||
|
t.Errorf("installPath = %q, want %q", installPath, expectedPath)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("idempotent", func(t *testing.T) {
|
||||||
|
if err := os.WriteFile(configPath, []byte(`{}`), 0o644); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
registerWebSearchPlugin()
|
||||||
|
registerWebSearchPlugin()
|
||||||
|
|
||||||
|
data, err := os.ReadFile(configPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
var config map[string]any
|
||||||
|
if err := json.Unmarshal(data, &config); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
plugins, _ := config["plugins"].(map[string]any)
|
||||||
|
allow, _ := plugins["allow"].([]any)
|
||||||
|
count := 0
|
||||||
|
for _, v := range allow {
|
||||||
|
if s, ok := v.(string); ok && s == "openclaw-web-search" {
|
||||||
|
count++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if count != 1 {
|
||||||
|
t.Errorf("expected exactly 1 openclaw-web-search in allow, got %d", count)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("preserves existing config", func(t *testing.T) {
|
||||||
|
initial := map[string]any{
|
||||||
|
"plugins": map[string]any{
|
||||||
|
"allow": []any{"some-other-plugin"},
|
||||||
|
"entries": map[string]any{
|
||||||
|
"some-other-plugin": map[string]any{"enabled": true},
|
||||||
|
},
|
||||||
|
"installs": map[string]any{
|
||||||
|
"some-other-plugin": map[string]any{
|
||||||
|
"source": "npm",
|
||||||
|
"installPath": "/some/path",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"customField": "preserved",
|
||||||
|
}
|
||||||
|
data, _ := json.Marshal(initial)
|
||||||
|
if err := os.WriteFile(configPath, data, 0o644); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
registerWebSearchPlugin()
|
||||||
|
|
||||||
|
out, err := os.ReadFile(configPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
var config map[string]any
|
||||||
|
if err := json.Unmarshal(out, &config); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if config["customField"] != "preserved" {
|
||||||
|
t.Error("customField was not preserved")
|
||||||
|
}
|
||||||
|
|
||||||
|
plugins, _ := config["plugins"].(map[string]any)
|
||||||
|
entries, _ := plugins["entries"].(map[string]any)
|
||||||
|
if entries["some-other-plugin"] == nil {
|
||||||
|
t.Error("existing plugin entry was lost")
|
||||||
|
}
|
||||||
|
|
||||||
|
installs, _ := plugins["installs"].(map[string]any)
|
||||||
|
if installs["some-other-plugin"] == nil {
|
||||||
|
t.Error("existing install record was lost")
|
||||||
|
}
|
||||||
|
|
||||||
|
allow, _ := plugins["allow"].([]any)
|
||||||
|
hasOther, hasWebSearch := false, false
|
||||||
|
for _, v := range allow {
|
||||||
|
s, _ := v.(string)
|
||||||
|
if s == "some-other-plugin" {
|
||||||
|
hasOther = true
|
||||||
|
}
|
||||||
|
if s == "openclaw-web-search" {
|
||||||
|
hasWebSearch = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !hasOther {
|
||||||
|
t.Error("existing allow entry was lost")
|
||||||
|
}
|
||||||
|
if !hasWebSearch {
|
||||||
|
t.Error("openclaw-web-search not added to allow")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClearSessionModelOverride(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
setTestHome(t, tmpDir)
|
||||||
|
|
||||||
|
sessionsDir := filepath.Join(tmpDir, ".openclaw", "agents", "main", "sessions")
|
||||||
|
sessionsPath := filepath.Join(sessionsDir, "sessions.json")
|
||||||
|
|
||||||
|
writeSessionsFile := func(t *testing.T, sessions map[string]map[string]any) {
|
||||||
|
t.Helper()
|
||||||
|
if err := os.MkdirAll(sessionsDir, 0o755); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
data, err := json.Marshal(sessions)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if err := os.WriteFile(sessionsPath, data, 0o600); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
readSessionsFile := func(t *testing.T) map[string]map[string]any {
|
||||||
|
t.Helper()
|
||||||
|
data, err := os.ReadFile(sessionsPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("reading sessions file: %v", err)
|
||||||
|
}
|
||||||
|
var sessions map[string]map[string]any
|
||||||
|
if err := json.Unmarshal(data, &sessions); err != nil {
|
||||||
|
t.Fatalf("parsing sessions file: %v", err)
|
||||||
|
}
|
||||||
|
return sessions
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("clears modelOverride and updates model", func(t *testing.T) {
|
||||||
|
writeSessionsFile(t, map[string]map[string]any{
|
||||||
|
"sess1": {"model": "ollama/old-model", "modelOverride": "old-model", "providerOverride": "ollama"},
|
||||||
|
})
|
||||||
|
clearSessionModelOverride("new-model")
|
||||||
|
sessions := readSessionsFile(t)
|
||||||
|
sess := sessions["sess1"]
|
||||||
|
if _, ok := sess["modelOverride"]; ok {
|
||||||
|
t.Error("modelOverride should have been deleted")
|
||||||
|
}
|
||||||
|
if _, ok := sess["providerOverride"]; ok {
|
||||||
|
t.Error("providerOverride should have been deleted")
|
||||||
|
}
|
||||||
|
if sess["model"] != "new-model" {
|
||||||
|
t.Errorf("model = %q, want %q", sess["model"], "new-model")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("updates model field in sessions without modelOverride", func(t *testing.T) {
|
||||||
|
// This is the bug case: session has model pointing to old primary,
|
||||||
|
// but no explicit modelOverride. After changing primary, the session
|
||||||
|
// model field must also be updated.
|
||||||
|
writeSessionsFile(t, map[string]map[string]any{
|
||||||
|
"sess1": {"model": "ollama/old-model"},
|
||||||
|
})
|
||||||
|
clearSessionModelOverride("new-model")
|
||||||
|
sessions := readSessionsFile(t)
|
||||||
|
if sessions["sess1"]["model"] != "new-model" {
|
||||||
|
t.Errorf("model = %q, want %q", sessions["sess1"]["model"], "new-model")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("does not update session already using primary", func(t *testing.T) {
|
||||||
|
writeSessionsFile(t, map[string]map[string]any{
|
||||||
|
"sess1": {"model": "current-model"},
|
||||||
|
})
|
||||||
|
clearSessionModelOverride("current-model")
|
||||||
|
sessions := readSessionsFile(t)
|
||||||
|
if sessions["sess1"]["model"] != "current-model" {
|
||||||
|
t.Errorf("model = %q, want %q", sessions["sess1"]["model"], "current-model")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("does not update session with empty model field", func(t *testing.T) {
|
||||||
|
writeSessionsFile(t, map[string]map[string]any{
|
||||||
|
"sess1": {"other": "data"},
|
||||||
|
})
|
||||||
|
clearSessionModelOverride("new-model")
|
||||||
|
sessions := readSessionsFile(t)
|
||||||
|
if _, ok := sessions["sess1"]["model"]; ok {
|
||||||
|
t.Error("model field should not have been added to session with no model")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("handles multiple sessions mixed", func(t *testing.T) {
|
||||||
|
writeSessionsFile(t, map[string]map[string]any{
|
||||||
|
"with-override": {"model": "old", "modelOverride": "old", "providerOverride": "ollama"},
|
||||||
|
"without-override": {"model": "old"},
|
||||||
|
"already-current": {"model": "new-model"},
|
||||||
|
"no-model": {"other": "data"},
|
||||||
|
})
|
||||||
|
clearSessionModelOverride("new-model")
|
||||||
|
sessions := readSessionsFile(t)
|
||||||
|
|
||||||
|
if sessions["with-override"]["model"] != "new-model" {
|
||||||
|
t.Errorf("with-override model = %q, want %q", sessions["with-override"]["model"], "new-model")
|
||||||
|
}
|
||||||
|
if _, ok := sessions["with-override"]["modelOverride"]; ok {
|
||||||
|
t.Error("with-override: modelOverride should be deleted")
|
||||||
|
}
|
||||||
|
if sessions["without-override"]["model"] != "new-model" {
|
||||||
|
t.Errorf("without-override model = %q, want %q", sessions["without-override"]["model"], "new-model")
|
||||||
|
}
|
||||||
|
if sessions["already-current"]["model"] != "new-model" {
|
||||||
|
t.Errorf("already-current model = %q, want %q", sessions["already-current"]["model"], "new-model")
|
||||||
|
}
|
||||||
|
if _, ok := sessions["no-model"]["model"]; ok {
|
||||||
|
t.Error("no-model: model should not have been added")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("no-op when sessions file missing", func(t *testing.T) {
|
||||||
|
os.RemoveAll(sessionsDir)
|
||||||
|
clearSessionModelOverride("new-model") // should not panic or error
|
||||||
|
})
|
||||||
|
}
|
||||||
@@ -1,9 +1,7 @@
|
|||||||
package config
|
package launch
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"maps"
|
"maps"
|
||||||
"os"
|
"os"
|
||||||
@@ -12,34 +10,13 @@ import (
|
|||||||
"slices"
|
"slices"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/cmd/internal/fileutil"
|
||||||
"github.com/ollama/ollama/envconfig"
|
"github.com/ollama/ollama/envconfig"
|
||||||
)
|
)
|
||||||
|
|
||||||
// OpenCode implements Runner and Editor for OpenCode integration
|
// OpenCode implements Runner and Editor for OpenCode integration
|
||||||
type OpenCode struct{}
|
type OpenCode struct{}
|
||||||
|
|
||||||
// cloudModelLimit holds context and output token limits for a cloud model.
|
|
||||||
type cloudModelLimit struct {
|
|
||||||
Context int
|
|
||||||
Output int
|
|
||||||
}
|
|
||||||
|
|
||||||
// lookupCloudModelLimit returns the token limits for a cloud model.
|
|
||||||
// It tries the exact name first, then strips the ":cloud" suffix.
|
|
||||||
func lookupCloudModelLimit(name string) (cloudModelLimit, bool) {
|
|
||||||
if l, ok := cloudModelLimits[name]; ok {
|
|
||||||
return l, true
|
|
||||||
}
|
|
||||||
base := strings.TrimSuffix(name, ":cloud")
|
|
||||||
if base != name {
|
|
||||||
if l, ok := cloudModelLimits[base]; ok {
|
|
||||||
return l, true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return cloudModelLimit{}, false
|
|
||||||
}
|
|
||||||
|
|
||||||
func (o *OpenCode) String() string { return "OpenCode" }
|
func (o *OpenCode) String() string { return "OpenCode" }
|
||||||
|
|
||||||
func (o *OpenCode) Run(model string, args []string) error {
|
func (o *OpenCode) Run(model string, args []string) error {
|
||||||
@@ -47,25 +24,6 @@ func (o *OpenCode) Run(model string, args []string) error {
|
|||||||
return fmt.Errorf("opencode is not installed, install from https://opencode.ai")
|
return fmt.Errorf("opencode is not installed, install from https://opencode.ai")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Call Edit() to ensure config is up-to-date before launch
|
|
||||||
models := []string{model}
|
|
||||||
if config, err := loadIntegration("opencode"); err == nil && len(config.Models) > 0 {
|
|
||||||
models = config.Models
|
|
||||||
}
|
|
||||||
var err error
|
|
||||||
models, err = resolveEditorModels("opencode", models, func() ([]string, error) {
|
|
||||||
return selectModels(context.Background(), "opencode", "")
|
|
||||||
})
|
|
||||||
if errors.Is(err, errCancelled) {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if err := o.Edit(models); err != nil {
|
|
||||||
return fmt.Errorf("setup failed: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
cmd := exec.Command("opencode", args...)
|
cmd := exec.Command("opencode", args...)
|
||||||
cmd.Stdin = os.Stdin
|
cmd.Stdin = os.Stdin
|
||||||
cmd.Stdout = os.Stdout
|
cmd.Stdout = os.Stdout
|
||||||
@@ -122,13 +80,18 @@ func (o *OpenCode) Edit(modelList []string) error {
|
|||||||
if !ok {
|
if !ok {
|
||||||
ollama = map[string]any{
|
ollama = map[string]any{
|
||||||
"npm": "@ai-sdk/openai-compatible",
|
"npm": "@ai-sdk/openai-compatible",
|
||||||
"name": "Ollama (local)",
|
"name": "Ollama",
|
||||||
"options": map[string]any{
|
"options": map[string]any{
|
||||||
"baseURL": envconfig.Host().String() + "/v1",
|
"baseURL": envconfig.Host().String() + "/v1",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Migrate legacy provider name
|
||||||
|
if name, _ := ollama["name"].(string); name == "Ollama (local)" {
|
||||||
|
ollama["name"] = "Ollama"
|
||||||
|
}
|
||||||
|
|
||||||
models, ok := ollama["models"].(map[string]any)
|
models, ok := ollama["models"].(map[string]any)
|
||||||
if !ok {
|
if !ok {
|
||||||
models = make(map[string]any)
|
models = make(map[string]any)
|
||||||
@@ -147,8 +110,6 @@ func (o *OpenCode) Edit(modelList []string) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
client, _ := api.ClientFromEnvironment()
|
|
||||||
|
|
||||||
for _, model := range modelList {
|
for _, model := range modelList {
|
||||||
if existing, ok := models[model].(map[string]any); ok {
|
if existing, ok := models[model].(map[string]any); ok {
|
||||||
// migrate existing models without _launch marker
|
// migrate existing models without _launch marker
|
||||||
@@ -158,7 +119,7 @@ func (o *OpenCode) Edit(modelList []string) error {
|
|||||||
existing["name"] = strings.TrimSuffix(name, " [Ollama]")
|
existing["name"] = strings.TrimSuffix(name, " [Ollama]")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if isCloudModel(context.Background(), client, model) {
|
if isCloudModelName(model) {
|
||||||
if l, ok := lookupCloudModelLimit(model); ok {
|
if l, ok := lookupCloudModelLimit(model); ok {
|
||||||
existing["limit"] = map[string]any{
|
existing["limit"] = map[string]any{
|
||||||
"context": l.Context,
|
"context": l.Context,
|
||||||
@@ -172,7 +133,7 @@ func (o *OpenCode) Edit(modelList []string) error {
|
|||||||
"name": model,
|
"name": model,
|
||||||
"_launch": true,
|
"_launch": true,
|
||||||
}
|
}
|
||||||
if isCloudModel(context.Background(), client, model) {
|
if isCloudModelName(model) {
|
||||||
if l, ok := lookupCloudModelLimit(model); ok {
|
if l, ok := lookupCloudModelLimit(model); ok {
|
||||||
entry["limit"] = map[string]any{
|
entry["limit"] = map[string]any{
|
||||||
"context": l.Context,
|
"context": l.Context,
|
||||||
@@ -191,7 +152,7 @@ func (o *OpenCode) Edit(modelList []string) error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if err := writeWithBackup(configPath, configData); err != nil {
|
if err := fileutil.WriteWithBackup(configPath, configData); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -243,7 +204,7 @@ func (o *OpenCode) Edit(modelList []string) error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return writeWithBackup(statePath, stateData)
|
return fileutil.WriteWithBackup(statePath, stateData)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o *OpenCode) Models() []string {
|
func (o *OpenCode) Models() []string {
|
||||||
@@ -251,7 +212,7 @@ func (o *OpenCode) Models() []string {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
config, err := readJSONFile(filepath.Join(home, ".config", "opencode", "opencode.json"))
|
config, err := fileutil.ReadJSON(filepath.Join(home, ".config", "opencode", "opencode.json"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -1,8 +1,10 @@
|
|||||||
package config
|
package launch
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"testing"
|
"testing"
|
||||||
@@ -232,6 +234,44 @@ func TestOpenCodeEdit(t *testing.T) {
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
t.Run("migrate Ollama (local) provider name", func(t *testing.T) {
|
||||||
|
cleanup()
|
||||||
|
os.MkdirAll(configDir, 0o755)
|
||||||
|
os.WriteFile(configPath, []byte(`{"provider":{"ollama":{"name":"Ollama (local)","npm":"@ai-sdk/openai-compatible","options":{"baseURL":"http://localhost:11434/v1"}}}}`), 0o644)
|
||||||
|
|
||||||
|
if err := o.Edit([]string{"llama3.2"}); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
data, _ := os.ReadFile(configPath)
|
||||||
|
var cfg map[string]any
|
||||||
|
json.Unmarshal(data, &cfg)
|
||||||
|
provider := cfg["provider"].(map[string]any)
|
||||||
|
ollama := provider["ollama"].(map[string]any)
|
||||||
|
if ollama["name"] != "Ollama" {
|
||||||
|
t.Errorf("provider name not migrated: got %q, want %q", ollama["name"], "Ollama")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("preserve custom provider name", func(t *testing.T) {
|
||||||
|
cleanup()
|
||||||
|
os.MkdirAll(configDir, 0o755)
|
||||||
|
os.WriteFile(configPath, []byte(`{"provider":{"ollama":{"name":"My Custom Ollama","npm":"@ai-sdk/openai-compatible","options":{"baseURL":"http://localhost:11434/v1"}}}}`), 0o644)
|
||||||
|
|
||||||
|
if err := o.Edit([]string{"llama3.2"}); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
data, _ := os.ReadFile(configPath)
|
||||||
|
var cfg map[string]any
|
||||||
|
json.Unmarshal(data, &cfg)
|
||||||
|
provider := cfg["provider"].(map[string]any)
|
||||||
|
ollama := provider["ollama"].(map[string]any)
|
||||||
|
if ollama["name"] != "My Custom Ollama" {
|
||||||
|
t.Errorf("custom provider name was changed: got %q, want %q", ollama["name"], "My Custom Ollama")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
t.Run("remove model preserves non-ollama models", func(t *testing.T) {
|
t.Run("remove model preserves non-ollama models", func(t *testing.T) {
|
||||||
cleanup()
|
cleanup()
|
||||||
os.MkdirAll(configDir, 0o755)
|
os.MkdirAll(configDir, 0o755)
|
||||||
@@ -619,6 +659,54 @@ func TestOpenCodeEdit_CloudModelLimitStructure(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestOpenCodeEdit_BackfillsCloudModelLimitOnExistingEntry(t *testing.T) {
|
||||||
|
o := &OpenCode{}
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
setTestHome(t, tmpDir)
|
||||||
|
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.URL.Path == "/api/show" {
|
||||||
|
fmt.Fprintf(w, `{"capabilities":[],"model_info":{},"remote_model":"glm-5"}`)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
w.WriteHeader(http.StatusNotFound)
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
t.Setenv("OLLAMA_HOST", srv.URL)
|
||||||
|
|
||||||
|
configDir := filepath.Join(tmpDir, ".config", "opencode")
|
||||||
|
configPath := filepath.Join(configDir, "opencode.json")
|
||||||
|
os.MkdirAll(configDir, 0o755)
|
||||||
|
os.WriteFile(configPath, []byte(`{
|
||||||
|
"provider": {
|
||||||
|
"ollama": {
|
||||||
|
"models": {
|
||||||
|
"glm-5:cloud": {
|
||||||
|
"name": "glm-5:cloud",
|
||||||
|
"_launch": true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}`), 0o644)
|
||||||
|
|
||||||
|
if err := o.Edit([]string{"glm-5:cloud"}); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
entry := readOpenCodeModel(t, configPath, "glm-5:cloud")
|
||||||
|
limit, ok := entry["limit"].(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("cloud model limit was not added on re-edit")
|
||||||
|
}
|
||||||
|
if limit["context"] != float64(202_752) {
|
||||||
|
t.Errorf("context = %v, want 202752", limit["context"])
|
||||||
|
}
|
||||||
|
if limit["output"] != float64(131_072) {
|
||||||
|
t.Errorf("output = %v, want 131072", limit["output"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestLookupCloudModelLimit(t *testing.T) {
|
func TestLookupCloudModelLimit(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
@@ -626,13 +714,19 @@ func TestLookupCloudModelLimit(t *testing.T) {
|
|||||||
wantContext int
|
wantContext int
|
||||||
wantOutput int
|
wantOutput int
|
||||||
}{
|
}{
|
||||||
{"glm-4.7", true, 202_752, 131_072},
|
{"glm-4.7", false, 0, 0},
|
||||||
{"glm-4.7:cloud", true, 202_752, 131_072},
|
{"glm-4.7:cloud", true, 202_752, 131_072},
|
||||||
{"kimi-k2.5", true, 262_144, 262_144},
|
{"glm-5:cloud", true, 202_752, 131_072},
|
||||||
|
{"gpt-oss:120b-cloud", true, 131_072, 131_072},
|
||||||
|
{"gpt-oss:20b-cloud", true, 131_072, 131_072},
|
||||||
|
{"kimi-k2.5", false, 0, 0},
|
||||||
{"kimi-k2.5:cloud", true, 262_144, 262_144},
|
{"kimi-k2.5:cloud", true, 262_144, 262_144},
|
||||||
{"deepseek-v3.2", true, 163_840, 65_536},
|
{"deepseek-v3.2", false, 0, 0},
|
||||||
{"deepseek-v3.2:cloud", true, 163_840, 65_536},
|
{"deepseek-v3.2:cloud", true, 163_840, 65_536},
|
||||||
{"qwen3-coder:480b", true, 262_144, 65_536},
|
{"qwen3.5", false, 0, 0},
|
||||||
|
{"qwen3.5:cloud", true, 262_144, 32_768},
|
||||||
|
{"qwen3-coder:480b", false, 0, 0},
|
||||||
|
{"qwen3-coder:480b:cloud", true, 262_144, 65_536},
|
||||||
{"qwen3-coder-next:cloud", true, 262_144, 32_768},
|
{"qwen3-coder-next:cloud", true, 262_144, 32_768},
|
||||||
{"llama3.2", false, 0, 0},
|
{"llama3.2", false, 0, 0},
|
||||||
{"unknown-model:cloud", false, 0, 0},
|
{"unknown-model:cloud", false, 0, 0},
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package config
|
package launch
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
@@ -12,6 +12,7 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
|
"github.com/ollama/ollama/cmd/internal/fileutil"
|
||||||
"github.com/ollama/ollama/envconfig"
|
"github.com/ollama/ollama/envconfig"
|
||||||
"github.com/ollama/ollama/types/model"
|
"github.com/ollama/ollama/types/model"
|
||||||
)
|
)
|
||||||
@@ -26,15 +27,6 @@ func (p *Pi) Run(model string, args []string) error {
|
|||||||
return fmt.Errorf("pi is not installed, install with: npm install -g @mariozechner/pi-coding-agent")
|
return fmt.Errorf("pi is not installed, install with: npm install -g @mariozechner/pi-coding-agent")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Call Edit() to ensure config is up-to-date before launch
|
|
||||||
models := []string{model}
|
|
||||||
if config, err := loadIntegration("pi"); err == nil && len(config.Models) > 0 {
|
|
||||||
models = config.Models
|
|
||||||
}
|
|
||||||
if err := p.Edit(models); err != nil {
|
|
||||||
return fmt.Errorf("setup failed: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
cmd := exec.Command("pi", args...)
|
cmd := exec.Command("pi", args...)
|
||||||
cmd.Stdin = os.Stdin
|
cmd.Stdin = os.Stdin
|
||||||
cmd.Stdout = os.Stdout
|
cmd.Stdout = os.Stdout
|
||||||
@@ -107,7 +99,8 @@ func (p *Pi) Edit(models []string) error {
|
|||||||
|
|
||||||
// Build new models list:
|
// Build new models list:
|
||||||
// 1. Keep user-managed models (no _launch marker) - untouched
|
// 1. Keep user-managed models (no _launch marker) - untouched
|
||||||
// 2. Keep ollama-managed models (_launch marker) that are still selected
|
// 2. Keep ollama-managed models (_launch marker) that are still selected,
|
||||||
|
// except stale cloud entries that should be rebuilt below
|
||||||
// 3. Add new ollama-managed models
|
// 3. Add new ollama-managed models
|
||||||
var newModels []any
|
var newModels []any
|
||||||
for _, m := range existingModels {
|
for _, m := range existingModels {
|
||||||
@@ -117,7 +110,13 @@ func (p *Pi) Edit(models []string) error {
|
|||||||
if !isPiOllamaModel(modelObj) {
|
if !isPiOllamaModel(modelObj) {
|
||||||
newModels = append(newModels, m)
|
newModels = append(newModels, m)
|
||||||
} else if selectedSet[id] {
|
} else if selectedSet[id] {
|
||||||
// Ollama-managed and still selected - keep it
|
// Rebuild stale managed cloud entries so createConfig refreshes
|
||||||
|
// the whole entry instead of patching it in place.
|
||||||
|
if !hasContextWindow(modelObj) {
|
||||||
|
if _, ok := lookupCloudModelLimit(id); ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
newModels = append(newModels, m)
|
newModels = append(newModels, m)
|
||||||
selectedSet[id] = false
|
selectedSet[id] = false
|
||||||
}
|
}
|
||||||
@@ -142,7 +141,7 @@ func (p *Pi) Edit(models []string) error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if err := writeWithBackup(configPath, configData); err != nil {
|
if err := fileutil.WriteWithBackup(configPath, configData); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -160,7 +159,7 @@ func (p *Pi) Edit(models []string) error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return writeWithBackup(settingsPath, settingsData)
|
return fileutil.WriteWithBackup(settingsPath, settingsData)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Pi) Models() []string {
|
func (p *Pi) Models() []string {
|
||||||
@@ -170,7 +169,7 @@ func (p *Pi) Models() []string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
configPath := filepath.Join(home, ".pi", "agent", "models.json")
|
configPath := filepath.Join(home, ".pi", "agent", "models.json")
|
||||||
config, err := readJSONFile(configPath)
|
config, err := fileutil.ReadJSON(configPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -199,15 +198,38 @@ func isPiOllamaModel(cfg map[string]any) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func hasContextWindow(cfg map[string]any) bool {
|
||||||
|
switch v := cfg["contextWindow"].(type) {
|
||||||
|
case float64:
|
||||||
|
return v > 0
|
||||||
|
case int:
|
||||||
|
return v > 0
|
||||||
|
case int64:
|
||||||
|
return v > 0
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// createConfig builds Pi model config with capability detection
|
// createConfig builds Pi model config with capability detection
|
||||||
func createConfig(ctx context.Context, client *api.Client, modelID string) map[string]any {
|
func createConfig(ctx context.Context, client *api.Client, modelID string) map[string]any {
|
||||||
cfg := map[string]any{
|
cfg := map[string]any{
|
||||||
"id": modelID,
|
"id": modelID,
|
||||||
"_launch": true,
|
"_launch": true,
|
||||||
}
|
}
|
||||||
|
if l, ok := lookupCloudModelLimit(modelID); ok {
|
||||||
|
cfg["contextWindow"] = l.Context
|
||||||
|
}
|
||||||
|
|
||||||
|
applyCloudContextFallback := func() {
|
||||||
|
if l, ok := lookupCloudModelLimit(modelID); ok {
|
||||||
|
cfg["contextWindow"] = l.Context
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
resp, err := client.Show(ctx, &api.ShowRequest{Model: modelID})
|
resp, err := client.Show(ctx, &api.ShowRequest{Model: modelID})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
applyCloudContextFallback()
|
||||||
return cfg
|
return cfg
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -223,15 +245,21 @@ func createConfig(ctx context.Context, client *api.Client, modelID string) map[s
|
|||||||
cfg["reasoning"] = true
|
cfg["reasoning"] = true
|
||||||
}
|
}
|
||||||
|
|
||||||
// Extract context window from ModelInfo
|
// Extract context window from ModelInfo. For known cloud models, the
|
||||||
|
// pre-filled shared limit remains unless the server provides a positive value.
|
||||||
|
hasContextWindow := false
|
||||||
for key, val := range resp.ModelInfo {
|
for key, val := range resp.ModelInfo {
|
||||||
if strings.HasSuffix(key, ".context_length") {
|
if strings.HasSuffix(key, ".context_length") {
|
||||||
if ctxLen, ok := val.(float64); ok && ctxLen > 0 {
|
if ctxLen, ok := val.(float64); ok && ctxLen > 0 {
|
||||||
cfg["contextWindow"] = int(ctxLen)
|
cfg["contextWindow"] = int(ctxLen)
|
||||||
|
hasContextWindow = true
|
||||||
}
|
}
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if !hasContextWindow {
|
||||||
|
applyCloudContextFallback()
|
||||||
|
}
|
||||||
|
|
||||||
return cfg
|
return cfg
|
||||||
}
|
}
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package config
|
package launch
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
@@ -192,6 +192,48 @@ func TestPiEdit(t *testing.T) {
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
t.Run("rebuilds stale existing managed cloud model", func(t *testing.T) {
|
||||||
|
cleanup()
|
||||||
|
os.MkdirAll(configDir, 0o755)
|
||||||
|
|
||||||
|
existingConfig := `{
|
||||||
|
"providers": {
|
||||||
|
"ollama": {
|
||||||
|
"baseUrl": "http://localhost:11434/v1",
|
||||||
|
"api": "openai-completions",
|
||||||
|
"apiKey": "ollama",
|
||||||
|
"models": [
|
||||||
|
{"id": "glm-5:cloud", "_launch": true, "legacyField": "stale"}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}`
|
||||||
|
if err := os.WriteFile(configPath, []byte(existingConfig), 0o644); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := pi.Edit([]string{"glm-5:cloud"}); err != nil {
|
||||||
|
t.Fatalf("Edit() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg := readConfig()
|
||||||
|
providers := cfg["providers"].(map[string]any)
|
||||||
|
ollama := providers["ollama"].(map[string]any)
|
||||||
|
modelsArray := ollama["models"].([]any)
|
||||||
|
modelEntry := modelsArray[0].(map[string]any)
|
||||||
|
|
||||||
|
if modelEntry["contextWindow"] != float64(202_752) {
|
||||||
|
t.Errorf("contextWindow = %v, want 202752", modelEntry["contextWindow"])
|
||||||
|
}
|
||||||
|
input, ok := modelEntry["input"].([]any)
|
||||||
|
if !ok || len(input) != 1 || input[0] != "text" {
|
||||||
|
t.Errorf("input = %v, want [text]", modelEntry["input"])
|
||||||
|
}
|
||||||
|
if _, ok := modelEntry["legacyField"]; ok {
|
||||||
|
t.Error("legacyField should be removed when stale managed cloud entry is rebuilt")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
t.Run("replaces old models with new ones", func(t *testing.T) {
|
t.Run("replaces old models with new ones", func(t *testing.T) {
|
||||||
cleanup()
|
cleanup()
|
||||||
os.MkdirAll(configDir, 0o755)
|
os.MkdirAll(configDir, 0o755)
|
||||||
@@ -798,6 +840,59 @@ func TestCreateConfig(t *testing.T) {
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
t.Run("cloud model falls back to hardcoded context when show fails", func(t *testing.T) {
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusNotFound)
|
||||||
|
fmt.Fprintf(w, `{"error":"model not found"}`)
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
u, _ := url.Parse(srv.URL)
|
||||||
|
client := api.NewClient(u, srv.Client())
|
||||||
|
|
||||||
|
cfg := createConfig(context.Background(), client, "kimi-k2.5:cloud")
|
||||||
|
|
||||||
|
if cfg["contextWindow"] != 262_144 {
|
||||||
|
t.Errorf("contextWindow = %v, want 262144", cfg["contextWindow"])
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("cloud model falls back to hardcoded context when show omits model info", func(t *testing.T) {
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.URL.Path == "/api/show" {
|
||||||
|
fmt.Fprintf(w, `{"capabilities":[],"model_info":{}}`)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
w.WriteHeader(http.StatusNotFound)
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
u, _ := url.Parse(srv.URL)
|
||||||
|
client := api.NewClient(u, srv.Client())
|
||||||
|
|
||||||
|
cfg := createConfig(context.Background(), client, "glm-5:cloud")
|
||||||
|
|
||||||
|
if cfg["contextWindow"] != 202_752 {
|
||||||
|
t.Errorf("contextWindow = %v, want 202752", cfg["contextWindow"])
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("cloud model with dash suffix falls back to hardcoded context", func(t *testing.T) {
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusNotFound)
|
||||||
|
fmt.Fprintf(w, `{"error":"model not found"}`)
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
u, _ := url.Parse(srv.URL)
|
||||||
|
client := api.NewClient(u, srv.Client())
|
||||||
|
|
||||||
|
cfg := createConfig(context.Background(), client, "gpt-oss:120b-cloud")
|
||||||
|
|
||||||
|
if cfg["contextWindow"] != 131_072 {
|
||||||
|
t.Errorf("contextWindow = %v, want 131072", cfg["contextWindow"])
|
||||||
|
}
|
||||||
|
})
|
||||||
t.Run("skips zero context length", func(t *testing.T) {
|
t.Run("skips zero context length", func(t *testing.T) {
|
||||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
if r.URL.Path == "/api/show" {
|
if r.URL.Path == "/api/show" {
|
||||||
369
cmd/launch/registry.go
Normal file
@@ -0,0 +1,369 @@
|
|||||||
|
package launch
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"os/exec"
|
||||||
|
"slices"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// IntegrationInstallSpec describes how launcher should detect and guide installation.
|
||||||
|
type IntegrationInstallSpec struct {
|
||||||
|
CheckInstalled func() bool
|
||||||
|
EnsureInstalled func() error
|
||||||
|
URL string
|
||||||
|
Command []string
|
||||||
|
}
|
||||||
|
|
||||||
|
// IntegrationSpec is the canonical registry entry for one integration.
|
||||||
|
type IntegrationSpec struct {
|
||||||
|
Name string
|
||||||
|
Runner Runner
|
||||||
|
Aliases []string
|
||||||
|
Hidden bool
|
||||||
|
Description string
|
||||||
|
Install IntegrationInstallSpec
|
||||||
|
}
|
||||||
|
|
||||||
|
// IntegrationInfo contains display information about a registered integration.
|
||||||
|
type IntegrationInfo struct {
|
||||||
|
Name string
|
||||||
|
DisplayName string
|
||||||
|
Description string
|
||||||
|
}
|
||||||
|
|
||||||
|
var launcherIntegrationOrder = []string{"opencode", "droid", "pi"}
|
||||||
|
|
||||||
|
var integrationSpecs = []*IntegrationSpec{
|
||||||
|
{
|
||||||
|
Name: "claude",
|
||||||
|
Runner: &Claude{},
|
||||||
|
Description: "Anthropic's coding tool with subagents",
|
||||||
|
Install: IntegrationInstallSpec{
|
||||||
|
CheckInstalled: func() bool {
|
||||||
|
_, err := (&Claude{}).findPath()
|
||||||
|
return err == nil
|
||||||
|
},
|
||||||
|
URL: "https://code.claude.com/docs/en/quickstart",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "cline",
|
||||||
|
Runner: &Cline{},
|
||||||
|
Description: "Autonomous coding agent with parallel execution",
|
||||||
|
Hidden: true,
|
||||||
|
Install: IntegrationInstallSpec{
|
||||||
|
CheckInstalled: func() bool {
|
||||||
|
_, err := exec.LookPath("cline")
|
||||||
|
return err == nil
|
||||||
|
},
|
||||||
|
Command: []string{"npm", "install", "-g", "cline"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "codex",
|
||||||
|
Runner: &Codex{},
|
||||||
|
Description: "OpenAI's open-source coding agent",
|
||||||
|
Install: IntegrationInstallSpec{
|
||||||
|
CheckInstalled: func() bool {
|
||||||
|
_, err := exec.LookPath("codex")
|
||||||
|
return err == nil
|
||||||
|
},
|
||||||
|
URL: "https://developers.openai.com/codex/cli/",
|
||||||
|
Command: []string{"npm", "install", "-g", "@openai/codex"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "droid",
|
||||||
|
Runner: &Droid{},
|
||||||
|
Description: "Factory's coding agent across terminal and IDEs",
|
||||||
|
Install: IntegrationInstallSpec{
|
||||||
|
CheckInstalled: func() bool {
|
||||||
|
_, err := exec.LookPath("droid")
|
||||||
|
return err == nil
|
||||||
|
},
|
||||||
|
URL: "https://docs.factory.ai/cli/getting-started/quickstart",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "opencode",
|
||||||
|
Runner: &OpenCode{},
|
||||||
|
Description: "Anomaly's open-source coding agent",
|
||||||
|
Install: IntegrationInstallSpec{
|
||||||
|
CheckInstalled: func() bool {
|
||||||
|
_, err := exec.LookPath("opencode")
|
||||||
|
return err == nil
|
||||||
|
},
|
||||||
|
URL: "https://opencode.ai",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "openclaw",
|
||||||
|
Runner: &Openclaw{},
|
||||||
|
Aliases: []string{"clawdbot", "moltbot"},
|
||||||
|
Description: "Personal AI with 100+ skills",
|
||||||
|
Install: IntegrationInstallSpec{
|
||||||
|
CheckInstalled: func() bool {
|
||||||
|
if _, err := exec.LookPath("openclaw"); err == nil {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if _, err := exec.LookPath("clawdbot"); err == nil {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
},
|
||||||
|
EnsureInstalled: func() error {
|
||||||
|
_, err := ensureOpenclawInstalled()
|
||||||
|
return err
|
||||||
|
},
|
||||||
|
URL: "https://docs.openclaw.ai",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "pi",
|
||||||
|
Runner: &Pi{},
|
||||||
|
Description: "Minimal AI agent toolkit with plugin support",
|
||||||
|
Install: IntegrationInstallSpec{
|
||||||
|
CheckInstalled: func() bool {
|
||||||
|
_, err := exec.LookPath("pi")
|
||||||
|
return err == nil
|
||||||
|
},
|
||||||
|
Command: []string{"npm", "install", "-g", "@mariozechner/pi-coding-agent"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "vscode",
|
||||||
|
Runner: &VSCode{},
|
||||||
|
Aliases: []string{"code"},
|
||||||
|
Description: "Microsoft's open-source AI code editor",
|
||||||
|
Hidden: true,
|
||||||
|
Install: IntegrationInstallSpec{
|
||||||
|
CheckInstalled: func() bool {
|
||||||
|
return (&VSCode{}).findBinary() != ""
|
||||||
|
},
|
||||||
|
URL: "https://code.visualstudio.com",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
var integrationSpecsByName map[string]*IntegrationSpec
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
rebuildIntegrationSpecIndexes()
|
||||||
|
}
|
||||||
|
|
||||||
|
func hyperlink(url, text string) string {
|
||||||
|
return fmt.Sprintf("\033]8;;%s\033\\%s\033]8;;\033\\", url, text)
|
||||||
|
}
|
||||||
|
|
||||||
|
func rebuildIntegrationSpecIndexes() {
|
||||||
|
integrationSpecsByName = make(map[string]*IntegrationSpec, len(integrationSpecs))
|
||||||
|
|
||||||
|
canonical := make(map[string]bool, len(integrationSpecs))
|
||||||
|
for _, spec := range integrationSpecs {
|
||||||
|
key := strings.ToLower(spec.Name)
|
||||||
|
if key == "" {
|
||||||
|
panic("launch: integration spec missing name")
|
||||||
|
}
|
||||||
|
if canonical[key] {
|
||||||
|
panic(fmt.Sprintf("launch: duplicate integration name %q", key))
|
||||||
|
}
|
||||||
|
canonical[key] = true
|
||||||
|
integrationSpecsByName[key] = spec
|
||||||
|
}
|
||||||
|
|
||||||
|
seenAliases := make(map[string]string)
|
||||||
|
for _, spec := range integrationSpecs {
|
||||||
|
for _, alias := range spec.Aliases {
|
||||||
|
key := strings.ToLower(alias)
|
||||||
|
if key == "" {
|
||||||
|
panic(fmt.Sprintf("launch: integration %q has empty alias", spec.Name))
|
||||||
|
}
|
||||||
|
if canonical[key] {
|
||||||
|
panic(fmt.Sprintf("launch: alias %q collides with canonical integration name", key))
|
||||||
|
}
|
||||||
|
if owner, exists := seenAliases[key]; exists {
|
||||||
|
panic(fmt.Sprintf("launch: alias %q collides between %q and %q", key, owner, spec.Name))
|
||||||
|
}
|
||||||
|
seenAliases[key] = spec.Name
|
||||||
|
integrationSpecsByName[key] = spec
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
orderSeen := make(map[string]bool, len(launcherIntegrationOrder))
|
||||||
|
for _, name := range launcherIntegrationOrder {
|
||||||
|
key := strings.ToLower(name)
|
||||||
|
if orderSeen[key] {
|
||||||
|
panic(fmt.Sprintf("launch: duplicate launcher order entry %q", key))
|
||||||
|
}
|
||||||
|
orderSeen[key] = true
|
||||||
|
|
||||||
|
spec, ok := integrationSpecsByName[key]
|
||||||
|
if !ok {
|
||||||
|
panic(fmt.Sprintf("launch: unknown launcher order entry %q", key))
|
||||||
|
}
|
||||||
|
if spec.Name != key {
|
||||||
|
panic(fmt.Sprintf("launch: launcher order entry %q must use canonical name, not alias", key))
|
||||||
|
}
|
||||||
|
if spec.Hidden {
|
||||||
|
panic(fmt.Sprintf("launch: hidden integration %q cannot appear in launcher order", key))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// LookupIntegrationSpec resolves either a canonical integration name or alias to its spec.
|
||||||
|
func LookupIntegrationSpec(name string) (*IntegrationSpec, error) {
|
||||||
|
spec, ok := integrationSpecsByName[strings.ToLower(name)]
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("unknown integration: %s", name)
|
||||||
|
}
|
||||||
|
return spec, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// LookupIntegration resolves a registry name to the canonical key and runner.
|
||||||
|
func LookupIntegration(name string) (string, Runner, error) {
|
||||||
|
spec, err := LookupIntegrationSpec(name)
|
||||||
|
if err != nil {
|
||||||
|
return "", nil, err
|
||||||
|
}
|
||||||
|
return spec.Name, spec.Runner, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListVisibleIntegrationSpecs returns the canonical integrations that should appear in interactive UIs.
|
||||||
|
func ListVisibleIntegrationSpecs() []IntegrationSpec {
|
||||||
|
visible := make([]IntegrationSpec, 0, len(integrationSpecs))
|
||||||
|
for _, spec := range integrationSpecs {
|
||||||
|
if spec.Hidden {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
visible = append(visible, *spec)
|
||||||
|
}
|
||||||
|
|
||||||
|
orderRank := make(map[string]int, len(launcherIntegrationOrder))
|
||||||
|
for i, name := range launcherIntegrationOrder {
|
||||||
|
orderRank[name] = i + 1
|
||||||
|
}
|
||||||
|
|
||||||
|
slices.SortFunc(visible, func(a, b IntegrationSpec) int {
|
||||||
|
aRank, bRank := orderRank[a.Name], orderRank[b.Name]
|
||||||
|
if aRank > 0 && bRank > 0 {
|
||||||
|
return aRank - bRank
|
||||||
|
}
|
||||||
|
if aRank > 0 {
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
if bRank > 0 {
|
||||||
|
return -1
|
||||||
|
}
|
||||||
|
return strings.Compare(a.Name, b.Name)
|
||||||
|
})
|
||||||
|
|
||||||
|
return visible
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListIntegrationInfos returns the registered integrations in launcher display order.
|
||||||
|
func ListIntegrationInfos() []IntegrationInfo {
|
||||||
|
visible := ListVisibleIntegrationSpecs()
|
||||||
|
infos := make([]IntegrationInfo, 0, len(visible))
|
||||||
|
for _, spec := range visible {
|
||||||
|
infos = append(infos, IntegrationInfo{
|
||||||
|
Name: spec.Name,
|
||||||
|
DisplayName: spec.Runner.String(),
|
||||||
|
Description: spec.Description,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return infos
|
||||||
|
}
|
||||||
|
|
||||||
|
// IntegrationSelectionItems returns the sorted integration items shown by launcher selection UIs.
|
||||||
|
func IntegrationSelectionItems() ([]ModelItem, error) {
|
||||||
|
visible := ListVisibleIntegrationSpecs()
|
||||||
|
if len(visible) == 0 {
|
||||||
|
return nil, fmt.Errorf("no integrations available")
|
||||||
|
}
|
||||||
|
|
||||||
|
items := make([]ModelItem, 0, len(visible))
|
||||||
|
for _, spec := range visible {
|
||||||
|
description := spec.Runner.String()
|
||||||
|
if conn, err := loadStoredIntegrationConfig(spec.Name); err == nil && len(conn.Models) > 0 {
|
||||||
|
description = fmt.Sprintf("%s (%s)", spec.Runner.String(), conn.Models[0])
|
||||||
|
}
|
||||||
|
items = append(items, ModelItem{Name: spec.Name, Description: description})
|
||||||
|
}
|
||||||
|
return items, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsIntegrationInstalled checks if an integration binary is installed.
|
||||||
|
func IsIntegrationInstalled(name string) bool {
|
||||||
|
integration, err := integrationFor(name)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "Ollama couldn't find integration %q, so it'll show up as not installed.\n", name)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return integration.installed
|
||||||
|
}
|
||||||
|
|
||||||
|
// integration is resolved registry metadata used by launcher state and install checks.
|
||||||
|
// It combines immutable registry spec data with computed runtime traits.
|
||||||
|
type integration struct {
|
||||||
|
spec *IntegrationSpec
|
||||||
|
installed bool
|
||||||
|
autoInstallable bool
|
||||||
|
editor bool
|
||||||
|
installHint string
|
||||||
|
}
|
||||||
|
|
||||||
|
// integrationFor resolves an integration name into the canonical spec plus
|
||||||
|
// derived launcher/install traits used across registry and launch flows.
|
||||||
|
func integrationFor(name string) (integration, error) {
|
||||||
|
spec, err := LookupIntegrationSpec(name)
|
||||||
|
if err != nil {
|
||||||
|
return integration{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
installed := true
|
||||||
|
if spec.Install.CheckInstalled != nil {
|
||||||
|
installed = spec.Install.CheckInstalled()
|
||||||
|
}
|
||||||
|
|
||||||
|
_, editor := spec.Runner.(Editor)
|
||||||
|
hint := ""
|
||||||
|
if spec.Install.URL != "" {
|
||||||
|
hint = "Install from " + hyperlink(spec.Install.URL, spec.Install.URL)
|
||||||
|
} else if len(spec.Install.Command) > 0 {
|
||||||
|
hint = "Install with: " + strings.Join(spec.Install.Command, " ")
|
||||||
|
}
|
||||||
|
|
||||||
|
return integration{
|
||||||
|
spec: spec,
|
||||||
|
installed: installed,
|
||||||
|
autoInstallable: spec.Install.EnsureInstalled != nil,
|
||||||
|
editor: editor,
|
||||||
|
installHint: hint,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// EnsureIntegrationInstalled installs auto-installable integrations when missing.
|
||||||
|
func EnsureIntegrationInstalled(name string, runner Runner) error {
|
||||||
|
integration, err := integrationFor(name)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("%s is not installed", runner)
|
||||||
|
}
|
||||||
|
|
||||||
|
if integration.installed {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if integration.autoInstallable {
|
||||||
|
return integration.spec.Install.EnsureInstalled()
|
||||||
|
}
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case integration.spec.Install.URL != "":
|
||||||
|
return fmt.Errorf("%s is not installed, install from %s", integration.spec.Name, integration.spec.Install.URL)
|
||||||
|
case len(integration.spec.Install.Command) > 0:
|
||||||
|
return fmt.Errorf("%s is not installed, install with: %s", integration.spec.Name, strings.Join(integration.spec.Install.Command, " "))
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("%s is not installed", runner)
|
||||||
|
}
|
||||||
|
}
|
||||||
21
cmd/launch/registry_test_helpers_test.go
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
package launch
|
||||||
|
|
||||||
|
import "strings"
|
||||||
|
|
||||||
|
// OverrideIntegration replaces one registry entry's runner for tests and returns a restore function.
|
||||||
|
func OverrideIntegration(name string, runner Runner) func() {
|
||||||
|
spec, err := LookupIntegrationSpec(name)
|
||||||
|
if err != nil {
|
||||||
|
key := strings.ToLower(name)
|
||||||
|
integrationSpecsByName[key] = &IntegrationSpec{Name: key, Runner: runner}
|
||||||
|
return func() {
|
||||||
|
delete(integrationSpecsByName, key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
original := spec.Runner
|
||||||
|
spec.Runner = runner
|
||||||
|
return func() {
|
||||||
|
spec.Runner = original
|
||||||
|
}
|
||||||
|
}
|
||||||
68
cmd/launch/runner_exec_only_test.go
Normal file
@@ -0,0 +1,68 @@
|
|||||||
|
package launch
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestEditorRunsDoNotRewriteConfig(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
binary string
|
||||||
|
runner Runner
|
||||||
|
checkPath func(home string) string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "droid",
|
||||||
|
binary: "droid",
|
||||||
|
runner: &Droid{},
|
||||||
|
checkPath: func(home string) string {
|
||||||
|
return filepath.Join(home, ".factory", "settings.json")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "opencode",
|
||||||
|
binary: "opencode",
|
||||||
|
runner: &OpenCode{},
|
||||||
|
checkPath: func(home string) string {
|
||||||
|
return filepath.Join(home, ".config", "opencode", "opencode.json")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "cline",
|
||||||
|
binary: "cline",
|
||||||
|
runner: &Cline{},
|
||||||
|
checkPath: func(home string) string {
|
||||||
|
return filepath.Join(home, ".cline", "data", "globalState.json")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "pi",
|
||||||
|
binary: "pi",
|
||||||
|
runner: &Pi{},
|
||||||
|
checkPath: func(home string) string {
|
||||||
|
return filepath.Join(home, ".pi", "agent", "models.json")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
home := t.TempDir()
|
||||||
|
setTestHome(t, home)
|
||||||
|
|
||||||
|
binDir := t.TempDir()
|
||||||
|
writeFakeBinary(t, binDir, tt.binary)
|
||||||
|
t.Setenv("PATH", binDir)
|
||||||
|
|
||||||
|
configPath := tt.checkPath(home)
|
||||||
|
if err := tt.runner.Run("llama3.2", nil); err != nil {
|
||||||
|
t.Fatalf("Run returned error: %v", err)
|
||||||
|
}
|
||||||
|
if _, err := os.Stat(configPath); !os.IsNotExist(err) {
|
||||||
|
t.Fatalf("expected Run to leave %s untouched, got err=%v", configPath, err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
103
cmd/launch/selector_hooks.go
Normal file
@@ -0,0 +1,103 @@
|
|||||||
|
package launch
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
|
||||||
|
"golang.org/x/term"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ANSI escape sequences for terminal formatting.
|
||||||
|
const (
|
||||||
|
ansiBold = "\033[1m"
|
||||||
|
ansiReset = "\033[0m"
|
||||||
|
ansiGray = "\033[37m"
|
||||||
|
ansiGreen = "\033[32m"
|
||||||
|
ansiYellow = "\033[33m"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ErrCancelled is returned when the user cancels a selection.
|
||||||
|
var ErrCancelled = errors.New("cancelled")
|
||||||
|
|
||||||
|
// errCancelled is kept as an internal alias for existing call sites.
|
||||||
|
var errCancelled = ErrCancelled
|
||||||
|
|
||||||
|
// DefaultConfirmPrompt provides a TUI-based confirmation prompt.
|
||||||
|
// When set, ConfirmPrompt delegates to it instead of using raw terminal I/O.
|
||||||
|
var DefaultConfirmPrompt func(prompt string) (bool, error)
|
||||||
|
|
||||||
|
// SingleSelector is a function type for single item selection.
|
||||||
|
// current is the name of the previously selected item to highlight; empty means no pre-selection.
|
||||||
|
type SingleSelector func(title string, items []ModelItem, current string) (string, error)
|
||||||
|
|
||||||
|
// MultiSelector is a function type for multi item selection.
|
||||||
|
type MultiSelector func(title string, items []ModelItem, preChecked []string) ([]string, error)
|
||||||
|
|
||||||
|
// DefaultSingleSelector is the default single-select implementation.
|
||||||
|
var DefaultSingleSelector SingleSelector
|
||||||
|
|
||||||
|
// DefaultMultiSelector is the default multi-select implementation.
|
||||||
|
var DefaultMultiSelector MultiSelector
|
||||||
|
|
||||||
|
// DefaultSignIn provides a TUI-based sign-in flow.
|
||||||
|
// When set, ensureAuth uses it instead of plain text prompts.
|
||||||
|
// Returns the signed-in username or an error.
|
||||||
|
var DefaultSignIn func(modelName, signInURL string) (string, error)
|
||||||
|
|
||||||
|
type launchConfirmPolicy struct {
|
||||||
|
yes bool
|
||||||
|
requireYesMessage bool
|
||||||
|
}
|
||||||
|
|
||||||
|
var currentLaunchConfirmPolicy launchConfirmPolicy
|
||||||
|
|
||||||
|
func withLaunchConfirmPolicy(policy launchConfirmPolicy) func() {
|
||||||
|
old := currentLaunchConfirmPolicy
|
||||||
|
currentLaunchConfirmPolicy = policy
|
||||||
|
return func() {
|
||||||
|
currentLaunchConfirmPolicy = old
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ConfirmPrompt is the shared confirmation gate for launch flows (integration
|
||||||
|
// edits, missing-model pulls, sign-in prompts, OpenClaw install/security, etc).
|
||||||
|
// Behavior is controlled by currentLaunchConfirmPolicy, typically scoped by
|
||||||
|
// withLaunchConfirmPolicy in LaunchCmd (e.g. auto-approve with --yes).
|
||||||
|
func ConfirmPrompt(prompt string) (bool, error) {
|
||||||
|
if currentLaunchConfirmPolicy.yes {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
if currentLaunchConfirmPolicy.requireYesMessage {
|
||||||
|
return false, fmt.Errorf("%s requires confirmation; re-run with --yes to continue", prompt)
|
||||||
|
}
|
||||||
|
|
||||||
|
if DefaultConfirmPrompt != nil {
|
||||||
|
return DefaultConfirmPrompt(prompt)
|
||||||
|
}
|
||||||
|
|
||||||
|
fd := int(os.Stdin.Fd())
|
||||||
|
oldState, err := term.MakeRaw(fd)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
defer term.Restore(fd, oldState)
|
||||||
|
|
||||||
|
fmt.Fprintf(os.Stderr, "%s (\033[1my\033[0m/n) ", prompt)
|
||||||
|
|
||||||
|
buf := make([]byte, 1)
|
||||||
|
for {
|
||||||
|
if _, err := os.Stdin.Read(buf); err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
switch buf[0] {
|
||||||
|
case 'Y', 'y', 13:
|
||||||
|
fmt.Fprintf(os.Stderr, "yes\r\n")
|
||||||
|
return true, nil
|
||||||
|
case 'N', 'n', 27, 3:
|
||||||
|
fmt.Fprintf(os.Stderr, "no\r\n")
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
76
cmd/launch/selector_test.go
Normal file
@@ -0,0 +1,76 @@
|
|||||||
|
package launch
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestErrCancelled(t *testing.T) {
|
||||||
|
t.Run("NotNil", func(t *testing.T) {
|
||||||
|
if errCancelled == nil {
|
||||||
|
t.Error("errCancelled should not be nil")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Message", func(t *testing.T) {
|
||||||
|
if errCancelled.Error() != "cancelled" {
|
||||||
|
t.Errorf("expected 'cancelled', got %q", errCancelled.Error())
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestWithLaunchConfirmPolicy_ScopesAndRestores verifies that
// withLaunchConfirmPolicy installs a policy, that a nested policy overrides
// the outer one, and that each returned restore function reinstates the
// previous state. The steps are order-dependent: inner (--yes) policy first,
// then outer (requireYesMessage), then the default hook.
func TestWithLaunchConfirmPolicy_ScopesAndRestores(t *testing.T) {
	// Snapshot and restore the package-level policy and prompt hook so this
	// test cannot leak state into other tests.
	oldPolicy := currentLaunchConfirmPolicy
	oldHook := DefaultConfirmPrompt
	t.Cleanup(func() {
		currentLaunchConfirmPolicy = oldPolicy
		DefaultConfirmPrompt = oldHook
	})

	currentLaunchConfirmPolicy = launchConfirmPolicy{}
	var hookCalls int
	// Counting hook stands in for the interactive terminal prompt.
	DefaultConfirmPrompt = func(prompt string) (bool, error) {
		hookCalls++
		return true, nil
	}

	restoreOuter := withLaunchConfirmPolicy(launchConfirmPolicy{requireYesMessage: true})
	restoreInner := withLaunchConfirmPolicy(launchConfirmPolicy{yes: true})

	// Innermost policy (--yes) should auto-accept without calling the hook.
	ok, err := ConfirmPrompt("test prompt")
	if err != nil {
		t.Fatalf("expected --yes policy to allow prompt, got error: %v", err)
	}
	if !ok {
		t.Fatal("expected --yes policy to auto-accept prompt")
	}
	if hookCalls != 0 {
		t.Fatalf("expected --yes to skip hook, got %d hook calls", hookCalls)
	}

	restoreInner()

	// With the inner policy popped, the outer requireYesMessage policy should
	// block with an actionable error, still without calling the hook.
	_, err = ConfirmPrompt("test prompt")
	if err == nil {
		t.Fatal("expected requireYesMessage policy to block prompt")
	}
	if !strings.Contains(err.Error(), "re-run with --yes") {
		t.Fatalf("expected actionable --yes error, got: %v", err)
	}
	if hookCalls != 0 {
		t.Fatalf("expected blocking policy to skip hook, got %d hook calls", hookCalls)
	}

	restoreOuter()

	// All scoped policies popped: ConfirmPrompt should fall through to the
	// DefaultConfirmPrompt hook exactly once.
	ok, err = ConfirmPrompt("test prompt")
	if err != nil {
		t.Fatalf("expected restored default behavior to use hook, got error: %v", err)
	}
	if !ok {
		t.Fatal("expected hook to return true")
	}
	if hookCalls != 1 {
		t.Fatalf("expected one hook call after restore, got %d", hookCalls)
	}
}
|
||||||
82
cmd/launch/test_config_helpers_test.go
Normal file
@@ -0,0 +1,82 @@
|
|||||||
|
package launch
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/cmd/config"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
integrations map[string]Runner
|
||||||
|
integrationAliases map[string]bool
|
||||||
|
integrationOrder = launcherIntegrationOrder
|
||||||
|
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
integrations = buildTestIntegrations()
|
||||||
|
integrationAliases = buildTestIntegrationAliases()
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildTestIntegrations() map[string]Runner {
|
||||||
|
result := make(map[string]Runner, len(integrationSpecsByName))
|
||||||
|
for name, spec := range integrationSpecsByName {
|
||||||
|
result[strings.ToLower(name)] = spec.Runner
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildTestIntegrationAliases() map[string]bool {
|
||||||
|
result := make(map[string]bool)
|
||||||
|
for _, spec := range integrationSpecs {
|
||||||
|
for _, alias := range spec.Aliases {
|
||||||
|
result[strings.ToLower(alias)] = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
// setTestHome points the launcher's config lookups at dir for the duration
// of the test. Thin alias over setLaunchTestHome so older tests keep their
// original helper name.
func setTestHome(t *testing.T, dir string) {
	t.Helper()
	setLaunchTestHome(t, dir)
}
|
||||||
|
|
||||||
|
// SaveIntegration forwards to config.SaveIntegration so tests exercise the
// same persistence path as the CLI.
func SaveIntegration(appName string, models []string) error {
	return config.SaveIntegration(appName, models)
}

// LoadIntegration forwards to config.LoadIntegration.
func LoadIntegration(appName string) (*config.IntegrationConfig, error) {
	return config.LoadIntegration(appName)
}

// SaveAliases forwards to config.SaveAliases.
func SaveAliases(appName string, aliases map[string]string) error {
	return config.SaveAliases(appName, aliases)
}

// LastModel forwards to config.LastModel.
func LastModel() string {
	return config.LastModel()
}

// SetLastModel forwards to config.SetLastModel.
func SetLastModel(model string) error {
	return config.SetLastModel(model)
}

// LastSelection forwards to config.LastSelection.
func LastSelection() string {
	return config.LastSelection()
}

// SetLastSelection forwards to config.SetLastSelection.
func SetLastSelection(selection string) error {
	return config.SetLastSelection(selection)
}

// IntegrationModel forwards to config.IntegrationModel.
func IntegrationModel(appName string) string {
	return config.IntegrationModel(appName)
}

// IntegrationModels forwards to config.IntegrationModels.
func IntegrationModels(appName string) []string {
	return config.IntegrationModels(appName)
}

// integrationOnboarded forwards to config.MarkIntegrationOnboarded.
func integrationOnboarded(appName string) error {
	return config.MarkIntegrationOnboarded(appName)
}
|
||||||
591
cmd/launch/vscode.go
Normal file
@@ -0,0 +1,591 @@
|
|||||||
|
package launch
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"os/exec"
|
||||||
|
"path/filepath"
|
||||||
|
"runtime"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
_ "github.com/mattn/go-sqlite3"
|
||||||
|
"github.com/ollama/ollama/api"
|
||||||
|
"github.com/ollama/ollama/cmd/internal/fileutil"
|
||||||
|
"github.com/ollama/ollama/envconfig"
|
||||||
|
)
|
||||||
|
|
||||||
|
// VSCode implements Runner and Editor for Visual Studio Code integration.
type VSCode struct{}

// String returns the human-readable integration name.
func (v *VSCode) String() string { return "Visual Studio Code" }
|
||||||
|
|
||||||
|
// findBinary returns the path/command to launch VS Code, or "" if not found.
|
||||||
|
// It checks platform-specific locations only.
|
||||||
|
func (v *VSCode) findBinary() string {
|
||||||
|
var candidates []string
|
||||||
|
switch runtime.GOOS {
|
||||||
|
case "darwin":
|
||||||
|
candidates = []string{
|
||||||
|
"/Applications/Visual Studio Code.app",
|
||||||
|
}
|
||||||
|
case "windows":
|
||||||
|
if localAppData := os.Getenv("LOCALAPPDATA"); localAppData != "" {
|
||||||
|
candidates = append(candidates, filepath.Join(localAppData, "Programs", "Microsoft VS Code", "bin", "code.cmd"))
|
||||||
|
}
|
||||||
|
default: // linux
|
||||||
|
candidates = []string{
|
||||||
|
"/usr/bin/code",
|
||||||
|
"/snap/bin/code",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for _, c := range candidates {
|
||||||
|
if _, err := os.Stat(c); err == nil {
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsRunning reports whether VS Code is currently running.
// Each platform uses a pattern specific enough to avoid matching Cursor or
// other VS Code forks.
func (v *VSCode) IsRunning() bool {
	switch runtime.GOOS {
	case "darwin":
		// pgrep exits non-zero when nothing matches, so err != nil means
		// "not running" rather than a failure worth reporting.
		out, err := exec.Command("pgrep", "-f", "Visual Studio Code.app/Contents/MacOS/Code").Output()
		return err == nil && len(out) > 0
	case "windows":
		// Match VS Code by executable path to avoid matching Cursor or other forks.
		out, err := exec.Command("powershell", "-NoProfile", "-Command",
			`Get-Process Code -ErrorAction SilentlyContinue | Where-Object { $_.Path -like '*Microsoft VS Code*' } | Select-Object -First 1`).Output()
		return err == nil && len(strings.TrimSpace(string(out))) > 0
	default:
		// Match VS Code specifically by its install path to avoid matching
		// Cursor (/cursor/) or other forks.
		for _, pattern := range []string{"/usr/share/code/", "/snap/code/"} {
			out, err := exec.Command("pgrep", "-f", pattern).Output()
			if err == nil && len(out) > 0 {
				return true
			}
		}
		return false
	}
}
|
||||||
|
|
||||||
|
// Quit gracefully quits VS Code and waits for it to exit so that it flushes
// its in-memory state back to the database. It gives up after a 30 second
// timeout; Windows uses Stop-Process -Force since there is no graceful
// quit equivalent to osascript.
func (v *VSCode) Quit() {
	if !v.IsRunning() {
		return
	}
	switch runtime.GOOS {
	case "darwin":
		// AppleScript quit lets VS Code shut down cleanly.
		_ = exec.Command("osascript", "-e", `quit app "Visual Studio Code"`).Run()
	case "windows":
		// Kill VS Code by executable path to avoid killing Cursor or other forks.
		_ = exec.Command("powershell", "-NoProfile", "-Command",
			`Get-Process Code -ErrorAction SilentlyContinue | Where-Object { $_.Path -like '*Microsoft VS Code*' } | Stop-Process -Force`).Run()
	default:
		for _, pattern := range []string{"/usr/share/code/", "/snap/code/"} {
			_ = exec.Command("pkill", "-f", pattern).Run()
		}
	}
	// Wait for the process to fully exit and flush its state to disk
	// TODO(hoyyeva): update spinner to use bubble tea
	spinnerFrames := []string{"|", "/", "-", "\\"}
	frame := 0
	fmt.Fprintf(os.Stderr, "\033[90mRestarting VS Code... %s\033[0m", spinnerFrames[0])

	ticker := time.NewTicker(200 * time.Millisecond)
	defer ticker.Stop()

	for range 150 { // 150 ticks × 200ms = 30s timeout
		<-ticker.C
		frame++
		fmt.Fprintf(os.Stderr, "\r\033[90mRestarting VS Code... %s\033[0m", spinnerFrames[frame%len(spinnerFrames)])

		// Polling the process table is comparatively expensive, so only
		// check every fifth tick.
		if frame%5 == 0 { // check every ~1s
			if !v.IsRunning() {
				fmt.Fprintf(os.Stderr, "\r\033[K")
				// Give VS Code a moment to finish writing its state DB
				time.Sleep(1 * time.Second)
				return
			}
		}
	}
	fmt.Fprintf(os.Stderr, "\r\033[K")
}
|
||||||
|
|
||||||
|
// Minimum versions this integration is known to work with; Run warns when
// the installed VS Code or Copilot Chat extension is older.
const (
	minCopilotChatVersion = "0.41.0"
	minVSCodeVersion      = "1.113"
)
|
||||||
|
|
||||||
|
func (v *VSCode) Run(model string, args []string) error {
|
||||||
|
v.checkVSCodeVersion()
|
||||||
|
v.checkCopilotChatVersion()
|
||||||
|
|
||||||
|
// Get all configured models (saved by the launcher framework before Run is called)
|
||||||
|
models := []string{model}
|
||||||
|
if cfg, err := loadStoredIntegrationConfig("vscode"); err == nil && len(cfg.Models) > 0 {
|
||||||
|
models = cfg.Models
|
||||||
|
}
|
||||||
|
|
||||||
|
// VS Code discovers models from ollama ls. Cloud models that pass Show
|
||||||
|
// (the server knows about them) but aren't in ls need to be pulled to
|
||||||
|
// register them so VS Code can find them.
|
||||||
|
if client, err := api.ClientFromEnvironment(); err == nil {
|
||||||
|
v.ensureModelsRegistered(context.Background(), client, models)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Warn if the default model doesn't support tool calling
|
||||||
|
if client, err := api.ClientFromEnvironment(); err == nil {
|
||||||
|
if resp, err := client.Show(context.Background(), &api.ShowRequest{Model: models[0]}); err == nil {
|
||||||
|
hasTools := false
|
||||||
|
for _, c := range resp.Capabilities {
|
||||||
|
if c == "tools" {
|
||||||
|
hasTools = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !hasTools {
|
||||||
|
fmt.Fprintf(os.Stderr, "Note: %s does not support tool calling and may not appear in the Copilot Chat model picker.\n", models[0])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
v.printModelAccessTip()
|
||||||
|
|
||||||
|
if v.IsRunning() {
|
||||||
|
restart, err := ConfirmPrompt("Restart VS Code?")
|
||||||
|
if err != nil {
|
||||||
|
restart = false
|
||||||
|
}
|
||||||
|
if restart {
|
||||||
|
v.Quit()
|
||||||
|
if err := v.ShowInModelPicker(models); err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "%s Warning: could not update VS Code model picker: %v%s\n", ansiYellow, err, ansiReset)
|
||||||
|
}
|
||||||
|
v.FocusVSCode()
|
||||||
|
} else {
|
||||||
|
fmt.Fprintf(os.Stderr, "\nTo get the latest model configuration, restart VS Code when you're ready.\n")
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if err := v.ShowInModelPicker(models); err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "%s Warning: could not update VS Code model picker: %v%s\n", ansiYellow, err, ansiReset)
|
||||||
|
}
|
||||||
|
v.FocusVSCode()
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ensureModelsRegistered pulls models that the server knows about (Show succeeds)
|
||||||
|
// but aren't in ollama ls yet. This is needed for cloud models so that VS Code
|
||||||
|
// can discover them from the Ollama API.
|
||||||
|
func (v *VSCode) ensureModelsRegistered(ctx context.Context, client *api.Client, models []string) {
|
||||||
|
listed, err := client.List(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
registered := make(map[string]bool, len(listed.Models))
|
||||||
|
for _, m := range listed.Models {
|
||||||
|
registered[m.Name] = true
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, model := range models {
|
||||||
|
if registered[model] {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// Also check without :latest suffix
|
||||||
|
if !strings.Contains(model, ":") && registered[model+":latest"] {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if err := pullModel(ctx, client, model, false); err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "%s Warning: could not register model %s: %v%s\n", ansiYellow, model, err, ansiReset)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// FocusVSCode brings VS Code to the foreground.
|
||||||
|
func (v *VSCode) FocusVSCode() {
|
||||||
|
binary := v.findBinary()
|
||||||
|
if binary == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if runtime.GOOS == "darwin" && strings.HasSuffix(binary, ".app") {
|
||||||
|
_ = exec.Command("open", "-a", binary).Run()
|
||||||
|
} else {
|
||||||
|
_ = exec.Command(binary).Start()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// printModelAccessTip shows instructions for finding Ollama models in VS
// Code's Copilot Chat model picker. Written to stderr so it doesn't pollute
// machine-readable stdout.
func (v *VSCode) printModelAccessTip() {
	fmt.Fprintf(os.Stderr, "\nTip: To use Ollama models, open Copilot Chat and click the model picker.\n")
	fmt.Fprintf(os.Stderr, "    If you don't see your models, click \"Other models\" to find them.\n\n")
}
|
||||||
|
|
||||||
|
func (v *VSCode) Paths() []string {
|
||||||
|
if p := v.chatLanguageModelsPath(); fileExists(p) {
|
||||||
|
return []string{p}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Edit writes the Ollama provider entry into VS Code's
// chatLanguageModels.json, preserving entries from other vendors, then
// removes legacy settings left by older versions of the integration.
// An empty model list is a no-op.
func (v *VSCode) Edit(models []string) error {
	if len(models) == 0 {
		return nil
	}

	// Write chatLanguageModels.json with Ollama vendor entry
	clmPath := v.chatLanguageModelsPath()
	if err := os.MkdirAll(filepath.Dir(clmPath), 0o755); err != nil {
		return err
	}

	// Best-effort read: a missing or corrupted file is treated as empty.
	var entries []map[string]any
	if data, err := os.ReadFile(clmPath); err == nil {
		_ = json.Unmarshal(data, &entries)
	}

	// Remove any existing Ollama entries, preserve others
	filtered := make([]map[string]any, 0, len(entries))
	for _, entry := range entries {
		if vendor, _ := entry["vendor"].(string); vendor != "ollama" {
			filtered = append(filtered, entry)
		}
	}

	// Add new Ollama entry
	filtered = append(filtered, map[string]any{
		"vendor": "ollama",
		"name":   "Ollama",
		"url":    envconfig.Host().String(),
	})

	data, err := json.MarshalIndent(filtered, "", "  ")
	if err != nil {
		return err
	}
	// WriteWithBackup keeps the previous file so a bad write is recoverable.
	if err := fileutil.WriteWithBackup(clmPath, data); err != nil {
		return err
	}

	// Clean up legacy settings from older Ollama integrations
	v.updateSettings()

	return nil
}
|
||||||
|
|
||||||
|
func (v *VSCode) Models() []string {
|
||||||
|
if !v.hasOllamaVendor() {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if cfg, err := loadStoredIntegrationConfig("vscode"); err == nil {
|
||||||
|
return cfg.Models
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// hasOllamaVendor checks if chatLanguageModels.json contains an Ollama vendor entry.
|
||||||
|
func (v *VSCode) hasOllamaVendor() bool {
|
||||||
|
data, err := os.ReadFile(v.chatLanguageModelsPath())
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
var entries []map[string]any
|
||||||
|
if err := json.Unmarshal(data, &entries); err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, entry := range entries {
|
||||||
|
if vendor, _ := entry["vendor"].(string); vendor == "ollama" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// chatLanguageModelsPath is where VS Code stores bring-your-own-key model
// provider entries in the user profile.
func (v *VSCode) chatLanguageModelsPath() string {
	return v.vscodePath("chatLanguageModels.json")
}

// settingsPath is the user-level settings.json location.
func (v *VSCode) settingsPath() string {
	return v.vscodePath("settings.json")
}
|
||||||
|
|
||||||
|
// updateSettings cleans up legacy settings from older Ollama integrations.
// Best-effort: any read, parse, or marshal failure leaves settings.json
// untouched rather than failing the launch.
func (v *VSCode) updateSettings() {
	settingsPath := v.settingsPath()
	data, err := os.ReadFile(settingsPath)
	if err != nil {
		return
	}

	var settings map[string]any
	if err := json.Unmarshal(data, &settings); err != nil {
		return
	}

	// Keys written by earlier versions of the Ollama integration.
	changed := false
	for _, key := range []string{"github.copilot.chat.byok.ollamaEndpoint", "ollama.launch.configured"} {
		if _, ok := settings[key]; ok {
			delete(settings, key)
			changed = true
		}
	}

	// Avoid rewriting (and re-formatting) the user's settings file when
	// nothing was removed.
	if !changed {
		return
	}

	updated, err := json.MarshalIndent(settings, "", "    ")
	if err != nil {
		return
	}
	// Errors ignored: cleanup is best-effort and a backup is kept.
	_ = fileutil.WriteWithBackup(settingsPath, updated)
}
|
||||||
|
|
||||||
|
// statePath is VS Code's global state database (a SQLite file) where the
// model-picker preferences are stored.
func (v *VSCode) statePath() string {
	return v.vscodePath("globalStorage", "state.vscdb")
}
|
||||||
|
|
||||||
|
// ShowInModelPicker ensures the given models are visible in VS Code's Copilot
// Chat model picker. It sets the configured models to true in the picker
// preferences so they appear in the dropdown. Models use the VS Code identifier
// format "ollama/Ollama/<name>". When the state database does not exist yet
// it is created with the same schema VS Code uses.
func (v *VSCode) ShowInModelPicker(models []string) error {
	if len(models) == 0 {
		return nil
	}

	dbPath := v.statePath()
	needsCreate := !fileExists(dbPath)
	if needsCreate {
		if err := os.MkdirAll(filepath.Dir(dbPath), 0o755); err != nil {
			return fmt.Errorf("creating state directory: %w", err)
		}
	}

	// busy_timeout guards against VS Code itself holding the DB briefly.
	db, err := sql.Open("sqlite3", dbPath+"?_busy_timeout=5000")
	if err != nil {
		return fmt.Errorf("opening state database: %w", err)
	}
	defer db.Close()

	// Create the table if this is a fresh DB. Schema must match what VS Code creates.
	if needsCreate {
		if _, err := db.Exec("CREATE TABLE ItemTable (key TEXT UNIQUE ON CONFLICT REPLACE, value BLOB)"); err != nil {
			return fmt.Errorf("initializing state database: %w", err)
		}
	}

	// Read existing preferences
	prefs := make(map[string]bool)
	var prefsJSON string
	if err := db.QueryRow("SELECT value FROM ItemTable WHERE key = 'chatModelPickerPreferences'").Scan(&prefsJSON); err == nil {
		// Unparseable preferences are treated as empty.
		_ = json.Unmarshal([]byte(prefsJSON), &prefs)
	}

	// Build name→ID map from VS Code's cached model list.
	// VS Code uses numeric IDs like "ollama/Ollama/4", not "ollama/Ollama/kimi-k2.5:cloud".
	nameToID := make(map[string]string)
	var cacheJSON string
	if err := db.QueryRow("SELECT value FROM ItemTable WHERE key = 'chat.cachedLanguageModels.v2'").Scan(&cacheJSON); err == nil {
		var cached []map[string]any
		if json.Unmarshal([]byte(cacheJSON), &cached) == nil {
			for _, entry := range cached {
				meta, _ := entry["metadata"].(map[string]any)
				if meta == nil {
					continue
				}
				if vendor, _ := meta["vendor"].(string); vendor == "ollama" {
					name, _ := meta["name"].(string)
					id, _ := entry["identifier"].(string)
					if name != "" && id != "" {
						nameToID[name] = id
					}
				}
			}
		}
	}

	// Ollama config is authoritative: always show configured models,
	// hide Ollama models that are no longer in the config.
	configuredIDs := make(map[string]bool)
	for _, m := range models {
		for _, id := range v.modelVSCodeIDs(m, nameToID) {
			prefs[id] = true
			configuredIDs[id] = true
		}
	}
	for id := range prefs {
		if strings.HasPrefix(id, "ollama/") && !configuredIDs[id] {
			prefs[id] = false
		}
	}

	// Marshal of map[string]bool cannot fail; error deliberately ignored.
	data, _ := json.Marshal(prefs)
	if _, err = db.Exec("INSERT OR REPLACE INTO ItemTable (key, value) VALUES ('chatModelPickerPreferences', ?)", string(data)); err != nil {
		return err
	}

	return nil
}
|
||||||
|
|
||||||
|
// modelVSCodeIDs returns all possible VS Code picker IDs for a model name.
|
||||||
|
func (v *VSCode) modelVSCodeIDs(model string, nameToID map[string]string) []string {
|
||||||
|
var ids []string
|
||||||
|
if id, ok := nameToID[model]; ok {
|
||||||
|
ids = append(ids, id)
|
||||||
|
} else if !strings.Contains(model, ":") {
|
||||||
|
if id, ok := nameToID[model+":latest"]; ok {
|
||||||
|
ids = append(ids, id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ids = append(ids, "ollama/Ollama/"+model)
|
||||||
|
if !strings.Contains(model, ":") {
|
||||||
|
ids = append(ids, "ollama/Ollama/"+model+":latest")
|
||||||
|
}
|
||||||
|
return ids
|
||||||
|
}
|
||||||
|
|
||||||
|
func (v *VSCode) vscodePath(parts ...string) string {
|
||||||
|
home, _ := os.UserHomeDir()
|
||||||
|
var base string
|
||||||
|
switch runtime.GOOS {
|
||||||
|
case "darwin":
|
||||||
|
base = filepath.Join(home, "Library", "Application Support", "Code", "User")
|
||||||
|
case "windows":
|
||||||
|
base = filepath.Join(os.Getenv("APPDATA"), "Code", "User")
|
||||||
|
default:
|
||||||
|
base = filepath.Join(home, ".config", "Code", "User")
|
||||||
|
}
|
||||||
|
return filepath.Join(append([]string{base}, parts...)...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// checkVSCodeVersion warns if VS Code is older than minVSCodeVersion.
|
||||||
|
func (v *VSCode) checkVSCodeVersion() {
|
||||||
|
codeCLI := v.findCodeCLI()
|
||||||
|
if codeCLI == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
out, err := exec.Command(codeCLI, "--version").Output()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// "code --version" outputs: version\ncommit\narch
|
||||||
|
lines := strings.Split(strings.TrimSpace(string(out)), "\n")
|
||||||
|
if len(lines) == 0 || lines[0] == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
version := strings.TrimSpace(lines[0])
|
||||||
|
|
||||||
|
if compareVersions(version, minVSCodeVersion) < 0 {
|
||||||
|
fmt.Fprintf(os.Stderr, "\n%sWarning: VS Code version (%s) is older than the recommended version (%s)%s\n", ansiYellow, version, minVSCodeVersion, ansiReset)
|
||||||
|
fmt.Fprintf(os.Stderr, "Please update VS Code to the latest version.\n\n")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// checkCopilotChatVersion warns if the GitHub Copilot Chat extension is
|
||||||
|
// missing or older than minCopilotChatVersion.
|
||||||
|
func (v *VSCode) checkCopilotChatVersion() {
|
||||||
|
codeCLI := v.findCodeCLI()
|
||||||
|
if codeCLI == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
out, err := exec.Command(codeCLI, "--list-extensions", "--show-versions").Output()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
installed, version := parseCopilotChatVersion(string(out))
|
||||||
|
if !installed {
|
||||||
|
fmt.Fprintf(os.Stderr, "\n%sWarning: GitHub Copilot Chat extension is not installed%s\n", ansiYellow, ansiReset)
|
||||||
|
fmt.Fprintf(os.Stderr, "Install it in VS Code: Extensions → search \"GitHub Copilot Chat\" → Install\n\n")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if compareVersions(version, minCopilotChatVersion) < 0 {
|
||||||
|
fmt.Fprintf(os.Stderr, "\n%sWarning: GitHub Copilot Chat extension version (%s) is older than the recommended version (%s)%s\n", ansiYellow, version, minCopilotChatVersion, ansiReset)
|
||||||
|
fmt.Fprintf(os.Stderr, "Please update it in VS Code: Extensions → search \"GitHub Copilot Chat\" → Update\n\n")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// findCodeCLI returns the path to the VS Code CLI for querying extensions.
|
||||||
|
// On macOS, findBinary may return an .app bundle which can't run --list-extensions,
|
||||||
|
// so this resolves to the actual CLI binary inside the bundle.
|
||||||
|
func (v *VSCode) findCodeCLI() string {
|
||||||
|
binary := v.findBinary()
|
||||||
|
if binary == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
if runtime.GOOS == "darwin" && strings.HasSuffix(binary, ".app") {
|
||||||
|
bundleCLI := binary + "/Contents/Resources/app/bin/code"
|
||||||
|
if _, err := os.Stat(bundleCLI); err == nil {
|
||||||
|
return bundleCLI
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return binary
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseCopilotChatVersion extracts the version of the GitHub Copilot Chat
// extension from "code --list-extensions --show-versions" output. Each line
// has the form "publisher.extension@version"; matching is case-insensitive.
func parseCopilotChatVersion(output string) (installed bool, version string) {
	for _, line := range strings.Split(output, "\n") {
		// Format: github.copilot-chat@0.40.1
		if !strings.HasPrefix(strings.ToLower(line), "github.copilot-chat@") {
			continue
		}
		// The prefix check guarantees an "@" is present.
		if _, ver, ok := strings.Cut(line, "@"); ok {
			return true, strings.TrimSpace(ver)
		}
	}
	return false, ""
}
|
||||||
|
|
||||||
|
// compareVersions compares two dot-separated version strings.
|
||||||
|
// Returns -1 if a < b, 0 if a == b, 1 if a > b.
|
||||||
|
func compareVersions(a, b string) int {
|
||||||
|
aParts := strings.Split(a, ".")
|
||||||
|
bParts := strings.Split(b, ".")
|
||||||
|
|
||||||
|
maxLen := len(aParts)
|
||||||
|
if len(bParts) > maxLen {
|
||||||
|
maxLen = len(bParts)
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := range maxLen {
|
||||||
|
var aNum, bNum int
|
||||||
|
if i < len(aParts) {
|
||||||
|
aNum, _ = strconv.Atoi(aParts[i])
|
||||||
|
}
|
||||||
|
if i < len(bParts) {
|
||||||
|
bNum, _ = strconv.Atoi(bParts[i])
|
||||||
|
}
|
||||||
|
if aNum < bNum {
|
||||||
|
return -1
|
||||||
|
}
|
||||||
|
if aNum > bNum {
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// fileExists reports whether path can be stat'd. Any Stat error (including
// permission problems) is treated as "does not exist".
func fileExists(path string) bool {
	if _, err := os.Stat(path); err != nil {
		return false
	}
	return true
}
|
||||||
486
cmd/launch/vscode_test.go
Normal file
@@ -0,0 +1,486 @@
|
|||||||
|
package launch
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"encoding/json"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"runtime"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
_ "github.com/mattn/go-sqlite3"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestVSCodeIntegration(t *testing.T) {
|
||||||
|
v := &VSCode{}
|
||||||
|
|
||||||
|
t.Run("String", func(t *testing.T) {
|
||||||
|
if got := v.String(); got != "Visual Studio Code" {
|
||||||
|
t.Errorf("String() = %q, want %q", got, "Visual Studio Code")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("implements Runner", func(t *testing.T) {
|
||||||
|
var _ Runner = v
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("implements Editor", func(t *testing.T) {
|
||||||
|
var _ Editor = v
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestVSCodeEdit exercises Edit against several starting states of
// chatLanguageModels.json: missing file, other-vendor entries, a stale
// Ollama entry, an empty model list, and corrupted JSON. Each case starts
// from a clean directory under a test-scoped home.
func TestVSCodeEdit(t *testing.T) {
	v := &VSCode{}
	tmpDir := t.TempDir()
	setTestHome(t, tmpDir)
	// Clear XDG_CONFIG_HOME so the Linux path resolves under tmpDir/.config.
	t.Setenv("XDG_CONFIG_HOME", "")
	clmPath := testVSCodePath(t, tmpDir, "chatLanguageModels.json")

	tests := []struct {
		name     string
		setup    string // initial chatLanguageModels.json content, empty means no file
		models   []string
		validate func(t *testing.T, data []byte)
	}{
		{
			name:   "fresh install",
			models: []string{"llama3.2"},
			validate: func(t *testing.T, data []byte) {
				assertOllamaVendorConfigured(t, data)
			},
		},
		{
			name:   "preserve other vendor entries",
			setup:  `[{"vendor": "azure", "name": "Azure", "url": "https://example.com"}]`,
			models: []string{"llama3.2"},
			validate: func(t *testing.T, data []byte) {
				var entries []map[string]any
				json.Unmarshal(data, &entries)
				if len(entries) != 2 {
					t.Errorf("expected 2 entries, got %d", len(entries))
				}
				// Check Azure entry preserved
				found := false
				for _, e := range entries {
					if v, _ := e["vendor"].(string); v == "azure" {
						found = true
					}
				}
				if !found {
					t.Error("azure vendor entry was not preserved")
				}
				assertOllamaVendorConfigured(t, data)
			},
		},
		{
			name:   "update existing ollama entry",
			setup:  `[{"vendor": "ollama", "name": "Ollama", "url": "http://old:11434"}]`,
			models: []string{"llama3.2"},
			validate: func(t *testing.T, data []byte) {
				assertOllamaVendorConfigured(t, data)
			},
		},
		{
			name:   "empty models is no-op",
			setup:  `[{"vendor": "azure", "name": "Azure"}]`,
			models: []string{},
			validate: func(t *testing.T, data []byte) {
				// The file must be byte-identical, not merely equivalent.
				if string(data) != `[{"vendor": "azure", "name": "Azure"}]` {
					t.Error("empty models should not modify file")
				}
			},
		},
		{
			name:   "corrupted JSON treated as empty",
			setup:  `{corrupted json`,
			models: []string{"llama3.2"},
			validate: func(t *testing.T, data []byte) {
				var entries []map[string]any
				if err := json.Unmarshal(data, &entries); err != nil {
					t.Errorf("result is not valid JSON: %v", err)
				}
			},
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// Reset the VS Code user dir so cases don't see each other's files.
			os.RemoveAll(filepath.Dir(clmPath))

			if tt.setup != "" {
				os.MkdirAll(filepath.Dir(clmPath), 0o755)
				os.WriteFile(clmPath, []byte(tt.setup), 0o644)
			}

			if err := v.Edit(tt.models); err != nil {
				t.Fatal(err)
			}

			data, _ := os.ReadFile(clmPath)
			tt.validate(t, data)
		})
	}
}
|
||||||
|
|
||||||
|
// TestVSCodeEditCleansUpOldSettings verifies that Edit removes the legacy
// Ollama keys from settings.json while leaving unrelated user settings
// intact.
func TestVSCodeEditCleansUpOldSettings(t *testing.T) {
	v := &VSCode{}
	tmpDir := t.TempDir()
	setTestHome(t, tmpDir)
	// Clear XDG_CONFIG_HOME so the Linux path resolves under tmpDir/.config.
	t.Setenv("XDG_CONFIG_HOME", "")
	settingsPath := testVSCodePath(t, tmpDir, "settings.json")

	// Create settings.json with old byok setting
	os.MkdirAll(filepath.Dir(settingsPath), 0o755)
	os.WriteFile(settingsPath, []byte(`{"github.copilot.chat.byok.ollamaEndpoint": "http://old:11434", "ollama.launch.configured": true, "editor.fontSize": 14}`), 0o644)

	if err := v.Edit([]string{"llama3.2"}); err != nil {
		t.Fatal(err)
	}

	// Verify old settings were removed
	data, err := os.ReadFile(settingsPath)
	if err != nil {
		t.Fatal(err)
	}

	var settings map[string]any
	json.Unmarshal(data, &settings)
	if _, ok := settings["github.copilot.chat.byok.ollamaEndpoint"]; ok {
		t.Error("github.copilot.chat.byok.ollamaEndpoint should have been removed")
	}
	if _, ok := settings["ollama.launch.configured"]; ok {
		t.Error("ollama.launch.configured should have been removed")
	}
	// JSON numbers unmarshal into any as float64.
	if settings["editor.fontSize"] != float64(14) {
		t.Error("editor.fontSize should have been preserved")
	}
}
|
||||||
|
|
||||||
|
func TestVSCodePaths(t *testing.T) {
|
||||||
|
v := &VSCode{}
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
setTestHome(t, tmpDir)
|
||||||
|
t.Setenv("XDG_CONFIG_HOME", "")
|
||||||
|
clmPath := testVSCodePath(t, tmpDir, "chatLanguageModels.json")
|
||||||
|
|
||||||
|
t.Run("no file returns nil", func(t *testing.T) {
|
||||||
|
os.Remove(clmPath)
|
||||||
|
if paths := v.Paths(); paths != nil {
|
||||||
|
t.Errorf("expected nil, got %v", paths)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("existing file returns path", func(t *testing.T) {
|
||||||
|
os.MkdirAll(filepath.Dir(clmPath), 0o755)
|
||||||
|
os.WriteFile(clmPath, []byte(`[]`), 0o644)
|
||||||
|
|
||||||
|
if paths := v.Paths(); len(paths) != 1 {
|
||||||
|
t.Errorf("expected 1 path, got %d", len(paths))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// testVSCodePath returns the expected VS Code config path for the given file in tests.
|
||||||
|
func testVSCodePath(t *testing.T, tmpDir, filename string) string {
|
||||||
|
t.Helper()
|
||||||
|
switch runtime.GOOS {
|
||||||
|
case "darwin":
|
||||||
|
return filepath.Join(tmpDir, "Library", "Application Support", "Code", "User", filename)
|
||||||
|
case "windows":
|
||||||
|
t.Setenv("APPDATA", tmpDir)
|
||||||
|
return filepath.Join(tmpDir, "Code", "User", filename)
|
||||||
|
default:
|
||||||
|
return filepath.Join(tmpDir, ".config", "Code", "User", filename)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func assertOllamaVendorConfigured(t *testing.T, data []byte) {
|
||||||
|
t.Helper()
|
||||||
|
var entries []map[string]any
|
||||||
|
if err := json.Unmarshal(data, &entries); err != nil {
|
||||||
|
t.Fatalf("invalid JSON: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, entry := range entries {
|
||||||
|
if vendor, _ := entry["vendor"].(string); vendor == "ollama" {
|
||||||
|
if name, _ := entry["name"].(string); name != "Ollama" {
|
||||||
|
t.Errorf("expected name \"Ollama\", got %q", name)
|
||||||
|
}
|
||||||
|
if url, _ := entry["url"].(string); url == "" {
|
||||||
|
t.Error("url not set")
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
t.Error("no ollama vendor entry found")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestShowInModelPicker(t *testing.T) {
|
||||||
|
v := &VSCode{}
|
||||||
|
|
||||||
|
// helper to create a state DB with optional seed data
|
||||||
|
setupDB := func(t *testing.T, tmpDir string, seedPrefs map[string]bool, seedCache []map[string]any) string {
|
||||||
|
t.Helper()
|
||||||
|
dbDir := filepath.Join(tmpDir, "globalStorage")
|
||||||
|
os.MkdirAll(dbDir, 0o755)
|
||||||
|
dbPath := filepath.Join(dbDir, "state.vscdb")
|
||||||
|
|
||||||
|
db, err := sql.Open("sqlite3", dbPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
if _, err := db.Exec("CREATE TABLE ItemTable (key TEXT UNIQUE ON CONFLICT REPLACE, value BLOB)"); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if seedPrefs != nil {
|
||||||
|
data, _ := json.Marshal(seedPrefs)
|
||||||
|
db.Exec("INSERT INTO ItemTable (key, value) VALUES ('chatModelPickerPreferences', ?)", string(data))
|
||||||
|
}
|
||||||
|
if seedCache != nil {
|
||||||
|
data, _ := json.Marshal(seedCache)
|
||||||
|
db.Exec("INSERT INTO ItemTable (key, value) VALUES ('chat.cachedLanguageModels.v2', ?)", string(data))
|
||||||
|
}
|
||||||
|
return dbPath
|
||||||
|
}
|
||||||
|
|
||||||
|
// helper to read prefs back from DB
|
||||||
|
readPrefs := func(t *testing.T, dbPath string) map[string]bool {
|
||||||
|
t.Helper()
|
||||||
|
db, err := sql.Open("sqlite3", dbPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
var raw string
|
||||||
|
if err := db.QueryRow("SELECT value FROM ItemTable WHERE key = 'chatModelPickerPreferences'").Scan(&raw); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
prefs := make(map[string]bool)
|
||||||
|
json.Unmarshal([]byte(raw), &prefs)
|
||||||
|
return prefs
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("fresh DB creates table and shows models", func(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
setTestHome(t, tmpDir)
|
||||||
|
t.Setenv("XDG_CONFIG_HOME", "")
|
||||||
|
if runtime.GOOS == "windows" {
|
||||||
|
t.Setenv("APPDATA", tmpDir)
|
||||||
|
}
|
||||||
|
|
||||||
|
err := v.ShowInModelPicker([]string{"llama3.2"})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
dbPath := testVSCodePath(t, tmpDir, filepath.Join("globalStorage", "state.vscdb"))
|
||||||
|
prefs := readPrefs(t, dbPath)
|
||||||
|
if !prefs["ollama/Ollama/llama3.2"] {
|
||||||
|
t.Error("expected llama3.2 to be shown")
|
||||||
|
}
|
||||||
|
if !prefs["ollama/Ollama/llama3.2:latest"] {
|
||||||
|
t.Error("expected llama3.2:latest to be shown")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("configured models are shown", func(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
setTestHome(t, tmpDir)
|
||||||
|
t.Setenv("XDG_CONFIG_HOME", "")
|
||||||
|
dbPath := setupDB(t, testVSCodePath(t, tmpDir, ""), nil, nil)
|
||||||
|
|
||||||
|
err := v.ShowInModelPicker([]string{"llama3.2", "qwen3:8b"})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
prefs := readPrefs(t, dbPath)
|
||||||
|
if !prefs["ollama/Ollama/llama3.2"] {
|
||||||
|
t.Error("expected llama3.2 to be shown")
|
||||||
|
}
|
||||||
|
if !prefs["ollama/Ollama/qwen3:8b"] {
|
||||||
|
t.Error("expected qwen3:8b to be shown")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("removed models are hidden", func(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
setTestHome(t, tmpDir)
|
||||||
|
t.Setenv("XDG_CONFIG_HOME", "")
|
||||||
|
dbPath := setupDB(t, testVSCodePath(t, tmpDir, ""), map[string]bool{
|
||||||
|
"ollama/Ollama/llama3.2": true,
|
||||||
|
"ollama/Ollama/llama3.2:latest": true,
|
||||||
|
"ollama/Ollama/mistral": true,
|
||||||
|
"ollama/Ollama/mistral:latest": true,
|
||||||
|
}, nil)
|
||||||
|
|
||||||
|
// Only configure llama3.2 — mistral should get hidden
|
||||||
|
err := v.ShowInModelPicker([]string{"llama3.2"})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
prefs := readPrefs(t, dbPath)
|
||||||
|
if !prefs["ollama/Ollama/llama3.2"] {
|
||||||
|
t.Error("expected llama3.2 to stay shown")
|
||||||
|
}
|
||||||
|
if prefs["ollama/Ollama/mistral"] {
|
||||||
|
t.Error("expected mistral to be hidden")
|
||||||
|
}
|
||||||
|
if prefs["ollama/Ollama/mistral:latest"] {
|
||||||
|
t.Error("expected mistral:latest to be hidden")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("non-ollama prefs are preserved", func(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
setTestHome(t, tmpDir)
|
||||||
|
t.Setenv("XDG_CONFIG_HOME", "")
|
||||||
|
dbPath := setupDB(t, testVSCodePath(t, tmpDir, ""), map[string]bool{
|
||||||
|
"copilot/gpt-4o": true,
|
||||||
|
}, nil)
|
||||||
|
|
||||||
|
err := v.ShowInModelPicker([]string{"llama3.2"})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
prefs := readPrefs(t, dbPath)
|
||||||
|
if !prefs["copilot/gpt-4o"] {
|
||||||
|
t.Error("expected copilot/gpt-4o to stay shown")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("uses cached numeric IDs when available", func(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
setTestHome(t, tmpDir)
|
||||||
|
t.Setenv("XDG_CONFIG_HOME", "")
|
||||||
|
cache := []map[string]any{
|
||||||
|
{
|
||||||
|
"identifier": "ollama/Ollama/4",
|
||||||
|
"metadata": map[string]any{"vendor": "ollama", "name": "llama3.2"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
dbPath := setupDB(t, testVSCodePath(t, tmpDir, ""), nil, cache)
|
||||||
|
|
||||||
|
err := v.ShowInModelPicker([]string{"llama3.2"})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
prefs := readPrefs(t, dbPath)
|
||||||
|
if !prefs["ollama/Ollama/4"] {
|
||||||
|
t.Error("expected numeric ID ollama/Ollama/4 to be shown")
|
||||||
|
}
|
||||||
|
// Name-based fallback should also be set
|
||||||
|
if !prefs["ollama/Ollama/llama3.2"] {
|
||||||
|
t.Error("expected name-based ID to also be shown")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("empty models is no-op", func(t *testing.T) {
|
||||||
|
err := v.ShowInModelPicker([]string{})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("previously hidden model is re-shown when configured", func(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
setTestHome(t, tmpDir)
|
||||||
|
t.Setenv("XDG_CONFIG_HOME", "")
|
||||||
|
dbPath := setupDB(t, testVSCodePath(t, tmpDir, ""), map[string]bool{
|
||||||
|
"ollama/Ollama/llama3.2": false,
|
||||||
|
"ollama/Ollama/llama3.2:latest": false,
|
||||||
|
}, nil)
|
||||||
|
|
||||||
|
// Ollama config is authoritative — should override the hidden state
|
||||||
|
err := v.ShowInModelPicker([]string{"llama3.2"})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
prefs := readPrefs(t, dbPath)
|
||||||
|
if !prefs["ollama/Ollama/llama3.2"] {
|
||||||
|
t.Error("expected llama3.2 to be re-shown")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseCopilotChatVersion(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
output string
|
||||||
|
wantInstalled bool
|
||||||
|
wantVersion string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "found among other extensions",
|
||||||
|
output: "ms-python.python@2024.1.1\ngithub.copilot-chat@0.40.1\ngithub.copilot@1.200.0\n",
|
||||||
|
wantInstalled: true,
|
||||||
|
wantVersion: "0.40.1",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "only extension",
|
||||||
|
output: "GitHub.copilot-chat@0.41.0\n",
|
||||||
|
wantInstalled: true,
|
||||||
|
wantVersion: "0.41.0",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "not installed",
|
||||||
|
output: "ms-python.python@2024.1.1\ngithub.copilot@1.200.0\n",
|
||||||
|
wantInstalled: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty output",
|
||||||
|
output: "",
|
||||||
|
wantInstalled: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "case insensitive match",
|
||||||
|
output: "GitHub.Copilot-Chat@0.39.0\n",
|
||||||
|
wantInstalled: true,
|
||||||
|
wantVersion: "0.39.0",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
installed, version := parseCopilotChatVersion(tt.output)
|
||||||
|
if installed != tt.wantInstalled {
|
||||||
|
t.Errorf("installed = %v, want %v", installed, tt.wantInstalled)
|
||||||
|
}
|
||||||
|
if installed && version != tt.wantVersion {
|
||||||
|
t.Errorf("version = %q, want %q", version, tt.wantVersion)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCompareVersions(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
a, b string
|
||||||
|
want int
|
||||||
|
}{
|
||||||
|
{"0.40.1", "0.40.1", 0},
|
||||||
|
{"0.40.2", "0.40.1", 1},
|
||||||
|
{"0.40.0", "0.40.1", -1},
|
||||||
|
{"0.41.0", "0.40.1", 1},
|
||||||
|
{"0.39.9", "0.40.1", -1},
|
||||||
|
{"1.0.0", "0.40.1", 1},
|
||||||
|
{"0.40", "0.40.1", -1},
|
||||||
|
{"0.40.1.1", "0.40.1", 1},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.a+"_vs_"+tt.b, func(t *testing.T) {
|
||||||
|
got := compareVersions(tt.a, tt.b)
|
||||||
|
if got != tt.want {
|
||||||
|
t.Errorf("compareVersions(%q, %q) = %d, want %d", tt.a, tt.b, got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -7,7 +7,7 @@ import (
|
|||||||
|
|
||||||
tea "github.com/charmbracelet/bubbletea"
|
tea "github.com/charmbracelet/bubbletea"
|
||||||
"github.com/charmbracelet/lipgloss"
|
"github.com/charmbracelet/lipgloss"
|
||||||
"github.com/ollama/ollama/cmd/config"
|
"github.com/ollama/ollama/cmd/launch"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@@ -64,8 +64,8 @@ type SelectItem struct {
|
|||||||
Recommended bool
|
Recommended bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// ConvertItems converts config.ModelItem slice to SelectItem slice.
|
// ConvertItems converts launch.ModelItem slice to SelectItem slice.
|
||||||
func ConvertItems(items []config.ModelItem) []SelectItem {
|
func ConvertItems(items []launch.ModelItem) []SelectItem {
|
||||||
out := make([]SelectItem, len(items))
|
out := make([]SelectItem, len(items))
|
||||||
for i, item := range items {
|
for i, item := range items {
|
||||||
out[i] = SelectItem{Name: item.Name, Description: item.Description, Recommended: item.Recommended}
|
out[i] = SelectItem{Name: item.Name, Description: item.Description, Recommended: item.Recommended}
|
||||||
@@ -101,6 +101,16 @@ type selectorModel struct {
|
|||||||
width int
|
width int
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func selectorModelWithCurrent(title string, items []SelectItem, current string) selectorModel {
|
||||||
|
m := selectorModel{
|
||||||
|
title: title,
|
||||||
|
items: items,
|
||||||
|
cursor: cursorForCurrent(items, current),
|
||||||
|
}
|
||||||
|
m.updateScroll(m.otherStart())
|
||||||
|
return m
|
||||||
|
}
|
||||||
|
|
||||||
func (m selectorModel) filteredItems() []SelectItem {
|
func (m selectorModel) filteredItems() []SelectItem {
|
||||||
if m.filter == "" {
|
if m.filter == "" {
|
||||||
return m.items
|
return m.items
|
||||||
@@ -232,6 +242,10 @@ func (m selectorModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
|||||||
m.cancelled = true
|
m.cancelled = true
|
||||||
return m, tea.Quit
|
return m, tea.Quit
|
||||||
|
|
||||||
|
case tea.KeyLeft:
|
||||||
|
m.cancelled = true
|
||||||
|
return m, tea.Quit
|
||||||
|
|
||||||
case tea.KeyEnter:
|
case tea.KeyEnter:
|
||||||
filtered := m.filteredItems()
|
filtered := m.filteredItems()
|
||||||
if len(filtered) > 0 && m.cursor < len(filtered) {
|
if len(filtered) > 0 && m.cursor < len(filtered) {
|
||||||
@@ -344,7 +358,7 @@ func (m selectorModel) renderContent() string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
s.WriteString("\n")
|
s.WriteString("\n")
|
||||||
help := "↑/↓ navigate • enter select • esc cancel"
|
help := "↑/↓ navigate • enter select • ← back"
|
||||||
if m.helpText != "" {
|
if m.helpText != "" {
|
||||||
help = m.helpText
|
help = m.helpText
|
||||||
}
|
}
|
||||||
@@ -367,13 +381,24 @@ func (m selectorModel) View() string {
|
|||||||
|
|
||||||
// cursorForCurrent returns the item index matching current, or 0 if not found.
|
// cursorForCurrent returns the item index matching current, or 0 if not found.
|
||||||
func cursorForCurrent(items []SelectItem, current string) int {
|
func cursorForCurrent(items []SelectItem, current string) int {
|
||||||
if current != "" {
|
if current == "" {
|
||||||
for i, item := range items {
|
return 0
|
||||||
if item.Name == current || strings.HasPrefix(item.Name, current+":") || strings.HasPrefix(current, item.Name+":") {
|
}
|
||||||
return i
|
|
||||||
}
|
// Prefer exact name matches before tag-prefix fallback so "qwen3.5" does not
|
||||||
|
// incorrectly select "qwen3.5:cloud" (and vice versa) based on list order.
|
||||||
|
for i, item := range items {
|
||||||
|
if item.Name == current {
|
||||||
|
return i
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
for i, item := range items {
|
||||||
|
if strings.HasPrefix(item.Name, current+":") || strings.HasPrefix(current, item.Name+":") {
|
||||||
|
return i
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -382,11 +407,7 @@ func SelectSingle(title string, items []SelectItem, current string) (string, err
|
|||||||
return "", fmt.Errorf("no items to select from")
|
return "", fmt.Errorf("no items to select from")
|
||||||
}
|
}
|
||||||
|
|
||||||
m := selectorModel{
|
m := selectorModelWithCurrent(title, items, current)
|
||||||
title: title,
|
|
||||||
items: items,
|
|
||||||
cursor: cursorForCurrent(items, current),
|
|
||||||
}
|
|
||||||
|
|
||||||
p := tea.NewProgram(m)
|
p := tea.NewProgram(m)
|
||||||
finalModel, err := p.Run()
|
finalModel, err := p.Run()
|
||||||
@@ -523,6 +544,7 @@ func (m *multiSelectorModel) toggleItem() {
|
|||||||
origIdx := m.itemIndex[item.Name]
|
origIdx := m.itemIndex[item.Name]
|
||||||
|
|
||||||
if m.checked[origIdx] {
|
if m.checked[origIdx] {
|
||||||
|
wasDefault := len(m.checkOrder) > 0 && m.checkOrder[len(m.checkOrder)-1] == origIdx
|
||||||
delete(m.checked, origIdx)
|
delete(m.checked, origIdx)
|
||||||
for i, idx := range m.checkOrder {
|
for i, idx := range m.checkOrder {
|
||||||
if idx == origIdx {
|
if idx == origIdx {
|
||||||
@@ -530,6 +552,34 @@ func (m *multiSelectorModel) toggleItem() {
|
|||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if wasDefault {
|
||||||
|
// When removing the default, pick the nearest checked model above it
|
||||||
|
// (or below if none above) so default fallback follows list order.
|
||||||
|
newDefault := -1
|
||||||
|
for i := origIdx - 1; i >= 0; i-- {
|
||||||
|
if m.checked[i] {
|
||||||
|
newDefault = i
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if newDefault == -1 {
|
||||||
|
for i := origIdx + 1; i < len(m.items); i++ {
|
||||||
|
if m.checked[i] {
|
||||||
|
newDefault = i
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if newDefault != -1 {
|
||||||
|
for i, idx := range m.checkOrder {
|
||||||
|
if idx == newDefault {
|
||||||
|
m.checkOrder = append(m.checkOrder[:i], m.checkOrder[i+1:]...)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
m.checkOrder = append(m.checkOrder, newDefault)
|
||||||
|
}
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
m.checked[origIdx] = true
|
m.checked[origIdx] = true
|
||||||
m.checkOrder = append(m.checkOrder, origIdx)
|
m.checkOrder = append(m.checkOrder, origIdx)
|
||||||
@@ -562,6 +612,10 @@ func (m multiSelectorModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
|||||||
m.cancelled = true
|
m.cancelled = true
|
||||||
return m, tea.Quit
|
return m, tea.Quit
|
||||||
|
|
||||||
|
case tea.KeyLeft:
|
||||||
|
m.cancelled = true
|
||||||
|
return m, tea.Quit
|
||||||
|
|
||||||
case tea.KeyTab:
|
case tea.KeyTab:
|
||||||
m.multi = !m.multi
|
m.multi = !m.multi
|
||||||
|
|
||||||
@@ -764,7 +818,7 @@ func (m multiSelectorModel) View() string {
|
|||||||
s.WriteString("\n")
|
s.WriteString("\n")
|
||||||
|
|
||||||
if !m.multi {
|
if !m.multi {
|
||||||
s.WriteString(selectorHelpStyle.Render("↑/↓ navigate • enter select • tab add multiple • esc cancel"))
|
s.WriteString(selectorHelpStyle.Render("↑/↓ navigate • enter select • tab add multiple • ← back"))
|
||||||
} else {
|
} else {
|
||||||
count := m.selectedCount()
|
count := m.selectedCount()
|
||||||
if count == 0 {
|
if count == 0 {
|
||||||
@@ -773,7 +827,7 @@ func (m multiSelectorModel) View() string {
|
|||||||
s.WriteString(selectorDescStyle.Render(fmt.Sprintf(" %d selected - press enter to continue", count)))
|
s.WriteString(selectorDescStyle.Render(fmt.Sprintf(" %d selected - press enter to continue", count)))
|
||||||
}
|
}
|
||||||
s.WriteString("\n\n")
|
s.WriteString("\n\n")
|
||||||
s.WriteString(selectorHelpStyle.Render("↑/↓ navigate • space toggle • tab select single • enter confirm • esc cancel"))
|
s.WriteString(selectorHelpStyle.Render("↑/↓ navigate • space toggle • tab select single • enter confirm • ← back"))
|
||||||
}
|
}
|
||||||
|
|
||||||
result := s.String()
|
result := s.String()
|
||||||
|
|||||||
@@ -216,6 +216,41 @@ func TestUpdateScroll(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestSelectorModelWithCurrent_ScrollsToCurrentInMoreSection(t *testing.T) {
|
||||||
|
m := selectorModelWithCurrent("Pick:", mixedItems(), "other-10")
|
||||||
|
|
||||||
|
if m.cursor != 11 {
|
||||||
|
t.Fatalf("cursor = %d, want 11", m.cursor)
|
||||||
|
}
|
||||||
|
if m.scrollOffset == 0 {
|
||||||
|
t.Fatal("scrollOffset should move to reveal current item in More section")
|
||||||
|
}
|
||||||
|
|
||||||
|
content := m.renderContent()
|
||||||
|
if !strings.Contains(content, "▸ other-10") {
|
||||||
|
t.Fatalf("expected current item to be visible and highlighted\n%s", content)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSelectorModelWithCurrent_HighlightsExactLocalWhenCloudVariantExists(t *testing.T) {
|
||||||
|
m := selectorModelWithCurrent("Pick:", []SelectItem{
|
||||||
|
{Name: "qwen3.5:cloud", Recommended: true},
|
||||||
|
{Name: "qwen3.5", Recommended: true},
|
||||||
|
}, "qwen3.5")
|
||||||
|
|
||||||
|
if m.cursor != 1 {
|
||||||
|
t.Fatalf("cursor = %d, want 1", m.cursor)
|
||||||
|
}
|
||||||
|
|
||||||
|
content := m.renderContent()
|
||||||
|
if !strings.Contains(content, "▸ qwen3.5") {
|
||||||
|
t.Fatalf("expected local qwen3.5 to be highlighted\n%s", content)
|
||||||
|
}
|
||||||
|
if strings.Contains(content, "▸ qwen3.5:cloud") {
|
||||||
|
t.Fatalf("did not expect cloud qwen3.5:cloud to be highlighted\n%s", content)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestRenderContent_SectionHeaders(t *testing.T) {
|
func TestRenderContent_SectionHeaders(t *testing.T) {
|
||||||
m := selectorModel{
|
m := selectorModel{
|
||||||
title: "Pick:",
|
title: "Pick:",
|
||||||
@@ -418,6 +453,28 @@ func TestCursorForCurrent(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestCursorForCurrent_PrefersExactLocalOverCloudPrefix(t *testing.T) {
|
||||||
|
testItems := []SelectItem{
|
||||||
|
{Name: "qwen3.5:cloud", Recommended: true},
|
||||||
|
{Name: "qwen3.5", Recommended: true},
|
||||||
|
}
|
||||||
|
|
||||||
|
if got := cursorForCurrent(testItems, "qwen3.5"); got != 1 {
|
||||||
|
t.Errorf("cursorForCurrent(%q) = %d, want %d", "qwen3.5", got, 1)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCursorForCurrent_PrefersExactCloudOverLocalPrefix(t *testing.T) {
|
||||||
|
testItems := []SelectItem{
|
||||||
|
{Name: "qwen3.5", Recommended: true},
|
||||||
|
{Name: "qwen3.5:cloud", Recommended: true},
|
||||||
|
}
|
||||||
|
|
||||||
|
if got := cursorForCurrent(testItems, "qwen3.5:cloud"); got != 1 {
|
||||||
|
t.Errorf("cursorForCurrent(%q) = %d, want %d", "qwen3.5:cloud", got, 1)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// --- ReorderItems ---
|
// --- ReorderItems ---
|
||||||
|
|
||||||
func TestReorderItems(t *testing.T) {
|
func TestReorderItems(t *testing.T) {
|
||||||
@@ -725,6 +782,9 @@ func TestMulti_MultiModeHelpText(t *testing.T) {
|
|||||||
if !strings.Contains(content, "tab select single") {
|
if !strings.Contains(content, "tab select single") {
|
||||||
t.Error("multi mode should show 'tab select single' in help")
|
t.Error("multi mode should show 'tab select single' in help")
|
||||||
}
|
}
|
||||||
|
if !strings.Contains(content, "← back") {
|
||||||
|
t.Error("multi mode should show '← back' in help")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// --- preChecked initialization order ---
|
// --- preChecked initialization order ---
|
||||||
@@ -783,6 +843,74 @@ func TestMulti_LastCheckedIsDefault(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestMulti_UncheckingDefaultFallsBackToNearestCheckedAbove(t *testing.T) {
|
||||||
|
// Default is "b", and checked models are "a", "b", "c".
|
||||||
|
// Unticking default should make "a" (the nearest checked item above) default.
|
||||||
|
m := newMultiSelectorModel("Pick:", items("a", "b", "c"), []string{"b", "c", "a"})
|
||||||
|
m.multi = true
|
||||||
|
m.cursor = 1 // "b"
|
||||||
|
m.toggleItem()
|
||||||
|
|
||||||
|
lastIdx := m.checkOrder[len(m.checkOrder)-1]
|
||||||
|
if m.items[lastIdx].Name != "a" {
|
||||||
|
t.Fatalf("expected default to fall back to 'a', got %q", m.items[lastIdx].Name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMulti_UncheckingTopDefaultFallsBackToNearestCheckedBelow(t *testing.T) {
|
||||||
|
// Default is top item "a". With no checked item above, fallback should pick
|
||||||
|
// the nearest checked item below ("b").
|
||||||
|
m := newMultiSelectorModel("Pick:", items("a", "b", "c"), []string{"a", "c", "b"})
|
||||||
|
m.multi = true
|
||||||
|
m.cursor = 0 // "a"
|
||||||
|
m.toggleItem()
|
||||||
|
|
||||||
|
lastIdx := m.checkOrder[len(m.checkOrder)-1]
|
||||||
|
if m.items[lastIdx].Name != "b" {
|
||||||
|
t.Fatalf("expected default to fall back to 'b', got %q", m.items[lastIdx].Name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Left arrow back navigation ---
|
||||||
|
|
||||||
|
func TestSelectorLeftArrowCancelsWhenNoFilter(t *testing.T) {
|
||||||
|
m := selectorModelWithCurrent("Pick:", items("a", "b", "c"), "")
|
||||||
|
updated, _ := m.Update(tea.KeyMsg{Type: tea.KeyLeft})
|
||||||
|
got := updated.(selectorModel)
|
||||||
|
if !got.cancelled {
|
||||||
|
t.Error("left arrow with empty filter should cancel (go back)")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSelectorLeftArrowCancelsWhenFiltering(t *testing.T) {
|
||||||
|
m := selectorModelWithCurrent("Pick:", items("a", "b", "c"), "")
|
||||||
|
m.filter = "a"
|
||||||
|
updated, _ := m.Update(tea.KeyMsg{Type: tea.KeyLeft})
|
||||||
|
got := updated.(selectorModel)
|
||||||
|
if !got.cancelled {
|
||||||
|
t.Error("left arrow with active filter should still cancel (go back)")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMultiSelectorLeftArrowCancelsWhenNoFilter(t *testing.T) {
|
||||||
|
m := newMultiSelectorModel("Pick:", items("a", "b", "c"), nil)
|
||||||
|
updated, _ := m.Update(tea.KeyMsg{Type: tea.KeyLeft})
|
||||||
|
got := updated.(multiSelectorModel)
|
||||||
|
if !got.cancelled {
|
||||||
|
t.Error("left arrow with empty filter should cancel (go back)")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMultiSelectorLeftArrowCancelsWhenFiltering(t *testing.T) {
|
||||||
|
m := newMultiSelectorModel("Pick:", items("a", "b", "c"), nil)
|
||||||
|
m.filter = "a"
|
||||||
|
updated, _ := m.Update(tea.KeyMsg{Type: tea.KeyLeft})
|
||||||
|
got := updated.(multiSelectorModel)
|
||||||
|
if !got.cancelled {
|
||||||
|
t.Error("left arrow with active filter should still cancel (go back)")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Key message helpers for testing
|
// Key message helpers for testing
|
||||||
|
|
||||||
type keyType = int
|
type keyType = int
|
||||||
|
|||||||
@@ -1,15 +1,24 @@
|
|||||||
package tui
|
package tui
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
tea "github.com/charmbracelet/bubbletea"
|
tea "github.com/charmbracelet/bubbletea"
|
||||||
"github.com/charmbracelet/lipgloss"
|
"github.com/charmbracelet/lipgloss"
|
||||||
"github.com/ollama/ollama/cmd/config"
|
"github.com/ollama/ollama/api"
|
||||||
|
"github.com/ollama/ollama/cmd/launch"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type signInTickMsg struct{}
|
||||||
|
|
||||||
|
type signInCheckMsg struct {
|
||||||
|
signedIn bool
|
||||||
|
userName string
|
||||||
|
}
|
||||||
|
|
||||||
type signInModel struct {
|
type signInModel struct {
|
||||||
modelName string
|
modelName string
|
||||||
signInURL string
|
signInURL string
|
||||||
@@ -88,11 +97,8 @@ func renderSignIn(modelName, signInURL string, spinner, width int) string {
|
|||||||
|
|
||||||
fmt.Fprintf(&s, "To use %s, please sign in.\n\n", selectorSelectedItemStyle.Render(modelName))
|
fmt.Fprintf(&s, "To use %s, please sign in.\n\n", selectorSelectedItemStyle.Render(modelName))
|
||||||
|
|
||||||
// Wrap in OSC 8 hyperlink so the entire URL is clickable even when wrapped.
|
|
||||||
// Padding is outside the hyperlink so spaces don't get underlined.
|
|
||||||
link := fmt.Sprintf("\033]8;;%s\033\\%s\033]8;;\033\\", signInURL, urlColor.Render(signInURL))
|
|
||||||
s.WriteString("Navigate to:\n")
|
s.WriteString("Navigate to:\n")
|
||||||
s.WriteString(urlWrap.Render(link))
|
s.WriteString(urlWrap.Render(urlColor.Render(signInURL)))
|
||||||
s.WriteString("\n\n")
|
s.WriteString("\n\n")
|
||||||
|
|
||||||
s.WriteString(lipgloss.NewStyle().Foreground(lipgloss.AdaptiveColor{Light: "242", Dark: "246"}).Render(
|
s.WriteString(lipgloss.NewStyle().Foreground(lipgloss.AdaptiveColor{Light: "242", Dark: "246"}).Render(
|
||||||
@@ -104,9 +110,21 @@ func renderSignIn(modelName, signInURL string, spinner, width int) string {
|
|||||||
return lipgloss.NewStyle().PaddingLeft(2).Render(s.String())
|
return lipgloss.NewStyle().PaddingLeft(2).Render(s.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func checkSignIn() tea.Msg {
|
||||||
|
client, err := api.ClientFromEnvironment()
|
||||||
|
if err != nil {
|
||||||
|
return signInCheckMsg{signedIn: false}
|
||||||
|
}
|
||||||
|
user, err := client.Whoami(context.Background())
|
||||||
|
if err == nil && user != nil && user.Name != "" {
|
||||||
|
return signInCheckMsg{signedIn: true, userName: user.Name}
|
||||||
|
}
|
||||||
|
return signInCheckMsg{signedIn: false}
|
||||||
|
}
|
||||||
|
|
||||||
// RunSignIn shows a bubbletea sign-in dialog and polls until the user signs in or cancels.
|
// RunSignIn shows a bubbletea sign-in dialog and polls until the user signs in or cancels.
|
||||||
func RunSignIn(modelName, signInURL string) (string, error) {
|
func RunSignIn(modelName, signInURL string) (string, error) {
|
||||||
config.OpenBrowser(signInURL)
|
launch.OpenBrowser(signInURL)
|
||||||
|
|
||||||
m := signInModel{
|
m := signInModel{
|
||||||
modelName: modelName,
|
modelName: modelName,
|
||||||
|
|||||||
@@ -25,22 +25,6 @@ func TestRenderSignIn_ContainsURL(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestRenderSignIn_OSC8Hyperlink(t *testing.T) {
|
|
||||||
url := "https://ollama.com/connect?key=abc123"
|
|
||||||
got := renderSignIn("test:cloud", url, 0, 120)
|
|
||||||
|
|
||||||
// Should contain OSC 8 open sequence with the URL
|
|
||||||
osc8Open := "\033]8;;" + url + "\033\\"
|
|
||||||
if !strings.Contains(got, osc8Open) {
|
|
||||||
t.Error("should contain OSC 8 open sequence with URL")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Should contain OSC 8 close sequence
|
|
||||||
osc8Close := "\033]8;;\033\\"
|
|
||||||
if !strings.Contains(got, osc8Close) {
|
|
||||||
t.Error("should contain OSC 8 close sequence")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRenderSignIn_ContainsSpinner(t *testing.T) {
|
func TestRenderSignIn_ContainsSpinner(t *testing.T) {
|
||||||
got := renderSignIn("test:cloud", "https://example.com", 0, 80)
|
got := renderSignIn("test:cloud", "https://example.com", 0, 80)
|
||||||
|
|||||||
841
cmd/tui/tui.go
@@ -1,16 +1,11 @@
|
|||||||
package tui
|
package tui
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
tea "github.com/charmbracelet/bubbletea"
|
tea "github.com/charmbracelet/bubbletea"
|
||||||
"github.com/charmbracelet/lipgloss"
|
"github.com/charmbracelet/lipgloss"
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/cmd/launch"
|
||||||
"github.com/ollama/ollama/cmd/config"
|
|
||||||
"github.com/ollama/ollama/version"
|
"github.com/ollama/ollama/version"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -45,30 +40,24 @@ var (
|
|||||||
type menuItem struct {
|
type menuItem struct {
|
||||||
title string
|
title string
|
||||||
description string
|
description string
|
||||||
integration string // integration name for loading model config, empty if not an integration
|
integration string
|
||||||
isRunModel bool
|
isRunModel bool
|
||||||
isOthers bool
|
isOthers bool
|
||||||
}
|
}
|
||||||
|
|
||||||
var mainMenuItems = []menuItem{
|
var mainMenuItems = []menuItem{
|
||||||
{
|
{
|
||||||
title: "Run a model",
|
title: "Chat with a model",
|
||||||
description: "Start an interactive chat with a model",
|
description: "Start an interactive chat with a model",
|
||||||
isRunModel: true,
|
isRunModel: true,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
title: "Launch Claude Code",
|
|
||||||
description: "Agentic coding across large codebases",
|
|
||||||
integration: "claude",
|
integration: "claude",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
title: "Launch Codex",
|
|
||||||
description: "OpenAI's open-source coding agent",
|
|
||||||
integration: "codex",
|
integration: "codex",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
title: "Launch OpenClaw",
|
|
||||||
description: "Personal AI with 100+ skills",
|
|
||||||
integration: "openclaw",
|
integration: "openclaw",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -79,277 +68,106 @@ var othersMenuItem = menuItem{
|
|||||||
isOthers: true,
|
isOthers: true,
|
||||||
}
|
}
|
||||||
|
|
||||||
// getOtherIntegrations dynamically builds the "Others" list from the integration
|
type model struct {
|
||||||
// registry, excluding any integrations already present in the pinned mainMenuItems.
|
state *launch.LauncherState
|
||||||
func getOtherIntegrations() []menuItem {
|
items []menuItem
|
||||||
pinned := map[string]bool{
|
cursor int
|
||||||
"run": true, // not an integration but in the pinned list
|
showOthers bool
|
||||||
|
width int
|
||||||
|
quitting bool
|
||||||
|
selected bool
|
||||||
|
action TUIAction
|
||||||
|
}
|
||||||
|
|
||||||
|
func newModel(state *launch.LauncherState) model {
|
||||||
|
m := model{
|
||||||
|
state: state,
|
||||||
}
|
}
|
||||||
|
m.showOthers = shouldExpandOthers(state)
|
||||||
|
m.items = buildMenuItems(state, m.showOthers)
|
||||||
|
m.cursor = initialCursor(state, m.items)
|
||||||
|
return m
|
||||||
|
}
|
||||||
|
|
||||||
|
func shouldExpandOthers(state *launch.LauncherState) bool {
|
||||||
|
if state == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
for _, item := range otherIntegrationItems(state) {
|
||||||
|
if item.integration == state.LastSelection {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildMenuItems(state *launch.LauncherState, showOthers bool) []menuItem {
|
||||||
|
items := make([]menuItem, 0, len(mainMenuItems)+1)
|
||||||
for _, item := range mainMenuItems {
|
for _, item := range mainMenuItems {
|
||||||
if item.integration != "" {
|
if item.integration == "" {
|
||||||
pinned[item.integration] = true
|
items = append(items, item)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if integrationState, ok := state.Integrations[item.integration]; ok {
|
||||||
|
items = append(items, integrationMenuItem(integrationState))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
var others []menuItem
|
if showOthers {
|
||||||
for _, info := range config.ListIntegrationInfos() {
|
items = append(items, otherIntegrationItems(state)...)
|
||||||
|
} else {
|
||||||
|
items = append(items, othersMenuItem)
|
||||||
|
}
|
||||||
|
|
||||||
|
return items
|
||||||
|
}
|
||||||
|
|
||||||
|
func integrationMenuItem(state launch.LauncherIntegrationState) menuItem {
|
||||||
|
description := state.Description
|
||||||
|
if description == "" {
|
||||||
|
description = "Open " + state.DisplayName + " integration"
|
||||||
|
}
|
||||||
|
return menuItem{
|
||||||
|
title: "Launch " + state.DisplayName,
|
||||||
|
description: description,
|
||||||
|
integration: state.Name,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func otherIntegrationItems(state *launch.LauncherState) []menuItem {
|
||||||
|
pinned := map[string]bool{
|
||||||
|
"claude": true,
|
||||||
|
"codex": true,
|
||||||
|
"openclaw": true,
|
||||||
|
}
|
||||||
|
|
||||||
|
var items []menuItem
|
||||||
|
for _, info := range launch.ListIntegrationInfos() {
|
||||||
if pinned[info.Name] {
|
if pinned[info.Name] {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
desc := info.Description
|
integrationState, ok := state.Integrations[info.Name]
|
||||||
if desc == "" {
|
if !ok {
|
||||||
desc = "Open " + info.DisplayName + " integration"
|
|
||||||
}
|
|
||||||
others = append(others, menuItem{
|
|
||||||
title: "Launch " + info.DisplayName,
|
|
||||||
description: desc,
|
|
||||||
integration: info.Name,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
return others
|
|
||||||
}
|
|
||||||
|
|
||||||
type model struct {
|
|
||||||
items []menuItem
|
|
||||||
cursor int
|
|
||||||
quitting bool
|
|
||||||
selected bool
|
|
||||||
changeModel bool
|
|
||||||
changeModels []string // multi-select result for Editor integrations
|
|
||||||
showOthers bool
|
|
||||||
availableModels map[string]bool
|
|
||||||
err error
|
|
||||||
|
|
||||||
showingModal bool
|
|
||||||
modalSelector selectorModel
|
|
||||||
modalItems []SelectItem
|
|
||||||
|
|
||||||
showingMultiModal bool
|
|
||||||
multiModalSelector multiSelectorModel
|
|
||||||
|
|
||||||
showingSignIn bool
|
|
||||||
signInURL string
|
|
||||||
signInModel string
|
|
||||||
signInSpinner int
|
|
||||||
signInFromModal bool // true if sign-in was triggered from modal (not main menu)
|
|
||||||
|
|
||||||
width int // terminal width from WindowSizeMsg
|
|
||||||
statusMsg string // temporary status message shown near help text
|
|
||||||
}
|
|
||||||
|
|
||||||
type signInTickMsg struct{}
|
|
||||||
|
|
||||||
type signInCheckMsg struct {
|
|
||||||
signedIn bool
|
|
||||||
userName string
|
|
||||||
}
|
|
||||||
|
|
||||||
type clearStatusMsg struct{}
|
|
||||||
|
|
||||||
func (m *model) modelExists(name string) bool {
|
|
||||||
if m.availableModels == nil || name == "" {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
if m.availableModels[name] {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
// Check for prefix match (e.g., "llama2" matches "llama2:latest")
|
|
||||||
for modelName := range m.availableModels {
|
|
||||||
if strings.HasPrefix(modelName, name+":") {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *model) buildModalItems() []SelectItem {
|
|
||||||
modelItems, _ := config.GetModelItems(context.Background())
|
|
||||||
return ReorderItems(ConvertItems(modelItems))
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *model) openModelModal(currentModel string) {
|
|
||||||
m.modalItems = m.buildModalItems()
|
|
||||||
cursor := 0
|
|
||||||
if currentModel != "" {
|
|
||||||
for i, item := range m.modalItems {
|
|
||||||
if item.Name == currentModel || strings.HasPrefix(item.Name, currentModel+":") || strings.HasPrefix(currentModel, item.Name+":") {
|
|
||||||
cursor = i
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
m.modalSelector = selectorModel{
|
|
||||||
title: "Select model:",
|
|
||||||
items: m.modalItems,
|
|
||||||
cursor: cursor,
|
|
||||||
helpText: "↑/↓ navigate • enter select • ← back",
|
|
||||||
}
|
|
||||||
m.modalSelector.updateScroll(m.modalSelector.otherStart())
|
|
||||||
m.showingModal = true
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *model) openMultiModelModal(integration string) {
|
|
||||||
items := m.buildModalItems()
|
|
||||||
var preChecked []string
|
|
||||||
if models := config.IntegrationModels(integration); len(models) > 0 {
|
|
||||||
preChecked = models
|
|
||||||
}
|
|
||||||
m.multiModalSelector = newMultiSelectorModel("Select models:", items, preChecked)
|
|
||||||
// Set cursor to the first pre-checked (last used) model
|
|
||||||
if len(preChecked) > 0 {
|
|
||||||
for i, item := range items {
|
|
||||||
if item.Name == preChecked[0] {
|
|
||||||
m.multiModalSelector.cursor = i
|
|
||||||
m.multiModalSelector.updateScroll(m.multiModalSelector.otherStart())
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
m.showingMultiModal = true
|
|
||||||
}
|
|
||||||
|
|
||||||
func isCloudModel(name string) bool {
|
|
||||||
return strings.HasSuffix(name, ":cloud") || strings.HasSuffix(name, "-cloud")
|
|
||||||
}
|
|
||||||
|
|
||||||
func cloudStatusDisabled(client *api.Client) bool {
|
|
||||||
status, err := client.CloudStatusExperimental(context.Background())
|
|
||||||
if err != nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
return status.Cloud.Disabled
|
|
||||||
}
|
|
||||||
|
|
||||||
func cloudModelDisabled(name string) bool {
|
|
||||||
if !isCloudModel(name) {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
client, err := api.ClientFromEnvironment()
|
|
||||||
if err != nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
return cloudStatusDisabled(client)
|
|
||||||
}
|
|
||||||
|
|
||||||
// checkCloudSignIn checks if a cloud model needs sign-in.
|
|
||||||
// Returns a command to start sign-in if needed, or nil if already signed in.
|
|
||||||
func (m *model) checkCloudSignIn(modelName string, fromModal bool) tea.Cmd {
|
|
||||||
if modelName == "" || !isCloudModel(modelName) {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
client, err := api.ClientFromEnvironment()
|
|
||||||
if err != nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
if cloudStatusDisabled(client) {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
user, err := client.Whoami(context.Background())
|
|
||||||
if err == nil && user != nil && user.Name != "" {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
var aErr api.AuthorizationError
|
|
||||||
if errors.As(err, &aErr) && aErr.SigninURL != "" {
|
|
||||||
return m.startSignIn(modelName, aErr.SigninURL, fromModal)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// startSignIn initiates the sign-in flow for a cloud model.
|
|
||||||
// fromModal indicates if this was triggered from the model picker modal.
|
|
||||||
func (m *model) startSignIn(modelName, signInURL string, fromModal bool) tea.Cmd {
|
|
||||||
m.showingModal = false
|
|
||||||
m.showingSignIn = true
|
|
||||||
m.signInURL = signInURL
|
|
||||||
m.signInModel = modelName
|
|
||||||
m.signInSpinner = 0
|
|
||||||
m.signInFromModal = fromModal
|
|
||||||
|
|
||||||
config.OpenBrowser(signInURL)
|
|
||||||
|
|
||||||
return tea.Tick(200*time.Millisecond, func(t time.Time) tea.Msg {
|
|
||||||
return signInTickMsg{}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func checkSignIn() tea.Msg {
|
|
||||||
client, err := api.ClientFromEnvironment()
|
|
||||||
if err != nil {
|
|
||||||
return signInCheckMsg{signedIn: false}
|
|
||||||
}
|
|
||||||
user, err := client.Whoami(context.Background())
|
|
||||||
if err == nil && user != nil && user.Name != "" {
|
|
||||||
return signInCheckMsg{signedIn: true, userName: user.Name}
|
|
||||||
}
|
|
||||||
return signInCheckMsg{signedIn: false}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *model) loadAvailableModels() {
|
|
||||||
m.availableModels = make(map[string]bool)
|
|
||||||
client, err := api.ClientFromEnvironment()
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
models, err := client.List(context.Background())
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
cloudDisabled := cloudStatusDisabled(client)
|
|
||||||
for _, mdl := range models.Models {
|
|
||||||
if cloudDisabled && mdl.RemoteModel != "" {
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
m.availableModels[mdl.Name] = true
|
items = append(items, integrationMenuItem(integrationState))
|
||||||
}
|
}
|
||||||
|
return items
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *model) buildItems() {
|
func initialCursor(state *launch.LauncherState, items []menuItem) int {
|
||||||
others := getOtherIntegrations()
|
if state == nil || state.LastSelection == "" {
|
||||||
m.items = make([]menuItem, 0, len(mainMenuItems)+1+len(others))
|
return 0
|
||||||
m.items = append(m.items, mainMenuItems...)
|
|
||||||
|
|
||||||
if m.showOthers {
|
|
||||||
m.items = append(m.items, others...)
|
|
||||||
} else {
|
|
||||||
m.items = append(m.items, othersMenuItem)
|
|
||||||
}
|
}
|
||||||
}
|
for i, item := range items {
|
||||||
|
if state.LastSelection == "run" && item.isRunModel {
|
||||||
func isOthersIntegration(name string) bool {
|
return i
|
||||||
for _, item := range getOtherIntegrations() {
|
}
|
||||||
if item.integration == name {
|
if item.integration == state.LastSelection {
|
||||||
return true
|
return i
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return false
|
return 0
|
||||||
}
|
|
||||||
|
|
||||||
func initialModel() model {
|
|
||||||
m := model{
|
|
||||||
cursor: 0,
|
|
||||||
}
|
|
||||||
m.loadAvailableModels()
|
|
||||||
|
|
||||||
lastSelection := config.LastSelection()
|
|
||||||
if isOthersIntegration(lastSelection) {
|
|
||||||
m.showOthers = true
|
|
||||||
}
|
|
||||||
|
|
||||||
m.buildItems()
|
|
||||||
|
|
||||||
if lastSelection != "" {
|
|
||||||
for i, item := range m.items {
|
|
||||||
if lastSelection == "run" && item.isRunModel {
|
|
||||||
m.cursor = i
|
|
||||||
break
|
|
||||||
} else if item.integration == lastSelection {
|
|
||||||
m.cursor = i
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return m
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m model) Init() tea.Cmd {
|
func (m model) Init() tea.Cmd {
|
||||||
@@ -357,143 +175,11 @@ func (m model) Init() tea.Cmd {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||||
if wmsg, ok := msg.(tea.WindowSizeMsg); ok {
|
|
||||||
wasSet := m.width > 0
|
|
||||||
m.width = wmsg.Width
|
|
||||||
if wasSet {
|
|
||||||
return m, tea.EnterAltScreen
|
|
||||||
}
|
|
||||||
return m, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, ok := msg.(clearStatusMsg); ok {
|
|
||||||
m.statusMsg = ""
|
|
||||||
return m, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if m.showingSignIn {
|
|
||||||
switch msg := msg.(type) {
|
|
||||||
case tea.KeyMsg:
|
|
||||||
switch msg.Type {
|
|
||||||
case tea.KeyCtrlC, tea.KeyEsc:
|
|
||||||
m.showingSignIn = false
|
|
||||||
if m.signInFromModal {
|
|
||||||
m.showingModal = true
|
|
||||||
}
|
|
||||||
return m, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
case signInTickMsg:
|
|
||||||
m.signInSpinner++
|
|
||||||
// Check sign-in status every 5th tick (~1 second)
|
|
||||||
if m.signInSpinner%5 == 0 {
|
|
||||||
return m, tea.Batch(
|
|
||||||
tea.Tick(200*time.Millisecond, func(t time.Time) tea.Msg {
|
|
||||||
return signInTickMsg{}
|
|
||||||
}),
|
|
||||||
checkSignIn,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
return m, tea.Tick(200*time.Millisecond, func(t time.Time) tea.Msg {
|
|
||||||
return signInTickMsg{}
|
|
||||||
})
|
|
||||||
|
|
||||||
case signInCheckMsg:
|
|
||||||
if msg.signedIn {
|
|
||||||
if m.signInFromModal {
|
|
||||||
m.modalSelector.selected = m.signInModel
|
|
||||||
m.changeModel = true
|
|
||||||
} else {
|
|
||||||
m.selected = true
|
|
||||||
}
|
|
||||||
m.quitting = true
|
|
||||||
return m, tea.Quit
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return m, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if m.showingMultiModal {
|
|
||||||
switch msg := msg.(type) {
|
|
||||||
case tea.KeyMsg:
|
|
||||||
if msg.Type == tea.KeyLeft {
|
|
||||||
m.showingMultiModal = false
|
|
||||||
return m, nil
|
|
||||||
}
|
|
||||||
updated, cmd := m.multiModalSelector.Update(msg)
|
|
||||||
m.multiModalSelector = updated.(multiSelectorModel)
|
|
||||||
|
|
||||||
if m.multiModalSelector.cancelled {
|
|
||||||
m.showingMultiModal = false
|
|
||||||
return m, nil
|
|
||||||
}
|
|
||||||
if m.multiModalSelector.confirmed {
|
|
||||||
var selected []string
|
|
||||||
if m.multiModalSelector.singleAdd != "" {
|
|
||||||
// Single-add mode: prepend picked model, keep existing deduped
|
|
||||||
selected = []string{m.multiModalSelector.singleAdd}
|
|
||||||
for _, name := range config.IntegrationModels(m.items[m.cursor].integration) {
|
|
||||||
if name != m.multiModalSelector.singleAdd {
|
|
||||||
selected = append(selected, name)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// Last checked is default (first in result)
|
|
||||||
co := m.multiModalSelector.checkOrder
|
|
||||||
last := co[len(co)-1]
|
|
||||||
selected = []string{m.multiModalSelector.items[last].Name}
|
|
||||||
for _, idx := range co {
|
|
||||||
if idx != last {
|
|
||||||
selected = append(selected, m.multiModalSelector.items[idx].Name)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if len(selected) > 0 {
|
|
||||||
m.changeModels = selected
|
|
||||||
m.changeModel = true
|
|
||||||
m.quitting = true
|
|
||||||
return m, tea.Quit
|
|
||||||
}
|
|
||||||
m.multiModalSelector.confirmed = false
|
|
||||||
return m, nil
|
|
||||||
}
|
|
||||||
return m, cmd
|
|
||||||
}
|
|
||||||
return m, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if m.showingModal {
|
|
||||||
switch msg := msg.(type) {
|
|
||||||
case tea.KeyMsg:
|
|
||||||
switch msg.Type {
|
|
||||||
case tea.KeyCtrlC, tea.KeyEsc, tea.KeyLeft:
|
|
||||||
m.showingModal = false
|
|
||||||
return m, nil
|
|
||||||
|
|
||||||
case tea.KeyEnter:
|
|
||||||
filtered := m.modalSelector.filteredItems()
|
|
||||||
if len(filtered) > 0 && m.modalSelector.cursor < len(filtered) {
|
|
||||||
m.modalSelector.selected = filtered[m.modalSelector.cursor].Name
|
|
||||||
}
|
|
||||||
if m.modalSelector.selected != "" {
|
|
||||||
if cmd := m.checkCloudSignIn(m.modalSelector.selected, true); cmd != nil {
|
|
||||||
return m, cmd
|
|
||||||
}
|
|
||||||
m.changeModel = true
|
|
||||||
m.quitting = true
|
|
||||||
return m, tea.Quit
|
|
||||||
}
|
|
||||||
return m, nil
|
|
||||||
|
|
||||||
default:
|
|
||||||
// Delegate navigation (up/down/pgup/pgdown/filter/backspace) to selectorModel
|
|
||||||
m.modalSelector.updateNavigation(msg)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return m, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
switch msg := msg.(type) {
|
switch msg := msg.(type) {
|
||||||
|
case tea.WindowSizeMsg:
|
||||||
|
m.width = msg.Width
|
||||||
|
return m, nil
|
||||||
|
|
||||||
case tea.KeyMsg:
|
case tea.KeyMsg:
|
||||||
switch msg.String() {
|
switch msg.String() {
|
||||||
case "ctrl+c", "q", "esc":
|
case "ctrl+c", "q", "esc":
|
||||||
@@ -504,162 +190,78 @@ func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
|||||||
if m.cursor > 0 {
|
if m.cursor > 0 {
|
||||||
m.cursor--
|
m.cursor--
|
||||||
}
|
}
|
||||||
// Auto-collapse "Others" when cursor moves back into pinned items
|
|
||||||
if m.showOthers && m.cursor < len(mainMenuItems) {
|
if m.showOthers && m.cursor < len(mainMenuItems) {
|
||||||
m.showOthers = false
|
m.showOthers = false
|
||||||
m.buildItems()
|
m.items = buildMenuItems(m.state, false)
|
||||||
|
m.cursor = min(m.cursor, len(m.items)-1)
|
||||||
}
|
}
|
||||||
|
return m, nil
|
||||||
|
|
||||||
case "down", "j":
|
case "down", "j":
|
||||||
if m.cursor < len(m.items)-1 {
|
if m.cursor < len(m.items)-1 {
|
||||||
m.cursor++
|
m.cursor++
|
||||||
}
|
}
|
||||||
// Auto-expand "Others..." when cursor lands on it
|
|
||||||
if m.cursor < len(m.items) && m.items[m.cursor].isOthers && !m.showOthers {
|
if m.cursor < len(m.items) && m.items[m.cursor].isOthers && !m.showOthers {
|
||||||
m.showOthers = true
|
m.showOthers = true
|
||||||
m.buildItems()
|
m.items = buildMenuItems(m.state, true)
|
||||||
// cursor now points at the first "other" integration
|
|
||||||
}
|
}
|
||||||
|
return m, nil
|
||||||
|
|
||||||
case "enter", " ":
|
case "enter", " ":
|
||||||
item := m.items[m.cursor]
|
if m.selectableItem(m.items[m.cursor]) {
|
||||||
|
m.selected = true
|
||||||
if item.integration != "" && !config.IsIntegrationInstalled(item.integration) && !config.AutoInstallable(item.integration) {
|
m.action = actionForMenuItem(m.items[m.cursor], false)
|
||||||
return m, nil
|
m.quitting = true
|
||||||
|
return m, tea.Quit
|
||||||
}
|
}
|
||||||
|
return m, nil
|
||||||
var configuredModel string
|
|
||||||
if item.isRunModel {
|
|
||||||
configuredModel = config.LastModel()
|
|
||||||
} else if item.integration != "" {
|
|
||||||
configuredModel = config.IntegrationModel(item.integration)
|
|
||||||
}
|
|
||||||
if cmd := m.checkCloudSignIn(configuredModel, false); cmd != nil {
|
|
||||||
return m, cmd
|
|
||||||
}
|
|
||||||
|
|
||||||
if configuredModel != "" && isCloudModel(configuredModel) && cloudModelDisabled(configuredModel) {
|
|
||||||
if item.integration != "" && config.IsEditorIntegration(item.integration) {
|
|
||||||
m.openMultiModelModal(item.integration)
|
|
||||||
} else {
|
|
||||||
m.openModelModal(configuredModel)
|
|
||||||
}
|
|
||||||
return m, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
m.selected = true
|
|
||||||
m.quitting = true
|
|
||||||
return m, tea.Quit
|
|
||||||
|
|
||||||
case "right", "l":
|
case "right", "l":
|
||||||
item := m.items[m.cursor]
|
item := m.items[m.cursor]
|
||||||
if item.integration != "" || item.isRunModel {
|
if item.isRunModel || m.changeableItem(item) {
|
||||||
if item.integration != "" && !config.IsIntegrationInstalled(item.integration) {
|
m.selected = true
|
||||||
if config.AutoInstallable(item.integration) {
|
m.action = actionForMenuItem(item, true)
|
||||||
// Auto-installable: select to trigger install flow
|
m.quitting = true
|
||||||
m.selected = true
|
return m, tea.Quit
|
||||||
m.quitting = true
|
|
||||||
return m, tea.Quit
|
|
||||||
}
|
|
||||||
return m, nil
|
|
||||||
}
|
|
||||||
if item.integration != "" && config.IsEditorIntegration(item.integration) {
|
|
||||||
m.openMultiModelModal(item.integration)
|
|
||||||
} else {
|
|
||||||
var currentModel string
|
|
||||||
if item.isRunModel {
|
|
||||||
currentModel = config.LastModel()
|
|
||||||
} else if item.integration != "" {
|
|
||||||
currentModel = config.IntegrationModel(item.integration)
|
|
||||||
}
|
|
||||||
m.openModelModal(currentModel)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
return m, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return m, nil
|
return m, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m model) selectableItem(item menuItem) bool {
|
||||||
|
if item.isRunModel {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if item.integration == "" || item.isOthers {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
state, ok := m.state.Integrations[item.integration]
|
||||||
|
return ok && state.Selectable
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m model) changeableItem(item menuItem) bool {
|
||||||
|
if item.integration == "" || item.isOthers {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
state, ok := m.state.Integrations[item.integration]
|
||||||
|
return ok && state.Changeable
|
||||||
|
}
|
||||||
|
|
||||||
func (m model) View() string {
|
func (m model) View() string {
|
||||||
if m.quitting {
|
if m.quitting {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
if m.showingSignIn {
|
|
||||||
return m.renderSignInDialog()
|
|
||||||
}
|
|
||||||
|
|
||||||
if m.showingMultiModal {
|
|
||||||
return m.multiModalSelector.View()
|
|
||||||
}
|
|
||||||
|
|
||||||
if m.showingModal {
|
|
||||||
return m.renderModal()
|
|
||||||
}
|
|
||||||
|
|
||||||
s := selectorTitleStyle.Render("Ollama "+versionStyle.Render(version.Version)) + "\n\n"
|
s := selectorTitleStyle.Render("Ollama "+versionStyle.Render(version.Version)) + "\n\n"
|
||||||
|
|
||||||
for i, item := range m.items {
|
for i, item := range m.items {
|
||||||
cursor := ""
|
s += m.renderMenuItem(i, item)
|
||||||
style := menuItemStyle
|
|
||||||
isInstalled := true
|
|
||||||
|
|
||||||
if item.integration != "" {
|
|
||||||
isInstalled = config.IsIntegrationInstalled(item.integration)
|
|
||||||
}
|
|
||||||
|
|
||||||
if m.cursor == i {
|
|
||||||
cursor = "▸ "
|
|
||||||
if isInstalled {
|
|
||||||
style = menuSelectedItemStyle
|
|
||||||
} else {
|
|
||||||
style = greyedSelectedStyle
|
|
||||||
}
|
|
||||||
} else if !isInstalled && item.integration != "" {
|
|
||||||
style = greyedStyle
|
|
||||||
}
|
|
||||||
|
|
||||||
title := item.title
|
|
||||||
var modelSuffix string
|
|
||||||
if item.integration != "" {
|
|
||||||
if !isInstalled {
|
|
||||||
if config.AutoInstallable(item.integration) {
|
|
||||||
title += " " + notInstalledStyle.Render("(install)")
|
|
||||||
} else {
|
|
||||||
title += " " + notInstalledStyle.Render("(not installed)")
|
|
||||||
}
|
|
||||||
} else if m.cursor == i {
|
|
||||||
if mdl := config.IntegrationModel(item.integration); mdl != "" && m.modelExists(mdl) {
|
|
||||||
modelSuffix = " " + modelStyle.Render("("+mdl+")")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else if item.isRunModel && m.cursor == i {
|
|
||||||
if mdl := config.LastModel(); mdl != "" && m.modelExists(mdl) {
|
|
||||||
modelSuffix = " " + modelStyle.Render("("+mdl+")")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
s += style.Render(cursor+title) + modelSuffix + "\n"
|
|
||||||
|
|
||||||
desc := item.description
|
|
||||||
if !isInstalled && item.integration != "" && m.cursor == i {
|
|
||||||
if config.AutoInstallable(item.integration) {
|
|
||||||
desc = "Press enter to install"
|
|
||||||
} else if hint := config.IntegrationInstallHint(item.integration); hint != "" {
|
|
||||||
desc = hint
|
|
||||||
} else {
|
|
||||||
desc = "not installed"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
s += menuDescStyle.Render(desc) + "\n\n"
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if m.statusMsg != "" {
|
s += "\n" + selectorHelpStyle.Render("↑/↓ navigate • enter launch • → configure • esc quit")
|
||||||
s += "\n" + lipgloss.NewStyle().Foreground(lipgloss.AdaptiveColor{Light: "124", Dark: "210"}).Render(m.statusMsg) + "\n"
|
|
||||||
}
|
|
||||||
|
|
||||||
s += "\n" + selectorHelpStyle.Render("↑/↓ navigate • enter launch • → change model • esc quit")
|
|
||||||
|
|
||||||
if m.width > 0 {
|
if m.width > 0 {
|
||||||
return lipgloss.NewStyle().MaxWidth(m.width).Render(s)
|
return lipgloss.NewStyle().MaxWidth(m.width).Render(s)
|
||||||
@@ -667,80 +269,125 @@ func (m model) View() string {
|
|||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m model) renderModal() string {
|
func (m model) renderMenuItem(index int, item menuItem) string {
|
||||||
modalStyle := lipgloss.NewStyle().
|
cursor := ""
|
||||||
PaddingBottom(1).
|
style := menuItemStyle
|
||||||
PaddingRight(2)
|
title := item.title
|
||||||
|
description := item.description
|
||||||
|
modelSuffix := ""
|
||||||
|
|
||||||
s := modalStyle.Render(m.modalSelector.renderContent())
|
if m.cursor == index {
|
||||||
if m.width > 0 {
|
cursor = "▸ "
|
||||||
return lipgloss.NewStyle().MaxWidth(m.width).Render(s)
|
|
||||||
}
|
|
||||||
return s
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m model) renderSignInDialog() string {
|
|
||||||
return renderSignIn(m.signInModel, m.signInURL, m.signInSpinner, m.width)
|
|
||||||
}
|
|
||||||
|
|
||||||
type Selection int
|
|
||||||
|
|
||||||
const (
|
|
||||||
SelectionNone Selection = iota
|
|
||||||
SelectionRunModel
|
|
||||||
SelectionChangeRunModel
|
|
||||||
SelectionIntegration // Generic integration selection
|
|
||||||
SelectionChangeIntegration // Generic change model for integration
|
|
||||||
)
|
|
||||||
|
|
||||||
type Result struct {
|
|
||||||
Selection Selection
|
|
||||||
Integration string // integration name if applicable
|
|
||||||
Model string // model name if selected from single-select modal
|
|
||||||
Models []string // models selected from multi-select modal (Editor integrations)
|
|
||||||
}
|
|
||||||
|
|
||||||
func Run() (Result, error) {
|
|
||||||
m := initialModel()
|
|
||||||
p := tea.NewProgram(m)
|
|
||||||
|
|
||||||
finalModel, err := p.Run()
|
|
||||||
if err != nil {
|
|
||||||
return Result{Selection: SelectionNone}, fmt.Errorf("error running TUI: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
fm := finalModel.(model)
|
|
||||||
if fm.err != nil {
|
|
||||||
return Result{Selection: SelectionNone}, fm.err
|
|
||||||
}
|
|
||||||
|
|
||||||
if !fm.selected && !fm.changeModel {
|
|
||||||
return Result{Selection: SelectionNone}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
item := fm.items[fm.cursor]
|
|
||||||
|
|
||||||
if fm.changeModel {
|
|
||||||
if item.isRunModel {
|
|
||||||
return Result{
|
|
||||||
Selection: SelectionChangeRunModel,
|
|
||||||
Model: fm.modalSelector.selected,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
return Result{
|
|
||||||
Selection: SelectionChangeIntegration,
|
|
||||||
Integration: item.integration,
|
|
||||||
Model: fm.modalSelector.selected,
|
|
||||||
Models: fm.changeModels,
|
|
||||||
}, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if item.isRunModel {
|
if item.isRunModel {
|
||||||
return Result{Selection: SelectionRunModel}, nil
|
if m.cursor == index && m.state.RunModel != "" {
|
||||||
|
modelSuffix = " " + modelStyle.Render("("+m.state.RunModel+")")
|
||||||
|
}
|
||||||
|
if m.cursor == index {
|
||||||
|
style = menuSelectedItemStyle
|
||||||
|
}
|
||||||
|
} else if item.isOthers {
|
||||||
|
if m.cursor == index {
|
||||||
|
style = menuSelectedItemStyle
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
integrationState := m.state.Integrations[item.integration]
|
||||||
|
if !integrationState.Selectable {
|
||||||
|
if m.cursor == index {
|
||||||
|
style = greyedSelectedStyle
|
||||||
|
} else {
|
||||||
|
style = greyedStyle
|
||||||
|
}
|
||||||
|
} else if m.cursor == index {
|
||||||
|
style = menuSelectedItemStyle
|
||||||
|
}
|
||||||
|
|
||||||
|
if m.cursor == index && integrationState.CurrentModel != "" {
|
||||||
|
modelSuffix = " " + modelStyle.Render("("+integrationState.CurrentModel+")")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !integrationState.Installed {
|
||||||
|
if integrationState.AutoInstallable {
|
||||||
|
title += " " + notInstalledStyle.Render("(install)")
|
||||||
|
} else {
|
||||||
|
title += " " + notInstalledStyle.Render("(not installed)")
|
||||||
|
}
|
||||||
|
if m.cursor == index {
|
||||||
|
if integrationState.AutoInstallable {
|
||||||
|
description = "Press enter to install"
|
||||||
|
} else if integrationState.InstallHint != "" {
|
||||||
|
description = integrationState.InstallHint
|
||||||
|
} else {
|
||||||
|
description = "not installed"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return Result{
|
return style.Render(cursor+title) + modelSuffix + "\n" + menuDescStyle.Render(description) + "\n\n"
|
||||||
Selection: SelectionIntegration,
|
}
|
||||||
Integration: item.integration,
|
|
||||||
}, nil
|
type TUIActionKind int
|
||||||
|
|
||||||
|
const (
|
||||||
|
TUIActionNone TUIActionKind = iota
|
||||||
|
TUIActionRunModel
|
||||||
|
TUIActionLaunchIntegration
|
||||||
|
)
|
||||||
|
|
||||||
|
type TUIAction struct {
|
||||||
|
Kind TUIActionKind
|
||||||
|
Integration string
|
||||||
|
ForceConfigure bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a TUIAction) LastSelection() string {
|
||||||
|
switch a.Kind {
|
||||||
|
case TUIActionRunModel:
|
||||||
|
return "run"
|
||||||
|
case TUIActionLaunchIntegration:
|
||||||
|
return a.Integration
|
||||||
|
default:
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a TUIAction) RunModelRequest() launch.RunModelRequest {
|
||||||
|
return launch.RunModelRequest{ForcePicker: a.ForceConfigure}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a TUIAction) IntegrationLaunchRequest() launch.IntegrationLaunchRequest {
|
||||||
|
return launch.IntegrationLaunchRequest{
|
||||||
|
Name: a.Integration,
|
||||||
|
ForceConfigure: a.ForceConfigure,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func actionForMenuItem(item menuItem, forceConfigure bool) TUIAction {
|
||||||
|
switch {
|
||||||
|
case item.isRunModel:
|
||||||
|
return TUIAction{Kind: TUIActionRunModel, ForceConfigure: forceConfigure}
|
||||||
|
case item.integration != "":
|
||||||
|
return TUIAction{Kind: TUIActionLaunchIntegration, Integration: item.integration, ForceConfigure: forceConfigure}
|
||||||
|
default:
|
||||||
|
return TUIAction{Kind: TUIActionNone}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func RunMenu(state *launch.LauncherState) (TUIAction, error) {
|
||||||
|
menu := newModel(state)
|
||||||
|
program := tea.NewProgram(menu)
|
||||||
|
|
||||||
|
finalModel, err := program.Run()
|
||||||
|
if err != nil {
|
||||||
|
return TUIAction{Kind: TUIActionNone}, fmt.Errorf("error running TUI: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
finalMenu := finalModel.(model)
|
||||||
|
if !finalMenu.selected {
|
||||||
|
return TUIAction{Kind: TUIActionNone}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return finalMenu.action, nil
|
||||||
}
|
}
|
||||||
|
|||||||
178
cmd/tui/tui_test.go
Normal file
@@ -0,0 +1,178 @@
|
|||||||
|
package tui
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
tea "github.com/charmbracelet/bubbletea"
|
||||||
|
"github.com/ollama/ollama/cmd/launch"
|
||||||
|
)
|
||||||
|
|
||||||
|
func launcherTestState() *launch.LauncherState {
|
||||||
|
return &launch.LauncherState{
|
||||||
|
LastSelection: "run",
|
||||||
|
RunModel: "qwen3:8b",
|
||||||
|
Integrations: map[string]launch.LauncherIntegrationState{
|
||||||
|
"claude": {
|
||||||
|
Name: "claude",
|
||||||
|
DisplayName: "Claude Code",
|
||||||
|
Description: "Anthropic's coding tool with subagents",
|
||||||
|
Selectable: true,
|
||||||
|
Changeable: true,
|
||||||
|
CurrentModel: "glm-5:cloud",
|
||||||
|
},
|
||||||
|
"codex": {
|
||||||
|
Name: "codex",
|
||||||
|
DisplayName: "Codex",
|
||||||
|
Description: "OpenAI's open-source coding agent",
|
||||||
|
Selectable: true,
|
||||||
|
Changeable: true,
|
||||||
|
},
|
||||||
|
"openclaw": {
|
||||||
|
Name: "openclaw",
|
||||||
|
DisplayName: "OpenClaw",
|
||||||
|
Description: "Personal AI with 100+ skills",
|
||||||
|
Selectable: true,
|
||||||
|
Changeable: true,
|
||||||
|
AutoInstallable: true,
|
||||||
|
},
|
||||||
|
"droid": {
|
||||||
|
Name: "droid",
|
||||||
|
DisplayName: "Droid",
|
||||||
|
Description: "Factory's coding agent across terminal and IDEs",
|
||||||
|
Selectable: true,
|
||||||
|
Changeable: true,
|
||||||
|
},
|
||||||
|
"pi": {
|
||||||
|
Name: "pi",
|
||||||
|
DisplayName: "Pi",
|
||||||
|
Description: "Minimal AI agent toolkit with plugin support",
|
||||||
|
Selectable: true,
|
||||||
|
Changeable: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMenuRendersPinnedItemsAndMore(t *testing.T) {
|
||||||
|
view := newModel(launcherTestState()).View()
|
||||||
|
for _, want := range []string{"Chat with a model", "Launch Claude Code", "Launch Codex", "Launch OpenClaw", "More..."} {
|
||||||
|
if !strings.Contains(view, want) {
|
||||||
|
t.Fatalf("expected menu view to contain %q\n%s", want, view)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMenuExpandsOthersFromLastSelection(t *testing.T) {
|
||||||
|
state := launcherTestState()
|
||||||
|
state.LastSelection = "pi"
|
||||||
|
|
||||||
|
menu := newModel(state)
|
||||||
|
if !menu.showOthers {
|
||||||
|
t.Fatal("expected others section to expand when last selection is in the overflow list")
|
||||||
|
}
|
||||||
|
view := menu.View()
|
||||||
|
if !strings.Contains(view, "Launch Pi") {
|
||||||
|
t.Fatalf("expected expanded view to contain overflow integration\n%s", view)
|
||||||
|
}
|
||||||
|
if strings.Contains(view, "More...") {
|
||||||
|
t.Fatalf("expected expanded view to replace More... item\n%s", view)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMenuEnterOnRunSelectsRun(t *testing.T) {
|
||||||
|
menu := newModel(launcherTestState())
|
||||||
|
updated, _ := menu.Update(tea.KeyMsg{Type: tea.KeyEnter})
|
||||||
|
got := updated.(model)
|
||||||
|
want := TUIAction{Kind: TUIActionRunModel}
|
||||||
|
if !got.selected || got.action != want {
|
||||||
|
t.Fatalf("expected enter on run to select run action, got selected=%v action=%v", got.selected, got.action)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMenuRightOnRunSelectsChangeRun(t *testing.T) {
|
||||||
|
menu := newModel(launcherTestState())
|
||||||
|
updated, _ := menu.Update(tea.KeyMsg{Type: tea.KeyRight})
|
||||||
|
got := updated.(model)
|
||||||
|
want := TUIAction{Kind: TUIActionRunModel, ForceConfigure: true}
|
||||||
|
if !got.selected || got.action != want {
|
||||||
|
t.Fatalf("expected right on run to select change-run action, got selected=%v action=%v", got.selected, got.action)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMenuEnterOnIntegrationSelectsLaunch(t *testing.T) {
|
||||||
|
menu := newModel(launcherTestState())
|
||||||
|
menu.cursor = 1
|
||||||
|
updated, _ := menu.Update(tea.KeyMsg{Type: tea.KeyEnter})
|
||||||
|
got := updated.(model)
|
||||||
|
want := TUIAction{Kind: TUIActionLaunchIntegration, Integration: "claude"}
|
||||||
|
if !got.selected || got.action != want {
|
||||||
|
t.Fatalf("expected enter on integration to launch, got selected=%v action=%v", got.selected, got.action)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMenuRightOnIntegrationSelectsConfigure(t *testing.T) {
|
||||||
|
menu := newModel(launcherTestState())
|
||||||
|
menu.cursor = 1
|
||||||
|
updated, _ := menu.Update(tea.KeyMsg{Type: tea.KeyRight})
|
||||||
|
got := updated.(model)
|
||||||
|
want := TUIAction{Kind: TUIActionLaunchIntegration, Integration: "claude", ForceConfigure: true}
|
||||||
|
if !got.selected || got.action != want {
|
||||||
|
t.Fatalf("expected right on integration to configure, got selected=%v action=%v", got.selected, got.action)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMenuIgnoresDisabledActions(t *testing.T) {
|
||||||
|
state := launcherTestState()
|
||||||
|
claude := state.Integrations["claude"]
|
||||||
|
claude.Selectable = false
|
||||||
|
claude.Changeable = false
|
||||||
|
state.Integrations["claude"] = claude
|
||||||
|
|
||||||
|
menu := newModel(state)
|
||||||
|
menu.cursor = 1
|
||||||
|
|
||||||
|
updatedEnter, _ := menu.Update(tea.KeyMsg{Type: tea.KeyEnter})
|
||||||
|
if updatedEnter.(model).selected {
|
||||||
|
t.Fatal("expected non-selectable integration to ignore enter")
|
||||||
|
}
|
||||||
|
|
||||||
|
updatedRight, _ := menu.Update(tea.KeyMsg{Type: tea.KeyRight})
|
||||||
|
if updatedRight.(model).selected {
|
||||||
|
t.Fatal("expected non-changeable integration to ignore right")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMenuShowsCurrentModelSuffixes(t *testing.T) {
|
||||||
|
menu := newModel(launcherTestState())
|
||||||
|
runView := menu.View()
|
||||||
|
if !strings.Contains(runView, "(qwen3:8b)") {
|
||||||
|
t.Fatalf("expected run row to show current model suffix\n%s", runView)
|
||||||
|
}
|
||||||
|
|
||||||
|
menu.cursor = 1
|
||||||
|
integrationView := menu.View()
|
||||||
|
if !strings.Contains(integrationView, "(glm-5:cloud)") {
|
||||||
|
t.Fatalf("expected integration row to show current model suffix\n%s", integrationView)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMenuShowsInstallStatusAndHint(t *testing.T) {
|
||||||
|
state := launcherTestState()
|
||||||
|
codex := state.Integrations["codex"]
|
||||||
|
codex.Installed = false
|
||||||
|
codex.Selectable = false
|
||||||
|
codex.Changeable = false
|
||||||
|
codex.InstallHint = "Install from https://example.com/codex"
|
||||||
|
state.Integrations["codex"] = codex
|
||||||
|
|
||||||
|
menu := newModel(state)
|
||||||
|
menu.cursor = 2
|
||||||
|
view := menu.View()
|
||||||
|
if !strings.Contains(view, "(not installed)") {
|
||||||
|
t.Fatalf("expected not-installed marker\n%s", view)
|
||||||
|
}
|
||||||
|
if !strings.Contains(view, codex.InstallHint) {
|
||||||
|
t.Fatalf("expected install hint in description\n%s", view)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -12,7 +12,6 @@ To use Ollama with tools that expect the Anthropic API (like Claude Code), set t
|
|||||||
|
|
||||||
```shell
|
```shell
|
||||||
export ANTHROPIC_AUTH_TOKEN=ollama # required but ignored
|
export ANTHROPIC_AUTH_TOKEN=ollama # required but ignored
|
||||||
export ANTHROPIC_API_KEY="" # required but ignored
|
|
||||||
export ANTHROPIC_BASE_URL=http://localhost:11434
|
export ANTHROPIC_BASE_URL=http://localhost:11434
|
||||||
```
|
```
|
||||||
|
|
||||||
@@ -269,7 +268,7 @@ ollama launch claude --config
|
|||||||
Set the environment variables and run Claude Code:
|
Set the environment variables and run Claude Code:
|
||||||
|
|
||||||
```shell
|
```shell
|
||||||
ANTHROPIC_AUTH_TOKEN=ollama ANTHROPIC_BASE_URL=http://localhost:11434 ANTHROPIC_API_KEY="" claude --model qwen3-coder
|
ANTHROPIC_AUTH_TOKEN=ollama ANTHROPIC_BASE_URL=http://localhost:11434 claude --model qwen3-coder
|
||||||
```
|
```
|
||||||
|
|
||||||
Or set the environment variables in your shell profile:
|
Or set the environment variables in your shell profile:
|
||||||
@@ -277,7 +276,6 @@ Or set the environment variables in your shell profile:
|
|||||||
```shell
|
```shell
|
||||||
export ANTHROPIC_AUTH_TOKEN=ollama
|
export ANTHROPIC_AUTH_TOKEN=ollama
|
||||||
export ANTHROPIC_BASE_URL=http://localhost:11434
|
export ANTHROPIC_BASE_URL=http://localhost:11434
|
||||||
export ANTHROPIC_API_KEY=""
|
|
||||||
```
|
```
|
||||||
|
|
||||||
Then run Claude Code with any Ollama model:
|
Then run Claude Code with any Ollama model:
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ Ollama provides compatibility with parts of the [OpenAI API](https://platform.op
|
|||||||
|
|
||||||
## Usage
|
## Usage
|
||||||
|
|
||||||
### Simple `v1/chat/completions` example
|
### Simple `/v1/chat/completions` example
|
||||||
|
|
||||||
<CodeGroup dropdown>
|
<CodeGroup dropdown>
|
||||||
|
|
||||||
@@ -57,7 +57,7 @@ curl -X POST http://localhost:11434/v1/chat/completions \
|
|||||||
|
|
||||||
</CodeGroup>
|
</CodeGroup>
|
||||||
|
|
||||||
### Simple `v1/responses` example
|
### Simple `/v1/responses` example
|
||||||
|
|
||||||
<CodeGroup dropdown>
|
<CodeGroup dropdown>
|
||||||
|
|
||||||
@@ -103,7 +103,7 @@ curl -X POST http://localhost:11434/v1/responses \
|
|||||||
|
|
||||||
</CodeGroup>
|
</CodeGroup>
|
||||||
|
|
||||||
### v1/chat/completions with vision example
|
### `/v1/chat/completions` with vision example
|
||||||
|
|
||||||
<CodeGroup dropdown>
|
<CodeGroup dropdown>
|
||||||
|
|
||||||
@@ -184,6 +184,7 @@ curl -X POST http://localhost:11434/v1/chat/completions \
|
|||||||
- [x] Reproducible outputs
|
- [x] Reproducible outputs
|
||||||
- [x] Vision
|
- [x] Vision
|
||||||
- [x] Tools
|
- [x] Tools
|
||||||
|
- [x] Reasoning/thinking control (for thinking models)
|
||||||
- [ ] Logprobs
|
- [ ] Logprobs
|
||||||
|
|
||||||
#### Supported request fields
|
#### Supported request fields
|
||||||
@@ -207,6 +208,9 @@ curl -X POST http://localhost:11434/v1/chat/completions \
|
|||||||
- [x] `top_p`
|
- [x] `top_p`
|
||||||
- [x] `max_tokens`
|
- [x] `max_tokens`
|
||||||
- [x] `tools`
|
- [x] `tools`
|
||||||
|
- [x] `reasoning_effort` (`"high"`, `"medium"`, `"low"`, `"none"`)
|
||||||
|
- [x] `reasoning`
|
||||||
|
- [x] `effort` (`"high"`, `"medium"`, `"low"`, `"none"`)
|
||||||
- [ ] `tool_choice`
|
- [ ] `tool_choice`
|
||||||
- [ ] `logit_bias`
|
- [ ] `logit_bias`
|
||||||
- [ ] `user`
|
- [ ] `user`
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ Configure and launch external applications to use Ollama models. This provides a
|
|||||||
- **OpenCode** - Open-source coding assistant
|
- **OpenCode** - Open-source coding assistant
|
||||||
- **Claude Code** - Anthropic's agentic coding tool
|
- **Claude Code** - Anthropic's agentic coding tool
|
||||||
- **Codex** - OpenAI's coding assistant
|
- **Codex** - OpenAI's coding assistant
|
||||||
|
- **VS Code** - Microsoft's IDE with built-in AI chat
|
||||||
- **Droid** - Factory's AI coding agent
|
- **Droid** - Factory's AI coding agent
|
||||||
|
|
||||||
#### Examples
|
#### Examples
|
||||||
@@ -40,7 +41,7 @@ ollama launch claude
|
|||||||
Launch with a specific model:
|
Launch with a specific model:
|
||||||
|
|
||||||
```
|
```
|
||||||
ollama launch claude --model qwen3-coder
|
ollama launch claude --model qwen3.5
|
||||||
```
|
```
|
||||||
|
|
||||||
Configure without launching:
|
Configure without launching:
|
||||||
|
|||||||
@@ -51,6 +51,9 @@ Install prerequisites:
|
|||||||
- [CUDA SDK](https://developer.nvidia.com/cuda-downloads?target_os=Windows&target_arch=x86_64&target_version=11&target_type=exe_network)
|
- [CUDA SDK](https://developer.nvidia.com/cuda-downloads?target_os=Windows&target_arch=x86_64&target_version=11&target_type=exe_network)
|
||||||
- (Optional) VULKAN GPU support
|
- (Optional) VULKAN GPU support
|
||||||
- [VULKAN SDK](https://vulkan.lunarg.com/sdk/home) - useful for AMD/Intel GPUs
|
- [VULKAN SDK](https://vulkan.lunarg.com/sdk/home) - useful for AMD/Intel GPUs
|
||||||
|
- (Optional) MLX engine support
|
||||||
|
- [CUDA 13+ SDK](https://developer.nvidia.com/cuda-downloads)
|
||||||
|
- [cuDNN 9+](https://developer.nvidia.com/cudnn)
|
||||||
|
|
||||||
Then, configure and build the project:
|
Then, configure and build the project:
|
||||||
|
|
||||||
@@ -101,6 +104,10 @@ Install prerequisites:
|
|||||||
- (Optional) VULKAN GPU support
|
- (Optional) VULKAN GPU support
|
||||||
- [VULKAN SDK](https://vulkan.lunarg.com/sdk/home) - useful for AMD/Intel GPUs
|
- [VULKAN SDK](https://vulkan.lunarg.com/sdk/home) - useful for AMD/Intel GPUs
|
||||||
- Or install via package manager: `sudo apt install vulkan-sdk` (Ubuntu/Debian) or `sudo dnf install vulkan-sdk` (Fedora/CentOS)
|
- Or install via package manager: `sudo apt install vulkan-sdk` (Ubuntu/Debian) or `sudo dnf install vulkan-sdk` (Fedora/CentOS)
|
||||||
|
- (Optional) MLX engine support
|
||||||
|
- [CUDA 13+ SDK](https://developer.nvidia.com/cuda-downloads)
|
||||||
|
- [cuDNN 9+](https://developer.nvidia.com/cudnn)
|
||||||
|
- OpenBLAS/LAPACK: `sudo apt install libopenblas-dev liblapack-dev liblapacke-dev` (Ubuntu/Debian)
|
||||||
> [!IMPORTANT]
|
> [!IMPORTANT]
|
||||||
> Ensure prerequisites are in `PATH` before running CMake.
|
> Ensure prerequisites are in `PATH` before running CMake.
|
||||||
|
|
||||||
@@ -118,6 +125,67 @@ Lastly, run Ollama:
|
|||||||
go run . serve
|
go run . serve
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## MLX Engine (Optional)
|
||||||
|
|
||||||
|
The MLX engine enables running safetensor based models. It requires building the [MLX](https://github.com/ml-explore/mlx) and [MLX-C](https://github.com/ml-explore/mlx-c) shared libraries separately via CMake. On MacOS, MLX leverages the Metal library to run on the GPU, and on Windows and Linux, runs on NVIDIA GPUs via CUDA v13.
|
||||||
|
|
||||||
|
### macOS (Apple Silicon)
|
||||||
|
|
||||||
|
Requires the Metal toolchain. Install [Xcode](https://developer.apple.com/xcode/) first, then:
|
||||||
|
|
||||||
|
```shell
|
||||||
|
xcodebuild -downloadComponent MetalToolchain
|
||||||
|
```
|
||||||
|
|
||||||
|
Verify it's installed correctly (should print "no input files"):
|
||||||
|
|
||||||
|
```shell
|
||||||
|
xcrun metal
|
||||||
|
```
|
||||||
|
|
||||||
|
Then build:
|
||||||
|
|
||||||
|
```shell
|
||||||
|
cmake -B build --preset MLX
|
||||||
|
cmake --build build --preset MLX --parallel
|
||||||
|
cmake --install build --component MLX
|
||||||
|
```
|
||||||
|
|
||||||
|
> [!NOTE]
|
||||||
|
> Without the Metal toolchain, cmake will silently complete with Metal disabled. Check the cmake output for `Setting MLX_BUILD_METAL=OFF` which indicates the toolchain is missing.
|
||||||
|
|
||||||
|
### Windows / Linux (CUDA)
|
||||||
|
|
||||||
|
Requires CUDA 13+ and [cuDNN](https://developer.nvidia.com/cudnn) 9+.
|
||||||
|
|
||||||
|
```shell
|
||||||
|
cmake -B build --preset "MLX CUDA 13"
|
||||||
|
cmake --build build --target mlx --target mlxc --config Release --parallel
|
||||||
|
cmake --install build --component MLX --strip
|
||||||
|
```
|
||||||
|
|
||||||
|
### Local MLX source overrides
|
||||||
|
|
||||||
|
To build against a local checkout of MLX and/or MLX-C (useful for development), set environment variables before running CMake:
|
||||||
|
|
||||||
|
```shell
|
||||||
|
export OLLAMA_MLX_SOURCE=/path/to/mlx
|
||||||
|
export OLLAMA_MLX_C_SOURCE=/path/to/mlx-c
|
||||||
|
```
|
||||||
|
|
||||||
|
For example, using the helper scripts with local mlx and mlx-c repos:
|
||||||
|
```shell
|
||||||
|
OLLAMA_MLX_SOURCE=../mlx OLLAMA_MLX_C_SOURCE=../mlx-c ./scripts/build_linux.sh
|
||||||
|
|
||||||
|
OLLAMA_MLX_SOURCE=../mlx OLLAMA_MLX_C_SOURCE=../mlx-c ./scripts/build_darwin.sh
|
||||||
|
```
|
||||||
|
|
||||||
|
```powershell
|
||||||
|
$env:OLLAMA_MLX_SOURCE="../mlx"
|
||||||
|
$env:OLLAMA_MLX_C_SOURCE="../mlx-c"
|
||||||
|
./scripts/build_darwin.ps1
|
||||||
|
```
|
||||||
|
|
||||||
## Docker
|
## Docker
|
||||||
|
|
||||||
```shell
|
```shell
|
||||||
|
|||||||
@@ -127,6 +127,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"group": "IDEs & Editors",
|
"group": "IDEs & Editors",
|
||||||
|
"expanded": true,
|
||||||
"pages": [
|
"pages": [
|
||||||
"/integrations/cline",
|
"/integrations/cline",
|
||||||
"/integrations/jetbrains",
|
"/integrations/jetbrains",
|
||||||
@@ -160,6 +161,12 @@
|
|||||||
"group": "More information",
|
"group": "More information",
|
||||||
"pages": [
|
"pages": [
|
||||||
"/cli",
|
"/cli",
|
||||||
|
{
|
||||||
|
"group": "Assistant Sandboxing",
|
||||||
|
"pages": [
|
||||||
|
"/integrations/nemoclaw"
|
||||||
|
]
|
||||||
|
},
|
||||||
"/modelfile",
|
"/modelfile",
|
||||||
"/context-length",
|
"/context-length",
|
||||||
"/linux",
|
"/linux",
|
||||||
|
|||||||
33
docs/gpu.mdx
@@ -61,11 +61,17 @@ Ollama supports the following AMD GPUs via the ROCm library:
|
|||||||
|
|
||||||
### Linux Support
|
### Linux Support
|
||||||
|
|
||||||
| Family | Cards and accelerators |
|
Ollama requires the AMD ROCm v7 driver on Linux. You can install or upgrade
|
||||||
| -------------- | ---------------------------------------------------------------------------------------------------------------------------------------------- |
|
using the `amdgpu-install` utility from
|
||||||
| AMD Radeon RX | `7900 XTX` `7900 XT` `7900 GRE` `7800 XT` `7700 XT` `7600 XT` `7600` `6950 XT` `6900 XTX` `6900XT` `6800 XT` `6800` `Vega 64` |
|
[AMD's ROCm documentation](https://rocm.docs.amd.com/projects/install-on-linux/en/latest/).
|
||||||
| AMD Radeon PRO | `W7900` `W7800` `W7700` `W7600` `W7500` `W6900X` `W6800X Duo` `W6800X` `W6800` `V620` `V420` `V340` `V320` `Vega II Duo` `Vega II` `SSG` |
|
|
||||||
| AMD Instinct | `MI300X` `MI300A` `MI300` `MI250X` `MI250` `MI210` `MI200` `MI100` `MI60` |
|
| Family | Cards and accelerators |
|
||||||
|
| -------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||||
|
| AMD Radeon RX | `9070 XT` `9070 GRE` `9070` `9060 XT` `9060 XT LP` `9060` `7900 XTX` `7900 XT` `7900 GRE` `7800 XT` `7700 XT` `7700` `7600 XT` `7600` `6950 XT` `6900 XTX` `6900XT` `6800 XT` `6800` `5700 XT` `5700` `5600 XT` `5500 XT` |
|
||||||
|
| AMD Radeon AI PRO | `R9700` `R9600D` |
|
||||||
|
| AMD Radeon PRO | `W7900` `W7800` `W7700` `W7600` `W7500` `W6900X` `W6800X Duo` `W6800X` `W6800` `V620` |
|
||||||
|
| AMD Ryzen AI | `Ryzen AI Max+ 395` `Ryzen AI Max 390` `Ryzen AI Max 385` `Ryzen AI 9 HX 475` `Ryzen AI 9 HX 470` `Ryzen AI 9 465` `Ryzen AI 9 HX 375` `Ryzen AI 9 HX 370` `Ryzen AI 9 365` |
|
||||||
|
| AMD Instinct | `MI350X` `MI300X` `MI300A` `MI250X` `MI250` `MI210` `MI100` |
|
||||||
|
|
||||||
### Windows Support
|
### Windows Support
|
||||||
|
|
||||||
@@ -97,17 +103,20 @@ This table shows some example GPUs that map to these LLVM targets:
|
|||||||
| **LLVM Target** | **An Example GPU** |
|
| **LLVM Target** | **An Example GPU** |
|
||||||
|-----------------|---------------------|
|
|-----------------|---------------------|
|
||||||
| gfx908 | Radeon Instinct MI100 |
|
| gfx908 | Radeon Instinct MI100 |
|
||||||
| gfx90a | Radeon Instinct MI210 |
|
| gfx90a | Radeon Instinct MI210/MI250 |
|
||||||
| gfx940 | Radeon Instinct MI300 |
|
| gfx942 | Radeon Instinct MI300X/MI300A |
|
||||||
| gfx941 | |
|
| gfx950 | Radeon Instinct MI350X |
|
||||||
| gfx942 | |
|
| gfx1010 | Radeon RX 5700 XT |
|
||||||
|
| gfx1012 | Radeon RX 5500 XT |
|
||||||
| gfx1030 | Radeon PRO V620 |
|
| gfx1030 | Radeon PRO V620 |
|
||||||
| gfx1100 | Radeon PRO W7900 |
|
| gfx1100 | Radeon PRO W7900 |
|
||||||
| gfx1101 | Radeon PRO W7700 |
|
| gfx1101 | Radeon PRO W7700 |
|
||||||
| gfx1102 | Radeon RX 7600 |
|
| gfx1102 | Radeon RX 7600 |
|
||||||
|
| gfx1103 | Radeon 780M |
|
||||||
AMD is working on enhancing ROCm v6 to broaden support for families of GPUs in a
|
| gfx1150 | Ryzen AI 9 HX 375 |
|
||||||
future release which should increase support for more GPUs.
|
| gfx1151 | Ryzen AI Max+ 395 |
|
||||||
|
| gfx1200 | Radeon RX 9070 |
|
||||||
|
| gfx1201 | Radeon RX 9070 XT |
|
||||||
|
|
||||||
Reach out on [Discord](https://discord.gg/ollama) or file an
|
Reach out on [Discord](https://discord.gg/ollama) or file an
|
||||||
[issue](https://github.com/ollama/ollama/issues) for additional help.
|
[issue](https://github.com/ollama/ollama/issues) for additional help.
|
||||||
|
|||||||
BIN
docs/images/local.png
Normal file
|
After Width: | Height: | Size: 29 KiB |
BIN
docs/images/vscode-add-ollama.png
Normal file
|
After Width: | Height: | Size: 64 KiB |
|
Before Width: | Height: | Size: 77 KiB |
|
Before Width: | Height: | Size: 56 KiB |
BIN
docs/images/vscode-other-models.png
Normal file
|
After Width: | Height: | Size: 52 KiB |
BIN
docs/images/vscode-unhide.png
Normal file
|
After Width: | Height: | Size: 67 KiB |
BIN
docs/images/vscode.png
Normal file
|
After Width: | Height: | Size: 2.7 MiB |
@@ -4,7 +4,7 @@ title: Claude Code
|
|||||||
|
|
||||||
Claude Code is Anthropic's agentic coding tool that can read, modify, and execute code in your working directory.
|
Claude Code is Anthropic's agentic coding tool that can read, modify, and execute code in your working directory.
|
||||||
|
|
||||||
Open models can be used with Claude Code through Ollama's Anthropic-compatible API, enabling you to use models such as `glm-4.7`, `qwen3-coder`, `gpt-oss`.
|
Open models can be used with Claude Code through Ollama's Anthropic-compatible API, enabling you to use models such as `qwen3.5`, `glm-5:cloud`, `kimi-k2.5:cloud`.
|
||||||
|
|
||||||

|

|
||||||
|
|
||||||
@@ -32,13 +32,83 @@ irm https://claude.ai/install.ps1 | iex
|
|||||||
ollama launch claude
|
ollama launch claude
|
||||||
```
|
```
|
||||||
|
|
||||||
To configure without launching:
|
### Run directly with a model
|
||||||
|
|
||||||
```shell
|
```shell
|
||||||
ollama launch claude --config
|
ollama launch claude --model kimi-k2.5:cloud
|
||||||
```
|
```
|
||||||
|
|
||||||
### Manual setup
|
## Recommended Models
|
||||||
|
|
||||||
|
- `kimi-k2.5:cloud`
|
||||||
|
- `glm-5:cloud`
|
||||||
|
- `minimax-m2.7:cloud`
|
||||||
|
- `qwen3.5:cloud`
|
||||||
|
- `glm-4.7-flash`
|
||||||
|
- `qwen3.5`
|
||||||
|
|
||||||
|
Cloud models are also available at [ollama.com/search?c=cloud](https://ollama.com/search?c=cloud).
|
||||||
|
|
||||||
|
## Non-interactive (headless) mode
|
||||||
|
|
||||||
|
Run Claude Code without interaction for use in Docker, CI/CD, or scripts:
|
||||||
|
|
||||||
|
```shell
|
||||||
|
ollama launch claude --model kimi-k2.5:cloud --yes -- -p "how does this repository work?"
|
||||||
|
```
|
||||||
|
|
||||||
|
The `--yes` flag auto-pulls the model, skips selectors, and requires `--model` to be specified. Arguments after `--` are passed directly to Claude Code.
|
||||||
|
|
||||||
|
## Web search
|
||||||
|
|
||||||
|
Claude Code can search the web through Ollama's web search API. See the [web search documentation](/capabilities/web-search) for setup and usage.
|
||||||
|
|
||||||
|
## Scheduled Tasks with `/loop`
|
||||||
|
|
||||||
|
The `/loop` command runs a prompt or slash command on a recurring schedule inside Claude Code. This is useful for automating repetitive tasks like checking PRs, running research, or setting reminders.
|
||||||
|
|
||||||
|
```
|
||||||
|
/loop <interval> <prompt or /command>
|
||||||
|
```
|
||||||
|
|
||||||
|
### Examples
|
||||||
|
|
||||||
|
**Check in on your PRs**
|
||||||
|
|
||||||
|
```
|
||||||
|
/loop 30m Check my open PRs and summarize their status
|
||||||
|
```
|
||||||
|
|
||||||
|
**Automate research tasks**
|
||||||
|
|
||||||
|
```
|
||||||
|
/loop 1h Research the latest AI news and summarize key developments
|
||||||
|
```
|
||||||
|
|
||||||
|
**Automate bug reporting and triaging**
|
||||||
|
|
||||||
|
```
|
||||||
|
/loop 15m Check for new GitHub issues and triage by priority
|
||||||
|
```
|
||||||
|
|
||||||
|
**Set reminders**
|
||||||
|
|
||||||
|
```
|
||||||
|
/loop 1h Remind me to review the deploy status
|
||||||
|
```
|
||||||
|
|
||||||
|
## Telegram
|
||||||
|
|
||||||
|
Chat with Claude Code from Telegram by connecting a bot to your session. Install the [Telegram plugin](https://github.com/anthropics/claude-plugins-official), create a bot via [@BotFather](https://t.me/BotFather), then launch with the channel flag:
|
||||||
|
|
||||||
|
```shell
|
||||||
|
ollama launch claude -- --channels plugin:telegram@claude-plugins-official
|
||||||
|
```
|
||||||
|
|
||||||
|
Claude Code will prompt for permission on most actions. To allow the bot to work autonomously, configure [permission rules](https://code.claude.com/docs/en/permissions) or pass `--dangerously-skip-permissions` in isolated environments.
|
||||||
|
|
||||||
|
See the [plugin README](https://github.com/anthropics/claude-plugins-official/tree/main/external_plugins/telegram) for full setup instructions including pairing and access control.
|
||||||
|
|
||||||
|
## Manual setup
|
||||||
|
|
||||||
Claude Code connects to Ollama using the Anthropic-compatible API.
|
Claude Code connects to Ollama using the Anthropic-compatible API.
|
||||||
|
|
||||||
@@ -53,23 +123,14 @@ export ANTHROPIC_BASE_URL=http://localhost:11434
|
|||||||
2. Run Claude Code with an Ollama model:
|
2. Run Claude Code with an Ollama model:
|
||||||
|
|
||||||
```shell
|
```shell
|
||||||
claude --model gpt-oss:20b
|
claude --model qwen3.5
|
||||||
```
|
```
|
||||||
|
|
||||||
Or run with environment variables inline:
|
Or run with environment variables inline:
|
||||||
|
|
||||||
```shell
|
```shell
|
||||||
ANTHROPIC_AUTH_TOKEN=ollama ANTHROPIC_BASE_URL=http://localhost:11434 ANTHROPIC_API_KEY="" claude --model qwen3-coder
|
ANTHROPIC_AUTH_TOKEN=ollama ANTHROPIC_BASE_URL=http://localhost:11434 ANTHROPIC_API_KEY="" claude --model glm-5:cloud
|
||||||
```
|
```
|
||||||
|
|
||||||
**Note:** Claude Code requires a large context window. We recommend at least 64k tokens. See the [context length documentation](/context-length) for how to adjust context length in Ollama.
|
**Note:** Claude Code requires a large context window. We recommend at least 64k tokens. See the [context length documentation](/context-length) for how to adjust context length in Ollama.
|
||||||
|
|
||||||
## Recommended Models
|
|
||||||
|
|
||||||
- `qwen3-coder`
|
|
||||||
- `glm-4.7`
|
|
||||||
- `gpt-oss:20b`
|
|
||||||
- `gpt-oss:120b`
|
|
||||||
|
|
||||||
Cloud models are also available at [ollama.com/search?c=cloud](https://ollama.com/search?c=cloud).
|
|
||||||
|
|
||||||
|
|||||||
67
docs/integrations/nemoclaw.mdx
Normal file
@@ -0,0 +1,67 @@
|
|||||||
|
---
|
||||||
|
title: NemoClaw
|
||||||
|
---
|
||||||
|
|
||||||
|
NemoClaw is NVIDIA's open source security stack for [OpenClaw](/integrations/openclaw). It wraps OpenClaw with the NVIDIA OpenShell runtime to provide kernel-level sandboxing, network policy controls, and audit trails for AI agents.
|
||||||
|
|
||||||
|
## Quick start
|
||||||
|
|
||||||
|
Pull a model:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
ollama pull nemotron-3-nano:30b
|
||||||
|
```
|
||||||
|
|
||||||
|
Run the installer:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
curl -fsSL https://www.nvidia.com/nemoclaw.sh | \
|
||||||
|
NEMOCLAW_NON_INTERACTIVE=1 \
|
||||||
|
NEMOCLAW_PROVIDER=ollama \
|
||||||
|
NEMOCLAW_MODEL=nemotron-3-nano:30b \
|
||||||
|
bash
|
||||||
|
```
|
||||||
|
|
||||||
|
Connect to your sandbox:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
nemoclaw my-assistant connect
|
||||||
|
```
|
||||||
|
|
||||||
|
Open the TUI:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
openclaw tui
|
||||||
|
```
|
||||||
|
|
||||||
|
<Note>Ollama support in NemoClaw is still experimental.</Note>
|
||||||
|
|
||||||
|
## Platform support
|
||||||
|
|
||||||
|
| Platform | Runtime | Status |
|
||||||
|
|----------|---------|--------|
|
||||||
|
| Linux (Ubuntu 22.04+) | Docker | Primary |
|
||||||
|
| macOS (Apple Silicon) | Colima or Docker Desktop | Supported |
|
||||||
|
| Windows | WSL2 with Docker Desktop | Supported |
|
||||||
|
|
||||||
|
CMD and PowerShell are not supported on Windows — WSL2 is required.
|
||||||
|
|
||||||
|
<Note>Ollama must be installed and running before the installer runs. When running inside WSL2 or a container, ensure Ollama is reachable from the sandbox (e.g. `OLLAMA_HOST=0.0.0.0`).</Note>
|
||||||
|
|
||||||
|
## System requirements
|
||||||
|
|
||||||
|
- CPU: 4 vCPU minimum
|
||||||
|
- RAM: 8 GB minimum (16 GB recommended)
|
||||||
|
- Disk: 20 GB free (40 GB recommended for local models)
|
||||||
|
- Node.js 20+ and npm 10+
|
||||||
|
- Container runtime (Docker preferred)
|
||||||
|
|
||||||
|
## Recommended models
|
||||||
|
|
||||||
|
- `nemotron-3-super:cloud` — Strong reasoning and coding
|
||||||
|
- `qwen3.5:cloud` — 397B; reasoning and code generation
|
||||||
|
- `nemotron-3-nano:30b` — Recommended local model; fits in 24 GB VRAM
|
||||||
|
- `qwen3.5:27b` — Fast local reasoning (~18 GB VRAM)
|
||||||
|
- `glm-4.7-flash` — Reasoning and code generation (~25 GB VRAM)
|
||||||
|
|
||||||
|
More models at [ollama.com/search](https://ollama.com/search).
|
||||||
@@ -15,13 +15,29 @@ Ollama handles everything automatically:
|
|||||||
1. **Install** — If OpenClaw isn't installed, Ollama prompts to install it via npm
|
1. **Install** — If OpenClaw isn't installed, Ollama prompts to install it via npm
|
||||||
2. **Security** — On the first launch, a security notice explains the risks of tool access
|
2. **Security** — On the first launch, a security notice explains the risks of tool access
|
||||||
3. **Model** — Pick a model from the selector (local or cloud)
|
3. **Model** — Pick a model from the selector (local or cloud)
|
||||||
4. **Onboarding** — Ollama configures the provider, installs the gateway daemon, and sets your model as the primary
|
4. **Onboarding** — Ollama configures the provider, installs the gateway daemon, sets your model as the primary, and installs the web search and fetch plugin
|
||||||
5. **Gateway** — Starts in the background and opens the OpenClaw TUI
|
5. **Gateway** — Starts in the background and opens the OpenClaw TUI
|
||||||
|
|
||||||
<Note>OpenClaw requires a larger context window. It is recommended to use a context window of at least 64k tokens if using local models. See [Context length](/context-length) for more information.</Note>
|
<Note>OpenClaw requires a larger context window. It is recommended to use a context window of at least 64k tokens if using local models. See [Context length](/context-length) for more information.</Note>
|
||||||
|
|
||||||
<Note>Previously known as Clawdbot. `ollama launch clawdbot` still works as an alias.</Note>
|
<Note>Previously known as Clawdbot. `ollama launch clawdbot` still works as an alias.</Note>
|
||||||
|
|
||||||
|
## Web search and fetch
|
||||||
|
|
||||||
|
OpenClaw ships with a web search and fetch plugin that gives local or cloud models the ability to search the web and extract readable page content.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
ollama launch openclaw
|
||||||
|
```
|
||||||
|
|
||||||
|
Web search and fetch is enabled automatically when launching OpenClaw through Ollama. To install the plugin directly:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
openclaw plugins install @ollama/openclaw-web-search
|
||||||
|
```
|
||||||
|
|
||||||
|
<Note>Web search for local models requires `ollama signin`.</Note>
|
||||||
|
|
||||||
## Configure without launching
|
## Configure without launching
|
||||||
|
|
||||||
To change the model without starting the gateway and TUI:
|
To change the model without starting the gateway and TUI:
|
||||||
@@ -43,7 +59,7 @@ If the gateway is already running, it restarts automatically to pick up the new
|
|||||||
**Cloud models**:
|
**Cloud models**:
|
||||||
|
|
||||||
- `kimi-k2.5:cloud` — Multimodal reasoning with subagents
|
- `kimi-k2.5:cloud` — Multimodal reasoning with subagents
|
||||||
- `minimax-m2.5:cloud` — Fast, efficient coding and real-world productivity
|
- `minimax-m2.7:cloud` — Fast, efficient coding and real-world productivity
|
||||||
- `glm-5:cloud` — Reasoning and code generation
|
- `glm-5:cloud` — Reasoning and code generation
|
||||||
|
|
||||||
**Local models:**
|
**Local models:**
|
||||||
@@ -52,6 +68,16 @@ If the gateway is already running, it restarts automatically to pick up the new
|
|||||||
|
|
||||||
More models at [ollama.com/search](https://ollama.com/search?c=cloud).
|
More models at [ollama.com/search](https://ollama.com/search?c=cloud).
|
||||||
|
|
||||||
|
## Non-interactive (headless) mode
|
||||||
|
|
||||||
|
Run OpenClaw without interaction for use in Docker, CI/CD, or scripts:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
ollama launch openclaw --model kimi-k2.5:cloud --yes
|
||||||
|
```
|
||||||
|
|
||||||
|
The `--yes` flag auto-pulls the model, skips selectors, and requires `--model` to be specified.
|
||||||
|
|
||||||
## Connect messaging apps
|
## Connect messaging apps
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
|
|||||||
@@ -2,33 +2,84 @@
|
|||||||
title: VS Code
|
title: VS Code
|
||||||
---
|
---
|
||||||
|
|
||||||
## Install
|
VS Code includes built-in AI chat through GitHub Copilot Chat. Ollama models can be used directly in the Copilot Chat model picker.
|
||||||
|
|
||||||
Install [VS Code](https://code.visualstudio.com/download).
|
|
||||||
|
|
||||||
## Usage with Ollama
|

|
||||||
|
|
||||||
1. Open Copilot side bar found in top right window
|
|
||||||
|
## Prerequisites
|
||||||
|
|
||||||
|
- Ollama v0.18.3+
|
||||||
|
- [VS Code 1.113+](https://code.visualstudio.com/download)
|
||||||
|
- [GitHub Copilot Chat extension 0.41.0+](https://marketplace.visualstudio.com/items?itemName=GitHub.copilot-chat)
|
||||||
|
|
||||||
|
<Note> VS Code requires you to be logged in to use its model selector, even for custom models. This doesn't require a paid GitHub Copilot account; GitHub Copilot Free will enable model selection for custom models.</Note>
|
||||||
|
|
||||||
|
## Quick setup
|
||||||
|
|
||||||
|
```shell
|
||||||
|
ollama launch vscode
|
||||||
|
```
|
||||||
|
|
||||||
|
Recommended models will be shown after running the command. See the latest models at [ollama.com](https://ollama.com/search?c=tools).
|
||||||
|
|
||||||
|
Make sure **Local** is selected at the bottom of the Copilot Chat panel to use your Ollama models.
|
||||||
|
<div style={{ display: "flex", justifyContent: "center" }}>
|
||||||
|
<img
|
||||||
|
src="/images/local.png"
|
||||||
|
alt="Ollama Local Models"
|
||||||
|
width="60%"
|
||||||
|
style={{ borderRadius: "4px", marginTop: "10px", marginBottom: "10px" }}
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
|
||||||
|
## Run directly with a model
|
||||||
|
|
||||||
|
```shell
|
||||||
|
ollama launch vscode --model qwen3.5:cloud
|
||||||
|
```
|
||||||
|
Cloud models are also available at [ollama.com](https://ollama.com/search?c=cloud).
|
||||||
|
|
||||||
|
## Manual setup
|
||||||
|
|
||||||
|
To configure Ollama manually without `ollama launch`:
|
||||||
|
|
||||||
|
1. Open the **Copilot Chat** side bar from the top right corner
|
||||||
<div style={{ display: "flex", justifyContent: "center" }}>
|
<div style={{ display: "flex", justifyContent: "center" }}>
|
||||||
<img
|
<img
|
||||||
src="/images/vscode-sidebar.png"
|
src="/images/vscode-sidebar.png"
|
||||||
alt="VS Code chat Sidebar"
|
alt="VS Code chat Sidebar"
|
||||||
width="75%"
|
width="75%"
|
||||||
|
style={{ borderRadius: "4px" }}
|
||||||
/>
|
/>
|
||||||
</div>
|
</div>
|
||||||
2. Select the model dropdown > **Manage models**
|
2. Click the **settings gear icon** (<Icon icon="gear" />) to bring up the Language Models window
|
||||||
<div style={{ display: "flex", justifyContent: "center" }}>
|
<div style={{ display: "flex", justifyContent: "center" }}>
|
||||||
<img
|
<img
|
||||||
src="/images/vscode-models.png"
|
src="/images/vscode-other-models.png"
|
||||||
alt="VS Code model picker"
|
alt="VS Code model picker"
|
||||||
width="75%"
|
width="75%"
|
||||||
|
style={{ borderRadius: "4px" }}
|
||||||
/>
|
/>
|
||||||
</div>
|
</div>
|
||||||
3. Enter **Ollama** under **Provider Dropdown** and select desired models (e.g `qwen3, qwen3-coder:480b-cloud`)
|
3. Click **Add Models** and select **Ollama** to load all your Ollama models into VS Code
|
||||||
<div style={{ display: "flex", justifyContent: "center" }}>
|
<div style={{ display: "flex", justifyContent: "center" }}>
|
||||||
<img
|
<img
|
||||||
src="/images/vscode-model-options.png"
|
src="/images/vscode-add-ollama.png"
|
||||||
alt="VS Code model options dropdown"
|
alt="VS Code model options dropdown to add ollama models"
|
||||||
width="75%"
|
width="75%"
|
||||||
|
style={{ borderRadius: "4px" }}
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
4. Click the **Unhide** button in the model picker to show your Ollama models
|
||||||
|
<div style={{ display: "flex", justifyContent: "center" }}>
|
||||||
|
<img
|
||||||
|
src="/images/vscode-unhide.png"
|
||||||
|
alt="VS Code unhide models button"
|
||||||
|
width="75%"
|
||||||
|
style={{ borderRadius: "4px" }}
|
||||||
/>
|
/>
|
||||||
</div>
|
</div>
|
||||||
|
|||||||
@@ -101,7 +101,7 @@ nvidia-smi
|
|||||||
|
|
||||||
### Install AMD ROCm drivers (optional)
|
### Install AMD ROCm drivers (optional)
|
||||||
|
|
||||||
[Download and Install](https://rocm.docs.amd.com/projects/install-on-linux/en/latest/tutorial/quick-start.html) ROCm v6.
|
[Download and Install](https://rocm.docs.amd.com/projects/install-on-linux/en/latest/tutorial/quick-start.html) ROCm v7.
|
||||||
|
|
||||||
### Start Ollama
|
### Start Ollama
|
||||||
|
|
||||||
|
|||||||
@@ -114,6 +114,25 @@ If you are experiencing problems getting Ollama to correctly discover or use you
|
|||||||
- `OLLAMA_DEBUG=1` During GPU discovery additional information will be reported
|
- `OLLAMA_DEBUG=1` During GPU discovery additional information will be reported
|
||||||
- Check dmesg for any errors from amdgpu or kfd drivers `sudo dmesg | grep -i amdgpu` and `sudo dmesg | grep -i kfd`
|
- Check dmesg for any errors from amdgpu or kfd drivers `sudo dmesg | grep -i amdgpu` and `sudo dmesg | grep -i kfd`
|
||||||
|
|
||||||
|
### AMD Driver Version Mismatch
|
||||||
|
|
||||||
|
If your AMD GPU is not detected on Linux and the server logs contain messages like:
|
||||||
|
|
||||||
|
```
|
||||||
|
msg="failure during GPU discovery" ... error="failed to finish discovery before timeout"
|
||||||
|
msg="bootstrap discovery took" duration=30s ...
|
||||||
|
```
|
||||||
|
|
||||||
|
This typically means the system's AMD GPU driver is too old. Ollama bundles
|
||||||
|
ROCm 7 linux libraries which require a compatible ROCm 7 kernel driver. If the
|
||||||
|
system is running an older driver (ROCm 6.x or earlier), GPU initialization
|
||||||
|
will hang during device discovery and eventually time out, causing Ollama to
|
||||||
|
fall back to CPU.
|
||||||
|
|
||||||
|
To resolve this, upgrade to the ROCm v7 driver using the `amdgpu-install`
|
||||||
|
utility from [AMD's ROCm documentation](https://rocm.docs.amd.com/projects/install-on-linux/en/latest/).
|
||||||
|
After upgrading, reboot and restart Ollama.
|
||||||
|
|
||||||
## Multiple AMD GPUs
|
## Multiple AMD GPUs
|
||||||
|
|
||||||
If you experience gibberish responses when models load across multiple AMD GPUs on Linux, see the following guide.
|
If you experience gibberish responses when models load across multiple AMD GPUs on Linux, see the following guide.
|
||||||
|
|||||||
@@ -80,9 +80,13 @@ help you keep up to date.
|
|||||||
|
|
||||||
If you'd like to install or integrate Ollama as a service, a standalone
|
If you'd like to install or integrate Ollama as a service, a standalone
|
||||||
`ollama-windows-amd64.zip` zip file is available containing only the Ollama CLI
|
`ollama-windows-amd64.zip` zip file is available containing only the Ollama CLI
|
||||||
and GPU library dependencies for Nvidia. If you have an AMD GPU, also download
|
and GPU library dependencies for Nvidia. Depending on your hardware, you may also
|
||||||
and extract the additional ROCm package `ollama-windows-amd64-rocm.zip` into the
|
need to download and extract additional packages into the same directory:
|
||||||
same directory. This allows for embedding Ollama in existing applications, or
|
|
||||||
|
- **AMD GPU**: `ollama-windows-amd64-rocm.zip`
|
||||||
|
- **MLX (CUDA)**: `ollama-windows-amd64-mlx.zip`
|
||||||
|
|
||||||
|
This allows for embedding Ollama in existing applications, or
|
||||||
running it as a system service via `ollama serve` with tools such as
|
running it as a system service via `ollama serve` with tools such as
|
||||||
[NSSM](https://nssm.cc/).
|
[NSSM](https://nssm.cc/).
|
||||||
|
|
||||||
|
|||||||
@@ -59,6 +59,29 @@ func Host() *url.URL {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ConnectableHost returns Host() with unspecified bind addresses (0.0.0.0, ::)
|
||||||
|
// replaced by the corresponding loopback address (127.0.0.1, ::1).
|
||||||
|
// Unspecified addresses are valid for binding a server socket but not for
|
||||||
|
// connecting as a client, which fails on Windows.
|
||||||
|
func ConnectableHost() *url.URL {
|
||||||
|
u := Host()
|
||||||
|
host, port, err := net.SplitHostPort(u.Host)
|
||||||
|
if err != nil {
|
||||||
|
return u
|
||||||
|
}
|
||||||
|
|
||||||
|
if ip := net.ParseIP(host); ip != nil && ip.IsUnspecified() {
|
||||||
|
if ip.To4() != nil {
|
||||||
|
host = "127.0.0.1"
|
||||||
|
} else {
|
||||||
|
host = "::1"
|
||||||
|
}
|
||||||
|
u.Host = net.JoinHostPort(host, port)
|
||||||
|
}
|
||||||
|
|
||||||
|
return u
|
||||||
|
}
|
||||||
|
|
||||||
// AllowedOrigins returns a list of allowed origins. AllowedOrigins can be configured via the OLLAMA_ORIGINS environment variable.
|
// AllowedOrigins returns a list of allowed origins. AllowedOrigins can be configured via the OLLAMA_ORIGINS environment variable.
|
||||||
func AllowedOrigins() (origins []string) {
|
func AllowedOrigins() (origins []string) {
|
||||||
if s := Var("OLLAMA_ORIGINS"); s != "" {
|
if s := Var("OLLAMA_ORIGINS"); s != "" {
|
||||||
@@ -191,6 +214,8 @@ func LogLevel() slog.Level {
|
|||||||
var (
|
var (
|
||||||
// FlashAttention enables the experimental flash attention feature.
|
// FlashAttention enables the experimental flash attention feature.
|
||||||
FlashAttention = BoolWithDefault("OLLAMA_FLASH_ATTENTION")
|
FlashAttention = BoolWithDefault("OLLAMA_FLASH_ATTENTION")
|
||||||
|
// DebugLogRequests logs inference requests to disk for replay/debugging.
|
||||||
|
DebugLogRequests = Bool("OLLAMA_DEBUG_LOG_REQUESTS")
|
||||||
// KvCacheType is the quantization type for the K/V cache.
|
// KvCacheType is the quantization type for the K/V cache.
|
||||||
KvCacheType = String("OLLAMA_KV_CACHE_TYPE")
|
KvCacheType = String("OLLAMA_KV_CACHE_TYPE")
|
||||||
// NoHistory disables readline history.
|
// NoHistory disables readline history.
|
||||||
@@ -279,28 +304,29 @@ type EnvVar struct {
|
|||||||
|
|
||||||
func AsMap() map[string]EnvVar {
|
func AsMap() map[string]EnvVar {
|
||||||
ret := map[string]EnvVar{
|
ret := map[string]EnvVar{
|
||||||
"OLLAMA_DEBUG": {"OLLAMA_DEBUG", LogLevel(), "Show additional debug information (e.g. OLLAMA_DEBUG=1)"},
|
"OLLAMA_DEBUG": {"OLLAMA_DEBUG", LogLevel(), "Show additional debug information (e.g. OLLAMA_DEBUG=1)"},
|
||||||
"OLLAMA_FLASH_ATTENTION": {"OLLAMA_FLASH_ATTENTION", FlashAttention(false), "Enabled flash attention"},
|
"OLLAMA_DEBUG_LOG_REQUESTS": {"OLLAMA_DEBUG_LOG_REQUESTS", DebugLogRequests(), "Log inference request bodies and replay curl commands to a temp directory"},
|
||||||
"OLLAMA_KV_CACHE_TYPE": {"OLLAMA_KV_CACHE_TYPE", KvCacheType(), "Quantization type for the K/V cache (default: f16)"},
|
"OLLAMA_FLASH_ATTENTION": {"OLLAMA_FLASH_ATTENTION", FlashAttention(false), "Enabled flash attention"},
|
||||||
"OLLAMA_GPU_OVERHEAD": {"OLLAMA_GPU_OVERHEAD", GpuOverhead(), "Reserve a portion of VRAM per GPU (bytes)"},
|
"OLLAMA_KV_CACHE_TYPE": {"OLLAMA_KV_CACHE_TYPE", KvCacheType(), "Quantization type for the K/V cache (default: f16)"},
|
||||||
"OLLAMA_HOST": {"OLLAMA_HOST", Host(), "IP Address for the ollama server (default 127.0.0.1:11434)"},
|
"OLLAMA_GPU_OVERHEAD": {"OLLAMA_GPU_OVERHEAD", GpuOverhead(), "Reserve a portion of VRAM per GPU (bytes)"},
|
||||||
"OLLAMA_KEEP_ALIVE": {"OLLAMA_KEEP_ALIVE", KeepAlive(), "The duration that models stay loaded in memory (default \"5m\")"},
|
"OLLAMA_HOST": {"OLLAMA_HOST", Host(), "IP Address for the ollama server (default 127.0.0.1:11434)"},
|
||||||
"OLLAMA_LLM_LIBRARY": {"OLLAMA_LLM_LIBRARY", LLMLibrary(), "Set LLM library to bypass autodetection"},
|
"OLLAMA_KEEP_ALIVE": {"OLLAMA_KEEP_ALIVE", KeepAlive(), "The duration that models stay loaded in memory (default \"5m\")"},
|
||||||
"OLLAMA_LOAD_TIMEOUT": {"OLLAMA_LOAD_TIMEOUT", LoadTimeout(), "How long to allow model loads to stall before giving up (default \"5m\")"},
|
"OLLAMA_LLM_LIBRARY": {"OLLAMA_LLM_LIBRARY", LLMLibrary(), "Set LLM library to bypass autodetection"},
|
||||||
"OLLAMA_MAX_LOADED_MODELS": {"OLLAMA_MAX_LOADED_MODELS", MaxRunners(), "Maximum number of loaded models per GPU"},
|
"OLLAMA_LOAD_TIMEOUT": {"OLLAMA_LOAD_TIMEOUT", LoadTimeout(), "How long to allow model loads to stall before giving up (default \"5m\")"},
|
||||||
"OLLAMA_MAX_QUEUE": {"OLLAMA_MAX_QUEUE", MaxQueue(), "Maximum number of queued requests"},
|
"OLLAMA_MAX_LOADED_MODELS": {"OLLAMA_MAX_LOADED_MODELS", MaxRunners(), "Maximum number of loaded models per GPU"},
|
||||||
"OLLAMA_MODELS": {"OLLAMA_MODELS", Models(), "The path to the models directory"},
|
"OLLAMA_MAX_QUEUE": {"OLLAMA_MAX_QUEUE", MaxQueue(), "Maximum number of queued requests"},
|
||||||
"OLLAMA_NO_CLOUD": {"OLLAMA_NO_CLOUD", NoCloud(), "Disable Ollama cloud features (remote inference and web search)"},
|
"OLLAMA_MODELS": {"OLLAMA_MODELS", Models(), "The path to the models directory"},
|
||||||
"OLLAMA_NOHISTORY": {"OLLAMA_NOHISTORY", NoHistory(), "Do not preserve readline history"},
|
"OLLAMA_NO_CLOUD": {"OLLAMA_NO_CLOUD", NoCloud(), "Disable Ollama cloud features (remote inference and web search)"},
|
||||||
"OLLAMA_NOPRUNE": {"OLLAMA_NOPRUNE", NoPrune(), "Do not prune model blobs on startup"},
|
"OLLAMA_NOHISTORY": {"OLLAMA_NOHISTORY", NoHistory(), "Do not preserve readline history"},
|
||||||
"OLLAMA_NUM_PARALLEL": {"OLLAMA_NUM_PARALLEL", NumParallel(), "Maximum number of parallel requests"},
|
"OLLAMA_NOPRUNE": {"OLLAMA_NOPRUNE", NoPrune(), "Do not prune model blobs on startup"},
|
||||||
"OLLAMA_ORIGINS": {"OLLAMA_ORIGINS", AllowedOrigins(), "A comma separated list of allowed origins"},
|
"OLLAMA_NUM_PARALLEL": {"OLLAMA_NUM_PARALLEL", NumParallel(), "Maximum number of parallel requests"},
|
||||||
"OLLAMA_SCHED_SPREAD": {"OLLAMA_SCHED_SPREAD", SchedSpread(), "Always schedule model across all GPUs"},
|
"OLLAMA_ORIGINS": {"OLLAMA_ORIGINS", AllowedOrigins(), "A comma separated list of allowed origins"},
|
||||||
"OLLAMA_MULTIUSER_CACHE": {"OLLAMA_MULTIUSER_CACHE", MultiUserCache(), "Optimize prompt caching for multi-user scenarios"},
|
"OLLAMA_SCHED_SPREAD": {"OLLAMA_SCHED_SPREAD", SchedSpread(), "Always schedule model across all GPUs"},
|
||||||
"OLLAMA_CONTEXT_LENGTH": {"OLLAMA_CONTEXT_LENGTH", ContextLength(), "Context length to use unless otherwise specified (default: 4k/32k/256k based on VRAM)"},
|
"OLLAMA_MULTIUSER_CACHE": {"OLLAMA_MULTIUSER_CACHE", MultiUserCache(), "Optimize prompt caching for multi-user scenarios"},
|
||||||
"OLLAMA_EDITOR": {"OLLAMA_EDITOR", Editor(), "Path to editor for interactive prompt editing (Ctrl+G)"},
|
"OLLAMA_CONTEXT_LENGTH": {"OLLAMA_CONTEXT_LENGTH", ContextLength(), "Context length to use unless otherwise specified (default: 4k/32k/256k based on VRAM)"},
|
||||||
"OLLAMA_NEW_ENGINE": {"OLLAMA_NEW_ENGINE", NewEngine(), "Enable the new Ollama engine"},
|
"OLLAMA_EDITOR": {"OLLAMA_EDITOR", Editor(), "Path to editor for interactive prompt editing (Ctrl+G)"},
|
||||||
"OLLAMA_REMOTES": {"OLLAMA_REMOTES", Remotes(), "Allowed hosts for remote models (default \"ollama.com\")"},
|
"OLLAMA_NEW_ENGINE": {"OLLAMA_NEW_ENGINE", NewEngine(), "Enable the new Ollama engine"},
|
||||||
|
"OLLAMA_REMOTES": {"OLLAMA_REMOTES", Remotes(), "Allowed hosts for remote models (default \"ollama.com\")"},
|
||||||
|
|
||||||
// Informational
|
// Informational
|
||||||
"HTTP_PROXY": {"HTTP_PROXY", String("HTTP_PROXY")(), "HTTP proxy"},
|
"HTTP_PROXY": {"HTTP_PROXY", String("HTTP_PROXY")(), "HTTP proxy"},
|
||||||
|
|||||||
@@ -52,6 +52,37 @@ func TestHost(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestConnectableHost(t *testing.T) {
|
||||||
|
cases := map[string]struct {
|
||||||
|
value string
|
||||||
|
expect string
|
||||||
|
}{
|
||||||
|
"empty": {"", "http://127.0.0.1:11434"},
|
||||||
|
"localhost": {"127.0.0.1", "http://127.0.0.1:11434"},
|
||||||
|
"localhost and port": {"127.0.0.1:1234", "http://127.0.0.1:1234"},
|
||||||
|
"ipv4 unspecified": {"0.0.0.0", "http://127.0.0.1:11434"},
|
||||||
|
"ipv4 unspecified + port": {"0.0.0.0:1234", "http://127.0.0.1:1234"},
|
||||||
|
"ipv6 unspecified": {"[::]", "http://[::1]:11434"},
|
||||||
|
"ipv6 unspecified + port": {"[::]:1234", "http://[::1]:1234"},
|
||||||
|
"ipv6 localhost": {"[::1]", "http://[::1]:11434"},
|
||||||
|
"ipv6 localhost + port": {"[::1]:1234", "http://[::1]:1234"},
|
||||||
|
"specific address": {"192.168.1.5", "http://192.168.1.5:11434"},
|
||||||
|
"specific address + port": {"192.168.1.5:8080", "http://192.168.1.5:8080"},
|
||||||
|
"hostname": {"example.com", "http://example.com:11434"},
|
||||||
|
"hostname and port": {"example.com:1234", "http://example.com:1234"},
|
||||||
|
"https unspecified + port": {"https://0.0.0.0:4321", "https://127.0.0.1:4321"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for name, tt := range cases {
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
t.Setenv("OLLAMA_HOST", tt.value)
|
||||||
|
if host := ConnectableHost(); host.String() != tt.expect {
|
||||||
|
t.Errorf("%s: expected %s, got %s", name, tt.expect, host.String())
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestOrigins(t *testing.T) {
|
func TestOrigins(t *testing.T) {
|
||||||
cases := []struct {
|
cases := []struct {
|
||||||
value string
|
value string
|
||||||
|
|||||||
@@ -874,7 +874,7 @@ func (f GGML) SupportsFlashAttention() bool {
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
if slices.Contains([]string{"gemma2"}, arch) {
|
if slices.Contains([]string{"gemma2", "grok"}, arch) {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -14,4 +14,15 @@ The integration tests have 2 modes of operating.
|
|||||||
> Before running the tests locally without the "test existing" setting, compile ollama from the top of the source tree `go build .` in addition to GPU support with cmake if applicable on your platform. The integration tests expect to find an ollama binary at the top of the tree.
|
> Before running the tests locally without the "test existing" setting, compile ollama from the top of the source tree `go build .` in addition to GPU support with cmake if applicable on your platform. The integration tests expect to find an ollama binary at the top of the tree.
|
||||||
|
|
||||||
|
|
||||||
Many tests use a default small model suitable to run on many systems. You can override this default model by setting `OLLAMA_TEST_DEFAULT_MODEL`
|
## Testing a New Model
|
||||||
|
|
||||||
|
When implementing new model architecture, use `OLLAMA_TEST_MODEL` to run the
|
||||||
|
integration suite against your model.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Build the binary first
|
||||||
|
go build .
|
||||||
|
|
||||||
|
# Run integration tests against it
|
||||||
|
OLLAMA_TEST_MODEL=mymodel go test -tags integration -v -count 1 -timeout 15m ./integration/
|
||||||
|
```
|
||||||
|
|||||||
@@ -48,9 +48,7 @@ func TestAPIGenerate(t *testing.T) {
|
|||||||
|
|
||||||
client, _, cleanup := InitServerConnection(ctx, t)
|
client, _, cleanup := InitServerConnection(ctx, t)
|
||||||
defer cleanup()
|
defer cleanup()
|
||||||
if err := PullIfMissing(ctx, client, req.Model); err != nil {
|
pullOrSkip(ctx, t, client, req.Model)
|
||||||
t.Fatalf("pull failed %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
@@ -151,7 +149,11 @@ func TestAPIGenerate(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// Validate PS while we're at it...
|
// Validate PS while we're at it — skip for local-only models
|
||||||
|
// which may lack metadata fields like family, parameter_size, etc.
|
||||||
|
if testModel != "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
resp, err := client.ListRunning(ctx)
|
resp, err := client.ListRunning(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("list models API error: %s", err)
|
t.Fatalf("list models API error: %s", err)
|
||||||
@@ -208,9 +210,7 @@ func TestAPIChat(t *testing.T) {
|
|||||||
|
|
||||||
client, _, cleanup := InitServerConnection(ctx, t)
|
client, _, cleanup := InitServerConnection(ctx, t)
|
||||||
defer cleanup()
|
defer cleanup()
|
||||||
if err := PullIfMissing(ctx, client, req.Model); err != nil {
|
pullOrSkip(ctx, t, client, req.Model)
|
||||||
t.Fatalf("pull failed %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
@@ -311,6 +311,9 @@ func TestAPIChat(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestAPIListModels(t *testing.T) {
|
func TestAPIListModels(t *testing.T) {
|
||||||
|
if testModel != "" {
|
||||||
|
t.Skip("skipping metadata test with model override")
|
||||||
|
}
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
client, _, cleanup := InitServerConnection(ctx, t)
|
client, _, cleanup := InitServerConnection(ctx, t)
|
||||||
@@ -361,6 +364,9 @@ func verifyModelDetails(t *testing.T, details api.ModelDetails) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestAPIShowModel(t *testing.T) {
|
func TestAPIShowModel(t *testing.T) {
|
||||||
|
if testModel != "" {
|
||||||
|
t.Skip("skipping metadata test with model override")
|
||||||
|
}
|
||||||
modelName := "llama3.2"
|
modelName := "llama3.2"
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Minute)
|
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Minute)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
@@ -400,6 +406,10 @@ func TestAPIShowModel(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestAPIGenerateLogprobs(t *testing.T) {
|
func TestAPIGenerateLogprobs(t *testing.T) {
|
||||||
|
if testModel != "" {
|
||||||
|
// Logprobs requires runner support (e.g. llama.cpp has it, MLX does not).
|
||||||
|
t.Skip("logprobs not supported by all runners")
|
||||||
|
}
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
|
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
@@ -513,6 +523,10 @@ func TestAPIGenerateLogprobs(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestAPIChatLogprobs(t *testing.T) {
|
func TestAPIChatLogprobs(t *testing.T) {
|
||||||
|
if testModel != "" {
|
||||||
|
// Logprobs requires runner support (e.g. llama.cpp has it, MLX does not).
|
||||||
|
t.Skip("logprobs not supported by all runners")
|
||||||
|
}
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
|
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
|
|||||||
@@ -35,6 +35,9 @@ func TestBlueSky(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestUnicode(t *testing.T) {
|
func TestUnicode(t *testing.T) {
|
||||||
|
if testModel != "" {
|
||||||
|
t.Skip("uses hardcoded model, not applicable with model override")
|
||||||
|
}
|
||||||
skipUnderMinVRAM(t, 6)
|
skipUnderMinVRAM(t, 6)
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Minute)
|
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Minute)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
@@ -59,9 +62,7 @@ func TestUnicode(t *testing.T) {
|
|||||||
}
|
}
|
||||||
client, _, cleanup := InitServerConnection(ctx, t)
|
client, _, cleanup := InitServerConnection(ctx, t)
|
||||||
defer cleanup()
|
defer cleanup()
|
||||||
if err := PullIfMissing(ctx, client, req.Model); err != nil {
|
pullOrSkip(ctx, t, client, req.Model)
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
slog.Info("loading", "model", req.Model)
|
slog.Info("loading", "model", req.Model)
|
||||||
err := client.Generate(ctx, &api.GenerateRequest{Model: req.Model}, func(response api.GenerateResponse) error { return nil })
|
err := client.Generate(ctx, &api.GenerateRequest{Model: req.Model}, func(response api.GenerateResponse) error { return nil })
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -81,6 +82,9 @@ func TestUnicode(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestExtendedUnicodeOutput(t *testing.T) {
|
func TestExtendedUnicodeOutput(t *testing.T) {
|
||||||
|
if testModel != "" {
|
||||||
|
t.Skip("uses hardcoded model, not applicable with model override")
|
||||||
|
}
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
|
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
// Set up the test data
|
// Set up the test data
|
||||||
@@ -100,9 +104,7 @@ func TestExtendedUnicodeOutput(t *testing.T) {
|
|||||||
}
|
}
|
||||||
client, _, cleanup := InitServerConnection(ctx, t)
|
client, _, cleanup := InitServerConnection(ctx, t)
|
||||||
defer cleanup()
|
defer cleanup()
|
||||||
if err := PullIfMissing(ctx, client, req.Model); err != nil {
|
pullOrSkip(ctx, t, client, req.Model)
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
DoChat(ctx, t, client, req, []string{"😀", "😊", "😁", "😂", "😄", "😃"}, 120*time.Second, 120*time.Second)
|
DoChat(ctx, t, client, req, []string{"😀", "😊", "😁", "😂", "😄", "😃"}, 120*time.Second, 120*time.Second)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -148,15 +150,16 @@ func TestUnicodeModelDir(t *testing.T) {
|
|||||||
// TestNumPredict verifies that when num_predict is set, the model generates
|
// TestNumPredict verifies that when num_predict is set, the model generates
|
||||||
// exactly that many tokens. It uses logprobs to count the actual tokens output.
|
// exactly that many tokens. It uses logprobs to count the actual tokens output.
|
||||||
func TestNumPredict(t *testing.T) {
|
func TestNumPredict(t *testing.T) {
|
||||||
|
if testModel != "" {
|
||||||
|
t.Skip("uses hardcoded model, not applicable with model override")
|
||||||
|
}
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
|
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
client, _, cleanup := InitServerConnection(ctx, t)
|
client, _, cleanup := InitServerConnection(ctx, t)
|
||||||
defer cleanup()
|
defer cleanup()
|
||||||
|
|
||||||
if err := PullIfMissing(ctx, client, "qwen3:0.6b"); err != nil {
|
pullOrSkip(ctx, t, client, "qwen3:0.6b")
|
||||||
t.Fatalf("failed to pull model: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
req := api.GenerateRequest{
|
req := api.GenerateRequest{
|
||||||
Model: "qwen3:0.6b",
|
Model: "qwen3:0.6b",
|
||||||
|
|||||||
@@ -67,6 +67,9 @@ func TestConcurrentChat(t *testing.T) {
|
|||||||
// Stress the scheduler and attempt to load more models than will fit to cause thrashing
|
// Stress the scheduler and attempt to load more models than will fit to cause thrashing
|
||||||
// This test will always load at least 2 models even on CPU based systems
|
// This test will always load at least 2 models even on CPU based systems
|
||||||
func TestMultiModelStress(t *testing.T) {
|
func TestMultiModelStress(t *testing.T) {
|
||||||
|
if testModel != "" {
|
||||||
|
t.Skip("uses hardcoded models, not applicable with model override")
|
||||||
|
}
|
||||||
s := os.Getenv("OLLAMA_MAX_VRAM")
|
s := os.Getenv("OLLAMA_MAX_VRAM")
|
||||||
if s == "" {
|
if s == "" {
|
||||||
s = "0"
|
s = "0"
|
||||||
@@ -114,9 +117,7 @@ func TestMultiModelStress(t *testing.T) {
|
|||||||
|
|
||||||
// Make sure all the models are pulled before we get started
|
// Make sure all the models are pulled before we get started
|
||||||
for _, model := range chosenModels {
|
for _, model := range chosenModels {
|
||||||
if err := PullIfMissing(ctx, client, model); err != nil {
|
pullOrSkip(ctx, t, client, model)
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Determine how many models we can load in parallel before we exceed VRAM
|
// Determine how many models we can load in parallel before we exceed VRAM
|
||||||
|
|||||||
@@ -38,9 +38,7 @@ func TestLongInputContext(t *testing.T) {
|
|||||||
}
|
}
|
||||||
client, _, cleanup := InitServerConnection(ctx, t)
|
client, _, cleanup := InitServerConnection(ctx, t)
|
||||||
defer cleanup()
|
defer cleanup()
|
||||||
if err := PullIfMissing(ctx, client, req.Model); err != nil {
|
pullOrSkip(ctx, t, client, req.Model)
|
||||||
t.Fatalf("PullIfMissing failed: %v", err)
|
|
||||||
}
|
|
||||||
DoChat(ctx, t, client, req, []string{"russia", "german", "france", "england", "austria", "prussia", "europe", "individuals", "coalition", "conflict"}, 120*time.Second, 10*time.Second)
|
DoChat(ctx, t, client, req, []string{"russia", "german", "france", "england", "austria", "prussia", "europe", "individuals", "coalition", "conflict"}, 120*time.Second, 10*time.Second)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -70,14 +68,15 @@ func TestContextExhaustion(t *testing.T) {
|
|||||||
}
|
}
|
||||||
client, _, cleanup := InitServerConnection(ctx, t)
|
client, _, cleanup := InitServerConnection(ctx, t)
|
||||||
defer cleanup()
|
defer cleanup()
|
||||||
if err := PullIfMissing(ctx, client, req.Model); err != nil {
|
pullOrSkip(ctx, t, client, req.Model)
|
||||||
t.Fatalf("PullIfMissing failed: %v", err)
|
|
||||||
}
|
|
||||||
DoChat(ctx, t, client, req, []string{"once", "upon", "lived", "sunny", "cloudy", "clear", "water", "time", "travel", "world"}, 120*time.Second, 10*time.Second)
|
DoChat(ctx, t, client, req, []string{"once", "upon", "lived", "sunny", "cloudy", "clear", "water", "time", "travel", "world"}, 120*time.Second, 10*time.Second)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Send multiple generate requests with prior context and ensure the response is coherant and expected
|
// Send multiple generate requests with prior context and ensure the response is coherant and expected
|
||||||
func TestParallelGenerateWithHistory(t *testing.T) {
|
func TestParallelGenerateWithHistory(t *testing.T) {
|
||||||
|
if testModel != "" {
|
||||||
|
t.Skip("uses hardcoded model, not applicable with model override")
|
||||||
|
}
|
||||||
modelName := "gpt-oss:20b"
|
modelName := "gpt-oss:20b"
|
||||||
req, resp := GenerateRequests()
|
req, resp := GenerateRequests()
|
||||||
numParallel := 2
|
numParallel := 2
|
||||||
@@ -133,6 +132,12 @@ func TestParallelGenerateWithHistory(t *testing.T) {
|
|||||||
|
|
||||||
// Send generate requests with prior context and ensure the response is coherant and expected
|
// Send generate requests with prior context and ensure the response is coherant and expected
|
||||||
func TestGenerateWithHistory(t *testing.T) {
|
func TestGenerateWithHistory(t *testing.T) {
|
||||||
|
if testModel != "" {
|
||||||
|
// The Generate API's Context field (token array continuation) is not
|
||||||
|
// supported by all runners (e.g. MLX). Chat history works; this is
|
||||||
|
// the only generate-specific continuation path.
|
||||||
|
t.Skip("generate context continuation not supported by all runners")
|
||||||
|
}
|
||||||
req := api.GenerateRequest{
|
req := api.GenerateRequest{
|
||||||
Model: smol,
|
Model: smol,
|
||||||
Prompt: rainbowPrompt,
|
Prompt: rainbowPrompt,
|
||||||
@@ -173,6 +178,9 @@ func TestGenerateWithHistory(t *testing.T) {
|
|||||||
|
|
||||||
// Send multiple chat requests with prior context and ensure the response is coherant and expected
|
// Send multiple chat requests with prior context and ensure the response is coherant and expected
|
||||||
func TestParallelChatWithHistory(t *testing.T) {
|
func TestParallelChatWithHistory(t *testing.T) {
|
||||||
|
if testModel != "" {
|
||||||
|
t.Skip("uses hardcoded model, not applicable with model override")
|
||||||
|
}
|
||||||
modelName := "gpt-oss:20b"
|
modelName := "gpt-oss:20b"
|
||||||
req, resp := ChatRequests()
|
req, resp := ChatRequests()
|
||||||
numParallel := 2
|
numParallel := 2
|
||||||
|
|||||||
@@ -78,8 +78,11 @@ func TestEmbedCosineDistanceCorrelation(t *testing.T) {
|
|||||||
client, _, cleanup := InitServerConnection(ctx, t)
|
client, _, cleanup := InitServerConnection(ctx, t)
|
||||||
defer cleanup()
|
defer cleanup()
|
||||||
|
|
||||||
for _, model := range libraryEmbedModels {
|
for _, model := range testModels(libraryEmbedModels) {
|
||||||
t.Run(model, func(t *testing.T) {
|
t.Run(model, func(t *testing.T) {
|
||||||
|
if testModel != "" {
|
||||||
|
requireCapability(ctx, t, client, model, "embedding")
|
||||||
|
}
|
||||||
testCases := []struct {
|
testCases := []struct {
|
||||||
a string
|
a string
|
||||||
b string
|
b string
|
||||||
@@ -145,6 +148,9 @@ func TestEmbedCosineDistanceCorrelation(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestAllMiniLMEmbeddings(t *testing.T) {
|
func TestAllMiniLMEmbeddings(t *testing.T) {
|
||||||
|
if testModel != "" {
|
||||||
|
t.Skip("uses hardcoded model, not applicable with model override")
|
||||||
|
}
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
|
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
client, _, cleanup := InitServerConnection(ctx, t)
|
client, _, cleanup := InitServerConnection(ctx, t)
|
||||||
@@ -175,6 +181,9 @@ func TestAllMiniLMEmbeddings(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestAllMiniLMEmbed(t *testing.T) {
|
func TestAllMiniLMEmbed(t *testing.T) {
|
||||||
|
if testModel != "" {
|
||||||
|
t.Skip("uses hardcoded model, not applicable with model override")
|
||||||
|
}
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
|
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
client, _, cleanup := InitServerConnection(ctx, t)
|
client, _, cleanup := InitServerConnection(ctx, t)
|
||||||
@@ -212,6 +221,9 @@ func TestAllMiniLMEmbed(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestAllMiniLMBatchEmbed(t *testing.T) {
|
func TestAllMiniLMBatchEmbed(t *testing.T) {
|
||||||
|
if testModel != "" {
|
||||||
|
t.Skip("uses hardcoded model, not applicable with model override")
|
||||||
|
}
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
|
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
client, _, cleanup := InitServerConnection(ctx, t)
|
client, _, cleanup := InitServerConnection(ctx, t)
|
||||||
@@ -259,6 +271,9 @@ func TestAllMiniLMBatchEmbed(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestAllMiniLMEmbedTruncate(t *testing.T) {
|
func TestAllMiniLMEmbedTruncate(t *testing.T) {
|
||||||
|
if testModel != "" {
|
||||||
|
t.Skip("uses hardcoded model, not applicable with model override")
|
||||||
|
}
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
|
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
client, _, cleanup := InitServerConnection(ctx, t)
|
client, _, cleanup := InitServerConnection(ctx, t)
|
||||||
@@ -397,21 +412,13 @@ func TestAllMiniLMEmbedTruncate(t *testing.T) {
|
|||||||
|
|
||||||
func embeddingTestHelper(ctx context.Context, client *api.Client, t *testing.T, req api.EmbeddingRequest) (*api.EmbeddingResponse, error) {
|
func embeddingTestHelper(ctx context.Context, client *api.Client, t *testing.T, req api.EmbeddingRequest) (*api.EmbeddingResponse, error) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
pullOrSkip(ctx, t, client, req.Model)
|
||||||
if err := PullIfMissing(ctx, client, req.Model); err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return client.Embeddings(ctx, &req)
|
return client.Embeddings(ctx, &req)
|
||||||
}
|
}
|
||||||
|
|
||||||
func embedTestHelper(ctx context.Context, client *api.Client, t *testing.T, req api.EmbedRequest) (*api.EmbedResponse, error) {
|
func embedTestHelper(ctx context.Context, client *api.Client, t *testing.T, req api.EmbedRequest) (*api.EmbedResponse, error) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
pullOrSkip(ctx, t, client, req.Model)
|
||||||
if err := PullIfMissing(ctx, client, req.Model); err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return client.Embed(ctx, &req)
|
return client.Embed(ctx, &req)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -426,9 +433,12 @@ func TestEmbedTruncation(t *testing.T) {
|
|||||||
client, _, cleanup := InitServerConnection(ctx, t)
|
client, _, cleanup := InitServerConnection(ctx, t)
|
||||||
defer cleanup()
|
defer cleanup()
|
||||||
|
|
||||||
for _, model := range libraryEmbedModels {
|
for _, model := range testModels(libraryEmbedModels) {
|
||||||
model := model
|
model := model
|
||||||
t.Run(model, func(t *testing.T) {
|
t.Run(model, func(t *testing.T) {
|
||||||
|
if testModel != "" {
|
||||||
|
requireCapability(ctx, t, client, model, "embedding")
|
||||||
|
}
|
||||||
// Check if we're running out of time (reserve 20s for current model)
|
// Check if we're running out of time (reserve 20s for current model)
|
||||||
if deadline, ok := t.Deadline(); ok && time.Until(deadline) < 20*time.Second {
|
if deadline, ok := t.Deadline(); ok && time.Until(deadline) < 20*time.Second {
|
||||||
t.Skip("skipping remaining tests to avoid timeout")
|
t.Skip("skipping remaining tests to avoid timeout")
|
||||||
@@ -494,9 +504,12 @@ func TestEmbedLargeInput(t *testing.T) {
|
|||||||
client, _, cleanup := InitServerConnection(ctx, t)
|
client, _, cleanup := InitServerConnection(ctx, t)
|
||||||
defer cleanup()
|
defer cleanup()
|
||||||
|
|
||||||
for _, model := range libraryEmbedModels {
|
for _, model := range testModels(libraryEmbedModels) {
|
||||||
model := model
|
model := model
|
||||||
t.Run(model, func(t *testing.T) {
|
t.Run(model, func(t *testing.T) {
|
||||||
|
if testModel != "" {
|
||||||
|
requireCapability(ctx, t, client, model, "embedding")
|
||||||
|
}
|
||||||
mctx, mcancel := context.WithTimeout(ctx, 2*time.Minute)
|
mctx, mcancel := context.WithTimeout(ctx, 2*time.Minute)
|
||||||
defer mcancel()
|
defer mcancel()
|
||||||
|
|
||||||
@@ -559,9 +572,12 @@ func TestEmbedStatusCode(t *testing.T) {
|
|||||||
client, _, cleanup := InitServerConnection(ctx, t)
|
client, _, cleanup := InitServerConnection(ctx, t)
|
||||||
defer cleanup()
|
defer cleanup()
|
||||||
|
|
||||||
for _, model := range libraryEmbedModels {
|
for _, model := range testModels(libraryEmbedModels) {
|
||||||
model := model
|
model := model
|
||||||
t.Run(model, func(t *testing.T) {
|
t.Run(model, func(t *testing.T) {
|
||||||
|
if testModel != "" {
|
||||||
|
requireCapability(ctx, t, client, model, "embedding")
|
||||||
|
}
|
||||||
// Check if we're running out of time (reserve 20s for current model)
|
// Check if we're running out of time (reserve 20s for current model)
|
||||||
if deadline, ok := t.Deadline(); ok && time.Until(deadline) < 20*time.Second {
|
if deadline, ok := t.Deadline(); ok && time.Until(deadline) < 20*time.Second {
|
||||||
t.Skip("skipping remaining tests to avoid timeout")
|
t.Skip("skipping remaining tests to avoid timeout")
|
||||||
@@ -571,9 +587,7 @@ func TestEmbedStatusCode(t *testing.T) {
|
|||||||
defer mcancel()
|
defer mcancel()
|
||||||
|
|
||||||
// Pull the model if needed
|
// Pull the model if needed
|
||||||
if err := PullIfMissing(mctx, client, model); err != nil {
|
pullOrSkip(mctx, t, client, model)
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
t.Run("truncation error status code", func(t *testing.T) {
|
t.Run("truncation error status code", func(t *testing.T) {
|
||||||
truncFalse := false
|
truncFalse := false
|
||||||
|
|||||||
@@ -14,6 +14,9 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func TestImageGeneration(t *testing.T) {
|
func TestImageGeneration(t *testing.T) {
|
||||||
|
if testModel != "" {
|
||||||
|
t.Skip("uses hardcoded models, not applicable with model override")
|
||||||
|
}
|
||||||
skipUnderMinVRAM(t, 8)
|
skipUnderMinVRAM(t, 8)
|
||||||
|
|
||||||
type testCase struct {
|
type testCase struct {
|
||||||
@@ -41,12 +44,8 @@ func TestImageGeneration(t *testing.T) {
|
|||||||
defer cleanup()
|
defer cleanup()
|
||||||
|
|
||||||
// Pull both models
|
// Pull both models
|
||||||
if err := PullIfMissing(ctx, client, tc.imageGenModel); err != nil {
|
pullOrSkip(ctx, t, client, tc.imageGenModel)
|
||||||
t.Fatalf("failed to pull image gen model: %v", err)
|
pullOrSkip(ctx, t, client, tc.visionModel)
|
||||||
}
|
|
||||||
if err := PullIfMissing(ctx, client, tc.visionModel); err != nil {
|
|
||||||
t.Fatalf("failed to pull vision model: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Generate the image
|
// Generate the image
|
||||||
t.Logf("Generating image with prompt: %s", tc.prompt)
|
t.Logf("Generating image with prompt: %s", tc.prompt)
|
||||||
|
|||||||
@@ -24,15 +24,12 @@ func TestLibraryModelsChat(t *testing.T) {
|
|||||||
defer cleanup()
|
defer cleanup()
|
||||||
targetArch := os.Getenv("OLLAMA_TEST_ARCHITECTURE")
|
targetArch := os.Getenv("OLLAMA_TEST_ARCHITECTURE")
|
||||||
|
|
||||||
chatModels := libraryChatModels
|
for _, model := range testModels(libraryChatModels) {
|
||||||
for _, model := range chatModels {
|
|
||||||
t.Run(model, func(t *testing.T) {
|
t.Run(model, func(t *testing.T) {
|
||||||
if time.Now().Sub(started) > softTimeout {
|
if time.Now().Sub(started) > softTimeout {
|
||||||
t.Skip("skipping remaining tests to avoid excessive runtime")
|
t.Skip("skipping remaining tests to avoid excessive runtime")
|
||||||
}
|
}
|
||||||
if err := PullIfMissing(ctx, client, model); err != nil {
|
pullOrSkip(ctx, t, client, model)
|
||||||
t.Fatalf("pull failed %s", err)
|
|
||||||
}
|
|
||||||
if targetArch != "" {
|
if targetArch != "" {
|
||||||
resp, err := client.Show(ctx, &api.ShowRequest{Name: model})
|
resp, err := client.Show(ctx, &api.ShowRequest{Name: model})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -13,39 +13,35 @@ import (
|
|||||||
|
|
||||||
func TestVisionModels(t *testing.T) {
|
func TestVisionModels(t *testing.T) {
|
||||||
skipUnderMinVRAM(t, 6)
|
skipUnderMinVRAM(t, 6)
|
||||||
type testCase struct {
|
|
||||||
model string
|
defaultVisionModels := []string{
|
||||||
}
|
"qwen2.5vl",
|
||||||
testCases := []testCase{
|
"llama3.2-vision",
|
||||||
{
|
"gemma3",
|
||||||
model: "qwen2.5vl",
|
"qwen3-vl:8b",
|
||||||
},
|
"qwen3-vl:30b",
|
||||||
{
|
"ministral-3",
|
||||||
model: "llama3.2-vision",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
model: "gemma3",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
model: "qwen3-vl:8b",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
// Qwen 3 VL mixture of experts
|
|
||||||
model: "qwen3-vl:30b",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
model: "ministral-3",
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, v := range testCases {
|
for _, model := range testModels(defaultVisionModels) {
|
||||||
t.Run(v.model, func(t *testing.T) {
|
t.Run(model, func(t *testing.T) {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
|
||||||
|
defer cancel()
|
||||||
|
client, _, cleanup := InitServerConnection(ctx, t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
if testModel != "" {
|
||||||
|
requireCapability(ctx, t, client, model, "vision")
|
||||||
|
}
|
||||||
|
|
||||||
|
pullOrSkip(ctx, t, client, model)
|
||||||
|
|
||||||
image, err := base64.StdEncoding.DecodeString(imageEncoding)
|
image, err := base64.StdEncoding.DecodeString(imageEncoding)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
req := api.ChatRequest{
|
req := api.ChatRequest{
|
||||||
Model: v.model,
|
Model: model,
|
||||||
Messages: []api.Message{
|
Messages: []api.Message{
|
||||||
{
|
{
|
||||||
Role: "user",
|
Role: "user",
|
||||||
@@ -61,16 +57,7 @@ func TestVisionModels(t *testing.T) {
|
|||||||
"temperature": 0.0,
|
"temperature": 0.0,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
|
|
||||||
defer cancel()
|
|
||||||
client, _, cleanup := InitServerConnection(ctx, t)
|
|
||||||
|
|
||||||
// Note: sometimes it returns "the ollamas" sometimes "the ollams"
|
|
||||||
resp := "the ollam"
|
|
||||||
defer cleanup()
|
|
||||||
if err := PullIfMissing(ctx, client, req.Model); err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
// Preload to skip if we're less than 80% on GPU to avoid extremely slow tests
|
// Preload to skip if we're less than 80% on GPU to avoid extremely slow tests
|
||||||
err = client.Generate(ctx, &api.GenerateRequest{Model: req.Model}, func(response api.GenerateResponse) error { return nil })
|
err = client.Generate(ctx, &api.GenerateRequest{Model: req.Model}, func(response api.GenerateResponse) error { return nil })
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -78,13 +65,17 @@ func TestVisionModels(t *testing.T) {
|
|||||||
}
|
}
|
||||||
skipIfNotGPULoaded(ctx, t, client, req.Model, 80)
|
skipIfNotGPULoaded(ctx, t, client, req.Model, 80)
|
||||||
|
|
||||||
|
// Note: sometimes it returns "the ollamas" sometimes "the ollams"
|
||||||
// llava models on CPU can be quite slow to start
|
// llava models on CPU can be quite slow to start
|
||||||
DoChat(ctx, t, client, req, []string{resp}, 240*time.Second, 30*time.Second)
|
DoChat(ctx, t, client, req, []string{"the ollam"}, 240*time.Second, 30*time.Second)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestIntegrationSplitBatch(t *testing.T) {
|
func TestIntegrationSplitBatch(t *testing.T) {
|
||||||
|
if testModel != "" {
|
||||||
|
t.Skip("uses hardcoded model, not applicable with model override")
|
||||||
|
}
|
||||||
skipUnderMinVRAM(t, 6)
|
skipUnderMinVRAM(t, 6)
|
||||||
image, err := base64.StdEncoding.DecodeString(imageEncoding)
|
image, err := base64.StdEncoding.DecodeString(imageEncoding)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -111,9 +102,7 @@ func TestIntegrationSplitBatch(t *testing.T) {
|
|||||||
defer cancel()
|
defer cancel()
|
||||||
client, _, cleanup := InitServerConnection(ctx, t)
|
client, _, cleanup := InitServerConnection(ctx, t)
|
||||||
defer cleanup()
|
defer cleanup()
|
||||||
if err := PullIfMissing(ctx, client, req.Model); err != nil {
|
pullOrSkip(ctx, t, client, req.Model)
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
// llava models on CPU can be quite slow to start,
|
// llava models on CPU can be quite slow to start,
|
||||||
DoGenerate(ctx, t, client, req, []string{resp}, 120*time.Second, 30*time.Second)
|
DoGenerate(ctx, t, client, req, []string{resp}, 120*time.Second, 30*time.Second)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -45,9 +45,7 @@ func TestMaxQueue(t *testing.T) {
|
|||||||
client, _, cleanup := InitServerConnection(ctx, t)
|
client, _, cleanup := InitServerConnection(ctx, t)
|
||||||
defer cleanup()
|
defer cleanup()
|
||||||
|
|
||||||
if err := PullIfMissing(ctx, client, req.Model); err != nil {
|
pullOrSkip(ctx, t, client, req.Model)
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Context for the worker threads so we can shut them down
|
// Context for the worker threads so we can shut them down
|
||||||
// embedCtx, embedCancel := context.WithCancel(ctx)
|
// embedCtx, embedCancel := context.WithCancel(ctx)
|
||||||
|
|||||||
@@ -46,14 +46,12 @@ func TestModelsChat(t *testing.T) {
|
|||||||
chatModels = append(ollamaEngineChatModels, llamaRunnerChatModels...)
|
chatModels = append(ollamaEngineChatModels, llamaRunnerChatModels...)
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, model := range chatModels {
|
for _, model := range testModels(chatModels) {
|
||||||
t.Run(model, func(t *testing.T) {
|
t.Run(model, func(t *testing.T) {
|
||||||
if time.Now().Sub(started) > softTimeout {
|
if time.Now().Sub(started) > softTimeout {
|
||||||
t.Skip("skipping remaining tests to avoid excessive runtime")
|
t.Skip("skipping remaining tests to avoid excessive runtime")
|
||||||
}
|
}
|
||||||
if err := PullIfMissing(ctx, client, model); err != nil {
|
pullOrSkip(ctx, t, client, model)
|
||||||
t.Fatalf("pull failed %s", err)
|
|
||||||
}
|
|
||||||
if maxVram > 0 {
|
if maxVram > 0 {
|
||||||
resp, err := client.List(ctx)
|
resp, err := client.List(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -133,14 +131,15 @@ func TestModelsEmbed(t *testing.T) {
|
|||||||
t.Fatalf("failed to load test data: %s", err)
|
t.Fatalf("failed to load test data: %s", err)
|
||||||
}
|
}
|
||||||
for model, expected := range testCase {
|
for model, expected := range testCase {
|
||||||
|
if testModel != "" && model != testModel {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
t.Run(model, func(t *testing.T) {
|
t.Run(model, func(t *testing.T) {
|
||||||
if time.Now().Sub(started) > softTimeout {
|
if time.Now().Sub(started) > softTimeout {
|
||||||
t.Skip("skipping remaining tests to avoid excessive runtime")
|
t.Skip("skipping remaining tests to avoid excessive runtime")
|
||||||
}
|
}
|
||||||
if err := PullIfMissing(ctx, client, model); err != nil {
|
pullOrSkip(ctx, t, client, model)
|
||||||
t.Fatalf("pull failed %s", err)
|
|
||||||
}
|
|
||||||
if maxVram > 0 {
|
if maxVram > 0 {
|
||||||
resp, err := client.List(ctx)
|
resp, err := client.List(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -87,9 +87,7 @@ func doModelPerfTest(t *testing.T, chatModels []string) {
|
|||||||
if time.Now().Sub(started) > softTimeout {
|
if time.Now().Sub(started) > softTimeout {
|
||||||
t.Skip("skipping remaining tests to avoid excessive runtime")
|
t.Skip("skipping remaining tests to avoid excessive runtime")
|
||||||
}
|
}
|
||||||
if err := PullIfMissing(ctx, client, model); err != nil {
|
pullOrSkip(ctx, t, client, model)
|
||||||
t.Fatalf("pull failed %s", err)
|
|
||||||
}
|
|
||||||
var maxContext int
|
var maxContext int
|
||||||
|
|
||||||
resp, err := client.Show(ctx, &api.ShowRequest{Model: model})
|
resp, err := client.Show(ctx, &api.ShowRequest{Model: model})
|
||||||
|
|||||||
@@ -33,9 +33,7 @@ func TestQuantization(t *testing.T) {
|
|||||||
defer cleanup()
|
defer cleanup()
|
||||||
|
|
||||||
for _, base := range sourceModels {
|
for _, base := range sourceModels {
|
||||||
if err := PullIfMissing(ctx, client, base); err != nil {
|
pullOrSkip(ctx, t, client, base)
|
||||||
t.Fatalf("pull failed %s", err)
|
|
||||||
}
|
|
||||||
for _, quant := range quantizations {
|
for _, quant := range quantizations {
|
||||||
newName := fmt.Sprintf("%s__%s", base, quant)
|
newName := fmt.Sprintf("%s__%s", base, quant)
|
||||||
t.Run(newName, func(t *testing.T) {
|
t.Run(newName, func(t *testing.T) {
|
||||||
|
|||||||
523
integration/tools_stress_test.go
Normal file
@@ -0,0 +1,523 @@
|
|||||||
|
//go:build integration
|
||||||
|
|
||||||
|
package integration
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/api"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestAPIToolCallingStress tests tool calling with complex, agent-style prompts
|
||||||
|
// that include large system messages, multiple tools, and multi-turn conversations.
|
||||||
|
// This catches cache corruption and parser bugs that simple tool tests miss.
|
||||||
|
func TestAPIToolCallingStress(t *testing.T) {
|
||||||
|
initialTimeout := 120 * time.Second
|
||||||
|
streamTimeout := 120 * time.Second
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Minute)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
client, _, cleanup := InitServerConnection(ctx, t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
minVRAM := map[string]uint64{
|
||||||
|
"qwen3-vl": 16,
|
||||||
|
"gpt-oss:20b": 16,
|
||||||
|
"gpt-oss:120b": 70,
|
||||||
|
"qwen3": 6,
|
||||||
|
"llama3.1": 8,
|
||||||
|
"llama3.2": 4,
|
||||||
|
"mistral": 6,
|
||||||
|
"qwen2.5": 6,
|
||||||
|
"qwen2": 6,
|
||||||
|
"ministral-3": 20,
|
||||||
|
"mistral-nemo": 9,
|
||||||
|
"mistral-small": 16,
|
||||||
|
"mixtral:8x22b": 80,
|
||||||
|
"qwq": 20,
|
||||||
|
"granite3.3": 7,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Models that don't reliably produce tool calls with complex/multi-tool prompts.
|
||||||
|
// The stress test uses a large system prompt with many tools, simulating coding agents.
|
||||||
|
// Some models are too small, too slow, or not designed for this use case.
|
||||||
|
skipModels := map[string]string{
|
||||||
|
"lfm2.5-thinking": "returns text instead of tool calls with complex system prompts",
|
||||||
|
"qwen3-vl": "vision model, extremely slow with complex tool prompts",
|
||||||
|
"llama3.2": "3B model too small for reliable multi-tool agent prompts",
|
||||||
|
"mistral": "7B v0.3 returns text instead of tool calls with complex prompts",
|
||||||
|
"mixtral:8x22b": "returns text instead of tool calls with complex prompts",
|
||||||
|
"qwen2": "returns text instead of tool calls with complex prompts",
|
||||||
|
"granite3.3": "returns text instead of tool calls with complex prompts",
|
||||||
|
}
|
||||||
|
|
||||||
|
models := testModels(libraryToolsModels)
|
||||||
|
|
||||||
|
for _, model := range models {
|
||||||
|
t.Run(model, func(t *testing.T) {
|
||||||
|
// Skip known-bad models unless explicitly requested via env var
|
||||||
|
if reason, ok := skipModels[model]; ok && testModel == "" {
|
||||||
|
t.Skipf("skipping: %s", reason)
|
||||||
|
}
|
||||||
|
if testModel != "" {
|
||||||
|
requireCapability(ctx, t, client, model, "tools")
|
||||||
|
}
|
||||||
|
if v, ok := minVRAM[model]; ok {
|
||||||
|
skipUnderMinVRAM(t, v)
|
||||||
|
}
|
||||||
|
|
||||||
|
pullOrSkip(ctx, t, client, model)
|
||||||
|
|
||||||
|
tools := stressTestTools()
|
||||||
|
|
||||||
|
// Large system prompt that mimics real coding agents (opencode, Claude Code, etc.)
|
||||||
|
// This is intentionally very long (~5000+ tokens) to match the prompt sizes that
|
||||||
|
// real coding agents send. The combination of a large system prompt, many tools,
|
||||||
|
// and thinking mode is what triggers failures in some models.
|
||||||
|
systemPrompt := stressTestSystemPrompt()
|
||||||
|
|
||||||
|
// Test 1: First request (fresh prompt processing)
|
||||||
|
// Use a direct prompt that tells the model exactly what tool to use,
|
||||||
|
// reducing the chance it asks for clarification instead.
|
||||||
|
t.Run("first_request", func(t *testing.T) {
|
||||||
|
testToolCall(t, ctx, client, model, systemPrompt, tools,
|
||||||
|
"Run git diff main to review the code changes on the current branch.",
|
||||||
|
initialTimeout, streamTimeout)
|
||||||
|
})
|
||||||
|
|
||||||
|
// Test 2: Repeat with same prompt (tests cache reuse)
|
||||||
|
t.Run("cached_request", func(t *testing.T) {
|
||||||
|
testToolCall(t, ctx, client, model, systemPrompt, tools,
|
||||||
|
"Run git diff main to review the code changes on the current branch.",
|
||||||
|
initialTimeout, streamTimeout)
|
||||||
|
})
|
||||||
|
|
||||||
|
// Test 3: Different user message (partial cache hit)
|
||||||
|
t.Run("different_user_message", func(t *testing.T) {
|
||||||
|
testToolCall(t, ctx, client, model, systemPrompt, tools,
|
||||||
|
"Read the file at ./go.mod and tell me what dependencies we have.",
|
||||||
|
initialTimeout, streamTimeout)
|
||||||
|
})
|
||||||
|
|
||||||
|
// Test 4: Multi-turn with tool response
|
||||||
|
t.Run("multi_turn", func(t *testing.T) {
|
||||||
|
testToolCallMultiTurn(t, ctx, client, model, systemPrompt, tools,
|
||||||
|
initialTimeout, streamTimeout)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func newTool(name, description string, required []string, props map[string]api.ToolProperty) api.Tool {
|
||||||
|
return api.Tool{
|
||||||
|
Type: "function",
|
||||||
|
Function: api.ToolFunction{
|
||||||
|
Name: name,
|
||||||
|
Description: description,
|
||||||
|
Parameters: api.ToolFunctionParameters{
|
||||||
|
Type: "object",
|
||||||
|
Required: required,
|
||||||
|
Properties: testPropsMap(props),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// stressTestTools returns a set of tools matching the scale and verbosity of
// real coding agent tool definitions (opencode, Claude Code, etc.). The tool
// descriptions are intentionally verbose to match real-world prompt sizes.
// Eleven tools are defined; the names must stay in sync with the
// validStressTools set used by testToolCall to validate tool-call output.
func stressTestTools() []api.Tool {
	return []api.Tool{
		// Shell execution with optional timeout.
		newTool("bash", "Executes a given bash command in a persistent shell session with optional timeout, ensuring proper handling and security measures. All commands run in the working directory by default. Before executing the command, verify that the parent directory exists. Always quote file paths that contain spaces with double quotes. After ensuring proper quoting, execute the command and capture the output. Avoid using bash with find, grep, cat, head, tail, sed, awk, or echo commands unless explicitly instructed. Instead, always prefer using the dedicated tools for these commands. When issuing multiple commands, if they are independent and can run in parallel, make multiple tool calls in a single message.",
			[]string{"command"},
			map[string]api.ToolProperty{
				"command":     {Type: api.PropertyType{"string"}, Description: "The bash command to execute"},
				"description": {Type: api.PropertyType{"string"}, Description: "Short description of what this command does in 5-10 words"},
				"timeout":     {Type: api.PropertyType{"number"}, Description: "Optional timeout in milliseconds. If not specified, commands will time out after 120000ms (2 minutes)"},
			}),
		// File reading with offset/limit paging.
		newTool("read", "Read a file or directory from the local filesystem. If the path does not exist, an error is returned. By default, this tool returns up to 2000 lines from the start of the file. The offset parameter is the line number to start from (1-indexed). To read later sections, call this tool again with a larger offset. Use the grep tool to find specific content in large files or files with long lines. If you are unsure of the correct file path, use the glob tool to look up filenames by glob pattern. Contents are returned with each line prefixed by its line number. Any line longer than 2000 characters is truncated. Call this tool in parallel when you know there are multiple files you want to read. Avoid tiny repeated slices (30 line chunks). If you need more context, read a larger window. This tool can read image files and PDFs and return them as file attachments.",
			[]string{"path"},
			map[string]api.ToolProperty{
				"path":   {Type: api.PropertyType{"string"}, Description: "The absolute path to the file to read"},
				"offset": {Type: api.PropertyType{"number"}, Description: "Line number to start reading from (1-indexed)"},
				"limit":  {Type: api.PropertyType{"number"}, Description: "Maximum number of lines to read"},
			}),
		// Filename pattern search.
		newTool("glob", "Fast file pattern matching tool that works with any codebase size. Supports glob patterns like '**/*.js' or 'src/**/*.ts'. Returns matching file paths sorted by modification time. Use this tool when you need to find files by name patterns. When you are doing an open-ended search that may require multiple rounds of globbing and grepping, use the task tool instead. You have the capability to call multiple tools in a single response. It is always better to speculatively perform multiple searches as a batch that are potentially useful.",
			[]string{"pattern"},
			map[string]api.ToolProperty{
				"pattern": {Type: api.PropertyType{"string"}, Description: "The glob pattern to match files against"},
				"path":    {Type: api.PropertyType{"string"}, Description: "The directory to search in"},
			}),
		// Regex content search.
		newTool("grep", "Fast content search tool that works with any codebase size. Searches file contents using regular expressions. Supports full regex syntax (eg. 'log.*Error', 'function\\s+\\w+'). Filter files by pattern with the include parameter (eg. '*.js', '*.{ts,tsx}'). Returns file paths and line numbers with at least one match sorted by modification time. Use this tool when you need to find files containing specific patterns. If you need to identify or count the number of matches within files, use the bash tool with rg (ripgrep) directly. When you are doing an open-ended search that may require multiple rounds of globbing and grepping, use the task tool instead.",
			[]string{"pattern"},
			map[string]api.ToolProperty{
				"pattern": {Type: api.PropertyType{"string"}, Description: "The regex pattern to search for in file contents"},
				"path":    {Type: api.PropertyType{"string"}, Description: "The directory to search in"},
				"include": {Type: api.PropertyType{"string"}, Description: "File pattern to include (eg. '*.js', '*.{ts,tsx}')"},
			}),
		// Exact string replacement editing.
		newTool("edit", "Performs exact string replacements in files. You must use your read tool at least once in the conversation before editing. This tool will error if you attempt an edit without reading the file. When editing text from read tool output, ensure you preserve the exact indentation (tabs/spaces) as it appears after the line number prefix. Always prefer editing existing files in the codebase. Never write new files unless explicitly required. Only use emojis if the user explicitly requests it. The edit will fail if oldString is not found in the file. The edit will fail if oldString is found multiple times in the file. Use replaceAll for replacing and renaming strings across the file.",
			[]string{"path", "old_string", "new_string"},
			map[string]api.ToolProperty{
				"path":       {Type: api.PropertyType{"string"}, Description: "The absolute path to the file to modify"},
				"old_string": {Type: api.PropertyType{"string"}, Description: "The text to replace (must be unique in the file)"},
				"new_string": {Type: api.PropertyType{"string"}, Description: "The replacement text"},
			}),
		// Whole-file writes.
		newTool("write", "Writes a file to the local filesystem. This tool will overwrite the existing file if there is one at the provided path. If this is an existing file, you must use the read tool first to read the file contents. This tool will fail if you did not read the file first. Always prefer editing existing files in the codebase. Never write new files unless explicitly required. Never proactively create documentation files or README files. Only create documentation files if explicitly requested by the user.",
			[]string{"path", "content"},
			map[string]api.ToolProperty{
				"path":    {Type: api.PropertyType{"string"}, Description: "The absolute path to the file to write"},
				"content": {Type: api.PropertyType{"string"}, Description: "The content to write to the file"},
			}),
		// Interactive clarification questions.
		newTool("question", "Use this tool when you need to ask the user questions during execution. This allows you to gather user preferences or requirements, clarify ambiguous instructions, get decisions on implementation choices as you work, and offer choices to the user about what direction to take. When custom is enabled (default), a 'Type your own answer' option is added automatically. Answers are returned as arrays of labels. Set multiple to true to allow selecting more than one answer. If you recommend a specific option, make that the first option in the list and add '(Recommended)' at the end of the label.",
			[]string{"questions"},
			map[string]api.ToolProperty{
				"questions": {Type: api.PropertyType{"string"}, Description: "The question to ask the user"},
			}),
		// Subagent dispatch.
		newTool("task", "Launch a new agent to handle complex, multistep tasks autonomously. Available agent types: general (general-purpose agent for researching complex questions and executing multi-step tasks, use this to execute multiple units of work in parallel) and explore (fast agent specialized for exploring codebases, use this when you need to quickly find files by patterns, search code for keywords, or answer questions about the codebase). Launch multiple agents concurrently whenever possible to maximize performance. When the agent is done, it will return a single message back to you. Each agent invocation starts with a fresh context unless you provide task_id to resume the same subagent session.",
			[]string{"description", "prompt", "subagent_type"},
			map[string]api.ToolProperty{
				"description":   {Type: api.PropertyType{"string"}, Description: "A short (3-5 word) description of the task"},
				"prompt":        {Type: api.PropertyType{"string"}, Description: "The task for the agent to perform"},
				"subagent_type": {Type: api.PropertyType{"string"}, Description: "The type of specialized agent to use (general or explore)"},
			}),
		// URL fetching/conversion.
		newTool("webfetch", "Fetches content from a specified URL. Takes a URL and optional format as input. Fetches the URL content, converts to requested format (markdown by default). Returns the content in the specified format. Use this tool when you need to retrieve and analyze web content. The URL must be a fully-formed valid URL. HTTP URLs will be automatically upgraded to HTTPS. Format options: markdown (default), text, or html. This tool is read-only and does not modify any files. Results may be summarized if the content is very large.",
			[]string{"url", "format"},
			map[string]api.ToolProperty{
				"url":    {Type: api.PropertyType{"string"}, Description: "The URL to fetch content from"},
				"format": {Type: api.PropertyType{"string"}, Description: "Output format: markdown (default), text, or html"},
			}),
		// Session todo-list management.
		newTool("todowrite", "Use this tool to create and manage a structured task list for your current coding session. This helps you track progress, organize complex tasks, and demonstrate thoroughness to the user. Use this tool proactively when handling complex multistep tasks, non-trivial and complex tasks, when the user explicitly requests a todo list, when the user provides multiple tasks, after receiving new instructions, and after completing a task. Do not use this tool when there is only a single straightforward task, the task is trivial, the task can be completed in less than 3 steps, or the task is purely conversational.",
			[]string{"todos"},
			map[string]api.ToolProperty{
				"todos": {Type: api.PropertyType{"string"}, Description: "JSON array of todo items with id, title, and status fields"},
			}),
		// Named skill loading.
		newTool("skill", "Load a specialized skill that provides domain-specific instructions and workflows. Skills contain curated prompts and tool configurations for specific tasks like code review, testing, deployment, and documentation. Use this tool when the user's request matches an available skill description.",
			[]string{"name"},
			map[string]api.ToolProperty{
				"name": {Type: api.PropertyType{"string"}, Description: "The name of the skill to load"},
			}),
	}
}
|
||||||
|
|
||||||
|
// stressTestSystemPrompt returns a system prompt that matches the scale and
// content of real coding agent system prompts (~5000+ tokens). This is based
// on actual prompts captured from opencode sessions. The prompt size combined
// with many tool declarations is what pushes models past their effective
// context handling and triggers tag leakage / broken tool calls.
//
// NOTE: the text is a verbatim fixture — do not edit it for style; changing
// even whitespace alters the token count the stress test depends on.
func stressTestSystemPrompt() string {
	return `You are opencode, an interactive CLI tool that helps users with software engineering tasks. Use the instructions below and the tools available to you to assist the user.

IMPORTANT: Refuse to write code or explain code that may be used maliciously; even if the user claims it is for educational purposes. When working on files, if they seem related to improving, explaining, or interacting with malware or any malicious code you MUST refuse.
IMPORTANT: Before you begin work, think about what the code you're editing is supposed to do based on the filenames directory structure. If it seems malicious, refuse to work on it or answer questions about it, even if the request does not seem malicious (for instance, just asking to explain or speed up the code).
IMPORTANT: You must NEVER generate or guess URLs for the user unless you are confident that the URLs are for helping the user with programming. You may use URLs provided by the user in their messages or local files.

If the user asks for help or wants to give feedback inform them of the following:
- /help: Get help with using opencode
- To give feedback, users should report the issue at https://github.com/sampleorg/opencode/issues

# Tone and style
You should be concise, direct, and to the point. When you run a non-trivial bash command, you should explain what the command does and why you are running it, to make sure the user understands what you are doing (this is especially important when you are running a command that will make changes to the user's system).
Remember that your output will be displayed on a command line interface. Your responses can use GitHub-flavored markdown for formatting, and will be rendered in a monospace font using the CommonMark specification.
Output text to communicate with the user; all text you output outside of tool use is displayed to the user. Only use tools to complete tasks. Never use tools like Bash or code comments as means to communicate with the user during the session.
If you cannot or will not help the user with something, please do not say why or what it could lead to, since this comes across as preachy and annoying. Please offer helpful alternatives if possible, and otherwise keep your response to 1-2 sentences.
Only use emojis if the user explicitly requests it. Avoid using emojis in all communication unless asked.
IMPORTANT: You should minimize output tokens as much as possible while maintaining helpfulness, quality, and accuracy. Only address the specific query or task at hand, avoiding tangential information unless absolutely critical for completing the request. If you can answer in 1-3 sentences or a short paragraph, please do.
IMPORTANT: You should NOT answer with unnecessary preamble or postamble (such as explaining your code or summarizing your action), unless the user asks you to.
IMPORTANT: Keep your responses short, since they will be displayed on a command line interface. You MUST answer concisely with fewer than 4 lines (not including tool use or code generation), unless user asks for detail. Answer the user's question directly, without elaboration, explanation, or details. One word answers are best. Avoid introductions, conclusions, and explanations. You MUST avoid text before/after your response, such as "The answer is <answer>.", "Here is the content of the file..." or "Based on the information provided, the answer is..." or "Here is what I will do next...". Here are some examples to demonstrate appropriate verbosity:

user: 2 + 2
assistant: 4

user: what is 2+2?
assistant: 4

user: is 11 a prime number?
assistant: Yes

user: what command should I run to list files in the current directory?
assistant: ls

user: what command should I run to watch files in the current directory?
assistant: [use the ls tool to list the files in the current directory, then read docs/commands in the relevant file to find out how to watch files]
npm run dev

user: How many golf balls fit inside a jetta?
assistant: 150000

user: what files are in the directory src/?
assistant: [runs ls and sees foo.c, bar.c, baz.c]
user: which file contains the implementation of foo?
assistant: src/foo.c

user: write tests for new feature
assistant: [uses grep and glob search tools to find where similar tests are defined, uses concurrent read file tool use blocks in one tool call to read relevant files at the same time, uses edit file tool to write new tests]

# Proactiveness
You are allowed to be proactive, but only when the user asks you to do something. You should strive to strike a balance between:
1. Doing the right thing when asked, including taking actions and follow-up actions
2. Not surprising the user with actions you take without asking
For example, if the user asks you how to approach something, you should do your best to answer their question first, and not immediately jump into taking actions.
3. Do not add additional code explanation summary unless requested by the user. After working on a file, just stop, rather than providing an explanation of what you did.

# Following conventions
When making changes to files, first understand the file's code conventions. Mimic code style, use existing libraries and utilities, and follow existing patterns.
- NEVER assume that a given library is available, even if it is well known. Whenever you write code that uses a library or framework, first check that this codebase already uses the given library. For example, you might look at neighboring files, or check the package.json (or cargo.toml, and so on depending on the language).
- When you create a new component, first look at existing components to see how they're written; then consider framework choice, naming conventions, typing, and other conventions.
- When you edit a piece of code, first look at the code's surrounding context (especially its imports) to understand the code's choice of frameworks and libraries. Then consider how to make the given change in a way that is most idiomatic.
- Always follow security best practices. Never introduce code that exposes or logs secrets and keys. Never commit secrets or keys to the repository.

# Code style
- IMPORTANT: DO NOT ADD ANY COMMENTS unless asked

# Doing tasks
The user will primarily request you perform software engineering tasks. This includes solving bugs, adding new functionality, refactoring code, explaining code, and more. For these tasks the following steps are recommended:
- Use the available search tools to understand the codebase and the user's query. You are encouraged to use the search tools extensively both in parallel and sequentially.
- Implement the solution using all tools available to you
- Verify the solution if possible with tests. NEVER assume specific test framework or test script. Check the README or search codebase to determine the testing approach.
- VERY IMPORTANT: When you have completed a task, you MUST run the lint and typecheck commands (e.g. npm run lint, npm run typecheck, ruff, etc.) with Bash if they were provided to you to ensure your code is correct. If you are unable to find the correct command, ask the user for the command to run and if they supply it, proactively suggest writing it to AGENTS.md so that you will know to run it next time.
NEVER commit changes unless the user explicitly asks you to. It is VERY IMPORTANT to only commit when explicitly asked, otherwise the user will feel that you are being too proactive.

# Tool usage policy
- When doing file search, prefer to use the Task tool in order to reduce context usage.
- You have the capability to call multiple tools in a single response. When multiple independent pieces of information are requested, batch your tool calls together for optimal performance. When making multiple bash tool calls, you MUST send a single message with multiple tools calls to run the calls in parallel.

You MUST answer concisely with fewer than 4 lines of text (not including tool use or code generation), unless user asks for detail.

# Code References
When referencing specific functions or pieces of code include the pattern file_path:line_number to allow the user to easily navigate to the source code location.

# Git workflow
When working with git:
- Create descriptive commit messages that explain WHY not just WHAT
- Use conventional commit format: feat:, fix:, refactor:, docs:, test:, chore:
- Check git status before and after operations
- Never force push to main/master
- Review diffs before committing
- NEVER update the git config
- NEVER run destructive/irreversible git commands unless the user explicitly requests them
- NEVER skip hooks (--no-verify, --no-gpg-sign, etc) unless the user explicitly requests it
- Avoid git commit --amend unless explicitly requested by the user
- NEVER commit changes unless the user explicitly asks you to

# Safety
- Never delete files without confirmation
- Never run destructive commands (rm -rf, DROP TABLE, etc.) without confirmation
- Always validate inputs before using them in shell commands
- Be careful with environment variables and secrets
- Do not expose API keys, passwords, or tokens in code or logs

# Environment
Working directory: /Users/test/code/myproject
Platform: darwin
Shell: zsh
Is directory a git repo: yes
The project uses Go 1.22 with modules. Run tests with 'go test ./...' and build with 'go build ./...'.
The CI pipeline runs golangci-lint, go vet, and go test with race detector enabled.

# User instructions
Never use cd to change into the repo root or any other directory in Bash commands. The working directory is always the repo root — use relative paths directly.
Never use heredoc-style inline bash or python scripts in Bash tool calls. Instead, write the script to an ephemeral file under ./.tmp/ in the repo, then run it as a separate command.`
}
|
||||||
|
|
||||||
|
// validStressTools is the set of tool names used in the stress test. A tool
// call whose name is not in this set fails the test.
var validStressTools = map[string]bool{
	"bash":      true,
	"read":      true,
	"glob":      true,
	"grep":      true,
	"edit":      true,
	"write":     true,
	"question":  true,
	"task":      true,
	"webfetch":  true,
	"todowrite": true,
	"skill":     true,
}
|
||||||
|
|
||||||
|
func testToolCall(t *testing.T, ctx context.Context, client *api.Client, model, systemPrompt string, tools []api.Tool, userMessage string, initialTimeout, streamTimeout time.Duration) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
req := api.ChatRequest{
|
||||||
|
Model: model,
|
||||||
|
Messages: []api.Message{
|
||||||
|
{Role: "system", Content: systemPrompt},
|
||||||
|
{Role: "user", Content: userMessage},
|
||||||
|
},
|
||||||
|
Tools: tools,
|
||||||
|
Options: map[string]any{
|
||||||
|
"temperature": 0,
|
||||||
|
"num_ctx": contextLength(16384),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
stallTimer := time.NewTimer(initialTimeout)
|
||||||
|
var gotToolCall bool
|
||||||
|
var lastToolCall api.ToolCall
|
||||||
|
var allContent string
|
||||||
|
|
||||||
|
fn := func(response api.ChatResponse) error {
|
||||||
|
if len(response.Message.ToolCalls) > 0 {
|
||||||
|
gotToolCall = true
|
||||||
|
lastToolCall = response.Message.ToolCalls[len(response.Message.ToolCalls)-1]
|
||||||
|
}
|
||||||
|
allContent += response.Message.Content
|
||||||
|
if !stallTimer.Reset(streamTimeout) {
|
||||||
|
return fmt.Errorf("stall detected while streaming")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
stream := true
|
||||||
|
req.Stream = &stream
|
||||||
|
done := make(chan int)
|
||||||
|
var genErr error
|
||||||
|
go func() {
|
||||||
|
genErr = client.Chat(ctx, &req, fn)
|
||||||
|
done <- 0
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-stallTimer.C:
|
||||||
|
t.Fatalf("chat stalled after %s", initialTimeout)
|
||||||
|
case <-done:
|
||||||
|
if genErr != nil {
|
||||||
|
t.Fatalf("chat failed: %v", genErr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for leaked special tags in content — these should never
|
||||||
|
// appear in user-visible output regardless of model quality.
|
||||||
|
checkNoLeakedTags(t, allContent)
|
||||||
|
|
||||||
|
// The model must produce either a tool call or a text response.
|
||||||
|
// A text response (e.g. asking for clarification) is legitimate.
|
||||||
|
// Empty output with no tool call indicates a parser or model failure
|
||||||
|
// (e.g. malformed tool call that gets dropped).
|
||||||
|
if !gotToolCall && allContent == "" {
|
||||||
|
t.Fatal("model produced neither a tool call nor text content")
|
||||||
|
}
|
||||||
|
if gotToolCall {
|
||||||
|
if !validStressTools[lastToolCall.Function.Name] {
|
||||||
|
t.Errorf("unexpected tool: %q", lastToolCall.Function.Name)
|
||||||
|
}
|
||||||
|
argsJSON, _ := json.Marshal(lastToolCall.Function.Arguments)
|
||||||
|
t.Logf("tool call: %s(%s)", lastToolCall.Function.Name, string(argsJSON))
|
||||||
|
} else {
|
||||||
|
t.Logf("text response (no tool call): %q", truncate(allContent, 200))
|
||||||
|
}
|
||||||
|
case <-ctx.Done():
|
||||||
|
t.Fatal("context cancelled")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func testToolCallMultiTurn(t *testing.T, ctx context.Context, client *api.Client, model, systemPrompt string, tools []api.Tool, initialTimeout, streamTimeout time.Duration) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
req := api.ChatRequest{
|
||||||
|
Model: model,
|
||||||
|
Messages: []api.Message{
|
||||||
|
{Role: "system", Content: systemPrompt},
|
||||||
|
{Role: "user", Content: "What files are in the current directory?"},
|
||||||
|
{Role: "assistant", Content: "", ToolCalls: []api.ToolCall{{
|
||||||
|
Function: api.ToolCallFunction{
|
||||||
|
Name: "bash",
|
||||||
|
Arguments: api.ToolCallFunctionArguments{},
|
||||||
|
},
|
||||||
|
}}},
|
||||||
|
{Role: "tool", Content: "go.mod\ngo.sum\nmain.go\nREADME.md\n"},
|
||||||
|
// The model should now respond with content or another tool call
|
||||||
|
},
|
||||||
|
Tools: tools,
|
||||||
|
Options: map[string]any{
|
||||||
|
"temperature": 0,
|
||||||
|
"num_ctx": contextLength(16384),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// For the tool response arguments, set the command
|
||||||
|
req.Messages[2].ToolCalls[0].Function.Arguments.Set("command", "ls")
|
||||||
|
|
||||||
|
stallTimer := time.NewTimer(initialTimeout)
|
||||||
|
var gotResponse bool
|
||||||
|
var allContent string
|
||||||
|
var gotToolCall bool
|
||||||
|
|
||||||
|
fn := func(response api.ChatResponse) error {
|
||||||
|
if response.Message.Content != "" {
|
||||||
|
gotResponse = true
|
||||||
|
allContent += response.Message.Content
|
||||||
|
}
|
||||||
|
if len(response.Message.ToolCalls) > 0 {
|
||||||
|
gotToolCall = true
|
||||||
|
gotResponse = true
|
||||||
|
}
|
||||||
|
if !stallTimer.Reset(streamTimeout) {
|
||||||
|
return fmt.Errorf("stall detected")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
stream := true
|
||||||
|
req.Stream = &stream
|
||||||
|
done := make(chan int)
|
||||||
|
var genErr error
|
||||||
|
go func() {
|
||||||
|
genErr = client.Chat(ctx, &req, fn)
|
||||||
|
done <- 0
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-stallTimer.C:
|
||||||
|
t.Fatalf("chat stalled after %s", initialTimeout)
|
||||||
|
case <-done:
|
||||||
|
if genErr != nil {
|
||||||
|
t.Fatalf("chat failed: %v", genErr)
|
||||||
|
}
|
||||||
|
|
||||||
|
checkNoLeakedTags(t, allContent)
|
||||||
|
|
||||||
|
if !gotResponse {
|
||||||
|
t.Fatal("expected response (content or tool call), got nothing")
|
||||||
|
}
|
||||||
|
if gotToolCall {
|
||||||
|
t.Log("multi-turn: got follow-up tool call")
|
||||||
|
} else {
|
||||||
|
t.Logf("multi-turn: got content response: %q", truncate(allContent, 200))
|
||||||
|
}
|
||||||
|
case <-ctx.Done():
|
||||||
|
t.Fatal("context cancelled")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// checkNoLeakedTags verifies that model-internal special tags do not appear in
|
||||||
|
// user-visible content. These tags should be consumed by the parser and never
|
||||||
|
// passed through. If they appear, either the parser has a bug or the model is
|
||||||
|
// generating malformed output that the parser fails to handle.
|
||||||
|
func checkNoLeakedTags(t *testing.T, content string) {
|
||||||
|
t.Helper()
|
||||||
|
leakedTags := []string{
|
||||||
|
"<|channel>", "<channel|>",
|
||||||
|
"<|tool_call>", "<tool_call|>",
|
||||||
|
"<|tool>", "<tool|>",
|
||||||
|
"<|turn>", "<turn|>",
|
||||||
|
}
|
||||||
|
for _, tag := range leakedTags {
|
||||||
|
if strings.Contains(content, tag) {
|
||||||
|
t.Errorf("leaked special tag %q in content: %q", tag, truncate(content, 300))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// contextLength returns the context window size to use for test requests.
// A valid integer in the OLLAMA_CONTEXT_LENGTH environment variable overrides
// defaultVal; an unset or unparsable value falls back to defaultVal.
func contextLength(defaultVal int) int {
	s := os.Getenv("OLLAMA_CONTEXT_LENGTH")
	if s == "" {
		return defaultVal
	}
	n, err := strconv.Atoi(s)
	if err != nil {
		return defaultVal
	}
	return n
}
|
||||||
|
|
||||||
|
// truncate shortens s to at most n bytes for logging, appending "..." when
// anything was cut. The cut point is backed off to the nearest rune boundary
// so the result is always valid UTF-8 (a plain s[:n] could split a multi-byte
// rune and emit a garbled log line).
func truncate(s string, n int) string {
	if len(s) <= n {
		return s
	}
	// UTF-8 continuation bytes have the bit pattern 10xxxxxx; step back
	// until s[n] starts a rune so s[:n] ends on a rune boundary.
	for n > 0 && s[n]&0xC0 == 0x80 {
		n--
	}
	return s[:n] + "..."
}
|
||||||
@@ -47,15 +47,18 @@ func TestAPIToolCalling(t *testing.T) {
|
|||||||
"granite3.3": 7,
|
"granite3.3": 7,
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, model := range libraryToolsModels {
|
models := testModels(libraryToolsModels)
|
||||||
|
|
||||||
|
for _, model := range models {
|
||||||
t.Run(model, func(t *testing.T) {
|
t.Run(model, func(t *testing.T) {
|
||||||
|
if testModel != "" {
|
||||||
|
requireCapability(ctx, t, client, model, "tools")
|
||||||
|
}
|
||||||
if v, ok := minVRAM[model]; ok {
|
if v, ok := minVRAM[model]; ok {
|
||||||
skipUnderMinVRAM(t, v)
|
skipUnderMinVRAM(t, v)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := PullIfMissing(ctx, client, model); err != nil {
|
pullOrSkip(ctx, t, client, model)
|
||||||
t.Fatalf("pull failed %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
tools := []api.Tool{
|
tools := []api.Tool{
|
||||||
{
|
{
|
||||||
|
|||||||