mirror of
https://github.com/ollama/ollama.git
synced 2026-04-19 03:54:21 +02:00
Compare commits
203 Commits
brucemacd/
...
pdevine/qw
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
578c32e42e | ||
|
|
a10d2625ca | ||
|
|
b960d769ad | ||
|
|
455a6099d1 | ||
|
|
7e6e8377eb | ||
|
|
126d8db7f3 | ||
|
|
3f3a24b418 | ||
|
|
96e36c0d90 | ||
|
|
6f8ddbb26b | ||
|
|
b5e7888414 | ||
|
|
eab4d22269 | ||
|
|
5759c2d2d2 | ||
|
|
42b1c2642b | ||
|
|
727d69ddf3 | ||
|
|
f622b0c5fc | ||
|
|
5d0000634c | ||
|
|
676d9845ba | ||
|
|
e37a9b4c01 | ||
|
|
d727aacd04 | ||
|
|
fa69b833cd | ||
|
|
bbbad97686 | ||
|
|
bcf6d55b54 | ||
|
|
810d4f9c22 | ||
|
|
856c047a6c | ||
|
|
79c1e93c00 | ||
|
|
f8b657c967 | ||
|
|
10fefe0d57 | ||
|
|
2f9a68f9e9 | ||
|
|
3980c0217d | ||
|
|
870599f5da | ||
|
|
abf8e8e9c8 | ||
|
|
f3f31a8192 | ||
|
|
9e7ba835da | ||
|
|
347f17b8d1 | ||
|
|
081b9eb423 | ||
|
|
bb867c6fdb | ||
|
|
81f4506a61 | ||
|
|
76925f1284 | ||
|
|
f676231de9 | ||
|
|
af5f7c0a9e | ||
|
|
a6b27d776b | ||
|
|
539741199e | ||
|
|
8f45236d09 | ||
|
|
97013a190c | ||
|
|
c222735c02 | ||
|
|
87d21c7fc0 | ||
|
|
54e05172a0 | ||
|
|
464186e995 | ||
|
|
8c4d5d6c2f | ||
|
|
bc72b14016 | ||
|
|
61086083eb | ||
|
|
62d1f01ab4 | ||
|
|
10e51c5177 | ||
|
|
3e06bde643 | ||
|
|
6be2de8214 | ||
|
|
ebb1b9ec14 | ||
|
|
d126467d5d | ||
|
|
afb4c62fbf | ||
|
|
e790dc435b | ||
|
|
288077c3a3 | ||
|
|
4425c54eda | ||
|
|
778899a5d2 | ||
|
|
4eab60c1e2 | ||
|
|
1af850e6e3 | ||
|
|
9b0c7cc7b9 | ||
|
|
6928630601 | ||
|
|
9896e3627f | ||
|
|
15732f0ea7 | ||
|
|
562c76d7cc | ||
|
|
122c68c151 | ||
|
|
82848a7806 | ||
|
|
39982a954e | ||
|
|
e9f6ea232f | ||
|
|
110eff01a9 | ||
|
|
799e51d419 | ||
|
|
e8fcb29586 | ||
|
|
97d2f05a6d | ||
|
|
8207e55ec7 | ||
|
|
ad16bffc7d | ||
|
|
c1e3ef4bcc | ||
|
|
a3093cd5e5 | ||
|
|
23d4cad1a2 | ||
|
|
86513cb697 | ||
|
|
3490e9590b | ||
|
|
8da09b1e7e | ||
|
|
a60b9adcce | ||
|
|
a16f96658b | ||
|
|
18ab09b431 | ||
|
|
638faeac54 | ||
|
|
dd5eb6337d | ||
|
|
79917cf80b | ||
|
|
cc90a035a0 | ||
|
|
d98dda4676 | ||
|
|
d69ddc1edc | ||
|
|
9bf41969f0 | ||
|
|
0f23b7bff5 | ||
|
|
4e57d2094e | ||
|
|
7f9efd53df | ||
|
|
da70c3222e | ||
|
|
9d902d63ce | ||
|
|
f4f0a4a471 | ||
|
|
3323c1d319 | ||
|
|
f20dc6b698 | ||
|
|
4b2ac1f369 | ||
|
|
8daf47fb3a | ||
|
|
6c980579cd | ||
|
|
5c73c4e2ee | ||
|
|
5daf59cc66 | ||
|
|
0ade9205cc | ||
|
|
06edabdde1 | ||
|
|
8b4e5a82a8 | ||
|
|
3445223311 | ||
|
|
fa6c0127e6 | ||
|
|
97323d1c68 | ||
|
|
458dd1b9d9 | ||
|
|
9d02d1d767 | ||
|
|
1a636fb47a | ||
|
|
0759fface9 | ||
|
|
325b72bc31 | ||
|
|
f01a9a7859 | ||
|
|
9aefd2dfee | ||
|
|
d07e4a1dd3 | ||
|
|
8a257ec00a | ||
|
|
2f4de1acf7 | ||
|
|
ec95c45f70 | ||
|
|
3a88f7eb20 | ||
|
|
0d5da826d4 | ||
|
|
9b795698b8 | ||
|
|
041fb77639 | ||
|
|
8224cce583 | ||
|
|
d18dcd7775 | ||
|
|
5f5ef20131 | ||
|
|
f0a07a353b | ||
|
|
948de6bbd2 | ||
|
|
598b74d42c | ||
|
|
935a48ed1a | ||
|
|
de39e24bf7 | ||
|
|
519b11eba1 | ||
|
|
379fd64fa8 | ||
|
|
59c019a6fb | ||
|
|
fad3bcccb2 | ||
|
|
bd6697ad95 | ||
|
|
f8dc7c9f54 | ||
|
|
4a3741129d | ||
|
|
77ba9404ac | ||
|
|
0aaf6119ec | ||
|
|
f08427c138 | ||
|
|
2dbb000908 | ||
|
|
c980e19995 | ||
|
|
6162374ca9 | ||
|
|
44bdd9a2ef | ||
|
|
db493d6e5e | ||
|
|
75695f16a5 | ||
|
|
a0407d07fa | ||
|
|
9ec733e527 | ||
|
|
5ef04dab52 | ||
|
|
aea316f1e9 | ||
|
|
235ba3df5c | ||
|
|
099a0f18ef | ||
|
|
fff696ee31 | ||
|
|
2e3ce6eab3 | ||
|
|
9e2003f88a | ||
|
|
42e1d49fbe | ||
|
|
814630ca60 | ||
|
|
87cf187774 | ||
|
|
6ddd8862cd | ||
|
|
f1373193dc | ||
|
|
8a4b77f9da | ||
|
|
5f53fe7884 | ||
|
|
7ab4ca0e7f | ||
|
|
e36f389e82 | ||
|
|
c61023f554 | ||
|
|
d25535c3f3 | ||
|
|
c323161f24 | ||
|
|
255579aaa7 | ||
|
|
f7102ba826 | ||
|
|
cefabd79a8 | ||
|
|
df70249520 | ||
|
|
77eb2ca619 | ||
|
|
ee25219edd | ||
|
|
b1fccabb34 | ||
|
|
a6355329bf | ||
|
|
0398b24b42 | ||
|
|
75b1dddf91 | ||
|
|
e1e80ffc3e | ||
|
|
71896485fd | ||
|
|
ef00199fb4 | ||
|
|
8f4a008139 | ||
|
|
d8cc798c2b | ||
|
|
6582f6da5c | ||
|
|
0334ffa625 | ||
|
|
d11fbd2c60 | ||
|
|
6a7c3f188e | ||
|
|
427e2c962a | ||
|
|
27db7f806f | ||
|
|
3590fbfa76 | ||
|
|
cd0094f772 | ||
|
|
06bc8e6712 | ||
|
|
fc5f9bb448 | ||
|
|
a0740f7ef7 | ||
|
|
a0923cbdd0 | ||
|
|
f92e362b2e | ||
|
|
aa23d8ecd2 |
70
.github/workflows/release.yaml
vendored
70
.github/workflows/release.yaml
vendored
@@ -117,6 +117,25 @@ jobs:
|
|||||||
install: https://sdk.lunarg.com/sdk/download/1.4.321.1/windows/vulkansdk-windows-X64-1.4.321.1.exe
|
install: https://sdk.lunarg.com/sdk/download/1.4.321.1/windows/vulkansdk-windows-X64-1.4.321.1.exe
|
||||||
flags: ''
|
flags: ''
|
||||||
runner_dir: 'vulkan'
|
runner_dir: 'vulkan'
|
||||||
|
- os: windows
|
||||||
|
arch: amd64
|
||||||
|
preset: 'MLX CUDA 13'
|
||||||
|
install: https://developer.download.nvidia.com/compute/cuda/13.0.0/local_installers/cuda_13.0.0_windows.exe
|
||||||
|
cudnn-install: https://developer.download.nvidia.com/compute/cudnn/redist/cudnn/windows-x86_64/cudnn-windows-x86_64-9.18.1.3_cuda13-archive.zip
|
||||||
|
cuda-components:
|
||||||
|
- '"cudart"'
|
||||||
|
- '"nvcc"'
|
||||||
|
- '"cublas"'
|
||||||
|
- '"cublas_dev"'
|
||||||
|
- '"cufft"'
|
||||||
|
- '"cufft_dev"'
|
||||||
|
- '"nvrtc"'
|
||||||
|
- '"nvrtc_dev"'
|
||||||
|
- '"crt"'
|
||||||
|
- '"nvvm"'
|
||||||
|
- '"nvptxcompiler"'
|
||||||
|
cuda-version: '13.0'
|
||||||
|
flags: ''
|
||||||
runs-on: ${{ matrix.arch == 'arm64' && format('{0}-{1}', matrix.os, matrix.arch) || matrix.os }}
|
runs-on: ${{ matrix.arch == 'arm64' && format('{0}-{1}', matrix.os, matrix.arch) || matrix.os }}
|
||||||
environment: release
|
environment: release
|
||||||
env:
|
env:
|
||||||
@@ -125,8 +144,10 @@ jobs:
|
|||||||
- name: Install system dependencies
|
- name: Install system dependencies
|
||||||
run: |
|
run: |
|
||||||
choco install -y --no-progress ccache ninja
|
choco install -y --no-progress ccache ninja
|
||||||
|
if (Get-Command ccache -ErrorAction SilentlyContinue) {
|
||||||
ccache -o cache_dir=${{ github.workspace }}\.ccache
|
ccache -o cache_dir=${{ github.workspace }}\.ccache
|
||||||
- if: startsWith(matrix.preset, 'CUDA ') || startsWith(matrix.preset, 'ROCm ') || startsWith(matrix.preset, 'Vulkan')
|
}
|
||||||
|
- if: startsWith(matrix.preset, 'CUDA ') || startsWith(matrix.preset, 'ROCm ') || startsWith(matrix.preset, 'Vulkan') || startsWith(matrix.preset, 'MLX ')
|
||||||
id: cache-install
|
id: cache-install
|
||||||
uses: actions/cache/restore@v4
|
uses: actions/cache/restore@v4
|
||||||
with:
|
with:
|
||||||
@@ -134,8 +155,9 @@ jobs:
|
|||||||
C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA
|
C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA
|
||||||
C:\Program Files\AMD\ROCm
|
C:\Program Files\AMD\ROCm
|
||||||
C:\VulkanSDK
|
C:\VulkanSDK
|
||||||
key: ${{ matrix.install }}
|
C:\Program Files\NVIDIA\CUDNN
|
||||||
- if: startsWith(matrix.preset, 'CUDA ')
|
key: ${{ matrix.install }}-${{ matrix.cudnn-install }}
|
||||||
|
- if: startsWith(matrix.preset, 'CUDA ') || startsWith(matrix.preset, 'MLX ')
|
||||||
name: Install CUDA ${{ matrix.cuda-version }}
|
name: Install CUDA ${{ matrix.cuda-version }}
|
||||||
run: |
|
run: |
|
||||||
$ErrorActionPreference = "Stop"
|
$ErrorActionPreference = "Stop"
|
||||||
@@ -179,6 +201,23 @@ jobs:
|
|||||||
run: |
|
run: |
|
||||||
echo "CC=clang.exe" | Out-File -FilePath $env:GITHUB_ENV -Append
|
echo "CC=clang.exe" | Out-File -FilePath $env:GITHUB_ENV -Append
|
||||||
echo "CXX=clang++.exe" | Out-File -FilePath $env:GITHUB_ENV -Append
|
echo "CXX=clang++.exe" | Out-File -FilePath $env:GITHUB_ENV -Append
|
||||||
|
- if: startsWith(matrix.preset, 'MLX ')
|
||||||
|
name: Install cuDNN for MLX
|
||||||
|
run: |
|
||||||
|
$ErrorActionPreference = "Stop"
|
||||||
|
$cudnnRoot = "C:\Program Files\NVIDIA\CUDNN"
|
||||||
|
if ("${{ steps.cache-install.outputs.cache-hit }}" -ne 'true') {
|
||||||
|
Invoke-WebRequest -Uri "${{ matrix.cudnn-install }}" -OutFile "cudnn.zip"
|
||||||
|
Expand-Archive -Path cudnn.zip -DestinationPath cudnn-extracted
|
||||||
|
$cudnnDir = (Get-ChildItem -Path cudnn-extracted -Directory)[0].FullName
|
||||||
|
New-Item -ItemType Directory -Force -Path $cudnnRoot
|
||||||
|
Copy-Item -Path "$cudnnDir\*" -Destination "$cudnnRoot\" -Recurse
|
||||||
|
}
|
||||||
|
|
||||||
|
echo "CUDNN_ROOT_DIR=$cudnnRoot" | Out-File -FilePath $env:GITHUB_ENV -Append
|
||||||
|
echo "CUDNN_INCLUDE_PATH=$cudnnRoot\include" | Out-File -FilePath $env:GITHUB_ENV -Append
|
||||||
|
echo "CUDNN_LIBRARY_PATH=$cudnnRoot\lib\x64" | Out-File -FilePath $env:GITHUB_ENV -Append
|
||||||
|
echo "$cudnnRoot\bin\x64" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
|
||||||
- if: ${{ !cancelled() && steps.cache-install.outputs.cache-hit != 'true' }}
|
- if: ${{ !cancelled() && steps.cache-install.outputs.cache-hit != 'true' }}
|
||||||
uses: actions/cache/save@v4
|
uses: actions/cache/save@v4
|
||||||
with:
|
with:
|
||||||
@@ -186,7 +225,8 @@ jobs:
|
|||||||
C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA
|
C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA
|
||||||
C:\Program Files\AMD\ROCm
|
C:\Program Files\AMD\ROCm
|
||||||
C:\VulkanSDK
|
C:\VulkanSDK
|
||||||
key: ${{ matrix.install }}
|
C:\Program Files\NVIDIA\CUDNN
|
||||||
|
key: ${{ matrix.install }}-${{ matrix.cudnn-install }}
|
||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v4
|
||||||
- uses: actions/cache@v4
|
- uses: actions/cache@v4
|
||||||
with:
|
with:
|
||||||
@@ -198,7 +238,7 @@ jobs:
|
|||||||
Enter-VsDevShell -VsInstallPath 'C:\Program Files\Microsoft Visual Studio\2022\Enterprise' -SkipAutomaticLocation -DevCmdArguments '-arch=x64 -no_logo'
|
Enter-VsDevShell -VsInstallPath 'C:\Program Files\Microsoft Visual Studio\2022\Enterprise' -SkipAutomaticLocation -DevCmdArguments '-arch=x64 -no_logo'
|
||||||
cmake --preset "${{ matrix.preset }}" ${{ matrix.flags }} --install-prefix "$((pwd).Path)\dist\${{ matrix.os }}-${{ matrix.arch }}"
|
cmake --preset "${{ matrix.preset }}" ${{ matrix.flags }} --install-prefix "$((pwd).Path)\dist\${{ matrix.os }}-${{ matrix.arch }}"
|
||||||
cmake --build --parallel ([Environment]::ProcessorCount) --preset "${{ matrix.preset }}"
|
cmake --build --parallel ([Environment]::ProcessorCount) --preset "${{ matrix.preset }}"
|
||||||
cmake --install build --component "${{ startsWith(matrix.preset, 'CUDA ') && 'CUDA' || startsWith(matrix.preset, 'ROCm ') && 'HIP' || startsWith(matrix.preset, 'Vulkan') && 'Vulkan' || 'CPU' }}" --strip
|
cmake --install build --component "${{ startsWith(matrix.preset, 'MLX ') && 'MLX' || startsWith(matrix.preset, 'CUDA ') && 'CUDA' || startsWith(matrix.preset, 'ROCm ') && 'HIP' || startsWith(matrix.preset, 'Vulkan') && 'Vulkan' || 'CPU' }}" --strip
|
||||||
Remove-Item -Path dist\lib\ollama\rocm\rocblas\library\*gfx906* -ErrorAction SilentlyContinue
|
Remove-Item -Path dist\lib\ollama\rocm\rocblas\library\*gfx906* -ErrorAction SilentlyContinue
|
||||||
env:
|
env:
|
||||||
CMAKE_GENERATOR: Ninja
|
CMAKE_GENERATOR: Ninja
|
||||||
@@ -337,6 +377,7 @@ jobs:
|
|||||||
name: bundles-windows
|
name: bundles-windows
|
||||||
path: |
|
path: |
|
||||||
dist/*.zip
|
dist/*.zip
|
||||||
|
dist/*.ps1
|
||||||
dist/OllamaSetup.exe
|
dist/OllamaSetup.exe
|
||||||
|
|
||||||
linux-build:
|
linux-build:
|
||||||
@@ -514,6 +555,9 @@ jobs:
|
|||||||
- name: Log dist contents
|
- name: Log dist contents
|
||||||
run: |
|
run: |
|
||||||
ls -l dist/
|
ls -l dist/
|
||||||
|
- name: Copy install scripts to dist
|
||||||
|
run: |
|
||||||
|
cp scripts/install.sh dist/install.sh
|
||||||
- name: Generate checksum file
|
- name: Generate checksum file
|
||||||
run: find . -type f -not -name 'sha256sum.txt' | xargs sha256sum | tee sha256sum.txt
|
run: find . -type f -not -name 'sha256sum.txt' | xargs sha256sum | tee sha256sum.txt
|
||||||
working-directory: dist
|
working-directory: dist
|
||||||
@@ -536,14 +580,22 @@ jobs:
|
|||||||
- name: Upload release artifacts
|
- name: Upload release artifacts
|
||||||
run: |
|
run: |
|
||||||
pids=()
|
pids=()
|
||||||
for payload in dist/*.txt dist/*.zip dist/*.tgz dist/*.tar.zst dist/*.exe dist/*.dmg ; do
|
for payload in dist/*.txt dist/*.zip dist/*.tgz dist/*.tar.zst dist/*.exe dist/*.dmg dist/*.ps1 dist/*.sh ; do
|
||||||
echo "Uploading $payload"
|
echo "Uploading $payload"
|
||||||
gh release upload ${GITHUB_REF_NAME} $payload --clobber &
|
gh release upload ${GITHUB_REF_NAME} $payload --clobber &
|
||||||
pids[$!]=$!
|
pids+=($!)
|
||||||
sleep 1
|
sleep 1
|
||||||
done
|
done
|
||||||
echo "Waiting for uploads to complete"
|
echo "Waiting for uploads to complete"
|
||||||
for pid in "${pids[*]}"; do
|
failed=0
|
||||||
wait $pid
|
for pid in "${pids[@]}"; do
|
||||||
|
if ! wait $pid; then
|
||||||
|
echo "::error::Upload failed (pid $pid)"
|
||||||
|
failed=1
|
||||||
|
fi
|
||||||
done
|
done
|
||||||
|
if [ $failed -ne 0 ]; then
|
||||||
|
echo "One or more uploads failed"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
echo "done"
|
echo "done"
|
||||||
|
|||||||
22
.github/workflows/test-install.yaml
vendored
Normal file
22
.github/workflows/test-install.yaml
vendored
Normal file
@@ -0,0 +1,22 @@
|
|||||||
|
name: test-install
|
||||||
|
|
||||||
|
on:
|
||||||
|
pull_request:
|
||||||
|
paths:
|
||||||
|
- 'scripts/install.sh'
|
||||||
|
- '.github/workflows/test-install.yaml'
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
test:
|
||||||
|
strategy:
|
||||||
|
matrix:
|
||||||
|
os: [ubuntu-latest, macos-latest]
|
||||||
|
runs-on: ${{ matrix.os }}
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
- name: Run install script
|
||||||
|
run: sh ./scripts/install.sh
|
||||||
|
env:
|
||||||
|
OLLAMA_NO_START: 1 # do not start app
|
||||||
|
- name: Verify ollama is available
|
||||||
|
run: ollama --version
|
||||||
62
.github/workflows/test.yaml
vendored
62
.github/workflows/test.yaml
vendored
@@ -37,7 +37,7 @@ jobs:
|
|||||||
| xargs python3 -c "import sys; from pathlib import Path; print(any(Path(x).match(glob) for x in sys.argv[1:] for glob in '$*'.split(' ')))"
|
| xargs python3 -c "import sys; from pathlib import Path; print(any(Path(x).match(glob) for x in sys.argv[1:] for glob in '$*'.split(' ')))"
|
||||||
}
|
}
|
||||||
|
|
||||||
echo changed=$(changed 'llama/llama.cpp/**/*' 'ml/backend/ggml/ggml/**/*') | tee -a $GITHUB_OUTPUT
|
echo changed=$(changed 'llama/llama.cpp/**/*' 'ml/backend/ggml/ggml/**/*' '.github/**/*') | tee -a $GITHUB_OUTPUT
|
||||||
echo vendorsha=$(make -f Makefile.sync print-base) | tee -a $GITHUB_OUTPUT
|
echo vendorsha=$(make -f Makefile.sync print-base) | tee -a $GITHUB_OUTPUT
|
||||||
|
|
||||||
linux:
|
linux:
|
||||||
@@ -51,7 +51,7 @@ jobs:
|
|||||||
container: nvidia/cuda:13.0.0-devel-ubuntu22.04
|
container: nvidia/cuda:13.0.0-devel-ubuntu22.04
|
||||||
flags: '-DCMAKE_CUDA_ARCHITECTURES=87'
|
flags: '-DCMAKE_CUDA_ARCHITECTURES=87'
|
||||||
- preset: ROCm
|
- preset: ROCm
|
||||||
container: rocm/dev-ubuntu-22.04:6.1.2
|
container: rocm/dev-ubuntu-22.04:7.2
|
||||||
extra-packages: rocm-libs
|
extra-packages: rocm-libs
|
||||||
flags: '-DAMDGPU_TARGETS=gfx1010 -DCMAKE_PREFIX_PATH=/opt/rocm'
|
flags: '-DAMDGPU_TARGETS=gfx1010 -DCMAKE_PREFIX_PATH=/opt/rocm'
|
||||||
- preset: Vulkan
|
- preset: Vulkan
|
||||||
@@ -60,6 +60,10 @@ jobs:
|
|||||||
mesa-vulkan-drivers vulkan-tools
|
mesa-vulkan-drivers vulkan-tools
|
||||||
libvulkan1 libvulkan-dev
|
libvulkan1 libvulkan-dev
|
||||||
vulkan-sdk cmake ccache g++ make
|
vulkan-sdk cmake ccache g++ make
|
||||||
|
- preset: 'MLX CUDA 13'
|
||||||
|
container: nvidia/cuda:13.0.0-devel-ubuntu22.04
|
||||||
|
extra-packages: libcudnn9-dev-cuda-13 libopenblas-dev liblapack-dev liblapacke-dev git curl
|
||||||
|
flags: '-DCMAKE_CUDA_ARCHITECTURES=87 -DBLAS_INCLUDE_DIRS=/usr/include/x86_64-linux-gnu -DLAPACK_INCLUDE_DIRS=/usr/include/x86_64-linux-gnu'
|
||||||
runs-on: linux
|
runs-on: linux
|
||||||
container: ${{ matrix.container }}
|
container: ${{ matrix.container }}
|
||||||
steps:
|
steps:
|
||||||
@@ -76,6 +80,10 @@ jobs:
|
|||||||
$sudo apt-get update
|
$sudo apt-get update
|
||||||
fi
|
fi
|
||||||
$sudo apt-get install -y cmake ccache ${{ matrix.extra-packages }}
|
$sudo apt-get install -y cmake ccache ${{ matrix.extra-packages }}
|
||||||
|
# MLX requires CMake 3.25+, install from official releases
|
||||||
|
if [ "${{ matrix.preset }}" = "MLX CUDA 13" ]; then
|
||||||
|
curl -fsSL https://github.com/Kitware/CMake/releases/download/v3.31.2/cmake-3.31.2-linux-$(uname -m).tar.gz | $sudo tar xz -C /usr/local --strip-components 1
|
||||||
|
fi
|
||||||
# Export VULKAN_SDK if provided by LunarG package (defensive)
|
# Export VULKAN_SDK if provided by LunarG package (defensive)
|
||||||
if [ -d "/usr/lib/x86_64-linux-gnu/vulkan" ] && [ "${{ matrix.preset }}" = "Vulkan" ]; then
|
if [ -d "/usr/lib/x86_64-linux-gnu/vulkan" ] && [ "${{ matrix.preset }}" = "Vulkan" ]; then
|
||||||
echo "VULKAN_SDK=/usr" >> $GITHUB_ENV
|
echo "VULKAN_SDK=/usr" >> $GITHUB_ENV
|
||||||
@@ -87,8 +95,8 @@ jobs:
|
|||||||
path: /github/home/.cache/ccache
|
path: /github/home/.cache/ccache
|
||||||
key: ccache-${{ runner.os }}-${{ runner.arch }}-${{ matrix.preset }}-${{ needs.changes.outputs.vendorsha }}
|
key: ccache-${{ runner.os }}-${{ runner.arch }}-${{ matrix.preset }}-${{ needs.changes.outputs.vendorsha }}
|
||||||
- run: |
|
- run: |
|
||||||
cmake --preset ${{ matrix.preset }} ${{ matrix.flags }}
|
cmake --preset "${{ matrix.preset }}" ${{ matrix.flags }}
|
||||||
cmake --build --preset ${{ matrix.preset }} --parallel
|
cmake --build --preset "${{ matrix.preset }}" --parallel
|
||||||
|
|
||||||
windows:
|
windows:
|
||||||
needs: [changes]
|
needs: [changes]
|
||||||
@@ -114,12 +122,31 @@ jobs:
|
|||||||
flags: '-DAMDGPU_TARGETS=gfx1010 -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -DCMAKE_C_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma" -DCMAKE_CXX_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma"'
|
flags: '-DAMDGPU_TARGETS=gfx1010 -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -DCMAKE_C_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma" -DCMAKE_CXX_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma"'
|
||||||
- preset: Vulkan
|
- preset: Vulkan
|
||||||
install: https://sdk.lunarg.com/sdk/download/1.4.321.1/windows/vulkansdk-windows-X64-1.4.321.1.exe
|
install: https://sdk.lunarg.com/sdk/download/1.4.321.1/windows/vulkansdk-windows-X64-1.4.321.1.exe
|
||||||
|
- preset: 'MLX CUDA 13'
|
||||||
|
install: https://developer.download.nvidia.com/compute/cuda/13.0.0/local_installers/cuda_13.0.0_windows.exe
|
||||||
|
cudnn-install: https://developer.download.nvidia.com/compute/cudnn/redist/cudnn/windows-x86_64/cudnn-windows-x86_64-9.18.1.3_cuda13-archive.zip
|
||||||
|
flags: '-DCMAKE_CUDA_ARCHITECTURES=80'
|
||||||
|
cuda-components:
|
||||||
|
- '"cudart"'
|
||||||
|
- '"nvcc"'
|
||||||
|
- '"cublas"'
|
||||||
|
- '"cublas_dev"'
|
||||||
|
- '"cufft"'
|
||||||
|
- '"cufft_dev"'
|
||||||
|
- '"nvrtc"'
|
||||||
|
- '"nvrtc_dev"'
|
||||||
|
- '"crt"'
|
||||||
|
- '"nvvm"'
|
||||||
|
- '"nvptxcompiler"'
|
||||||
|
cuda-version: '13.0'
|
||||||
runs-on: windows
|
runs-on: windows
|
||||||
steps:
|
steps:
|
||||||
- run: |
|
- run: |
|
||||||
choco install -y --no-progress ccache ninja
|
choco install -y --no-progress ccache ninja
|
||||||
|
if (Get-Command ccache -ErrorAction SilentlyContinue) {
|
||||||
ccache -o cache_dir=${{ github.workspace }}\.ccache
|
ccache -o cache_dir=${{ github.workspace }}\.ccache
|
||||||
- if: matrix.preset == 'CUDA' || matrix.preset == 'ROCm' || matrix.preset == 'Vulkan'
|
}
|
||||||
|
- if: matrix.preset == 'CUDA' || matrix.preset == 'ROCm' || matrix.preset == 'Vulkan' || matrix.preset == 'MLX CUDA 13'
|
||||||
id: cache-install
|
id: cache-install
|
||||||
uses: actions/cache/restore@v4
|
uses: actions/cache/restore@v4
|
||||||
with:
|
with:
|
||||||
@@ -127,8 +154,9 @@ jobs:
|
|||||||
C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA
|
C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA
|
||||||
C:\Program Files\AMD\ROCm
|
C:\Program Files\AMD\ROCm
|
||||||
C:\VulkanSDK
|
C:\VulkanSDK
|
||||||
key: ${{ matrix.install }}
|
C:\Program Files\NVIDIA\CUDNN
|
||||||
- if: matrix.preset == 'CUDA'
|
key: ${{ matrix.install }}-${{ matrix.cudnn-install }}
|
||||||
|
- if: matrix.preset == 'CUDA' || matrix.preset == 'MLX CUDA 13'
|
||||||
name: Install CUDA ${{ matrix.cuda-version }}
|
name: Install CUDA ${{ matrix.cuda-version }}
|
||||||
run: |
|
run: |
|
||||||
$ErrorActionPreference = "Stop"
|
$ErrorActionPreference = "Stop"
|
||||||
@@ -168,6 +196,23 @@ jobs:
|
|||||||
$vulkanPath = (Resolve-Path "C:\VulkanSDK\*").path
|
$vulkanPath = (Resolve-Path "C:\VulkanSDK\*").path
|
||||||
echo "$vulkanPath\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
|
echo "$vulkanPath\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
|
||||||
echo "VULKAN_SDK=$vulkanPath" >> $env:GITHUB_ENV
|
echo "VULKAN_SDK=$vulkanPath" >> $env:GITHUB_ENV
|
||||||
|
- if: matrix.preset == 'MLX CUDA 13'
|
||||||
|
name: Install cuDNN for MLX
|
||||||
|
run: |
|
||||||
|
$ErrorActionPreference = "Stop"
|
||||||
|
$cudnnRoot = "C:\Program Files\NVIDIA\CUDNN"
|
||||||
|
if ("${{ steps.cache-install.outputs.cache-hit }}" -ne 'true') {
|
||||||
|
Invoke-WebRequest -Uri "${{ matrix.cudnn-install }}" -OutFile "cudnn.zip"
|
||||||
|
Expand-Archive -Path cudnn.zip -DestinationPath cudnn-extracted
|
||||||
|
$cudnnDir = (Get-ChildItem -Path cudnn-extracted -Directory)[0].FullName
|
||||||
|
New-Item -ItemType Directory -Force -Path $cudnnRoot
|
||||||
|
Copy-Item -Path "$cudnnDir\*" -Destination "$cudnnRoot\" -Recurse
|
||||||
|
}
|
||||||
|
|
||||||
|
echo "CUDNN_ROOT_DIR=$cudnnRoot" | Out-File -FilePath $env:GITHUB_ENV -Append
|
||||||
|
echo "CUDNN_INCLUDE_PATH=$cudnnRoot\include" | Out-File -FilePath $env:GITHUB_ENV -Append
|
||||||
|
echo "CUDNN_LIBRARY_PATH=$cudnnRoot\lib\x64" | Out-File -FilePath $env:GITHUB_ENV -Append
|
||||||
|
echo "$cudnnRoot\bin\x64" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
|
||||||
- if: ${{ !cancelled() && steps.cache-install.outputs.cache-hit != 'true' }}
|
- if: ${{ !cancelled() && steps.cache-install.outputs.cache-hit != 'true' }}
|
||||||
uses: actions/cache/save@v4
|
uses: actions/cache/save@v4
|
||||||
with:
|
with:
|
||||||
@@ -175,7 +220,8 @@ jobs:
|
|||||||
C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA
|
C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA
|
||||||
C:\Program Files\AMD\ROCm
|
C:\Program Files\AMD\ROCm
|
||||||
C:\VulkanSDK
|
C:\VulkanSDK
|
||||||
key: ${{ matrix.install }}
|
C:\Program Files\NVIDIA\CUDNN
|
||||||
|
key: ${{ matrix.install }}-${{ matrix.cudnn-install }}
|
||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v4
|
||||||
- uses: actions/cache@v4
|
- uses: actions/cache@v4
|
||||||
with:
|
with:
|
||||||
|
|||||||
110
CMakeLists.txt
110
CMakeLists.txt
@@ -64,10 +64,15 @@ set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${OLLAMA_BUILD_DIR})
|
|||||||
set(CMAKE_LIBRARY_OUTPUT_DIRECTORY_DEBUG ${OLLAMA_BUILD_DIR})
|
set(CMAKE_LIBRARY_OUTPUT_DIRECTORY_DEBUG ${OLLAMA_BUILD_DIR})
|
||||||
set(CMAKE_LIBRARY_OUTPUT_DIRECTORY_RELEASE ${OLLAMA_BUILD_DIR})
|
set(CMAKE_LIBRARY_OUTPUT_DIRECTORY_RELEASE ${OLLAMA_BUILD_DIR})
|
||||||
|
|
||||||
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src)
|
# Store ggml include paths for use with target_include_directories later.
|
||||||
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/include)
|
# We avoid global include_directories() to prevent polluting the include path
|
||||||
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-cpu)
|
# for other projects like MLX (whose openblas dependency has its own common.h).
|
||||||
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-cpu/amx)
|
set(GGML_INCLUDE_DIRS
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/include
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-cpu
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-cpu/amx
|
||||||
|
)
|
||||||
|
|
||||||
add_compile_definitions(NDEBUG GGML_VERSION=0x0 GGML_COMMIT=0x0)
|
add_compile_definitions(NDEBUG GGML_VERSION=0x0 GGML_COMMIT=0x0)
|
||||||
|
|
||||||
@@ -87,6 +92,14 @@ if(NOT CPU_VARIANTS)
|
|||||||
set(CPU_VARIANTS "ggml-cpu")
|
set(CPU_VARIANTS "ggml-cpu")
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
|
# Apply ggml include directories to ggml targets only (not globally)
|
||||||
|
target_include_directories(ggml-base PRIVATE ${GGML_INCLUDE_DIRS})
|
||||||
|
foreach(variant ${CPU_VARIANTS})
|
||||||
|
if(TARGET ${variant})
|
||||||
|
target_include_directories(${variant} PRIVATE ${GGML_INCLUDE_DIRS})
|
||||||
|
endif()
|
||||||
|
endforeach()
|
||||||
|
|
||||||
install(TARGETS ggml-base ${CPU_VARIANTS}
|
install(TARGETS ggml-base ${CPU_VARIANTS}
|
||||||
RUNTIME_DEPENDENCIES
|
RUNTIME_DEPENDENCIES
|
||||||
PRE_EXCLUDE_REGEXES ".*"
|
PRE_EXCLUDE_REGEXES ".*"
|
||||||
@@ -103,6 +116,7 @@ if(CMAKE_CUDA_COMPILER)
|
|||||||
|
|
||||||
find_package(CUDAToolkit)
|
find_package(CUDAToolkit)
|
||||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-cuda)
|
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-cuda)
|
||||||
|
target_include_directories(ggml-cuda PRIVATE ${GGML_INCLUDE_DIRS})
|
||||||
install(TARGETS ggml-cuda
|
install(TARGETS ggml-cuda
|
||||||
RUNTIME_DEPENDENCIES
|
RUNTIME_DEPENDENCIES
|
||||||
DIRECTORIES ${CUDAToolkit_BIN_DIR} ${CUDAToolkit_BIN_DIR}/x64 ${CUDAToolkit_LIBRARY_DIR}
|
DIRECTORIES ${CUDAToolkit_BIN_DIR} ${CUDAToolkit_BIN_DIR}/x64 ${CUDAToolkit_LIBRARY_DIR}
|
||||||
@@ -134,6 +148,7 @@ if(CMAKE_HIP_COMPILER)
|
|||||||
if(AMDGPU_TARGETS)
|
if(AMDGPU_TARGETS)
|
||||||
find_package(hip REQUIRED)
|
find_package(hip REQUIRED)
|
||||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-hip)
|
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-hip)
|
||||||
|
target_include_directories(ggml-hip PRIVATE ${GGML_INCLUDE_DIRS})
|
||||||
|
|
||||||
if (WIN32)
|
if (WIN32)
|
||||||
target_compile_definitions(ggml-hip PRIVATE GGML_CUDA_NO_PEER_COPY)
|
target_compile_definitions(ggml-hip PRIVATE GGML_CUDA_NO_PEER_COPY)
|
||||||
@@ -148,7 +163,7 @@ if(CMAKE_HIP_COMPILER)
|
|||||||
)
|
)
|
||||||
install(RUNTIME_DEPENDENCY_SET rocm
|
install(RUNTIME_DEPENDENCY_SET rocm
|
||||||
DIRECTORIES ${HIP_BIN_INSTALL_DIR} ${HIP_LIB_INSTALL_DIR}
|
DIRECTORIES ${HIP_BIN_INSTALL_DIR} ${HIP_LIB_INSTALL_DIR}
|
||||||
PRE_INCLUDE_REGEXES hipblas rocblas amdhip64 rocsolver amd_comgr hsa-runtime64 rocsparse tinfo rocprofiler-register drm drm_amdgpu numa elf
|
PRE_INCLUDE_REGEXES hipblas rocblas amdhip64 rocsolver amd_comgr hsa-runtime64 rocsparse tinfo rocprofiler-register roctx64 rocroller drm drm_amdgpu numa elf
|
||||||
PRE_EXCLUDE_REGEXES ".*"
|
PRE_EXCLUDE_REGEXES ".*"
|
||||||
POST_EXCLUDE_REGEXES "system32"
|
POST_EXCLUDE_REGEXES "system32"
|
||||||
RUNTIME DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT HIP
|
RUNTIME DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT HIP
|
||||||
@@ -168,6 +183,7 @@ if(NOT APPLE)
|
|||||||
find_package(Vulkan)
|
find_package(Vulkan)
|
||||||
if(Vulkan_FOUND)
|
if(Vulkan_FOUND)
|
||||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-vulkan)
|
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-vulkan)
|
||||||
|
target_include_directories(ggml-vulkan PRIVATE ${GGML_INCLUDE_DIRS})
|
||||||
install(TARGETS ggml-vulkan
|
install(TARGETS ggml-vulkan
|
||||||
RUNTIME_DEPENDENCIES
|
RUNTIME_DEPENDENCIES
|
||||||
PRE_INCLUDE_REGEXES vulkan
|
PRE_INCLUDE_REGEXES vulkan
|
||||||
@@ -179,18 +195,43 @@ if(NOT APPLE)
|
|||||||
endif()
|
endif()
|
||||||
|
|
||||||
option(MLX_ENGINE "Enable MLX backend" OFF)
|
option(MLX_ENGINE "Enable MLX backend" OFF)
|
||||||
|
|
||||||
if(MLX_ENGINE)
|
if(MLX_ENGINE)
|
||||||
message(STATUS "Setting up MLX (this takes a while...)")
|
message(STATUS "Setting up MLX (this takes a while...)")
|
||||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/x/ml/backend/mlx)
|
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/x/imagegen/mlx)
|
||||||
|
|
||||||
# Find CUDA toolkit if MLX is built with CUDA support
|
# Find CUDA toolkit if MLX is built with CUDA support
|
||||||
find_package(CUDAToolkit)
|
find_package(CUDAToolkit)
|
||||||
|
|
||||||
|
# Build list of directories for runtime dependency resolution
|
||||||
|
set(MLX_RUNTIME_DIRS ${CUDAToolkit_BIN_DIR} ${CUDAToolkit_BIN_DIR}/x64 ${CUDAToolkit_LIBRARY_DIR})
|
||||||
|
# Add cuDNN bin paths for DLLs (Windows MLX CUDA builds)
|
||||||
|
# CUDNN_ROOT_DIR is the standard CMake variable for cuDNN location
|
||||||
|
if(DEFINED ENV{CUDNN_ROOT_DIR})
|
||||||
|
# cuDNN 9.x has versioned subdirectories under bin/ (e.g., bin/13.0/)
|
||||||
|
file(GLOB CUDNN_BIN_SUBDIRS "$ENV{CUDNN_ROOT_DIR}/bin/*")
|
||||||
|
list(APPEND MLX_RUNTIME_DIRS ${CUDNN_BIN_SUBDIRS})
|
||||||
|
endif()
|
||||||
|
# Add build output directory and MLX dependency build directories
|
||||||
|
list(APPEND MLX_RUNTIME_DIRS ${OLLAMA_BUILD_DIR})
|
||||||
|
# OpenBLAS DLL location (pre-built zip extracts into openblas-src/bin/)
|
||||||
|
list(APPEND MLX_RUNTIME_DIRS ${CMAKE_BINARY_DIR}/_deps/openblas-src/bin)
|
||||||
|
# NCCL: on Linux, if real NCCL is found, cmake bundles libnccl.so via the
|
||||||
|
# regex below. If NCCL is not found, MLX links a static stub (OBJECT lib)
|
||||||
|
# so there is no runtime dependency. This path covers the stub build dir
|
||||||
|
# for windows so we include the DLL in our dependencies.
|
||||||
|
list(APPEND MLX_RUNTIME_DIRS ${CMAKE_BINARY_DIR}/_deps/mlx-build/mlx/distributed/nccl/nccl_stub-prefix/src/nccl_stub-build/Release)
|
||||||
|
|
||||||
|
# Base regexes for runtime dependencies (cross-platform)
|
||||||
|
set(MLX_INCLUDE_REGEXES cublas cublasLt cudart cufft nvrtc nvrtc-builtins cudnn nccl openblas gfortran)
|
||||||
|
# On Windows, also include dl.dll (dlfcn-win32 POSIX emulation layer)
|
||||||
|
if(WIN32)
|
||||||
|
list(APPEND MLX_INCLUDE_REGEXES "^dl\\.dll$")
|
||||||
|
endif()
|
||||||
|
|
||||||
install(TARGETS mlx mlxc
|
install(TARGETS mlx mlxc
|
||||||
RUNTIME_DEPENDENCIES
|
RUNTIME_DEPENDENCIES
|
||||||
DIRECTORIES ${CUDAToolkit_BIN_DIR} ${CUDAToolkit_BIN_DIR}/x64 ${CUDAToolkit_LIBRARY_DIR}
|
DIRECTORIES ${MLX_RUNTIME_DIRS}
|
||||||
PRE_INCLUDE_REGEXES cublas cublasLt cudart nvrtc nvrtc-builtins cudnn nccl openblas gfortran
|
PRE_INCLUDE_REGEXES ${MLX_INCLUDE_REGEXES}
|
||||||
PRE_EXCLUDE_REGEXES ".*"
|
PRE_EXCLUDE_REGEXES ".*"
|
||||||
RUNTIME DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT MLX
|
RUNTIME DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT MLX
|
||||||
LIBRARY DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT MLX
|
LIBRARY DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT MLX
|
||||||
@@ -205,13 +246,54 @@ if(MLX_ENGINE)
|
|||||||
COMPONENT MLX)
|
COMPONENT MLX)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
# Manually install cudart and cublas since they might not be picked up as direct dependencies
|
# Install CCCL headers for NVRTC JIT compilation at runtime.
|
||||||
|
# MLX's own install rules use the default component so they get skipped by
|
||||||
|
# --component MLX. Headers are installed alongside libmlx in OLLAMA_INSTALL_DIR.
|
||||||
|
# On Linux, MLX's jit_module.cpp resolves CCCL via
|
||||||
|
# current_binary_dir().parent_path() / "include" / "cccl", so we create a
|
||||||
|
# symlink from lib/ollama/include -> ${OLLAMA_RUNNER_DIR}/include
|
||||||
|
# This will need refinement if we add multiple CUDA versions for MLX in the future.
|
||||||
|
if(EXISTS ${CMAKE_BINARY_DIR}/_deps/cccl-src/include/cuda)
|
||||||
|
install(DIRECTORY ${CMAKE_BINARY_DIR}/_deps/cccl-src/include/cuda
|
||||||
|
DESTINATION ${OLLAMA_INSTALL_DIR}/include/cccl
|
||||||
|
COMPONENT MLX)
|
||||||
|
install(DIRECTORY ${CMAKE_BINARY_DIR}/_deps/cccl-src/include/nv
|
||||||
|
DESTINATION ${OLLAMA_INSTALL_DIR}/include/cccl
|
||||||
|
COMPONENT MLX)
|
||||||
|
if(NOT WIN32 AND NOT APPLE)
|
||||||
|
install(CODE "
|
||||||
|
set(_link \"${CMAKE_INSTALL_PREFIX}/lib/ollama/include\")
|
||||||
|
set(_target \"${OLLAMA_RUNNER_DIR}/include\")
|
||||||
|
if(NOT EXISTS \${_link})
|
||||||
|
execute_process(COMMAND \${CMAKE_COMMAND} -E create_symlink \${_target} \${_link})
|
||||||
|
endif()
|
||||||
|
" COMPONENT MLX)
|
||||||
|
endif()
|
||||||
|
endif()
|
||||||
|
|
||||||
|
# On Windows, explicitly install dl.dll (dlfcn-win32 POSIX dlopen emulation)
|
||||||
|
# RUNTIME_DEPENDENCIES auto-excludes it via POST_EXCLUDE_FILES_STRICT because
|
||||||
|
# dlfcn-win32 is a known CMake target with its own install rules (which install
|
||||||
|
# to the wrong destination). We must install it explicitly here.
|
||||||
|
if(WIN32)
|
||||||
|
install(FILES ${OLLAMA_BUILD_DIR}/dl.dll
|
||||||
|
DESTINATION ${OLLAMA_INSTALL_DIR}
|
||||||
|
COMPONENT MLX)
|
||||||
|
endif()
|
||||||
|
|
||||||
|
# Manually install CUDA runtime libraries that MLX loads via dlopen
|
||||||
|
# (not detected by RUNTIME_DEPENDENCIES since they aren't link-time deps)
|
||||||
if(CUDAToolkit_FOUND)
|
if(CUDAToolkit_FOUND)
|
||||||
file(GLOB CUDART_LIBS
|
file(GLOB MLX_CUDA_LIBS
|
||||||
"${CUDAToolkit_LIBRARY_DIR}/libcudart.so*"
|
"${CUDAToolkit_LIBRARY_DIR}/libcudart.so*"
|
||||||
"${CUDAToolkit_LIBRARY_DIR}/libcublas.so*")
|
"${CUDAToolkit_LIBRARY_DIR}/libcublas.so*"
|
||||||
if(CUDART_LIBS)
|
"${CUDAToolkit_LIBRARY_DIR}/libcublasLt.so*"
|
||||||
install(FILES ${CUDART_LIBS}
|
"${CUDAToolkit_LIBRARY_DIR}/libnvrtc.so*"
|
||||||
|
"${CUDAToolkit_LIBRARY_DIR}/libnvrtc-builtins.so*"
|
||||||
|
"${CUDAToolkit_LIBRARY_DIR}/libcufft.so*"
|
||||||
|
"${CUDAToolkit_LIBRARY_DIR}/libcudnn.so*")
|
||||||
|
if(MLX_CUDA_LIBS)
|
||||||
|
install(FILES ${MLX_CUDA_LIBS}
|
||||||
DESTINATION ${OLLAMA_INSTALL_DIR}
|
DESTINATION ${OLLAMA_INSTALL_DIR}
|
||||||
COMPONENT MLX)
|
COMPONENT MLX)
|
||||||
endif()
|
endif()
|
||||||
|
|||||||
@@ -77,6 +77,15 @@
|
|||||||
"OLLAMA_RUNNER_DIR": "rocm"
|
"OLLAMA_RUNNER_DIR": "rocm"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"name": "ROCm 7",
|
||||||
|
"inherits": [ "ROCm" ],
|
||||||
|
"cacheVariables": {
|
||||||
|
"CMAKE_HIP_FLAGS": "-parallel-jobs=4",
|
||||||
|
"AMDGPU_TARGETS": "gfx942;gfx950;gfx1010;gfx1012;gfx1030;gfx1100;gfx1101;gfx1102;gfx1103;gfx1150;gfx1151;gfx1200;gfx1201;gfx908:xnack-;gfx90a:xnack+;gfx90a:xnack-",
|
||||||
|
"OLLAMA_RUNNER_DIR": "rocm"
|
||||||
|
}
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"name": "Vulkan",
|
"name": "Vulkan",
|
||||||
"inherits": [ "Default" ],
|
"inherits": [ "Default" ],
|
||||||
@@ -103,6 +112,7 @@
|
|||||||
"name": "MLX CUDA 13",
|
"name": "MLX CUDA 13",
|
||||||
"inherits": [ "MLX", "CUDA 13" ],
|
"inherits": [ "MLX", "CUDA 13" ],
|
||||||
"cacheVariables": {
|
"cacheVariables": {
|
||||||
|
"MLX_CUDA_ARCHITECTURES": "86;89;90;90a;100;103;75-virtual;80-virtual;110-virtual;120-virtual;121-virtual",
|
||||||
"OLLAMA_RUNNER_DIR": "mlx_cuda_v13"
|
"OLLAMA_RUNNER_DIR": "mlx_cuda_v13"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -158,6 +168,11 @@
|
|||||||
"inherits": [ "ROCm" ],
|
"inherits": [ "ROCm" ],
|
||||||
"configurePreset": "ROCm 6"
|
"configurePreset": "ROCm 6"
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"name": "ROCm 7",
|
||||||
|
"inherits": [ "ROCm" ],
|
||||||
|
"configurePreset": "ROCm 7"
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"name": "Vulkan",
|
"name": "Vulkan",
|
||||||
"targets": [ "ggml-vulkan" ],
|
"targets": [ "ggml-vulkan" ],
|
||||||
|
|||||||
133
Dockerfile
133
Dockerfile
@@ -1,33 +1,23 @@
|
|||||||
# vim: filetype=dockerfile
|
# vim: filetype=dockerfile
|
||||||
|
|
||||||
ARG FLAVOR=${TARGETARCH}
|
ARG FLAVOR=${TARGETARCH}
|
||||||
ARG PARALLEL=8
|
|
||||||
|
|
||||||
ARG ROCMVERSION=6.3.3
|
ARG ROCMVERSION=7.2
|
||||||
ARG JETPACK5VERSION=r35.4.1
|
ARG JETPACK5VERSION=r35.4.1
|
||||||
ARG JETPACK6VERSION=r36.4.0
|
ARG JETPACK6VERSION=r36.4.0
|
||||||
ARG CMAKEVERSION=3.31.2
|
ARG CMAKEVERSION=3.31.2
|
||||||
|
ARG NINJAVERSION=1.12.1
|
||||||
ARG VULKANVERSION=1.4.321.1
|
ARG VULKANVERSION=1.4.321.1
|
||||||
|
|
||||||
# We require gcc v10 minimum. v10.3 has regressions, so the rockylinux 8.5 AppStream has the latest compatible version
|
# Default empty stages for local MLX source overrides.
|
||||||
|
# Override with: docker build --build-context local-mlx=../mlx --build-context local-mlx-c=../mlx-c
|
||||||
|
FROM scratch AS local-mlx
|
||||||
|
FROM scratch AS local-mlx-c
|
||||||
|
|
||||||
FROM --platform=linux/amd64 rocm/dev-almalinux-8:${ROCMVERSION}-complete AS base-amd64
|
FROM --platform=linux/amd64 rocm/dev-almalinux-8:${ROCMVERSION}-complete AS base-amd64
|
||||||
RUN yum install -y yum-utils \
|
RUN dnf install -y yum-utils ccache gcc-toolset-11-gcc gcc-toolset-11-gcc-c++ gcc-toolset-11-binutils \
|
||||||
&& yum-config-manager --add-repo https://dl.rockylinux.org/vault/rocky/8.5/AppStream/\$basearch/os/ \
|
|
||||||
&& rpm --import https://dl.rockylinux.org/pub/rocky/RPM-GPG-KEY-Rocky-8 \
|
|
||||||
&& dnf install -y yum-utils ccache gcc-toolset-10-gcc-10.2.1-8.2.el8 gcc-toolset-10-gcc-c++-10.2.1-8.2.el8 gcc-toolset-10-binutils-2.35-11.el8 \
|
|
||||||
&& dnf install -y ccache \
|
|
||||||
&& yum-config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel8/x86_64/cuda-rhel8.repo
|
&& yum-config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel8/x86_64/cuda-rhel8.repo
|
||||||
ENV PATH=/opt/rh/gcc-toolset-10/root/usr/bin:$PATH
|
ENV PATH=/opt/rh/gcc-toolset-11/root/usr/bin:$PATH
|
||||||
ARG VULKANVERSION
|
|
||||||
RUN wget https://sdk.lunarg.com/sdk/download/${VULKANVERSION}/linux/vulkansdk-linux-x86_64-${VULKANVERSION}.tar.xz -O /tmp/vulkansdk-linux-x86_64-${VULKANVERSION}.tar.xz \
|
|
||||||
&& tar xvf /tmp/vulkansdk-linux-x86_64-${VULKANVERSION}.tar.xz \
|
|
||||||
&& dnf -y install ninja-build \
|
|
||||||
&& ln -s /usr/bin/python3 /usr/bin/python \
|
|
||||||
&& /${VULKANVERSION}/vulkansdk -j 8 vulkan-headers \
|
|
||||||
&& /${VULKANVERSION}/vulkansdk -j 8 shaderc
|
|
||||||
RUN cp -r /${VULKANVERSION}/x86_64/include/* /usr/local/include/ \
|
|
||||||
&& cp -r /${VULKANVERSION}/x86_64/lib/* /usr/local/lib
|
|
||||||
ENV PATH=/${VULKANVERSION}/x86_64/bin:$PATH
|
|
||||||
|
|
||||||
FROM --platform=linux/arm64 almalinux:8 AS base-arm64
|
FROM --platform=linux/arm64 almalinux:8 AS base-arm64
|
||||||
# install epel-release for ccache
|
# install epel-release for ccache
|
||||||
@@ -38,100 +28,119 @@ ENV CC=clang CXX=clang++
|
|||||||
|
|
||||||
FROM base-${TARGETARCH} AS base
|
FROM base-${TARGETARCH} AS base
|
||||||
ARG CMAKEVERSION
|
ARG CMAKEVERSION
|
||||||
|
ARG NINJAVERSION
|
||||||
RUN curl -fsSL https://github.com/Kitware/CMake/releases/download/v${CMAKEVERSION}/cmake-${CMAKEVERSION}-linux-$(uname -m).tar.gz | tar xz -C /usr/local --strip-components 1
|
RUN curl -fsSL https://github.com/Kitware/CMake/releases/download/v${CMAKEVERSION}/cmake-${CMAKEVERSION}-linux-$(uname -m).tar.gz | tar xz -C /usr/local --strip-components 1
|
||||||
|
RUN dnf install -y unzip \
|
||||||
|
&& curl -fsSL -o /tmp/ninja.zip https://github.com/ninja-build/ninja/releases/download/v${NINJAVERSION}/ninja-linux$([ "$(uname -m)" = "aarch64" ] && echo "-aarch64").zip \
|
||||||
|
&& unzip /tmp/ninja.zip -d /usr/local/bin \
|
||||||
|
&& rm /tmp/ninja.zip
|
||||||
|
ENV CMAKE_GENERATOR=Ninja
|
||||||
ENV LDFLAGS=-s
|
ENV LDFLAGS=-s
|
||||||
|
|
||||||
FROM base AS cpu
|
FROM base AS cpu
|
||||||
RUN dnf install -y gcc-toolset-11-gcc gcc-toolset-11-gcc-c++
|
RUN dnf install -y gcc-toolset-11-gcc gcc-toolset-11-gcc-c++
|
||||||
ENV PATH=/opt/rh/gcc-toolset-11/root/usr/bin:$PATH
|
ENV PATH=/opt/rh/gcc-toolset-11/root/usr/bin:$PATH
|
||||||
ARG PARALLEL
|
|
||||||
COPY CMakeLists.txt CMakePresets.json .
|
COPY CMakeLists.txt CMakePresets.json .
|
||||||
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
|
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
|
||||||
RUN --mount=type=cache,target=/root/.ccache \
|
RUN --mount=type=cache,target=/root/.ccache \
|
||||||
cmake --preset 'CPU' \
|
cmake --preset 'CPU' \
|
||||||
&& cmake --build --parallel ${PARALLEL} --preset 'CPU' \
|
&& cmake --build --preset 'CPU' -- -l $(nproc) \
|
||||||
&& cmake --install build --component CPU --strip --parallel ${PARALLEL}
|
&& cmake --install build --component CPU --strip
|
||||||
|
|
||||||
FROM base AS cuda-11
|
FROM base AS cuda-11
|
||||||
ARG CUDA11VERSION=11.8
|
ARG CUDA11VERSION=11.8
|
||||||
RUN dnf install -y cuda-toolkit-${CUDA11VERSION//./-}
|
RUN dnf install -y cuda-toolkit-${CUDA11VERSION//./-}
|
||||||
ENV PATH=/usr/local/cuda-11/bin:$PATH
|
ENV PATH=/usr/local/cuda-11/bin:$PATH
|
||||||
ARG PARALLEL
|
|
||||||
COPY CMakeLists.txt CMakePresets.json .
|
COPY CMakeLists.txt CMakePresets.json .
|
||||||
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
|
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
|
||||||
RUN --mount=type=cache,target=/root/.ccache \
|
RUN --mount=type=cache,target=/root/.ccache \
|
||||||
cmake --preset 'CUDA 11' \
|
cmake --preset 'CUDA 11' \
|
||||||
&& cmake --build --parallel ${PARALLEL} --preset 'CUDA 11' \
|
&& cmake --build --preset 'CUDA 11' -- -l $(nproc) \
|
||||||
&& cmake --install build --component CUDA --strip --parallel ${PARALLEL}
|
&& cmake --install build --component CUDA --strip
|
||||||
|
|
||||||
FROM base AS cuda-12
|
FROM base AS cuda-12
|
||||||
ARG CUDA12VERSION=12.8
|
ARG CUDA12VERSION=12.8
|
||||||
RUN dnf install -y cuda-toolkit-${CUDA12VERSION//./-}
|
RUN dnf install -y cuda-toolkit-${CUDA12VERSION//./-}
|
||||||
ENV PATH=/usr/local/cuda-12/bin:$PATH
|
ENV PATH=/usr/local/cuda-12/bin:$PATH
|
||||||
ARG PARALLEL
|
|
||||||
COPY CMakeLists.txt CMakePresets.json .
|
COPY CMakeLists.txt CMakePresets.json .
|
||||||
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
|
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
|
||||||
RUN --mount=type=cache,target=/root/.ccache \
|
RUN --mount=type=cache,target=/root/.ccache \
|
||||||
cmake --preset 'CUDA 12' \
|
cmake --preset 'CUDA 12' \
|
||||||
&& cmake --build --parallel ${PARALLEL} --preset 'CUDA 12' \
|
&& cmake --build --preset 'CUDA 12' -- -l $(nproc) \
|
||||||
&& cmake --install build --component CUDA --strip --parallel ${PARALLEL}
|
&& cmake --install build --component CUDA --strip
|
||||||
|
|
||||||
|
|
||||||
FROM base AS cuda-13
|
FROM base AS cuda-13
|
||||||
ARG CUDA13VERSION=13.0
|
ARG CUDA13VERSION=13.0
|
||||||
RUN dnf install -y cuda-toolkit-${CUDA13VERSION//./-}
|
RUN dnf install -y cuda-toolkit-${CUDA13VERSION//./-}
|
||||||
ENV PATH=/usr/local/cuda-13/bin:$PATH
|
ENV PATH=/usr/local/cuda-13/bin:$PATH
|
||||||
ARG PARALLEL
|
|
||||||
COPY CMakeLists.txt CMakePresets.json .
|
COPY CMakeLists.txt CMakePresets.json .
|
||||||
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
|
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
|
||||||
RUN --mount=type=cache,target=/root/.ccache \
|
RUN --mount=type=cache,target=/root/.ccache \
|
||||||
cmake --preset 'CUDA 13' \
|
cmake --preset 'CUDA 13' \
|
||||||
&& cmake --build --parallel ${PARALLEL} --preset 'CUDA 13' \
|
&& cmake --build --preset 'CUDA 13' -- -l $(nproc) \
|
||||||
&& cmake --install build --component CUDA --strip --parallel ${PARALLEL}
|
&& cmake --install build --component CUDA --strip
|
||||||
|
|
||||||
|
|
||||||
FROM base AS rocm-6
|
FROM base AS rocm-7
|
||||||
ENV PATH=/opt/rocm/hcc/bin:/opt/rocm/hip/bin:/opt/rocm/bin:/opt/rocm/hcc/bin:$PATH
|
ENV PATH=/opt/rocm/hcc/bin:/opt/rocm/hip/bin:/opt/rocm/bin:/opt/rocm/hcc/bin:$PATH
|
||||||
ARG PARALLEL
|
|
||||||
COPY CMakeLists.txt CMakePresets.json .
|
COPY CMakeLists.txt CMakePresets.json .
|
||||||
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
|
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
|
||||||
RUN --mount=type=cache,target=/root/.ccache \
|
RUN --mount=type=cache,target=/root/.ccache \
|
||||||
cmake --preset 'ROCm 6' \
|
cmake --preset 'ROCm 7' \
|
||||||
&& cmake --build --parallel ${PARALLEL} --preset 'ROCm 6' \
|
&& cmake --build --preset 'ROCm 7' -- -l $(nproc) \
|
||||||
&& cmake --install build --component HIP --strip --parallel ${PARALLEL}
|
&& cmake --install build --component HIP --strip
|
||||||
RUN rm -f dist/lib/ollama/rocm/rocblas/library/*gfx90[06]*
|
RUN rm -f dist/lib/ollama/rocm/rocblas/library/*gfx90[06]*
|
||||||
|
|
||||||
FROM --platform=linux/arm64 nvcr.io/nvidia/l4t-jetpack:${JETPACK5VERSION} AS jetpack-5
|
FROM --platform=linux/arm64 nvcr.io/nvidia/l4t-jetpack:${JETPACK5VERSION} AS jetpack-5
|
||||||
ARG CMAKEVERSION
|
ARG CMAKEVERSION
|
||||||
RUN apt-get update && apt-get install -y curl ccache \
|
ARG NINJAVERSION
|
||||||
&& curl -fsSL https://github.com/Kitware/CMake/releases/download/v${CMAKEVERSION}/cmake-${CMAKEVERSION}-linux-$(uname -m).tar.gz | tar xz -C /usr/local --strip-components 1
|
RUN apt-get update && apt-get install -y curl ccache unzip \
|
||||||
|
&& curl -fsSL https://github.com/Kitware/CMake/releases/download/v${CMAKEVERSION}/cmake-${CMAKEVERSION}-linux-$(uname -m).tar.gz | tar xz -C /usr/local --strip-components 1 \
|
||||||
|
&& curl -fsSL -o /tmp/ninja.zip https://github.com/ninja-build/ninja/releases/download/v${NINJAVERSION}/ninja-linux-aarch64.zip \
|
||||||
|
&& unzip /tmp/ninja.zip -d /usr/local/bin \
|
||||||
|
&& rm /tmp/ninja.zip
|
||||||
|
ENV CMAKE_GENERATOR=Ninja
|
||||||
COPY CMakeLists.txt CMakePresets.json .
|
COPY CMakeLists.txt CMakePresets.json .
|
||||||
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
|
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
|
||||||
ARG PARALLEL
|
|
||||||
RUN --mount=type=cache,target=/root/.ccache \
|
RUN --mount=type=cache,target=/root/.ccache \
|
||||||
cmake --preset 'JetPack 5' \
|
cmake --preset 'JetPack 5' \
|
||||||
&& cmake --build --parallel ${PARALLEL} --preset 'JetPack 5' \
|
&& cmake --build --preset 'JetPack 5' -- -l $(nproc) \
|
||||||
&& cmake --install build --component CUDA --strip --parallel ${PARALLEL}
|
&& cmake --install build --component CUDA --strip
|
||||||
|
|
||||||
FROM --platform=linux/arm64 nvcr.io/nvidia/l4t-jetpack:${JETPACK6VERSION} AS jetpack-6
|
FROM --platform=linux/arm64 nvcr.io/nvidia/l4t-jetpack:${JETPACK6VERSION} AS jetpack-6
|
||||||
ARG CMAKEVERSION
|
ARG CMAKEVERSION
|
||||||
RUN apt-get update && apt-get install -y curl ccache \
|
ARG NINJAVERSION
|
||||||
&& curl -fsSL https://github.com/Kitware/CMake/releases/download/v${CMAKEVERSION}/cmake-${CMAKEVERSION}-linux-$(uname -m).tar.gz | tar xz -C /usr/local --strip-components 1
|
RUN apt-get update && apt-get install -y curl ccache unzip \
|
||||||
|
&& curl -fsSL https://github.com/Kitware/CMake/releases/download/v${CMAKEVERSION}/cmake-${CMAKEVERSION}-linux-$(uname -m).tar.gz | tar xz -C /usr/local --strip-components 1 \
|
||||||
|
&& curl -fsSL -o /tmp/ninja.zip https://github.com/ninja-build/ninja/releases/download/v${NINJAVERSION}/ninja-linux-aarch64.zip \
|
||||||
|
&& unzip /tmp/ninja.zip -d /usr/local/bin \
|
||||||
|
&& rm /tmp/ninja.zip
|
||||||
|
ENV CMAKE_GENERATOR=Ninja
|
||||||
COPY CMakeLists.txt CMakePresets.json .
|
COPY CMakeLists.txt CMakePresets.json .
|
||||||
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
|
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
|
||||||
ARG PARALLEL
|
|
||||||
RUN --mount=type=cache,target=/root/.ccache \
|
RUN --mount=type=cache,target=/root/.ccache \
|
||||||
cmake --preset 'JetPack 6' \
|
cmake --preset 'JetPack 6' \
|
||||||
&& cmake --build --parallel ${PARALLEL} --preset 'JetPack 6' \
|
&& cmake --build --preset 'JetPack 6' -- -l $(nproc) \
|
||||||
&& cmake --install build --component CUDA --strip --parallel ${PARALLEL}
|
&& cmake --install build --component CUDA --strip
|
||||||
|
|
||||||
FROM base AS vulkan
|
FROM base AS vulkan
|
||||||
|
ARG VULKANVERSION
|
||||||
|
RUN ln -s /usr/bin/python3 /usr/bin/python \
|
||||||
|
&& wget https://sdk.lunarg.com/sdk/download/${VULKANVERSION}/linux/vulkansdk-linux-x86_64-${VULKANVERSION}.tar.xz -O /tmp/vulkansdk.tar.xz \
|
||||||
|
&& tar xvf /tmp/vulkansdk.tar.xz -C /tmp \
|
||||||
|
&& /tmp/${VULKANVERSION}/vulkansdk -j 8 vulkan-headers \
|
||||||
|
&& /tmp/${VULKANVERSION}/vulkansdk -j 8 shaderc \
|
||||||
|
&& cp -r /tmp/${VULKANVERSION}/x86_64/include/* /usr/local/include/ \
|
||||||
|
&& cp -r /tmp/${VULKANVERSION}/x86_64/lib/* /usr/local/lib \
|
||||||
|
&& cp -r /tmp/${VULKANVERSION}/x86_64/bin/* /usr/local/bin/ \
|
||||||
|
&& rm -rf /tmp/${VULKANVERSION} /tmp/vulkansdk.tar.xz
|
||||||
COPY CMakeLists.txt CMakePresets.json .
|
COPY CMakeLists.txt CMakePresets.json .
|
||||||
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
|
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
|
||||||
RUN --mount=type=cache,target=/root/.ccache \
|
RUN --mount=type=cache,target=/root/.ccache \
|
||||||
cmake --preset 'Vulkan' \
|
cmake --preset 'Vulkan' \
|
||||||
&& cmake --build --parallel --preset 'Vulkan' \
|
&& cmake --build --preset 'Vulkan' -- -l $(nproc) \
|
||||||
&& cmake --install build --component Vulkan --strip --parallel 8
|
&& cmake --install build --component Vulkan --strip
|
||||||
|
|
||||||
FROM base AS mlx
|
FROM base AS mlx
|
||||||
ARG CUDA13VERSION=13.0
|
ARG CUDA13VERSION=13.0
|
||||||
@@ -143,20 +152,27 @@ ENV PATH=/usr/local/cuda-13/bin:$PATH
|
|||||||
ENV BLAS_INCLUDE_DIRS=/usr/include/openblas
|
ENV BLAS_INCLUDE_DIRS=/usr/include/openblas
|
||||||
ENV LAPACK_INCLUDE_DIRS=/usr/include/openblas
|
ENV LAPACK_INCLUDE_DIRS=/usr/include/openblas
|
||||||
ENV CGO_LDFLAGS="-L/usr/local/cuda-13/lib64 -L/usr/local/cuda-13/targets/x86_64-linux/lib/stubs"
|
ENV CGO_LDFLAGS="-L/usr/local/cuda-13/lib64 -L/usr/local/cuda-13/targets/x86_64-linux/lib/stubs"
|
||||||
ARG PARALLEL
|
|
||||||
WORKDIR /go/src/github.com/ollama/ollama
|
WORKDIR /go/src/github.com/ollama/ollama
|
||||||
COPY CMakeLists.txt CMakePresets.json .
|
COPY CMakeLists.txt CMakePresets.json .
|
||||||
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
|
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
|
||||||
COPY x/ml/backend/mlx x/ml/backend/mlx
|
COPY x/imagegen/mlx x/imagegen/mlx
|
||||||
COPY go.mod go.sum .
|
COPY go.mod go.sum .
|
||||||
COPY MLX_VERSION .
|
COPY MLX_VERSION MLX_CORE_VERSION .
|
||||||
RUN curl -fsSL https://golang.org/dl/go$(awk '/^go/ { print $2 }' go.mod).linux-$(case $(uname -m) in x86_64) echo amd64 ;; aarch64) echo arm64 ;; esac).tar.gz | tar xz -C /usr/local
|
RUN curl -fsSL https://golang.org/dl/go$(awk '/^go/ { print $2 }' go.mod).linux-$(case $(uname -m) in x86_64) echo amd64 ;; aarch64) echo arm64 ;; esac).tar.gz | tar xz -C /usr/local
|
||||||
ENV PATH=/usr/local/go/bin:$PATH
|
ENV PATH=/usr/local/go/bin:$PATH
|
||||||
RUN go mod download
|
RUN go mod download
|
||||||
RUN --mount=type=cache,target=/root/.ccache \
|
RUN --mount=type=cache,target=/root/.ccache \
|
||||||
cmake --preset 'MLX CUDA 13' -DBLAS_INCLUDE_DIRS=/usr/include/openblas -DLAPACK_INCLUDE_DIRS=/usr/include/openblas \
|
--mount=type=bind,from=local-mlx,target=/tmp/local-mlx \
|
||||||
&& cmake --build --parallel ${PARALLEL} --preset 'MLX CUDA 13' \
|
--mount=type=bind,from=local-mlx-c,target=/tmp/local-mlx-c \
|
||||||
&& cmake --install build --component MLX --strip --parallel ${PARALLEL}
|
if [ -f /tmp/local-mlx/CMakeLists.txt ]; then \
|
||||||
|
export OLLAMA_MLX_SOURCE=/tmp/local-mlx; \
|
||||||
|
fi \
|
||||||
|
&& if [ -f /tmp/local-mlx-c/CMakeLists.txt ]; then \
|
||||||
|
export OLLAMA_MLX_C_SOURCE=/tmp/local-mlx-c; \
|
||||||
|
fi \
|
||||||
|
&& cmake --preset 'MLX CUDA 13' -DBLAS_INCLUDE_DIRS=/usr/include/openblas -DLAPACK_INCLUDE_DIRS=/usr/include/openblas \
|
||||||
|
&& cmake --build --preset 'MLX CUDA 13' -- -l $(nproc) \
|
||||||
|
&& cmake --install build --component MLX --strip
|
||||||
|
|
||||||
FROM base AS build
|
FROM base AS build
|
||||||
WORKDIR /go/src/github.com/ollama/ollama
|
WORKDIR /go/src/github.com/ollama/ollama
|
||||||
@@ -165,16 +181,14 @@ RUN curl -fsSL https://golang.org/dl/go$(awk '/^go/ { print $2 }' go.mod).linux-
|
|||||||
ENV PATH=/usr/local/go/bin:$PATH
|
ENV PATH=/usr/local/go/bin:$PATH
|
||||||
RUN go mod download
|
RUN go mod download
|
||||||
COPY . .
|
COPY . .
|
||||||
# Clone mlx-c headers for CGO (version from MLX_VERSION file)
|
|
||||||
RUN git clone --depth 1 --branch "$(cat MLX_VERSION)" https://github.com/ml-explore/mlx-c.git build/_deps/mlx-c-src
|
|
||||||
ARG GOFLAGS="'-ldflags=-w -s'"
|
ARG GOFLAGS="'-ldflags=-w -s'"
|
||||||
ENV CGO_ENABLED=1
|
ENV CGO_ENABLED=1
|
||||||
ARG CGO_CFLAGS
|
ARG CGO_CFLAGS
|
||||||
ARG CGO_CXXFLAGS
|
ARG CGO_CXXFLAGS
|
||||||
ENV CGO_CFLAGS="${CGO_CFLAGS} -I/go/src/github.com/ollama/ollama/build/_deps/mlx-c-src"
|
ENV CGO_CFLAGS="${CGO_CFLAGS}"
|
||||||
ENV CGO_CXXFLAGS="${CGO_CXXFLAGS}"
|
ENV CGO_CXXFLAGS="${CGO_CXXFLAGS}"
|
||||||
RUN --mount=type=cache,target=/root/.cache/go-build \
|
RUN --mount=type=cache,target=/root/.cache/go-build \
|
||||||
go build -tags mlx -trimpath -buildmode=pie -o /bin/ollama .
|
go build -trimpath -buildmode=pie -o /bin/ollama .
|
||||||
|
|
||||||
FROM --platform=linux/amd64 scratch AS amd64
|
FROM --platform=linux/amd64 scratch AS amd64
|
||||||
# COPY --from=cuda-11 dist/lib/ollama/ /lib/ollama/
|
# COPY --from=cuda-11 dist/lib/ollama/ /lib/ollama/
|
||||||
@@ -191,10 +205,9 @@ COPY --from=jetpack-5 dist/lib/ollama/ /lib/ollama/
|
|||||||
COPY --from=jetpack-6 dist/lib/ollama/ /lib/ollama/
|
COPY --from=jetpack-6 dist/lib/ollama/ /lib/ollama/
|
||||||
|
|
||||||
FROM scratch AS rocm
|
FROM scratch AS rocm
|
||||||
COPY --from=rocm-6 dist/lib/ollama /lib/ollama
|
COPY --from=rocm-7 dist/lib/ollama /lib/ollama
|
||||||
|
|
||||||
FROM ${FLAVOR} AS archive
|
FROM ${FLAVOR} AS archive
|
||||||
ARG VULKANVERSION
|
|
||||||
COPY --from=cpu dist/lib/ollama /lib/ollama
|
COPY --from=cpu dist/lib/ollama /lib/ollama
|
||||||
COPY --from=build /bin/ollama /bin/ollama
|
COPY --from=build /bin/ollama /bin/ollama
|
||||||
|
|
||||||
|
|||||||
1
MLX_CORE_VERSION
Normal file
1
MLX_CORE_VERSION
Normal file
@@ -0,0 +1 @@
|
|||||||
|
v0.30.6
|
||||||
@@ -1 +1 @@
|
|||||||
v0.4.1
|
v0.5.0
|
||||||
|
|||||||
910
README.md
910
README.md
@@ -1,20 +1,30 @@
|
|||||||
<div align="center">
|
<p align="center">
|
||||||
<a href="https://ollama.com">
|
<a href="https://ollama.com">
|
||||||
<img alt="ollama" width="240" src="https://github.com/ollama/ollama/assets/3325447/0d0b44e2-8f4a-4e99-9b52-a5c1c741c8f7">
|
<img src="https://github.com/ollama/ollama/assets/3325447/0d0b44e2-8f4a-4e99-9b52-a5c1c741c8f7" alt="ollama" width="200"/>
|
||||||
</a>
|
</a>
|
||||||
</div>
|
</p>
|
||||||
|
|
||||||
# Ollama
|
# Ollama
|
||||||
|
|
||||||
Get up and running with large language models.
|
Start building with open models.
|
||||||
|
|
||||||
|
## Download
|
||||||
|
|
||||||
### macOS
|
### macOS
|
||||||
|
|
||||||
[Download](https://ollama.com/download/Ollama.dmg)
|
```shell
|
||||||
|
curl -fsSL https://ollama.com/install.sh | sh
|
||||||
|
```
|
||||||
|
|
||||||
|
or [download manually](https://ollama.com/download/Ollama.dmg)
|
||||||
|
|
||||||
### Windows
|
### Windows
|
||||||
|
|
||||||
[Download](https://ollama.com/download/OllamaSetup.exe)
|
```shell
|
||||||
|
irm https://ollama.com/install.ps1 | iex
|
||||||
|
```
|
||||||
|
|
||||||
|
or [download manually](https://ollama.com/download/OllamaSetup.exe)
|
||||||
|
|
||||||
### Linux
|
### Linux
|
||||||
|
|
||||||
@@ -36,647 +46,311 @@ The official [Ollama Docker image](https://hub.docker.com/r/ollama/ollama) `olla
|
|||||||
### Community
|
### Community
|
||||||
|
|
||||||
- [Discord](https://discord.gg/ollama)
|
- [Discord](https://discord.gg/ollama)
|
||||||
|
- [𝕏 (Twitter)](https://x.com/ollama)
|
||||||
- [Reddit](https://reddit.com/r/ollama)
|
- [Reddit](https://reddit.com/r/ollama)
|
||||||
|
|
||||||
## Quickstart
|
## Get started
|
||||||
|
|
||||||
To run and chat with [Gemma 3](https://ollama.com/library/gemma3):
|
```
|
||||||
|
ollama
|
||||||
|
```
|
||||||
|
|
||||||
```shell
|
You'll be prompted to run a model or connect Ollama to your existing agents or applications such as `claude`, `codex`, `openclaw` and more.
|
||||||
|
|
||||||
|
### Coding
|
||||||
|
|
||||||
|
To launch a specific integration:
|
||||||
|
|
||||||
|
```
|
||||||
|
ollama launch claude
|
||||||
|
```
|
||||||
|
|
||||||
|
Supported integrations include [Claude Code](https://docs.ollama.com/integrations/claude-code), [Codex](https://docs.ollama.com/integrations/codex), [Droid](https://docs.ollama.com/integrations/droid), and [OpenCode](https://docs.ollama.com/integrations/opencode).
|
||||||
|
|
||||||
|
### AI assistant
|
||||||
|
|
||||||
|
Use [OpenClaw](https://docs.ollama.com/integrations/openclaw) to turn Ollama into a personal AI assistant across WhatsApp, Telegram, Slack, Discord, and more:
|
||||||
|
|
||||||
|
```
|
||||||
|
ollama launch openclaw
|
||||||
|
```
|
||||||
|
|
||||||
|
### Chat with a model
|
||||||
|
|
||||||
|
Run and chat with [Gemma 3](https://ollama.com/library/gemma3):
|
||||||
|
|
||||||
|
```
|
||||||
ollama run gemma3
|
ollama run gemma3
|
||||||
```
|
```
|
||||||
|
|
||||||
## Model library
|
See [ollama.com/library](https://ollama.com/library) for the full list.
|
||||||
|
|
||||||
Ollama supports a list of models available on [ollama.com/library](https://ollama.com/library "ollama model library")
|
See the [quickstart guide](https://docs.ollama.com/quickstart) for more details.
|
||||||
|
|
||||||
Here are some example models that can be downloaded:
|
|
||||||
|
|
||||||
| Model | Parameters | Size | Download |
|
|
||||||
| ------------------ | ---------- | ----- | -------------------------------- |
|
|
||||||
| Gemma 3 | 1B | 815MB | `ollama run gemma3:1b` |
|
|
||||||
| Gemma 3 | 4B | 3.3GB | `ollama run gemma3` |
|
|
||||||
| Gemma 3 | 12B | 8.1GB | `ollama run gemma3:12b` |
|
|
||||||
| Gemma 3 | 27B | 17GB | `ollama run gemma3:27b` |
|
|
||||||
| QwQ | 32B | 20GB | `ollama run qwq` |
|
|
||||||
| DeepSeek-R1 | 7B | 4.7GB | `ollama run deepseek-r1` |
|
|
||||||
| DeepSeek-R1 | 671B | 404GB | `ollama run deepseek-r1:671b` |
|
|
||||||
| Llama 4 | 109B | 67GB | `ollama run llama4:scout` |
|
|
||||||
| Llama 4 | 400B | 245GB | `ollama run llama4:maverick` |
|
|
||||||
| Llama 3.3 | 70B | 43GB | `ollama run llama3.3` |
|
|
||||||
| Llama 3.2 | 3B | 2.0GB | `ollama run llama3.2` |
|
|
||||||
| Llama 3.2 | 1B | 1.3GB | `ollama run llama3.2:1b` |
|
|
||||||
| Llama 3.2 Vision | 11B | 7.9GB | `ollama run llama3.2-vision` |
|
|
||||||
| Llama 3.2 Vision | 90B | 55GB | `ollama run llama3.2-vision:90b` |
|
|
||||||
| Llama 3.1 | 8B | 4.7GB | `ollama run llama3.1` |
|
|
||||||
| Llama 3.1 | 405B | 231GB | `ollama run llama3.1:405b` |
|
|
||||||
| Phi 4 | 14B | 9.1GB | `ollama run phi4` |
|
|
||||||
| Phi 4 Mini | 3.8B | 2.5GB | `ollama run phi4-mini` |
|
|
||||||
| Mistral | 7B | 4.1GB | `ollama run mistral` |
|
|
||||||
| Moondream 2 | 1.4B | 829MB | `ollama run moondream` |
|
|
||||||
| Neural Chat | 7B | 4.1GB | `ollama run neural-chat` |
|
|
||||||
| Starling | 7B | 4.1GB | `ollama run starling-lm` |
|
|
||||||
| Code Llama | 7B | 3.8GB | `ollama run codellama` |
|
|
||||||
| Llama 2 Uncensored | 7B | 3.8GB | `ollama run llama2-uncensored` |
|
|
||||||
| LLaVA | 7B | 4.5GB | `ollama run llava` |
|
|
||||||
| Granite-3.3 | 8B | 4.9GB | `ollama run granite3.3` |
|
|
||||||
|
|
||||||
> [!NOTE]
|
|
||||||
> You should have at least 8 GB of RAM available to run the 7B models, 16 GB to run the 13B models, and 32 GB to run the 33B models.
|
|
||||||
|
|
||||||
## Customize a model
|
|
||||||
|
|
||||||
### Import from GGUF
|
|
||||||
|
|
||||||
Ollama supports importing GGUF models in the Modelfile:
|
|
||||||
|
|
||||||
1. Create a file named `Modelfile`, with a `FROM` instruction with the local filepath to the model you want to import.
|
|
||||||
|
|
||||||
```
|
|
||||||
FROM ./vicuna-33b.Q4_0.gguf
|
|
||||||
```
|
|
||||||
|
|
||||||
2. Create the model in Ollama
|
|
||||||
|
|
||||||
```shell
|
|
||||||
ollama create example -f Modelfile
|
|
||||||
```
|
|
||||||
|
|
||||||
3. Run the model
|
|
||||||
|
|
||||||
```shell
|
|
||||||
ollama run example
|
|
||||||
```
|
|
||||||
|
|
||||||
### Import from Safetensors
|
|
||||||
|
|
||||||
See the [guide](https://docs.ollama.com/import) on importing models for more information.
|
|
||||||
|
|
||||||
### Customize a prompt
|
|
||||||
|
|
||||||
Models from the Ollama library can be customized with a prompt. For example, to customize the `llama3.2` model:
|
|
||||||
|
|
||||||
```shell
|
|
||||||
ollama pull llama3.2
|
|
||||||
```
|
|
||||||
|
|
||||||
Create a `Modelfile`:
|
|
||||||
|
|
||||||
```
|
|
||||||
FROM llama3.2
|
|
||||||
|
|
||||||
# set the temperature to 1 [higher is more creative, lower is more coherent]
|
|
||||||
PARAMETER temperature 1
|
|
||||||
|
|
||||||
# set the system message
|
|
||||||
SYSTEM """
|
|
||||||
You are Mario from Super Mario Bros. Answer as Mario, the assistant, only.
|
|
||||||
"""
|
|
||||||
```
|
|
||||||
|
|
||||||
Next, create and run the model:
|
|
||||||
|
|
||||||
```
|
|
||||||
ollama create mario -f ./Modelfile
|
|
||||||
ollama run mario
|
|
||||||
>>> hi
|
|
||||||
Hello! It's your friend Mario.
|
|
||||||
```
|
|
||||||
|
|
||||||
For more information on working with a Modelfile, see the [Modelfile](https://docs.ollama.com/modelfile) documentation.
|
|
||||||
|
|
||||||
## CLI Reference
|
|
||||||
|
|
||||||
### Create a model
|
|
||||||
|
|
||||||
`ollama create` is used to create a model from a Modelfile.
|
|
||||||
|
|
||||||
```shell
|
|
||||||
ollama create mymodel -f ./Modelfile
|
|
||||||
```
|
|
||||||
|
|
||||||
### Pull a model
|
|
||||||
|
|
||||||
```shell
|
|
||||||
ollama pull llama3.2
|
|
||||||
```
|
|
||||||
|
|
||||||
> This command can also be used to update a local model. Only the diff will be pulled.
|
|
||||||
|
|
||||||
### Remove a model
|
|
||||||
|
|
||||||
```shell
|
|
||||||
ollama rm llama3.2
|
|
||||||
```
|
|
||||||
|
|
||||||
### Copy a model
|
|
||||||
|
|
||||||
```shell
|
|
||||||
ollama cp llama3.2 my-model
|
|
||||||
```
|
|
||||||
|
|
||||||
### Multiline input
|
|
||||||
|
|
||||||
For multiline input, you can wrap text with `"""`:
|
|
||||||
|
|
||||||
```
|
|
||||||
>>> """Hello,
|
|
||||||
... world!
|
|
||||||
... """
|
|
||||||
I'm a basic program that prints the famous "Hello, world!" message to the console.
|
|
||||||
```
|
|
||||||
|
|
||||||
### Multimodal models
|
|
||||||
|
|
||||||
```
|
|
||||||
ollama run llava "What's in this image? /Users/jmorgan/Desktop/smile.png"
|
|
||||||
```
|
|
||||||
|
|
||||||
> **Output**: The image features a yellow smiley face, which is likely the central focus of the picture.
|
|
||||||
|
|
||||||
### Pass the prompt as an argument
|
|
||||||
|
|
||||||
```shell
|
|
||||||
ollama run llama3.2 "Summarize this file: $(cat README.md)"
|
|
||||||
```
|
|
||||||
|
|
||||||
> **Output**: Ollama is a lightweight, extensible framework for building and running language models on the local machine. It provides a simple API for creating, running, and managing models, as well as a library of pre-built models that can be easily used in a variety of applications.
|
|
||||||
|
|
||||||
### Show model information
|
|
||||||
|
|
||||||
```shell
|
|
||||||
ollama show llama3.2
|
|
||||||
```
|
|
||||||
|
|
||||||
### List models on your computer
|
|
||||||
|
|
||||||
```shell
|
|
||||||
ollama list
|
|
||||||
```
|
|
||||||
|
|
||||||
### List which models are currently loaded
|
|
||||||
|
|
||||||
```shell
|
|
||||||
ollama ps
|
|
||||||
```
|
|
||||||
|
|
||||||
### Stop a model which is currently running
|
|
||||||
|
|
||||||
```shell
|
|
||||||
ollama stop llama3.2
|
|
||||||
```
|
|
||||||
|
|
||||||
### Generate embeddings from the CLI
|
|
||||||
|
|
||||||
```shell
|
|
||||||
ollama run embeddinggemma "Your text to embed"
|
|
||||||
```
|
|
||||||
|
|
||||||
You can also pipe text for scripted workflows:
|
|
||||||
|
|
||||||
```shell
|
|
||||||
echo "Your text to embed" | ollama run embeddinggemma
|
|
||||||
```
|
|
||||||
|
|
||||||
### Start Ollama
|
|
||||||
|
|
||||||
`ollama serve` is used when you want to start ollama without running the desktop application.
|
|
||||||
|
|
||||||
## Building
|
|
||||||
|
|
||||||
See the [developer guide](https://github.com/ollama/ollama/blob/main/docs/development.md)
|
|
||||||
|
|
||||||
### Running local builds
|
|
||||||
|
|
||||||
Next, start the server:
|
|
||||||
|
|
||||||
```shell
|
|
||||||
./ollama serve
|
|
||||||
```
|
|
||||||
|
|
||||||
Finally, in a separate shell, run a model:
|
|
||||||
|
|
||||||
```shell
|
|
||||||
./ollama run llama3.2
|
|
||||||
```
|
|
||||||
|
|
||||||
## Building with MLX (experimental)
|
|
||||||
|
|
||||||
First build the MLX libraries:
|
|
||||||
|
|
||||||
```shell
|
|
||||||
cmake --preset MLX
|
|
||||||
cmake --build --preset MLX --parallel
|
|
||||||
cmake --install build --component MLX
|
|
||||||
```
|
|
||||||
|
|
||||||
When building with the `-tags mlx` flag, the main `ollama` binary includes MLX support for experimental features like image generation:
|
|
||||||
|
|
||||||
```shell
|
|
||||||
go build -tags mlx .
|
|
||||||
```
|
|
||||||
|
|
||||||
Finally, start the server:
|
|
||||||
|
|
||||||
```
|
|
||||||
./ollama serve
|
|
||||||
```
|
|
||||||
|
|
||||||
### Building MLX with CUDA
|
|
||||||
|
|
||||||
When building with CUDA, use the preset "MLX CUDA 13" or "MLX CUDA 12" to enable CUDA with default architectures:
|
|
||||||
|
|
||||||
```shell
|
|
||||||
cmake --preset 'MLX CUDA 13'
|
|
||||||
cmake --build --preset 'MLX CUDA 13' --parallel
|
|
||||||
cmake --install build --component MLX
|
|
||||||
```
|
|
||||||
|
|
||||||
## REST API
|
## REST API
|
||||||
|
|
||||||
Ollama has a REST API for running and managing models.
|
Ollama has a REST API for running and managing models.
|
||||||
|
|
||||||
### Generate a response
|
|
||||||
|
|
||||||
```shell
|
|
||||||
curl http://localhost:11434/api/generate -d '{
|
|
||||||
"model": "llama3.2",
|
|
||||||
"prompt":"Why is the sky blue?"
|
|
||||||
}'
|
|
||||||
```
|
```
|
||||||
|
|
||||||
### Chat with a model
|
|
||||||
|
|
||||||
```shell
|
|
||||||
curl http://localhost:11434/api/chat -d '{
|
curl http://localhost:11434/api/chat -d '{
|
||||||
"model": "llama3.2",
|
"model": "gemma3",
|
||||||
"messages": [
|
"messages": [{
|
||||||
{ "role": "user", "content": "why is the sky blue?" }
|
"role": "user",
|
||||||
]
|
"content": "Why is the sky blue?"
|
||||||
|
}],
|
||||||
|
"stream": false
|
||||||
}'
|
}'
|
||||||
```
|
```
|
||||||
|
|
||||||
See the [API documentation](./docs/api.md) for all endpoints.
|
See the [API documentation](https://docs.ollama.com/api) for all endpoints.
|
||||||
|
|
||||||
|
### Python
|
||||||
|
|
||||||
|
```
|
||||||
|
pip install ollama
|
||||||
|
```
|
||||||
|
|
||||||
|
```python
|
||||||
|
from ollama import chat
|
||||||
|
|
||||||
|
response = chat(model='gemma3', messages=[
|
||||||
|
{
|
||||||
|
'role': 'user',
|
||||||
|
'content': 'Why is the sky blue?',
|
||||||
|
},
|
||||||
|
])
|
||||||
|
print(response.message.content)
|
||||||
|
```
|
||||||
|
|
||||||
|
### JavaScript
|
||||||
|
|
||||||
|
```
|
||||||
|
npm i ollama
|
||||||
|
```
|
||||||
|
|
||||||
|
```javascript
|
||||||
|
import ollama from "ollama";
|
||||||
|
|
||||||
|
const response = await ollama.chat({
|
||||||
|
model: "gemma3",
|
||||||
|
messages: [{ role: "user", content: "Why is the sky blue?" }],
|
||||||
|
});
|
||||||
|
console.log(response.message.content);
|
||||||
|
```
|
||||||
|
|
||||||
|
## Supported backends
|
||||||
|
|
||||||
|
- [llama.cpp](https://github.com/ggml-org/llama.cpp) project founded by Georgi Gerganov.
|
||||||
|
|
||||||
|
## Documentation
|
||||||
|
|
||||||
|
- [CLI reference](https://docs.ollama.com/cli)
|
||||||
|
- [REST API reference](https://docs.ollama.com/api)
|
||||||
|
- [Importing models](https://docs.ollama.com/import)
|
||||||
|
- [Modelfile reference](https://docs.ollama.com/modelfile)
|
||||||
|
- [Building from source](https://github.com/ollama/ollama/blob/main/docs/development.md)
|
||||||
|
|
||||||
## Community Integrations

> Want to add your project? Open a pull request.
|
||||||
|
|
||||||
- [Onyx](https://github.com/onyx-dot-app/onyx)
|
### Chat Interfaces
|
||||||
- [Open WebUI](https://github.com/open-webui/open-webui)
|
|
||||||
- [SwiftChat (macOS with ReactNative)](https://github.com/aws-samples/swift-chat)
|
|
||||||
- [Enchanted (macOS native)](https://github.com/AugustDev/enchanted)
|
|
||||||
- [Hollama](https://github.com/fmaclen/hollama)
|
|
||||||
- [Lollms WebUI (Single user)](https://github.com/ParisNeo/lollms-webui)
|
|
||||||
- [Lollms (Multi users)](https://github.com/ParisNeo/lollms)
|
|
||||||
- [LibreChat](https://github.com/danny-avila/LibreChat)
|
|
||||||
- [Bionic GPT](https://github.com/bionic-gpt/bionic-gpt)
|
|
||||||
- [HTML UI](https://github.com/rtcfirefly/ollama-ui)
|
|
||||||
- [AI-UI](https://github.com/bajahaw/ai-ui)
|
|
||||||
- [Saddle](https://github.com/jikkuatwork/saddle)
|
|
||||||
- [TagSpaces](https://www.tagspaces.org) (A platform for file-based apps, [utilizing Ollama](https://docs.tagspaces.org/ai/) for the generation of tags and descriptions)
|
|
||||||
- [Chatbot UI](https://github.com/ivanfioravanti/chatbot-ollama)
|
|
||||||
- [Chatbot UI v2](https://github.com/mckaywrigley/chatbot-ui)
|
|
||||||
- [Typescript UI](https://github.com/ollama-interface/Ollama-Gui?tab=readme-ov-file)
|
|
||||||
- [Minimalistic React UI for Ollama Models](https://github.com/richawo/minimal-llm-ui)
|
|
||||||
- [Ollamac](https://github.com/kevinhermawan/Ollamac)
|
|
||||||
- [big-AGI](https://github.com/enricoros/big-AGI)
|
|
||||||
- [Cheshire Cat assistant framework](https://github.com/cheshire-cat-ai/core)
|
|
||||||
- [Amica](https://github.com/semperai/amica)
|
|
||||||
- [chatd](https://github.com/BruceMacD/chatd)
|
|
||||||
- [Ollama-SwiftUI](https://github.com/kghandour/Ollama-SwiftUI)
|
|
||||||
- [Dify.AI](https://github.com/langgenius/dify)
|
|
||||||
- [MindMac](https://mindmac.app)
|
|
||||||
- [NextJS Web Interface for Ollama](https://github.com/jakobhoeg/nextjs-ollama-llm-ui)
|
|
||||||
- [Msty](https://msty.app)
|
|
||||||
- [Chatbox](https://github.com/Bin-Huang/Chatbox)
|
|
||||||
- [WinForm Ollama Copilot](https://github.com/tgraupmann/WinForm_Ollama_Copilot)
|
|
||||||
- [NextChat](https://github.com/ChatGPTNextWeb/ChatGPT-Next-Web) with [Get Started Doc](https://docs.nextchat.dev/models/ollama)
|
|
||||||
- [Alpaca WebUI](https://github.com/mmo80/alpaca-webui)
|
|
||||||
- [OllamaGUI](https://github.com/enoch1118/ollamaGUI)
|
|
||||||
- [OpenAOE](https://github.com/InternLM/OpenAOE)
|
|
||||||
- [Odin Runes](https://github.com/leonid20000/OdinRunes)
|
|
||||||
- [LLM-X](https://github.com/mrdjohnson/llm-x) (Progressive Web App)
|
|
||||||
- [AnythingLLM (Docker + MacOs/Windows/Linux native app)](https://github.com/Mintplex-Labs/anything-llm)
|
|
||||||
- [Ollama Basic Chat: Uses HyperDiv Reactive UI](https://github.com/rapidarchitect/ollama_basic_chat)
|
|
||||||
- [Ollama-chats RPG](https://github.com/drazdra/ollama-chats)
|
|
||||||
- [IntelliBar](https://intellibar.app/) (AI-powered assistant for macOS)
|
|
||||||
- [Jirapt](https://github.com/AliAhmedNada/jirapt) (Jira Integration to generate issues, tasks, epics)
|
|
||||||
- [ojira](https://github.com/AliAhmedNada/ojira) (Jira chrome plugin to easily generate descriptions for tasks)
|
|
||||||
- [QA-Pilot](https://github.com/reid41/QA-Pilot) (Interactive chat tool that can leverage Ollama models for rapid understanding and navigation of GitHub code repositories)
|
|
||||||
- [ChatOllama](https://github.com/sugarforever/chat-ollama) (Open Source Chatbot based on Ollama with Knowledge Bases)
|
|
||||||
- [CRAG Ollama Chat](https://github.com/Nagi-ovo/CRAG-Ollama-Chat) (Simple Web Search with Corrective RAG)
|
|
||||||
- [RAGFlow](https://github.com/infiniflow/ragflow) (Open-source Retrieval-Augmented Generation engine based on deep document understanding)
|
|
||||||
- [StreamDeploy](https://github.com/StreamDeploy-DevRel/streamdeploy-llm-app-scaffold) (LLM Application Scaffold)
|
|
||||||
- [chat](https://github.com/swuecho/chat) (chat web app for teams)
|
|
||||||
- [Lobe Chat](https://github.com/lobehub/lobe-chat) with [Integrating Doc](https://lobehub.com/docs/self-hosting/examples/ollama)
|
|
||||||
- [Ollama RAG Chatbot](https://github.com/datvodinh/rag-chatbot.git) (Local Chat with multiple PDFs using Ollama and RAG)
|
|
||||||
- [BrainSoup](https://www.nurgo-software.com/products/brainsoup) (Flexible native client with RAG & multi-agent automation)
|
|
||||||
- [macai](https://github.com/Renset/macai) (macOS client for Ollama, ChatGPT, and other compatible API back-ends)
|
|
||||||
- [RWKV-Runner](https://github.com/josStorer/RWKV-Runner) (RWKV offline LLM deployment tool, also usable as a client for ChatGPT and Ollama)
|
|
||||||
- [Ollama Grid Search](https://github.com/dezoito/ollama-grid-search) (app to evaluate and compare models)
|
|
||||||
- [Olpaka](https://github.com/Otacon/olpaka) (User-friendly Flutter Web App for Ollama)
|
|
||||||
- [Casibase](https://casibase.org) (An open source AI knowledge base and dialogue system combining the latest RAG, SSO, ollama support, and multiple large language models.)
|
|
||||||
- [OllamaSpring](https://github.com/CrazyNeil/OllamaSpring) (Ollama Client for macOS)
|
|
||||||
- [LLocal.in](https://github.com/kartikm7/llocal) (Easy to use Electron Desktop Client for Ollama)
|
|
||||||
- [Shinkai Desktop](https://github.com/dcSpark/shinkai-apps) (Two click install Local AI using Ollama + Files + RAG)
|
|
||||||
- [AiLama](https://github.com/zeyoyt/ailama) (A Discord User App that allows you to interact with Ollama anywhere in Discord)
|
|
||||||
- [Ollama with Google Mesop](https://github.com/rapidarchitect/ollama_mesop/) (Mesop Chat Client implementation with Ollama)
|
|
||||||
- [R2R](https://github.com/SciPhi-AI/R2R) (Open-source RAG engine)
|
|
||||||
- [Ollama-Kis](https://github.com/elearningshow/ollama-kis) (A simple easy-to-use GUI with sample custom LLM for Drivers Education)
|
|
||||||
- [OpenGPA](https://opengpa.org) (Open-source offline-first Enterprise Agentic Application)
|
|
||||||
- [Painting Droid](https://github.com/mateuszmigas/painting-droid) (Painting app with AI integrations)
|
|
||||||
- [Kerlig AI](https://www.kerlig.com/) (AI writing assistant for macOS)
|
|
||||||
- [AI Studio](https://github.com/MindWorkAI/AI-Studio)
|
|
||||||
- [Sidellama](https://github.com/gyopak/sidellama) (browser-based LLM client)
|
|
||||||
- [LLMStack](https://github.com/trypromptly/LLMStack) (No-code multi-agent framework to build LLM agents and workflows)
|
|
||||||
- [BoltAI for Mac](https://boltai.com) (AI Chat Client for Mac)
|
|
||||||
- [Harbor](https://github.com/av/harbor) (Containerized LLM Toolkit with Ollama as default backend)
|
|
||||||
- [PyGPT](https://github.com/szczyglis-dev/py-gpt) (AI desktop assistant for Linux, Windows, and Mac)
|
|
||||||
- [Alpaca](https://github.com/Jeffser/Alpaca) (An Ollama client application for Linux and macOS made with GTK4 and Adwaita)
|
|
||||||
- [AutoGPT](https://github.com/Significant-Gravitas/AutoGPT/blob/master/docs/content/platform/ollama.md) (AutoGPT Ollama integration)
|
|
||||||
- [Go-CREW](https://www.jonathanhecl.com/go-crew/) (Powerful Offline RAG in Golang)
|
|
||||||
- [PartCAD](https://github.com/openvmp/partcad/) (CAD model generation with OpenSCAD and CadQuery)
|
|
||||||
- [Ollama4j Web UI](https://github.com/ollama4j/ollama4j-web-ui) - Java-based Web UI for Ollama built with Vaadin, Spring Boot, and Ollama4j
|
|
||||||
- [PyOllaMx](https://github.com/kspviswa/pyOllaMx) - macOS application capable of chatting with both Ollama and Apple MLX models.
|
|
||||||
- [Cline](https://github.com/cline/cline) - Formerly known as Claude Dev is a VS Code extension for multi-file/whole-repo coding
|
|
||||||
- [Void](https://github.com/voideditor/void) (Open source AI code editor and Cursor alternative)
|
|
||||||
- [Cherry Studio](https://github.com/kangfenmao/cherry-studio) (Desktop client with Ollama support)
|
|
||||||
- [ConfiChat](https://github.com/1runeberg/confichat) (Lightweight, standalone, multi-platform, and privacy-focused LLM chat interface with optional encryption)
|
|
||||||
- [Archyve](https://github.com/nickthecook/archyve) (RAG-enabling document library)
|
|
||||||
- [crewAI with Mesop](https://github.com/rapidarchitect/ollama-crew-mesop) (Mesop Web Interface to run crewAI with Ollama)
|
|
||||||
- [Tkinter-based client](https://github.com/chyok/ollama-gui) (Python tkinter-based Client for Ollama)
|
|
||||||
- [LLMChat](https://github.com/trendy-design/llmchat) (Privacy focused, 100% local, intuitive all-in-one chat interface)
|
|
||||||
- [Local Multimodal AI Chat](https://github.com/Leon-Sander/Local-Multimodal-AI-Chat) (Ollama-based LLM Chat with support for multiple features, including PDF RAG, voice chat, image-based interactions, and integration with OpenAI.)
|
|
||||||
- [ARGO](https://github.com/xark-argo/argo) (Locally download and run Ollama and Huggingface models with RAG and deep research on Mac/Windows/Linux)
|
|
||||||
- [OrionChat](https://github.com/EliasPereirah/OrionChat) - OrionChat is a web interface for chatting with different AI providers
|
|
||||||
- [G1](https://github.com/bklieger-groq/g1) (Prototype of using prompting strategies to improve the LLM's reasoning through o1-like reasoning chains.)
|
|
||||||
- [Web management](https://github.com/lemonit-eric-mao/ollama-web-management) (Web management page)
|
|
||||||
- [Promptery](https://github.com/promptery/promptery) (desktop client for Ollama.)
|
|
||||||
- [Ollama App](https://github.com/JHubi1/ollama-app) (Modern and easy-to-use multi-platform client for Ollama)
|
|
||||||
- [chat-ollama](https://github.com/annilq/chat-ollama) (a React Native client for Ollama)
|
|
||||||
- [SpaceLlama](https://github.com/tcsenpai/spacellama) (Firefox and Chrome extension to quickly summarize web pages with ollama in a sidebar)
|
|
||||||
- [YouLama](https://github.com/tcsenpai/youlama) (Webapp to quickly summarize any YouTube video, supporting Invidious as well)
|
|
||||||
- [DualMind](https://github.com/tcsenpai/dualmind) (Experimental app allowing two models to talk to each other in the terminal or in a web interface)
|
|
||||||
- [ollamarama-matrix](https://github.com/h1ddenpr0cess20/ollamarama-matrix) (Ollama chatbot for the Matrix chat protocol)
|
|
||||||
- [ollama-chat-app](https://github.com/anan1213095357/ollama-chat-app) (Flutter-based chat app)
|
|
||||||
- [Perfect Memory AI](https://www.perfectmemory.ai/) (Productivity AI assists personalized by what you have seen on your screen, heard, and said in the meetings)
|
|
||||||
- [Hexabot](https://github.com/hexastack/hexabot) (A conversational AI builder)
|
|
||||||
- [Reddit Rate](https://github.com/rapidarchitect/reddit_analyzer) (Search and Rate Reddit topics with a weighted summation)
|
|
||||||
- [OpenTalkGpt](https://github.com/adarshM84/OpenTalkGpt) (Chrome Extension to manage open-source models supported by Ollama, create custom models, and chat with models from a user-friendly UI)
|
|
||||||
- [VT](https://github.com/vinhnx/vt.ai) (A minimal multimodal AI chat app, with dynamic conversation routing. Supports local models via Ollama)
|
|
||||||
- [Nosia](https://github.com/nosia-ai/nosia) (Easy to install and use RAG platform based on Ollama)
|
|
||||||
- [Witsy](https://github.com/nbonamy/witsy) (An AI Desktop application available for Mac/Windows/Linux)
|
|
||||||
- [Abbey](https://github.com/US-Artificial-Intelligence/abbey) (A configurable AI interface server with notebooks, document storage, and YouTube support)
|
|
||||||
- [Minima](https://github.com/dmayboroda/minima) (RAG with on-premises or fully local workflow)
|
|
||||||
- [aidful-ollama-model-delete](https://github.com/AidfulAI/aidful-ollama-model-delete) (User interface for simplified model cleanup)
|
|
||||||
- [Perplexica](https://github.com/ItzCrazyKns/Perplexica) (An AI-powered search engine & an open-source alternative to Perplexity AI)
|
|
||||||
- [Ollama Chat WebUI for Docker ](https://github.com/oslook/ollama-webui) (Support for local docker deployment, lightweight ollama webui)
|
|
||||||
- [AI Toolkit for Visual Studio Code](https://aka.ms/ai-tooklit/ollama-docs) (Microsoft-official VS Code extension to chat, test, evaluate models with Ollama support, and use them in your AI applications.)
|
|
||||||
- [MinimalNextOllamaChat](https://github.com/anilkay/MinimalNextOllamaChat) (Minimal Web UI for Chat and Model Control)
|
|
||||||
- [Chipper](https://github.com/TilmanGriesel/chipper) AI interface for tinkerers (Ollama, Haystack RAG, Python)
|
|
||||||
- [ChibiChat](https://github.com/CosmicEventHorizon/ChibiChat) (Kotlin-based Android app to chat with Ollama and Koboldcpp API endpoints)
|
|
||||||
- [LocalLLM](https://github.com/qusaismael/localllm) (Minimal Web-App to run ollama models on it with a GUI)
|
|
||||||
- [Ollamazing](https://github.com/buiducnhat/ollamazing) (Web extension to run Ollama models)
|
|
||||||
- [OpenDeepResearcher-via-searxng](https://github.com/benhaotang/OpenDeepResearcher-via-searxng) (A Deep Research equivalent endpoint with Ollama support for running locally)
|
|
||||||
- [AntSK](https://github.com/AIDotNet/AntSK) (Out-of-the-box & Adaptable RAG Chatbot)
|
|
||||||
- [MaxKB](https://github.com/1Panel-dev/MaxKB/) (Ready-to-use & flexible RAG Chatbot)
|
|
||||||
- [yla](https://github.com/danielekp/yla) (Web interface to freely interact with your customized models)
|
|
||||||
- [LangBot](https://github.com/RockChinQ/LangBot) (LLM-based instant messaging bots platform, with Agents, RAG features, supports multiple platforms)
|
|
||||||
- [1Panel](https://github.com/1Panel-dev/1Panel/) (Web-based Linux Server Management Tool)
|
|
||||||
- [AstrBot](https://github.com/Soulter/AstrBot/) (User-friendly LLM-based multi-platform chatbot with a WebUI, supporting RAG, LLM agents, and plugins integration)
|
|
||||||
- [Reins](https://github.com/ibrahimcetin/reins) (Easily tweak parameters, customize system prompts per chat, and enhance your AI experiments with reasoning model support.)
|
|
||||||
- [Flufy](https://github.com/Aharon-Bensadoun/Flufy) (A beautiful chat interface for interacting with Ollama's API. Built with React, TypeScript, and Material-UI.)
|
|
||||||
- [Ellama](https://github.com/zeozeozeo/ellama) (Friendly native app to chat with an Ollama instance)
|
|
||||||
- [screenpipe](https://github.com/mediar-ai/screenpipe) Build agents powered by your screen history
|
|
||||||
- [Ollamb](https://github.com/hengkysteen/ollamb) (Simple yet rich in features, cross-platform built with Flutter and designed for Ollama. Try the [web demo](https://hengkysteen.github.io/demo/ollamb/).)
|
|
||||||
- [Writeopia](https://github.com/Writeopia/Writeopia) (Text editor with integration with Ollama)
|
|
||||||
- [AppFlowy](https://github.com/AppFlowy-IO/AppFlowy) (AI collaborative workspace with Ollama, cross-platform and self-hostable)
|
|
||||||
- [Lumina](https://github.com/cushydigit/lumina.git) (A lightweight, minimal React.js frontend for interacting with Ollama servers)
|
|
||||||
- [Tiny Notepad](https://pypi.org/project/tiny-notepad) (A lightweight, notepad-like interface to chat with ollama available on PyPI)
|
|
||||||
- [macLlama (macOS native)](https://github.com/hellotunamayo/macLlama) (A native macOS GUI application for interacting with Ollama models, featuring a chat interface.)
|
|
||||||
- [GPTranslate](https://github.com/philberndt/GPTranslate) (A fast and lightweight, AI powered desktop translation application written with Rust and Tauri. Features real-time translation with OpenAI/Azure/Ollama.)
|
|
||||||
- [ollama launcher](https://github.com/NGC13009/ollama-launcher) (A launcher for Ollama, aiming to provide users with convenient functions such as ollama server launching, management, or configuration.)
|
|
||||||
- [ai-hub](https://github.com/Aj-Seven/ai-hub) (AI Hub supports multiple models via API keys and Chat support via Ollama API.)
|
|
||||||
- [Mayan EDMS](https://gitlab.com/mayan-edms/mayan-edms) (Open source document management system to organize, tag, search, and automate your files with powerful Ollama driven workflows.)
|
|
||||||
- [Serene Pub](https://github.com/doolijb/serene-pub) (Beginner friendly, open source AI Roleplaying App for Windows, Mac OS and Linux. Search, download and use models with Ollama all inside the app.)
|
|
||||||
- [Andes](https://github.com/aqerd/andes) (A Visual Studio Code extension that provides a local UI interface for Ollama models)
|
|
||||||
- [KDeps](https://github.com/kdeps/kdeps) (Kdeps is an offline-first AI framework for building Dockerized full-stack AI applications declaratively using Apple PKL and integrates APIs with Ollama on the backend.)
|
|
||||||
- [Clueless](https://github.com/KashyapTan/clueless) (Open Source & Local Cluely: A desktop application LLM assistant to help you talk to anything on your screen using locally served Ollama models. Also undetectable to screenshare)
|
|
||||||
- [ollama-co2](https://github.com/carbonatedWaterOrg/ollama-co2) (FastAPI web interface for monitoring and managing local and remote Ollama servers with real-time model monitoring and concurrent downloads)
|
|
||||||
- [Hillnote](https://hillnote.com) (A Markdown-first workspace designed to supercharge your AI workflow. Create documents ready to integrate with Claude, ChatGPT, Gemini, Cursor, and more - all while keeping your work on your device.)
|
|
||||||
|
|
||||||
#### Web
|
||||||
|
|
||||||
|
- [Open WebUI](https://github.com/open-webui/open-webui) - Extensible, self-hosted AI interface
|
||||||
|
- [Onyx](https://github.com/onyx-dot-app/onyx) - Connected AI workspace
|
||||||
|
- [LibreChat](https://github.com/danny-avila/LibreChat) - Enhanced ChatGPT clone with multi-provider support
|
||||||
|
- [Lobe Chat](https://github.com/lobehub/lobe-chat) - Modern chat framework with plugin ecosystem ([docs](https://lobehub.com/docs/self-hosting/examples/ollama))
|
||||||
|
- [NextChat](https://github.com/ChatGPTNextWeb/ChatGPT-Next-Web) - Cross-platform ChatGPT UI ([docs](https://docs.nextchat.dev/models/ollama))
|
||||||
|
- [Perplexica](https://github.com/ItzCrazyKns/Perplexica) - AI-powered search engine, open-source Perplexity alternative
|
||||||
|
- [big-AGI](https://github.com/enricoros/big-AGI) - AI suite for professionals
|
||||||
|
- [Lollms WebUI](https://github.com/ParisNeo/lollms-webui) - Multi-model web interface
|
||||||
|
- [ChatOllama](https://github.com/sugarforever/chat-ollama) - Chatbot with knowledge bases
|
||||||
|
- [Bionic GPT](https://github.com/bionic-gpt/bionic-gpt) - On-premise AI platform
|
||||||
|
- [Chatbot UI](https://github.com/ivanfioravanti/chatbot-ollama) - ChatGPT-style web interface
|
||||||
|
- [Hollama](https://github.com/fmaclen/hollama) - Minimal web interface
|
||||||
|
- [Chatbox](https://github.com/Bin-Huang/Chatbox) - Desktop and web AI client
|
||||||
|
- [chat](https://github.com/swuecho/chat) - Chat web app for teams
|
||||||
|
- [Ollama RAG Chatbot](https://github.com/datvodinh/rag-chatbot.git) - Chat with multiple PDFs using RAG
|
||||||
|
- [Tkinter-based client](https://github.com/chyok/ollama-gui) - Python desktop client
|
||||||
|
|
||||||
|
#### Desktop
|
||||||
|
|
||||||
|
- [Dify.AI](https://github.com/langgenius/dify) - LLM app development platform
|
||||||
|
- [AnythingLLM](https://github.com/Mintplex-Labs/anything-llm) - All-in-one AI app for Mac, Windows, and Linux
|
||||||
|
- [Maid](https://github.com/Mobile-Artificial-Intelligence/maid) - Cross-platform mobile and desktop client
|
||||||
|
- [Witsy](https://github.com/nbonamy/witsy) - AI desktop app for Mac, Windows, and Linux
|
||||||
|
- [Cherry Studio](https://github.com/kangfenmao/cherry-studio) - Multi-provider desktop client
|
||||||
|
- [Ollama App](https://github.com/JHubi1/ollama-app) - Multi-platform client for desktop and mobile
|
||||||
|
- [PyGPT](https://github.com/szczyglis-dev/py-gpt) - AI desktop assistant for Linux, Windows, and Mac
|
||||||
|
- [Alpaca](https://github.com/Jeffser/Alpaca) - GTK4 client for Linux and macOS
|
||||||
|
- [SwiftChat](https://github.com/aws-samples/swift-chat) - Cross-platform including iOS, Android, and Apple Vision Pro
|
||||||
|
- [Enchanted](https://github.com/AugustDev/enchanted) - Native macOS and iOS client
|
||||||
|
- [RWKV-Runner](https://github.com/josStorer/RWKV-Runner) - Multi-model desktop runner
|
||||||
|
- [Ollama Grid Search](https://github.com/dezoito/ollama-grid-search) - Evaluate and compare models
|
||||||
|
- [macai](https://github.com/Renset/macai) - macOS client for Ollama and ChatGPT
|
||||||
|
- [AI Studio](https://github.com/MindWorkAI/AI-Studio) - Multi-provider desktop IDE
|
||||||
|
- [Reins](https://github.com/ibrahimcetin/reins) - Parameter tuning and reasoning model support
|
||||||
|
- [ConfiChat](https://github.com/1runeberg/confichat) - Privacy-focused with optional encryption
|
||||||
|
- [LLocal.in](https://github.com/kartikm7/llocal) - Electron desktop client
|
||||||
|
- [MindMac](https://mindmac.app) - AI chat client for Mac
|
||||||
|
- [Msty](https://msty.app) - Multi-model desktop client
|
||||||
|
- [BoltAI for Mac](https://boltai.com) - AI chat client for Mac
|
||||||
|
- [IntelliBar](https://intellibar.app/) - AI-powered assistant for macOS
|
||||||
|
- [Kerlig AI](https://www.kerlig.com/) - AI writing assistant for macOS
|
||||||
|
- [Hillnote](https://hillnote.com) - Markdown-first AI workspace
|
||||||
|
- [Perfect Memory AI](https://www.perfectmemory.ai/) - Productivity AI personalized by screen and meeting history
|
||||||
|
|
||||||
|
#### Mobile
|
||||||
|
|
||||||
|
- [Ollama Android Chat](https://github.com/sunshine0523/OllamaServer) - One-click Ollama on Android
|
||||||
|
|
||||||
|
> SwiftChat, Enchanted, Maid, Ollama App, Reins, and ConfiChat listed above also support mobile platforms.
|
||||||
|
|
||||||
|
### Code Editors & Development
|
||||||
|
|
||||||
|
- [Cline](https://github.com/cline/cline) - VS Code extension for multi-file/whole-repo coding
|
||||||
|
- [Continue](https://github.com/continuedev/continue) - Open-source AI code assistant for any IDE
|
||||||
|
- [Void](https://github.com/voideditor/void) - Open source AI code editor, Cursor alternative
|
||||||
|
- [Copilot for Obsidian](https://github.com/logancyang/obsidian-copilot) - AI assistant for Obsidian
|
||||||
|
- [twinny](https://github.com/rjmacarthy/twinny) - Copilot and Copilot chat alternative
|
||||||
|
- [gptel Emacs client](https://github.com/karthink/gptel) - LLM client for Emacs
|
||||||
|
- [Ollama Copilot](https://github.com/bernardo-bruning/ollama-copilot) - Use Ollama as GitHub Copilot
|
||||||
|
- [Obsidian Local GPT](https://github.com/pfrankov/obsidian-local-gpt) - Local AI for Obsidian
|
||||||
|
- [Ellama Emacs client](https://github.com/s-kostyaev/ellama) - LLM tool for Emacs
|
||||||
|
- [orbiton](https://github.com/xyproto/orbiton) - Config-free text editor with Ollama tab completion
|
||||||
|
- [AI ST Completion](https://github.com/yaroslavyaroslav/OpenAI-sublime-text) - Sublime Text 4 AI assistant
|
||||||
|
- [VT Code](https://github.com/vinhnx/vtcode) - Rust-based terminal coding agent with Tree-sitter
|
||||||
|
- [QodeAssist](https://github.com/Palm1r/QodeAssist) - AI coding assistant for Qt Creator
|
||||||
|
- [AI Toolkit for VS Code](https://aka.ms/ai-tooklit/ollama-docs) - Microsoft-official VS Code extension
|
||||||
|
- [Open Interpreter](https://docs.openinterpreter.com/language-model-setup/local-models/ollama) - Natural language interface for computers
|
||||||
|
|
||||||
|
### Libraries & SDKs
|
||||||
|
|
||||||
|
- [LiteLLM](https://github.com/BerriAI/litellm) - Unified API for 100+ LLM providers
|
||||||
|
- [Semantic Kernel](https://github.com/microsoft/semantic-kernel/tree/main/python/semantic_kernel/connectors/ai/ollama) - Microsoft AI orchestration SDK
|
||||||
|
- [LangChain4j](https://github.com/langchain4j/langchain4j) - Java LangChain ([example](https://github.com/langchain4j/langchain4j-examples/tree/main/ollama-examples/src/main/java))
|
||||||
|
- [LangChainGo](https://github.com/tmc/langchaingo/) - Go LangChain ([example](https://github.com/tmc/langchaingo/tree/main/examples/ollama-completion-example))
|
||||||
|
- [Spring AI](https://github.com/spring-projects/spring-ai) - Spring framework AI support ([docs](https://docs.spring.io/spring-ai/reference/api/chat/ollama-chat.html))
|
||||||
|
- [LangChain](https://python.langchain.com/docs/integrations/chat/ollama/) and [LangChain.js](https://js.langchain.com/docs/integrations/chat/ollama/) with [example](https://js.langchain.com/docs/tutorials/local_rag/)
|
||||||
|
- [Ollama for Ruby](https://github.com/crmne/ruby_llm) - Ruby LLM library
|
||||||
|
- [any-llm](https://github.com/mozilla-ai/any-llm) - Unified LLM interface by Mozilla
|
||||||
|
- [OllamaSharp for .NET](https://github.com/awaescher/OllamaSharp) - .NET SDK
|
||||||
|
- [LangChainRust](https://github.com/Abraxas-365/langchain-rust) - Rust LangChain ([example](https://github.com/Abraxas-365/langchain-rust/blob/main/examples/llm_ollama.rs))
|
||||||
|
- [Agents-Flex for Java](https://github.com/agents-flex/agents-flex) - Java agent framework ([example](https://github.com/agents-flex/agents-flex/tree/main/agents-flex-llm/agents-flex-llm-ollama/src/test/java/com/agentsflex/llm/ollama))
|
||||||
|
- [Elixir LangChain](https://github.com/brainlid/langchain) - Elixir LangChain
|
||||||
|
- [Ollama-rs for Rust](https://github.com/pepperoni21/ollama-rs) - Rust SDK
|
||||||
|
- [LangChain for .NET](https://github.com/tryAGI/LangChain) - .NET LangChain ([example](https://github.com/tryAGI/LangChain/blob/main/examples/LangChain.Samples.OpenAI/Program.cs))
|
||||||
|
- [chromem-go](https://github.com/philippgille/chromem-go) - Go vector database with Ollama embeddings ([example](https://github.com/philippgille/chromem-go/tree/v0.5.0/examples/rag-wikipedia-ollama))
|
||||||
|
- [LangChainDart](https://github.com/davidmigloz/langchain_dart) - Dart LangChain
|
||||||
|
- [LlmTornado](https://github.com/lofcz/llmtornado) - Unified C# interface for multiple inference APIs
|
||||||
|
- [Ollama4j for Java](https://github.com/ollama4j/ollama4j) - Java SDK
|
||||||
|
- [Ollama for Laravel](https://github.com/cloudstudio/ollama-laravel) - Laravel integration
|
||||||
|
- [Ollama for Swift](https://github.com/mattt/ollama-swift) - Swift SDK
|
||||||
|
- [LlamaIndex](https://docs.llamaindex.ai/en/stable/examples/llm/ollama/) and [LlamaIndexTS](https://ts.llamaindex.ai/modules/llms/available_llms/ollama) - Data framework for LLM apps
|
||||||
|
- [Haystack](https://github.com/deepset-ai/haystack-integrations/blob/main/integrations/ollama.md) - AI pipeline framework
|
||||||
|
- [Firebase Genkit](https://firebase.google.com/docs/genkit/plugins/ollama) - Google AI framework
|
||||||
|
- [Ollama-hpp for C++](https://github.com/jmont-dev/ollama-hpp) - C++ SDK
|
||||||
|
- [PromptingTools.jl](https://github.com/svilupp/PromptingTools.jl) - Julia LLM toolkit ([example](https://svilupp.github.io/PromptingTools.jl/dev/examples/working_with_ollama))
|
||||||
|
- [Ollama for R - rollama](https://github.com/JBGruber/rollama) - R SDK
|
||||||
|
- [Portkey](https://portkey.ai/docs/welcome/integration-guides/ollama) - AI gateway
|
||||||
|
- [Testcontainers](https://testcontainers.com/modules/ollama/) - Container-based testing
|
||||||
|
- [LLPhant](https://github.com/theodo-group/LLPhant?tab=readme-ov-file#ollama) - PHP AI framework
|
||||||
|
|
||||||
|
### Frameworks & Agents
|
||||||
|
|
||||||
|
- [AutoGPT](https://github.com/Significant-Gravitas/AutoGPT/blob/master/docs/content/platform/ollama.md) - Autonomous AI agent platform
|
||||||
|
- [crewAI](https://github.com/crewAIInc/crewAI) - Multi-agent orchestration framework
|
||||||
|
- [Strands Agents](https://github.com/strands-agents/sdk-python) - Model-driven agent building by AWS
|
||||||
|
- [Cheshire Cat](https://github.com/cheshire-cat-ai/core) - AI assistant framework
|
||||||
|
- [any-agent](https://github.com/mozilla-ai/any-agent) - Unified agent framework interface by Mozilla
|
||||||
|
- [Stakpak](https://github.com/stakpak/agent) - Open source DevOps agent
|
||||||
|
- [Hexabot](https://github.com/hexastack/hexabot) - Conversational AI builder
|
||||||
|
- [Neuro SAN](https://github.com/cognizant-ai-lab/neuro-san-studio) - Multi-agent orchestration ([docs](https://github.com/cognizant-ai-lab/neuro-san-studio/blob/main/docs/user_guide.md#ollama))
|
||||||
|
|
||||||
|
### RAG & Knowledge Bases
|
||||||
|
|
||||||
|
- [RAGFlow](https://github.com/infiniflow/ragflow) - RAG engine based on deep document understanding
|
||||||
|
- [R2R](https://github.com/SciPhi-AI/R2R) - Open-source RAG engine
|
||||||
|
- [MaxKB](https://github.com/1Panel-dev/MaxKB/) - Ready-to-use RAG chatbot
|
||||||
|
- [Minima](https://github.com/dmayboroda/minima) - On-premises or fully local RAG
|
||||||
|
- [Chipper](https://github.com/TilmanGriesel/chipper) - AI interface with Haystack RAG
|
||||||
|
- [ARGO](https://github.com/xark-argo/argo) - RAG and deep research on Mac/Windows/Linux
|
||||||
|
- [Archyve](https://github.com/nickthecook/archyve) - RAG-enabling document library
|
||||||
|
- [Casibase](https://casibase.org) - AI knowledge base with RAG and SSO
|
||||||
|
- [BrainSoup](https://www.nurgo-software.com/products/brainsoup) - Native client with RAG and multi-agent automation
|
||||||
|
|
||||||
|
### Bots & Messaging
|
||||||
|
|
||||||
|
- [LangBot](https://github.com/RockChinQ/LangBot) - Multi-platform messaging bots with agents and RAG
|
||||||
|
- [AstrBot](https://github.com/Soulter/AstrBot/) - Multi-platform chatbot with RAG and plugins
|
||||||
|
- [Discord-Ollama Chat Bot](https://github.com/kevinthedang/discord-ollama) - TypeScript Discord bot
|
||||||
|
- [Ollama Telegram Bot](https://github.com/ruecat/ollama-telegram) - Telegram bot
|
||||||
|
- [LLM Telegram Bot](https://github.com/innightwolfsleep/llm_telegram_bot) - Telegram bot for roleplay
|
||||||
|
|
||||||
|
### Terminal & CLI
|
||||||
|
|
||||||
|
- [aichat](https://github.com/sigoden/aichat) - All-in-one LLM CLI with Shell Assistant, RAG, and AI tools
|
||||||
|
- [oterm](https://github.com/ggozad/oterm) - Terminal client for Ollama
|
||||||
|
- [gollama](https://github.com/sammcj/gollama) - Go-based model manager for Ollama
|
||||||
|
- [tlm](https://github.com/yusufcanb/tlm) - Local shell copilot
|
||||||
|
- [tenere](https://github.com/pythops/tenere) - TUI for LLMs
|
||||||
|
- [ParLlama](https://github.com/paulrobello/parllama) - TUI for Ollama
|
||||||
|
- [llm-ollama](https://github.com/taketwo/llm-ollama) - Plugin for [Datasette's LLM CLI](https://llm.datasette.io/en/stable/)
|
||||||
|
- [ShellOracle](https://github.com/djcopley/ShellOracle) - Shell command suggestions
|
||||||
|
- [LLM-X](https://github.com/mrdjohnson/llm-x) - Progressive web app for LLMs
|
||||||
|
- [cmdh](https://github.com/pgibler/cmdh) - Natural language to shell commands
|
||||||
|
- [VT](https://github.com/vinhnx/vt.ai) - Minimal multimodal AI chat app
|
||||||
|
|
||||||
|
### Productivity & Apps
|
||||||
|
|
||||||
|
- [AppFlowy](https://github.com/AppFlowy-IO/AppFlowy) - AI collaborative workspace, self-hostable Notion alternative
|
||||||
|
- [Screenpipe](https://github.com/mediar-ai/screenpipe) - 24/7 screen and mic recording with AI-powered search
|
||||||
|
- [Vibe](https://github.com/thewh1teagle/vibe) - Transcribe and analyze meetings
|
||||||
|
- [Page Assist](https://github.com/n4ze3m/page-assist) - Chrome extension for AI-powered browsing
|
||||||
|
- [NativeMind](https://github.com/NativeMindBrowser/NativeMindExtension) - Private, on-device browser AI assistant
|
||||||
|
- [Ollama Fortress](https://github.com/ParisNeo/ollama_proxy_server) - Security proxy for Ollama
|
||||||
|
- [1Panel](https://github.com/1Panel-dev/1Panel/) - Web-based Linux server management
|
||||||
|
- [Writeopia](https://github.com/Writeopia/Writeopia) - Text editor with Ollama integration
|
||||||
|
- [QA-Pilot](https://github.com/reid41/QA-Pilot) - GitHub code repository understanding
|
||||||
|
- [Raycast extension](https://github.com/MassimilianoPasquini97/raycast_ollama) - Ollama in Raycast
|
||||||
|
- [Painting Droid](https://github.com/mateuszmigas/painting-droid) - Painting app with AI integrations
|
||||||
|
- [Serene Pub](https://github.com/doolijb/serene-pub) - AI roleplaying app
|
||||||
|
- [Mayan EDMS](https://gitlab.com/mayan-edms/mayan-edms) - Document management with Ollama workflows
|
||||||
|
- [TagSpaces](https://www.tagspaces.org) - File management with [AI tagging](https://docs.tagspaces.org/ai/)
|
||||||
|
|
||||||
|
### Observability & Monitoring
|
||||||
|
|
||||||
|
- [Opik](https://www.comet.com/docs/opik/cookbook/ollama) - Debug, evaluate, and monitor LLM applications
|
||||||
|
- [OpenLIT](https://github.com/openlit/openlit) - OpenTelemetry-native monitoring for Ollama and GPUs
|
||||||
|
- [Lunary](https://lunary.ai/docs/integrations/ollama) - LLM observability with analytics and PII masking
|
||||||
|
- [Langfuse](https://langfuse.com/docs/integrations/ollama) - Open source LLM observability
|
||||||
|
- [HoneyHive](https://docs.honeyhive.ai/integrations/ollama) - AI observability and evaluation for agents
|
||||||
|
- [MLflow Tracing](https://mlflow.org/docs/latest/llms/tracing/index.html#automatic-tracing) - Open source LLM observability
|
||||||
|
|
||||||
|
### Database & Embeddings
|
||||||
|
|
||||||
|
- [pgai](https://github.com/timescale/pgai) - PostgreSQL as a vector database ([guide](https://github.com/timescale/pgai/blob/main/docs/vectorizer-quick-start.md))
|
||||||
|
- [MindsDB](https://github.com/mindsdb/mindsdb/blob/staging/mindsdb/integrations/handlers/ollama_handler/README.md) - Connect Ollama with 200+ data platforms
|
||||||
|
- [chromem-go](https://github.com/philippgille/chromem-go/blob/v0.5.0/embed_ollama.go) - Embeddable vector database for Go ([example](https://github.com/philippgille/chromem-go/tree/v0.5.0/examples/rag-wikipedia-ollama))
|
||||||
|
- [Kangaroo](https://github.com/dbkangaroo/kangaroo) - AI-powered SQL client
|
||||||
|
|
||||||
|
### Infrastructure & Deployment
|
||||||
|
|
||||||
|
#### Cloud
|
||||||
|
|
||||||
- [Google Cloud](https://cloud.google.com/run/docs/tutorials/gpu-gemma2-with-ollama)
|
- [Google Cloud](https://cloud.google.com/run/docs/tutorials/gpu-gemma2-with-ollama)
|
||||||
- [Fly.io](https://fly.io/docs/python/do-more/add-ollama/)
|
- [Fly.io](https://fly.io/docs/python/do-more/add-ollama/)
|
||||||
- [Koyeb](https://www.koyeb.com/deploy/ollama)
|
- [Koyeb](https://www.koyeb.com/deploy/ollama)
|
||||||
|
- [Harbor](https://github.com/av/harbor) - Containerized LLM toolkit with Ollama as default backend
|
||||||
|
|
||||||
### Tutorial
|
#### Package Managers
|
||||||
|
|
||||||
- [handy-ollama](https://github.com/datawhalechina/handy-ollama) (Chinese Tutorial for Ollama by [Datawhale](https://github.com/datawhalechina) - China's Largest Open Source AI Learning Community)
|
|
||||||
|
|
||||||
### Terminal
|
|
||||||
|
|
||||||
- [oterm](https://github.com/ggozad/oterm)
|
|
||||||
- [Ellama Emacs client](https://github.com/s-kostyaev/ellama)
|
|
||||||
- [Emacs client](https://github.com/zweifisch/ollama)
|
|
||||||
- [neollama](https://github.com/paradoxical-dev/neollama) UI client for interacting with models from within Neovim
|
|
||||||
- [gen.nvim](https://github.com/David-Kunz/gen.nvim)
|
|
||||||
- [ollama.nvim](https://github.com/nomnivore/ollama.nvim)
|
|
||||||
- [ollero.nvim](https://github.com/marco-souza/ollero.nvim)
|
|
||||||
- [ollama-chat.nvim](https://github.com/gerazov/ollama-chat.nvim)
|
|
||||||
- [ogpt.nvim](https://github.com/huynle/ogpt.nvim)
|
|
||||||
- [gptel Emacs client](https://github.com/karthink/gptel)
|
|
||||||
- [Oatmeal](https://github.com/dustinblackman/oatmeal)
|
|
||||||
- [cmdh](https://github.com/pgibler/cmdh)
|
|
||||||
- [ooo](https://github.com/npahlfer/ooo)
|
|
||||||
- [shell-pilot](https://github.com/reid41/shell-pilot) (Interact with models via pure shell scripts on Linux or macOS)
|
|
||||||
- [tenere](https://github.com/pythops/tenere)
|
|
||||||
- [llm-ollama](https://github.com/taketwo/llm-ollama) for [Datasette's LLM CLI](https://llm.datasette.io/en/stable/).
|
|
||||||
- [typechat-cli](https://github.com/anaisbetts/typechat-cli)
|
|
||||||
- [ShellOracle](https://github.com/djcopley/ShellOracle)
|
|
||||||
- [tlm](https://github.com/yusufcanb/tlm)
|
|
||||||
- [podman-ollama](https://github.com/ericcurtin/podman-ollama)
|
|
||||||
- [gollama](https://github.com/sammcj/gollama)
|
|
||||||
- [ParLlama](https://github.com/paulrobello/parllama)
|
|
||||||
- [Ollama eBook Summary](https://github.com/cognitivetech/ollama-ebook-summary/)
|
|
||||||
- [Ollama Mixture of Experts (MOE) in 50 lines of code](https://github.com/rapidarchitect/ollama_moe)
|
|
||||||
- [vim-intelligence-bridge](https://github.com/pepo-ec/vim-intelligence-bridge) Simple interaction of "Ollama" with the Vim editor
|
|
||||||
- [x-cmd ollama](https://x-cmd.com/mod/ollama)
|
|
||||||
- [bb7](https://github.com/drunkwcodes/bb7)
|
|
||||||
- [SwollamaCLI](https://github.com/marcusziade/Swollama) bundled with the Swollama Swift package. [Demo](https://github.com/marcusziade/Swollama?tab=readme-ov-file#cli-usage)
|
|
||||||
- [aichat](https://github.com/sigoden/aichat) All-in-one LLM CLI tool featuring Shell Assistant, Chat-REPL, RAG, AI tools & agents, with access to OpenAI, Claude, Gemini, Ollama, Groq, and more.
|
|
||||||
- [PowershAI](https://github.com/rrg92/powershai) PowerShell module that brings AI to terminal on Windows, including support for Ollama
|
|
||||||
- [DeepShell](https://github.com/Abyss-c0re/deepshell) Your self-hosted AI assistant. Interactive Shell, Files and Folders analysis.
|
|
||||||
- [orbiton](https://github.com/xyproto/orbiton) Configuration-free text editor and IDE with support for tab completion with Ollama.
|
|
||||||
- [orca-cli](https://github.com/molbal/orca-cli) Ollama Registry CLI Application - Browse, pull, and download models from Ollama Registry in your terminal.
|
|
||||||
- [GGUF-to-Ollama](https://github.com/jonathanhecl/gguf-to-ollama) - Importing GGUF to Ollama made easy (multiplatform)
|
|
||||||
- [AWS-Strands-With-Ollama](https://github.com/rapidarchitect/ollama_strands) - AWS Strands Agents with Ollama Examples
|
|
||||||
- [ollama-multirun](https://github.com/attogram/ollama-multirun) - A bash shell script to run a single prompt against any or all of your locally installed ollama models, saving the output and performance statistics as easily navigable web pages. ([Demo](https://attogram.github.io/ai_test_zone/))
|
|
||||||
- [ollama-bash-toolshed](https://github.com/attogram/ollama-bash-toolshed) - Bash scripts to chat with tool using models. Add new tools to your shed with ease. Runs on Ollama.
|
|
||||||
- [hle-eval-ollama](https://github.com/mags0ft/hle-eval-ollama) - Runs benchmarks like "Humanity's Last Exam" (HLE) on your favorite local Ollama models and evaluates the quality of their responses
|
|
||||||
- [VT Code](https://github.com/vinhnx/vtcode) - VT Code is a Rust-based terminal coding agent with semantic code intelligence via Tree-sitter. Ollama integration for running local/cloud models with configurable endpoints.
|
|
||||||
|
|
||||||
### Apple Vision Pro
|
|
||||||
|
|
||||||
- [SwiftChat](https://github.com/aws-samples/swift-chat) (Cross-platform AI chat app supporting Apple Vision Pro via "Designed for iPad")
|
|
||||||
- [Enchanted](https://github.com/AugustDev/enchanted)
|
|
||||||
|
|
||||||
### Database
|
|
||||||
|
|
||||||
- [pgai](https://github.com/timescale/pgai) - PostgreSQL as a vector database (Create and search embeddings from Ollama models using pgvector)
|
|
||||||
- [Get started guide](https://github.com/timescale/pgai/blob/main/docs/vectorizer-quick-start.md)
|
|
||||||
- [MindsDB](https://github.com/mindsdb/mindsdb/blob/staging/mindsdb/integrations/handlers/ollama_handler/README.md) (Connects Ollama models with nearly 200 data platforms and apps)
|
|
||||||
- [chromem-go](https://github.com/philippgille/chromem-go/blob/v0.5.0/embed_ollama.go) with [example](https://github.com/philippgille/chromem-go/tree/v0.5.0/examples/rag-wikipedia-ollama)
|
|
||||||
- [Kangaroo](https://github.com/dbkangaroo/kangaroo) (AI-powered SQL client and admin tool for popular databases)
|
|
||||||
|
|
||||||
### Package managers
|
|
||||||
|
|
||||||
- [Pacman](https://archlinux.org/packages/extra/x86_64/ollama/)
|
- [Pacman](https://archlinux.org/packages/extra/x86_64/ollama/)
|
||||||
- [Gentoo](https://github.com/gentoo/guru/tree/master/app-misc/ollama)
|
|
||||||
- [Homebrew](https://formulae.brew.sh/formula/ollama)
|
- [Homebrew](https://formulae.brew.sh/formula/ollama)
|
||||||
- [Helm Chart](https://artifacthub.io/packages/helm/ollama-helm/ollama)
|
|
||||||
- [Guix channel](https://codeberg.org/tusharhero/ollama-guix)
|
|
||||||
- [Nix package](https://search.nixos.org/packages?show=ollama&from=0&size=50&sort=relevance&type=packages&query=ollama)
|
- [Nix package](https://search.nixos.org/packages?show=ollama&from=0&size=50&sort=relevance&type=packages&query=ollama)
|
||||||
|
- [Helm Chart](https://artifacthub.io/packages/helm/ollama-helm/ollama)
|
||||||
|
- [Gentoo](https://github.com/gentoo/guru/tree/master/app-misc/ollama)
|
||||||
- [Flox](https://flox.dev/blog/ollama-part-one)
|
- [Flox](https://flox.dev/blog/ollama-part-one)
|
||||||
|
- [Guix channel](https://codeberg.org/tusharhero/ollama-guix)
|
||||||
### Libraries
|
|
||||||
|
|
||||||
- [LangChain](https://python.langchain.com/docs/integrations/chat/ollama/) and [LangChain.js](https://js.langchain.com/docs/integrations/chat/ollama/) with [example](https://js.langchain.com/docs/tutorials/local_rag/)
|
|
||||||
- [Firebase Genkit](https://firebase.google.com/docs/genkit/plugins/ollama)
|
|
||||||
- [crewAI](https://github.com/crewAIInc/crewAI)
|
|
||||||
- [Yacana](https://remembersoftwares.github.io/yacana/) (User-friendly multi-agent framework for brainstorming and executing predetermined flows with built-in tool integration)
|
|
||||||
- [Strands Agents](https://github.com/strands-agents/sdk-python) (A model-driven approach to building AI agents in just a few lines of code)
|
|
||||||
- [Spring AI](https://github.com/spring-projects/spring-ai) with [reference](https://docs.spring.io/spring-ai/reference/api/chat/ollama-chat.html) and [example](https://github.com/tzolov/ollama-tools)
|
|
||||||
- [LangChainGo](https://github.com/tmc/langchaingo/) with [example](https://github.com/tmc/langchaingo/tree/main/examples/ollama-completion-example)
|
|
||||||
- [LangChain4j](https://github.com/langchain4j/langchain4j) with [example](https://github.com/langchain4j/langchain4j-examples/tree/main/ollama-examples/src/main/java)
|
|
||||||
- [LangChainRust](https://github.com/Abraxas-365/langchain-rust) with [example](https://github.com/Abraxas-365/langchain-rust/blob/main/examples/llm_ollama.rs)
|
|
||||||
- [LangChain for .NET](https://github.com/tryAGI/LangChain) with [example](https://github.com/tryAGI/LangChain/blob/main/examples/LangChain.Samples.OpenAI/Program.cs)
|
|
||||||
- [LLPhant](https://github.com/theodo-group/LLPhant?tab=readme-ov-file#ollama)
|
|
||||||
- [LlamaIndex](https://docs.llamaindex.ai/en/stable/examples/llm/ollama/) and [LlamaIndexTS](https://ts.llamaindex.ai/modules/llms/available_llms/ollama)
|
|
||||||
- [LiteLLM](https://github.com/BerriAI/litellm)
|
|
||||||
- [OllamaFarm for Go](https://github.com/presbrey/ollamafarm)
|
|
||||||
- [OllamaSharp for .NET](https://github.com/awaescher/OllamaSharp)
|
|
||||||
- [Ollama for Ruby](https://github.com/crmne/ruby_llm)
|
|
||||||
- [Ollama-rs for Rust](https://github.com/pepperoni21/ollama-rs)
|
|
||||||
- [Ollama-hpp for C++](https://github.com/jmont-dev/ollama-hpp)
|
|
||||||
- [Ollama4j for Java](https://github.com/ollama4j/ollama4j)
|
|
||||||
- [ModelFusion Typescript Library](https://modelfusion.dev/integration/model-provider/ollama)
|
|
||||||
- [OllamaKit for Swift](https://github.com/kevinhermawan/OllamaKit)
|
|
||||||
- [Ollama for Dart](https://github.com/breitburg/dart-ollama)
|
|
||||||
- [Ollama for Laravel](https://github.com/cloudstudio/ollama-laravel)
|
|
||||||
- [LangChainDart](https://github.com/davidmigloz/langchain_dart)
|
|
||||||
- [Semantic Kernel - Python](https://github.com/microsoft/semantic-kernel/tree/main/python/semantic_kernel/connectors/ai/ollama)
|
|
||||||
- [Haystack](https://github.com/deepset-ai/haystack-integrations/blob/main/integrations/ollama.md)
|
|
||||||
- [Elixir LangChain](https://github.com/brainlid/langchain)
|
|
||||||
- [Ollama for R - rollama](https://github.com/JBGruber/rollama)
|
|
||||||
- [Ollama for R - ollama-r](https://github.com/hauselin/ollama-r)
|
|
||||||
- [Ollama-ex for Elixir](https://github.com/lebrunel/ollama-ex)
|
|
||||||
- [Ollama Connector for SAP ABAP](https://github.com/b-tocs/abap_btocs_ollama)
|
|
||||||
- [Testcontainers](https://testcontainers.com/modules/ollama/)
|
|
||||||
- [Portkey](https://portkey.ai/docs/welcome/integration-guides/ollama)
|
|
||||||
- [PromptingTools.jl](https://github.com/svilupp/PromptingTools.jl) with an [example](https://svilupp.github.io/PromptingTools.jl/dev/examples/working_with_ollama)
|
|
||||||
- [LlamaScript](https://github.com/Project-Llama/llamascript)
|
|
||||||
- [llm-axe](https://github.com/emirsahin1/llm-axe) (Python Toolkit for Building LLM Powered Apps)
|
|
||||||
- [Gollm](https://docs.gollm.co/examples/ollama-example)
|
|
||||||
- [Gollama for Golang](https://github.com/jonathanhecl/gollama)
|
|
||||||
- [Ollamaclient for Golang](https://github.com/xyproto/ollamaclient)
|
|
||||||
- [High-level function abstraction in Go](https://gitlab.com/tozd/go/fun)
|
|
||||||
- [Ollama PHP](https://github.com/ArdaGnsrn/ollama-php)
|
|
||||||
- [Agents-Flex for Java](https://github.com/agents-flex/agents-flex) with [example](https://github.com/agents-flex/agents-flex/tree/main/agents-flex-llm/agents-flex-llm-ollama/src/test/java/com/agentsflex/llm/ollama)
|
|
||||||
- [Parakeet](https://github.com/parakeet-nest/parakeet) is a GoLang library, made to simplify the development of small generative AI applications with Ollama.
|
|
||||||
- [Haverscript](https://github.com/andygill/haverscript) with [examples](https://github.com/andygill/haverscript/tree/main/examples)
|
|
||||||
- [Ollama for Swift](https://github.com/mattt/ollama-swift)
|
|
||||||
- [Swollama for Swift](https://github.com/guitaripod/Swollama) with [DocC](https://guitaripod.github.io/Swollama/documentation/swollama)
|
|
||||||
- [GoLamify](https://github.com/prasad89/golamify)
|
|
||||||
- [Ollama for Haskell](https://github.com/tusharad/ollama-haskell)
|
|
||||||
- [multi-llm-ts](https://github.com/nbonamy/multi-llm-ts) (A Typescript/JavaScript library allowing access to different LLM in a unified API)
|
|
||||||
- [LlmTornado](https://github.com/lofcz/llmtornado) (C# library providing a unified interface for major FOSS & Commercial inference APIs)
|
|
||||||
- [Ollama for Zig](https://github.com/dravenk/ollama-zig)
|
|
||||||
- [Abso](https://github.com/lunary-ai/abso) (OpenAI-compatible TypeScript SDK for any LLM provider)
|
|
||||||
- [Nichey](https://github.com/goodreasonai/nichey) is a Python package for generating custom wikis for your research topic
|
|
||||||
- [Ollama for D](https://github.com/kassane/ollama-d)
|
|
||||||
- [OllamaPlusPlus](https://github.com/HardCodeDev777/OllamaPlusPlus) (Very simple C++ library for Ollama)
|
|
||||||
- [any-llm](https://github.com/mozilla-ai/any-llm) (A single interface to use different llm providers by [mozilla.ai](https://www.mozilla.ai/))
|
|
||||||
- [any-agent](https://github.com/mozilla-ai/any-agent) (A single interface to use and evaluate different agent frameworks by [mozilla.ai](https://www.mozilla.ai/))
|
|
||||||
- [Neuro SAN](https://github.com/cognizant-ai-lab/neuro-san-studio) (Data-driven multi-agent orchestration framework) with [example](https://github.com/cognizant-ai-lab/neuro-san-studio/blob/main/docs/user_guide.md#ollama)
|
|
||||||
- [achatbot-go](https://github.com/ai-bot-pro/achatbot-go) a multimodal (text/audio/image) chatbot.
|
|
||||||
- [Ollama Bash Lib](https://github.com/attogram/ollama-bash-lib) - A Bash Library for Ollama. Run LLM prompts straight from your shell, and more
|
|
||||||
|
|
||||||
### Mobile
|
|
||||||
|
|
||||||
- [SwiftChat](https://github.com/aws-samples/swift-chat) (Lightning-fast Cross-platform AI chat app with native UI for Android, iOS, and iPad)
|
|
||||||
- [Enchanted](https://github.com/AugustDev/enchanted)
|
|
||||||
- [Maid](https://github.com/Mobile-Artificial-Intelligence/maid)
|
|
||||||
- [Ollama App](https://github.com/JHubi1/ollama-app) (Modern and easy-to-use multi-platform client for Ollama)
|
|
||||||
- [ConfiChat](https://github.com/1runeberg/confichat) (Lightweight, standalone, multi-platform, and privacy-focused LLM chat interface with optional encryption)
|
|
||||||
- [Ollama Android Chat](https://github.com/sunshine0523/OllamaServer) (No need for Termux, start the Ollama service with one click on an Android device)
|
|
||||||
- [Reins](https://github.com/ibrahimcetin/reins) (Easily tweak parameters, customize system prompts per chat, and enhance your AI experiments with reasoning model support.)
|
|
||||||
|
|
||||||
### Extensions & Plugins
|
|
||||||
|
|
||||||
- [Raycast extension](https://github.com/MassimilianoPasquini97/raycast_ollama)
|
|
||||||
- [Discollama](https://github.com/mxyng/discollama) (Discord bot inside the Ollama discord channel)
|
|
||||||
- [Continue](https://github.com/continuedev/continue)
|
|
||||||
- [Vibe](https://github.com/thewh1teagle/vibe) (Transcribe and analyze meetings with Ollama)
|
|
||||||
- [Obsidian Ollama plugin](https://github.com/hinterdupfinger/obsidian-ollama)
|
|
||||||
- [Logseq Ollama plugin](https://github.com/omagdy7/ollama-logseq)
|
|
||||||
- [NotesOllama](https://github.com/andersrex/notesollama) (Apple Notes Ollama plugin)
|
|
||||||
- [Dagger Chatbot](https://github.com/samalba/dagger-chatbot)
|
|
||||||
- [Discord AI Bot](https://github.com/mekb-turtle/discord-ai-bot)
|
|
||||||
- [Ollama Telegram Bot](https://github.com/ruecat/ollama-telegram)
|
|
||||||
- [Hass Ollama Conversation](https://github.com/ej52/hass-ollama-conversation)
|
|
||||||
- [Rivet plugin](https://github.com/abrenneke/rivet-plugin-ollama)
|
|
||||||
- [Obsidian BMO Chatbot plugin](https://github.com/longy2k/obsidian-bmo-chatbot)
|
|
||||||
- [Cliobot](https://github.com/herval/cliobot) (Telegram bot with Ollama support)
|
|
||||||
- [Copilot for Obsidian plugin](https://github.com/logancyang/obsidian-copilot)
|
|
||||||
- [Obsidian Local GPT plugin](https://github.com/pfrankov/obsidian-local-gpt)
|
|
||||||
- [Open Interpreter](https://docs.openinterpreter.com/language-model-setup/local-models/ollama)
|
|
||||||
- [Llama Coder](https://github.com/ex3ndr/llama-coder) (Copilot alternative using Ollama)
|
|
||||||
- [Ollama Copilot](https://github.com/bernardo-bruning/ollama-copilot) (Proxy that allows you to use Ollama as a copilot like GitHub Copilot)
|
|
||||||
- [twinny](https://github.com/rjmacarthy/twinny) (Copilot and Copilot chat alternative using Ollama)
|
|
||||||
- [Wingman-AI](https://github.com/RussellCanfield/wingman-ai) (Copilot code and chat alternative using Ollama and Hugging Face)
|
|
||||||
- [Page Assist](https://github.com/n4ze3m/page-assist) (Chrome Extension)
|
|
||||||
- [Plasmoid Ollama Control](https://github.com/imoize/plasmoid-ollamacontrol) (KDE Plasma extension that allows you to quickly manage/control Ollama model)
|
|
||||||
- [AI Telegram Bot](https://github.com/tusharhero/aitelegrambot) (Telegram bot using Ollama in backend)
|
|
||||||
- [AI ST Completion](https://github.com/yaroslavyaroslav/OpenAI-sublime-text) (Sublime Text 4 AI assistant plugin with Ollama support)
|
|
||||||
- [Discord-Ollama Chat Bot](https://github.com/kevinthedang/discord-ollama) (Generalized TypeScript Discord Bot w/ Tuning Documentation)
|
|
||||||
- [ChatGPTBox: All in one browser extension](https://github.com/josStorer/chatGPTBox) with [Integrating Tutorial](https://github.com/josStorer/chatGPTBox/issues/616#issuecomment-1975186467)
|
|
||||||
- [Discord AI chat/moderation bot](https://github.com/rapmd73/Companion) Chat/moderation bot written in python. Uses Ollama to create personalities.
|
|
||||||
- [Headless Ollama](https://github.com/nischalj10/headless-ollama) (Scripts to automatically install ollama client & models on any OS for apps that depend on ollama server)
|
|
||||||
- [Terraform AWS Ollama & Open WebUI](https://github.com/xuyangbocn/terraform-aws-self-host-llm) (A Terraform module to deploy on AWS a ready-to-use Ollama service, together with its front-end Open WebUI service.)
|
|
||||||
- [node-red-contrib-ollama](https://github.com/jakubburkiewicz/node-red-contrib-ollama)
|
|
||||||
- [Local AI Helper](https://github.com/ivostoykov/localAI) (Chrome and Firefox extensions that enable interactions with the active tab and customisable API endpoints. Includes secure storage for user prompts.)
|
|
||||||
- [LSP-AI](https://github.com/SilasMarvin/lsp-ai) (Open-source language server for AI-powered functionality)
|
|
||||||
- [QodeAssist](https://github.com/Palm1r/QodeAssist) (AI-powered coding assistant plugin for Qt Creator)
|
|
||||||
- [Obsidian Quiz Generator plugin](https://github.com/ECuiDev/obsidian-quiz-generator)
|
|
||||||
- [AI Summary Helper plugin](https://github.com/philffm/ai-summary-helper)
|
|
||||||
- [TextCraft](https://github.com/suncloudsmoon/TextCraft) (Copilot in Word alternative using Ollama)
|
|
||||||
- [Alfred Ollama](https://github.com/zeitlings/alfred-ollama) (Alfred Workflow)
|
|
||||||
- [TextLLaMA](https://github.com/adarshM84/TextLLaMA) A Chrome Extension that helps you write emails, correct grammar, and translate into any language
|
|
||||||
- [Simple-Discord-AI](https://github.com/zyphixor/simple-discord-ai)
|
|
||||||
- [LLM Telegram Bot](https://github.com/innightwolfsleep/llm_telegram_bot) (Telegram bot, primarily for RP. Oobabooga-like buttons, [A1111](https://github.com/AUTOMATIC1111/stable-diffusion-webui) API integration, etc.)
|
|
||||||
- [mcp-llm](https://github.com/sammcj/mcp-llm) (MCP Server to allow LLMs to call other LLMs)
|
|
||||||
- [SimpleOllamaUnity](https://github.com/HardCodeDev777/SimpleOllamaUnity) (Unity Engine extension for communicating with Ollama in a few lines of code. Also works at runtime)
|
|
||||||
- [UnityCodeLama](https://github.com/HardCodeDev777/UnityCodeLama) (Unity Editor tool to analyze scripts via Ollama)
|
|
||||||
- [NativeMind](https://github.com/NativeMindBrowser/NativeMindExtension) (Private, on-device AI Assistant, no cloud dependencies)
|
|
||||||
- [GMAI - Gradle Managed AI](https://gmai.premex.se/) (Gradle plugin for automated Ollama lifecycle management during build phases)
|
|
||||||
- [NOMYO Router](https://github.com/nomyo-ai/nomyo-router) (A transparent Ollama proxy with model deployment aware routing which auto-manages multiple Ollama instances in a given network)
|
|
||||||
|
|
||||||
### Supported backends
|
|
||||||
|
|
||||||
- [llama.cpp](https://github.com/ggml-org/llama.cpp) project founded by Georgi Gerganov.
|
|
||||||
|
|
||||||
### Observability
|
|
||||||
|
|
||||||
- [Opik](https://www.comet.com/docs/opik/cookbook/ollama) is an open-source platform to debug, evaluate, and monitor your LLM applications, RAG systems, and agentic workflows with comprehensive tracing, automated evaluations, and production-ready dashboards. Opik supports native integration to Ollama.
|
|
||||||
- [Lunary](https://lunary.ai/docs/integrations/ollama) is the leading open-source LLM observability platform. It provides a variety of enterprise-grade features such as real-time analytics, prompt templates management, PII masking, and comprehensive agent tracing.
|
|
||||||
- [OpenLIT](https://github.com/openlit/openlit) is an OpenTelemetry-native tool for monitoring Ollama Applications & GPUs using traces and metrics.
|
|
||||||
- [HoneyHive](https://docs.honeyhive.ai/integrations/ollama) is an AI observability and evaluation platform for AI agents. Use HoneyHive to evaluate agent performance, interrogate failures, and monitor quality in production.
|
|
||||||
- [Langfuse](https://langfuse.com/docs/integrations/ollama) is an open source LLM observability platform that enables teams to collaboratively monitor, evaluate and debug AI applications.
|
|
||||||
- [MLflow Tracing](https://mlflow.org/docs/latest/llms/tracing/index.html#automatic-tracing) is an open source LLM observability tool with a convenient API to log and visualize traces, making it easy to debug and evaluate GenAI applications.
|
|
||||||
|
|
||||||
### Security
|
|
||||||
|
|
||||||
- [Ollama Fortress](https://github.com/ParisNeo/ollama_proxy_server)
|
|
||||||
|
|||||||
475
anthropic/anthropic.go
Normal file → Executable file
475
anthropic/anthropic.go
Normal file → Executable file
@@ -1,17 +1,25 @@
|
|||||||
package anthropic
|
package anthropic
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
|
"github.com/ollama/ollama/auth"
|
||||||
|
internalcloud "github.com/ollama/ollama/internal/cloud"
|
||||||
|
"github.com/ollama/ollama/logutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Error types matching Anthropic API
|
// Error types matching Anthropic API
|
||||||
@@ -82,22 +90,25 @@ type MessageParam struct {
|
|||||||
// Text and Thinking use pointers so they serialize as the field being present (even if empty)
|
// Text and Thinking use pointers so they serialize as the field being present (even if empty)
|
||||||
// only when set, which is required for SDK streaming accumulation.
|
// only when set, which is required for SDK streaming accumulation.
|
||||||
type ContentBlock struct {
|
type ContentBlock struct {
|
||||||
Type string `json:"type"` // text, image, tool_use, tool_result, thinking
|
Type string `json:"type"` // text, image, tool_use, tool_result, thinking, server_tool_use, web_search_tool_result
|
||||||
|
|
||||||
// For text blocks - pointer so field only appears when set (SDK requires it for accumulation)
|
// For text blocks - pointer so field only appears when set (SDK requires it for accumulation)
|
||||||
Text *string `json:"text,omitempty"`
|
Text *string `json:"text,omitempty"`
|
||||||
|
|
||||||
|
// For text blocks with citations
|
||||||
|
Citations []Citation `json:"citations,omitempty"`
|
||||||
|
|
||||||
// For image blocks
|
// For image blocks
|
||||||
Source *ImageSource `json:"source,omitempty"`
|
Source *ImageSource `json:"source,omitempty"`
|
||||||
|
|
||||||
// For tool_use blocks
|
// For tool_use and server_tool_use blocks
|
||||||
ID string `json:"id,omitempty"`
|
ID string `json:"id,omitempty"`
|
||||||
Name string `json:"name,omitempty"`
|
Name string `json:"name,omitempty"`
|
||||||
Input any `json:"input,omitempty"`
|
Input any `json:"input,omitempty"`
|
||||||
|
|
||||||
// For tool_result blocks
|
// For tool_result and web_search_tool_result blocks
|
||||||
ToolUseID string `json:"tool_use_id,omitempty"`
|
ToolUseID string `json:"tool_use_id,omitempty"`
|
||||||
Content any `json:"content,omitempty"` // string or []ContentBlock
|
Content any `json:"content,omitempty"` // string, []ContentBlock, []WebSearchResult, or WebSearchToolResultError
|
||||||
IsError bool `json:"is_error,omitempty"`
|
IsError bool `json:"is_error,omitempty"`
|
||||||
|
|
||||||
// For thinking blocks - pointer so field only appears when set (SDK requires it for accumulation)
|
// For thinking blocks - pointer so field only appears when set (SDK requires it for accumulation)
|
||||||
@@ -105,6 +116,30 @@ type ContentBlock struct {
|
|||||||
Signature string `json:"signature,omitempty"`
|
Signature string `json:"signature,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Citation represents a citation in a text block
|
||||||
|
type Citation struct {
|
||||||
|
Type string `json:"type"` // "web_search_result_location"
|
||||||
|
URL string `json:"url"`
|
||||||
|
Title string `json:"title"`
|
||||||
|
EncryptedIndex string `json:"encrypted_index,omitempty"`
|
||||||
|
CitedText string `json:"cited_text,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// WebSearchResult represents a single web search result
|
||||||
|
type WebSearchResult struct {
|
||||||
|
Type string `json:"type"` // "web_search_result"
|
||||||
|
URL string `json:"url"`
|
||||||
|
Title string `json:"title"`
|
||||||
|
EncryptedContent string `json:"encrypted_content,omitempty"`
|
||||||
|
PageAge string `json:"page_age,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// WebSearchToolResultError represents an error from web search
|
||||||
|
type WebSearchToolResultError struct {
|
||||||
|
Type string `json:"type"` // "web_search_tool_result_error"
|
||||||
|
ErrorCode string `json:"error_code"`
|
||||||
|
}
|
||||||
|
|
||||||
// ImageSource represents the source of an image
|
// ImageSource represents the source of an image
|
||||||
type ImageSource struct {
|
type ImageSource struct {
|
||||||
Type string `json:"type"` // "base64" or "url"
|
Type string `json:"type"` // "base64" or "url"
|
||||||
@@ -115,10 +150,13 @@ type ImageSource struct {
|
|||||||
|
|
||||||
// Tool represents a tool definition
|
// Tool represents a tool definition
|
||||||
type Tool struct {
|
type Tool struct {
|
||||||
Type string `json:"type,omitempty"` // "custom" for user-defined tools
|
Type string `json:"type,omitempty"` // "custom" for user-defined tools, or "web_search_20250305" for web search
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
Description string `json:"description,omitempty"`
|
Description string `json:"description,omitempty"`
|
||||||
InputSchema json.RawMessage `json:"input_schema,omitempty"`
|
InputSchema json.RawMessage `json:"input_schema,omitempty"`
|
||||||
|
|
||||||
|
// Web search specific fields
|
||||||
|
MaxUses int `json:"max_uses,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// ToolChoice controls how the model uses tools
|
// ToolChoice controls how the model uses tools
|
||||||
@@ -211,6 +249,7 @@ type MessageDelta struct {
|
|||||||
|
|
||||||
// DeltaUsage contains cumulative token usage
|
// DeltaUsage contains cumulative token usage
|
||||||
type DeltaUsage struct {
|
type DeltaUsage struct {
|
||||||
|
InputTokens int `json:"input_tokens"`
|
||||||
OutputTokens int `json:"output_tokens"`
|
OutputTokens int `json:"output_tokens"`
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -232,6 +271,8 @@ type StreamErrorEvent struct {
|
|||||||
|
|
||||||
// FromMessagesRequest converts an Anthropic MessagesRequest to an Ollama api.ChatRequest
|
// FromMessagesRequest converts an Anthropic MessagesRequest to an Ollama api.ChatRequest
|
||||||
func FromMessagesRequest(r MessagesRequest) (*api.ChatRequest, error) {
|
func FromMessagesRequest(r MessagesRequest) (*api.ChatRequest, error) {
|
||||||
|
logutil.Trace("anthropic: converting request", "req", TraceMessagesRequest(r))
|
||||||
|
|
||||||
var messages []api.Message
|
var messages []api.Message
|
||||||
|
|
||||||
if r.System != nil {
|
if r.System != nil {
|
||||||
@@ -258,9 +299,10 @@ func FromMessagesRequest(r MessagesRequest) (*api.ChatRequest, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, msg := range r.Messages {
|
for i, msg := range r.Messages {
|
||||||
converted, err := convertMessage(msg)
|
converted, err := convertMessage(msg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
logutil.Trace("anthropic: message conversion failed", "index", i, "role", msg.Role, "err", err)
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
messages = append(messages, converted...)
|
messages = append(messages, converted...)
|
||||||
@@ -287,8 +329,24 @@ func FromMessagesRequest(r MessagesRequest) (*api.ChatRequest, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var tools api.Tools
|
var tools api.Tools
|
||||||
|
hasBuiltinWebSearch := false
|
||||||
for _, t := range r.Tools {
|
for _, t := range r.Tools {
|
||||||
tool, err := convertTool(t)
|
if strings.HasPrefix(t.Type, "web_search") {
|
||||||
|
hasBuiltinWebSearch = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, t := range r.Tools {
|
||||||
|
// Anthropic built-in web_search maps to Ollama function name "web_search".
|
||||||
|
// If a user-defined tool also uses that name in the same request, drop the
|
||||||
|
// user-defined one to avoid ambiguous tool-call routing.
|
||||||
|
if hasBuiltinWebSearch && !strings.HasPrefix(t.Type, "web_search") && t.Name == "web_search" {
|
||||||
|
logutil.Trace("anthropic: dropping colliding custom web_search tool", "tool", TraceTool(t))
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
tool, _, err := convertTool(t)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -301,15 +359,17 @@ func FromMessagesRequest(r MessagesRequest) (*api.ChatRequest, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
stream := r.Stream
|
stream := r.Stream
|
||||||
|
convertedRequest := &api.ChatRequest{
|
||||||
return &api.ChatRequest{
|
|
||||||
Model: r.Model,
|
Model: r.Model,
|
||||||
Messages: messages,
|
Messages: messages,
|
||||||
Options: options,
|
Options: options,
|
||||||
Stream: &stream,
|
Stream: &stream,
|
||||||
Tools: tools,
|
Tools: tools,
|
||||||
Think: think,
|
Think: think,
|
||||||
}, nil
|
}
|
||||||
|
logutil.Trace("anthropic: converted request", "req", TraceChatRequest(convertedRequest))
|
||||||
|
|
||||||
|
return convertedRequest, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// convertMessage converts an Anthropic MessageParam to Ollama api.Message(s)
|
// convertMessage converts an Anthropic MessageParam to Ollama api.Message(s)
|
||||||
@@ -327,10 +387,19 @@ func convertMessage(msg MessageParam) ([]api.Message, error) {
|
|||||||
var toolCalls []api.ToolCall
|
var toolCalls []api.ToolCall
|
||||||
var thinking string
|
var thinking string
|
||||||
var toolResults []api.Message
|
var toolResults []api.Message
|
||||||
|
textBlocks := 0
|
||||||
|
imageBlocks := 0
|
||||||
|
toolUseBlocks := 0
|
||||||
|
toolResultBlocks := 0
|
||||||
|
serverToolUseBlocks := 0
|
||||||
|
webSearchToolResultBlocks := 0
|
||||||
|
thinkingBlocks := 0
|
||||||
|
unknownBlocks := 0
|
||||||
|
|
||||||
for _, block := range content {
|
for _, block := range content {
|
||||||
blockMap, ok := block.(map[string]any)
|
blockMap, ok := block.(map[string]any)
|
||||||
if !ok {
|
if !ok {
|
||||||
|
logutil.Trace("anthropic: invalid content block format", "role", role)
|
||||||
return nil, errors.New("invalid content block format")
|
return nil, errors.New("invalid content block format")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -338,13 +407,16 @@ func convertMessage(msg MessageParam) ([]api.Message, error) {
|
|||||||
|
|
||||||
switch blockType {
|
switch blockType {
|
||||||
case "text":
|
case "text":
|
||||||
|
textBlocks++
|
||||||
if text, ok := blockMap["text"].(string); ok {
|
if text, ok := blockMap["text"].(string); ok {
|
||||||
textContent.WriteString(text)
|
textContent.WriteString(text)
|
||||||
}
|
}
|
||||||
|
|
||||||
case "image":
|
case "image":
|
||||||
|
imageBlocks++
|
||||||
source, ok := blockMap["source"].(map[string]any)
|
source, ok := blockMap["source"].(map[string]any)
|
||||||
if !ok {
|
if !ok {
|
||||||
|
logutil.Trace("anthropic: invalid image source", "role", role)
|
||||||
return nil, errors.New("invalid image source")
|
return nil, errors.New("invalid image source")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -353,21 +425,26 @@ func convertMessage(msg MessageParam) ([]api.Message, error) {
|
|||||||
data, _ := source["data"].(string)
|
data, _ := source["data"].(string)
|
||||||
decoded, err := base64.StdEncoding.DecodeString(data)
|
decoded, err := base64.StdEncoding.DecodeString(data)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
logutil.Trace("anthropic: invalid base64 image data", "role", role, "error", err)
|
||||||
return nil, fmt.Errorf("invalid base64 image data: %w", err)
|
return nil, fmt.Errorf("invalid base64 image data: %w", err)
|
||||||
}
|
}
|
||||||
images = append(images, decoded)
|
images = append(images, decoded)
|
||||||
} else {
|
} else {
|
||||||
|
logutil.Trace("anthropic: unsupported image source type", "role", role, "source_type", sourceType)
|
||||||
return nil, fmt.Errorf("invalid image source type: %s. Only base64 images are supported.", sourceType)
|
return nil, fmt.Errorf("invalid image source type: %s. Only base64 images are supported.", sourceType)
|
||||||
}
|
}
|
||||||
// URL images would need to be fetched - skip for now
|
// URL images would need to be fetched - skip for now
|
||||||
|
|
||||||
case "tool_use":
|
case "tool_use":
|
||||||
|
toolUseBlocks++
|
||||||
id, ok := blockMap["id"].(string)
|
id, ok := blockMap["id"].(string)
|
||||||
if !ok {
|
if !ok {
|
||||||
|
logutil.Trace("anthropic: tool_use block missing id", "role", role)
|
||||||
return nil, errors.New("tool_use block missing required 'id' field")
|
return nil, errors.New("tool_use block missing required 'id' field")
|
||||||
}
|
}
|
||||||
name, ok := blockMap["name"].(string)
|
name, ok := blockMap["name"].(string)
|
||||||
if !ok {
|
if !ok {
|
||||||
|
logutil.Trace("anthropic: tool_use block missing name", "role", role)
|
||||||
return nil, errors.New("tool_use block missing required 'name' field")
|
return nil, errors.New("tool_use block missing required 'name' field")
|
||||||
}
|
}
|
||||||
tc := api.ToolCall{
|
tc := api.ToolCall{
|
||||||
@@ -382,6 +459,7 @@ func convertMessage(msg MessageParam) ([]api.Message, error) {
|
|||||||
toolCalls = append(toolCalls, tc)
|
toolCalls = append(toolCalls, tc)
|
||||||
|
|
||||||
case "tool_result":
|
case "tool_result":
|
||||||
|
toolResultBlocks++
|
||||||
toolUseID, _ := blockMap["tool_use_id"].(string)
|
toolUseID, _ := blockMap["tool_use_id"].(string)
|
||||||
var resultContent string
|
var resultContent string
|
||||||
|
|
||||||
@@ -407,9 +485,36 @@ func convertMessage(msg MessageParam) ([]api.Message, error) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
case "thinking":
|
case "thinking":
|
||||||
|
thinkingBlocks++
|
||||||
if t, ok := blockMap["thinking"].(string); ok {
|
if t, ok := blockMap["thinking"].(string); ok {
|
||||||
thinking = t
|
thinking = t
|
||||||
}
|
}
|
||||||
|
|
||||||
|
case "server_tool_use":
|
||||||
|
serverToolUseBlocks++
|
||||||
|
id, _ := blockMap["id"].(string)
|
||||||
|
name, _ := blockMap["name"].(string)
|
||||||
|
tc := api.ToolCall{
|
||||||
|
ID: id,
|
||||||
|
Function: api.ToolCallFunction{
|
||||||
|
Name: name,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
if input, ok := blockMap["input"].(map[string]any); ok {
|
||||||
|
tc.Function.Arguments = mapToArgs(input)
|
||||||
|
}
|
||||||
|
toolCalls = append(toolCalls, tc)
|
||||||
|
|
||||||
|
case "web_search_tool_result":
|
||||||
|
webSearchToolResultBlocks++
|
||||||
|
toolUseID, _ := blockMap["tool_use_id"].(string)
|
||||||
|
toolResults = append(toolResults, api.Message{
|
||||||
|
Role: "tool",
|
||||||
|
Content: formatWebSearchToolResultContent(blockMap["content"]),
|
||||||
|
ToolCallID: toolUseID,
|
||||||
|
})
|
||||||
|
default:
|
||||||
|
unknownBlocks++
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -426,6 +531,19 @@ func convertMessage(msg MessageParam) ([]api.Message, error) {
|
|||||||
|
|
||||||
// Add tool results as separate messages
|
// Add tool results as separate messages
|
||||||
messages = append(messages, toolResults...)
|
messages = append(messages, toolResults...)
|
||||||
|
logutil.Trace("anthropic: converted block message",
|
||||||
|
"role", role,
|
||||||
|
"blocks", len(content),
|
||||||
|
"text", textBlocks,
|
||||||
|
"image", imageBlocks,
|
||||||
|
"tool_use", toolUseBlocks,
|
||||||
|
"tool_result", toolResultBlocks,
|
||||||
|
"server_tool_use", serverToolUseBlocks,
|
||||||
|
"web_search_result", webSearchToolResultBlocks,
|
||||||
|
"thinking", thinkingBlocks,
|
||||||
|
"unknown", unknownBlocks,
|
||||||
|
"messages", TraceAPIMessages(messages),
|
||||||
|
)
|
||||||
|
|
||||||
default:
|
default:
|
||||||
return nil, fmt.Errorf("invalid message content type: %T", content)
|
return nil, fmt.Errorf("invalid message content type: %T", content)
|
||||||
@@ -434,12 +552,94 @@ func convertMessage(msg MessageParam) ([]api.Message, error) {
|
|||||||
return messages, nil
|
return messages, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// convertTool converts an Anthropic Tool to an Ollama api.Tool
|
func formatWebSearchToolResultContent(content any) string {
|
||||||
func convertTool(t Tool) (api.Tool, error) {
|
switch c := content.(type) {
|
||||||
|
case string:
|
||||||
|
return c
|
||||||
|
case []WebSearchResult:
|
||||||
|
var resultContent strings.Builder
|
||||||
|
for _, item := range c {
|
||||||
|
if item.Type != "web_search_result" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
fmt.Fprintf(&resultContent, "- %s: %s\n", item.Title, item.URL)
|
||||||
|
}
|
||||||
|
return resultContent.String()
|
||||||
|
case []any:
|
||||||
|
var resultContent strings.Builder
|
||||||
|
for _, item := range c {
|
||||||
|
itemMap, ok := item.(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
switch itemMap["type"] {
|
||||||
|
case "web_search_result":
|
||||||
|
title, _ := itemMap["title"].(string)
|
||||||
|
url, _ := itemMap["url"].(string)
|
||||||
|
fmt.Fprintf(&resultContent, "- %s: %s\n", title, url)
|
||||||
|
case "web_search_tool_result_error":
|
||||||
|
errorCode, _ := itemMap["error_code"].(string)
|
||||||
|
if errorCode == "" {
|
||||||
|
return "web_search_tool_result_error"
|
||||||
|
}
|
||||||
|
return "web_search_tool_result_error: " + errorCode
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return resultContent.String()
|
||||||
|
case map[string]any:
|
||||||
|
if c["type"] == "web_search_tool_result_error" {
|
||||||
|
errorCode, _ := c["error_code"].(string)
|
||||||
|
if errorCode == "" {
|
||||||
|
return "web_search_tool_result_error"
|
||||||
|
}
|
||||||
|
return "web_search_tool_result_error: " + errorCode
|
||||||
|
}
|
||||||
|
data, err := json.Marshal(c)
|
||||||
|
if err != nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return string(data)
|
||||||
|
case WebSearchToolResultError:
|
||||||
|
if c.ErrorCode == "" {
|
||||||
|
return "web_search_tool_result_error"
|
||||||
|
}
|
||||||
|
return "web_search_tool_result_error: " + c.ErrorCode
|
||||||
|
default:
|
||||||
|
data, err := json.Marshal(c)
|
||||||
|
if err != nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return string(data)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// convertTool converts an Anthropic Tool to an Ollama api.Tool, returning true if it's a server tool
|
||||||
|
func convertTool(t Tool) (api.Tool, bool, error) {
|
||||||
|
if strings.HasPrefix(t.Type, "web_search") {
|
||||||
|
props := api.NewToolPropertiesMap()
|
||||||
|
props.Set("query", api.ToolProperty{
|
||||||
|
Type: api.PropertyType{"string"},
|
||||||
|
Description: "The search query to look up on the web",
|
||||||
|
})
|
||||||
|
return api.Tool{
|
||||||
|
Type: "function",
|
||||||
|
Function: api.ToolFunction{
|
||||||
|
Name: "web_search",
|
||||||
|
Description: "Search the web for current information. Use this to find up-to-date information about any topic.",
|
||||||
|
Parameters: api.ToolFunctionParameters{
|
||||||
|
Type: "object",
|
||||||
|
Required: []string{"query"},
|
||||||
|
Properties: props,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}, true, nil
|
||||||
|
}
|
||||||
|
|
||||||
var params api.ToolFunctionParameters
|
var params api.ToolFunctionParameters
|
||||||
if len(t.InputSchema) > 0 {
|
if len(t.InputSchema) > 0 {
|
||||||
if err := json.Unmarshal(t.InputSchema, ¶ms); err != nil {
|
if err := json.Unmarshal(t.InputSchema, ¶ms); err != nil {
|
||||||
return api.Tool{}, fmt.Errorf("invalid input_schema for tool %q: %w", t.Name, err)
|
logutil.Trace("anthropic: invalid tool schema", "tool", t.Name, "err", err)
|
||||||
|
return api.Tool{}, false, fmt.Errorf("invalid input_schema for tool %q: %w", t.Name, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -450,7 +650,7 @@ func convertTool(t Tool) (api.Tool, error) {
|
|||||||
Description: t.Description,
|
Description: t.Description,
|
||||||
Parameters: params,
|
Parameters: params,
|
||||||
},
|
},
|
||||||
}, nil
|
}, false, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// ToMessagesResponse converts an Ollama api.ChatResponse to an Anthropic MessagesResponse
|
// ToMessagesResponse converts an Ollama api.ChatResponse to an Anthropic MessagesResponse
|
||||||
@@ -523,17 +723,19 @@ type StreamConverter struct {
|
|||||||
contentIndex int
|
contentIndex int
|
||||||
inputTokens int
|
inputTokens int
|
||||||
outputTokens int
|
outputTokens int
|
||||||
|
estimatedInputTokens int // Estimated tokens from request (used when actual metrics are 0)
|
||||||
thinkingStarted bool
|
thinkingStarted bool
|
||||||
thinkingDone bool
|
thinkingDone bool
|
||||||
textStarted bool
|
textStarted bool
|
||||||
toolCallsSent map[string]bool
|
toolCallsSent map[string]bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewStreamConverter(id, model string) *StreamConverter {
|
func NewStreamConverter(id, model string, estimatedInputTokens int) *StreamConverter {
|
||||||
return &StreamConverter{
|
return &StreamConverter{
|
||||||
ID: id,
|
ID: id,
|
||||||
Model: model,
|
Model: model,
|
||||||
firstWrite: true,
|
firstWrite: true,
|
||||||
|
estimatedInputTokens: estimatedInputTokens,
|
||||||
toolCallsSent: make(map[string]bool),
|
toolCallsSent: make(map[string]bool),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -550,7 +752,11 @@ func (c *StreamConverter) Process(r api.ChatResponse) []StreamEvent {
|
|||||||
|
|
||||||
if c.firstWrite {
|
if c.firstWrite {
|
||||||
c.firstWrite = false
|
c.firstWrite = false
|
||||||
|
// Use actual metrics if available, otherwise use estimate
|
||||||
c.inputTokens = r.Metrics.PromptEvalCount
|
c.inputTokens = r.Metrics.PromptEvalCount
|
||||||
|
if c.inputTokens == 0 && c.estimatedInputTokens > 0 {
|
||||||
|
c.inputTokens = c.estimatedInputTokens
|
||||||
|
}
|
||||||
|
|
||||||
events = append(events, StreamEvent{
|
events = append(events, StreamEvent{
|
||||||
Event: "message_start",
|
Event: "message_start",
|
||||||
@@ -646,6 +852,19 @@ func (c *StreamConverter) Process(r api.ChatResponse) []StreamEvent {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Close thinking block if still open (thinking → tool_use without text in between)
|
||||||
|
if c.thinkingStarted && !c.thinkingDone {
|
||||||
|
c.thinkingDone = true
|
||||||
|
events = append(events, StreamEvent{
|
||||||
|
Event: "content_block_stop",
|
||||||
|
Data: ContentBlockStopEvent{
|
||||||
|
Type: "content_block_stop",
|
||||||
|
Index: c.contentIndex,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
c.contentIndex++
|
||||||
|
}
|
||||||
|
|
||||||
if c.textStarted {
|
if c.textStarted {
|
||||||
events = append(events, StreamEvent{
|
events = append(events, StreamEvent{
|
||||||
Event: "content_block_stop",
|
Event: "content_block_stop",
|
||||||
@@ -721,6 +940,7 @@ func (c *StreamConverter) Process(r api.ChatResponse) []StreamEvent {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
c.inputTokens = r.Metrics.PromptEvalCount
|
||||||
c.outputTokens = r.Metrics.EvalCount
|
c.outputTokens = r.Metrics.EvalCount
|
||||||
stopReason := mapStopReason(r.DoneReason, len(c.toolCallsSent) > 0)
|
stopReason := mapStopReason(r.DoneReason, len(c.toolCallsSent) > 0)
|
||||||
|
|
||||||
@@ -732,6 +952,7 @@ func (c *StreamConverter) Process(r api.ChatResponse) []StreamEvent {
|
|||||||
StopReason: stopReason,
|
StopReason: stopReason,
|
||||||
},
|
},
|
||||||
Usage: DeltaUsage{
|
Usage: DeltaUsage{
|
||||||
|
InputTokens: c.inputTokens,
|
||||||
OutputTokens: c.outputTokens,
|
OutputTokens: c.outputTokens,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -776,3 +997,227 @@ func mapToArgs(m map[string]any) api.ToolCallFunctionArguments {
|
|||||||
}
|
}
|
||||||
return args
|
return args
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CountTokensRequest represents an Anthropic count_tokens request
|
||||||
|
type CountTokensRequest struct {
|
||||||
|
Model string `json:"model"`
|
||||||
|
Messages []MessageParam `json:"messages"`
|
||||||
|
System any `json:"system,omitempty"`
|
||||||
|
Tools []Tool `json:"tools,omitempty"`
|
||||||
|
Thinking *ThinkingConfig `json:"thinking,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// EstimateInputTokens estimates input tokens from a MessagesRequest (reuses CountTokensRequest logic)
|
||||||
|
func EstimateInputTokens(req MessagesRequest) int {
|
||||||
|
return estimateTokens(CountTokensRequest{
|
||||||
|
Model: req.Model,
|
||||||
|
Messages: req.Messages,
|
||||||
|
System: req.System,
|
||||||
|
Tools: req.Tools,
|
||||||
|
Thinking: req.Thinking,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// CountTokensResponse represents an Anthropic count_tokens response
|
||||||
|
type CountTokensResponse struct {
|
||||||
|
InputTokens int `json:"input_tokens"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// estimateTokens returns a rough estimate of tokens (len/4).
|
||||||
|
// TODO: Replace with actual tokenization via Tokenize API for accuracy.
|
||||||
|
// Current len/4 heuristic is a rough approximation (~4 chars/token average).
|
||||||
|
func estimateTokens(req CountTokensRequest) int {
|
||||||
|
var totalLen int
|
||||||
|
|
||||||
|
// Count system prompt
|
||||||
|
if req.System != nil {
|
||||||
|
totalLen += countAnyContent(req.System)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Count messages
|
||||||
|
for _, msg := range req.Messages {
|
||||||
|
// Count role (always present)
|
||||||
|
totalLen += len(msg.Role)
|
||||||
|
// Count content
|
||||||
|
contentLen := countAnyContent(msg.Content)
|
||||||
|
totalLen += contentLen
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tool := range req.Tools {
|
||||||
|
totalLen += len(tool.Name) + len(tool.Description) + len(tool.InputSchema)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return len/4 as rough token estimate, minimum 1 if there's any content
|
||||||
|
tokens := totalLen / 4
|
||||||
|
if tokens == 0 && (len(req.Messages) > 0 || req.System != nil) {
|
||||||
|
tokens = 1
|
||||||
|
}
|
||||||
|
return tokens
|
||||||
|
}
|
||||||
|
|
||||||
|
func countAnyContent(content any) int {
|
||||||
|
if content == nil {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
switch c := content.(type) {
|
||||||
|
case string:
|
||||||
|
return len(c)
|
||||||
|
case []any:
|
||||||
|
total := 0
|
||||||
|
for _, block := range c {
|
||||||
|
total += countContentBlock(block)
|
||||||
|
}
|
||||||
|
return total
|
||||||
|
default:
|
||||||
|
if data, err := json.Marshal(content); err == nil {
|
||||||
|
return len(data)
|
||||||
|
}
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func countContentBlock(block any) int {
|
||||||
|
blockMap, ok := block.(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
if s, ok := block.(string); ok {
|
||||||
|
return len(s)
|
||||||
|
}
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
total := 0
|
||||||
|
blockType, _ := blockMap["type"].(string)
|
||||||
|
|
||||||
|
if text, ok := blockMap["text"].(string); ok {
|
||||||
|
total += len(text)
|
||||||
|
}
|
||||||
|
|
||||||
|
if thinking, ok := blockMap["thinking"].(string); ok {
|
||||||
|
total += len(thinking)
|
||||||
|
}
|
||||||
|
|
||||||
|
if blockType == "tool_use" {
|
||||||
|
if data, err := json.Marshal(blockMap); err == nil {
|
||||||
|
total += len(data)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if blockType == "tool_result" {
|
||||||
|
if data, err := json.Marshal(blockMap); err == nil {
|
||||||
|
total += len(data)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return total
|
||||||
|
}
|
||||||
|
|
||||||
|
// OllamaWebSearchRequest represents a request to the Ollama web search API
|
||||||
|
type OllamaWebSearchRequest struct {
|
||||||
|
Query string `json:"query"`
|
||||||
|
MaxResults int `json:"max_results,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// OllamaWebSearchResult represents a single search result from Ollama API
|
||||||
|
type OllamaWebSearchResult struct {
|
||||||
|
Title string `json:"title"`
|
||||||
|
URL string `json:"url"`
|
||||||
|
Content string `json:"content"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// OllamaWebSearchResponse represents the response from the Ollama web search API
|
||||||
|
type OllamaWebSearchResponse struct {
|
||||||
|
Results []OllamaWebSearchResult `json:"results"`
|
||||||
|
}
|
||||||
|
|
||||||
|
var WebSearchEndpoint = "https://ollama.com/api/web_search"
|
||||||
|
|
||||||
|
func WebSearch(ctx context.Context, query string, maxResults int) (*OllamaWebSearchResponse, error) {
|
||||||
|
if internalcloud.Disabled() {
|
||||||
|
logutil.TraceContext(ctx, "anthropic: web search blocked", "reason", "cloud_disabled")
|
||||||
|
return nil, errors.New(internalcloud.DisabledError("web search is unavailable"))
|
||||||
|
}
|
||||||
|
|
||||||
|
if maxResults <= 0 {
|
||||||
|
maxResults = 5
|
||||||
|
}
|
||||||
|
if maxResults > 10 {
|
||||||
|
maxResults = 10
|
||||||
|
}
|
||||||
|
|
||||||
|
reqBody := OllamaWebSearchRequest{
|
||||||
|
Query: query,
|
||||||
|
MaxResults: maxResults,
|
||||||
|
}
|
||||||
|
|
||||||
|
body, err := json.Marshal(reqBody)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to marshal web search request: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
searchURL, err := url.Parse(WebSearchEndpoint)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse web search URL: %w", err)
|
||||||
|
}
|
||||||
|
logutil.TraceContext(ctx, "anthropic: web search request",
|
||||||
|
"query", TraceTruncateString(query),
|
||||||
|
"max_results", maxResults,
|
||||||
|
"url", searchURL.String(),
|
||||||
|
)
|
||||||
|
|
||||||
|
q := searchURL.Query()
|
||||||
|
q.Set("ts", strconv.FormatInt(time.Now().Unix(), 10))
|
||||||
|
searchURL.RawQuery = q.Encode()
|
||||||
|
|
||||||
|
signature := ""
|
||||||
|
if strings.EqualFold(searchURL.Hostname(), "ollama.com") {
|
||||||
|
challenge := fmt.Sprintf("%s,%s", http.MethodPost, searchURL.RequestURI())
|
||||||
|
signature, err = auth.Sign(ctx, []byte(challenge))
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to sign web search request: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
logutil.TraceContext(ctx, "anthropic: web search auth", "signed", signature != "")
|
||||||
|
|
||||||
|
req, err := http.NewRequestWithContext(ctx, "POST", searchURL.String(), bytes.NewReader(body))
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create web search request: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
if signature != "" {
|
||||||
|
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", signature))
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := http.DefaultClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("web search request failed: %w", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
logutil.TraceContext(ctx, "anthropic: web search response", "status", resp.StatusCode)
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
respBody, _ := io.ReadAll(resp.Body)
|
||||||
|
return nil, fmt.Errorf("web search returned status %d: %s", resp.StatusCode, string(respBody))
|
||||||
|
}
|
||||||
|
|
||||||
|
var searchResp OllamaWebSearchResponse
|
||||||
|
if err := json.NewDecoder(resp.Body).Decode(&searchResp); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to decode web search response: %w", err)
|
||||||
|
}
|
||||||
|
logutil.TraceContext(ctx, "anthropic: web search results", "count", len(searchResp.Results))
|
||||||
|
|
||||||
|
return &searchResp, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func ConvertOllamaToAnthropicResults(ollamaResults *OllamaWebSearchResponse) []WebSearchResult {
|
||||||
|
var results []WebSearchResult
|
||||||
|
for _, r := range ollamaResults.Results {
|
||||||
|
results = append(results, WebSearchResult{
|
||||||
|
Type: "web_search_result",
|
||||||
|
URL: r.URL,
|
||||||
|
Title: r.Title,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return results
|
||||||
|
}
|
||||||
|
|||||||
633
anthropic/anthropic_test.go
Normal file → Executable file
633
anthropic/anthropic_test.go
Normal file → Executable file
@@ -3,6 +3,7 @@ package anthropic
|
|||||||
import (
|
import (
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/google/go-cmp/cmp"
|
"github.com/google/go-cmp/cmp"
|
||||||
@@ -300,6 +301,78 @@ func TestFromMessagesRequest_WithTools(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestFromMessagesRequest_DropsCustomWebSearchWhenBuiltinPresent(t *testing.T) {
|
||||||
|
req := MessagesRequest{
|
||||||
|
Model: "test-model",
|
||||||
|
MaxTokens: 1024,
|
||||||
|
Messages: []MessageParam{{Role: "user", Content: "Hello"}},
|
||||||
|
Tools: []Tool{
|
||||||
|
{
|
||||||
|
Type: "web_search_20250305",
|
||||||
|
Name: "web_search",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Type: "custom",
|
||||||
|
Name: "web_search",
|
||||||
|
Description: "User-defined web search that should be dropped",
|
||||||
|
InputSchema: json.RawMessage(`{"type":"invalid"}`),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Type: "custom",
|
||||||
|
Name: "get_weather",
|
||||||
|
Description: "Get current weather",
|
||||||
|
InputSchema: json.RawMessage(`{"type":"object","properties":{"location":{"type":"string"}},"required":["location"]}`),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := FromMessagesRequest(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(result.Tools) != 2 {
|
||||||
|
t.Fatalf("expected 2 tools after dropping custom web_search, got %d", len(result.Tools))
|
||||||
|
}
|
||||||
|
if result.Tools[0].Function.Name != "web_search" {
|
||||||
|
t.Fatalf("expected first tool to be built-in web_search, got %q", result.Tools[0].Function.Name)
|
||||||
|
}
|
||||||
|
if result.Tools[1].Function.Name != "get_weather" {
|
||||||
|
t.Fatalf("expected second tool to be get_weather, got %q", result.Tools[1].Function.Name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFromMessagesRequest_KeepsCustomWebSearchWhenBuiltinAbsent(t *testing.T) {
|
||||||
|
req := MessagesRequest{
|
||||||
|
Model: "test-model",
|
||||||
|
MaxTokens: 1024,
|
||||||
|
Messages: []MessageParam{{Role: "user", Content: "Hello"}},
|
||||||
|
Tools: []Tool{
|
||||||
|
{
|
||||||
|
Type: "custom",
|
||||||
|
Name: "web_search",
|
||||||
|
Description: "User-defined web search",
|
||||||
|
InputSchema: json.RawMessage(`{"type":"object","properties":{"query":{"type":"string"}},"required":["query"]}`),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := FromMessagesRequest(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(result.Tools) != 1 {
|
||||||
|
t.Fatalf("expected 1 custom tool, got %d", len(result.Tools))
|
||||||
|
}
|
||||||
|
if result.Tools[0].Function.Name != "web_search" {
|
||||||
|
t.Fatalf("expected custom tool name web_search, got %q", result.Tools[0].Function.Name)
|
||||||
|
}
|
||||||
|
if result.Tools[0].Function.Description != "User-defined web search" {
|
||||||
|
t.Fatalf("expected custom description preserved, got %q", result.Tools[0].Function.Description)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestFromMessagesRequest_WithThinking(t *testing.T) {
|
func TestFromMessagesRequest_WithThinking(t *testing.T) {
|
||||||
req := MessagesRequest{
|
req := MessagesRequest{
|
||||||
Model: "test-model",
|
Model: "test-model",
|
||||||
@@ -321,8 +394,6 @@ func TestFromMessagesRequest_WithThinking(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// TestFromMessagesRequest_ThinkingOnlyBlock verifies that messages containing only
|
|
||||||
// a thinking block (no text, images, or tool calls) are preserved and not dropped.
|
|
||||||
func TestFromMessagesRequest_ThinkingOnlyBlock(t *testing.T) {
|
func TestFromMessagesRequest_ThinkingOnlyBlock(t *testing.T) {
|
||||||
req := MessagesRequest{
|
req := MessagesRequest{
|
||||||
Model: "test-model",
|
Model: "test-model",
|
||||||
@@ -605,7 +676,7 @@ func TestGenerateMessageID(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestStreamConverter_Basic(t *testing.T) {
|
func TestStreamConverter_Basic(t *testing.T) {
|
||||||
conv := NewStreamConverter("msg_123", "test-model")
|
conv := NewStreamConverter("msg_123", "test-model", 0)
|
||||||
|
|
||||||
// First chunk
|
// First chunk
|
||||||
resp1 := api.ChatResponse{
|
resp1 := api.ChatResponse{
|
||||||
@@ -642,7 +713,7 @@ func TestStreamConverter_Basic(t *testing.T) {
|
|||||||
},
|
},
|
||||||
Done: true,
|
Done: true,
|
||||||
DoneReason: "stop",
|
DoneReason: "stop",
|
||||||
Metrics: api.Metrics{EvalCount: 5},
|
Metrics: api.Metrics{PromptEvalCount: 10, EvalCount: 5},
|
||||||
}
|
}
|
||||||
|
|
||||||
events2 := conv.Process(resp2)
|
events2 := conv.Process(resp2)
|
||||||
@@ -650,6 +721,24 @@ func TestStreamConverter_Basic(t *testing.T) {
|
|||||||
// Should have content_block_delta, content_block_stop, message_delta, message_stop
|
// Should have content_block_delta, content_block_stop, message_delta, message_stop
|
||||||
hasStop := false
|
hasStop := false
|
||||||
for _, e := range events2 {
|
for _, e := range events2 {
|
||||||
|
if e.Event == "message_delta" {
|
||||||
|
if data, ok := e.Data.(MessageDeltaEvent); ok {
|
||||||
|
if data.Type != "message_delta" {
|
||||||
|
t.Errorf("unexpected data type: %+v", data)
|
||||||
|
}
|
||||||
|
|
||||||
|
if data.Delta.StopReason != "end_turn" {
|
||||||
|
t.Errorf("unexpected stop reason: %+v", data.Delta.StopReason)
|
||||||
|
}
|
||||||
|
|
||||||
|
if data.Usage.InputTokens != 10 || data.Usage.OutputTokens != 5 {
|
||||||
|
t.Errorf("unexpected usage: %+v", data.Usage)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
t.Errorf("unexpected data: %+v", e.Data)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if e.Event == "message_stop" {
|
if e.Event == "message_stop" {
|
||||||
hasStop = true
|
hasStop = true
|
||||||
}
|
}
|
||||||
@@ -660,7 +749,7 @@ func TestStreamConverter_Basic(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestStreamConverter_WithToolCalls(t *testing.T) {
|
func TestStreamConverter_WithToolCalls(t *testing.T) {
|
||||||
conv := NewStreamConverter("msg_123", "test-model")
|
conv := NewStreamConverter("msg_123", "test-model", 0)
|
||||||
|
|
||||||
resp := api.ChatResponse{
|
resp := api.ChatResponse{
|
||||||
Model: "test-model",
|
Model: "test-model",
|
||||||
@@ -710,10 +799,111 @@ func TestStreamConverter_WithToolCalls(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestStreamConverter_ThinkingDirectlyFollowedByToolCall verifies that when a
|
||||||
|
// model emits a thinking block followed directly by a tool_use block (with no
|
||||||
|
// text block in between), the streaming converter correctly closes the thinking
|
||||||
|
// block and increments the content index before opening the tool_use block.
|
||||||
|
// Previously, the converter reused contentIndex=0 for the tool_use block,
|
||||||
|
// which caused "Content block not found" errors in clients. See #14816.
|
||||||
|
func TestStreamConverter_ThinkingDirectlyFollowedByToolCall(t *testing.T) {
|
||||||
|
conv := NewStreamConverter("msg_123", "test-model", 0)
|
||||||
|
|
||||||
|
// First chunk: thinking content (no text)
|
||||||
|
resp1 := api.ChatResponse{
|
||||||
|
Model: "test-model",
|
||||||
|
Message: api.Message{
|
||||||
|
Role: "assistant",
|
||||||
|
Thinking: "I should call the tool.",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
events1 := conv.Process(resp1)
|
||||||
|
|
||||||
|
// Should have: message_start, content_block_start(thinking), content_block_delta(thinking)
|
||||||
|
if len(events1) < 3 {
|
||||||
|
t.Fatalf("expected at least 3 events for thinking chunk, got %d", len(events1))
|
||||||
|
}
|
||||||
|
if events1[0].Event != "message_start" {
|
||||||
|
t.Errorf("expected first event 'message_start', got %q", events1[0].Event)
|
||||||
|
}
|
||||||
|
thinkingStart, ok := events1[1].Data.(ContentBlockStartEvent)
|
||||||
|
if !ok || thinkingStart.ContentBlock.Type != "thinking" {
|
||||||
|
t.Errorf("expected content_block_start(thinking) as second event, got %+v", events1[1])
|
||||||
|
}
|
||||||
|
if thinkingStart.Index != 0 {
|
||||||
|
t.Errorf("expected thinking block at index 0, got %d", thinkingStart.Index)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Second chunk: tool call (no text between thinking and tool)
|
||||||
|
resp2 := api.ChatResponse{
|
||||||
|
Model: "test-model",
|
||||||
|
Message: api.Message{
|
||||||
|
Role: "assistant",
|
||||||
|
ToolCalls: []api.ToolCall{
|
||||||
|
{
|
||||||
|
ID: "call_abc",
|
||||||
|
Function: api.ToolCallFunction{
|
||||||
|
Name: "ask_user",
|
||||||
|
Arguments: testArgs(map[string]any{"question": "cats or dogs?"}),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Done: true,
|
||||||
|
DoneReason: "stop",
|
||||||
|
Metrics: api.Metrics{PromptEvalCount: 10, EvalCount: 5},
|
||||||
|
}
|
||||||
|
events2 := conv.Process(resp2)
|
||||||
|
|
||||||
|
// Expect: content_block_stop(index=0), content_block_start(tool_use, index=1),
|
||||||
|
// content_block_delta(input_json_delta, index=1), content_block_stop(index=1),
|
||||||
|
// message_delta, message_stop
|
||||||
|
var thinkingStop, toolStart, toolDelta, toolStop *StreamEvent
|
||||||
|
for i := range events2 {
|
||||||
|
e := &events2[i]
|
||||||
|
switch e.Event {
|
||||||
|
case "content_block_stop":
|
||||||
|
if stop, ok := e.Data.(ContentBlockStopEvent); ok {
|
||||||
|
if stop.Index == 0 && thinkingStop == nil {
|
||||||
|
thinkingStop = e
|
||||||
|
} else if stop.Index == 1 {
|
||||||
|
toolStop = e
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case "content_block_start":
|
||||||
|
if start, ok := e.Data.(ContentBlockStartEvent); ok && start.ContentBlock.Type == "tool_use" {
|
||||||
|
toolStart = e
|
||||||
|
}
|
||||||
|
case "content_block_delta":
|
||||||
|
if delta, ok := e.Data.(ContentBlockDeltaEvent); ok && delta.Delta.Type == "input_json_delta" {
|
||||||
|
toolDelta = e
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if thinkingStop == nil {
|
||||||
|
t.Error("expected content_block_stop for thinking block (index 0)")
|
||||||
|
}
|
||||||
|
if toolStart == nil {
|
||||||
|
t.Fatal("expected content_block_start for tool_use block")
|
||||||
|
}
|
||||||
|
if start, ok := toolStart.Data.(ContentBlockStartEvent); !ok || start.Index != 1 {
|
||||||
|
t.Errorf("expected tool_use block at index 1, got %+v", toolStart.Data)
|
||||||
|
}
|
||||||
|
if toolDelta == nil {
|
||||||
|
t.Fatal("expected input_json_delta event for tool call")
|
||||||
|
}
|
||||||
|
if delta, ok := toolDelta.Data.(ContentBlockDeltaEvent); !ok || delta.Index != 1 {
|
||||||
|
t.Errorf("expected tool delta at index 1, got %+v", toolDelta.Data)
|
||||||
|
}
|
||||||
|
if toolStop == nil {
|
||||||
|
t.Error("expected content_block_stop for tool_use block (index 1)")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestStreamConverter_ToolCallWithUnmarshalableArgs(t *testing.T) {
|
func TestStreamConverter_ToolCallWithUnmarshalableArgs(t *testing.T) {
|
||||||
// Test that unmarshalable arguments (like channels) are handled gracefully
|
// Test that unmarshalable arguments (like channels) are handled gracefully
|
||||||
// and don't cause a panic or corrupt stream
|
// and don't cause a panic or corrupt stream
|
||||||
conv := NewStreamConverter("msg_123", "test-model")
|
conv := NewStreamConverter("msg_123", "test-model", 0)
|
||||||
|
|
||||||
// Create a channel which cannot be JSON marshaled
|
// Create a channel which cannot be JSON marshaled
|
||||||
unmarshalable := make(chan int)
|
unmarshalable := make(chan int)
|
||||||
@@ -760,7 +950,7 @@ func TestStreamConverter_ToolCallWithUnmarshalableArgs(t *testing.T) {
|
|||||||
|
|
||||||
func TestStreamConverter_MultipleToolCallsWithMixedValidity(t *testing.T) {
|
func TestStreamConverter_MultipleToolCallsWithMixedValidity(t *testing.T) {
|
||||||
// Test that valid tool calls still work when mixed with invalid ones
|
// Test that valid tool calls still work when mixed with invalid ones
|
||||||
conv := NewStreamConverter("msg_123", "test-model")
|
conv := NewStreamConverter("msg_123", "test-model", 0)
|
||||||
|
|
||||||
unmarshalable := make(chan int)
|
unmarshalable := make(chan int)
|
||||||
badArgs := api.NewToolCallFunctionArguments()
|
badArgs := api.NewToolCallFunctionArguments()
|
||||||
@@ -824,10 +1014,6 @@ func TestStreamConverter_MultipleToolCallsWithMixedValidity(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// TestContentBlockJSON_EmptyFieldsPresent verifies that empty text and thinking fields
|
|
||||||
// are serialized in JSON output. The Anthropic SDK requires these fields to be present
|
|
||||||
// (even when empty) in content_block_start events to properly accumulate streaming deltas.
|
|
||||||
// Without these fields, the SDK throws: "TypeError: unsupported operand type(s) for +=: 'NoneType' and 'str'"
|
|
||||||
func TestContentBlockJSON_EmptyFieldsPresent(t *testing.T) {
|
func TestContentBlockJSON_EmptyFieldsPresent(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
@@ -881,11 +1067,9 @@ func TestContentBlockJSON_EmptyFieldsPresent(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// TestStreamConverter_ContentBlockStartIncludesEmptyFields verifies that content_block_start
|
|
||||||
// events include the required empty fields for SDK compatibility.
|
|
||||||
func TestStreamConverter_ContentBlockStartIncludesEmptyFields(t *testing.T) {
|
func TestStreamConverter_ContentBlockStartIncludesEmptyFields(t *testing.T) {
|
||||||
t.Run("text block start includes empty text", func(t *testing.T) {
|
t.Run("text block start includes empty text", func(t *testing.T) {
|
||||||
conv := NewStreamConverter("msg_123", "test-model")
|
conv := NewStreamConverter("msg_123", "test-model", 0)
|
||||||
|
|
||||||
resp := api.ChatResponse{
|
resp := api.ChatResponse{
|
||||||
Model: "test-model",
|
Model: "test-model",
|
||||||
@@ -919,7 +1103,7 @@ func TestStreamConverter_ContentBlockStartIncludesEmptyFields(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
t.Run("thinking block start includes empty thinking", func(t *testing.T) {
|
t.Run("thinking block start includes empty thinking", func(t *testing.T) {
|
||||||
conv := NewStreamConverter("msg_123", "test-model")
|
conv := NewStreamConverter("msg_123", "test-model", 0)
|
||||||
|
|
||||||
resp := api.ChatResponse{
|
resp := api.ChatResponse{
|
||||||
Model: "test-model",
|
Model: "test-model",
|
||||||
@@ -951,3 +1135,422 @@ func TestStreamConverter_ContentBlockStartIncludesEmptyFields(t *testing.T) {
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestEstimateTokens_SimpleMessage(t *testing.T) {
|
||||||
|
req := CountTokensRequest{
|
||||||
|
Model: "test-model",
|
||||||
|
Messages: []MessageParam{
|
||||||
|
{Role: "user", Content: "Hello, world!"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
tokens := estimateTokens(req)
|
||||||
|
|
||||||
|
// "user" (4) + "Hello, world!" (13) = 17 chars / 4 = 4 tokens
|
||||||
|
if tokens < 1 {
|
||||||
|
t.Errorf("expected at least 1 token, got %d", tokens)
|
||||||
|
}
|
||||||
|
// Sanity check: shouldn't be wildly off
|
||||||
|
if tokens > 10 {
|
||||||
|
t.Errorf("expected fewer than 10 tokens for short message, got %d", tokens)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEstimateTokens_WithSystemPrompt(t *testing.T) {
|
||||||
|
req := CountTokensRequest{
|
||||||
|
Model: "test-model",
|
||||||
|
System: "You are a helpful assistant.",
|
||||||
|
Messages: []MessageParam{
|
||||||
|
{Role: "user", Content: "Hello"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
tokens := estimateTokens(req)
|
||||||
|
|
||||||
|
// System prompt adds to count
|
||||||
|
if tokens < 5 {
|
||||||
|
t.Errorf("expected at least 5 tokens with system prompt, got %d", tokens)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEstimateTokens_WithTools(t *testing.T) {
|
||||||
|
req := CountTokensRequest{
|
||||||
|
Model: "test-model",
|
||||||
|
Messages: []MessageParam{
|
||||||
|
{Role: "user", Content: "What's the weather?"},
|
||||||
|
},
|
||||||
|
Tools: []Tool{
|
||||||
|
{
|
||||||
|
Name: "get_weather",
|
||||||
|
Description: "Get the current weather for a location",
|
||||||
|
InputSchema: json.RawMessage(`{"type":"object","properties":{"location":{"type":"string"}}}`),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
tokens := estimateTokens(req)
|
||||||
|
|
||||||
|
// Tools add significant content
|
||||||
|
if tokens < 10 {
|
||||||
|
t.Errorf("expected at least 10 tokens with tools, got %d", tokens)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEstimateTokens_WithThinking(t *testing.T) {
|
||||||
|
req := CountTokensRequest{
|
||||||
|
Model: "test-model",
|
||||||
|
Messages: []MessageParam{
|
||||||
|
{Role: "user", Content: "Hello"},
|
||||||
|
{
|
||||||
|
Role: "assistant",
|
||||||
|
Content: []any{
|
||||||
|
map[string]any{
|
||||||
|
"type": "thinking",
|
||||||
|
"thinking": "Let me think about this carefully...",
|
||||||
|
},
|
||||||
|
map[string]any{
|
||||||
|
"type": "text",
|
||||||
|
"text": "Here is my response.",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
tokens := estimateTokens(req)
|
||||||
|
|
||||||
|
// Thinking content should be counted
|
||||||
|
if tokens < 10 {
|
||||||
|
t.Errorf("expected at least 10 tokens with thinking content, got %d", tokens)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEstimateTokens_EmptyContent(t *testing.T) {
|
||||||
|
req := CountTokensRequest{
|
||||||
|
Model: "test-model",
|
||||||
|
Messages: []MessageParam{},
|
||||||
|
}
|
||||||
|
|
||||||
|
tokens := estimateTokens(req)
|
||||||
|
|
||||||
|
if tokens != 0 {
|
||||||
|
t.Errorf("expected 0 tokens for empty content, got %d", tokens)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Web Search Tests
|
||||||
|
|
||||||
|
func TestConvertTool_WebSearch(t *testing.T) {
|
||||||
|
tool := Tool{
|
||||||
|
Type: "web_search_20250305",
|
||||||
|
Name: "web_search",
|
||||||
|
MaxUses: 5,
|
||||||
|
}
|
||||||
|
|
||||||
|
result, isServerTool, err := convertTool(tool)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !isServerTool {
|
||||||
|
t.Error("expected isServerTool to be true for web_search tool")
|
||||||
|
}
|
||||||
|
|
||||||
|
if result.Type != "function" {
|
||||||
|
t.Errorf("expected type 'function', got %q", result.Type)
|
||||||
|
}
|
||||||
|
|
||||||
|
if result.Function.Name != "web_search" {
|
||||||
|
t.Errorf("expected name 'web_search', got %q", result.Function.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
if result.Function.Description == "" {
|
||||||
|
t.Error("expected non-empty description for web_search tool")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check that query parameter is defined
|
||||||
|
if result.Function.Parameters.Properties == nil {
|
||||||
|
t.Fatal("expected properties to be defined")
|
||||||
|
}
|
||||||
|
|
||||||
|
queryProp, ok := result.Function.Parameters.Properties.Get("query")
|
||||||
|
if !ok {
|
||||||
|
t.Error("expected 'query' property to be defined")
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(queryProp.Type) == 0 || queryProp.Type[0] != "string" {
|
||||||
|
t.Errorf("expected query type to be 'string', got %v", queryProp.Type)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConvertTool_RegularTool(t *testing.T) {
|
||||||
|
tool := Tool{
|
||||||
|
Type: "custom",
|
||||||
|
Name: "get_weather",
|
||||||
|
Description: "Get the weather",
|
||||||
|
InputSchema: json.RawMessage(`{"type":"object","properties":{"location":{"type":"string"}}}`),
|
||||||
|
}
|
||||||
|
|
||||||
|
result, isServerTool, err := convertTool(tool)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if isServerTool {
|
||||||
|
t.Error("expected isServerTool to be false for regular tool")
|
||||||
|
}
|
||||||
|
|
||||||
|
if result.Function.Name != "get_weather" {
|
||||||
|
t.Errorf("expected name 'get_weather', got %q", result.Function.Name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConvertMessage_ServerToolUse(t *testing.T) {
|
||||||
|
msg := MessageParam{
|
||||||
|
Role: "assistant",
|
||||||
|
Content: []any{
|
||||||
|
map[string]any{
|
||||||
|
"type": "server_tool_use",
|
||||||
|
"id": "srvtoolu_123",
|
||||||
|
"name": "web_search",
|
||||||
|
"input": map[string]any{"query": "test query"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
messages, err := convertMessage(msg)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(messages) != 1 {
|
||||||
|
t.Fatalf("expected 1 message, got %d", len(messages))
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(messages[0].ToolCalls) != 1 {
|
||||||
|
t.Fatalf("expected 1 tool call, got %d", len(messages[0].ToolCalls))
|
||||||
|
}
|
||||||
|
|
||||||
|
tc := messages[0].ToolCalls[0]
|
||||||
|
if tc.ID != "srvtoolu_123" {
|
||||||
|
t.Errorf("expected tool call ID 'srvtoolu_123', got %q", tc.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
if tc.Function.Name != "web_search" {
|
||||||
|
t.Errorf("expected tool name 'web_search', got %q", tc.Function.Name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConvertMessage_WebSearchToolResult(t *testing.T) {
|
||||||
|
msg := MessageParam{
|
||||||
|
Role: "user",
|
||||||
|
Content: []any{
|
||||||
|
map[string]any{
|
||||||
|
"type": "web_search_tool_result",
|
||||||
|
"tool_use_id": "srvtoolu_123",
|
||||||
|
"content": []any{
|
||||||
|
map[string]any{
|
||||||
|
"type": "web_search_result",
|
||||||
|
"title": "Test Result",
|
||||||
|
"url": "https://example.com",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
messages, err := convertMessage(msg)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should have a tool result message
|
||||||
|
if len(messages) != 1 {
|
||||||
|
t.Fatalf("expected 1 message, got %d", len(messages))
|
||||||
|
}
|
||||||
|
|
||||||
|
if messages[0].Role != "tool" {
|
||||||
|
t.Errorf("expected role 'tool', got %q", messages[0].Role)
|
||||||
|
}
|
||||||
|
|
||||||
|
if messages[0].ToolCallID != "srvtoolu_123" {
|
||||||
|
t.Errorf("expected tool_call_id 'srvtoolu_123', got %q", messages[0].ToolCallID)
|
||||||
|
}
|
||||||
|
|
||||||
|
if messages[0].Content == "" {
|
||||||
|
t.Error("expected non-empty content from web search results")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConvertMessage_WebSearchToolResultEmptyStillCreatesToolMessage(t *testing.T) {
|
||||||
|
msg := MessageParam{
|
||||||
|
Role: "user",
|
||||||
|
Content: []any{
|
||||||
|
map[string]any{
|
||||||
|
"type": "web_search_tool_result",
|
||||||
|
"tool_use_id": "srvtoolu_empty",
|
||||||
|
"content": []any{},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
messages, err := convertMessage(msg)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(messages) != 1 {
|
||||||
|
t.Fatalf("expected 1 message, got %d", len(messages))
|
||||||
|
}
|
||||||
|
if messages[0].Role != "tool" {
|
||||||
|
t.Fatalf("expected role tool, got %q", messages[0].Role)
|
||||||
|
}
|
||||||
|
if messages[0].ToolCallID != "srvtoolu_empty" {
|
||||||
|
t.Fatalf("expected tool_call_id srvtoolu_empty, got %q", messages[0].ToolCallID)
|
||||||
|
}
|
||||||
|
if messages[0].Content != "" {
|
||||||
|
t.Fatalf("expected empty content for empty web search results, got %q", messages[0].Content)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConvertMessage_WebSearchToolResultErrorStillCreatesToolMessage(t *testing.T) {
|
||||||
|
msg := MessageParam{
|
||||||
|
Role: "user",
|
||||||
|
Content: []any{
|
||||||
|
map[string]any{
|
||||||
|
"type": "web_search_tool_result",
|
||||||
|
"tool_use_id": "srvtoolu_error",
|
||||||
|
"content": map[string]any{
|
||||||
|
"type": "web_search_tool_result_error",
|
||||||
|
"error_code": "max_uses_exceeded",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
messages, err := convertMessage(msg)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(messages) != 1 {
|
||||||
|
t.Fatalf("expected 1 message, got %d", len(messages))
|
||||||
|
}
|
||||||
|
if messages[0].Role != "tool" {
|
||||||
|
t.Fatalf("expected role tool, got %q", messages[0].Role)
|
||||||
|
}
|
||||||
|
if messages[0].ToolCallID != "srvtoolu_error" {
|
||||||
|
t.Fatalf("expected tool_call_id srvtoolu_error, got %q", messages[0].ToolCallID)
|
||||||
|
}
|
||||||
|
if !strings.Contains(messages[0].Content, "max_uses_exceeded") {
|
||||||
|
t.Fatalf("expected error code in converted tool content, got %q", messages[0].Content)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConvertOllamaToAnthropicResults(t *testing.T) {
|
||||||
|
ollamaResp := &OllamaWebSearchResponse{
|
||||||
|
Results: []OllamaWebSearchResult{
|
||||||
|
{
|
||||||
|
Title: "Test Title",
|
||||||
|
URL: "https://example.com",
|
||||||
|
Content: "Test content",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Title: "Another Result",
|
||||||
|
URL: "https://example.org",
|
||||||
|
Content: "More content",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
results := ConvertOllamaToAnthropicResults(ollamaResp)
|
||||||
|
|
||||||
|
if len(results) != 2 {
|
||||||
|
t.Fatalf("expected 2 results, got %d", len(results))
|
||||||
|
}
|
||||||
|
|
||||||
|
if results[0].Type != "web_search_result" {
|
||||||
|
t.Errorf("expected type 'web_search_result', got %q", results[0].Type)
|
||||||
|
}
|
||||||
|
|
||||||
|
if results[0].Title != "Test Title" {
|
||||||
|
t.Errorf("expected title 'Test Title', got %q", results[0].Title)
|
||||||
|
}
|
||||||
|
|
||||||
|
if results[0].URL != "https://example.com" {
|
||||||
|
t.Errorf("expected URL 'https://example.com', got %q", results[0].URL)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWebSearchTypes(t *testing.T) {
|
||||||
|
// Test that WebSearchResult serializes correctly
|
||||||
|
result := WebSearchResult{
|
||||||
|
Type: "web_search_result",
|
||||||
|
URL: "https://example.com",
|
||||||
|
Title: "Test",
|
||||||
|
EncryptedContent: "abc123",
|
||||||
|
PageAge: "2025-01-01",
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := json.Marshal(result)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to marshal WebSearchResult: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var unmarshaled WebSearchResult
|
||||||
|
if err := json.Unmarshal(data, &unmarshaled); err != nil {
|
||||||
|
t.Fatalf("failed to unmarshal WebSearchResult: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if unmarshaled.Type != result.Type {
|
||||||
|
t.Errorf("type mismatch: expected %q, got %q", result.Type, unmarshaled.Type)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test WebSearchToolResultError
|
||||||
|
errResult := WebSearchToolResultError{
|
||||||
|
Type: "web_search_tool_result_error",
|
||||||
|
ErrorCode: "max_uses_exceeded",
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err = json.Marshal(errResult)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to marshal WebSearchToolResultError: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var unmarshaledErr WebSearchToolResultError
|
||||||
|
if err := json.Unmarshal(data, &unmarshaledErr); err != nil {
|
||||||
|
t.Fatalf("failed to unmarshal WebSearchToolResultError: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if unmarshaledErr.ErrorCode != "max_uses_exceeded" {
|
||||||
|
t.Errorf("error_code mismatch: expected 'max_uses_exceeded', got %q", unmarshaledErr.ErrorCode)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCitation(t *testing.T) {
|
||||||
|
citation := Citation{
|
||||||
|
Type: "web_search_result_location",
|
||||||
|
URL: "https://example.com",
|
||||||
|
Title: "Example",
|
||||||
|
EncryptedIndex: "enc123",
|
||||||
|
CitedText: "Some cited text...",
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := json.Marshal(citation)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to marshal Citation: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var unmarshaled Citation
|
||||||
|
if err := json.Unmarshal(data, &unmarshaled); err != nil {
|
||||||
|
t.Fatalf("failed to unmarshal Citation: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if unmarshaled.Type != "web_search_result_location" {
|
||||||
|
t.Errorf("type mismatch: expected 'web_search_result_location', got %q", unmarshaled.Type)
|
||||||
|
}
|
||||||
|
|
||||||
|
if unmarshaled.CitedText != "Some cited text..." {
|
||||||
|
t.Errorf("cited_text mismatch: expected 'Some cited text...', got %q", unmarshaled.CitedText)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
352
anthropic/trace.go
Normal file
352
anthropic/trace.go
Normal file
@@ -0,0 +1,352 @@
|
|||||||
|
package anthropic
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"sort"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/api"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Trace truncation limits.
|
||||||
|
const (
|
||||||
|
TraceMaxStringRunes = 240
|
||||||
|
TraceMaxSliceItems = 8
|
||||||
|
TraceMaxMapEntries = 16
|
||||||
|
TraceMaxDepth = 4
|
||||||
|
)
|
||||||
|
|
||||||
|
// TraceTruncateString shortens s to TraceMaxStringRunes, appending a count of
|
||||||
|
// omitted characters when truncated.
|
||||||
|
func TraceTruncateString(s string) string {
|
||||||
|
if len(s) == 0 {
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
runes := []rune(s)
|
||||||
|
if len(runes) <= TraceMaxStringRunes {
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("%s...(+%d chars)", string(runes[:TraceMaxStringRunes]), len(runes)-TraceMaxStringRunes)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TraceJSON round-trips v through JSON and returns a compacted representation.
|
||||||
|
func TraceJSON(v any) any {
|
||||||
|
if v == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
data, err := json.Marshal(v)
|
||||||
|
if err != nil {
|
||||||
|
return map[string]any{"marshal_error": err.Error(), "type": fmt.Sprintf("%T", v)}
|
||||||
|
}
|
||||||
|
var out any
|
||||||
|
if err := json.Unmarshal(data, &out); err != nil {
|
||||||
|
return TraceTruncateString(string(data))
|
||||||
|
}
|
||||||
|
return TraceCompactValue(out, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TraceCompactValue recursively truncates strings, slices, and maps for trace
|
||||||
|
// output. depth tracks recursion to enforce TraceMaxDepth.
|
||||||
|
func TraceCompactValue(v any, depth int) any {
|
||||||
|
if v == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if depth >= TraceMaxDepth {
|
||||||
|
switch t := v.(type) {
|
||||||
|
case string:
|
||||||
|
return TraceTruncateString(t)
|
||||||
|
case []any:
|
||||||
|
return fmt.Sprintf("<array len=%d>", len(t))
|
||||||
|
case map[string]any:
|
||||||
|
return fmt.Sprintf("<object keys=%d>", len(t))
|
||||||
|
default:
|
||||||
|
return fmt.Sprintf("<%T>", v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
switch t := v.(type) {
|
||||||
|
case string:
|
||||||
|
return TraceTruncateString(t)
|
||||||
|
case []any:
|
||||||
|
limit := min(len(t), TraceMaxSliceItems)
|
||||||
|
out := make([]any, 0, limit+1)
|
||||||
|
for i := range limit {
|
||||||
|
out = append(out, TraceCompactValue(t[i], depth+1))
|
||||||
|
}
|
||||||
|
if len(t) > limit {
|
||||||
|
out = append(out, fmt.Sprintf("... +%d more items", len(t)-limit))
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
case map[string]any:
|
||||||
|
keys := make([]string, 0, len(t))
|
||||||
|
for k := range t {
|
||||||
|
keys = append(keys, k)
|
||||||
|
}
|
||||||
|
sort.Strings(keys)
|
||||||
|
limit := min(len(keys), TraceMaxMapEntries)
|
||||||
|
out := make(map[string]any, limit+1)
|
||||||
|
for i := range limit {
|
||||||
|
out[keys[i]] = TraceCompactValue(t[keys[i]], depth+1)
|
||||||
|
}
|
||||||
|
if len(keys) > limit {
|
||||||
|
out["__truncated_keys"] = len(keys) - limit
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
default:
|
||||||
|
return t
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// Anthropic request/response tracing
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
// TraceMessagesRequest returns a compact trace representation of a MessagesRequest.
|
||||||
|
func TraceMessagesRequest(r MessagesRequest) map[string]any {
|
||||||
|
return map[string]any{
|
||||||
|
"model": r.Model,
|
||||||
|
"max_tokens": r.MaxTokens,
|
||||||
|
"messages": traceMessageParams(r.Messages),
|
||||||
|
"system": traceAnthropicContent(r.System),
|
||||||
|
"stream": r.Stream,
|
||||||
|
"tools": traceTools(r.Tools),
|
||||||
|
"tool_choice": TraceJSON(r.ToolChoice),
|
||||||
|
"thinking": TraceJSON(r.Thinking),
|
||||||
|
"stop_sequences": r.StopSequences,
|
||||||
|
"temperature": ptrVal(r.Temperature),
|
||||||
|
"top_p": ptrVal(r.TopP),
|
||||||
|
"top_k": ptrVal(r.TopK),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TraceMessagesResponse returns a compact trace representation of a MessagesResponse.
|
||||||
|
func TraceMessagesResponse(r MessagesResponse) map[string]any {
|
||||||
|
return map[string]any{
|
||||||
|
"id": r.ID,
|
||||||
|
"model": r.Model,
|
||||||
|
"content": TraceJSON(r.Content),
|
||||||
|
"stop_reason": r.StopReason,
|
||||||
|
"usage": r.Usage,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func traceMessageParams(msgs []MessageParam) []map[string]any {
|
||||||
|
out := make([]map[string]any, 0, len(msgs))
|
||||||
|
for _, m := range msgs {
|
||||||
|
out = append(out, map[string]any{
|
||||||
|
"role": m.Role,
|
||||||
|
"content": traceAnthropicContent(m.Content),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func traceAnthropicContent(content any) any {
|
||||||
|
switch c := content.(type) {
|
||||||
|
case nil:
|
||||||
|
return nil
|
||||||
|
case string:
|
||||||
|
return TraceTruncateString(c)
|
||||||
|
case []any:
|
||||||
|
blocks := make([]any, 0, len(c))
|
||||||
|
for _, block := range c {
|
||||||
|
blockMap, ok := block.(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
blocks = append(blocks, TraceCompactValue(block, 0))
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
blocks = append(blocks, traceAnthropicBlock(blockMap))
|
||||||
|
}
|
||||||
|
return blocks
|
||||||
|
default:
|
||||||
|
return TraceJSON(c)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func traceAnthropicBlock(block map[string]any) map[string]any {
|
||||||
|
blockType, _ := block["type"].(string)
|
||||||
|
out := map[string]any{"type": blockType}
|
||||||
|
switch blockType {
|
||||||
|
case "text":
|
||||||
|
if text, ok := block["text"].(string); ok {
|
||||||
|
out["text"] = TraceTruncateString(text)
|
||||||
|
} else {
|
||||||
|
out["text"] = TraceCompactValue(block["text"], 0)
|
||||||
|
}
|
||||||
|
case "thinking":
|
||||||
|
if thinking, ok := block["thinking"].(string); ok {
|
||||||
|
out["thinking"] = TraceTruncateString(thinking)
|
||||||
|
} else {
|
||||||
|
out["thinking"] = TraceCompactValue(block["thinking"], 0)
|
||||||
|
}
|
||||||
|
case "tool_use", "server_tool_use":
|
||||||
|
out["id"] = block["id"]
|
||||||
|
out["name"] = block["name"]
|
||||||
|
out["input"] = TraceCompactValue(block["input"], 0)
|
||||||
|
case "tool_result", "web_search_tool_result":
|
||||||
|
out["tool_use_id"] = block["tool_use_id"]
|
||||||
|
out["content"] = TraceCompactValue(block["content"], 0)
|
||||||
|
case "image":
|
||||||
|
if source, ok := block["source"].(map[string]any); ok {
|
||||||
|
out["source"] = map[string]any{
|
||||||
|
"type": source["type"],
|
||||||
|
"media_type": source["media_type"],
|
||||||
|
"url": source["url"],
|
||||||
|
"data_len": len(fmt.Sprint(source["data"])),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
out["block"] = TraceCompactValue(block, 0)
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func traceTools(tools []Tool) []map[string]any {
|
||||||
|
out := make([]map[string]any, 0, len(tools))
|
||||||
|
for _, t := range tools {
|
||||||
|
out = append(out, TraceTool(t))
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
// TraceTool returns a compact trace representation of an Anthropic Tool.
|
||||||
|
func TraceTool(t Tool) map[string]any {
|
||||||
|
return map[string]any{
|
||||||
|
"type": t.Type,
|
||||||
|
"name": t.Name,
|
||||||
|
"description": TraceTruncateString(t.Description),
|
||||||
|
"input_schema": TraceJSON(t.InputSchema),
|
||||||
|
"max_uses": t.MaxUses,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ContentBlockTypes returns the type strings from content (when it's []any blocks).
|
||||||
|
func ContentBlockTypes(content any) []string {
|
||||||
|
blocks, ok := content.([]any)
|
||||||
|
if !ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
types := make([]string, 0, len(blocks))
|
||||||
|
for _, block := range blocks {
|
||||||
|
blockMap, ok := block.(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
types = append(types, fmt.Sprintf("%T", block))
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
t, _ := blockMap["type"].(string)
|
||||||
|
types = append(types, t)
|
||||||
|
}
|
||||||
|
return types
|
||||||
|
}
|
||||||
|
|
||||||
|
func ptrVal[T any](v *T) any {
|
||||||
|
if v == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return *v
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// Ollama api.* tracing (shared between anthropic and middleware packages)
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
// TraceChatRequest returns a compact trace representation of an Ollama ChatRequest.
|
||||||
|
func TraceChatRequest(req *api.ChatRequest) map[string]any {
|
||||||
|
if req == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
stream := false
|
||||||
|
if req.Stream != nil {
|
||||||
|
stream = *req.Stream
|
||||||
|
}
|
||||||
|
return map[string]any{
|
||||||
|
"model": req.Model,
|
||||||
|
"messages": TraceAPIMessages(req.Messages),
|
||||||
|
"tools": TraceAPITools(req.Tools),
|
||||||
|
"stream": stream,
|
||||||
|
"options": req.Options,
|
||||||
|
"think": TraceJSON(req.Think),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TraceChatResponse returns a compact trace representation of an Ollama ChatResponse.
|
||||||
|
func TraceChatResponse(resp api.ChatResponse) map[string]any {
|
||||||
|
return map[string]any{
|
||||||
|
"model": resp.Model,
|
||||||
|
"done": resp.Done,
|
||||||
|
"done_reason": resp.DoneReason,
|
||||||
|
"message": TraceAPIMessage(resp.Message),
|
||||||
|
"metrics": TraceJSON(resp.Metrics),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TraceAPIMessages returns compact trace representations for a slice of api.Message.
|
||||||
|
func TraceAPIMessages(msgs []api.Message) []map[string]any {
|
||||||
|
out := make([]map[string]any, 0, len(msgs))
|
||||||
|
for _, m := range msgs {
|
||||||
|
out = append(out, TraceAPIMessage(m))
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
// TraceAPIMessage returns a compact trace representation of a single api.Message.
|
||||||
|
func TraceAPIMessage(m api.Message) map[string]any {
|
||||||
|
return map[string]any{
|
||||||
|
"role": m.Role,
|
||||||
|
"content": TraceTruncateString(m.Content),
|
||||||
|
"thinking": TraceTruncateString(m.Thinking),
|
||||||
|
"images": traceImageSizes(m.Images),
|
||||||
|
"tool_calls": traceToolCalls(m.ToolCalls),
|
||||||
|
"tool_name": m.ToolName,
|
||||||
|
"tool_call_id": m.ToolCallID,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func traceImageSizes(images []api.ImageData) []int {
|
||||||
|
if len(images) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
sizes := make([]int, 0, len(images))
|
||||||
|
for _, img := range images {
|
||||||
|
sizes = append(sizes, len(img))
|
||||||
|
}
|
||||||
|
return sizes
|
||||||
|
}
|
||||||
|
|
||||||
|
// TraceAPITools returns compact trace representations for a slice of api.Tool.
|
||||||
|
func TraceAPITools(tools api.Tools) []map[string]any {
|
||||||
|
out := make([]map[string]any, 0, len(tools))
|
||||||
|
for _, t := range tools {
|
||||||
|
out = append(out, TraceAPITool(t))
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
// TraceAPITool returns a compact trace representation of a single api.Tool.
|
||||||
|
func TraceAPITool(t api.Tool) map[string]any {
|
||||||
|
return map[string]any{
|
||||||
|
"type": t.Type,
|
||||||
|
"name": t.Function.Name,
|
||||||
|
"description": TraceTruncateString(t.Function.Description),
|
||||||
|
"parameters": TraceJSON(t.Function.Parameters),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TraceToolCall returns a compact trace representation of an api.ToolCall.
|
||||||
|
func TraceToolCall(tc api.ToolCall) map[string]any {
|
||||||
|
return map[string]any{
|
||||||
|
"id": tc.ID,
|
||||||
|
"name": tc.Function.Name,
|
||||||
|
"args": TraceJSON(tc.Function.Arguments),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func traceToolCalls(tcs []api.ToolCall) []map[string]any {
|
||||||
|
if len(tcs) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
out := make([]map[string]any, 0, len(tcs))
|
||||||
|
for _, tc := range tcs {
|
||||||
|
out = append(out, TraceToolCall(tc))
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
@@ -449,6 +449,16 @@ func (c *Client) Version(ctx context.Context) (string, error) {
|
|||||||
return version.Version, nil
|
return version.Version, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CloudStatusExperimental returns whether cloud features are disabled on the server.
|
||||||
|
func (c *Client) CloudStatusExperimental(ctx context.Context) (*StatusResponse, error) {
|
||||||
|
var status StatusResponse
|
||||||
|
if err := c.do(ctx, http.MethodGet, "/api/status", nil, &status); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &status, nil
|
||||||
|
}
|
||||||
|
|
||||||
// Signout will signout a client for a local ollama server.
|
// Signout will signout a client for a local ollama server.
|
||||||
func (c *Client) Signout(ctx context.Context) error {
|
func (c *Client) Signout(ctx context.Context) error {
|
||||||
return c.do(ctx, http.MethodPost, "/api/signout", nil, nil)
|
return c.do(ctx, http.MethodPost, "/api/signout", nil, nil)
|
||||||
|
|||||||
10
api/types.go
10
api/types.go
@@ -834,6 +834,16 @@ type TokenResponse struct {
|
|||||||
Token string `json:"token"`
|
Token string `json:"token"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type CloudStatus struct {
|
||||||
|
Disabled bool `json:"disabled"`
|
||||||
|
Source string `json:"source"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// StatusResponse is the response from [Client.CloudStatusExperimental].
|
||||||
|
type StatusResponse struct {
|
||||||
|
Cloud CloudStatus `json:"cloud"`
|
||||||
|
}
|
||||||
|
|
||||||
// GenerateResponse is the response passed into [GenerateResponseFunc].
|
// GenerateResponse is the response passed into [GenerateResponseFunc].
|
||||||
type GenerateResponse struct {
|
type GenerateResponse struct {
|
||||||
// Model is the model name that generated the response.
|
// Model is the model name that generated the response.
|
||||||
|
|||||||
@@ -35,6 +35,7 @@ import (
|
|||||||
var (
|
var (
|
||||||
wv = &Webview{}
|
wv = &Webview{}
|
||||||
uiServerPort int
|
uiServerPort int
|
||||||
|
appStore *store.Store
|
||||||
)
|
)
|
||||||
|
|
||||||
var debug = strings.EqualFold(os.Getenv("OLLAMA_DEBUG"), "true") || os.Getenv("OLLAMA_DEBUG") == "1"
|
var debug = strings.EqualFold(os.Getenv("OLLAMA_DEBUG"), "true") || os.Getenv("OLLAMA_DEBUG") == "1"
|
||||||
@@ -208,6 +209,7 @@ func main() {
|
|||||||
uiServerPort = port
|
uiServerPort = port
|
||||||
|
|
||||||
st := &store.Store{}
|
st := &store.Store{}
|
||||||
|
appStore = st
|
||||||
|
|
||||||
// Enable CORS in development mode
|
// Enable CORS in development mode
|
||||||
if devMode {
|
if devMode {
|
||||||
@@ -253,6 +255,8 @@ func main() {
|
|||||||
done <- osrv.Run(octx)
|
done <- osrv.Run(octx)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
upd := &updater.Updater{Store: st}
|
||||||
|
|
||||||
uiServer := ui.Server{
|
uiServer := ui.Server{
|
||||||
Token: token,
|
Token: token,
|
||||||
Restart: func() {
|
Restart: func() {
|
||||||
@@ -267,6 +271,10 @@ func main() {
|
|||||||
ToolRegistry: toolRegistry,
|
ToolRegistry: toolRegistry,
|
||||||
Dev: devMode,
|
Dev: devMode,
|
||||||
Logger: slog.Default(),
|
Logger: slog.Default(),
|
||||||
|
Updater: upd,
|
||||||
|
UpdateAvailableFunc: func() {
|
||||||
|
UpdateAvailable("")
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
srv := &http.Server{
|
srv := &http.Server{
|
||||||
@@ -284,8 +292,20 @@ func main() {
|
|||||||
slog.Debug("background desktop server done")
|
slog.Debug("background desktop server done")
|
||||||
}()
|
}()
|
||||||
|
|
||||||
updater := &updater.Updater{Store: st}
|
upd.StartBackgroundUpdaterChecker(ctx, UpdateAvailable)
|
||||||
updater.StartBackgroundUpdaterChecker(ctx, UpdateAvailable)
|
|
||||||
|
// Check for pending updates on startup (show tray notification if update is ready)
|
||||||
|
if updater.IsUpdatePending() {
|
||||||
|
// On Windows, the tray is initialized in osRun(). Calling UpdateAvailable
|
||||||
|
// before that would dereference a nil tray callback.
|
||||||
|
// TODO: refactor so the update check runs after platform init on all platforms.
|
||||||
|
if runtime.GOOS == "windows" {
|
||||||
|
slog.Debug("update pending on startup, deferring tray notification until tray initialization")
|
||||||
|
} else {
|
||||||
|
slog.Debug("update pending on startup, showing tray notification")
|
||||||
|
UpdateAvailable("")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
hasCompletedFirstRun, err := st.HasCompletedFirstRun()
|
hasCompletedFirstRun, err := st.HasCompletedFirstRun()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -348,6 +368,17 @@ func startHiddenTasks() {
|
|||||||
// CLI triggered app startup use-case
|
// CLI triggered app startup use-case
|
||||||
slog.Info("deferring pending update for fast startup")
|
slog.Info("deferring pending update for fast startup")
|
||||||
} else {
|
} else {
|
||||||
|
// Check if auto-update is enabled before automatically upgrading
|
||||||
|
settings, err := appStore.Settings()
|
||||||
|
if err != nil {
|
||||||
|
slog.Warn("failed to load settings for upgrade check", "error", err)
|
||||||
|
} else if !settings.AutoUpdateEnabled {
|
||||||
|
slog.Info("auto-update disabled, skipping automatic upgrade at startup")
|
||||||
|
// Still show tray notification so user knows update is ready
|
||||||
|
UpdateAvailable("")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
if err := updater.DoUpgradeAtStartup(); err != nil {
|
if err := updater.DoUpgradeAtStartup(); err != nil {
|
||||||
slog.Info("unable to perform upgrade at startup", "error", err)
|
slog.Info("unable to perform upgrade at startup", "error", err)
|
||||||
// Make sure the restart to upgrade menu shows so we can attempt an interactive upgrade to get authorization
|
// Make sure the restart to upgrade menu shows so we can attempt an interactive upgrade to get authorization
|
||||||
|
|||||||
@@ -154,6 +154,10 @@ func handleURLSchemeRequest(urlScheme string) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func UpdateAvailable(ver string) error {
|
func UpdateAvailable(ver string) error {
|
||||||
|
if app.t == nil {
|
||||||
|
slog.Debug("tray not yet initialized, skipping update notification")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
return app.t.UpdateAvailable(ver)
|
return app.t.UpdateAvailable(ver)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -165,6 +169,14 @@ func osRun(shutdown func(), hasCompletedFirstRun, startHidden bool) {
|
|||||||
log.Fatalf("Failed to start: %s", err)
|
log.Fatalf("Failed to start: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Check for pending updates now that the tray is initialized.
|
||||||
|
// The platform-independent check in app.go fires before osRun,
|
||||||
|
// when app.t is still nil, so we must re-check here.
|
||||||
|
if updater.IsUpdatePending() {
|
||||||
|
slog.Debug("update pending on startup, showing tray notification")
|
||||||
|
UpdateAvailable("")
|
||||||
|
}
|
||||||
|
|
||||||
signals := make(chan os.Signal, 1)
|
signals := make(chan os.Signal, 1)
|
||||||
signal.Notify(signals, syscall.SIGINT, syscall.SIGTERM)
|
signal.Notify(signals, syscall.SIGINT, syscall.SIGTERM)
|
||||||
|
|
||||||
|
|||||||
@@ -41,6 +41,11 @@ type InferenceCompute struct {
|
|||||||
VRAM string
|
VRAM string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type InferenceInfo struct {
|
||||||
|
Computes []InferenceCompute
|
||||||
|
DefaultContextLength int
|
||||||
|
}
|
||||||
|
|
||||||
func New(s *store.Store, devMode bool) *Server {
|
func New(s *store.Store, devMode bool) *Server {
|
||||||
p := resolvePath("ollama")
|
p := resolvePath("ollama")
|
||||||
return &Server{store: s, bin: p, dev: devMode}
|
return &Server{store: s, bin: p, dev: devMode}
|
||||||
@@ -205,6 +210,11 @@ func (s *Server) cmd(ctx context.Context) (*exec.Cmd, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
cloudDisabled, err := s.store.CloudDisabled()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
cmd := commandContext(ctx, s.bin, "serve")
|
cmd := commandContext(ctx, s.bin, "serve")
|
||||||
cmd.Stdout, cmd.Stderr = s.log, s.log
|
cmd.Stdout, cmd.Stderr = s.log, s.log
|
||||||
|
|
||||||
@@ -230,6 +240,11 @@ func (s *Server) cmd(ctx context.Context) (*exec.Cmd, error) {
|
|||||||
if settings.ContextLength > 0 {
|
if settings.ContextLength > 0 {
|
||||||
env["OLLAMA_CONTEXT_LENGTH"] = strconv.Itoa(settings.ContextLength)
|
env["OLLAMA_CONTEXT_LENGTH"] = strconv.Itoa(settings.ContextLength)
|
||||||
}
|
}
|
||||||
|
if cloudDisabled {
|
||||||
|
env["OLLAMA_NO_CLOUD"] = "1"
|
||||||
|
} else {
|
||||||
|
env["OLLAMA_NO_CLOUD"] = "0"
|
||||||
|
}
|
||||||
cmd.Env = []string{}
|
cmd.Env = []string{}
|
||||||
for k, v := range env {
|
for k, v := range env {
|
||||||
cmd.Env = append(cmd.Env, k+"="+v)
|
cmd.Env = append(cmd.Env, k+"="+v)
|
||||||
@@ -262,9 +277,12 @@ func openRotatingLog() (io.WriteCloser, error) {
|
|||||||
|
|
||||||
// Attempt to retrieve inference compute information from the server
|
// Attempt to retrieve inference compute information from the server
|
||||||
// log. Set ctx to timeout to control how long to wait for the logs to appear
|
// log. Set ctx to timeout to control how long to wait for the logs to appear
|
||||||
func GetInferenceComputer(ctx context.Context) ([]InferenceCompute, error) {
|
func GetInferenceInfo(ctx context.Context) (*InferenceInfo, error) {
|
||||||
inference := []InferenceCompute{}
|
info := &InferenceInfo{}
|
||||||
marker := regexp.MustCompile(`inference compute.*library=`)
|
computeMarker := regexp.MustCompile(`inference compute.*library=`)
|
||||||
|
defaultCtxMarker := regexp.MustCompile(`vram-based default context`)
|
||||||
|
defaultCtxRegex := regexp.MustCompile(`default_num_ctx=(\d+)`)
|
||||||
|
|
||||||
q := `inference compute.*%s=["]([^"]*)["]`
|
q := `inference compute.*%s=["]([^"]*)["]`
|
||||||
nq := `inference compute.*%s=(\S+)\s`
|
nq := `inference compute.*%s=(\S+)\s`
|
||||||
type regex struct {
|
type regex struct {
|
||||||
@@ -330,8 +348,8 @@ func GetInferenceComputer(ctx context.Context) ([]InferenceCompute, error) {
|
|||||||
scanner := bufio.NewScanner(file)
|
scanner := bufio.NewScanner(file)
|
||||||
for scanner.Scan() {
|
for scanner.Scan() {
|
||||||
line := scanner.Text()
|
line := scanner.Text()
|
||||||
match := marker.FindStringSubmatch(line)
|
// Check for inference compute lines
|
||||||
if len(match) > 0 {
|
if computeMarker.MatchString(line) {
|
||||||
ic := InferenceCompute{
|
ic := InferenceCompute{
|
||||||
Library: get("library", line),
|
Library: get("library", line),
|
||||||
Variant: get("variant", line),
|
Variant: get("variant", line),
|
||||||
@@ -342,12 +360,25 @@ func GetInferenceComputer(ctx context.Context) ([]InferenceCompute, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
slog.Info("Matched", "inference compute", ic)
|
slog.Info("Matched", "inference compute", ic)
|
||||||
inference = append(inference, ic)
|
info.Computes = append(info.Computes, ic)
|
||||||
} else {
|
continue
|
||||||
// Break out on first non matching line after we start matching
|
|
||||||
if len(inference) > 0 {
|
|
||||||
return inference, nil
|
|
||||||
}
|
}
|
||||||
|
// Check for default context length line
|
||||||
|
if defaultCtxMarker.MatchString(line) {
|
||||||
|
match := defaultCtxRegex.FindStringSubmatch(line)
|
||||||
|
if len(match) > 1 {
|
||||||
|
numCtx, err := strconv.Atoi(match[1])
|
||||||
|
if err == nil {
|
||||||
|
info.DefaultContextLength = numCtx
|
||||||
|
slog.Info("Matched default context length", "default_num_ctx", numCtx)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return info, nil
|
||||||
|
}
|
||||||
|
// If we've found compute info but hit a non-matching line, return what we have
|
||||||
|
// This handles older server versions that don't log the default context line
|
||||||
|
if len(info.Computes) > 0 {
|
||||||
|
return info, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
time.Sleep(100 * time.Millisecond)
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
|||||||
@@ -111,7 +111,7 @@ func TestServerCmd(t *testing.T) {
|
|||||||
for _, want := range tt.want {
|
for _, want := range tt.want {
|
||||||
found := false
|
found := false
|
||||||
for _, env := range cmd.Env {
|
for _, env := range cmd.Env {
|
||||||
if strings.Contains(env, want) {
|
if strings.HasPrefix(env, want) {
|
||||||
found = true
|
found = true
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
@@ -123,7 +123,7 @@ func TestServerCmd(t *testing.T) {
|
|||||||
|
|
||||||
for _, dont := range tt.dont {
|
for _, dont := range tt.dont {
|
||||||
for _, env := range cmd.Env {
|
for _, env := range cmd.Env {
|
||||||
if strings.Contains(env, dont) {
|
if strings.HasPrefix(env, dont) {
|
||||||
t.Errorf("unexpected environment variable: %s", env)
|
t.Errorf("unexpected environment variable: %s", env)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -136,44 +136,119 @@ func TestServerCmd(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestGetInferenceComputer(t *testing.T) {
|
func TestServerCmdCloudSettingEnv(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
envValue string
|
||||||
|
configContent string
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "default cloud enabled",
|
||||||
|
want: "OLLAMA_NO_CLOUD=0",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "env disables cloud",
|
||||||
|
envValue: "1",
|
||||||
|
want: "OLLAMA_NO_CLOUD=1",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "config disables cloud",
|
||||||
|
configContent: `{"disable_ollama_cloud": true}`,
|
||||||
|
want: "OLLAMA_NO_CLOUD=1",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid env disables cloud",
|
||||||
|
envValue: "invalid",
|
||||||
|
want: "OLLAMA_NO_CLOUD=1",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
tmpHome := t.TempDir()
|
||||||
|
t.Setenv("HOME", tmpHome)
|
||||||
|
t.Setenv("USERPROFILE", tmpHome)
|
||||||
|
t.Setenv("OLLAMA_NO_CLOUD", tt.envValue)
|
||||||
|
|
||||||
|
if tt.configContent != "" {
|
||||||
|
configDir := filepath.Join(tmpHome, ".ollama")
|
||||||
|
if err := os.MkdirAll(configDir, 0o755); err != nil {
|
||||||
|
t.Fatalf("mkdir config dir: %v", err)
|
||||||
|
}
|
||||||
|
configPath := filepath.Join(configDir, "server.json")
|
||||||
|
if err := os.WriteFile(configPath, []byte(tt.configContent), 0o644); err != nil {
|
||||||
|
t.Fatalf("write config: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
st := &store.Store{DBPath: filepath.Join(t.TempDir(), "db.sqlite")}
|
||||||
|
defer st.Close()
|
||||||
|
|
||||||
|
s := &Server{store: st}
|
||||||
|
cmd, err := s.cmd(t.Context())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("s.cmd() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
found := false
|
||||||
|
for _, env := range cmd.Env {
|
||||||
|
if env == tt.want {
|
||||||
|
found = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !found {
|
||||||
|
t.Fatalf("expected environment variable %q in command env", tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetInferenceInfo(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
log string
|
log string
|
||||||
exp []InferenceCompute
|
expComputes []InferenceCompute
|
||||||
|
expDefaultCtxLen int
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "metal",
|
name: "metal",
|
||||||
log: `time=2025-06-30T09:23:07.374-07:00 level=DEBUG source=sched.go:108 msg="starting llm scheduler"
|
log: `time=2025-06-30T09:23:07.374-07:00 level=DEBUG source=sched.go:108 msg="starting llm scheduler"
|
||||||
time=2025-06-30T09:23:07.416-07:00 level=INFO source=types.go:130 msg="inference compute" id=0 library=metal variant="" compute="" driver=0.0 name="" total="96.0 GiB" available="96.0 GiB"
|
time=2025-06-30T09:23:07.416-07:00 level=INFO source=types.go:130 msg="inference compute" id=0 library=metal variant="" compute="" driver=0.0 name="" total="96.0 GiB" available="96.0 GiB"
|
||||||
|
time=2025-06-30T09:23:07.417-07:00 level=INFO source=routes.go:1721 msg="vram-based default context" total_vram="96.0 GiB" default_num_ctx=262144
|
||||||
time=2025-06-30T09:25:56.197-07:00 level=DEBUG source=ggml.go:155 msg="key not found" key=general.alignment default=32
|
time=2025-06-30T09:25:56.197-07:00 level=DEBUG source=ggml.go:155 msg="key not found" key=general.alignment default=32
|
||||||
`,
|
`,
|
||||||
exp: []InferenceCompute{{
|
expComputes: []InferenceCompute{{
|
||||||
Library: "metal",
|
Library: "metal",
|
||||||
Driver: "0.0",
|
Driver: "0.0",
|
||||||
VRAM: "96.0 GiB",
|
VRAM: "96.0 GiB",
|
||||||
}},
|
}},
|
||||||
|
expDefaultCtxLen: 262144,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "cpu",
|
name: "cpu",
|
||||||
log: `time=2025-07-01T17:59:51.470Z level=INFO source=gpu.go:377 msg="no compatible GPUs were discovered"
|
log: `time=2025-07-01T17:59:51.470Z level=INFO source=gpu.go:377 msg="no compatible GPUs were discovered"
|
||||||
time=2025-07-01T17:59:51.470Z level=INFO source=types.go:130 msg="inference compute" id=0 library=cpu variant="" compute="" driver=0.0 name="" total="31.3 GiB" available="30.4 GiB"
|
time=2025-07-01T17:59:51.470Z level=INFO source=types.go:130 msg="inference compute" id=0 library=cpu variant="" compute="" driver=0.0 name="" total="31.3 GiB" available="30.4 GiB"
|
||||||
|
time=2025-07-01T17:59:51.471Z level=INFO source=routes.go:1721 msg="vram-based default context" total_vram="31.3 GiB" default_num_ctx=32768
|
||||||
[GIN] 2025/07/01 - 18:00:09 | 200 | 50.263µs | 100.126.204.152 | HEAD "/"
|
[GIN] 2025/07/01 - 18:00:09 | 200 | 50.263µs | 100.126.204.152 | HEAD "/"
|
||||||
`,
|
`,
|
||||||
exp: []InferenceCompute{{
|
expComputes: []InferenceCompute{{
|
||||||
Library: "cpu",
|
Library: "cpu",
|
||||||
Driver: "0.0",
|
Driver: "0.0",
|
||||||
VRAM: "31.3 GiB",
|
VRAM: "31.3 GiB",
|
||||||
}},
|
}},
|
||||||
|
expDefaultCtxLen: 32768,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "cuda1",
|
name: "cuda1",
|
||||||
log: `time=2025-07-01T19:33:43.162Z level=DEBUG source=amd_linux.go:419 msg="amdgpu driver not detected /sys/module/amdgpu"
|
log: `time=2025-07-01T19:33:43.162Z level=DEBUG source=amd_linux.go:419 msg="amdgpu driver not detected /sys/module/amdgpu"
|
||||||
releasing cuda driver library
|
releasing cuda driver library
|
||||||
time=2025-07-01T19:33:43.162Z level=INFO source=types.go:130 msg="inference compute" id=GPU-452cac9f-6960-839c-4fb3-0cec83699196 library=cuda variant=v12 compute=6.1 driver=12.7 name="NVIDIA GeForce GT 1030" total="3.9 GiB" available="3.9 GiB"
|
time=2025-07-01T19:33:43.162Z level=INFO source=types.go:130 msg="inference compute" id=GPU-452cac9f-6960-839c-4fb3-0cec83699196 library=cuda variant=v12 compute=6.1 driver=12.7 name="NVIDIA GeForce GT 1030" total="3.9 GiB" available="3.9 GiB"
|
||||||
|
time=2025-07-01T19:33:43.163Z level=INFO source=routes.go:1721 msg="vram-based default context" total_vram="3.9 GiB" default_num_ctx=4096
|
||||||
[GIN] 2025/07/01 - 18:00:09 | 200 | 50.263µs | 100.126.204.152 | HEAD "/"
|
[GIN] 2025/07/01 - 18:00:09 | 200 | 50.263µs | 100.126.204.152 | HEAD "/"
|
||||||
`,
|
`,
|
||||||
exp: []InferenceCompute{{
|
expComputes: []InferenceCompute{{
|
||||||
Library: "cuda",
|
Library: "cuda",
|
||||||
Variant: "v12",
|
Variant: "v12",
|
||||||
Compute: "6.1",
|
Compute: "6.1",
|
||||||
@@ -181,6 +256,7 @@ time=2025-07-01T19:33:43.162Z level=INFO source=types.go:130 msg="inference comp
|
|||||||
Name: "NVIDIA GeForce GT 1030",
|
Name: "NVIDIA GeForce GT 1030",
|
||||||
VRAM: "3.9 GiB",
|
VRAM: "3.9 GiB",
|
||||||
}},
|
}},
|
||||||
|
expDefaultCtxLen: 4096,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "frank",
|
name: "frank",
|
||||||
@@ -188,9 +264,10 @@ time=2025-07-01T19:33:43.162Z level=INFO source=types.go:130 msg="inference comp
|
|||||||
releasing cuda driver library
|
releasing cuda driver library
|
||||||
time=2025-07-01T19:36:13.315Z level=INFO source=types.go:130 msg="inference compute" id=GPU-d6de3398-9932-6902-11ec-fee8e424c8a2 library=cuda variant=v12 compute=7.5 driver=12.8 name="NVIDIA GeForce RTX 2080 Ti" total="10.6 GiB" available="10.4 GiB"
|
time=2025-07-01T19:36:13.315Z level=INFO source=types.go:130 msg="inference compute" id=GPU-d6de3398-9932-6902-11ec-fee8e424c8a2 library=cuda variant=v12 compute=7.5 driver=12.8 name="NVIDIA GeForce RTX 2080 Ti" total="10.6 GiB" available="10.4 GiB"
|
||||||
time=2025-07-01T19:36:13.315Z level=INFO source=types.go:130 msg="inference compute" id=GPU-9abb57639fa80c50 library=rocm variant="" compute=gfx1030 driver=6.3 name=1002:73bf total="16.0 GiB" available="1.3 GiB"
|
time=2025-07-01T19:36:13.315Z level=INFO source=types.go:130 msg="inference compute" id=GPU-9abb57639fa80c50 library=rocm variant="" compute=gfx1030 driver=6.3 name=1002:73bf total="16.0 GiB" available="1.3 GiB"
|
||||||
|
time=2025-07-01T19:36:13.316Z level=INFO source=routes.go:1721 msg="vram-based default context" total_vram="26.6 GiB" default_num_ctx=32768
|
||||||
[GIN] 2025/07/01 - 18:00:09 | 200 | 50.263µs | 100.126.204.152 | HEAD "/"
|
[GIN] 2025/07/01 - 18:00:09 | 200 | 50.263µs | 100.126.204.152 | HEAD "/"
|
||||||
`,
|
`,
|
||||||
exp: []InferenceCompute{
|
expComputes: []InferenceCompute{
|
||||||
{
|
{
|
||||||
Library: "cuda",
|
Library: "cuda",
|
||||||
Variant: "v12",
|
Variant: "v12",
|
||||||
@@ -207,6 +284,20 @@ time=2025-07-01T19:33:43.162Z level=INFO source=types.go:130 msg="inference comp
|
|||||||
VRAM: "16.0 GiB",
|
VRAM: "16.0 GiB",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
expDefaultCtxLen: 32768,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "missing_default_context",
|
||||||
|
log: `time=2025-06-30T09:23:07.374-07:00 level=DEBUG source=sched.go:108 msg="starting llm scheduler"
|
||||||
|
time=2025-06-30T09:23:07.416-07:00 level=INFO source=types.go:130 msg="inference compute" id=0 library=metal variant="" compute="" driver=0.0 name="" total="96.0 GiB" available="96.0 GiB"
|
||||||
|
time=2025-06-30T09:25:56.197-07:00 level=DEBUG source=ggml.go:155 msg="key not found" key=general.alignment default=32
|
||||||
|
`,
|
||||||
|
expComputes: []InferenceCompute{{
|
||||||
|
Library: "metal",
|
||||||
|
Driver: "0.0",
|
||||||
|
VRAM: "96.0 GiB",
|
||||||
|
}},
|
||||||
|
expDefaultCtxLen: 0, // No default context line, should return 0
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
@@ -219,18 +310,21 @@ time=2025-07-01T19:33:43.162Z level=INFO source=types.go:130 msg="inference comp
|
|||||||
}
|
}
|
||||||
ctx, cancel := context.WithTimeout(t.Context(), 10*time.Millisecond)
|
ctx, cancel := context.WithTimeout(t.Context(), 10*time.Millisecond)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
ics, err := GetInferenceComputer(ctx)
|
info, err := GetInferenceInfo(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf(" failed to get inference compute: %v", err)
|
t.Fatalf("failed to get inference info: %v", err)
|
||||||
}
|
}
|
||||||
if !reflect.DeepEqual(ics, tt.exp) {
|
if !reflect.DeepEqual(info.Computes, tt.expComputes) {
|
||||||
t.Fatalf("got:\n%#v\nwant:\n%#v", ics, tt.exp)
|
t.Fatalf("computes mismatch\ngot:\n%#v\nwant:\n%#v", info.Computes, tt.expComputes)
|
||||||
|
}
|
||||||
|
if info.DefaultContextLength != tt.expDefaultCtxLen {
|
||||||
|
t.Fatalf("default context length mismatch: got %d, want %d", info.DefaultContextLength, tt.expDefaultCtxLen)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestGetInferenceComputerTimeout(t *testing.T) {
|
func TestGetInferenceInfoTimeout(t *testing.T) {
|
||||||
ctx, cancel := context.WithTimeout(t.Context(), 10*time.Millisecond)
|
ctx, cancel := context.WithTimeout(t.Context(), 10*time.Millisecond)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
tmpDir := t.TempDir()
|
tmpDir := t.TempDir()
|
||||||
@@ -239,7 +333,7 @@ func TestGetInferenceComputerTimeout(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to write log file %s: %s", serverLogPath, err)
|
t.Fatalf("failed to write log file %s: %s", serverLogPath, err)
|
||||||
}
|
}
|
||||||
_, err = GetInferenceComputer(ctx)
|
_, err = GetInferenceInfo(ctx)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Fatal("expected timeout")
|
t.Fatal("expected timeout")
|
||||||
}
|
}
|
||||||
|
|||||||
128
app/store/cloud_config.go
Normal file
128
app/store/cloud_config.go
Normal file
@@ -0,0 +1,128 @@
|
|||||||
|
//go:build windows || darwin
|
||||||
|
|
||||||
|
package store
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/envconfig"
|
||||||
|
)
|
||||||
|
|
||||||
|
const serverConfigFilename = "server.json"
|
||||||
|
|
||||||
|
type serverConfig struct {
|
||||||
|
DisableOllamaCloud bool `json:"disable_ollama_cloud,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// CloudDisabled returns whether cloud features should be disabled.
|
||||||
|
// The source of truth is: OLLAMA_NO_CLOUD OR ~/.ollama/server.json:disable_ollama_cloud.
|
||||||
|
func (s *Store) CloudDisabled() (bool, error) {
|
||||||
|
disabled, _, err := s.CloudStatus()
|
||||||
|
return disabled, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// CloudStatus returns whether cloud is disabled and the source of that decision.
|
||||||
|
// Source is one of: "none", "env", "config", "both".
|
||||||
|
func (s *Store) CloudStatus() (bool, string, error) {
|
||||||
|
if err := s.ensureDB(); err != nil {
|
||||||
|
return false, "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
configDisabled, err := readServerConfigCloudDisabled()
|
||||||
|
if err != nil {
|
||||||
|
return false, "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
envDisabled := envconfig.NoCloudEnv()
|
||||||
|
return envDisabled || configDisabled, cloudStatusSource(envDisabled, configDisabled), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetCloudEnabled writes the cloud setting to ~/.ollama/server.json.
|
||||||
|
func (s *Store) SetCloudEnabled(enabled bool) error {
|
||||||
|
if err := s.ensureDB(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return setCloudEnabled(enabled)
|
||||||
|
}
|
||||||
|
|
||||||
|
func setCloudEnabled(enabled bool) error {
|
||||||
|
configPath, err := serverConfigPath()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := os.MkdirAll(filepath.Dir(configPath), 0o755); err != nil {
|
||||||
|
return fmt.Errorf("create server config directory: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
configMap := map[string]any{}
|
||||||
|
if data, err := os.ReadFile(configPath); err == nil {
|
||||||
|
if err := json.Unmarshal(data, &configMap); err != nil {
|
||||||
|
// If the existing file is invalid JSON, overwrite with a fresh object.
|
||||||
|
configMap = map[string]any{}
|
||||||
|
}
|
||||||
|
} else if !errors.Is(err, os.ErrNotExist) {
|
||||||
|
return fmt.Errorf("read server config: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
configMap["disable_ollama_cloud"] = !enabled
|
||||||
|
|
||||||
|
data, err := json.MarshalIndent(configMap, "", " ")
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("marshal server config: %w", err)
|
||||||
|
}
|
||||||
|
data = append(data, '\n')
|
||||||
|
|
||||||
|
if err := os.WriteFile(configPath, data, 0o644); err != nil {
|
||||||
|
return fmt.Errorf("write server config: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func readServerConfigCloudDisabled() (bool, error) {
|
||||||
|
configPath, err := serverConfigPath()
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := os.ReadFile(configPath)
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, os.ErrNotExist) {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
return false, fmt.Errorf("read server config: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var cfg serverConfig
|
||||||
|
// Invalid or unexpected JSON should not block startup; treat as default.
|
||||||
|
if json.Unmarshal(data, &cfg) == nil {
|
||||||
|
return cfg.DisableOllamaCloud, nil
|
||||||
|
}
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func serverConfigPath() (string, error) {
|
||||||
|
home, err := os.UserHomeDir()
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("resolve home directory: %w", err)
|
||||||
|
}
|
||||||
|
return filepath.Join(home, ".ollama", serverConfigFilename), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// cloudStatusSource names the mechanism(s) that disabled cloud: "env" for the
// OLLAMA_NO_CLOUD environment variable, "config" for server.json, "both" when
// the two agree, and "none" when cloud is not disabled at all.
func cloudStatusSource(envDisabled bool, configDisabled bool) string {
	if envDisabled && configDisabled {
		return "both"
	}
	if envDisabled {
		return "env"
	}
	if configDisabled {
		return "config"
	}
	return "none"
}
|
||||||
130
app/store/cloud_config_test.go
Normal file
130
app/store/cloud_config_test.go
Normal file
@@ -0,0 +1,130 @@
|
|||||||
|
//go:build windows || darwin
|
||||||
|
|
||||||
|
package store
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestCloudDisabled covers the four ways cloud can be toggled: not at all,
// via the OLLAMA_NO_CLOUD environment variable, via server.json, and with an
// unparseable config file (which must be ignored). Each case also verifies
// the source string reported by CloudStatus.
func TestCloudDisabled(t *testing.T) {
	tests := []struct {
		name          string
		envValue      string // value set for OLLAMA_NO_CLOUD
		configContent string // raw JSON for ~/.ollama/server.json; "" writes no file
		wantDisabled  bool
		wantSource    string // expected CloudStatus source: "none", "env", or "config"
	}{
		{
			name:         "default enabled",
			wantDisabled: false,
			wantSource:   "none",
		},
		{
			name:         "env disables cloud",
			envValue:     "1",
			wantDisabled: true,
			wantSource:   "env",
		},
		{
			name:          "config disables cloud",
			configContent: `{"disable_ollama_cloud": true}`,
			wantDisabled:  true,
			wantSource:    "config",
		},
		{
			// Env wins even when the config explicitly re-enables cloud.
			name:          "env and config",
			envValue:      "1",
			configContent: `{"disable_ollama_cloud": false}`,
			wantDisabled:  true,
			wantSource:    "env",
		},
		{
			name:          "invalid config is ignored",
			configContent: `{bad`,
			wantDisabled:  false,
			wantSource:    "none",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// Sandbox the home directory so the test never touches the
			// real ~/.ollama/server.json.
			tmpHome := t.TempDir()
			setTestHome(t, tmpHome)
			t.Setenv("OLLAMA_NO_CLOUD", tt.envValue)

			if tt.configContent != "" {
				configDir := filepath.Join(tmpHome, ".ollama")
				if err := os.MkdirAll(configDir, 0o755); err != nil {
					t.Fatalf("mkdir config dir: %v", err)
				}
				configPath := filepath.Join(configDir, serverConfigFilename)
				if err := os.WriteFile(configPath, []byte(tt.configContent), 0o644); err != nil {
					t.Fatalf("write config: %v", err)
				}
			}

			s := &Store{DBPath: filepath.Join(tmpHome, "db.sqlite")}
			defer s.Close()

			// CloudDisabled and CloudStatus must agree on the disabled bit.
			disabled, err := s.CloudDisabled()
			if err != nil {
				t.Fatalf("CloudDisabled() error = %v", err)
			}
			if disabled != tt.wantDisabled {
				t.Fatalf("CloudDisabled() = %v, want %v", disabled, tt.wantDisabled)
			}

			statusDisabled, source, err := s.CloudStatus()
			if err != nil {
				t.Fatalf("CloudStatus() error = %v", err)
			}
			if statusDisabled != tt.wantDisabled {
				t.Fatalf("CloudStatus() disabled = %v, want %v", statusDisabled, tt.wantDisabled)
			}
			if source != tt.wantSource {
				t.Fatalf("CloudStatus() source = %v, want %v", source, tt.wantSource)
			}
		})
	}
}
|
||||||
|
|
||||||
|
// TestSetCloudEnabled verifies that SetCloudEnabled rewrites
// ~/.ollama/server.json in place: the cloud flag is flipped while keys the
// app does not own are preserved.
func TestSetCloudEnabled(t *testing.T) {
	tmpHome := t.TempDir()
	setTestHome(t, tmpHome)

	// Seed a config that disables cloud and carries an unrelated key.
	configDir := filepath.Join(tmpHome, ".ollama")
	if err := os.MkdirAll(configDir, 0o755); err != nil {
		t.Fatalf("mkdir config dir: %v", err)
	}
	configPath := filepath.Join(configDir, serverConfigFilename)
	if err := os.WriteFile(configPath, []byte(`{"another_key":"value","disable_ollama_cloud":true}`), 0o644); err != nil {
		t.Fatalf("seed config: %v", err)
	}

	s := &Store{DBPath: filepath.Join(tmpHome, "db.sqlite")}
	defer s.Close()

	if err := s.SetCloudEnabled(true); err != nil {
		t.Fatalf("SetCloudEnabled(true) error = %v", err)
	}

	data, err := os.ReadFile(configPath)
	if err != nil {
		t.Fatalf("read config: %v", err)
	}

	var got map[string]any
	if err := json.Unmarshal(data, &got); err != nil {
		t.Fatalf("unmarshal config: %v", err)
	}

	// Enabling cloud stores the inverse flag as false.
	if got["disable_ollama_cloud"] != false {
		t.Fatalf("disable_ollama_cloud = %v, want false", got["disable_ollama_cloud"])
	}
	// The unrelated key must survive the rewrite untouched.
	if got["another_key"] != "value" {
		t.Fatalf("another_key = %v, want value", got["another_key"])
	}
}
|
||||||
@@ -9,12 +9,12 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
sqlite3 "github.com/mattn/go-sqlite3"
|
_ "github.com/mattn/go-sqlite3"
|
||||||
)
|
)
|
||||||
|
|
||||||
// currentSchemaVersion defines the current database schema version.
|
// currentSchemaVersion defines the current database schema version.
|
||||||
// Increment this when making schema changes that require migrations.
|
// Increment this when making schema changes that require migrations.
|
||||||
const currentSchemaVersion = 12
|
const currentSchemaVersion = 15
|
||||||
|
|
||||||
// database wraps the SQLite connection.
|
// database wraps the SQLite connection.
|
||||||
// SQLite handles its own locking for concurrent access:
|
// SQLite handles its own locking for concurrent access:
|
||||||
@@ -73,7 +73,7 @@ func (db *database) init() error {
|
|||||||
agent BOOLEAN NOT NULL DEFAULT 0,
|
agent BOOLEAN NOT NULL DEFAULT 0,
|
||||||
tools BOOLEAN NOT NULL DEFAULT 0,
|
tools BOOLEAN NOT NULL DEFAULT 0,
|
||||||
working_dir TEXT NOT NULL DEFAULT '',
|
working_dir TEXT NOT NULL DEFAULT '',
|
||||||
context_length INTEGER NOT NULL DEFAULT 4096,
|
context_length INTEGER NOT NULL DEFAULT 0,
|
||||||
window_width INTEGER NOT NULL DEFAULT 0,
|
window_width INTEGER NOT NULL DEFAULT 0,
|
||||||
window_height INTEGER NOT NULL DEFAULT 0,
|
window_height INTEGER NOT NULL DEFAULT 0,
|
||||||
config_migrated BOOLEAN NOT NULL DEFAULT 0,
|
config_migrated BOOLEAN NOT NULL DEFAULT 0,
|
||||||
@@ -84,7 +84,9 @@ func (db *database) init() error {
|
|||||||
sidebar_open BOOLEAN NOT NULL DEFAULT 0,
|
sidebar_open BOOLEAN NOT NULL DEFAULT 0,
|
||||||
think_enabled BOOLEAN NOT NULL DEFAULT 0,
|
think_enabled BOOLEAN NOT NULL DEFAULT 0,
|
||||||
think_level TEXT NOT NULL DEFAULT '',
|
think_level TEXT NOT NULL DEFAULT '',
|
||||||
|
cloud_setting_migrated BOOLEAN NOT NULL DEFAULT 0,
|
||||||
remote TEXT NOT NULL DEFAULT '', -- deprecated
|
remote TEXT NOT NULL DEFAULT '', -- deprecated
|
||||||
|
auto_update_enabled BOOLEAN NOT NULL DEFAULT 1,
|
||||||
schema_version INTEGER NOT NULL DEFAULT %d
|
schema_version INTEGER NOT NULL DEFAULT %d
|
||||||
);
|
);
|
||||||
|
|
||||||
@@ -244,6 +246,24 @@ func (db *database) migrate() error {
|
|||||||
return fmt.Errorf("migrate v11 to v12: %w", err)
|
return fmt.Errorf("migrate v11 to v12: %w", err)
|
||||||
}
|
}
|
||||||
version = 12
|
version = 12
|
||||||
|
case 12:
|
||||||
|
// add cloud_setting_migrated column to settings table
|
||||||
|
if err := db.migrateV12ToV13(); err != nil {
|
||||||
|
return fmt.Errorf("migrate v12 to v13: %w", err)
|
||||||
|
}
|
||||||
|
version = 13
|
||||||
|
case 13:
|
||||||
|
// change default context_length from 4096 to 0 (VRAM-based tiered defaults)
|
||||||
|
if err := db.migrateV13ToV14(); err != nil {
|
||||||
|
return fmt.Errorf("migrate v13 to v14: %w", err)
|
||||||
|
}
|
||||||
|
version = 14
|
||||||
|
case 14:
|
||||||
|
// add auto_update_enabled column to settings table
|
||||||
|
if err := db.migrateV14ToV15(); err != nil {
|
||||||
|
return fmt.Errorf("migrate v14 to v15: %w", err)
|
||||||
|
}
|
||||||
|
version = 15
|
||||||
default:
|
default:
|
||||||
// If we have a version we don't recognize, just set it to current
|
// If we have a version we don't recognize, just set it to current
|
||||||
// This might happen during development
|
// This might happen during development
|
||||||
@@ -452,6 +472,52 @@ func (db *database) migrateV11ToV12() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// migrateV12ToV13 adds cloud_setting_migrated to settings.
|
||||||
|
func (db *database) migrateV12ToV13() error {
|
||||||
|
_, err := db.conn.Exec(`ALTER TABLE settings ADD COLUMN cloud_setting_migrated BOOLEAN NOT NULL DEFAULT 0`)
|
||||||
|
if err != nil && !duplicateColumnError(err) {
|
||||||
|
return fmt.Errorf("add cloud_setting_migrated column: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = db.conn.Exec(`UPDATE settings SET schema_version = 13`)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("update schema version: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// migrateV13ToV14 changes the default context_length from 4096 to 0.
|
||||||
|
// When context_length is 0, the ollama server uses VRAM-based tiered defaults.
|
||||||
|
func (db *database) migrateV13ToV14() error {
|
||||||
|
_, err := db.conn.Exec(`UPDATE settings SET context_length = 0 WHERE context_length = 4096`)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("update context_length default: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = db.conn.Exec(`UPDATE settings SET schema_version = 14`)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("update schema version: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// migrateV14ToV15 adds the auto_update_enabled column to the settings table
|
||||||
|
func (db *database) migrateV14ToV15() error {
|
||||||
|
_, err := db.conn.Exec(`ALTER TABLE settings ADD COLUMN auto_update_enabled BOOLEAN NOT NULL DEFAULT 1`)
|
||||||
|
if err != nil && !duplicateColumnError(err) {
|
||||||
|
return fmt.Errorf("add auto_update_enabled column: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = db.conn.Exec(`UPDATE settings SET schema_version = 15`)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("update schema version: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// cleanupOrphanedData removes orphaned records that may exist due to the foreign key bug
|
// cleanupOrphanedData removes orphaned records that may exist due to the foreign key bug
|
||||||
func (db *database) cleanupOrphanedData() error {
|
func (db *database) cleanupOrphanedData() error {
|
||||||
_, err := db.conn.Exec(`
|
_, err := db.conn.Exec(`
|
||||||
@@ -482,19 +548,11 @@ func (db *database) cleanupOrphanedData() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func duplicateColumnError(err error) bool {
|
func duplicateColumnError(err error) bool {
|
||||||
if sqlite3Err, ok := err.(sqlite3.Error); ok {
|
return err != nil && strings.Contains(err.Error(), "duplicate column name")
|
||||||
return sqlite3Err.Code == sqlite3.ErrError &&
|
|
||||||
strings.Contains(sqlite3Err.Error(), "duplicate column name")
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func columnNotExists(err error) bool {
|
func columnNotExists(err error) bool {
|
||||||
if sqlite3Err, ok := err.(sqlite3.Error); ok {
|
return err != nil && strings.Contains(err.Error(), "no such column")
|
||||||
return sqlite3Err.Code == sqlite3.ErrError &&
|
|
||||||
strings.Contains(sqlite3Err.Error(), "no such column")
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *database) getAllChats() ([]Chat, error) {
|
func (db *database) getAllChats() ([]Chat, error) {
|
||||||
@@ -1108,9 +1166,9 @@ func (db *database) getSettings() (Settings, error) {
|
|||||||
var s Settings
|
var s Settings
|
||||||
|
|
||||||
err := db.conn.QueryRow(`
|
err := db.conn.QueryRow(`
|
||||||
SELECT expose, survey, browser, models, agent, tools, working_dir, context_length, airplane_mode, turbo_enabled, websearch_enabled, selected_model, sidebar_open, think_enabled, think_level
|
SELECT expose, survey, browser, models, agent, tools, working_dir, context_length, turbo_enabled, websearch_enabled, selected_model, sidebar_open, think_enabled, think_level, auto_update_enabled
|
||||||
FROM settings
|
FROM settings
|
||||||
`).Scan(&s.Expose, &s.Survey, &s.Browser, &s.Models, &s.Agent, &s.Tools, &s.WorkingDir, &s.ContextLength, &s.AirplaneMode, &s.TurboEnabled, &s.WebSearchEnabled, &s.SelectedModel, &s.SidebarOpen, &s.ThinkEnabled, &s.ThinkLevel)
|
`).Scan(&s.Expose, &s.Survey, &s.Browser, &s.Models, &s.Agent, &s.Tools, &s.WorkingDir, &s.ContextLength, &s.TurboEnabled, &s.WebSearchEnabled, &s.SelectedModel, &s.SidebarOpen, &s.ThinkEnabled, &s.ThinkLevel, &s.AutoUpdateEnabled)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return Settings{}, fmt.Errorf("get settings: %w", err)
|
return Settings{}, fmt.Errorf("get settings: %w", err)
|
||||||
}
|
}
|
||||||
@@ -1121,14 +1179,40 @@ func (db *database) getSettings() (Settings, error) {
|
|||||||
func (db *database) setSettings(s Settings) error {
|
func (db *database) setSettings(s Settings) error {
|
||||||
_, err := db.conn.Exec(`
|
_, err := db.conn.Exec(`
|
||||||
UPDATE settings
|
UPDATE settings
|
||||||
SET expose = ?, survey = ?, browser = ?, models = ?, agent = ?, tools = ?, working_dir = ?, context_length = ?, airplane_mode = ?, turbo_enabled = ?, websearch_enabled = ?, selected_model = ?, sidebar_open = ?, think_enabled = ?, think_level = ?
|
SET expose = ?, survey = ?, browser = ?, models = ?, agent = ?, tools = ?, working_dir = ?, context_length = ?, turbo_enabled = ?, websearch_enabled = ?, selected_model = ?, sidebar_open = ?, think_enabled = ?, think_level = ?, auto_update_enabled = ?
|
||||||
`, s.Expose, s.Survey, s.Browser, s.Models, s.Agent, s.Tools, s.WorkingDir, s.ContextLength, s.AirplaneMode, s.TurboEnabled, s.WebSearchEnabled, s.SelectedModel, s.SidebarOpen, s.ThinkEnabled, s.ThinkLevel)
|
`, s.Expose, s.Survey, s.Browser, s.Models, s.Agent, s.Tools, s.WorkingDir, s.ContextLength, s.TurboEnabled, s.WebSearchEnabled, s.SelectedModel, s.SidebarOpen, s.ThinkEnabled, s.ThinkLevel, s.AutoUpdateEnabled)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("set settings: %w", err)
|
return fmt.Errorf("set settings: %w", err)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (db *database) isCloudSettingMigrated() (bool, error) {
|
||||||
|
var migrated bool
|
||||||
|
err := db.conn.QueryRow("SELECT cloud_setting_migrated FROM settings").Scan(&migrated)
|
||||||
|
if err != nil {
|
||||||
|
return false, fmt.Errorf("get cloud setting migration status: %w", err)
|
||||||
|
}
|
||||||
|
return migrated, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (db *database) setCloudSettingMigrated(migrated bool) error {
|
||||||
|
_, err := db.conn.Exec("UPDATE settings SET cloud_setting_migrated = ?", migrated)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("set cloud setting migration status: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (db *database) getAirplaneMode() (bool, error) {
|
||||||
|
var airplaneMode bool
|
||||||
|
err := db.conn.QueryRow("SELECT airplane_mode FROM settings").Scan(&airplaneMode)
|
||||||
|
if err != nil {
|
||||||
|
return false, fmt.Errorf("get airplane_mode: %w", err)
|
||||||
|
}
|
||||||
|
return airplaneMode, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (db *database) getWindowSize() (int, int, error) {
|
func (db *database) getWindowSize() (int, int, error) {
|
||||||
var width, height int
|
var width, height int
|
||||||
err := db.conn.QueryRow("SELECT window_width, window_height FROM settings").Scan(&width, &height)
|
err := db.conn.QueryRow("SELECT window_width, window_height FROM settings").Scan(&width, &height)
|
||||||
|
|||||||
@@ -98,6 +98,43 @@ func TestSchemaMigrations(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestMigrationV13ToV14ContextLength seeds a v13 settings row holding the old
// 4096 context-length default, runs migrate(), and verifies the value is
// rewritten to 0 and the schema version advances to the current one.
func TestMigrationV13ToV14ContextLength(t *testing.T) {
	tmpDir := t.TempDir()
	dbPath := filepath.Join(tmpDir, "test.db")

	db, err := newDatabase(dbPath)
	if err != nil {
		t.Fatalf("failed to create database: %v", err)
	}
	defer db.Close()

	// Force the pre-migration state: old default plus schema_version 13.
	_, err = db.conn.Exec("UPDATE settings SET context_length = 4096, schema_version = 13")
	if err != nil {
		t.Fatalf("failed to seed v13 settings row: %v", err)
	}

	if err := db.migrate(); err != nil {
		t.Fatalf("migration from v13 to v14 failed: %v", err)
	}

	var contextLength int
	if err := db.conn.QueryRow("SELECT context_length FROM settings").Scan(&contextLength); err != nil {
		t.Fatalf("failed to read context_length: %v", err)
	}

	// 0 means "use VRAM-based tiered defaults" in the new scheme.
	if contextLength != 0 {
		t.Fatalf("expected context_length to migrate to 0, got %d", contextLength)
	}

	// migrate() runs every remaining step, so the version must reach current.
	version, err := db.getSchemaVersion()
	if err != nil {
		t.Fatalf("failed to get schema version: %v", err)
	}
	if version != currentSchemaVersion {
		t.Fatalf("expected schema version %d, got %d", currentSchemaVersion, version)
	}
}
|
||||||
|
|
||||||
func TestChatDeletionWithCascade(t *testing.T) {
|
func TestChatDeletionWithCascade(t *testing.T) {
|
||||||
t.Run("chat deletion cascades to related messages", func(t *testing.T) {
|
t.Run("chat deletion cascades to related messages", func(t *testing.T) {
|
||||||
tmpDir := t.TempDir()
|
tmpDir := t.TempDir()
|
||||||
|
|||||||
@@ -127,6 +127,65 @@ func TestNoConfigToMigrate(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestCloudMigrationFromAirplaneMode verifies the one-time migration of the
// legacy airplane_mode flag: opening the store must write
// disable_ollama_cloud=true into server.json, leave the legacy DB value
// untouched, and mark the migration as done so it never runs again.
func TestCloudMigrationFromAirplaneMode(t *testing.T) {
	tmpHome := t.TempDir()
	setTestHome(t, tmpHome)
	// Clear the env override so the result reflects the config file only.
	t.Setenv("OLLAMA_NO_CLOUD", "")

	// Seed a database with airplane mode on and the migration not yet run.
	dbPath := filepath.Join(tmpHome, "db.sqlite")
	db, err := newDatabase(dbPath)
	if err != nil {
		t.Fatalf("failed to create database: %v", err)
	}

	if _, err := db.conn.Exec("UPDATE settings SET airplane_mode = 1, cloud_setting_migrated = 0"); err != nil {
		db.Close()
		t.Fatalf("failed to seed airplane migration state: %v", err)
	}
	db.Close()

	s := Store{DBPath: dbPath}
	defer s.Close()

	// Trigger DB initialization + one-time cloud migration.
	if _, err := s.ID(); err != nil {
		t.Fatalf("failed to initialize store: %v", err)
	}

	disabled, err := s.CloudDisabled()
	if err != nil {
		t.Fatalf("CloudDisabled() error: %v", err)
	}
	if !disabled {
		t.Fatal("expected cloud to be disabled after migrating airplane_mode=true")
	}

	// The migration must have materialized server.json with the flag set.
	configPath := filepath.Join(tmpHome, ".ollama", serverConfigFilename)
	data, err := os.ReadFile(configPath)
	if err != nil {
		t.Fatalf("failed to read migrated server config: %v", err)
	}

	var cfg map[string]any
	if err := json.Unmarshal(data, &cfg); err != nil {
		t.Fatalf("failed to parse migrated server config: %v", err)
	}
	if cfg["disable_ollama_cloud"] != true {
		t.Fatalf("disable_ollama_cloud = %v, want true", cfg["disable_ollama_cloud"])
	}

	// The legacy column stays as-is; only the migrated marker flips.
	var airplaneMode, migrated bool
	if err := s.db.conn.QueryRow("SELECT airplane_mode, cloud_setting_migrated FROM settings").Scan(&airplaneMode, &migrated); err != nil {
		t.Fatalf("failed to read migration flags from DB: %v", err)
	}
	if !airplaneMode {
		t.Fatal("expected legacy airplane_mode value to remain unchanged")
	}
	if !migrated {
		t.Fatal("expected cloud_setting_migrated to be true")
	}
}
|
||||||
|
|
||||||
const (
|
const (
|
||||||
v1Schema = `
|
v1Schema = `
|
||||||
CREATE TABLE IF NOT EXISTS settings (
|
CREATE TABLE IF NOT EXISTS settings (
|
||||||
|
|||||||
@@ -149,9 +149,6 @@ type Settings struct {
|
|||||||
// ContextLength specifies the context length for the ollama server (using OLLAMA_CONTEXT_LENGTH)
|
// ContextLength specifies the context length for the ollama server (using OLLAMA_CONTEXT_LENGTH)
|
||||||
ContextLength int
|
ContextLength int
|
||||||
|
|
||||||
// AirplaneMode when true, turns off Ollama Turbo features and only uses local models
|
|
||||||
AirplaneMode bool
|
|
||||||
|
|
||||||
// TurboEnabled indicates if Ollama Turbo features are enabled
|
// TurboEnabled indicates if Ollama Turbo features are enabled
|
||||||
TurboEnabled bool
|
TurboEnabled bool
|
||||||
|
|
||||||
@@ -169,6 +166,9 @@ type Settings struct {
|
|||||||
|
|
||||||
// SidebarOpen indicates if the chat sidebar is open
|
// SidebarOpen indicates if the chat sidebar is open
|
||||||
SidebarOpen bool
|
SidebarOpen bool
|
||||||
|
|
||||||
|
// AutoUpdateEnabled indicates if automatic updates should be downloaded
|
||||||
|
AutoUpdateEnabled bool
|
||||||
}
|
}
|
||||||
|
|
||||||
type Store struct {
|
type Store struct {
|
||||||
@@ -259,6 +259,40 @@ func (s *Store) ensureDB() error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Run one-time migration from legacy airplane_mode behavior.
|
||||||
|
if err := s.migrateCloudSetting(database); err != nil {
|
||||||
|
return fmt.Errorf("migrate cloud setting: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// migrateCloudSetting migrates legacy airplane_mode into server.json exactly once.
|
||||||
|
// After this, cloud state is sourced from server.json OR OLLAMA_NO_CLOUD.
|
||||||
|
func (s *Store) migrateCloudSetting(database *database) error {
|
||||||
|
migrated, err := database.isCloudSettingMigrated()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if migrated {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
airplaneMode, err := database.getAirplaneMode()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if airplaneMode {
|
||||||
|
if err := setCloudEnabled(false); err != nil {
|
||||||
|
return fmt.Errorf("migrate airplane_mode to cloud disabled: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := database.setCloudSettingMigrated(true); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
11
app/store/test_home_test.go
Normal file
11
app/store/test_home_test.go
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
//go:build windows || darwin
|
||||||
|
|
||||||
|
package store
|
||||||
|
|
||||||
|
import "testing"
|
||||||
|
|
||||||
|
func setTestHome(t *testing.T, home string) {
|
||||||
|
t.Helper()
|
||||||
|
t.Setenv("HOME", home)
|
||||||
|
t.Setenv("USERPROFILE", home)
|
||||||
|
}
|
||||||
2
app/store/testdata/schema.sql
vendored
2
app/store/testdata/schema.sql
vendored
@@ -13,7 +13,7 @@ CREATE TABLE IF NOT EXISTS settings (
|
|||||||
agent BOOLEAN NOT NULL DEFAULT 0,
|
agent BOOLEAN NOT NULL DEFAULT 0,
|
||||||
tools BOOLEAN NOT NULL DEFAULT 0,
|
tools BOOLEAN NOT NULL DEFAULT 0,
|
||||||
working_dir TEXT NOT NULL DEFAULT '',
|
working_dir TEXT NOT NULL DEFAULT '',
|
||||||
context_length INTEGER NOT NULL DEFAULT 4096,
|
context_length INTEGER NOT NULL DEFAULT 0,
|
||||||
window_width INTEGER NOT NULL DEFAULT 0,
|
window_width INTEGER NOT NULL DEFAULT 0,
|
||||||
window_height INTEGER NOT NULL DEFAULT 0,
|
window_height INTEGER NOT NULL DEFAULT 0,
|
||||||
config_migrated BOOLEAN NOT NULL DEFAULT 0,
|
config_migrated BOOLEAN NOT NULL DEFAULT 0,
|
||||||
|
|||||||
35
app/tools/cloud_policy.go
Normal file
35
app/tools/cloud_policy.go
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
//go:build windows || darwin
|
||||||
|
|
||||||
|
package tools
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/api"
|
||||||
|
internalcloud "github.com/ollama/ollama/internal/cloud"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ensureCloudEnabledForTool checks cloud policy from the connected Ollama server.
|
||||||
|
// If policy cannot be determined, this fails closed and blocks the operation.
|
||||||
|
func ensureCloudEnabledForTool(ctx context.Context, operation string) error {
|
||||||
|
// Reuse shared message formatting; policy evaluation is still done via
|
||||||
|
// the connected server's /api/status endpoint below.
|
||||||
|
disabledMessage := internalcloud.DisabledError(operation)
|
||||||
|
|
||||||
|
client, err := api.ClientFromEnvironment()
|
||||||
|
if err != nil {
|
||||||
|
return errors.New(disabledMessage + " (unable to verify server cloud policy)")
|
||||||
|
}
|
||||||
|
|
||||||
|
status, err := client.CloudStatusExperimental(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return errors.New(disabledMessage + " (unable to verify server cloud policy)")
|
||||||
|
}
|
||||||
|
|
||||||
|
if status.Cloud.Disabled {
|
||||||
|
return errors.New(disabledMessage)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
73
app/tools/cloud_policy_test.go
Normal file
73
app/tools/cloud_policy_test.go
Normal file
@@ -0,0 +1,73 @@
|
|||||||
|
//go:build windows || darwin
|
||||||
|
|
||||||
|
package tools
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestEnsureCloudEnabledForTool stubs the Ollama server's /api/status
// endpoint and checks the three policy outcomes: cloud enabled (allowed),
// cloud disabled (blocked), and status unavailable (fail closed).
func TestEnsureCloudEnabledForTool(t *testing.T) {
	const op = "web search is unavailable"
	const disabledPrefix = "ollama cloud is disabled: web search is unavailable"

	t.Run("enabled allows tool execution", func(t *testing.T) {
		ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			if r.URL.Path != "/api/status" {
				http.NotFound(w, r)
				return
			}
			w.Header().Set("Content-Type", "application/json")
			_, _ = w.Write([]byte(`{"cloud":{"disabled":false,"source":"none"}}`))
		}))
		t.Cleanup(ts.Close)
		// Point the api client at the stub server.
		t.Setenv("OLLAMA_HOST", ts.URL)

		if err := ensureCloudEnabledForTool(context.Background(), op); err != nil {
			t.Fatalf("expected nil error, got %v", err)
		}
	})

	t.Run("disabled blocks tool execution", func(t *testing.T) {
		ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			if r.URL.Path != "/api/status" {
				http.NotFound(w, r)
				return
			}
			w.Header().Set("Content-Type", "application/json")
			_, _ = w.Write([]byte(`{"cloud":{"disabled":true,"source":"config"}}`))
		}))
		t.Cleanup(ts.Close)
		t.Setenv("OLLAMA_HOST", ts.URL)

		err := ensureCloudEnabledForTool(context.Background(), op)
		if err == nil {
			t.Fatal("expected error, got nil")
		}
		// The disabled case carries the bare message with no verification note.
		if got := err.Error(); got != disabledPrefix {
			t.Fatalf("unexpected error: %q", got)
		}
	})

	t.Run("status unavailable fails closed", func(t *testing.T) {
		// A server with no /api/status route simulates an unreachable policy.
		ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			http.NotFound(w, r)
		}))
		t.Cleanup(ts.Close)
		t.Setenv("OLLAMA_HOST", ts.URL)

		err := ensureCloudEnabledForTool(context.Background(), op)
		if err == nil {
			t.Fatal("expected error, got nil")
		}
		if got := err.Error(); !strings.Contains(got, disabledPrefix) {
			t.Fatalf("expected disabled prefix, got %q", got)
		}
		if got := err.Error(); !strings.Contains(got, "unable to verify server cloud policy") {
			t.Fatalf("expected verification failure detail, got %q", got)
		}
	})
}
|
||||||
@@ -77,6 +77,10 @@ func (w *WebFetch) Execute(ctx context.Context, args map[string]any) (any, strin
|
|||||||
}
|
}
|
||||||
|
|
||||||
func performWebFetch(ctx context.Context, targetURL string) (*FetchResponse, error) {
|
func performWebFetch(ctx context.Context, targetURL string) (*FetchResponse, error) {
|
||||||
|
if err := ensureCloudEnabledForTool(ctx, "web fetch is unavailable"); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
reqBody := FetchRequest{URL: targetURL}
|
reqBody := FetchRequest{URL: targetURL}
|
||||||
jsonBody, err := json.Marshal(reqBody)
|
jsonBody, err := json.Marshal(reqBody)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -93,6 +93,10 @@ func (w *WebSearch) Execute(ctx context.Context, args map[string]any) (any, stri
|
|||||||
}
|
}
|
||||||
|
|
||||||
func performWebSearch(ctx context.Context, query string, maxResults int) (*SearchResponse, error) {
|
func performWebSearch(ctx context.Context, query string, maxResults int) (*SearchResponse, error) {
|
||||||
|
if err := ensureCloudEnabledForTool(ctx, "web search is unavailable"); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
reqBody := SearchRequest{Query: query, MaxResults: maxResults}
|
reqBody := SearchRequest{Query: query, MaxResults: maxResults}
|
||||||
|
|
||||||
jsonBody, err := json.Marshal(reqBody)
|
jsonBody, err := json.Marshal(reqBody)
|
||||||
|
|||||||
@@ -289,10 +289,12 @@ export class InferenceCompute {
|
|||||||
}
|
}
|
||||||
export class InferenceComputeResponse {
|
export class InferenceComputeResponse {
|
||||||
inferenceComputes: InferenceCompute[];
|
inferenceComputes: InferenceCompute[];
|
||||||
|
defaultContextLength: number;
|
||||||
|
|
||||||
constructor(source: any = {}) {
|
constructor(source: any = {}) {
|
||||||
if ('string' === typeof source) source = JSON.parse(source);
|
if ('string' === typeof source) source = JSON.parse(source);
|
||||||
this.inferenceComputes = this.convertValues(source["inferenceComputes"], InferenceCompute);
|
this.inferenceComputes = this.convertValues(source["inferenceComputes"], InferenceCompute);
|
||||||
|
this.defaultContextLength = source["defaultContextLength"];
|
||||||
}
|
}
|
||||||
|
|
||||||
convertValues(a: any, classs: any, asMap: boolean = false): any {
|
convertValues(a: any, classs: any, asMap: boolean = false): any {
|
||||||
@@ -406,13 +408,13 @@ export class Settings {
|
|||||||
Tools: boolean;
|
Tools: boolean;
|
||||||
WorkingDir: string;
|
WorkingDir: string;
|
||||||
ContextLength: number;
|
ContextLength: number;
|
||||||
AirplaneMode: boolean;
|
|
||||||
TurboEnabled: boolean;
|
TurboEnabled: boolean;
|
||||||
WebSearchEnabled: boolean;
|
WebSearchEnabled: boolean;
|
||||||
ThinkEnabled: boolean;
|
ThinkEnabled: boolean;
|
||||||
ThinkLevel: string;
|
ThinkLevel: string;
|
||||||
SelectedModel: string;
|
SelectedModel: string;
|
||||||
SidebarOpen: boolean;
|
SidebarOpen: boolean;
|
||||||
|
AutoUpdateEnabled: boolean;
|
||||||
|
|
||||||
constructor(source: any = {}) {
|
constructor(source: any = {}) {
|
||||||
if ('string' === typeof source) source = JSON.parse(source);
|
if ('string' === typeof source) source = JSON.parse(source);
|
||||||
@@ -424,13 +426,13 @@ export class Settings {
|
|||||||
this.Tools = source["Tools"];
|
this.Tools = source["Tools"];
|
||||||
this.WorkingDir = source["WorkingDir"];
|
this.WorkingDir = source["WorkingDir"];
|
||||||
this.ContextLength = source["ContextLength"];
|
this.ContextLength = source["ContextLength"];
|
||||||
this.AirplaneMode = source["AirplaneMode"];
|
|
||||||
this.TurboEnabled = source["TurboEnabled"];
|
this.TurboEnabled = source["TurboEnabled"];
|
||||||
this.WebSearchEnabled = source["WebSearchEnabled"];
|
this.WebSearchEnabled = source["WebSearchEnabled"];
|
||||||
this.ThinkEnabled = source["ThinkEnabled"];
|
this.ThinkEnabled = source["ThinkEnabled"];
|
||||||
this.ThinkLevel = source["ThinkLevel"];
|
this.ThinkLevel = source["ThinkLevel"];
|
||||||
this.SelectedModel = source["SelectedModel"];
|
this.SelectedModel = source["SelectedModel"];
|
||||||
this.SidebarOpen = source["SidebarOpen"];
|
this.SidebarOpen = source["SidebarOpen"];
|
||||||
|
this.AutoUpdateEnabled = source["AutoUpdateEnabled"];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
export class SettingsResponse {
|
export class SettingsResponse {
|
||||||
|
|||||||
@@ -4,7 +4,6 @@ import {
|
|||||||
ChatEvent,
|
ChatEvent,
|
||||||
DownloadEvent,
|
DownloadEvent,
|
||||||
ErrorEvent,
|
ErrorEvent,
|
||||||
InferenceCompute,
|
|
||||||
InferenceComputeResponse,
|
InferenceComputeResponse,
|
||||||
ModelCapabilitiesResponse,
|
ModelCapabilitiesResponse,
|
||||||
Model,
|
Model,
|
||||||
@@ -27,6 +26,12 @@ declare module "@/gotypes" {
|
|||||||
Model.prototype.isCloud = function (): boolean {
|
Model.prototype.isCloud = function (): boolean {
|
||||||
return this.model.endsWith("cloud");
|
return this.model.endsWith("cloud");
|
||||||
};
|
};
|
||||||
|
|
||||||
|
export type CloudStatusSource = "env" | "config" | "both" | "none";
|
||||||
|
export interface CloudStatusResponse {
|
||||||
|
disabled: boolean;
|
||||||
|
source: CloudStatusSource;
|
||||||
|
}
|
||||||
// Helper function to convert Uint8Array to base64
|
// Helper function to convert Uint8Array to base64
|
||||||
function uint8ArrayToBase64(uint8Array: Uint8Array): string {
|
function uint8ArrayToBase64(uint8Array: Uint8Array): string {
|
||||||
const chunkSize = 0x8000; // 32KB chunks to avoid stack overflow
|
const chunkSize = 0x8000; // 32KB chunks to avoid stack overflow
|
||||||
@@ -285,6 +290,28 @@ export async function updateSettings(settings: Settings): Promise<{
|
|||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export async function updateCloudSetting(
|
||||||
|
enabled: boolean,
|
||||||
|
): Promise<CloudStatusResponse> {
|
||||||
|
const response = await fetch(`${API_BASE}/api/v1/cloud`, {
|
||||||
|
method: "POST",
|
||||||
|
headers: {
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
},
|
||||||
|
body: JSON.stringify({ enabled }),
|
||||||
|
});
|
||||||
|
if (!response.ok) {
|
||||||
|
const error = await response.text();
|
||||||
|
throw new Error(error || "Failed to update cloud setting");
|
||||||
|
}
|
||||||
|
|
||||||
|
const data = await response.json();
|
||||||
|
return {
|
||||||
|
disabled: Boolean(data.disabled),
|
||||||
|
source: (data.source as CloudStatusSource) || "none",
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
export async function renameChat(chatId: string, title: string): Promise<void> {
|
export async function renameChat(chatId: string, title: string): Promise<void> {
|
||||||
const response = await fetch(`${API_BASE}/api/v1/chat/${chatId}/rename`, {
|
const response = await fetch(`${API_BASE}/api/v1/chat/${chatId}/rename`, {
|
||||||
method: "PUT",
|
method: "PUT",
|
||||||
@@ -379,7 +406,7 @@ export async function* pullModel(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
export async function getInferenceCompute(): Promise<InferenceCompute[]> {
|
export async function getInferenceCompute(): Promise<InferenceComputeResponse> {
|
||||||
const response = await fetch(`${API_BASE}/api/v1/inference-compute`);
|
const response = await fetch(`${API_BASE}/api/v1/inference-compute`);
|
||||||
if (!response.ok) {
|
if (!response.ok) {
|
||||||
throw new Error(
|
throw new Error(
|
||||||
@@ -388,8 +415,7 @@ export async function getInferenceCompute(): Promise<InferenceCompute[]> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
const data = await response.json();
|
const data = await response.json();
|
||||||
const inferenceComputeResponse = new InferenceComputeResponse(data);
|
return new InferenceComputeResponse(data);
|
||||||
return inferenceComputeResponse.inferenceComputes || [];
|
|
||||||
}
|
}
|
||||||
|
|
||||||
export async function fetchHealth(): Promise<boolean> {
|
export async function fetchHealth(): Promise<boolean> {
|
||||||
@@ -414,3 +440,16 @@ export async function fetchHealth(): Promise<boolean> {
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export async function getCloudStatus(): Promise<CloudStatusResponse | null> {
|
||||||
|
const response = await fetch(`${API_BASE}/api/v1/cloud`);
|
||||||
|
if (!response.ok) {
|
||||||
|
throw new Error(`Failed to fetch cloud status: ${response.status}`);
|
||||||
|
}
|
||||||
|
|
||||||
|
const data = await response.json();
|
||||||
|
return {
|
||||||
|
disabled: Boolean(data.disabled),
|
||||||
|
source: (data.source as CloudStatusSource) || "none",
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|||||||
@@ -17,11 +17,15 @@ import {
|
|||||||
} from "@/hooks/useChats";
|
} from "@/hooks/useChats";
|
||||||
import { useNavigate } from "@tanstack/react-router";
|
import { useNavigate } from "@tanstack/react-router";
|
||||||
import { useSelectedModel } from "@/hooks/useSelectedModel";
|
import { useSelectedModel } from "@/hooks/useSelectedModel";
|
||||||
import { useHasVisionCapability } from "@/hooks/useModelCapabilities";
|
import {
|
||||||
|
useHasVisionCapability,
|
||||||
|
useHasToolsCapability,
|
||||||
|
} from "@/hooks/useModelCapabilities";
|
||||||
import { useUser } from "@/hooks/useUser";
|
import { useUser } from "@/hooks/useUser";
|
||||||
import { DisplayLogin } from "@/components/DisplayLogin";
|
import { DisplayLogin } from "@/components/DisplayLogin";
|
||||||
import { ErrorEvent, Message } from "@/gotypes";
|
import { ErrorEvent, Message } from "@/gotypes";
|
||||||
import { useSettings } from "@/hooks/useSettings";
|
import { useSettings } from "@/hooks/useSettings";
|
||||||
|
import { useCloudStatus } from "@/hooks/useCloudStatus";
|
||||||
import { ThinkButton } from "./ThinkButton";
|
import { ThinkButton } from "./ThinkButton";
|
||||||
import { ErrorMessage } from "./ErrorMessage";
|
import { ErrorMessage } from "./ErrorMessage";
|
||||||
import { processFiles } from "@/utils/fileValidation";
|
import { processFiles } from "@/utils/fileValidation";
|
||||||
@@ -141,19 +145,14 @@ function ChatForm({
|
|||||||
const {
|
const {
|
||||||
settings: {
|
settings: {
|
||||||
webSearchEnabled,
|
webSearchEnabled,
|
||||||
airplaneMode,
|
|
||||||
thinkEnabled,
|
thinkEnabled,
|
||||||
thinkLevel: settingsThinkLevel,
|
thinkLevel: settingsThinkLevel,
|
||||||
},
|
},
|
||||||
setSettings,
|
setSettings,
|
||||||
} = useSettings();
|
} = useSettings();
|
||||||
|
const { cloudDisabled } = useCloudStatus();
|
||||||
|
|
||||||
// current supported models for web search
|
const supportsWebSearch = useHasToolsCapability(selectedModel?.model);
|
||||||
const modelLower = selectedModel?.model.toLowerCase() || "";
|
|
||||||
const supportsWebSearch =
|
|
||||||
modelLower.startsWith("gpt-oss") ||
|
|
||||||
modelLower.startsWith("qwen3") ||
|
|
||||||
modelLower.startsWith("deepseek-v3");
|
|
||||||
// Use per-chat thinking level instead of global
|
// Use per-chat thinking level instead of global
|
||||||
const thinkLevel: ThinkingLevel =
|
const thinkLevel: ThinkingLevel =
|
||||||
settingsThinkLevel === "none" || !settingsThinkLevel
|
settingsThinkLevel === "none" || !settingsThinkLevel
|
||||||
@@ -180,6 +179,12 @@ function ChatForm({
|
|||||||
setSettings,
|
setSettings,
|
||||||
]);
|
]);
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
if (cloudDisabled && webSearchEnabled) {
|
||||||
|
setSettings({ WebSearchEnabled: false });
|
||||||
|
}
|
||||||
|
}, [cloudDisabled, webSearchEnabled, setSettings]);
|
||||||
|
|
||||||
const removeFile = (index: number) => {
|
const removeFile = (index: number) => {
|
||||||
setMessage((prev) => ({
|
setMessage((prev) => ({
|
||||||
...prev,
|
...prev,
|
||||||
@@ -234,19 +239,19 @@ function ChatForm({
|
|||||||
|
|
||||||
// Determine if login banner should be shown
|
// Determine if login banner should be shown
|
||||||
const shouldShowLoginBanner =
|
const shouldShowLoginBanner =
|
||||||
|
!cloudDisabled &&
|
||||||
!isLoadingUser &&
|
!isLoadingUser &&
|
||||||
!isAuthenticated &&
|
!isAuthenticated &&
|
||||||
((webSearchEnabled && supportsWebSearch) ||
|
((webSearchEnabled && supportsWebSearch) || selectedModel?.isCloud());
|
||||||
(selectedModel?.isCloud() && !airplaneMode));
|
|
||||||
|
|
||||||
// Determine which feature to highlight in the banner
|
// Determine which feature to highlight in the banner
|
||||||
const getActiveFeatureForBanner = () => {
|
const getActiveFeatureForBanner = () => {
|
||||||
|
if (cloudDisabled) return null;
|
||||||
if (!isAuthenticated) {
|
if (!isAuthenticated) {
|
||||||
if (loginPromptFeature) return loginPromptFeature;
|
if (loginPromptFeature) return loginPromptFeature;
|
||||||
if (webSearchEnabled && selectedModel?.isCloud() && !airplaneMode)
|
if (webSearchEnabled && selectedModel?.isCloud()) return "webSearch";
|
||||||
return "webSearch";
|
|
||||||
if (webSearchEnabled) return "webSearch";
|
if (webSearchEnabled) return "webSearch";
|
||||||
if (selectedModel?.isCloud() && !airplaneMode) return "turbo";
|
if (selectedModel?.isCloud()) return "turbo";
|
||||||
}
|
}
|
||||||
return null;
|
return null;
|
||||||
};
|
};
|
||||||
@@ -269,11 +274,12 @@ function ChatForm({
|
|||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
if (
|
if (
|
||||||
isAuthenticated ||
|
isAuthenticated ||
|
||||||
(!webSearchEnabled && !!selectedModel?.isCloud() && !airplaneMode)
|
cloudDisabled ||
|
||||||
|
(!webSearchEnabled && !!selectedModel?.isCloud())
|
||||||
) {
|
) {
|
||||||
setLoginPromptFeature(null);
|
setLoginPromptFeature(null);
|
||||||
}
|
}
|
||||||
}, [isAuthenticated, webSearchEnabled, selectedModel, airplaneMode]);
|
}, [isAuthenticated, webSearchEnabled, selectedModel, cloudDisabled]);
|
||||||
|
|
||||||
// When entering edit mode, populate the composition with existing data
|
// When entering edit mode, populate the composition with existing data
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
@@ -465,6 +471,10 @@ function ChatForm({
|
|||||||
const handleSubmit = async () => {
|
const handleSubmit = async () => {
|
||||||
if (!message.content.trim() || isStreaming || isDownloading) return;
|
if (!message.content.trim() || isStreaming || isDownloading) return;
|
||||||
|
|
||||||
|
if (cloudDisabled && selectedModel?.isCloud()) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
// Check if cloud mode is enabled but user is not authenticated
|
// Check if cloud mode is enabled but user is not authenticated
|
||||||
if (shouldShowLoginBanner) {
|
if (shouldShowLoginBanner) {
|
||||||
return;
|
return;
|
||||||
@@ -478,7 +488,8 @@ function ChatForm({
|
|||||||
}),
|
}),
|
||||||
);
|
);
|
||||||
|
|
||||||
const useWebSearch = supportsWebSearch && webSearchEnabled && !airplaneMode;
|
const useWebSearch =
|
||||||
|
supportsWebSearch && webSearchEnabled && !cloudDisabled;
|
||||||
const useThink = modelSupportsThinkingLevels
|
const useThink = modelSupportsThinkingLevels
|
||||||
? thinkLevel
|
? thinkLevel
|
||||||
: supportsThinkToggling
|
: supportsThinkToggling
|
||||||
@@ -899,7 +910,7 @@ function ChatForm({
|
|||||||
)}
|
)}
|
||||||
<WebSearchButton
|
<WebSearchButton
|
||||||
ref={webSearchButtonRef}
|
ref={webSearchButtonRef}
|
||||||
isVisible={supportsWebSearch && airplaneMode === false}
|
isVisible={supportsWebSearch && cloudDisabled === false}
|
||||||
isActive={webSearchEnabled}
|
isActive={webSearchEnabled}
|
||||||
onToggle={() => {
|
onToggle={() => {
|
||||||
if (!webSearchEnabled && !isAuthenticated) {
|
if (!webSearchEnabled && !isAuthenticated) {
|
||||||
@@ -940,6 +951,7 @@ function ChatForm({
|
|||||||
!isDownloading &&
|
!isDownloading &&
|
||||||
(!message.content.trim() ||
|
(!message.content.trim() ||
|
||||||
shouldShowLoginBanner ||
|
shouldShowLoginBanner ||
|
||||||
|
(cloudDisabled && selectedModel?.isCloud()) ||
|
||||||
message.fileErrors.length > 0)
|
message.fileErrors.length > 0)
|
||||||
}
|
}
|
||||||
className={`flex items-center justify-center h-9 w-9 rounded-full disabled:cursor-default cursor-pointer bg-black text-white dark:bg-white dark:text-black disabled:opacity-10 focus:outline-none focus:ring-2 focus:ring-blue-500`}
|
className={`flex items-center justify-center h-9 w-9 rounded-full disabled:cursor-default cursor-pointer bg-black text-white dark:bg-white dark:text-black disabled:opacity-10 focus:outline-none focus:ring-2 focus:ring-blue-500`}
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ import {
|
|||||||
} from "react";
|
} from "react";
|
||||||
import { Model } from "@/gotypes";
|
import { Model } from "@/gotypes";
|
||||||
import { useSelectedModel } from "@/hooks/useSelectedModel";
|
import { useSelectedModel } from "@/hooks/useSelectedModel";
|
||||||
import { useSettings } from "@/hooks/useSettings";
|
import { useCloudStatus } from "@/hooks/useCloudStatus";
|
||||||
import { useQueryClient } from "@tanstack/react-query";
|
import { useQueryClient } from "@tanstack/react-query";
|
||||||
import { getModelUpstreamInfo } from "@/api";
|
import { getModelUpstreamInfo } from "@/api";
|
||||||
import { ArrowDownTrayIcon } from "@heroicons/react/24/outline";
|
import { ArrowDownTrayIcon } from "@heroicons/react/24/outline";
|
||||||
@@ -34,7 +34,7 @@ export const ModelPicker = forwardRef<
|
|||||||
chatId,
|
chatId,
|
||||||
searchQuery,
|
searchQuery,
|
||||||
);
|
);
|
||||||
const { settings } = useSettings();
|
const { cloudDisabled } = useCloudStatus();
|
||||||
const dropdownRef = useRef<HTMLDivElement>(null);
|
const dropdownRef = useRef<HTMLDivElement>(null);
|
||||||
const searchInputRef = useRef<HTMLInputElement>(null);
|
const searchInputRef = useRef<HTMLInputElement>(null);
|
||||||
const queryClient = useQueryClient();
|
const queryClient = useQueryClient();
|
||||||
@@ -219,7 +219,7 @@ export const ModelPicker = forwardRef<
|
|||||||
models={models}
|
models={models}
|
||||||
selectedModel={selectedModel}
|
selectedModel={selectedModel}
|
||||||
onModelSelect={handleModelSelect}
|
onModelSelect={handleModelSelect}
|
||||||
airplaneMode={settings.airplaneMode}
|
cloudDisabled={cloudDisabled}
|
||||||
isOpen={isOpen}
|
isOpen={isOpen}
|
||||||
/>
|
/>
|
||||||
</div>
|
</div>
|
||||||
@@ -233,13 +233,13 @@ export const ModelList = forwardRef(function ModelList(
|
|||||||
models,
|
models,
|
||||||
selectedModel,
|
selectedModel,
|
||||||
onModelSelect,
|
onModelSelect,
|
||||||
airplaneMode,
|
cloudDisabled,
|
||||||
isOpen,
|
isOpen,
|
||||||
}: {
|
}: {
|
||||||
models: Model[];
|
models: Model[];
|
||||||
selectedModel: Model | null;
|
selectedModel: Model | null;
|
||||||
onModelSelect: (model: Model) => void;
|
onModelSelect: (model: Model) => void;
|
||||||
airplaneMode: boolean;
|
cloudDisabled: boolean;
|
||||||
isOpen: boolean;
|
isOpen: boolean;
|
||||||
},
|
},
|
||||||
ref,
|
ref,
|
||||||
@@ -348,7 +348,7 @@ export const ModelList = forwardRef(function ModelList(
|
|||||||
</svg>
|
</svg>
|
||||||
)}
|
)}
|
||||||
{model.digest === undefined &&
|
{model.digest === undefined &&
|
||||||
(airplaneMode || !model.isCloud()) && (
|
(cloudDisabled || !model.isCloud()) && (
|
||||||
<ArrowDownTrayIcon
|
<ArrowDownTrayIcon
|
||||||
className="h-4 w-4 text-neutral-500 dark:text-neutral-400"
|
className="h-4 w-4 text-neutral-500 dark:text-neutral-400"
|
||||||
strokeWidth={1.75}
|
strokeWidth={1.75}
|
||||||
|
|||||||
@@ -11,15 +11,24 @@ import {
|
|||||||
FolderIcon,
|
FolderIcon,
|
||||||
BoltIcon,
|
BoltIcon,
|
||||||
WrenchIcon,
|
WrenchIcon,
|
||||||
|
CloudIcon,
|
||||||
XMarkIcon,
|
XMarkIcon,
|
||||||
CogIcon,
|
CogIcon,
|
||||||
ArrowLeftIcon,
|
ArrowLeftIcon,
|
||||||
|
ArrowDownTrayIcon,
|
||||||
} from "@heroicons/react/20/solid";
|
} from "@heroicons/react/20/solid";
|
||||||
import { Settings as SettingsType } from "@/gotypes";
|
import { Settings as SettingsType } from "@/gotypes";
|
||||||
import { useNavigate } from "@tanstack/react-router";
|
import { useNavigate } from "@tanstack/react-router";
|
||||||
import { useUser } from "@/hooks/useUser";
|
import { useUser } from "@/hooks/useUser";
|
||||||
|
import { useCloudStatus } from "@/hooks/useCloudStatus";
|
||||||
import { useQuery, useMutation, useQueryClient } from "@tanstack/react-query";
|
import { useQuery, useMutation, useQueryClient } from "@tanstack/react-query";
|
||||||
import { getSettings, updateSettings } from "@/api";
|
import {
|
||||||
|
getSettings,
|
||||||
|
type CloudStatusResponse,
|
||||||
|
updateCloudSetting,
|
||||||
|
updateSettings,
|
||||||
|
getInferenceCompute,
|
||||||
|
} from "@/api";
|
||||||
|
|
||||||
function AnimatedDots() {
|
function AnimatedDots() {
|
||||||
return (
|
return (
|
||||||
@@ -53,6 +62,11 @@ export default function Settings() {
|
|||||||
const [connectionError, setConnectionError] = useState<string | null>(null);
|
const [connectionError, setConnectionError] = useState<string | null>(null);
|
||||||
const [pollingInterval, setPollingInterval] = useState<number | null>(null);
|
const [pollingInterval, setPollingInterval] = useState<number | null>(null);
|
||||||
const navigate = useNavigate();
|
const navigate = useNavigate();
|
||||||
|
const {
|
||||||
|
cloudDisabled,
|
||||||
|
cloudStatus,
|
||||||
|
isLoading: cloudStatusLoading,
|
||||||
|
} = useCloudStatus();
|
||||||
|
|
||||||
const {
|
const {
|
||||||
data: settingsData,
|
data: settingsData,
|
||||||
@@ -65,6 +79,13 @@ export default function Settings() {
|
|||||||
|
|
||||||
const settings = settingsData?.settings || null;
|
const settings = settingsData?.settings || null;
|
||||||
|
|
||||||
|
const { data: inferenceComputeResponse } = useQuery({
|
||||||
|
queryKey: ["inferenceCompute"],
|
||||||
|
queryFn: getInferenceCompute,
|
||||||
|
});
|
||||||
|
|
||||||
|
const defaultContextLength = inferenceComputeResponse?.defaultContextLength;
|
||||||
|
|
||||||
const updateSettingsMutation = useMutation({
|
const updateSettingsMutation = useMutation({
|
||||||
mutationFn: updateSettings,
|
mutationFn: updateSettings,
|
||||||
onSuccess: () => {
|
onSuccess: () => {
|
||||||
@@ -74,6 +95,50 @@ export default function Settings() {
|
|||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
|
||||||
|
const updateCloudMutation = useMutation({
|
||||||
|
mutationFn: (enabled: boolean) => updateCloudSetting(enabled),
|
||||||
|
onMutate: async (enabled: boolean) => {
|
||||||
|
await queryClient.cancelQueries({ queryKey: ["cloudStatus"] });
|
||||||
|
|
||||||
|
const previous = queryClient.getQueryData<CloudStatusResponse | null>([
|
||||||
|
"cloudStatus",
|
||||||
|
]);
|
||||||
|
const envForcesDisabled =
|
||||||
|
previous?.source === "env" || previous?.source === "both";
|
||||||
|
|
||||||
|
queryClient.setQueryData<CloudStatusResponse | null>(
|
||||||
|
["cloudStatus"],
|
||||||
|
previous
|
||||||
|
? {
|
||||||
|
...previous,
|
||||||
|
disabled: !enabled || envForcesDisabled,
|
||||||
|
}
|
||||||
|
: {
|
||||||
|
disabled: !enabled,
|
||||||
|
source: "config",
|
||||||
|
},
|
||||||
|
);
|
||||||
|
|
||||||
|
return { previous };
|
||||||
|
},
|
||||||
|
onError: (_error, _enabled, context) => {
|
||||||
|
if (context?.previous !== undefined) {
|
||||||
|
queryClient.setQueryData(["cloudStatus"], context.previous);
|
||||||
|
}
|
||||||
|
},
|
||||||
|
onSuccess: (status) => {
|
||||||
|
queryClient.setQueryData<CloudStatusResponse | null>(
|
||||||
|
["cloudStatus"],
|
||||||
|
status,
|
||||||
|
);
|
||||||
|
queryClient.invalidateQueries({ queryKey: ["models"] });
|
||||||
|
queryClient.invalidateQueries({ queryKey: ["cloudStatus"] });
|
||||||
|
|
||||||
|
setShowSaved(true);
|
||||||
|
setTimeout(() => setShowSaved(false), 1500);
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
refetchUser();
|
refetchUser();
|
||||||
}, []); // eslint-disable-line react-hooks/exhaustive-deps
|
}, []); // eslint-disable-line react-hooks/exhaustive-deps
|
||||||
@@ -148,13 +213,18 @@ export default function Settings() {
|
|||||||
Models: "",
|
Models: "",
|
||||||
Agent: false,
|
Agent: false,
|
||||||
Tools: false,
|
Tools: false,
|
||||||
ContextLength: 4096,
|
ContextLength: 0,
|
||||||
AirplaneMode: false,
|
AutoUpdateEnabled: true,
|
||||||
});
|
});
|
||||||
updateSettingsMutation.mutate(defaultSettings);
|
updateSettingsMutation.mutate(defaultSettings);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
const cloudOverriddenByEnv =
|
||||||
|
cloudStatus?.source === "env" || cloudStatus?.source === "both";
|
||||||
|
const cloudToggleDisabled =
|
||||||
|
cloudStatusLoading || updateCloudMutation.isPending || cloudOverriddenByEnv;
|
||||||
|
|
||||||
const handleConnectOllamaAccount = async () => {
|
const handleConnectOllamaAccount = async () => {
|
||||||
setConnectionError(null);
|
setConnectionError(null);
|
||||||
|
|
||||||
@@ -237,7 +307,7 @@ export default function Settings() {
|
|||||||
<div className="space-y-4 max-w-2xl mx-auto">
|
<div className="space-y-4 max-w-2xl mx-auto">
|
||||||
{/* Connect Ollama Account */}
|
{/* Connect Ollama Account */}
|
||||||
<div className="overflow-hidden rounded-xl bg-white dark:bg-neutral-800">
|
<div className="overflow-hidden rounded-xl bg-white dark:bg-neutral-800">
|
||||||
<div className="p-4 border-b border-neutral-200 dark:border-neutral-800">
|
<div className="p-4">
|
||||||
<Field>
|
<Field>
|
||||||
{isLoading ? (
|
{isLoading ? (
|
||||||
// Loading skeleton, this will only happen if the app started recently
|
// Loading skeleton, this will only happen if the app started recently
|
||||||
@@ -344,6 +414,57 @@ export default function Settings() {
|
|||||||
{/* Local Configuration */}
|
{/* Local Configuration */}
|
||||||
<div className="relative overflow-hidden rounded-xl bg-white dark:bg-neutral-800">
|
<div className="relative overflow-hidden rounded-xl bg-white dark:bg-neutral-800">
|
||||||
<div className="space-y-4 p-4">
|
<div className="space-y-4 p-4">
|
||||||
|
<Field>
|
||||||
|
<div className="flex items-start justify-between gap-4">
|
||||||
|
<div className="flex items-start space-x-3 flex-1">
|
||||||
|
<CloudIcon className="mt-1 h-5 w-5 flex-shrink-0 text-black dark:text-neutral-100" />
|
||||||
|
<div>
|
||||||
|
<Label>Cloud</Label>
|
||||||
|
<Description>
|
||||||
|
{cloudOverriddenByEnv
|
||||||
|
? "The OLLAMA_NO_CLOUD environment variable is currently forcing cloud off."
|
||||||
|
: "Enable cloud models and web search."}
|
||||||
|
</Description>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<div className="flex-shrink-0">
|
||||||
|
<Switch
|
||||||
|
checked={!cloudDisabled}
|
||||||
|
disabled={cloudToggleDisabled}
|
||||||
|
onChange={(checked) => {
|
||||||
|
if (cloudOverriddenByEnv) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
updateCloudMutation.mutate(checked);
|
||||||
|
}}
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</Field>
|
||||||
|
|
||||||
|
{/* Auto Update */}
|
||||||
|
<Field>
|
||||||
|
<div className="flex items-start justify-between gap-4">
|
||||||
|
<div className="flex items-start space-x-3 flex-1">
|
||||||
|
<ArrowDownTrayIcon className="mt-1 h-5 w-5 flex-shrink-0 text-black dark:text-neutral-100" />
|
||||||
|
<div>
|
||||||
|
<Label>Auto-download updates</Label>
|
||||||
|
<Description>
|
||||||
|
{settings.AutoUpdateEnabled
|
||||||
|
? "Automatically download updates when available."
|
||||||
|
: "Updates will not be downloaded automatically."}
|
||||||
|
</Description>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<div className="flex-shrink-0">
|
||||||
|
<Switch
|
||||||
|
checked={settings.AutoUpdateEnabled}
|
||||||
|
onChange={(checked) => handleChange("AutoUpdateEnabled", checked)}
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</Field>
|
||||||
|
|
||||||
{/* Expose Ollama */}
|
{/* Expose Ollama */}
|
||||||
<Field>
|
<Field>
|
||||||
<div className="flex items-start justify-between gap-4">
|
<div className="flex items-start justify-between gap-4">
|
||||||
@@ -419,13 +540,11 @@ export default function Settings() {
|
|||||||
</Description>
|
</Description>
|
||||||
<div className="mt-3">
|
<div className="mt-3">
|
||||||
<Slider
|
<Slider
|
||||||
value={(() => {
|
value={settings.ContextLength || defaultContextLength || 0}
|
||||||
// Otherwise use the settings value
|
|
||||||
return settings.ContextLength || 4096;
|
|
||||||
})()}
|
|
||||||
onChange={(value) => {
|
onChange={(value) => {
|
||||||
handleChange("ContextLength", value);
|
handleChange("ContextLength", value);
|
||||||
}}
|
}}
|
||||||
|
disabled={!defaultContextLength}
|
||||||
options={[
|
options={[
|
||||||
{ value: 4096, label: "4k" },
|
{ value: 4096, label: "4k" },
|
||||||
{ value: 8192, label: "8k" },
|
{ value: 8192, label: "8k" },
|
||||||
@@ -440,35 +559,6 @@ export default function Settings() {
|
|||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
</Field>
|
</Field>
|
||||||
{/* Airplane Mode */}
|
|
||||||
<Field>
|
|
||||||
<div className="flex items-start justify-between gap-4">
|
|
||||||
<div className="flex items-start space-x-3 flex-1">
|
|
||||||
<svg
|
|
||||||
className="mt-1 h-5 w-5 flex-shrink-0 text-black dark:text-neutral-100"
|
|
||||||
viewBox="0 0 21.5508 17.9033"
|
|
||||||
fill="currentColor"
|
|
||||||
>
|
|
||||||
<path d="M21.5508 8.94727C21.542 7.91895 20.1445 7.17188 18.4658 7.17188L14.9238 7.17188C14.4316 7.17188 14.2471 7.09277 13.957 6.75879L8.05078 0.316406C7.86621 0.105469 7.6377 0 7.37402 0L6.35449 0C6.12598 0 5.99414 0.202148 6.1084 0.448242L9.14941 7.17188L4.68457 7.68164L3.09375 4.76367C2.97949 4.54395 2.78613 4.44727 2.49609 4.44727L2.11816 4.44727C1.88965 4.44727 1.74023 4.59668 1.74023 4.8252L1.74023 13.0693C1.74023 13.2979 1.88965 13.4385 2.11816 13.4385L2.49609 13.4385C2.78613 13.4385 2.97949 13.3418 3.09375 13.1309L4.68457 10.2129L9.14941 10.7227L6.1084 17.4463C5.99414 17.6836 6.12598 17.8945 6.35449 17.8945L7.37402 17.8945C7.6377 17.8945 7.86621 17.7803 8.05078 17.5781L13.957 11.127C14.2471 10.8018 14.4316 10.7227 14.9238 10.7227L18.4658 10.7227C20.1445 10.7227 21.542 9.9668 21.5508 8.94727Z" />
|
|
||||||
</svg>
|
|
||||||
<div>
|
|
||||||
<Label>Airplane mode</Label>
|
|
||||||
<Description>
|
|
||||||
Airplane mode keeps data local, disabling cloud models
|
|
||||||
and web search.
|
|
||||||
</Description>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
<div className="flex-shrink-0">
|
|
||||||
<Switch
|
|
||||||
checked={settings.AirplaneMode}
|
|
||||||
onChange={(checked) =>
|
|
||||||
handleChange("AirplaneMode", checked)
|
|
||||||
}
|
|
||||||
/>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
</Field>
|
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
|
|||||||
@@ -6,10 +6,11 @@ export interface SliderProps {
|
|||||||
value?: number;
|
value?: number;
|
||||||
onChange?: (value: number) => void;
|
onChange?: (value: number) => void;
|
||||||
className?: string;
|
className?: string;
|
||||||
|
disabled?: boolean;
|
||||||
}
|
}
|
||||||
|
|
||||||
const Slider = React.forwardRef<HTMLDivElement, SliderProps>(
|
const Slider = React.forwardRef<HTMLDivElement, SliderProps>(
|
||||||
({ label, options, value = 0, onChange }, ref) => {
|
({ label, options, value = 0, onChange, disabled = false }, ref) => {
|
||||||
const [selectedValue, setSelectedValue] = React.useState(value);
|
const [selectedValue, setSelectedValue] = React.useState(value);
|
||||||
const [isDragging, setIsDragging] = React.useState(false);
|
const [isDragging, setIsDragging] = React.useState(false);
|
||||||
const containerRef = React.useRef<HTMLDivElement>(null);
|
const containerRef = React.useRef<HTMLDivElement>(null);
|
||||||
@@ -20,6 +21,7 @@ const Slider = React.forwardRef<HTMLDivElement, SliderProps>(
|
|||||||
}, [value]);
|
}, [value]);
|
||||||
|
|
||||||
const handleClick = (optionValue: number) => {
|
const handleClick = (optionValue: number) => {
|
||||||
|
if (disabled) return;
|
||||||
setSelectedValue(optionValue);
|
setSelectedValue(optionValue);
|
||||||
onChange?.(optionValue);
|
onChange?.(optionValue);
|
||||||
};
|
};
|
||||||
@@ -39,6 +41,7 @@ const Slider = React.forwardRef<HTMLDivElement, SliderProps>(
|
|||||||
};
|
};
|
||||||
|
|
||||||
const handleMouseDown = (e: React.MouseEvent) => {
|
const handleMouseDown = (e: React.MouseEvent) => {
|
||||||
|
if (disabled) return;
|
||||||
setIsDragging(true);
|
setIsDragging(true);
|
||||||
e.preventDefault();
|
e.preventDefault();
|
||||||
};
|
};
|
||||||
@@ -77,7 +80,7 @@ const Slider = React.forwardRef<HTMLDivElement, SliderProps>(
|
|||||||
}
|
}
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<div className="space-y-2" ref={ref}>
|
<div className={`space-y-2 ${disabled ? "opacity-50" : ""}`} ref={ref}>
|
||||||
{label && <label className="text-sm font-medium">{label}</label>}
|
{label && <label className="text-sm font-medium">{label}</label>}
|
||||||
<div className="relative">
|
<div className="relative">
|
||||||
<div className="absolute top-[9px] left-2 right-2 h-1 bg-neutral-200 dark:bg-neutral-700 pointer-events-none rounded-full" />
|
<div className="absolute top-[9px] left-2 right-2 h-1 bg-neutral-200 dark:bg-neutral-700 pointer-events-none rounded-full" />
|
||||||
@@ -88,10 +91,11 @@ const Slider = React.forwardRef<HTMLDivElement, SliderProps>(
|
|||||||
<button
|
<button
|
||||||
onClick={() => handleClick(option.value)}
|
onClick={() => handleClick(option.value)}
|
||||||
onMouseDown={handleMouseDown}
|
onMouseDown={handleMouseDown}
|
||||||
className="relative px-3 py-6 -mx-3 -my-6 z-10 cursor-pointer"
|
disabled={disabled}
|
||||||
|
className={`relative px-3 py-6 -mx-3 -my-6 z-10 ${disabled ? "cursor-not-allowed" : "cursor-pointer"}`}
|
||||||
>
|
>
|
||||||
<div className="relative w-5 h-5 flex items-center justify-center">
|
<div className="relative w-5 h-5 flex items-center justify-center">
|
||||||
{selectedValue === option.value && (
|
{selectedValue === option.value && !disabled && (
|
||||||
<div className="w-4 h-4 bg-white dark:bg-white border border-neutral-400 dark:border-neutral-500 rounded-full cursor-grab active:cursor-grabbing" />
|
<div className="w-4 h-4 bg-white dark:bg-white border border-neutral-400 dark:border-neutral-500 rounded-full cursor-grab active:cursor-grabbing" />
|
||||||
)}
|
)}
|
||||||
</div>
|
</div>
|
||||||
|
|||||||
@@ -6,8 +6,8 @@ import { useSelectedModel } from "./useSelectedModel";
|
|||||||
import { createQueryBatcher } from "./useQueryBatcher";
|
import { createQueryBatcher } from "./useQueryBatcher";
|
||||||
import { useRefetchModels } from "./useModels";
|
import { useRefetchModels } from "./useModels";
|
||||||
import { useStreamingContext } from "@/contexts/StreamingContext";
|
import { useStreamingContext } from "@/contexts/StreamingContext";
|
||||||
import { useSettings } from "./useSettings";
|
|
||||||
import { getModelCapabilities } from "@/api";
|
import { getModelCapabilities } from "@/api";
|
||||||
|
import { useCloudStatus } from "./useCloudStatus";
|
||||||
|
|
||||||
export const useChats = () => {
|
export const useChats = () => {
|
||||||
return useQuery({
|
return useQuery({
|
||||||
@@ -116,11 +116,9 @@ export const useIsModelStale = (modelName: string) => {
|
|||||||
export const useShouldShowStaleDisplay = (model: Model | null) => {
|
export const useShouldShowStaleDisplay = (model: Model | null) => {
|
||||||
const isStale = useIsModelStale(model?.model || "");
|
const isStale = useIsModelStale(model?.model || "");
|
||||||
const { data: dismissedModels } = useDismissedStaleModels();
|
const { data: dismissedModels } = useDismissedStaleModels();
|
||||||
const {
|
const { cloudDisabled } = useCloudStatus();
|
||||||
settings: { airplaneMode },
|
|
||||||
} = useSettings();
|
|
||||||
|
|
||||||
if (model?.isCloud() && !airplaneMode) {
|
if (model?.isCloud() && !cloudDisabled) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
20
app/ui/app/src/hooks/useCloudStatus.ts
Normal file
20
app/ui/app/src/hooks/useCloudStatus.ts
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
import { useQuery } from "@tanstack/react-query";
|
||||||
|
import { getCloudStatus, type CloudStatusResponse } from "@/api";
|
||||||
|
|
||||||
|
export function useCloudStatus() {
|
||||||
|
const cloudQuery = useQuery<CloudStatusResponse | null>({
|
||||||
|
queryKey: ["cloudStatus"],
|
||||||
|
queryFn: getCloudStatus,
|
||||||
|
retry: false,
|
||||||
|
staleTime: 60 * 1000,
|
||||||
|
});
|
||||||
|
|
||||||
|
return {
|
||||||
|
cloudStatus: cloudQuery.data,
|
||||||
|
cloudDisabled: cloudQuery.data?.disabled ?? false,
|
||||||
|
isKnown: cloudQuery.data !== null && cloudQuery.data !== undefined,
|
||||||
|
isLoading: cloudQuery.isLoading,
|
||||||
|
isError: cloudQuery.isError,
|
||||||
|
error: cloudQuery.error,
|
||||||
|
};
|
||||||
|
}
|
||||||
@@ -20,3 +20,8 @@ export function useHasVisionCapability(modelName: string | undefined) {
|
|||||||
const { data: capabilitiesResponse } = useModelCapabilities(modelName);
|
const { data: capabilitiesResponse } = useModelCapabilities(modelName);
|
||||||
return capabilitiesResponse?.capabilities?.includes("vision") ?? false;
|
return capabilitiesResponse?.capabilities?.includes("vision") ?? false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export function useHasToolsCapability(modelName: string | undefined) {
|
||||||
|
const { data: capabilitiesResponse } = useModelCapabilities(modelName);
|
||||||
|
return capabilitiesResponse?.capabilities?.includes("tools") ?? false;
|
||||||
|
}
|
||||||
|
|||||||
@@ -2,11 +2,11 @@ import { useQuery } from "@tanstack/react-query";
|
|||||||
import { Model } from "@/gotypes";
|
import { Model } from "@/gotypes";
|
||||||
import { getModels } from "@/api";
|
import { getModels } from "@/api";
|
||||||
import { mergeModels } from "@/utils/mergeModels";
|
import { mergeModels } from "@/utils/mergeModels";
|
||||||
import { useSettings } from "./useSettings";
|
|
||||||
import { useMemo } from "react";
|
import { useMemo } from "react";
|
||||||
|
import { useCloudStatus } from "./useCloudStatus";
|
||||||
|
|
||||||
export function useModels(searchQuery = "") {
|
export function useModels(searchQuery = "") {
|
||||||
const { settings } = useSettings();
|
const { cloudDisabled } = useCloudStatus();
|
||||||
const localQuery = useQuery<Model[], Error>({
|
const localQuery = useQuery<Model[], Error>({
|
||||||
queryKey: ["models", searchQuery],
|
queryKey: ["models", searchQuery],
|
||||||
queryFn: () => getModels(searchQuery),
|
queryFn: () => getModels(searchQuery),
|
||||||
@@ -20,7 +20,7 @@ export function useModels(searchQuery = "") {
|
|||||||
});
|
});
|
||||||
|
|
||||||
const allModels = useMemo(() => {
|
const allModels = useMemo(() => {
|
||||||
const models = mergeModels(localQuery.data || [], settings.airplaneMode);
|
const models = mergeModels(localQuery.data || [], cloudDisabled);
|
||||||
|
|
||||||
if (searchQuery && searchQuery.trim()) {
|
if (searchQuery && searchQuery.trim()) {
|
||||||
const query = searchQuery.toLowerCase().trim();
|
const query = searchQuery.toLowerCase().trim();
|
||||||
@@ -40,7 +40,7 @@ export function useModels(searchQuery = "") {
|
|||||||
}
|
}
|
||||||
|
|
||||||
return models;
|
return models;
|
||||||
}, [localQuery.data, searchQuery, settings.airplaneMode]);
|
}, [localQuery.data, searchQuery, cloudDisabled]);
|
||||||
|
|
||||||
return {
|
return {
|
||||||
...localQuery,
|
...localQuery,
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import { Model } from "@/gotypes";
|
|||||||
import { FEATURED_MODELS } from "@/utils/mergeModels";
|
import { FEATURED_MODELS } from "@/utils/mergeModels";
|
||||||
import { getTotalVRAM } from "@/utils/vram.ts";
|
import { getTotalVRAM } from "@/utils/vram.ts";
|
||||||
import { getInferenceCompute } from "@/api";
|
import { getInferenceCompute } from "@/api";
|
||||||
|
import { useCloudStatus } from "./useCloudStatus";
|
||||||
|
|
||||||
export function recommendDefaultModel(totalVRAM: number): string {
|
export function recommendDefaultModel(totalVRAM: number): string {
|
||||||
const vram = Math.max(0, Number(totalVRAM) || 0);
|
const vram = Math.max(0, Number(totalVRAM) || 0);
|
||||||
@@ -22,16 +23,19 @@ export function recommendDefaultModel(totalVRAM: number): string {
|
|||||||
export function useSelectedModel(currentChatId?: string, searchQuery?: string) {
|
export function useSelectedModel(currentChatId?: string, searchQuery?: string) {
|
||||||
const { settings, setSettings } = useSettings();
|
const { settings, setSettings } = useSettings();
|
||||||
const { data: models = [], isLoading } = useModels(searchQuery || "");
|
const { data: models = [], isLoading } = useModels(searchQuery || "");
|
||||||
|
const { cloudDisabled } = useCloudStatus();
|
||||||
const { data: chatData, isLoading: isChatLoading } = useChat(
|
const { data: chatData, isLoading: isChatLoading } = useChat(
|
||||||
currentChatId && currentChatId !== "new" ? currentChatId : "",
|
currentChatId && currentChatId !== "new" ? currentChatId : "",
|
||||||
);
|
);
|
||||||
|
|
||||||
const { data: inferenceComputes = [] } = useQuery({
|
const { data: inferenceComputeResponse } = useQuery({
|
||||||
queryKey: ["inference-compute"],
|
queryKey: ["inferenceCompute"],
|
||||||
queryFn: getInferenceCompute,
|
queryFn: getInferenceCompute,
|
||||||
enabled: !settings.selectedModel, // Only fetch if no model is selected
|
enabled: !settings.selectedModel, // Only fetch if no model is selected
|
||||||
});
|
});
|
||||||
|
|
||||||
|
const inferenceComputes = inferenceComputeResponse?.inferenceComputes || [];
|
||||||
|
|
||||||
const totalVRAM = useMemo(
|
const totalVRAM = useMemo(
|
||||||
() => getTotalVRAM(inferenceComputes),
|
() => getTotalVRAM(inferenceComputes),
|
||||||
[inferenceComputes],
|
[inferenceComputes],
|
||||||
@@ -46,12 +50,11 @@ export function useSelectedModel(currentChatId?: string, searchQuery?: string) {
|
|||||||
const restoredChatRef = useRef<string | null>(null);
|
const restoredChatRef = useRef<string | null>(null);
|
||||||
|
|
||||||
const selectedModel: Model | null = useMemo(() => {
|
const selectedModel: Model | null = useMemo(() => {
|
||||||
// if airplane mode is on and selected model ends with cloud,
|
// If cloud is disabled and selected model ends with cloud, switch to a local default.
|
||||||
// switch to recommended default model
|
if (cloudDisabled && settings.selectedModel?.endsWith("cloud")) {
|
||||||
if (settings.airplaneMode && settings.selectedModel?.endsWith("cloud")) {
|
|
||||||
return (
|
return (
|
||||||
models.find((m) => m.model === recommendedModel) ||
|
models.find((m) => m.model === recommendedModel) ||
|
||||||
models.find((m) => m.isCloud) ||
|
models.find((m) => !m.isCloud()) ||
|
||||||
models.find((m) => m.digest === undefined || m.digest === "") ||
|
models.find((m) => m.digest === undefined || m.digest === "") ||
|
||||||
models[0] ||
|
models[0] ||
|
||||||
null
|
null
|
||||||
@@ -68,7 +71,7 @@ export function useSelectedModel(currentChatId?: string, searchQuery?: string) {
|
|||||||
"qwen3-coder:480b",
|
"qwen3-coder:480b",
|
||||||
];
|
];
|
||||||
const shouldMigrate =
|
const shouldMigrate =
|
||||||
!settings.airplaneMode &&
|
!cloudDisabled &&
|
||||||
settings.turboEnabled &&
|
settings.turboEnabled &&
|
||||||
baseModelsToMigrate.includes(settings.selectedModel);
|
baseModelsToMigrate.includes(settings.selectedModel);
|
||||||
|
|
||||||
@@ -96,13 +99,18 @@ export function useSelectedModel(currentChatId?: string, searchQuery?: string) {
|
|||||||
})) ||
|
})) ||
|
||||||
null
|
null
|
||||||
);
|
);
|
||||||
}, [models, settings.selectedModel, settings.airplaneMode, recommendedModel]);
|
}, [
|
||||||
|
models,
|
||||||
|
settings.selectedModel,
|
||||||
|
cloudDisabled,
|
||||||
|
recommendedModel,
|
||||||
|
]);
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
if (!selectedModel) return;
|
if (!selectedModel) return;
|
||||||
|
|
||||||
if (
|
if (
|
||||||
settings.airplaneMode &&
|
cloudDisabled &&
|
||||||
settings.selectedModel?.endsWith("cloud") &&
|
settings.selectedModel?.endsWith("cloud") &&
|
||||||
selectedModel.model !== settings.selectedModel
|
selectedModel.model !== settings.selectedModel
|
||||||
) {
|
) {
|
||||||
@@ -110,13 +118,17 @@ export function useSelectedModel(currentChatId?: string, searchQuery?: string) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (
|
if (
|
||||||
!settings.airplaneMode &&
|
!cloudDisabled &&
|
||||||
settings.turboEnabled &&
|
settings.turboEnabled &&
|
||||||
selectedModel.model !== settings.selectedModel
|
selectedModel.model !== settings.selectedModel
|
||||||
) {
|
) {
|
||||||
setSettings({ SelectedModel: selectedModel.model, TurboEnabled: false });
|
setSettings({ SelectedModel: selectedModel.model, TurboEnabled: false });
|
||||||
}
|
}
|
||||||
}, [selectedModel, settings.airplaneMode, settings.selectedModel]);
|
}, [
|
||||||
|
selectedModel,
|
||||||
|
cloudDisabled,
|
||||||
|
settings.selectedModel,
|
||||||
|
]);
|
||||||
|
|
||||||
// Set model from chat history when chat data loads
|
// Set model from chat history when chat data loads
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
@@ -169,7 +181,9 @@ export function useSelectedModel(currentChatId?: string, searchQuery?: string) {
|
|||||||
|
|
||||||
const defaultModel =
|
const defaultModel =
|
||||||
models.find((m) => m.model === recommendedModel) ||
|
models.find((m) => m.model === recommendedModel) ||
|
||||||
models.find((m) => m.isCloud()) ||
|
(cloudDisabled
|
||||||
|
? models.find((m) => !m.isCloud())
|
||||||
|
: models.find((m) => m.isCloud())) ||
|
||||||
models.find((m) => m.digest === undefined || m.digest === "") ||
|
models.find((m) => m.digest === undefined || m.digest === "") ||
|
||||||
models[0];
|
models[0];
|
||||||
|
|
||||||
@@ -181,6 +195,7 @@ export function useSelectedModel(currentChatId?: string, searchQuery?: string) {
|
|||||||
inferenceComputes.length,
|
inferenceComputes.length,
|
||||||
models.length,
|
models.length,
|
||||||
settings.selectedModel,
|
settings.selectedModel,
|
||||||
|
cloudDisabled,
|
||||||
]);
|
]);
|
||||||
|
|
||||||
// Add the selected model to the models list if it's not already there
|
// Add the selected model to the models list if it's not already there
|
||||||
|
|||||||
@@ -9,7 +9,6 @@ interface SettingsState {
|
|||||||
webSearchEnabled: boolean;
|
webSearchEnabled: boolean;
|
||||||
selectedModel: string;
|
selectedModel: string;
|
||||||
sidebarOpen: boolean;
|
sidebarOpen: boolean;
|
||||||
airplaneMode: boolean;
|
|
||||||
thinkEnabled: boolean;
|
thinkEnabled: boolean;
|
||||||
thinkLevel: string;
|
thinkLevel: string;
|
||||||
}
|
}
|
||||||
@@ -51,7 +50,6 @@ export function useSettings() {
|
|||||||
thinkLevel: settingsData?.settings?.ThinkLevel ?? "none",
|
thinkLevel: settingsData?.settings?.ThinkLevel ?? "none",
|
||||||
selectedModel: settingsData?.settings?.SelectedModel ?? "",
|
selectedModel: settingsData?.settings?.SelectedModel ?? "",
|
||||||
sidebarOpen: settingsData?.settings?.SidebarOpen ?? false,
|
sidebarOpen: settingsData?.settings?.SidebarOpen ?? false,
|
||||||
airplaneMode: settingsData?.settings?.AirplaneMode ?? false,
|
|
||||||
}),
|
}),
|
||||||
[settingsData?.settings],
|
[settingsData?.settings],
|
||||||
);
|
);
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ import type { QueryClient } from "@tanstack/react-query";
|
|||||||
import { createRootRouteWithContext, Outlet } from "@tanstack/react-router";
|
import { createRootRouteWithContext, Outlet } from "@tanstack/react-router";
|
||||||
import { getSettings } from "@/api";
|
import { getSettings } from "@/api";
|
||||||
import { useQuery } from "@tanstack/react-query";
|
import { useQuery } from "@tanstack/react-query";
|
||||||
|
import { useCloudStatus } from "@/hooks/useCloudStatus";
|
||||||
|
|
||||||
function RootComponent() {
|
function RootComponent() {
|
||||||
// This hook ensures settings are fetched on app startup
|
// This hook ensures settings are fetched on app startup
|
||||||
@@ -9,6 +10,8 @@ function RootComponent() {
|
|||||||
queryKey: ["settings"],
|
queryKey: ["settings"],
|
||||||
queryFn: getSettings,
|
queryFn: getSettings,
|
||||||
});
|
});
|
||||||
|
// Fetch cloud status on startup (best-effort)
|
||||||
|
useCloudStatus();
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<div>
|
<div>
|
||||||
|
|||||||
@@ -41,14 +41,14 @@ describe("Model merging logic", () => {
|
|||||||
expect(merged.length).toBe(FEATURED_MODELS.length + 2);
|
expect(merged.length).toBe(FEATURED_MODELS.length + 2);
|
||||||
});
|
});
|
||||||
|
|
||||||
it("should hide cloud models in airplane mode", () => {
|
it("should hide cloud models when cloud is disabled", () => {
|
||||||
const localModels: Model[] = [
|
const localModels: Model[] = [
|
||||||
new Model({ model: "gpt-oss:120b-cloud" }),
|
new Model({ model: "gpt-oss:120b-cloud" }),
|
||||||
new Model({ model: "llama3:latest" }),
|
new Model({ model: "llama3:latest" }),
|
||||||
new Model({ model: "mistral:latest" }),
|
new Model({ model: "mistral:latest" }),
|
||||||
];
|
];
|
||||||
|
|
||||||
const merged = mergeModels(localModels, true); // airplane mode = true
|
const merged = mergeModels(localModels, true); // cloud disabled = true
|
||||||
|
|
||||||
// No cloud models should be present
|
// No cloud models should be present
|
||||||
const cloudModels = merged.filter((m) => m.isCloud());
|
const cloudModels = merged.filter((m) => m.isCloud());
|
||||||
|
|||||||
@@ -32,7 +32,7 @@ function alphabeticalSort(a: Model, b: Model): number {
|
|||||||
//Merges models, sorting cloud models first, then other models
|
//Merges models, sorting cloud models first, then other models
|
||||||
export function mergeModels(
|
export function mergeModels(
|
||||||
localModels: Model[],
|
localModels: Model[],
|
||||||
airplaneMode: boolean = false,
|
hideCloudModels: boolean = false,
|
||||||
): Model[] {
|
): Model[] {
|
||||||
const allModels = (localModels || []).map((model) => model);
|
const allModels = (localModels || []).map((model) => model);
|
||||||
|
|
||||||
@@ -95,7 +95,7 @@ export function mergeModels(
|
|||||||
|
|
||||||
remainingModels.sort(alphabeticalSort);
|
remainingModels.sort(alphabeticalSort);
|
||||||
|
|
||||||
return airplaneMode
|
return hideCloudModels
|
||||||
? [...featuredModels, ...remainingModels]
|
? [...featuredModels, ...remainingModels]
|
||||||
: [...cloudModels, ...featuredModels, ...remainingModels];
|
: [...cloudModels, ...featuredModels, ...remainingModels];
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -46,6 +46,7 @@ type InferenceCompute struct {
|
|||||||
|
|
||||||
type InferenceComputeResponse struct {
|
type InferenceComputeResponse struct {
|
||||||
InferenceComputes []InferenceCompute `json:"inferenceComputes"`
|
InferenceComputes []InferenceCompute `json:"inferenceComputes"`
|
||||||
|
DefaultContextLength int `json:"defaultContextLength"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type ModelCapabilitiesResponse struct {
|
type ModelCapabilitiesResponse struct {
|
||||||
|
|||||||
94
app/ui/ui.go
94
app/ui/ui.go
@@ -28,6 +28,7 @@ import (
|
|||||||
"github.com/ollama/ollama/app/tools"
|
"github.com/ollama/ollama/app/tools"
|
||||||
"github.com/ollama/ollama/app/types/not"
|
"github.com/ollama/ollama/app/types/not"
|
||||||
"github.com/ollama/ollama/app/ui/responses"
|
"github.com/ollama/ollama/app/ui/responses"
|
||||||
|
"github.com/ollama/ollama/app/updater"
|
||||||
"github.com/ollama/ollama/app/version"
|
"github.com/ollama/ollama/app/version"
|
||||||
ollamaAuth "github.com/ollama/ollama/auth"
|
ollamaAuth "github.com/ollama/ollama/auth"
|
||||||
"github.com/ollama/ollama/envconfig"
|
"github.com/ollama/ollama/envconfig"
|
||||||
@@ -106,6 +107,10 @@ type Server struct {
|
|||||||
|
|
||||||
// Dev is true if the server is running in development mode
|
// Dev is true if the server is running in development mode
|
||||||
Dev bool
|
Dev bool
|
||||||
|
|
||||||
|
// Updater for checking and downloading updates
|
||||||
|
Updater *updater.Updater
|
||||||
|
UpdateAvailableFunc func()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) log() *slog.Logger {
|
func (s *Server) log() *slog.Logger {
|
||||||
@@ -150,7 +155,7 @@ func (s *Server) ollamaProxy() http.Handler {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
target := envconfig.Host()
|
target := envconfig.ConnectableHost()
|
||||||
s.log().Info("configuring ollama proxy", "target", target.String())
|
s.log().Info("configuring ollama proxy", "target", target.String())
|
||||||
|
|
||||||
newProxy := httputil.NewSingleHostReverseProxy(target)
|
newProxy := httputil.NewSingleHostReverseProxy(target)
|
||||||
@@ -284,12 +289,15 @@ func (s *Server) Handler() http.Handler {
|
|||||||
mux.Handle("POST /api/v1/model/upstream", handle(s.modelUpstream))
|
mux.Handle("POST /api/v1/model/upstream", handle(s.modelUpstream))
|
||||||
mux.Handle("GET /api/v1/settings", handle(s.getSettings))
|
mux.Handle("GET /api/v1/settings", handle(s.getSettings))
|
||||||
mux.Handle("POST /api/v1/settings", handle(s.settings))
|
mux.Handle("POST /api/v1/settings", handle(s.settings))
|
||||||
|
mux.Handle("GET /api/v1/cloud", handle(s.getCloudSetting))
|
||||||
|
mux.Handle("POST /api/v1/cloud", handle(s.cloudSetting))
|
||||||
|
|
||||||
// Ollama proxy endpoints
|
// Ollama proxy endpoints
|
||||||
ollamaProxy := s.ollamaProxy()
|
ollamaProxy := s.ollamaProxy()
|
||||||
mux.Handle("GET /api/tags", ollamaProxy)
|
mux.Handle("GET /api/tags", ollamaProxy)
|
||||||
mux.Handle("POST /api/show", ollamaProxy)
|
mux.Handle("POST /api/show", ollamaProxy)
|
||||||
mux.Handle("GET /api/version", ollamaProxy)
|
mux.Handle("GET /api/version", ollamaProxy)
|
||||||
|
mux.Handle("GET /api/status", ollamaProxy)
|
||||||
mux.Handle("HEAD /api/version", ollamaProxy)
|
mux.Handle("HEAD /api/version", ollamaProxy)
|
||||||
mux.Handle("POST /api/me", ollamaProxy)
|
mux.Handle("POST /api/me", ollamaProxy)
|
||||||
mux.Handle("POST /api/signout", ollamaProxy)
|
mux.Handle("POST /api/signout", ollamaProxy)
|
||||||
@@ -826,8 +834,9 @@ func (s *Server) chat(w http.ResponseWriter, r *http.Request) error {
|
|||||||
|
|
||||||
if !hasAttachments {
|
if !hasAttachments {
|
||||||
WebSearchEnabled := req.WebSearch != nil && *req.WebSearch
|
WebSearchEnabled := req.WebSearch != nil && *req.WebSearch
|
||||||
|
hasToolsCapability := slices.Contains(details.Capabilities, model.CapabilityTools)
|
||||||
|
|
||||||
if WebSearchEnabled {
|
if WebSearchEnabled && hasToolsCapability {
|
||||||
if supportsBrowserTools(req.Model) {
|
if supportsBrowserTools(req.Model) {
|
||||||
browserState, ok := s.browserState(chat)
|
browserState, ok := s.browserState(chat)
|
||||||
if !ok {
|
if !ok {
|
||||||
@@ -837,7 +846,7 @@ func (s *Server) chat(w http.ResponseWriter, r *http.Request) error {
|
|||||||
registry.Register(tools.NewBrowserSearch(browser))
|
registry.Register(tools.NewBrowserSearch(browser))
|
||||||
registry.Register(tools.NewBrowserOpen(browser))
|
registry.Register(tools.NewBrowserOpen(browser))
|
||||||
registry.Register(tools.NewBrowserFind(browser))
|
registry.Register(tools.NewBrowserFind(browser))
|
||||||
} else if supportsWebSearchTools(req.Model) {
|
} else {
|
||||||
registry.Register(&tools.WebSearch{})
|
registry.Register(&tools.WebSearch{})
|
||||||
registry.Register(&tools.WebFetch{})
|
registry.Register(&tools.WebFetch{})
|
||||||
}
|
}
|
||||||
@@ -1417,11 +1426,6 @@ func (s *Server) getSettings(w http.ResponseWriter, r *http.Request) error {
|
|||||||
settings.Models = envconfig.Models()
|
settings.Models = envconfig.Models()
|
||||||
}
|
}
|
||||||
|
|
||||||
// set default context length if not set
|
|
||||||
if settings.ContextLength == 0 {
|
|
||||||
settings.ContextLength = 4096
|
|
||||||
}
|
|
||||||
|
|
||||||
// Include current runtime settings
|
// Include current runtime settings
|
||||||
settings.Agent = s.Agent
|
settings.Agent = s.Agent
|
||||||
settings.Tools = s.Tools
|
settings.Tools = s.Tools
|
||||||
@@ -1448,6 +1452,24 @@ func (s *Server) settings(w http.ResponseWriter, r *http.Request) error {
|
|||||||
return fmt.Errorf("failed to save settings: %w", err)
|
return fmt.Errorf("failed to save settings: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Handle auto-update toggle changes
|
||||||
|
if old.AutoUpdateEnabled != settings.AutoUpdateEnabled {
|
||||||
|
if !settings.AutoUpdateEnabled {
|
||||||
|
// Auto-update disabled: cancel any ongoing download
|
||||||
|
if s.Updater != nil {
|
||||||
|
s.Updater.CancelOngoingDownload()
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Auto-update re-enabled: show notification if update is already staged, or trigger immediate check
|
||||||
|
if (updater.IsUpdatePending() || updater.UpdateDownloaded) && s.UpdateAvailableFunc != nil {
|
||||||
|
s.UpdateAvailableFunc()
|
||||||
|
} else if s.Updater != nil {
|
||||||
|
// Trigger the background checker to run immediately
|
||||||
|
s.Updater.TriggerImmediateCheck()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if old.ContextLength != settings.ContextLength ||
|
if old.ContextLength != settings.ContextLength ||
|
||||||
old.Models != settings.Models ||
|
old.Models != settings.Models ||
|
||||||
old.Expose != settings.Expose {
|
old.Expose != settings.Expose {
|
||||||
@@ -1460,17 +1482,51 @@ func (s *Server) settings(w http.ResponseWriter, r *http.Request) error {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *Server) cloudSetting(w http.ResponseWriter, r *http.Request) error {
|
||||||
|
var req struct {
|
||||||
|
Enabled bool `json:"enabled"`
|
||||||
|
}
|
||||||
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||||
|
return fmt.Errorf("invalid request body: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := s.Store.SetCloudEnabled(req.Enabled); err != nil {
|
||||||
|
return fmt.Errorf("failed to persist cloud setting: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
s.Restart()
|
||||||
|
|
||||||
|
return s.writeCloudStatus(w)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) getCloudSetting(w http.ResponseWriter, r *http.Request) error {
|
||||||
|
return s.writeCloudStatus(w)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) writeCloudStatus(w http.ResponseWriter) error {
|
||||||
|
disabled, source, err := s.Store.CloudStatus()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to load cloud status: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
return json.NewEncoder(w).Encode(map[string]any{
|
||||||
|
"disabled": disabled,
|
||||||
|
"source": source,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func (s *Server) getInferenceCompute(w http.ResponseWriter, r *http.Request) error {
|
func (s *Server) getInferenceCompute(w http.ResponseWriter, r *http.Request) error {
|
||||||
ctx, cancel := context.WithTimeout(r.Context(), 500*time.Millisecond)
|
ctx, cancel := context.WithTimeout(r.Context(), 500*time.Millisecond)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
serverInferenceComputes, err := server.GetInferenceComputer(ctx)
|
info, err := server.GetInferenceInfo(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.log().Error("failed to get inference compute", "error", err)
|
s.log().Error("failed to get inference info", "error", err)
|
||||||
return fmt.Errorf("failed to get inference compute: %w", err)
|
return fmt.Errorf("failed to get inference info: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
inferenceComputes := make([]responses.InferenceCompute, len(serverInferenceComputes))
|
inferenceComputes := make([]responses.InferenceCompute, len(info.Computes))
|
||||||
for i, ic := range serverInferenceComputes {
|
for i, ic := range info.Computes {
|
||||||
inferenceComputes[i] = responses.InferenceCompute{
|
inferenceComputes[i] = responses.InferenceCompute{
|
||||||
Library: ic.Library,
|
Library: ic.Library,
|
||||||
Variant: ic.Variant,
|
Variant: ic.Variant,
|
||||||
@@ -1483,6 +1539,7 @@ func (s *Server) getInferenceCompute(w http.ResponseWriter, r *http.Request) err
|
|||||||
|
|
||||||
response := responses.InferenceComputeResponse{
|
response := responses.InferenceComputeResponse{
|
||||||
InferenceComputes: inferenceComputes,
|
InferenceComputes: inferenceComputes,
|
||||||
|
DefaultContextLength: info.DefaultContextLength,
|
||||||
}
|
}
|
||||||
|
|
||||||
w.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
@@ -1615,17 +1672,6 @@ func supportsBrowserTools(model string) bool {
|
|||||||
return strings.HasPrefix(strings.ToLower(model), "gpt-oss")
|
return strings.HasPrefix(strings.ToLower(model), "gpt-oss")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Web search tools are simpler, providing only basic web search and fetch capabilities (e.g., "web_search", "web_fetch") without simulating a browser. Currently only qwen3 and deepseek-v3 support web search tools.
|
|
||||||
func supportsWebSearchTools(model string) bool {
|
|
||||||
model = strings.ToLower(model)
|
|
||||||
prefixes := []string{"qwen3", "deepseek-v3"}
|
|
||||||
for _, p := range prefixes {
|
|
||||||
if strings.HasPrefix(model, p) {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
// buildChatRequest converts store.Chat to api.ChatRequest
|
// buildChatRequest converts store.Chat to api.ChatRequest
|
||||||
func (s *Server) buildChatRequest(chat *store.Chat, model string, think any, availableTools []map[string]any) (*api.ChatRequest, error) {
|
func (s *Server) buildChatRequest(chat *store.Chat, model string, think any, availableTools []map[string]any) (*api.ChatRequest, error) {
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ package ui
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
@@ -11,9 +12,11 @@ import (
|
|||||||
"path/filepath"
|
"path/filepath"
|
||||||
"runtime"
|
"runtime"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync/atomic"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/ollama/ollama/app/store"
|
"github.com/ollama/ollama/app/store"
|
||||||
|
"github.com/ollama/ollama/app/updater"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestHandlePostApiSettings(t *testing.T) {
|
func TestHandlePostApiSettings(t *testing.T) {
|
||||||
@@ -115,6 +118,107 @@ func TestHandlePostApiSettings(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestHandlePostApiCloudSetting(t *testing.T) {
|
||||||
|
tmpHome := t.TempDir()
|
||||||
|
t.Setenv("HOME", tmpHome)
|
||||||
|
t.Setenv("OLLAMA_NO_CLOUD", "")
|
||||||
|
|
||||||
|
testStore := &store.Store{
|
||||||
|
DBPath: filepath.Join(t.TempDir(), "db.sqlite"),
|
||||||
|
}
|
||||||
|
defer testStore.Close()
|
||||||
|
|
||||||
|
restartCount := 0
|
||||||
|
server := &Server{
|
||||||
|
Store: testStore,
|
||||||
|
Restart: func() {
|
||||||
|
restartCount++
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range []struct {
|
||||||
|
name string
|
||||||
|
body string
|
||||||
|
wantEnabled bool
|
||||||
|
}{
|
||||||
|
{name: "disable cloud", body: `{"enabled": false}`, wantEnabled: false},
|
||||||
|
{name: "enable cloud", body: `{"enabled": true}`, wantEnabled: true},
|
||||||
|
} {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
req := httptest.NewRequest("POST", "/api/v1/cloud", bytes.NewBufferString(tc.body))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
|
||||||
|
if err := server.cloudSetting(rr, req); err != nil {
|
||||||
|
t.Fatalf("cloudSetting() error = %v", err)
|
||||||
|
}
|
||||||
|
if rr.Code != http.StatusOK {
|
||||||
|
t.Fatalf("cloudSetting() status = %d, want %d", rr.Code, http.StatusOK)
|
||||||
|
}
|
||||||
|
|
||||||
|
var got map[string]any
|
||||||
|
if err := json.Unmarshal(rr.Body.Bytes(), &got); err != nil {
|
||||||
|
t.Fatalf("cloudSetting() invalid response JSON: %v", err)
|
||||||
|
}
|
||||||
|
if got["disabled"] != !tc.wantEnabled {
|
||||||
|
t.Fatalf("response disabled = %v, want %v", got["disabled"], !tc.wantEnabled)
|
||||||
|
}
|
||||||
|
|
||||||
|
disabled, err := testStore.CloudDisabled()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("CloudDisabled() error = %v", err)
|
||||||
|
}
|
||||||
|
if gotEnabled := !disabled; gotEnabled != tc.wantEnabled {
|
||||||
|
t.Fatalf("cloud enabled = %v, want %v", gotEnabled, tc.wantEnabled)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
if restartCount != 2 {
|
||||||
|
t.Fatalf("Restart called %d times, want 2", restartCount)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandleGetApiCloudSetting(t *testing.T) {
|
||||||
|
tmpHome := t.TempDir()
|
||||||
|
t.Setenv("HOME", tmpHome)
|
||||||
|
t.Setenv("OLLAMA_NO_CLOUD", "")
|
||||||
|
|
||||||
|
testStore := &store.Store{
|
||||||
|
DBPath: filepath.Join(t.TempDir(), "db.sqlite"),
|
||||||
|
}
|
||||||
|
defer testStore.Close()
|
||||||
|
|
||||||
|
if err := testStore.SetCloudEnabled(false); err != nil {
|
||||||
|
t.Fatalf("SetCloudEnabled(false) error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
server := &Server{
|
||||||
|
Store: testStore,
|
||||||
|
Restart: func() {},
|
||||||
|
}
|
||||||
|
|
||||||
|
req := httptest.NewRequest("GET", "/api/v1/cloud", nil)
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
if err := server.getCloudSetting(rr, req); err != nil {
|
||||||
|
t.Fatalf("getCloudSetting() error = %v", err)
|
||||||
|
}
|
||||||
|
if rr.Code != http.StatusOK {
|
||||||
|
t.Fatalf("getCloudSetting() status = %d, want %d", rr.Code, http.StatusOK)
|
||||||
|
}
|
||||||
|
|
||||||
|
var got map[string]any
|
||||||
|
if err := json.Unmarshal(rr.Body.Bytes(), &got); err != nil {
|
||||||
|
t.Fatalf("getCloudSetting() invalid response JSON: %v", err)
|
||||||
|
}
|
||||||
|
if got["disabled"] != true {
|
||||||
|
t.Fatalf("response disabled = %v, want true", got["disabled"])
|
||||||
|
}
|
||||||
|
if got["source"] != "config" {
|
||||||
|
t.Fatalf("response source = %v, want config", got["source"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestAuthenticationMiddleware(t *testing.T) {
|
func TestAuthenticationMiddleware(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
@@ -421,3 +525,290 @@ func TestUserAgentTransport(t *testing.T) {
|
|||||||
|
|
||||||
t.Logf("User-Agent transport successfully set: %s", receivedUA)
|
t.Logf("User-Agent transport successfully set: %s", receivedUA)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestSupportsBrowserTools(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
model string
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
{"gpt-oss", true},
|
||||||
|
{"gpt-oss-latest", true},
|
||||||
|
{"GPT-OSS", true},
|
||||||
|
{"Gpt-Oss-v2", true},
|
||||||
|
{"qwen3", false},
|
||||||
|
{"deepseek-v3", false},
|
||||||
|
{"llama3.3", false},
|
||||||
|
{"", false},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.model, func(t *testing.T) {
|
||||||
|
if got := supportsBrowserTools(tt.model); got != tt.want {
|
||||||
|
t.Errorf("supportsBrowserTools(%q) = %v, want %v", tt.model, got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWebSearchToolRegistration(t *testing.T) {
|
||||||
|
// Validates that the capability-gating logic in chat() correctly
|
||||||
|
// decides which tools to register based on model capabilities and
|
||||||
|
// the web search flag.
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
webSearchEnabled bool
|
||||||
|
hasToolsCap bool
|
||||||
|
model string
|
||||||
|
wantBrowser bool // expects browser tools (gpt-oss)
|
||||||
|
wantWebSearch bool // expects basic web search/fetch tools
|
||||||
|
wantNone bool // expects no tools registered
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "web search enabled with tools capability - browser model",
|
||||||
|
webSearchEnabled: true,
|
||||||
|
hasToolsCap: true,
|
||||||
|
model: "gpt-oss-latest",
|
||||||
|
wantBrowser: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "web search enabled with tools capability - non-browser model",
|
||||||
|
webSearchEnabled: true,
|
||||||
|
hasToolsCap: true,
|
||||||
|
model: "qwen3",
|
||||||
|
wantWebSearch: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "web search enabled without tools capability",
|
||||||
|
webSearchEnabled: true,
|
||||||
|
hasToolsCap: false,
|
||||||
|
model: "llama3.3",
|
||||||
|
wantNone: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "web search disabled with tools capability",
|
||||||
|
webSearchEnabled: false,
|
||||||
|
hasToolsCap: true,
|
||||||
|
model: "qwen3",
|
||||||
|
wantNone: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "web search disabled without tools capability",
|
||||||
|
webSearchEnabled: false,
|
||||||
|
hasToolsCap: false,
|
||||||
|
model: "llama3.3",
|
||||||
|
wantNone: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
// Replicate the decision logic from chat() handler
|
||||||
|
gotBrowser := false
|
||||||
|
gotWebSearch := false
|
||||||
|
|
||||||
|
if tt.webSearchEnabled && tt.hasToolsCap {
|
||||||
|
if supportsBrowserTools(tt.model) {
|
||||||
|
gotBrowser = true
|
||||||
|
} else {
|
||||||
|
gotWebSearch = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if tt.wantBrowser && !gotBrowser {
|
||||||
|
t.Error("expected browser tools to be registered")
|
||||||
|
}
|
||||||
|
if tt.wantWebSearch && !gotWebSearch {
|
||||||
|
t.Error("expected web search tools to be registered")
|
||||||
|
}
|
||||||
|
if tt.wantNone && (gotBrowser || gotWebSearch) {
|
||||||
|
t.Error("expected no tools to be registered")
|
||||||
|
}
|
||||||
|
if !tt.wantBrowser && gotBrowser {
|
||||||
|
t.Error("unexpected browser tools registered")
|
||||||
|
}
|
||||||
|
if !tt.wantWebSearch && gotWebSearch {
|
||||||
|
t.Error("unexpected web search tools registered")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSettingsToggleAutoUpdateOff_CancelsDownload(t *testing.T) {
|
||||||
|
testStore := &store.Store{
|
||||||
|
DBPath: filepath.Join(t.TempDir(), "db.sqlite"),
|
||||||
|
}
|
||||||
|
defer testStore.Close()
|
||||||
|
|
||||||
|
// Start with auto-update enabled
|
||||||
|
settings, err := testStore.Settings()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
settings.AutoUpdateEnabled = true
|
||||||
|
if err := testStore.SetSettings(settings); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
upd := &updater.Updater{Store: &store.Store{
|
||||||
|
DBPath: filepath.Join(t.TempDir(), "db2.sqlite"),
|
||||||
|
}}
|
||||||
|
defer upd.Store.Close()
|
||||||
|
|
||||||
|
// We can't easily mock CancelOngoingDownload, but we can verify
|
||||||
|
// the full settings handler flow works without error
|
||||||
|
server := &Server{
|
||||||
|
Store: testStore,
|
||||||
|
Restart: func() {},
|
||||||
|
Updater: upd,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Disable auto-update via settings API
|
||||||
|
settings.AutoUpdateEnabled = false
|
||||||
|
body, err := json.Marshal(settings)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
req := httptest.NewRequest("POST", "/api/v1/settings", bytes.NewReader(body))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
|
||||||
|
if err := server.settings(rr, req); err != nil {
|
||||||
|
t.Fatalf("settings() error = %v", err)
|
||||||
|
}
|
||||||
|
if rr.Code != http.StatusOK {
|
||||||
|
t.Fatalf("settings() status = %d, want %d", rr.Code, http.StatusOK)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify settings were saved with auto-update disabled
|
||||||
|
saved, err := testStore.Settings()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if saved.AutoUpdateEnabled {
|
||||||
|
t.Fatal("expected AutoUpdateEnabled to be false after toggle off")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSettingsToggleAutoUpdateOn_WithPendingUpdate_ShowsNotification(t *testing.T) {
|
||||||
|
testStore := &store.Store{
|
||||||
|
DBPath: filepath.Join(t.TempDir(), "db.sqlite"),
|
||||||
|
}
|
||||||
|
defer testStore.Close()
|
||||||
|
|
||||||
|
// Start with auto-update disabled
|
||||||
|
settings, err := testStore.Settings()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
settings.AutoUpdateEnabled = false
|
||||||
|
if err := testStore.SetSettings(settings); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Simulate that an update was previously downloaded
|
||||||
|
oldVal := updater.UpdateDownloaded
|
||||||
|
updater.UpdateDownloaded = true
|
||||||
|
defer func() { updater.UpdateDownloaded = oldVal }()
|
||||||
|
|
||||||
|
var notificationCalled atomic.Bool
|
||||||
|
server := &Server{
|
||||||
|
Store: testStore,
|
||||||
|
Restart: func() {},
|
||||||
|
UpdateAvailableFunc: func() {
|
||||||
|
notificationCalled.Store(true)
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Re-enable auto-update via settings API
|
||||||
|
settings.AutoUpdateEnabled = true
|
||||||
|
body, err := json.Marshal(settings)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
req := httptest.NewRequest("POST", "/api/v1/settings", bytes.NewReader(body))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
|
||||||
|
if err := server.settings(rr, req); err != nil {
|
||||||
|
t.Fatalf("settings() error = %v", err)
|
||||||
|
}
|
||||||
|
if rr.Code != http.StatusOK {
|
||||||
|
t.Fatalf("settings() status = %d, want %d", rr.Code, http.StatusOK)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !notificationCalled.Load() {
|
||||||
|
t.Fatal("expected UpdateAvailableFunc to be called when re-enabling with a downloaded update")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSettingsToggleAutoUpdateOn_NoPendingUpdate_TriggersCheck(t *testing.T) {
|
||||||
|
testStore := &store.Store{
|
||||||
|
DBPath: filepath.Join(t.TempDir(), "db.sqlite"),
|
||||||
|
}
|
||||||
|
defer testStore.Close()
|
||||||
|
|
||||||
|
// Start with auto-update disabled
|
||||||
|
settings, err := testStore.Settings()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
settings.AutoUpdateEnabled = false
|
||||||
|
if err := testStore.SetSettings(settings); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure no pending update - clear both the downloaded flag and the stage dir
|
||||||
|
oldVal := updater.UpdateDownloaded
|
||||||
|
updater.UpdateDownloaded = false
|
||||||
|
defer func() { updater.UpdateDownloaded = oldVal }()
|
||||||
|
|
||||||
|
oldStageDir := updater.UpdateStageDir
|
||||||
|
updater.UpdateStageDir = t.TempDir() // empty dir means IsUpdatePending() returns false
|
||||||
|
defer func() { updater.UpdateStageDir = oldStageDir }()
|
||||||
|
|
||||||
|
upd := &updater.Updater{Store: &store.Store{
|
||||||
|
DBPath: filepath.Join(t.TempDir(), "db2.sqlite"),
|
||||||
|
}}
|
||||||
|
defer upd.Store.Close()
|
||||||
|
|
||||||
|
// Initialize the checkNow channel by starting (and immediately stopping) the checker
|
||||||
|
// so TriggerImmediateCheck doesn't panic on nil channel
|
||||||
|
ctx, cancel := context.WithCancel(t.Context())
|
||||||
|
upd.StartBackgroundUpdaterChecker(ctx, func(string) error { return nil })
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
var notificationCalled atomic.Bool
|
||||||
|
server := &Server{
|
||||||
|
Store: testStore,
|
||||||
|
Restart: func() {},
|
||||||
|
Updater: upd,
|
||||||
|
UpdateAvailableFunc: func() {
|
||||||
|
notificationCalled.Store(true)
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Re-enable auto-update via settings API
|
||||||
|
settings.AutoUpdateEnabled = true
|
||||||
|
body, err := json.Marshal(settings)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
req := httptest.NewRequest("POST", "/api/v1/settings", bytes.NewReader(body))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
|
||||||
|
if err := server.settings(rr, req); err != nil {
|
||||||
|
t.Fatalf("settings() error = %v", err)
|
||||||
|
}
|
||||||
|
if rr.Code != http.StatusOK {
|
||||||
|
t.Fatalf("settings() status = %d, want %d", rr.Code, http.StatusOK)
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateAvailableFunc should NOT be called since there's no pending update
|
||||||
|
if notificationCalled.Load() {
|
||||||
|
t.Fatal("UpdateAvailableFunc should not be called when there is no pending update")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ import (
|
|||||||
"runtime"
|
"runtime"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/ollama/ollama/app/store"
|
"github.com/ollama/ollama/app/store"
|
||||||
@@ -58,7 +59,8 @@ func (u *Updater) checkForUpdate(ctx context.Context) (bool, UpdateResponse) {
|
|||||||
query := requestURL.Query()
|
query := requestURL.Query()
|
||||||
query.Add("os", runtime.GOOS)
|
query.Add("os", runtime.GOOS)
|
||||||
query.Add("arch", runtime.GOARCH)
|
query.Add("arch", runtime.GOARCH)
|
||||||
query.Add("version", version.Version)
|
currentVersion := version.Version
|
||||||
|
query.Add("version", currentVersion)
|
||||||
query.Add("ts", strconv.FormatInt(time.Now().Unix(), 10))
|
query.Add("ts", strconv.FormatInt(time.Now().Unix(), 10))
|
||||||
|
|
||||||
// The original macOS app used to use the device ID
|
// The original macOS app used to use the device ID
|
||||||
@@ -131,15 +133,27 @@ func (u *Updater) checkForUpdate(ctx context.Context) (bool, UpdateResponse) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (u *Updater) DownloadNewRelease(ctx context.Context, updateResp UpdateResponse) error {
|
func (u *Updater) DownloadNewRelease(ctx context.Context, updateResp UpdateResponse) error {
|
||||||
|
// Create a cancellable context for this download
|
||||||
|
downloadCtx, cancel := context.WithCancel(ctx)
|
||||||
|
u.cancelDownloadLock.Lock()
|
||||||
|
u.cancelDownload = cancel
|
||||||
|
u.cancelDownloadLock.Unlock()
|
||||||
|
defer func() {
|
||||||
|
u.cancelDownloadLock.Lock()
|
||||||
|
u.cancelDownload = nil
|
||||||
|
u.cancelDownloadLock.Unlock()
|
||||||
|
cancel()
|
||||||
|
}()
|
||||||
|
|
||||||
// Do a head first to check etag info
|
// Do a head first to check etag info
|
||||||
req, err := http.NewRequestWithContext(ctx, http.MethodHead, updateResp.UpdateURL, nil)
|
req, err := http.NewRequestWithContext(downloadCtx, http.MethodHead, updateResp.UpdateURL, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// In case of slow downloads, continue the update check in the background
|
// In case of slow downloads, continue the update check in the background
|
||||||
bgctx, cancel := context.WithCancel(ctx)
|
bgctx, bgcancel := context.WithCancel(downloadCtx)
|
||||||
defer cancel()
|
defer bgcancel()
|
||||||
go func() {
|
go func() {
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
@@ -176,6 +190,7 @@ func (u *Updater) DownloadNewRelease(ctx context.Context, updateResp UpdateRespo
|
|||||||
_, err = os.Stat(stageFilename)
|
_, err = os.Stat(stageFilename)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
slog.Info("update already downloaded", "bundle", stageFilename)
|
slog.Info("update already downloaded", "bundle", stageFilename)
|
||||||
|
UpdateDownloaded = true
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -245,32 +260,84 @@ func cleanupOldDownloads(stageDir string) {
|
|||||||
|
|
||||||
type Updater struct {
|
type Updater struct {
|
||||||
Store *store.Store
|
Store *store.Store
|
||||||
|
cancelDownload context.CancelFunc
|
||||||
|
cancelDownloadLock sync.Mutex
|
||||||
|
checkNow chan struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// CancelOngoingDownload cancels any currently running download
|
||||||
|
func (u *Updater) CancelOngoingDownload() {
|
||||||
|
u.cancelDownloadLock.Lock()
|
||||||
|
defer u.cancelDownloadLock.Unlock()
|
||||||
|
if u.cancelDownload != nil {
|
||||||
|
slog.Info("cancelling ongoing update download")
|
||||||
|
u.cancelDownload()
|
||||||
|
u.cancelDownload = nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TriggerImmediateCheck signals the background checker to check for updates immediately
|
||||||
|
func (u *Updater) TriggerImmediateCheck() {
|
||||||
|
if u.checkNow != nil {
|
||||||
|
select {
|
||||||
|
case u.checkNow <- struct{}{}:
|
||||||
|
default:
|
||||||
|
// Check already pending, no need to queue another
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *Updater) StartBackgroundUpdaterChecker(ctx context.Context, cb func(string) error) {
|
func (u *Updater) StartBackgroundUpdaterChecker(ctx context.Context, cb func(string) error) {
|
||||||
|
u.checkNow = make(chan struct{}, 1)
|
||||||
|
u.checkNow <- struct{}{} // Trigger first check after initial delay
|
||||||
go func() {
|
go func() {
|
||||||
// Don't blast an update message immediately after startup
|
// Don't blast an update message immediately after startup
|
||||||
time.Sleep(UpdateCheckInitialDelay)
|
time.Sleep(UpdateCheckInitialDelay)
|
||||||
slog.Info("beginning update checker", "interval", UpdateCheckInterval)
|
slog.Info("beginning update checker", "interval", UpdateCheckInterval)
|
||||||
|
ticker := time.NewTicker(UpdateCheckInterval)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
for {
|
for {
|
||||||
available, resp := u.checkForUpdate(ctx)
|
|
||||||
if available {
|
|
||||||
err := u.DownloadNewRelease(ctx, resp)
|
|
||||||
if err != nil {
|
|
||||||
slog.Error(fmt.Sprintf("failed to download new release: %s", err))
|
|
||||||
} else {
|
|
||||||
err = cb(resp.UpdateVersion)
|
|
||||||
if err != nil {
|
|
||||||
slog.Warn(fmt.Sprintf("failed to register update available with tray: %s", err))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
slog.Debug("stopping background update checker")
|
slog.Debug("stopping background update checker")
|
||||||
return
|
return
|
||||||
default:
|
case <-u.checkNow:
|
||||||
time.Sleep(UpdateCheckInterval)
|
// Immediate check triggered
|
||||||
|
case <-ticker.C:
|
||||||
|
// Regular interval check
|
||||||
|
}
|
||||||
|
|
||||||
|
// Always check for updates
|
||||||
|
available, resp := u.checkForUpdate(ctx)
|
||||||
|
if !available {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update is available - check if auto-update is enabled for downloading
|
||||||
|
settings, err := u.Store.Settings()
|
||||||
|
if err != nil {
|
||||||
|
slog.Error("failed to load settings", "error", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if !settings.AutoUpdateEnabled {
|
||||||
|
// Auto-update disabled - don't download, just log
|
||||||
|
slog.Debug("update available but auto-update disabled", "version", resp.UpdateVersion)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Auto-update is enabled - download
|
||||||
|
err = u.DownloadNewRelease(ctx, resp)
|
||||||
|
if err != nil {
|
||||||
|
slog.Error("failed to download new release", "error", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Download successful - show tray notification
|
||||||
|
err = cb(resp.UpdateVersion)
|
||||||
|
if err != nil {
|
||||||
|
slog.Warn("failed to register update available with tray", "error", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|||||||
@@ -11,6 +11,8 @@ import (
|
|||||||
"log/slog"
|
"log/slog"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
|
"path/filepath"
|
||||||
|
"sync/atomic"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -33,7 +35,7 @@ func TestIsNewReleaseAvailable(t *testing.T) {
|
|||||||
defer server.Close()
|
defer server.Close()
|
||||||
slog.Debug("server", "url", server.URL)
|
slog.Debug("server", "url", server.URL)
|
||||||
|
|
||||||
updater := &Updater{Store: &store.Store{}}
|
updater := &Updater{Store: &store.Store{DBPath: filepath.Join(t.TempDir(), "test.db")}}
|
||||||
defer updater.Store.Close() // Ensure database is closed
|
defer updater.Store.Close() // Ensure database is closed
|
||||||
UpdateCheckURLBase = server.URL + "/update.json"
|
UpdateCheckURLBase = server.URL + "/update.json"
|
||||||
updatePresent, resp := updater.checkForUpdate(t.Context())
|
updatePresent, resp := updater.checkForUpdate(t.Context())
|
||||||
@@ -84,8 +86,18 @@ func TestBackgoundChecker(t *testing.T) {
|
|||||||
defer server.Close()
|
defer server.Close()
|
||||||
UpdateCheckURLBase = server.URL + "/update.json"
|
UpdateCheckURLBase = server.URL + "/update.json"
|
||||||
|
|
||||||
updater := &Updater{Store: &store.Store{}}
|
updater := &Updater{Store: &store.Store{DBPath: filepath.Join(t.TempDir(), "test.db")}}
|
||||||
defer updater.Store.Close() // Ensure database is closed
|
defer updater.Store.Close()
|
||||||
|
|
||||||
|
settings, err := updater.Store.Settings()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
settings.AutoUpdateEnabled = true
|
||||||
|
if err := updater.Store.SetSettings(settings); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
updater.StartBackgroundUpdaterChecker(ctx, cb)
|
updater.StartBackgroundUpdaterChecker(ctx, cb)
|
||||||
select {
|
select {
|
||||||
case <-stallTimer.C:
|
case <-stallTimer.C:
|
||||||
@@ -99,3 +111,267 @@ func TestBackgoundChecker(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestAutoUpdateDisabledSkipsDownload(t *testing.T) {
|
||||||
|
UpdateStageDir = t.TempDir()
|
||||||
|
var downloadAttempted atomic.Bool
|
||||||
|
done := make(chan struct{})
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(t.Context())
|
||||||
|
defer cancel()
|
||||||
|
UpdateCheckInitialDelay = 5 * time.Millisecond
|
||||||
|
UpdateCheckInterval = 5 * time.Millisecond
|
||||||
|
VerifyDownload = func() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var server *httptest.Server
|
||||||
|
server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.URL.Path == "/update.json" {
|
||||||
|
w.Write([]byte(
|
||||||
|
fmt.Sprintf(`{"version": "9.9.9", "url": "%s"}`,
|
||||||
|
server.URL+"/9.9.9/"+Installer)))
|
||||||
|
} else if r.URL.Path == "/9.9.9/"+Installer {
|
||||||
|
downloadAttempted.Store(true)
|
||||||
|
buf := &bytes.Buffer{}
|
||||||
|
zw := zip.NewWriter(buf)
|
||||||
|
zw.Close()
|
||||||
|
io.Copy(w, buf)
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
UpdateCheckURLBase = server.URL + "/update.json"
|
||||||
|
|
||||||
|
updater := &Updater{Store: &store.Store{DBPath: filepath.Join(t.TempDir(), "test.db")}}
|
||||||
|
defer updater.Store.Close()
|
||||||
|
|
||||||
|
// Ensure auto-update is disabled
|
||||||
|
settings, err := updater.Store.Settings()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
settings.AutoUpdateEnabled = false
|
||||||
|
if err := updater.Store.SetSettings(settings); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cb := func(ver string) error {
|
||||||
|
t.Fatal("callback should not be called when auto-update is disabled")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
updater.StartBackgroundUpdaterChecker(ctx, cb)
|
||||||
|
|
||||||
|
// Wait enough time for multiple check cycles
|
||||||
|
time.Sleep(50 * time.Millisecond)
|
||||||
|
close(done)
|
||||||
|
|
||||||
|
if downloadAttempted.Load() {
|
||||||
|
t.Fatal("download should not be attempted when auto-update is disabled")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAutoUpdateReenabledDownloadsUpdate(t *testing.T) {
|
||||||
|
UpdateStageDir = t.TempDir()
|
||||||
|
var downloadAttempted atomic.Bool
|
||||||
|
callbackCalled := make(chan struct{}, 1)
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(t.Context())
|
||||||
|
defer cancel()
|
||||||
|
UpdateCheckInitialDelay = 5 * time.Millisecond
|
||||||
|
UpdateCheckInterval = 5 * time.Millisecond
|
||||||
|
VerifyDownload = func() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var server *httptest.Server
|
||||||
|
server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.URL.Path == "/update.json" {
|
||||||
|
w.Write([]byte(
|
||||||
|
fmt.Sprintf(`{"version": "9.9.9", "url": "%s"}`,
|
||||||
|
server.URL+"/9.9.9/"+Installer)))
|
||||||
|
} else if r.URL.Path == "/9.9.9/"+Installer {
|
||||||
|
downloadAttempted.Store(true)
|
||||||
|
buf := &bytes.Buffer{}
|
||||||
|
zw := zip.NewWriter(buf)
|
||||||
|
zw.Close()
|
||||||
|
io.Copy(w, buf)
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
UpdateCheckURLBase = server.URL + "/update.json"
|
||||||
|
|
||||||
|
upd := &Updater{Store: &store.Store{DBPath: filepath.Join(t.TempDir(), "test.db")}}
|
||||||
|
defer upd.Store.Close()
|
||||||
|
|
||||||
|
// Start with auto-update disabled
|
||||||
|
settings, err := upd.Store.Settings()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
settings.AutoUpdateEnabled = false
|
||||||
|
if err := upd.Store.SetSettings(settings); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cb := func(ver string) error {
|
||||||
|
select {
|
||||||
|
case callbackCalled <- struct{}{}:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
upd.StartBackgroundUpdaterChecker(ctx, cb)
|
||||||
|
|
||||||
|
// Wait for a few cycles with auto-update disabled - no download should happen
|
||||||
|
time.Sleep(50 * time.Millisecond)
|
||||||
|
if downloadAttempted.Load() {
|
||||||
|
t.Fatal("download should not happen while auto-update is disabled")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Re-enable auto-update
|
||||||
|
settings.AutoUpdateEnabled = true
|
||||||
|
if err := upd.Store.SetSettings(settings); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for the checker to pick it up and download
|
||||||
|
select {
|
||||||
|
case <-callbackCalled:
|
||||||
|
// Success: download happened and callback was called after re-enabling
|
||||||
|
if !downloadAttempted.Load() {
|
||||||
|
t.Fatal("expected download to be attempted after re-enabling")
|
||||||
|
}
|
||||||
|
case <-time.After(5 * time.Second):
|
||||||
|
t.Fatal("expected download and callback after re-enabling auto-update")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCancelOngoingDownload(t *testing.T) {
|
||||||
|
UpdateStageDir = t.TempDir()
|
||||||
|
downloadStarted := make(chan struct{})
|
||||||
|
downloadCancelled := make(chan struct{})
|
||||||
|
|
||||||
|
ctx := t.Context()
|
||||||
|
VerifyDownload = func() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var server *httptest.Server
|
||||||
|
server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.URL.Path == "/update.json" {
|
||||||
|
w.Write([]byte(
|
||||||
|
fmt.Sprintf(`{"version": "9.9.9", "url": "%s"}`,
|
||||||
|
server.URL+"/9.9.9/"+Installer)))
|
||||||
|
} else if r.URL.Path == "/9.9.9/"+Installer {
|
||||||
|
if r.Method == http.MethodHead {
|
||||||
|
w.Header().Set("Content-Length", "1000000")
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// Signal that download has started
|
||||||
|
close(downloadStarted)
|
||||||
|
// Wait for cancellation or timeout
|
||||||
|
select {
|
||||||
|
case <-r.Context().Done():
|
||||||
|
close(downloadCancelled)
|
||||||
|
return
|
||||||
|
case <-time.After(5 * time.Second):
|
||||||
|
t.Error("download was not cancelled in time")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
UpdateCheckURLBase = server.URL + "/update.json"
|
||||||
|
|
||||||
|
updater := &Updater{Store: &store.Store{DBPath: filepath.Join(t.TempDir(), "test.db")}}
|
||||||
|
defer updater.Store.Close()
|
||||||
|
|
||||||
|
_, resp := updater.checkForUpdate(ctx)
|
||||||
|
|
||||||
|
// Start download in goroutine
|
||||||
|
go func() {
|
||||||
|
_ = updater.DownloadNewRelease(ctx, resp)
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Wait for download to start
|
||||||
|
select {
|
||||||
|
case <-downloadStarted:
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Fatal("download did not start in time")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Cancel the download
|
||||||
|
updater.CancelOngoingDownload()
|
||||||
|
|
||||||
|
// Verify cancellation was received
|
||||||
|
select {
|
||||||
|
case <-downloadCancelled:
|
||||||
|
// Success
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Fatal("download cancellation was not received by server")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTriggerImmediateCheck(t *testing.T) {
|
||||||
|
UpdateStageDir = t.TempDir()
|
||||||
|
checkCount := atomic.Int32{}
|
||||||
|
checkDone := make(chan struct{}, 10)
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(t.Context())
|
||||||
|
defer cancel()
|
||||||
|
// Set a very long interval so only TriggerImmediateCheck causes checks
|
||||||
|
UpdateCheckInitialDelay = 1 * time.Millisecond
|
||||||
|
UpdateCheckInterval = 1 * time.Hour
|
||||||
|
VerifyDownload = func() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.URL.Path == "/update.json" {
|
||||||
|
checkCount.Add(1)
|
||||||
|
select {
|
||||||
|
case checkDone <- struct{}{}:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
// Return no update available
|
||||||
|
w.WriteHeader(http.StatusNoContent)
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
UpdateCheckURLBase = server.URL + "/update.json"
|
||||||
|
|
||||||
|
updater := &Updater{Store: &store.Store{DBPath: filepath.Join(t.TempDir(), "test.db")}}
|
||||||
|
defer updater.Store.Close()
|
||||||
|
|
||||||
|
cb := func(ver string) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
updater.StartBackgroundUpdaterChecker(ctx, cb)
|
||||||
|
|
||||||
|
// Wait for the initial check that fires after the initial delay
|
||||||
|
select {
|
||||||
|
case <-checkDone:
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Fatal("initial check did not happen")
|
||||||
|
}
|
||||||
|
|
||||||
|
initialCount := checkCount.Load()
|
||||||
|
|
||||||
|
// Trigger immediate check
|
||||||
|
updater.TriggerImmediateCheck()
|
||||||
|
|
||||||
|
// Wait for the triggered check
|
||||||
|
select {
|
||||||
|
case <-checkDone:
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Fatal("triggered check did not happen")
|
||||||
|
}
|
||||||
|
|
||||||
|
finalCount := checkCount.Load()
|
||||||
|
if finalCount <= initialCount {
|
||||||
|
t.Fatalf("TriggerImmediateCheck did not cause additional check: initial=%d, final=%d", initialCount, finalCount)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -369,25 +369,6 @@ func (t *winTray) addSeparatorMenuItem(menuItemId, parentId uint32) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// func (t *winTray) hideMenuItem(menuItemId, parentId uint32) error {
|
|
||||||
// const ERROR_SUCCESS syscall.Errno = 0
|
|
||||||
|
|
||||||
// t.muMenus.RLock()
|
|
||||||
// menu := uintptr(t.menus[parentId])
|
|
||||||
// t.muMenus.RUnlock()
|
|
||||||
// res, _, err := pRemoveMenu.Call(
|
|
||||||
// menu,
|
|
||||||
// uintptr(menuItemId),
|
|
||||||
// MF_BYCOMMAND,
|
|
||||||
// )
|
|
||||||
// if res == 0 && err.(syscall.Errno) != ERROR_SUCCESS {
|
|
||||||
// return err
|
|
||||||
// }
|
|
||||||
// t.delFromVisibleItems(parentId, menuItemId)
|
|
||||||
|
|
||||||
// return nil
|
|
||||||
// }
|
|
||||||
|
|
||||||
func (t *winTray) showMenu() error {
|
func (t *winTray) showMenu() error {
|
||||||
p := point{}
|
p := point{}
|
||||||
boolRet, _, err := pGetCursorPos.Call(uintptr(unsafe.Pointer(&p)))
|
boolRet, _, err := pGetCursorPos.Call(uintptr(unsafe.Pointer(&p)))
|
||||||
|
|||||||
@@ -51,7 +51,6 @@ const (
|
|||||||
IMAGE_ICON = 1 // Loads an icon
|
IMAGE_ICON = 1 // Loads an icon
|
||||||
LR_DEFAULTSIZE = 0x00000040 // Loads default-size icon for windows(SM_CXICON x SM_CYICON) if cx, cy are set to zero
|
LR_DEFAULTSIZE = 0x00000040 // Loads default-size icon for windows(SM_CXICON x SM_CYICON) if cx, cy are set to zero
|
||||||
LR_LOADFROMFILE = 0x00000010 // Loads the stand-alone image from the file
|
LR_LOADFROMFILE = 0x00000010 // Loads the stand-alone image from the file
|
||||||
MF_BYCOMMAND = 0x00000000
|
|
||||||
MFS_DISABLED = 0x00000003
|
MFS_DISABLED = 0x00000003
|
||||||
MFT_SEPARATOR = 0x00000800
|
MFT_SEPARATOR = 0x00000800
|
||||||
MFT_STRING = 0x00000000
|
MFT_STRING = 0x00000000
|
||||||
|
|||||||
13
cmd/background_unix.go
Normal file
13
cmd/background_unix.go
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
//go:build !windows
|
||||||
|
|
||||||
|
package cmd
|
||||||
|
|
||||||
|
import "syscall"
|
||||||
|
|
||||||
|
// backgroundServerSysProcAttr returns SysProcAttr for running the server in the background on Unix.
|
||||||
|
// Setpgid prevents the server from being killed when the parent process exits.
|
||||||
|
func backgroundServerSysProcAttr() *syscall.SysProcAttr {
|
||||||
|
return &syscall.SysProcAttr{
|
||||||
|
Setpgid: true,
|
||||||
|
}
|
||||||
|
}
|
||||||
12
cmd/background_windows.go
Normal file
12
cmd/background_windows.go
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
package cmd
|
||||||
|
|
||||||
|
import "syscall"
|
||||||
|
|
||||||
|
// backgroundServerSysProcAttr returns SysProcAttr for running the server in the background on Windows.
|
||||||
|
// CREATE_NO_WINDOW (0x08000000) prevents a console window from appearing.
|
||||||
|
func backgroundServerSysProcAttr() *syscall.SysProcAttr {
|
||||||
|
return &syscall.SysProcAttr{
|
||||||
|
CreationFlags: 0x08000000,
|
||||||
|
HideWindow: true,
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,27 +1,31 @@
|
|||||||
Ollama Benchmark Tool
|
Ollama Benchmark Tool
|
||||||
---------------------
|
---------------------
|
||||||
|
|
||||||
A Go-based command-line tool for benchmarking Ollama models with configurable parameters and multiple output formats.
|
A Go-based command-line tool for benchmarking Ollama models with configurable parameters, warmup phases, TTFT tracking, VRAM monitoring, and benchstat/CSV output.
|
||||||
|
|
||||||
## Features
|
## Features
|
||||||
|
|
||||||
* Benchmark multiple models in a single run
|
* Benchmark multiple models in a single run
|
||||||
* Support for both text and image prompts
|
* Support for both text and image prompts
|
||||||
* Configurable generation parameters (temperature, max tokens, seed, etc.)
|
* Configurable generation parameters (temperature, max tokens, seed, etc.)
|
||||||
* Supports benchstat and CSV output formats
|
* Warmup phase before timed epochs to stabilize measurements
|
||||||
* Detailed performance metrics (prefill, generate, load, total durations)
|
* Time-to-first-token (TTFT) tracking per epoch
|
||||||
|
* Model metadata display (parameter size, quantization level, family)
|
||||||
|
* VRAM and CPU memory usage tracking via running process info
|
||||||
|
* Controlled prompt token length for reproducible benchmarks
|
||||||
|
* Benchstat and CSV output formats
|
||||||
|
|
||||||
## Building from Source
|
## Building from Source
|
||||||
|
|
||||||
```
|
```
|
||||||
go build -o ollama-bench bench.go
|
go build -o ollama-bench ./cmd/bench
|
||||||
./ollama-bench -model gpt-oss:20b -epochs 6 -format csv
|
./ollama-bench -model gemma3 -epochs 6 -format csv
|
||||||
```
|
```
|
||||||
|
|
||||||
Using Go Run (without building)
|
Using Go Run (without building)
|
||||||
|
|
||||||
```
|
```
|
||||||
go run bench.go -model gpt-oss:20b -epochs 3
|
go run ./cmd/bench -model gemma3 -epochs 3
|
||||||
```
|
```
|
||||||
|
|
||||||
## Usage
|
## Usage
|
||||||
@@ -45,10 +49,16 @@ benchstat -col /name gemma.bench
|
|||||||
./ollama-bench -model qwen3-vl -image photo.jpg -epochs 6 -max-tokens 100 -p "Describe this image"
|
./ollama-bench -model qwen3-vl -image photo.jpg -epochs 6 -max-tokens 100 -p "Describe this image"
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### Controlled Prompt Length
|
||||||
|
|
||||||
|
```
|
||||||
|
./ollama-bench -model gemma3 -epochs 6 -prompt-tokens 512
|
||||||
|
```
|
||||||
|
|
||||||
### Advanced Example
|
### Advanced Example
|
||||||
|
|
||||||
```
|
```
|
||||||
./ollama-bench -model llama3 -epochs 10 -temperature 0.7 -max-tokens 500 -seed 42 -format csv -output results.csv
|
./ollama-bench -model llama3 -epochs 10 -temperature 0.7 -max-tokens 500 -seed 42 -warmup 2 -format csv -output results.csv
|
||||||
```
|
```
|
||||||
|
|
||||||
## Command Line Options
|
## Command Line Options
|
||||||
@@ -56,41 +66,48 @@ benchstat -col /name gemma.bench
|
|||||||
| Option | Description | Default |
|
| Option | Description | Default |
|
||||||
|----------|-------------|---------|
|
|----------|-------------|---------|
|
||||||
| -model | Comma-separated list of models to benchmark | (required) |
|
| -model | Comma-separated list of models to benchmark | (required) |
|
||||||
| -epochs | Number of iterations per model | 1 |
|
| -epochs | Number of iterations per model | 6 |
|
||||||
| -max-tokens | Maximum tokens for model response | 0 (unlimited) |
|
| -max-tokens | Maximum tokens for model response | 200 |
|
||||||
| -temperature | Temperature parameter | 0.0 |
|
| -temperature | Temperature parameter | 0.0 |
|
||||||
| -seed | Random seed | 0 (random) |
|
| -seed | Random seed | 0 (random) |
|
||||||
| -timeout | Timeout in seconds | 300 |
|
| -timeout | Timeout in seconds | 300 |
|
||||||
| -p | Prompt text | "Write a long story." |
|
| -p | Prompt text | (default story prompt) |
|
||||||
| -image | Image file to include in prompt | |
|
| -image | Image file to include in prompt | |
|
||||||
| -k | Keep-alive duration in seconds | 0 |
|
| -k | Keep-alive duration in seconds | 0 |
|
||||||
| -format | Output format (benchstat, csv) | benchstat |
|
| -format | Output format (benchstat, csv) | benchstat |
|
||||||
| -output | Output file for results | "" (stdout) |
|
| -output | Output file for results | "" (stdout) |
|
||||||
|
| -warmup | Number of warmup requests before timing | 1 |
|
||||||
|
| -prompt-tokens | Generate prompt targeting ~N tokens (0 = use -p) | 0 |
|
||||||
| -v | Verbose mode | false |
|
| -v | Verbose mode | false |
|
||||||
| -debug | Show debug information | false |
|
| -debug | Show debug information | false |
|
||||||
|
|
||||||
## Output Formats
|
## Output Formats
|
||||||
|
|
||||||
### Markdown Format
|
### Benchstat Format (default)
|
||||||
|
|
||||||
The default markdown format is suitable for copying and pasting into a GitHub issue and will look like:
|
Compatible with Go's benchstat tool for statistical analysis. Uses one value/unit pair per line, standard `ns/op` for timing metrics, and `ns/token` for throughput. Each epoch produces one set of lines -- benchstat aggregates across repeated runs to compute statistics.
|
||||||
```
|
|
||||||
Model | Step | Count | Duration | nsPerToken | tokensPerSec |
|
|
||||||
|-------|------|-------|----------|------------|--------------|
|
|
||||||
| gpt-oss:20b | prefill | 124 | 30.006458ms | 241987.56 | 4132.44 |
|
|
||||||
| gpt-oss:20b | generate | 200 | 2.646843954s | 13234219.77 | 75.56 |
|
|
||||||
| gpt-oss:20b | load | 1 | 121.674208ms | - | - |
|
|
||||||
| gpt-oss:20b | total | 1 | 2.861047625s | - | - |
|
|
||||||
```
|
|
||||||
|
|
||||||
### Benchstat Format
|
|
||||||
|
|
||||||
Compatible with Go's benchstat tool for statistical analysis:
|
|
||||||
|
|
||||||
```
|
```
|
||||||
BenchmarkModel/name=gpt-oss:20b/step=prefill 128 78125.00 ns/token 12800.00 token/sec
|
# Model: gemma3 | Params: 4.3B | Quant: Q4_K_M | Family: gemma3 | Size: 4080218931 | VRAM: 4080218931
|
||||||
BenchmarkModel/name=gpt-oss:20b/step=generate 512 19531.25 ns/token 51200.00 token/sec
|
BenchmarkModel/name=gemma3/step=prefill 1 78125.00 ns/token 12800.00 token/sec
|
||||||
BenchmarkModel/name=gpt-oss:20b/step=load 1 1500000000 ns/request
|
BenchmarkModel/name=gemma3/step=generate 1 19531.25 ns/token 51200.00 token/sec
|
||||||
|
BenchmarkModel/name=gemma3/step=ttft 1 45123000 ns/op
|
||||||
|
BenchmarkModel/name=gemma3/step=load 1 1500000000 ns/op
|
||||||
|
BenchmarkModel/name=gemma3/step=total 1 2861047625 ns/op
|
||||||
|
```
|
||||||
|
|
||||||
|
Use with benchstat:
|
||||||
|
```
|
||||||
|
./ollama-bench -model gemma3 -epochs 6 > gemma3.bench
|
||||||
|
benchstat -col /step gemma3.bench
|
||||||
|
```
|
||||||
|
|
||||||
|
Compare two runs:
|
||||||
|
```
|
||||||
|
./ollama-bench -model gemma3 -epochs 6 > before.bench
|
||||||
|
# ... make changes ...
|
||||||
|
./ollama-bench -model gemma3 -epochs 6 > after.bench
|
||||||
|
benchstat before.bench after.bench
|
||||||
```
|
```
|
||||||
|
|
||||||
### CSV Format
|
### CSV Format
|
||||||
@@ -99,17 +116,28 @@ Machine-readable comma-separated values:
|
|||||||
|
|
||||||
```
|
```
|
||||||
NAME,STEP,COUNT,NS_PER_COUNT,TOKEN_PER_SEC
|
NAME,STEP,COUNT,NS_PER_COUNT,TOKEN_PER_SEC
|
||||||
gpt-oss:20b,prefill,128,78125.00,12800.00
|
# Model: gemma3 | Params: 4.3B | Quant: Q4_K_M | Family: gemma3 | Size: 4080218931 | VRAM: 4080218931
|
||||||
gpt-oss:20b,generate,512,19531.25,51200.00
|
gemma3,prefill,128,78125.00,12800.00
|
||||||
gpt-oss:20b,load,1,1500000000,0
|
gemma3,generate,512,19531.25,51200.00
|
||||||
|
gemma3,ttft,1,45123000,0
|
||||||
|
gemma3,load,1,1500000000,0
|
||||||
|
gemma3,total,1,2861047625,0
|
||||||
```
|
```
|
||||||
|
|
||||||
## Metrics Explained
|
## Metrics Explained
|
||||||
|
|
||||||
The tool reports four types of metrics for each model:
|
The tool reports the following metrics for each epoch:
|
||||||
|
|
||||||
* prefill: Time spent processing the prompt
|
* **prefill**: Time spent processing the prompt (ns/token)
|
||||||
* generate: Time spent generating the response
|
* **generate**: Time spent generating the response (ns/token)
|
||||||
* load: Model loading time (one-time cost)
|
* **ttft**: Time to first token -- latency from request start to first response content
|
||||||
* total: Total request duration
|
* **load**: Model loading time (one-time cost)
|
||||||
|
* **total**: Total request duration
|
||||||
|
|
||||||
|
Additionally, the model info comment line (displayed once per model before epochs) includes:
|
||||||
|
|
||||||
|
* **Params**: Model parameter count (e.g., 4.3B)
|
||||||
|
* **Quant**: Quantization level (e.g., Q4_K_M)
|
||||||
|
* **Family**: Model family (e.g., gemma3)
|
||||||
|
* **Size**: Total model memory in bytes
|
||||||
|
* **VRAM**: GPU memory used by the loaded model (when Size > VRAM, the difference is CPU spill)
|
||||||
|
|||||||
@@ -30,6 +30,8 @@ type flagOptions struct {
|
|||||||
outputFile *string
|
outputFile *string
|
||||||
debug *bool
|
debug *bool
|
||||||
verbose *bool
|
verbose *bool
|
||||||
|
warmup *int
|
||||||
|
promptTokens *int
|
||||||
}
|
}
|
||||||
|
|
||||||
type Metrics struct {
|
type Metrics struct {
|
||||||
@@ -39,48 +41,169 @@ type Metrics struct {
|
|||||||
Duration time.Duration
|
Duration time.Duration
|
||||||
}
|
}
|
||||||
|
|
||||||
var once sync.Once
|
type ModelInfo struct {
|
||||||
|
Name string
|
||||||
|
ParameterSize string
|
||||||
|
QuantizationLevel string
|
||||||
|
Family string
|
||||||
|
SizeBytes int64
|
||||||
|
VRAMBytes int64
|
||||||
|
}
|
||||||
|
|
||||||
const DefaultPrompt = `Please write a descriptive story about a llama named Alonso who grows up to be President of the Land of Llamas. Include details about Alonso's childhood, adolescent years, and how he grew up to be a political mover and shaker. Write the story with a sense of whimsy.`
|
const DefaultPrompt = `Please write a descriptive story about a llama named Alonso who grows up to be President of the Land of Llamas. Include details about Alonso's childhood, adolescent years, and how he grew up to be a political mover and shaker. Write the story with a sense of whimsy.`
|
||||||
|
|
||||||
|
// Word list for generating prompts targeting a specific token count.
|
||||||
|
var promptWordList = []string{
|
||||||
|
"the", "quick", "brown", "fox", "jumps", "over", "lazy", "dog",
|
||||||
|
"a", "bright", "sunny", "day", "in", "the", "meadow", "where",
|
||||||
|
"flowers", "bloom", "and", "birds", "sing", "their", "morning",
|
||||||
|
"songs", "while", "gentle", "breeze", "carries", "sweet", "scent",
|
||||||
|
"of", "pine", "trees", "across", "rolling", "hills", "toward",
|
||||||
|
"distant", "mountains", "covered", "with", "fresh", "snow",
|
||||||
|
"beneath", "clear", "blue", "sky", "children", "play", "near",
|
||||||
|
"old", "stone", "bridge", "that", "crosses", "winding", "river",
|
||||||
|
}
|
||||||
|
|
||||||
|
func generatePromptForTokenCount(targetTokens int, epoch int) string {
|
||||||
|
// ~1.3 tokens per word heuristic
|
||||||
|
targetWords := int(float64(targetTokens) / 1.3)
|
||||||
|
if targetWords < 1 {
|
||||||
|
targetWords = 1
|
||||||
|
}
|
||||||
|
|
||||||
|
// Vary the starting offset by epoch to defeat KV cache prefix matching
|
||||||
|
offset := epoch * 7 // stride by a prime to get good distribution
|
||||||
|
n := len(promptWordList)
|
||||||
|
words := make([]string, targetWords)
|
||||||
|
for i := range words {
|
||||||
|
words[i] = promptWordList[((i+offset)%n+n)%n]
|
||||||
|
}
|
||||||
|
return strings.Join(words, " ")
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildGenerateRequest(model string, fOpt flagOptions, imgData api.ImageData, epoch int) *api.GenerateRequest {
|
||||||
|
options := make(map[string]interface{})
|
||||||
|
if *fOpt.maxTokens > 0 {
|
||||||
|
options["num_predict"] = *fOpt.maxTokens
|
||||||
|
}
|
||||||
|
options["temperature"] = *fOpt.temperature
|
||||||
|
if fOpt.seed != nil && *fOpt.seed > 0 {
|
||||||
|
options["seed"] = *fOpt.seed
|
||||||
|
}
|
||||||
|
|
||||||
|
var keepAliveDuration *api.Duration
|
||||||
|
if *fOpt.keepAlive > 0 {
|
||||||
|
duration := api.Duration{Duration: time.Duration(*fOpt.keepAlive * float64(time.Second))}
|
||||||
|
keepAliveDuration = &duration
|
||||||
|
}
|
||||||
|
|
||||||
|
prompt := *fOpt.prompt
|
||||||
|
if *fOpt.promptTokens > 0 {
|
||||||
|
prompt = generatePromptForTokenCount(*fOpt.promptTokens, epoch)
|
||||||
|
} else {
|
||||||
|
// Vary the prompt per epoch to defeat KV cache prefix matching
|
||||||
|
prompt = fmt.Sprintf("[%d] %s", epoch, prompt)
|
||||||
|
}
|
||||||
|
|
||||||
|
req := &api.GenerateRequest{
|
||||||
|
Model: model,
|
||||||
|
Prompt: prompt,
|
||||||
|
Raw: true,
|
||||||
|
Options: options,
|
||||||
|
KeepAlive: keepAliveDuration,
|
||||||
|
}
|
||||||
|
|
||||||
|
if imgData != nil {
|
||||||
|
req.Images = []api.ImageData{imgData}
|
||||||
|
}
|
||||||
|
|
||||||
|
return req
|
||||||
|
}
|
||||||
|
|
||||||
|
func fetchModelInfo(ctx context.Context, client *api.Client, model string) ModelInfo {
|
||||||
|
info := ModelInfo{Name: model}
|
||||||
|
resp, err := client.Show(ctx, &api.ShowRequest{Model: model})
|
||||||
|
if err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "WARNING: Could not fetch model info for '%s': %v\n", model, err)
|
||||||
|
return info
|
||||||
|
}
|
||||||
|
info.ParameterSize = resp.Details.ParameterSize
|
||||||
|
info.QuantizationLevel = resp.Details.QuantizationLevel
|
||||||
|
info.Family = resp.Details.Family
|
||||||
|
return info
|
||||||
|
}
|
||||||
|
|
||||||
|
func fetchMemoryUsage(ctx context.Context, client *api.Client, model string) (size, vram int64) {
|
||||||
|
resp, err := client.ListRunning(ctx)
|
||||||
|
if err != nil {
|
||||||
|
if debug := os.Getenv("OLLAMA_DEBUG"); debug != "" {
|
||||||
|
fmt.Fprintf(os.Stderr, "WARNING: Could not fetch memory usage: %v\n", err)
|
||||||
|
}
|
||||||
|
return 0, 0
|
||||||
|
}
|
||||||
|
for _, m := range resp.Models {
|
||||||
|
if m.Name == model || m.Model == model {
|
||||||
|
return m.Size, m.SizeVRAM
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Try prefix match (model names may include :latest or tags)
|
||||||
|
for _, m := range resp.Models {
|
||||||
|
if strings.HasPrefix(m.Name, model) || strings.HasPrefix(m.Model, model) {
|
||||||
|
return m.Size, m.SizeVRAM
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return 0, 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func outputFormatHeader(w io.Writer, format string, verbose bool) {
|
||||||
|
switch format {
|
||||||
|
case "benchstat":
|
||||||
|
if verbose {
|
||||||
|
fmt.Fprintf(w, "goos: %s\n", runtime.GOOS)
|
||||||
|
fmt.Fprintf(w, "goarch: %s\n", runtime.GOARCH)
|
||||||
|
}
|
||||||
|
case "csv":
|
||||||
|
headings := []string{"NAME", "STEP", "COUNT", "NS_PER_COUNT", "TOKEN_PER_SEC"}
|
||||||
|
fmt.Fprintln(w, strings.Join(headings, ","))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func outputModelInfo(w io.Writer, format string, info ModelInfo) {
|
||||||
|
params := cmp.Or(info.ParameterSize, "unknown")
|
||||||
|
quant := cmp.Or(info.QuantizationLevel, "unknown")
|
||||||
|
family := cmp.Or(info.Family, "unknown")
|
||||||
|
|
||||||
|
memStr := ""
|
||||||
|
if info.SizeBytes > 0 {
|
||||||
|
memStr = fmt.Sprintf(" | Size: %d | VRAM: %d", info.SizeBytes, info.VRAMBytes)
|
||||||
|
}
|
||||||
|
fmt.Fprintf(w, "# Model: %s | Params: %s | Quant: %s | Family: %s%s\n",
|
||||||
|
info.Name, params, quant, family, memStr)
|
||||||
|
}
|
||||||
|
|
||||||
func OutputMetrics(w io.Writer, format string, metrics []Metrics, verbose bool) {
|
func OutputMetrics(w io.Writer, format string, metrics []Metrics, verbose bool) {
|
||||||
switch format {
|
switch format {
|
||||||
case "benchstat":
|
case "benchstat":
|
||||||
if verbose {
|
|
||||||
printHeader := func() {
|
|
||||||
fmt.Fprintf(w, "sysname: %s\n", runtime.GOOS)
|
|
||||||
fmt.Fprintf(w, "machine: %s\n", runtime.GOARCH)
|
|
||||||
}
|
|
||||||
once.Do(printHeader)
|
|
||||||
}
|
|
||||||
for _, m := range metrics {
|
for _, m := range metrics {
|
||||||
if m.Step == "generate" || m.Step == "prefill" {
|
if m.Step == "generate" || m.Step == "prefill" {
|
||||||
if m.Count > 0 {
|
if m.Count > 0 {
|
||||||
nsPerToken := float64(m.Duration.Nanoseconds()) / float64(m.Count)
|
nsPerToken := float64(m.Duration.Nanoseconds()) / float64(m.Count)
|
||||||
tokensPerSec := float64(m.Count) / (float64(m.Duration.Nanoseconds()) + 1e-12) * 1e9
|
tokensPerSec := float64(m.Count) / (float64(m.Duration.Nanoseconds()) + 1e-12) * 1e9
|
||||||
|
fmt.Fprintf(w, "BenchmarkModel/name=%s/step=%s 1 %.2f ns/token %.2f token/sec\n",
|
||||||
fmt.Fprintf(w, "BenchmarkModel/name=%s/step=%s %d %.2f ns/token %.2f token/sec\n",
|
m.Model, m.Step, nsPerToken, tokensPerSec)
|
||||||
m.Model, m.Step, m.Count, nsPerToken, tokensPerSec)
|
|
||||||
} else {
|
} else {
|
||||||
fmt.Fprintf(w, "BenchmarkModel/name=%s/step=%s %d 0 ns/token 0 token/sec\n",
|
fmt.Fprintf(w, "BenchmarkModel/name=%s/step=%s 1 0 ns/token 0 token/sec\n",
|
||||||
m.Model, m.Step, m.Count)
|
m.Model, m.Step)
|
||||||
}
|
}
|
||||||
|
} else if m.Step == "ttft" {
|
||||||
|
fmt.Fprintf(w, "BenchmarkModel/name=%s/step=ttft 1 %d ns/op\n",
|
||||||
|
m.Model, m.Duration.Nanoseconds())
|
||||||
} else {
|
} else {
|
||||||
var suffix string
|
fmt.Fprintf(w, "BenchmarkModel/name=%s/step=%s 1 %d ns/op\n",
|
||||||
if m.Step == "load" {
|
m.Model, m.Step, m.Duration.Nanoseconds())
|
||||||
suffix = "/step=load"
|
|
||||||
}
|
|
||||||
fmt.Fprintf(w, "BenchmarkModel/name=%s%s 1 %d ns/request\n",
|
|
||||||
m.Model, suffix, m.Duration.Nanoseconds())
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
case "csv":
|
case "csv":
|
||||||
printHeader := func() {
|
|
||||||
headings := []string{"NAME", "STEP", "COUNT", "NS_PER_COUNT", "TOKEN_PER_SEC"}
|
|
||||||
fmt.Fprintln(w, strings.Join(headings, ","))
|
|
||||||
}
|
|
||||||
once.Do(printHeader)
|
|
||||||
|
|
||||||
for _, m := range metrics {
|
for _, m := range metrics {
|
||||||
if m.Step == "generate" || m.Step == "prefill" {
|
if m.Step == "generate" || m.Step == "prefill" {
|
||||||
var nsPerToken float64
|
var nsPerToken float64
|
||||||
@@ -94,39 +217,14 @@ func OutputMetrics(w io.Writer, format string, metrics []Metrics, verbose bool)
|
|||||||
fmt.Fprintf(w, "%s,%s,1,%d,0\n", m.Model, m.Step, m.Duration.Nanoseconds())
|
fmt.Fprintf(w, "%s,%s,1,%d,0\n", m.Model, m.Step, m.Duration.Nanoseconds())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
case "markdown":
|
|
||||||
printHeader := func() {
|
|
||||||
fmt.Fprintln(w, "| Model | Step | Count | Duration | nsPerToken | tokensPerSec |")
|
|
||||||
fmt.Fprintln(w, "|-------|------|-------|----------|------------|--------------|")
|
|
||||||
}
|
|
||||||
once.Do(printHeader)
|
|
||||||
|
|
||||||
for _, m := range metrics {
|
|
||||||
var nsPerToken, tokensPerSec float64
|
|
||||||
var nsPerTokenStr, tokensPerSecStr string
|
|
||||||
|
|
||||||
if m.Step == "generate" || m.Step == "prefill" {
|
|
||||||
nsPerToken = float64(m.Duration.Nanoseconds()) / float64(m.Count)
|
|
||||||
tokensPerSec = float64(m.Count) / (float64(m.Duration.Nanoseconds()) + 1e-12) * 1e9
|
|
||||||
nsPerTokenStr = fmt.Sprintf("%.2f", nsPerToken)
|
|
||||||
tokensPerSecStr = fmt.Sprintf("%.2f", tokensPerSec)
|
|
||||||
} else {
|
|
||||||
nsPerTokenStr = "-"
|
|
||||||
tokensPerSecStr = "-"
|
|
||||||
}
|
|
||||||
|
|
||||||
fmt.Fprintf(w, "| %s | %s | %d | %v | %s | %s |\n",
|
|
||||||
m.Model, m.Step, m.Count, m.Duration, nsPerTokenStr, tokensPerSecStr)
|
|
||||||
}
|
|
||||||
default:
|
default:
|
||||||
fmt.Fprintf(os.Stderr, "Unknown output format '%s'\n", format)
|
fmt.Fprintf(os.Stderr, "Unknown output format '%s'\n", format)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func BenchmarkChat(fOpt flagOptions) error {
|
func BenchmarkModel(fOpt flagOptions) error {
|
||||||
models := strings.Split(*fOpt.models, ",")
|
models := strings.Split(*fOpt.models, ",")
|
||||||
|
|
||||||
// todo - add multi-image support
|
|
||||||
var imgData api.ImageData
|
var imgData api.ImageData
|
||||||
var err error
|
var err error
|
||||||
if *fOpt.imageFile != "" {
|
if *fOpt.imageFile != "" {
|
||||||
@@ -158,54 +256,83 @@ func BenchmarkChat(fOpt flagOptions) error {
|
|||||||
out = f
|
out = f
|
||||||
}
|
}
|
||||||
|
|
||||||
|
outputFormatHeader(out, *fOpt.format, *fOpt.verbose)
|
||||||
|
|
||||||
|
// Log prompt-tokens info in debug mode
|
||||||
|
if *fOpt.debug && *fOpt.promptTokens > 0 {
|
||||||
|
prompt := generatePromptForTokenCount(*fOpt.promptTokens, 0)
|
||||||
|
wordCount := len(strings.Fields(prompt))
|
||||||
|
fmt.Fprintf(os.Stderr, "Generated prompt targeting ~%d tokens (%d words, varied per epoch)\n", *fOpt.promptTokens, wordCount)
|
||||||
|
}
|
||||||
|
|
||||||
for _, model := range models {
|
for _, model := range models {
|
||||||
for range *fOpt.epochs {
|
// Fetch model info
|
||||||
options := make(map[string]interface{})
|
infoCtx, infoCancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||||
if *fOpt.maxTokens > 0 {
|
info := fetchModelInfo(infoCtx, client, model)
|
||||||
options["num_predict"] = *fOpt.maxTokens
|
infoCancel()
|
||||||
|
|
||||||
|
// Warmup phase (uses negative epoch numbers to avoid colliding with timed epochs)
|
||||||
|
for i := range *fOpt.warmup {
|
||||||
|
req := buildGenerateRequest(model, fOpt, imgData, -(i + 1))
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(*fOpt.timeout)*time.Second)
|
||||||
|
|
||||||
|
err = client.Generate(ctx, req, func(resp api.GenerateResponse) error {
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
cancel()
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "WARNING: Warmup %d/%d for %s failed: %v\n", i+1, *fOpt.warmup, model, err)
|
||||||
|
} else if *fOpt.debug {
|
||||||
|
fmt.Fprintf(os.Stderr, "Warmup %d/%d for %s complete\n", i+1, *fOpt.warmup, model)
|
||||||
}
|
}
|
||||||
options["temperature"] = *fOpt.temperature
|
|
||||||
if fOpt.seed != nil && *fOpt.seed > 0 {
|
|
||||||
options["seed"] = *fOpt.seed
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var keepAliveDuration *api.Duration
|
// Fetch memory usage once after warmup (model is loaded and stable)
|
||||||
if *fOpt.keepAlive > 0 {
|
memCtx, memCancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
duration := api.Duration{Duration: time.Duration(*fOpt.keepAlive * float64(time.Second))}
|
info.SizeBytes, info.VRAMBytes = fetchMemoryUsage(memCtx, client, model)
|
||||||
keepAliveDuration = &duration
|
memCancel()
|
||||||
}
|
|
||||||
|
|
||||||
req := &api.ChatRequest{
|
outputModelInfo(out, *fOpt.format, info)
|
||||||
Model: model,
|
|
||||||
Messages: []api.Message{
|
|
||||||
{
|
|
||||||
Role: "user",
|
|
||||||
Content: *fOpt.prompt,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
Options: options,
|
|
||||||
KeepAlive: keepAliveDuration,
|
|
||||||
}
|
|
||||||
|
|
||||||
if imgData != nil {
|
|
||||||
req.Messages[0].Images = []api.ImageData{imgData}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
// Timed epoch loop
|
||||||
|
shortCount := 0
|
||||||
|
for epoch := range *fOpt.epochs {
|
||||||
var responseMetrics *api.Metrics
|
var responseMetrics *api.Metrics
|
||||||
|
var ttft time.Duration
|
||||||
|
short := false
|
||||||
|
|
||||||
|
// Retry loop: if the model hits a stop token before max-tokens,
|
||||||
|
// retry with a different prompt (up to maxRetries times).
|
||||||
|
const maxRetries = 3
|
||||||
|
for attempt := range maxRetries + 1 {
|
||||||
|
responseMetrics = nil
|
||||||
|
ttft = 0
|
||||||
|
var ttftOnce sync.Once
|
||||||
|
|
||||||
|
req := buildGenerateRequest(model, fOpt, imgData, epoch+attempt*1000)
|
||||||
|
requestStart := time.Now()
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(*fOpt.timeout)*time.Second)
|
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(*fOpt.timeout)*time.Second)
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
err = client.Chat(ctx, req, func(resp api.ChatResponse) error {
|
err = client.Generate(ctx, req, func(resp api.GenerateResponse) error {
|
||||||
if *fOpt.debug {
|
if *fOpt.debug {
|
||||||
fmt.Fprintf(os.Stderr, "%s", cmp.Or(resp.Message.Thinking, resp.Message.Content))
|
fmt.Fprintf(os.Stderr, "%s", cmp.Or(resp.Thinking, resp.Response))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Capture TTFT on first content
|
||||||
|
ttftOnce.Do(func() {
|
||||||
|
if resp.Response != "" || resp.Thinking != "" {
|
||||||
|
ttft = time.Since(requestStart)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
if resp.Done {
|
if resp.Done {
|
||||||
responseMetrics = &resp.Metrics
|
responseMetrics = &resp.Metrics
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
|
cancel()
|
||||||
|
|
||||||
if *fOpt.debug {
|
if *fOpt.debug {
|
||||||
fmt.Fprintln(os.Stderr)
|
fmt.Fprintln(os.Stderr)
|
||||||
@@ -213,18 +340,42 @@ func BenchmarkChat(fOpt flagOptions) error {
|
|||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if ctx.Err() == context.DeadlineExceeded {
|
if ctx.Err() == context.DeadlineExceeded {
|
||||||
fmt.Fprintf(os.Stderr, "ERROR: Chat request timed out with model '%s' after %vs\n", model, 1)
|
fmt.Fprintf(os.Stderr, "ERROR: Request timed out with model '%s' after %vs\n", model, *fOpt.timeout)
|
||||||
continue
|
} else {
|
||||||
|
fmt.Fprintf(os.Stderr, "ERROR: Couldn't generate with model '%s': %v\n", model, err)
|
||||||
}
|
}
|
||||||
fmt.Fprintf(os.Stderr, "ERROR: Couldn't chat with model '%s': %v\n", model, err)
|
break
|
||||||
continue
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if responseMetrics == nil {
|
if responseMetrics == nil {
|
||||||
fmt.Fprintf(os.Stderr, "ERROR: No metrics received for model '%s'\n", model)
|
fmt.Fprintf(os.Stderr, "ERROR: No metrics received for model '%s'\n", model)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if the response was shorter than requested
|
||||||
|
short = *fOpt.maxTokens > 0 && responseMetrics.EvalCount < *fOpt.maxTokens
|
||||||
|
if !short || attempt == maxRetries {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
if *fOpt.debug {
|
||||||
|
fmt.Fprintf(os.Stderr, "Short response (%d/%d tokens), retrying with different prompt (attempt %d/%d)\n",
|
||||||
|
responseMetrics.EvalCount, *fOpt.maxTokens, attempt+1, maxRetries)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil || responseMetrics == nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if short {
|
||||||
|
shortCount++
|
||||||
|
if *fOpt.debug {
|
||||||
|
fmt.Fprintf(os.Stderr, "WARNING: Short response (%d/%d tokens) after %d retries for epoch %d\n",
|
||||||
|
responseMetrics.EvalCount, *fOpt.maxTokens, maxRetries, epoch+1)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
metrics := []Metrics{
|
metrics := []Metrics{
|
||||||
{
|
{
|
||||||
Model: model,
|
Model: model,
|
||||||
@@ -238,6 +389,12 @@ func BenchmarkChat(fOpt flagOptions) error {
|
|||||||
Count: responseMetrics.EvalCount,
|
Count: responseMetrics.EvalCount,
|
||||||
Duration: responseMetrics.EvalDuration,
|
Duration: responseMetrics.EvalDuration,
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
Model: model,
|
||||||
|
Step: "ttft",
|
||||||
|
Count: 1,
|
||||||
|
Duration: ttft,
|
||||||
|
},
|
||||||
{
|
{
|
||||||
Model: model,
|
Model: model,
|
||||||
Step: "load",
|
Step: "load",
|
||||||
@@ -254,15 +411,42 @@ func BenchmarkChat(fOpt flagOptions) error {
|
|||||||
|
|
||||||
OutputMetrics(out, *fOpt.format, metrics, *fOpt.verbose)
|
OutputMetrics(out, *fOpt.format, metrics, *fOpt.verbose)
|
||||||
|
|
||||||
|
if *fOpt.debug && *fOpt.promptTokens > 0 {
|
||||||
|
fmt.Fprintf(os.Stderr, "Generated prompt targeting ~%d tokens (actual: %d)\n",
|
||||||
|
*fOpt.promptTokens, responseMetrics.PromptEvalCount)
|
||||||
|
}
|
||||||
|
|
||||||
if *fOpt.keepAlive > 0 {
|
if *fOpt.keepAlive > 0 {
|
||||||
time.Sleep(time.Duration(*fOpt.keepAlive*float64(time.Second)) + 200*time.Millisecond)
|
time.Sleep(time.Duration(*fOpt.keepAlive*float64(time.Second)) + 200*time.Millisecond)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if shortCount > 0 {
|
||||||
|
fmt.Fprintf(os.Stderr, "WARNING: %d/%d epochs for '%s' had short responses (<%d tokens). Generation metrics may be unreliable.\n",
|
||||||
|
shortCount, *fOpt.epochs, model, *fOpt.maxTokens)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unload model before moving to the next one
|
||||||
|
unloadModel(client, model, *fOpt.timeout)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func unloadModel(client *api.Client, model string, timeout int) {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(timeout)*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
zero := api.Duration{Duration: 0}
|
||||||
|
req := &api.GenerateRequest{
|
||||||
|
Model: model,
|
||||||
|
KeepAlive: &zero,
|
||||||
|
}
|
||||||
|
_ = client.Generate(ctx, req, func(resp api.GenerateResponse) error {
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func readImage(filePath string) (api.ImageData, error) {
|
func readImage(filePath string) (api.ImageData, error) {
|
||||||
file, err := os.Open(filePath)
|
file, err := os.Open(filePath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -289,10 +473,12 @@ func main() {
|
|||||||
prompt: flag.String("p", DefaultPrompt, "Prompt to use"),
|
prompt: flag.String("p", DefaultPrompt, "Prompt to use"),
|
||||||
imageFile: flag.String("image", "", "Filename for an image to include"),
|
imageFile: flag.String("image", "", "Filename for an image to include"),
|
||||||
keepAlive: flag.Float64("k", 0, "Keep alive duration in seconds"),
|
keepAlive: flag.Float64("k", 0, "Keep alive duration in seconds"),
|
||||||
format: flag.String("format", "markdown", "Output format [benchstat|csv] (default benchstat)"),
|
format: flag.String("format", "benchstat", "Output format [benchstat|csv]"),
|
||||||
outputFile: flag.String("output", "", "Output file for results (stdout if empty)"),
|
outputFile: flag.String("output", "", "Output file for results (stdout if empty)"),
|
||||||
verbose: flag.Bool("v", false, "Show system information"),
|
verbose: flag.Bool("v", false, "Show system information"),
|
||||||
debug: flag.Bool("debug", false, "Show debug information"),
|
debug: flag.Bool("debug", false, "Show debug information"),
|
||||||
|
warmup: flag.Int("warmup", 1, "Number of warmup requests before timing"),
|
||||||
|
promptTokens: flag.Int("prompt-tokens", 0, "Generate prompt targeting ~N tokens (0 = use -p prompt)"),
|
||||||
}
|
}
|
||||||
|
|
||||||
flag.Usage = func() {
|
flag.Usage = func() {
|
||||||
@@ -302,11 +488,12 @@ func main() {
|
|||||||
fmt.Fprintf(os.Stderr, "Options:\n")
|
fmt.Fprintf(os.Stderr, "Options:\n")
|
||||||
flag.PrintDefaults()
|
flag.PrintDefaults()
|
||||||
fmt.Fprintf(os.Stderr, "\nExamples:\n")
|
fmt.Fprintf(os.Stderr, "\nExamples:\n")
|
||||||
fmt.Fprintf(os.Stderr, " bench -model gpt-oss:20b -epochs 3 -temperature 0.7\n")
|
fmt.Fprintf(os.Stderr, " bench -model gemma3,llama3 -epochs 6\n")
|
||||||
|
fmt.Fprintf(os.Stderr, " bench -model gemma3 -epochs 6 -prompt-tokens 512 -format csv\n")
|
||||||
}
|
}
|
||||||
flag.Parse()
|
flag.Parse()
|
||||||
|
|
||||||
if !slices.Contains([]string{"markdown", "benchstat", "csv"}, *fOpt.format) {
|
if !slices.Contains([]string{"benchstat", "csv"}, *fOpt.format) {
|
||||||
fmt.Fprintf(os.Stderr, "ERROR: Unknown format '%s'\n", *fOpt.format)
|
fmt.Fprintf(os.Stderr, "ERROR: Unknown format '%s'\n", *fOpt.format)
|
||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
@@ -317,5 +504,5 @@ func main() {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
BenchmarkChat(fOpt)
|
BenchmarkModel(fOpt)
|
||||||
}
|
}
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
366
cmd/cmd.go
366
cmd/cmd.go
@@ -11,10 +11,12 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"log"
|
"log"
|
||||||
|
"log/slog"
|
||||||
"math"
|
"math"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
|
"os/exec"
|
||||||
"os/signal"
|
"os/signal"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"runtime"
|
"runtime"
|
||||||
@@ -29,6 +31,7 @@ import (
|
|||||||
"github.com/containerd/console"
|
"github.com/containerd/console"
|
||||||
"github.com/mattn/go-runewidth"
|
"github.com/mattn/go-runewidth"
|
||||||
"github.com/olekukonko/tablewriter"
|
"github.com/olekukonko/tablewriter"
|
||||||
|
"github.com/pkg/browser"
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
"golang.org/x/crypto/ssh"
|
"golang.org/x/crypto/ssh"
|
||||||
"golang.org/x/sync/errgroup"
|
"golang.org/x/sync/errgroup"
|
||||||
@@ -36,8 +39,12 @@ import (
|
|||||||
|
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
"github.com/ollama/ollama/cmd/config"
|
"github.com/ollama/ollama/cmd/config"
|
||||||
|
"github.com/ollama/ollama/cmd/launch"
|
||||||
|
"github.com/ollama/ollama/cmd/tui"
|
||||||
"github.com/ollama/ollama/envconfig"
|
"github.com/ollama/ollama/envconfig"
|
||||||
"github.com/ollama/ollama/format"
|
"github.com/ollama/ollama/format"
|
||||||
|
"github.com/ollama/ollama/internal/modelref"
|
||||||
|
"github.com/ollama/ollama/logutil"
|
||||||
"github.com/ollama/ollama/parser"
|
"github.com/ollama/ollama/parser"
|
||||||
"github.com/ollama/ollama/progress"
|
"github.com/ollama/ollama/progress"
|
||||||
"github.com/ollama/ollama/readline"
|
"github.com/ollama/ollama/readline"
|
||||||
@@ -52,7 +59,50 @@ import (
|
|||||||
"github.com/ollama/ollama/x/imagegen"
|
"github.com/ollama/ollama/x/imagegen"
|
||||||
)
|
)
|
||||||
|
|
||||||
const ConnectInstructions = "To sign in, navigate to:\n %s\n\n"
|
func init() {
|
||||||
|
// Override default selectors to use Bubbletea TUI instead of raw terminal I/O.
|
||||||
|
launch.DefaultSingleSelector = func(title string, items []launch.ModelItem, current string) (string, error) {
|
||||||
|
if !term.IsTerminal(int(os.Stdin.Fd())) || !term.IsTerminal(int(os.Stdout.Fd())) {
|
||||||
|
return "", fmt.Errorf("model selection requires an interactive terminal; use --model to run in headless mode")
|
||||||
|
}
|
||||||
|
tuiItems := tui.ReorderItems(tui.ConvertItems(items))
|
||||||
|
result, err := tui.SelectSingle(title, tuiItems, current)
|
||||||
|
if errors.Is(err, tui.ErrCancelled) {
|
||||||
|
return "", launch.ErrCancelled
|
||||||
|
}
|
||||||
|
return result, err
|
||||||
|
}
|
||||||
|
|
||||||
|
launch.DefaultMultiSelector = func(title string, items []launch.ModelItem, preChecked []string) ([]string, error) {
|
||||||
|
if !term.IsTerminal(int(os.Stdin.Fd())) || !term.IsTerminal(int(os.Stdout.Fd())) {
|
||||||
|
return nil, fmt.Errorf("model selection requires an interactive terminal; use --model to run in headless mode")
|
||||||
|
}
|
||||||
|
tuiItems := tui.ReorderItems(tui.ConvertItems(items))
|
||||||
|
result, err := tui.SelectMultiple(title, tuiItems, preChecked)
|
||||||
|
if errors.Is(err, tui.ErrCancelled) {
|
||||||
|
return nil, launch.ErrCancelled
|
||||||
|
}
|
||||||
|
return result, err
|
||||||
|
}
|
||||||
|
|
||||||
|
launch.DefaultSignIn = func(modelName, signInURL string) (string, error) {
|
||||||
|
userName, err := tui.RunSignIn(modelName, signInURL)
|
||||||
|
if errors.Is(err, tui.ErrCancelled) {
|
||||||
|
return "", launch.ErrCancelled
|
||||||
|
}
|
||||||
|
return userName, err
|
||||||
|
}
|
||||||
|
|
||||||
|
launch.DefaultConfirmPrompt = func(prompt string) (bool, error) {
|
||||||
|
ok, err := tui.RunConfirm(prompt)
|
||||||
|
if errors.Is(err, tui.ErrCancelled) {
|
||||||
|
return false, launch.ErrCancelled
|
||||||
|
}
|
||||||
|
return ok, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const ConnectInstructions = "If your browser did not open, navigate to:\n %s\n\n"
|
||||||
|
|
||||||
// ensureThinkingSupport emits a warning if the model does not advertise thinking support
|
// ensureThinkingSupport emits a warning if the model does not advertise thinking support
|
||||||
func ensureThinkingSupport(ctx context.Context, client *api.Client, name string) {
|
func ensureThinkingSupport(ctx context.Context, client *api.Client, name string) {
|
||||||
@@ -91,6 +141,17 @@ func getModelfileName(cmd *cobra.Command) (string, error) {
|
|||||||
return absName, nil
|
return absName, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// isLocalhost returns true if the configured Ollama host is a loopback or unspecified address.
|
||||||
|
func isLocalhost() bool {
|
||||||
|
host := envconfig.Host()
|
||||||
|
h, _, _ := net.SplitHostPort(host.Host)
|
||||||
|
if h == "localhost" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
ip := net.ParseIP(h)
|
||||||
|
return ip != nil && (ip.IsLoopback() || ip.IsUnspecified())
|
||||||
|
}
|
||||||
|
|
||||||
func CreateHandler(cmd *cobra.Command, args []string) error {
|
func CreateHandler(cmd *cobra.Command, args []string) error {
|
||||||
p := progress.NewProgress(os.Stderr)
|
p := progress.NewProgress(os.Stderr)
|
||||||
defer p.Stop()
|
defer p.Stop()
|
||||||
@@ -105,6 +166,9 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
|
|||||||
// Check for --experimental flag for safetensors model creation
|
// Check for --experimental flag for safetensors model creation
|
||||||
experimental, _ := cmd.Flags().GetBool("experimental")
|
experimental, _ := cmd.Flags().GetBool("experimental")
|
||||||
if experimental {
|
if experimental {
|
||||||
|
if !isLocalhost() {
|
||||||
|
return errors.New("remote safetensor model creation not yet supported")
|
||||||
|
}
|
||||||
// Get Modelfile content - either from -f flag or default to "FROM ."
|
// Get Modelfile content - either from -f flag or default to "FROM ."
|
||||||
var reader io.Reader
|
var reader io.Reader
|
||||||
filename, err := getModelfileName(cmd)
|
filename, err := getModelfileName(cmd)
|
||||||
@@ -128,25 +192,9 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
|
|||||||
return fmt.Errorf("failed to parse Modelfile: %w", err)
|
return fmt.Errorf("failed to parse Modelfile: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Extract FROM path and configuration
|
modelDir, mfConfig, err := xcreateclient.ConfigFromModelfile(modelfile)
|
||||||
var modelDir string
|
if err != nil {
|
||||||
mfConfig := &xcreateclient.ModelfileConfig{}
|
return err
|
||||||
|
|
||||||
for _, cmd := range modelfile.Commands {
|
|
||||||
switch cmd.Name {
|
|
||||||
case "model":
|
|
||||||
modelDir = cmd.Args
|
|
||||||
case "template":
|
|
||||||
mfConfig.Template = cmd.Args
|
|
||||||
case "system":
|
|
||||||
mfConfig.System = cmd.Args
|
|
||||||
case "license":
|
|
||||||
mfConfig.License = cmd.Args
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if modelDir == "" {
|
|
||||||
modelDir = "."
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Resolve relative paths based on Modelfile location
|
// Resolve relative paths based on Modelfile location
|
||||||
@@ -170,6 +218,9 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
|
|||||||
if filename == "" {
|
if filename == "" {
|
||||||
// No Modelfile found - check if current directory is an image gen model
|
// No Modelfile found - check if current directory is an image gen model
|
||||||
if create.IsTensorModelDir(".") {
|
if create.IsTensorModelDir(".") {
|
||||||
|
if !isLocalhost() {
|
||||||
|
return errors.New("remote safetensor model creation not yet supported")
|
||||||
|
}
|
||||||
quantize, _ := cmd.Flags().GetString("quantize")
|
quantize, _ := cmd.Flags().GetString("quantize")
|
||||||
return xcreateclient.CreateModel(xcreateclient.CreateOptions{
|
return xcreateclient.CreateModel(xcreateclient.CreateOptions{
|
||||||
ModelName: modelName,
|
ModelName: modelName,
|
||||||
@@ -362,18 +413,35 @@ func loadOrUnloadModel(cmd *cobra.Command, opts *runOptions) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
requestedCloud := modelref.HasExplicitCloudSource(opts.Model)
|
||||||
|
|
||||||
if info, err := client.Show(cmd.Context(), &api.ShowRequest{Model: opts.Model}); err != nil {
|
if info, err := client.Show(cmd.Context(), &api.ShowRequest{Model: opts.Model}); err != nil {
|
||||||
return err
|
return err
|
||||||
} else if info.RemoteHost != "" {
|
} else if info.RemoteHost != "" || requestedCloud {
|
||||||
// Cloud model, no need to load/unload
|
// Cloud model, no need to load/unload
|
||||||
|
|
||||||
|
isCloud := requestedCloud || strings.HasPrefix(info.RemoteHost, "https://ollama.com")
|
||||||
|
|
||||||
|
// Check if user is signed in for ollama.com cloud models
|
||||||
|
if isCloud {
|
||||||
|
if _, err := client.Whoami(cmd.Context()); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if opts.ShowConnect {
|
if opts.ShowConnect {
|
||||||
p.StopAndClear()
|
p.StopAndClear()
|
||||||
if strings.HasPrefix(info.RemoteHost, "https://ollama.com") {
|
remoteModel := info.RemoteModel
|
||||||
fmt.Fprintf(os.Stderr, "Connecting to '%s' on 'ollama.com' ⚡\n", info.RemoteModel)
|
if remoteModel == "" {
|
||||||
|
remoteModel = opts.Model
|
||||||
|
}
|
||||||
|
if isCloud {
|
||||||
|
fmt.Fprintf(os.Stderr, "Connecting to '%s' on 'ollama.com' ⚡\n", remoteModel)
|
||||||
} else {
|
} else {
|
||||||
fmt.Fprintf(os.Stderr, "Connecting to '%s' on '%s'\n", info.RemoteModel, info.RemoteHost)
|
fmt.Fprintf(os.Stderr, "Connecting to '%s' on '%s'\n", remoteModel, info.RemoteHost)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -442,6 +510,64 @@ func generateEmbedding(cmd *cobra.Command, modelName, input string, keepAlive *a
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO(parthsareen): consolidate with TUI signin flow
|
||||||
|
func handleCloudAuthorizationError(err error) bool {
|
||||||
|
var authErr api.AuthorizationError
|
||||||
|
if errors.As(err, &authErr) && authErr.StatusCode == http.StatusUnauthorized {
|
||||||
|
fmt.Printf("You need to be signed in to Ollama to run Cloud models.\n\n")
|
||||||
|
if authErr.SigninURL != "" {
|
||||||
|
fmt.Printf(ConnectInstructions, authErr.SigninURL)
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// TEMP(drifkin): To match legacy `ollama run some-model:cloud` behavior, we
|
||||||
|
// best-effort pull cloud stub files for any explicit cloud source models.
|
||||||
|
// Remove this once `/api/tags` is cloud-aware.
|
||||||
|
func ensureCloudStub(ctx context.Context, client *api.Client, modelName string) {
|
||||||
|
if !modelref.HasExplicitCloudSource(modelName) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
normalizedName, _, err := modelref.NormalizePullName(modelName)
|
||||||
|
if err != nil {
|
||||||
|
slog.Warn("failed to normalize pull name", "model", modelName, "error", err, "normalizedName", normalizedName)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
listResp, err := client.List(ctx)
|
||||||
|
if err != nil {
|
||||||
|
slog.Warn("failed to list models", "error", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if hasListedModelName(listResp.Models, modelName) || hasListedModelName(listResp.Models, normalizedName) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
logutil.Trace("pulling cloud stub", "model", modelName, "normalizedName", normalizedName)
|
||||||
|
err = client.Pull(ctx, &api.PullRequest{
|
||||||
|
Model: normalizedName,
|
||||||
|
}, func(api.ProgressResponse) error {
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
slog.Warn("failed to pull cloud stub", "model", modelName, "error", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func hasListedModelName(models []api.ListModelResponse, name string) bool {
|
||||||
|
for _, m := range models {
|
||||||
|
if strings.EqualFold(m.Name, name) || strings.EqualFold(m.Model, name) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
func RunHandler(cmd *cobra.Command, args []string) error {
|
func RunHandler(cmd *cobra.Command, args []string) error {
|
||||||
interactive := true
|
interactive := true
|
||||||
|
|
||||||
@@ -538,12 +664,16 @@ func RunHandler(cmd *cobra.Command, args []string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
name := args[0]
|
name := args[0]
|
||||||
|
requestedCloud := modelref.HasExplicitCloudSource(name)
|
||||||
|
|
||||||
info, err := func() (*api.ShowResponse, error) {
|
info, err := func() (*api.ShowResponse, error) {
|
||||||
showReq := &api.ShowRequest{Name: name}
|
showReq := &api.ShowRequest{Name: name}
|
||||||
info, err := client.Show(cmd.Context(), showReq)
|
info, err := client.Show(cmd.Context(), showReq)
|
||||||
var se api.StatusError
|
var se api.StatusError
|
||||||
if errors.As(err, &se) && se.StatusCode == http.StatusNotFound {
|
if errors.As(err, &se) && se.StatusCode == http.StatusNotFound {
|
||||||
|
if requestedCloud {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
if err := PullHandler(cmd, []string{name}); err != nil {
|
if err := PullHandler(cmd, []string{name}); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -552,9 +682,14 @@ func RunHandler(cmd *cobra.Command, args []string) error {
|
|||||||
return info, err
|
return info, err
|
||||||
}()
|
}()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
if handleCloudAuthorizationError(err) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ensureCloudStub(cmd.Context(), client, name)
|
||||||
|
|
||||||
opts.Think, err = inferThinkingOption(&info.Capabilities, &opts, thinkFlag.Changed)
|
opts.Think, err = inferThinkingOption(&info.Capabilities, &opts, thinkFlag.Changed)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -646,7 +781,13 @@ func RunHandler(cmd *cobra.Command, args []string) error {
|
|||||||
|
|
||||||
return generateInteractive(cmd, opts)
|
return generateInteractive(cmd, opts)
|
||||||
}
|
}
|
||||||
return generate(cmd, opts)
|
if err := generate(cmd, opts); err != nil {
|
||||||
|
if handleCloudAuthorizationError(err) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func SigninHandler(cmd *cobra.Command, args []string) error {
|
func SigninHandler(cmd *cobra.Command, args []string) error {
|
||||||
@@ -663,6 +804,7 @@ func SigninHandler(cmd *cobra.Command, args []string) error {
|
|||||||
fmt.Println()
|
fmt.Println()
|
||||||
|
|
||||||
if aErr.SigninURL != "" {
|
if aErr.SigninURL != "" {
|
||||||
|
_ = browser.OpenURL(aErr.SigninURL)
|
||||||
fmt.Printf(ConnectInstructions, aErr.SigninURL)
|
fmt.Printf(ConnectInstructions, aErr.SigninURL)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
@@ -1750,7 +1892,7 @@ func checkServerHeartbeat(cmd *cobra.Command, _ []string) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if err := startApp(cmd.Context(), client); err != nil {
|
if err := startApp(cmd.Context(), client); err != nil {
|
||||||
return fmt.Errorf("ollama server not responding - %w", err)
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
@@ -1791,6 +1933,144 @@ Environment Variables:
|
|||||||
cmd.SetUsageTemplate(cmd.UsageTemplate() + envUsage)
|
cmd.SetUsageTemplate(cmd.UsageTemplate() + envUsage)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ensureServerRunning checks if the ollama server is running and starts it in the background if not.
|
||||||
|
func ensureServerRunning(ctx context.Context) error {
|
||||||
|
client, err := api.ClientFromEnvironment()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if server is already running
|
||||||
|
if err := client.Heartbeat(ctx); err == nil {
|
||||||
|
return nil // server is already running
|
||||||
|
}
|
||||||
|
|
||||||
|
// Server not running, start it in the background
|
||||||
|
exe, err := os.Executable()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("could not find executable: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
serverCmd := exec.CommandContext(ctx, exe, "serve")
|
||||||
|
serverCmd.Env = os.Environ()
|
||||||
|
serverCmd.SysProcAttr = backgroundServerSysProcAttr()
|
||||||
|
if err := serverCmd.Start(); err != nil {
|
||||||
|
return fmt.Errorf("failed to start server: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for the server to be ready
|
||||||
|
for {
|
||||||
|
time.Sleep(500 * time.Millisecond)
|
||||||
|
if err := client.Heartbeat(ctx); err == nil {
|
||||||
|
return nil // server has started
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func launchInteractiveModel(cmd *cobra.Command, modelName string) error {
|
||||||
|
opts := runOptions{
|
||||||
|
Model: modelName,
|
||||||
|
WordWrap: os.Getenv("TERM") == "xterm-256color",
|
||||||
|
Options: map[string]any{},
|
||||||
|
ShowConnect: true,
|
||||||
|
}
|
||||||
|
// loadOrUnloadModel is cloud-safe here: remote/cloud models skip local preload
|
||||||
|
// and only validate auth/connectivity before interactive chat starts.
|
||||||
|
if err := loadOrUnloadModel(cmd, &opts); err != nil {
|
||||||
|
return fmt.Errorf("error loading model: %w", err)
|
||||||
|
}
|
||||||
|
if err := generateInteractive(cmd, opts); err != nil {
|
||||||
|
return fmt.Errorf("error running model: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// runInteractiveTUI runs the main interactive TUI menu.
|
||||||
|
func runInteractiveTUI(cmd *cobra.Command) {
|
||||||
|
// Ensure the server is running before showing the TUI
|
||||||
|
if err := ensureServerRunning(cmd.Context()); err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "Error starting server: %v\n", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
deps := launcherDeps{
|
||||||
|
buildState: launch.BuildLauncherState,
|
||||||
|
runMenu: tui.RunMenu,
|
||||||
|
resolveRunModel: launch.ResolveRunModel,
|
||||||
|
launchIntegration: launch.LaunchIntegration,
|
||||||
|
runModel: launchInteractiveModel,
|
||||||
|
}
|
||||||
|
|
||||||
|
for {
|
||||||
|
continueLoop, err := runInteractiveTUIStep(cmd, deps)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
|
||||||
|
}
|
||||||
|
if !continueLoop {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type launcherDeps struct {
|
||||||
|
buildState func(context.Context) (*launch.LauncherState, error)
|
||||||
|
runMenu func(*launch.LauncherState) (tui.TUIAction, error)
|
||||||
|
resolveRunModel func(context.Context, launch.RunModelRequest) (string, error)
|
||||||
|
launchIntegration func(context.Context, launch.IntegrationLaunchRequest) error
|
||||||
|
runModel func(*cobra.Command, string) error
|
||||||
|
}
|
||||||
|
|
||||||
|
func runInteractiveTUIStep(cmd *cobra.Command, deps launcherDeps) (bool, error) {
|
||||||
|
state, err := deps.buildState(cmd.Context())
|
||||||
|
if err != nil {
|
||||||
|
return false, fmt.Errorf("build launcher state: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
action, err := deps.runMenu(state)
|
||||||
|
if err != nil {
|
||||||
|
return false, fmt.Errorf("run launcher menu: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return runLauncherAction(cmd, action, deps)
|
||||||
|
}
|
||||||
|
|
||||||
|
func saveLauncherSelection(action tui.TUIAction) {
|
||||||
|
// Best effort only: this affects menu recall, not launch correctness.
|
||||||
|
_ = config.SetLastSelection(action.LastSelection())
|
||||||
|
}
|
||||||
|
|
||||||
|
func runLauncherAction(cmd *cobra.Command, action tui.TUIAction, deps launcherDeps) (bool, error) {
|
||||||
|
switch action.Kind {
|
||||||
|
case tui.TUIActionNone:
|
||||||
|
return false, nil
|
||||||
|
case tui.TUIActionRunModel:
|
||||||
|
saveLauncherSelection(action)
|
||||||
|
modelName, err := deps.resolveRunModel(cmd.Context(), action.RunModelRequest())
|
||||||
|
if errors.Is(err, launch.ErrCancelled) {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return true, fmt.Errorf("selecting model: %w", err)
|
||||||
|
}
|
||||||
|
if err := deps.runModel(cmd, modelName); err != nil {
|
||||||
|
return true, err
|
||||||
|
}
|
||||||
|
return true, nil
|
||||||
|
case tui.TUIActionLaunchIntegration:
|
||||||
|
saveLauncherSelection(action)
|
||||||
|
err := deps.launchIntegration(cmd.Context(), action.IntegrationLaunchRequest())
|
||||||
|
if errors.Is(err, launch.ErrCancelled) {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return true, fmt.Errorf("launching %s: %w", action.Integration, err)
|
||||||
|
}
|
||||||
|
return true, nil
|
||||||
|
default:
|
||||||
|
return false, fmt.Errorf("unknown launcher action: %d", action.Kind)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func NewCLI() *cobra.Command {
|
func NewCLI() *cobra.Command {
|
||||||
log.SetFlags(log.LstdFlags | log.Lshortfile)
|
log.SetFlags(log.LstdFlags | log.Lshortfile)
|
||||||
cobra.EnableCommandSorting = false
|
cobra.EnableCommandSorting = false
|
||||||
@@ -1813,11 +2093,13 @@ func NewCLI() *cobra.Command {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
cmd.Print(cmd.UsageString())
|
runInteractiveTUI(cmd)
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
rootCmd.Flags().BoolP("version", "v", false, "Show version information")
|
rootCmd.Flags().BoolP("version", "v", false, "Show version information")
|
||||||
|
rootCmd.Flags().Bool("verbose", false, "Show timings for response")
|
||||||
|
rootCmd.Flags().Bool("nowordwrap", false, "Don't wrap words to the next line automatically")
|
||||||
|
|
||||||
createCmd := &cobra.Command{
|
createCmd := &cobra.Command{
|
||||||
Use: "create MODEL",
|
Use: "create MODEL",
|
||||||
@@ -1877,6 +2159,9 @@ func NewCLI() *cobra.Command {
|
|||||||
// Image generation flags (width, height, steps, seed, etc.)
|
// Image generation flags (width, height, steps, seed, etc.)
|
||||||
imagegen.RegisterFlags(runCmd)
|
imagegen.RegisterFlags(runCmd)
|
||||||
|
|
||||||
|
runCmd.Flags().Bool("imagegen", false, "Use the imagegen runner for LLM inference")
|
||||||
|
runCmd.Flags().MarkHidden("imagegen")
|
||||||
|
|
||||||
stopCmd := &cobra.Command{
|
stopCmd := &cobra.Command{
|
||||||
Use: "stop MODEL",
|
Use: "stop MODEL",
|
||||||
Short: "Stop a running model",
|
Short: "Stop a running model",
|
||||||
@@ -1888,7 +2173,7 @@ func NewCLI() *cobra.Command {
|
|||||||
serveCmd := &cobra.Command{
|
serveCmd := &cobra.Command{
|
||||||
Use: "serve",
|
Use: "serve",
|
||||||
Aliases: []string{"start"},
|
Aliases: []string{"start"},
|
||||||
Short: "Start ollama",
|
Short: "Start Ollama",
|
||||||
Args: cobra.ExactArgs(0),
|
Args: cobra.ExactArgs(0),
|
||||||
RunE: RunServer,
|
RunE: RunServer,
|
||||||
}
|
}
|
||||||
@@ -1921,6 +2206,15 @@ func NewCLI() *cobra.Command {
|
|||||||
RunE: SigninHandler,
|
RunE: SigninHandler,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
loginCmd := &cobra.Command{
|
||||||
|
Use: "login",
|
||||||
|
Short: "Sign in to ollama.com",
|
||||||
|
Hidden: true,
|
||||||
|
Args: cobra.ExactArgs(0),
|
||||||
|
PreRunE: checkServerHeartbeat,
|
||||||
|
RunE: SigninHandler,
|
||||||
|
}
|
||||||
|
|
||||||
signoutCmd := &cobra.Command{
|
signoutCmd := &cobra.Command{
|
||||||
Use: "signout",
|
Use: "signout",
|
||||||
Short: "Sign out from ollama.com",
|
Short: "Sign out from ollama.com",
|
||||||
@@ -1929,6 +2223,15 @@ func NewCLI() *cobra.Command {
|
|||||||
RunE: SignoutHandler,
|
RunE: SignoutHandler,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
logoutCmd := &cobra.Command{
|
||||||
|
Use: "logout",
|
||||||
|
Short: "Sign out from ollama.com",
|
||||||
|
Hidden: true,
|
||||||
|
Args: cobra.ExactArgs(0),
|
||||||
|
PreRunE: checkServerHeartbeat,
|
||||||
|
RunE: SignoutHandler,
|
||||||
|
}
|
||||||
|
|
||||||
listCmd := &cobra.Command{
|
listCmd := &cobra.Command{
|
||||||
Use: "list",
|
Use: "list",
|
||||||
Aliases: []string{"ls"},
|
Aliases: []string{"ls"},
|
||||||
@@ -1991,7 +2294,7 @@ func NewCLI() *cobra.Command {
|
|||||||
switch cmd {
|
switch cmd {
|
||||||
case runCmd:
|
case runCmd:
|
||||||
imagegen.AppendFlagsDocs(cmd)
|
imagegen.AppendFlagsDocs(cmd)
|
||||||
appendEnvDocs(cmd, []envconfig.EnvVar{envVars["OLLAMA_HOST"], envVars["OLLAMA_NOHISTORY"]})
|
appendEnvDocs(cmd, []envconfig.EnvVar{envVars["OLLAMA_EDITOR"], envVars["OLLAMA_HOST"], envVars["OLLAMA_NOHISTORY"]})
|
||||||
case serveCmd:
|
case serveCmd:
|
||||||
appendEnvDocs(cmd, []envconfig.EnvVar{
|
appendEnvDocs(cmd, []envconfig.EnvVar{
|
||||||
envVars["OLLAMA_DEBUG"],
|
envVars["OLLAMA_DEBUG"],
|
||||||
@@ -2002,6 +2305,7 @@ func NewCLI() *cobra.Command {
|
|||||||
envVars["OLLAMA_MAX_QUEUE"],
|
envVars["OLLAMA_MAX_QUEUE"],
|
||||||
envVars["OLLAMA_MODELS"],
|
envVars["OLLAMA_MODELS"],
|
||||||
envVars["OLLAMA_NUM_PARALLEL"],
|
envVars["OLLAMA_NUM_PARALLEL"],
|
||||||
|
envVars["OLLAMA_NO_CLOUD"],
|
||||||
envVars["OLLAMA_NOPRUNE"],
|
envVars["OLLAMA_NOPRUNE"],
|
||||||
envVars["OLLAMA_ORIGINS"],
|
envVars["OLLAMA_ORIGINS"],
|
||||||
envVars["OLLAMA_SCHED_SPREAD"],
|
envVars["OLLAMA_SCHED_SPREAD"],
|
||||||
@@ -2025,13 +2329,15 @@ func NewCLI() *cobra.Command {
|
|||||||
pullCmd,
|
pullCmd,
|
||||||
pushCmd,
|
pushCmd,
|
||||||
signinCmd,
|
signinCmd,
|
||||||
|
loginCmd,
|
||||||
signoutCmd,
|
signoutCmd,
|
||||||
|
logoutCmd,
|
||||||
listCmd,
|
listCmd,
|
||||||
psCmd,
|
psCmd,
|
||||||
copyCmd,
|
copyCmd,
|
||||||
deleteCmd,
|
deleteCmd,
|
||||||
runnerCmd,
|
runnerCmd,
|
||||||
config.LaunchCmd(checkServerHeartbeat),
|
launch.LaunchCmd(checkServerHeartbeat, runInteractiveTUI),
|
||||||
)
|
)
|
||||||
|
|
||||||
return rootCmd
|
return rootCmd
|
||||||
|
|||||||
233
cmd/cmd_launcher_test.go
Normal file
233
cmd/cmd_launcher_test.go
Normal file
@@ -0,0 +1,233 @@
|
|||||||
|
package cmd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/spf13/cobra"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/cmd/config"
|
||||||
|
"github.com/ollama/ollama/cmd/launch"
|
||||||
|
"github.com/ollama/ollama/cmd/tui"
|
||||||
|
)
|
||||||
|
|
||||||
|
func setCmdTestHome(t *testing.T, dir string) {
|
||||||
|
t.Helper()
|
||||||
|
t.Setenv("HOME", dir)
|
||||||
|
t.Setenv("USERPROFILE", dir)
|
||||||
|
}
|
||||||
|
|
||||||
|
func unexpectedRunModelResolution(t *testing.T) func(context.Context, launch.RunModelRequest) (string, error) {
|
||||||
|
t.Helper()
|
||||||
|
return func(ctx context.Context, req launch.RunModelRequest) (string, error) {
|
||||||
|
t.Fatalf("did not expect run-model resolution: %+v", req)
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func unexpectedIntegrationLaunch(t *testing.T) func(context.Context, launch.IntegrationLaunchRequest) error {
|
||||||
|
t.Helper()
|
||||||
|
return func(ctx context.Context, req launch.IntegrationLaunchRequest) error {
|
||||||
|
t.Fatalf("did not expect integration launch: %+v", req)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func unexpectedModelLaunch(t *testing.T) func(*cobra.Command, string) error {
|
||||||
|
t.Helper()
|
||||||
|
return func(cmd *cobra.Command, model string) error {
|
||||||
|
t.Fatalf("did not expect chat launch: %s", model)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRunInteractiveTUI_RunModelActionsUseResolveRunModel(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
action tui.TUIAction
|
||||||
|
wantForce bool
|
||||||
|
wantModel string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "enter uses saved model flow",
|
||||||
|
action: tui.TUIAction{Kind: tui.TUIActionRunModel},
|
||||||
|
wantModel: "qwen3:8b",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "right forces picker",
|
||||||
|
action: tui.TUIAction{Kind: tui.TUIActionRunModel, ForceConfigure: true},
|
||||||
|
wantForce: true,
|
||||||
|
wantModel: "glm-5:cloud",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
setCmdTestHome(t, t.TempDir())
|
||||||
|
|
||||||
|
var menuCalls int
|
||||||
|
runMenu := func(state *launch.LauncherState) (tui.TUIAction, error) {
|
||||||
|
menuCalls++
|
||||||
|
if menuCalls == 1 {
|
||||||
|
return tt.action, nil
|
||||||
|
}
|
||||||
|
return tui.TUIAction{Kind: tui.TUIActionNone}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var gotReq launch.RunModelRequest
|
||||||
|
var launched string
|
||||||
|
deps := launcherDeps{
|
||||||
|
buildState: func(ctx context.Context) (*launch.LauncherState, error) {
|
||||||
|
return &launch.LauncherState{}, nil
|
||||||
|
},
|
||||||
|
runMenu: runMenu,
|
||||||
|
resolveRunModel: func(ctx context.Context, req launch.RunModelRequest) (string, error) {
|
||||||
|
gotReq = req
|
||||||
|
return tt.wantModel, nil
|
||||||
|
},
|
||||||
|
launchIntegration: unexpectedIntegrationLaunch(t),
|
||||||
|
runModel: func(cmd *cobra.Command, model string) error {
|
||||||
|
launched = model
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd := &cobra.Command{}
|
||||||
|
cmd.SetContext(context.Background())
|
||||||
|
for {
|
||||||
|
continueLoop, err := runInteractiveTUIStep(cmd, deps)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected step error: %v", err)
|
||||||
|
}
|
||||||
|
if !continueLoop {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if gotReq.ForcePicker != tt.wantForce {
|
||||||
|
t.Fatalf("expected ForcePicker=%v, got %v", tt.wantForce, gotReq.ForcePicker)
|
||||||
|
}
|
||||||
|
if launched != tt.wantModel {
|
||||||
|
t.Fatalf("expected interactive launcher to run %q, got %q", tt.wantModel, launched)
|
||||||
|
}
|
||||||
|
if got := config.LastSelection(); got != "run" {
|
||||||
|
t.Fatalf("expected last selection to be run, got %q", got)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRunInteractiveTUI_IntegrationActionsUseLaunchIntegration(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
action tui.TUIAction
|
||||||
|
wantForce bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "enter launches integration",
|
||||||
|
action: tui.TUIAction{Kind: tui.TUIActionLaunchIntegration, Integration: "claude"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "right forces configure",
|
||||||
|
action: tui.TUIAction{Kind: tui.TUIActionLaunchIntegration, Integration: "claude", ForceConfigure: true},
|
||||||
|
wantForce: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
setCmdTestHome(t, t.TempDir())
|
||||||
|
|
||||||
|
var menuCalls int
|
||||||
|
runMenu := func(state *launch.LauncherState) (tui.TUIAction, error) {
|
||||||
|
menuCalls++
|
||||||
|
if menuCalls == 1 {
|
||||||
|
return tt.action, nil
|
||||||
|
}
|
||||||
|
return tui.TUIAction{Kind: tui.TUIActionNone}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var gotReq launch.IntegrationLaunchRequest
|
||||||
|
deps := launcherDeps{
|
||||||
|
buildState: func(ctx context.Context) (*launch.LauncherState, error) {
|
||||||
|
return &launch.LauncherState{}, nil
|
||||||
|
},
|
||||||
|
runMenu: runMenu,
|
||||||
|
resolveRunModel: unexpectedRunModelResolution(t),
|
||||||
|
launchIntegration: func(ctx context.Context, req launch.IntegrationLaunchRequest) error {
|
||||||
|
gotReq = req
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
runModel: unexpectedModelLaunch(t),
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd := &cobra.Command{}
|
||||||
|
cmd.SetContext(context.Background())
|
||||||
|
for {
|
||||||
|
continueLoop, err := runInteractiveTUIStep(cmd, deps)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected step error: %v", err)
|
||||||
|
}
|
||||||
|
if !continueLoop {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if gotReq.Name != "claude" {
|
||||||
|
t.Fatalf("expected integration name to be passed through, got %q", gotReq.Name)
|
||||||
|
}
|
||||||
|
if gotReq.ForceConfigure != tt.wantForce {
|
||||||
|
t.Fatalf("expected ForceConfigure=%v, got %v", tt.wantForce, gotReq.ForceConfigure)
|
||||||
|
}
|
||||||
|
if got := config.LastSelection(); got != "claude" {
|
||||||
|
t.Fatalf("expected last selection to be claude, got %q", got)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRunLauncherAction_RunModelContinuesAfterCancellation(t *testing.T) {
|
||||||
|
setCmdTestHome(t, t.TempDir())
|
||||||
|
|
||||||
|
cmd := &cobra.Command{}
|
||||||
|
cmd.SetContext(context.Background())
|
||||||
|
|
||||||
|
continueLoop, err := runLauncherAction(cmd, tui.TUIAction{Kind: tui.TUIActionRunModel}, launcherDeps{
|
||||||
|
buildState: nil,
|
||||||
|
runMenu: nil,
|
||||||
|
resolveRunModel: func(ctx context.Context, req launch.RunModelRequest) (string, error) {
|
||||||
|
return "", launch.ErrCancelled
|
||||||
|
},
|
||||||
|
launchIntegration: unexpectedIntegrationLaunch(t),
|
||||||
|
runModel: unexpectedModelLaunch(t),
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("expected nil error on cancellation, got %v", err)
|
||||||
|
}
|
||||||
|
if !continueLoop {
|
||||||
|
t.Fatal("expected cancellation to continue the menu loop")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRunLauncherAction_IntegrationContinuesAfterCancellation(t *testing.T) {
|
||||||
|
setCmdTestHome(t, t.TempDir())
|
||||||
|
|
||||||
|
cmd := &cobra.Command{}
|
||||||
|
cmd.SetContext(context.Background())
|
||||||
|
|
||||||
|
continueLoop, err := runLauncherAction(cmd, tui.TUIAction{Kind: tui.TUIActionLaunchIntegration, Integration: "claude"}, launcherDeps{
|
||||||
|
buildState: nil,
|
||||||
|
runMenu: nil,
|
||||||
|
resolveRunModel: unexpectedRunModelResolution(t),
|
||||||
|
launchIntegration: func(ctx context.Context, req launch.IntegrationLaunchRequest) error {
|
||||||
|
return launch.ErrCancelled
|
||||||
|
},
|
||||||
|
runModel: unexpectedModelLaunch(t),
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("expected nil error on cancellation, got %v", err)
|
||||||
|
}
|
||||||
|
if !continueLoop {
|
||||||
|
t.Fatal("expected cancellation to continue the menu loop")
|
||||||
|
}
|
||||||
|
}
|
||||||
551
cmd/cmd_test.go
551
cmd/cmd_test.go
@@ -3,6 +3,7 @@ package cmd
|
|||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
@@ -704,6 +705,347 @@ func TestRunEmbeddingModelNoInput(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestRunHandler_CloudAuthErrorOnShow_PrintsSigninMessage(t *testing.T) {
|
||||||
|
var generateCalled bool
|
||||||
|
|
||||||
|
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
switch {
|
||||||
|
case r.URL.Path == "/api/show" && r.Method == http.MethodPost:
|
||||||
|
w.WriteHeader(http.StatusUnauthorized)
|
||||||
|
if err := json.NewEncoder(w).Encode(map[string]string{
|
||||||
|
"error": "unauthorized",
|
||||||
|
"signin_url": "https://ollama.com/signin",
|
||||||
|
}); err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
case r.URL.Path == "/api/generate" && r.Method == http.MethodPost:
|
||||||
|
generateCalled = true
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
if err := json.NewEncoder(w).Encode(api.GenerateResponse{Done: true}); err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
http.NotFound(w, r)
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
|
||||||
|
t.Setenv("OLLAMA_HOST", mockServer.URL)
|
||||||
|
t.Cleanup(mockServer.Close)
|
||||||
|
|
||||||
|
cmd := &cobra.Command{}
|
||||||
|
cmd.SetContext(t.Context())
|
||||||
|
cmd.Flags().String("keepalive", "", "")
|
||||||
|
cmd.Flags().Bool("truncate", false, "")
|
||||||
|
cmd.Flags().Int("dimensions", 0, "")
|
||||||
|
cmd.Flags().Bool("verbose", false, "")
|
||||||
|
cmd.Flags().Bool("insecure", false, "")
|
||||||
|
cmd.Flags().Bool("nowordwrap", false, "")
|
||||||
|
cmd.Flags().String("format", "", "")
|
||||||
|
cmd.Flags().String("think", "", "")
|
||||||
|
cmd.Flags().Bool("hidethinking", false, "")
|
||||||
|
|
||||||
|
oldStdout := os.Stdout
|
||||||
|
readOut, writeOut, _ := os.Pipe()
|
||||||
|
os.Stdout = writeOut
|
||||||
|
t.Cleanup(func() { os.Stdout = oldStdout })
|
||||||
|
|
||||||
|
err := RunHandler(cmd, []string{"gpt-oss:20b:cloud", "hi"})
|
||||||
|
|
||||||
|
_ = writeOut.Close()
|
||||||
|
var out bytes.Buffer
|
||||||
|
_, _ = io.Copy(&out, readOut)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("RunHandler returned error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if generateCalled {
|
||||||
|
t.Fatal("expected run to stop before /api/generate after unauthorized /api/show")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !strings.Contains(out.String(), "You need to be signed in to Ollama to run Cloud models.") {
|
||||||
|
t.Fatalf("expected sign-in guidance message, got %q", out.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
if !strings.Contains(out.String(), "https://ollama.com/signin") {
|
||||||
|
t.Fatalf("expected signin_url in output, got %q", out.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRunHandler_CloudAuthErrorOnGenerate_PrintsSigninMessage(t *testing.T) {
|
||||||
|
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
switch {
|
||||||
|
case r.URL.Path == "/api/show" && r.Method == http.MethodPost:
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
if err := json.NewEncoder(w).Encode(api.ShowResponse{
|
||||||
|
Capabilities: []model.Capability{model.CapabilityCompletion},
|
||||||
|
}); err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
case r.URL.Path == "/api/generate" && r.Method == http.MethodPost:
|
||||||
|
w.WriteHeader(http.StatusUnauthorized)
|
||||||
|
if err := json.NewEncoder(w).Encode(map[string]string{
|
||||||
|
"error": "unauthorized",
|
||||||
|
"signin_url": "https://ollama.com/signin",
|
||||||
|
}); err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
http.NotFound(w, r)
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
|
||||||
|
t.Setenv("OLLAMA_HOST", mockServer.URL)
|
||||||
|
t.Cleanup(mockServer.Close)
|
||||||
|
|
||||||
|
cmd := &cobra.Command{}
|
||||||
|
cmd.SetContext(t.Context())
|
||||||
|
cmd.Flags().String("keepalive", "", "")
|
||||||
|
cmd.Flags().Bool("truncate", false, "")
|
||||||
|
cmd.Flags().Int("dimensions", 0, "")
|
||||||
|
cmd.Flags().Bool("verbose", false, "")
|
||||||
|
cmd.Flags().Bool("insecure", false, "")
|
||||||
|
cmd.Flags().Bool("nowordwrap", false, "")
|
||||||
|
cmd.Flags().String("format", "", "")
|
||||||
|
cmd.Flags().String("think", "", "")
|
||||||
|
cmd.Flags().Bool("hidethinking", false, "")
|
||||||
|
|
||||||
|
oldStdout := os.Stdout
|
||||||
|
readOut, writeOut, _ := os.Pipe()
|
||||||
|
os.Stdout = writeOut
|
||||||
|
t.Cleanup(func() { os.Stdout = oldStdout })
|
||||||
|
|
||||||
|
err := RunHandler(cmd, []string{"gpt-oss:20b:cloud", "hi"})
|
||||||
|
|
||||||
|
_ = writeOut.Close()
|
||||||
|
var out bytes.Buffer
|
||||||
|
_, _ = io.Copy(&out, readOut)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("RunHandler returned error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !strings.Contains(out.String(), "You need to be signed in to Ollama to run Cloud models.") {
|
||||||
|
t.Fatalf("expected sign-in guidance message, got %q", out.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
if !strings.Contains(out.String(), "https://ollama.com/signin") {
|
||||||
|
t.Fatalf("expected signin_url in output, got %q", out.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRunHandler_ExplicitCloudStubMissing_PullsNormalizedNameTEMP(t *testing.T) {
|
||||||
|
var pulledModel string
|
||||||
|
var generateCalled bool
|
||||||
|
|
||||||
|
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
switch {
|
||||||
|
case r.URL.Path == "/api/show" && r.Method == http.MethodPost:
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
if err := json.NewEncoder(w).Encode(api.ShowResponse{
|
||||||
|
Capabilities: []model.Capability{model.CapabilityCompletion},
|
||||||
|
RemoteModel: "gpt-oss:20b",
|
||||||
|
}); err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
case r.URL.Path == "/api/tags" && r.Method == http.MethodGet:
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
if err := json.NewEncoder(w).Encode(api.ListResponse{Models: nil}); err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
case r.URL.Path == "/api/pull" && r.Method == http.MethodPost:
|
||||||
|
var req api.PullRequest
|
||||||
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
pulledModel = req.Model
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
if err := json.NewEncoder(w).Encode(api.ProgressResponse{Status: "success"}); err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
case r.URL.Path == "/api/generate" && r.Method == http.MethodPost:
|
||||||
|
generateCalled = true
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
if err := json.NewEncoder(w).Encode(api.GenerateResponse{Done: true}); err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
http.NotFound(w, r)
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
|
||||||
|
t.Setenv("OLLAMA_HOST", mockServer.URL)
|
||||||
|
t.Cleanup(mockServer.Close)
|
||||||
|
|
||||||
|
cmd := &cobra.Command{}
|
||||||
|
cmd.SetContext(t.Context())
|
||||||
|
cmd.Flags().String("keepalive", "", "")
|
||||||
|
cmd.Flags().Bool("truncate", false, "")
|
||||||
|
cmd.Flags().Int("dimensions", 0, "")
|
||||||
|
cmd.Flags().Bool("verbose", false, "")
|
||||||
|
cmd.Flags().Bool("insecure", false, "")
|
||||||
|
cmd.Flags().Bool("nowordwrap", false, "")
|
||||||
|
cmd.Flags().String("format", "", "")
|
||||||
|
cmd.Flags().String("think", "", "")
|
||||||
|
cmd.Flags().Bool("hidethinking", false, "")
|
||||||
|
|
||||||
|
err := RunHandler(cmd, []string{"gpt-oss:20b:cloud", "hi"})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("RunHandler returned error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if pulledModel != "gpt-oss:20b-cloud" {
|
||||||
|
t.Fatalf("expected normalized pull model %q, got %q", "gpt-oss:20b-cloud", pulledModel)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !generateCalled {
|
||||||
|
t.Fatal("expected /api/generate to be called")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRunHandler_ExplicitCloudStubPresent_SkipsPullTEMP(t *testing.T) {
|
||||||
|
var pullCalled bool
|
||||||
|
var generateCalled bool
|
||||||
|
|
||||||
|
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
switch {
|
||||||
|
case r.URL.Path == "/api/show" && r.Method == http.MethodPost:
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
if err := json.NewEncoder(w).Encode(api.ShowResponse{
|
||||||
|
Capabilities: []model.Capability{model.CapabilityCompletion},
|
||||||
|
RemoteModel: "gpt-oss:20b",
|
||||||
|
}); err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
case r.URL.Path == "/api/tags" && r.Method == http.MethodGet:
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
if err := json.NewEncoder(w).Encode(api.ListResponse{
|
||||||
|
Models: []api.ListModelResponse{{Name: "gpt-oss:20b-cloud"}},
|
||||||
|
}); err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
case r.URL.Path == "/api/pull" && r.Method == http.MethodPost:
|
||||||
|
pullCalled = true
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
if err := json.NewEncoder(w).Encode(api.ProgressResponse{Status: "success"}); err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
case r.URL.Path == "/api/generate" && r.Method == http.MethodPost:
|
||||||
|
generateCalled = true
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
if err := json.NewEncoder(w).Encode(api.GenerateResponse{Done: true}); err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
http.NotFound(w, r)
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
|
||||||
|
t.Setenv("OLLAMA_HOST", mockServer.URL)
|
||||||
|
t.Cleanup(mockServer.Close)
|
||||||
|
|
||||||
|
cmd := &cobra.Command{}
|
||||||
|
cmd.SetContext(t.Context())
|
||||||
|
cmd.Flags().String("keepalive", "", "")
|
||||||
|
cmd.Flags().Bool("truncate", false, "")
|
||||||
|
cmd.Flags().Int("dimensions", 0, "")
|
||||||
|
cmd.Flags().Bool("verbose", false, "")
|
||||||
|
cmd.Flags().Bool("insecure", false, "")
|
||||||
|
cmd.Flags().Bool("nowordwrap", false, "")
|
||||||
|
cmd.Flags().String("format", "", "")
|
||||||
|
cmd.Flags().String("think", "", "")
|
||||||
|
cmd.Flags().Bool("hidethinking", false, "")
|
||||||
|
|
||||||
|
err := RunHandler(cmd, []string{"gpt-oss:20b:cloud", "hi"})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("RunHandler returned error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if pullCalled {
|
||||||
|
t.Fatal("expected /api/pull not to be called when cloud stub already exists")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !generateCalled {
|
||||||
|
t.Fatal("expected /api/generate to be called")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRunHandler_ExplicitCloudStubPullFailure_IsBestEffortTEMP(t *testing.T) {
|
||||||
|
var generateCalled bool
|
||||||
|
|
||||||
|
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
switch {
|
||||||
|
case r.URL.Path == "/api/show" && r.Method == http.MethodPost:
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
if err := json.NewEncoder(w).Encode(api.ShowResponse{
|
||||||
|
Capabilities: []model.Capability{model.CapabilityCompletion},
|
||||||
|
RemoteModel: "gpt-oss:20b",
|
||||||
|
}); err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
case r.URL.Path == "/api/tags" && r.Method == http.MethodGet:
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
if err := json.NewEncoder(w).Encode(api.ListResponse{Models: nil}); err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
case r.URL.Path == "/api/pull" && r.Method == http.MethodPost:
|
||||||
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
|
if err := json.NewEncoder(w).Encode(map[string]string{"error": "pull failed"}); err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
case r.URL.Path == "/api/generate" && r.Method == http.MethodPost:
|
||||||
|
generateCalled = true
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
if err := json.NewEncoder(w).Encode(api.GenerateResponse{Done: true}); err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
http.NotFound(w, r)
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
|
||||||
|
t.Setenv("OLLAMA_HOST", mockServer.URL)
|
||||||
|
t.Cleanup(mockServer.Close)
|
||||||
|
|
||||||
|
cmd := &cobra.Command{}
|
||||||
|
cmd.SetContext(t.Context())
|
||||||
|
cmd.Flags().String("keepalive", "", "")
|
||||||
|
cmd.Flags().Bool("truncate", false, "")
|
||||||
|
cmd.Flags().Int("dimensions", 0, "")
|
||||||
|
cmd.Flags().Bool("verbose", false, "")
|
||||||
|
cmd.Flags().Bool("insecure", false, "")
|
||||||
|
cmd.Flags().Bool("nowordwrap", false, "")
|
||||||
|
cmd.Flags().String("format", "", "")
|
||||||
|
cmd.Flags().String("think", "", "")
|
||||||
|
cmd.Flags().Bool("hidethinking", false, "")
|
||||||
|
|
||||||
|
err := RunHandler(cmd, []string{"gpt-oss:20b:cloud", "hi"})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("RunHandler returned error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !generateCalled {
|
||||||
|
t.Fatal("expected /api/generate to be called despite pull failure")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestGetModelfileName(t *testing.T) {
|
func TestGetModelfileName(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
@@ -1211,6 +1553,20 @@ func TestNewCreateRequest(t *testing.T) {
|
|||||||
Model: "newmodel",
|
Model: "newmodel",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"explicit cloud model preserves source when parent lacks it",
|
||||||
|
"newmodel",
|
||||||
|
runOptions{
|
||||||
|
Model: "qwen3.5:cloud",
|
||||||
|
ParentModel: "qwen3.5",
|
||||||
|
Messages: []api.Message{},
|
||||||
|
WordWrap: true,
|
||||||
|
},
|
||||||
|
&api.CreateRequest{
|
||||||
|
From: "qwen3.5:cloud",
|
||||||
|
Model: "newmodel",
|
||||||
|
},
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"parent model as filepath test",
|
"parent model as filepath test",
|
||||||
"newmodel",
|
"newmodel",
|
||||||
@@ -1553,7 +1909,7 @@ func TestShowInfoImageGen(t *testing.T) {
|
|||||||
Details: api.ModelDetails{
|
Details: api.ModelDetails{
|
||||||
Family: "ZImagePipeline",
|
Family: "ZImagePipeline",
|
||||||
ParameterSize: "10.3B",
|
ParameterSize: "10.3B",
|
||||||
QuantizationLevel: "FP8",
|
QuantizationLevel: "Q8",
|
||||||
},
|
},
|
||||||
Capabilities: []model.Capability{model.CapabilityImage},
|
Capabilities: []model.Capability{model.CapabilityImage},
|
||||||
Requires: "0.14.0",
|
Requires: "0.14.0",
|
||||||
@@ -1565,7 +1921,7 @@ func TestShowInfoImageGen(t *testing.T) {
|
|||||||
expect := " Model\n" +
|
expect := " Model\n" +
|
||||||
" architecture ZImagePipeline \n" +
|
" architecture ZImagePipeline \n" +
|
||||||
" parameters 10.3B \n" +
|
" parameters 10.3B \n" +
|
||||||
" quantization FP8 \n" +
|
" quantization Q8 \n" +
|
||||||
" requires 0.14.0 \n" +
|
" requires 0.14.0 \n" +
|
||||||
"\n" +
|
"\n" +
|
||||||
" Capabilities\n" +
|
" Capabilities\n" +
|
||||||
@@ -1659,3 +2015,194 @@ func TestRunOptions_Copy_Independence(t *testing.T) {
|
|||||||
t.Error("Copy Think should not be affected by original modification")
|
t.Error("Copy Think should not be affected by original modification")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestLoadOrUnloadModel_CloudModelAuth(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
model string
|
||||||
|
showStatus int
|
||||||
|
remoteHost string
|
||||||
|
remoteModel string
|
||||||
|
whoamiStatus int
|
||||||
|
whoamiResp any
|
||||||
|
expectWhoami bool
|
||||||
|
expectedError string
|
||||||
|
expectAuthError bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "ollama.com cloud model - user signed in",
|
||||||
|
model: "test-cloud-model",
|
||||||
|
remoteHost: "https://ollama.com",
|
||||||
|
remoteModel: "test-model",
|
||||||
|
whoamiStatus: http.StatusOK,
|
||||||
|
whoamiResp: api.UserResponse{Name: "testuser"},
|
||||||
|
expectWhoami: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ollama.com cloud model - user not signed in",
|
||||||
|
model: "test-cloud-model",
|
||||||
|
remoteHost: "https://ollama.com",
|
||||||
|
remoteModel: "test-model",
|
||||||
|
whoamiStatus: http.StatusUnauthorized,
|
||||||
|
whoamiResp: map[string]string{
|
||||||
|
"error": "unauthorized",
|
||||||
|
"signin_url": "https://ollama.com/signin",
|
||||||
|
},
|
||||||
|
expectWhoami: true,
|
||||||
|
expectedError: "unauthorized",
|
||||||
|
expectAuthError: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "non-ollama.com remote - no auth check",
|
||||||
|
model: "test-cloud-model",
|
||||||
|
remoteHost: "https://other-remote.com",
|
||||||
|
remoteModel: "test-model",
|
||||||
|
whoamiStatus: http.StatusUnauthorized, // should not be called
|
||||||
|
whoamiResp: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "explicit :cloud model - auth check without remote metadata",
|
||||||
|
model: "kimi-k2.5:cloud",
|
||||||
|
remoteHost: "",
|
||||||
|
remoteModel: "",
|
||||||
|
whoamiStatus: http.StatusOK,
|
||||||
|
whoamiResp: api.UserResponse{Name: "testuser"},
|
||||||
|
expectWhoami: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "explicit :cloud model without local stub returns not found by default",
|
||||||
|
model: "minimax-m2.7:cloud",
|
||||||
|
showStatus: http.StatusNotFound,
|
||||||
|
whoamiStatus: http.StatusOK,
|
||||||
|
whoamiResp: api.UserResponse{Name: "testuser"},
|
||||||
|
expectedError: "not found",
|
||||||
|
expectWhoami: false,
|
||||||
|
expectAuthError: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "explicit -cloud model - auth check without remote metadata",
|
||||||
|
model: "kimi-k2.5:latest-cloud",
|
||||||
|
remoteHost: "",
|
||||||
|
remoteModel: "",
|
||||||
|
whoamiStatus: http.StatusOK,
|
||||||
|
whoamiResp: api.UserResponse{Name: "testuser"},
|
||||||
|
expectWhoami: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "dash cloud-like name without explicit source does not require auth",
|
||||||
|
model: "test-cloud-model",
|
||||||
|
remoteHost: "",
|
||||||
|
remoteModel: "",
|
||||||
|
whoamiStatus: http.StatusUnauthorized, // should not be called
|
||||||
|
whoamiResp: nil,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
whoamiCalled := false
|
||||||
|
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
switch r.URL.Path {
|
||||||
|
case "/api/show":
|
||||||
|
if tt.showStatus != 0 && tt.showStatus != http.StatusOK {
|
||||||
|
w.WriteHeader(tt.showStatus)
|
||||||
|
_ = json.NewEncoder(w).Encode(map[string]string{"error": "not found"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
if err := json.NewEncoder(w).Encode(api.ShowResponse{
|
||||||
|
RemoteHost: tt.remoteHost,
|
||||||
|
RemoteModel: tt.remoteModel,
|
||||||
|
}); err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
case "/api/me":
|
||||||
|
whoamiCalled = true
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.WriteHeader(tt.whoamiStatus)
|
||||||
|
if tt.whoamiResp != nil {
|
||||||
|
if err := json.NewEncoder(w).Encode(tt.whoamiResp); err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case "/api/generate":
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
default:
|
||||||
|
http.NotFound(w, r)
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
defer mockServer.Close()
|
||||||
|
|
||||||
|
t.Setenv("OLLAMA_HOST", mockServer.URL)
|
||||||
|
|
||||||
|
cmd := &cobra.Command{}
|
||||||
|
cmd.SetContext(t.Context())
|
||||||
|
|
||||||
|
opts := &runOptions{
|
||||||
|
Model: tt.model,
|
||||||
|
ShowConnect: false,
|
||||||
|
}
|
||||||
|
|
||||||
|
err := loadOrUnloadModel(cmd, opts)
|
||||||
|
|
||||||
|
if whoamiCalled != tt.expectWhoami {
|
||||||
|
t.Errorf("whoami called = %v, want %v", whoamiCalled, tt.expectWhoami)
|
||||||
|
}
|
||||||
|
|
||||||
|
if tt.expectedError != "" {
|
||||||
|
if err == nil {
|
||||||
|
t.Errorf("expected error containing %q, got nil", tt.expectedError)
|
||||||
|
} else {
|
||||||
|
if !tt.expectAuthError && !strings.Contains(strings.ToLower(err.Error()), strings.ToLower(tt.expectedError)) {
|
||||||
|
t.Errorf("expected error containing %q, got %v", tt.expectedError, err)
|
||||||
|
}
|
||||||
|
if tt.expectAuthError {
|
||||||
|
var authErr api.AuthorizationError
|
||||||
|
if !errors.As(err, &authErr) {
|
||||||
|
t.Errorf("expected AuthorizationError, got %T: %v", err, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("expected no error, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsLocalhost(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
host string
|
||||||
|
expected bool
|
||||||
|
}{
|
||||||
|
{"default empty", "", true},
|
||||||
|
{"localhost no port", "localhost", true},
|
||||||
|
{"localhost with port", "localhost:11435", true},
|
||||||
|
{"127.0.0.1 no port", "127.0.0.1", true},
|
||||||
|
{"127.0.0.1 with port", "127.0.0.1:11434", true},
|
||||||
|
{"0.0.0.0 no port", "0.0.0.0", true},
|
||||||
|
{"0.0.0.0 with port", "0.0.0.0:11434", true},
|
||||||
|
{"::1 no port", "::1", true},
|
||||||
|
{"[::1] with port", "[::1]:11434", true},
|
||||||
|
{"loopback with scheme", "http://localhost:11434", true},
|
||||||
|
{"remote hostname", "example.com", false},
|
||||||
|
{"remote hostname with port", "example.com:11434", false},
|
||||||
|
{"remote IP", "192.168.1.1", false},
|
||||||
|
{"remote IP with port", "192.168.1.1:11434", false},
|
||||||
|
{"remote with scheme", "http://example.com:11434", false},
|
||||||
|
{"https remote", "https://example.com:443", false},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
t.Setenv("OLLAMA_HOST", tt.host)
|
||||||
|
got := isLocalhost()
|
||||||
|
if got != tt.expected {
|
||||||
|
t.Errorf("isLocalhost() with OLLAMA_HOST=%q = %v, want %v", tt.host, got, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,60 +0,0 @@
|
|||||||
package config
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"os"
|
|
||||||
"os/exec"
|
|
||||||
"path/filepath"
|
|
||||||
"runtime"
|
|
||||||
|
|
||||||
"github.com/ollama/ollama/envconfig"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Claude implements Runner for Claude Code integration
|
|
||||||
type Claude struct{}
|
|
||||||
|
|
||||||
func (c *Claude) String() string { return "Claude Code" }
|
|
||||||
|
|
||||||
func (c *Claude) args(model string) []string {
|
|
||||||
if model != "" {
|
|
||||||
return []string{"--model", model}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Claude) findPath() (string, error) {
|
|
||||||
if p, err := exec.LookPath("claude"); err == nil {
|
|
||||||
return p, nil
|
|
||||||
}
|
|
||||||
home, err := os.UserHomeDir()
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
name := "claude"
|
|
||||||
if runtime.GOOS == "windows" {
|
|
||||||
name = "claude.exe"
|
|
||||||
}
|
|
||||||
fallback := filepath.Join(home, ".claude", "local", name)
|
|
||||||
if _, err := os.Stat(fallback); err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
return fallback, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Claude) Run(model string) error {
|
|
||||||
claudePath, err := c.findPath()
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("claude is not installed, install from https://code.claude.com/docs/en/quickstart")
|
|
||||||
}
|
|
||||||
|
|
||||||
cmd := exec.Command(claudePath, c.args(model)...)
|
|
||||||
cmd.Stdin = os.Stdin
|
|
||||||
cmd.Stdout = os.Stdout
|
|
||||||
cmd.Stderr = os.Stderr
|
|
||||||
cmd.Env = append(os.Environ(),
|
|
||||||
"ANTHROPIC_BASE_URL="+envconfig.Host().String(),
|
|
||||||
"ANTHROPIC_API_KEY=",
|
|
||||||
"ANTHROPIC_AUTH_TOKEN=ollama",
|
|
||||||
)
|
|
||||||
return cmd.Run()
|
|
||||||
}
|
|
||||||
@@ -1,101 +0,0 @@
|
|||||||
package config
|
|
||||||
|
|
||||||
import (
|
|
||||||
"os"
|
|
||||||
"path/filepath"
|
|
||||||
"runtime"
|
|
||||||
"slices"
|
|
||||||
"testing"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestClaudeIntegration(t *testing.T) {
|
|
||||||
c := &Claude{}
|
|
||||||
|
|
||||||
t.Run("String", func(t *testing.T) {
|
|
||||||
if got := c.String(); got != "Claude Code" {
|
|
||||||
t.Errorf("String() = %q, want %q", got, "Claude Code")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("implements Runner", func(t *testing.T) {
|
|
||||||
var _ Runner = c
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestClaudeFindPath(t *testing.T) {
|
|
||||||
c := &Claude{}
|
|
||||||
|
|
||||||
t.Run("finds claude in PATH", func(t *testing.T) {
|
|
||||||
tmpDir := t.TempDir()
|
|
||||||
name := "claude"
|
|
||||||
if runtime.GOOS == "windows" {
|
|
||||||
name = "claude.exe"
|
|
||||||
}
|
|
||||||
fakeBin := filepath.Join(tmpDir, name)
|
|
||||||
os.WriteFile(fakeBin, []byte("#!/bin/sh\n"), 0o755)
|
|
||||||
t.Setenv("PATH", tmpDir)
|
|
||||||
|
|
||||||
got, err := c.findPath()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unexpected error: %v", err)
|
|
||||||
}
|
|
||||||
if got != fakeBin {
|
|
||||||
t.Errorf("findPath() = %q, want %q", got, fakeBin)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("falls back to ~/.claude/local/claude", func(t *testing.T) {
|
|
||||||
tmpDir := t.TempDir()
|
|
||||||
setTestHome(t, tmpDir)
|
|
||||||
t.Setenv("PATH", t.TempDir()) // empty dir, no claude binary
|
|
||||||
|
|
||||||
name := "claude"
|
|
||||||
if runtime.GOOS == "windows" {
|
|
||||||
name = "claude.exe"
|
|
||||||
}
|
|
||||||
fallback := filepath.Join(tmpDir, ".claude", "local", name)
|
|
||||||
os.MkdirAll(filepath.Dir(fallback), 0o755)
|
|
||||||
os.WriteFile(fallback, []byte("#!/bin/sh\n"), 0o755)
|
|
||||||
|
|
||||||
got, err := c.findPath()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unexpected error: %v", err)
|
|
||||||
}
|
|
||||||
if got != fallback {
|
|
||||||
t.Errorf("findPath() = %q, want %q", got, fallback)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("returns error when neither PATH nor fallback exists", func(t *testing.T) {
|
|
||||||
tmpDir := t.TempDir()
|
|
||||||
setTestHome(t, tmpDir)
|
|
||||||
t.Setenv("PATH", t.TempDir()) // empty dir, no claude binary
|
|
||||||
|
|
||||||
_, err := c.findPath()
|
|
||||||
if err == nil {
|
|
||||||
t.Fatal("expected error, got nil")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestClaudeArgs(t *testing.T) {
|
|
||||||
c := &Claude{}
|
|
||||||
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
model string
|
|
||||||
want []string
|
|
||||||
}{
|
|
||||||
{"with model", "llama3.2", []string{"--model", "llama3.2"}},
|
|
||||||
{"empty model", "", nil},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
got := c.args(tt.model)
|
|
||||||
if !slices.Equal(got, tt.want) {
|
|
||||||
t.Errorf("args(%q) = %v, want %v", tt.model, got, tt.want)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,195 +0,0 @@
|
|||||||
package config
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"os"
|
|
||||||
"os/exec"
|
|
||||||
"path/filepath"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
"github.com/ollama/ollama/envconfig"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Clawdbot integrates Ollama with the Clawdbot agent by rewriting
// ~/.clawdbot/clawdbot.json to register Ollama as a model provider.
type Clawdbot struct{}

// String returns the human-readable integration name.
func (c *Clawdbot) String() string { return "Clawdbot" }

// ansiGreen is the ANSI escape sequence used to highlight status output.
const ansiGreen = "\033[32m"
|
|
||||||
|
|
||||||
// Run configures Clawdbot to use Ollama and then starts its gateway.
// The requested model is written into Clawdbot's config via Edit; a
// model list previously saved for this integration takes precedence.
// A gateway that is already running is reported and treated as success.
func (c *Clawdbot) Run(model string) error {
	if _, err := exec.LookPath("clawdbot"); err != nil {
		return fmt.Errorf("clawdbot is not installed, install from https://docs.clawd.bot")
	}

	models := []string{model}
	// Prefer the model list previously saved for this integration, if any.
	if config, err := loadIntegration("clawdbot"); err == nil && len(config.Models) > 0 {
		models = config.Models
	}
	if err := c.Edit(models); err != nil {
		return fmt.Errorf("setup failed: %w", err)
	}

	cmd := exec.Command("clawdbot", "gateway")
	cmd.Stdin = os.Stdin

	// Capture output to detect "already running" message
	var outputBuf bytes.Buffer
	cmd.Stdout = io.MultiWriter(os.Stdout, &outputBuf)
	cmd.Stderr = io.MultiWriter(os.Stderr, &outputBuf)

	err := cmd.Run()
	// An already-running gateway is not a failure: the config edit above
	// has still been applied. NOTE(review): this matches on clawdbot's
	// output text — confirm the message is stable across versions.
	if err != nil && strings.Contains(outputBuf.String(), "Gateway already running") {
		fmt.Fprintf(os.Stderr, "%sClawdbot has been configured with Ollama. Gateway is already running.%s\n", ansiGreen, ansiReset)
		return nil
	}
	return err
}
|
|
||||||
|
|
||||||
func (c *Clawdbot) Paths() []string {
|
|
||||||
home, _ := os.UserHomeDir()
|
|
||||||
p := filepath.Join(home, ".clawdbot", "clawdbot.json")
|
|
||||||
if _, err := os.Stat(p); err == nil {
|
|
||||||
return []string{p}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Edit rewrites ~/.clawdbot/clawdbot.json so that the given models are
// registered under an "ollama" provider and the first model becomes the
// default agent model. All unrelated settings in the file (other
// providers, other agents, top-level keys) are preserved. An empty
// model list is a no-op. The previous file contents are backed up via
// writeWithBackup before being replaced.
func (c *Clawdbot) Edit(models []string) error {
	if len(models) == 0 {
		return nil
	}

	home, err := os.UserHomeDir()
	if err != nil {
		return err
	}

	configPath := filepath.Join(home, ".clawdbot", "clawdbot.json")
	if err := os.MkdirAll(filepath.Dir(configPath), 0o755); err != nil {
		return err
	}

	// Read into map[string]any to preserve unknown fields
	config := make(map[string]any)
	if data, err := os.ReadFile(configPath); err == nil {
		// Deliberately best-effort: corrupted JSON is treated as an
		// empty config rather than aborting the setup.
		_ = json.Unmarshal(data, &config)
	}

	// Navigate/create: models.providers.ollama (preserving other providers)
	modelsSection, _ := config["models"].(map[string]any)
	if modelsSection == nil {
		modelsSection = make(map[string]any)
	}
	providers, _ := modelsSection["providers"].(map[string]any)
	if providers == nil {
		providers = make(map[string]any)
	}
	ollama, _ := providers["ollama"].(map[string]any)
	if ollama == nil {
		ollama = make(map[string]any)
	}

	ollama["baseUrl"] = envconfig.Host().String() + "/v1"
	// needed to register provider
	ollama["apiKey"] = "ollama-local"
	// TODO(parthsareen): potentially move to responses
	ollama["api"] = "openai-completions"

	// Build map of existing models to preserve user customizations
	existingModels, _ := ollama["models"].([]any)
	existingByID := make(map[string]map[string]any)
	for _, m := range existingModels {
		if entry, ok := m.(map[string]any); ok {
			if id, ok := entry["id"].(string); ok {
				existingByID[id] = entry
			}
		}
	}

	// Rebuild the model list from scratch: models not in the new list
	// are dropped, known schema fields are reset, and any extra fields
	// a user added to a kept model are carried over below.
	var newModels []any
	for _, model := range models {
		entry := map[string]any{
			"id":        model,
			"name":      model,
			"reasoning": false,
			"input":     []any{"text"},
			"cost": map[string]any{
				"input":      0,
				"output":     0,
				"cacheRead":  0,
				"cacheWrite": 0,
			},
			// TODO(parthsareen): get these values from API
			"contextWindow": 131072,
			"maxTokens":     16384,
		}
		// Merge existing fields (user customizations)
		if existing, ok := existingByID[model]; ok {
			for k, v := range existing {
				if _, isNew := entry[k]; !isNew {
					entry[k] = v
				}
			}
		}
		newModels = append(newModels, entry)
	}
	ollama["models"] = newModels

	providers["ollama"] = ollama
	modelsSection["providers"] = providers
	config["models"] = modelsSection

	// Update agents.defaults.model.primary (preserving other agent settings)
	agents, _ := config["agents"].(map[string]any)
	if agents == nil {
		agents = make(map[string]any)
	}
	defaults, _ := agents["defaults"].(map[string]any)
	if defaults == nil {
		defaults = make(map[string]any)
	}
	modelConfig, _ := defaults["model"].(map[string]any)
	if modelConfig == nil {
		modelConfig = make(map[string]any)
	}
	modelConfig["primary"] = "ollama/" + models[0]
	defaults["model"] = modelConfig
	agents["defaults"] = defaults
	config["agents"] = agents

	data, err := json.MarshalIndent(config, "", " ")
	if err != nil {
		return err
	}
	return writeWithBackup(configPath, data)
}
|
|
||||||
|
|
||||||
func (c *Clawdbot) Models() []string {
|
|
||||||
home, err := os.UserHomeDir()
|
|
||||||
if err != nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
config, err := readJSONFile(filepath.Join(home, ".clawdbot", "clawdbot.json"))
|
|
||||||
if err != nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
modelsSection, _ := config["models"].(map[string]any)
|
|
||||||
providers, _ := modelsSection["providers"].(map[string]any)
|
|
||||||
ollama, _ := providers["ollama"].(map[string]any)
|
|
||||||
modelList, _ := ollama["models"].([]any)
|
|
||||||
|
|
||||||
var result []string
|
|
||||||
for _, m := range modelList {
|
|
||||||
if entry, ok := m.(map[string]any); ok {
|
|
||||||
if id, ok := entry["id"].(string); ok {
|
|
||||||
result = append(result, id)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return result
|
|
||||||
}
|
|
||||||
@@ -1,625 +0,0 @@
|
|||||||
package config
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
|
||||||
"os"
|
|
||||||
"path/filepath"
|
|
||||||
"testing"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestClawdbotIntegration(t *testing.T) {
|
|
||||||
c := &Clawdbot{}
|
|
||||||
|
|
||||||
t.Run("String", func(t *testing.T) {
|
|
||||||
if got := c.String(); got != "Clawdbot" {
|
|
||||||
t.Errorf("String() = %q, want %q", got, "Clawdbot")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("implements Runner", func(t *testing.T) {
|
|
||||||
var _ Runner = c
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("implements Editor", func(t *testing.T) {
|
|
||||||
var _ Editor = c
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestClawdbotEdit exercises Clawdbot.Edit's merge semantics against a
// config file in a sandboxed home directory: fresh install, primary
// model selection, preservation of unrelated settings, user
// customizations, list replacement, and tolerance of bad input.
func TestClawdbotEdit(t *testing.T) {
	c := &Clawdbot{}
	tmpDir := t.TempDir()
	setTestHome(t, tmpDir)

	configDir := filepath.Join(tmpDir, ".clawdbot")
	configPath := filepath.Join(configDir, "clawdbot.json")

	// cleanup resets the config dir between subtests (shared tmpDir).
	cleanup := func() { os.RemoveAll(configDir) }

	t.Run("fresh install", func(t *testing.T) {
		cleanup()
		if err := c.Edit([]string{"llama3.2"}); err != nil {
			t.Fatal(err)
		}
		assertClawdbotModelExists(t, configPath, "llama3.2")
		assertClawdbotPrimaryModel(t, configPath, "ollama/llama3.2")
	})

	t.Run("multiple models - first is primary", func(t *testing.T) {
		cleanup()
		if err := c.Edit([]string{"llama3.2", "mistral"}); err != nil {
			t.Fatal(err)
		}
		assertClawdbotModelExists(t, configPath, "llama3.2")
		assertClawdbotModelExists(t, configPath, "mistral")
		assertClawdbotPrimaryModel(t, configPath, "ollama/llama3.2")
	})

	t.Run("preserve other providers", func(t *testing.T) {
		cleanup()
		os.MkdirAll(configDir, 0o755)
		os.WriteFile(configPath, []byte(`{"models":{"providers":{"anthropic":{"apiKey":"xxx"}}}}`), 0o644)
		if err := c.Edit([]string{"llama3.2"}); err != nil {
			t.Fatal(err)
		}
		data, _ := os.ReadFile(configPath)
		var cfg map[string]any
		json.Unmarshal(data, &cfg)
		models := cfg["models"].(map[string]any)
		providers := models["providers"].(map[string]any)
		if providers["anthropic"] == nil {
			t.Error("anthropic provider was removed")
		}
	})

	t.Run("preserve top-level keys", func(t *testing.T) {
		cleanup()
		os.MkdirAll(configDir, 0o755)
		os.WriteFile(configPath, []byte(`{"theme":"dark","mcp":{"servers":{}}}`), 0o644)
		if err := c.Edit([]string{"llama3.2"}); err != nil {
			t.Fatal(err)
		}
		data, _ := os.ReadFile(configPath)
		var cfg map[string]any
		json.Unmarshal(data, &cfg)
		if cfg["theme"] != "dark" {
			t.Error("theme was removed")
		}
		if cfg["mcp"] == nil {
			t.Error("mcp was removed")
		}
	})

	t.Run("preserve user customizations on models", func(t *testing.T) {
		cleanup()
		c.Edit([]string{"llama3.2"})

		// User adds custom field
		data, _ := os.ReadFile(configPath)
		var cfg map[string]any
		json.Unmarshal(data, &cfg)
		models := cfg["models"].(map[string]any)
		providers := models["providers"].(map[string]any)
		ollama := providers["ollama"].(map[string]any)
		modelList := ollama["models"].([]any)
		entry := modelList[0].(map[string]any)
		entry["customField"] = "user-value"
		configData, _ := json.MarshalIndent(cfg, "", " ")
		os.WriteFile(configPath, configData, 0o644)

		// Re-run Edit
		c.Edit([]string{"llama3.2"})

		// The custom field must survive the rewrite.
		data, _ = os.ReadFile(configPath)
		json.Unmarshal(data, &cfg)
		models = cfg["models"].(map[string]any)
		providers = models["providers"].(map[string]any)
		ollama = providers["ollama"].(map[string]any)
		modelList = ollama["models"].([]any)
		entry = modelList[0].(map[string]any)
		if entry["customField"] != "user-value" {
			t.Error("custom field was lost")
		}
	})

	t.Run("edit replaces models list", func(t *testing.T) {
		cleanup()
		c.Edit([]string{"llama3.2", "mistral"})
		c.Edit([]string{"llama3.2"})

		assertClawdbotModelExists(t, configPath, "llama3.2")
		assertClawdbotModelNotExists(t, configPath, "mistral")
	})

	t.Run("empty models is no-op", func(t *testing.T) {
		cleanup()
		os.MkdirAll(configDir, 0o755)
		original := `{"existing":"data"}`
		os.WriteFile(configPath, []byte(original), 0o644)

		c.Edit([]string{})

		data, _ := os.ReadFile(configPath)
		if string(data) != original {
			t.Error("empty models should not modify file")
		}
	})

	t.Run("corrupted JSON treated as empty", func(t *testing.T) {
		cleanup()
		os.MkdirAll(configDir, 0o755)
		os.WriteFile(configPath, []byte(`{corrupted`), 0o644)

		if err := c.Edit([]string{"llama3.2"}); err != nil {
			t.Fatal(err)
		}

		data, _ := os.ReadFile(configPath)
		var cfg map[string]any
		if err := json.Unmarshal(data, &cfg); err != nil {
			t.Error("result should be valid JSON")
		}
	})

	t.Run("wrong type models section", func(t *testing.T) {
		cleanup()
		os.MkdirAll(configDir, 0o755)
		os.WriteFile(configPath, []byte(`{"models":"not a map"}`), 0o644)

		if err := c.Edit([]string{"llama3.2"}); err != nil {
			t.Fatal(err)
		}
		assertClawdbotModelExists(t, configPath, "llama3.2")
	})
}
|
|
||||||
|
|
||||||
// TestClawdbotModels covers reading registered Ollama model IDs back
// out of the Clawdbot config file.
func TestClawdbotModels(t *testing.T) {
	c := &Clawdbot{}
	tmpDir := t.TempDir()
	setTestHome(t, tmpDir)

	t.Run("no config returns nil", func(t *testing.T) {
		if models := c.Models(); len(models) > 0 {
			t.Errorf("expected nil/empty, got %v", models)
		}
	})

	t.Run("returns all ollama models", func(t *testing.T) {
		configDir := filepath.Join(tmpDir, ".clawdbot")
		os.MkdirAll(configDir, 0o755)
		os.WriteFile(filepath.Join(configDir, "clawdbot.json"), []byte(`{
	"models":{"providers":{"ollama":{"models":[
		{"id":"llama3.2"},
		{"id":"mistral"}
	]}}}
}`), 0o644)

		models := c.Models()
		if len(models) != 2 {
			t.Errorf("expected 2 models, got %v", models)
		}
	})
}
|
|
||||||
|
|
||||||
// Helper functions
|
|
||||||
// assertClawdbotModelExists fails the test when the config at path does
// not register model under the ollama provider. Uses checked reads and
// type assertions (matching assertClawdbotModelNotExists) so a missing
// or malformed config fails with a message instead of panicking.
func assertClawdbotModelExists(t *testing.T, path, model string) {
	t.Helper()
	data, err := os.ReadFile(path)
	if err != nil {
		t.Fatalf("reading config %q: %v", path, err)
	}
	var cfg map[string]any
	if err := json.Unmarshal(data, &cfg); err != nil {
		t.Fatalf("parsing config %q: %v", path, err)
	}
	models, _ := cfg["models"].(map[string]any)
	providers, _ := models["providers"].(map[string]any)
	ollama, _ := providers["ollama"].(map[string]any)
	modelList, _ := ollama["models"].([]any)
	for _, m := range modelList {
		if entry, ok := m.(map[string]any); ok {
			if entry["id"] == model {
				return
			}
		}
	}
	t.Errorf("model %s not found", model)
}
|
|
||||||
|
|
||||||
// assertClawdbotModelNotExists fails the test when model is still
// registered under the ollama provider in the config at path.
func assertClawdbotModelNotExists(t *testing.T, path, model string) {
	t.Helper()
	data, _ := os.ReadFile(path)
	var cfg map[string]any
	json.Unmarshal(data, &cfg)

	// Checked assertions: a malformed config simply yields no entries.
	section, _ := cfg["models"].(map[string]any)
	providers, _ := section["providers"].(map[string]any)
	provider, _ := providers["ollama"].(map[string]any)
	entries, _ := provider["models"].([]any)

	for _, raw := range entries {
		entry, ok := raw.(map[string]any)
		if !ok {
			continue
		}
		if entry["id"] == model {
			t.Errorf("model %s should not exist", model)
		}
	}
}
|
|
||||||
|
|
||||||
// assertClawdbotPrimaryModel fails the test when the config at path
// does not set agents.defaults.model.primary to expected. Uses checked
// reads and type assertions so a missing or malformed config fails
// with a message instead of panicking.
func assertClawdbotPrimaryModel(t *testing.T, path, expected string) {
	t.Helper()
	data, err := os.ReadFile(path)
	if err != nil {
		t.Fatalf("reading config %q: %v", path, err)
	}
	var cfg map[string]any
	if err := json.Unmarshal(data, &cfg); err != nil {
		t.Fatalf("parsing config %q: %v", path, err)
	}
	agents, _ := cfg["agents"].(map[string]any)
	defaults, _ := agents["defaults"].(map[string]any)
	model, _ := defaults["model"].(map[string]any)
	if model["primary"] != expected {
		t.Errorf("primary model = %v, want %v", model["primary"], expected)
	}
}
|
|
||||||
|
|
||||||
func TestClawdbotPaths(t *testing.T) {
|
|
||||||
c := &Clawdbot{}
|
|
||||||
|
|
||||||
t.Run("returns path when config exists", func(t *testing.T) {
|
|
||||||
tmpDir := t.TempDir()
|
|
||||||
setTestHome(t, tmpDir)
|
|
||||||
configDir := filepath.Join(tmpDir, ".clawdbot")
|
|
||||||
os.MkdirAll(configDir, 0o755)
|
|
||||||
os.WriteFile(filepath.Join(configDir, "clawdbot.json"), []byte(`{}`), 0o644)
|
|
||||||
|
|
||||||
paths := c.Paths()
|
|
||||||
if len(paths) != 1 {
|
|
||||||
t.Errorf("expected 1 path, got %d", len(paths))
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("returns nil when config missing", func(t *testing.T) {
|
|
||||||
tmpDir := t.TempDir()
|
|
||||||
setTestHome(t, tmpDir)
|
|
||||||
if paths := c.Paths(); paths != nil {
|
|
||||||
t.Errorf("expected nil, got %v", paths)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestClawdbotModelsEdgeCases verifies Models degrades to nil/empty on
// malformed configs: corrupted JSON, wrong types at each nesting level,
// and model entries whose id is missing or not a string.
func TestClawdbotModelsEdgeCases(t *testing.T) {
	c := &Clawdbot{}
	tmpDir := t.TempDir()
	setTestHome(t, tmpDir)
	configDir := filepath.Join(tmpDir, ".clawdbot")
	configPath := filepath.Join(configDir, "clawdbot.json")
	// cleanup resets the config dir between subtests (shared tmpDir).
	cleanup := func() { os.RemoveAll(configDir) }

	t.Run("corrupted JSON returns nil", func(t *testing.T) {
		cleanup()
		os.MkdirAll(configDir, 0o755)
		os.WriteFile(configPath, []byte(`{corrupted`), 0o644)
		if models := c.Models(); models != nil {
			t.Errorf("expected nil, got %v", models)
		}
	})

	t.Run("wrong type at models level", func(t *testing.T) {
		cleanup()
		os.MkdirAll(configDir, 0o755)
		os.WriteFile(configPath, []byte(`{"models":"string"}`), 0o644)
		if models := c.Models(); models != nil {
			t.Errorf("expected nil, got %v", models)
		}
	})

	t.Run("wrong type at providers level", func(t *testing.T) {
		cleanup()
		os.MkdirAll(configDir, 0o755)
		os.WriteFile(configPath, []byte(`{"models":{"providers":"string"}}`), 0o644)
		if models := c.Models(); models != nil {
			t.Errorf("expected nil, got %v", models)
		}
	})

	t.Run("wrong type at ollama level", func(t *testing.T) {
		cleanup()
		os.MkdirAll(configDir, 0o755)
		os.WriteFile(configPath, []byte(`{"models":{"providers":{"ollama":"string"}}}`), 0o644)
		if models := c.Models(); models != nil {
			t.Errorf("expected nil, got %v", models)
		}
	})

	t.Run("model entry missing id", func(t *testing.T) {
		cleanup()
		os.MkdirAll(configDir, 0o755)
		os.WriteFile(configPath, []byte(`{"models":{"providers":{"ollama":{"models":[{"name":"test"}]}}}}`), 0o644)
		if len(c.Models()) != 0 {
			t.Error("expected empty for missing id")
		}
	})

	t.Run("model id is not string", func(t *testing.T) {
		cleanup()
		os.MkdirAll(configDir, 0o755)
		os.WriteFile(configPath, []byte(`{"models":{"providers":{"ollama":{"models":[{"id":123}]}}}}`), 0o644)
		if len(c.Models()) != 0 {
			t.Error("expected empty for non-string id")
		}
	})
}
|
|
||||||
|
|
||||||
// TestClawdbotEditSchemaFields verifies that Edit writes every field
// Clawdbot's model schema requires (reasoning, input, context window,
// token limit, and all four cost components).
func TestClawdbotEditSchemaFields(t *testing.T) {
	c := &Clawdbot{}
	tmpDir := t.TempDir()
	setTestHome(t, tmpDir)
	configPath := filepath.Join(tmpDir, ".clawdbot", "clawdbot.json")

	if err := c.Edit([]string{"llama3.2"}); err != nil {
		t.Fatal(err)
	}

	data, _ := os.ReadFile(configPath)
	var cfg map[string]any
	json.Unmarshal(data, &cfg)
	models := cfg["models"].(map[string]any)
	providers := models["providers"].(map[string]any)
	ollama := providers["ollama"].(map[string]any)
	modelList := ollama["models"].([]any)
	entry := modelList[0].(map[string]any)

	// Verify required schema fields
	if entry["reasoning"] != false {
		t.Error("reasoning should be false")
	}
	if entry["input"] == nil {
		t.Error("input should be set")
	}
	if entry["contextWindow"] == nil {
		t.Error("contextWindow should be set")
	}
	if entry["maxTokens"] == nil {
		t.Error("maxTokens should be set")
	}
	cost := entry["cost"].(map[string]any)
	if cost["cacheRead"] == nil {
		t.Error("cost.cacheRead should be set")
	}
	if cost["cacheWrite"] == nil {
		t.Error("cost.cacheWrite should be set")
	}
}
|
|
||||||
|
|
||||||
// TestClawdbotEditModelNames verifies Edit handles model names
// containing tags (colons), namespaces (slashes), and hyphens without
// mangling the id or the "ollama/<model>" primary reference.
func TestClawdbotEditModelNames(t *testing.T) {
	c := &Clawdbot{}
	tmpDir := t.TempDir()
	setTestHome(t, tmpDir)
	configPath := filepath.Join(tmpDir, ".clawdbot", "clawdbot.json")
	// cleanup resets the config dir between subtests (shared tmpDir).
	cleanup := func() { os.RemoveAll(filepath.Join(tmpDir, ".clawdbot")) }

	t.Run("model with colon tag", func(t *testing.T) {
		cleanup()
		if err := c.Edit([]string{"llama3.2:70b"}); err != nil {
			t.Fatal(err)
		}
		assertClawdbotModelExists(t, configPath, "llama3.2:70b")
		assertClawdbotPrimaryModel(t, configPath, "ollama/llama3.2:70b")
	})

	t.Run("model with slash", func(t *testing.T) {
		cleanup()
		if err := c.Edit([]string{"library/model:tag"}); err != nil {
			t.Fatal(err)
		}
		assertClawdbotModelExists(t, configPath, "library/model:tag")
		assertClawdbotPrimaryModel(t, configPath, "ollama/library/model:tag")
	})

	t.Run("model with hyphen", func(t *testing.T) {
		cleanup()
		if err := c.Edit([]string{"test-model"}); err != nil {
			t.Fatal(err)
		}
		assertClawdbotModelExists(t, configPath, "test-model")
	})
}
|
|
||||||
|
|
||||||
// TestClawdbotEditAgentsPreservation verifies Edit only touches
// agents.defaults.model.primary, leaving sibling default settings and
// other agent entries intact.
func TestClawdbotEditAgentsPreservation(t *testing.T) {
	c := &Clawdbot{}
	tmpDir := t.TempDir()
	setTestHome(t, tmpDir)
	configDir := filepath.Join(tmpDir, ".clawdbot")
	configPath := filepath.Join(configDir, "clawdbot.json")
	// cleanup resets the config dir between subtests (shared tmpDir).
	cleanup := func() { os.RemoveAll(configDir) }

	t.Run("preserve other agent defaults", func(t *testing.T) {
		cleanup()
		os.MkdirAll(configDir, 0o755)
		os.WriteFile(configPath, []byte(`{"agents":{"defaults":{"model":{"primary":"old"},"temperature":0.7}}}`), 0o644)

		c.Edit([]string{"llama3.2"})

		data, _ := os.ReadFile(configPath)
		var cfg map[string]any
		json.Unmarshal(data, &cfg)
		agents := cfg["agents"].(map[string]any)
		defaults := agents["defaults"].(map[string]any)
		// JSON numbers decode as float64, so comparing against 0.7 works.
		if defaults["temperature"] != 0.7 {
			t.Error("temperature setting was lost")
		}
	})

	t.Run("preserve other agents besides defaults", func(t *testing.T) {
		cleanup()
		os.MkdirAll(configDir, 0o755)
		os.WriteFile(configPath, []byte(`{"agents":{"defaults":{},"custom-agent":{"foo":"bar"}}}`), 0o644)

		c.Edit([]string{"llama3.2"})

		data, _ := os.ReadFile(configPath)
		var cfg map[string]any
		json.Unmarshal(data, &cfg)
		agents := cfg["agents"].(map[string]any)
		if agents["custom-agent"] == nil {
			t.Error("custom-agent was lost")
		}
	})
}
|
|
||||||
|
|
||||||
// testClawdbotFixture is a representative pre-existing clawdbot.json
// containing top-level settings, a foreign provider, a customized
// ollama model entry, and agent settings — everything Edit must
// preserve or merge rather than clobber.
const testClawdbotFixture = `{
	"theme": "dark",
	"mcp": {"servers": {"custom": {"enabled": true}}},
	"models": {
		"providers": {
			"anthropic": {"apiKey": "xxx"},
			"ollama": {
				"baseUrl": "http://127.0.0.1:11434/v1",
				"models": [{"id": "old-model", "customField": "preserved"}]
			}
		}
	},
	"agents": {
		"defaults": {"model": {"primary": "old"}, "temperature": 0.7},
		"custom-agent": {"foo": "bar"}
	}
}`
|
|
||||||
|
|
||||||
// TestClawdbotEdit_RoundTrip runs Edit against the full fixture and
// verifies every category of pre-existing data survives: top-level
// keys, mcp servers, other providers, other agents, and agent defaults.
func TestClawdbotEdit_RoundTrip(t *testing.T) {
	c := &Clawdbot{}
	tmpDir := t.TempDir()
	setTestHome(t, tmpDir)
	configDir := filepath.Join(tmpDir, ".clawdbot")
	configPath := filepath.Join(configDir, "clawdbot.json")

	os.MkdirAll(configDir, 0o755)
	os.WriteFile(configPath, []byte(testClawdbotFixture), 0o644)

	if err := c.Edit([]string{"llama3.2", "mistral"}); err != nil {
		t.Fatal(err)
	}

	data, _ := os.ReadFile(configPath)
	var cfg map[string]any
	json.Unmarshal(data, &cfg)

	// Verify top-level preserved
	if cfg["theme"] != "dark" {
		t.Error("theme not preserved")
	}
	mcp := cfg["mcp"].(map[string]any)
	servers := mcp["servers"].(map[string]any)
	if servers["custom"] == nil {
		t.Error("mcp.servers.custom not preserved")
	}

	// Verify other providers preserved
	models := cfg["models"].(map[string]any)
	providers := models["providers"].(map[string]any)
	if providers["anthropic"] == nil {
		t.Error("anthropic provider not preserved")
	}

	// Verify agents preserved
	agents := cfg["agents"].(map[string]any)
	if agents["custom-agent"] == nil {
		t.Error("custom-agent not preserved")
	}
	defaults := agents["defaults"].(map[string]any)
	if defaults["temperature"] != 0.7 {
		t.Error("temperature not preserved")
	}
}
|
|
||||||
|
|
||||||
func TestClawdbotEdit_Idempotent(t *testing.T) {
|
|
||||||
c := &Clawdbot{}
|
|
||||||
tmpDir := t.TempDir()
|
|
||||||
setTestHome(t, tmpDir)
|
|
||||||
configDir := filepath.Join(tmpDir, ".clawdbot")
|
|
||||||
configPath := filepath.Join(configDir, "clawdbot.json")
|
|
||||||
|
|
||||||
os.MkdirAll(configDir, 0o755)
|
|
||||||
os.WriteFile(configPath, []byte(testClawdbotFixture), 0o644)
|
|
||||||
|
|
||||||
c.Edit([]string{"llama3.2", "mistral"})
|
|
||||||
firstData, _ := os.ReadFile(configPath)
|
|
||||||
|
|
||||||
c.Edit([]string{"llama3.2", "mistral"})
|
|
||||||
secondData, _ := os.ReadFile(configPath)
|
|
||||||
|
|
||||||
if string(firstData) != string(secondData) {
|
|
||||||
t.Error("repeated edits with same models produced different results")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestClawdbotEdit_MultipleConsecutiveEdits hammers Edit with
// alternating model sets and checks the file stays valid JSON and
// keeps unrelated settings across all rewrites.
func TestClawdbotEdit_MultipleConsecutiveEdits(t *testing.T) {
	c := &Clawdbot{}
	tmpDir := t.TempDir()
	setTestHome(t, tmpDir)
	configDir := filepath.Join(tmpDir, ".clawdbot")
	configPath := filepath.Join(configDir, "clawdbot.json")

	os.MkdirAll(configDir, 0o755)
	os.WriteFile(configPath, []byte(testClawdbotFixture), 0o644)

	for i := range 10 {
		// Alternate between two different model sets each iteration.
		models := []string{"model-a", "model-b"}
		if i%2 == 0 {
			models = []string{"model-x", "model-y", "model-z"}
		}
		if err := c.Edit(models); err != nil {
			t.Fatalf("edit %d failed: %v", i, err)
		}
	}

	data, _ := os.ReadFile(configPath)
	var cfg map[string]any
	if err := json.Unmarshal(data, &cfg); err != nil {
		t.Fatalf("file is not valid JSON after multiple edits: %v", err)
	}

	if cfg["theme"] != "dark" {
		t.Error("theme lost after multiple edits")
	}
}
|
|
||||||
|
|
||||||
// TestClawdbotEdit_BackupCreated verifies Edit backs up the previous
// config contents before rewriting the file. The unique marker keeps
// the check robust even though the backup directory under os.TempDir
// is shared across test runs.
// NOTE(review): relies on writeWithBackup placing backups in
// <os.TempDir()>/ollama-backups — confirm if that location changes.
func TestClawdbotEdit_BackupCreated(t *testing.T) {
	c := &Clawdbot{}
	tmpDir := t.TempDir()
	setTestHome(t, tmpDir)
	configDir := filepath.Join(tmpDir, ".clawdbot")
	configPath := filepath.Join(configDir, "clawdbot.json")
	backupDir := filepath.Join(os.TempDir(), "ollama-backups")

	os.MkdirAll(configDir, 0o755)
	uniqueMarker := fmt.Sprintf("test-marker-%d", os.Getpid())
	original := fmt.Sprintf(`{"theme": "%s"}`, uniqueMarker)
	os.WriteFile(configPath, []byte(original), 0o644)

	if err := c.Edit([]string{"model-a"}); err != nil {
		t.Fatal(err)
	}

	// Scan all backups for one whose contents match the original file.
	backups, _ := filepath.Glob(filepath.Join(backupDir, "clawdbot.json.*"))
	foundBackup := false
	for _, backup := range backups {
		data, _ := os.ReadFile(backup)
		if string(data) == original {
			foundBackup = true
			break
		}
	}

	if !foundBackup {
		t.Error("backup with original content not found")
	}
}
|
|
||||||
|
|
||||||
func TestClawdbotEdit_CreatesDirectoryIfMissing(t *testing.T) {
|
|
||||||
c := &Clawdbot{}
|
|
||||||
tmpDir := t.TempDir()
|
|
||||||
setTestHome(t, tmpDir)
|
|
||||||
configDir := filepath.Join(tmpDir, ".clawdbot")
|
|
||||||
|
|
||||||
if _, err := os.Stat(configDir); !os.IsNotExist(err) {
|
|
||||||
t.Fatal("directory should not exist before test")
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := c.Edit([]string{"model-a"}); err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, err := os.Stat(configDir); os.IsNotExist(err) {
|
|
||||||
t.Fatal("directory was not created")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,28 +0,0 @@
|
|||||||
package config
|
|
||||||
|
|
||||||
import (
|
|
||||||
"slices"
|
|
||||||
"testing"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestCodexArgs(t *testing.T) {
|
|
||||||
c := &Codex{}
|
|
||||||
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
model string
|
|
||||||
want []string
|
|
||||||
}{
|
|
||||||
{"with model", "llama3.2", []string{"--oss", "-m", "llama3.2"}},
|
|
||||||
{"empty model", "", []string{"--oss"}},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
got := c.args(tt.model)
|
|
||||||
if !slices.Equal(got, tt.want) {
|
|
||||||
t.Errorf("args(%q) = %v, want %v", tt.model, got, tt.want)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -9,17 +9,34 @@ import (
|
|||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/cmd/internal/fileutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
type integration struct {
|
type integration struct {
|
||||||
Models []string `json:"models"`
|
Models []string `json:"models"`
|
||||||
|
Aliases map[string]string `json:"aliases,omitempty"`
|
||||||
|
Onboarded bool `json:"onboarded,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// IntegrationConfig is the persisted config for one integration.
|
||||||
|
type IntegrationConfig = integration
|
||||||
|
|
||||||
type config struct {
|
type config struct {
|
||||||
Integrations map[string]*integration `json:"integrations"`
|
Integrations map[string]*integration `json:"integrations"`
|
||||||
|
LastModel string `json:"last_model,omitempty"`
|
||||||
|
LastSelection string `json:"last_selection,omitempty"` // "run" or integration name
|
||||||
}
|
}
|
||||||
|
|
||||||
func configPath() (string, error) {
|
func configPath() (string, error) {
|
||||||
|
home, err := os.UserHomeDir()
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return filepath.Join(home, ".ollama", "config.json"), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func legacyConfigPath() (string, error) {
|
||||||
home, err := os.UserHomeDir()
|
home, err := os.UserHomeDir()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
@@ -27,6 +44,44 @@ func configPath() (string, error) {
|
|||||||
return filepath.Join(home, ".ollama", "config", "config.json"), nil
|
return filepath.Join(home, ".ollama", "config", "config.json"), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// migrateConfig moves the config from the legacy path to ~/.ollama/config.json
|
||||||
|
func migrateConfig() (bool, error) {
|
||||||
|
oldPath, err := legacyConfigPath()
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
oldData, err := os.ReadFile(oldPath)
|
||||||
|
if err != nil {
|
||||||
|
if os.IsNotExist(err) {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ignore legacy files with invalid JSON and continue startup.
|
||||||
|
if !json.Valid(oldData) {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
newPath, err := configPath()
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := os.MkdirAll(filepath.Dir(newPath), 0o755); err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
if err := os.WriteFile(newPath, oldData, 0o644); err != nil {
|
||||||
|
return false, fmt.Errorf("write new config: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_ = os.Remove(oldPath)
|
||||||
|
_ = os.Remove(filepath.Dir(oldPath)) // clean up empty directory
|
||||||
|
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
func load() (*config, error) {
|
func load() (*config, error) {
|
||||||
path, err := configPath()
|
path, err := configPath()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -34,6 +89,11 @@ func load() (*config, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
data, err := os.ReadFile(path)
|
data, err := os.ReadFile(path)
|
||||||
|
if err != nil && os.IsNotExist(err) {
|
||||||
|
if migrated, merr := migrateConfig(); merr == nil && migrated {
|
||||||
|
data, err = os.ReadFile(path)
|
||||||
|
}
|
||||||
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if os.IsNotExist(err) {
|
if os.IsNotExist(err) {
|
||||||
return &config{Integrations: make(map[string]*integration)}, nil
|
return &config{Integrations: make(map[string]*integration)}, nil
|
||||||
@@ -66,10 +126,10 @@ func save(cfg *config) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return writeWithBackup(path, data)
|
return fileutil.WriteWithBackup(path, data)
|
||||||
}
|
}
|
||||||
|
|
||||||
func saveIntegration(appName string, models []string) error {
|
func SaveIntegration(appName string, models []string) error {
|
||||||
if appName == "" {
|
if appName == "" {
|
||||||
return errors.New("app name cannot be empty")
|
return errors.New("app name cannot be empty")
|
||||||
}
|
}
|
||||||
@@ -79,25 +139,134 @@ func saveIntegration(appName string, models []string) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
cfg.Integrations[strings.ToLower(appName)] = &integration{
|
key := strings.ToLower(appName)
|
||||||
|
existing := cfg.Integrations[key]
|
||||||
|
var aliases map[string]string
|
||||||
|
var onboarded bool
|
||||||
|
if existing != nil {
|
||||||
|
aliases = existing.Aliases
|
||||||
|
onboarded = existing.Onboarded
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg.Integrations[key] = &integration{
|
||||||
Models: models,
|
Models: models,
|
||||||
|
Aliases: aliases,
|
||||||
|
Onboarded: onboarded,
|
||||||
}
|
}
|
||||||
|
|
||||||
return save(cfg)
|
return save(cfg)
|
||||||
}
|
}
|
||||||
|
|
||||||
func loadIntegration(appName string) (*integration, error) {
|
// MarkIntegrationOnboarded marks an integration as onboarded in Ollama's config.
|
||||||
|
func MarkIntegrationOnboarded(appName string) error {
|
||||||
|
cfg, err := load()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
key := strings.ToLower(appName)
|
||||||
|
existing := cfg.Integrations[key]
|
||||||
|
if existing == nil {
|
||||||
|
existing = &integration{}
|
||||||
|
}
|
||||||
|
existing.Onboarded = true
|
||||||
|
cfg.Integrations[key] = existing
|
||||||
|
return save(cfg)
|
||||||
|
}
|
||||||
|
|
||||||
|
// IntegrationModel returns the first configured model for an integration, or empty string if not configured.
|
||||||
|
func IntegrationModel(appName string) string {
|
||||||
|
integrationConfig, err := LoadIntegration(appName)
|
||||||
|
if err != nil || len(integrationConfig.Models) == 0 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return integrationConfig.Models[0]
|
||||||
|
}
|
||||||
|
|
||||||
|
// IntegrationModels returns all configured models for an integration, or nil.
|
||||||
|
func IntegrationModels(appName string) []string {
|
||||||
|
integrationConfig, err := LoadIntegration(appName)
|
||||||
|
if err != nil || len(integrationConfig.Models) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return integrationConfig.Models
|
||||||
|
}
|
||||||
|
|
||||||
|
// LastModel returns the last model that was run, or empty string if none.
|
||||||
|
func LastModel() string {
|
||||||
|
cfg, err := load()
|
||||||
|
if err != nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return cfg.LastModel
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetLastModel saves the last model that was run.
|
||||||
|
func SetLastModel(model string) error {
|
||||||
|
cfg, err := load()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
cfg.LastModel = model
|
||||||
|
return save(cfg)
|
||||||
|
}
|
||||||
|
|
||||||
|
// LastSelection returns the last menu selection ("run" or integration name), or empty string if none.
|
||||||
|
func LastSelection() string {
|
||||||
|
cfg, err := load()
|
||||||
|
if err != nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return cfg.LastSelection
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetLastSelection saves the last menu selection ("run" or integration name).
|
||||||
|
func SetLastSelection(selection string) error {
|
||||||
|
cfg, err := load()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
cfg.LastSelection = selection
|
||||||
|
return save(cfg)
|
||||||
|
}
|
||||||
|
|
||||||
|
// LoadIntegration returns the saved config for one integration.
|
||||||
|
func LoadIntegration(appName string) (*integration, error) {
|
||||||
cfg, err := load()
|
cfg, err := load()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
ic, ok := cfg.Integrations[strings.ToLower(appName)]
|
integrationConfig, ok := cfg.Integrations[strings.ToLower(appName)]
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, os.ErrNotExist
|
return nil, os.ErrNotExist
|
||||||
}
|
}
|
||||||
|
|
||||||
return ic, nil
|
return integrationConfig, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SaveAliases replaces the saved aliases for one integration.
|
||||||
|
func SaveAliases(appName string, aliases map[string]string) error {
|
||||||
|
if appName == "" {
|
||||||
|
return errors.New("app name cannot be empty")
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg, err := load()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
key := strings.ToLower(appName)
|
||||||
|
existing := cfg.Integrations[key]
|
||||||
|
if existing == nil {
|
||||||
|
existing = &integration{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Replace aliases entirely (not merge) so deletions are persisted
|
||||||
|
existing.Aliases = aliases
|
||||||
|
|
||||||
|
cfg.Integrations[key] = existing
|
||||||
|
return save(cfg)
|
||||||
}
|
}
|
||||||
|
|
||||||
func listIntegrations() ([]integration, error) {
|
func listIntegrations() ([]integration, error) {
|
||||||
@@ -107,8 +276,8 @@ func listIntegrations() ([]integration, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
result := make([]integration, 0, len(cfg.Integrations))
|
result := make([]integration, 0, len(cfg.Integrations))
|
||||||
for _, ic := range cfg.Integrations {
|
for _, integrationConfig := range cfg.Integrations {
|
||||||
result = append(result, *ic)
|
result = append(result, *integrationConfig)
|
||||||
}
|
}
|
||||||
|
|
||||||
return result, nil
|
return result, nil
|
||||||
|
|||||||
641
cmd/config/config_cloud_test.go
Normal file
641
cmd/config/config_cloud_test.go
Normal file
@@ -0,0 +1,641 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestSetAliases_CloudModel(t *testing.T) {
|
||||||
|
// Test the SetAliases logic by checking the alias map behavior
|
||||||
|
aliases := map[string]string{
|
||||||
|
"primary": "kimi-k2.5:cloud",
|
||||||
|
"fast": "kimi-k2.5:cloud",
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify fast is set (cloud model behavior)
|
||||||
|
if aliases["fast"] == "" {
|
||||||
|
t.Error("cloud model should have fast alias set")
|
||||||
|
}
|
||||||
|
if aliases["fast"] != aliases["primary"] {
|
||||||
|
t.Errorf("fast should equal primary for auto-set, got fast=%q primary=%q", aliases["fast"], aliases["primary"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSetAliases_LocalModel(t *testing.T) {
|
||||||
|
aliases := map[string]string{
|
||||||
|
"primary": "llama3.2:latest",
|
||||||
|
}
|
||||||
|
// Simulate local model behavior: fast should be empty
|
||||||
|
delete(aliases, "fast")
|
||||||
|
|
||||||
|
if aliases["fast"] != "" {
|
||||||
|
t.Error("local model should have empty fast alias")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSaveAliases_ReplacesNotMerges(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
setTestHome(t, tmpDir)
|
||||||
|
|
||||||
|
// First save with both primary and fast
|
||||||
|
initial := map[string]string{
|
||||||
|
"primary": "cloud-model",
|
||||||
|
"fast": "cloud-model",
|
||||||
|
}
|
||||||
|
if err := SaveAliases("claude", initial); err != nil {
|
||||||
|
t.Fatalf("failed to save initial aliases: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify both are saved
|
||||||
|
loaded, err := LoadIntegration("claude")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to load: %v", err)
|
||||||
|
}
|
||||||
|
if loaded.Aliases["fast"] != "cloud-model" {
|
||||||
|
t.Errorf("expected fast=cloud-model, got %q", loaded.Aliases["fast"])
|
||||||
|
}
|
||||||
|
|
||||||
|
// Now save without fast (simulating switch to local model)
|
||||||
|
updated := map[string]string{
|
||||||
|
"primary": "local-model",
|
||||||
|
// fast intentionally missing
|
||||||
|
}
|
||||||
|
if err := SaveAliases("claude", updated); err != nil {
|
||||||
|
t.Fatalf("failed to save updated aliases: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify fast is GONE (not merged/preserved)
|
||||||
|
loaded, err = LoadIntegration("claude")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to load after update: %v", err)
|
||||||
|
}
|
||||||
|
if loaded.Aliases["fast"] != "" {
|
||||||
|
t.Errorf("fast should be removed after saving without it, got %q", loaded.Aliases["fast"])
|
||||||
|
}
|
||||||
|
if loaded.Aliases["primary"] != "local-model" {
|
||||||
|
t.Errorf("primary should be updated to local-model, got %q", loaded.Aliases["primary"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSaveAliases_PreservesModels(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
setTestHome(t, tmpDir)
|
||||||
|
|
||||||
|
// First save integration with models
|
||||||
|
if err := SaveIntegration("claude", []string{"model1", "model2"}); err != nil {
|
||||||
|
t.Fatalf("failed to save integration: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Then update aliases
|
||||||
|
aliases := map[string]string{"primary": "new-model"}
|
||||||
|
if err := SaveAliases("claude", aliases); err != nil {
|
||||||
|
t.Fatalf("failed to save aliases: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify models are preserved
|
||||||
|
loaded, err := LoadIntegration("claude")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to load: %v", err)
|
||||||
|
}
|
||||||
|
if len(loaded.Models) != 2 || loaded.Models[0] != "model1" {
|
||||||
|
t.Errorf("models should be preserved, got %v", loaded.Models)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestSaveAliases_EmptyMap clears all aliases
|
||||||
|
func TestSaveAliases_EmptyMap(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
setTestHome(t, tmpDir)
|
||||||
|
|
||||||
|
// Save with aliases
|
||||||
|
if err := SaveAliases("claude", map[string]string{"primary": "model", "fast": "model"}); err != nil {
|
||||||
|
t.Fatalf("failed to save: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Save empty map
|
||||||
|
if err := SaveAliases("claude", map[string]string{}); err != nil {
|
||||||
|
t.Fatalf("failed to save empty: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
loaded, err := LoadIntegration("claude")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to load: %v", err)
|
||||||
|
}
|
||||||
|
if len(loaded.Aliases) != 0 {
|
||||||
|
t.Errorf("aliases should be empty, got %v", loaded.Aliases)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestSaveAliases_NilMap handles nil gracefully
|
||||||
|
func TestSaveAliases_NilMap(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
setTestHome(t, tmpDir)
|
||||||
|
|
||||||
|
// Save with aliases first
|
||||||
|
if err := SaveAliases("claude", map[string]string{"primary": "model"}); err != nil {
|
||||||
|
t.Fatalf("failed to save: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Save nil map - should clear aliases
|
||||||
|
if err := SaveAliases("claude", nil); err != nil {
|
||||||
|
t.Fatalf("failed to save nil: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
loaded, err := LoadIntegration("claude")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to load: %v", err)
|
||||||
|
}
|
||||||
|
if len(loaded.Aliases) > 0 {
|
||||||
|
t.Errorf("aliases should be nil or empty, got %v", loaded.Aliases)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestSaveAliases_EmptyAppName returns error
|
||||||
|
func TestSaveAliases_EmptyAppName(t *testing.T) {
|
||||||
|
err := SaveAliases("", map[string]string{"primary": "model"})
|
||||||
|
if err == nil {
|
||||||
|
t.Error("expected error for empty app name")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSaveAliases_CaseInsensitive(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
setTestHome(t, tmpDir)
|
||||||
|
|
||||||
|
if err := SaveAliases("Claude", map[string]string{"primary": "model1"}); err != nil {
|
||||||
|
t.Fatalf("failed to save: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Load with different case
|
||||||
|
loaded, err := LoadIntegration("claude")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to load: %v", err)
|
||||||
|
}
|
||||||
|
if loaded.Aliases["primary"] != "model1" {
|
||||||
|
t.Errorf("expected primary=model1, got %q", loaded.Aliases["primary"])
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update with different case
|
||||||
|
if err := SaveAliases("CLAUDE", map[string]string{"primary": "model2"}); err != nil {
|
||||||
|
t.Fatalf("failed to update: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
loaded, err = LoadIntegration("claude")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to load after update: %v", err)
|
||||||
|
}
|
||||||
|
if loaded.Aliases["primary"] != "model2" {
|
||||||
|
t.Errorf("expected primary=model2, got %q", loaded.Aliases["primary"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestSaveAliases_CreatesIntegration creates integration if it doesn't exist
|
||||||
|
func TestSaveAliases_CreatesIntegration(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
setTestHome(t, tmpDir)
|
||||||
|
|
||||||
|
// Save aliases for non-existent integration
|
||||||
|
if err := SaveAliases("newintegration", map[string]string{"primary": "model"}); err != nil {
|
||||||
|
t.Fatalf("failed to save: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
loaded, err := LoadIntegration("newintegration")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to load: %v", err)
|
||||||
|
}
|
||||||
|
if loaded.Aliases["primary"] != "model" {
|
||||||
|
t.Errorf("expected primary=model, got %q", loaded.Aliases["primary"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConfigureAliases_AliasMap(t *testing.T) {
|
||||||
|
t.Run("cloud model auto-sets fast to primary", func(t *testing.T) {
|
||||||
|
aliases := make(map[string]string)
|
||||||
|
aliases["primary"] = "cloud-model"
|
||||||
|
|
||||||
|
// Simulate cloud model behavior
|
||||||
|
isCloud := true
|
||||||
|
if isCloud {
|
||||||
|
if aliases["fast"] == "" {
|
||||||
|
aliases["fast"] = aliases["primary"]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if aliases["fast"] != "cloud-model" {
|
||||||
|
t.Errorf("expected fast=cloud-model, got %q", aliases["fast"])
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("cloud model preserves custom fast", func(t *testing.T) {
|
||||||
|
aliases := map[string]string{
|
||||||
|
"primary": "cloud-model",
|
||||||
|
"fast": "custom-fast-model",
|
||||||
|
}
|
||||||
|
|
||||||
|
// Simulate cloud model behavior - should preserve existing fast
|
||||||
|
isCloud := true
|
||||||
|
if isCloud {
|
||||||
|
if aliases["fast"] == "" {
|
||||||
|
aliases["fast"] = aliases["primary"]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if aliases["fast"] != "custom-fast-model" {
|
||||||
|
t.Errorf("expected fast=custom-fast-model (preserved), got %q", aliases["fast"])
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("local model clears fast", func(t *testing.T) {
|
||||||
|
aliases := map[string]string{
|
||||||
|
"primary": "local-model",
|
||||||
|
"fast": "should-be-cleared",
|
||||||
|
}
|
||||||
|
|
||||||
|
// Simulate local model behavior
|
||||||
|
isCloud := false
|
||||||
|
if !isCloud {
|
||||||
|
delete(aliases, "fast")
|
||||||
|
}
|
||||||
|
|
||||||
|
if aliases["fast"] != "" {
|
||||||
|
t.Errorf("expected fast to be cleared, got %q", aliases["fast"])
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("switching cloud to local clears fast", func(t *testing.T) {
|
||||||
|
// Start with cloud config
|
||||||
|
aliases := map[string]string{
|
||||||
|
"primary": "cloud-model",
|
||||||
|
"fast": "cloud-model",
|
||||||
|
}
|
||||||
|
|
||||||
|
// Switch to local
|
||||||
|
aliases["primary"] = "local-model"
|
||||||
|
isCloud := false
|
||||||
|
if !isCloud {
|
||||||
|
delete(aliases, "fast")
|
||||||
|
}
|
||||||
|
|
||||||
|
if aliases["fast"] != "" {
|
||||||
|
t.Errorf("fast should be cleared when switching to local, got %q", aliases["fast"])
|
||||||
|
}
|
||||||
|
if aliases["primary"] != "local-model" {
|
||||||
|
t.Errorf("primary should be updated, got %q", aliases["primary"])
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("switching local to cloud sets fast", func(t *testing.T) {
|
||||||
|
// Start with local config (no fast)
|
||||||
|
aliases := map[string]string{
|
||||||
|
"primary": "local-model",
|
||||||
|
}
|
||||||
|
|
||||||
|
// Switch to cloud
|
||||||
|
aliases["primary"] = "cloud-model"
|
||||||
|
isCloud := true
|
||||||
|
if isCloud {
|
||||||
|
if aliases["fast"] == "" {
|
||||||
|
aliases["fast"] = aliases["primary"]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if aliases["fast"] != "cloud-model" {
|
||||||
|
t.Errorf("fast should be set when switching to cloud, got %q", aliases["fast"])
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSetAliases_PrefixMapping(t *testing.T) {
|
||||||
|
// This tests the expected mapping without needing a real client
|
||||||
|
aliases := map[string]string{
|
||||||
|
"primary": "my-cloud-model",
|
||||||
|
"fast": "my-fast-model",
|
||||||
|
}
|
||||||
|
|
||||||
|
expectedMappings := map[string]string{
|
||||||
|
"claude-sonnet-": aliases["primary"],
|
||||||
|
"claude-haiku-": aliases["fast"],
|
||||||
|
}
|
||||||
|
|
||||||
|
if expectedMappings["claude-sonnet-"] != "my-cloud-model" {
|
||||||
|
t.Errorf("claude-sonnet- should map to primary")
|
||||||
|
}
|
||||||
|
if expectedMappings["claude-haiku-"] != "my-fast-model" {
|
||||||
|
t.Errorf("claude-haiku- should map to fast")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSetAliases_LocalDeletesPrefixes(t *testing.T) {
|
||||||
|
aliases := map[string]string{
|
||||||
|
"primary": "local-model",
|
||||||
|
// fast is empty/missing - indicates local model
|
||||||
|
}
|
||||||
|
|
||||||
|
prefixesToDelete := []string{"claude-sonnet-", "claude-haiku-"}
|
||||||
|
|
||||||
|
// Verify the logic: when fast is empty, we should delete
|
||||||
|
if aliases["fast"] != "" {
|
||||||
|
t.Error("fast should be empty for local model")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify we have the right prefixes to delete
|
||||||
|
if len(prefixesToDelete) != 2 {
|
||||||
|
t.Errorf("expected 2 prefixes to delete, got %d", len(prefixesToDelete))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestAtomicUpdate_ServerFailsConfigNotSaved simulates atomic update behavior
|
||||||
|
func TestAtomicUpdate_ServerFailsConfigNotSaved(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
setTestHome(t, tmpDir)
|
||||||
|
|
||||||
|
// Simulate: server fails, config should NOT be saved
|
||||||
|
serverErr := errors.New("server unavailable")
|
||||||
|
|
||||||
|
if serverErr == nil {
|
||||||
|
t.Error("config should NOT be saved when server fails")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestAtomicUpdate_ServerSucceedsConfigSaved simulates successful atomic update
|
||||||
|
func TestAtomicUpdate_ServerSucceedsConfigSaved(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
setTestHome(t, tmpDir)
|
||||||
|
|
||||||
|
// Simulate: server succeeds, config should be saved
|
||||||
|
var serverErr error
|
||||||
|
if serverErr != nil {
|
||||||
|
t.Fatal("server should succeed")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := SaveAliases("claude", map[string]string{"primary": "model"}); err != nil {
|
||||||
|
t.Fatalf("saveAliases failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify it was actually saved
|
||||||
|
loaded, err := LoadIntegration("claude")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to load: %v", err)
|
||||||
|
}
|
||||||
|
if loaded.Aliases["primary"] != "model" {
|
||||||
|
t.Errorf("expected primary=model, got %q", loaded.Aliases["primary"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConfigFile_PreservesUnknownFields(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
setTestHome(t, tmpDir)
|
||||||
|
|
||||||
|
// Write config with extra fields
|
||||||
|
configPath := filepath.Join(tmpDir, ".ollama", "config.json")
|
||||||
|
os.MkdirAll(filepath.Dir(configPath), 0o755)
|
||||||
|
|
||||||
|
// Note: Our config struct only has Integrations, so top-level unknown fields
|
||||||
|
// won't be preserved by our current implementation. This test documents that.
|
||||||
|
initialConfig := `{
|
||||||
|
"integrations": {
|
||||||
|
"claude": {
|
||||||
|
"models": ["model1"],
|
||||||
|
"aliases": {"primary": "model1"},
|
||||||
|
"unknownField": "should be lost"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"topLevelUnknown": "will be lost"
|
||||||
|
}`
|
||||||
|
os.WriteFile(configPath, []byte(initialConfig), 0o644)
|
||||||
|
|
||||||
|
// Update aliases
|
||||||
|
if err := SaveAliases("claude", map[string]string{"primary": "model2"}); err != nil {
|
||||||
|
t.Fatalf("failed to save: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read raw file to check
|
||||||
|
data, _ := os.ReadFile(configPath)
|
||||||
|
content := string(data)
|
||||||
|
|
||||||
|
// models should be preserved
|
||||||
|
if !contains(content, "model1") {
|
||||||
|
t.Error("models should be preserved")
|
||||||
|
}
|
||||||
|
|
||||||
|
// primary should be updated
|
||||||
|
if !contains(content, "model2") {
|
||||||
|
t.Error("primary should be updated to model2")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func contains(s, substr string) bool {
|
||||||
|
return len(s) >= len(substr) && (s == substr || len(s) > 0 && containsHelper(s, substr))
|
||||||
|
}
|
||||||
|
|
||||||
|
func containsHelper(s, substr string) bool {
|
||||||
|
for i := 0; i <= len(s)-len(substr); i++ {
|
||||||
|
if s[i:i+len(substr)] == substr {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestModelNameEdgeCases(t *testing.T) {
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
model string
|
||||||
|
}{
|
||||||
|
{"simple", "llama3.2"},
|
||||||
|
{"with tag", "llama3.2:latest"},
|
||||||
|
{"with cloud tag", "kimi-k2.5:cloud"},
|
||||||
|
{"with namespace", "library/llama3.2"},
|
||||||
|
{"with dots", "glm-4.7-flash"},
|
||||||
|
{"with numbers", "qwen3:8b"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
setTestHome(t, tmpDir)
|
||||||
|
|
||||||
|
aliases := map[string]string{"primary": tc.model}
|
||||||
|
if err := SaveAliases("claude", aliases); err != nil {
|
||||||
|
t.Fatalf("failed to save model %q: %v", tc.model, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
loaded, err := LoadIntegration("claude")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to load: %v", err)
|
||||||
|
}
|
||||||
|
if loaded.Aliases["primary"] != tc.model {
|
||||||
|
t.Errorf("expected primary=%q, got %q", tc.model, loaded.Aliases["primary"])
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSwitchingScenarios(t *testing.T) {
|
||||||
|
t.Run("cloud to local removes fast", func(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
setTestHome(t, tmpDir)
|
||||||
|
|
||||||
|
// Initial cloud config
|
||||||
|
if err := SaveAliases("claude", map[string]string{
|
||||||
|
"primary": "cloud-model",
|
||||||
|
"fast": "cloud-model",
|
||||||
|
}); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Switch to local (no fast)
|
||||||
|
if err := SaveAliases("claude", map[string]string{
|
||||||
|
"primary": "local-model",
|
||||||
|
}); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
loaded, _ := LoadIntegration("claude")
|
||||||
|
if loaded.Aliases["fast"] != "" {
|
||||||
|
t.Errorf("fast should be removed, got %q", loaded.Aliases["fast"])
|
||||||
|
}
|
||||||
|
if loaded.Aliases["primary"] != "local-model" {
|
||||||
|
t.Errorf("primary should be local-model, got %q", loaded.Aliases["primary"])
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("local to cloud adds fast", func(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
setTestHome(t, tmpDir)
|
||||||
|
|
||||||
|
// Initial local config
|
||||||
|
if err := SaveAliases("claude", map[string]string{
|
||||||
|
"primary": "local-model",
|
||||||
|
}); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Switch to cloud (with fast)
|
||||||
|
if err := SaveAliases("claude", map[string]string{
|
||||||
|
"primary": "cloud-model",
|
||||||
|
"fast": "cloud-model",
|
||||||
|
}); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
loaded, _ := LoadIntegration("claude")
|
||||||
|
if loaded.Aliases["fast"] != "cloud-model" {
|
||||||
|
t.Errorf("fast should be cloud-model, got %q", loaded.Aliases["fast"])
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("cloud to different cloud updates both", func(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
setTestHome(t, tmpDir)
|
||||||
|
|
||||||
|
// Initial cloud config
|
||||||
|
if err := SaveAliases("claude", map[string]string{
|
||||||
|
"primary": "cloud-model-1",
|
||||||
|
"fast": "cloud-model-1",
|
||||||
|
}); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Switch to different cloud
|
||||||
|
if err := SaveAliases("claude", map[string]string{
|
||||||
|
"primary": "cloud-model-2",
|
||||||
|
"fast": "cloud-model-2",
|
||||||
|
}); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
loaded, _ := LoadIntegration("claude")
|
||||||
|
if loaded.Aliases["primary"] != "cloud-model-2" {
|
||||||
|
t.Errorf("primary should be cloud-model-2, got %q", loaded.Aliases["primary"])
|
||||||
|
}
|
||||||
|
if loaded.Aliases["fast"] != "cloud-model-2" {
|
||||||
|
t.Errorf("fast should be cloud-model-2, got %q", loaded.Aliases["fast"])
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestModelsAndAliasesMustStayInSync(t *testing.T) {
|
||||||
|
t.Run("saveAliases followed by saveIntegration keeps them in sync", func(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
setTestHome(t, tmpDir)
|
||||||
|
|
||||||
|
// Save aliases with one model
|
||||||
|
if err := SaveAliases("claude", map[string]string{"primary": "model-a"}); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Save integration with same model (this is the pattern we use)
|
||||||
|
if err := SaveIntegration("claude", []string{"model-a"}); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
loaded, _ := LoadIntegration("claude")
|
||||||
|
if loaded.Aliases["primary"] != loaded.Models[0] {
|
||||||
|
t.Errorf("aliases.primary (%q) != models[0] (%q)", loaded.Aliases["primary"], loaded.Models[0])
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("out of sync config is detectable", func(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
setTestHome(t, tmpDir)
|
||||||
|
|
||||||
|
// Simulate out-of-sync state (like manual edit or bug)
|
||||||
|
if err := SaveIntegration("claude", []string{"old-model"}); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if err := SaveAliases("claude", map[string]string{"primary": "new-model"}); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
loaded, _ := LoadIntegration("claude")
|
||||||
|
|
||||||
|
// They should be different (this is the bug state)
|
||||||
|
if loaded.Models[0] == loaded.Aliases["primary"] {
|
||||||
|
t.Error("expected out-of-sync state for this test")
|
||||||
|
}
|
||||||
|
|
||||||
|
// The fix: when updating aliases, also update models
|
||||||
|
if err := SaveIntegration("claude", []string{loaded.Aliases["primary"]}); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
loaded, _ = LoadIntegration("claude")
|
||||||
|
if loaded.Models[0] != loaded.Aliases["primary"] {
|
||||||
|
t.Errorf("after fix: models[0] (%q) should equal aliases.primary (%q)",
|
||||||
|
loaded.Models[0], loaded.Aliases["primary"])
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("updating primary alias updates models too", func(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
setTestHome(t, tmpDir)
|
||||||
|
|
||||||
|
// Initial state
|
||||||
|
if err := SaveIntegration("claude", []string{"initial-model"}); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if err := SaveAliases("claude", map[string]string{"primary": "initial-model"}); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update aliases AND models together
|
||||||
|
newAliases := map[string]string{"primary": "updated-model"}
|
||||||
|
if err := SaveAliases("claude", newAliases); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if err := SaveIntegration("claude", []string{newAliases["primary"]}); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
loaded, _ := LoadIntegration("claude")
|
||||||
|
if loaded.Models[0] != "updated-model" {
|
||||||
|
t.Errorf("models[0] should be updated-model, got %q", loaded.Models[0])
|
||||||
|
}
|
||||||
|
if loaded.Aliases["primary"] != "updated-model" {
|
||||||
|
t.Errorf("aliases.primary should be updated-model, got %q", loaded.Aliases["primary"])
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
@@ -10,28 +10,21 @@ import (
|
|||||||
// setTestHome sets both HOME (Unix) and USERPROFILE (Windows) for cross-platform tests
|
// setTestHome sets both HOME (Unix) and USERPROFILE (Windows) for cross-platform tests
|
||||||
func setTestHome(t *testing.T, dir string) {
|
func setTestHome(t *testing.T, dir string) {
|
||||||
t.Setenv("HOME", dir)
|
t.Setenv("HOME", dir)
|
||||||
|
t.Setenv("TMPDIR", dir)
|
||||||
t.Setenv("USERPROFILE", dir)
|
t.Setenv("USERPROFILE", dir)
|
||||||
}
|
}
|
||||||
|
|
||||||
// editorPaths is a test helper that safely calls Paths if the runner implements Editor
|
|
||||||
func editorPaths(r Runner) []string {
|
|
||||||
if editor, ok := r.(Editor); ok {
|
|
||||||
return editor.Paths()
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestIntegrationConfig(t *testing.T) {
|
func TestIntegrationConfig(t *testing.T) {
|
||||||
tmpDir := t.TempDir()
|
tmpDir := t.TempDir()
|
||||||
setTestHome(t, tmpDir)
|
setTestHome(t, tmpDir)
|
||||||
|
|
||||||
t.Run("save and load round-trip", func(t *testing.T) {
|
t.Run("save and load round-trip", func(t *testing.T) {
|
||||||
models := []string{"llama3.2", "mistral", "qwen2.5"}
|
models := []string{"llama3.2", "mistral", "qwen2.5"}
|
||||||
if err := saveIntegration("claude", models); err != nil {
|
if err := SaveIntegration("claude", models); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
config, err := loadIntegration("claude")
|
config, err := LoadIntegration("claude")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@@ -46,10 +39,57 @@ func TestIntegrationConfig(t *testing.T) {
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("defaultModel returns first model", func(t *testing.T) {
|
t.Run("save and load aliases", func(t *testing.T) {
|
||||||
saveIntegration("codex", []string{"model-a", "model-b"})
|
models := []string{"llama3.2"}
|
||||||
|
if err := SaveIntegration("claude", models); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
aliases := map[string]string{
|
||||||
|
"primary": "llama3.2:70b",
|
||||||
|
"fast": "llama3.2:8b",
|
||||||
|
}
|
||||||
|
if err := SaveAliases("claude", aliases); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
config, _ := loadIntegration("codex")
|
config, err := LoadIntegration("claude")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if config.Aliases == nil {
|
||||||
|
t.Fatal("expected aliases to be saved")
|
||||||
|
}
|
||||||
|
for k, v := range aliases {
|
||||||
|
if config.Aliases[k] != v {
|
||||||
|
t.Errorf("alias %s: expected %s, got %s", k, v, config.Aliases[k])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("saveIntegration preserves aliases", func(t *testing.T) {
|
||||||
|
if err := SaveIntegration("claude", []string{"model-a"}); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if err := SaveAliases("claude", map[string]string{"primary": "model-a", "fast": "model-small"}); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := SaveIntegration("claude", []string{"model-b"}); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
config, err := LoadIntegration("claude")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if config.Aliases["primary"] != "model-a" {
|
||||||
|
t.Errorf("expected aliases to be preserved, got %v", config.Aliases)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("defaultModel returns first model", func(t *testing.T) {
|
||||||
|
SaveIntegration("codex", []string{"model-a", "model-b"})
|
||||||
|
|
||||||
|
config, _ := LoadIntegration("codex")
|
||||||
defaultModel := ""
|
defaultModel := ""
|
||||||
if len(config.Models) > 0 {
|
if len(config.Models) > 0 {
|
||||||
defaultModel = config.Models[0]
|
defaultModel = config.Models[0]
|
||||||
@@ -71,9 +111,9 @@ func TestIntegrationConfig(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
t.Run("app name is case-insensitive", func(t *testing.T) {
|
t.Run("app name is case-insensitive", func(t *testing.T) {
|
||||||
saveIntegration("Claude", []string{"model-x"})
|
SaveIntegration("Claude", []string{"model-x"})
|
||||||
|
|
||||||
config, err := loadIntegration("claude")
|
config, err := LoadIntegration("claude")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@@ -87,11 +127,11 @@ func TestIntegrationConfig(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
t.Run("multiple integrations in single file", func(t *testing.T) {
|
t.Run("multiple integrations in single file", func(t *testing.T) {
|
||||||
saveIntegration("app1", []string{"model-1"})
|
SaveIntegration("app1", []string{"model-1"})
|
||||||
saveIntegration("app2", []string{"model-2"})
|
SaveIntegration("app2", []string{"model-2"})
|
||||||
|
|
||||||
config1, _ := loadIntegration("app1")
|
config1, _ := LoadIntegration("app1")
|
||||||
config2, _ := loadIntegration("app2")
|
config2, _ := LoadIntegration("app2")
|
||||||
|
|
||||||
defaultModel1 := ""
|
defaultModel1 := ""
|
||||||
if len(config1.Models) > 0 {
|
if len(config1.Models) > 0 {
|
||||||
@@ -125,8 +165,8 @@ func TestListIntegrations(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
t.Run("returns all saved integrations", func(t *testing.T) {
|
t.Run("returns all saved integrations", func(t *testing.T) {
|
||||||
saveIntegration("claude", []string{"model-1"})
|
SaveIntegration("claude", []string{"model-1"})
|
||||||
saveIntegration("droid", []string{"model-2"})
|
SaveIntegration("droid", []string{"model-2"})
|
||||||
|
|
||||||
configs, err := listIntegrations()
|
configs, err := listIntegrations()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -138,75 +178,15 @@ func TestListIntegrations(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestEditorPaths(t *testing.T) {
|
|
||||||
tmpDir := t.TempDir()
|
|
||||||
setTestHome(t, tmpDir)
|
|
||||||
|
|
||||||
t.Run("returns empty for claude (no Editor)", func(t *testing.T) {
|
|
||||||
r := integrations["claude"]
|
|
||||||
paths := editorPaths(r)
|
|
||||||
if len(paths) != 0 {
|
|
||||||
t.Errorf("expected no paths for claude, got %v", paths)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("returns empty for codex (no Editor)", func(t *testing.T) {
|
|
||||||
r := integrations["codex"]
|
|
||||||
paths := editorPaths(r)
|
|
||||||
if len(paths) != 0 {
|
|
||||||
t.Errorf("expected no paths for codex, got %v", paths)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("returns empty for droid when no config exists", func(t *testing.T) {
|
|
||||||
r := integrations["droid"]
|
|
||||||
paths := editorPaths(r)
|
|
||||||
if len(paths) != 0 {
|
|
||||||
t.Errorf("expected no paths, got %v", paths)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("returns path for droid when config exists", func(t *testing.T) {
|
|
||||||
settingsDir, _ := os.UserHomeDir()
|
|
||||||
settingsDir = filepath.Join(settingsDir, ".factory")
|
|
||||||
os.MkdirAll(settingsDir, 0o755)
|
|
||||||
os.WriteFile(filepath.Join(settingsDir, "settings.json"), []byte(`{}`), 0o644)
|
|
||||||
|
|
||||||
r := integrations["droid"]
|
|
||||||
paths := editorPaths(r)
|
|
||||||
if len(paths) != 1 {
|
|
||||||
t.Errorf("expected 1 path, got %d", len(paths))
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("returns paths for opencode when configs exist", func(t *testing.T) {
|
|
||||||
home, _ := os.UserHomeDir()
|
|
||||||
configDir := filepath.Join(home, ".config", "opencode")
|
|
||||||
stateDir := filepath.Join(home, ".local", "state", "opencode")
|
|
||||||
os.MkdirAll(configDir, 0o755)
|
|
||||||
os.MkdirAll(stateDir, 0o755)
|
|
||||||
os.WriteFile(filepath.Join(configDir, "opencode.json"), []byte(`{}`), 0o644)
|
|
||||||
os.WriteFile(filepath.Join(stateDir, "model.json"), []byte(`{}`), 0o644)
|
|
||||||
|
|
||||||
r := integrations["opencode"]
|
|
||||||
paths := editorPaths(r)
|
|
||||||
if len(paths) != 2 {
|
|
||||||
t.Errorf("expected 2 paths, got %d: %v", len(paths), paths)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestLoadIntegration_CorruptedJSON(t *testing.T) {
|
func TestLoadIntegration_CorruptedJSON(t *testing.T) {
|
||||||
tmpDir := t.TempDir()
|
tmpDir := t.TempDir()
|
||||||
setTestHome(t, tmpDir)
|
setTestHome(t, tmpDir)
|
||||||
|
|
||||||
// Create corrupted config.json file
|
dir := filepath.Join(tmpDir, ".ollama")
|
||||||
dir := filepath.Join(tmpDir, ".ollama", "config")
|
|
||||||
os.MkdirAll(dir, 0o755)
|
os.MkdirAll(dir, 0o755)
|
||||||
os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{corrupted json`), 0o644)
|
os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{corrupted json`), 0o644)
|
||||||
|
|
||||||
// Corrupted file is treated as empty, so loadIntegration returns not found
|
_, err := LoadIntegration("test")
|
||||||
_, err := loadIntegration("test")
|
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Error("expected error for nonexistent integration in corrupted file")
|
t.Error("expected error for nonexistent integration in corrupted file")
|
||||||
}
|
}
|
||||||
@@ -216,11 +196,11 @@ func TestSaveIntegration_NilModels(t *testing.T) {
|
|||||||
tmpDir := t.TempDir()
|
tmpDir := t.TempDir()
|
||||||
setTestHome(t, tmpDir)
|
setTestHome(t, tmpDir)
|
||||||
|
|
||||||
if err := saveIntegration("test", nil); err != nil {
|
if err := SaveIntegration("test", nil); err != nil {
|
||||||
t.Fatalf("saveIntegration with nil models failed: %v", err)
|
t.Fatalf("saveIntegration with nil models failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
config, err := loadIntegration("test")
|
config, err := LoadIntegration("test")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("loadIntegration failed: %v", err)
|
t.Fatalf("loadIntegration failed: %v", err)
|
||||||
}
|
}
|
||||||
@@ -236,7 +216,7 @@ func TestSaveIntegration_EmptyAppName(t *testing.T) {
|
|||||||
tmpDir := t.TempDir()
|
tmpDir := t.TempDir()
|
||||||
setTestHome(t, tmpDir)
|
setTestHome(t, tmpDir)
|
||||||
|
|
||||||
err := saveIntegration("", []string{"model"})
|
err := SaveIntegration("", []string{"model"})
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Error("expected error for empty app name, got nil")
|
t.Error("expected error for empty app name, got nil")
|
||||||
}
|
}
|
||||||
@@ -249,7 +229,7 @@ func TestLoadIntegration_NonexistentIntegration(t *testing.T) {
|
|||||||
tmpDir := t.TempDir()
|
tmpDir := t.TempDir()
|
||||||
setTestHome(t, tmpDir)
|
setTestHome(t, tmpDir)
|
||||||
|
|
||||||
_, err := loadIntegration("nonexistent")
|
_, err := LoadIntegration("nonexistent")
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Error("expected error for nonexistent integration, got nil")
|
t.Error("expected error for nonexistent integration, got nil")
|
||||||
}
|
}
|
||||||
@@ -267,7 +247,7 @@ func TestConfigPath(t *testing.T) {
|
|||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
expected := filepath.Join(tmpDir, ".ollama", "config", "config.json")
|
expected := filepath.Join(tmpDir, ".ollama", "config.json")
|
||||||
if path != expected {
|
if path != expected {
|
||||||
t.Errorf("expected %s, got %s", expected, path)
|
t.Errorf("expected %s, got %s", expected, path)
|
||||||
}
|
}
|
||||||
@@ -322,6 +302,183 @@ func TestLoad(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestMigrateConfig(t *testing.T) {
|
||||||
|
t.Run("migrates legacy file to new location", func(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
setTestHome(t, tmpDir)
|
||||||
|
|
||||||
|
legacyDir := filepath.Join(tmpDir, ".ollama", "config")
|
||||||
|
os.MkdirAll(legacyDir, 0o755)
|
||||||
|
data := []byte(`{"integrations":{"claude":{"models":["llama3.2"]}}}`)
|
||||||
|
os.WriteFile(filepath.Join(legacyDir, "config.json"), data, 0o644)
|
||||||
|
|
||||||
|
migrated, err := migrateConfig()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if !migrated {
|
||||||
|
t.Fatal("expected migration to occur")
|
||||||
|
}
|
||||||
|
|
||||||
|
newPath, _ := configPath()
|
||||||
|
got, err := os.ReadFile(newPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("new config not found: %v", err)
|
||||||
|
}
|
||||||
|
if string(got) != string(data) {
|
||||||
|
t.Errorf("content mismatch: got %s", got)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := os.Stat(filepath.Join(legacyDir, "config.json")); !os.IsNotExist(err) {
|
||||||
|
t.Error("legacy file should have been removed")
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := os.Stat(legacyDir); !os.IsNotExist(err) {
|
||||||
|
t.Error("legacy directory should have been removed")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("no-op when no legacy file exists", func(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
setTestHome(t, tmpDir)
|
||||||
|
|
||||||
|
migrated, err := migrateConfig()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if migrated {
|
||||||
|
t.Error("expected no migration")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("skips corrupt legacy file", func(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
setTestHome(t, tmpDir)
|
||||||
|
|
||||||
|
legacyDir := filepath.Join(tmpDir, ".ollama", "config")
|
||||||
|
os.MkdirAll(legacyDir, 0o755)
|
||||||
|
os.WriteFile(filepath.Join(legacyDir, "config.json"), []byte(`{corrupt`), 0o644)
|
||||||
|
|
||||||
|
migrated, err := migrateConfig()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if migrated {
|
||||||
|
t.Error("should not migrate corrupt file")
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := os.Stat(filepath.Join(legacyDir, "config.json")); os.IsNotExist(err) {
|
||||||
|
t.Error("corrupt legacy file should not have been deleted")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("new path takes precedence over legacy", func(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
setTestHome(t, tmpDir)
|
||||||
|
|
||||||
|
legacyDir := filepath.Join(tmpDir, ".ollama", "config")
|
||||||
|
os.MkdirAll(legacyDir, 0o755)
|
||||||
|
os.WriteFile(filepath.Join(legacyDir, "config.json"), []byte(`{"integrations":{"old":{"models":["old-model"]}}}`), 0o644)
|
||||||
|
|
||||||
|
newDir := filepath.Join(tmpDir, ".ollama")
|
||||||
|
os.WriteFile(filepath.Join(newDir, "config.json"), []byte(`{"integrations":{"new":{"models":["new-model"]}}}`), 0o644)
|
||||||
|
|
||||||
|
cfg, err := load()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if _, ok := cfg.Integrations["new"]; !ok {
|
||||||
|
t.Error("expected new-path integration to be loaded")
|
||||||
|
}
|
||||||
|
if _, ok := cfg.Integrations["old"]; ok {
|
||||||
|
t.Error("legacy integration should not have been loaded")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("idempotent when called twice", func(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
setTestHome(t, tmpDir)
|
||||||
|
|
||||||
|
legacyDir := filepath.Join(tmpDir, ".ollama", "config")
|
||||||
|
os.MkdirAll(legacyDir, 0o755)
|
||||||
|
os.WriteFile(filepath.Join(legacyDir, "config.json"), []byte(`{"integrations":{}}`), 0o644)
|
||||||
|
|
||||||
|
if _, err := migrateConfig(); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
migrated, err := migrateConfig()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if migrated {
|
||||||
|
t.Error("second migration should be a no-op")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("legacy directory preserved if not empty", func(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
setTestHome(t, tmpDir)
|
||||||
|
|
||||||
|
legacyDir := filepath.Join(tmpDir, ".ollama", "config")
|
||||||
|
os.MkdirAll(legacyDir, 0o755)
|
||||||
|
os.WriteFile(filepath.Join(legacyDir, "config.json"), []byte(`{"integrations":{}}`), 0o644)
|
||||||
|
os.WriteFile(filepath.Join(legacyDir, "other-file.txt"), []byte("keep me"), 0o644)
|
||||||
|
|
||||||
|
if _, err := migrateConfig(); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := os.Stat(legacyDir); os.IsNotExist(err) {
|
||||||
|
t.Error("directory with other files should not have been removed")
|
||||||
|
}
|
||||||
|
if _, err := os.Stat(filepath.Join(legacyDir, "other-file.txt")); os.IsNotExist(err) {
|
||||||
|
t.Error("other files in legacy directory should be untouched")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("save writes to new path after migration", func(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
setTestHome(t, tmpDir)
|
||||||
|
|
||||||
|
legacyDir := filepath.Join(tmpDir, ".ollama", "config")
|
||||||
|
os.MkdirAll(legacyDir, 0o755)
|
||||||
|
os.WriteFile(filepath.Join(legacyDir, "config.json"), []byte(`{"integrations":{"claude":{"models":["llama3.2"]}}}`), 0o644)
|
||||||
|
|
||||||
|
// load triggers migration, then save should write to new path
|
||||||
|
if err := SaveIntegration("codex", []string{"qwen2.5"}); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
newPath := filepath.Join(tmpDir, ".ollama", "config.json")
|
||||||
|
if _, err := os.Stat(newPath); os.IsNotExist(err) {
|
||||||
|
t.Error("save should write to new path")
|
||||||
|
}
|
||||||
|
|
||||||
|
// old path should not be recreated
|
||||||
|
if _, err := os.Stat(filepath.Join(legacyDir, "config.json")); !os.IsNotExist(err) {
|
||||||
|
t.Error("save should not recreate legacy path")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("load triggers migration transparently", func(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
setTestHome(t, tmpDir)
|
||||||
|
|
||||||
|
legacyDir := filepath.Join(tmpDir, ".ollama", "config")
|
||||||
|
os.MkdirAll(legacyDir, 0o755)
|
||||||
|
os.WriteFile(filepath.Join(legacyDir, "config.json"), []byte(`{"integrations":{"claude":{"models":["llama3.2"]}}}`), 0o644)
|
||||||
|
|
||||||
|
cfg, err := load()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if cfg.Integrations["claude"] == nil || cfg.Integrations["claude"].Models[0] != "llama3.2" {
|
||||||
|
t.Error("migration via load() did not preserve data")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func TestSave(t *testing.T) {
|
func TestSave(t *testing.T) {
|
||||||
tmpDir := t.TempDir()
|
tmpDir := t.TempDir()
|
||||||
setTestHome(t, tmpDir)
|
setTestHome(t, tmpDir)
|
||||||
|
|||||||
@@ -1,355 +0,0 @@
|
|||||||
package config
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"maps"
|
|
||||||
"os"
|
|
||||||
"os/exec"
|
|
||||||
"runtime"
|
|
||||||
"slices"
|
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/ollama/ollama/api"
|
|
||||||
"github.com/spf13/cobra"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Runners execute the launching of a model with the integration - claude, codex
|
|
||||||
// Editors can edit config files (supports multi-model selection) - opencode, droid
|
|
||||||
// They are composable interfaces where in some cases an editor is also a runner - opencode, droid
|
|
||||||
// Runner can run an integration with a model.
|
|
||||||
|
|
||||||
type Runner interface {
|
|
||||||
Run(model string) error
|
|
||||||
// String returns the human-readable name of the integration
|
|
||||||
String() string
|
|
||||||
}
|
|
||||||
|
|
||||||
// Editor can edit config files (supports multi-model selection)
|
|
||||||
type Editor interface {
|
|
||||||
// Paths returns the paths to the config files for the integration
|
|
||||||
Paths() []string
|
|
||||||
// Edit updates the config files for the integration with the given models
|
|
||||||
Edit(models []string) error
|
|
||||||
// Models returns the models currently configured for the integration
|
|
||||||
// TODO(parthsareen): add error return to Models()
|
|
||||||
Models() []string
|
|
||||||
}
|
|
||||||
|
|
||||||
// integrations is the registry of available integrations.
|
|
||||||
var integrations = map[string]Runner{
|
|
||||||
"claude": &Claude{},
|
|
||||||
"clawdbot": &Clawdbot{},
|
|
||||||
"codex": &Codex{},
|
|
||||||
"droid": &Droid{},
|
|
||||||
"opencode": &OpenCode{},
|
|
||||||
}
|
|
||||||
|
|
||||||
func selectIntegration() (string, error) {
|
|
||||||
if len(integrations) == 0 {
|
|
||||||
return "", fmt.Errorf("no integrations available")
|
|
||||||
}
|
|
||||||
|
|
||||||
names := slices.Sorted(maps.Keys(integrations))
|
|
||||||
var items []selectItem
|
|
||||||
for _, name := range names {
|
|
||||||
r := integrations[name]
|
|
||||||
description := r.String()
|
|
||||||
if conn, err := loadIntegration(name); err == nil && len(conn.Models) > 0 {
|
|
||||||
description = fmt.Sprintf("%s (%s)", r.String(), conn.Models[0])
|
|
||||||
}
|
|
||||||
items = append(items, selectItem{Name: name, Description: description})
|
|
||||||
}
|
|
||||||
|
|
||||||
return selectPrompt("Select integration:", items)
|
|
||||||
}
|
|
||||||
|
|
||||||
// selectModels lets the user select models for an integration
|
|
||||||
func selectModels(ctx context.Context, name, current string) ([]string, error) {
|
|
||||||
r, ok := integrations[name]
|
|
||||||
if !ok {
|
|
||||||
return nil, fmt.Errorf("unknown integration: %s", name)
|
|
||||||
}
|
|
||||||
|
|
||||||
client, err := api.ClientFromEnvironment()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
models, err := client.List(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(models.Models) == 0 {
|
|
||||||
return nil, fmt.Errorf("no models available, run 'ollama pull <model>' first")
|
|
||||||
}
|
|
||||||
|
|
||||||
var items []selectItem
|
|
||||||
cloudModels := make(map[string]bool)
|
|
||||||
for _, m := range models.Models {
|
|
||||||
if m.RemoteModel != "" {
|
|
||||||
cloudModels[m.Name] = true
|
|
||||||
}
|
|
||||||
items = append(items, selectItem{Name: m.Name})
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(items) == 0 {
|
|
||||||
return nil, fmt.Errorf("no local models available, run 'ollama pull <model>' first")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get previously configured models (saved config takes precedence)
|
|
||||||
var preChecked []string
|
|
||||||
if saved, err := loadIntegration(name); err == nil {
|
|
||||||
preChecked = saved.Models
|
|
||||||
} else if editor, ok := r.(Editor); ok {
|
|
||||||
preChecked = editor.Models()
|
|
||||||
}
|
|
||||||
checked := make(map[string]bool, len(preChecked))
|
|
||||||
for _, n := range preChecked {
|
|
||||||
checked[n] = true
|
|
||||||
}
|
|
||||||
|
|
||||||
// Resolve current to full name (e.g., "llama3.2" -> "llama3.2:latest")
|
|
||||||
for _, item := range items {
|
|
||||||
if item.Name == current || strings.HasPrefix(item.Name, current+":") {
|
|
||||||
current = item.Name
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// If current model is configured, move to front of preChecked
|
|
||||||
if checked[current] {
|
|
||||||
preChecked = append([]string{current}, slices.DeleteFunc(preChecked, func(m string) bool { return m == current })...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Sort: checked first, then alphabetical
|
|
||||||
slices.SortFunc(items, func(a, b selectItem) int {
|
|
||||||
ac, bc := checked[a.Name], checked[b.Name]
|
|
||||||
if ac != bc {
|
|
||||||
if ac {
|
|
||||||
return -1
|
|
||||||
}
|
|
||||||
return 1
|
|
||||||
}
|
|
||||||
return strings.Compare(strings.ToLower(a.Name), strings.ToLower(b.Name))
|
|
||||||
})
|
|
||||||
|
|
||||||
var selected []string
|
|
||||||
// only editors support multi-model selection
|
|
||||||
if _, ok := r.(Editor); ok {
|
|
||||||
selected, err = multiSelectPrompt(fmt.Sprintf("Select models for %s:", r), items, preChecked)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
model, err := selectPrompt(fmt.Sprintf("Select model for %s:", r), items)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
selected = []string{model}
|
|
||||||
}
|
|
||||||
|
|
||||||
// if any model in selected is a cloud model, ensure signed in
|
|
||||||
var selectedCloudModels []string
|
|
||||||
for _, m := range selected {
|
|
||||||
if cloudModels[m] {
|
|
||||||
selectedCloudModels = append(selectedCloudModels, m)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if len(selectedCloudModels) > 0 {
|
|
||||||
// ensure user is signed in
|
|
||||||
user, err := client.Whoami(ctx)
|
|
||||||
if err == nil && user != nil && user.Name != "" {
|
|
||||||
return selected, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
var aErr api.AuthorizationError
|
|
||||||
if !errors.As(err, &aErr) || aErr.SigninURL == "" {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
modelList := strings.Join(selectedCloudModels, ", ")
|
|
||||||
yes, err := confirmPrompt(fmt.Sprintf("sign in to use %s?", modelList))
|
|
||||||
if err != nil || !yes {
|
|
||||||
return nil, fmt.Errorf("%s requires sign in", modelList)
|
|
||||||
}
|
|
||||||
|
|
||||||
fmt.Fprintf(os.Stderr, "\nTo sign in, navigate to:\n %s\n\n", aErr.SigninURL)
|
|
||||||
|
|
||||||
// TODO(parthsareen): extract into auth package for cmd
|
|
||||||
// Auto-open browser (best effort, fail silently)
|
|
||||||
switch runtime.GOOS {
|
|
||||||
case "darwin":
|
|
||||||
_ = exec.Command("open", aErr.SigninURL).Start()
|
|
||||||
case "linux":
|
|
||||||
_ = exec.Command("xdg-open", aErr.SigninURL).Start()
|
|
||||||
case "windows":
|
|
||||||
_ = exec.Command("rundll32", "url.dll,FileProtocolHandler", aErr.SigninURL).Start()
|
|
||||||
}
|
|
||||||
|
|
||||||
spinnerFrames := []string{"|", "/", "-", "\\"}
|
|
||||||
frame := 0
|
|
||||||
|
|
||||||
fmt.Fprintf(os.Stderr, "\033[90mwaiting for sign in to complete... %s\033[0m", spinnerFrames[0])
|
|
||||||
|
|
||||||
ticker := time.NewTicker(200 * time.Millisecond)
|
|
||||||
defer ticker.Stop()
|
|
||||||
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
fmt.Fprintf(os.Stderr, "\r\033[K")
|
|
||||||
return nil, ctx.Err()
|
|
||||||
case <-ticker.C:
|
|
||||||
frame++
|
|
||||||
fmt.Fprintf(os.Stderr, "\r\033[90mwaiting for sign in to complete... %s\033[0m", spinnerFrames[frame%len(spinnerFrames)])
|
|
||||||
|
|
||||||
// poll every 10th frame (~2 seconds)
|
|
||||||
if frame%10 == 0 {
|
|
||||||
u, err := client.Whoami(ctx)
|
|
||||||
if err == nil && u != nil && u.Name != "" {
|
|
||||||
fmt.Fprintf(os.Stderr, "\r\033[K\033[A\r\033[K\033[1msigned in:\033[0m %s\n", u.Name)
|
|
||||||
return selected, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return selected, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func runIntegration(name, modelName string) error {
|
|
||||||
r, ok := integrations[name]
|
|
||||||
if !ok {
|
|
||||||
return fmt.Errorf("unknown integration: %s", name)
|
|
||||||
}
|
|
||||||
fmt.Fprintf(os.Stderr, "\nLaunching %s with %s...\n", r, modelName)
|
|
||||||
return r.Run(modelName)
|
|
||||||
}
|
|
||||||
|
|
||||||
// LaunchCmd returns the cobra command for launching integrations.
|
|
||||||
func LaunchCmd(checkServerHeartbeat func(cmd *cobra.Command, args []string) error) *cobra.Command {
|
|
||||||
var modelFlag string
|
|
||||||
var configFlag bool
|
|
||||||
|
|
||||||
cmd := &cobra.Command{
|
|
||||||
Use: "launch [INTEGRATION]",
|
|
||||||
Short: "Launch an integration with Ollama",
|
|
||||||
Long: `Launch an integration configured with Ollama models.
|
|
||||||
|
|
||||||
Supported integrations:
|
|
||||||
claude Claude Code
|
|
||||||
clawdbot Clawdbot
|
|
||||||
codex Codex
|
|
||||||
droid Droid
|
|
||||||
opencode OpenCode
|
|
||||||
|
|
||||||
Examples:
|
|
||||||
ollama launch
|
|
||||||
ollama launch claude
|
|
||||||
ollama launch claude --model <model>
|
|
||||||
ollama launch droid --config (does not auto-launch)`,
|
|
||||||
Args: cobra.MaximumNArgs(1),
|
|
||||||
PreRunE: checkServerHeartbeat,
|
|
||||||
RunE: func(cmd *cobra.Command, args []string) error {
|
|
||||||
var name string
|
|
||||||
if len(args) > 0 {
|
|
||||||
name = args[0]
|
|
||||||
} else {
|
|
||||||
var err error
|
|
||||||
name, err = selectIntegration()
|
|
||||||
if errors.Is(err, errCancelled) {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
r, ok := integrations[strings.ToLower(name)]
|
|
||||||
if !ok {
|
|
||||||
return fmt.Errorf("unknown integration: %s", name)
|
|
||||||
}
|
|
||||||
|
|
||||||
// If launching without --model, use saved config if available
|
|
||||||
if !configFlag && modelFlag == "" {
|
|
||||||
if config, err := loadIntegration(name); err == nil && len(config.Models) > 0 {
|
|
||||||
return runIntegration(name, config.Models[0])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
var models []string
|
|
||||||
if modelFlag != "" {
|
|
||||||
// When --model is specified, merge with existing models (new model becomes default)
|
|
||||||
models = []string{modelFlag}
|
|
||||||
if existing, err := loadIntegration(name); err == nil && len(existing.Models) > 0 {
|
|
||||||
for _, m := range existing.Models {
|
|
||||||
if m != modelFlag {
|
|
||||||
models = append(models, m)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
var err error
|
|
||||||
models, err = selectModels(cmd.Context(), name, "")
|
|
||||||
if errors.Is(err, errCancelled) {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if editor, isEditor := r.(Editor); isEditor {
|
|
||||||
paths := editor.Paths()
|
|
||||||
if len(paths) > 0 {
|
|
||||||
fmt.Fprintf(os.Stderr, "This will modify your %s configuration:\n", r)
|
|
||||||
for _, p := range paths {
|
|
||||||
fmt.Fprintf(os.Stderr, " %s\n", p)
|
|
||||||
}
|
|
||||||
fmt.Fprintf(os.Stderr, "Backups will be saved to %s/\n\n", backupDir())
|
|
||||||
|
|
||||||
if ok, _ := confirmPrompt("Proceed?"); !ok {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := saveIntegration(name, models); err != nil {
|
|
||||||
return fmt.Errorf("failed to save: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if editor, isEditor := r.(Editor); isEditor {
|
|
||||||
if err := editor.Edit(models); err != nil {
|
|
||||||
return fmt.Errorf("setup failed: %w", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, isEditor := r.(Editor); isEditor {
|
|
||||||
if len(models) == 1 {
|
|
||||||
fmt.Fprintf(os.Stderr, "Added %s to %s\n", models[0], r)
|
|
||||||
} else {
|
|
||||||
fmt.Fprintf(os.Stderr, "Added %d models to %s (default: %s)\n", len(models), r, models[0])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if configFlag {
|
|
||||||
if launch, _ := confirmPrompt(fmt.Sprintf("\nLaunch %s now?", r)); launch {
|
|
||||||
return runIntegration(name, models[0])
|
|
||||||
}
|
|
||||||
fmt.Fprintf(os.Stderr, "Run 'ollama launch %s' to start with %s\n", strings.ToLower(name), models[0])
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return runIntegration(name, models[0])
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
cmd.Flags().StringVar(&modelFlag, "model", "", "Model to use")
|
|
||||||
cmd.Flags().BoolVar(&configFlag, "config", false, "Configure without launching")
|
|
||||||
return cmd
|
|
||||||
}
|
|
||||||
@@ -1,188 +0,0 @@
|
|||||||
package config
|
|
||||||
|
|
||||||
import (
|
|
||||||
"slices"
|
|
||||||
"strings"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/spf13/cobra"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestIntegrationLookup(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
input string
|
|
||||||
wantFound bool
|
|
||||||
wantName string
|
|
||||||
}{
|
|
||||||
{"claude lowercase", "claude", true, "Claude Code"},
|
|
||||||
{"claude uppercase", "CLAUDE", true, "Claude Code"},
|
|
||||||
{"claude mixed case", "Claude", true, "Claude Code"},
|
|
||||||
{"codex", "codex", true, "Codex"},
|
|
||||||
{"droid", "droid", true, "Droid"},
|
|
||||||
{"opencode", "opencode", true, "OpenCode"},
|
|
||||||
{"unknown integration", "unknown", false, ""},
|
|
||||||
{"empty string", "", false, ""},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
r, found := integrations[strings.ToLower(tt.input)]
|
|
||||||
if found != tt.wantFound {
|
|
||||||
t.Errorf("integrations[%q] found = %v, want %v", tt.input, found, tt.wantFound)
|
|
||||||
}
|
|
||||||
if found && r.String() != tt.wantName {
|
|
||||||
t.Errorf("integrations[%q].String() = %q, want %q", tt.input, r.String(), tt.wantName)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestIntegrationRegistry(t *testing.T) {
|
|
||||||
expectedIntegrations := []string{"claude", "codex", "droid", "opencode"}
|
|
||||||
|
|
||||||
for _, name := range expectedIntegrations {
|
|
||||||
t.Run(name, func(t *testing.T) {
|
|
||||||
r, ok := integrations[name]
|
|
||||||
if !ok {
|
|
||||||
t.Fatalf("integration %q not found in registry", name)
|
|
||||||
}
|
|
||||||
if r.String() == "" {
|
|
||||||
t.Error("integration.String() should not be empty")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestHasLocalModel(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
models []string
|
|
||||||
want bool
|
|
||||||
}{
|
|
||||||
{"empty list", []string{}, false},
|
|
||||||
{"single local model", []string{"llama3.2"}, true},
|
|
||||||
{"single cloud model", []string{"cloud-model"}, false},
|
|
||||||
{"mixed models", []string{"cloud-model", "llama3.2"}, true},
|
|
||||||
{"multiple local models", []string{"llama3.2", "qwen2.5"}, true},
|
|
||||||
{"multiple cloud models", []string{"cloud-a", "cloud-b"}, false},
|
|
||||||
{"local model first", []string{"llama3.2", "cloud-model"}, true},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
got := slices.ContainsFunc(tt.models, func(m string) bool {
|
|
||||||
return !strings.Contains(m, "cloud")
|
|
||||||
})
|
|
||||||
if got != tt.want {
|
|
||||||
t.Errorf("hasLocalModel(%v) = %v, want %v", tt.models, got, tt.want)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestLaunchCmd(t *testing.T) {
|
|
||||||
// Mock checkServerHeartbeat that always succeeds
|
|
||||||
mockCheck := func(cmd *cobra.Command, args []string) error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
cmd := LaunchCmd(mockCheck)
|
|
||||||
|
|
||||||
t.Run("command structure", func(t *testing.T) {
|
|
||||||
if cmd.Use != "launch [INTEGRATION]" {
|
|
||||||
t.Errorf("Use = %q, want %q", cmd.Use, "launch [INTEGRATION]")
|
|
||||||
}
|
|
||||||
if cmd.Short == "" {
|
|
||||||
t.Error("Short description should not be empty")
|
|
||||||
}
|
|
||||||
if cmd.Long == "" {
|
|
||||||
t.Error("Long description should not be empty")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("flags exist", func(t *testing.T) {
|
|
||||||
modelFlag := cmd.Flags().Lookup("model")
|
|
||||||
if modelFlag == nil {
|
|
||||||
t.Error("--model flag should exist")
|
|
||||||
}
|
|
||||||
|
|
||||||
configFlag := cmd.Flags().Lookup("config")
|
|
||||||
if configFlag == nil {
|
|
||||||
t.Error("--config flag should exist")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("PreRunE is set", func(t *testing.T) {
|
|
||||||
if cmd.PreRunE == nil {
|
|
||||||
t.Error("PreRunE should be set to checkServerHeartbeat")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRunIntegration_UnknownIntegration(t *testing.T) {
|
|
||||||
err := runIntegration("unknown-integration", "model")
|
|
||||||
if err == nil {
|
|
||||||
t.Error("expected error for unknown integration, got nil")
|
|
||||||
}
|
|
||||||
if !strings.Contains(err.Error(), "unknown integration") {
|
|
||||||
t.Errorf("error should mention 'unknown integration', got: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestHasLocalModel_DocumentsHeuristic(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
models []string
|
|
||||||
want bool
|
|
||||||
reason string
|
|
||||||
}{
|
|
||||||
{"empty list", []string{}, false, "empty list has no local models"},
|
|
||||||
{"contains-cloud-substring", []string{"deepseek-r1:cloud"}, false, "model with 'cloud' substring is considered cloud"},
|
|
||||||
{"cloud-in-name", []string{"my-cloud-model"}, false, "'cloud' anywhere in name = cloud model"},
|
|
||||||
{"cloudless", []string{"cloudless-model"}, false, "'cloudless' still contains 'cloud'"},
|
|
||||||
{"local-model", []string{"llama3.2"}, true, "no 'cloud' = local"},
|
|
||||||
{"mixed", []string{"cloud-model", "llama3.2"}, true, "one local model = hasLocalModel true"},
|
|
||||||
{"all-cloud", []string{"cloud-a", "cloud-b"}, false, "all contain 'cloud'"},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
got := slices.ContainsFunc(tt.models, func(m string) bool {
|
|
||||||
return !strings.Contains(m, "cloud")
|
|
||||||
})
|
|
||||||
if got != tt.want {
|
|
||||||
t.Errorf("hasLocalModel(%v) = %v, want %v (%s)", tt.models, got, tt.want, tt.reason)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestLaunchCmd_NilHeartbeat(t *testing.T) {
|
|
||||||
// This should not panic - cmd creation should work even with nil
|
|
||||||
cmd := LaunchCmd(nil)
|
|
||||||
if cmd == nil {
|
|
||||||
t.Fatal("LaunchCmd returned nil")
|
|
||||||
}
|
|
||||||
|
|
||||||
// PreRunE should be nil when passed nil
|
|
||||||
if cmd.PreRunE != nil {
|
|
||||||
t.Log("Note: PreRunE is set even when nil is passed (acceptable)")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestAllIntegrations_HaveRequiredMethods(t *testing.T) {
|
|
||||||
for name, r := range integrations {
|
|
||||||
t.Run(name, func(t *testing.T) {
|
|
||||||
// Test String() doesn't panic and returns non-empty
|
|
||||||
displayName := r.String()
|
|
||||||
if displayName == "" {
|
|
||||||
t.Error("String() should not return empty")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Test Run() exists (we can't call it without actually running the command)
|
|
||||||
// Just verify the method is available
|
|
||||||
var _ func(string) error = r.Run
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,499 +0,0 @@
|
|||||||
package config
|
|
||||||
|
|
||||||
import (
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"os"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
"golang.org/x/term"
|
|
||||||
)
|
|
||||||
|
|
||||||
// ANSI escape sequences for terminal formatting.
|
|
||||||
const (
|
|
||||||
ansiHideCursor = "\033[?25l"
|
|
||||||
ansiShowCursor = "\033[?25h"
|
|
||||||
ansiBold = "\033[1m"
|
|
||||||
ansiReset = "\033[0m"
|
|
||||||
ansiGray = "\033[37m"
|
|
||||||
ansiClearDown = "\033[J"
|
|
||||||
)
|
|
||||||
|
|
||||||
const maxDisplayedItems = 10
|
|
||||||
|
|
||||||
var errCancelled = errors.New("cancelled")
|
|
||||||
|
|
||||||
type selectItem struct {
|
|
||||||
Name string
|
|
||||||
Description string
|
|
||||||
}
|
|
||||||
|
|
||||||
type inputEvent int
|
|
||||||
|
|
||||||
const (
|
|
||||||
eventNone inputEvent = iota
|
|
||||||
eventEnter
|
|
||||||
eventEscape
|
|
||||||
eventUp
|
|
||||||
eventDown
|
|
||||||
eventTab
|
|
||||||
eventBackspace
|
|
||||||
eventChar
|
|
||||||
)
|
|
||||||
|
|
||||||
type selectState struct {
|
|
||||||
items []selectItem
|
|
||||||
filter string
|
|
||||||
selected int
|
|
||||||
scrollOffset int
|
|
||||||
}
|
|
||||||
|
|
||||||
func newSelectState(items []selectItem) *selectState {
|
|
||||||
return &selectState{items: items}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *selectState) filtered() []selectItem {
|
|
||||||
return filterItems(s.items, s.filter)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *selectState) handleInput(event inputEvent, char byte) (done bool, result string, err error) {
|
|
||||||
filtered := s.filtered()
|
|
||||||
|
|
||||||
switch event {
|
|
||||||
case eventEnter:
|
|
||||||
if len(filtered) > 0 && s.selected < len(filtered) {
|
|
||||||
return true, filtered[s.selected].Name, nil
|
|
||||||
}
|
|
||||||
case eventEscape:
|
|
||||||
return true, "", errCancelled
|
|
||||||
case eventBackspace:
|
|
||||||
if len(s.filter) > 0 {
|
|
||||||
s.filter = s.filter[:len(s.filter)-1]
|
|
||||||
s.selected = 0
|
|
||||||
s.scrollOffset = 0
|
|
||||||
}
|
|
||||||
case eventUp:
|
|
||||||
if s.selected > 0 {
|
|
||||||
s.selected--
|
|
||||||
if s.selected < s.scrollOffset {
|
|
||||||
s.scrollOffset = s.selected
|
|
||||||
}
|
|
||||||
}
|
|
||||||
case eventDown:
|
|
||||||
if s.selected < len(filtered)-1 {
|
|
||||||
s.selected++
|
|
||||||
if s.selected >= s.scrollOffset+maxDisplayedItems {
|
|
||||||
s.scrollOffset = s.selected - maxDisplayedItems + 1
|
|
||||||
}
|
|
||||||
}
|
|
||||||
case eventChar:
|
|
||||||
s.filter += string(char)
|
|
||||||
s.selected = 0
|
|
||||||
s.scrollOffset = 0
|
|
||||||
}
|
|
||||||
|
|
||||||
return false, "", nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type multiSelectState struct {
|
|
||||||
items []selectItem
|
|
||||||
itemIndex map[string]int
|
|
||||||
filter string
|
|
||||||
highlighted int
|
|
||||||
scrollOffset int
|
|
||||||
checked map[int]bool
|
|
||||||
checkOrder []int
|
|
||||||
focusOnButton bool
|
|
||||||
}
|
|
||||||
|
|
||||||
func newMultiSelectState(items []selectItem, preChecked []string) *multiSelectState {
|
|
||||||
s := &multiSelectState{
|
|
||||||
items: items,
|
|
||||||
itemIndex: make(map[string]int, len(items)),
|
|
||||||
checked: make(map[int]bool),
|
|
||||||
}
|
|
||||||
|
|
||||||
for i, item := range items {
|
|
||||||
s.itemIndex[item.Name] = i
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, name := range preChecked {
|
|
||||||
if idx, ok := s.itemIndex[name]; ok {
|
|
||||||
s.checked[idx] = true
|
|
||||||
s.checkOrder = append(s.checkOrder, idx)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return s
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *multiSelectState) filtered() []selectItem {
|
|
||||||
return filterItems(s.items, s.filter)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *multiSelectState) toggleItem() {
|
|
||||||
filtered := s.filtered()
|
|
||||||
if len(filtered) == 0 || s.highlighted >= len(filtered) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
item := filtered[s.highlighted]
|
|
||||||
origIdx := s.itemIndex[item.Name]
|
|
||||||
|
|
||||||
if s.checked[origIdx] {
|
|
||||||
delete(s.checked, origIdx)
|
|
||||||
for i, idx := range s.checkOrder {
|
|
||||||
if idx == origIdx {
|
|
||||||
s.checkOrder = append(s.checkOrder[:i], s.checkOrder[i+1:]...)
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
s.checked[origIdx] = true
|
|
||||||
s.checkOrder = append(s.checkOrder, origIdx)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *multiSelectState) handleInput(event inputEvent, char byte) (done bool, result []string, err error) {
|
|
||||||
filtered := s.filtered()
|
|
||||||
|
|
||||||
switch event {
|
|
||||||
case eventEnter:
|
|
||||||
if s.focusOnButton && len(s.checkOrder) > 0 {
|
|
||||||
var res []string
|
|
||||||
for _, idx := range s.checkOrder {
|
|
||||||
res = append(res, s.items[idx].Name)
|
|
||||||
}
|
|
||||||
return true, res, nil
|
|
||||||
} else if !s.focusOnButton {
|
|
||||||
s.toggleItem()
|
|
||||||
}
|
|
||||||
case eventTab:
|
|
||||||
if len(s.checkOrder) > 0 {
|
|
||||||
s.focusOnButton = !s.focusOnButton
|
|
||||||
}
|
|
||||||
case eventEscape:
|
|
||||||
return true, nil, errCancelled
|
|
||||||
case eventBackspace:
|
|
||||||
if len(s.filter) > 0 {
|
|
||||||
s.filter = s.filter[:len(s.filter)-1]
|
|
||||||
s.highlighted = 0
|
|
||||||
s.scrollOffset = 0
|
|
||||||
s.focusOnButton = false
|
|
||||||
}
|
|
||||||
case eventUp:
|
|
||||||
if s.focusOnButton {
|
|
||||||
s.focusOnButton = false
|
|
||||||
} else if s.highlighted > 0 {
|
|
||||||
s.highlighted--
|
|
||||||
if s.highlighted < s.scrollOffset {
|
|
||||||
s.scrollOffset = s.highlighted
|
|
||||||
}
|
|
||||||
}
|
|
||||||
case eventDown:
|
|
||||||
if s.focusOnButton {
|
|
||||||
s.focusOnButton = false
|
|
||||||
} else if s.highlighted < len(filtered)-1 {
|
|
||||||
s.highlighted++
|
|
||||||
if s.highlighted >= s.scrollOffset+maxDisplayedItems {
|
|
||||||
s.scrollOffset = s.highlighted - maxDisplayedItems + 1
|
|
||||||
}
|
|
||||||
}
|
|
||||||
case eventChar:
|
|
||||||
s.filter += string(char)
|
|
||||||
s.highlighted = 0
|
|
||||||
s.scrollOffset = 0
|
|
||||||
s.focusOnButton = false
|
|
||||||
}
|
|
||||||
|
|
||||||
return false, nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *multiSelectState) selectedCount() int {
|
|
||||||
return len(s.checkOrder)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Terminal I/O handling
|
|
||||||
|
|
||||||
type terminalState struct {
|
|
||||||
fd int
|
|
||||||
oldState *term.State
|
|
||||||
}
|
|
||||||
|
|
||||||
func enterRawMode() (*terminalState, error) {
|
|
||||||
fd := int(os.Stdin.Fd())
|
|
||||||
oldState, err := term.MakeRaw(fd)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
fmt.Fprint(os.Stderr, ansiHideCursor)
|
|
||||||
return &terminalState{fd: fd, oldState: oldState}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *terminalState) restore() {
|
|
||||||
fmt.Fprint(os.Stderr, ansiShowCursor)
|
|
||||||
term.Restore(t.fd, t.oldState)
|
|
||||||
}
|
|
||||||
|
|
||||||
func clearLines(n int) {
|
|
||||||
if n > 0 {
|
|
||||||
fmt.Fprintf(os.Stderr, "\033[%dA", n)
|
|
||||||
fmt.Fprint(os.Stderr, ansiClearDown)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func parseInput(r io.Reader) (inputEvent, byte, error) {
|
|
||||||
buf := make([]byte, 3)
|
|
||||||
n, err := r.Read(buf)
|
|
||||||
if err != nil {
|
|
||||||
return 0, 0, err
|
|
||||||
}
|
|
||||||
|
|
||||||
switch {
|
|
||||||
case n == 1 && buf[0] == 13:
|
|
||||||
return eventEnter, 0, nil
|
|
||||||
case n == 1 && (buf[0] == 3 || buf[0] == 27):
|
|
||||||
return eventEscape, 0, nil
|
|
||||||
case n == 1 && buf[0] == 9:
|
|
||||||
return eventTab, 0, nil
|
|
||||||
case n == 1 && buf[0] == 127:
|
|
||||||
return eventBackspace, 0, nil
|
|
||||||
case n == 3 && buf[0] == 27 && buf[1] == 91 && buf[2] == 65:
|
|
||||||
return eventUp, 0, nil
|
|
||||||
case n == 3 && buf[0] == 27 && buf[1] == 91 && buf[2] == 66:
|
|
||||||
return eventDown, 0, nil
|
|
||||||
case n == 1 && buf[0] >= 32 && buf[0] < 127:
|
|
||||||
return eventChar, buf[0], nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return eventNone, 0, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Rendering
|
|
||||||
|
|
||||||
func renderSelect(w io.Writer, prompt string, s *selectState) int {
|
|
||||||
filtered := s.filtered()
|
|
||||||
|
|
||||||
fmt.Fprintf(w, "%s %s\r\n", prompt, s.filter)
|
|
||||||
lineCount := 1
|
|
||||||
|
|
||||||
if len(filtered) == 0 {
|
|
||||||
fmt.Fprintf(w, " %s(no matches)%s\r\n", ansiGray, ansiReset)
|
|
||||||
lineCount++
|
|
||||||
} else {
|
|
||||||
displayCount := min(len(filtered), maxDisplayedItems)
|
|
||||||
|
|
||||||
for i := range displayCount {
|
|
||||||
idx := s.scrollOffset + i
|
|
||||||
if idx >= len(filtered) {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
item := filtered[idx]
|
|
||||||
prefix := " "
|
|
||||||
if idx == s.selected {
|
|
||||||
prefix = " " + ansiBold + "> "
|
|
||||||
}
|
|
||||||
if item.Description != "" {
|
|
||||||
fmt.Fprintf(w, "%s%s%s %s- %s%s\r\n", prefix, item.Name, ansiReset, ansiGray, item.Description, ansiReset)
|
|
||||||
} else {
|
|
||||||
fmt.Fprintf(w, "%s%s%s\r\n", prefix, item.Name, ansiReset)
|
|
||||||
}
|
|
||||||
lineCount++
|
|
||||||
}
|
|
||||||
|
|
||||||
if remaining := len(filtered) - s.scrollOffset - displayCount; remaining > 0 {
|
|
||||||
fmt.Fprintf(w, " %s... and %d more%s\r\n", ansiGray, remaining, ansiReset)
|
|
||||||
lineCount++
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return lineCount
|
|
||||||
}
|
|
||||||
|
|
||||||
func renderMultiSelect(w io.Writer, prompt string, s *multiSelectState) int {
|
|
||||||
filtered := s.filtered()
|
|
||||||
|
|
||||||
fmt.Fprintf(w, "%s %s\r\n", prompt, s.filter)
|
|
||||||
lineCount := 1
|
|
||||||
|
|
||||||
if len(filtered) == 0 {
|
|
||||||
fmt.Fprintf(w, " %s(no matches)%s\r\n", ansiGray, ansiReset)
|
|
||||||
lineCount++
|
|
||||||
} else {
|
|
||||||
displayCount := min(len(filtered), maxDisplayedItems)
|
|
||||||
|
|
||||||
for i := range displayCount {
|
|
||||||
idx := s.scrollOffset + i
|
|
||||||
if idx >= len(filtered) {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
item := filtered[idx]
|
|
||||||
origIdx := s.itemIndex[item.Name]
|
|
||||||
|
|
||||||
checkbox := "[ ]"
|
|
||||||
if s.checked[origIdx] {
|
|
||||||
checkbox = "[x]"
|
|
||||||
}
|
|
||||||
|
|
||||||
prefix := " "
|
|
||||||
suffix := ""
|
|
||||||
if idx == s.highlighted && !s.focusOnButton {
|
|
||||||
prefix = "> "
|
|
||||||
}
|
|
||||||
if len(s.checkOrder) > 0 && s.checkOrder[0] == origIdx {
|
|
||||||
suffix = " " + ansiGray + "(default)" + ansiReset
|
|
||||||
}
|
|
||||||
|
|
||||||
if idx == s.highlighted && !s.focusOnButton {
|
|
||||||
fmt.Fprintf(w, " %s%s %s %s%s%s\r\n", ansiBold, prefix, checkbox, item.Name, ansiReset, suffix)
|
|
||||||
} else {
|
|
||||||
fmt.Fprintf(w, " %s %s %s%s\r\n", prefix, checkbox, item.Name, suffix)
|
|
||||||
}
|
|
||||||
lineCount++
|
|
||||||
}
|
|
||||||
|
|
||||||
if remaining := len(filtered) - s.scrollOffset - displayCount; remaining > 0 {
|
|
||||||
fmt.Fprintf(w, " %s... and %d more%s\r\n", ansiGray, remaining, ansiReset)
|
|
||||||
lineCount++
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fmt.Fprintf(w, "\r\n")
|
|
||||||
lineCount++
|
|
||||||
count := s.selectedCount()
|
|
||||||
switch {
|
|
||||||
case count == 0:
|
|
||||||
fmt.Fprintf(w, " %sSelect at least one model.%s\r\n", ansiGray, ansiReset)
|
|
||||||
case s.focusOnButton:
|
|
||||||
fmt.Fprintf(w, " %s> [ Continue ]%s %s(%d selected)%s\r\n", ansiBold, ansiReset, ansiGray, count, ansiReset)
|
|
||||||
default:
|
|
||||||
fmt.Fprintf(w, " %s[ Continue ] (%d selected) - press Tab%s\r\n", ansiGray, count, ansiReset)
|
|
||||||
}
|
|
||||||
lineCount++
|
|
||||||
|
|
||||||
return lineCount
|
|
||||||
}
|
|
||||||
|
|
||||||
// selectPrompt prompts the user to select a single item from a list.
|
|
||||||
func selectPrompt(prompt string, items []selectItem) (string, error) {
|
|
||||||
if len(items) == 0 {
|
|
||||||
return "", fmt.Errorf("no items to select from")
|
|
||||||
}
|
|
||||||
|
|
||||||
ts, err := enterRawMode()
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
defer ts.restore()
|
|
||||||
|
|
||||||
state := newSelectState(items)
|
|
||||||
var lastLineCount int
|
|
||||||
|
|
||||||
render := func() {
|
|
||||||
clearLines(lastLineCount)
|
|
||||||
lastLineCount = renderSelect(os.Stderr, prompt, state)
|
|
||||||
}
|
|
||||||
|
|
||||||
render()
|
|
||||||
|
|
||||||
for {
|
|
||||||
event, char, err := parseInput(os.Stdin)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
done, result, err := state.handleInput(event, char)
|
|
||||||
if done {
|
|
||||||
clearLines(lastLineCount)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
return result, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
render()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// multiSelectPrompt prompts the user to select multiple items from a list.
|
|
||||||
func multiSelectPrompt(prompt string, items []selectItem, preChecked []string) ([]string, error) {
|
|
||||||
if len(items) == 0 {
|
|
||||||
return nil, fmt.Errorf("no items to select from")
|
|
||||||
}
|
|
||||||
|
|
||||||
ts, err := enterRawMode()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
defer ts.restore()
|
|
||||||
|
|
||||||
state := newMultiSelectState(items, preChecked)
|
|
||||||
var lastLineCount int
|
|
||||||
|
|
||||||
render := func() {
|
|
||||||
clearLines(lastLineCount)
|
|
||||||
lastLineCount = renderMultiSelect(os.Stderr, prompt, state)
|
|
||||||
}
|
|
||||||
|
|
||||||
render()
|
|
||||||
|
|
||||||
for {
|
|
||||||
event, char, err := parseInput(os.Stdin)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
done, result, err := state.handleInput(event, char)
|
|
||||||
if done {
|
|
||||||
clearLines(lastLineCount)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return result, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
render()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func confirmPrompt(prompt string) (bool, error) {
|
|
||||||
fd := int(os.Stdin.Fd())
|
|
||||||
oldState, err := term.MakeRaw(fd)
|
|
||||||
if err != nil {
|
|
||||||
return false, err
|
|
||||||
}
|
|
||||||
defer term.Restore(fd, oldState)
|
|
||||||
|
|
||||||
fmt.Fprintf(os.Stderr, "%s (\033[1my\033[0m/n) ", prompt)
|
|
||||||
|
|
||||||
buf := make([]byte, 1)
|
|
||||||
for {
|
|
||||||
if _, err := os.Stdin.Read(buf); err != nil {
|
|
||||||
return false, err
|
|
||||||
}
|
|
||||||
|
|
||||||
switch buf[0] {
|
|
||||||
case 'Y', 'y', 13:
|
|
||||||
fmt.Fprintf(os.Stderr, "yes\r\n")
|
|
||||||
return true, nil
|
|
||||||
case 'N', 'n', 27, 3:
|
|
||||||
fmt.Fprintf(os.Stderr, "no\r\n")
|
|
||||||
return false, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func filterItems(items []selectItem, filter string) []selectItem {
|
|
||||||
if filter == "" {
|
|
||||||
return items
|
|
||||||
}
|
|
||||||
var result []selectItem
|
|
||||||
filterLower := strings.ToLower(filter)
|
|
||||||
for _, item := range items {
|
|
||||||
if strings.Contains(strings.ToLower(item.Name), filterLower) {
|
|
||||||
result = append(result, item)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return result
|
|
||||||
}
|
|
||||||
@@ -1,913 +0,0 @@
|
|||||||
package config
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"strings"
|
|
||||||
"testing"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestFilterItems(t *testing.T) {
|
|
||||||
items := []selectItem{
|
|
||||||
{Name: "llama3.2:latest"},
|
|
||||||
{Name: "qwen2.5:7b"},
|
|
||||||
{Name: "deepseek-v3:cloud"},
|
|
||||||
{Name: "GPT-OSS:20b"},
|
|
||||||
}
|
|
||||||
|
|
||||||
t.Run("EmptyFilter_ReturnsAllItems", func(t *testing.T) {
|
|
||||||
result := filterItems(items, "")
|
|
||||||
if len(result) != len(items) {
|
|
||||||
t.Errorf("expected %d items, got %d", len(items), len(result))
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("CaseInsensitive_UppercaseFilterMatchesLowercase", func(t *testing.T) {
|
|
||||||
result := filterItems(items, "LLAMA")
|
|
||||||
if len(result) != 1 || result[0].Name != "llama3.2:latest" {
|
|
||||||
t.Errorf("expected llama3.2:latest, got %v", result)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("CaseInsensitive_LowercaseFilterMatchesUppercase", func(t *testing.T) {
|
|
||||||
result := filterItems(items, "gpt")
|
|
||||||
if len(result) != 1 || result[0].Name != "GPT-OSS:20b" {
|
|
||||||
t.Errorf("expected GPT-OSS:20b, got %v", result)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("PartialMatch", func(t *testing.T) {
|
|
||||||
result := filterItems(items, "deep")
|
|
||||||
if len(result) != 1 || result[0].Name != "deepseek-v3:cloud" {
|
|
||||||
t.Errorf("expected deepseek-v3:cloud, got %v", result)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("NoMatch_ReturnsEmpty", func(t *testing.T) {
|
|
||||||
result := filterItems(items, "nonexistent")
|
|
||||||
if len(result) != 0 {
|
|
||||||
t.Errorf("expected 0 items, got %d", len(result))
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSelectState(t *testing.T) {
|
|
||||||
items := []selectItem{
|
|
||||||
{Name: "item1"},
|
|
||||||
{Name: "item2"},
|
|
||||||
{Name: "item3"},
|
|
||||||
}
|
|
||||||
|
|
||||||
t.Run("InitialState", func(t *testing.T) {
|
|
||||||
s := newSelectState(items)
|
|
||||||
if s.selected != 0 {
|
|
||||||
t.Errorf("expected selected=0, got %d", s.selected)
|
|
||||||
}
|
|
||||||
if s.filter != "" {
|
|
||||||
t.Errorf("expected empty filter, got %q", s.filter)
|
|
||||||
}
|
|
||||||
if s.scrollOffset != 0 {
|
|
||||||
t.Errorf("expected scrollOffset=0, got %d", s.scrollOffset)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("Enter_SelectsCurrentItem", func(t *testing.T) {
|
|
||||||
s := newSelectState(items)
|
|
||||||
done, result, err := s.handleInput(eventEnter, 0)
|
|
||||||
if !done || result != "item1" || err != nil {
|
|
||||||
t.Errorf("expected (true, item1, nil), got (%v, %v, %v)", done, result, err)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("Enter_WithFilter_SelectsFilteredItem", func(t *testing.T) {
|
|
||||||
s := newSelectState(items)
|
|
||||||
s.filter = "item3"
|
|
||||||
done, result, err := s.handleInput(eventEnter, 0)
|
|
||||||
if !done || result != "item3" || err != nil {
|
|
||||||
t.Errorf("expected (true, item3, nil), got (%v, %v, %v)", done, result, err)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("Enter_EmptyFilteredList_DoesNothing", func(t *testing.T) {
|
|
||||||
s := newSelectState(items)
|
|
||||||
s.filter = "nonexistent"
|
|
||||||
done, result, err := s.handleInput(eventEnter, 0)
|
|
||||||
if done || result != "" || err != nil {
|
|
||||||
t.Errorf("expected (false, '', nil), got (%v, %v, %v)", done, result, err)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("Escape_ReturnsCancelledError", func(t *testing.T) {
|
|
||||||
s := newSelectState(items)
|
|
||||||
done, result, err := s.handleInput(eventEscape, 0)
|
|
||||||
if !done || result != "" || err != errCancelled {
|
|
||||||
t.Errorf("expected (true, '', errCancelled), got (%v, %v, %v)", done, result, err)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("Down_MovesSelection", func(t *testing.T) {
|
|
||||||
s := newSelectState(items)
|
|
||||||
s.handleInput(eventDown, 0)
|
|
||||||
if s.selected != 1 {
|
|
||||||
t.Errorf("expected selected=1, got %d", s.selected)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("Down_AtBottom_StaysAtBottom", func(t *testing.T) {
|
|
||||||
s := newSelectState(items)
|
|
||||||
s.selected = 2
|
|
||||||
s.handleInput(eventDown, 0)
|
|
||||||
if s.selected != 2 {
|
|
||||||
t.Errorf("expected selected=2 (stayed at bottom), got %d", s.selected)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("Up_MovesSelection", func(t *testing.T) {
|
|
||||||
s := newSelectState(items)
|
|
||||||
s.selected = 2
|
|
||||||
s.handleInput(eventUp, 0)
|
|
||||||
if s.selected != 1 {
|
|
||||||
t.Errorf("expected selected=1, got %d", s.selected)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("Up_AtTop_StaysAtTop", func(t *testing.T) {
|
|
||||||
s := newSelectState(items)
|
|
||||||
s.handleInput(eventUp, 0)
|
|
||||||
if s.selected != 0 {
|
|
||||||
t.Errorf("expected selected=0 (stayed at top), got %d", s.selected)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("Char_AppendsToFilter", func(t *testing.T) {
|
|
||||||
s := newSelectState(items)
|
|
||||||
s.handleInput(eventChar, 'i')
|
|
||||||
s.handleInput(eventChar, 't')
|
|
||||||
s.handleInput(eventChar, 'e')
|
|
||||||
s.handleInput(eventChar, 'm')
|
|
||||||
s.handleInput(eventChar, '2')
|
|
||||||
if s.filter != "item2" {
|
|
||||||
t.Errorf("expected filter='item2', got %q", s.filter)
|
|
||||||
}
|
|
||||||
filtered := s.filtered()
|
|
||||||
if len(filtered) != 1 || filtered[0].Name != "item2" {
|
|
||||||
t.Errorf("expected [item2], got %v", filtered)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("Char_ResetsSelectionToZero", func(t *testing.T) {
|
|
||||||
s := newSelectState(items)
|
|
||||||
s.selected = 2
|
|
||||||
s.handleInput(eventChar, 'x')
|
|
||||||
if s.selected != 0 {
|
|
||||||
t.Errorf("expected selected=0 after typing, got %d", s.selected)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("Backspace_RemovesLastFilterChar", func(t *testing.T) {
|
|
||||||
s := newSelectState(items)
|
|
||||||
s.filter = "test"
|
|
||||||
s.handleInput(eventBackspace, 0)
|
|
||||||
if s.filter != "tes" {
|
|
||||||
t.Errorf("expected filter='tes', got %q", s.filter)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("Backspace_EmptyFilter_DoesNothing", func(t *testing.T) {
|
|
||||||
s := newSelectState(items)
|
|
||||||
s.handleInput(eventBackspace, 0)
|
|
||||||
if s.filter != "" {
|
|
||||||
t.Errorf("expected filter='', got %q", s.filter)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("Backspace_ResetsSelectionToZero", func(t *testing.T) {
|
|
||||||
s := newSelectState(items)
|
|
||||||
s.filter = "test"
|
|
||||||
s.selected = 2
|
|
||||||
s.handleInput(eventBackspace, 0)
|
|
||||||
if s.selected != 0 {
|
|
||||||
t.Errorf("expected selected=0 after backspace, got %d", s.selected)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("Scroll_DownPastVisibleItems_ScrollsViewport", func(t *testing.T) {
|
|
||||||
// maxDisplayedItems is 10, so with 15 items we need to scroll
|
|
||||||
manyItems := make([]selectItem, 15)
|
|
||||||
for i := range manyItems {
|
|
||||||
manyItems[i] = selectItem{Name: string(rune('a' + i))}
|
|
||||||
}
|
|
||||||
s := newSelectState(manyItems)
|
|
||||||
|
|
||||||
// move down 12 times (past the 10-item viewport)
|
|
||||||
for range 12 {
|
|
||||||
s.handleInput(eventDown, 0)
|
|
||||||
}
|
|
||||||
|
|
||||||
if s.selected != 12 {
|
|
||||||
t.Errorf("expected selected=12, got %d", s.selected)
|
|
||||||
}
|
|
||||||
if s.scrollOffset != 3 {
|
|
||||||
t.Errorf("expected scrollOffset=3 (12-10+1), got %d", s.scrollOffset)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("Scroll_UpPastScrollOffset_ScrollsViewport", func(t *testing.T) {
|
|
||||||
manyItems := make([]selectItem, 15)
|
|
||||||
for i := range manyItems {
|
|
||||||
manyItems[i] = selectItem{Name: string(rune('a' + i))}
|
|
||||||
}
|
|
||||||
s := newSelectState(manyItems)
|
|
||||||
s.selected = 5
|
|
||||||
s.scrollOffset = 5
|
|
||||||
|
|
||||||
s.handleInput(eventUp, 0)
|
|
||||||
|
|
||||||
if s.selected != 4 {
|
|
||||||
t.Errorf("expected selected=4, got %d", s.selected)
|
|
||||||
}
|
|
||||||
if s.scrollOffset != 4 {
|
|
||||||
t.Errorf("expected scrollOffset=4, got %d", s.scrollOffset)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestMultiSelectState(t *testing.T) {
|
|
||||||
items := []selectItem{
|
|
||||||
{Name: "item1"},
|
|
||||||
{Name: "item2"},
|
|
||||||
{Name: "item3"},
|
|
||||||
}
|
|
||||||
|
|
||||||
t.Run("InitialState_NoPrechecked", func(t *testing.T) {
|
|
||||||
s := newMultiSelectState(items, nil)
|
|
||||||
if s.highlighted != 0 {
|
|
||||||
t.Errorf("expected highlighted=0, got %d", s.highlighted)
|
|
||||||
}
|
|
||||||
if s.selectedCount() != 0 {
|
|
||||||
t.Errorf("expected 0 selected, got %d", s.selectedCount())
|
|
||||||
}
|
|
||||||
if s.focusOnButton {
|
|
||||||
t.Error("expected focusOnButton=false initially")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("InitialState_WithPrechecked", func(t *testing.T) {
|
|
||||||
s := newMultiSelectState(items, []string{"item2", "item3"})
|
|
||||||
if s.selectedCount() != 2 {
|
|
||||||
t.Errorf("expected 2 selected, got %d", s.selectedCount())
|
|
||||||
}
|
|
||||||
if !s.checked[1] || !s.checked[2] {
|
|
||||||
t.Error("expected item2 and item3 to be checked")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("Prechecked_PreservesSelectionOrder", func(t *testing.T) {
|
|
||||||
// order matters: first checked = default model
|
|
||||||
s := newMultiSelectState(items, []string{"item3", "item1"})
|
|
||||||
if len(s.checkOrder) != 2 {
|
|
||||||
t.Fatalf("expected 2 in checkOrder, got %d", len(s.checkOrder))
|
|
||||||
}
|
|
||||||
if s.checkOrder[0] != 2 || s.checkOrder[1] != 0 {
|
|
||||||
t.Errorf("expected checkOrder=[2,0] (item3 first), got %v", s.checkOrder)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("Prechecked_IgnoresInvalidNames", func(t *testing.T) {
|
|
||||||
s := newMultiSelectState(items, []string{"item1", "nonexistent"})
|
|
||||||
if s.selectedCount() != 1 {
|
|
||||||
t.Errorf("expected 1 selected (nonexistent ignored), got %d", s.selectedCount())
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("Toggle_ChecksUncheckedItem", func(t *testing.T) {
|
|
||||||
s := newMultiSelectState(items, nil)
|
|
||||||
s.toggleItem()
|
|
||||||
if !s.checked[0] {
|
|
||||||
t.Error("expected item1 to be checked after toggle")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("Toggle_UnchecksCheckedItem", func(t *testing.T) {
|
|
||||||
s := newMultiSelectState(items, []string{"item1"})
|
|
||||||
s.toggleItem()
|
|
||||||
if s.checked[0] {
|
|
||||||
t.Error("expected item1 to be unchecked after toggle")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("Toggle_RemovesFromCheckOrder", func(t *testing.T) {
|
|
||||||
s := newMultiSelectState(items, []string{"item1", "item2", "item3"})
|
|
||||||
s.highlighted = 1 // toggle item2
|
|
||||||
s.toggleItem()
|
|
||||||
|
|
||||||
if len(s.checkOrder) != 2 {
|
|
||||||
t.Fatalf("expected 2 in checkOrder, got %d", len(s.checkOrder))
|
|
||||||
}
|
|
||||||
// should be [0, 2] (item1, item3) with item2 removed
|
|
||||||
if s.checkOrder[0] != 0 || s.checkOrder[1] != 2 {
|
|
||||||
t.Errorf("expected checkOrder=[0,2], got %v", s.checkOrder)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("Enter_TogglesWhenNotOnButton", func(t *testing.T) {
|
|
||||||
s := newMultiSelectState(items, nil)
|
|
||||||
s.handleInput(eventEnter, 0)
|
|
||||||
if !s.checked[0] {
|
|
||||||
t.Error("expected item1 to be checked after enter")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("Enter_OnButton_ReturnsSelection", func(t *testing.T) {
|
|
||||||
s := newMultiSelectState(items, []string{"item2", "item1"})
|
|
||||||
s.focusOnButton = true
|
|
||||||
|
|
||||||
done, result, err := s.handleInput(eventEnter, 0)
|
|
||||||
|
|
||||||
if !done || err != nil {
|
|
||||||
t.Errorf("expected done=true, err=nil, got done=%v, err=%v", done, err)
|
|
||||||
}
|
|
||||||
// result should preserve selection order
|
|
||||||
if len(result) != 2 || result[0] != "item2" || result[1] != "item1" {
|
|
||||||
t.Errorf("expected [item2, item1], got %v", result)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("Enter_OnButton_EmptySelection_DoesNothing", func(t *testing.T) {
|
|
||||||
s := newMultiSelectState(items, nil)
|
|
||||||
s.focusOnButton = true
|
|
||||||
done, result, err := s.handleInput(eventEnter, 0)
|
|
||||||
if done || result != nil || err != nil {
|
|
||||||
t.Errorf("expected (false, nil, nil), got (%v, %v, %v)", done, result, err)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("Tab_SwitchesToButton_WhenHasSelection", func(t *testing.T) {
|
|
||||||
s := newMultiSelectState(items, []string{"item1"})
|
|
||||||
s.handleInput(eventTab, 0)
|
|
||||||
if !s.focusOnButton {
|
|
||||||
t.Error("expected focus on button after tab")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("Tab_DoesNothing_WhenNoSelection", func(t *testing.T) {
|
|
||||||
s := newMultiSelectState(items, nil)
|
|
||||||
s.handleInput(eventTab, 0)
|
|
||||||
if s.focusOnButton {
|
|
||||||
t.Error("tab should not focus button when nothing selected")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("Tab_TogglesButtonFocus", func(t *testing.T) {
|
|
||||||
s := newMultiSelectState(items, []string{"item1"})
|
|
||||||
s.handleInput(eventTab, 0)
|
|
||||||
if !s.focusOnButton {
|
|
||||||
t.Error("expected focus on button after first tab")
|
|
||||||
}
|
|
||||||
s.handleInput(eventTab, 0)
|
|
||||||
if s.focusOnButton {
|
|
||||||
t.Error("expected focus back on list after second tab")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("Escape_ReturnsCancelledError", func(t *testing.T) {
|
|
||||||
s := newMultiSelectState(items, []string{"item1"})
|
|
||||||
done, result, err := s.handleInput(eventEscape, 0)
|
|
||||||
if !done || result != nil || err != errCancelled {
|
|
||||||
t.Errorf("expected (true, nil, errCancelled), got (%v, %v, %v)", done, result, err)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("IsDefault_TrueForFirstChecked", func(t *testing.T) {
|
|
||||||
s := newMultiSelectState(items, []string{"item2", "item1"})
|
|
||||||
if !(len(s.checkOrder) > 0 && s.checkOrder[0] == 1) {
|
|
||||||
t.Error("expected item2 (idx 1) to be default (first checked)")
|
|
||||||
}
|
|
||||||
if len(s.checkOrder) > 0 && s.checkOrder[0] == 0 {
|
|
||||||
t.Error("expected item1 (idx 0) to NOT be default")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("IsDefault_FalseWhenNothingChecked", func(t *testing.T) {
|
|
||||||
s := newMultiSelectState(items, nil)
|
|
||||||
if len(s.checkOrder) > 0 && s.checkOrder[0] == 0 {
|
|
||||||
t.Error("expected isDefault=false when nothing checked")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("Down_MovesHighlight", func(t *testing.T) {
|
|
||||||
s := newMultiSelectState(items, nil)
|
|
||||||
s.handleInput(eventDown, 0)
|
|
||||||
if s.highlighted != 1 {
|
|
||||||
t.Errorf("expected highlighted=1, got %d", s.highlighted)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("Up_MovesHighlight", func(t *testing.T) {
|
|
||||||
s := newMultiSelectState(items, nil)
|
|
||||||
s.highlighted = 1
|
|
||||||
s.handleInput(eventUp, 0)
|
|
||||||
if s.highlighted != 0 {
|
|
||||||
t.Errorf("expected highlighted=0, got %d", s.highlighted)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("Arrow_ReturnsFocusFromButton", func(t *testing.T) {
|
|
||||||
s := newMultiSelectState(items, []string{"item1"})
|
|
||||||
s.focusOnButton = true
|
|
||||||
s.handleInput(eventDown, 0)
|
|
||||||
if s.focusOnButton {
|
|
||||||
t.Error("expected focus to return to list on arrow key")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("Char_AppendsToFilter", func(t *testing.T) {
|
|
||||||
s := newMultiSelectState(items, nil)
|
|
||||||
s.handleInput(eventChar, 'x')
|
|
||||||
if s.filter != "x" {
|
|
||||||
t.Errorf("expected filter='x', got %q", s.filter)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("Char_ResetsHighlightAndScroll", func(t *testing.T) {
|
|
||||||
manyItems := make([]selectItem, 15)
|
|
||||||
for i := range manyItems {
|
|
||||||
manyItems[i] = selectItem{Name: string(rune('a' + i))}
|
|
||||||
}
|
|
||||||
s := newMultiSelectState(manyItems, nil)
|
|
||||||
s.highlighted = 10
|
|
||||||
s.scrollOffset = 5
|
|
||||||
|
|
||||||
s.handleInput(eventChar, 'x')
|
|
||||||
|
|
||||||
if s.highlighted != 0 {
|
|
||||||
t.Errorf("expected highlighted=0, got %d", s.highlighted)
|
|
||||||
}
|
|
||||||
if s.scrollOffset != 0 {
|
|
||||||
t.Errorf("expected scrollOffset=0, got %d", s.scrollOffset)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("Backspace_RemovesLastFilterChar", func(t *testing.T) {
|
|
||||||
s := newMultiSelectState(items, nil)
|
|
||||||
s.filter = "test"
|
|
||||||
s.handleInput(eventBackspace, 0)
|
|
||||||
if s.filter != "tes" {
|
|
||||||
t.Errorf("expected filter='tes', got %q", s.filter)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("Backspace_RemovesFocusFromButton", func(t *testing.T) {
|
|
||||||
s := newMultiSelectState(items, []string{"item1"})
|
|
||||||
s.filter = "x"
|
|
||||||
s.focusOnButton = true
|
|
||||||
s.handleInput(eventBackspace, 0)
|
|
||||||
if s.focusOnButton {
|
|
||||||
t.Error("expected focusOnButton=false after backspace")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestParseInput(t *testing.T) {
|
|
||||||
t.Run("Enter", func(t *testing.T) {
|
|
||||||
event, char, err := parseInput(bytes.NewReader([]byte{13}))
|
|
||||||
if err != nil || event != eventEnter || char != 0 {
|
|
||||||
t.Errorf("expected (eventEnter, 0, nil), got (%v, %v, %v)", event, char, err)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("Escape", func(t *testing.T) {
|
|
||||||
event, _, err := parseInput(bytes.NewReader([]byte{27}))
|
|
||||||
if err != nil || event != eventEscape {
|
|
||||||
t.Errorf("expected eventEscape, got %v", event)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("CtrlC_TreatedAsEscape", func(t *testing.T) {
|
|
||||||
event, _, err := parseInput(bytes.NewReader([]byte{3}))
|
|
||||||
if err != nil || event != eventEscape {
|
|
||||||
t.Errorf("expected eventEscape for Ctrl+C, got %v", event)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("Tab", func(t *testing.T) {
|
|
||||||
event, _, err := parseInput(bytes.NewReader([]byte{9}))
|
|
||||||
if err != nil || event != eventTab {
|
|
||||||
t.Errorf("expected eventTab, got %v", event)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("Backspace", func(t *testing.T) {
|
|
||||||
event, _, err := parseInput(bytes.NewReader([]byte{127}))
|
|
||||||
if err != nil || event != eventBackspace {
|
|
||||||
t.Errorf("expected eventBackspace, got %v", event)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("UpArrow", func(t *testing.T) {
|
|
||||||
event, _, err := parseInput(bytes.NewReader([]byte{27, 91, 65}))
|
|
||||||
if err != nil || event != eventUp {
|
|
||||||
t.Errorf("expected eventUp, got %v", event)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("DownArrow", func(t *testing.T) {
|
|
||||||
event, _, err := parseInput(bytes.NewReader([]byte{27, 91, 66}))
|
|
||||||
if err != nil || event != eventDown {
|
|
||||||
t.Errorf("expected eventDown, got %v", event)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("PrintableChars", func(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
char byte
|
|
||||||
}{
|
|
||||||
{"lowercase", 'a'},
|
|
||||||
{"uppercase", 'Z'},
|
|
||||||
{"digit", '5'},
|
|
||||||
{"space", ' '},
|
|
||||||
{"tilde", '~'},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
event, char, err := parseInput(bytes.NewReader([]byte{tt.char}))
|
|
||||||
if err != nil || event != eventChar || char != tt.char {
|
|
||||||
t.Errorf("expected (eventChar, %q), got (%v, %q)", tt.char, event, char)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRenderSelect(t *testing.T) {
|
|
||||||
items := []selectItem{
|
|
||||||
{Name: "item1", Description: "first item"},
|
|
||||||
{Name: "item2"},
|
|
||||||
}
|
|
||||||
|
|
||||||
t.Run("ShowsPromptAndItems", func(t *testing.T) {
|
|
||||||
s := newSelectState(items)
|
|
||||||
var buf bytes.Buffer
|
|
||||||
lineCount := renderSelect(&buf, "Select:", s)
|
|
||||||
|
|
||||||
output := buf.String()
|
|
||||||
if !strings.Contains(output, "Select:") {
|
|
||||||
t.Error("expected prompt in output")
|
|
||||||
}
|
|
||||||
if !strings.Contains(output, "item1") {
|
|
||||||
t.Error("expected item1 in output")
|
|
||||||
}
|
|
||||||
if !strings.Contains(output, "first item") {
|
|
||||||
t.Error("expected description in output")
|
|
||||||
}
|
|
||||||
if !strings.Contains(output, "item2") {
|
|
||||||
t.Error("expected item2 in output")
|
|
||||||
}
|
|
||||||
if lineCount != 3 { // 1 prompt + 2 items
|
|
||||||
t.Errorf("expected 3 lines, got %d", lineCount)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("EmptyFilteredList_ShowsNoMatches", func(t *testing.T) {
|
|
||||||
s := newSelectState(items)
|
|
||||||
s.filter = "xyz"
|
|
||||||
var buf bytes.Buffer
|
|
||||||
renderSelect(&buf, "Select:", s)
|
|
||||||
|
|
||||||
if !strings.Contains(buf.String(), "no matches") {
|
|
||||||
t.Error("expected 'no matches' message")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("LongList_ShowsRemainingCount", func(t *testing.T) {
|
|
||||||
manyItems := make([]selectItem, 15)
|
|
||||||
for i := range manyItems {
|
|
||||||
manyItems[i] = selectItem{Name: string(rune('a' + i))}
|
|
||||||
}
|
|
||||||
s := newSelectState(manyItems)
|
|
||||||
var buf bytes.Buffer
|
|
||||||
renderSelect(&buf, "Select:", s)
|
|
||||||
|
|
||||||
// 15 items - 10 displayed = 5 more
|
|
||||||
if !strings.Contains(buf.String(), "5 more") {
|
|
||||||
t.Error("expected '5 more' indicator")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRenderMultiSelect(t *testing.T) {
|
|
||||||
items := []selectItem{
|
|
||||||
{Name: "item1"},
|
|
||||||
{Name: "item2"},
|
|
||||||
}
|
|
||||||
|
|
||||||
t.Run("ShowsCheckboxes", func(t *testing.T) {
|
|
||||||
s := newMultiSelectState(items, []string{"item1"})
|
|
||||||
var buf bytes.Buffer
|
|
||||||
renderMultiSelect(&buf, "Select:", s)
|
|
||||||
|
|
||||||
output := buf.String()
|
|
||||||
if !strings.Contains(output, "[x]") {
|
|
||||||
t.Error("expected checked checkbox [x]")
|
|
||||||
}
|
|
||||||
if !strings.Contains(output, "[ ]") {
|
|
||||||
t.Error("expected unchecked checkbox [ ]")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("ShowsDefaultMarker", func(t *testing.T) {
|
|
||||||
s := newMultiSelectState(items, []string{"item1"})
|
|
||||||
var buf bytes.Buffer
|
|
||||||
renderMultiSelect(&buf, "Select:", s)
|
|
||||||
|
|
||||||
if !strings.Contains(buf.String(), "(default)") {
|
|
||||||
t.Error("expected (default) marker for first checked item")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("ShowsSelectedCount", func(t *testing.T) {
|
|
||||||
s := newMultiSelectState(items, []string{"item1", "item2"})
|
|
||||||
var buf bytes.Buffer
|
|
||||||
renderMultiSelect(&buf, "Select:", s)
|
|
||||||
|
|
||||||
if !strings.Contains(buf.String(), "2 selected") {
|
|
||||||
t.Error("expected '2 selected' in output")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("NoSelection_ShowsHelperText", func(t *testing.T) {
|
|
||||||
s := newMultiSelectState(items, nil)
|
|
||||||
var buf bytes.Buffer
|
|
||||||
renderMultiSelect(&buf, "Select:", s)
|
|
||||||
|
|
||||||
if !strings.Contains(buf.String(), "Select at least one") {
|
|
||||||
t.Error("expected 'Select at least one' helper text")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestErrCancelled(t *testing.T) {
|
|
||||||
t.Run("NotNil", func(t *testing.T) {
|
|
||||||
if errCancelled == nil {
|
|
||||||
t.Error("errCancelled should not be nil")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("Message", func(t *testing.T) {
|
|
||||||
if errCancelled.Error() != "cancelled" {
|
|
||||||
t.Errorf("expected 'cancelled', got %q", errCancelled.Error())
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// Edge case tests for selector.go
|
|
||||||
|
|
||||||
// TestSelectState_SingleItem verifies that single item list works without crash.
|
|
||||||
// List with only one item should still work.
|
|
||||||
func TestSelectState_SingleItem(t *testing.T) {
|
|
||||||
items := []selectItem{{Name: "only-one"}}
|
|
||||||
|
|
||||||
s := newSelectState(items)
|
|
||||||
|
|
||||||
// Down should do nothing (already at bottom)
|
|
||||||
s.handleInput(eventDown, 0)
|
|
||||||
if s.selected != 0 {
|
|
||||||
t.Errorf("down on single item: expected selected=0, got %d", s.selected)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Up should do nothing (already at top)
|
|
||||||
s.handleInput(eventUp, 0)
|
|
||||||
if s.selected != 0 {
|
|
||||||
t.Errorf("up on single item: expected selected=0, got %d", s.selected)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Enter should select the only item
|
|
||||||
done, result, err := s.handleInput(eventEnter, 0)
|
|
||||||
if !done || result != "only-one" || err != nil {
|
|
||||||
t.Errorf("enter on single item: expected (true, 'only-one', nil), got (%v, %q, %v)", done, result, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestSelectState_ExactlyMaxItems verifies boundary condition at maxDisplayedItems.
|
|
||||||
// List with exactly maxDisplayedItems items should not scroll.
|
|
||||||
func TestSelectState_ExactlyMaxItems(t *testing.T) {
|
|
||||||
items := make([]selectItem, maxDisplayedItems)
|
|
||||||
for i := range items {
|
|
||||||
items[i] = selectItem{Name: string(rune('a' + i))}
|
|
||||||
}
|
|
||||||
|
|
||||||
s := newSelectState(items)
|
|
||||||
|
|
||||||
// Move to last item
|
|
||||||
for range maxDisplayedItems - 1 {
|
|
||||||
s.handleInput(eventDown, 0)
|
|
||||||
}
|
|
||||||
|
|
||||||
if s.selected != maxDisplayedItems-1 {
|
|
||||||
t.Errorf("expected selected=%d, got %d", maxDisplayedItems-1, s.selected)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Should not scroll when exactly at max
|
|
||||||
if s.scrollOffset != 0 {
|
|
||||||
t.Errorf("expected scrollOffset=0 for exactly maxDisplayedItems, got %d", s.scrollOffset)
|
|
||||||
}
|
|
||||||
|
|
||||||
// One more down should do nothing
|
|
||||||
s.handleInput(eventDown, 0)
|
|
||||||
if s.selected != maxDisplayedItems-1 {
|
|
||||||
t.Errorf("down at max: expected selected=%d, got %d", maxDisplayedItems-1, s.selected)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestFilterItems_RegexSpecialChars verifies that filter is literal, not regex.
|
|
||||||
// User typing "model.v1" shouldn't match "modelsv1".
|
|
||||||
func TestFilterItems_RegexSpecialChars(t *testing.T) {
|
|
||||||
items := []selectItem{
|
|
||||||
{Name: "model.v1"},
|
|
||||||
{Name: "modelsv1"},
|
|
||||||
{Name: "model-v1"},
|
|
||||||
}
|
|
||||||
|
|
||||||
// Filter with dot should only match literal dot
|
|
||||||
result := filterItems(items, "model.v1")
|
|
||||||
if len(result) != 1 {
|
|
||||||
t.Errorf("expected 1 exact match, got %d", len(result))
|
|
||||||
}
|
|
||||||
if len(result) > 0 && result[0].Name != "model.v1" {
|
|
||||||
t.Errorf("expected 'model.v1', got %s", result[0].Name)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Other regex special chars should be literal too
|
|
||||||
items2 := []selectItem{
|
|
||||||
{Name: "test[0]"},
|
|
||||||
{Name: "test0"},
|
|
||||||
{Name: "test(1)"},
|
|
||||||
}
|
|
||||||
|
|
||||||
result2 := filterItems(items2, "test[0]")
|
|
||||||
if len(result2) != 1 || result2[0].Name != "test[0]" {
|
|
||||||
t.Errorf("expected only 'test[0]', got %v", result2)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestMultiSelectState_DuplicateNames documents handling of duplicate item names.
|
|
||||||
// itemIndex uses name as key - duplicates cause collision. This documents
|
|
||||||
// the current behavior: the last index for a duplicate name is stored
|
|
||||||
func TestMultiSelectState_DuplicateNames(t *testing.T) {
|
|
||||||
// Duplicate names - this is an edge case that shouldn't happen in practice
|
|
||||||
items := []selectItem{
|
|
||||||
{Name: "duplicate"},
|
|
||||||
{Name: "duplicate"},
|
|
||||||
{Name: "unique"},
|
|
||||||
}
|
|
||||||
|
|
||||||
s := newMultiSelectState(items, nil)
|
|
||||||
|
|
||||||
// DOCUMENTED BEHAVIOR: itemIndex maps name to LAST index
|
|
||||||
// When there are duplicates, only the last occurrence's index is stored
|
|
||||||
if s.itemIndex["duplicate"] != 1 {
|
|
||||||
t.Errorf("itemIndex should map 'duplicate' to last index (1), got %d", s.itemIndex["duplicate"])
|
|
||||||
}
|
|
||||||
|
|
||||||
// Toggle item at highlighted=0 (first "duplicate")
|
|
||||||
// Due to name collision, toggleItem uses itemIndex["duplicate"] = 1
|
|
||||||
// So it actually toggles the SECOND duplicate item, not the first
|
|
||||||
s.toggleItem()
|
|
||||||
|
|
||||||
// This documents the potentially surprising behavior:
|
|
||||||
// We toggled at highlighted=0, but itemIndex lookup returned 1
|
|
||||||
if !s.checked[1] {
|
|
||||||
t.Error("toggle should check index 1 (due to name collision in itemIndex)")
|
|
||||||
}
|
|
||||||
if s.checked[0] {
|
|
||||||
t.Log("Note: index 0 is NOT checked, even though highlighted=0 (name collision behavior)")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestSelectState_FilterReducesBelowSelection verifies selection resets when filter reduces list.
|
|
||||||
// Prevents index-out-of-bounds on next keystroke
|
|
||||||
func TestSelectState_FilterReducesBelowSelection(t *testing.T) {
|
|
||||||
items := []selectItem{
|
|
||||||
{Name: "apple"},
|
|
||||||
{Name: "banana"},
|
|
||||||
{Name: "cherry"},
|
|
||||||
}
|
|
||||||
|
|
||||||
s := newSelectState(items)
|
|
||||||
s.selected = 2 // Select "cherry"
|
|
||||||
|
|
||||||
// Type a filter that removes cherry from results
|
|
||||||
s.handleInput(eventChar, 'a') // Filter to "a" - matches "apple" and "banana"
|
|
||||||
|
|
||||||
// Selection should reset to 0
|
|
||||||
if s.selected != 0 {
|
|
||||||
t.Errorf("expected selected=0 after filter, got %d", s.selected)
|
|
||||||
}
|
|
||||||
|
|
||||||
filtered := s.filtered()
|
|
||||||
if len(filtered) != 2 {
|
|
||||||
t.Errorf("expected 2 filtered items, got %d", len(filtered))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestFilterItems_UnicodeCharacters verifies filtering works with UTF-8.
|
|
||||||
// Model names might contain unicode characters
|
|
||||||
func TestFilterItems_UnicodeCharacters(t *testing.T) {
|
|
||||||
items := []selectItem{
|
|
||||||
{Name: "llama-日本語"},
|
|
||||||
{Name: "模型-chinese"},
|
|
||||||
{Name: "émoji-🦙"},
|
|
||||||
{Name: "regular-model"},
|
|
||||||
}
|
|
||||||
|
|
||||||
t.Run("filter japanese", func(t *testing.T) {
|
|
||||||
result := filterItems(items, "日本")
|
|
||||||
if len(result) != 1 || result[0].Name != "llama-日本語" {
|
|
||||||
t.Errorf("expected llama-日本語, got %v", result)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("filter chinese", func(t *testing.T) {
|
|
||||||
result := filterItems(items, "模型")
|
|
||||||
if len(result) != 1 || result[0].Name != "模型-chinese" {
|
|
||||||
t.Errorf("expected 模型-chinese, got %v", result)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("filter emoji", func(t *testing.T) {
|
|
||||||
result := filterItems(items, "🦙")
|
|
||||||
if len(result) != 1 || result[0].Name != "émoji-🦙" {
|
|
||||||
t.Errorf("expected émoji-🦙, got %v", result)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("filter accented char", func(t *testing.T) {
|
|
||||||
result := filterItems(items, "émoji")
|
|
||||||
if len(result) != 1 || result[0].Name != "émoji-🦙" {
|
|
||||||
t.Errorf("expected émoji-🦙, got %v", result)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestMultiSelectState_FilterReducesBelowHighlight verifies highlight resets when filter reduces list.
|
|
||||||
func TestMultiSelectState_FilterReducesBelowHighlight(t *testing.T) {
|
|
||||||
items := []selectItem{
|
|
||||||
{Name: "apple"},
|
|
||||||
{Name: "banana"},
|
|
||||||
{Name: "cherry"},
|
|
||||||
}
|
|
||||||
|
|
||||||
s := newMultiSelectState(items, nil)
|
|
||||||
s.highlighted = 2 // Highlight "cherry"
|
|
||||||
|
|
||||||
// Type a filter that removes cherry
|
|
||||||
s.handleInput(eventChar, 'a')
|
|
||||||
|
|
||||||
if s.highlighted != 0 {
|
|
||||||
t.Errorf("expected highlighted=0 after filter, got %d", s.highlighted)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestMultiSelectState_EmptyItems verifies handling of empty item list.
|
|
||||||
// Empty list should be handled gracefully.
|
|
||||||
func TestMultiSelectState_EmptyItems(t *testing.T) {
|
|
||||||
s := newMultiSelectState([]selectItem{}, nil)
|
|
||||||
|
|
||||||
// Toggle should not panic on empty list
|
|
||||||
s.toggleItem()
|
|
||||||
|
|
||||||
if s.selectedCount() != 0 {
|
|
||||||
t.Errorf("expected 0 selected for empty list, got %d", s.selectedCount())
|
|
||||||
}
|
|
||||||
|
|
||||||
// Render should handle empty list
|
|
||||||
var buf bytes.Buffer
|
|
||||||
lineCount := renderMultiSelect(&buf, "Select:", s)
|
|
||||||
if lineCount == 0 {
|
|
||||||
t.Error("renderMultiSelect should produce output even for empty list")
|
|
||||||
}
|
|
||||||
if !strings.Contains(buf.String(), "no matches") {
|
|
||||||
t.Error("expected 'no matches' for empty list")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestSelectState_RenderWithDescriptions verifies rendering items with descriptions.
|
|
||||||
func TestSelectState_RenderWithDescriptions(t *testing.T) {
|
|
||||||
items := []selectItem{
|
|
||||||
{Name: "item1", Description: "First item description"},
|
|
||||||
{Name: "item2", Description: ""},
|
|
||||||
{Name: "item3", Description: "Third item"},
|
|
||||||
}
|
|
||||||
|
|
||||||
s := newSelectState(items)
|
|
||||||
var buf bytes.Buffer
|
|
||||||
renderSelect(&buf, "Select:", s)
|
|
||||||
|
|
||||||
output := buf.String()
|
|
||||||
if !strings.Contains(output, "First item description") {
|
|
||||||
t.Error("expected description to be rendered")
|
|
||||||
}
|
|
||||||
if !strings.Contains(output, "item2") {
|
|
||||||
t.Error("expected item without description to be rendered")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
5
cmd/editor_unix.go
Normal file
5
cmd/editor_unix.go
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
//go:build !windows

package cmd

// defaultEditor is the last-resort editor used for Ctrl+G prompt editing
// when none of OLLAMA_EDITOR, VISUAL, or EDITOR is set (see
// editInExternalEditor). vi is specified by POSIX, so it is a safe default
// on unix-like systems.
const defaultEditor = "vi"
|
||||||
5
cmd/editor_windows.go
Normal file
5
cmd/editor_windows.go
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
//go:build windows
|
||||||
|
|
||||||
|
package cmd
|
||||||
|
|
||||||
|
const defaultEditor = "edit"
|
||||||
@@ -7,6 +7,7 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
|
"os/exec"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"regexp"
|
"regexp"
|
||||||
"slices"
|
"slices"
|
||||||
@@ -16,6 +17,7 @@ import (
|
|||||||
|
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
"github.com/ollama/ollama/envconfig"
|
"github.com/ollama/ollama/envconfig"
|
||||||
|
"github.com/ollama/ollama/internal/modelref"
|
||||||
"github.com/ollama/ollama/readline"
|
"github.com/ollama/ollama/readline"
|
||||||
"github.com/ollama/ollama/types/errtypes"
|
"github.com/ollama/ollama/types/errtypes"
|
||||||
"github.com/ollama/ollama/types/model"
|
"github.com/ollama/ollama/types/model"
|
||||||
@@ -79,6 +81,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
|
|||||||
fmt.Fprintln(os.Stderr, " Ctrl + w Delete the word before the cursor")
|
fmt.Fprintln(os.Stderr, " Ctrl + w Delete the word before the cursor")
|
||||||
fmt.Fprintln(os.Stderr, "")
|
fmt.Fprintln(os.Stderr, "")
|
||||||
fmt.Fprintln(os.Stderr, " Ctrl + l Clear the screen")
|
fmt.Fprintln(os.Stderr, " Ctrl + l Clear the screen")
|
||||||
|
fmt.Fprintln(os.Stderr, " Ctrl + g Open default editor to compose a prompt")
|
||||||
fmt.Fprintln(os.Stderr, " Ctrl + c Stop the model from responding")
|
fmt.Fprintln(os.Stderr, " Ctrl + c Stop the model from responding")
|
||||||
fmt.Fprintln(os.Stderr, " Ctrl + d Exit ollama (/bye)")
|
fmt.Fprintln(os.Stderr, " Ctrl + d Exit ollama (/bye)")
|
||||||
fmt.Fprintln(os.Stderr, "")
|
fmt.Fprintln(os.Stderr, "")
|
||||||
@@ -147,6 +150,18 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
|
|||||||
scanner.Prompt.UseAlt = false
|
scanner.Prompt.UseAlt = false
|
||||||
sb.Reset()
|
sb.Reset()
|
||||||
|
|
||||||
|
continue
|
||||||
|
case errors.Is(err, readline.ErrEditPrompt):
|
||||||
|
sb.Reset()
|
||||||
|
content, err := editInExternalEditor(line)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "error: %v\n", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(content) == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
scanner.Prefill = content
|
||||||
continue
|
continue
|
||||||
case err != nil:
|
case err != nil:
|
||||||
return err
|
return err
|
||||||
@@ -526,6 +541,13 @@ func NewCreateRequest(name string, opts runOptions) *api.CreateRequest {
|
|||||||
parentModel = ""
|
parentModel = ""
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Preserve explicit cloud intent for sessions started with `:cloud`.
|
||||||
|
// Cloud model metadata can return a source-less parent_model (for example
|
||||||
|
// "qwen3.5"), which would otherwise make `/save` create a local derivative.
|
||||||
|
if modelref.HasExplicitCloudSource(opts.Model) && !modelref.HasExplicitCloudSource(parentModel) {
|
||||||
|
parentModel = ""
|
||||||
|
}
|
||||||
|
|
||||||
req := &api.CreateRequest{
|
req := &api.CreateRequest{
|
||||||
Model: name,
|
Model: name,
|
||||||
From: cmp.Or(parentModel, opts.Model),
|
From: cmp.Or(parentModel, opts.Model),
|
||||||
@@ -598,6 +620,57 @@ func extractFileData(input string) (string, []api.ImageData, error) {
|
|||||||
return strings.TrimSpace(input), imgs, nil
|
return strings.TrimSpace(input), imgs, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func editInExternalEditor(content string) (string, error) {
|
||||||
|
editor := envconfig.Editor()
|
||||||
|
if editor == "" {
|
||||||
|
editor = os.Getenv("VISUAL")
|
||||||
|
}
|
||||||
|
if editor == "" {
|
||||||
|
editor = os.Getenv("EDITOR")
|
||||||
|
}
|
||||||
|
if editor == "" {
|
||||||
|
editor = defaultEditor
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check that the editor binary exists
|
||||||
|
name := strings.Fields(editor)[0]
|
||||||
|
if _, err := exec.LookPath(name); err != nil {
|
||||||
|
return "", fmt.Errorf("editor %q not found, set OLLAMA_EDITOR to the path of your preferred editor", name)
|
||||||
|
}
|
||||||
|
|
||||||
|
tmpFile, err := os.CreateTemp("", "ollama-prompt-*.txt")
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("creating temp file: %w", err)
|
||||||
|
}
|
||||||
|
defer os.Remove(tmpFile.Name())
|
||||||
|
|
||||||
|
if content != "" {
|
||||||
|
if _, err := tmpFile.WriteString(content); err != nil {
|
||||||
|
tmpFile.Close()
|
||||||
|
return "", fmt.Errorf("writing to temp file: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
tmpFile.Close()
|
||||||
|
|
||||||
|
args := strings.Fields(editor)
|
||||||
|
args = append(args, tmpFile.Name())
|
||||||
|
cmd := exec.Command(args[0], args[1:]...)
|
||||||
|
cmd.Stdin = os.Stdin
|
||||||
|
cmd.Stdout = os.Stdout
|
||||||
|
cmd.Stderr = os.Stderr
|
||||||
|
|
||||||
|
if err := cmd.Run(); err != nil {
|
||||||
|
return "", fmt.Errorf("editor exited with error: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := os.ReadFile(tmpFile.Name())
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("reading temp file: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return strings.TrimRight(string(data), "\n"), nil
|
||||||
|
}
|
||||||
|
|
||||||
func getImageData(filePath string) ([]byte, error) {
|
func getImageData(filePath string) ([]byte, error) {
|
||||||
file, err := os.Open(filePath)
|
file, err := os.Open(filePath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -1,4 +1,6 @@
|
|||||||
package config
|
// Package fileutil provides small shared helpers for reading JSON files
|
||||||
|
// and writing config files with backup-on-overwrite semantics.
|
||||||
|
package fileutil
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
@@ -9,7 +11,8 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
func readJSONFile(path string) (map[string]any, error) {
|
// ReadJSON reads a JSON object file into a generic map.
|
||||||
|
func ReadJSON(path string) (map[string]any, error) {
|
||||||
data, err := os.ReadFile(path)
|
data, err := os.ReadFile(path)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -33,12 +36,13 @@ func copyFile(src, dst string) error {
|
|||||||
return os.WriteFile(dst, data, info.Mode().Perm())
|
return os.WriteFile(dst, data, info.Mode().Perm())
|
||||||
}
|
}
|
||||||
|
|
||||||
func backupDir() string {
|
// BackupDir returns the shared backup directory used before overwriting files.
|
||||||
|
func BackupDir() string {
|
||||||
return filepath.Join(os.TempDir(), "ollama-backups")
|
return filepath.Join(os.TempDir(), "ollama-backups")
|
||||||
}
|
}
|
||||||
|
|
||||||
func backupToTmp(srcPath string) (string, error) {
|
func backupToTmp(srcPath string) (string, error) {
|
||||||
dir := backupDir()
|
dir := BackupDir()
|
||||||
if err := os.MkdirAll(dir, 0o755); err != nil {
|
if err := os.MkdirAll(dir, 0o755); err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
@@ -50,8 +54,8 @@ func backupToTmp(srcPath string) (string, error) {
|
|||||||
return backupPath, nil
|
return backupPath, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// writeWithBackup writes data to path via temp file + rename, backing up any existing file first
|
// WriteWithBackup writes data to path via temp file + rename, backing up any existing file first.
|
||||||
func writeWithBackup(path string, data []byte) error {
|
func WriteWithBackup(path string, data []byte) error {
|
||||||
var backupPath string
|
var backupPath string
|
||||||
// backup must be created before any writes to the target file
|
// backup must be created before any writes to the target file
|
||||||
if existingContent, err := os.ReadFile(path); err == nil {
|
if existingContent, err := os.ReadFile(path); err == nil {
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package config
|
package fileutil
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
@@ -9,6 +9,21 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func TestMain(m *testing.M) {
|
||||||
|
tmpRoot, err := os.MkdirTemp("", "fileutil-test-*")
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := os.Setenv("TMPDIR", tmpRoot); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
code := m.Run()
|
||||||
|
_ = os.RemoveAll(tmpRoot)
|
||||||
|
os.Exit(code)
|
||||||
|
}
|
||||||
|
|
||||||
func mustMarshal(t *testing.T, v any) []byte {
|
func mustMarshal(t *testing.T, v any) []byte {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
data, err := json.MarshalIndent(v, "", " ")
|
data, err := json.MarshalIndent(v, "", " ")
|
||||||
@@ -18,14 +33,19 @@ func mustMarshal(t *testing.T, v any) []byte {
|
|||||||
return data
|
return data
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func isolatedTempDir(t *testing.T) string {
|
||||||
|
t.Helper()
|
||||||
|
return t.TempDir()
|
||||||
|
}
|
||||||
|
|
||||||
func TestWriteWithBackup(t *testing.T) {
|
func TestWriteWithBackup(t *testing.T) {
|
||||||
tmpDir := t.TempDir()
|
tmpDir := isolatedTempDir(t)
|
||||||
|
|
||||||
t.Run("creates file", func(t *testing.T) {
|
t.Run("creates file", func(t *testing.T) {
|
||||||
path := filepath.Join(tmpDir, "new.json")
|
path := filepath.Join(tmpDir, "new.json")
|
||||||
data := mustMarshal(t, map[string]string{"key": "value"})
|
data := mustMarshal(t, map[string]string{"key": "value"})
|
||||||
|
|
||||||
if err := writeWithBackup(path, data); err != nil {
|
if err := WriteWithBackup(path, data); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -43,17 +63,17 @@ func TestWriteWithBackup(t *testing.T) {
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("creates backup in /tmp/ollama-backups", func(t *testing.T) {
|
t.Run("creates backup in the temp backup directory", func(t *testing.T) {
|
||||||
path := filepath.Join(tmpDir, "backup.json")
|
path := filepath.Join(tmpDir, "backup.json")
|
||||||
|
|
||||||
os.WriteFile(path, []byte(`{"original": true}`), 0o644)
|
os.WriteFile(path, []byte(`{"original": true}`), 0o644)
|
||||||
|
|
||||||
data := mustMarshal(t, map[string]bool{"updated": true})
|
data := mustMarshal(t, map[string]bool{"updated": true})
|
||||||
if err := writeWithBackup(path, data); err != nil {
|
if err := WriteWithBackup(path, data); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
entries, err := os.ReadDir(backupDir())
|
entries, err := os.ReadDir(BackupDir())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal("backup directory not created")
|
t.Fatal("backup directory not created")
|
||||||
}
|
}
|
||||||
@@ -63,7 +83,7 @@ func TestWriteWithBackup(t *testing.T) {
|
|||||||
if filepath.Ext(entry.Name()) != ".json" {
|
if filepath.Ext(entry.Name()) != ".json" {
|
||||||
name := entry.Name()
|
name := entry.Name()
|
||||||
if len(name) > len("backup.json.") && name[:len("backup.json.")] == "backup.json." {
|
if len(name) > len("backup.json.") && name[:len("backup.json.")] == "backup.json." {
|
||||||
backupPath := filepath.Join(backupDir(), name)
|
backupPath := filepath.Join(BackupDir(), name)
|
||||||
backup, err := os.ReadFile(backupPath)
|
backup, err := os.ReadFile(backupPath)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
var backupData map[string]bool
|
var backupData map[string]bool
|
||||||
@@ -79,7 +99,7 @@ func TestWriteWithBackup(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if !foundBackup {
|
if !foundBackup {
|
||||||
t.Error("backup file not created in /tmp/ollama-backups")
|
t.Error("backup file not created in backup directory")
|
||||||
}
|
}
|
||||||
|
|
||||||
current, _ := os.ReadFile(path)
|
current, _ := os.ReadFile(path)
|
||||||
@@ -94,11 +114,11 @@ func TestWriteWithBackup(t *testing.T) {
|
|||||||
path := filepath.Join(tmpDir, "nobak.json")
|
path := filepath.Join(tmpDir, "nobak.json")
|
||||||
|
|
||||||
data := mustMarshal(t, map[string]string{"new": "file"})
|
data := mustMarshal(t, map[string]string{"new": "file"})
|
||||||
if err := writeWithBackup(path, data); err != nil {
|
if err := WriteWithBackup(path, data); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
entries, _ := os.ReadDir(backupDir())
|
entries, _ := os.ReadDir(BackupDir())
|
||||||
for _, entry := range entries {
|
for _, entry := range entries {
|
||||||
if len(entry.Name()) > len("nobak.json.") && entry.Name()[:len("nobak.json.")] == "nobak.json." {
|
if len(entry.Name()) > len("nobak.json.") && entry.Name()[:len("nobak.json.")] == "nobak.json." {
|
||||||
t.Error("backup should not exist for new file")
|
t.Error("backup should not exist for new file")
|
||||||
@@ -111,11 +131,11 @@ func TestWriteWithBackup(t *testing.T) {
|
|||||||
|
|
||||||
data := mustMarshal(t, map[string]string{"key": "value"})
|
data := mustMarshal(t, map[string]string{"key": "value"})
|
||||||
|
|
||||||
if err := writeWithBackup(path, data); err != nil {
|
if err := WriteWithBackup(path, data); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
entries1, _ := os.ReadDir(backupDir())
|
entries1, _ := os.ReadDir(BackupDir())
|
||||||
countBefore := 0
|
countBefore := 0
|
||||||
for _, e := range entries1 {
|
for _, e := range entries1 {
|
||||||
if len(e.Name()) > len("unchanged.json.") && e.Name()[:len("unchanged.json.")] == "unchanged.json." {
|
if len(e.Name()) > len("unchanged.json.") && e.Name()[:len("unchanged.json.")] == "unchanged.json." {
|
||||||
@@ -123,11 +143,11 @@ func TestWriteWithBackup(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := writeWithBackup(path, data); err != nil {
|
if err := WriteWithBackup(path, data); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
entries2, _ := os.ReadDir(backupDir())
|
entries2, _ := os.ReadDir(BackupDir())
|
||||||
countAfter := 0
|
countAfter := 0
|
||||||
for _, e := range entries2 {
|
for _, e := range entries2 {
|
||||||
if len(e.Name()) > len("unchanged.json.") && e.Name()[:len("unchanged.json.")] == "unchanged.json." {
|
if len(e.Name()) > len("unchanged.json.") && e.Name()[:len("unchanged.json.")] == "unchanged.json." {
|
||||||
@@ -145,11 +165,11 @@ func TestWriteWithBackup(t *testing.T) {
|
|||||||
|
|
||||||
os.WriteFile(path, []byte(`{"v": 1}`), 0o644)
|
os.WriteFile(path, []byte(`{"v": 1}`), 0o644)
|
||||||
data := mustMarshal(t, map[string]int{"v": 2})
|
data := mustMarshal(t, map[string]int{"v": 2})
|
||||||
if err := writeWithBackup(path, data); err != nil {
|
if err := WriteWithBackup(path, data); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
entries, _ := os.ReadDir(backupDir())
|
entries, _ := os.ReadDir(BackupDir())
|
||||||
var found bool
|
var found bool
|
||||||
for _, entry := range entries {
|
for _, entry := range entries {
|
||||||
name := entry.Name()
|
name := entry.Name()
|
||||||
@@ -161,7 +181,7 @@ func TestWriteWithBackup(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
found = true
|
found = true
|
||||||
os.Remove(filepath.Join(backupDir(), name))
|
os.Remove(filepath.Join(BackupDir(), name))
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -180,7 +200,7 @@ func TestWriteWithBackup_FailsIfBackupFails(t *testing.T) {
|
|||||||
t.Skip("permission tests unreliable on Windows")
|
t.Skip("permission tests unreliable on Windows")
|
||||||
}
|
}
|
||||||
|
|
||||||
tmpDir := t.TempDir()
|
tmpDir := isolatedTempDir(t)
|
||||||
path := filepath.Join(tmpDir, "config.json")
|
path := filepath.Join(tmpDir, "config.json")
|
||||||
|
|
||||||
// Create original file
|
// Create original file
|
||||||
@@ -188,13 +208,13 @@ func TestWriteWithBackup_FailsIfBackupFails(t *testing.T) {
|
|||||||
os.WriteFile(path, originalContent, 0o644)
|
os.WriteFile(path, originalContent, 0o644)
|
||||||
|
|
||||||
// Make backup directory read-only to force backup failure
|
// Make backup directory read-only to force backup failure
|
||||||
backupDir := backupDir()
|
backupDir := BackupDir()
|
||||||
os.MkdirAll(backupDir, 0o755)
|
os.MkdirAll(backupDir, 0o755)
|
||||||
os.Chmod(backupDir, 0o444) // Read-only
|
os.Chmod(backupDir, 0o444) // Read-only
|
||||||
defer os.Chmod(backupDir, 0o755)
|
defer os.Chmod(backupDir, 0o755)
|
||||||
|
|
||||||
newContent := []byte(`{"updated": true}`)
|
newContent := []byte(`{"updated": true}`)
|
||||||
err := writeWithBackup(path, newContent)
|
err := WriteWithBackup(path, newContent)
|
||||||
|
|
||||||
// Should fail because backup couldn't be created
|
// Should fail because backup couldn't be created
|
||||||
if err == nil {
|
if err == nil {
|
||||||
@@ -215,7 +235,7 @@ func TestWriteWithBackup_PermissionDenied(t *testing.T) {
|
|||||||
t.Skip("permission tests unreliable on Windows")
|
t.Skip("permission tests unreliable on Windows")
|
||||||
}
|
}
|
||||||
|
|
||||||
tmpDir := t.TempDir()
|
tmpDir := isolatedTempDir(t)
|
||||||
|
|
||||||
// Create a read-only directory
|
// Create a read-only directory
|
||||||
readOnlyDir := filepath.Join(tmpDir, "readonly")
|
readOnlyDir := filepath.Join(tmpDir, "readonly")
|
||||||
@@ -224,7 +244,7 @@ func TestWriteWithBackup_PermissionDenied(t *testing.T) {
|
|||||||
defer os.Chmod(readOnlyDir, 0o755)
|
defer os.Chmod(readOnlyDir, 0o755)
|
||||||
|
|
||||||
path := filepath.Join(readOnlyDir, "config.json")
|
path := filepath.Join(readOnlyDir, "config.json")
|
||||||
err := writeWithBackup(path, []byte(`{"test": true}`))
|
err := WriteWithBackup(path, []byte(`{"test": true}`))
|
||||||
|
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Error("expected permission error, got nil")
|
t.Error("expected permission error, got nil")
|
||||||
@@ -234,10 +254,10 @@ func TestWriteWithBackup_PermissionDenied(t *testing.T) {
|
|||||||
// TestWriteWithBackup_DirectoryDoesNotExist verifies behavior when target directory doesn't exist.
|
// TestWriteWithBackup_DirectoryDoesNotExist verifies behavior when target directory doesn't exist.
|
||||||
// writeWithBackup doesn't create directories - caller is responsible.
|
// writeWithBackup doesn't create directories - caller is responsible.
|
||||||
func TestWriteWithBackup_DirectoryDoesNotExist(t *testing.T) {
|
func TestWriteWithBackup_DirectoryDoesNotExist(t *testing.T) {
|
||||||
tmpDir := t.TempDir()
|
tmpDir := isolatedTempDir(t)
|
||||||
path := filepath.Join(tmpDir, "nonexistent", "subdir", "config.json")
|
path := filepath.Join(tmpDir, "nonexistent", "subdir", "config.json")
|
||||||
|
|
||||||
err := writeWithBackup(path, []byte(`{"test": true}`))
|
err := WriteWithBackup(path, []byte(`{"test": true}`))
|
||||||
|
|
||||||
// Should fail because directory doesn't exist
|
// Should fail because directory doesn't exist
|
||||||
if err == nil {
|
if err == nil {
|
||||||
@@ -252,7 +272,7 @@ func TestWriteWithBackup_SymlinkTarget(t *testing.T) {
|
|||||||
t.Skip("symlink tests may require admin on Windows")
|
t.Skip("symlink tests may require admin on Windows")
|
||||||
}
|
}
|
||||||
|
|
||||||
tmpDir := t.TempDir()
|
tmpDir := isolatedTempDir(t)
|
||||||
realFile := filepath.Join(tmpDir, "real.json")
|
realFile := filepath.Join(tmpDir, "real.json")
|
||||||
symlink := filepath.Join(tmpDir, "link.json")
|
symlink := filepath.Join(tmpDir, "link.json")
|
||||||
|
|
||||||
@@ -261,7 +281,7 @@ func TestWriteWithBackup_SymlinkTarget(t *testing.T) {
|
|||||||
os.Symlink(realFile, symlink)
|
os.Symlink(realFile, symlink)
|
||||||
|
|
||||||
// Write through symlink
|
// Write through symlink
|
||||||
err := writeWithBackup(symlink, []byte(`{"v": 2}`))
|
err := WriteWithBackup(symlink, []byte(`{"v": 2}`))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("writeWithBackup through symlink failed: %v", err)
|
t.Fatalf("writeWithBackup through symlink failed: %v", err)
|
||||||
}
|
}
|
||||||
@@ -276,7 +296,7 @@ func TestWriteWithBackup_SymlinkTarget(t *testing.T) {
|
|||||||
// TestBackupToTmp_SpecialCharsInFilename verifies backup works with special characters.
|
// TestBackupToTmp_SpecialCharsInFilename verifies backup works with special characters.
|
||||||
// User may have config files with unusual names.
|
// User may have config files with unusual names.
|
||||||
func TestBackupToTmp_SpecialCharsInFilename(t *testing.T) {
|
func TestBackupToTmp_SpecialCharsInFilename(t *testing.T) {
|
||||||
tmpDir := t.TempDir()
|
tmpDir := isolatedTempDir(t)
|
||||||
|
|
||||||
// File with spaces and special chars
|
// File with spaces and special chars
|
||||||
path := filepath.Join(tmpDir, "my config (backup).json")
|
path := filepath.Join(tmpDir, "my config (backup).json")
|
||||||
@@ -305,7 +325,7 @@ func TestCopyFile_PreservesPermissions(t *testing.T) {
|
|||||||
t.Skip("permission preservation tests unreliable on Windows")
|
t.Skip("permission preservation tests unreliable on Windows")
|
||||||
}
|
}
|
||||||
|
|
||||||
tmpDir := t.TempDir()
|
tmpDir := isolatedTempDir(t)
|
||||||
src := filepath.Join(tmpDir, "src.json")
|
src := filepath.Join(tmpDir, "src.json")
|
||||||
dst := filepath.Join(tmpDir, "dst.json")
|
dst := filepath.Join(tmpDir, "dst.json")
|
||||||
|
|
||||||
@@ -327,7 +347,7 @@ func TestCopyFile_PreservesPermissions(t *testing.T) {
|
|||||||
|
|
||||||
// TestCopyFile_SourceNotFound verifies clear error when source doesn't exist.
|
// TestCopyFile_SourceNotFound verifies clear error when source doesn't exist.
|
||||||
func TestCopyFile_SourceNotFound(t *testing.T) {
|
func TestCopyFile_SourceNotFound(t *testing.T) {
|
||||||
tmpDir := t.TempDir()
|
tmpDir := isolatedTempDir(t)
|
||||||
src := filepath.Join(tmpDir, "nonexistent.json")
|
src := filepath.Join(tmpDir, "nonexistent.json")
|
||||||
dst := filepath.Join(tmpDir, "dst.json")
|
dst := filepath.Join(tmpDir, "dst.json")
|
||||||
|
|
||||||
@@ -339,11 +359,11 @@ func TestCopyFile_SourceNotFound(t *testing.T) {
|
|||||||
|
|
||||||
// TestWriteWithBackup_TargetIsDirectory verifies error when path points to a directory.
|
// TestWriteWithBackup_TargetIsDirectory verifies error when path points to a directory.
|
||||||
func TestWriteWithBackup_TargetIsDirectory(t *testing.T) {
|
func TestWriteWithBackup_TargetIsDirectory(t *testing.T) {
|
||||||
tmpDir := t.TempDir()
|
tmpDir := isolatedTempDir(t)
|
||||||
dirPath := filepath.Join(tmpDir, "actualdir")
|
dirPath := filepath.Join(tmpDir, "actualdir")
|
||||||
os.MkdirAll(dirPath, 0o755)
|
os.MkdirAll(dirPath, 0o755)
|
||||||
|
|
||||||
err := writeWithBackup(dirPath, []byte(`{"test": true}`))
|
err := WriteWithBackup(dirPath, []byte(`{"test": true}`))
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Error("expected error when target is a directory, got nil")
|
t.Error("expected error when target is a directory, got nil")
|
||||||
}
|
}
|
||||||
@@ -351,10 +371,10 @@ func TestWriteWithBackup_TargetIsDirectory(t *testing.T) {
|
|||||||
|
|
||||||
// TestWriteWithBackup_EmptyData verifies writing zero bytes works correctly.
|
// TestWriteWithBackup_EmptyData verifies writing zero bytes works correctly.
|
||||||
func TestWriteWithBackup_EmptyData(t *testing.T) {
|
func TestWriteWithBackup_EmptyData(t *testing.T) {
|
||||||
tmpDir := t.TempDir()
|
tmpDir := isolatedTempDir(t)
|
||||||
path := filepath.Join(tmpDir, "empty.json")
|
path := filepath.Join(tmpDir, "empty.json")
|
||||||
|
|
||||||
err := writeWithBackup(path, []byte{})
|
err := WriteWithBackup(path, []byte{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("writeWithBackup with empty data failed: %v", err)
|
t.Fatalf("writeWithBackup with empty data failed: %v", err)
|
||||||
}
|
}
|
||||||
@@ -375,7 +395,7 @@ func TestWriteWithBackup_FileUnreadableButDirWritable(t *testing.T) {
|
|||||||
t.Skip("permission tests unreliable on Windows")
|
t.Skip("permission tests unreliable on Windows")
|
||||||
}
|
}
|
||||||
|
|
||||||
tmpDir := t.TempDir()
|
tmpDir := isolatedTempDir(t)
|
||||||
path := filepath.Join(tmpDir, "unreadable.json")
|
path := filepath.Join(tmpDir, "unreadable.json")
|
||||||
|
|
||||||
// Create file and make it unreadable
|
// Create file and make it unreadable
|
||||||
@@ -384,7 +404,7 @@ func TestWriteWithBackup_FileUnreadableButDirWritable(t *testing.T) {
|
|||||||
defer os.Chmod(path, 0o644)
|
defer os.Chmod(path, 0o644)
|
||||||
|
|
||||||
// Should fail because we can't read the file to compare/backup
|
// Should fail because we can't read the file to compare/backup
|
||||||
err := writeWithBackup(path, []byte(`{"updated": true}`))
|
err := WriteWithBackup(path, []byte(`{"updated": true}`))
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Error("expected error when file is unreadable, got nil")
|
t.Error("expected error when file is unreadable, got nil")
|
||||||
}
|
}
|
||||||
@@ -393,7 +413,7 @@ func TestWriteWithBackup_FileUnreadableButDirWritable(t *testing.T) {
|
|||||||
// TestWriteWithBackup_RapidSuccessiveWrites verifies backup works with multiple writes
|
// TestWriteWithBackup_RapidSuccessiveWrites verifies backup works with multiple writes
|
||||||
// within the same second (timestamp collision scenario).
|
// within the same second (timestamp collision scenario).
|
||||||
func TestWriteWithBackup_RapidSuccessiveWrites(t *testing.T) {
|
func TestWriteWithBackup_RapidSuccessiveWrites(t *testing.T) {
|
||||||
tmpDir := t.TempDir()
|
tmpDir := isolatedTempDir(t)
|
||||||
path := filepath.Join(tmpDir, "rapid.json")
|
path := filepath.Join(tmpDir, "rapid.json")
|
||||||
|
|
||||||
// Create initial file
|
// Create initial file
|
||||||
@@ -402,7 +422,7 @@ func TestWriteWithBackup_RapidSuccessiveWrites(t *testing.T) {
|
|||||||
// Rapid successive writes
|
// Rapid successive writes
|
||||||
for i := 1; i <= 3; i++ {
|
for i := 1; i <= 3; i++ {
|
||||||
data := []byte(fmt.Sprintf(`{"v": %d}`, i))
|
data := []byte(fmt.Sprintf(`{"v": %d}`, i))
|
||||||
if err := writeWithBackup(path, data); err != nil {
|
if err := WriteWithBackup(path, data); err != nil {
|
||||||
t.Fatalf("write %d failed: %v", i, err)
|
t.Fatalf("write %d failed: %v", i, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -414,7 +434,7 @@ func TestWriteWithBackup_RapidSuccessiveWrites(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Verify at least one backup exists
|
// Verify at least one backup exists
|
||||||
entries, _ := os.ReadDir(backupDir())
|
entries, _ := os.ReadDir(BackupDir())
|
||||||
var backupCount int
|
var backupCount int
|
||||||
for _, e := range entries {
|
for _, e := range entries {
|
||||||
if len(e.Name()) > len("rapid.json.") && e.Name()[:len("rapid.json.")] == "rapid.json." {
|
if len(e.Name()) > len("rapid.json.") && e.Name()[:len("rapid.json.")] == "rapid.json." {
|
||||||
@@ -432,8 +452,9 @@ func TestWriteWithBackup_BackupDirIsFile(t *testing.T) {
|
|||||||
t.Skip("test modifies system temp directory")
|
t.Skip("test modifies system temp directory")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
tmpDir := isolatedTempDir(t)
|
||||||
// Create a file at the backup directory path
|
// Create a file at the backup directory path
|
||||||
backupPath := backupDir()
|
backupPath := BackupDir()
|
||||||
// Clean up any existing directory first
|
// Clean up any existing directory first
|
||||||
os.RemoveAll(backupPath)
|
os.RemoveAll(backupPath)
|
||||||
// Create a file instead of directory
|
// Create a file instead of directory
|
||||||
@@ -443,11 +464,10 @@ func TestWriteWithBackup_BackupDirIsFile(t *testing.T) {
|
|||||||
os.MkdirAll(backupPath, 0o755)
|
os.MkdirAll(backupPath, 0o755)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
tmpDir := t.TempDir()
|
|
||||||
path := filepath.Join(tmpDir, "test.json")
|
path := filepath.Join(tmpDir, "test.json")
|
||||||
os.WriteFile(path, []byte(`{"original": true}`), 0o644)
|
os.WriteFile(path, []byte(`{"original": true}`), 0o644)
|
||||||
|
|
||||||
err := writeWithBackup(path, []byte(`{"updated": true}`))
|
err := WriteWithBackup(path, []byte(`{"updated": true}`))
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Error("expected error when backup dir is a file, got nil")
|
t.Error("expected error when backup dir is a file, got nil")
|
||||||
}
|
}
|
||||||
@@ -459,7 +479,7 @@ func TestWriteWithBackup_NoOrphanTempFiles(t *testing.T) {
|
|||||||
t.Skip("permission tests unreliable on Windows")
|
t.Skip("permission tests unreliable on Windows")
|
||||||
}
|
}
|
||||||
|
|
||||||
tmpDir := t.TempDir()
|
tmpDir := isolatedTempDir(t)
|
||||||
|
|
||||||
// Count existing temp files
|
// Count existing temp files
|
||||||
countTempFiles := func() int {
|
countTempFiles := func() int {
|
||||||
@@ -493,7 +513,7 @@ func TestWriteWithBackup_NoOrphanTempFiles(t *testing.T) {
|
|||||||
badPath := filepath.Join(tmpDir, "isdir")
|
badPath := filepath.Join(tmpDir, "isdir")
|
||||||
os.MkdirAll(badPath, 0o755)
|
os.MkdirAll(badPath, 0o755)
|
||||||
|
|
||||||
_ = writeWithBackup(badPath, []byte(`{"test": true}`))
|
_ = WriteWithBackup(badPath, []byte(`{"test": true}`))
|
||||||
|
|
||||||
after := countTempFiles()
|
after := countTempFiles()
|
||||||
if after > before {
|
if after > before {
|
||||||
87
cmd/launch/claude.go
Normal file
87
cmd/launch/claude.go
Normal file
@@ -0,0 +1,87 @@
|
|||||||
|
package launch
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"os/exec"
|
||||||
|
"path/filepath"
|
||||||
|
"runtime"
|
||||||
|
"strconv"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/envconfig"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Claude implements Runner for Claude Code integration.
|
||||||
|
type Claude struct{}
|
||||||
|
|
||||||
|
func (c *Claude) String() string { return "Claude Code" }
|
||||||
|
|
||||||
|
func (c *Claude) args(model string, extra []string) []string {
|
||||||
|
var args []string
|
||||||
|
if model != "" {
|
||||||
|
args = append(args, "--model", model)
|
||||||
|
}
|
||||||
|
args = append(args, extra...)
|
||||||
|
return args
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Claude) findPath() (string, error) {
|
||||||
|
if p, err := exec.LookPath("claude"); err == nil {
|
||||||
|
return p, nil
|
||||||
|
}
|
||||||
|
home, err := os.UserHomeDir()
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
name := "claude"
|
||||||
|
if runtime.GOOS == "windows" {
|
||||||
|
name = "claude.exe"
|
||||||
|
}
|
||||||
|
fallback := filepath.Join(home, ".claude", "local", name)
|
||||||
|
if _, err := os.Stat(fallback); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return fallback, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Claude) Run(model string, args []string) error {
|
||||||
|
claudePath, err := c.findPath()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("claude is not installed, install from https://code.claude.com/docs/en/quickstart")
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd := exec.Command(claudePath, c.args(model, args)...)
|
||||||
|
cmd.Stdin = os.Stdin
|
||||||
|
cmd.Stdout = os.Stdout
|
||||||
|
cmd.Stderr = os.Stderr
|
||||||
|
|
||||||
|
env := append(os.Environ(),
|
||||||
|
"ANTHROPIC_BASE_URL="+envconfig.Host().String(),
|
||||||
|
"ANTHROPIC_API_KEY=",
|
||||||
|
"ANTHROPIC_AUTH_TOKEN=ollama",
|
||||||
|
"CLAUDE_CODE_ATTRIBUTION_HEADER=0",
|
||||||
|
)
|
||||||
|
|
||||||
|
env = append(env, c.modelEnvVars(model)...)
|
||||||
|
|
||||||
|
cmd.Env = env
|
||||||
|
return cmd.Run()
|
||||||
|
}
|
||||||
|
|
||||||
|
// modelEnvVars returns Claude Code env vars that route all model tiers through Ollama.
|
||||||
|
func (c *Claude) modelEnvVars(model string) []string {
|
||||||
|
env := []string{
|
||||||
|
"ANTHROPIC_DEFAULT_OPUS_MODEL=" + model,
|
||||||
|
"ANTHROPIC_DEFAULT_SONNET_MODEL=" + model,
|
||||||
|
"ANTHROPIC_DEFAULT_HAIKU_MODEL=" + model,
|
||||||
|
"CLAUDE_CODE_SUBAGENT_MODEL=" + model,
|
||||||
|
}
|
||||||
|
|
||||||
|
if isCloudModelName(model) {
|
||||||
|
if l, ok := lookupCloudModelLimit(model); ok {
|
||||||
|
env = append(env, "CLAUDE_CODE_AUTO_COMPACT_WINDOW="+strconv.Itoa(l.Context))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return env
|
||||||
|
}
|
||||||
171
cmd/launch/claude_test.go
Normal file
171
cmd/launch/claude_test.go
Normal file
@@ -0,0 +1,171 @@
|
|||||||
|
package launch
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"runtime"
|
||||||
|
"slices"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestClaudeIntegration(t *testing.T) {
|
||||||
|
c := &Claude{}
|
||||||
|
|
||||||
|
t.Run("String", func(t *testing.T) {
|
||||||
|
if got := c.String(); got != "Claude Code" {
|
||||||
|
t.Errorf("String() = %q, want %q", got, "Claude Code")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("implements Runner", func(t *testing.T) {
|
||||||
|
var _ Runner = c
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClaudeFindPath(t *testing.T) {
|
||||||
|
c := &Claude{}
|
||||||
|
|
||||||
|
t.Run("finds claude in PATH", func(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
name := "claude"
|
||||||
|
if runtime.GOOS == "windows" {
|
||||||
|
name = "claude.exe"
|
||||||
|
}
|
||||||
|
fakeBin := filepath.Join(tmpDir, name)
|
||||||
|
os.WriteFile(fakeBin, []byte("#!/bin/sh\n"), 0o755)
|
||||||
|
t.Setenv("PATH", tmpDir)
|
||||||
|
|
||||||
|
got, err := c.findPath()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if got != fakeBin {
|
||||||
|
t.Errorf("findPath() = %q, want %q", got, fakeBin)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("falls back to ~/.claude/local/claude", func(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
setTestHome(t, tmpDir)
|
||||||
|
t.Setenv("PATH", t.TempDir()) // empty dir, no claude binary
|
||||||
|
|
||||||
|
name := "claude"
|
||||||
|
if runtime.GOOS == "windows" {
|
||||||
|
name = "claude.exe"
|
||||||
|
}
|
||||||
|
fallback := filepath.Join(tmpDir, ".claude", "local", name)
|
||||||
|
os.MkdirAll(filepath.Dir(fallback), 0o755)
|
||||||
|
os.WriteFile(fallback, []byte("#!/bin/sh\n"), 0o755)
|
||||||
|
|
||||||
|
got, err := c.findPath()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if got != fallback {
|
||||||
|
t.Errorf("findPath() = %q, want %q", got, fallback)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("returns error when neither PATH nor fallback exists", func(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
setTestHome(t, tmpDir)
|
||||||
|
t.Setenv("PATH", t.TempDir()) // empty dir, no claude binary
|
||||||
|
|
||||||
|
_, err := c.findPath()
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error, got nil")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClaudeArgs(t *testing.T) {
|
||||||
|
c := &Claude{}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
model string
|
||||||
|
args []string
|
||||||
|
want []string
|
||||||
|
}{
|
||||||
|
{"with model", "llama3.2", nil, []string{"--model", "llama3.2"}},
|
||||||
|
{"empty model", "", nil, nil},
|
||||||
|
{"with model and verbose", "llama3.2", []string{"--verbose"}, []string{"--model", "llama3.2", "--verbose"}},
|
||||||
|
{"empty model with help", "", []string{"--help"}, []string{"--help"}},
|
||||||
|
{"with allowed tools", "llama3.2", []string{"--allowedTools", "Read,Write,Bash"}, []string{"--model", "llama3.2", "--allowedTools", "Read,Write,Bash"}},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := c.args(tt.model, tt.args)
|
||||||
|
if !slices.Equal(got, tt.want) {
|
||||||
|
t.Errorf("args(%q, %v) = %v, want %v", tt.model, tt.args, got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClaudeModelEnvVars(t *testing.T) {
|
||||||
|
c := &Claude{}
|
||||||
|
|
||||||
|
envMap := func(envs []string) map[string]string {
|
||||||
|
m := make(map[string]string)
|
||||||
|
for _, e := range envs {
|
||||||
|
k, v, _ := strings.Cut(e, "=")
|
||||||
|
m[k] = v
|
||||||
|
}
|
||||||
|
return m
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("maps all Claude model env vars to the provided model", func(t *testing.T) {
|
||||||
|
got := envMap(c.modelEnvVars("llama3.2"))
|
||||||
|
if got["ANTHROPIC_DEFAULT_OPUS_MODEL"] != "llama3.2" {
|
||||||
|
t.Errorf("OPUS = %q, want llama3.2", got["ANTHROPIC_DEFAULT_OPUS_MODEL"])
|
||||||
|
}
|
||||||
|
if got["ANTHROPIC_DEFAULT_SONNET_MODEL"] != "llama3.2" {
|
||||||
|
t.Errorf("SONNET = %q, want llama3.2", got["ANTHROPIC_DEFAULT_SONNET_MODEL"])
|
||||||
|
}
|
||||||
|
if got["ANTHROPIC_DEFAULT_HAIKU_MODEL"] != "llama3.2" {
|
||||||
|
t.Errorf("HAIKU = %q, want llama3.2", got["ANTHROPIC_DEFAULT_HAIKU_MODEL"])
|
||||||
|
}
|
||||||
|
if got["CLAUDE_CODE_SUBAGENT_MODEL"] != "llama3.2" {
|
||||||
|
t.Errorf("SUBAGENT = %q, want llama3.2", got["CLAUDE_CODE_SUBAGENT_MODEL"])
|
||||||
|
}
|
||||||
|
if got["CLAUDE_CODE_AUTO_COMPACT_WINDOW"] != "" {
|
||||||
|
t.Errorf("AUTO_COMPACT_WINDOW = %q, want empty for local models", got["CLAUDE_CODE_AUTO_COMPACT_WINDOW"])
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("supports empty model", func(t *testing.T) {
|
||||||
|
got := envMap(c.modelEnvVars(""))
|
||||||
|
if got["ANTHROPIC_DEFAULT_OPUS_MODEL"] != "" {
|
||||||
|
t.Errorf("OPUS = %q, want empty", got["ANTHROPIC_DEFAULT_OPUS_MODEL"])
|
||||||
|
}
|
||||||
|
if got["ANTHROPIC_DEFAULT_SONNET_MODEL"] != "" {
|
||||||
|
t.Errorf("SONNET = %q, want empty", got["ANTHROPIC_DEFAULT_SONNET_MODEL"])
|
||||||
|
}
|
||||||
|
if got["ANTHROPIC_DEFAULT_HAIKU_MODEL"] != "" {
|
||||||
|
t.Errorf("HAIKU = %q, want empty", got["ANTHROPIC_DEFAULT_HAIKU_MODEL"])
|
||||||
|
}
|
||||||
|
if got["CLAUDE_CODE_SUBAGENT_MODEL"] != "" {
|
||||||
|
t.Errorf("SUBAGENT = %q, want empty", got["CLAUDE_CODE_SUBAGENT_MODEL"])
|
||||||
|
}
|
||||||
|
if got["CLAUDE_CODE_AUTO_COMPACT_WINDOW"] != "" {
|
||||||
|
t.Errorf("AUTO_COMPACT_WINDOW = %q, want empty", got["CLAUDE_CODE_AUTO_COMPACT_WINDOW"])
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("sets auto compact window for known cloud models", func(t *testing.T) {
|
||||||
|
got := envMap(c.modelEnvVars("glm-5:cloud"))
|
||||||
|
if got["CLAUDE_CODE_AUTO_COMPACT_WINDOW"] != "202752" {
|
||||||
|
t.Errorf("AUTO_COMPACT_WINDOW = %q, want 202752", got["CLAUDE_CODE_AUTO_COMPACT_WINDOW"])
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("does not set auto compact window for unknown cloud models", func(t *testing.T) {
|
||||||
|
got := envMap(c.modelEnvVars("unknown-model:cloud"))
|
||||||
|
if got["CLAUDE_CODE_AUTO_COMPACT_WINDOW"] != "" {
|
||||||
|
t.Errorf("AUTO_COMPACT_WINDOW = %q, want empty", got["CLAUDE_CODE_AUTO_COMPACT_WINDOW"])
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
104
cmd/launch/cline.go
Normal file
104
cmd/launch/cline.go
Normal file
@@ -0,0 +1,104 @@
|
|||||||
|
package launch
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"os/exec"
|
||||||
|
"path/filepath"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/cmd/internal/fileutil"
|
||||||
|
"github.com/ollama/ollama/envconfig"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Cline implements Runner and Editor for the Cline CLI integration
|
||||||
|
type Cline struct{}
|
||||||
|
|
||||||
|
func (c *Cline) String() string { return "Cline" }
|
||||||
|
|
||||||
|
func (c *Cline) Run(model string, args []string) error {
|
||||||
|
if _, err := exec.LookPath("cline"); err != nil {
|
||||||
|
return fmt.Errorf("cline is not installed, install with: npm install -g cline")
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd := exec.Command("cline", args...)
|
||||||
|
cmd.Stdin = os.Stdin
|
||||||
|
cmd.Stdout = os.Stdout
|
||||||
|
cmd.Stderr = os.Stderr
|
||||||
|
return cmd.Run()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Cline) Paths() []string {
|
||||||
|
home, err := os.UserHomeDir()
|
||||||
|
if err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
p := filepath.Join(home, ".cline", "data", "globalState.json")
|
||||||
|
if _, err := os.Stat(p); err == nil {
|
||||||
|
return []string{p}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Cline) Edit(models []string) error {
|
||||||
|
if len(models) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
home, err := os.UserHomeDir()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
configPath := filepath.Join(home, ".cline", "data", "globalState.json")
|
||||||
|
if err := os.MkdirAll(filepath.Dir(configPath), 0o755); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
config := make(map[string]any)
|
||||||
|
if data, err := os.ReadFile(configPath); err == nil {
|
||||||
|
if err := json.Unmarshal(data, &config); err != nil {
|
||||||
|
return fmt.Errorf("failed to parse config: %w, at: %s", err, configPath)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set Ollama as the provider for both act and plan modes
|
||||||
|
baseURL := envconfig.Host().String()
|
||||||
|
config["ollamaBaseUrl"] = baseURL
|
||||||
|
config["actModeApiProvider"] = "ollama"
|
||||||
|
config["actModeOllamaModelId"] = models[0]
|
||||||
|
config["actModeOllamaBaseUrl"] = baseURL
|
||||||
|
config["planModeApiProvider"] = "ollama"
|
||||||
|
config["planModeOllamaModelId"] = models[0]
|
||||||
|
config["planModeOllamaBaseUrl"] = baseURL
|
||||||
|
|
||||||
|
config["welcomeViewCompleted"] = true
|
||||||
|
|
||||||
|
data, err := json.MarshalIndent(config, "", " ")
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return fileutil.WriteWithBackup(configPath, data)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Cline) Models() []string {
|
||||||
|
home, err := os.UserHomeDir()
|
||||||
|
if err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
config, err := fileutil.ReadJSON(filepath.Join(home, ".cline", "data", "globalState.json"))
|
||||||
|
if err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if config["actModeApiProvider"] != "ollama" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
modelID, _ := config["actModeOllamaModelId"].(string)
|
||||||
|
if modelID == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return []string{modelID}
|
||||||
|
}
|
||||||
204
cmd/launch/cline_test.go
Normal file
204
cmd/launch/cline_test.go
Normal file
@@ -0,0 +1,204 @@
|
|||||||
|
package launch
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestClineIntegration(t *testing.T) {
|
||||||
|
c := &Cline{}
|
||||||
|
|
||||||
|
t.Run("String", func(t *testing.T) {
|
||||||
|
if got := c.String(); got != "Cline" {
|
||||||
|
t.Errorf("String() = %q, want %q", got, "Cline")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("implements Runner", func(t *testing.T) {
|
||||||
|
var _ Runner = c
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("implements Editor", func(t *testing.T) {
|
||||||
|
var _ Editor = c
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClineEdit(t *testing.T) {
|
||||||
|
c := &Cline{}
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
setTestHome(t, tmpDir)
|
||||||
|
|
||||||
|
configDir := filepath.Join(tmpDir, ".cline", "data")
|
||||||
|
configPath := filepath.Join(configDir, "globalState.json")
|
||||||
|
|
||||||
|
readConfig := func() map[string]any {
|
||||||
|
data, _ := os.ReadFile(configPath)
|
||||||
|
var config map[string]any
|
||||||
|
json.Unmarshal(data, &config)
|
||||||
|
return config
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("creates config from scratch", func(t *testing.T) {
|
||||||
|
os.RemoveAll(filepath.Join(tmpDir, ".cline"))
|
||||||
|
|
||||||
|
if err := c.Edit([]string{"kimi-k2.5:cloud"}); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
config := readConfig()
|
||||||
|
if config["actModeApiProvider"] != "ollama" {
|
||||||
|
t.Errorf("actModeApiProvider = %v, want ollama", config["actModeApiProvider"])
|
||||||
|
}
|
||||||
|
if config["actModeOllamaModelId"] != "kimi-k2.5:cloud" {
|
||||||
|
t.Errorf("actModeOllamaModelId = %v, want kimi-k2.5:cloud", config["actModeOllamaModelId"])
|
||||||
|
}
|
||||||
|
if config["planModeApiProvider"] != "ollama" {
|
||||||
|
t.Errorf("planModeApiProvider = %v, want ollama", config["planModeApiProvider"])
|
||||||
|
}
|
||||||
|
if config["planModeOllamaModelId"] != "kimi-k2.5:cloud" {
|
||||||
|
t.Errorf("planModeOllamaModelId = %v, want kimi-k2.5:cloud", config["planModeOllamaModelId"])
|
||||||
|
}
|
||||||
|
if config["welcomeViewCompleted"] != true {
|
||||||
|
t.Errorf("welcomeViewCompleted = %v, want true", config["welcomeViewCompleted"])
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("preserves existing fields", func(t *testing.T) {
|
||||||
|
os.RemoveAll(filepath.Join(tmpDir, ".cline"))
|
||||||
|
os.MkdirAll(configDir, 0o755)
|
||||||
|
|
||||||
|
existing := map[string]any{
|
||||||
|
"remoteRulesToggles": map[string]any{},
|
||||||
|
"remoteWorkflowToggles": map[string]any{},
|
||||||
|
"customSetting": "keep-me",
|
||||||
|
}
|
||||||
|
data, _ := json.Marshal(existing)
|
||||||
|
os.WriteFile(configPath, data, 0o644)
|
||||||
|
|
||||||
|
if err := c.Edit([]string{"glm-5:cloud"}); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
config := readConfig()
|
||||||
|
if config["customSetting"] != "keep-me" {
|
||||||
|
t.Errorf("customSetting was not preserved")
|
||||||
|
}
|
||||||
|
if config["actModeOllamaModelId"] != "glm-5:cloud" {
|
||||||
|
t.Errorf("actModeOllamaModelId = %v, want glm-5:cloud", config["actModeOllamaModelId"])
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("updates model on re-edit", func(t *testing.T) {
|
||||||
|
os.RemoveAll(filepath.Join(tmpDir, ".cline"))
|
||||||
|
|
||||||
|
if err := c.Edit([]string{"kimi-k2.5:cloud"}); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if err := c.Edit([]string{"glm-5:cloud"}); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
config := readConfig()
|
||||||
|
if config["actModeOllamaModelId"] != "glm-5:cloud" {
|
||||||
|
t.Errorf("actModeOllamaModelId = %v, want glm-5:cloud", config["actModeOllamaModelId"])
|
||||||
|
}
|
||||||
|
if config["planModeOllamaModelId"] != "glm-5:cloud" {
|
||||||
|
t.Errorf("planModeOllamaModelId = %v, want glm-5:cloud", config["planModeOllamaModelId"])
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("empty models is no-op", func(t *testing.T) {
|
||||||
|
os.RemoveAll(filepath.Join(tmpDir, ".cline"))
|
||||||
|
|
||||||
|
if err := c.Edit(nil); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := os.Stat(configPath); !os.IsNotExist(err) {
|
||||||
|
t.Error("expected no config file to be created for empty models")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("uses first model as primary", func(t *testing.T) {
|
||||||
|
os.RemoveAll(filepath.Join(tmpDir, ".cline"))
|
||||||
|
|
||||||
|
if err := c.Edit([]string{"kimi-k2.5:cloud", "glm-5:cloud"}); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
config := readConfig()
|
||||||
|
if config["actModeOllamaModelId"] != "kimi-k2.5:cloud" {
|
||||||
|
t.Errorf("actModeOllamaModelId = %v, want kimi-k2.5:cloud (first model)", config["actModeOllamaModelId"])
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClineModels(t *testing.T) {
|
||||||
|
c := &Cline{}
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
setTestHome(t, tmpDir)
|
||||||
|
|
||||||
|
configDir := filepath.Join(tmpDir, ".cline", "data")
|
||||||
|
configPath := filepath.Join(configDir, "globalState.json")
|
||||||
|
|
||||||
|
t.Run("returns nil when no config", func(t *testing.T) {
|
||||||
|
if models := c.Models(); models != nil {
|
||||||
|
t.Errorf("Models() = %v, want nil", models)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("returns nil when provider is not ollama", func(t *testing.T) {
|
||||||
|
os.MkdirAll(configDir, 0o755)
|
||||||
|
config := map[string]any{
|
||||||
|
"actModeApiProvider": "anthropic",
|
||||||
|
"actModeOllamaModelId": "some-model",
|
||||||
|
}
|
||||||
|
data, _ := json.Marshal(config)
|
||||||
|
os.WriteFile(configPath, data, 0o644)
|
||||||
|
|
||||||
|
if models := c.Models(); models != nil {
|
||||||
|
t.Errorf("Models() = %v, want nil", models)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("returns model when ollama is configured", func(t *testing.T) {
|
||||||
|
os.MkdirAll(configDir, 0o755)
|
||||||
|
config := map[string]any{
|
||||||
|
"actModeApiProvider": "ollama",
|
||||||
|
"actModeOllamaModelId": "kimi-k2.5:cloud",
|
||||||
|
}
|
||||||
|
data, _ := json.Marshal(config)
|
||||||
|
os.WriteFile(configPath, data, 0o644)
|
||||||
|
|
||||||
|
models := c.Models()
|
||||||
|
if len(models) != 1 || models[0] != "kimi-k2.5:cloud" {
|
||||||
|
t.Errorf("Models() = %v, want [kimi-k2.5:cloud]", models)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClinePaths(t *testing.T) {
|
||||||
|
c := &Cline{}
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
setTestHome(t, tmpDir)
|
||||||
|
|
||||||
|
t.Run("returns nil when no config exists", func(t *testing.T) {
|
||||||
|
if paths := c.Paths(); paths != nil {
|
||||||
|
t.Errorf("Paths() = %v, want nil", paths)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("returns path when config exists", func(t *testing.T) {
|
||||||
|
configDir := filepath.Join(tmpDir, ".cline", "data")
|
||||||
|
os.MkdirAll(configDir, 0o755)
|
||||||
|
configPath := filepath.Join(configDir, "globalState.json")
|
||||||
|
os.WriteFile(configPath, []byte("{}"), 0o644)
|
||||||
|
|
||||||
|
paths := c.Paths()
|
||||||
|
if len(paths) != 1 || paths[0] != configPath {
|
||||||
|
t.Errorf("Paths() = %v, want [%s]", paths, configPath)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package config
|
package launch
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
@@ -6,6 +6,7 @@ import (
|
|||||||
"os/exec"
|
"os/exec"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/envconfig"
|
||||||
"golang.org/x/mod/semver"
|
"golang.org/x/mod/semver"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -14,23 +15,28 @@ type Codex struct{}
|
|||||||
|
|
||||||
func (c *Codex) String() string { return "Codex" }
|
func (c *Codex) String() string { return "Codex" }
|
||||||
|
|
||||||
func (c *Codex) args(model string) []string {
|
func (c *Codex) args(model string, extra []string) []string {
|
||||||
args := []string{"--oss"}
|
args := []string{"--oss"}
|
||||||
if model != "" {
|
if model != "" {
|
||||||
args = append(args, "-m", model)
|
args = append(args, "-m", model)
|
||||||
}
|
}
|
||||||
|
args = append(args, extra...)
|
||||||
return args
|
return args
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Codex) Run(model string) error {
|
func (c *Codex) Run(model string, args []string) error {
|
||||||
if err := checkCodexVersion(); err != nil {
|
if err := checkCodexVersion(); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
cmd := exec.Command("codex", c.args(model)...)
|
cmd := exec.Command("codex", c.args(model, args)...)
|
||||||
cmd.Stdin = os.Stdin
|
cmd.Stdin = os.Stdin
|
||||||
cmd.Stdout = os.Stdout
|
cmd.Stdout = os.Stdout
|
||||||
cmd.Stderr = os.Stderr
|
cmd.Stderr = os.Stderr
|
||||||
|
cmd.Env = append(os.Environ(),
|
||||||
|
"OPENAI_BASE_URL="+envconfig.Host().String()+"/v1/",
|
||||||
|
"OPENAI_API_KEY=ollama",
|
||||||
|
)
|
||||||
return cmd.Run()
|
return cmd.Run()
|
||||||
}
|
}
|
||||||
|
|
||||||
31
cmd/launch/codex_test.go
Normal file
31
cmd/launch/codex_test.go
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
package launch
|
||||||
|
|
||||||
|
import (
|
||||||
|
"slices"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestCodexArgs(t *testing.T) {
|
||||||
|
c := &Codex{}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
model string
|
||||||
|
args []string
|
||||||
|
want []string
|
||||||
|
}{
|
||||||
|
{"with model", "llama3.2", nil, []string{"--oss", "-m", "llama3.2"}},
|
||||||
|
{"empty model", "", nil, []string{"--oss"}},
|
||||||
|
{"with model and profile", "qwen3.5", []string{"-p", "myprofile"}, []string{"--oss", "-m", "qwen3.5", "-p", "myprofile"}},
|
||||||
|
{"with sandbox flag", "llama3.2", []string{"--sandbox", "workspace-write"}, []string{"--oss", "-m", "llama3.2", "--sandbox", "workspace-write"}},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := c.args(tt.model, tt.args)
|
||||||
|
if !slices.Equal(got, tt.want) {
|
||||||
|
t.Errorf("args(%q, %v) = %v, want %v", tt.model, tt.args, got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
598
cmd/launch/command_test.go
Normal file
598
cmd/launch/command_test.go
Normal file
@@ -0,0 +1,598 @@
|
|||||||
|
package launch
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"os"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/google/go-cmp/cmp"
|
||||||
|
"github.com/ollama/ollama/cmd/config"
|
||||||
|
"github.com/spf13/cobra"
|
||||||
|
)
|
||||||
|
|
||||||
|
func captureStderr(t *testing.T, fn func()) string {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
oldStderr := os.Stderr
|
||||||
|
r, w, err := os.Pipe()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to create stderr pipe: %v", err)
|
||||||
|
}
|
||||||
|
os.Stderr = w
|
||||||
|
defer func() {
|
||||||
|
os.Stderr = oldStderr
|
||||||
|
}()
|
||||||
|
|
||||||
|
done := make(chan string, 1)
|
||||||
|
go func() {
|
||||||
|
var buf bytes.Buffer
|
||||||
|
_, _ = io.Copy(&buf, r)
|
||||||
|
done <- buf.String()
|
||||||
|
}()
|
||||||
|
|
||||||
|
fn()
|
||||||
|
|
||||||
|
_ = w.Close()
|
||||||
|
return <-done
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLaunchCmd(t *testing.T) {
|
||||||
|
mockCheck := func(cmd *cobra.Command, args []string) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
mockTUI := func(cmd *cobra.Command) {}
|
||||||
|
cmd := LaunchCmd(mockCheck, mockTUI)
|
||||||
|
|
||||||
|
t.Run("command structure", func(t *testing.T) {
|
||||||
|
if cmd.Use != "launch [INTEGRATION] [-- [EXTRA_ARGS...]]" {
|
||||||
|
t.Errorf("Use = %q, want %q", cmd.Use, "launch [INTEGRATION] [-- [EXTRA_ARGS...]]")
|
||||||
|
}
|
||||||
|
if cmd.Short == "" {
|
||||||
|
t.Error("Short description should not be empty")
|
||||||
|
}
|
||||||
|
if cmd.Long == "" {
|
||||||
|
t.Error("Long description should not be empty")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("flags exist", func(t *testing.T) {
|
||||||
|
if cmd.Flags().Lookup("model") == nil {
|
||||||
|
t.Error("--model flag should exist")
|
||||||
|
}
|
||||||
|
if cmd.Flags().Lookup("config") == nil {
|
||||||
|
t.Error("--config flag should exist")
|
||||||
|
}
|
||||||
|
if cmd.Flags().Lookup("yes") == nil {
|
||||||
|
t.Error("--yes flag should exist")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("PreRunE is set", func(t *testing.T) {
|
||||||
|
if cmd.PreRunE == nil {
|
||||||
|
t.Error("PreRunE should be set to checkServerHeartbeat")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLaunchCmdTUICallback(t *testing.T) {
|
||||||
|
mockCheck := func(cmd *cobra.Command, args []string) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("no args calls TUI", func(t *testing.T) {
|
||||||
|
tuiCalled := false
|
||||||
|
mockTUI := func(cmd *cobra.Command) {
|
||||||
|
tuiCalled = true
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd := LaunchCmd(mockCheck, mockTUI)
|
||||||
|
cmd.SetArgs([]string{})
|
||||||
|
_ = cmd.Execute()
|
||||||
|
|
||||||
|
if !tuiCalled {
|
||||||
|
t.Error("TUI callback should be called when no args provided")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("integration arg bypasses TUI", func(t *testing.T) {
|
||||||
|
srv := httptest.NewServer(http.NotFoundHandler())
|
||||||
|
defer srv.Close()
|
||||||
|
t.Setenv("OLLAMA_HOST", srv.URL)
|
||||||
|
|
||||||
|
tuiCalled := false
|
||||||
|
mockTUI := func(cmd *cobra.Command) {
|
||||||
|
tuiCalled = true
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd := LaunchCmd(mockCheck, mockTUI)
|
||||||
|
cmd.SetArgs([]string{"claude"})
|
||||||
|
_ = cmd.Execute()
|
||||||
|
|
||||||
|
if tuiCalled {
|
||||||
|
t.Error("TUI callback should NOT be called when integration arg provided")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("--model flag without integration returns error", func(t *testing.T) {
|
||||||
|
tuiCalled := false
|
||||||
|
mockTUI := func(cmd *cobra.Command) {
|
||||||
|
tuiCalled = true
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd := LaunchCmd(mockCheck, mockTUI)
|
||||||
|
cmd.SetArgs([]string{"--model", "test-model"})
|
||||||
|
err := cmd.Execute()
|
||||||
|
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected --model without an integration to fail")
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "require an integration name") {
|
||||||
|
t.Fatalf("expected integration-name guidance, got %v", err)
|
||||||
|
}
|
||||||
|
if tuiCalled {
|
||||||
|
t.Error("TUI callback should NOT be called when --model is provided without an integration")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("--config flag without integration returns error", func(t *testing.T) {
|
||||||
|
tuiCalled := false
|
||||||
|
mockTUI := func(cmd *cobra.Command) {
|
||||||
|
tuiCalled = true
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd := LaunchCmd(mockCheck, mockTUI)
|
||||||
|
cmd.SetArgs([]string{"--config"})
|
||||||
|
err := cmd.Execute()
|
||||||
|
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected --config without an integration to fail")
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "require an integration name") {
|
||||||
|
t.Fatalf("expected integration-name guidance, got %v", err)
|
||||||
|
}
|
||||||
|
if tuiCalled {
|
||||||
|
t.Error("TUI callback should NOT be called when --config is provided without an integration")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("--yes flag without integration returns error", func(t *testing.T) {
|
||||||
|
tuiCalled := false
|
||||||
|
mockTUI := func(cmd *cobra.Command) {
|
||||||
|
tuiCalled = true
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd := LaunchCmd(mockCheck, mockTUI)
|
||||||
|
cmd.SetArgs([]string{"--yes"})
|
||||||
|
err := cmd.Execute()
|
||||||
|
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected --yes without an integration to fail")
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "require an integration name") {
|
||||||
|
t.Fatalf("expected integration-name guidance, got %v", err)
|
||||||
|
}
|
||||||
|
if tuiCalled {
|
||||||
|
t.Error("TUI callback should NOT be called when --yes is provided without an integration")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("extra args without integration return error", func(t *testing.T) {
|
||||||
|
tuiCalled := false
|
||||||
|
mockTUI := func(cmd *cobra.Command) {
|
||||||
|
tuiCalled = true
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd := LaunchCmd(mockCheck, mockTUI)
|
||||||
|
cmd.SetArgs([]string{"--model", "test-model", "--", "--sandbox", "workspace-write"})
|
||||||
|
err := cmd.Execute()
|
||||||
|
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected flags and extra args without an integration to fail")
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "require an integration name") {
|
||||||
|
t.Fatalf("expected integration-name guidance, got %v", err)
|
||||||
|
}
|
||||||
|
if tuiCalled {
|
||||||
|
t.Error("TUI callback should NOT be called when flags or extra args are provided without an integration")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLaunchCmdNilHeartbeat(t *testing.T) {
|
||||||
|
cmd := LaunchCmd(nil, nil)
|
||||||
|
if cmd == nil {
|
||||||
|
t.Fatal("LaunchCmd returned nil")
|
||||||
|
}
|
||||||
|
if cmd.PreRunE != nil {
|
||||||
|
t.Log("Note: PreRunE is set even when nil is passed (acceptable)")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLaunchCmdModelFlagFiltersDisabledCloudFromSavedConfig(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
setLaunchTestHome(t, tmpDir)
|
||||||
|
|
||||||
|
if err := config.SaveIntegration("stubeditor", []string{"glm-5:cloud"}); err != nil {
|
||||||
|
t.Fatalf("failed to seed saved config: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
switch r.URL.Path {
|
||||||
|
case "/api/status":
|
||||||
|
fmt.Fprintf(w, `{"cloud":{"disabled":true,"source":"config"}}`)
|
||||||
|
case "/api/show":
|
||||||
|
fmt.Fprintf(w, `{"model":"llama3.2"}`)
|
||||||
|
default:
|
||||||
|
w.WriteHeader(http.StatusNotFound)
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
t.Setenv("OLLAMA_HOST", srv.URL)
|
||||||
|
|
||||||
|
stub := &launcherEditorRunner{}
|
||||||
|
restore := OverrideIntegration("stubeditor", stub)
|
||||||
|
defer restore()
|
||||||
|
|
||||||
|
cmd := LaunchCmd(func(cmd *cobra.Command, args []string) error { return nil }, func(cmd *cobra.Command) {})
|
||||||
|
cmd.SetArgs([]string{"stubeditor", "--model", "llama3.2"})
|
||||||
|
if err := cmd.Execute(); err != nil {
|
||||||
|
t.Fatalf("launch command failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
saved, err := config.LoadIntegration("stubeditor")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to reload integration config: %v", err)
|
||||||
|
}
|
||||||
|
if diff := cmp.Diff([]string{"llama3.2"}, saved.Models); diff != "" {
|
||||||
|
t.Fatalf("saved models mismatch (-want +got):\n%s", diff)
|
||||||
|
}
|
||||||
|
if diff := cmp.Diff([][]string{{"llama3.2"}}, stub.edited); diff != "" {
|
||||||
|
t.Fatalf("editor models mismatch (-want +got):\n%s", diff)
|
||||||
|
}
|
||||||
|
if stub.ranModel != "llama3.2" {
|
||||||
|
t.Fatalf("expected launch to run with llama3.2, got %q", stub.ranModel)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLaunchCmdModelFlagClearsDisabledCloudOverride(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
setLaunchTestHome(t, tmpDir)
|
||||||
|
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
switch r.URL.Path {
|
||||||
|
case "/api/status":
|
||||||
|
fmt.Fprintf(w, `{"cloud":{"disabled":true,"source":"config"}}`)
|
||||||
|
case "/api/tags":
|
||||||
|
fmt.Fprint(w, `{"models":[{"name":"llama3.2"}]}`)
|
||||||
|
case "/api/show":
|
||||||
|
fmt.Fprint(w, `{"model":"llama3.2"}`)
|
||||||
|
default:
|
||||||
|
w.WriteHeader(http.StatusNotFound)
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
t.Setenv("OLLAMA_HOST", srv.URL)
|
||||||
|
|
||||||
|
stub := &launcherSingleRunner{}
|
||||||
|
restore := OverrideIntegration("stubapp", stub)
|
||||||
|
defer restore()
|
||||||
|
|
||||||
|
oldSelector := DefaultSingleSelector
|
||||||
|
defer func() { DefaultSingleSelector = oldSelector }()
|
||||||
|
|
||||||
|
var selectorCalls int
|
||||||
|
var gotCurrent string
|
||||||
|
DefaultSingleSelector = func(title string, items []ModelItem, current string) (string, error) {
|
||||||
|
selectorCalls++
|
||||||
|
gotCurrent = current
|
||||||
|
return "llama3.2", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd := LaunchCmd(func(cmd *cobra.Command, args []string) error { return nil }, func(cmd *cobra.Command) {})
|
||||||
|
cmd.SetArgs([]string{"stubapp", "--model", "glm-5:cloud"})
|
||||||
|
stderr := captureStderr(t, func() {
|
||||||
|
if err := cmd.Execute(); err != nil {
|
||||||
|
t.Fatalf("launch command failed: %v", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
if selectorCalls != 1 {
|
||||||
|
t.Fatalf("expected disabled cloud override to fall back to selector, got %d calls", selectorCalls)
|
||||||
|
}
|
||||||
|
if gotCurrent != "" {
|
||||||
|
t.Fatalf("expected disabled override to be cleared before selection, got current %q", gotCurrent)
|
||||||
|
}
|
||||||
|
if stub.ranModel != "llama3.2" {
|
||||||
|
t.Fatalf("expected launch to run with replacement local model, got %q", stub.ranModel)
|
||||||
|
}
|
||||||
|
if !strings.Contains(stderr, "Warning: ignoring --model glm-5:cloud because cloud is disabled") {
|
||||||
|
t.Fatalf("expected disabled-cloud warning, got stderr: %q", stderr)
|
||||||
|
}
|
||||||
|
|
||||||
|
saved, err := config.LoadIntegration("stubapp")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to reload integration config: %v", err)
|
||||||
|
}
|
||||||
|
if diff := cmp.Diff([]string{"llama3.2"}, saved.Models); diff != "" {
|
||||||
|
t.Fatalf("saved models mismatch (-want +got):\n%s", diff)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLaunchCmdYes_AutoConfirmsLaunchPromptPath(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
setLaunchTestHome(t, tmpDir)
|
||||||
|
withLauncherHooks(t)
|
||||||
|
withInteractiveSession(t, false)
|
||||||
|
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
switch r.URL.Path {
|
||||||
|
case "/api/show":
|
||||||
|
fmt.Fprint(w, `{"model":"llama3.2"}`)
|
||||||
|
case "/api/status":
|
||||||
|
w.WriteHeader(http.StatusNotFound)
|
||||||
|
fmt.Fprint(w, `{"error":"not found"}`)
|
||||||
|
default:
|
||||||
|
w.WriteHeader(http.StatusNotFound)
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
t.Setenv("OLLAMA_HOST", srv.URL)
|
||||||
|
|
||||||
|
stub := &launcherEditorRunner{paths: []string{"/tmp/stubeditor.json"}}
|
||||||
|
restore := OverrideIntegration("stubeditor", stub)
|
||||||
|
defer restore()
|
||||||
|
|
||||||
|
DefaultConfirmPrompt = func(prompt string) (bool, error) {
|
||||||
|
t.Fatalf("unexpected prompt with --yes: %q", prompt)
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd := LaunchCmd(func(cmd *cobra.Command, args []string) error { return nil }, func(cmd *cobra.Command) {})
|
||||||
|
cmd.SetArgs([]string{"stubeditor", "--model", "llama3.2", "--yes"})
|
||||||
|
if err := cmd.Execute(); err != nil {
|
||||||
|
t.Fatalf("launch command with --yes failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := cmp.Diff([][]string{{"llama3.2"}}, stub.edited); diff != "" {
|
||||||
|
t.Fatalf("editor models mismatch (-want +got):\n%s", diff)
|
||||||
|
}
|
||||||
|
if stub.ranModel != "llama3.2" {
|
||||||
|
t.Fatalf("expected launch to run with llama3.2, got %q", stub.ranModel)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLaunchCmdHeadlessWithYes_AutoPullsMissingLocalModel(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
setLaunchTestHome(t, tmpDir)
|
||||||
|
withLauncherHooks(t)
|
||||||
|
withInteractiveSession(t, false)
|
||||||
|
|
||||||
|
var pullCalled bool
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
switch r.URL.Path {
|
||||||
|
case "/api/show":
|
||||||
|
w.WriteHeader(http.StatusNotFound)
|
||||||
|
fmt.Fprint(w, `{"error":"model not found"}`)
|
||||||
|
case "/api/pull":
|
||||||
|
pullCalled = true
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
fmt.Fprint(w, `{"status":"success"}`)
|
||||||
|
default:
|
||||||
|
w.WriteHeader(http.StatusNotFound)
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
t.Setenv("OLLAMA_HOST", srv.URL)
|
||||||
|
|
||||||
|
stub := &launcherSingleRunner{}
|
||||||
|
restore := OverrideIntegration("stubapp", stub)
|
||||||
|
defer restore()
|
||||||
|
|
||||||
|
DefaultConfirmPrompt = func(prompt string) (bool, error) {
|
||||||
|
t.Fatalf("unexpected prompt with --yes in headless autopull path: %q", prompt)
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd := LaunchCmd(func(cmd *cobra.Command, args []string) error { return nil }, func(cmd *cobra.Command) {})
|
||||||
|
cmd.SetArgs([]string{"stubapp", "--model", "missing-model", "--yes"})
|
||||||
|
if err := cmd.Execute(); err != nil {
|
||||||
|
t.Fatalf("launch command with --yes failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !pullCalled {
|
||||||
|
t.Fatal("expected missing local model to be auto-pulled with --yes in headless mode")
|
||||||
|
}
|
||||||
|
if stub.ranModel != "missing-model" {
|
||||||
|
t.Fatalf("expected launch to run with pulled model, got %q", stub.ranModel)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLaunchCmdHeadlessWithoutYes_ReturnsActionableConfirmError(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
setLaunchTestHome(t, tmpDir)
|
||||||
|
withLauncherHooks(t)
|
||||||
|
withInteractiveSession(t, false)
|
||||||
|
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
switch r.URL.Path {
|
||||||
|
case "/api/show":
|
||||||
|
fmt.Fprint(w, `{"model":"llama3.2"}`)
|
||||||
|
case "/api/status":
|
||||||
|
w.WriteHeader(http.StatusNotFound)
|
||||||
|
fmt.Fprint(w, `{"error":"not found"}`)
|
||||||
|
default:
|
||||||
|
w.WriteHeader(http.StatusNotFound)
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
t.Setenv("OLLAMA_HOST", srv.URL)
|
||||||
|
|
||||||
|
stub := &launcherEditorRunner{paths: []string{"/tmp/stubeditor.json"}}
|
||||||
|
restore := OverrideIntegration("stubeditor", stub)
|
||||||
|
defer restore()
|
||||||
|
|
||||||
|
DefaultConfirmPrompt = func(prompt string) (bool, error) {
|
||||||
|
t.Fatalf("unexpected prompt in headless non-yes mode: %q", prompt)
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd := LaunchCmd(func(cmd *cobra.Command, args []string) error { return nil }, func(cmd *cobra.Command) {})
|
||||||
|
cmd.SetArgs([]string{"stubeditor", "--model", "llama3.2"})
|
||||||
|
err := cmd.Execute()
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected launch command to fail without --yes in headless mode")
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "re-run with --yes") {
|
||||||
|
t.Fatalf("expected actionable --yes guidance, got %v", err)
|
||||||
|
}
|
||||||
|
if len(stub.edited) != 0 {
|
||||||
|
t.Fatalf("expected no editor writes when confirmation is blocked, got %v", stub.edited)
|
||||||
|
}
|
||||||
|
if stub.ranModel != "" {
|
||||||
|
t.Fatalf("expected launch to abort before run, got %q", stub.ranModel)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLaunchCmdIntegrationArgPromptsForModelWithSavedSelection(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
setLaunchTestHome(t, tmpDir)
|
||||||
|
|
||||||
|
if err := config.SaveIntegration("stubapp", []string{"llama3.2"}); err != nil {
|
||||||
|
t.Fatalf("failed to seed saved config: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
switch r.URL.Path {
|
||||||
|
case "/api/tags":
|
||||||
|
fmt.Fprint(w, `{"models":[{"name":"llama3.2"},{"name":"qwen3:8b"}]}`)
|
||||||
|
case "/api/show":
|
||||||
|
fmt.Fprint(w, `{"model":"qwen3:8b"}`)
|
||||||
|
default:
|
||||||
|
w.WriteHeader(http.StatusNotFound)
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
t.Setenv("OLLAMA_HOST", srv.URL)
|
||||||
|
|
||||||
|
stub := &launcherSingleRunner{}
|
||||||
|
restore := OverrideIntegration("stubapp", stub)
|
||||||
|
defer restore()
|
||||||
|
|
||||||
|
oldSelector := DefaultSingleSelector
|
||||||
|
defer func() { DefaultSingleSelector = oldSelector }()
|
||||||
|
|
||||||
|
var gotCurrent string
|
||||||
|
DefaultSingleSelector = func(title string, items []ModelItem, current string) (string, error) {
|
||||||
|
gotCurrent = current
|
||||||
|
return "qwen3:8b", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd := LaunchCmd(func(cmd *cobra.Command, args []string) error { return nil }, func(cmd *cobra.Command) {})
|
||||||
|
cmd.SetArgs([]string{"stubapp"})
|
||||||
|
if err := cmd.Execute(); err != nil {
|
||||||
|
t.Fatalf("launch command failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if gotCurrent != "llama3.2" {
|
||||||
|
t.Fatalf("expected selector current model to be saved model llama3.2, got %q", gotCurrent)
|
||||||
|
}
|
||||||
|
if stub.ranModel != "qwen3:8b" {
|
||||||
|
t.Fatalf("expected launch to run selected model qwen3:8b, got %q", stub.ranModel)
|
||||||
|
}
|
||||||
|
|
||||||
|
saved, err := config.LoadIntegration("stubapp")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to reload integration config: %v", err)
|
||||||
|
}
|
||||||
|
if diff := cmp.Diff([]string{"qwen3:8b"}, saved.Models); diff != "" {
|
||||||
|
t.Fatalf("saved models mismatch (-want +got):\n%s", diff)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLaunchCmdHeadlessYes_IntegrationRequiresModelEvenWhenSaved(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
setLaunchTestHome(t, tmpDir)
|
||||||
|
withLauncherHooks(t)
|
||||||
|
withInteractiveSession(t, false)
|
||||||
|
|
||||||
|
if err := config.SaveIntegration("stubapp", []string{"llama3.2"}); err != nil {
|
||||||
|
t.Fatalf("failed to seed saved config: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
switch r.URL.Path {
|
||||||
|
case "/api/show":
|
||||||
|
fmt.Fprint(w, `{"model":"llama3.2"}`)
|
||||||
|
default:
|
||||||
|
w.WriteHeader(http.StatusNotFound)
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
t.Setenv("OLLAMA_HOST", srv.URL)
|
||||||
|
|
||||||
|
stub := &launcherSingleRunner{}
|
||||||
|
restore := OverrideIntegration("stubapp", stub)
|
||||||
|
defer restore()
|
||||||
|
|
||||||
|
oldSelector := DefaultSingleSelector
|
||||||
|
defer func() { DefaultSingleSelector = oldSelector }()
|
||||||
|
DefaultSingleSelector = func(title string, items []ModelItem, current string) (string, error) {
|
||||||
|
t.Fatal("selector should not be called for headless --yes saved-model launch")
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd := LaunchCmd(func(cmd *cobra.Command, args []string) error { return nil }, func(cmd *cobra.Command) {})
|
||||||
|
cmd.SetArgs([]string{"stubapp", "--yes"})
|
||||||
|
err := cmd.Execute()
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected launch command to fail when --yes is used headlessly without --model")
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "requires --model <model>") {
|
||||||
|
t.Fatalf("expected actionable --model guidance, got %v", err)
|
||||||
|
}
|
||||||
|
if stub.ranModel != "" {
|
||||||
|
t.Fatalf("expected launch to abort before run, got %q", stub.ranModel)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLaunchCmdHeadlessYes_IntegrationWithoutSavedModelReturnsError(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
setLaunchTestHome(t, tmpDir)
|
||||||
|
withLauncherHooks(t)
|
||||||
|
withInteractiveSession(t, false)
|
||||||
|
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusNotFound)
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
t.Setenv("OLLAMA_HOST", srv.URL)
|
||||||
|
|
||||||
|
stub := &launcherSingleRunner{}
|
||||||
|
restore := OverrideIntegration("stubapp", stub)
|
||||||
|
defer restore()
|
||||||
|
|
||||||
|
oldSelector := DefaultSingleSelector
|
||||||
|
defer func() { DefaultSingleSelector = oldSelector }()
|
||||||
|
DefaultSingleSelector = func(title string, items []ModelItem, current string) (string, error) {
|
||||||
|
t.Fatal("selector should not be called for headless --yes without saved model")
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd := LaunchCmd(func(cmd *cobra.Command, args []string) error { return nil }, func(cmd *cobra.Command) {})
|
||||||
|
cmd.SetArgs([]string{"stubapp", "--yes"})
|
||||||
|
err := cmd.Execute()
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected launch command to fail when --yes is used headlessly without --model")
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "requires --model <model>") {
|
||||||
|
t.Fatalf("expected actionable --model guidance, got %v", err)
|
||||||
|
}
|
||||||
|
if stub.ranModel != "" {
|
||||||
|
t.Fatalf("expected launch to abort before run, got %q", stub.ranModel)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package config
|
package launch
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
@@ -8,6 +8,7 @@ import (
|
|||||||
"path/filepath"
|
"path/filepath"
|
||||||
"slices"
|
"slices"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/cmd/internal/fileutil"
|
||||||
"github.com/ollama/ollama/envconfig"
|
"github.com/ollama/ollama/envconfig"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -39,21 +40,12 @@ type modelEntry struct {
|
|||||||
|
|
||||||
func (d *Droid) String() string { return "Droid" }
|
func (d *Droid) String() string { return "Droid" }
|
||||||
|
|
||||||
func (d *Droid) Run(model string) error {
|
func (d *Droid) Run(model string, args []string) error {
|
||||||
if _, err := exec.LookPath("droid"); err != nil {
|
if _, err := exec.LookPath("droid"); err != nil {
|
||||||
return fmt.Errorf("droid is not installed, install from https://docs.factory.ai/cli/getting-started/quickstart")
|
return fmt.Errorf("droid is not installed, install from https://docs.factory.ai/cli/getting-started/quickstart")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Call Edit() to ensure config is up-to-date before launch
|
cmd := exec.Command("droid", args...)
|
||||||
models := []string{model}
|
|
||||||
if config, err := loadIntegration("droid"); err == nil && len(config.Models) > 0 {
|
|
||||||
models = config.Models
|
|
||||||
}
|
|
||||||
if err := d.Edit(models); err != nil {
|
|
||||||
return fmt.Errorf("setup failed: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
cmd := exec.Command("droid")
|
|
||||||
cmd.Stdin = os.Stdin
|
cmd.Stdin = os.Stdin
|
||||||
cmd.Stdout = os.Stdout
|
cmd.Stdout = os.Stdout
|
||||||
cmd.Stderr = os.Stderr
|
cmd.Stderr = os.Stderr
|
||||||
@@ -98,6 +90,16 @@ func (d *Droid) Edit(models []string) error {
|
|||||||
json.Unmarshal(data, &settings) // ignore error, zero values are fine
|
json.Unmarshal(data, &settings) // ignore error, zero values are fine
|
||||||
}
|
}
|
||||||
|
|
||||||
|
settingsMap = updateDroidSettings(settingsMap, settings, models)
|
||||||
|
|
||||||
|
data, err := json.MarshalIndent(settingsMap, "", " ")
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return fileutil.WriteWithBackup(settingsPath, data)
|
||||||
|
}
|
||||||
|
|
||||||
|
func updateDroidSettings(settingsMap map[string]any, settings droidSettings, models []string) map[string]any {
|
||||||
// Keep only non-Ollama models from the raw map (preserves extra fields)
|
// Keep only non-Ollama models from the raw map (preserves extra fields)
|
||||||
// Rebuild Ollama models
|
// Rebuild Ollama models
|
||||||
var nonOllamaModels []any
|
var nonOllamaModels []any
|
||||||
@@ -112,9 +114,16 @@ func (d *Droid) Edit(models []string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Build new Ollama model entries with sequential indices (0, 1, 2, ...)
|
// Build new Ollama model entries with sequential indices (0, 1, 2, ...)
|
||||||
|
|
||||||
var newModels []any
|
var newModels []any
|
||||||
var defaultModelID string
|
var defaultModelID string
|
||||||
for i, model := range models {
|
for i, model := range models {
|
||||||
|
maxOutput := 64000
|
||||||
|
if isCloudModelName(model) {
|
||||||
|
if l, ok := lookupCloudModelLimit(model); ok {
|
||||||
|
maxOutput = l.Output
|
||||||
|
}
|
||||||
|
}
|
||||||
modelID := fmt.Sprintf("custom:%s-%d", model, i)
|
modelID := fmt.Sprintf("custom:%s-%d", model, i)
|
||||||
newModels = append(newModels, modelEntry{
|
newModels = append(newModels, modelEntry{
|
||||||
Model: model,
|
Model: model,
|
||||||
@@ -122,7 +131,7 @@ func (d *Droid) Edit(models []string) error {
|
|||||||
BaseURL: envconfig.Host().String() + "/v1",
|
BaseURL: envconfig.Host().String() + "/v1",
|
||||||
APIKey: "ollama",
|
APIKey: "ollama",
|
||||||
Provider: "generic-chat-completion-api",
|
Provider: "generic-chat-completion-api",
|
||||||
MaxOutputTokens: 64000,
|
MaxOutputTokens: maxOutput,
|
||||||
SupportsImages: false,
|
SupportsImages: false,
|
||||||
ID: modelID,
|
ID: modelID,
|
||||||
Index: i,
|
Index: i,
|
||||||
@@ -146,12 +155,7 @@ func (d *Droid) Edit(models []string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
settingsMap["sessionDefaultSettings"] = sessionSettings
|
settingsMap["sessionDefaultSettings"] = sessionSettings
|
||||||
|
return settingsMap
|
||||||
data, err := json.MarshalIndent(settingsMap, "", " ")
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return writeWithBackup(settingsPath, data)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Droid) Models() []string {
|
func (d *Droid) Models() []string {
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package config
|
package launch
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
@@ -6,6 +6,8 @@ import (
|
|||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/cmd/internal/fileutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestDroidIntegration(t *testing.T) {
|
func TestDroidIntegration(t *testing.T) {
|
||||||
@@ -362,7 +364,7 @@ func TestDroidEdit_DuplicateModels(t *testing.T) {
|
|||||||
t.Fatalf("Edit with duplicates failed: %v", err)
|
t.Fatalf("Edit with duplicates failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
settings, err := readJSONFile(settingsPath)
|
settings, err := fileutil.ReadJSON(settingsPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("readJSONFile failed: %v", err)
|
t.Fatalf("readJSONFile failed: %v", err)
|
||||||
}
|
}
|
||||||
@@ -392,7 +394,7 @@ func TestDroidEdit_MalformedModelEntry(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Malformed entries (non-object) are dropped - only valid model objects are preserved
|
// Malformed entries (non-object) are dropped - only valid model objects are preserved
|
||||||
settings, _ := readJSONFile(settingsPath)
|
settings, _ := fileutil.ReadJSON(settingsPath)
|
||||||
customModels, _ := settings["customModels"].([]any)
|
customModels, _ := settings["customModels"].([]any)
|
||||||
|
|
||||||
// Should have: 1 new Ollama model only (malformed entries dropped)
|
// Should have: 1 new Ollama model only (malformed entries dropped)
|
||||||
@@ -419,7 +421,7 @@ func TestDroidEdit_WrongTypeSessionSettings(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Should create proper sessionDefaultSettings
|
// Should create proper sessionDefaultSettings
|
||||||
settings, _ := readJSONFile(settingsPath)
|
settings, _ := fileutil.ReadJSON(settingsPath)
|
||||||
session, ok := settings["sessionDefaultSettings"].(map[string]any)
|
session, ok := settings["sessionDefaultSettings"].(map[string]any)
|
||||||
if !ok {
|
if !ok {
|
||||||
t.Fatalf("sessionDefaultSettings should be map after setup, got %T", settings["sessionDefaultSettings"])
|
t.Fatalf("sessionDefaultSettings should be map after setup, got %T", settings["sessionDefaultSettings"])
|
||||||
@@ -1008,34 +1010,34 @@ func TestDroidEdit_ModelNamesWithSpecialCharacters(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestDroidEdit_MissingCustomModelsKey(t *testing.T) {
|
func TestDroidEdit_MissingCustomModelsKey(t *testing.T) {
|
||||||
d := &Droid{}
|
|
||||||
tmpDir := t.TempDir()
|
|
||||||
setTestHome(t, tmpDir)
|
|
||||||
|
|
||||||
settingsDir := filepath.Join(tmpDir, ".factory")
|
|
||||||
settingsPath := filepath.Join(settingsDir, "settings.json")
|
|
||||||
|
|
||||||
os.MkdirAll(settingsDir, 0o755)
|
|
||||||
|
|
||||||
// No customModels key at all
|
// No customModels key at all
|
||||||
original := `{
|
original := `{
|
||||||
"diffMode": "github",
|
"diffMode": "github",
|
||||||
"sessionDefaultSettings": {"autonomyMode": "auto-high"}
|
"sessionDefaultSettings": {"autonomyMode": "auto-high"}
|
||||||
}`
|
}`
|
||||||
os.WriteFile(settingsPath, []byte(original), 0o644)
|
|
||||||
|
|
||||||
if err := d.Edit([]string{"model-a"}); err != nil {
|
var settingsStruct droidSettings
|
||||||
|
var settings map[string]any
|
||||||
|
if err := json.Unmarshal([]byte(original), &settings); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal([]byte(original), &settingsStruct); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
data, _ := os.ReadFile(settingsPath)
|
settings = updateDroidSettings(settings, settingsStruct, []string{"model-a"})
|
||||||
var settings map[string]any
|
|
||||||
json.Unmarshal(data, &settings)
|
|
||||||
|
|
||||||
// Original fields preserved
|
// Original fields preserved
|
||||||
if settings["diffMode"] != "github" {
|
if settings["diffMode"] != "github" {
|
||||||
t.Error("diffMode not preserved")
|
t.Error("diffMode not preserved")
|
||||||
}
|
}
|
||||||
|
session, ok := settings["sessionDefaultSettings"].(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("sessionDefaultSettings not preserved")
|
||||||
|
}
|
||||||
|
if session["autonomyMode"] != "auto-high" {
|
||||||
|
t.Error("sessionDefaultSettings.autonomyMode not preserved")
|
||||||
|
}
|
||||||
|
|
||||||
// customModels created
|
// customModels created
|
||||||
models, ok := settings["customModels"].([]any)
|
models, ok := settings["customModels"].([]any)
|
||||||
@@ -1251,6 +1253,47 @@ func TestDroidEdit_LargeNumberOfModels(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestDroidEdit_LocalModelDefaultMaxOutput(t *testing.T) {
|
||||||
|
d := &Droid{}
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
setTestHome(t, tmpDir)
|
||||||
|
|
||||||
|
settingsDir := filepath.Join(tmpDir, ".factory")
|
||||||
|
settingsPath := filepath.Join(settingsDir, "settings.json")
|
||||||
|
|
||||||
|
if err := d.Edit([]string{"llama3.2"}); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
data, _ := os.ReadFile(settingsPath)
|
||||||
|
var settings map[string]any
|
||||||
|
json.Unmarshal(data, &settings)
|
||||||
|
|
||||||
|
models := settings["customModels"].([]any)
|
||||||
|
entry := models[0].(map[string]any)
|
||||||
|
if entry["maxOutputTokens"] != float64(64000) {
|
||||||
|
t.Errorf("local model maxOutputTokens = %v, want 64000", entry["maxOutputTokens"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDroidEdit_CloudModelLimitsUsed(t *testing.T) {
|
||||||
|
// Verify that every cloud model in cloudModelLimits has a valid output
|
||||||
|
// value that would be used for maxOutputTokens when the selected model uses
|
||||||
|
// the explicit :cloud source tag.
|
||||||
|
for name, expected := range cloudModelLimits {
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
cloudName := name + ":cloud"
|
||||||
|
l, ok := lookupCloudModelLimit(cloudName)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("lookupCloudModelLimit(%q) returned false", cloudName)
|
||||||
|
}
|
||||||
|
if l.Output != expected.Output {
|
||||||
|
t.Errorf("output = %d, want %d", l.Output, expected.Output)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestDroidEdit_ArraysWithMixedTypes(t *testing.T) {
|
func TestDroidEdit_ArraysWithMixedTypes(t *testing.T) {
|
||||||
d := &Droid{}
|
d := &Droid{}
|
||||||
tmpDir := t.TempDir()
|
tmpDir := t.TempDir()
|
||||||
1574
cmd/launch/integrations_test.go
Normal file
1574
cmd/launch/integrations_test.go
Normal file
File diff suppressed because it is too large
Load Diff
840
cmd/launch/launch.go
Normal file
840
cmd/launch/launch.go
Normal file
@@ -0,0 +1,840 @@
|
|||||||
|
package launch
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/api"
|
||||||
|
"github.com/ollama/ollama/cmd/config"
|
||||||
|
"github.com/spf13/cobra"
|
||||||
|
"golang.org/x/term"
|
||||||
|
)
|
||||||
|
|
||||||
|
// LauncherState is the launch-owned snapshot used to render the root launcher menu.
|
||||||
|
type LauncherState struct {
|
||||||
|
LastSelection string
|
||||||
|
RunModel string
|
||||||
|
RunModelUsable bool
|
||||||
|
Integrations map[string]LauncherIntegrationState
|
||||||
|
}
|
||||||
|
|
||||||
|
// LauncherIntegrationState is the launch-owned status for one launcher integration.
|
||||||
|
type LauncherIntegrationState struct {
|
||||||
|
Name string
|
||||||
|
DisplayName string
|
||||||
|
Description string
|
||||||
|
Installed bool
|
||||||
|
AutoInstallable bool
|
||||||
|
Selectable bool
|
||||||
|
Changeable bool
|
||||||
|
CurrentModel string
|
||||||
|
ModelUsable bool
|
||||||
|
InstallHint string
|
||||||
|
Editor bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// RunModelRequest controls how the root launcher resolves the chat model.
|
||||||
|
type RunModelRequest struct {
|
||||||
|
ForcePicker bool
|
||||||
|
Policy *LaunchPolicy
|
||||||
|
}
|
||||||
|
|
||||||
|
// LaunchConfirmMode controls confirmation behavior across launch flows.
|
||||||
|
type LaunchConfirmMode int
|
||||||
|
|
||||||
|
const (
|
||||||
|
// LaunchConfirmPrompt prompts the user for confirmation.
|
||||||
|
LaunchConfirmPrompt LaunchConfirmMode = iota
|
||||||
|
// LaunchConfirmAutoApprove skips prompts and treats confirmation as accepted.
|
||||||
|
LaunchConfirmAutoApprove
|
||||||
|
// LaunchConfirmRequireYes rejects confirmation requests with a --yes hint.
|
||||||
|
LaunchConfirmRequireYes
|
||||||
|
)
|
||||||
|
|
||||||
|
// LaunchMissingModelMode controls local missing-model handling in launch flows.
|
||||||
|
type LaunchMissingModelMode int
|
||||||
|
|
||||||
|
const (
|
||||||
|
// LaunchMissingModelPromptToPull prompts to pull a missing local model.
|
||||||
|
LaunchMissingModelPromptToPull LaunchMissingModelMode = iota
|
||||||
|
// LaunchMissingModelAutoPull pulls a missing local model without prompting.
|
||||||
|
LaunchMissingModelAutoPull
|
||||||
|
// LaunchMissingModelFail fails immediately when a local model is missing.
|
||||||
|
LaunchMissingModelFail
|
||||||
|
)
|
||||||
|
|
||||||
|
// LaunchPolicy controls launch behavior that may vary by caller context.
|
||||||
|
type LaunchPolicy struct {
|
||||||
|
Confirm LaunchConfirmMode
|
||||||
|
MissingModel LaunchMissingModelMode
|
||||||
|
}
|
||||||
|
|
||||||
|
func defaultLaunchPolicy(interactive bool, yes bool) LaunchPolicy {
|
||||||
|
policy := LaunchPolicy{
|
||||||
|
Confirm: LaunchConfirmPrompt,
|
||||||
|
MissingModel: LaunchMissingModelPromptToPull,
|
||||||
|
}
|
||||||
|
switch {
|
||||||
|
case yes:
|
||||||
|
// if yes flag is set, auto approve and auto pull
|
||||||
|
policy.Confirm = LaunchConfirmAutoApprove
|
||||||
|
policy.MissingModel = LaunchMissingModelAutoPull
|
||||||
|
case !interactive:
|
||||||
|
// otherwise make sure to stop when needed
|
||||||
|
policy.Confirm = LaunchConfirmRequireYes
|
||||||
|
policy.MissingModel = LaunchMissingModelFail
|
||||||
|
}
|
||||||
|
return policy
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p LaunchPolicy) confirmPolicy() launchConfirmPolicy {
|
||||||
|
switch p.Confirm {
|
||||||
|
case LaunchConfirmAutoApprove:
|
||||||
|
return launchConfirmPolicy{yes: true}
|
||||||
|
case LaunchConfirmRequireYes:
|
||||||
|
return launchConfirmPolicy{requireYesMessage: true}
|
||||||
|
default:
|
||||||
|
return launchConfirmPolicy{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p LaunchPolicy) missingModelPolicy() missingModelPolicy {
|
||||||
|
switch p.MissingModel {
|
||||||
|
case LaunchMissingModelAutoPull:
|
||||||
|
return missingModelAutoPull
|
||||||
|
case LaunchMissingModelFail:
|
||||||
|
return missingModelFail
|
||||||
|
default:
|
||||||
|
return missingModelPromptPull
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// IntegrationLaunchRequest controls the canonical integration launcher flow.
|
||||||
|
type IntegrationLaunchRequest struct {
|
||||||
|
Name string
|
||||||
|
ModelOverride string
|
||||||
|
ForceConfigure bool
|
||||||
|
ConfigureOnly bool
|
||||||
|
ExtraArgs []string
|
||||||
|
Policy *LaunchPolicy
|
||||||
|
}
|
||||||
|
|
||||||
|
var isInteractiveSession = func() bool {
|
||||||
|
return term.IsTerminal(int(os.Stdin.Fd())) && term.IsTerminal(int(os.Stdout.Fd()))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Runner executes a model with an integration.
|
||||||
|
type Runner interface {
|
||||||
|
Run(model string, args []string) error
|
||||||
|
String() string
|
||||||
|
}
|
||||||
|
|
||||||
|
// Editor can edit config files for integrations that support model configuration.
|
||||||
|
type Editor interface {
|
||||||
|
Paths() []string
|
||||||
|
Edit(models []string) error
|
||||||
|
Models() []string
|
||||||
|
}
|
||||||
|
|
||||||
|
type modelInfo struct {
|
||||||
|
Name string
|
||||||
|
Remote bool
|
||||||
|
ToolCapable bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// ModelInfo re-exports launcher model inventory details for callers.
|
||||||
|
type ModelInfo = modelInfo
|
||||||
|
|
||||||
|
// ModelItem represents a model for selection UIs.
|
||||||
|
type ModelItem struct {
|
||||||
|
Name string
|
||||||
|
Description string
|
||||||
|
Recommended bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// LaunchCmd returns the cobra command for launching integrations.
|
||||||
|
// The runTUI callback is called when the root launcher UI should be shown.
|
||||||
|
func LaunchCmd(checkServerHeartbeat func(cmd *cobra.Command, args []string) error, runTUI func(cmd *cobra.Command)) *cobra.Command {
|
||||||
|
var modelFlag string
|
||||||
|
var configFlag bool
|
||||||
|
var yesFlag bool
|
||||||
|
|
||||||
|
cmd := &cobra.Command{
|
||||||
|
Use: "launch [INTEGRATION] [-- [EXTRA_ARGS...]]",
|
||||||
|
Short: "Launch the Ollama menu or an integration",
|
||||||
|
Long: `Launch the Ollama interactive menu, or directly launch a specific integration.
|
||||||
|
|
||||||
|
Without arguments, this is equivalent to running 'ollama' directly.
|
||||||
|
Flags and extra arguments require an integration name.
|
||||||
|
|
||||||
|
Supported integrations:
|
||||||
|
claude Claude Code
|
||||||
|
cline Cline
|
||||||
|
codex Codex
|
||||||
|
droid Droid
|
||||||
|
opencode OpenCode
|
||||||
|
openclaw OpenClaw (aliases: clawdbot, moltbot)
|
||||||
|
pi Pi
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
ollama launch
|
||||||
|
ollama launch claude
|
||||||
|
ollama launch claude --model <model>
|
||||||
|
ollama launch droid --config (does not auto-launch)
|
||||||
|
ollama launch codex -- -p myprofile (pass extra args to integration)
|
||||||
|
ollama launch codex -- --sandbox workspace-write`,
|
||||||
|
Args: cobra.ArbitraryArgs,
|
||||||
|
PreRunE: checkServerHeartbeat,
|
||||||
|
RunE: func(cmd *cobra.Command, args []string) error {
|
||||||
|
policy := defaultLaunchPolicy(isInteractiveSession(), yesFlag)
|
||||||
|
// reset when done to make sure state doens't leak between launches
|
||||||
|
restoreConfirmPolicy := withLaunchConfirmPolicy(policy.confirmPolicy())
|
||||||
|
defer restoreConfirmPolicy()
|
||||||
|
|
||||||
|
var name string
|
||||||
|
var passArgs []string
|
||||||
|
dashIdx := cmd.ArgsLenAtDash()
|
||||||
|
|
||||||
|
if dashIdx == -1 {
|
||||||
|
if len(args) > 1 {
|
||||||
|
return fmt.Errorf("unexpected arguments: %v\nUse '--' to pass extra arguments to the integration", args[1:])
|
||||||
|
}
|
||||||
|
if len(args) == 1 {
|
||||||
|
name = args[0]
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if dashIdx > 1 {
|
||||||
|
return fmt.Errorf("expected at most 1 integration name before '--', got %d", dashIdx)
|
||||||
|
}
|
||||||
|
if dashIdx == 1 {
|
||||||
|
name = args[0]
|
||||||
|
}
|
||||||
|
passArgs = args[dashIdx:]
|
||||||
|
}
|
||||||
|
|
||||||
|
if name == "" {
|
||||||
|
if cmd.Flags().Changed("model") || cmd.Flags().Changed("config") || cmd.Flags().Changed("yes") || len(passArgs) > 0 {
|
||||||
|
return fmt.Errorf("flags and extra args require an integration name, for example: 'ollama launch claude --model qwen3.5'")
|
||||||
|
}
|
||||||
|
runTUI(cmd)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if modelFlag != "" && isCloudModelName(modelFlag) {
|
||||||
|
if client, err := api.ClientFromEnvironment(); err == nil {
|
||||||
|
if disabled, _ := cloudStatusDisabled(cmd.Context(), client); disabled {
|
||||||
|
fmt.Fprintf(os.Stderr, "Warning: ignoring --model %s because cloud is disabled\n", modelFlag)
|
||||||
|
modelFlag = ""
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
headlessYes := yesFlag && !isInteractiveSession()
|
||||||
|
err := LaunchIntegration(cmd.Context(), IntegrationLaunchRequest{
|
||||||
|
Name: name,
|
||||||
|
ModelOverride: modelFlag,
|
||||||
|
ForceConfigure: configFlag || (modelFlag == "" && !headlessYes),
|
||||||
|
ConfigureOnly: configFlag,
|
||||||
|
ExtraArgs: passArgs,
|
||||||
|
Policy: &policy,
|
||||||
|
})
|
||||||
|
if errors.Is(err, ErrCancelled) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd.Flags().StringVar(&modelFlag, "model", "", "Model to use")
|
||||||
|
cmd.Flags().BoolVar(&configFlag, "config", false, "Configure without launching")
|
||||||
|
cmd.Flags().BoolVarP(&yesFlag, "yes", "y", false, "Automatically answer yes to confirmation prompts")
|
||||||
|
return cmd
|
||||||
|
}
|
||||||
|
|
||||||
|
type launcherClient struct {
|
||||||
|
apiClient *api.Client
|
||||||
|
modelInventory []ModelInfo
|
||||||
|
inventoryLoaded bool
|
||||||
|
policy LaunchPolicy
|
||||||
|
}
|
||||||
|
|
||||||
|
func newLauncherClient(policy LaunchPolicy) (*launcherClient, error) {
|
||||||
|
apiClient, err := api.ClientFromEnvironment()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &launcherClient{
|
||||||
|
apiClient: apiClient,
|
||||||
|
policy: policy,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// BuildLauncherState returns the launch-owned root launcher menu snapshot.
|
||||||
|
func BuildLauncherState(ctx context.Context) (*LauncherState, error) {
|
||||||
|
launchClient, err := newLauncherClient(defaultLaunchPolicy(isInteractiveSession(), false))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return launchClient.buildLauncherState(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResolveRunModel returns the model that should be used for interactive chat.
|
||||||
|
func ResolveRunModel(ctx context.Context, req RunModelRequest) (string, error) {
|
||||||
|
// Called by the launcher TUI "Run a model" action (cmd/runLauncherAction),
|
||||||
|
// which resolves models separately from LaunchIntegration. Callers can pass
|
||||||
|
// Policy directly; otherwise we fall back to ambient --yes/session defaults.
|
||||||
|
policy := defaultLaunchPolicy(isInteractiveSession(), currentLaunchConfirmPolicy.yes)
|
||||||
|
if req.Policy != nil {
|
||||||
|
policy = *req.Policy
|
||||||
|
}
|
||||||
|
|
||||||
|
launchClient, err := newLauncherClient(policy)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return launchClient.resolveRunModel(ctx, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
// LaunchIntegration runs the canonical launcher flow for one integration.
// It resolves the integration by name, ensures it is installed (unless
// configure-only), applies the launch policy, and dispatches to either the
// editor or the single-model flow.
func LaunchIntegration(ctx context.Context, req IntegrationLaunchRequest) error {
	name, runner, err := LookupIntegration(req.Name)
	if err != nil {
		return err
	}
	if !req.ConfigureOnly {
		// Install up front only when we intend to launch; configure-only
		// defers installation until the user confirms a launch later.
		if err := EnsureIntegrationInstalled(name, runner); err != nil {
			return err
		}
	}

	var policy LaunchPolicy
	// TUI does not set a policy, whereas ollama launch <app> does as it can have flags which change the behavior
	if req.Policy == nil {
		policy = defaultLaunchPolicy(isInteractiveSession(), false)
	} else {
		policy = *req.Policy
	}

	launchClient, err := newLauncherClient(policy)
	if err != nil {
		return err
	}
	// Saved config is best-effort; a load failure falls through to selection.
	saved, _ := loadStoredIntegrationConfig(name)
	// In headless --yes mode we cannot prompt, so require an explicit --model.
	if policy.Confirm == LaunchConfirmAutoApprove && !isInteractiveSession() && req.ModelOverride == "" {
		return fmt.Errorf("headless --yes launch for %s requires --model <model>", name)
	}

	// Editors manage multiple models via config files; everything else uses
	// the simpler single-model flow.
	if editor, ok := runner.(Editor); ok {
		return launchClient.launchEditorIntegration(ctx, name, runner, editor, saved, req)
	}
	return launchClient.launchSingleIntegration(ctx, name, runner, saved, req)
}
|
||||||
|
|
||||||
|
func (c *launcherClient) buildLauncherState(ctx context.Context) (*LauncherState, error) {
|
||||||
|
_ = c.loadModelInventoryOnce(ctx)
|
||||||
|
|
||||||
|
state := &LauncherState{
|
||||||
|
LastSelection: config.LastSelection(),
|
||||||
|
RunModel: config.LastModel(),
|
||||||
|
Integrations: make(map[string]LauncherIntegrationState),
|
||||||
|
}
|
||||||
|
runModelUsable, err := c.savedModelUsable(ctx, state.RunModel)
|
||||||
|
if err != nil {
|
||||||
|
runModelUsable = false
|
||||||
|
}
|
||||||
|
state.RunModelUsable = runModelUsable
|
||||||
|
|
||||||
|
for _, info := range ListIntegrationInfos() {
|
||||||
|
integrationState, err := c.buildLauncherIntegrationState(ctx, info)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
state.Integrations[info.Name] = integrationState
|
||||||
|
}
|
||||||
|
|
||||||
|
return state, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// buildLauncherIntegrationState assembles the launcher menu entry for one
// integration, combining its install status with its saved-model state.
func (c *launcherClient) buildLauncherIntegrationState(ctx context.Context, info IntegrationInfo) (LauncherIntegrationState, error) {
	integration, err := integrationFor(info.Name)
	if err != nil {
		return LauncherIntegrationState{}, err
	}
	currentModel, usable, err := c.launcherModelState(ctx, info.Name, integration.editor)
	if err != nil {
		return LauncherIntegrationState{}, err
	}

	return LauncherIntegrationState{
		Name:            info.Name,
		DisplayName:     info.DisplayName,
		Description:     info.Description,
		Installed:       integration.installed,
		AutoInstallable: integration.autoInstallable,
		// An integration is selectable/changeable when it is either already
		// installed or installable automatically.
		Selectable:   integration.installed || integration.autoInstallable,
		Changeable:   integration.installed || integration.autoInstallable,
		CurrentModel: currentModel,
		ModelUsable:  usable,
		InstallHint:  integration.installHint,
		Editor:       integration.editor,
	}, nil
}
|
||||||
|
|
||||||
|
// launcherModelState reports the integration's primary saved model and
// whether it is currently usable. A missing or empty saved config yields
// ("", false, nil) rather than an error.
func (c *launcherClient) launcherModelState(ctx context.Context, name string, isEditor bool) (string, bool, error) {
	cfg, loadErr := loadStoredIntegrationConfig(name)
	hasModels := loadErr == nil && len(cfg.Models) > 0
	if !hasModels {
		return "", false, nil
	}

	if isEditor {
		// Editors save several models; prefer the first that is not a
		// disabled cloud model, else report the saved primary as unusable.
		filtered := c.filterDisabledCloudModels(ctx, cfg.Models)
		if len(filtered) > 0 {
			return filtered[0], true, nil
		}
		return cfg.Models[0], false, nil
	}

	model := cfg.Models[0]
	// Usability-check errors are treated as "not usable" rather than fatal.
	usable, usableErr := c.savedModelUsable(ctx, model)
	return model, usableErr == nil && usable, nil
}
|
||||||
|
|
||||||
|
// resolveRunModel picks the model for interactive chat: the last-used model
// when it is still usable (or in headless auto-approve mode), otherwise the
// result of an interactive picker. A changed selection is persisted as the
// new last-used model.
func (c *launcherClient) resolveRunModel(ctx context.Context, req RunModelRequest) (string, error) {
	current := config.LastModel()
	// Headless --yes: no picker is possible, so reuse the last model as-is.
	if !req.ForcePicker && current != "" && c.policy.Confirm == LaunchConfirmAutoApprove && !isInteractiveSession() {
		if err := c.ensureModelsReady(ctx, []string{current}); err != nil {
			return "", err
		}
		fmt.Fprintf(os.Stderr, "Headless mode: auto-selected last used model %q\n", current)
		return current, nil
	}

	if !req.ForcePicker {
		usable, err := c.savedModelUsable(ctx, current)
		if err != nil {
			return "", err
		}
		if usable {
			if err := c.ensureModelsReady(ctx, []string{current}); err != nil {
				return "", err
			}
			return current, nil
		}
	}

	model, err := c.selectSingleModelWithSelector(ctx, "Select model to run:", current, DefaultSingleSelector)
	if err != nil {
		return "", err
	}
	// Only persist when the selection actually changed.
	if model != current {
		if err := config.SetLastModel(model); err != nil {
			return "", err
		}
	}
	return model, nil
}
|
||||||
|
|
||||||
|
// launchSingleIntegration resolves exactly one model for a non-editor
// integration (override > saved > picker), persists it when changed, and
// then launches the integration.
func (c *launcherClient) launchSingleIntegration(ctx context.Context, name string, runner Runner, saved *config.IntegrationConfig, req IntegrationLaunchRequest) error {
	current := primaryModelFromConfig(saved)
	target := req.ModelOverride
	needsConfigure := req.ForceConfigure

	if target == "" {
		// No explicit override: fall back to the saved model, prompting for
		// a new one when it is missing or unusable.
		target = current
		usable, err := c.savedModelUsable(ctx, target)
		if err != nil {
			return err
		}
		if !usable {
			needsConfigure = true
		}
	}

	if needsConfigure {
		selected, err := c.selectSingleModelWithSelector(ctx, fmt.Sprintf("Select model for %s:", runner), target, DefaultSingleSelector)
		if err != nil {
			return err
		}
		target = selected
	} else if err := c.ensureModelsReady(ctx, []string{target}); err != nil {
		return err
	}

	// An empty result means nothing was selected; quietly do nothing.
	if target == "" {
		return nil
	}

	if target != current {
		if err := config.SaveIntegration(name, []string{target}); err != nil {
			return fmt.Errorf("failed to save: %w", err)
		}
	}

	return launchAfterConfiguration(name, runner, target, req)
}
|
||||||
|
|
||||||
|
// launchEditorIntegration resolves the model list for an editor integration,
// rewrites the editor's managed config files when the selection changed, and
// launches with the primary (first) model.
func (c *launcherClient) launchEditorIntegration(ctx context.Context, name string, runner Runner, editor Editor, saved *config.IntegrationConfig, req IntegrationLaunchRequest) error {
	models, needsConfigure := c.resolveEditorLaunchModels(ctx, saved, req)

	if needsConfigure {
		selected, err := c.selectMultiModelsForIntegration(ctx, runner, models)
		if err != nil {
			return err
		}
		models = selected
	} else if err := c.ensureModelsReady(ctx, models); err != nil {
		return err
	}

	// An empty selection means the user chose nothing; quietly do nothing.
	if len(models) == 0 {
		return nil
	}

	// Apply editor config only when the picker ran or an explicit
	// --model override was given; otherwise the saved state is current.
	if needsConfigure || req.ModelOverride != "" {
		if err := prepareEditorIntegration(name, runner, editor, models); err != nil {
			return err
		}
	}

	return launchAfterConfiguration(name, runner, models[0], req)
}
|
||||||
|
|
||||||
|
func (c *launcherClient) selectSingleModelWithSelector(ctx context.Context, title, current string, selector SingleSelector) (string, error) {
|
||||||
|
if selector == nil {
|
||||||
|
return "", fmt.Errorf("no selector configured")
|
||||||
|
}
|
||||||
|
|
||||||
|
items, _, err := c.loadSelectableModels(ctx, nil, current, "no models available, run 'ollama pull <model>' first")
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
selected, err := selector(title, items, current)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
if err := c.ensureModelsReady(ctx, []string{selected}); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return selected, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// selectMultiModelsForIntegration shows the multi-select picker for an
// editor integration and ensures every chosen model is ready to use.
func (c *launcherClient) selectMultiModelsForIntegration(ctx context.Context, runner Runner, preChecked []string) ([]string, error) {
	if DefaultMultiSelector == nil {
		return nil, fmt.Errorf("no selector configured")
	}

	// The first pre-checked model acts as the "current" highlight.
	current := firstModel(preChecked)

	items, orderedChecked, err := c.loadSelectableModels(ctx, preChecked, current, "no models available")
	if err != nil {
		return nil, err
	}
	if len(preChecked) > 0 {
		// Keep list order stable in multi-select even when there are existing checks.
		// checked/default state still comes from orderedChecked.
		stableItems, _, stableErr := c.loadSelectableModels(ctx, nil, current, "no models available")
		if stableErr != nil {
			return nil, stableErr
		}
		items = stableItems
	}

	selected, err := DefaultMultiSelector(fmt.Sprintf("Select models for %s:", runner), items, orderedChecked)
	if err != nil {
		return nil, err
	}
	if err := c.ensureModelsReady(ctx, selected); err != nil {
		return nil, err
	}
	return selected, nil
}
|
||||||
|
|
||||||
|
func (c *launcherClient) loadSelectableModels(ctx context.Context, preChecked []string, current, emptyMessage string) ([]ModelItem, []string, error) {
|
||||||
|
if err := c.loadModelInventoryOnce(ctx); err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
cloudDisabled, _ := cloudStatusDisabled(ctx, c.apiClient)
|
||||||
|
items, orderedChecked, _, _ := buildModelList(c.modelInventory, preChecked, current)
|
||||||
|
if cloudDisabled {
|
||||||
|
items = filterCloudItems(items)
|
||||||
|
orderedChecked = c.filterDisabledCloudModels(ctx, orderedChecked)
|
||||||
|
}
|
||||||
|
if len(items) == 0 {
|
||||||
|
return nil, nil, errors.New(emptyMessage)
|
||||||
|
}
|
||||||
|
return items, orderedChecked, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *launcherClient) ensureModelsReady(ctx context.Context, models []string) error {
|
||||||
|
var deduped []string
|
||||||
|
seen := make(map[string]bool, len(models))
|
||||||
|
for _, model := range models {
|
||||||
|
if model == "" || seen[model] {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
seen[model] = true
|
||||||
|
deduped = append(deduped, model)
|
||||||
|
}
|
||||||
|
models = deduped
|
||||||
|
if len(models) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
cloudModels := make(map[string]bool, len(models))
|
||||||
|
for _, model := range models {
|
||||||
|
isCloudModel := isCloudModelName(model)
|
||||||
|
if isCloudModel {
|
||||||
|
cloudModels[model] = true
|
||||||
|
}
|
||||||
|
if err := showOrPullWithPolicy(ctx, c.apiClient, model, c.policy.missingModelPolicy(), isCloudModel); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ensureAuth(ctx, c.apiClient, cloudModels, models)
|
||||||
|
}
|
||||||
|
|
||||||
|
// resolveEditorLaunchModels decides which models an editor launch should use
// and whether model selection must still run. The bool result is true when
// the multi-select picker is required.
func (c *launcherClient) resolveEditorLaunchModels(ctx context.Context, saved *config.IntegrationConfig, req IntegrationLaunchRequest) ([]string, bool) {
	if req.ForceConfigure {
		return editorPreCheckedModels(saved, req.ModelOverride), true
	}

	if req.ModelOverride != "" {
		// The override leads, followed by the remaining saved models.
		// Disabled cloud models are dropped; an empty result forces the picker.
		models := append([]string{req.ModelOverride}, additionalSavedModels(saved, req.ModelOverride)...)
		models = c.filterDisabledCloudModels(ctx, models)
		return models, len(models) == 0
	}

	if saved == nil || len(saved.Models) == 0 {
		return nil, true
	}

	models := c.filterDisabledCloudModels(ctx, saved.Models)
	return models, len(models) == 0
}
|
||||||
|
|
||||||
|
func (c *launcherClient) filterDisabledCloudModels(ctx context.Context, models []string) []string {
|
||||||
|
// if connection cannot be established or there is a 404, cloud models will continue to be displayed
|
||||||
|
cloudDisabled, _ := cloudStatusDisabled(ctx, c.apiClient)
|
||||||
|
if !cloudDisabled {
|
||||||
|
return append([]string(nil), models...)
|
||||||
|
}
|
||||||
|
|
||||||
|
filtered := make([]string, 0, len(models))
|
||||||
|
for _, model := range models {
|
||||||
|
if !isCloudModelName(model) {
|
||||||
|
filtered = append(filtered, model)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return filtered
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *launcherClient) savedModelUsable(ctx context.Context, name string) (bool, error) {
|
||||||
|
if err := c.loadModelInventoryOnce(ctx); err != nil {
|
||||||
|
return c.showBasedModelUsable(ctx, name)
|
||||||
|
}
|
||||||
|
return c.singleModelUsable(ctx, name), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// showBasedModelUsable probes one model via the Show API when the inventory
// list is unavailable. A 404 means the model is simply absent (not an
// error); cloud-backed models are usable only while cloud access is enabled.
func (c *launcherClient) showBasedModelUsable(ctx context.Context, name string) (bool, error) {
	if name == "" {
		return false, nil
	}

	info, err := c.apiClient.Show(ctx, &api.ShowRequest{Model: name})
	if err != nil {
		var statusErr api.StatusError
		if errors.As(err, &statusErr) && statusErr.StatusCode == http.StatusNotFound {
			return false, nil
		}
		return false, err
	}

	if isCloudModelName(name) || info.RemoteModel != "" {
		// Status-check errors are ignored; unknown status counts as enabled.
		cloudDisabled, _ := cloudStatusDisabled(ctx, c.apiClient)

		return !cloudDisabled, nil
	}

	return true, nil
}
|
||||||
|
|
||||||
|
func (c *launcherClient) singleModelUsable(ctx context.Context, name string) bool {
|
||||||
|
if name == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if isCloudModelName(name) {
|
||||||
|
cloudDisabled, _ := cloudStatusDisabled(ctx, c.apiClient)
|
||||||
|
return !cloudDisabled
|
||||||
|
}
|
||||||
|
return c.hasLocalModel(name)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *launcherClient) hasLocalModel(name string) bool {
|
||||||
|
for _, model := range c.modelInventory {
|
||||||
|
if model.Remote {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if model.Name == name || strings.HasPrefix(model.Name, name+":") {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// loadModelInventoryOnce fetches and caches the server's model list on the
// first successful call; subsequent calls are no-ops. Cloud models are
// dropped from the cache when cloud access is disabled.
func (c *launcherClient) loadModelInventoryOnce(ctx context.Context) error {
	if c.inventoryLoaded {
		return nil
	}

	resp, err := c.apiClient.List(ctx)
	if err != nil {
		return err
	}

	// Reset length but keep capacity for retries after an earlier failure.
	c.modelInventory = c.modelInventory[:0]
	for _, model := range resp.Models {
		c.modelInventory = append(c.modelInventory, ModelInfo{
			Name:   model.Name,
			Remote: model.RemoteModel != "",
		})
	}

	// Status-check errors are ignored; unknown status keeps cloud models.
	cloudDisabled, _ := cloudStatusDisabled(ctx, c.apiClient)
	if cloudDisabled {
		c.modelInventory = filterCloudModels(c.modelInventory)
	}
	c.inventoryLoaded = true
	return nil
}
|
||||||
|
|
||||||
|
func runIntegration(runner Runner, modelName string, args []string) error {
|
||||||
|
fmt.Fprintf(os.Stderr, "\nLaunching %s with %s...\n", runner, modelName)
|
||||||
|
return runner.Run(modelName, args)
|
||||||
|
}
|
||||||
|
|
||||||
|
func launchAfterConfiguration(name string, runner Runner, model string, req IntegrationLaunchRequest) error {
|
||||||
|
if req.ConfigureOnly {
|
||||||
|
launch, err := ConfirmPrompt(fmt.Sprintf("Launch %s now?", runner))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if !launch {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err := EnsureIntegrationInstalled(name, runner); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return runIntegration(runner, model, req.ExtraArgs)
|
||||||
|
}
|
||||||
|
|
||||||
|
// loadStoredIntegrationConfig loads the saved config for an integration,
// falling back to legacy alias names and migrating any legacy config to the
// canonical name. When nothing is saved anywhere, the original
// os.ErrNotExist-wrapped error from the canonical load is returned.
func loadStoredIntegrationConfig(name string) (*config.IntegrationConfig, error) {
	cfg, err := config.LoadIntegration(name)
	if err == nil {
		return cfg, nil
	}
	if !errors.Is(err, os.ErrNotExist) {
		return nil, err
	}

	spec, specErr := LookupIntegrationSpec(name)
	if specErr != nil {
		// Intentionally return the original not-exist error so callers can
		// still detect "no saved config" via errors.Is.
		return nil, err
	}

	for _, alias := range spec.Aliases {
		legacy, legacyErr := config.LoadIntegration(alias)
		if legacyErr == nil {
			// Copy the legacy config under the canonical name, then re-read
			// it; fall back to the legacy copy if the migration didn't stick.
			migrateLegacyIntegrationConfig(spec.Name, legacy)
			if migrated, migratedErr := config.LoadIntegration(spec.Name); migratedErr == nil {
				return migrated, nil
			}
			return legacy, nil
		}
		if legacyErr != nil && !errors.Is(legacyErr, os.ErrNotExist) {
			return nil, legacyErr
		}
	}

	return nil, err
}
|
||||||
|
|
||||||
|
func migrateLegacyIntegrationConfig(canonical string, legacy *config.IntegrationConfig) {
|
||||||
|
if legacy == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
_ = config.SaveIntegration(canonical, append([]string(nil), legacy.Models...))
|
||||||
|
if len(legacy.Aliases) > 0 {
|
||||||
|
_ = config.SaveAliases(canonical, cloneAliases(legacy.Aliases))
|
||||||
|
}
|
||||||
|
if legacy.Onboarded {
|
||||||
|
_ = config.MarkIntegrationOnboarded(canonical)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func primaryModelFromConfig(cfg *config.IntegrationConfig) string {
|
||||||
|
if cfg == nil || len(cfg.Models) == 0 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return cfg.Models[0]
|
||||||
|
}
|
||||||
|
|
||||||
|
// cloneAliases returns a fresh, always non-nil copy of the alias map so
// callers can mutate the result without touching the original.
func cloneAliases(aliases map[string]string) map[string]string {
	out := make(map[string]string, len(aliases))
	for k, v := range aliases {
		out[k] = v
	}
	return out
}
|
||||||
|
|
||||||
|
// singleModelPrechecked wraps a non-empty current model in a one-element
// slice for the picker's pre-checked list; empty input yields nil.
func singleModelPrechecked(current string) []string {
	if current != "" {
		return []string{current}
	}
	return nil
}
|
||||||
|
|
||||||
|
// firstModel returns the first element of models, or "" for an empty list.
func firstModel(models []string) string {
	if len(models) > 0 {
		return models[0]
	}
	return ""
}
|
||||||
|
|
||||||
|
func editorPreCheckedModels(saved *config.IntegrationConfig, override string) []string {
|
||||||
|
if override == "" {
|
||||||
|
if saved == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return append([]string(nil), saved.Models...)
|
||||||
|
}
|
||||||
|
return append([]string{override}, additionalSavedModels(saved, override)...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func additionalSavedModels(saved *config.IntegrationConfig, exclude string) []string {
|
||||||
|
if saved == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var models []string
|
||||||
|
for _, model := range saved.Models {
|
||||||
|
if model != exclude {
|
||||||
|
models = append(models, model)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return models
|
||||||
|
}
|
||||||
1498
cmd/launch/launch_test.go
Normal file
1498
cmd/launch/launch_test.go
Normal file
File diff suppressed because it is too large
Load Diff
494
cmd/launch/models.go
Normal file
494
cmd/launch/models.go
Normal file
@@ -0,0 +1,494 @@
|
|||||||
|
package launch
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
"os/exec"
|
||||||
|
"runtime"
|
||||||
|
"slices"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/api"
|
||||||
|
"github.com/ollama/ollama/cmd/config"
|
||||||
|
"github.com/ollama/ollama/cmd/internal/fileutil"
|
||||||
|
internalcloud "github.com/ollama/ollama/internal/cloud"
|
||||||
|
"github.com/ollama/ollama/internal/modelref"
|
||||||
|
"github.com/ollama/ollama/progress"
|
||||||
|
)
|
||||||
|
|
||||||
|
// recommendedModels is the curated set of models surfaced in selection UIs
// even when they are not installed locally. Cloud entries carry the
// ":cloud" tag; the rest run locally.
var recommendedModels = []ModelItem{
	{Name: "kimi-k2.5:cloud", Description: "Multimodal reasoning with subagents", Recommended: true},
	{Name: "qwen3.5:cloud", Description: "Reasoning, coding, and agentic tool use with vision", Recommended: true},
	{Name: "glm-5:cloud", Description: "Reasoning and code generation", Recommended: true},
	{Name: "minimax-m2.7:cloud", Description: "Fast, efficient coding and real-world productivity", Recommended: true},
	{Name: "glm-4.7-flash", Description: "Reasoning and code generation locally", Recommended: true},
	{Name: "qwen3.5", Description: "Reasoning, coding, and visual understanding locally", Recommended: true},
}
|
||||||
|
|
||||||
|
// recommendedVRAM maps locally-runnable recommended models to the
// approximate VRAM hint shown alongside them in selection UIs.
var recommendedVRAM = map[string]string{
	"glm-4.7-flash": "~25GB",
	"qwen3.5":       "~11GB",
}
|
||||||
|
|
||||||
|
// cloudModelLimit holds context and output token limits for a cloud model.
type cloudModelLimit struct {
	// Context is the maximum context window size, in tokens.
	Context int
	// Output is the maximum number of output tokens per response.
	Output int
}
|
||||||
|
|
||||||
|
// cloudModelLimits maps cloud model base names to their token limits.
// TODO(parthsareen): grab context/output limits from model info instead of hardcoding
var cloudModelLimits = map[string]cloudModelLimit{
	"minimax-m2.7":        {Context: 204_800, Output: 128_000},
	"cogito-2.1:671b":     {Context: 163_840, Output: 65_536},
	"deepseek-v3.1:671b":  {Context: 163_840, Output: 163_840},
	"deepseek-v3.2":       {Context: 163_840, Output: 65_536},
	"glm-4.6":             {Context: 202_752, Output: 131_072},
	"glm-4.7":             {Context: 202_752, Output: 131_072},
	"glm-5":               {Context: 202_752, Output: 131_072},
	"gpt-oss:120b":        {Context: 131_072, Output: 131_072},
	"gpt-oss:20b":         {Context: 131_072, Output: 131_072},
	"kimi-k2:1t":          {Context: 262_144, Output: 262_144},
	"kimi-k2.5":           {Context: 262_144, Output: 262_144},
	"kimi-k2-thinking":    {Context: 262_144, Output: 262_144},
	"nemotron-3-nano:30b": {Context: 1_048_576, Output: 131_072},
	"qwen3-coder:480b":    {Context: 262_144, Output: 65_536},
	"qwen3-coder-next":    {Context: 262_144, Output: 32_768},
	"qwen3-next:80b":      {Context: 262_144, Output: 32_768},
	"qwen3.5":             {Context: 262_144, Output: 32_768},
}
|
||||||
|
|
||||||
|
// lookupCloudModelLimit returns the token limits for a cloud model.
|
||||||
|
// It normalizes explicit cloud source suffixes before checking the shared limit map.
|
||||||
|
func lookupCloudModelLimit(name string) (cloudModelLimit, bool) {
|
||||||
|
base, stripped := modelref.StripCloudSourceTag(name)
|
||||||
|
if stripped {
|
||||||
|
if l, ok := cloudModelLimits[base]; ok {
|
||||||
|
return l, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return cloudModelLimit{}, false
|
||||||
|
}
|
||||||
|
|
||||||
|
// missingModelPolicy controls how model-not-found errors should be handled.
type missingModelPolicy int

const (
	// missingModelPromptPull prompts the user to download missing local models.
	// This is the zero value and therefore the default behavior.
	missingModelPromptPull missingModelPolicy = iota
	// missingModelAutoPull downloads missing local models without prompting.
	missingModelAutoPull
	// missingModelFail returns an error for missing local models without prompting.
	missingModelFail
)
|
||||||
|
|
||||||
|
// OpenBrowser opens the URL in the user's browser. Launch failures are
// deliberately ignored: the sign-in URL is also printed for manual use.
func OpenBrowser(url string) {
	switch runtime.GOOS {
	case "darwin":
		_ = exec.Command("open", url).Start()
	case "windows":
		_ = exec.Command("rundll32", "url.dll,FileProtocolHandler", url).Start()
	case "linux":
		// Skip on headless systems where no display server is available
		if os.Getenv("DISPLAY") != "" || os.Getenv("WAYLAND_DISPLAY") != "" {
			_ = exec.Command("xdg-open", url).Start()
		}
	}
}
|
||||||
|
|
||||||
|
// ensureAuth ensures the user is signed in before cloud-backed models run.
// It is a no-op when no selected model is cloud-backed. With a DefaultSignIn
// hook installed it delegates to that; otherwise it prompts, opens the
// sign-in URL in a browser, and polls Whoami until the sign-in completes or
// ctx is cancelled.
func ensureAuth(ctx context.Context, client *api.Client, cloudModels map[string]bool, selected []string) error {
	var selectedCloudModels []string
	for _, m := range selected {
		if cloudModels[m] {
			selectedCloudModels = append(selectedCloudModels, m)
		}
	}
	if len(selectedCloudModels) == 0 {
		return nil
	}
	if disabled, known := cloudStatusDisabled(ctx, client); known && disabled {
		return errors.New(internalcloud.DisabledError("remote inference is unavailable"))
	}

	user, err := client.Whoami(ctx)
	if err == nil && user != nil && user.Name != "" {
		// Already signed in.
		return nil
	}

	// Anything other than an authorization error with a sign-in URL is
	// unrecoverable here.
	var aErr api.AuthorizationError
	if !errors.As(err, &aErr) || aErr.SigninURL == "" {
		return err
	}

	modelList := strings.Join(selectedCloudModels, ", ")

	if DefaultSignIn != nil {
		_, err := DefaultSignIn(modelList, aErr.SigninURL)
		if errors.Is(err, ErrCancelled) {
			return ErrCancelled
		}
		if err != nil {
			return fmt.Errorf("%s requires sign in", modelList)
		}
		return nil
	}

	yes, err := ConfirmPrompt(fmt.Sprintf("sign in to use %s?", modelList))
	if errors.Is(err, ErrCancelled) {
		return ErrCancelled
	}
	if err != nil {
		return err
	}
	if !yes {
		return ErrCancelled
	}

	fmt.Fprintf(os.Stderr, "\nTo sign in, navigate to:\n  %s\n\n", aErr.SigninURL)
	OpenBrowser(aErr.SigninURL)

	// Render a dim spinner, checking Whoami every 10th tick (~2s at 200ms).
	spinnerFrames := []string{"|", "/", "-", "\\"}
	frame := 0
	fmt.Fprintf(os.Stderr, "\033[90mwaiting for sign in to complete... %s\033[0m", spinnerFrames[0])

	ticker := time.NewTicker(200 * time.Millisecond)
	defer ticker.Stop()

	for {
		select {
		case <-ctx.Done():
			// Clear the spinner line before surfacing cancellation.
			fmt.Fprintf(os.Stderr, "\r\033[K")
			return ctx.Err()
		case <-ticker.C:
			frame++
			fmt.Fprintf(os.Stderr, "\r\033[90mwaiting for sign in to complete... %s\033[0m", spinnerFrames[frame%len(spinnerFrames)])

			if frame%10 == 0 {
				u, err := client.Whoami(ctx)
				if err == nil && u != nil && u.Name != "" {
					fmt.Fprintf(os.Stderr, "\r\033[K\033[A\r\033[K\033[1msigned in:\033[0m %s\n", u.Name)
					return nil
				}
			}
		}
	}
}
|
||||||
|
|
||||||
|
// showOrPullWithPolicy checks if a model exists and applies the provided missing-model policy.
|
||||||
|
func showOrPullWithPolicy(ctx context.Context, client *api.Client, model string, policy missingModelPolicy, isCloudModel bool) error {
|
||||||
|
if _, err := client.Show(ctx, &api.ShowRequest{Model: model}); err == nil {
|
||||||
|
return nil
|
||||||
|
} else {
|
||||||
|
var statusErr api.StatusError
|
||||||
|
if !errors.As(err, &statusErr) || statusErr.StatusCode != http.StatusNotFound {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if isCloudModel {
|
||||||
|
if disabled, known := cloudStatusDisabled(ctx, client); known && disabled {
|
||||||
|
return errors.New(internalcloud.DisabledError("remote inference is unavailable"))
|
||||||
|
}
|
||||||
|
return fmt.Errorf("model %q not found", model)
|
||||||
|
}
|
||||||
|
|
||||||
|
switch policy {
|
||||||
|
case missingModelAutoPull:
|
||||||
|
return pullMissingModel(ctx, client, model)
|
||||||
|
case missingModelFail:
|
||||||
|
return fmt.Errorf("model %q not found; run 'ollama pull %s' first, or use --yes to auto-pull", model, model)
|
||||||
|
default:
|
||||||
|
return confirmAndPull(ctx, client, model)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func confirmAndPull(ctx context.Context, client *api.Client, model string) error {
|
||||||
|
if ok, err := ConfirmPrompt(fmt.Sprintf("Download %s?", model)); err != nil {
|
||||||
|
return err
|
||||||
|
} else if !ok {
|
||||||
|
return errCancelled
|
||||||
|
}
|
||||||
|
fmt.Fprintf(os.Stderr, "\n")
|
||||||
|
return pullMissingModel(ctx, client, model)
|
||||||
|
}
|
||||||
|
|
||||||
|
func pullMissingModel(ctx context.Context, client *api.Client, model string) error {
|
||||||
|
if err := pullModel(ctx, client, model, false); err != nil {
|
||||||
|
return fmt.Errorf("failed to pull %s: %w", model, err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// prepareEditorIntegration persists models and applies editor-managed config files.
|
||||||
|
func prepareEditorIntegration(name string, runner Runner, editor Editor, models []string) error {
|
||||||
|
if ok, err := confirmEditorEdit(runner, editor); err != nil {
|
||||||
|
return err
|
||||||
|
} else if !ok {
|
||||||
|
return errCancelled
|
||||||
|
}
|
||||||
|
if err := editor.Edit(models); err != nil {
|
||||||
|
return fmt.Errorf("setup failed: %w", err)
|
||||||
|
}
|
||||||
|
if err := config.SaveIntegration(name, models); err != nil {
|
||||||
|
return fmt.Errorf("failed to save: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func confirmEditorEdit(runner Runner, editor Editor) (bool, error) {
|
||||||
|
paths := editor.Paths()
|
||||||
|
if len(paths) == 0 {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Fprintf(os.Stderr, "This will modify your %s configuration:\n", runner)
|
||||||
|
for _, path := range paths {
|
||||||
|
fmt.Fprintf(os.Stderr, " %s\n", path)
|
||||||
|
}
|
||||||
|
fmt.Fprintf(os.Stderr, "Backups will be saved to %s/\n\n", fileutil.BackupDir())
|
||||||
|
|
||||||
|
return ConfirmPrompt("Proceed?")
|
||||||
|
}
|
||||||
|
|
||||||
|
// buildModelList merges existing models with recommendations for selection UIs.
//
// It returns the selectable items, the pre-checked names reordered so the
// current model (if checked) comes first, and two sets: names the user already
// has installed and names that are cloud-hosted. The caller's preChecked slice
// may be reordered via the returned orderedChecked.
func buildModelList(existing []modelInfo, preChecked []string, current string) (items []ModelItem, orderedChecked []string, existingModels, cloudModels map[string]bool) {
	existingModels = make(map[string]bool)
	cloudModels = make(map[string]bool)
	recommended := make(map[string]bool)
	var hasLocalModel, hasCloudModel bool

	// Index the package-level recommendations by name for quick lookup.
	recDesc := make(map[string]string)
	for _, rec := range recommendedModels {
		recommended[rec.Name] = true
		recDesc[rec.Name] = rec.Description
	}

	// Installed models come first; record both the full name and the
	// ":latest"-stripped display name so either form matches later.
	for _, m := range existing {
		existingModels[m.Name] = true
		if m.Remote {
			cloudModels[m.Name] = true
			hasCloudModel = true
		} else {
			hasLocalModel = true
		}
		displayName := strings.TrimSuffix(m.Name, ":latest")
		existingModels[displayName] = true
		item := ModelItem{Name: displayName, Recommended: recommended[displayName], Description: recDesc[displayName]}
		items = append(items, item)
	}

	// Append recommendations the user doesn't already have.
	for _, rec := range recommendedModels {
		if existingModels[rec.Name] || existingModels[rec.Name+":latest"] {
			continue
		}
		items = append(items, rec)
		if isCloudModelName(rec.Name) {
			cloudModels[rec.Name] = true
		}
	}

	checked := make(map[string]bool, len(preChecked))
	for _, n := range preChecked {
		checked[n] = true
	}

	// Resolve "current" against the item list: exact match first, then a
	// tag-prefixed match (e.g. "llama3" -> "llama3:8b").
	if current != "" {
		matchedCurrent := false
		for _, item := range items {
			if item.Name == current {
				current = item.Name
				matchedCurrent = true
				break
			}
		}
		if !matchedCurrent {
			for _, item := range items {
				if strings.HasPrefix(item.Name, current+":") {
					current = item.Name
					break
				}
			}
		}
	}

	// If the current model is among the pre-checked names, move it to the front.
	if checked[current] {
		preChecked = append([]string{current}, slices.DeleteFunc(preChecked, func(m string) bool { return m == current })...)
	}

	// Annotate models that would require a download (not installed and not
	// cloud-hosted) with description, VRAM hint, and a download marker.
	notInstalled := make(map[string]bool)
	for i := range items {
		if !existingModels[items[i].Name] && !cloudModels[items[i].Name] {
			notInstalled[items[i].Name] = true
			var parts []string
			if items[i].Description != "" {
				parts = append(parts, items[i].Description)
			}
			if vram := recommendedVRAM[items[i].Name]; vram != "" {
				parts = append(parts, vram)
			}
			parts = append(parts, "(not downloaded)")
			items[i].Description = strings.Join(parts, ", ")
		}
	}

	// Rank recommendations by their order in recommendedModels (1-based; 0
	// means "not recommended").
	recRank := make(map[string]int)
	for i, rec := range recommendedModels {
		recRank[rec.Name] = i + 1
	}

	onlyLocal := hasLocalModel && !hasCloudModel

	// Sort keys, in priority order: checked first, then recommended, then
	// (among recommendations) cloud vs local depending on what the user
	// already runs, then recommendation rank, then installed before
	// not-yet-downloaded, then case-insensitive name. Stable sort preserves
	// the original order for ties.
	if hasLocalModel || hasCloudModel {
		slices.SortStableFunc(items, func(a, b ModelItem) int {
			ac, bc := checked[a.Name], checked[b.Name]
			aNew, bNew := notInstalled[a.Name], notInstalled[b.Name]
			aRec, bRec := recRank[a.Name] > 0, recRank[b.Name] > 0
			aCloud, bCloud := cloudModels[a.Name], cloudModels[b.Name]

			if ac != bc {
				if ac {
					return -1
				}
				return 1
			}
			if aRec != bRec {
				if aRec {
					return -1
				}
				return 1
			}
			if aRec && bRec {
				if aCloud != bCloud {
					// Users with only local models see local recommendations
					// first; otherwise cloud recommendations lead.
					if onlyLocal {
						if aCloud {
							return 1
						}
						return -1
					}
					if aCloud {
						return -1
					}
					return 1
				}
				return recRank[a.Name] - recRank[b.Name]
			}
			if aNew != bNew {
				if aNew {
					return 1
				}
				return -1
			}
			return strings.Compare(strings.ToLower(a.Name), strings.ToLower(b.Name))
		})
	}

	return items, preChecked, existingModels, cloudModels
}
|
||||||
|
|
||||||
|
// isCloudModelName reports whether the model name has an explicit cloud source.
// It delegates to the shared name parser so the definition of "cloud source"
// lives in one place.
func isCloudModelName(name string) bool {
	return modelref.HasExplicitCloudSource(name)
}
|
||||||
|
|
||||||
|
// filterCloudModels drops remote-only models from the given inventory.
|
||||||
|
func filterCloudModels(existing []modelInfo) []modelInfo {
|
||||||
|
filtered := existing[:0]
|
||||||
|
for _, m := range existing {
|
||||||
|
if !m.Remote {
|
||||||
|
filtered = append(filtered, m)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return filtered
|
||||||
|
}
|
||||||
|
|
||||||
|
// filterCloudItems removes cloud models from selection items.
|
||||||
|
func filterCloudItems(items []ModelItem) []ModelItem {
|
||||||
|
filtered := items[:0]
|
||||||
|
for _, item := range items {
|
||||||
|
if !isCloudModelName(item.Name) {
|
||||||
|
filtered = append(filtered, item)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return filtered
|
||||||
|
}
|
||||||
|
|
||||||
|
func isCloudModel(ctx context.Context, client *api.Client, name string) bool {
|
||||||
|
if client == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
resp, err := client.Show(ctx, &api.ShowRequest{Model: name})
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return resp.RemoteModel != ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// cloudStatusDisabled returns whether cloud usage is currently disabled.
// known is false when the status could not be determined; callers should not
// assume cloud is enabled in that case.
func cloudStatusDisabled(ctx context.Context, client *api.Client) (disabled bool, known bool) {
	status, err := client.CloudStatusExperimental(ctx)
	if err != nil {
		var statusErr api.StatusError
		// NOTE(review): this 404 branch (older servers without the endpoint)
		// returns the same values as the general error path below, so it is
		// currently dead code — either give it distinct handling or fold the
		// two returns together.
		if errors.As(err, &statusErr) && statusErr.StatusCode == http.StatusNotFound {
			return false, false
		}
		return false, false
	}
	return status.Cloud.Disabled, true
}
|
||||||
|
|
||||||
|
// TODO(parthsareen): this duplicates the pull progress UI in cmd.PullHandler.
// Move the shared pull rendering to a small utility once the package boundary settles.

// pullModel downloads a model while rendering per-layer progress bars (keyed
// by layer digest) and a spinner for digest-less status updates. The progress
// display is written to stderr and stopped when the pull completes.
func pullModel(ctx context.Context, client *api.Client, model string, insecure bool) error {
	p := progress.NewProgress(os.Stderr)
	defer p.Stop()

	bars := make(map[string]*progress.Bar)
	var status string
	var spinner *progress.Spinner

	// fn is invoked by the client for every progress event; it mutates the
	// bar/spinner state above, relying on the callbacks arriving in order.
	fn := func(resp api.ProgressResponse) error {
		if resp.Digest != "" {
			// Skip layers that haven't transferred any bytes yet.
			if resp.Completed == 0 {
				return nil
			}

			// A layer update supersedes any status spinner.
			if spinner != nil {
				spinner.Stop()
			}

			bar, ok := bars[resp.Digest]
			if !ok {
				// Label the bar with a shortened digest (first 12 hex chars)
				// when the digest has the expected "sha256:" prefix.
				name, isDigest := strings.CutPrefix(resp.Digest, "sha256:")
				name = strings.TrimSpace(name)
				if isDigest {
					name = name[:min(12, len(name))]
				}
				bar = progress.NewBar(fmt.Sprintf("pulling %s:", name), resp.Total, resp.Completed)
				bars[resp.Digest] = bar
				p.Add(resp.Digest, bar)
			}

			bar.Set(resp.Completed)
		} else if status != resp.Status {
			// New digest-less status: replace the previous spinner.
			if spinner != nil {
				spinner.Stop()
			}

			status = resp.Status
			spinner = progress.NewSpinner(status)
			p.Add(status, spinner)
		}

		return nil
	}

	request := api.PullRequest{Name: model, Insecure: insecure}
	return client.Pull(ctx, &request, fn)
}
|
||||||
896
cmd/launch/openclaw.go
Normal file
896
cmd/launch/openclaw.go
Normal file
@@ -0,0 +1,896 @@
|
|||||||
|
package launch
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"net/url"
|
||||||
|
"os"
|
||||||
|
"os/exec"
|
||||||
|
"path/filepath"
|
||||||
|
"runtime"
|
||||||
|
"slices"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"golang.org/x/mod/semver"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/api"
|
||||||
|
"github.com/ollama/ollama/cmd/internal/fileutil"
|
||||||
|
"github.com/ollama/ollama/envconfig"
|
||||||
|
"github.com/ollama/ollama/types/model"
|
||||||
|
)
|
||||||
|
|
||||||
|
// defaultGatewayPort is the OpenClaw gateway port assumed when the user's
// config does not specify one (see gatewayInfo).
const defaultGatewayPort = 18789

// Bound model capability probing so launch/config cannot hang on slow/unreachable API calls.
var openclawModelShowTimeout = 5 * time.Second

// openclawFreshInstall is set to true when ensureOpenclawInstalled performs an install,
// letting Run skip the redundant "openclaw update" before first onboarding.
var openclawFreshInstall bool

// Openclaw launches the OpenClaw assistant configured to use Ollama.
type Openclaw struct{}

// String returns the display name used in user-facing prompts.
func (c *Openclaw) String() string { return "OpenClaw" }
|
||||||
|
|
||||||
|
// Run launches OpenClaw configured against Ollama with the given model.
// On first launch it shows a security warning, runs the onboarding wizard
// non-interactively, and upgrades the local device's scopes. Extra args, when
// present, are executed verbatim instead of the default gateway+TUI flow.
func (c *Openclaw) Run(model string, args []string) error {
	bin, err := ensureOpenclawInstalled()
	if err != nil {
		return err
	}

	firstLaunch := !c.onboarded()

	if firstLaunch {
		fmt.Fprintf(os.Stderr, "\n%sSecurity%s\n\n", ansiBold, ansiReset)
		fmt.Fprintf(os.Stderr, " OpenClaw can read files and run actions when tools are enabled.\n")
		fmt.Fprintf(os.Stderr, " A bad prompt can trick it into doing unsafe things.\n\n")
		fmt.Fprintf(os.Stderr, "%s Learn more: https://docs.openclaw.ai/gateway/security%s\n\n", ansiGray, ansiReset)

		ok, err := ConfirmPrompt("I understand the risks. Continue?")
		if err != nil {
			return err
		}
		if !ok {
			// Declining the risk prompt is not an error; just exit quietly.
			return nil
		}

		// Ensure the latest version is installed before onboarding so we get
		// the newest wizard flags (e.g. --auth-choice ollama).
		if !openclawFreshInstall {
			update := exec.Command(bin, "update")
			update.Stdout = os.Stdout
			update.Stderr = os.Stderr
			_ = update.Run() // best-effort; continue even if update fails
		}

		fmt.Fprintf(os.Stderr, "\n%sSetting up OpenClaw with Ollama...%s\n", ansiGreen, ansiReset)
		fmt.Fprintf(os.Stderr, "%s Model: %s%s\n\n", ansiGray, model, ansiReset)

		onboardArgs := []string{
			"onboard",
			"--non-interactive",
			"--accept-risk",
			"--auth-choice", "ollama",
			"--custom-base-url", envconfig.Host().String(),
			"--custom-model-id", model,
			"--skip-channels",
			"--skip-skills",
		}
		// Only request daemon installation where a service manager exists.
		if canInstallDaemon() {
			onboardArgs = append(onboardArgs, "--install-daemon")
		}
		cmd := exec.Command(bin, onboardArgs...)
		cmd.Stdin = os.Stdin
		cmd.Stdout = os.Stdout
		cmd.Stderr = os.Stderr
		if err := cmd.Run(); err != nil {
			return windowsHint(fmt.Errorf("openclaw onboarding failed: %w\n\nTry running: openclaw onboard", err))
		}

		patchDeviceScopes()
	}

	if ensureWebSearchPlugin() {
		registerWebSearchPlugin()
	}

	fmt.Fprintf(os.Stderr, "\n%sStarting your assistant — this may take a moment...%s\n\n", ansiGray, ansiReset)

	// When extra args are passed through, run exactly what the user asked for
	// after setup and skip the built-in gateway+TUI convenience flow.
	if len(args) > 0 {
		cmd := exec.Command(bin, args...)
		cmd.Env = openclawEnv()
		cmd.Stdin = os.Stdin
		cmd.Stdout = os.Stdout
		cmd.Stderr = os.Stderr
		if err := cmd.Run(); err != nil {
			return windowsHint(err)
		}
		return nil
	}

	token, port := c.gatewayInfo()
	addr := fmt.Sprintf("localhost:%d", port)

	// If the gateway is already running (e.g. via the daemon), restart it
	// so it picks up any config changes (model, provider, etc.).
	if portOpen(addr) {
		restart := exec.Command(bin, "daemon", "restart")
		restart.Env = openclawEnv()
		if err := restart.Run(); err != nil {
			fmt.Fprintf(os.Stderr, "%s Warning: daemon restart failed: %v%s\n", ansiYellow, err, ansiReset)
		}
		if !waitForPort(addr, 10*time.Second) {
			fmt.Fprintf(os.Stderr, "%s Warning: gateway did not come back after restart%s\n", ansiYellow, ansiReset)
		}
	}

	// If the gateway isn't running, start it as a background child process.
	if !portOpen(addr) {
		gw := exec.Command(bin, "gateway", "run", "--force")
		gw.Env = openclawEnv()
		if err := gw.Start(); err != nil {
			return windowsHint(fmt.Errorf("failed to start gateway: %w", err))
		}
		// The child gateway lives only as long as this process; kill and
		// reap it when Run returns.
		defer func() {
			if gw.Process != nil {
				_ = gw.Process.Kill()
				_ = gw.Wait()
			}
		}()
	}

	fmt.Fprintf(os.Stderr, "%sStarting gateway...%s\n", ansiGray, ansiReset)
	if !waitForPort(addr, 30*time.Second) {
		return windowsHint(fmt.Errorf("gateway did not start on %s", addr))
	}

	printOpenclawReady(bin, token, port, firstLaunch)

	tuiArgs := []string{"tui"}
	if firstLaunch {
		tuiArgs = append(tuiArgs, "--message", "Wake up, my friend!")
	}
	tui := exec.Command(bin, tuiArgs...)
	tui.Env = openclawEnv()
	tui.Stdin = os.Stdin
	tui.Stdout = os.Stdout
	tui.Stderr = os.Stderr
	if err := tui.Run(); err != nil {
		return windowsHint(err)
	}

	return nil
}
|
||||||
|
|
||||||
|
// gatewayInfo reads the gateway auth token and port from the OpenClaw config.
// It checks the current config location first, then the legacy clawdbot path;
// the first readable, parseable file wins. When no config can be read, an
// empty token and the default port are returned.
func (c *Openclaw) gatewayInfo() (token string, port int) {
	port = defaultGatewayPort
	home, err := os.UserHomeDir()
	if err != nil {
		return "", port
	}

	for _, path := range []string{
		filepath.Join(home, ".openclaw", "openclaw.json"),
		filepath.Join(home, ".clawdbot", "clawdbot.json"),
	} {
		data, err := os.ReadFile(path)
		if err != nil {
			continue
		}
		var config map[string]any
		if json.Unmarshal(data, &config) != nil {
			continue
		}
		// Reads from a nil map are safe, so a missing "gateway"/"auth"
		// section simply leaves the defaults in place.
		gw, _ := config["gateway"].(map[string]any)
		if p, ok := gw["port"].(float64); ok && p > 0 {
			port = int(p)
		}
		auth, _ := gw["auth"].(map[string]any)
		if t, _ := auth["token"].(string); t != "" {
			token = t
		}
		return token, port
	}
	return "", port
}
|
||||||
|
|
||||||
|
// printOpenclawReady announces the running gateway on stderr, including a
// Web UI link (with the auth token embedded in the fragment when available)
// and, on first launch, quick-start hints.
func printOpenclawReady(bin, token string, port int, firstLaunch bool) {
	u := fmt.Sprintf("http://localhost:%d", port)
	if token != "" {
		// The token rides in the URL fragment so it is consumed client-side.
		u += "/#token=" + url.QueryEscape(token)
	}

	fmt.Fprintf(os.Stderr, "\n%s✓ OpenClaw is running%s\n\n", ansiGreen, ansiReset)
	fmt.Fprintf(os.Stderr, " Open the Web UI:\n")
	fmt.Fprintf(os.Stderr, " %s\n\n", hyperlink(u, u))

	if firstLaunch {
		fmt.Fprintf(os.Stderr, "%s Quick start:%s\n", ansiBold, ansiReset)
		fmt.Fprintf(os.Stderr, "%s /help see all commands%s\n", ansiGray, ansiReset)
		fmt.Fprintf(os.Stderr, "%s %s configure --section channels connect WhatsApp, Telegram, etc.%s\n", ansiGray, bin, ansiReset)
		fmt.Fprintf(os.Stderr, "%s %s skills browse and install skills%s\n\n", ansiGray, bin, ansiReset)
		fmt.Fprintf(os.Stderr, "%s The OpenClaw gateway is running in the background.%s\n", ansiYellow, ansiReset)
		fmt.Fprintf(os.Stderr, "%s Stop it with: %s gateway stop%s\n\n", ansiYellow, bin, ansiReset)
	} else {
		fmt.Fprintf(os.Stderr, "%sTip: connect WhatsApp, Telegram, and more with: %s configure --section channels%s\n", ansiGray, bin, ansiReset)
	}
}
|
||||||
|
|
||||||
|
// openclawEnv returns the current environment with provider API keys cleared
|
||||||
|
// so openclaw only uses the Ollama gateway, not keys from the user's shell.
|
||||||
|
func openclawEnv() []string {
|
||||||
|
clear := map[string]bool{
|
||||||
|
"ANTHROPIC_API_KEY": true,
|
||||||
|
"ANTHROPIC_OAUTH_TOKEN": true,
|
||||||
|
"OPENAI_API_KEY": true,
|
||||||
|
"GEMINI_API_KEY": true,
|
||||||
|
"MISTRAL_API_KEY": true,
|
||||||
|
"GROQ_API_KEY": true,
|
||||||
|
"XAI_API_KEY": true,
|
||||||
|
"OPENROUTER_API_KEY": true,
|
||||||
|
}
|
||||||
|
var env []string
|
||||||
|
for _, e := range os.Environ() {
|
||||||
|
key, _, _ := strings.Cut(e, "=")
|
||||||
|
if !clear[key] {
|
||||||
|
env = append(env, e)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return env
|
||||||
|
}
|
||||||
|
|
||||||
|
// portOpen checks if a TCP port is currently accepting connections.
|
||||||
|
func portOpen(addr string) bool {
|
||||||
|
conn, err := net.DialTimeout("tcp", addr, 500*time.Millisecond)
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
conn.Close()
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func waitForPort(addr string, timeout time.Duration) bool {
|
||||||
|
deadline := time.Now().Add(timeout)
|
||||||
|
for time.Now().Before(deadline) {
|
||||||
|
conn, err := net.DialTimeout("tcp", addr, 500*time.Millisecond)
|
||||||
|
if err == nil {
|
||||||
|
conn.Close()
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
time.Sleep(250 * time.Millisecond)
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// windowsHint augments err with WSL2 setup guidance on Windows; on every
// other platform the error is returned unchanged. The original error is
// wrapped with %w so errors.Is/As still work.
func windowsHint(err error) error {
	if runtime.GOOS == "windows" {
		return fmt.Errorf("%w\n\n"+
			"OpenClaw runs best on WSL2.\n"+
			"Quick setup: wsl --install\n"+
			"Guide: https://docs.openclaw.ai/windows", err)
	}
	return err
}
|
||||||
|
|
||||||
|
// onboarded checks if OpenClaw onboarding wizard was completed
|
||||||
|
// by looking for the wizard.lastRunAt marker in the config
|
||||||
|
func (c *Openclaw) onboarded() bool {
|
||||||
|
home, err := os.UserHomeDir()
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
configPath := filepath.Join(home, ".openclaw", "openclaw.json")
|
||||||
|
legacyPath := filepath.Join(home, ".clawdbot", "clawdbot.json")
|
||||||
|
|
||||||
|
config := make(map[string]any)
|
||||||
|
if data, err := os.ReadFile(configPath); err == nil {
|
||||||
|
_ = json.Unmarshal(data, &config)
|
||||||
|
} else if data, err := os.ReadFile(legacyPath); err == nil {
|
||||||
|
_ = json.Unmarshal(data, &config)
|
||||||
|
} else {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for wizard.lastRunAt marker (set when onboarding completes)
|
||||||
|
wizard, _ := config["wizard"].(map[string]any)
|
||||||
|
if wizard == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
lastRunAt, _ := wizard["lastRunAt"].(string)
|
||||||
|
return lastRunAt != ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// patchDeviceScopes upgrades the local CLI device's paired scopes to include
// operator.admin. Only patches the local device, not remote ones.
// Best-effort: silently returns on any error.
func patchDeviceScopes() {
	home, err := os.UserHomeDir()
	if err != nil {
		return
	}

	// Identify which entry in paired.json belongs to this machine.
	deviceID := readLocalDeviceID(home)
	if deviceID == "" {
		return
	}

	path := filepath.Join(home, ".openclaw", "devices", "paired.json")
	data, err := os.ReadFile(path)
	if err != nil {
		return
	}

	var devices map[string]map[string]any
	if err := json.Unmarshal(data, &devices); err != nil {
		return
	}

	dev, ok := devices[deviceID]
	if !ok {
		return
	}

	required := []string{
		"operator.read",
		"operator.admin",
		"operator.approvals",
		"operator.pairing",
	}

	// Patch both the device-level scope list and every issued token's scopes.
	changed := patchScopes(dev, "scopes", required)
	if tokens, ok := dev["tokens"].(map[string]any); ok {
		for _, tok := range tokens {
			if tokenMap, ok := tok.(map[string]any); ok {
				if patchScopes(tokenMap, "scopes", required) {
					changed = true
				}
			}
		}
	}

	// Avoid rewriting the file when nothing was added.
	if !changed {
		return
	}

	out, err := json.MarshalIndent(devices, "", " ")
	if err != nil {
		return
	}
	_ = os.WriteFile(path, out, 0o600)
}
|
||||||
|
|
||||||
|
// readLocalDeviceID reads the local device ID from openclaw's identity file.
|
||||||
|
func readLocalDeviceID(home string) string {
|
||||||
|
data, err := os.ReadFile(filepath.Join(home, ".openclaw", "identity", "device-auth.json"))
|
||||||
|
if err != nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
var auth map[string]any
|
||||||
|
if err := json.Unmarshal(data, &auth); err != nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
id, _ := auth["deviceId"].(string)
|
||||||
|
return id
|
||||||
|
}
|
||||||
|
|
||||||
|
// patchScopes ensures obj[key] contains all required scopes. Returns true if
// any scopes were added. Non-string entries in the existing list are kept but
// never counted as present.
func patchScopes(obj map[string]any, key string, required []string) bool {
	scopes, _ := obj[key].([]any)
	present := make(map[string]bool, len(scopes))
	for _, v := range scopes {
		if name, ok := v.(string); ok {
			present[name] = true
		}
	}

	changed := false
	for _, name := range required {
		if present[name] {
			continue
		}
		scopes = append(scopes, name)
		changed = true
	}
	// Only write back when something was appended, leaving obj untouched
	// (including an absent key) otherwise.
	if changed {
		obj[key] = scopes
	}
	return changed
}
|
||||||
|
|
||||||
|
// canInstallDaemon reports whether the openclaw daemon can be installed as a
// background service. Returns false on Linux when systemd is absent (e.g.
// containers) so that --install-daemon is omitted and the gateway is started
// as a foreground child process instead. Returns true in all other cases.
func canInstallDaemon() bool {
	if runtime.GOOS != "linux" {
		return true
	}
	// /run/systemd/system exists as a directory when systemd is the init system.
	// This is absent in most containers.
	if info, err := os.Stat("/run/systemd/system"); err != nil || !info.IsDir() {
		return false
	}
	// Even when systemd is the init system, user services require a user
	// manager instance. XDG_RUNTIME_DIR being set is a prerequisite.
	return os.Getenv("XDG_RUNTIME_DIR") != ""
}
|
||||||
|
|
||||||
|
// ensureOpenclawInstalled returns the name of an openclaw binary on PATH,
// installing it via npm (after confirmation) when neither the current nor the
// legacy binary is found. On a fresh install it also sets openclawFreshInstall
// so callers can skip a redundant update step.
func ensureOpenclawInstalled() (string, error) {
	if _, err := exec.LookPath("openclaw"); err == nil {
		return "openclaw", nil
	}
	// Fall back to the legacy binary name.
	if _, err := exec.LookPath("clawdbot"); err == nil {
		return "clawdbot", nil
	}

	// Installation needs both npm and git; report everything that's missing
	// in one message rather than failing piecemeal.
	_, npmErr := exec.LookPath("npm")
	_, gitErr := exec.LookPath("git")
	if npmErr != nil || gitErr != nil {
		var missing []string
		if npmErr != nil {
			missing = append(missing, "npm (Node.js): https://nodejs.org/")
		}
		if gitErr != nil {
			missing = append(missing, "git: https://git-scm.com/")
		}
		return "", fmt.Errorf("openclaw is not installed and required dependencies are missing\n\nInstall the following first:\n %s", strings.Join(missing, "\n "))
	}

	ok, err := ConfirmPrompt("OpenClaw is not installed. Install with npm?")
	if err != nil {
		return "", err
	}
	if !ok {
		return "", fmt.Errorf("openclaw installation cancelled")
	}

	fmt.Fprintf(os.Stderr, "\nInstalling OpenClaw...\n")
	cmd := exec.Command("npm", "install", "-g", "openclaw@latest")
	cmd.Stdin = os.Stdin
	cmd.Stdout = os.Stdout
	cmd.Stderr = os.Stderr
	if err := cmd.Run(); err != nil {
		return "", fmt.Errorf("failed to install openclaw: %w", err)
	}

	// npm's global bin directory may not be on PATH in this shell.
	if _, err := exec.LookPath("openclaw"); err != nil {
		return "", fmt.Errorf("openclaw was installed but the binary was not found on PATH\n\nYou may need to restart your shell")
	}

	fmt.Fprintf(os.Stderr, "%sOpenClaw installed successfully%s\n\n", ansiGreen, ansiReset)
	openclawFreshInstall = true
	return "openclaw", nil
}
|
||||||
|
|
||||||
|
func (c *Openclaw) Paths() []string {
|
||||||
|
home, _ := os.UserHomeDir()
|
||||||
|
p := filepath.Join(home, ".openclaw", "openclaw.json")
|
||||||
|
if _, err := os.Stat(p); err == nil {
|
||||||
|
return []string{p}
|
||||||
|
}
|
||||||
|
legacy := filepath.Join(home, ".clawdbot", "clawdbot.json")
|
||||||
|
if _, err := os.Stat(legacy); err == nil {
|
||||||
|
return []string{legacy}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Edit writes the Ollama provider and model list into the OpenClaw config,
// preserving unknown fields and per-model user customizations, and sets the
// first model as the default agent model. A no-op when models is empty.
func (c *Openclaw) Edit(models []string) error {
	if len(models) == 0 {
		return nil
	}

	home, err := os.UserHomeDir()
	if err != nil {
		return err
	}

	configPath := filepath.Join(home, ".openclaw", "openclaw.json")
	legacyPath := filepath.Join(home, ".clawdbot", "clawdbot.json")
	if err := os.MkdirAll(filepath.Dir(configPath), 0o755); err != nil {
		return err
	}

	// Read into map[string]any to preserve unknown fields
	config := make(map[string]any)
	if data, err := os.ReadFile(configPath); err == nil {
		_ = json.Unmarshal(data, &config)
	} else if data, err := os.ReadFile(legacyPath); err == nil {
		_ = json.Unmarshal(data, &config)
	}

	// Navigate/create: models.providers.ollama (preserving other providers)
	modelsSection, _ := config["models"].(map[string]any)
	if modelsSection == nil {
		modelsSection = make(map[string]any)
	}
	providers, _ := modelsSection["providers"].(map[string]any)
	if providers == nil {
		providers = make(map[string]any)
	}
	ollama, _ := providers["ollama"].(map[string]any)
	if ollama == nil {
		ollama = make(map[string]any)
	}

	ollama["baseUrl"] = envconfig.Host().String()
	// needed to register provider
	ollama["apiKey"] = "ollama-local"
	ollama["api"] = "ollama"

	// Build map of existing models to preserve user customizations
	existingModels, _ := ollama["models"].([]any)
	existingByID := make(map[string]map[string]any)
	for _, m := range existingModels {
		if entry, ok := m.(map[string]any); ok {
			if id, ok := entry["id"].(string); ok {
				existingByID[id] = entry
			}
		}
	}

	// Client errors are tolerated; openclawModelConfig is expected to cope
	// with a nil/unreachable client (best-effort capability probing).
	client, _ := api.ClientFromEnvironment()

	var newModels []any
	for _, m := range models {
		entry, _ := openclawModelConfig(context.Background(), client, m)
		// Merge existing fields (user customizations)
		if existing, ok := existingByID[m]; ok {
			for k, v := range existing {
				if _, isNew := entry[k]; !isNew {
					entry[k] = v
				}
			}
		}
		newModels = append(newModels, entry)
	}
	ollama["models"] = newModels

	providers["ollama"] = ollama
	modelsSection["providers"] = providers
	config["models"] = modelsSection

	// Update agents.defaults.model.primary (preserving other agent settings)
	agents, _ := config["agents"].(map[string]any)
	if agents == nil {
		agents = make(map[string]any)
	}
	defaults, _ := agents["defaults"].(map[string]any)
	if defaults == nil {
		defaults = make(map[string]any)
	}
	modelConfig, _ := defaults["model"].(map[string]any)
	if modelConfig == nil {
		modelConfig = make(map[string]any)
	}
	modelConfig["primary"] = "ollama/" + models[0]
	defaults["model"] = modelConfig
	agents["defaults"] = defaults
	config["agents"] = agents

	data, err := json.MarshalIndent(config, "", " ")
	if err != nil {
		return err
	}
	// WriteWithBackup preserves the previous config so the change is reversible.
	if err := fileutil.WriteWithBackup(configPath, data); err != nil {
		return err
	}

	// Clear any per-session model overrides so the new primary takes effect
	// immediately rather than being shadowed by a cached modelOverride.
	clearSessionModelOverride(models[0])
	return nil
}
|
||||||
|
|
||||||
|
// clearSessionModelOverride removes per-session model overrides from the main
|
||||||
|
// agent session so the global primary model takes effect on the next TUI launch.
|
||||||
|
func clearSessionModelOverride(primary string) {
|
||||||
|
home, err := os.UserHomeDir()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
path := filepath.Join(home, ".openclaw", "agents", "main", "sessions", "sessions.json")
|
||||||
|
data, err := os.ReadFile(path)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
var sessions map[string]map[string]any
|
||||||
|
if json.Unmarshal(data, &sessions) != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
changed := false
|
||||||
|
for _, sess := range sessions {
|
||||||
|
if override, _ := sess["modelOverride"].(string); override != "" && override != primary {
|
||||||
|
delete(sess, "modelOverride")
|
||||||
|
delete(sess, "providerOverride")
|
||||||
|
}
|
||||||
|
if model, _ := sess["model"].(string); model != "" && model != primary {
|
||||||
|
sess["model"] = primary
|
||||||
|
changed = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !changed {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
out, err := json.MarshalIndent(sessions, "", " ")
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
_ = os.WriteFile(path, out, 0o600)
|
||||||
|
}
|
||||||
|
|
||||||
|
const (
	// webSearchNpmPackage is the npm package providing the OpenClaw
	// web-search extension installed by ensureWebSearchPlugin.
	webSearchNpmPackage = "@ollama/openclaw-web-search"
	// webSearchMinVersion is the minimum acceptable installed version; older
	// installs are re-installed.
	webSearchMinVersion = "0.2.1"
)
|
||||||
|
|
||||||
|
// ensureWebSearchPlugin installs the openclaw-web-search extension into the
// user-level extensions directory (~/.openclaw/extensions/) if it isn't already
// present, or re-installs if the installed version is older than webSearchMinVersion.
// Returns true if the extension is available.
//
// NOTE(review): extraction shells out to an external `tar` binary, so this is
// best-effort on systems without tar on PATH — confirm whether that is
// acceptable on Windows.
func ensureWebSearchPlugin() bool {
	home, err := os.UserHomeDir()
	if err != nil {
		return false
	}

	pluginDir := filepath.Join(home, ".openclaw", "extensions", "openclaw-web-search")
	if webSearchPluginUpToDate(pluginDir) {
		return true
	}

	npmBin, err := exec.LookPath("npm")
	if err != nil {
		// Without npm we cannot install; silently report unavailable.
		return false
	}

	if err := os.MkdirAll(pluginDir, 0o755); err != nil {
		return false
	}

	// Download the tarball via `npm pack`, extract it flat into the plugin dir.
	pack := exec.Command(npmBin, "pack", webSearchNpmPackage, "--pack-destination", pluginDir)
	out, err := pack.Output()
	if err != nil {
		fmt.Fprintf(os.Stderr, "%s Warning: could not download web search plugin: %v%s\n", ansiYellow, err, ansiReset)
		return false
	}

	// `npm pack` prints the generated tarball filename on stdout.
	tgzName := strings.TrimSpace(string(out))
	tgzPath := filepath.Join(pluginDir, tgzName)
	defer os.Remove(tgzPath)

	// --strip-components=1 drops the "package/" prefix npm puts in tarballs.
	tar := exec.Command("tar", "xzf", tgzPath, "--strip-components=1", "-C", pluginDir)
	if err := tar.Run(); err != nil {
		fmt.Fprintf(os.Stderr, "%s Warning: could not extract web search plugin: %v%s\n", ansiYellow, err, ansiReset)
		return false
	}

	fmt.Fprintf(os.Stderr, "%s ✓ Installed web search plugin%s\n", ansiGreen, ansiReset)
	return true
}
|
||||||
|
|
||||||
|
// webSearchPluginUpToDate returns true if the plugin is installed and its
|
||||||
|
// package.json version is >= webSearchMinVersion.
|
||||||
|
func webSearchPluginUpToDate(pluginDir string) bool {
|
||||||
|
data, err := os.ReadFile(filepath.Join(pluginDir, "package.json"))
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
var pkg struct {
|
||||||
|
Version string `json:"version"`
|
||||||
|
}
|
||||||
|
if json.Unmarshal(data, &pkg) != nil || pkg.Version == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return !versionLessThan(pkg.Version, webSearchMinVersion)
|
||||||
|
}
|
||||||
|
|
||||||
|
// versionLessThan compares two semver version strings (major.minor.patch).
|
||||||
|
// Inputs may omit the "v" prefix; it is added automatically for semver.Compare.
|
||||||
|
func versionLessThan(a, b string) bool {
|
||||||
|
if !strings.HasPrefix(a, "v") {
|
||||||
|
a = "v" + a
|
||||||
|
}
|
||||||
|
if !strings.HasPrefix(b, "v") {
|
||||||
|
b = "v" + b
|
||||||
|
}
|
||||||
|
return semver.Compare(a, b) < 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// registerWebSearchPlugin adds plugins.entries.openclaw-web-search to the OpenClaw
|
||||||
|
// config so the gateway activates it on next start. Best-effort; silently returns
|
||||||
|
// on any error.
|
||||||
|
func registerWebSearchPlugin() {
|
||||||
|
home, err := os.UserHomeDir()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
configPath := filepath.Join(home, ".openclaw", "openclaw.json")
|
||||||
|
data, err := os.ReadFile(configPath)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
var config map[string]any
|
||||||
|
if json.Unmarshal(data, &config) != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
plugins, _ := config["plugins"].(map[string]any)
|
||||||
|
if plugins == nil {
|
||||||
|
plugins = make(map[string]any)
|
||||||
|
}
|
||||||
|
entries, _ := plugins["entries"].(map[string]any)
|
||||||
|
if entries == nil {
|
||||||
|
entries = make(map[string]any)
|
||||||
|
}
|
||||||
|
entries["openclaw-web-search"] = map[string]any{"enabled": true}
|
||||||
|
plugins["entries"] = entries
|
||||||
|
|
||||||
|
// Pin trust so the gateway doesn't warn about untracked plugins.
|
||||||
|
allow, _ := plugins["allow"].([]any)
|
||||||
|
hasAllow := false
|
||||||
|
for _, v := range allow {
|
||||||
|
if s, ok := v.(string); ok && s == "openclaw-web-search" {
|
||||||
|
hasAllow = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !hasAllow {
|
||||||
|
allow = append(allow, "openclaw-web-search")
|
||||||
|
}
|
||||||
|
plugins["allow"] = allow
|
||||||
|
|
||||||
|
// Record install provenance so the loader can verify the plugin origin.
|
||||||
|
installs, _ := plugins["installs"].(map[string]any)
|
||||||
|
if installs == nil {
|
||||||
|
installs = make(map[string]any)
|
||||||
|
}
|
||||||
|
pluginDir := filepath.Join(home, ".openclaw", "extensions", "openclaw-web-search")
|
||||||
|
installs["openclaw-web-search"] = map[string]any{
|
||||||
|
"source": "npm",
|
||||||
|
"spec": webSearchNpmPackage,
|
||||||
|
"installPath": pluginDir,
|
||||||
|
}
|
||||||
|
plugins["installs"] = installs
|
||||||
|
|
||||||
|
config["plugins"] = plugins
|
||||||
|
|
||||||
|
// Add plugin tools to tools.alsoAllow so they survive the coding profile's
|
||||||
|
// policy pipeline (which has an explicit allow list of core tools only).
|
||||||
|
tools, _ := config["tools"].(map[string]any)
|
||||||
|
if tools == nil {
|
||||||
|
tools = make(map[string]any)
|
||||||
|
}
|
||||||
|
|
||||||
|
alsoAllow, _ := tools["alsoAllow"].([]any)
|
||||||
|
needed := []string{"ollama_web_search", "ollama_web_fetch"}
|
||||||
|
have := make(map[string]bool, len(alsoAllow))
|
||||||
|
for _, v := range alsoAllow {
|
||||||
|
if s, ok := v.(string); ok {
|
||||||
|
have[s] = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for _, name := range needed {
|
||||||
|
if !have[name] {
|
||||||
|
alsoAllow = append(alsoAllow, name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
tools["alsoAllow"] = alsoAllow
|
||||||
|
|
||||||
|
// Disable built-in web search/fetch since our plugin replaces them.
|
||||||
|
web, _ := tools["web"].(map[string]any)
|
||||||
|
if web == nil {
|
||||||
|
web = make(map[string]any)
|
||||||
|
}
|
||||||
|
web["search"] = map[string]any{"enabled": false}
|
||||||
|
web["fetch"] = map[string]any{"enabled": false}
|
||||||
|
tools["web"] = web
|
||||||
|
config["tools"] = tools
|
||||||
|
|
||||||
|
out, err := json.MarshalIndent(config, "", " ")
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
_ = os.WriteFile(configPath, out, 0o600)
|
||||||
|
}
|
||||||
|
|
||||||
|
// openclawModelConfig builds an OpenClaw model config entry with capability detection.
|
||||||
|
// The second return value indicates whether the model is a cloud (remote) model.
|
||||||
|
func openclawModelConfig(ctx context.Context, client *api.Client, modelID string) (map[string]any, bool) {
|
||||||
|
entry := map[string]any{
|
||||||
|
"id": modelID,
|
||||||
|
"name": modelID,
|
||||||
|
"input": []any{"text"},
|
||||||
|
"cost": map[string]any{
|
||||||
|
"input": 0,
|
||||||
|
"output": 0,
|
||||||
|
"cacheRead": 0,
|
||||||
|
"cacheWrite": 0,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
if client == nil {
|
||||||
|
return entry, false
|
||||||
|
}
|
||||||
|
|
||||||
|
showCtx := ctx
|
||||||
|
if _, hasDeadline := ctx.Deadline(); !hasDeadline {
|
||||||
|
var cancel context.CancelFunc
|
||||||
|
showCtx, cancel = context.WithTimeout(ctx, openclawModelShowTimeout)
|
||||||
|
defer cancel()
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := client.Show(showCtx, &api.ShowRequest{Model: modelID})
|
||||||
|
if err != nil {
|
||||||
|
return entry, false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set input types based on vision capability
|
||||||
|
if slices.Contains(resp.Capabilities, model.CapabilityVision) {
|
||||||
|
entry["input"] = []any{"text", "image"}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set reasoning based on thinking capability
|
||||||
|
if slices.Contains(resp.Capabilities, model.CapabilityThinking) {
|
||||||
|
entry["reasoning"] = true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Cloud models: use hardcoded limits for context/output tokens.
|
||||||
|
// Capability detection above still applies (vision, thinking).
|
||||||
|
if resp.RemoteModel != "" {
|
||||||
|
if l, ok := lookupCloudModelLimit(modelID); ok {
|
||||||
|
entry["contextWindow"] = l.Context
|
||||||
|
entry["maxTokens"] = l.Output
|
||||||
|
}
|
||||||
|
return entry, true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract context window from ModelInfo (local models only)
|
||||||
|
for key, val := range resp.ModelInfo {
|
||||||
|
if strings.HasSuffix(key, ".context_length") {
|
||||||
|
if ctxLen, ok := val.(float64); ok && ctxLen > 0 {
|
||||||
|
entry["contextWindow"] = int(ctxLen)
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return entry, false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Openclaw) Models() []string {
|
||||||
|
home, err := os.UserHomeDir()
|
||||||
|
if err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
config, err := fileutil.ReadJSON(filepath.Join(home, ".openclaw", "openclaw.json"))
|
||||||
|
if err != nil {
|
||||||
|
config, err = fileutil.ReadJSON(filepath.Join(home, ".clawdbot", "clawdbot.json"))
|
||||||
|
if err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
modelsSection, _ := config["models"].(map[string]any)
|
||||||
|
providers, _ := modelsSection["providers"].(map[string]any)
|
||||||
|
ollama, _ := providers["ollama"].(map[string]any)
|
||||||
|
modelList, _ := ollama["models"].([]any)
|
||||||
|
|
||||||
|
var result []string
|
||||||
|
for _, m := range modelList {
|
||||||
|
if entry, ok := m.(map[string]any); ok {
|
||||||
|
if id, ok := entry["id"].(string); ok {
|
||||||
|
result = append(result, id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
1891
cmd/launch/openclaw_test.go
Normal file
1891
cmd/launch/openclaw_test.go
Normal file
File diff suppressed because it is too large
Load Diff
@@ -1,4 +1,4 @@
|
|||||||
package config
|
package launch
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
@@ -10,6 +10,7 @@ import (
|
|||||||
"slices"
|
"slices"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/cmd/internal/fileutil"
|
||||||
"github.com/ollama/ollama/envconfig"
|
"github.com/ollama/ollama/envconfig"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -18,21 +19,12 @@ type OpenCode struct{}
|
|||||||
|
|
||||||
func (o *OpenCode) String() string { return "OpenCode" }
|
func (o *OpenCode) String() string { return "OpenCode" }
|
||||||
|
|
||||||
func (o *OpenCode) Run(model string) error {
|
func (o *OpenCode) Run(model string, args []string) error {
|
||||||
if _, err := exec.LookPath("opencode"); err != nil {
|
if _, err := exec.LookPath("opencode"); err != nil {
|
||||||
return fmt.Errorf("opencode is not installed, install from https://opencode.ai")
|
return fmt.Errorf("opencode is not installed, install from https://opencode.ai")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Call Edit() to ensure config is up-to-date before launch
|
cmd := exec.Command("opencode", args...)
|
||||||
models := []string{model}
|
|
||||||
if config, err := loadIntegration("opencode"); err == nil && len(config.Models) > 0 {
|
|
||||||
models = config.Models
|
|
||||||
}
|
|
||||||
if err := o.Edit(models); err != nil {
|
|
||||||
return fmt.Errorf("setup failed: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
cmd := exec.Command("opencode")
|
|
||||||
cmd.Stdin = os.Stdin
|
cmd.Stdin = os.Stdin
|
||||||
cmd.Stdout = os.Stdout
|
cmd.Stdout = os.Stdout
|
||||||
cmd.Stderr = os.Stderr
|
cmd.Stderr = os.Stderr
|
||||||
@@ -88,13 +80,18 @@ func (o *OpenCode) Edit(modelList []string) error {
|
|||||||
if !ok {
|
if !ok {
|
||||||
ollama = map[string]any{
|
ollama = map[string]any{
|
||||||
"npm": "@ai-sdk/openai-compatible",
|
"npm": "@ai-sdk/openai-compatible",
|
||||||
"name": "Ollama (local)",
|
"name": "Ollama",
|
||||||
"options": map[string]any{
|
"options": map[string]any{
|
||||||
"baseURL": envconfig.Host().String() + "/v1",
|
"baseURL": envconfig.Host().String() + "/v1",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Migrate legacy provider name
|
||||||
|
if name, _ := ollama["name"].(string); name == "Ollama (local)" {
|
||||||
|
ollama["name"] = "Ollama"
|
||||||
|
}
|
||||||
|
|
||||||
models, ok := ollama["models"].(map[string]any)
|
models, ok := ollama["models"].(map[string]any)
|
||||||
if !ok {
|
if !ok {
|
||||||
models = make(map[string]any)
|
models = make(map[string]any)
|
||||||
@@ -122,12 +119,29 @@ func (o *OpenCode) Edit(modelList []string) error {
|
|||||||
existing["name"] = strings.TrimSuffix(name, " [Ollama]")
|
existing["name"] = strings.TrimSuffix(name, " [Ollama]")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if isCloudModelName(model) {
|
||||||
|
if l, ok := lookupCloudModelLimit(model); ok {
|
||||||
|
existing["limit"] = map[string]any{
|
||||||
|
"context": l.Context,
|
||||||
|
"output": l.Output,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
models[model] = map[string]any{
|
entry := map[string]any{
|
||||||
"name": model,
|
"name": model,
|
||||||
"_launch": true,
|
"_launch": true,
|
||||||
}
|
}
|
||||||
|
if isCloudModelName(model) {
|
||||||
|
if l, ok := lookupCloudModelLimit(model); ok {
|
||||||
|
entry["limit"] = map[string]any{
|
||||||
|
"context": l.Context,
|
||||||
|
"output": l.Output,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
models[model] = entry
|
||||||
}
|
}
|
||||||
|
|
||||||
ollama["models"] = models
|
ollama["models"] = models
|
||||||
@@ -138,7 +152,7 @@ func (o *OpenCode) Edit(modelList []string) error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if err := writeWithBackup(configPath, configData); err != nil {
|
if err := fileutil.WriteWithBackup(configPath, configData); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -190,7 +204,7 @@ func (o *OpenCode) Edit(modelList []string) error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return writeWithBackup(statePath, stateData)
|
return fileutil.WriteWithBackup(statePath, stateData)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o *OpenCode) Models() []string {
|
func (o *OpenCode) Models() []string {
|
||||||
@@ -198,7 +212,7 @@ func (o *OpenCode) Models() []string {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
config, err := readJSONFile(filepath.Join(home, ".config", "opencode", "opencode.json"))
|
config, err := fileutil.ReadJSON(filepath.Join(home, ".config", "opencode", "opencode.json"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -1,7 +1,10 @@
|
|||||||
package config
|
package launch
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"testing"
|
"testing"
|
||||||
@@ -231,6 +234,44 @@ func TestOpenCodeEdit(t *testing.T) {
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
t.Run("migrate Ollama (local) provider name", func(t *testing.T) {
|
||||||
|
cleanup()
|
||||||
|
os.MkdirAll(configDir, 0o755)
|
||||||
|
os.WriteFile(configPath, []byte(`{"provider":{"ollama":{"name":"Ollama (local)","npm":"@ai-sdk/openai-compatible","options":{"baseURL":"http://localhost:11434/v1"}}}}`), 0o644)
|
||||||
|
|
||||||
|
if err := o.Edit([]string{"llama3.2"}); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
data, _ := os.ReadFile(configPath)
|
||||||
|
var cfg map[string]any
|
||||||
|
json.Unmarshal(data, &cfg)
|
||||||
|
provider := cfg["provider"].(map[string]any)
|
||||||
|
ollama := provider["ollama"].(map[string]any)
|
||||||
|
if ollama["name"] != "Ollama" {
|
||||||
|
t.Errorf("provider name not migrated: got %q, want %q", ollama["name"], "Ollama")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("preserve custom provider name", func(t *testing.T) {
|
||||||
|
cleanup()
|
||||||
|
os.MkdirAll(configDir, 0o755)
|
||||||
|
os.WriteFile(configPath, []byte(`{"provider":{"ollama":{"name":"My Custom Ollama","npm":"@ai-sdk/openai-compatible","options":{"baseURL":"http://localhost:11434/v1"}}}}`), 0o644)
|
||||||
|
|
||||||
|
if err := o.Edit([]string{"llama3.2"}); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
data, _ := os.ReadFile(configPath)
|
||||||
|
var cfg map[string]any
|
||||||
|
json.Unmarshal(data, &cfg)
|
||||||
|
provider := cfg["provider"].(map[string]any)
|
||||||
|
ollama := provider["ollama"].(map[string]any)
|
||||||
|
if ollama["name"] != "My Custom Ollama" {
|
||||||
|
t.Errorf("custom provider name was changed: got %q, want %q", ollama["name"], "My Custom Ollama")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
t.Run("remove model preserves non-ollama models", func(t *testing.T) {
|
t.Run("remove model preserves non-ollama models", func(t *testing.T) {
|
||||||
cleanup()
|
cleanup()
|
||||||
os.MkdirAll(configDir, 0o755)
|
os.MkdirAll(configDir, 0o755)
|
||||||
@@ -495,6 +536,220 @@ func TestOpenCodeEdit_SpecialCharsInModelName(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func readOpenCodeModel(t *testing.T, configPath, model string) map[string]any {
|
||||||
|
t.Helper()
|
||||||
|
data, err := os.ReadFile(configPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
var cfg map[string]any
|
||||||
|
json.Unmarshal(data, &cfg)
|
||||||
|
provider := cfg["provider"].(map[string]any)
|
||||||
|
ollama := provider["ollama"].(map[string]any)
|
||||||
|
models := ollama["models"].(map[string]any)
|
||||||
|
entry, ok := models[model].(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("model %s not found in config", model)
|
||||||
|
}
|
||||||
|
return entry
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOpenCodeEdit_LocalModelNoLimit(t *testing.T) {
|
||||||
|
o := &OpenCode{}
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
setTestHome(t, tmpDir)
|
||||||
|
|
||||||
|
configPath := filepath.Join(tmpDir, ".config", "opencode", "opencode.json")
|
||||||
|
|
||||||
|
if err := o.Edit([]string{"llama3.2"}); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
entry := readOpenCodeModel(t, configPath, "llama3.2")
|
||||||
|
if entry["limit"] != nil {
|
||||||
|
t.Errorf("local model should not have limit set, got %v", entry["limit"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOpenCodeEdit_PreservesUserLimit(t *testing.T) {
|
||||||
|
o := &OpenCode{}
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
setTestHome(t, tmpDir)
|
||||||
|
|
||||||
|
configDir := filepath.Join(tmpDir, ".config", "opencode")
|
||||||
|
configPath := filepath.Join(configDir, "opencode.json")
|
||||||
|
|
||||||
|
// Set up a model with a user-configured limit
|
||||||
|
os.MkdirAll(configDir, 0o755)
|
||||||
|
os.WriteFile(configPath, []byte(`{
|
||||||
|
"provider": {
|
||||||
|
"ollama": {
|
||||||
|
"models": {
|
||||||
|
"llama3.2": {
|
||||||
|
"name": "llama3.2",
|
||||||
|
"_launch": true,
|
||||||
|
"limit": {"context": 8192, "output": 4096}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}`), 0o644)
|
||||||
|
|
||||||
|
// Re-edit should preserve the user's limit (not delete it)
|
||||||
|
if err := o.Edit([]string{"llama3.2"}); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
entry := readOpenCodeModel(t, configPath, "llama3.2")
|
||||||
|
limit, ok := entry["limit"].(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("user-configured limit was removed")
|
||||||
|
}
|
||||||
|
if limit["context"] != float64(8192) {
|
||||||
|
t.Errorf("context limit changed: got %v, want 8192", limit["context"])
|
||||||
|
}
|
||||||
|
if limit["output"] != float64(4096) {
|
||||||
|
t.Errorf("output limit changed: got %v, want 4096", limit["output"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOpenCodeEdit_CloudModelLimitStructure(t *testing.T) {
|
||||||
|
// Verify that when a cloud model entry has limits set (as Edit would do),
|
||||||
|
// the structure matches what opencode expects and re-edit preserves them.
|
||||||
|
o := &OpenCode{}
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
setTestHome(t, tmpDir)
|
||||||
|
|
||||||
|
configDir := filepath.Join(tmpDir, ".config", "opencode")
|
||||||
|
configPath := filepath.Join(configDir, "opencode.json")
|
||||||
|
|
||||||
|
expected := cloudModelLimits["glm-4.7"]
|
||||||
|
|
||||||
|
// Simulate a cloud model that already has the limit set by a previous Edit
|
||||||
|
os.MkdirAll(configDir, 0o755)
|
||||||
|
os.WriteFile(configPath, []byte(fmt.Sprintf(`{
|
||||||
|
"provider": {
|
||||||
|
"ollama": {
|
||||||
|
"models": {
|
||||||
|
"glm-4.7:cloud": {
|
||||||
|
"name": "glm-4.7:cloud",
|
||||||
|
"_launch": true,
|
||||||
|
"limit": {"context": %d, "output": %d}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}`, expected.Context, expected.Output)), 0o644)
|
||||||
|
|
||||||
|
// Re-edit should preserve the cloud model limit
|
||||||
|
if err := o.Edit([]string{"glm-4.7:cloud"}); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
entry := readOpenCodeModel(t, configPath, "glm-4.7:cloud")
|
||||||
|
limit, ok := entry["limit"].(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("cloud model limit was removed on re-edit")
|
||||||
|
}
|
||||||
|
if limit["context"] != float64(expected.Context) {
|
||||||
|
t.Errorf("context = %v, want %d", limit["context"], expected.Context)
|
||||||
|
}
|
||||||
|
if limit["output"] != float64(expected.Output) {
|
||||||
|
t.Errorf("output = %v, want %d", limit["output"], expected.Output)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOpenCodeEdit_BackfillsCloudModelLimitOnExistingEntry(t *testing.T) {
|
||||||
|
o := &OpenCode{}
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
setTestHome(t, tmpDir)
|
||||||
|
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.URL.Path == "/api/show" {
|
||||||
|
fmt.Fprintf(w, `{"capabilities":[],"model_info":{},"remote_model":"glm-5"}`)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
w.WriteHeader(http.StatusNotFound)
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
t.Setenv("OLLAMA_HOST", srv.URL)
|
||||||
|
|
||||||
|
configDir := filepath.Join(tmpDir, ".config", "opencode")
|
||||||
|
configPath := filepath.Join(configDir, "opencode.json")
|
||||||
|
os.MkdirAll(configDir, 0o755)
|
||||||
|
os.WriteFile(configPath, []byte(`{
|
||||||
|
"provider": {
|
||||||
|
"ollama": {
|
||||||
|
"models": {
|
||||||
|
"glm-5:cloud": {
|
||||||
|
"name": "glm-5:cloud",
|
||||||
|
"_launch": true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}`), 0o644)
|
||||||
|
|
||||||
|
if err := o.Edit([]string{"glm-5:cloud"}); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
entry := readOpenCodeModel(t, configPath, "glm-5:cloud")
|
||||||
|
limit, ok := entry["limit"].(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("cloud model limit was not added on re-edit")
|
||||||
|
}
|
||||||
|
if limit["context"] != float64(202_752) {
|
||||||
|
t.Errorf("context = %v, want 202752", limit["context"])
|
||||||
|
}
|
||||||
|
if limit["output"] != float64(131_072) {
|
||||||
|
t.Errorf("output = %v, want 131072", limit["output"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLookupCloudModelLimit(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
wantOK bool
|
||||||
|
wantContext int
|
||||||
|
wantOutput int
|
||||||
|
}{
|
||||||
|
{"glm-4.7", false, 0, 0},
|
||||||
|
{"glm-4.7:cloud", true, 202_752, 131_072},
|
||||||
|
{"glm-5:cloud", true, 202_752, 131_072},
|
||||||
|
{"gpt-oss:120b-cloud", true, 131_072, 131_072},
|
||||||
|
{"gpt-oss:20b-cloud", true, 131_072, 131_072},
|
||||||
|
{"kimi-k2.5", false, 0, 0},
|
||||||
|
{"kimi-k2.5:cloud", true, 262_144, 262_144},
|
||||||
|
{"deepseek-v3.2", false, 0, 0},
|
||||||
|
{"deepseek-v3.2:cloud", true, 163_840, 65_536},
|
||||||
|
{"qwen3.5", false, 0, 0},
|
||||||
|
{"qwen3.5:cloud", true, 262_144, 32_768},
|
||||||
|
{"qwen3-coder:480b", false, 0, 0},
|
||||||
|
{"qwen3-coder:480b:cloud", true, 262_144, 65_536},
|
||||||
|
{"qwen3-coder-next:cloud", true, 262_144, 32_768},
|
||||||
|
{"llama3.2", false, 0, 0},
|
||||||
|
{"unknown-model:cloud", false, 0, 0},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
l, ok := lookupCloudModelLimit(tt.name)
|
||||||
|
if ok != tt.wantOK {
|
||||||
|
t.Errorf("lookupCloudModelLimit(%q) ok = %v, want %v", tt.name, ok, tt.wantOK)
|
||||||
|
}
|
||||||
|
if ok {
|
||||||
|
if l.Context != tt.wantContext {
|
||||||
|
t.Errorf("context = %d, want %d", l.Context, tt.wantContext)
|
||||||
|
}
|
||||||
|
if l.Output != tt.wantOutput {
|
||||||
|
t.Errorf("output = %d, want %d", l.Output, tt.wantOutput)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestOpenCodeModels_NoConfig(t *testing.T) {
|
func TestOpenCodeModels_NoConfig(t *testing.T) {
|
||||||
o := &OpenCode{}
|
o := &OpenCode{}
|
||||||
tmpDir := t.TempDir()
|
tmpDir := t.TempDir()
|
||||||
265
cmd/launch/pi.go
Normal file
265
cmd/launch/pi.go
Normal file
@@ -0,0 +1,265 @@
|
|||||||
|
package launch
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
"os/exec"
|
||||||
|
"path/filepath"
|
||||||
|
"slices"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/api"
|
||||||
|
"github.com/ollama/ollama/cmd/internal/fileutil"
|
||||||
|
"github.com/ollama/ollama/envconfig"
|
||||||
|
"github.com/ollama/ollama/types/model"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Pi implements Runner and Editor for Pi (Pi Coding Agent) integration
|
||||||
|
type Pi struct{}
|
||||||
|
|
||||||
|
func (p *Pi) String() string { return "Pi" }
|
||||||
|
|
||||||
|
func (p *Pi) Run(model string, args []string) error {
|
||||||
|
if _, err := exec.LookPath("pi"); err != nil {
|
||||||
|
return fmt.Errorf("pi is not installed, install with: npm install -g @mariozechner/pi-coding-agent")
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd := exec.Command("pi", args...)
|
||||||
|
cmd.Stdin = os.Stdin
|
||||||
|
cmd.Stdout = os.Stdout
|
||||||
|
cmd.Stderr = os.Stderr
|
||||||
|
return cmd.Run()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Pi) Paths() []string {
|
||||||
|
home, err := os.UserHomeDir()
|
||||||
|
if err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var paths []string
|
||||||
|
modelsPath := filepath.Join(home, ".pi", "agent", "models.json")
|
||||||
|
if _, err := os.Stat(modelsPath); err == nil {
|
||||||
|
paths = append(paths, modelsPath)
|
||||||
|
}
|
||||||
|
settingsPath := filepath.Join(home, ".pi", "agent", "settings.json")
|
||||||
|
if _, err := os.Stat(settingsPath); err == nil {
|
||||||
|
paths = append(paths, settingsPath)
|
||||||
|
}
|
||||||
|
return paths
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Pi) Edit(models []string) error {
|
||||||
|
if len(models) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
home, err := os.UserHomeDir()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
configPath := filepath.Join(home, ".pi", "agent", "models.json")
|
||||||
|
if err := os.MkdirAll(filepath.Dir(configPath), 0o755); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
config := make(map[string]any)
|
||||||
|
if data, err := os.ReadFile(configPath); err == nil {
|
||||||
|
_ = json.Unmarshal(data, &config)
|
||||||
|
}
|
||||||
|
|
||||||
|
providers, ok := config["providers"].(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
providers = make(map[string]any)
|
||||||
|
}
|
||||||
|
|
||||||
|
ollama, ok := providers["ollama"].(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
ollama = map[string]any{
|
||||||
|
"baseUrl": envconfig.Host().String() + "/v1",
|
||||||
|
"api": "openai-completions",
|
||||||
|
"apiKey": "ollama",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
existingModels, ok := ollama["models"].([]any)
|
||||||
|
if !ok {
|
||||||
|
existingModels = make([]any, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build set of selected models to track which need to be added
|
||||||
|
selectedSet := make(map[string]bool, len(models))
|
||||||
|
for _, m := range models {
|
||||||
|
selectedSet[m] = true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build new models list:
|
||||||
|
// 1. Keep user-managed models (no _launch marker) - untouched
|
||||||
|
// 2. Keep ollama-managed models (_launch marker) that are still selected,
|
||||||
|
// except stale cloud entries that should be rebuilt below
|
||||||
|
// 3. Add new ollama-managed models
|
||||||
|
var newModels []any
|
||||||
|
for _, m := range existingModels {
|
||||||
|
if modelObj, ok := m.(map[string]any); ok {
|
||||||
|
if id, ok := modelObj["id"].(string); ok {
|
||||||
|
// User-managed model (no _launch marker) - always preserve
|
||||||
|
if !isPiOllamaModel(modelObj) {
|
||||||
|
newModels = append(newModels, m)
|
||||||
|
} else if selectedSet[id] {
|
||||||
|
// Rebuild stale managed cloud entries so createConfig refreshes
|
||||||
|
// the whole entry instead of patching it in place.
|
||||||
|
if !hasContextWindow(modelObj) {
|
||||||
|
if _, ok := lookupCloudModelLimit(id); ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
newModels = append(newModels, m)
|
||||||
|
selectedSet[id] = false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add newly selected models that weren't already in the list
|
||||||
|
client := api.NewClient(envconfig.Host(), http.DefaultClient)
|
||||||
|
ctx := context.Background()
|
||||||
|
for _, model := range models {
|
||||||
|
if selectedSet[model] {
|
||||||
|
newModels = append(newModels, createConfig(ctx, client, model))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
ollama["models"] = newModels
|
||||||
|
providers["ollama"] = ollama
|
||||||
|
config["providers"] = providers
|
||||||
|
|
||||||
|
configData, err := json.MarshalIndent(config, "", " ")
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := fileutil.WriteWithBackup(configPath, configData); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update settings.json with default provider and model
|
||||||
|
settingsPath := filepath.Join(home, ".pi", "agent", "settings.json")
|
||||||
|
settings := make(map[string]any)
|
||||||
|
if data, err := os.ReadFile(settingsPath); err == nil {
|
||||||
|
_ = json.Unmarshal(data, &settings)
|
||||||
|
}
|
||||||
|
|
||||||
|
settings["defaultProvider"] = "ollama"
|
||||||
|
settings["defaultModel"] = models[0]
|
||||||
|
|
||||||
|
settingsData, err := json.MarshalIndent(settings, "", " ")
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return fileutil.WriteWithBackup(settingsPath, settingsData)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Pi) Models() []string {
|
||||||
|
home, err := os.UserHomeDir()
|
||||||
|
if err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
configPath := filepath.Join(home, ".pi", "agent", "models.json")
|
||||||
|
config, err := fileutil.ReadJSON(configPath)
|
||||||
|
if err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
providers, _ := config["providers"].(map[string]any)
|
||||||
|
ollama, _ := providers["ollama"].(map[string]any)
|
||||||
|
models, _ := ollama["models"].([]any)
|
||||||
|
|
||||||
|
var result []string
|
||||||
|
for _, m := range models {
|
||||||
|
if modelObj, ok := m.(map[string]any); ok {
|
||||||
|
if id, ok := modelObj["id"].(string); ok {
|
||||||
|
result = append(result, id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
slices.Sort(result)
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
// isPiOllamaModel reports whether a model config entry is managed by ollama
// launch, i.e. whether it carries a boolean "_launch": true marker.
func isPiOllamaModel(cfg map[string]any) bool {
	managed, ok := cfg["_launch"].(bool)
	return ok && managed
}
|
||||||
|
|
||||||
|
// hasContextWindow reports whether cfg has a positive "contextWindow" value.
// JSON-decoded configs store numbers as float64, while entries built in-process
// may hold int or int64, so all three numeric forms are accepted.
func hasContextWindow(cfg map[string]any) bool {
	raw, ok := cfg["contextWindow"]
	if !ok {
		return false
	}
	if f, ok := raw.(float64); ok {
		return f > 0
	}
	if n, ok := raw.(int); ok {
		return n > 0
	}
	if n, ok := raw.(int64); ok {
		return n > 0
	}
	return false
}
|
||||||
|
|
||||||
|
// createConfig builds Pi model config with capability detection
|
||||||
|
func createConfig(ctx context.Context, client *api.Client, modelID string) map[string]any {
|
||||||
|
cfg := map[string]any{
|
||||||
|
"id": modelID,
|
||||||
|
"_launch": true,
|
||||||
|
}
|
||||||
|
if l, ok := lookupCloudModelLimit(modelID); ok {
|
||||||
|
cfg["contextWindow"] = l.Context
|
||||||
|
}
|
||||||
|
|
||||||
|
applyCloudContextFallback := func() {
|
||||||
|
if l, ok := lookupCloudModelLimit(modelID); ok {
|
||||||
|
cfg["contextWindow"] = l.Context
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := client.Show(ctx, &api.ShowRequest{Model: modelID})
|
||||||
|
if err != nil {
|
||||||
|
applyCloudContextFallback()
|
||||||
|
return cfg
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set input types based on vision capability
|
||||||
|
if slices.Contains(resp.Capabilities, model.CapabilityVision) {
|
||||||
|
cfg["input"] = []string{"text", "image"}
|
||||||
|
} else {
|
||||||
|
cfg["input"] = []string{"text"}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set reasoning based on thinking capability
|
||||||
|
if slices.Contains(resp.Capabilities, model.CapabilityThinking) {
|
||||||
|
cfg["reasoning"] = true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract context window from ModelInfo. For known cloud models, the
|
||||||
|
// pre-filled shared limit remains unless the server provides a positive value.
|
||||||
|
hasContextWindow := false
|
||||||
|
for key, val := range resp.ModelInfo {
|
||||||
|
if strings.HasSuffix(key, ".context_length") {
|
||||||
|
if ctxLen, ok := val.(float64); ok && ctxLen > 0 {
|
||||||
|
cfg["contextWindow"] = int(ctxLen)
|
||||||
|
hasContextWindow = true
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !hasContextWindow {
|
||||||
|
applyCloudContextFallback()
|
||||||
|
}
|
||||||
|
|
||||||
|
return cfg
|
||||||
|
}
|
||||||
925
cmd/launch/pi_test.go
Normal file
925
cmd/launch/pi_test.go
Normal file
@@ -0,0 +1,925 @@
|
|||||||
|
package launch
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"net/url"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/api"
|
||||||
|
"github.com/ollama/ollama/types/model"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestPiIntegration checks the basic identity of the Pi integration: its
// display name and that it satisfies the Runner and Editor interfaces.
func TestPiIntegration(t *testing.T) {
	pi := &Pi{}

	t.Run("String", func(t *testing.T) {
		if got := pi.String(); got != "Pi" {
			t.Errorf("String() = %q, want %q", got, "Pi")
		}
	})

	// The interface checks below are compile-time assertions; the subtests
	// exist only so the coverage is visible in test output.
	t.Run("implements Runner", func(t *testing.T) {
		var _ Runner = pi
	})

	t.Run("implements Editor", func(t *testing.T) {
		var _ Editor = pi
	})
}
|
||||||
|
|
||||||
|
// TestPiPaths verifies that Paths() reports the models.json config file only
// when it actually exists under the (test-scoped) home directory.
func TestPiPaths(t *testing.T) {
	pi := &Pi{}

	t.Run("returns empty when no config exists", func(t *testing.T) {
		tmpDir := t.TempDir()
		setTestHome(t, tmpDir)

		paths := pi.Paths()
		if len(paths) != 0 {
			t.Errorf("Paths() = %v, want empty", paths)
		}
	})

	t.Run("returns path when config exists", func(t *testing.T) {
		tmpDir := t.TempDir()
		setTestHome(t, tmpDir)

		// Create an empty-but-valid config at the expected location.
		configDir := filepath.Join(tmpDir, ".pi", "agent")
		if err := os.MkdirAll(configDir, 0o755); err != nil {
			t.Fatal(err)
		}
		configPath := filepath.Join(configDir, "models.json")
		if err := os.WriteFile(configPath, []byte("{}"), 0o644); err != nil {
			t.Fatal(err)
		}

		paths := pi.Paths()
		if len(paths) != 1 || paths[0] != configPath {
			t.Errorf("Paths() = %v, want [%s]", paths, configPath)
		}
	})
}
|
||||||
|
|
||||||
|
// TestPiEdit exercises (*Pi).Edit end-to-end against a mock Ollama server:
// creating and updating models.json, preserving user-managed entries and
// custom provider settings, recovering from corrupt files, and keeping
// settings.json's default provider/model in sync. Subtests share one temp
// home and call cleanup() to reset the config directory between cases.
func TestPiEdit(t *testing.T) {
	// Mock Ollama server for createConfig calls during Edit
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.URL.Path == "/api/show" {
			fmt.Fprintf(w, `{"capabilities":[],"model_info":{}}`)
			return
		}
		w.WriteHeader(http.StatusNotFound)
	}))
	defer srv.Close()
	t.Setenv("OLLAMA_HOST", srv.URL)

	pi := &Pi{}
	tmpDir := t.TempDir()
	setTestHome(t, tmpDir)

	configDir := filepath.Join(tmpDir, ".pi", "agent")
	configPath := filepath.Join(configDir, "models.json")

	// cleanup removes the whole config dir so each subtest starts fresh.
	cleanup := func() {
		os.RemoveAll(configDir)
	}

	// readConfig parses models.json, ignoring read/parse errors (subtests
	// assert on the resulting structure instead).
	readConfig := func() map[string]any {
		data, _ := os.ReadFile(configPath)
		var cfg map[string]any
		json.Unmarshal(data, &cfg)
		return cfg
	}

	t.Run("returns nil for empty models", func(t *testing.T) {
		if err := pi.Edit([]string{}); err != nil {
			t.Errorf("Edit([]) error = %v, want nil", err)
		}
	})

	t.Run("creates config with models", func(t *testing.T) {
		cleanup()

		models := []string{"llama3.2", "qwen3:8b"}
		if err := pi.Edit(models); err != nil {
			t.Fatalf("Edit() error = %v", err)
		}

		cfg := readConfig()

		providers, ok := cfg["providers"].(map[string]any)
		if !ok {
			t.Error("Config missing providers")
		}

		ollama, ok := providers["ollama"].(map[string]any)
		if !ok {
			t.Error("Providers missing ollama")
		}

		modelsArray, ok := ollama["models"].([]any)
		if !ok || len(modelsArray) != 2 {
			t.Errorf("Expected 2 models, got %v", modelsArray)
		}

		// Freshly-created provider entries get the standard connection fields.
		if ollama["baseUrl"] == nil {
			t.Error("Missing baseUrl")
		}
		if ollama["api"] != "openai-completions" {
			t.Errorf("Expected api=openai-completions, got %v", ollama["api"])
		}
		if ollama["apiKey"] != "ollama" {
			t.Errorf("Expected apiKey=ollama, got %v", ollama["apiKey"])
		}
	})

	t.Run("updates existing config preserving ollama provider settings", func(t *testing.T) {
		cleanup()
		os.MkdirAll(configDir, 0o755)

		existingConfig := `{
			"providers": {
				"ollama": {
					"baseUrl": "http://custom:8080/v1",
					"api": "custom-api",
					"apiKey": "custom-key",
					"models": [
						{"id": "old-model", "_launch": true}
					]
				}
			}
		}`
		if err := os.WriteFile(configPath, []byte(existingConfig), 0o644); err != nil {
			t.Fatal(err)
		}

		models := []string{"new-model"}
		if err := pi.Edit(models); err != nil {
			t.Fatalf("Edit() error = %v", err)
		}

		cfg := readConfig()
		providers := cfg["providers"].(map[string]any)
		ollama := providers["ollama"].(map[string]any)

		// User-customized connection settings must survive the edit.
		if ollama["baseUrl"] != "http://custom:8080/v1" {
			t.Errorf("Custom baseUrl not preserved, got %v", ollama["baseUrl"])
		}
		if ollama["api"] != "custom-api" {
			t.Errorf("Custom api not preserved, got %v", ollama["api"])
		}
		if ollama["apiKey"] != "custom-key" {
			t.Errorf("Custom apiKey not preserved, got %v", ollama["apiKey"])
		}

		modelsArray := ollama["models"].([]any)
		if len(modelsArray) != 1 {
			t.Errorf("Expected 1 model after update, got %d", len(modelsArray))
		} else {
			modelEntry := modelsArray[0].(map[string]any)
			if modelEntry["id"] != "new-model" {
				t.Errorf("Expected new-model, got %v", modelEntry["id"])
			}
			// Verify _launch marker is present
			if modelEntry["_launch"] != true {
				t.Errorf("Expected _launch marker to be true")
			}
		}
	})

	t.Run("rebuilds stale existing managed cloud model", func(t *testing.T) {
		cleanup()
		os.MkdirAll(configDir, 0o755)

		existingConfig := `{
			"providers": {
				"ollama": {
					"baseUrl": "http://localhost:11434/v1",
					"api": "openai-completions",
					"apiKey": "ollama",
					"models": [
						{"id": "glm-5:cloud", "_launch": true, "legacyField": "stale"}
					]
				}
			}
		}`
		if err := os.WriteFile(configPath, []byte(existingConfig), 0o644); err != nil {
			t.Fatal(err)
		}

		if err := pi.Edit([]string{"glm-5:cloud"}); err != nil {
			t.Fatalf("Edit() error = %v", err)
		}

		cfg := readConfig()
		providers := cfg["providers"].(map[string]any)
		ollama := providers["ollama"].(map[string]any)
		modelsArray := ollama["models"].([]any)
		modelEntry := modelsArray[0].(map[string]any)

		// The managed entry is rebuilt from scratch: cloud context limit is
		// applied and leftover fields from older versions are dropped.
		if modelEntry["contextWindow"] != float64(202_752) {
			t.Errorf("contextWindow = %v, want 202752", modelEntry["contextWindow"])
		}
		input, ok := modelEntry["input"].([]any)
		if !ok || len(input) != 1 || input[0] != "text" {
			t.Errorf("input = %v, want [text]", modelEntry["input"])
		}
		if _, ok := modelEntry["legacyField"]; ok {
			t.Error("legacyField should be removed when stale managed cloud entry is rebuilt")
		}
	})

	t.Run("replaces old models with new ones", func(t *testing.T) {
		cleanup()
		os.MkdirAll(configDir, 0o755)

		// Old models must have _launch marker to be managed by us
		existingConfig := `{
			"providers": {
				"ollama": {
					"baseUrl": "http://localhost:11434/v1",
					"api": "openai-completions",
					"apiKey": "ollama",
					"models": [
						{"id": "old-model-1", "_launch": true},
						{"id": "old-model-2", "_launch": true}
					]
				}
			}
		}`
		if err := os.WriteFile(configPath, []byte(existingConfig), 0o644); err != nil {
			t.Fatal(err)
		}

		newModels := []string{"new-model-1", "new-model-2"}
		if err := pi.Edit(newModels); err != nil {
			t.Fatalf("Edit() error = %v", err)
		}

		cfg := readConfig()
		providers := cfg["providers"].(map[string]any)
		ollama := providers["ollama"].(map[string]any)
		modelsArray := ollama["models"].([]any)

		if len(modelsArray) != 2 {
			t.Errorf("Expected 2 models, got %d", len(modelsArray))
		}

		modelIDs := make(map[string]bool)
		for _, m := range modelsArray {
			modelObj := m.(map[string]any)
			id := modelObj["id"].(string)
			modelIDs[id] = true
		}

		if !modelIDs["new-model-1"] || !modelIDs["new-model-2"] {
			t.Errorf("Expected new models, got %v", modelIDs)
		}
		if modelIDs["old-model-1"] || modelIDs["old-model-2"] {
			t.Errorf("Old models should have been removed, got %v", modelIDs)
		}
	})

	t.Run("handles partial overlap in model list", func(t *testing.T) {
		cleanup()
		os.MkdirAll(configDir, 0o755)

		// Models must have _launch marker to be managed
		existingConfig := `{
			"providers": {
				"ollama": {
					"baseUrl": "http://localhost:11434/v1",
					"api": "openai-completions",
					"apiKey": "ollama",
					"models": [
						{"id": "keep-model", "_launch": true},
						{"id": "remove-model", "_launch": true}
					]
				}
			}
		}`
		if err := os.WriteFile(configPath, []byte(existingConfig), 0o644); err != nil {
			t.Fatal(err)
		}

		newModels := []string{"keep-model", "add-model"}
		if err := pi.Edit(newModels); err != nil {
			t.Fatalf("Edit() error = %v", err)
		}

		cfg := readConfig()
		providers := cfg["providers"].(map[string]any)
		ollama := providers["ollama"].(map[string]any)
		modelsArray := ollama["models"].([]any)

		if len(modelsArray) != 2 {
			t.Errorf("Expected 2 models, got %d", len(modelsArray))
		}

		modelIDs := make(map[string]bool)
		for _, m := range modelsArray {
			modelObj := m.(map[string]any)
			id := modelObj["id"].(string)
			modelIDs[id] = true
		}

		if !modelIDs["keep-model"] || !modelIDs["add-model"] {
			t.Errorf("Expected keep-model and add-model, got %v", modelIDs)
		}
		if modelIDs["remove-model"] {
			t.Errorf("remove-model should have been removed")
		}
	})

	t.Run("handles corrupt config gracefully", func(t *testing.T) {
		cleanup()
		os.MkdirAll(configDir, 0o755)

		if err := os.WriteFile(configPath, []byte("{invalid json}"), 0o644); err != nil {
			t.Fatal(err)
		}

		models := []string{"test-model"}
		if err := pi.Edit(models); err != nil {
			t.Fatalf("Edit() should not fail with corrupt config, got %v", err)
		}

		data, err := os.ReadFile(configPath)
		if err != nil {
			t.Fatalf("Failed to read config: %v", err)
		}

		// Edit must leave behind valid JSON even when it started from garbage.
		var cfg map[string]any
		if err := json.Unmarshal(data, &cfg); err != nil {
			t.Fatalf("Config should be valid after Edit, got parse error: %v", err)
		}

		providers := cfg["providers"].(map[string]any)
		ollama := providers["ollama"].(map[string]any)
		modelsArray := ollama["models"].([]any)

		if len(modelsArray) != 1 {
			t.Errorf("Expected 1 model, got %d", len(modelsArray))
		}
	})

	// CRITICAL SAFETY TEST: verifies we don't stomp on user configs
	t.Run("preserves user-managed models without _launch marker", func(t *testing.T) {
		cleanup()
		os.MkdirAll(configDir, 0o755)

		// User has manually configured models in ollama provider (no _launch marker)
		existingConfig := `{
			"providers": {
				"ollama": {
					"baseUrl": "http://localhost:11434/v1",
					"api": "openai-completions",
					"apiKey": "ollama",
					"models": [
						{"id": "user-model-1"},
						{"id": "user-model-2", "customField": "preserved"},
						{"id": "ollama-managed", "_launch": true}
					]
				}
			}
		}`
		if err := os.WriteFile(configPath, []byte(existingConfig), 0o644); err != nil {
			t.Fatal(err)
		}

		// Add a new ollama-managed model
		newModels := []string{"new-ollama-model"}
		if err := pi.Edit(newModels); err != nil {
			t.Fatalf("Edit() error = %v", err)
		}

		cfg := readConfig()
		providers := cfg["providers"].(map[string]any)
		ollama := providers["ollama"].(map[string]any)
		modelsArray := ollama["models"].([]any)

		// Should have: new-ollama-model (managed) + 2 user models (preserved)
		if len(modelsArray) != 3 {
			t.Errorf("Expected 3 models (1 new managed + 2 preserved user models), got %d", len(modelsArray))
		}

		modelIDs := make(map[string]map[string]any)
		for _, m := range modelsArray {
			modelObj := m.(map[string]any)
			id := modelObj["id"].(string)
			modelIDs[id] = modelObj
		}

		// Verify new model has _launch marker
		if m, ok := modelIDs["new-ollama-model"]; !ok {
			t.Errorf("new-ollama-model should be present")
		} else if m["_launch"] != true {
			t.Errorf("new-ollama-model should have _launch marker")
		}

		// Verify user models are preserved
		if _, ok := modelIDs["user-model-1"]; !ok {
			t.Errorf("user-model-1 should be preserved")
		}
		if _, ok := modelIDs["user-model-2"]; !ok {
			t.Errorf("user-model-2 should be preserved")
		} else if modelIDs["user-model-2"]["customField"] != "preserved" {
			t.Errorf("user-model-2 customField should be preserved")
		}

		// Verify old ollama-managed model is removed (not in new list)
		if _, ok := modelIDs["ollama-managed"]; ok {
			t.Errorf("ollama-managed should be removed (old ollama model not in new selection)")
		}
	})

	t.Run("updates settings.json with default provider and model", func(t *testing.T) {
		cleanup()
		os.MkdirAll(configDir, 0o755)

		// Create existing settings with other fields
		settingsPath := filepath.Join(configDir, "settings.json")
		existingSettings := `{
			"theme": "dark",
			"customSetting": "value",
			"defaultProvider": "anthropic",
			"defaultModel": "claude-3"
		}`
		if err := os.WriteFile(settingsPath, []byte(existingSettings), 0o644); err != nil {
			t.Fatal(err)
		}

		models := []string{"llama3.2"}
		if err := pi.Edit(models); err != nil {
			t.Fatalf("Edit() error = %v", err)
		}

		data, err := os.ReadFile(settingsPath)
		if err != nil {
			t.Fatalf("Failed to read settings: %v", err)
		}

		var settings map[string]any
		if err := json.Unmarshal(data, &settings); err != nil {
			t.Fatalf("Failed to parse settings: %v", err)
		}

		// Verify defaultProvider is set to ollama
		if settings["defaultProvider"] != "ollama" {
			t.Errorf("defaultProvider = %v, want ollama", settings["defaultProvider"])
		}

		// Verify defaultModel is set to first model
		if settings["defaultModel"] != "llama3.2" {
			t.Errorf("defaultModel = %v, want llama3.2", settings["defaultModel"])
		}

		// Verify other fields are preserved
		if settings["theme"] != "dark" {
			t.Errorf("theme = %v, want dark (preserved)", settings["theme"])
		}
		if settings["customSetting"] != "value" {
			t.Errorf("customSetting = %v, want value (preserved)", settings["customSetting"])
		}
	})

	t.Run("creates settings.json if it does not exist", func(t *testing.T) {
		cleanup()
		os.MkdirAll(configDir, 0o755)

		models := []string{"qwen3:8b"}
		if err := pi.Edit(models); err != nil {
			t.Fatalf("Edit() error = %v", err)
		}

		settingsPath := filepath.Join(configDir, "settings.json")
		data, err := os.ReadFile(settingsPath)
		if err != nil {
			t.Fatalf("settings.json should be created: %v", err)
		}

		var settings map[string]any
		if err := json.Unmarshal(data, &settings); err != nil {
			t.Fatalf("Failed to parse settings: %v", err)
		}

		if settings["defaultProvider"] != "ollama" {
			t.Errorf("defaultProvider = %v, want ollama", settings["defaultProvider"])
		}
		if settings["defaultModel"] != "qwen3:8b" {
			t.Errorf("defaultModel = %v, want qwen3:8b", settings["defaultModel"])
		}
	})

	t.Run("handles corrupt settings.json gracefully", func(t *testing.T) {
		cleanup()
		os.MkdirAll(configDir, 0o755)

		// Create corrupt settings
		settingsPath := filepath.Join(configDir, "settings.json")
		if err := os.WriteFile(settingsPath, []byte("{invalid"), 0o644); err != nil {
			t.Fatal(err)
		}

		models := []string{"test-model"}
		if err := pi.Edit(models); err != nil {
			t.Fatalf("Edit() should not fail with corrupt settings, got %v", err)
		}

		data, err := os.ReadFile(settingsPath)
		if err != nil {
			t.Fatalf("Failed to read settings: %v", err)
		}

		var settings map[string]any
		if err := json.Unmarshal(data, &settings); err != nil {
			t.Fatalf("settings.json should be valid after Edit, got parse error: %v", err)
		}

		if settings["defaultProvider"] != "ollama" {
			t.Errorf("defaultProvider = %v, want ollama", settings["defaultProvider"])
		}
		if settings["defaultModel"] != "test-model" {
			t.Errorf("defaultModel = %v, want test-model", settings["defaultModel"])
		}
	})
}
|
||||||
|
|
||||||
|
// TestPiModels verifies (*Pi).Models against a test-scoped home directory:
// nil for missing/corrupt config, sorted IDs for valid config, and nil when
// the models array is absent.
func TestPiModels(t *testing.T) {
	pi := &Pi{}

	t.Run("returns nil when no config exists", func(t *testing.T) {
		tmpDir := t.TempDir()
		setTestHome(t, tmpDir)

		models := pi.Models()
		if models != nil {
			t.Errorf("Models() = %v, want nil", models)
		}
	})

	t.Run("returns models from config", func(t *testing.T) {
		tmpDir := t.TempDir()
		setTestHome(t, tmpDir)

		configDir := filepath.Join(tmpDir, ".pi", "agent")
		if err := os.MkdirAll(configDir, 0o755); err != nil {
			t.Fatal(err)
		}
		config := `{
			"providers": {
				"ollama": {
					"models": [
						{"id": "llama3.2"},
						{"id": "qwen3:8b"}
					]
				}
			}
		}`
		configPath := filepath.Join(configDir, "models.json")
		if err := os.WriteFile(configPath, []byte(config), 0o644); err != nil {
			t.Fatal(err)
		}

		models := pi.Models()
		if len(models) != 2 {
			t.Errorf("Models() returned %d models, want 2", len(models))
		}
		if models[0] != "llama3.2" || models[1] != "qwen3:8b" {
			t.Errorf("Models() = %v, want [llama3.2 qwen3:8b] (sorted)", models)
		}
	})

	t.Run("returns sorted models", func(t *testing.T) {
		tmpDir := t.TempDir()
		setTestHome(t, tmpDir)

		configDir := filepath.Join(tmpDir, ".pi", "agent")
		if err := os.MkdirAll(configDir, 0o755); err != nil {
			t.Fatal(err)
		}
		// IDs are deliberately out of order in the file to exercise sorting.
		config := `{
			"providers": {
				"ollama": {
					"models": [
						{"id": "z-model"},
						{"id": "a-model"},
						{"id": "m-model"}
					]
				}
			}
		}`
		configPath := filepath.Join(configDir, "models.json")
		if err := os.WriteFile(configPath, []byte(config), 0o644); err != nil {
			t.Fatal(err)
		}

		models := pi.Models()
		if models[0] != "a-model" || models[1] != "m-model" || models[2] != "z-model" {
			t.Errorf("Models() = %v, want [a-model m-model z-model] (sorted)", models)
		}
	})

	t.Run("returns nil when models array is missing", func(t *testing.T) {
		tmpDir := t.TempDir()
		setTestHome(t, tmpDir)

		configDir := filepath.Join(tmpDir, ".pi", "agent")
		if err := os.MkdirAll(configDir, 0o755); err != nil {
			t.Fatal(err)
		}
		config := `{
			"providers": {
				"ollama": {}
			}
		}`
		configPath := filepath.Join(configDir, "models.json")
		if err := os.WriteFile(configPath, []byte(config), 0o644); err != nil {
			t.Fatal(err)
		}

		models := pi.Models()
		if models != nil {
			t.Errorf("Models() = %v, want nil when models array is missing", models)
		}
	})

	t.Run("handles corrupt config gracefully", func(t *testing.T) {
		tmpDir := t.TempDir()
		setTestHome(t, tmpDir)

		configDir := filepath.Join(tmpDir, ".pi", "agent")
		if err := os.MkdirAll(configDir, 0o755); err != nil {
			t.Fatal(err)
		}
		configPath := filepath.Join(configDir, "models.json")
		if err := os.WriteFile(configPath, []byte("{invalid json}"), 0o644); err != nil {
			t.Fatal(err)
		}

		models := pi.Models()
		if models != nil {
			t.Errorf("Models() = %v, want nil for corrupt config", models)
		}
	})
}
|
||||||
|
|
||||||
|
// TestIsPiOllamaModel is a table-driven test covering every shape the
// "_launch" marker can take: true, false, absent, wrong type, and empty map.
func TestIsPiOllamaModel(t *testing.T) {
	tests := []struct {
		name string
		cfg  map[string]any
		want bool
	}{
		{"with _launch true", map[string]any{"id": "m", "_launch": true}, true},
		{"with _launch false", map[string]any{"id": "m", "_launch": false}, false},
		{"without _launch", map[string]any{"id": "m"}, false},
		{"with _launch non-bool", map[string]any{"id": "m", "_launch": "yes"}, false},
		{"empty map", map[string]any{}, false},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			if got := isPiOllamaModel(tt.cfg); got != tt.want {
				t.Errorf("isPiOllamaModel(%v) = %v, want %v", tt.cfg, got, tt.want)
			}
		})
	}
}
|
||||||
|
|
||||||
|
func TestCreateConfig(t *testing.T) {
|
||||||
|
t.Run("sets vision input when model has vision capability", func(t *testing.T) {
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.URL.Path == "/api/show" {
|
||||||
|
fmt.Fprintf(w, `{"capabilities":["vision"],"model_info":{}}`)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
w.WriteHeader(http.StatusNotFound)
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
u, _ := url.Parse(srv.URL)
|
||||||
|
client := api.NewClient(u, srv.Client())
|
||||||
|
|
||||||
|
cfg := createConfig(context.Background(), client, "llava:7b")
|
||||||
|
|
||||||
|
if cfg["id"] != "llava:7b" {
|
||||||
|
t.Errorf("id = %v, want llava:7b", cfg["id"])
|
||||||
|
}
|
||||||
|
if cfg["_launch"] != true {
|
||||||
|
t.Error("expected _launch = true")
|
||||||
|
}
|
||||||
|
input, ok := cfg["input"].([]string)
|
||||||
|
if !ok || len(input) != 2 || input[0] != "text" || input[1] != "image" {
|
||||||
|
t.Errorf("input = %v, want [text image]", cfg["input"])
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("sets text-only input when model lacks vision", func(t *testing.T) {
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.URL.Path == "/api/show" {
|
||||||
|
fmt.Fprintf(w, `{"capabilities":["completion"],"model_info":{}}`)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
w.WriteHeader(http.StatusNotFound)
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
u, _ := url.Parse(srv.URL)
|
||||||
|
client := api.NewClient(u, srv.Client())
|
||||||
|
|
||||||
|
cfg := createConfig(context.Background(), client, "llama3.2")
|
||||||
|
|
||||||
|
input, ok := cfg["input"].([]string)
|
||||||
|
if !ok || len(input) != 1 || input[0] != "text" {
|
||||||
|
t.Errorf("input = %v, want [text]", cfg["input"])
|
||||||
|
}
|
||||||
|
if _, ok := cfg["reasoning"]; ok {
|
||||||
|
t.Error("reasoning should not be set for non-thinking model")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("sets reasoning when model has thinking capability", func(t *testing.T) {
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.URL.Path == "/api/show" {
|
||||||
|
fmt.Fprintf(w, `{"capabilities":["thinking"],"model_info":{}}`)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
w.WriteHeader(http.StatusNotFound)
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
u, _ := url.Parse(srv.URL)
|
||||||
|
client := api.NewClient(u, srv.Client())
|
||||||
|
|
||||||
|
cfg := createConfig(context.Background(), client, "qwq")
|
||||||
|
|
||||||
|
if cfg["reasoning"] != true {
|
||||||
|
t.Error("expected reasoning = true for thinking model")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("extracts context window from model info", func(t *testing.T) {
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.URL.Path == "/api/show" {
|
||||||
|
fmt.Fprintf(w, `{"capabilities":[],"model_info":{"llama.context_length":131072}}`)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
w.WriteHeader(http.StatusNotFound)
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
u, _ := url.Parse(srv.URL)
|
||||||
|
client := api.NewClient(u, srv.Client())
|
||||||
|
|
||||||
|
cfg := createConfig(context.Background(), client, "llama3.2")
|
||||||
|
|
||||||
|
if cfg["contextWindow"] != 131072 {
|
||||||
|
t.Errorf("contextWindow = %v, want 131072", cfg["contextWindow"])
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("handles all capabilities together", func(t *testing.T) {
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.URL.Path == "/api/show" {
|
||||||
|
fmt.Fprintf(w, `{"capabilities":["vision","thinking"],"model_info":{"qwen3.context_length":32768}}`)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
w.WriteHeader(http.StatusNotFound)
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
u, _ := url.Parse(srv.URL)
|
||||||
|
client := api.NewClient(u, srv.Client())
|
||||||
|
|
||||||
|
cfg := createConfig(context.Background(), client, "qwen3-vision")
|
||||||
|
|
||||||
|
input := cfg["input"].([]string)
|
||||||
|
if len(input) != 2 || input[0] != "text" || input[1] != "image" {
|
||||||
|
t.Errorf("input = %v, want [text image]", input)
|
||||||
|
}
|
||||||
|
if cfg["reasoning"] != true {
|
||||||
|
t.Error("expected reasoning = true")
|
||||||
|
}
|
||||||
|
if cfg["contextWindow"] != 32768 {
|
||||||
|
t.Errorf("contextWindow = %v, want 32768", cfg["contextWindow"])
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("returns minimal config when show fails", func(t *testing.T) {
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusNotFound)
|
||||||
|
fmt.Fprintf(w, `{"error":"model not found"}`)
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
u, _ := url.Parse(srv.URL)
|
||||||
|
client := api.NewClient(u, srv.Client())
|
||||||
|
|
||||||
|
cfg := createConfig(context.Background(), client, "missing-model")
|
||||||
|
|
||||||
|
if cfg["id"] != "missing-model" {
|
||||||
|
t.Errorf("id = %v, want missing-model", cfg["id"])
|
||||||
|
}
|
||||||
|
if cfg["_launch"] != true {
|
||||||
|
t.Error("expected _launch = true")
|
||||||
|
}
|
||||||
|
// Should not have capability fields
|
||||||
|
if _, ok := cfg["input"]; ok {
|
||||||
|
t.Error("input should not be set when show fails")
|
||||||
|
}
|
||||||
|
if _, ok := cfg["reasoning"]; ok {
|
||||||
|
t.Error("reasoning should not be set when show fails")
|
||||||
|
}
|
||||||
|
if _, ok := cfg["contextWindow"]; ok {
|
||||||
|
t.Error("contextWindow should not be set when show fails")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("cloud model falls back to hardcoded context when show fails", func(t *testing.T) {
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusNotFound)
|
||||||
|
fmt.Fprintf(w, `{"error":"model not found"}`)
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
u, _ := url.Parse(srv.URL)
|
||||||
|
client := api.NewClient(u, srv.Client())
|
||||||
|
|
||||||
|
cfg := createConfig(context.Background(), client, "kimi-k2.5:cloud")
|
||||||
|
|
||||||
|
if cfg["contextWindow"] != 262_144 {
|
||||||
|
t.Errorf("contextWindow = %v, want 262144", cfg["contextWindow"])
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("cloud model falls back to hardcoded context when show omits model info", func(t *testing.T) {
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.URL.Path == "/api/show" {
|
||||||
|
fmt.Fprintf(w, `{"capabilities":[],"model_info":{}}`)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
w.WriteHeader(http.StatusNotFound)
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
u, _ := url.Parse(srv.URL)
|
||||||
|
client := api.NewClient(u, srv.Client())
|
||||||
|
|
||||||
|
cfg := createConfig(context.Background(), client, "glm-5:cloud")
|
||||||
|
|
||||||
|
if cfg["contextWindow"] != 202_752 {
|
||||||
|
t.Errorf("contextWindow = %v, want 202752", cfg["contextWindow"])
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("cloud model with dash suffix falls back to hardcoded context", func(t *testing.T) {
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusNotFound)
|
||||||
|
fmt.Fprintf(w, `{"error":"model not found"}`)
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
u, _ := url.Parse(srv.URL)
|
||||||
|
client := api.NewClient(u, srv.Client())
|
||||||
|
|
||||||
|
cfg := createConfig(context.Background(), client, "gpt-oss:120b-cloud")
|
||||||
|
|
||||||
|
if cfg["contextWindow"] != 131_072 {
|
||||||
|
t.Errorf("contextWindow = %v, want 131072", cfg["contextWindow"])
|
||||||
|
}
|
||||||
|
})
|
||||||
|
t.Run("skips zero context length", func(t *testing.T) {
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.URL.Path == "/api/show" {
|
||||||
|
fmt.Fprintf(w, `{"capabilities":[],"model_info":{"llama.context_length":0}}`)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
w.WriteHeader(http.StatusNotFound)
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
u, _ := url.Parse(srv.URL)
|
||||||
|
client := api.NewClient(u, srv.Client())
|
||||||
|
|
||||||
|
cfg := createConfig(context.Background(), client, "test-model")
|
||||||
|
|
||||||
|
if _, ok := cfg["contextWindow"]; ok {
|
||||||
|
t.Error("contextWindow should not be set for zero value")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure Capability constants used in createConfig match expected values
|
||||||
|
func TestPiCapabilityConstants(t *testing.T) {
|
||||||
|
if model.CapabilityVision != "vision" {
|
||||||
|
t.Errorf("CapabilityVision = %q, want %q", model.CapabilityVision, "vision")
|
||||||
|
}
|
||||||
|
if model.CapabilityThinking != "thinking" {
|
||||||
|
t.Errorf("CapabilityThinking = %q, want %q", model.CapabilityThinking, "thinking")
|
||||||
|
}
|
||||||
|
}
|
||||||
355
cmd/launch/registry.go
Normal file
355
cmd/launch/registry.go
Normal file
@@ -0,0 +1,355 @@
|
|||||||
|
package launch
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"os/exec"
|
||||||
|
"slices"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// IntegrationInstallSpec describes how launcher should detect and guide installation.
type IntegrationInstallSpec struct {
	// CheckInstalled reports whether the integration is present.
	// A nil check means the integration is treated as always installed.
	CheckInstalled func() bool
	// EnsureInstalled performs an automatic install when the integration is
	// missing; nil means the integration is not auto-installable.
	EnsureInstalled func() error
	// URL is a documentation page offered to the user as a manual install hint.
	URL string
	// Command is an install command (argv form) shown to the user as a hint.
	Command []string
}
|
||||||
|
|
||||||
|
// IntegrationSpec is the canonical registry entry for one integration.
type IntegrationSpec struct {
	Name        string   // canonical name; lookups are case-insensitive
	Runner      Runner   // implementation that launches the integration
	Aliases     []string // alternate lookup names; must not collide with canonical names
	Hidden      bool     // hidden specs are excluded from interactive UIs and launcher order
	Description string   // short human-readable summary shown in selection UIs
	Install     IntegrationInstallSpec
}
|
||||||
|
|
||||||
|
// IntegrationInfo contains display information about a registered integration.
type IntegrationInfo struct {
	Name        string // canonical registry name
	DisplayName string // the runner's String() representation
	Description string // summary copied from the spec
}
|
||||||
|
|
||||||
|
// launcherIntegrationOrder lists canonical (non-hidden) integrations that get an
// explicit relative position when launcher UIs sort integrations; see
// ListVisibleIntegrationSpecs for how the ranking is applied.
var launcherIntegrationOrder = []string{"opencode", "droid", "pi", "cline"}
|
||||||
|
|
||||||
|
// integrationSpecs is the source-of-truth registry of supported integrations.
// Lookup indexes are derived from it by rebuildIntegrationSpecIndexes.
var integrationSpecs = []*IntegrationSpec{
	{
		Name:        "claude",
		Runner:      &Claude{},
		Description: "Anthropic's coding tool with subagents",
		Install: IntegrationInstallSpec{
			CheckInstalled: func() bool {
				// Claude resolves its binary via its own path search, not PATH alone.
				_, err := (&Claude{}).findPath()
				return err == nil
			},
			URL: "https://code.claude.com/docs/en/quickstart",
		},
	},
	{
		Name:        "cline",
		Runner:      &Cline{},
		Description: "Autonomous coding agent with parallel execution",
		Install: IntegrationInstallSpec{
			CheckInstalled: func() bool {
				_, err := exec.LookPath("cline")
				return err == nil
			},
			Command: []string{"npm", "install", "-g", "cline"},
		},
	},
	{
		Name:        "codex",
		Runner:      &Codex{},
		Description: "OpenAI's open-source coding agent",
		Install: IntegrationInstallSpec{
			CheckInstalled: func() bool {
				_, err := exec.LookPath("codex")
				return err == nil
			},
			URL:     "https://developers.openai.com/codex/cli/",
			Command: []string{"npm", "install", "-g", "@openai/codex"},
		},
	},
	{
		Name:        "droid",
		Runner:      &Droid{},
		Description: "Factory's coding agent across terminal and IDEs",
		Install: IntegrationInstallSpec{
			CheckInstalled: func() bool {
				_, err := exec.LookPath("droid")
				return err == nil
			},
			URL: "https://docs.factory.ai/cli/getting-started/quickstart",
		},
	},
	{
		Name:        "opencode",
		Runner:      &OpenCode{},
		Description: "Anomaly's open-source coding agent",
		Install: IntegrationInstallSpec{
			CheckInstalled: func() bool {
				_, err := exec.LookPath("opencode")
				return err == nil
			},
			URL: "https://opencode.ai",
		},
	},
	{
		Name:        "openclaw",
		Runner:      &Openclaw{},
		Aliases:     []string{"clawdbot", "moltbot"},
		Description: "Personal AI with 100+ skills",
		Install: IntegrationInstallSpec{
			CheckInstalled: func() bool {
				// Accept either the "openclaw" or "clawdbot" binary name.
				if _, err := exec.LookPath("openclaw"); err == nil {
					return true
				}
				if _, err := exec.LookPath("clawdbot"); err == nil {
					return true
				}
				return false
			},
			EnsureInstalled: func() error {
				_, err := ensureOpenclawInstalled()
				return err
			},
			URL: "https://docs.openclaw.ai",
		},
	},
	{
		Name:        "pi",
		Runner:      &Pi{},
		Description: "Minimal AI agent toolkit with plugin support",
		Install: IntegrationInstallSpec{
			CheckInstalled: func() bool {
				_, err := exec.LookPath("pi")
				return err == nil
			},
			Command: []string{"npm", "install", "-g", "@mariozechner/pi-coding-agent"},
		},
	},
}
|
||||||
|
|
||||||
|
// integrationSpecsByName indexes integrationSpecs by lowercased canonical name
// and alias; populated by rebuildIntegrationSpecIndexes (called from init).
var integrationSpecsByName map[string]*IntegrationSpec
|
||||||
|
|
||||||
|
// Build the name/alias lookup index at package load so the registry is
// validated (and panics on misconfiguration) before any lookup runs.
func init() {
	rebuildIntegrationSpecIndexes()
}
|
||||||
|
|
||||||
|
// hyperlink wraps text in an OSC 8 terminal escape sequence linking to url,
// for terminals that support clickable hyperlinks.
func hyperlink(url, text string) string {
	const open = "\033]8;;"  // OSC 8 prefix
	const st = "\033\\"      // string terminator
	return open + url + st + text + open + st
}
|
||||||
|
|
||||||
|
func rebuildIntegrationSpecIndexes() {
|
||||||
|
integrationSpecsByName = make(map[string]*IntegrationSpec, len(integrationSpecs))
|
||||||
|
|
||||||
|
canonical := make(map[string]bool, len(integrationSpecs))
|
||||||
|
for _, spec := range integrationSpecs {
|
||||||
|
key := strings.ToLower(spec.Name)
|
||||||
|
if key == "" {
|
||||||
|
panic("launch: integration spec missing name")
|
||||||
|
}
|
||||||
|
if canonical[key] {
|
||||||
|
panic(fmt.Sprintf("launch: duplicate integration name %q", key))
|
||||||
|
}
|
||||||
|
canonical[key] = true
|
||||||
|
integrationSpecsByName[key] = spec
|
||||||
|
}
|
||||||
|
|
||||||
|
seenAliases := make(map[string]string)
|
||||||
|
for _, spec := range integrationSpecs {
|
||||||
|
for _, alias := range spec.Aliases {
|
||||||
|
key := strings.ToLower(alias)
|
||||||
|
if key == "" {
|
||||||
|
panic(fmt.Sprintf("launch: integration %q has empty alias", spec.Name))
|
||||||
|
}
|
||||||
|
if canonical[key] {
|
||||||
|
panic(fmt.Sprintf("launch: alias %q collides with canonical integration name", key))
|
||||||
|
}
|
||||||
|
if owner, exists := seenAliases[key]; exists {
|
||||||
|
panic(fmt.Sprintf("launch: alias %q collides between %q and %q", key, owner, spec.Name))
|
||||||
|
}
|
||||||
|
seenAliases[key] = spec.Name
|
||||||
|
integrationSpecsByName[key] = spec
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
orderSeen := make(map[string]bool, len(launcherIntegrationOrder))
|
||||||
|
for _, name := range launcherIntegrationOrder {
|
||||||
|
key := strings.ToLower(name)
|
||||||
|
if orderSeen[key] {
|
||||||
|
panic(fmt.Sprintf("launch: duplicate launcher order entry %q", key))
|
||||||
|
}
|
||||||
|
orderSeen[key] = true
|
||||||
|
|
||||||
|
spec, ok := integrationSpecsByName[key]
|
||||||
|
if !ok {
|
||||||
|
panic(fmt.Sprintf("launch: unknown launcher order entry %q", key))
|
||||||
|
}
|
||||||
|
if spec.Name != key {
|
||||||
|
panic(fmt.Sprintf("launch: launcher order entry %q must use canonical name, not alias", key))
|
||||||
|
}
|
||||||
|
if spec.Hidden {
|
||||||
|
panic(fmt.Sprintf("launch: hidden integration %q cannot appear in launcher order", key))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// LookupIntegrationSpec resolves either a canonical integration name or alias to its spec.
|
||||||
|
func LookupIntegrationSpec(name string) (*IntegrationSpec, error) {
|
||||||
|
spec, ok := integrationSpecsByName[strings.ToLower(name)]
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("unknown integration: %s", name)
|
||||||
|
}
|
||||||
|
return spec, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// LookupIntegration resolves a registry name to the canonical key and runner.
|
||||||
|
func LookupIntegration(name string) (string, Runner, error) {
|
||||||
|
spec, err := LookupIntegrationSpec(name)
|
||||||
|
if err != nil {
|
||||||
|
return "", nil, err
|
||||||
|
}
|
||||||
|
return spec.Name, spec.Runner, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListVisibleIntegrationSpecs returns the canonical integrations that should appear in interactive UIs.
|
||||||
|
func ListVisibleIntegrationSpecs() []IntegrationSpec {
|
||||||
|
visible := make([]IntegrationSpec, 0, len(integrationSpecs))
|
||||||
|
for _, spec := range integrationSpecs {
|
||||||
|
if spec.Hidden {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
visible = append(visible, *spec)
|
||||||
|
}
|
||||||
|
|
||||||
|
orderRank := make(map[string]int, len(launcherIntegrationOrder))
|
||||||
|
for i, name := range launcherIntegrationOrder {
|
||||||
|
orderRank[name] = i + 1
|
||||||
|
}
|
||||||
|
|
||||||
|
slices.SortFunc(visible, func(a, b IntegrationSpec) int {
|
||||||
|
aRank, bRank := orderRank[a.Name], orderRank[b.Name]
|
||||||
|
if aRank > 0 && bRank > 0 {
|
||||||
|
return aRank - bRank
|
||||||
|
}
|
||||||
|
if aRank > 0 {
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
if bRank > 0 {
|
||||||
|
return -1
|
||||||
|
}
|
||||||
|
return strings.Compare(a.Name, b.Name)
|
||||||
|
})
|
||||||
|
|
||||||
|
return visible
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListIntegrationInfos returns the registered integrations in launcher display order.
|
||||||
|
func ListIntegrationInfos() []IntegrationInfo {
|
||||||
|
visible := ListVisibleIntegrationSpecs()
|
||||||
|
infos := make([]IntegrationInfo, 0, len(visible))
|
||||||
|
for _, spec := range visible {
|
||||||
|
infos = append(infos, IntegrationInfo{
|
||||||
|
Name: spec.Name,
|
||||||
|
DisplayName: spec.Runner.String(),
|
||||||
|
Description: spec.Description,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return infos
|
||||||
|
}
|
||||||
|
|
||||||
|
// IntegrationSelectionItems returns the sorted integration items shown by launcher selection UIs.
|
||||||
|
func IntegrationSelectionItems() ([]ModelItem, error) {
|
||||||
|
visible := ListVisibleIntegrationSpecs()
|
||||||
|
if len(visible) == 0 {
|
||||||
|
return nil, fmt.Errorf("no integrations available")
|
||||||
|
}
|
||||||
|
|
||||||
|
items := make([]ModelItem, 0, len(visible))
|
||||||
|
for _, spec := range visible {
|
||||||
|
description := spec.Runner.String()
|
||||||
|
if conn, err := loadStoredIntegrationConfig(spec.Name); err == nil && len(conn.Models) > 0 {
|
||||||
|
description = fmt.Sprintf("%s (%s)", spec.Runner.String(), conn.Models[0])
|
||||||
|
}
|
||||||
|
items = append(items, ModelItem{Name: spec.Name, Description: description})
|
||||||
|
}
|
||||||
|
return items, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsIntegrationInstalled checks if an integration binary is installed.
|
||||||
|
func IsIntegrationInstalled(name string) bool {
|
||||||
|
integration, err := integrationFor(name)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "Ollama couldn't find integration %q, so it'll show up as not installed.\n", name)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return integration.installed
|
||||||
|
}
|
||||||
|
|
||||||
|
// integration is resolved registry metadata used by launcher state and install checks.
// It combines immutable registry spec data with computed runtime traits.
type integration struct {
	spec            *IntegrationSpec // canonical registry entry
	installed       bool             // result of CheckInstalled (true when no check exists)
	autoInstallable bool             // true when the spec provides EnsureInstalled
	editor          bool             // true when the runner implements Editor
	installHint     string           // manual-install instruction, empty when none applies
}
|
||||||
|
|
||||||
|
// integrationFor resolves an integration name into the canonical spec plus
|
||||||
|
// derived launcher/install traits used across registry and launch flows.
|
||||||
|
func integrationFor(name string) (integration, error) {
|
||||||
|
spec, err := LookupIntegrationSpec(name)
|
||||||
|
if err != nil {
|
||||||
|
return integration{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
installed := true
|
||||||
|
if spec.Install.CheckInstalled != nil {
|
||||||
|
installed = spec.Install.CheckInstalled()
|
||||||
|
}
|
||||||
|
|
||||||
|
_, editor := spec.Runner.(Editor)
|
||||||
|
hint := ""
|
||||||
|
if spec.Install.URL != "" {
|
||||||
|
hint = "Install from " + hyperlink(spec.Install.URL, spec.Install.URL)
|
||||||
|
} else if len(spec.Install.Command) > 0 {
|
||||||
|
hint = "Install with: " + strings.Join(spec.Install.Command, " ")
|
||||||
|
}
|
||||||
|
|
||||||
|
return integration{
|
||||||
|
spec: spec,
|
||||||
|
installed: installed,
|
||||||
|
autoInstallable: spec.Install.EnsureInstalled != nil,
|
||||||
|
editor: editor,
|
||||||
|
installHint: hint,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// EnsureIntegrationInstalled installs auto-installable integrations when missing.
|
||||||
|
func EnsureIntegrationInstalled(name string, runner Runner) error {
|
||||||
|
integration, err := integrationFor(name)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("%s is not installed", runner)
|
||||||
|
}
|
||||||
|
|
||||||
|
if integration.installed {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if integration.autoInstallable {
|
||||||
|
return integration.spec.Install.EnsureInstalled()
|
||||||
|
}
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case integration.spec.Install.URL != "":
|
||||||
|
return fmt.Errorf("%s is not installed, install from %s", integration.spec.Name, integration.spec.Install.URL)
|
||||||
|
case len(integration.spec.Install.Command) > 0:
|
||||||
|
return fmt.Errorf("%s is not installed, install with: %s", integration.spec.Name, strings.Join(integration.spec.Install.Command, " "))
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("%s is not installed", runner)
|
||||||
|
}
|
||||||
|
}
|
||||||
21
cmd/launch/registry_test_helpers_test.go
Normal file
21
cmd/launch/registry_test_helpers_test.go
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
package launch
|
||||||
|
|
||||||
|
import "strings"
|
||||||
|
|
||||||
|
// OverrideIntegration replaces one registry entry's runner for tests and returns a restore function.
|
||||||
|
func OverrideIntegration(name string, runner Runner) func() {
|
||||||
|
spec, err := LookupIntegrationSpec(name)
|
||||||
|
if err != nil {
|
||||||
|
key := strings.ToLower(name)
|
||||||
|
integrationSpecsByName[key] = &IntegrationSpec{Name: key, Runner: runner}
|
||||||
|
return func() {
|
||||||
|
delete(integrationSpecsByName, key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
original := spec.Runner
|
||||||
|
spec.Runner = runner
|
||||||
|
return func() {
|
||||||
|
spec.Runner = original
|
||||||
|
}
|
||||||
|
}
|
||||||
68
cmd/launch/runner_exec_only_test.go
Normal file
68
cmd/launch/runner_exec_only_test.go
Normal file
@@ -0,0 +1,68 @@
|
|||||||
|
package launch
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestEditorRunsDoNotRewriteConfig verifies that invoking each runner's Run
// with a model name does not create (or modify) the integration's on-disk
// config file under the test home directory.
func TestEditorRunsDoNotRewriteConfig(t *testing.T) {
	tests := []struct {
		name      string                   // subtest name
		binary    string                   // fake executable placed on PATH
		runner    Runner                   // runner under test
		checkPath func(home string) string // config file that must stay absent
	}{
		{
			name:   "droid",
			binary: "droid",
			runner: &Droid{},
			checkPath: func(home string) string {
				return filepath.Join(home, ".factory", "settings.json")
			},
		},
		{
			name:   "opencode",
			binary: "opencode",
			runner: &OpenCode{},
			checkPath: func(home string) string {
				return filepath.Join(home, ".config", "opencode", "opencode.json")
			},
		},
		{
			name:   "cline",
			binary: "cline",
			runner: &Cline{},
			checkPath: func(home string) string {
				return filepath.Join(home, ".cline", "data", "globalState.json")
			},
		},
		{
			name:   "pi",
			binary: "pi",
			runner: &Pi{},
			checkPath: func(home string) string {
				return filepath.Join(home, ".pi", "agent", "models.json")
			},
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// Isolate HOME and PATH so Run only sees temp dirs and the fake binary.
			home := t.TempDir()
			setTestHome(t, home)

			binDir := t.TempDir()
			writeFakeBinary(t, binDir, tt.binary)
			t.Setenv("PATH", binDir)

			configPath := tt.checkPath(home)
			if err := tt.runner.Run("llama3.2", nil); err != nil {
				t.Fatalf("Run returned error: %v", err)
			}
			// The config file must not exist after Run; any Stat error other
			// than "not exist" (or a nil error) is a failure.
			if _, err := os.Stat(configPath); !os.IsNotExist(err) {
				t.Fatalf("expected Run to leave %s untouched, got err=%v", configPath, err)
			}
		})
	}
}
|
||||||
103
cmd/launch/selector_hooks.go
Normal file
103
cmd/launch/selector_hooks.go
Normal file
@@ -0,0 +1,103 @@
|
|||||||
|
package launch
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
|
||||||
|
"golang.org/x/term"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ANSI escape sequences for terminal formatting.
const (
	ansiBold   = "\033[1m"
	ansiReset  = "\033[0m"
	ansiGray   = "\033[37m"
	ansiGreen  = "\033[32m"
	ansiYellow = "\033[33m"
)
|
||||||
|
|
||||||
|
// ErrCancelled is returned when the user cancels a selection.
var ErrCancelled = errors.New("cancelled")

// errCancelled is kept as an internal alias for existing call sites.
var errCancelled = ErrCancelled

// DefaultConfirmPrompt provides a TUI-based confirmation prompt.
// When set, ConfirmPrompt delegates to it instead of using raw terminal I/O.
var DefaultConfirmPrompt func(prompt string) (bool, error)

// SingleSelector is a function type for single item selection.
// current is the name of the previously selected item to highlight; empty means no pre-selection.
type SingleSelector func(title string, items []ModelItem, current string) (string, error)

// MultiSelector is a function type for multi item selection.
// preChecked lists item names that start selected.
type MultiSelector func(title string, items []ModelItem, preChecked []string) ([]string, error)

// DefaultSingleSelector is the default single-select implementation.
var DefaultSingleSelector SingleSelector

// DefaultMultiSelector is the default multi-select implementation.
var DefaultMultiSelector MultiSelector

// DefaultSignIn provides a TUI-based sign-in flow.
// When set, ensureAuth uses it instead of plain text prompts.
// Returns the signed-in username or an error.
var DefaultSignIn func(modelName, signInURL string) (string, error)
|
||||||
|
|
||||||
|
// launchConfirmPolicy controls how ConfirmPrompt behaves during a launch flow.
type launchConfirmPolicy struct {
	yes               bool // auto-approve every prompt
	requireYesMessage bool // reject prompts with an error telling the user to re-run with --yes
}

// currentLaunchConfirmPolicy is the policy ConfirmPrompt consults; scoped by
// withLaunchConfirmPolicy.
var currentLaunchConfirmPolicy launchConfirmPolicy

// withLaunchConfirmPolicy installs policy and returns a function that restores
// the previous policy, for defer-style scoping.
func withLaunchConfirmPolicy(policy launchConfirmPolicy) func() {
	old := currentLaunchConfirmPolicy
	currentLaunchConfirmPolicy = policy
	return func() {
		currentLaunchConfirmPolicy = old
	}
}
|
||||||
|
|
||||||
|
// ConfirmPrompt is the shared confirmation gate for launch flows (integration
|
||||||
|
// edits, missing-model pulls, sign-in prompts, OpenClaw install/security, etc).
|
||||||
|
// Behavior is controlled by currentLaunchConfirmPolicy, typically scoped by
|
||||||
|
// withLaunchConfirmPolicy in LaunchCmd (e.g. auto-approve with --yes).
|
||||||
|
func ConfirmPrompt(prompt string) (bool, error) {
|
||||||
|
if currentLaunchConfirmPolicy.yes {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
if currentLaunchConfirmPolicy.requireYesMessage {
|
||||||
|
return false, fmt.Errorf("%s requires confirmation; re-run with --yes to continue", prompt)
|
||||||
|
}
|
||||||
|
|
||||||
|
if DefaultConfirmPrompt != nil {
|
||||||
|
return DefaultConfirmPrompt(prompt)
|
||||||
|
}
|
||||||
|
|
||||||
|
fd := int(os.Stdin.Fd())
|
||||||
|
oldState, err := term.MakeRaw(fd)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
defer term.Restore(fd, oldState)
|
||||||
|
|
||||||
|
fmt.Fprintf(os.Stderr, "%s (\033[1my\033[0m/n) ", prompt)
|
||||||
|
|
||||||
|
buf := make([]byte, 1)
|
||||||
|
for {
|
||||||
|
if _, err := os.Stdin.Read(buf); err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
switch buf[0] {
|
||||||
|
case 'Y', 'y', 13:
|
||||||
|
fmt.Fprintf(os.Stderr, "yes\r\n")
|
||||||
|
return true, nil
|
||||||
|
case 'N', 'n', 27, 3:
|
||||||
|
fmt.Fprintf(os.Stderr, "no\r\n")
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user