Compare commits
193 Commits
pdevine/sa
...
hoyyeva/op
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7a3ed0a1b4 | ||
|
|
03f9e57274 | ||
|
|
30d9100fff | ||
|
|
698e04a14b | ||
|
|
1d9537bc33 | ||
|
|
120424d832 | ||
|
|
5818001610 | ||
|
|
2cba7756c5 | ||
|
|
bf2a421727 | ||
|
|
f3cf6b75fb | ||
|
|
5dfac387a6 | ||
|
|
a99e5d9c22 | ||
|
|
0abf3aca36 | ||
|
|
ee0266462a | ||
|
|
c88fb286ec | ||
|
|
d3da29cbfc | ||
|
|
1b70bb8a10 | ||
|
|
ec29ce4ce3 | ||
|
|
4d75f5da03 | ||
|
|
798fd09bfe | ||
|
|
9330bb9120 | ||
|
|
40a1317dfd | ||
|
|
fdfe9cec98 | ||
|
|
9517864603 | ||
|
|
8e6d86dbe3 | ||
|
|
80d3744c5d | ||
|
|
2a94f03823 | ||
|
|
eb97274e5c | ||
|
|
6b5db12aa2 | ||
|
|
612f0a17d3 | ||
|
|
673726fa0e | ||
|
|
b5918f9785 | ||
|
|
d17f482d50 | ||
|
|
4e16f562c0 | ||
|
|
55308f1421 | ||
|
|
d64812eb5d | ||
|
|
f86a969f27 | ||
|
|
9fa80a1660 | ||
|
|
dde09129d1 | ||
|
|
780556c4d0 | ||
|
|
dfae363b5b | ||
|
|
30fdd229a4 | ||
|
|
e823bff873 | ||
|
|
8968740836 | ||
|
|
8c8f8f3450 | ||
|
|
82f0139587 | ||
|
|
26a58b294c | ||
|
|
34a790a2e6 | ||
|
|
4589fa2cf5 | ||
|
|
4bc2728047 | ||
|
|
49d5fd5a3e | ||
|
|
3cd2b03a5e | ||
|
|
c8e0878814 | ||
|
|
bb0c58e134 | ||
|
|
036ed1b9b5 | ||
|
|
3536ef58f6 | ||
|
|
de9673ac3f | ||
|
|
96b202d34b | ||
|
|
79865e6c5a | ||
|
|
5ab10d347a | ||
|
|
a8292dd85f | ||
|
|
cb0033598e | ||
|
|
4d14b0ff92 | ||
|
|
d9cb70c270 | ||
|
|
31f968fe1f | ||
|
|
b7bda92d52 | ||
|
|
8e54823fd3 | ||
|
|
7c8da5679e | ||
|
|
6214103e66 | ||
|
|
9e7cb9697e | ||
|
|
3824e380a8 | ||
|
|
c9b2dcfc52 | ||
|
|
b00bd1dfd4 | ||
|
|
ac83ac20c4 | ||
|
|
e7ccc129ea | ||
|
|
69ed0c2729 | ||
|
|
1cefa749aa | ||
|
|
aec2fef95d | ||
|
|
366625a831 | ||
|
|
516ebd8548 | ||
|
|
f567abc63f | ||
|
|
1adfc27f04 | ||
|
|
4a2b9f9dbc | ||
|
|
e46b67a6cc | ||
|
|
c000afe76c | ||
|
|
9d7b18f81e | ||
|
|
4f5999fd3f | ||
|
|
ac5f0dbb6a | ||
|
|
d1151e18a1 | ||
|
|
ebbce136c7 | ||
|
|
26b9f53f8e | ||
|
|
7575438366 | ||
|
|
7d7c90d702 | ||
|
|
4fda69809a | ||
|
|
c9b5da6b0c | ||
|
|
de5cb7311f | ||
|
|
95ee7fbd29 | ||
|
|
ec55536734 | ||
|
|
77491439c2 | ||
|
|
b166b36cd2 | ||
|
|
c2b0bb7a52 | ||
|
|
22c2bdbd8a | ||
|
|
6df6d097d9 | ||
|
|
d7c176ab91 | ||
|
|
0ff7d724ff | ||
|
|
46cb7795e1 | ||
|
|
126d8db7f3 | ||
|
|
3f3a24b418 | ||
|
|
96e36c0d90 | ||
|
|
6f8ddbb26b | ||
|
|
b5e7888414 | ||
|
|
eab4d22269 | ||
|
|
5759c2d2d2 | ||
|
|
42b1c2642b | ||
|
|
727d69ddf3 | ||
|
|
f622b0c5fc | ||
|
|
5d0000634c | ||
|
|
676d9845ba | ||
|
|
e37a9b4c01 | ||
|
|
d727aacd04 | ||
|
|
fa69b833cd | ||
|
|
bbbad97686 | ||
|
|
bcf6d55b54 | ||
|
|
810d4f9c22 | ||
|
|
856c047a6c | ||
|
|
79c1e93c00 | ||
|
|
f8b657c967 | ||
|
|
10fefe0d57 | ||
|
|
2f9a68f9e9 | ||
|
|
3980c0217d | ||
|
|
870599f5da | ||
|
|
abf8e8e9c8 | ||
|
|
f3f31a8192 | ||
|
|
9e7ba835da | ||
|
|
347f17b8d1 | ||
|
|
081b9eb423 | ||
|
|
bb867c6fdb | ||
|
|
81f4506a61 | ||
|
|
76925f1284 | ||
|
|
f676231de9 | ||
|
|
af5f7c0a9e | ||
|
|
a6b27d776b | ||
|
|
539741199e | ||
|
|
8f45236d09 | ||
|
|
97013a190c | ||
|
|
c222735c02 | ||
|
|
87d21c7fc0 | ||
|
|
54e05172a0 | ||
|
|
464186e995 | ||
|
|
8c4d5d6c2f | ||
|
|
bc72b14016 | ||
|
|
61086083eb | ||
|
|
62d1f01ab4 | ||
|
|
10e51c5177 | ||
|
|
3e06bde643 | ||
|
|
6be2de8214 | ||
|
|
ebb1b9ec14 | ||
|
|
d126467d5d | ||
|
|
afb4c62fbf | ||
|
|
e790dc435b | ||
|
|
288077c3a3 | ||
|
|
4425c54eda | ||
|
|
778899a5d2 | ||
|
|
4eab60c1e2 | ||
|
|
1af850e6e3 | ||
|
|
9b0c7cc7b9 | ||
|
|
6928630601 | ||
|
|
9896e3627f | ||
|
|
15732f0ea7 | ||
|
|
562c76d7cc | ||
|
|
122c68c151 | ||
|
|
82848a7806 | ||
|
|
39982a954e | ||
|
|
e9f6ea232f | ||
|
|
110eff01a9 | ||
|
|
799e51d419 | ||
|
|
e8fcb29586 | ||
|
|
97d2f05a6d | ||
|
|
8207e55ec7 | ||
|
|
ad16bffc7d | ||
|
|
c1e3ef4bcc | ||
|
|
a3093cd5e5 | ||
|
|
23d4cad1a2 | ||
|
|
86513cb697 | ||
|
|
3490e9590b | ||
|
|
8da09b1e7e | ||
|
|
a60b9adcce | ||
|
|
a16f96658b | ||
|
|
18ab09b431 | ||
|
|
638faeac54 | ||
|
|
dd5eb6337d | ||
|
|
79917cf80b | ||
|
|
cc90a035a0 |
69
.github/workflows/release.yaml
vendored
@@ -27,7 +27,7 @@ jobs:
|
|||||||
echo vendorsha=$(make -f Makefile.sync print-base) | tee -a $GITHUB_OUTPUT
|
echo vendorsha=$(make -f Makefile.sync print-base) | tee -a $GITHUB_OUTPUT
|
||||||
|
|
||||||
darwin-build:
|
darwin-build:
|
||||||
runs-on: macos-14-xlarge
|
runs-on: macos-26-xlarge
|
||||||
environment: release
|
environment: release
|
||||||
needs: setup-environment
|
needs: setup-environment
|
||||||
env:
|
env:
|
||||||
@@ -117,6 +117,25 @@ jobs:
|
|||||||
install: https://sdk.lunarg.com/sdk/download/1.4.321.1/windows/vulkansdk-windows-X64-1.4.321.1.exe
|
install: https://sdk.lunarg.com/sdk/download/1.4.321.1/windows/vulkansdk-windows-X64-1.4.321.1.exe
|
||||||
flags: ''
|
flags: ''
|
||||||
runner_dir: 'vulkan'
|
runner_dir: 'vulkan'
|
||||||
|
- os: windows
|
||||||
|
arch: amd64
|
||||||
|
preset: 'MLX CUDA 13'
|
||||||
|
install: https://developer.download.nvidia.com/compute/cuda/13.0.0/local_installers/cuda_13.0.0_windows.exe
|
||||||
|
cudnn-install: https://developer.download.nvidia.com/compute/cudnn/redist/cudnn/windows-x86_64/cudnn-windows-x86_64-9.18.1.3_cuda13-archive.zip
|
||||||
|
cuda-components:
|
||||||
|
- '"cudart"'
|
||||||
|
- '"nvcc"'
|
||||||
|
- '"cublas"'
|
||||||
|
- '"cublas_dev"'
|
||||||
|
- '"cufft"'
|
||||||
|
- '"cufft_dev"'
|
||||||
|
- '"nvrtc"'
|
||||||
|
- '"nvrtc_dev"'
|
||||||
|
- '"crt"'
|
||||||
|
- '"nvvm"'
|
||||||
|
- '"nvptxcompiler"'
|
||||||
|
cuda-version: '13.0'
|
||||||
|
flags: ''
|
||||||
runs-on: ${{ matrix.arch == 'arm64' && format('{0}-{1}', matrix.os, matrix.arch) || matrix.os }}
|
runs-on: ${{ matrix.arch == 'arm64' && format('{0}-{1}', matrix.os, matrix.arch) || matrix.os }}
|
||||||
environment: release
|
environment: release
|
||||||
env:
|
env:
|
||||||
@@ -125,8 +144,10 @@ jobs:
|
|||||||
- name: Install system dependencies
|
- name: Install system dependencies
|
||||||
run: |
|
run: |
|
||||||
choco install -y --no-progress ccache ninja
|
choco install -y --no-progress ccache ninja
|
||||||
ccache -o cache_dir=${{ github.workspace }}\.ccache
|
if (Get-Command ccache -ErrorAction SilentlyContinue) {
|
||||||
- if: startsWith(matrix.preset, 'CUDA ') || startsWith(matrix.preset, 'ROCm ') || startsWith(matrix.preset, 'Vulkan')
|
ccache -o cache_dir=${{ github.workspace }}\.ccache
|
||||||
|
}
|
||||||
|
- if: startsWith(matrix.preset, 'CUDA ') || startsWith(matrix.preset, 'ROCm ') || startsWith(matrix.preset, 'Vulkan') || startsWith(matrix.preset, 'MLX ')
|
||||||
id: cache-install
|
id: cache-install
|
||||||
uses: actions/cache/restore@v4
|
uses: actions/cache/restore@v4
|
||||||
with:
|
with:
|
||||||
@@ -134,8 +155,9 @@ jobs:
|
|||||||
C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA
|
C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA
|
||||||
C:\Program Files\AMD\ROCm
|
C:\Program Files\AMD\ROCm
|
||||||
C:\VulkanSDK
|
C:\VulkanSDK
|
||||||
key: ${{ matrix.install }}
|
C:\Program Files\NVIDIA\CUDNN
|
||||||
- if: startsWith(matrix.preset, 'CUDA ')
|
key: ${{ matrix.install }}-${{ matrix.cudnn-install }}
|
||||||
|
- if: startsWith(matrix.preset, 'CUDA ') || startsWith(matrix.preset, 'MLX ')
|
||||||
name: Install CUDA ${{ matrix.cuda-version }}
|
name: Install CUDA ${{ matrix.cuda-version }}
|
||||||
run: |
|
run: |
|
||||||
$ErrorActionPreference = "Stop"
|
$ErrorActionPreference = "Stop"
|
||||||
@@ -179,6 +201,23 @@ jobs:
|
|||||||
run: |
|
run: |
|
||||||
echo "CC=clang.exe" | Out-File -FilePath $env:GITHUB_ENV -Append
|
echo "CC=clang.exe" | Out-File -FilePath $env:GITHUB_ENV -Append
|
||||||
echo "CXX=clang++.exe" | Out-File -FilePath $env:GITHUB_ENV -Append
|
echo "CXX=clang++.exe" | Out-File -FilePath $env:GITHUB_ENV -Append
|
||||||
|
- if: startsWith(matrix.preset, 'MLX ')
|
||||||
|
name: Install cuDNN for MLX
|
||||||
|
run: |
|
||||||
|
$ErrorActionPreference = "Stop"
|
||||||
|
$cudnnRoot = "C:\Program Files\NVIDIA\CUDNN"
|
||||||
|
if ("${{ steps.cache-install.outputs.cache-hit }}" -ne 'true') {
|
||||||
|
Invoke-WebRequest -Uri "${{ matrix.cudnn-install }}" -OutFile "cudnn.zip"
|
||||||
|
Expand-Archive -Path cudnn.zip -DestinationPath cudnn-extracted
|
||||||
|
$cudnnDir = (Get-ChildItem -Path cudnn-extracted -Directory)[0].FullName
|
||||||
|
New-Item -ItemType Directory -Force -Path $cudnnRoot
|
||||||
|
Copy-Item -Path "$cudnnDir\*" -Destination "$cudnnRoot\" -Recurse
|
||||||
|
}
|
||||||
|
|
||||||
|
echo "CUDNN_ROOT_DIR=$cudnnRoot" | Out-File -FilePath $env:GITHUB_ENV -Append
|
||||||
|
echo "CUDNN_INCLUDE_PATH=$cudnnRoot\include" | Out-File -FilePath $env:GITHUB_ENV -Append
|
||||||
|
echo "CUDNN_LIBRARY_PATH=$cudnnRoot\lib\x64" | Out-File -FilePath $env:GITHUB_ENV -Append
|
||||||
|
echo "$cudnnRoot\bin\x64" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
|
||||||
- if: ${{ !cancelled() && steps.cache-install.outputs.cache-hit != 'true' }}
|
- if: ${{ !cancelled() && steps.cache-install.outputs.cache-hit != 'true' }}
|
||||||
uses: actions/cache/save@v4
|
uses: actions/cache/save@v4
|
||||||
with:
|
with:
|
||||||
@@ -186,7 +225,8 @@ jobs:
|
|||||||
C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA
|
C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA
|
||||||
C:\Program Files\AMD\ROCm
|
C:\Program Files\AMD\ROCm
|
||||||
C:\VulkanSDK
|
C:\VulkanSDK
|
||||||
key: ${{ matrix.install }}
|
C:\Program Files\NVIDIA\CUDNN
|
||||||
|
key: ${{ matrix.install }}-${{ matrix.cudnn-install }}
|
||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v4
|
||||||
- uses: actions/cache@v4
|
- uses: actions/cache@v4
|
||||||
with:
|
with:
|
||||||
@@ -198,7 +238,7 @@ jobs:
|
|||||||
Enter-VsDevShell -VsInstallPath 'C:\Program Files\Microsoft Visual Studio\2022\Enterprise' -SkipAutomaticLocation -DevCmdArguments '-arch=x64 -no_logo'
|
Enter-VsDevShell -VsInstallPath 'C:\Program Files\Microsoft Visual Studio\2022\Enterprise' -SkipAutomaticLocation -DevCmdArguments '-arch=x64 -no_logo'
|
||||||
cmake --preset "${{ matrix.preset }}" ${{ matrix.flags }} --install-prefix "$((pwd).Path)\dist\${{ matrix.os }}-${{ matrix.arch }}"
|
cmake --preset "${{ matrix.preset }}" ${{ matrix.flags }} --install-prefix "$((pwd).Path)\dist\${{ matrix.os }}-${{ matrix.arch }}"
|
||||||
cmake --build --parallel ([Environment]::ProcessorCount) --preset "${{ matrix.preset }}"
|
cmake --build --parallel ([Environment]::ProcessorCount) --preset "${{ matrix.preset }}"
|
||||||
cmake --install build --component "${{ startsWith(matrix.preset, 'CUDA ') && 'CUDA' || startsWith(matrix.preset, 'ROCm ') && 'HIP' || startsWith(matrix.preset, 'Vulkan') && 'Vulkan' || 'CPU' }}" --strip
|
cmake --install build --component "${{ startsWith(matrix.preset, 'MLX ') && 'MLX' || startsWith(matrix.preset, 'CUDA ') && 'CUDA' || startsWith(matrix.preset, 'ROCm ') && 'HIP' || startsWith(matrix.preset, 'Vulkan') && 'Vulkan' || 'CPU' }}" --strip
|
||||||
Remove-Item -Path dist\lib\ollama\rocm\rocblas\library\*gfx906* -ErrorAction SilentlyContinue
|
Remove-Item -Path dist\lib\ollama\rocm\rocblas\library\*gfx906* -ErrorAction SilentlyContinue
|
||||||
env:
|
env:
|
||||||
CMAKE_GENERATOR: Ninja
|
CMAKE_GENERATOR: Ninja
|
||||||
@@ -384,6 +424,7 @@ jobs:
|
|||||||
lib/ollama/cuda_v*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
|
lib/ollama/cuda_v*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
|
||||||
lib/ollama/vulkan*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
|
lib/ollama/vulkan*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
|
||||||
lib/ollama/mlx*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
|
lib/ollama/mlx*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
|
||||||
|
lib/ollama/include*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
|
||||||
lib/ollama/cuda_jetpack5) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-jetpack5.tar.in ;;
|
lib/ollama/cuda_jetpack5) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-jetpack5.tar.in ;;
|
||||||
lib/ollama/cuda_jetpack6) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-jetpack6.tar.in ;;
|
lib/ollama/cuda_jetpack6) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-jetpack6.tar.in ;;
|
||||||
lib/ollama/rocm) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-rocm.tar.in ;;
|
lib/ollama/rocm) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-rocm.tar.in ;;
|
||||||
@@ -543,11 +584,19 @@ jobs:
|
|||||||
for payload in dist/*.txt dist/*.zip dist/*.tgz dist/*.tar.zst dist/*.exe dist/*.dmg dist/*.ps1 dist/*.sh ; do
|
for payload in dist/*.txt dist/*.zip dist/*.tgz dist/*.tar.zst dist/*.exe dist/*.dmg dist/*.ps1 dist/*.sh ; do
|
||||||
echo "Uploading $payload"
|
echo "Uploading $payload"
|
||||||
gh release upload ${GITHUB_REF_NAME} $payload --clobber &
|
gh release upload ${GITHUB_REF_NAME} $payload --clobber &
|
||||||
pids[$!]=$!
|
pids+=($!)
|
||||||
sleep 1
|
sleep 1
|
||||||
done
|
done
|
||||||
echo "Waiting for uploads to complete"
|
echo "Waiting for uploads to complete"
|
||||||
for pid in "${pids[*]}"; do
|
failed=0
|
||||||
wait $pid
|
for pid in "${pids[@]}"; do
|
||||||
|
if ! wait $pid; then
|
||||||
|
echo "::error::Upload failed (pid $pid)"
|
||||||
|
failed=1
|
||||||
|
fi
|
||||||
done
|
done
|
||||||
|
if [ $failed -ne 0 ]; then
|
||||||
|
echo "One or more uploads failed"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
echo "done"
|
echo "done"
|
||||||
|
|||||||
71
.github/workflows/test.yaml
vendored
@@ -37,7 +37,7 @@ jobs:
|
|||||||
| xargs python3 -c "import sys; from pathlib import Path; print(any(Path(x).match(glob) for x in sys.argv[1:] for glob in '$*'.split(' ')))"
|
| xargs python3 -c "import sys; from pathlib import Path; print(any(Path(x).match(glob) for x in sys.argv[1:] for glob in '$*'.split(' ')))"
|
||||||
}
|
}
|
||||||
|
|
||||||
echo changed=$(changed 'llama/llama.cpp/**/*' 'ml/backend/ggml/ggml/**/*') | tee -a $GITHUB_OUTPUT
|
echo changed=$(changed 'llama/llama.cpp/**/*' 'ml/backend/ggml/ggml/**/*' '.github/**/*') | tee -a $GITHUB_OUTPUT
|
||||||
echo vendorsha=$(make -f Makefile.sync print-base) | tee -a $GITHUB_OUTPUT
|
echo vendorsha=$(make -f Makefile.sync print-base) | tee -a $GITHUB_OUTPUT
|
||||||
|
|
||||||
linux:
|
linux:
|
||||||
@@ -51,7 +51,7 @@ jobs:
|
|||||||
container: nvidia/cuda:13.0.0-devel-ubuntu22.04
|
container: nvidia/cuda:13.0.0-devel-ubuntu22.04
|
||||||
flags: '-DCMAKE_CUDA_ARCHITECTURES=87'
|
flags: '-DCMAKE_CUDA_ARCHITECTURES=87'
|
||||||
- preset: ROCm
|
- preset: ROCm
|
||||||
container: rocm/dev-ubuntu-22.04:6.1.2
|
container: rocm/dev-ubuntu-22.04:7.2.1
|
||||||
extra-packages: rocm-libs
|
extra-packages: rocm-libs
|
||||||
flags: '-DAMDGPU_TARGETS=gfx1010 -DCMAKE_PREFIX_PATH=/opt/rocm'
|
flags: '-DAMDGPU_TARGETS=gfx1010 -DCMAKE_PREFIX_PATH=/opt/rocm'
|
||||||
- preset: Vulkan
|
- preset: Vulkan
|
||||||
@@ -60,6 +60,11 @@ jobs:
|
|||||||
mesa-vulkan-drivers vulkan-tools
|
mesa-vulkan-drivers vulkan-tools
|
||||||
libvulkan1 libvulkan-dev
|
libvulkan1 libvulkan-dev
|
||||||
vulkan-sdk cmake ccache g++ make
|
vulkan-sdk cmake ccache g++ make
|
||||||
|
- preset: 'MLX CUDA 13'
|
||||||
|
container: nvidia/cuda:13.0.0-devel-ubuntu22.04
|
||||||
|
extra-packages: libcudnn9-dev-cuda-13 libopenblas-dev liblapack-dev liblapacke-dev git curl
|
||||||
|
flags: '-DCMAKE_CUDA_ARCHITECTURES=87 -DBLAS_INCLUDE_DIRS=/usr/include/x86_64-linux-gnu -DLAPACK_INCLUDE_DIRS=/usr/include/x86_64-linux-gnu'
|
||||||
|
install-go: true
|
||||||
runs-on: linux
|
runs-on: linux
|
||||||
container: ${{ matrix.container }}
|
container: ${{ matrix.container }}
|
||||||
steps:
|
steps:
|
||||||
@@ -76,19 +81,29 @@ jobs:
|
|||||||
$sudo apt-get update
|
$sudo apt-get update
|
||||||
fi
|
fi
|
||||||
$sudo apt-get install -y cmake ccache ${{ matrix.extra-packages }}
|
$sudo apt-get install -y cmake ccache ${{ matrix.extra-packages }}
|
||||||
|
# MLX requires CMake 3.25+, install from official releases
|
||||||
|
if [ "${{ matrix.preset }}" = "MLX CUDA 13" ]; then
|
||||||
|
curl -fsSL https://github.com/Kitware/CMake/releases/download/v3.31.2/cmake-3.31.2-linux-$(uname -m).tar.gz | $sudo tar xz -C /usr/local --strip-components 1
|
||||||
|
fi
|
||||||
# Export VULKAN_SDK if provided by LunarG package (defensive)
|
# Export VULKAN_SDK if provided by LunarG package (defensive)
|
||||||
if [ -d "/usr/lib/x86_64-linux-gnu/vulkan" ] && [ "${{ matrix.preset }}" = "Vulkan" ]; then
|
if [ -d "/usr/lib/x86_64-linux-gnu/vulkan" ] && [ "${{ matrix.preset }}" = "Vulkan" ]; then
|
||||||
echo "VULKAN_SDK=/usr" >> $GITHUB_ENV
|
echo "VULKAN_SDK=/usr" >> $GITHUB_ENV
|
||||||
fi
|
fi
|
||||||
env:
|
env:
|
||||||
DEBIAN_FRONTEND: noninteractive
|
DEBIAN_FRONTEND: noninteractive
|
||||||
|
- if: matrix.install-go
|
||||||
|
name: Install Go
|
||||||
|
run: |
|
||||||
|
GO_VERSION=$(awk '/^go / { print $2 }' go.mod)
|
||||||
|
curl -fsSL "https://golang.org/dl/go${GO_VERSION}.linux-$(dpkg --print-architecture).tar.gz" | tar xz -C /usr/local
|
||||||
|
echo "/usr/local/go/bin" >> $GITHUB_PATH
|
||||||
- uses: actions/cache@v4
|
- uses: actions/cache@v4
|
||||||
with:
|
with:
|
||||||
path: /github/home/.cache/ccache
|
path: /github/home/.cache/ccache
|
||||||
key: ccache-${{ runner.os }}-${{ runner.arch }}-${{ matrix.preset }}-${{ needs.changes.outputs.vendorsha }}
|
key: ccache-${{ runner.os }}-${{ runner.arch }}-${{ matrix.preset }}-${{ needs.changes.outputs.vendorsha }}
|
||||||
- run: |
|
- run: |
|
||||||
cmake --preset ${{ matrix.preset }} ${{ matrix.flags }}
|
cmake --preset "${{ matrix.preset }}" ${{ matrix.flags }}
|
||||||
cmake --build --preset ${{ matrix.preset }} --parallel
|
cmake --build --preset "${{ matrix.preset }}" --parallel
|
||||||
|
|
||||||
windows:
|
windows:
|
||||||
needs: [changes]
|
needs: [changes]
|
||||||
@@ -114,12 +129,31 @@ jobs:
|
|||||||
flags: '-DAMDGPU_TARGETS=gfx1010 -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -DCMAKE_C_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma" -DCMAKE_CXX_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma"'
|
flags: '-DAMDGPU_TARGETS=gfx1010 -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -DCMAKE_C_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma" -DCMAKE_CXX_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma"'
|
||||||
- preset: Vulkan
|
- preset: Vulkan
|
||||||
install: https://sdk.lunarg.com/sdk/download/1.4.321.1/windows/vulkansdk-windows-X64-1.4.321.1.exe
|
install: https://sdk.lunarg.com/sdk/download/1.4.321.1/windows/vulkansdk-windows-X64-1.4.321.1.exe
|
||||||
|
- preset: 'MLX CUDA 13'
|
||||||
|
install: https://developer.download.nvidia.com/compute/cuda/13.0.0/local_installers/cuda_13.0.0_windows.exe
|
||||||
|
cudnn-install: https://developer.download.nvidia.com/compute/cudnn/redist/cudnn/windows-x86_64/cudnn-windows-x86_64-9.18.1.3_cuda13-archive.zip
|
||||||
|
flags: '-DCMAKE_CUDA_ARCHITECTURES=80'
|
||||||
|
cuda-components:
|
||||||
|
- '"cudart"'
|
||||||
|
- '"nvcc"'
|
||||||
|
- '"cublas"'
|
||||||
|
- '"cublas_dev"'
|
||||||
|
- '"cufft"'
|
||||||
|
- '"cufft_dev"'
|
||||||
|
- '"nvrtc"'
|
||||||
|
- '"nvrtc_dev"'
|
||||||
|
- '"crt"'
|
||||||
|
- '"nvvm"'
|
||||||
|
- '"nvptxcompiler"'
|
||||||
|
cuda-version: '13.0'
|
||||||
runs-on: windows
|
runs-on: windows
|
||||||
steps:
|
steps:
|
||||||
- run: |
|
- run: |
|
||||||
choco install -y --no-progress ccache ninja
|
choco install -y --no-progress ccache ninja
|
||||||
ccache -o cache_dir=${{ github.workspace }}\.ccache
|
if (Get-Command ccache -ErrorAction SilentlyContinue) {
|
||||||
- if: matrix.preset == 'CUDA' || matrix.preset == 'ROCm' || matrix.preset == 'Vulkan'
|
ccache -o cache_dir=${{ github.workspace }}\.ccache
|
||||||
|
}
|
||||||
|
- if: matrix.preset == 'CUDA' || matrix.preset == 'ROCm' || matrix.preset == 'Vulkan' || matrix.preset == 'MLX CUDA 13'
|
||||||
id: cache-install
|
id: cache-install
|
||||||
uses: actions/cache/restore@v4
|
uses: actions/cache/restore@v4
|
||||||
with:
|
with:
|
||||||
@@ -127,8 +161,9 @@ jobs:
|
|||||||
C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA
|
C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA
|
||||||
C:\Program Files\AMD\ROCm
|
C:\Program Files\AMD\ROCm
|
||||||
C:\VulkanSDK
|
C:\VulkanSDK
|
||||||
key: ${{ matrix.install }}
|
C:\Program Files\NVIDIA\CUDNN
|
||||||
- if: matrix.preset == 'CUDA'
|
key: ${{ matrix.install }}-${{ matrix.cudnn-install }}
|
||||||
|
- if: matrix.preset == 'CUDA' || matrix.preset == 'MLX CUDA 13'
|
||||||
name: Install CUDA ${{ matrix.cuda-version }}
|
name: Install CUDA ${{ matrix.cuda-version }}
|
||||||
run: |
|
run: |
|
||||||
$ErrorActionPreference = "Stop"
|
$ErrorActionPreference = "Stop"
|
||||||
@@ -168,6 +203,23 @@ jobs:
|
|||||||
$vulkanPath = (Resolve-Path "C:\VulkanSDK\*").path
|
$vulkanPath = (Resolve-Path "C:\VulkanSDK\*").path
|
||||||
echo "$vulkanPath\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
|
echo "$vulkanPath\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
|
||||||
echo "VULKAN_SDK=$vulkanPath" >> $env:GITHUB_ENV
|
echo "VULKAN_SDK=$vulkanPath" >> $env:GITHUB_ENV
|
||||||
|
- if: matrix.preset == 'MLX CUDA 13'
|
||||||
|
name: Install cuDNN for MLX
|
||||||
|
run: |
|
||||||
|
$ErrorActionPreference = "Stop"
|
||||||
|
$cudnnRoot = "C:\Program Files\NVIDIA\CUDNN"
|
||||||
|
if ("${{ steps.cache-install.outputs.cache-hit }}" -ne 'true') {
|
||||||
|
Invoke-WebRequest -Uri "${{ matrix.cudnn-install }}" -OutFile "cudnn.zip"
|
||||||
|
Expand-Archive -Path cudnn.zip -DestinationPath cudnn-extracted
|
||||||
|
$cudnnDir = (Get-ChildItem -Path cudnn-extracted -Directory)[0].FullName
|
||||||
|
New-Item -ItemType Directory -Force -Path $cudnnRoot
|
||||||
|
Copy-Item -Path "$cudnnDir\*" -Destination "$cudnnRoot\" -Recurse
|
||||||
|
}
|
||||||
|
|
||||||
|
echo "CUDNN_ROOT_DIR=$cudnnRoot" | Out-File -FilePath $env:GITHUB_ENV -Append
|
||||||
|
echo "CUDNN_INCLUDE_PATH=$cudnnRoot\include" | Out-File -FilePath $env:GITHUB_ENV -Append
|
||||||
|
echo "CUDNN_LIBRARY_PATH=$cudnnRoot\lib\x64" | Out-File -FilePath $env:GITHUB_ENV -Append
|
||||||
|
echo "$cudnnRoot\bin\x64" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
|
||||||
- if: ${{ !cancelled() && steps.cache-install.outputs.cache-hit != 'true' }}
|
- if: ${{ !cancelled() && steps.cache-install.outputs.cache-hit != 'true' }}
|
||||||
uses: actions/cache/save@v4
|
uses: actions/cache/save@v4
|
||||||
with:
|
with:
|
||||||
@@ -175,7 +227,8 @@ jobs:
|
|||||||
C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA
|
C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA
|
||||||
C:\Program Files\AMD\ROCm
|
C:\Program Files\AMD\ROCm
|
||||||
C:\VulkanSDK
|
C:\VulkanSDK
|
||||||
key: ${{ matrix.install }}
|
C:\Program Files\NVIDIA\CUDNN
|
||||||
|
key: ${{ matrix.install }}-${{ matrix.cudnn-install }}
|
||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v4
|
||||||
- uses: actions/cache@v4
|
- uses: actions/cache@v4
|
||||||
with:
|
with:
|
||||||
|
|||||||
1
.gitignore
vendored
@@ -15,3 +15,4 @@ __debug_bin*
|
|||||||
llama/build
|
llama/build
|
||||||
llama/vendor
|
llama/vendor
|
||||||
/ollama
|
/ollama
|
||||||
|
integration/testdata/models/
|
||||||
|
|||||||
171
CMakeLists.txt
@@ -64,10 +64,15 @@ set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${OLLAMA_BUILD_DIR})
|
|||||||
set(CMAKE_LIBRARY_OUTPUT_DIRECTORY_DEBUG ${OLLAMA_BUILD_DIR})
|
set(CMAKE_LIBRARY_OUTPUT_DIRECTORY_DEBUG ${OLLAMA_BUILD_DIR})
|
||||||
set(CMAKE_LIBRARY_OUTPUT_DIRECTORY_RELEASE ${OLLAMA_BUILD_DIR})
|
set(CMAKE_LIBRARY_OUTPUT_DIRECTORY_RELEASE ${OLLAMA_BUILD_DIR})
|
||||||
|
|
||||||
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src)
|
# Store ggml include paths for use with target_include_directories later.
|
||||||
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/include)
|
# We avoid global include_directories() to prevent polluting the include path
|
||||||
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-cpu)
|
# for other projects like MLX (whose openblas dependency has its own common.h).
|
||||||
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-cpu/amx)
|
set(GGML_INCLUDE_DIRS
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/include
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-cpu
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-cpu/amx
|
||||||
|
)
|
||||||
|
|
||||||
add_compile_definitions(NDEBUG GGML_VERSION=0x0 GGML_COMMIT=0x0)
|
add_compile_definitions(NDEBUG GGML_VERSION=0x0 GGML_COMMIT=0x0)
|
||||||
|
|
||||||
@@ -87,6 +92,14 @@ if(NOT CPU_VARIANTS)
|
|||||||
set(CPU_VARIANTS "ggml-cpu")
|
set(CPU_VARIANTS "ggml-cpu")
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
|
# Apply ggml include directories to ggml targets only (not globally)
|
||||||
|
target_include_directories(ggml-base PRIVATE ${GGML_INCLUDE_DIRS})
|
||||||
|
foreach(variant ${CPU_VARIANTS})
|
||||||
|
if(TARGET ${variant})
|
||||||
|
target_include_directories(${variant} PRIVATE ${GGML_INCLUDE_DIRS})
|
||||||
|
endif()
|
||||||
|
endforeach()
|
||||||
|
|
||||||
install(TARGETS ggml-base ${CPU_VARIANTS}
|
install(TARGETS ggml-base ${CPU_VARIANTS}
|
||||||
RUNTIME_DEPENDENCIES
|
RUNTIME_DEPENDENCIES
|
||||||
PRE_EXCLUDE_REGEXES ".*"
|
PRE_EXCLUDE_REGEXES ".*"
|
||||||
@@ -103,6 +116,7 @@ if(CMAKE_CUDA_COMPILER)
|
|||||||
|
|
||||||
find_package(CUDAToolkit)
|
find_package(CUDAToolkit)
|
||||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-cuda)
|
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-cuda)
|
||||||
|
target_include_directories(ggml-cuda PRIVATE ${GGML_INCLUDE_DIRS})
|
||||||
install(TARGETS ggml-cuda
|
install(TARGETS ggml-cuda
|
||||||
RUNTIME_DEPENDENCIES
|
RUNTIME_DEPENDENCIES
|
||||||
DIRECTORIES ${CUDAToolkit_BIN_DIR} ${CUDAToolkit_BIN_DIR}/x64 ${CUDAToolkit_LIBRARY_DIR}
|
DIRECTORIES ${CUDAToolkit_BIN_DIR} ${CUDAToolkit_BIN_DIR}/x64 ${CUDAToolkit_LIBRARY_DIR}
|
||||||
@@ -134,6 +148,7 @@ if(CMAKE_HIP_COMPILER)
|
|||||||
if(AMDGPU_TARGETS)
|
if(AMDGPU_TARGETS)
|
||||||
find_package(hip REQUIRED)
|
find_package(hip REQUIRED)
|
||||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-hip)
|
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-hip)
|
||||||
|
target_include_directories(ggml-hip PRIVATE ${GGML_INCLUDE_DIRS})
|
||||||
|
|
||||||
if (WIN32)
|
if (WIN32)
|
||||||
target_compile_definitions(ggml-hip PRIVATE GGML_CUDA_NO_PEER_COPY)
|
target_compile_definitions(ggml-hip PRIVATE GGML_CUDA_NO_PEER_COPY)
|
||||||
@@ -148,7 +163,7 @@ if(CMAKE_HIP_COMPILER)
|
|||||||
)
|
)
|
||||||
install(RUNTIME_DEPENDENCY_SET rocm
|
install(RUNTIME_DEPENDENCY_SET rocm
|
||||||
DIRECTORIES ${HIP_BIN_INSTALL_DIR} ${HIP_LIB_INSTALL_DIR}
|
DIRECTORIES ${HIP_BIN_INSTALL_DIR} ${HIP_LIB_INSTALL_DIR}
|
||||||
PRE_INCLUDE_REGEXES hipblas rocblas amdhip64 rocsolver amd_comgr hsa-runtime64 rocsparse tinfo rocprofiler-register drm drm_amdgpu numa elf
|
PRE_INCLUDE_REGEXES hipblas rocblas amdhip64 rocsolver amd_comgr hsa-runtime64 rocsparse tinfo rocprofiler-register roctx64 rocroller drm drm_amdgpu numa elf
|
||||||
PRE_EXCLUDE_REGEXES ".*"
|
PRE_EXCLUDE_REGEXES ".*"
|
||||||
POST_EXCLUDE_REGEXES "system32"
|
POST_EXCLUDE_REGEXES "system32"
|
||||||
RUNTIME DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT HIP
|
RUNTIME DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT HIP
|
||||||
@@ -168,6 +183,7 @@ if(NOT APPLE)
|
|||||||
find_package(Vulkan)
|
find_package(Vulkan)
|
||||||
if(Vulkan_FOUND)
|
if(Vulkan_FOUND)
|
||||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-vulkan)
|
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-vulkan)
|
||||||
|
target_include_directories(ggml-vulkan PRIVATE ${GGML_INCLUDE_DIRS})
|
||||||
install(TARGETS ggml-vulkan
|
install(TARGETS ggml-vulkan
|
||||||
RUNTIME_DEPENDENCIES
|
RUNTIME_DEPENDENCIES
|
||||||
PRE_INCLUDE_REGEXES vulkan
|
PRE_INCLUDE_REGEXES vulkan
|
||||||
@@ -179,7 +195,6 @@ if(NOT APPLE)
|
|||||||
endif()
|
endif()
|
||||||
|
|
||||||
option(MLX_ENGINE "Enable MLX backend" OFF)
|
option(MLX_ENGINE "Enable MLX backend" OFF)
|
||||||
|
|
||||||
if(MLX_ENGINE)
|
if(MLX_ENGINE)
|
||||||
message(STATUS "Setting up MLX (this takes a while...)")
|
message(STATUS "Setting up MLX (this takes a while...)")
|
||||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/x/imagegen/mlx)
|
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/x/imagegen/mlx)
|
||||||
@@ -187,10 +202,36 @@ if(MLX_ENGINE)
|
|||||||
# Find CUDA toolkit if MLX is built with CUDA support
|
# Find CUDA toolkit if MLX is built with CUDA support
|
||||||
find_package(CUDAToolkit)
|
find_package(CUDAToolkit)
|
||||||
|
|
||||||
|
# Build list of directories for runtime dependency resolution
|
||||||
|
set(MLX_RUNTIME_DIRS ${CUDAToolkit_BIN_DIR} ${CUDAToolkit_BIN_DIR}/x64 ${CUDAToolkit_LIBRARY_DIR})
|
||||||
|
# Add cuDNN bin paths for DLLs (Windows MLX CUDA builds)
|
||||||
|
# CUDNN_ROOT_DIR is the standard CMake variable for cuDNN location
|
||||||
|
if(DEFINED ENV{CUDNN_ROOT_DIR})
|
||||||
|
# cuDNN 9.x has versioned subdirectories under bin/ (e.g., bin/13.0/)
|
||||||
|
file(GLOB CUDNN_BIN_SUBDIRS "$ENV{CUDNN_ROOT_DIR}/bin/*")
|
||||||
|
list(APPEND MLX_RUNTIME_DIRS ${CUDNN_BIN_SUBDIRS})
|
||||||
|
endif()
|
||||||
|
# Add build output directory and MLX dependency build directories
|
||||||
|
list(APPEND MLX_RUNTIME_DIRS ${OLLAMA_BUILD_DIR})
|
||||||
|
# OpenBLAS DLL location (pre-built zip extracts into openblas-src/bin/)
|
||||||
|
list(APPEND MLX_RUNTIME_DIRS ${CMAKE_BINARY_DIR}/_deps/openblas-src/bin)
|
||||||
|
# NCCL: on Linux, if real NCCL is found, cmake bundles libnccl.so via the
|
||||||
|
# regex below. If NCCL is not found, MLX links a static stub (OBJECT lib)
|
||||||
|
# so there is no runtime dependency. This path covers the stub build dir
|
||||||
|
# for windows so we include the DLL in our dependencies.
|
||||||
|
list(APPEND MLX_RUNTIME_DIRS ${CMAKE_BINARY_DIR}/_deps/mlx-build/mlx/distributed/nccl/nccl_stub-prefix/src/nccl_stub-build/Release)
|
||||||
|
|
||||||
|
# Base regexes for runtime dependencies (cross-platform)
|
||||||
|
set(MLX_INCLUDE_REGEXES cublas cublasLt cudart cufft nvrtc nvrtc-builtins cudnn nccl openblas gfortran)
|
||||||
|
# On Windows, also include dl.dll (dlfcn-win32 POSIX emulation layer)
|
||||||
|
if(WIN32)
|
||||||
|
list(APPEND MLX_INCLUDE_REGEXES "^dl\\.dll$")
|
||||||
|
endif()
|
||||||
|
|
||||||
install(TARGETS mlx mlxc
|
install(TARGETS mlx mlxc
|
||||||
RUNTIME_DEPENDENCIES
|
RUNTIME_DEPENDENCIES
|
||||||
DIRECTORIES ${CUDAToolkit_BIN_DIR} ${CUDAToolkit_BIN_DIR}/x64 ${CUDAToolkit_LIBRARY_DIR}
|
DIRECTORIES ${MLX_RUNTIME_DIRS}
|
||||||
PRE_INCLUDE_REGEXES cublas cublasLt cudart nvrtc nvrtc-builtins cudnn nccl openblas gfortran
|
PRE_INCLUDE_REGEXES ${MLX_INCLUDE_REGEXES}
|
||||||
PRE_EXCLUDE_REGEXES ".*"
|
PRE_EXCLUDE_REGEXES ".*"
|
||||||
RUNTIME DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT MLX
|
RUNTIME DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT MLX
|
||||||
LIBRARY DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT MLX
|
LIBRARY DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT MLX
|
||||||
@@ -205,13 +246,117 @@ if(MLX_ENGINE)
|
|||||||
COMPONENT MLX)
|
COMPONENT MLX)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
# Manually install cudart and cublas since they might not be picked up as direct dependencies
|
# Install headers for NVRTC JIT compilation at runtime.
|
||||||
|
# MLX's own install rules use the default component so they get skipped by
|
||||||
|
# --component MLX. Headers are installed alongside libmlx in OLLAMA_INSTALL_DIR.
|
||||||
|
#
|
||||||
|
# Layout:
|
||||||
|
# ${OLLAMA_INSTALL_DIR}/include/cccl/{cuda,nv}/ — CCCL headers
|
||||||
|
# ${OLLAMA_INSTALL_DIR}/include/*.h — CUDA toolkit headers
|
||||||
|
#
|
||||||
|
# MLX's jit_module.cpp resolves CCCL via
|
||||||
|
# current_binary_dir()[.parent_path()] / "include" / "cccl"
|
||||||
|
# On Linux, MLX's jit_module.cpp resolves CCCL via
|
||||||
|
# current_binary_dir().parent_path() / "include" / "cccl", so we create a
|
||||||
|
# symlink from lib/ollama/include -> ${OLLAMA_RUNNER_DIR}/include
|
||||||
|
# This will need refinement if we add multiple CUDA versions for MLX in the future.
|
||||||
|
# CUDA runtime headers are found via CUDA_PATH env var (set by mlxrunner).
|
||||||
|
if(EXISTS ${CMAKE_BINARY_DIR}/_deps/cccl-src/include/cuda)
|
||||||
|
install(DIRECTORY ${CMAKE_BINARY_DIR}/_deps/cccl-src/include/cuda
|
||||||
|
DESTINATION ${OLLAMA_INSTALL_DIR}/include/cccl
|
||||||
|
COMPONENT MLX)
|
||||||
|
install(DIRECTORY ${CMAKE_BINARY_DIR}/_deps/cccl-src/include/nv
|
||||||
|
DESTINATION ${OLLAMA_INSTALL_DIR}/include/cccl
|
||||||
|
COMPONENT MLX)
|
||||||
|
if(NOT WIN32 AND NOT APPLE)
|
||||||
|
install(CODE "
|
||||||
|
set(_link \"${CMAKE_INSTALL_PREFIX}/lib/ollama/include\")
|
||||||
|
set(_target \"${OLLAMA_RUNNER_DIR}/include\")
|
||||||
|
if(NOT EXISTS \${_link})
|
||||||
|
execute_process(COMMAND \${CMAKE_COMMAND} -E create_symlink \${_target} \${_link})
|
||||||
|
endif()
|
||||||
|
" COMPONENT MLX)
|
||||||
|
endif()
|
||||||
|
endif()
|
||||||
|
|
||||||
|
# Install minimal CUDA toolkit headers needed by MLX JIT kernels.
|
||||||
|
# These are the transitive closure of includes from mlx/backend/cuda/device/*.cuh.
|
||||||
|
# The Go mlxrunner sets CUDA_PATH to OLLAMA_INSTALL_DIR so MLX finds them at
|
||||||
|
# $CUDA_PATH/include/*.h via NVRTC --include-path.
|
||||||
if(CUDAToolkit_FOUND)
|
if(CUDAToolkit_FOUND)
|
||||||
file(GLOB CUDART_LIBS
|
# CUDAToolkit_INCLUDE_DIRS may be a semicolon-separated list
|
||||||
|
# (e.g. ".../include;.../include/cccl"). Find the entry that
|
||||||
|
# contains the CUDA runtime headers we need.
|
||||||
|
set(_cuda_inc "")
|
||||||
|
foreach(_dir ${CUDAToolkit_INCLUDE_DIRS})
|
||||||
|
if(EXISTS "${_dir}/cuda_runtime_api.h")
|
||||||
|
set(_cuda_inc "${_dir}")
|
||||||
|
break()
|
||||||
|
endif()
|
||||||
|
endforeach()
|
||||||
|
if(NOT _cuda_inc)
|
||||||
|
message(WARNING "Could not find cuda_runtime_api.h in CUDAToolkit_INCLUDE_DIRS: ${CUDAToolkit_INCLUDE_DIRS}")
|
||||||
|
else()
|
||||||
|
set(_dst "${OLLAMA_INSTALL_DIR}/include")
|
||||||
|
set(_MLX_JIT_CUDA_HEADERS
|
||||||
|
builtin_types.h
|
||||||
|
cooperative_groups.h
|
||||||
|
cuda_bf16.h
|
||||||
|
cuda_bf16.hpp
|
||||||
|
cuda_device_runtime_api.h
|
||||||
|
cuda_fp16.h
|
||||||
|
cuda_fp16.hpp
|
||||||
|
cuda_fp8.h
|
||||||
|
cuda_fp8.hpp
|
||||||
|
cuda_runtime_api.h
|
||||||
|
device_types.h
|
||||||
|
driver_types.h
|
||||||
|
math_constants.h
|
||||||
|
surface_types.h
|
||||||
|
texture_types.h
|
||||||
|
vector_functions.h
|
||||||
|
vector_functions.hpp
|
||||||
|
vector_types.h
|
||||||
|
)
|
||||||
|
foreach(_hdr ${_MLX_JIT_CUDA_HEADERS})
|
||||||
|
install(FILES "${_cuda_inc}/${_hdr}"
|
||||||
|
DESTINATION ${_dst}
|
||||||
|
COMPONENT MLX)
|
||||||
|
endforeach()
|
||||||
|
# Subdirectory headers
|
||||||
|
install(DIRECTORY "${_cuda_inc}/cooperative_groups"
|
||||||
|
DESTINATION ${_dst}
|
||||||
|
COMPONENT MLX
|
||||||
|
FILES_MATCHING PATTERN "*.h")
|
||||||
|
install(FILES "${_cuda_inc}/crt/host_defines.h"
|
||||||
|
DESTINATION "${_dst}/crt"
|
||||||
|
COMPONENT MLX)
|
||||||
|
endif()
|
||||||
|
endif()
|
||||||
|
|
||||||
|
# On Windows, explicitly install dl.dll (dlfcn-win32 POSIX dlopen emulation)
|
||||||
|
# RUNTIME_DEPENDENCIES auto-excludes it via POST_EXCLUDE_FILES_STRICT because
|
||||||
|
# dlfcn-win32 is a known CMake target with its own install rules (which install
|
||||||
|
# to the wrong destination). We must install it explicitly here.
|
||||||
|
if(WIN32)
|
||||||
|
install(FILES ${OLLAMA_BUILD_DIR}/dl.dll
|
||||||
|
DESTINATION ${OLLAMA_INSTALL_DIR}
|
||||||
|
COMPONENT MLX)
|
||||||
|
endif()
|
||||||
|
|
||||||
|
# Manually install CUDA runtime libraries that MLX loads via dlopen
|
||||||
|
# (not detected by RUNTIME_DEPENDENCIES since they aren't link-time deps)
|
||||||
|
if(CUDAToolkit_FOUND)
|
||||||
|
file(GLOB MLX_CUDA_LIBS
|
||||||
"${CUDAToolkit_LIBRARY_DIR}/libcudart.so*"
|
"${CUDAToolkit_LIBRARY_DIR}/libcudart.so*"
|
||||||
"${CUDAToolkit_LIBRARY_DIR}/libcublas.so*")
|
"${CUDAToolkit_LIBRARY_DIR}/libcublas.so*"
|
||||||
if(CUDART_LIBS)
|
"${CUDAToolkit_LIBRARY_DIR}/libcublasLt.so*"
|
||||||
install(FILES ${CUDART_LIBS}
|
"${CUDAToolkit_LIBRARY_DIR}/libnvrtc.so*"
|
||||||
|
"${CUDAToolkit_LIBRARY_DIR}/libnvrtc-builtins.so*"
|
||||||
|
"${CUDAToolkit_LIBRARY_DIR}/libcufft.so*"
|
||||||
|
"${CUDAToolkit_LIBRARY_DIR}/libcudnn.so*")
|
||||||
|
if(MLX_CUDA_LIBS)
|
||||||
|
install(FILES ${MLX_CUDA_LIBS}
|
||||||
DESTINATION ${OLLAMA_INSTALL_DIR}
|
DESTINATION ${OLLAMA_INSTALL_DIR}
|
||||||
COMPONENT MLX)
|
COMPONENT MLX)
|
||||||
endif()
|
endif()
|
||||||
|
|||||||
@@ -77,6 +77,15 @@
|
|||||||
"OLLAMA_RUNNER_DIR": "rocm"
|
"OLLAMA_RUNNER_DIR": "rocm"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"name": "ROCm 7",
|
||||||
|
"inherits": [ "ROCm" ],
|
||||||
|
"cacheVariables": {
|
||||||
|
"CMAKE_HIP_FLAGS": "-parallel-jobs=4",
|
||||||
|
"AMDGPU_TARGETS": "gfx942;gfx950;gfx1010;gfx1012;gfx1030;gfx1100;gfx1101;gfx1102;gfx1103;gfx1150;gfx1151;gfx1200;gfx1201;gfx908:xnack-;gfx90a:xnack+;gfx90a:xnack-",
|
||||||
|
"OLLAMA_RUNNER_DIR": "rocm"
|
||||||
|
}
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"name": "Vulkan",
|
"name": "Vulkan",
|
||||||
"inherits": [ "Default" ],
|
"inherits": [ "Default" ],
|
||||||
@@ -103,6 +112,7 @@
|
|||||||
"name": "MLX CUDA 13",
|
"name": "MLX CUDA 13",
|
||||||
"inherits": [ "MLX", "CUDA 13" ],
|
"inherits": [ "MLX", "CUDA 13" ],
|
||||||
"cacheVariables": {
|
"cacheVariables": {
|
||||||
|
"MLX_CUDA_ARCHITECTURES": "86;89;90;90a;100;103;75-virtual;80-virtual;110-virtual;120-virtual;121-virtual",
|
||||||
"OLLAMA_RUNNER_DIR": "mlx_cuda_v13"
|
"OLLAMA_RUNNER_DIR": "mlx_cuda_v13"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -158,6 +168,11 @@
|
|||||||
"inherits": [ "ROCm" ],
|
"inherits": [ "ROCm" ],
|
||||||
"configurePreset": "ROCm 6"
|
"configurePreset": "ROCm 6"
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"name": "ROCm 7",
|
||||||
|
"inherits": [ "ROCm" ],
|
||||||
|
"configurePreset": "ROCm 7"
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"name": "Vulkan",
|
"name": "Vulkan",
|
||||||
"targets": [ "ggml-vulkan" ],
|
"targets": [ "ggml-vulkan" ],
|
||||||
|
|||||||
122
Dockerfile
@@ -1,28 +1,23 @@
|
|||||||
# vim: filetype=dockerfile
|
# vim: filetype=dockerfile
|
||||||
|
|
||||||
ARG FLAVOR=${TARGETARCH}
|
ARG FLAVOR=${TARGETARCH}
|
||||||
ARG PARALLEL=8
|
|
||||||
|
|
||||||
ARG ROCMVERSION=6.3.3
|
ARG ROCMVERSION=7.2.1
|
||||||
ARG JETPACK5VERSION=r35.4.1
|
ARG JETPACK5VERSION=r35.4.1
|
||||||
ARG JETPACK6VERSION=r36.4.0
|
ARG JETPACK6VERSION=r36.4.0
|
||||||
ARG CMAKEVERSION=3.31.2
|
ARG CMAKEVERSION=3.31.2
|
||||||
|
ARG NINJAVERSION=1.12.1
|
||||||
ARG VULKANVERSION=1.4.321.1
|
ARG VULKANVERSION=1.4.321.1
|
||||||
|
|
||||||
|
# Default empty stages for local MLX source overrides.
|
||||||
|
# Override with: docker build --build-context local-mlx=../mlx --build-context local-mlx-c=../mlx-c
|
||||||
|
FROM scratch AS local-mlx
|
||||||
|
FROM scratch AS local-mlx-c
|
||||||
|
|
||||||
FROM --platform=linux/amd64 rocm/dev-almalinux-8:${ROCMVERSION}-complete AS base-amd64
|
FROM --platform=linux/amd64 rocm/dev-almalinux-8:${ROCMVERSION}-complete AS base-amd64
|
||||||
RUN dnf install -y yum-utils ccache gcc-toolset-11-gcc gcc-toolset-11-gcc-c++ gcc-toolset-11-binutils \
|
RUN dnf install -y yum-utils ccache gcc-toolset-11-gcc gcc-toolset-11-gcc-c++ gcc-toolset-11-binutils \
|
||||||
&& yum-config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel8/x86_64/cuda-rhel8.repo
|
&& yum-config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel8/x86_64/cuda-rhel8.repo
|
||||||
ENV PATH=/opt/rh/gcc-toolset-11/root/usr/bin:$PATH
|
ENV PATH=/opt/rh/gcc-toolset-11/root/usr/bin:$PATH
|
||||||
ARG VULKANVERSION
|
|
||||||
RUN wget https://sdk.lunarg.com/sdk/download/${VULKANVERSION}/linux/vulkansdk-linux-x86_64-${VULKANVERSION}.tar.xz -O /tmp/vulkansdk-linux-x86_64-${VULKANVERSION}.tar.xz \
|
|
||||||
&& tar xvf /tmp/vulkansdk-linux-x86_64-${VULKANVERSION}.tar.xz \
|
|
||||||
&& dnf -y install ninja-build \
|
|
||||||
&& ln -s /usr/bin/python3 /usr/bin/python \
|
|
||||||
&& /${VULKANVERSION}/vulkansdk -j 8 vulkan-headers \
|
|
||||||
&& /${VULKANVERSION}/vulkansdk -j 8 shaderc
|
|
||||||
RUN cp -r /${VULKANVERSION}/x86_64/include/* /usr/local/include/ \
|
|
||||||
&& cp -r /${VULKANVERSION}/x86_64/lib/* /usr/local/lib
|
|
||||||
ENV PATH=/${VULKANVERSION}/x86_64/bin:$PATH
|
|
||||||
|
|
||||||
FROM --platform=linux/arm64 almalinux:8 AS base-arm64
|
FROM --platform=linux/arm64 almalinux:8 AS base-arm64
|
||||||
# install epel-release for ccache
|
# install epel-release for ccache
|
||||||
@@ -33,100 +28,119 @@ ENV CC=clang CXX=clang++
|
|||||||
|
|
||||||
FROM base-${TARGETARCH} AS base
|
FROM base-${TARGETARCH} AS base
|
||||||
ARG CMAKEVERSION
|
ARG CMAKEVERSION
|
||||||
|
ARG NINJAVERSION
|
||||||
RUN curl -fsSL https://github.com/Kitware/CMake/releases/download/v${CMAKEVERSION}/cmake-${CMAKEVERSION}-linux-$(uname -m).tar.gz | tar xz -C /usr/local --strip-components 1
|
RUN curl -fsSL https://github.com/Kitware/CMake/releases/download/v${CMAKEVERSION}/cmake-${CMAKEVERSION}-linux-$(uname -m).tar.gz | tar xz -C /usr/local --strip-components 1
|
||||||
|
RUN dnf install -y unzip \
|
||||||
|
&& curl -fsSL -o /tmp/ninja.zip https://github.com/ninja-build/ninja/releases/download/v${NINJAVERSION}/ninja-linux$([ "$(uname -m)" = "aarch64" ] && echo "-aarch64").zip \
|
||||||
|
&& unzip /tmp/ninja.zip -d /usr/local/bin \
|
||||||
|
&& rm /tmp/ninja.zip
|
||||||
|
ENV CMAKE_GENERATOR=Ninja
|
||||||
ENV LDFLAGS=-s
|
ENV LDFLAGS=-s
|
||||||
|
|
||||||
FROM base AS cpu
|
FROM base AS cpu
|
||||||
RUN dnf install -y gcc-toolset-11-gcc gcc-toolset-11-gcc-c++
|
RUN dnf install -y gcc-toolset-11-gcc gcc-toolset-11-gcc-c++
|
||||||
ENV PATH=/opt/rh/gcc-toolset-11/root/usr/bin:$PATH
|
ENV PATH=/opt/rh/gcc-toolset-11/root/usr/bin:$PATH
|
||||||
ARG PARALLEL
|
|
||||||
COPY CMakeLists.txt CMakePresets.json .
|
COPY CMakeLists.txt CMakePresets.json .
|
||||||
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
|
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
|
||||||
RUN --mount=type=cache,target=/root/.ccache \
|
RUN --mount=type=cache,target=/root/.ccache \
|
||||||
cmake --preset 'CPU' \
|
cmake --preset 'CPU' \
|
||||||
&& cmake --build --parallel ${PARALLEL} --preset 'CPU' \
|
&& cmake --build --preset 'CPU' -- -l $(nproc) \
|
||||||
&& cmake --install build --component CPU --strip --parallel ${PARALLEL}
|
&& cmake --install build --component CPU --strip
|
||||||
|
|
||||||
FROM base AS cuda-11
|
FROM base AS cuda-11
|
||||||
ARG CUDA11VERSION=11.8
|
ARG CUDA11VERSION=11.8
|
||||||
RUN dnf install -y cuda-toolkit-${CUDA11VERSION//./-}
|
RUN dnf install -y cuda-toolkit-${CUDA11VERSION//./-}
|
||||||
ENV PATH=/usr/local/cuda-11/bin:$PATH
|
ENV PATH=/usr/local/cuda-11/bin:$PATH
|
||||||
ARG PARALLEL
|
|
||||||
COPY CMakeLists.txt CMakePresets.json .
|
COPY CMakeLists.txt CMakePresets.json .
|
||||||
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
|
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
|
||||||
RUN --mount=type=cache,target=/root/.ccache \
|
RUN --mount=type=cache,target=/root/.ccache \
|
||||||
cmake --preset 'CUDA 11' \
|
cmake --preset 'CUDA 11' \
|
||||||
&& cmake --build --parallel ${PARALLEL} --preset 'CUDA 11' \
|
&& cmake --build --preset 'CUDA 11' -- -l $(nproc) \
|
||||||
&& cmake --install build --component CUDA --strip --parallel ${PARALLEL}
|
&& cmake --install build --component CUDA --strip
|
||||||
|
|
||||||
FROM base AS cuda-12
|
FROM base AS cuda-12
|
||||||
ARG CUDA12VERSION=12.8
|
ARG CUDA12VERSION=12.8
|
||||||
RUN dnf install -y cuda-toolkit-${CUDA12VERSION//./-}
|
RUN dnf install -y cuda-toolkit-${CUDA12VERSION//./-}
|
||||||
ENV PATH=/usr/local/cuda-12/bin:$PATH
|
ENV PATH=/usr/local/cuda-12/bin:$PATH
|
||||||
ARG PARALLEL
|
|
||||||
COPY CMakeLists.txt CMakePresets.json .
|
COPY CMakeLists.txt CMakePresets.json .
|
||||||
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
|
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
|
||||||
RUN --mount=type=cache,target=/root/.ccache \
|
RUN --mount=type=cache,target=/root/.ccache \
|
||||||
cmake --preset 'CUDA 12' \
|
cmake --preset 'CUDA 12' \
|
||||||
&& cmake --build --parallel ${PARALLEL} --preset 'CUDA 12' \
|
&& cmake --build --preset 'CUDA 12' -- -l $(nproc) \
|
||||||
&& cmake --install build --component CUDA --strip --parallel ${PARALLEL}
|
&& cmake --install build --component CUDA --strip
|
||||||
|
|
||||||
|
|
||||||
FROM base AS cuda-13
|
FROM base AS cuda-13
|
||||||
ARG CUDA13VERSION=13.0
|
ARG CUDA13VERSION=13.0
|
||||||
RUN dnf install -y cuda-toolkit-${CUDA13VERSION//./-}
|
RUN dnf install -y cuda-toolkit-${CUDA13VERSION//./-}
|
||||||
ENV PATH=/usr/local/cuda-13/bin:$PATH
|
ENV PATH=/usr/local/cuda-13/bin:$PATH
|
||||||
ARG PARALLEL
|
|
||||||
COPY CMakeLists.txt CMakePresets.json .
|
COPY CMakeLists.txt CMakePresets.json .
|
||||||
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
|
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
|
||||||
RUN --mount=type=cache,target=/root/.ccache \
|
RUN --mount=type=cache,target=/root/.ccache \
|
||||||
cmake --preset 'CUDA 13' \
|
cmake --preset 'CUDA 13' \
|
||||||
&& cmake --build --parallel ${PARALLEL} --preset 'CUDA 13' \
|
&& cmake --build --preset 'CUDA 13' -- -l $(nproc) \
|
||||||
&& cmake --install build --component CUDA --strip --parallel ${PARALLEL}
|
&& cmake --install build --component CUDA --strip
|
||||||
|
|
||||||
|
|
||||||
FROM base AS rocm-6
|
FROM base AS rocm-7
|
||||||
ENV PATH=/opt/rocm/hcc/bin:/opt/rocm/hip/bin:/opt/rocm/bin:/opt/rocm/hcc/bin:$PATH
|
ENV PATH=/opt/rocm/hcc/bin:/opt/rocm/hip/bin:/opt/rocm/bin:/opt/rocm/hcc/bin:$PATH
|
||||||
ARG PARALLEL
|
|
||||||
COPY CMakeLists.txt CMakePresets.json .
|
COPY CMakeLists.txt CMakePresets.json .
|
||||||
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
|
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
|
||||||
RUN --mount=type=cache,target=/root/.ccache \
|
RUN --mount=type=cache,target=/root/.ccache \
|
||||||
cmake --preset 'ROCm 6' \
|
cmake --preset 'ROCm 7' \
|
||||||
&& cmake --build --parallel ${PARALLEL} --preset 'ROCm 6' \
|
&& cmake --build --preset 'ROCm 7' -- -l $(nproc) \
|
||||||
&& cmake --install build --component HIP --strip --parallel ${PARALLEL}
|
&& cmake --install build --component HIP --strip
|
||||||
RUN rm -f dist/lib/ollama/rocm/rocblas/library/*gfx90[06]*
|
RUN rm -f dist/lib/ollama/rocm/rocblas/library/*gfx90[06]*
|
||||||
|
|
||||||
FROM --platform=linux/arm64 nvcr.io/nvidia/l4t-jetpack:${JETPACK5VERSION} AS jetpack-5
|
FROM --platform=linux/arm64 nvcr.io/nvidia/l4t-jetpack:${JETPACK5VERSION} AS jetpack-5
|
||||||
ARG CMAKEVERSION
|
ARG CMAKEVERSION
|
||||||
RUN apt-get update && apt-get install -y curl ccache \
|
ARG NINJAVERSION
|
||||||
&& curl -fsSL https://github.com/Kitware/CMake/releases/download/v${CMAKEVERSION}/cmake-${CMAKEVERSION}-linux-$(uname -m).tar.gz | tar xz -C /usr/local --strip-components 1
|
RUN apt-get update && apt-get install -y curl ccache unzip \
|
||||||
|
&& curl -fsSL https://github.com/Kitware/CMake/releases/download/v${CMAKEVERSION}/cmake-${CMAKEVERSION}-linux-$(uname -m).tar.gz | tar xz -C /usr/local --strip-components 1 \
|
||||||
|
&& curl -fsSL -o /tmp/ninja.zip https://github.com/ninja-build/ninja/releases/download/v${NINJAVERSION}/ninja-linux-aarch64.zip \
|
||||||
|
&& unzip /tmp/ninja.zip -d /usr/local/bin \
|
||||||
|
&& rm /tmp/ninja.zip
|
||||||
|
ENV CMAKE_GENERATOR=Ninja
|
||||||
COPY CMakeLists.txt CMakePresets.json .
|
COPY CMakeLists.txt CMakePresets.json .
|
||||||
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
|
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
|
||||||
ARG PARALLEL
|
|
||||||
RUN --mount=type=cache,target=/root/.ccache \
|
RUN --mount=type=cache,target=/root/.ccache \
|
||||||
cmake --preset 'JetPack 5' \
|
cmake --preset 'JetPack 5' \
|
||||||
&& cmake --build --parallel ${PARALLEL} --preset 'JetPack 5' \
|
&& cmake --build --preset 'JetPack 5' -- -l $(nproc) \
|
||||||
&& cmake --install build --component CUDA --strip --parallel ${PARALLEL}
|
&& cmake --install build --component CUDA --strip
|
||||||
|
|
||||||
FROM --platform=linux/arm64 nvcr.io/nvidia/l4t-jetpack:${JETPACK6VERSION} AS jetpack-6
|
FROM --platform=linux/arm64 nvcr.io/nvidia/l4t-jetpack:${JETPACK6VERSION} AS jetpack-6
|
||||||
ARG CMAKEVERSION
|
ARG CMAKEVERSION
|
||||||
RUN apt-get update && apt-get install -y curl ccache \
|
ARG NINJAVERSION
|
||||||
&& curl -fsSL https://github.com/Kitware/CMake/releases/download/v${CMAKEVERSION}/cmake-${CMAKEVERSION}-linux-$(uname -m).tar.gz | tar xz -C /usr/local --strip-components 1
|
RUN apt-get update && apt-get install -y curl ccache unzip \
|
||||||
|
&& curl -fsSL https://github.com/Kitware/CMake/releases/download/v${CMAKEVERSION}/cmake-${CMAKEVERSION}-linux-$(uname -m).tar.gz | tar xz -C /usr/local --strip-components 1 \
|
||||||
|
&& curl -fsSL -o /tmp/ninja.zip https://github.com/ninja-build/ninja/releases/download/v${NINJAVERSION}/ninja-linux-aarch64.zip \
|
||||||
|
&& unzip /tmp/ninja.zip -d /usr/local/bin \
|
||||||
|
&& rm /tmp/ninja.zip
|
||||||
|
ENV CMAKE_GENERATOR=Ninja
|
||||||
COPY CMakeLists.txt CMakePresets.json .
|
COPY CMakeLists.txt CMakePresets.json .
|
||||||
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
|
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
|
||||||
ARG PARALLEL
|
|
||||||
RUN --mount=type=cache,target=/root/.ccache \
|
RUN --mount=type=cache,target=/root/.ccache \
|
||||||
cmake --preset 'JetPack 6' \
|
cmake --preset 'JetPack 6' \
|
||||||
&& cmake --build --parallel ${PARALLEL} --preset 'JetPack 6' \
|
&& cmake --build --preset 'JetPack 6' -- -l $(nproc) \
|
||||||
&& cmake --install build --component CUDA --strip --parallel ${PARALLEL}
|
&& cmake --install build --component CUDA --strip
|
||||||
|
|
||||||
FROM base AS vulkan
|
FROM base AS vulkan
|
||||||
|
ARG VULKANVERSION
|
||||||
|
RUN ln -s /usr/bin/python3 /usr/bin/python \
|
||||||
|
&& wget https://sdk.lunarg.com/sdk/download/${VULKANVERSION}/linux/vulkansdk-linux-x86_64-${VULKANVERSION}.tar.xz -O /tmp/vulkansdk.tar.xz \
|
||||||
|
&& tar xvf /tmp/vulkansdk.tar.xz -C /tmp \
|
||||||
|
&& /tmp/${VULKANVERSION}/vulkansdk -j 8 vulkan-headers \
|
||||||
|
&& /tmp/${VULKANVERSION}/vulkansdk -j 8 shaderc \
|
||||||
|
&& cp -r /tmp/${VULKANVERSION}/x86_64/include/* /usr/local/include/ \
|
||||||
|
&& cp -r /tmp/${VULKANVERSION}/x86_64/lib/* /usr/local/lib \
|
||||||
|
&& cp -r /tmp/${VULKANVERSION}/x86_64/bin/* /usr/local/bin/ \
|
||||||
|
&& rm -rf /tmp/${VULKANVERSION} /tmp/vulkansdk.tar.xz
|
||||||
COPY CMakeLists.txt CMakePresets.json .
|
COPY CMakeLists.txt CMakePresets.json .
|
||||||
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
|
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
|
||||||
RUN --mount=type=cache,target=/root/.ccache \
|
RUN --mount=type=cache,target=/root/.ccache \
|
||||||
cmake --preset 'Vulkan' \
|
cmake --preset 'Vulkan' \
|
||||||
&& cmake --build --parallel --preset 'Vulkan' \
|
&& cmake --build --preset 'Vulkan' -- -l $(nproc) \
|
||||||
&& cmake --install build --component Vulkan --strip --parallel 8
|
&& cmake --install build --component Vulkan --strip
|
||||||
|
|
||||||
FROM base AS mlx
|
FROM base AS mlx
|
||||||
ARG CUDA13VERSION=13.0
|
ARG CUDA13VERSION=13.0
|
||||||
@@ -138,20 +152,27 @@ ENV PATH=/usr/local/cuda-13/bin:$PATH
|
|||||||
ENV BLAS_INCLUDE_DIRS=/usr/include/openblas
|
ENV BLAS_INCLUDE_DIRS=/usr/include/openblas
|
||||||
ENV LAPACK_INCLUDE_DIRS=/usr/include/openblas
|
ENV LAPACK_INCLUDE_DIRS=/usr/include/openblas
|
||||||
ENV CGO_LDFLAGS="-L/usr/local/cuda-13/lib64 -L/usr/local/cuda-13/targets/x86_64-linux/lib/stubs"
|
ENV CGO_LDFLAGS="-L/usr/local/cuda-13/lib64 -L/usr/local/cuda-13/targets/x86_64-linux/lib/stubs"
|
||||||
ARG PARALLEL
|
|
||||||
WORKDIR /go/src/github.com/ollama/ollama
|
WORKDIR /go/src/github.com/ollama/ollama
|
||||||
COPY CMakeLists.txt CMakePresets.json .
|
COPY CMakeLists.txt CMakePresets.json .
|
||||||
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
|
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
|
||||||
COPY x/imagegen/mlx x/imagegen/mlx
|
COPY x/imagegen/mlx x/imagegen/mlx
|
||||||
COPY go.mod go.sum .
|
COPY go.mod go.sum .
|
||||||
COPY MLX_VERSION .
|
COPY MLX_VERSION MLX_C_VERSION .
|
||||||
RUN curl -fsSL https://golang.org/dl/go$(awk '/^go/ { print $2 }' go.mod).linux-$(case $(uname -m) in x86_64) echo amd64 ;; aarch64) echo arm64 ;; esac).tar.gz | tar xz -C /usr/local
|
RUN curl -fsSL https://golang.org/dl/go$(awk '/^go/ { print $2 }' go.mod).linux-$(case $(uname -m) in x86_64) echo amd64 ;; aarch64) echo arm64 ;; esac).tar.gz | tar xz -C /usr/local
|
||||||
ENV PATH=/usr/local/go/bin:$PATH
|
ENV PATH=/usr/local/go/bin:$PATH
|
||||||
RUN go mod download
|
RUN go mod download
|
||||||
RUN --mount=type=cache,target=/root/.ccache \
|
RUN --mount=type=cache,target=/root/.ccache \
|
||||||
cmake --preset 'MLX CUDA 13' -DBLAS_INCLUDE_DIRS=/usr/include/openblas -DLAPACK_INCLUDE_DIRS=/usr/include/openblas \
|
--mount=type=bind,from=local-mlx,target=/tmp/local-mlx \
|
||||||
&& cmake --build --parallel ${PARALLEL} --preset 'MLX CUDA 13' \
|
--mount=type=bind,from=local-mlx-c,target=/tmp/local-mlx-c \
|
||||||
&& cmake --install build --component MLX --strip --parallel ${PARALLEL}
|
if [ -f /tmp/local-mlx/CMakeLists.txt ]; then \
|
||||||
|
export OLLAMA_MLX_SOURCE=/tmp/local-mlx; \
|
||||||
|
fi \
|
||||||
|
&& if [ -f /tmp/local-mlx-c/CMakeLists.txt ]; then \
|
||||||
|
export OLLAMA_MLX_C_SOURCE=/tmp/local-mlx-c; \
|
||||||
|
fi \
|
||||||
|
&& cmake --preset 'MLX CUDA 13' -DBLAS_INCLUDE_DIRS=/usr/include/openblas -DLAPACK_INCLUDE_DIRS=/usr/include/openblas \
|
||||||
|
&& cmake --build --preset 'MLX CUDA 13' -- -l $(nproc) \
|
||||||
|
&& cmake --install build --component MLX --strip
|
||||||
|
|
||||||
FROM base AS build
|
FROM base AS build
|
||||||
WORKDIR /go/src/github.com/ollama/ollama
|
WORKDIR /go/src/github.com/ollama/ollama
|
||||||
@@ -160,16 +181,14 @@ RUN curl -fsSL https://golang.org/dl/go$(awk '/^go/ { print $2 }' go.mod).linux-
|
|||||||
ENV PATH=/usr/local/go/bin:$PATH
|
ENV PATH=/usr/local/go/bin:$PATH
|
||||||
RUN go mod download
|
RUN go mod download
|
||||||
COPY . .
|
COPY . .
|
||||||
# Clone mlx-c headers for CGO (version from MLX_VERSION file)
|
|
||||||
RUN git clone --depth 1 --branch "$(cat MLX_VERSION)" https://github.com/ml-explore/mlx-c.git build/_deps/mlx-c-src
|
|
||||||
ARG GOFLAGS="'-ldflags=-w -s'"
|
ARG GOFLAGS="'-ldflags=-w -s'"
|
||||||
ENV CGO_ENABLED=1
|
ENV CGO_ENABLED=1
|
||||||
ARG CGO_CFLAGS
|
ARG CGO_CFLAGS
|
||||||
ARG CGO_CXXFLAGS
|
ARG CGO_CXXFLAGS
|
||||||
ENV CGO_CFLAGS="${CGO_CFLAGS} -I/go/src/github.com/ollama/ollama/build/_deps/mlx-c-src"
|
ENV CGO_CFLAGS="${CGO_CFLAGS}"
|
||||||
ENV CGO_CXXFLAGS="${CGO_CXXFLAGS}"
|
ENV CGO_CXXFLAGS="${CGO_CXXFLAGS}"
|
||||||
RUN --mount=type=cache,target=/root/.cache/go-build \
|
RUN --mount=type=cache,target=/root/.cache/go-build \
|
||||||
go build -tags mlx -trimpath -buildmode=pie -o /bin/ollama .
|
go build -trimpath -buildmode=pie -o /bin/ollama .
|
||||||
|
|
||||||
FROM --platform=linux/amd64 scratch AS amd64
|
FROM --platform=linux/amd64 scratch AS amd64
|
||||||
# COPY --from=cuda-11 dist/lib/ollama/ /lib/ollama/
|
# COPY --from=cuda-11 dist/lib/ollama/ /lib/ollama/
|
||||||
@@ -186,10 +205,9 @@ COPY --from=jetpack-5 dist/lib/ollama/ /lib/ollama/
|
|||||||
COPY --from=jetpack-6 dist/lib/ollama/ /lib/ollama/
|
COPY --from=jetpack-6 dist/lib/ollama/ /lib/ollama/
|
||||||
|
|
||||||
FROM scratch AS rocm
|
FROM scratch AS rocm
|
||||||
COPY --from=rocm-6 dist/lib/ollama /lib/ollama
|
COPY --from=rocm-7 dist/lib/ollama /lib/ollama
|
||||||
|
|
||||||
FROM ${FLAVOR} AS archive
|
FROM ${FLAVOR} AS archive
|
||||||
ARG VULKANVERSION
|
|
||||||
COPY --from=cpu dist/lib/ollama /lib/ollama
|
COPY --from=cpu dist/lib/ollama /lib/ollama
|
||||||
COPY --from=build /bin/ollama /bin/ollama
|
COPY --from=build /bin/ollama /bin/ollama
|
||||||
|
|
||||||
|
|||||||
1
MLX_C_VERSION
Normal file
@@ -0,0 +1 @@
|
|||||||
|
0726ca922fc902c4c61ef9c27d94132be418e945
|
||||||
@@ -1 +1 @@
|
|||||||
v0.5.0
|
38ad257088fb2193ad47e527cf6534a689f30943
|
||||||
|
|||||||
@@ -68,7 +68,7 @@ type MessagesRequest struct {
|
|||||||
Model string `json:"model"`
|
Model string `json:"model"`
|
||||||
MaxTokens int `json:"max_tokens"`
|
MaxTokens int `json:"max_tokens"`
|
||||||
Messages []MessageParam `json:"messages"`
|
Messages []MessageParam `json:"messages"`
|
||||||
System any `json:"system,omitempty"` // string or []ContentBlock
|
System any `json:"system,omitempty"` // string or []map[string]any (JSON-decoded ContentBlock)
|
||||||
Stream bool `json:"stream,omitempty"`
|
Stream bool `json:"stream,omitempty"`
|
||||||
Temperature *float64 `json:"temperature,omitempty"`
|
Temperature *float64 `json:"temperature,omitempty"`
|
||||||
TopP *float64 `json:"top_p,omitempty"`
|
TopP *float64 `json:"top_p,omitempty"`
|
||||||
@@ -82,8 +82,27 @@ type MessagesRequest struct {
|
|||||||
|
|
||||||
// MessageParam represents a message in the request
|
// MessageParam represents a message in the request
|
||||||
type MessageParam struct {
|
type MessageParam struct {
|
||||||
Role string `json:"role"` // "user" or "assistant"
|
Role string `json:"role"` // "user" or "assistant"
|
||||||
Content any `json:"content"` // string or []ContentBlock
|
Content []ContentBlock `json:"content"` // always []ContentBlock; plain strings are normalized on unmarshal
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MessageParam) UnmarshalJSON(data []byte) error {
|
||||||
|
var raw struct {
|
||||||
|
Role string `json:"role"`
|
||||||
|
Content json.RawMessage `json:"content"`
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(data, &raw); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
m.Role = raw.Role
|
||||||
|
|
||||||
|
var s string
|
||||||
|
if err := json.Unmarshal(raw.Content, &s); err == nil {
|
||||||
|
m.Content = []ContentBlock{{Type: "text", Text: &s}}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return json.Unmarshal(raw.Content, &m.Content)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ContentBlock represents a content block in a message.
|
// ContentBlock represents a content block in a message.
|
||||||
@@ -102,9 +121,9 @@ type ContentBlock struct {
|
|||||||
Source *ImageSource `json:"source,omitempty"`
|
Source *ImageSource `json:"source,omitempty"`
|
||||||
|
|
||||||
// For tool_use and server_tool_use blocks
|
// For tool_use and server_tool_use blocks
|
||||||
ID string `json:"id,omitempty"`
|
ID string `json:"id,omitempty"`
|
||||||
Name string `json:"name,omitempty"`
|
Name string `json:"name,omitempty"`
|
||||||
Input any `json:"input,omitempty"`
|
Input api.ToolCallFunctionArguments `json:"input,omitzero"`
|
||||||
|
|
||||||
// For tool_result and web_search_tool_result blocks
|
// For tool_result and web_search_tool_result blocks
|
||||||
ToolUseID string `json:"tool_use_id,omitempty"`
|
ToolUseID string `json:"tool_use_id,omitempty"`
|
||||||
@@ -377,178 +396,145 @@ func convertMessage(msg MessageParam) ([]api.Message, error) {
|
|||||||
var messages []api.Message
|
var messages []api.Message
|
||||||
role := strings.ToLower(msg.Role)
|
role := strings.ToLower(msg.Role)
|
||||||
|
|
||||||
switch content := msg.Content.(type) {
|
var textContent strings.Builder
|
||||||
case string:
|
var images []api.ImageData
|
||||||
messages = append(messages, api.Message{Role: role, Content: content})
|
var toolCalls []api.ToolCall
|
||||||
|
var thinking string
|
||||||
|
var toolResults []api.Message
|
||||||
|
textBlocks := 0
|
||||||
|
imageBlocks := 0
|
||||||
|
toolUseBlocks := 0
|
||||||
|
toolResultBlocks := 0
|
||||||
|
serverToolUseBlocks := 0
|
||||||
|
webSearchToolResultBlocks := 0
|
||||||
|
thinkingBlocks := 0
|
||||||
|
unknownBlocks := 0
|
||||||
|
|
||||||
case []any:
|
for _, block := range msg.Content {
|
||||||
var textContent strings.Builder
|
switch block.Type {
|
||||||
var images []api.ImageData
|
case "text":
|
||||||
var toolCalls []api.ToolCall
|
textBlocks++
|
||||||
var thinking string
|
if block.Text != nil {
|
||||||
var toolResults []api.Message
|
textContent.WriteString(*block.Text)
|
||||||
textBlocks := 0
|
|
||||||
imageBlocks := 0
|
|
||||||
toolUseBlocks := 0
|
|
||||||
toolResultBlocks := 0
|
|
||||||
serverToolUseBlocks := 0
|
|
||||||
webSearchToolResultBlocks := 0
|
|
||||||
thinkingBlocks := 0
|
|
||||||
unknownBlocks := 0
|
|
||||||
|
|
||||||
for _, block := range content {
|
|
||||||
blockMap, ok := block.(map[string]any)
|
|
||||||
if !ok {
|
|
||||||
logutil.Trace("anthropic: invalid content block format", "role", role)
|
|
||||||
return nil, errors.New("invalid content block format")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
blockType, _ := blockMap["type"].(string)
|
case "image":
|
||||||
|
imageBlocks++
|
||||||
|
if block.Source == nil {
|
||||||
|
logutil.Trace("anthropic: invalid image source", "role", role)
|
||||||
|
return nil, errors.New("invalid image source")
|
||||||
|
}
|
||||||
|
|
||||||
switch blockType {
|
if block.Source.Type == "base64" {
|
||||||
case "text":
|
decoded, err := base64.StdEncoding.DecodeString(block.Source.Data)
|
||||||
textBlocks++
|
if err != nil {
|
||||||
if text, ok := blockMap["text"].(string); ok {
|
logutil.Trace("anthropic: invalid base64 image data", "role", role, "error", err)
|
||||||
textContent.WriteString(text)
|
return nil, fmt.Errorf("invalid base64 image data: %w", err)
|
||||||
}
|
}
|
||||||
|
images = append(images, decoded)
|
||||||
|
} else {
|
||||||
|
logutil.Trace("anthropic: unsupported image source type", "role", role, "source_type", block.Source.Type)
|
||||||
|
return nil, fmt.Errorf("invalid image source type: %s. Only base64 images are supported.", block.Source.Type)
|
||||||
|
}
|
||||||
|
|
||||||
case "image":
|
case "tool_use":
|
||||||
imageBlocks++
|
toolUseBlocks++
|
||||||
source, ok := blockMap["source"].(map[string]any)
|
if block.ID == "" {
|
||||||
if !ok {
|
logutil.Trace("anthropic: tool_use block missing id", "role", role)
|
||||||
logutil.Trace("anthropic: invalid image source", "role", role)
|
return nil, errors.New("tool_use block missing required 'id' field")
|
||||||
return nil, errors.New("invalid image source")
|
}
|
||||||
}
|
if block.Name == "" {
|
||||||
|
logutil.Trace("anthropic: tool_use block missing name", "role", role)
|
||||||
|
return nil, errors.New("tool_use block missing required 'name' field")
|
||||||
|
}
|
||||||
|
toolCalls = append(toolCalls, api.ToolCall{
|
||||||
|
ID: block.ID,
|
||||||
|
Function: api.ToolCallFunction{
|
||||||
|
Name: block.Name,
|
||||||
|
Arguments: block.Input,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
sourceType, _ := source["type"].(string)
|
case "tool_result":
|
||||||
if sourceType == "base64" {
|
toolResultBlocks++
|
||||||
data, _ := source["data"].(string)
|
var resultContent string
|
||||||
decoded, err := base64.StdEncoding.DecodeString(data)
|
|
||||||
if err != nil {
|
|
||||||
logutil.Trace("anthropic: invalid base64 image data", "role", role, "error", err)
|
|
||||||
return nil, fmt.Errorf("invalid base64 image data: %w", err)
|
|
||||||
}
|
|
||||||
images = append(images, decoded)
|
|
||||||
} else {
|
|
||||||
logutil.Trace("anthropic: unsupported image source type", "role", role, "source_type", sourceType)
|
|
||||||
return nil, fmt.Errorf("invalid image source type: %s. Only base64 images are supported.", sourceType)
|
|
||||||
}
|
|
||||||
// URL images would need to be fetched - skip for now
|
|
||||||
|
|
||||||
case "tool_use":
|
switch c := block.Content.(type) {
|
||||||
toolUseBlocks++
|
case string:
|
||||||
id, ok := blockMap["id"].(string)
|
resultContent = c
|
||||||
if !ok {
|
case []any:
|
||||||
logutil.Trace("anthropic: tool_use block missing id", "role", role)
|
for _, cb := range c {
|
||||||
return nil, errors.New("tool_use block missing required 'id' field")
|
if cbMap, ok := cb.(map[string]any); ok {
|
||||||
}
|
if cbMap["type"] == "text" {
|
||||||
name, ok := blockMap["name"].(string)
|
if text, ok := cbMap["text"].(string); ok {
|
||||||
if !ok {
|
resultContent += text
|
||||||
logutil.Trace("anthropic: tool_use block missing name", "role", role)
|
|
||||||
return nil, errors.New("tool_use block missing required 'name' field")
|
|
||||||
}
|
|
||||||
tc := api.ToolCall{
|
|
||||||
ID: id,
|
|
||||||
Function: api.ToolCallFunction{
|
|
||||||
Name: name,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
if input, ok := blockMap["input"].(map[string]any); ok {
|
|
||||||
tc.Function.Arguments = mapToArgs(input)
|
|
||||||
}
|
|
||||||
toolCalls = append(toolCalls, tc)
|
|
||||||
|
|
||||||
case "tool_result":
|
|
||||||
toolResultBlocks++
|
|
||||||
toolUseID, _ := blockMap["tool_use_id"].(string)
|
|
||||||
var resultContent string
|
|
||||||
|
|
||||||
switch c := blockMap["content"].(type) {
|
|
||||||
case string:
|
|
||||||
resultContent = c
|
|
||||||
case []any:
|
|
||||||
for _, cb := range c {
|
|
||||||
if cbMap, ok := cb.(map[string]any); ok {
|
|
||||||
if cbMap["type"] == "text" {
|
|
||||||
if text, ok := cbMap["text"].(string); ok {
|
|
||||||
resultContent += text
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
toolResults = append(toolResults, api.Message{
|
|
||||||
Role: "tool",
|
|
||||||
Content: resultContent,
|
|
||||||
ToolCallID: toolUseID,
|
|
||||||
})
|
|
||||||
|
|
||||||
case "thinking":
|
|
||||||
thinkingBlocks++
|
|
||||||
if t, ok := blockMap["thinking"].(string); ok {
|
|
||||||
thinking = t
|
|
||||||
}
|
|
||||||
|
|
||||||
case "server_tool_use":
|
|
||||||
serverToolUseBlocks++
|
|
||||||
id, _ := blockMap["id"].(string)
|
|
||||||
name, _ := blockMap["name"].(string)
|
|
||||||
tc := api.ToolCall{
|
|
||||||
ID: id,
|
|
||||||
Function: api.ToolCallFunction{
|
|
||||||
Name: name,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
if input, ok := blockMap["input"].(map[string]any); ok {
|
|
||||||
tc.Function.Arguments = mapToArgs(input)
|
|
||||||
}
|
|
||||||
toolCalls = append(toolCalls, tc)
|
|
||||||
|
|
||||||
case "web_search_tool_result":
|
|
||||||
webSearchToolResultBlocks++
|
|
||||||
toolUseID, _ := blockMap["tool_use_id"].(string)
|
|
||||||
toolResults = append(toolResults, api.Message{
|
|
||||||
Role: "tool",
|
|
||||||
Content: formatWebSearchToolResultContent(blockMap["content"]),
|
|
||||||
ToolCallID: toolUseID,
|
|
||||||
})
|
|
||||||
default:
|
|
||||||
unknownBlocks++
|
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
if textContent.Len() > 0 || len(images) > 0 || len(toolCalls) > 0 || thinking != "" {
|
toolResults = append(toolResults, api.Message{
|
||||||
m := api.Message{
|
Role: "tool",
|
||||||
Role: role,
|
Content: resultContent,
|
||||||
Content: textContent.String(),
|
ToolCallID: block.ToolUseID,
|
||||||
Images: images,
|
})
|
||||||
ToolCalls: toolCalls,
|
|
||||||
Thinking: thinking,
|
case "thinking":
|
||||||
|
thinkingBlocks++
|
||||||
|
if block.Thinking != nil {
|
||||||
|
thinking = *block.Thinking
|
||||||
}
|
}
|
||||||
messages = append(messages, m)
|
|
||||||
|
case "server_tool_use":
|
||||||
|
serverToolUseBlocks++
|
||||||
|
toolCalls = append(toolCalls, api.ToolCall{
|
||||||
|
ID: block.ID,
|
||||||
|
Function: api.ToolCallFunction{
|
||||||
|
Name: block.Name,
|
||||||
|
Arguments: block.Input,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
case "web_search_tool_result":
|
||||||
|
webSearchToolResultBlocks++
|
||||||
|
toolResults = append(toolResults, api.Message{
|
||||||
|
Role: "tool",
|
||||||
|
Content: formatWebSearchToolResultContent(block.Content),
|
||||||
|
ToolCallID: block.ToolUseID,
|
||||||
|
})
|
||||||
|
default:
|
||||||
|
unknownBlocks++
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add tool results as separate messages
|
|
||||||
messages = append(messages, toolResults...)
|
|
||||||
logutil.Trace("anthropic: converted block message",
|
|
||||||
"role", role,
|
|
||||||
"blocks", len(content),
|
|
||||||
"text", textBlocks,
|
|
||||||
"image", imageBlocks,
|
|
||||||
"tool_use", toolUseBlocks,
|
|
||||||
"tool_result", toolResultBlocks,
|
|
||||||
"server_tool_use", serverToolUseBlocks,
|
|
||||||
"web_search_result", webSearchToolResultBlocks,
|
|
||||||
"thinking", thinkingBlocks,
|
|
||||||
"unknown", unknownBlocks,
|
|
||||||
"messages", TraceAPIMessages(messages),
|
|
||||||
)
|
|
||||||
|
|
||||||
default:
|
|
||||||
return nil, fmt.Errorf("invalid message content type: %T", content)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if textContent.Len() > 0 || len(images) > 0 || len(toolCalls) > 0 || thinking != "" {
|
||||||
|
m := api.Message{
|
||||||
|
Role: role,
|
||||||
|
Content: textContent.String(),
|
||||||
|
Images: images,
|
||||||
|
ToolCalls: toolCalls,
|
||||||
|
Thinking: thinking,
|
||||||
|
}
|
||||||
|
messages = append(messages, m)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add tool results as separate messages
|
||||||
|
messages = append(messages, toolResults...)
|
||||||
|
logutil.Trace("anthropic: converted block message",
|
||||||
|
"role", role,
|
||||||
|
"blocks", len(msg.Content),
|
||||||
|
"text", textBlocks,
|
||||||
|
"image", imageBlocks,
|
||||||
|
"tool_use", toolUseBlocks,
|
||||||
|
"tool_result", toolResultBlocks,
|
||||||
|
"server_tool_use", serverToolUseBlocks,
|
||||||
|
"web_search_result", webSearchToolResultBlocks,
|
||||||
|
"thinking", thinkingBlocks,
|
||||||
|
"unknown", unknownBlocks,
|
||||||
|
"messages", TraceAPIMessages(messages),
|
||||||
|
)
|
||||||
|
|
||||||
return messages, nil
|
return messages, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -852,6 +838,19 @@ func (c *StreamConverter) Process(r api.ChatResponse) []StreamEvent {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Close thinking block if still open (thinking → tool_use without text in between)
|
||||||
|
if c.thinkingStarted && !c.thinkingDone {
|
||||||
|
c.thinkingDone = true
|
||||||
|
events = append(events, StreamEvent{
|
||||||
|
Event: "content_block_stop",
|
||||||
|
Data: ContentBlockStopEvent{
|
||||||
|
Type: "content_block_stop",
|
||||||
|
Index: c.contentIndex,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
c.contentIndex++
|
||||||
|
}
|
||||||
|
|
||||||
if c.textStarted {
|
if c.textStarted {
|
||||||
events = append(events, StreamEvent{
|
events = append(events, StreamEvent{
|
||||||
Event: "content_block_stop",
|
Event: "content_block_stop",
|
||||||
@@ -869,7 +868,6 @@ func (c *StreamConverter) Process(r api.ChatResponse) []StreamEvent {
|
|||||||
slog.Error("failed to marshal tool arguments", "error", err, "tool_id", tc.ID)
|
slog.Error("failed to marshal tool arguments", "error", err, "tool_id", tc.ID)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
events = append(events, StreamEvent{
|
events = append(events, StreamEvent{
|
||||||
Event: "content_block_start",
|
Event: "content_block_start",
|
||||||
Data: ContentBlockStartEvent{
|
Data: ContentBlockStartEvent{
|
||||||
@@ -879,7 +877,7 @@ func (c *StreamConverter) Process(r api.ChatResponse) []StreamEvent {
|
|||||||
Type: "tool_use",
|
Type: "tool_use",
|
||||||
ID: tc.ID,
|
ID: tc.ID,
|
||||||
Name: tc.Function.Name,
|
Name: tc.Function.Name,
|
||||||
Input: map[string]any{},
|
Input: api.NewToolCallFunctionArguments(),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
@@ -976,15 +974,6 @@ func ptr(s string) *string {
|
|||||||
return &s
|
return &s
|
||||||
}
|
}
|
||||||
|
|
||||||
// mapToArgs converts a map to ToolCallFunctionArguments
|
|
||||||
func mapToArgs(m map[string]any) api.ToolCallFunctionArguments {
|
|
||||||
args := api.NewToolCallFunctionArguments()
|
|
||||||
for k, v := range m {
|
|
||||||
args.Set(k, v)
|
|
||||||
}
|
|
||||||
return args
|
|
||||||
}
|
|
||||||
|
|
||||||
// CountTokensRequest represents an Anthropic count_tokens request
|
// CountTokensRequest represents an Anthropic count_tokens request
|
||||||
type CountTokensRequest struct {
|
type CountTokensRequest struct {
|
||||||
Model string `json:"model"`
|
Model string `json:"model"`
|
||||||
@@ -1017,17 +1006,13 @@ func estimateTokens(req CountTokensRequest) int {
|
|||||||
var totalLen int
|
var totalLen int
|
||||||
|
|
||||||
// Count system prompt
|
// Count system prompt
|
||||||
if req.System != nil {
|
totalLen += countAnyContent(req.System)
|
||||||
totalLen += countAnyContent(req.System)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Count messages
|
|
||||||
for _, msg := range req.Messages {
|
for _, msg := range req.Messages {
|
||||||
// Count role (always present)
|
// Count role (always present)
|
||||||
totalLen += len(msg.Role)
|
totalLen += len(msg.Role)
|
||||||
// Count content
|
// Count content
|
||||||
contentLen := countAnyContent(msg.Content)
|
totalLen += countAnyContent(msg.Content)
|
||||||
totalLen += contentLen
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tool := range req.Tools {
|
for _, tool := range req.Tools {
|
||||||
@@ -1050,12 +1035,25 @@ func countAnyContent(content any) int {
|
|||||||
switch c := content.(type) {
|
switch c := content.(type) {
|
||||||
case string:
|
case string:
|
||||||
return len(c)
|
return len(c)
|
||||||
case []any:
|
case []ContentBlock:
|
||||||
total := 0
|
total := 0
|
||||||
for _, block := range c {
|
for _, block := range c {
|
||||||
total += countContentBlock(block)
|
total += countContentBlock(block)
|
||||||
}
|
}
|
||||||
return total
|
return total
|
||||||
|
case []any:
|
||||||
|
total := 0
|
||||||
|
for _, item := range c {
|
||||||
|
data, err := json.Marshal(item)
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
var block ContentBlock
|
||||||
|
if err := json.Unmarshal(data, &block); err == nil {
|
||||||
|
total += countContentBlock(block)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return total
|
||||||
default:
|
default:
|
||||||
if data, err := json.Marshal(content); err == nil {
|
if data, err := json.Marshal(content); err == nil {
|
||||||
return len(data)
|
return len(data)
|
||||||
@@ -1064,38 +1062,19 @@ func countAnyContent(content any) int {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func countContentBlock(block any) int {
|
func countContentBlock(block ContentBlock) int {
|
||||||
blockMap, ok := block.(map[string]any)
|
|
||||||
if !ok {
|
|
||||||
if s, ok := block.(string); ok {
|
|
||||||
return len(s)
|
|
||||||
}
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
|
|
||||||
total := 0
|
total := 0
|
||||||
blockType, _ := blockMap["type"].(string)
|
if block.Text != nil {
|
||||||
|
total += len(*block.Text)
|
||||||
if text, ok := blockMap["text"].(string); ok {
|
|
||||||
total += len(text)
|
|
||||||
}
|
}
|
||||||
|
if block.Thinking != nil {
|
||||||
if thinking, ok := blockMap["thinking"].(string); ok {
|
total += len(*block.Thinking)
|
||||||
total += len(thinking)
|
|
||||||
}
|
}
|
||||||
|
if block.Type == "tool_use" || block.Type == "tool_result" {
|
||||||
if blockType == "tool_use" {
|
if data, err := json.Marshal(block); err == nil {
|
||||||
if data, err := json.Marshal(blockMap); err == nil {
|
|
||||||
total += len(data)
|
total += len(data)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if blockType == "tool_result" {
|
|
||||||
if data, err := json.Marshal(blockMap); err == nil {
|
|
||||||
total += len(data)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return total
|
return total
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -15,11 +15,16 @@ const (
|
|||||||
testImage = `iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=`
|
testImage = `iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=`
|
||||||
)
|
)
|
||||||
|
|
||||||
// testArgs creates ToolCallFunctionArguments from a map (convenience function for tests)
|
// textContent is a convenience for constructing []ContentBlock with a single text block in tests.
|
||||||
func testArgs(m map[string]any) api.ToolCallFunctionArguments {
|
func textContent(s string) []ContentBlock {
|
||||||
|
return []ContentBlock{{Type: "text", Text: &s}}
|
||||||
|
}
|
||||||
|
|
||||||
|
// makeArgs creates ToolCallFunctionArguments from key-value pairs (convenience function for tests)
|
||||||
|
func makeArgs(kvs ...any) api.ToolCallFunctionArguments {
|
||||||
args := api.NewToolCallFunctionArguments()
|
args := api.NewToolCallFunctionArguments()
|
||||||
for k, v := range m {
|
for i := 0; i < len(kvs)-1; i += 2 {
|
||||||
args.Set(k, v)
|
args.Set(kvs[i].(string), kvs[i+1])
|
||||||
}
|
}
|
||||||
return args
|
return args
|
||||||
}
|
}
|
||||||
@@ -29,7 +34,7 @@ func TestFromMessagesRequest_Basic(t *testing.T) {
|
|||||||
Model: "test-model",
|
Model: "test-model",
|
||||||
MaxTokens: 1024,
|
MaxTokens: 1024,
|
||||||
Messages: []MessageParam{
|
Messages: []MessageParam{
|
||||||
{Role: "user", Content: "Hello"},
|
{Role: "user", Content: textContent("Hello")},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -61,7 +66,7 @@ func TestFromMessagesRequest_WithSystemPrompt(t *testing.T) {
|
|||||||
MaxTokens: 1024,
|
MaxTokens: 1024,
|
||||||
System: "You are a helpful assistant.",
|
System: "You are a helpful assistant.",
|
||||||
Messages: []MessageParam{
|
Messages: []MessageParam{
|
||||||
{Role: "user", Content: "Hello"},
|
{Role: "user", Content: textContent("Hello")},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -88,7 +93,7 @@ func TestFromMessagesRequest_WithSystemPromptArray(t *testing.T) {
|
|||||||
map[string]any{"type": "text", "text": " Be concise."},
|
map[string]any{"type": "text", "text": " Be concise."},
|
||||||
},
|
},
|
||||||
Messages: []MessageParam{
|
Messages: []MessageParam{
|
||||||
{Role: "user", Content: "Hello"},
|
{Role: "user", Content: textContent("Hello")},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -113,7 +118,7 @@ func TestFromMessagesRequest_WithOptions(t *testing.T) {
|
|||||||
req := MessagesRequest{
|
req := MessagesRequest{
|
||||||
Model: "test-model",
|
Model: "test-model",
|
||||||
MaxTokens: 2048,
|
MaxTokens: 2048,
|
||||||
Messages: []MessageParam{{Role: "user", Content: "Hello"}},
|
Messages: []MessageParam{{Role: "user", Content: textContent("Hello")}},
|
||||||
Temperature: &temp,
|
Temperature: &temp,
|
||||||
TopP: &topP,
|
TopP: &topP,
|
||||||
TopK: &topK,
|
TopK: &topK,
|
||||||
@@ -148,14 +153,14 @@ func TestFromMessagesRequest_WithImage(t *testing.T) {
|
|||||||
Messages: []MessageParam{
|
Messages: []MessageParam{
|
||||||
{
|
{
|
||||||
Role: "user",
|
Role: "user",
|
||||||
Content: []any{
|
Content: []ContentBlock{
|
||||||
map[string]any{"type": "text", "text": "What's in this image?"},
|
{Type: "text", Text: ptr("What's in this image?")},
|
||||||
map[string]any{
|
{
|
||||||
"type": "image",
|
Type: "image",
|
||||||
"source": map[string]any{
|
Source: &ImageSource{
|
||||||
"type": "base64",
|
Type: "base64",
|
||||||
"media_type": "image/png",
|
MediaType: "image/png",
|
||||||
"data": testImage,
|
Data: testImage,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -190,15 +195,15 @@ func TestFromMessagesRequest_WithToolUse(t *testing.T) {
|
|||||||
Model: "test-model",
|
Model: "test-model",
|
||||||
MaxTokens: 1024,
|
MaxTokens: 1024,
|
||||||
Messages: []MessageParam{
|
Messages: []MessageParam{
|
||||||
{Role: "user", Content: "What's the weather in Paris?"},
|
{Role: "user", Content: textContent("What's the weather in Paris?")},
|
||||||
{
|
{
|
||||||
Role: "assistant",
|
Role: "assistant",
|
||||||
Content: []any{
|
Content: []ContentBlock{
|
||||||
map[string]any{
|
{
|
||||||
"type": "tool_use",
|
Type: "tool_use",
|
||||||
"id": "call_123",
|
ID: "call_123",
|
||||||
"name": "get_weather",
|
Name: "get_weather",
|
||||||
"input": map[string]any{"location": "Paris"},
|
Input: makeArgs("location", "Paris"),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -234,11 +239,11 @@ func TestFromMessagesRequest_WithToolResult(t *testing.T) {
|
|||||||
Messages: []MessageParam{
|
Messages: []MessageParam{
|
||||||
{
|
{
|
||||||
Role: "user",
|
Role: "user",
|
||||||
Content: []any{
|
Content: []ContentBlock{
|
||||||
map[string]any{
|
{
|
||||||
"type": "tool_result",
|
Type: "tool_result",
|
||||||
"tool_use_id": "call_123",
|
ToolUseID: "call_123",
|
||||||
"content": "The weather in Paris is sunny, 22°C",
|
Content: "The weather in Paris is sunny, 22°C",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -270,7 +275,7 @@ func TestFromMessagesRequest_WithTools(t *testing.T) {
|
|||||||
req := MessagesRequest{
|
req := MessagesRequest{
|
||||||
Model: "test-model",
|
Model: "test-model",
|
||||||
MaxTokens: 1024,
|
MaxTokens: 1024,
|
||||||
Messages: []MessageParam{{Role: "user", Content: "Hello"}},
|
Messages: []MessageParam{{Role: "user", Content: textContent("Hello")}},
|
||||||
Tools: []Tool{
|
Tools: []Tool{
|
||||||
{
|
{
|
||||||
Name: "get_weather",
|
Name: "get_weather",
|
||||||
@@ -305,7 +310,7 @@ func TestFromMessagesRequest_DropsCustomWebSearchWhenBuiltinPresent(t *testing.T
|
|||||||
req := MessagesRequest{
|
req := MessagesRequest{
|
||||||
Model: "test-model",
|
Model: "test-model",
|
||||||
MaxTokens: 1024,
|
MaxTokens: 1024,
|
||||||
Messages: []MessageParam{{Role: "user", Content: "Hello"}},
|
Messages: []MessageParam{{Role: "user", Content: textContent("Hello")}},
|
||||||
Tools: []Tool{
|
Tools: []Tool{
|
||||||
{
|
{
|
||||||
Type: "web_search_20250305",
|
Type: "web_search_20250305",
|
||||||
@@ -346,7 +351,7 @@ func TestFromMessagesRequest_KeepsCustomWebSearchWhenBuiltinAbsent(t *testing.T)
|
|||||||
req := MessagesRequest{
|
req := MessagesRequest{
|
||||||
Model: "test-model",
|
Model: "test-model",
|
||||||
MaxTokens: 1024,
|
MaxTokens: 1024,
|
||||||
Messages: []MessageParam{{Role: "user", Content: "Hello"}},
|
Messages: []MessageParam{{Role: "user", Content: textContent("Hello")}},
|
||||||
Tools: []Tool{
|
Tools: []Tool{
|
||||||
{
|
{
|
||||||
Type: "custom",
|
Type: "custom",
|
||||||
@@ -377,7 +382,7 @@ func TestFromMessagesRequest_WithThinking(t *testing.T) {
|
|||||||
req := MessagesRequest{
|
req := MessagesRequest{
|
||||||
Model: "test-model",
|
Model: "test-model",
|
||||||
MaxTokens: 1024,
|
MaxTokens: 1024,
|
||||||
Messages: []MessageParam{{Role: "user", Content: "Hello"}},
|
Messages: []MessageParam{{Role: "user", Content: textContent("Hello")}},
|
||||||
Thinking: &ThinkingConfig{Type: "enabled", BudgetTokens: 1000},
|
Thinking: &ThinkingConfig{Type: "enabled", BudgetTokens: 1000},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -399,13 +404,13 @@ func TestFromMessagesRequest_ThinkingOnlyBlock(t *testing.T) {
|
|||||||
Model: "test-model",
|
Model: "test-model",
|
||||||
MaxTokens: 1024,
|
MaxTokens: 1024,
|
||||||
Messages: []MessageParam{
|
Messages: []MessageParam{
|
||||||
{Role: "user", Content: "Hello"},
|
{Role: "user", Content: textContent("Hello")},
|
||||||
{
|
{
|
||||||
Role: "assistant",
|
Role: "assistant",
|
||||||
Content: []any{
|
Content: []ContentBlock{
|
||||||
map[string]any{
|
{
|
||||||
"type": "thinking",
|
Type: "thinking",
|
||||||
"thinking": "Let me think about this...",
|
Thinking: ptr("Let me think about this..."),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -434,10 +439,10 @@ func TestFromMessagesRequest_ToolUseMissingID(t *testing.T) {
|
|||||||
Messages: []MessageParam{
|
Messages: []MessageParam{
|
||||||
{
|
{
|
||||||
Role: "assistant",
|
Role: "assistant",
|
||||||
Content: []any{
|
Content: []ContentBlock{
|
||||||
map[string]any{
|
{
|
||||||
"type": "tool_use",
|
Type: "tool_use",
|
||||||
"name": "get_weather",
|
Name: "get_weather",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -460,10 +465,10 @@ func TestFromMessagesRequest_ToolUseMissingName(t *testing.T) {
|
|||||||
Messages: []MessageParam{
|
Messages: []MessageParam{
|
||||||
{
|
{
|
||||||
Role: "assistant",
|
Role: "assistant",
|
||||||
Content: []any{
|
Content: []ContentBlock{
|
||||||
map[string]any{
|
{
|
||||||
"type": "tool_use",
|
Type: "tool_use",
|
||||||
"id": "call_123",
|
ID: "call_123",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -483,7 +488,7 @@ func TestFromMessagesRequest_InvalidToolSchema(t *testing.T) {
|
|||||||
req := MessagesRequest{
|
req := MessagesRequest{
|
||||||
Model: "test-model",
|
Model: "test-model",
|
||||||
MaxTokens: 1024,
|
MaxTokens: 1024,
|
||||||
Messages: []MessageParam{{Role: "user", Content: "Hello"}},
|
Messages: []MessageParam{{Role: "user", Content: textContent("Hello")}},
|
||||||
Tools: []Tool{
|
Tools: []Tool{
|
||||||
{
|
{
|
||||||
Name: "bad_tool",
|
Name: "bad_tool",
|
||||||
@@ -548,7 +553,7 @@ func TestToMessagesResponse_WithToolCalls(t *testing.T) {
|
|||||||
ID: "call_123",
|
ID: "call_123",
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "get_weather",
|
Name: "get_weather",
|
||||||
Arguments: testArgs(map[string]any{"location": "Paris"}),
|
Arguments: makeArgs("location", "Paris"),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -760,7 +765,7 @@ func TestStreamConverter_WithToolCalls(t *testing.T) {
|
|||||||
ID: "call_123",
|
ID: "call_123",
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "get_weather",
|
Name: "get_weather",
|
||||||
Arguments: testArgs(map[string]any{"location": "Paris"}),
|
Arguments: makeArgs("location", "Paris"),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -799,6 +804,107 @@ func TestStreamConverter_WithToolCalls(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestStreamConverter_ThinkingDirectlyFollowedByToolCall verifies that when a
|
||||||
|
// model emits a thinking block followed directly by a tool_use block (with no
|
||||||
|
// text block in between), the streaming converter correctly closes the thinking
|
||||||
|
// block and increments the content index before opening the tool_use block.
|
||||||
|
// Previously, the converter reused contentIndex=0 for the tool_use block,
|
||||||
|
// which caused "Content block not found" errors in clients. See #14816.
|
||||||
|
func TestStreamConverter_ThinkingDirectlyFollowedByToolCall(t *testing.T) {
|
||||||
|
conv := NewStreamConverter("msg_123", "test-model", 0)
|
||||||
|
|
||||||
|
// First chunk: thinking content (no text)
|
||||||
|
resp1 := api.ChatResponse{
|
||||||
|
Model: "test-model",
|
||||||
|
Message: api.Message{
|
||||||
|
Role: "assistant",
|
||||||
|
Thinking: "I should call the tool.",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
events1 := conv.Process(resp1)
|
||||||
|
|
||||||
|
// Should have: message_start, content_block_start(thinking), content_block_delta(thinking)
|
||||||
|
if len(events1) < 3 {
|
||||||
|
t.Fatalf("expected at least 3 events for thinking chunk, got %d", len(events1))
|
||||||
|
}
|
||||||
|
if events1[0].Event != "message_start" {
|
||||||
|
t.Errorf("expected first event 'message_start', got %q", events1[0].Event)
|
||||||
|
}
|
||||||
|
thinkingStart, ok := events1[1].Data.(ContentBlockStartEvent)
|
||||||
|
if !ok || thinkingStart.ContentBlock.Type != "thinking" {
|
||||||
|
t.Errorf("expected content_block_start(thinking) as second event, got %+v", events1[1])
|
||||||
|
}
|
||||||
|
if thinkingStart.Index != 0 {
|
||||||
|
t.Errorf("expected thinking block at index 0, got %d", thinkingStart.Index)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Second chunk: tool call (no text between thinking and tool)
|
||||||
|
resp2 := api.ChatResponse{
|
||||||
|
Model: "test-model",
|
||||||
|
Message: api.Message{
|
||||||
|
Role: "assistant",
|
||||||
|
ToolCalls: []api.ToolCall{
|
||||||
|
{
|
||||||
|
ID: "call_abc",
|
||||||
|
Function: api.ToolCallFunction{
|
||||||
|
Name: "ask_user",
|
||||||
|
Arguments: makeArgs("question", "cats or dogs?"),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Done: true,
|
||||||
|
DoneReason: "stop",
|
||||||
|
Metrics: api.Metrics{PromptEvalCount: 10, EvalCount: 5},
|
||||||
|
}
|
||||||
|
events2 := conv.Process(resp2)
|
||||||
|
|
||||||
|
// Expect: content_block_stop(index=0), content_block_start(tool_use, index=1),
|
||||||
|
// content_block_delta(input_json_delta, index=1), content_block_stop(index=1),
|
||||||
|
// message_delta, message_stop
|
||||||
|
var thinkingStop, toolStart, toolDelta, toolStop *StreamEvent
|
||||||
|
for i := range events2 {
|
||||||
|
e := &events2[i]
|
||||||
|
switch e.Event {
|
||||||
|
case "content_block_stop":
|
||||||
|
if stop, ok := e.Data.(ContentBlockStopEvent); ok {
|
||||||
|
if stop.Index == 0 && thinkingStop == nil {
|
||||||
|
thinkingStop = e
|
||||||
|
} else if stop.Index == 1 {
|
||||||
|
toolStop = e
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case "content_block_start":
|
||||||
|
if start, ok := e.Data.(ContentBlockStartEvent); ok && start.ContentBlock.Type == "tool_use" {
|
||||||
|
toolStart = e
|
||||||
|
}
|
||||||
|
case "content_block_delta":
|
||||||
|
if delta, ok := e.Data.(ContentBlockDeltaEvent); ok && delta.Delta.Type == "input_json_delta" {
|
||||||
|
toolDelta = e
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if thinkingStop == nil {
|
||||||
|
t.Error("expected content_block_stop for thinking block (index 0)")
|
||||||
|
}
|
||||||
|
if toolStart == nil {
|
||||||
|
t.Fatal("expected content_block_start for tool_use block")
|
||||||
|
}
|
||||||
|
if start, ok := toolStart.Data.(ContentBlockStartEvent); !ok || start.Index != 1 {
|
||||||
|
t.Errorf("expected tool_use block at index 1, got %+v", toolStart.Data)
|
||||||
|
}
|
||||||
|
if toolDelta == nil {
|
||||||
|
t.Fatal("expected input_json_delta event for tool call")
|
||||||
|
}
|
||||||
|
if delta, ok := toolDelta.Data.(ContentBlockDeltaEvent); !ok || delta.Index != 1 {
|
||||||
|
t.Errorf("expected tool delta at index 1, got %+v", toolDelta.Data)
|
||||||
|
}
|
||||||
|
if toolStop == nil {
|
||||||
|
t.Error("expected content_block_stop for tool_use block (index 1)")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestStreamConverter_ToolCallWithUnmarshalableArgs(t *testing.T) {
|
func TestStreamConverter_ToolCallWithUnmarshalableArgs(t *testing.T) {
|
||||||
// Test that unmarshalable arguments (like channels) are handled gracefully
|
// Test that unmarshalable arguments (like channels) are handled gracefully
|
||||||
// and don't cause a panic or corrupt stream
|
// and don't cause a panic or corrupt stream
|
||||||
@@ -864,7 +970,7 @@ func TestStreamConverter_MultipleToolCallsWithMixedValidity(t *testing.T) {
|
|||||||
ID: "call_good",
|
ID: "call_good",
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "good_function",
|
Name: "good_function",
|
||||||
Arguments: testArgs(map[string]any{"location": "Paris"}),
|
Arguments: makeArgs("location", "Paris"),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -966,6 +1072,57 @@ func TestContentBlockJSON_EmptyFieldsPresent(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestContentBlockJSON_NonToolBlocksDoNotIncludeInput(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
block ContentBlock
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "text block",
|
||||||
|
block: ContentBlock{
|
||||||
|
Type: "text",
|
||||||
|
Text: ptr("hello"),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "thinking block",
|
||||||
|
block: ContentBlock{
|
||||||
|
Type: "thinking",
|
||||||
|
Thinking: ptr("let me think"),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "image block",
|
||||||
|
block: ContentBlock{
|
||||||
|
Type: "image",
|
||||||
|
Source: &ImageSource{
|
||||||
|
Type: "base64",
|
||||||
|
MediaType: "image/png",
|
||||||
|
Data: testImage,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
data, err := json.Marshal(tt.block)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to marshal: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var result map[string]any
|
||||||
|
if err := json.Unmarshal(data, &result); err != nil {
|
||||||
|
t.Fatalf("failed to unmarshal: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, ok := result["input"]; ok {
|
||||||
|
t.Fatalf("unexpected input field in non-tool block JSON: %s", string(data))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestStreamConverter_ContentBlockStartIncludesEmptyFields(t *testing.T) {
|
func TestStreamConverter_ContentBlockStartIncludesEmptyFields(t *testing.T) {
|
||||||
t.Run("text block start includes empty text", func(t *testing.T) {
|
t.Run("text block start includes empty text", func(t *testing.T) {
|
||||||
conv := NewStreamConverter("msg_123", "test-model", 0)
|
conv := NewStreamConverter("msg_123", "test-model", 0)
|
||||||
@@ -986,7 +1143,9 @@ func TestStreamConverter_ContentBlockStartIncludesEmptyFields(t *testing.T) {
|
|||||||
// Marshal and verify the text field is present
|
// Marshal and verify the text field is present
|
||||||
data, _ := json.Marshal(start)
|
data, _ := json.Marshal(start)
|
||||||
var result map[string]any
|
var result map[string]any
|
||||||
json.Unmarshal(data, &result)
|
if err := json.Unmarshal(data, &result); err != nil {
|
||||||
|
t.Fatalf("failed to unmarshal content_block_start JSON: %v", err)
|
||||||
|
}
|
||||||
cb := result["content_block"].(map[string]any)
|
cb := result["content_block"].(map[string]any)
|
||||||
if _, ok := cb["text"]; !ok {
|
if _, ok := cb["text"]; !ok {
|
||||||
t.Error("content_block_start for text should include 'text' field")
|
t.Error("content_block_start for text should include 'text' field")
|
||||||
@@ -1033,13 +1192,71 @@ func TestStreamConverter_ContentBlockStartIncludesEmptyFields(t *testing.T) {
|
|||||||
t.Error("expected thinking content_block_start event")
|
t.Error("expected thinking content_block_start event")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
t.Run("tool_use block start includes empty input object", func(t *testing.T) {
|
||||||
|
conv := NewStreamConverter("msg_123", "test-model", 0)
|
||||||
|
|
||||||
|
resp := api.ChatResponse{
|
||||||
|
Model: "test-model",
|
||||||
|
Message: api.Message{
|
||||||
|
Role: "assistant",
|
||||||
|
ToolCalls: []api.ToolCall{
|
||||||
|
{
|
||||||
|
ID: "call_123",
|
||||||
|
Function: api.ToolCallFunction{
|
||||||
|
Name: "get_weather",
|
||||||
|
Arguments: makeArgs("location", "Paris"),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
events := conv.Process(resp)
|
||||||
|
|
||||||
|
var foundToolStart bool
|
||||||
|
for _, e := range events {
|
||||||
|
if e.Event == "content_block_start" {
|
||||||
|
if start, ok := e.Data.(ContentBlockStartEvent); ok {
|
||||||
|
if start.ContentBlock.Type == "tool_use" {
|
||||||
|
foundToolStart = true
|
||||||
|
if start.ContentBlock.Input.Len() != 0 {
|
||||||
|
t.Errorf("expected empty input object, got len=%d", start.ContentBlock.Input.Len())
|
||||||
|
}
|
||||||
|
|
||||||
|
data, _ := json.Marshal(start)
|
||||||
|
var result map[string]any
|
||||||
|
json.Unmarshal(data, &result)
|
||||||
|
cb := result["content_block"].(map[string]any)
|
||||||
|
input, ok := cb["input"]
|
||||||
|
if !ok {
|
||||||
|
t.Error("content_block_start for tool_use should include 'input' field")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
inputMap, ok := input.(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
t.Errorf("input field should be an object, got %T", input)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if len(inputMap) != 0 {
|
||||||
|
t.Errorf("expected empty input object in content_block_start, got %v", inputMap)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !foundToolStart {
|
||||||
|
t.Error("expected tool_use content_block_start event")
|
||||||
|
}
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestEstimateTokens_SimpleMessage(t *testing.T) {
|
func TestEstimateTokens_SimpleMessage(t *testing.T) {
|
||||||
req := CountTokensRequest{
|
req := CountTokensRequest{
|
||||||
Model: "test-model",
|
Model: "test-model",
|
||||||
Messages: []MessageParam{
|
Messages: []MessageParam{
|
||||||
{Role: "user", Content: "Hello, world!"},
|
{Role: "user", Content: textContent("Hello, world!")},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1060,7 +1277,7 @@ func TestEstimateTokens_WithSystemPrompt(t *testing.T) {
|
|||||||
Model: "test-model",
|
Model: "test-model",
|
||||||
System: "You are a helpful assistant.",
|
System: "You are a helpful assistant.",
|
||||||
Messages: []MessageParam{
|
Messages: []MessageParam{
|
||||||
{Role: "user", Content: "Hello"},
|
{Role: "user", Content: textContent("Hello")},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1076,7 +1293,7 @@ func TestEstimateTokens_WithTools(t *testing.T) {
|
|||||||
req := CountTokensRequest{
|
req := CountTokensRequest{
|
||||||
Model: "test-model",
|
Model: "test-model",
|
||||||
Messages: []MessageParam{
|
Messages: []MessageParam{
|
||||||
{Role: "user", Content: "What's the weather?"},
|
{Role: "user", Content: textContent("What's the weather?")},
|
||||||
},
|
},
|
||||||
Tools: []Tool{
|
Tools: []Tool{
|
||||||
{
|
{
|
||||||
@@ -1099,17 +1316,17 @@ func TestEstimateTokens_WithThinking(t *testing.T) {
|
|||||||
req := CountTokensRequest{
|
req := CountTokensRequest{
|
||||||
Model: "test-model",
|
Model: "test-model",
|
||||||
Messages: []MessageParam{
|
Messages: []MessageParam{
|
||||||
{Role: "user", Content: "Hello"},
|
{Role: "user", Content: textContent("Hello")},
|
||||||
{
|
{
|
||||||
Role: "assistant",
|
Role: "assistant",
|
||||||
Content: []any{
|
Content: []ContentBlock{
|
||||||
map[string]any{
|
{
|
||||||
"type": "thinking",
|
Type: "thinking",
|
||||||
"thinking": "Let me think about this carefully...",
|
Thinking: ptr("Let me think about this carefully..."),
|
||||||
},
|
},
|
||||||
map[string]any{
|
{
|
||||||
"type": "text",
|
Type: "text",
|
||||||
"text": "Here is my response.",
|
Text: ptr("Here is my response."),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -1207,12 +1424,12 @@ func TestConvertTool_RegularTool(t *testing.T) {
|
|||||||
func TestConvertMessage_ServerToolUse(t *testing.T) {
|
func TestConvertMessage_ServerToolUse(t *testing.T) {
|
||||||
msg := MessageParam{
|
msg := MessageParam{
|
||||||
Role: "assistant",
|
Role: "assistant",
|
||||||
Content: []any{
|
Content: []ContentBlock{
|
||||||
map[string]any{
|
{
|
||||||
"type": "server_tool_use",
|
Type: "server_tool_use",
|
||||||
"id": "srvtoolu_123",
|
ID: "srvtoolu_123",
|
||||||
"name": "web_search",
|
Name: "web_search",
|
||||||
"input": map[string]any{"query": "test query"},
|
Input: makeArgs("query", "test query"),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -1243,11 +1460,11 @@ func TestConvertMessage_ServerToolUse(t *testing.T) {
|
|||||||
func TestConvertMessage_WebSearchToolResult(t *testing.T) {
|
func TestConvertMessage_WebSearchToolResult(t *testing.T) {
|
||||||
msg := MessageParam{
|
msg := MessageParam{
|
||||||
Role: "user",
|
Role: "user",
|
||||||
Content: []any{
|
Content: []ContentBlock{
|
||||||
map[string]any{
|
{
|
||||||
"type": "web_search_tool_result",
|
Type: "web_search_tool_result",
|
||||||
"tool_use_id": "srvtoolu_123",
|
ToolUseID: "srvtoolu_123",
|
||||||
"content": []any{
|
Content: []any{
|
||||||
map[string]any{
|
map[string]any{
|
||||||
"type": "web_search_result",
|
"type": "web_search_result",
|
||||||
"title": "Test Result",
|
"title": "Test Result",
|
||||||
@@ -1284,11 +1501,11 @@ func TestConvertMessage_WebSearchToolResult(t *testing.T) {
|
|||||||
func TestConvertMessage_WebSearchToolResultEmptyStillCreatesToolMessage(t *testing.T) {
|
func TestConvertMessage_WebSearchToolResultEmptyStillCreatesToolMessage(t *testing.T) {
|
||||||
msg := MessageParam{
|
msg := MessageParam{
|
||||||
Role: "user",
|
Role: "user",
|
||||||
Content: []any{
|
Content: []ContentBlock{
|
||||||
map[string]any{
|
{
|
||||||
"type": "web_search_tool_result",
|
Type: "web_search_tool_result",
|
||||||
"tool_use_id": "srvtoolu_empty",
|
ToolUseID: "srvtoolu_empty",
|
||||||
"content": []any{},
|
Content: []any{},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -1315,11 +1532,11 @@ func TestConvertMessage_WebSearchToolResultEmptyStillCreatesToolMessage(t *testi
|
|||||||
func TestConvertMessage_WebSearchToolResultErrorStillCreatesToolMessage(t *testing.T) {
|
func TestConvertMessage_WebSearchToolResultErrorStillCreatesToolMessage(t *testing.T) {
|
||||||
msg := MessageParam{
|
msg := MessageParam{
|
||||||
Role: "user",
|
Role: "user",
|
||||||
Content: []any{
|
Content: []ContentBlock{
|
||||||
map[string]any{
|
{
|
||||||
"type": "web_search_tool_result",
|
Type: "web_search_tool_result",
|
||||||
"tool_use_id": "srvtoolu_error",
|
ToolUseID: "srvtoolu_error",
|
||||||
"content": map[string]any{
|
Content: map[string]any{
|
||||||
"type": "web_search_tool_result_error",
|
"type": "web_search_tool_result_error",
|
||||||
"error_code": "max_uses_exceeded",
|
"error_code": "max_uses_exceeded",
|
||||||
},
|
},
|
||||||
|
|||||||
@@ -476,25 +476,3 @@ func (c *Client) Whoami(ctx context.Context) (*UserResponse, error) {
|
|||||||
}
|
}
|
||||||
return &resp, nil
|
return &resp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// AliasRequest is the request body for creating or updating a model alias.
|
|
||||||
type AliasRequest struct {
|
|
||||||
Alias string `json:"alias"`
|
|
||||||
Target string `json:"target"`
|
|
||||||
PrefixMatching bool `json:"prefix_matching,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetAliasExperimental creates or updates a model alias via the experimental aliases API.
|
|
||||||
func (c *Client) SetAliasExperimental(ctx context.Context, req *AliasRequest) error {
|
|
||||||
return c.do(ctx, http.MethodPost, "/api/experimental/aliases", req, nil)
|
|
||||||
}
|
|
||||||
|
|
||||||
// AliasDeleteRequest is the request body for deleting a model alias.
|
|
||||||
type AliasDeleteRequest struct {
|
|
||||||
Alias string `json:"alias"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// DeleteAliasExperimental deletes a model alias via the experimental aliases API.
|
|
||||||
func (c *Client) DeleteAliasExperimental(ctx context.Context, req *AliasDeleteRequest) error {
|
|
||||||
return c.do(ctx, http.MethodDelete, "/api/experimental/aliases", req, nil)
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -436,6 +436,7 @@ type ToolProperty struct {
|
|||||||
Description string `json:"description,omitempty"`
|
Description string `json:"description,omitempty"`
|
||||||
Enum []any `json:"enum,omitempty"`
|
Enum []any `json:"enum,omitempty"`
|
||||||
Properties *ToolPropertiesMap `json:"properties,omitempty"`
|
Properties *ToolPropertiesMap `json:"properties,omitempty"`
|
||||||
|
Required []string `json:"required,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// ToTypeScriptType converts a ToolProperty to a TypeScript type string
|
// ToTypeScriptType converts a ToolProperty to a TypeScript type string
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ import (
|
|||||||
|
|
||||||
// currentSchemaVersion defines the current database schema version.
|
// currentSchemaVersion defines the current database schema version.
|
||||||
// Increment this when making schema changes that require migrations.
|
// Increment this when making schema changes that require migrations.
|
||||||
const currentSchemaVersion = 15
|
const currentSchemaVersion = 16
|
||||||
|
|
||||||
// database wraps the SQLite connection.
|
// database wraps the SQLite connection.
|
||||||
// SQLite handles its own locking for concurrent access:
|
// SQLite handles its own locking for concurrent access:
|
||||||
@@ -82,6 +82,7 @@ func (db *database) init() error {
|
|||||||
websearch_enabled BOOLEAN NOT NULL DEFAULT 0,
|
websearch_enabled BOOLEAN NOT NULL DEFAULT 0,
|
||||||
selected_model TEXT NOT NULL DEFAULT '',
|
selected_model TEXT NOT NULL DEFAULT '',
|
||||||
sidebar_open BOOLEAN NOT NULL DEFAULT 0,
|
sidebar_open BOOLEAN NOT NULL DEFAULT 0,
|
||||||
|
last_home_view TEXT NOT NULL DEFAULT 'launch',
|
||||||
think_enabled BOOLEAN NOT NULL DEFAULT 0,
|
think_enabled BOOLEAN NOT NULL DEFAULT 0,
|
||||||
think_level TEXT NOT NULL DEFAULT '',
|
think_level TEXT NOT NULL DEFAULT '',
|
||||||
cloud_setting_migrated BOOLEAN NOT NULL DEFAULT 0,
|
cloud_setting_migrated BOOLEAN NOT NULL DEFAULT 0,
|
||||||
@@ -264,6 +265,12 @@ func (db *database) migrate() error {
|
|||||||
return fmt.Errorf("migrate v14 to v15: %w", err)
|
return fmt.Errorf("migrate v14 to v15: %w", err)
|
||||||
}
|
}
|
||||||
version = 15
|
version = 15
|
||||||
|
case 15:
|
||||||
|
// add last_home_view column to settings table
|
||||||
|
if err := db.migrateV15ToV16(); err != nil {
|
||||||
|
return fmt.Errorf("migrate v15 to v16: %w", err)
|
||||||
|
}
|
||||||
|
version = 16
|
||||||
default:
|
default:
|
||||||
// If we have a version we don't recognize, just set it to current
|
// If we have a version we don't recognize, just set it to current
|
||||||
// This might happen during development
|
// This might happen during development
|
||||||
@@ -518,6 +525,21 @@ func (db *database) migrateV14ToV15() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// migrateV15ToV16 adds the last_home_view column to the settings table
|
||||||
|
func (db *database) migrateV15ToV16() error {
|
||||||
|
_, err := db.conn.Exec(`ALTER TABLE settings ADD COLUMN last_home_view TEXT NOT NULL DEFAULT 'launch'`)
|
||||||
|
if err != nil && !duplicateColumnError(err) {
|
||||||
|
return fmt.Errorf("add last_home_view column: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = db.conn.Exec(`UPDATE settings SET schema_version = 16`)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("update schema version: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// cleanupOrphanedData removes orphaned records that may exist due to the foreign key bug
|
// cleanupOrphanedData removes orphaned records that may exist due to the foreign key bug
|
||||||
func (db *database) cleanupOrphanedData() error {
|
func (db *database) cleanupOrphanedData() error {
|
||||||
_, err := db.conn.Exec(`
|
_, err := db.conn.Exec(`
|
||||||
@@ -1166,9 +1188,9 @@ func (db *database) getSettings() (Settings, error) {
|
|||||||
var s Settings
|
var s Settings
|
||||||
|
|
||||||
err := db.conn.QueryRow(`
|
err := db.conn.QueryRow(`
|
||||||
SELECT expose, survey, browser, models, agent, tools, working_dir, context_length, turbo_enabled, websearch_enabled, selected_model, sidebar_open, think_enabled, think_level, auto_update_enabled
|
SELECT expose, survey, browser, models, agent, tools, working_dir, context_length, turbo_enabled, websearch_enabled, selected_model, sidebar_open, last_home_view, think_enabled, think_level, auto_update_enabled
|
||||||
FROM settings
|
FROM settings
|
||||||
`).Scan(&s.Expose, &s.Survey, &s.Browser, &s.Models, &s.Agent, &s.Tools, &s.WorkingDir, &s.ContextLength, &s.TurboEnabled, &s.WebSearchEnabled, &s.SelectedModel, &s.SidebarOpen, &s.ThinkEnabled, &s.ThinkLevel, &s.AutoUpdateEnabled)
|
`).Scan(&s.Expose, &s.Survey, &s.Browser, &s.Models, &s.Agent, &s.Tools, &s.WorkingDir, &s.ContextLength, &s.TurboEnabled, &s.WebSearchEnabled, &s.SelectedModel, &s.SidebarOpen, &s.LastHomeView, &s.ThinkEnabled, &s.ThinkLevel, &s.AutoUpdateEnabled)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return Settings{}, fmt.Errorf("get settings: %w", err)
|
return Settings{}, fmt.Errorf("get settings: %w", err)
|
||||||
}
|
}
|
||||||
@@ -1177,10 +1199,26 @@ func (db *database) getSettings() (Settings, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (db *database) setSettings(s Settings) error {
|
func (db *database) setSettings(s Settings) error {
|
||||||
|
lastHomeView := strings.ToLower(strings.TrimSpace(s.LastHomeView))
|
||||||
|
validLaunchView := map[string]struct{}{
|
||||||
|
"launch": {},
|
||||||
|
"openclaw": {},
|
||||||
|
"claude": {},
|
||||||
|
"codex": {},
|
||||||
|
"opencode": {},
|
||||||
|
"droid": {},
|
||||||
|
"pi": {},
|
||||||
|
}
|
||||||
|
if lastHomeView != "chat" {
|
||||||
|
if _, ok := validLaunchView[lastHomeView]; !ok {
|
||||||
|
lastHomeView = "launch"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
_, err := db.conn.Exec(`
|
_, err := db.conn.Exec(`
|
||||||
UPDATE settings
|
UPDATE settings
|
||||||
SET expose = ?, survey = ?, browser = ?, models = ?, agent = ?, tools = ?, working_dir = ?, context_length = ?, turbo_enabled = ?, websearch_enabled = ?, selected_model = ?, sidebar_open = ?, think_enabled = ?, think_level = ?, auto_update_enabled = ?
|
SET expose = ?, survey = ?, browser = ?, models = ?, agent = ?, tools = ?, working_dir = ?, context_length = ?, turbo_enabled = ?, websearch_enabled = ?, selected_model = ?, sidebar_open = ?, last_home_view = ?, think_enabled = ?, think_level = ?, auto_update_enabled = ?
|
||||||
`, s.Expose, s.Survey, s.Browser, s.Models, s.Agent, s.Tools, s.WorkingDir, s.ContextLength, s.TurboEnabled, s.WebSearchEnabled, s.SelectedModel, s.SidebarOpen, s.ThinkEnabled, s.ThinkLevel, s.AutoUpdateEnabled)
|
`, s.Expose, s.Survey, s.Browser, s.Models, s.Agent, s.Tools, s.WorkingDir, s.ContextLength, s.TurboEnabled, s.WebSearchEnabled, s.SelectedModel, s.SidebarOpen, lastHomeView, s.ThinkEnabled, s.ThinkLevel, s.AutoUpdateEnabled)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("set settings: %w", err)
|
return fmt.Errorf("set settings: %w", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -135,6 +135,45 @@ func TestMigrationV13ToV14ContextLength(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestMigrationV15ToV16LastHomeViewDefaultsToLaunch(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
dbPath := filepath.Join(tmpDir, "test.db")
|
||||||
|
|
||||||
|
db, err := newDatabase(dbPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to create database: %v", err)
|
||||||
|
}
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
if _, err := db.conn.Exec(`
|
||||||
|
ALTER TABLE settings DROP COLUMN last_home_view;
|
||||||
|
UPDATE settings SET schema_version = 15;
|
||||||
|
`); err != nil {
|
||||||
|
t.Fatalf("failed to seed v15 settings row: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := db.migrate(); err != nil {
|
||||||
|
t.Fatalf("migration from v15 to v16 failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var lastHomeView string
|
||||||
|
if err := db.conn.QueryRow("SELECT last_home_view FROM settings").Scan(&lastHomeView); err != nil {
|
||||||
|
t.Fatalf("failed to read last_home_view: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if lastHomeView != "launch" {
|
||||||
|
t.Fatalf("expected last_home_view to default to launch after migration, got %q", lastHomeView)
|
||||||
|
}
|
||||||
|
|
||||||
|
version, err := db.getSchemaVersion()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to get schema version: %v", err)
|
||||||
|
}
|
||||||
|
if version != currentSchemaVersion {
|
||||||
|
t.Fatalf("expected schema version %d, got %d", currentSchemaVersion, version)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestChatDeletionWithCascade(t *testing.T) {
|
func TestChatDeletionWithCascade(t *testing.T) {
|
||||||
t.Run("chat deletion cascades to related messages", func(t *testing.T) {
|
t.Run("chat deletion cascades to related messages", func(t *testing.T) {
|
||||||
tmpDir := t.TempDir()
|
tmpDir := t.TempDir()
|
||||||
|
|||||||
@@ -167,6 +167,9 @@ type Settings struct {
|
|||||||
// SidebarOpen indicates if the chat sidebar is open
|
// SidebarOpen indicates if the chat sidebar is open
|
||||||
SidebarOpen bool
|
SidebarOpen bool
|
||||||
|
|
||||||
|
// LastHomeView stores the preferred home route target ("chat" or integration name)
|
||||||
|
LastHomeView string
|
||||||
|
|
||||||
// AutoUpdateEnabled indicates if automatic updates should be downloaded
|
// AutoUpdateEnabled indicates if automatic updates should be downloaded
|
||||||
AutoUpdateEnabled bool
|
AutoUpdateEnabled bool
|
||||||
}
|
}
|
||||||
@@ -389,6 +392,10 @@ func (s *Store) Settings() (Settings, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if settings.LastHomeView == "" {
|
||||||
|
settings.LastHomeView = "launch"
|
||||||
|
}
|
||||||
|
|
||||||
return settings, nil
|
return settings, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -81,6 +81,32 @@ func TestStore(t *testing.T) {
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
t.Run("settings default home view is launch", func(t *testing.T) {
|
||||||
|
loaded, err := s.Settings()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if loaded.LastHomeView != "launch" {
|
||||||
|
t.Fatalf("expected default LastHomeView to be launch, got %q", loaded.LastHomeView)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("settings empty home view falls back to launch", func(t *testing.T) {
|
||||||
|
if err := s.SetSettings(Settings{LastHomeView: ""}); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
loaded, err := s.Settings()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if loaded.LastHomeView != "launch" {
|
||||||
|
t.Fatalf("expected empty LastHomeView to fall back to launch, got %q", loaded.LastHomeView)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
t.Run("window size", func(t *testing.T) {
|
t.Run("window size", func(t *testing.T) {
|
||||||
if err := s.SetWindowSize(1024, 768); err != nil {
|
if err := s.SetWindowSize(1024, 768); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
|
|||||||
@@ -414,6 +414,7 @@ export class Settings {
|
|||||||
ThinkLevel: string;
|
ThinkLevel: string;
|
||||||
SelectedModel: string;
|
SelectedModel: string;
|
||||||
SidebarOpen: boolean;
|
SidebarOpen: boolean;
|
||||||
|
LastHomeView: string;
|
||||||
AutoUpdateEnabled: boolean;
|
AutoUpdateEnabled: boolean;
|
||||||
|
|
||||||
constructor(source: any = {}) {
|
constructor(source: any = {}) {
|
||||||
@@ -432,6 +433,7 @@ export class Settings {
|
|||||||
this.ThinkLevel = source["ThinkLevel"];
|
this.ThinkLevel = source["ThinkLevel"];
|
||||||
this.SelectedModel = source["SelectedModel"];
|
this.SelectedModel = source["SelectedModel"];
|
||||||
this.SidebarOpen = source["SidebarOpen"];
|
this.SidebarOpen = source["SidebarOpen"];
|
||||||
|
this.LastHomeView = source["LastHomeView"];
|
||||||
this.AutoUpdateEnabled = source["AutoUpdateEnabled"];
|
this.AutoUpdateEnabled = source["AutoUpdateEnabled"];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -550,14 +552,12 @@ export class Error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
export class ModelUpstreamResponse {
|
export class ModelUpstreamResponse {
|
||||||
digest?: string;
|
stale: boolean;
|
||||||
pushTime: number;
|
|
||||||
error?: string;
|
error?: string;
|
||||||
|
|
||||||
constructor(source: any = {}) {
|
constructor(source: any = {}) {
|
||||||
if ('string' === typeof source) source = JSON.parse(source);
|
if ('string' === typeof source) source = JSON.parse(source);
|
||||||
this.digest = source["digest"];
|
this.stale = source["stale"];
|
||||||
this.pushTime = source["pushTime"];
|
|
||||||
this.error = source["error"];
|
this.error = source["error"];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
7
app/ui/app/public/launch-icons/claude.svg
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<!-- Generated by Pixelmator Pro 3.6.17 -->
|
||||||
|
<svg width="1200" height="1200" viewBox="0 0 1200 1200" xmlns="http://www.w3.org/2000/svg">
|
||||||
|
<g id="g314">
|
||||||
|
<path id="path147" fill="#d97757" stroke="none" d="M 233.959793 800.214905 L 468.644287 668.536987 L 472.590637 657.100647 L 468.644287 650.738403 L 457.208069 650.738403 L 417.986633 648.322144 L 283.892639 644.69812 L 167.597321 639.865845 L 54.926208 633.825623 L 26.577238 627.785339 L 3.3e-05 592.751709 L 2.73832 575.27533 L 26.577238 559.248352 L 60.724873 562.228149 L 136.187973 567.382629 L 249.422867 575.194763 L 331.570496 580.026978 L 453.261841 592.671082 L 472.590637 592.671082 L 475.328857 584.859009 L 468.724915 580.026978 L 463.570557 575.194763 L 346.389313 495.785217 L 219.543671 411.865906 L 153.100723 363.543762 L 117.181267 339.060425 L 99.060455 316.107361 L 91.248367 266.01355 L 123.865784 230.093994 L 167.677887 233.073853 L 178.872513 236.053772 L 223.248367 270.201477 L 318.040283 343.570496 L 441.825592 434.738342 L 459.946411 449.798706 L 467.194672 444.64447 L 468.080597 441.020203 L 459.946411 427.409485 L 392.617493 305.718323 L 320.778564 181.932983 L 288.80542 130.630859 L 280.348999 99.865845 C 277.369171 87.221436 275.194641 76.590698 275.194641 63.624268 L 312.322174 13.20813 L 332.8591 6.604126 L 382.389313 13.20813 L 403.248352 31.328979 L 434.013519 101.71814 L 483.865753 212.537048 L 561.181274 363.221497 L 583.812134 407.919434 L 595.892639 449.315491 L 600.40271 461.959839 L 608.214783 461.959839 L 608.214783 454.711609 L 614.577271 369.825623 L 626.335632 265.61084 L 637.771851 131.516846 L 641.718201 93.745117 L 660.402832 48.483276 L 697.530334 24.000122 L 726.52356 37.852417 L 750.362549 72 L 747.060486 94.067139 L 732.886047 186.201416 L 705.100708 330.52356 L 686.979919 427.167847 L 697.530334 427.167847 L 709.61084 415.087341 L 758.496704 350.174561 L 840.644348 247.490051 L 876.885925 206.738342 L 919.167847 161.71814 L 946.308838 140.29541 L 997.61084 140.29541 L 1035.38269 196.429626 L 1018.469849 254.416199 L 965.637634 321.422852 L 921.825562 378.201538 L 859.006714 462.765259 L 819.785278 530.41626 L 823.409424 535.812073 L 832.75177 534.92627 L 974.657776 504.724915 L 1051.328979 490.872559 L 1142.818848 475.167786 L 1184.214844 494.496582 L 1188.724854 514.147644 L 1172.456421 554.335693 L 1074.604126 578.496765 L 959.838989 601.449829 L 788.939636 641.879272 L 786.845764 643.409485 L 789.261841 646.389343 L 866.255127 653.637634 L 899.194702 655.409424 L 979.812134 655.409424 L 1129.932861 666.604187 L 1169.154419 692.537109 L 1192.671265 724.268677 L 1188.724854 748.429688 L 1128.322144 779.194641 L 1046.818848 759.865845 L 856.590759 714.604126 L 791.355774 698.335754 L 782.335693 698.335754 L 782.335693 703.731567 L 836.69812 756.885986 L 936.322205 846.845581 L 1061.073975 962.81897 L 1067.436279 991.490112 L 1051.409424 1014.120911 L 1034.496704 1011.704712 L 924.885986 929.234924 L 882.604126 892.107544 L 786.845764 811.48999 L 780.483276 811.48999 L 780.483276 819.946289 L 802.550415 852.241699 L 919.087341 1027.409424 L 925.127625 1081.127686 L 916.671204 1098.604126 L 886.469849 1109.154419 L 853.288696 1103.114136 L 785.073914 1007.355835 L 714.684631 899.516785 L 657.906067 802.872498 L 650.979858 806.81897 L 617.476624 1167.704834 L 601.771851 1186.147705 L 565.530212 1200 L 535.328857 1177.046997 L 519.302124 1139.919556 L 535.328857 1066.550537 L 554.657776 970.792053 L 570.362488 894.68457 L 584.536926 800.134277 L 592.993347 768.724976 L 592.429626 766.630859 L 585.503479 767.516968 L 514.22821 865.369263 L 405.825531 1011.865906 L 320.053711 1103.677979 L 299.516815 1111.812256 L 263.919525 1093.369263 L 267.221497 1060.429688 L 287.114136 1031.114136 L 405.825531 880.107361 L 477.422913 786.52356 L 523.651062 732.483276 L 523.328918 724.671265 L 520.590698 724.671265 L 205.288605 929.395935 L 149.154434 936.644409 L 124.993355 914.01355 L 127.973183 876.885986 L 139.409409 864.80542 L 234.201385 799.570435 L 233.879227 799.8927 Z"/>
|
||||||
|
</g>
|
||||||
|
</svg>
|
||||||
|
After Width: | Height: | Size: 4.0 KiB |
1
app/ui/app/public/launch-icons/codex-dark.svg
Normal file
@@ -0,0 +1 @@
|
|||||||
|
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 320 320"><path fill="#fff" d="m297.06 130.97c7.26-21.79 4.76-45.66-6.85-65.48-17.46-30.4-52.56-46.04-86.84-38.68-15.25-17.18-37.16-26.95-60.13-26.81-35.04-.08-66.13 22.48-76.91 55.82-22.51 4.61-41.94 18.7-53.31 38.67-17.59 30.32-13.58 68.54 9.92 94.54-7.26 21.79-4.76 45.66 6.85 65.48 17.46 30.4 52.56 46.04 86.84 38.68 15.24 17.18 37.16 26.95 60.13 26.8 35.06.09 66.16-22.49 76.94-55.86 22.51-4.61 41.94-18.7 53.31-38.67 17.57-30.32 13.55-68.51-9.94-94.51zm-120.28 168.11c-14.03.02-27.62-4.89-38.39-13.88.49-.26 1.34-.73 1.89-1.07l63.72-36.8c3.26-1.85 5.26-5.32 5.24-9.07v-89.83l26.93 15.55c.29.14.48.42.52.74v74.39c-.04 33.08-26.83 59.9-59.91 59.97zm-128.84-55.03c-7.03-12.14-9.56-26.37-7.15-40.18.47.28 1.3.79 1.89 1.13l63.72 36.8c3.23 1.89 7.23 1.89 10.47 0l77.79-44.92v31.1c.02.32-.13.63-.38.83l-64.41 37.19c-28.69 16.52-65.33 6.7-81.92-21.95zm-16.77-139.09c7-12.16 18.05-21.46 31.21-26.29 0 .55-.03 1.52-.03 2.2v73.61c-.02 3.74 1.98 7.21 5.23 9.06l77.79 44.91-26.93 15.55c-.27.18-.61.21-.91.08l-64.42-37.22c-28.63-16.58-38.45-53.21-21.95-81.89zm221.26 51.49-77.79-44.92 26.93-15.54c.27-.18.61-.21.91-.08l64.42 37.19c28.68 16.57 38.51 53.26 21.94 81.94-7.01 12.14-18.05 21.44-31.2 26.28v-75.81c.03-3.74-1.96-7.2-5.2-9.06zm26.8-40.34c-.47-.29-1.3-.79-1.89-1.13l-63.72-36.8c-3.23-1.89-7.23-1.89-10.47 0l-77.79 44.92v-31.1c-.02-.32.13-.63.38-.83l64.41-37.16c28.69-16.55 65.37-6.7 81.91 22 6.99 12.12 9.52 26.31 7.15 40.1zm-168.51 55.43-26.94-15.55c-.29-.14-.48-.42-.52-.74v-74.39c.02-33.12 26.89-59.96 60.01-59.94 14.01 0 27.57 4.92 38.34 13.88-.49.26-1.33.73-1.89 1.07l-63.72 36.8c-3.26 1.85-5.26 5.31-5.24 9.06l-.04 89.79zm14.63-31.54 34.65-20.01 34.65 20v40.01l-34.65 20-34.65-20z"/></svg>
|
||||||
|
After Width: | Height: | Size: 1.7 KiB |
1
app/ui/app/public/launch-icons/codex.svg
Normal file
@@ -0,0 +1 @@
|
|||||||
|
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 320 320"><path d="m297.06 130.97c7.26-21.79 4.76-45.66-6.85-65.48-17.46-30.4-52.56-46.04-86.84-38.68-15.25-17.18-37.16-26.95-60.13-26.81-35.04-.08-66.13 22.48-76.91 55.82-22.51 4.61-41.94 18.7-53.31 38.67-17.59 30.32-13.58 68.54 9.92 94.54-7.26 21.79-4.76 45.66 6.85 65.48 17.46 30.4 52.56 46.04 86.84 38.68 15.24 17.18 37.16 26.95 60.13 26.8 35.06.09 66.16-22.49 76.94-55.86 22.51-4.61 41.94-18.7 53.31-38.67 17.57-30.32 13.55-68.51-9.94-94.51zm-120.28 168.11c-14.03.02-27.62-4.89-38.39-13.88.49-.26 1.34-.73 1.89-1.07l63.72-36.8c3.26-1.85 5.26-5.32 5.24-9.07v-89.83l26.93 15.55c.29.14.48.42.52.74v74.39c-.04 33.08-26.83 59.9-59.91 59.97zm-128.84-55.03c-7.03-12.14-9.56-26.37-7.15-40.18.47.28 1.3.79 1.89 1.13l63.72 36.8c3.23 1.89 7.23 1.89 10.47 0l77.79-44.92v31.1c.02.32-.13.63-.38.83l-64.41 37.19c-28.69 16.52-65.33 6.7-81.92-21.95zm-16.77-139.09c7-12.16 18.05-21.46 31.21-26.29 0 .55-.03 1.52-.03 2.2v73.61c-.02 3.74 1.98 7.21 5.23 9.06l77.79 44.91-26.93 15.55c-.27.18-.61.21-.91.08l-64.42-37.22c-28.63-16.58-38.45-53.21-21.95-81.89zm221.26 51.49-77.79-44.92 26.93-15.54c.27-.18.61-.21.91-.08l64.42 37.19c28.68 16.57 38.51 53.26 21.94 81.94-7.01 12.14-18.05 21.44-31.2 26.28v-75.81c.03-3.74-1.96-7.2-5.2-9.06zm26.8-40.34c-.47-.29-1.3-.79-1.89-1.13l-63.72-36.8c-3.23-1.89-7.23-1.89-10.47 0l-77.79 44.92v-31.1c-.02-.32.13-.63.38-.83l64.41-37.16c28.69-16.55 65.37-6.7 81.91 22 6.99 12.12 9.52 26.31 7.15 40.1zm-168.51 55.43-26.94-15.55c-.29-.14-.48-.42-.52-.74v-74.39c.02-33.12 26.89-59.96 60.01-59.94 14.01 0 27.57 4.92 38.34 13.88-.49.26-1.33.73-1.89 1.07l-63.72 36.8c-3.26 1.85-5.26 5.31-5.24 9.06l-.04 89.79zm14.63-31.54 34.65-20.01 34.65 20v40.01l-34.65 20-34.65-20z"/></svg>
|
||||||
|
After Width: | Height: | Size: 1.7 KiB |
8
app/ui/app/public/launch-icons/droid.svg
Normal file
|
After Width: | Height: | Size: 6.2 KiB |
242
app/ui/app/public/launch-icons/openclaw.svg
Normal file
@@ -0,0 +1,242 @@
|
|||||||
|
<svg version="1.2" xmlns="http://www.w3.org/2000/svg" viewBox="0 0 500 500" width="500" height="500">
|
||||||
|
<style>
|
||||||
|
.s0 { fill: #f6f4f4 }
|
||||||
|
.s1 { fill: #0b0303 }
|
||||||
|
.s2 { fill: #ef0011 }
|
||||||
|
.s3 { fill: #f3e2e2 }
|
||||||
|
.s4 { fill: #f00212 }
|
||||||
|
.s5 { fill: #ba000d }
|
||||||
|
.s6 { fill: #faf1f1 }
|
||||||
|
.s7 { fill: #0b0100 }
|
||||||
|
.s8 { fill: #fbedee }
|
||||||
|
.s9 { fill: #faeaea }
|
||||||
|
.s10 { fill: #ab797d }
|
||||||
|
.s11 { fill: #f8eaea }
|
||||||
|
.s12 { fill: #902021 }
|
||||||
|
.s13 { fill: #f9eeee }
|
||||||
|
.s14 { fill: #f6ecec }
|
||||||
|
.s15 { fill: #080201 }
|
||||||
|
.s16 { fill: #150100 }
|
||||||
|
.s17 { fill: #f2e7e7 }
|
||||||
|
.s18 { fill: #fbe7e8 }
|
||||||
|
.s19 { fill: #060101 }
|
||||||
|
.s20 { fill: #f5e7e7 }
|
||||||
|
.s21 { fill: #fa999e }
|
||||||
|
.s22 { fill: #c46064 }
|
||||||
|
.s23 { fill: #180300 }
|
||||||
|
.s24 { fill: #f6dcdd }
|
||||||
|
.s25 { fill: #f2e6e6 }
|
||||||
|
.s26 { fill: #110200 }
|
||||||
|
.s27 { fill: #eb0011 }
|
||||||
|
.s28 { fill: #e20010 }
|
||||||
|
.s29 { fill: #ea0011 }
|
||||||
|
.s30 { fill: #760007 }
|
||||||
|
.s31 { fill: #f00514 }
|
||||||
|
.s32 { fill: #fcebeb }
|
||||||
|
.s33 { fill: #ecd6d6 }
|
||||||
|
.s34 { fill: #f5e3e3 }
|
||||||
|
.s35 { fill: #f5e4e4 }
|
||||||
|
.s36 { fill: #faf6f6 }
|
||||||
|
.s37 { fill: #e50010 }
|
||||||
|
.s38 { fill: #d5000f }
|
||||||
|
.s39 { fill: #f2e2e3 }
|
||||||
|
.s40 { fill: #ef1018 }
|
||||||
|
.s41 { fill: #f4e8e9 }
|
||||||
|
.s42 { fill: #ef0513 }
|
||||||
|
.s43 { fill: #f5e5e5 }
|
||||||
|
.s44 { fill: #f00413 }
|
||||||
|
.s45 { fill: #f4e9ea }
|
||||||
|
.s46 { fill: #ed0011 }
|
||||||
|
.s47 { fill: #e80011 }
|
||||||
|
.s48 { fill: #e60613 }
|
||||||
|
.s49 { fill: #f0d6d6 }
|
||||||
|
.s50 { fill: #fca9ac }
|
||||||
|
.s51 { fill: #9c000c }
|
||||||
|
.s52 { fill: #73393b }
|
||||||
|
</style>
|
||||||
|
<g>
|
||||||
|
<path fill-rule="evenodd" class="s0" d="m166.5 52.5q3.5 0 7 0 2.75 2.99 1.5 7-21.27 45.61-20.5 96 39.99 2.76 72 26.5 7.87 6.86 13.5 15.5 42.88-56.39 103.5-92.5 47.35-25.46 101-25 14.52 0.38 23.5 11.5 3.19 7.74 2 16-1.81 7.18-4.5 14-1 0-1 1-5.04 6.05-9 13-1 0-1 1 0 0.5 0 1-12.42 12.15-28.5 19-6.02 36.27-41.5 45-0.83 2.75 0 5 19.02-12.85 41.5-9 10.85-8.09 23.5-13 15.01-6.37 31-2.5 14.09 7.43 14 23.5-2.83 23.25-15.5 43-6.42 9.92-14 19-10.04 8.8-19.5 18-72.02 48.88-156.5 27-19.63 9.6-41.5 10.5-4.59 1.27-9 3 2 1 4 2 20.09-1.11 35 12 25.46 6.95 37.5 30.5 1.26 5.69-1 11-3.38 3.79-7.5 6.5 5.74 10.07 1.5 20.5-7.55 7.47-17.5 3.5-11.01-5.34-22.5-9.5-18.26 10-38.5 13-15.5 0-31 0-26.62-4.54-51-17-4.17 1.33-8 3.5-7.23 5.87-15 11-8.62 2.58-13.5-4.5-1.82 2.32-4.5 3.5-6.06 2.24-12 3.5-7.5 0-15 0-27.42-2.56-50-18.5-18-17.25-23-41.5 0-11.5 0-23 4.12-22.7 25-33 6.95-16.67 22-26.5-20.39-20.8-14.5-49.5 7.01-26.98 28.5-44.5 7.56-5.27 15-10.5-13.09-30.88-7.5-64 3.16-15.57 14.5-26.5 6.85-2.48 8 4.5-6.59 39.53 11 75.5 7.99-0.49 16-2 2.42-34.57 14.5-67.5 8.51-22.23 27.5-36z"/>
|
||||||
|
</g>
|
||||||
|
<g>
|
||||||
|
<path fill-rule="evenodd" class="s1" d="m113.5 401.5q0.48-5.1-1-10-0.91 0.19-1 1-2.46 1.74-5 3.5 5.65 9.54-5 13-32.21 5.55-61-10-32.89-23.11-29.5-63.5 2.96-22.67 23.5-32 7.99-19.75 27-29.5-27.65-23.7-15.5-58.5 7.33-16.82 20.5-29.5 10.79-8.14 22-15.5-16.49-37.08-5.5-76 3.19-6.13 7.5-11.5 1.48-0.89 2 1-5.69 41.09 12.5 78.5 1 1 2 2 9.97-3.24 20.5-4 2 0 4 0 0-7.5 0-15 0.99-42.22 24.5-77 6.12-7.12 14-12-4.65 13.43-10 27-11.93 37.6-9.5 77 49.38 0.7 83.5 36 2.75 4.5 5.5 9 38.99-52.24 93-88.5 45.84-29.03 100-32.5 15.69-1.56 29 6.5 5.68 7.29 3.5 16.5-10.38 33.62-43.5 45-4.39 37.33-41 45-0.79 8.63-6 15.5 1.91 1.83 4.5 2.5 22.27-17.25 50.5-14.5 12.93-9.41 28-15 36.22-8.28 31.5 28.5-15.19 51.69-62.5 77.5-65.92 35.87-138 15.5-19.67 10.42-42 10.5-8.39 2.88-17 5 3.58 6.08 10 9 20.92-1.14 36 13 22.67 5.23 34.5 25.5 3.33 7.13-3.5 11.5-3.88 1.8-8 3 7.36 8.45 6.5 19.5-4.43 5.66-11.5 3.5-12.84-5.67-26-10.5-39.4 21.02-83 10.5-18.85-5.78-36.5-14.5-13.65 4.14-23.5 14.5-9.51 3.74-11-6.5z"/>
|
||||||
|
</g>
|
||||||
|
<g>
|
||||||
|
<path fill-rule="evenodd" class="s2" d="m153.5 173.5q24.62 1.46 46 13.5 12.11 8.1 17.5 21.5 0.74 2.45 0.5 5 0.09 0.81 1 1 1.48-4.9 1-10 5.04 10.48 1.5 22-9.81 27.86-35.5 42.5-26.17 14.97-56 19.5-2.77-0.4-2 1 2.86 1.27 6 1 25.64 1.53 48.5-10 0.34 10.08 2 20 1.08 5.76 5 10 1 1.5 0 3-31.11 20.84-68.5 17.5-23.7-5.7-32.5-28.5-4.39-9.18-3.5-19 15.41 6.23 32 4.5-20.68-6.39-39-18-34.81-27.22-12.5-65.5 11.84-14.83 29-23 4.21 7.66 11.5 12.5 3 1 6 0-26.04-34.62-29-78-0.13-8.46 2-16.5 1 6.5 2 13 3.43 39.53 24.5 73 2.03 2.28 4.5 4 0.5-1.25 1-2.5-1.27-6.54-5-12 0.5-0.75 1-1.5 9.72-3.43 20-4 0.55 10.34 8 17.5 1.94 0.74 4 0.5-17.8-64.6 16.5-122 0.98-1.79 1.5 0-28.21 56.64-13.5 118 1.08 1.43 2.5 0.5 2.21-4.98 2-10.5z"/>
|
||||||
|
</g>
|
||||||
|
<g>
|
||||||
|
<path fill-rule="evenodd" class="s3" d="m454.5 97.5q-18.37-2.97-37-1.5-16.14 2.08-32 5.5 32.38-14.09 67-7.5 1.98 1.22 2 3.5z"/>
|
||||||
|
</g>
|
||||||
|
<g>
|
||||||
|
<path fill-rule="evenodd" class="s4" d="m454.5 97.5q-1.33 11.18-8.5 20-21.81 26.28-55.5 32-1.11-0.2-2 0.5 2.31 2.82 5.5 4.5 1 2 0 4-9.56 11.3-19.5 20 19.71-8.72 31-27 2.68-0.43 5 1-14.24 30.97-48 36.5-9.93 1.71-20 1.5-6.8-0.48-13 1 5.81 6.92 14 11-10.78 16.03-27 26.5 27.16-7.4 38-33.5 4.34 1.35 9 1-9.08 23.84-33 33.5-18.45 6.41-38 7 22.59 8.92 45-1 12.05-5.52 24-11 9.01-1.79 17 2.5 5.28-4.38 11-8 12.8-6.07 27-5 0 0.5 0 1-19.34 2.69-34 15.5 0.5 0.25 1 0.5 17.79-8.09 36-15 2.71-0.79 5-2 2.5-1 5-2 5.53-4.04 11-8 11.7-4.18 24-6.5 7.78-1.36 15 1.5-2.97 18.45-13.5 34-34.92 49.37-94.5 62.5-59.27 12.45-108-23-15.53-12.52-21.5-31.5-2.47-14.26 4-27-3.15 24.41 14 42-4.92-10.28-7-22-1.97-17.63 7-33 47.28-69.5 125.5-100 15.86-3.42 32-5.5 18.63-1.47 37 1.5z"/>
|
||||||
|
</g>
|
||||||
|
<g>
|
||||||
|
<path fill-rule="evenodd" class="s5" d="m86.5 112.5q-1-6.5-2-13 0.7-5.34 3.5-10-1.8 11.32-1.5 23z"/>
|
||||||
|
</g>
|
||||||
|
<g>
|
||||||
|
<path fill-rule="evenodd" class="s6" d="m433.5 97.5q2.22-0.39 4 1-10 13.75-27 14-0.24-2.06 0.5-4 10.3-7.78 22.5-11z"/>
|
||||||
|
</g>
|
||||||
|
<g>
|
||||||
|
<path fill-rule="evenodd" class="s7" d="m407.5 101.5q2.55-0.24 5 0.5-52.87 18.31-84.5 64.5-6.94 7.95-17 11-9.38-2.38-5-11 40.38-48.62 101.5-65z"/>
|
||||||
|
</g>
|
||||||
|
<g>
|
||||||
|
<path fill-rule="evenodd" class="s8" d="m402.5 112.5q3 0 6 0-2.56 8.8-12 7-0.22-1.58 0.5-3 2.72-2.22 5.5-4z"/>
|
||||||
|
</g>
|
||||||
|
<g>
|
||||||
|
</g>
|
||||||
|
<g>
|
||||||
|
</g>
|
||||||
|
<g>
|
||||||
|
</g>
|
||||||
|
<g>
|
||||||
|
</g>
|
||||||
|
<g>
|
||||||
|
</g>
|
||||||
|
<g>
|
||||||
|
</g>
|
||||||
|
<g>
|
||||||
|
</g>
|
||||||
|
<g>
|
||||||
|
<path fill-rule="evenodd" class="s9" d="m390.5 149.5q7.77 0.52 15 2-11.29 18.28-31 27 9.94-8.7 19.5-20 1-2 0-4-3.19-1.68-5.5-4.5 0.89-0.7 2-0.5z"/>
|
||||||
|
</g>
|
||||||
|
<g>
|
||||||
|
<path fill-rule="evenodd" class="s10" d="m131.5 145.5q0 7.5 0 15-2 0-4 0 1.06-1.36 3-1-0.48-7.29 1-14z"/>
|
||||||
|
</g>
|
||||||
|
<g>
|
||||||
|
<path fill-rule="evenodd" class="s11" d="m219.5 204.5q-1 4.5-2 9 0.24-2.55-0.5-5-5.39-13.4-17.5-21.5-21.38-12.04-46-13.5 0-2 0-4 36.7-0.86 61.5 26 3.06 4.11 4.5 9z"/>
|
||||||
|
</g>
|
||||||
|
<g>
|
||||||
|
<path fill-rule="evenodd" class="s12" d="m329.5 191.5q6.2-1.48 13-1-3.5 1-7 2-2.9-0.97-6-1z"/>
|
||||||
|
</g>
|
||||||
|
<g>
|
||||||
|
<path fill-rule="evenodd" class="s13" d="m329.5 191.5q3.1 0.03 6 1 9.55 1.31 19 3-10.84 26.1-38 33.5 16.22-10.47 27-26.5-8.19-4.08-14-11z"/>
|
||||||
|
</g>
|
||||||
|
<g>
|
||||||
|
<path fill-rule="evenodd" class="s14" d="m479.5 199.5q-7.22-2.86-15-1.5-12.3 2.32-24 6.5 15.6-13.11 36-11.5 3.63 2.26 3 6.5z"/>
|
||||||
|
</g>
|
||||||
|
<g>
|
||||||
|
<path fill-rule="evenodd" class="s15" d="m193.5 216.5q-12.01 1.52-22 8-2.83 1.29-5.5 3-4.79-4.57-6.5-11-5.04 2.2-9.5-1-3.47-6.4 3.5-3 4.4 0.05 8-2.5 9.22-9.73 21-16 6.3-3.24 12 1-2.9 1.22-6 1.5 2.61 5.74 4.5 12 0.75 3.97 0.5 8z"/>
|
||||||
|
</g>
|
||||||
|
<g>
|
||||||
|
<path fill-rule="evenodd" class="s16" d="m458.5 200.5q3.04-0.24 6 0.5-18.02 7.05-33 19-1 1-2 0 11.53-14.3 29-19.5z"/>
|
||||||
|
</g>
|
||||||
|
<g>
|
||||||
|
<path fill-rule="evenodd" class="s17" d="m178.5 202.5q6.85-0.63 4.5 6-7.6 5.09-6-4 1.08-0.82 1.5-2z"/>
|
||||||
|
</g>
|
||||||
|
<g>
|
||||||
|
<path fill-rule="evenodd" class="s18" d="m469.5 201.5q-2.26 13.65-14.5 22-0.47-2.11 1-4 7.08-8.82 13.5-18z"/>
|
||||||
|
</g>
|
||||||
|
<g>
|
||||||
|
<path fill-rule="evenodd" class="s19" d="m74.5 208.5q8.22-0.2 16 2.5 11.8 4.26 23.5 8.5 5.65-0.63 8-6 2.41 11.83-9.5 13 0.55 3.61 2 7-0.5 1-1 2-4.67-0.94-9.5-1-9.96 0.44-19.5 2.5-5.05-3.55-6.5-9.5-0.75-7.48-0.5-15-6.47 0.15-3-4z"/>
|
||||||
|
</g>
|
||||||
|
<g>
|
||||||
|
<path fill-rule="evenodd" class="s20" d="m429.5 212.5q-2.5 1-5 2-4 0-8 0-14.2-1.07-27 5 15.27-12.44 35-9.5 2.72 1.14 5 2.5z"/>
|
||||||
|
</g>
|
||||||
|
<g>
|
||||||
|
<path fill-rule="evenodd" class="s21" d="m219.5 204.5q0.48 5.1-1 10-0.91-0.19-1-1 1-4.5 2-9z"/>
|
||||||
|
</g>
|
||||||
|
<g>
|
||||||
|
<path fill-rule="evenodd" class="s22" d="m416.5 215.5q0-0.5 0-1 4 0 8 0-2.29 1.21-5 2-1.06-1.36-3-1z"/>
|
||||||
|
</g>
|
||||||
|
<g>
|
||||||
|
<path fill-rule="evenodd" class="s23" d="m416.5 215.5q1.94-0.36 3 1-18.21 6.91-36 15-0.5-0.25-1-0.5 14.66-12.81 34-15.5z"/>
|
||||||
|
</g>
|
||||||
|
<g>
|
||||||
|
<path fill-rule="evenodd" class="s24" d="m193.5 216.5q4.39 1.3 9 3-0.79 1.04-2 1.5-14.77-0.13-29 3.5 9.99-6.48 22-8z"/>
|
||||||
|
</g>
|
||||||
|
<g>
|
||||||
|
<path fill-rule="evenodd" class="s25" d="m98.5 219.5q6.09-0.98 6 5-3.04 0.24-6-0.5-1.84-2.24 0-4.5z"/>
|
||||||
|
</g>
|
||||||
|
<g>
|
||||||
|
<path fill-rule="evenodd" class="s26" d="m176.5 229.5q8.85-1.14 16 4-4.98 1.75-10 0-13.56 14.3-33 19.5-28.06 8.2-55 1 3.32-6.4 10-5.5-0.71 1.47-2 2.5 36.58 4.24 69-14 4.68-2.13 1-5 2.35-0.91 4-2.5z"/>
|
||||||
|
</g>
|
||||||
|
<g>
|
||||||
|
<path fill-rule="evenodd" class="s27" d="m231.5 238.5q1.31-0.2 2 1-3.13 28.62 15 51-16.25 6.75-27-7.5-1-1-2 0 14.73 29.34 46 18.5 1.79 0.52 0 1.5-37.63 16.82-50.5-22.5-5.1-26.48 16.5-42z"/>
|
||||||
|
</g>
|
||||||
|
<g>
|
||||||
|
<path fill-rule="evenodd" class="s28" d="m243.5 259.5q5.88 3.62 10.5 9 12.96 18.46 32.5 29.5-31.51-7.75-43-38.5z"/>
|
||||||
|
</g>
|
||||||
|
<g>
|
||||||
|
<path fill-rule="evenodd" class="s29" d="m203.5 266.5q1.31-0.2 2 1-2.48 22.08 12 39-6.99 1.35-14 0.5 4.59 4.08 10 7-8.71 0.28-14.5-6.5-16.98-22.76 4.5-41z"/>
|
||||||
|
</g>
|
||||||
|
<g>
|
||||||
|
<path fill-rule="evenodd" class="s27" d="m58.5 284.5q9.6-2.17 14.5 6 5.15 14.18-1 28-11.05-13.14-27.5-17.5 5.15-9.9 14-16.5z"/>
|
||||||
|
</g>
|
||||||
|
<g>
|
||||||
|
<path fill-rule="evenodd" class="s30" d="m129.5 288.5q2 1 4 2-3.14 0.27-6-1-0.77-1.4 2-1z"/>
|
||||||
|
</g>
|
||||||
|
<g>
|
||||||
|
<path fill-rule="evenodd" class="s31" d="m56.5 313.5q3.43 5.43 8 10-4.88 0.44-8 4-1.11-0.2-2 0.5 28.91 1.65 38 28.5 0.45 3.16-1 6-11.02-7.01-23-12.5-4.75-3.75-9.5-7.5 1.47 7.42 7 13 8.34 27.18 32 43 0.99 2.41-1.5 3.5-40.25 5.58-66.5-25.5-15.67-22.01-8-48 10.46-23.87 34.5-15z"/>
|
||||||
|
</g>
|
||||||
|
<g>
|
||||||
|
<path fill-rule="evenodd" class="s32" d="m45.5 317.5q4.03-0.25 8 0.5 2.46 4.16-2 6-6.04 2.01-9-3.5 1.26-1.85 3-3z"/>
|
||||||
|
</g>
|
||||||
|
<g>
|
||||||
|
<path fill-rule="evenodd" class="s33" d="m56.5 313.5q4.91 3.14 9.5 7 0.88 2.25-1.5 3-4.57-4.57-8-10z"/>
|
||||||
|
</g>
|
||||||
|
<g>
|
||||||
|
<path fill-rule="evenodd" class="s34" d="m198.5 319.5q-11.1 11.56-27 15.5-15.75 4.88-32 2.5 28.81-3.69 54-18.5 2.65-0.96 5 0.5z"/>
|
||||||
|
</g>
|
||||||
|
<g>
|
||||||
|
<path fill-rule="evenodd" class="s4" d="m198.5 319.5q1.44 0.68 2.5 2 2.41 8.23 6 16 1.2 2.64-0.5 5-30.65 21.41-68 18.5-25.16-6.17-32.5-30.5 6.96 4.99 15.5 6.5 8.99 0.75 18 0.5 16.25 2.38 32-2.5 15.9-3.94 27-15.5z"/>
|
||||||
|
</g>
|
||||||
|
<g>
|
||||||
|
<path fill-rule="evenodd" class="s35" d="m92.5 356.5q-9.09-26.85-38-28.5 0.89-0.7 2-0.5 25.47-4.89 35.5 19 0.75 4.98 0.5 10z"/>
|
||||||
|
</g>
|
||||||
|
<g>
|
||||||
|
<path fill-rule="evenodd" class="s36" d="m72.5 335.5q3.62-0.38 5 3-4.22 1.83-5-3z"/>
|
||||||
|
</g>
|
||||||
|
<g>
|
||||||
|
<path fill-rule="evenodd" class="s37" d="m223.5 336.5q5.59-0.48 11 1-4.04 4.16-8.5 8-5.99-3.8-2.5-9z"/>
|
||||||
|
</g>
|
||||||
|
<g>
|
||||||
|
<path fill-rule="evenodd" class="s38" d="m90.5 334.5q0.59-1.54 2-0.5 3.94 5.45 9 10 7 6 14 12-6.91-1.7-13-6-6.21-7.72-12-15.5z"/>
|
||||||
|
</g>
|
||||||
|
<g>
|
||||||
|
<path fill-rule="evenodd" class="s39" d="m261.5 346.5q-3.54-2.44-8-3.5-6.98-0.75-14-0.5 0.63-1.08 2-1.5 13.82-2.52 26 4-2.63 1.98-6 1.5z"/>
|
||||||
|
</g>
|
||||||
|
<g>
|
||||||
|
<path fill-rule="evenodd" class="s40" d="m239.5 342.5q7.02-0.25 14 0.5 4.46 1.06 8 3.5-5.2 2.35-10 5.5-3.88 4.65-9 7.5-9.89-3.09-9.5-13 2.36-3.63 6.5-4z"/>
|
||||||
|
</g>
|
||||||
|
<g>
|
||||||
|
<path fill-rule="evenodd" class="s41" d="m214.5 349.5q-21.43 15.48-48 16 22.82-5.9 43-18.5 3.64-1.12 5 2.5z"/>
|
||||||
|
</g>
|
||||||
|
<g>
|
||||||
|
<path fill-rule="evenodd" class="s42" d="m214.5 349.5q5.96 7.2 13.5 13 1 1 0 2-28.58 23.34-65.5 20.5-18.15-4.24-27.5-19.5 1.13 0.94 2.5 1.5 14.7 1.42 29-1.5 26.57-0.52 48-16z"/>
|
||||||
|
</g>
|
||||||
|
<g>
|
||||||
|
<path fill-rule="evenodd" class="s43" d="m302.5 373.5q-14.74-16.73-37-19-4.55 0.25-9 1 25.3-10.24 43.5 11 2.85 2.91 2.5 7z"/>
|
||||||
|
</g>
|
||||||
|
<g>
|
||||||
|
<path fill-rule="evenodd" class="s44" d="m302.5 373.5q0.21 2.44-2 3.5-28.69 7.6-50.5-12.5-0.06-6.71 6.5-9 4.45-0.75 9-1 22.26 2.27 37 19z"/>
|
||||||
|
</g>
|
||||||
|
<g>
|
||||||
|
<path fill-rule="evenodd" class="s45" d="m100.5 356.5q5.42 2.71 11 5.5-13.04 7.54-18.5 21.5-7.57-7.14-10.5-17 5.58 1.54 10 5.5 4.2 0.84 5.5-3.5 1.41-5.99 2.5-12z"/>
|
||||||
|
</g>
|
||||||
|
<g>
|
||||||
|
<path fill-rule="evenodd" class="s8" d="m83.5 394.5q-18.9-10.15-29.5-29-1.54-3.52-2-7 5.79 2.39 10 7 7.82 16.63 21.5 29z"/>
|
||||||
|
</g>
|
||||||
|
<g>
|
||||||
|
<path fill-rule="evenodd" class="s46" d="m232.5 365.5q17.6 6.19 10.5 23-10.6 10.42-25.5 11.5-25.94 3.21-49-9 36.75-1.65 64-25.5z"/>
|
||||||
|
</g>
|
||||||
|
<g>
|
||||||
|
<path fill-rule="evenodd" class="s47" d="m113.5 367.5q7.7-0.01 9.5 7-9.69 7.19-18.5 15.5-7.23 5.76-5.5-3.5 3.12-12.84 14.5-19z"/>
|
||||||
|
</g>
|
||||||
|
<g>
|
||||||
|
<path fill-rule="evenodd" class="s29" d="m126.5 380.5q7.88-0.4 12 6.5-8.5 7.25-17 14.5-5.62-12.55 5-21z"/>
|
||||||
|
</g>
|
||||||
|
<g>
|
||||||
|
<path fill-rule="evenodd" class="s48" d="m283.5 385.5q3.22 2.95 7 5.5 2.8 4.03 6 7.5 0.42 2.77-2 4-15.5-9.75-31-19.5-1.79-0.98 0-1.5 9.96 2.49 20 4z"/>
|
||||||
|
</g>
|
||||||
|
<g>
|
||||||
|
<path fill-rule="evenodd" class="s49" d="m283.5 385.5q8.71-1.27 11.5 7 1.22 2.9 1.5 6-3.2-3.47-6-7.5-3.78-2.55-7-5.5z"/>
|
||||||
|
</g>
|
||||||
|
<g>
|
||||||
|
<path fill-rule="evenodd" class="s50" d="m83.5 394.5q1.88-0.06 3 1.5-2.25 0.88-3-1.5z"/>
|
||||||
|
</g>
|
||||||
|
<g>
|
||||||
|
<path fill-rule="evenodd" class="s51" d="m258.5 392.5q3.51 0.41 0 2.5-2.33 1.93-5 2 2.61-2.28 5-4.5z"/>
|
||||||
|
</g>
|
||||||
|
<g>
|
||||||
|
<path fill-rule="evenodd" class="s52" d="m111.5 392.5q0.09-0.81 1-1 1.48 4.9 1 10-1-4.5-2-9z"/>
|
||||||
|
</g>
|
||||||
|
</svg>
|
||||||
|
After Width: | Height: | Size: 13 KiB |
7
app/ui/app/public/launch-icons/opencode.svg
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
<svg xmlns="http://www.w3.org/2000/svg" version="1.1" xmlns:xlink="http://www.w3.org/1999/xlink" width="512" height="512"><svg width="512" height="512" viewBox="0 0 512 512" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||||
|
<rect width="512" height="512" fill="#131010"></rect>
|
||||||
|
<path d="M320 224V352H192V224H320Z" fill="#5A5858"></path>
|
||||||
|
<path fill-rule="evenodd" clip-rule="evenodd" d="M384 416H128V96H384V416ZM320 160H192V352H320V160Z" fill="white"></path>
|
||||||
|
</svg><style>@media (prefers-color-scheme: light) { :root { filter: none; } }
|
||||||
|
@media (prefers-color-scheme: dark) { :root { filter: none; } }
|
||||||
|
</style></svg>
|
||||||
|
After Width: | Height: | Size: 612 B |
9
app/ui/app/public/launch-icons/pi-dark.svg
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 800 800">
|
||||||
|
<rect width="800" height="800" rx="160" fill="#fff"/>
|
||||||
|
<path fill="#000" fill-rule="evenodd" d="
|
||||||
|
M165.29 165.29 H517.36 V400 H400 V517.36 H282.65 V634.72 H165.29 Z
|
||||||
|
M282.65 282.65 V400 H400 V282.65 Z
|
||||||
|
"/>
|
||||||
|
<path fill="#000" d="M517.36 400 H634.72 V634.72 H517.36 Z"/>
|
||||||
|
</svg>
|
||||||
|
After Width: | Height: | Size: 389 B |
9
app/ui/app/public/launch-icons/pi.svg
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 800 800">
|
||||||
|
<rect width="800" height="800" rx="160" fill="#000"/>
|
||||||
|
<path fill="#fff" fill-rule="evenodd" d="
|
||||||
|
M165.29 165.29 H517.36 V400 H400 V517.36 H282.65 V634.72 H165.29 Z
|
||||||
|
M282.65 282.65 V400 H400 V282.65 Z
|
||||||
|
"/>
|
||||||
|
<path fill="#fff" d="M517.36 400 H634.72 V634.72 H517.36 Z"/>
|
||||||
|
</svg>
|
||||||
|
After Width: | Height: | Size: 389 B |
@@ -161,7 +161,7 @@ export async function getModels(query?: string): Promise<Model[]> {
|
|||||||
// Add query if it's in the registry and not already in the list
|
// Add query if it's in the registry and not already in the list
|
||||||
if (!exactMatch) {
|
if (!exactMatch) {
|
||||||
const result = await getModelUpstreamInfo(new Model({ model: query }));
|
const result = await getModelUpstreamInfo(new Model({ model: query }));
|
||||||
const existsUpstream = !!result.digest && !result.error;
|
const existsUpstream = result.exists;
|
||||||
if (existsUpstream) {
|
if (existsUpstream) {
|
||||||
filteredModels.push(new Model({ model: query }));
|
filteredModels.push(new Model({ model: query }));
|
||||||
}
|
}
|
||||||
@@ -339,7 +339,7 @@ export async function deleteChat(chatId: string): Promise<void> {
|
|||||||
// Get upstream information for model staleness checking
|
// Get upstream information for model staleness checking
|
||||||
export async function getModelUpstreamInfo(
|
export async function getModelUpstreamInfo(
|
||||||
model: Model,
|
model: Model,
|
||||||
): Promise<{ digest?: string; pushTime: number; error?: string }> {
|
): Promise<{ stale: boolean; exists: boolean; error?: string }> {
|
||||||
try {
|
try {
|
||||||
const response = await fetch(`${API_BASE}/api/v1/model/upstream`, {
|
const response = await fetch(`${API_BASE}/api/v1/model/upstream`, {
|
||||||
method: "POST",
|
method: "POST",
|
||||||
@@ -353,22 +353,22 @@ export async function getModelUpstreamInfo(
|
|||||||
|
|
||||||
if (!response.ok) {
|
if (!response.ok) {
|
||||||
console.warn(
|
console.warn(
|
||||||
`Failed to check upstream digest for ${model.model}: ${response.status}`,
|
`Failed to check upstream for ${model.model}: ${response.status}`,
|
||||||
);
|
);
|
||||||
return { pushTime: 0 };
|
return { stale: false, exists: false };
|
||||||
}
|
}
|
||||||
|
|
||||||
const data = await response.json();
|
const data = await response.json();
|
||||||
|
|
||||||
if (data.error) {
|
if (data.error) {
|
||||||
console.warn(`Upstream digest check: ${data.error}`);
|
console.warn(`Upstream check: ${data.error}`);
|
||||||
return { error: data.error, pushTime: 0 };
|
return { stale: false, exists: false, error: data.error };
|
||||||
}
|
}
|
||||||
|
|
||||||
return { digest: data.digest, pushTime: data.pushTime || 0 };
|
return { stale: !!data.stale, exists: true };
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
console.warn(`Error checking model staleness:`, error);
|
console.warn(`Error checking model staleness:`, error);
|
||||||
return { pushTime: 0 };
|
return { stale: false, exists: false };
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -480,13 +480,15 @@ function ChatForm({
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Prepare attachments for submission
|
// Prepare attachments for submission, excluding unsupported images
|
||||||
const attachmentsToSend: FileAttachment[] = message.attachments.map(
|
const attachmentsToSend: FileAttachment[] = message.attachments
|
||||||
(att) => ({
|
.filter(
|
||||||
|
(att) => hasVisionCapability || !isImageFile(att.filename),
|
||||||
|
)
|
||||||
|
.map((att) => ({
|
||||||
filename: att.filename,
|
filename: att.filename,
|
||||||
data: att.data || new Uint8Array(0), // Empty data for existing files
|
data: att.data || new Uint8Array(0), // Empty data for existing files
|
||||||
}),
|
}));
|
||||||
);
|
|
||||||
|
|
||||||
const useWebSearch =
|
const useWebSearch =
|
||||||
supportsWebSearch && webSearchEnabled && !cloudDisabled;
|
supportsWebSearch && webSearchEnabled && !cloudDisabled;
|
||||||
@@ -736,10 +738,17 @@ function ChatForm({
|
|||||||
)}
|
)}
|
||||||
{(message.attachments.length > 0 || message.fileErrors.length > 0) && (
|
{(message.attachments.length > 0 || message.fileErrors.length > 0) && (
|
||||||
<div className="flex gap-2 overflow-x-auto px-3 pt pb-3 w-full scrollbar-hide">
|
<div className="flex gap-2 overflow-x-auto px-3 pt pb-3 w-full scrollbar-hide">
|
||||||
{message.attachments.map((attachment, index) => (
|
{message.attachments.map((attachment, index) => {
|
||||||
|
const isUnsupportedImage =
|
||||||
|
!hasVisionCapability && isImageFile(attachment.filename);
|
||||||
|
return (
|
||||||
<div
|
<div
|
||||||
key={attachment.id}
|
key={attachment.id}
|
||||||
className="group flex items-center gap-2 py-2 px-3 rounded-lg bg-neutral-50 dark:bg-neutral-700/50 hover:bg-neutral-100 dark:hover:bg-neutral-700 transition-colors flex-shrink-0"
|
className={`group flex items-center gap-2 py-2 px-3 rounded-lg transition-colors flex-shrink-0 ${
|
||||||
|
isUnsupportedImage
|
||||||
|
? "bg-red-50 dark:bg-red-900/20 border border-red-200 dark:border-red-800"
|
||||||
|
: "bg-neutral-50 dark:bg-neutral-700/50 hover:bg-neutral-100 dark:hover:bg-neutral-700"
|
||||||
|
}`}
|
||||||
>
|
>
|
||||||
{isImageFile(attachment.filename) ? (
|
{isImageFile(attachment.filename) ? (
|
||||||
<ImageThumbnail
|
<ImageThumbnail
|
||||||
@@ -764,9 +773,16 @@ function ChatForm({
|
|||||||
/>
|
/>
|
||||||
</svg>
|
</svg>
|
||||||
)}
|
)}
|
||||||
<span className="text-sm text-neutral-700 dark:text-neutral-300 max-w-[150px] truncate">
|
<div className="flex flex-col min-w-0">
|
||||||
{attachment.filename}
|
<span className={`text-sm max-w-36 truncate ${isUnsupportedImage ? "text-red-700 dark:text-red-300" : "text-neutral-700 dark:text-neutral-300"}`}>
|
||||||
</span>
|
{attachment.filename}
|
||||||
|
</span>
|
||||||
|
{isUnsupportedImage && (
|
||||||
|
<span className="text-xs text-red-600 dark:text-red-400 opacity-75">
|
||||||
|
This model does not support images
|
||||||
|
</span>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
<button
|
<button
|
||||||
type="button"
|
type="button"
|
||||||
onClick={() => removeFile(index)}
|
onClick={() => removeFile(index)}
|
||||||
@@ -788,7 +804,8 @@ function ChatForm({
|
|||||||
</svg>
|
</svg>
|
||||||
</button>
|
</button>
|
||||||
</div>
|
</div>
|
||||||
))}
|
);
|
||||||
|
})}
|
||||||
{message.fileErrors.map((fileError, index) => (
|
{message.fileErrors.map((fileError, index) => (
|
||||||
<div
|
<div
|
||||||
key={`error-${index}`}
|
key={`error-${index}`}
|
||||||
|
|||||||
@@ -6,12 +6,13 @@ import { getChat } from "@/api";
|
|||||||
import { Link } from "@/components/ui/link";
|
import { Link } from "@/components/ui/link";
|
||||||
import { useState, useRef, useEffect, useCallback, useMemo } from "react";
|
import { useState, useRef, useEffect, useCallback, useMemo } from "react";
|
||||||
import { ChatsResponse } from "@/gotypes";
|
import { ChatsResponse } from "@/gotypes";
|
||||||
import { CogIcon } from "@heroicons/react/24/outline";
|
import { CogIcon, RocketLaunchIcon } from "@heroicons/react/24/outline";
|
||||||
|
|
||||||
// there's a hidden debug feature to copy a chat's data to the clipboard by
|
// there's a hidden debug feature to copy a chat's data to the clipboard by
|
||||||
// holding shift and clicking this many times within this many seconds
|
// holding shift and clicking this many times within this many seconds
|
||||||
const DEBUG_SHIFT_CLICKS_REQUIRED = 5;
|
const DEBUG_SHIFT_CLICKS_REQUIRED = 5;
|
||||||
const DEBUG_SHIFT_CLICK_WINDOW_MS = 7000; // 7 seconds
|
const DEBUG_SHIFT_CLICK_WINDOW_MS = 7000; // 7 seconds
|
||||||
|
const launchSidebarRequestedKey = "ollama.launchSidebarRequested";
|
||||||
|
|
||||||
interface ChatSidebarProps {
|
interface ChatSidebarProps {
|
||||||
currentChatId?: string;
|
currentChatId?: string;
|
||||||
@@ -267,9 +268,8 @@ export function ChatSidebar({ currentChatId }: ChatSidebarProps) {
|
|||||||
<Link
|
<Link
|
||||||
href="/c/new"
|
href="/c/new"
|
||||||
mask={{ to: "/" }}
|
mask={{ to: "/" }}
|
||||||
className={`flex w-full items-center gap-3 rounded-lg px-2 py-2 text-left text-sm text-neutral-700 hover:bg-neutral-100 dark:hover:bg-neutral-800 dark:text-neutral-100 ${
|
className={`flex w-full items-center gap-3 rounded-lg px-2 py-2 text-left text-sm text-neutral-700 hover:bg-neutral-100 dark:hover:bg-neutral-800 dark:text-neutral-100 ${currentChatId === "new" ? "bg-neutral-100 dark:bg-neutral-800" : ""
|
||||||
currentChatId === "new" ? "bg-neutral-100 dark:bg-neutral-800" : ""
|
}`}
|
||||||
}`}
|
|
||||||
draggable={false}
|
draggable={false}
|
||||||
>
|
>
|
||||||
<svg
|
<svg
|
||||||
@@ -283,6 +283,23 @@ export function ChatSidebar({ currentChatId }: ChatSidebarProps) {
|
|||||||
</svg>
|
</svg>
|
||||||
<span className="truncate">New Chat</span>
|
<span className="truncate">New Chat</span>
|
||||||
</Link>
|
</Link>
|
||||||
|
<Link
|
||||||
|
to="/c/$chatId"
|
||||||
|
params={{ chatId: "launch" }}
|
||||||
|
onClick={() => {
|
||||||
|
if (currentChatId !== "launch") {
|
||||||
|
sessionStorage.setItem(launchSidebarRequestedKey, "1");
|
||||||
|
}
|
||||||
|
}}
|
||||||
|
className={`flex w-full items-center gap-3 rounded-lg px-2 py-2 text-left text-sm text-neutral-700 hover:bg-neutral-100 dark:hover:bg-neutral-800 dark:text-neutral-100 cursor-pointer ${currentChatId === "launch"
|
||||||
|
? "bg-neutral-100 dark:bg-neutral-800"
|
||||||
|
: ""
|
||||||
|
}`}
|
||||||
|
draggable={false}
|
||||||
|
>
|
||||||
|
<RocketLaunchIcon className="h-5 w-5 stroke-current" />
|
||||||
|
<span className="truncate">Launch</span>
|
||||||
|
</Link>
|
||||||
{isWindows && (
|
{isWindows && (
|
||||||
<Link
|
<Link
|
||||||
href="/settings"
|
href="/settings"
|
||||||
@@ -304,19 +321,18 @@ export function ChatSidebar({ currentChatId }: ChatSidebarProps) {
|
|||||||
{group.chats.map((chat) => (
|
{group.chats.map((chat) => (
|
||||||
<div
|
<div
|
||||||
key={chat.id}
|
key={chat.id}
|
||||||
className={`allow-context-menu flex items-center relative text-sm text-neutral-800 dark:text-neutral-400 rounded-lg hover:bg-neutral-100 dark:hover:bg-neutral-800 ${
|
className={`allow-context-menu flex items-center relative text-sm text-neutral-800 dark:text-neutral-400 rounded-lg hover:bg-neutral-100 dark:hover:bg-neutral-800 ${chat.id === currentChatId
|
||||||
chat.id === currentChatId
|
? "bg-neutral-100 text-black dark:bg-neutral-800"
|
||||||
? "bg-neutral-100 text-black dark:bg-neutral-800"
|
: ""
|
||||||
: ""
|
}`}
|
||||||
}`}
|
|
||||||
onMouseEnter={() => handleMouseEnter(chat.id)}
|
onMouseEnter={() => handleMouseEnter(chat.id)}
|
||||||
onContextMenu={(e) =>
|
onContextMenu={(e) =>
|
||||||
handleContextMenu(
|
handleContextMenu(
|
||||||
e,
|
e,
|
||||||
chat.id,
|
chat.id,
|
||||||
chat.title ||
|
chat.title ||
|
||||||
chat.userExcerpt ||
|
chat.userExcerpt ||
|
||||||
chat.createdAt.toLocaleString(),
|
chat.createdAt.toLocaleString(),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
>
|
>
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ interface CopyButtonProps {
|
|||||||
showLabels?: boolean;
|
showLabels?: boolean;
|
||||||
className?: string;
|
className?: string;
|
||||||
title?: string;
|
title?: string;
|
||||||
|
onCopy?: () => void;
|
||||||
}
|
}
|
||||||
|
|
||||||
const CopyButton: React.FC<CopyButtonProps> = ({
|
const CopyButton: React.FC<CopyButtonProps> = ({
|
||||||
@@ -20,6 +21,7 @@ const CopyButton: React.FC<CopyButtonProps> = ({
|
|||||||
showLabels = false,
|
showLabels = false,
|
||||||
className = "",
|
className = "",
|
||||||
title = "",
|
title = "",
|
||||||
|
onCopy,
|
||||||
}) => {
|
}) => {
|
||||||
const [isCopied, setIsCopied] = useState(false);
|
const [isCopied, setIsCopied] = useState(false);
|
||||||
|
|
||||||
@@ -48,12 +50,14 @@ const CopyButton: React.FC<CopyButtonProps> = ({
|
|||||||
}
|
}
|
||||||
|
|
||||||
setIsCopied(true);
|
setIsCopied(true);
|
||||||
|
onCopy?.();
|
||||||
setTimeout(() => setIsCopied(false), 2000);
|
setTimeout(() => setIsCopied(false), 2000);
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
console.error("Clipboard API failed, falling back to plain text", error);
|
console.error("Clipboard API failed, falling back to plain text", error);
|
||||||
try {
|
try {
|
||||||
await navigator.clipboard.writeText(content);
|
await navigator.clipboard.writeText(content);
|
||||||
setIsCopied(true);
|
setIsCopied(true);
|
||||||
|
onCopy?.();
|
||||||
setTimeout(() => setIsCopied(false), 2000);
|
setTimeout(() => setIsCopied(false), 2000);
|
||||||
} catch (fallbackError) {
|
} catch (fallbackError) {
|
||||||
console.error("Fallback copy also failed:", fallbackError);
|
console.error("Fallback copy also failed:", fallbackError);
|
||||||
|
|||||||
133
app/ui/app/src/components/LaunchCommands.tsx
Normal file
@@ -0,0 +1,133 @@
|
|||||||
|
import { useSettings } from "@/hooks/useSettings";
|
||||||
|
import CopyButton from "@/components/CopyButton";
|
||||||
|
|
||||||
|
interface LaunchCommand {
|
||||||
|
id: string;
|
||||||
|
name: string;
|
||||||
|
command: string;
|
||||||
|
description: string;
|
||||||
|
icon: string;
|
||||||
|
darkIcon?: string;
|
||||||
|
iconClassName?: string;
|
||||||
|
borderless?: boolean;
|
||||||
|
}
|
||||||
|
|
||||||
|
const LAUNCH_COMMANDS: LaunchCommand[] = [
|
||||||
|
{
|
||||||
|
id: "openclaw",
|
||||||
|
name: "OpenClaw",
|
||||||
|
command: "ollama launch openclaw",
|
||||||
|
description: "Personal AI with 100+ skills",
|
||||||
|
icon: "/launch-icons/openclaw.svg",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
id: "claude",
|
||||||
|
name: "Claude",
|
||||||
|
command: "ollama launch claude",
|
||||||
|
description: "Anthropic's coding tool with subagents",
|
||||||
|
icon: "/launch-icons/claude.svg",
|
||||||
|
iconClassName: "h-7 w-7",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
id: "codex",
|
||||||
|
name: "Codex",
|
||||||
|
command: "ollama launch codex",
|
||||||
|
description: "OpenAI's open-source coding agent",
|
||||||
|
icon: "/launch-icons/codex.svg",
|
||||||
|
darkIcon: "/launch-icons/codex-dark.svg",
|
||||||
|
iconClassName: "h-7 w-7",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
id: "opencode",
|
||||||
|
name: "OpenCode",
|
||||||
|
command: "ollama launch opencode",
|
||||||
|
description: "Anomaly's open-source coding agent",
|
||||||
|
icon: "/launch-icons/opencode.svg",
|
||||||
|
iconClassName: "h-7 w-7 rounded",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
id: "droid",
|
||||||
|
name: "Droid",
|
||||||
|
command: "ollama launch droid",
|
||||||
|
description: "Factory's coding agent across terminal and IDEs",
|
||||||
|
icon: "/launch-icons/droid.svg",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
id: "pi",
|
||||||
|
name: "Pi",
|
||||||
|
command: "ollama launch pi",
|
||||||
|
description: "Minimal AI agent toolkit with plugin support",
|
||||||
|
icon: "/launch-icons/pi.svg",
|
||||||
|
darkIcon: "/launch-icons/pi-dark.svg",
|
||||||
|
iconClassName: "h-7 w-7",
|
||||||
|
},
|
||||||
|
];
|
||||||
|
|
||||||
|
export default function LaunchCommands() {
|
||||||
|
const isWindows = navigator.platform.toLowerCase().includes("win");
|
||||||
|
const { setSettings } = useSettings();
|
||||||
|
|
||||||
|
const renderCommandCard = (item: LaunchCommand) => (
|
||||||
|
<div key={item.command} className="w-full text-left">
|
||||||
|
<div className="flex items-start gap-4 sm:gap-5">
|
||||||
|
<div
|
||||||
|
aria-hidden="true"
|
||||||
|
className={`flex h-10 w-10 shrink-0 items-center justify-center rounded-lg overflow-hidden ${item.borderless ? "" : "border border-neutral-200 bg-white dark:border-neutral-700 dark:bg-neutral-900"}`}
|
||||||
|
>
|
||||||
|
{item.darkIcon ? (
|
||||||
|
<picture>
|
||||||
|
<source srcSet={item.darkIcon} media="(prefers-color-scheme: dark)" />
|
||||||
|
<img src={item.icon} alt="" className={`${item.iconClassName ?? "h-8 w-8"} rounded-sm`} />
|
||||||
|
</picture>
|
||||||
|
) : (
|
||||||
|
<img src={item.icon} alt="" className={item.borderless ? "h-full w-full rounded-xl" : `${item.iconClassName ?? "h-8 w-8"} rounded-sm`} />
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div className="min-w-0 flex-1">
|
||||||
|
<span className="text-sm font-medium text-neutral-900 dark:text-neutral-100">
|
||||||
|
{item.name}
|
||||||
|
</span>
|
||||||
|
<p className="mt-0.5 text-xs text-neutral-500 dark:text-neutral-400">
|
||||||
|
{item.description}
|
||||||
|
</p>
|
||||||
|
<div className="mt-2 flex items-center gap-2 rounded-xl border-neutral-200 dark:border-neutral-700 bg-neutral-50 dark:bg-neutral-800 px-3 py-2">
|
||||||
|
<code className="min-w-0 flex-1 truncate text-xs text-neutral-600 dark:text-neutral-300">
|
||||||
|
{item.command}
|
||||||
|
</code>
|
||||||
|
<CopyButton
|
||||||
|
content={item.command}
|
||||||
|
size="md"
|
||||||
|
title="Copy command to clipboard"
|
||||||
|
className="text-neutral-500 dark:text-neutral-400 hover:text-neutral-700 dark:hover:text-neutral-200 hover:bg-neutral-200/60 dark:hover:bg-neutral-700/70"
|
||||||
|
onCopy={() => {
|
||||||
|
setSettings({ LastHomeView: item.id }).catch(() => { });
|
||||||
|
}}
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<main className="flex h-screen w-full flex-col relative">
|
||||||
|
<section
|
||||||
|
className={`flex-1 overflow-y-auto overscroll-contain relative min-h-0 ${isWindows ? "xl:pt-4" : "xl:pt-8"}`}
|
||||||
|
>
|
||||||
|
<div className="max-w-[730px] mx-auto w-full px-4 pt-4 pb-20 sm:px-6 sm:pt-6 sm:pb-24 lg:px-8 lg:pt-8 lg:pb-28">
|
||||||
|
<h1 className="text-xl font-semibold text-neutral-900 dark:text-neutral-100">
|
||||||
|
Launch
|
||||||
|
</h1>
|
||||||
|
<p className="mt-1 text-sm text-neutral-500 dark:text-neutral-400">
|
||||||
|
Copy a command and run it in your terminal.
|
||||||
|
</p>
|
||||||
|
|
||||||
|
<div className="mt-6 grid gap-7">
|
||||||
|
{LAUNCH_COMMANDS.map(renderCommandCard)}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</section>
|
||||||
|
</main>
|
||||||
|
);
|
||||||
|
}
|
||||||
@@ -536,7 +536,7 @@ function ToolCallDisplay({
|
|||||||
let args: Record<string, unknown> | null = null;
|
let args: Record<string, unknown> | null = null;
|
||||||
try {
|
try {
|
||||||
args = JSON.parse(toolCall.function.arguments) as Record<string, unknown>;
|
args = JSON.parse(toolCall.function.arguments) as Record<string, unknown>;
|
||||||
} catch (e) {
|
} catch {
|
||||||
args = null;
|
args = null;
|
||||||
}
|
}
|
||||||
const query = args && typeof args.query === "string" ? args.query : "";
|
const query = args && typeof args.query === "string" ? args.query : "";
|
||||||
@@ -562,7 +562,7 @@ function ToolCallDisplay({
|
|||||||
let args: Record<string, unknown> | null = null;
|
let args: Record<string, unknown> | null = null;
|
||||||
try {
|
try {
|
||||||
args = JSON.parse(toolCall.function.arguments) as Record<string, unknown>;
|
args = JSON.parse(toolCall.function.arguments) as Record<string, unknown>;
|
||||||
} catch (e) {
|
} catch {
|
||||||
args = null;
|
args = null;
|
||||||
}
|
}
|
||||||
const url = args && typeof args.url === "string" ? args.url : "";
|
const url = args && typeof args.url === "string" ? args.url : "";
|
||||||
|
|||||||
@@ -73,7 +73,7 @@ export default function MessageList({
|
|||||||
? String(args.url).trim()
|
? String(args.url).trim()
|
||||||
: "";
|
: "";
|
||||||
if (candidate) lastQuery = candidate;
|
if (candidate) lastQuery = candidate;
|
||||||
} catch {}
|
} catch { /* ignored */ }
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -61,24 +61,7 @@ export const ModelPicker = forwardRef<
|
|||||||
try {
|
try {
|
||||||
const upstreamInfo = await getModelUpstreamInfo(model);
|
const upstreamInfo = await getModelUpstreamInfo(model);
|
||||||
|
|
||||||
// Compare local digest with upstream digest
|
if (upstreamInfo.stale) {
|
||||||
let isStale =
|
|
||||||
model.digest &&
|
|
||||||
upstreamInfo.digest &&
|
|
||||||
model.digest !== upstreamInfo.digest;
|
|
||||||
|
|
||||||
// If the model has a modified time and upstream has a push time,
|
|
||||||
// check if the model was modified after the push time - if so, it's not stale
|
|
||||||
if (isStale && model.modified_at && upstreamInfo.pushTime > 0) {
|
|
||||||
const modifiedAtTime =
|
|
||||||
new Date(model.modified_at as string | number | Date).getTime() /
|
|
||||||
1000;
|
|
||||||
if (modifiedAtTime > upstreamInfo.pushTime) {
|
|
||||||
isStale = false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (isStale) {
|
|
||||||
const currentStaleModels =
|
const currentStaleModels =
|
||||||
queryClient.getQueryData<Map<string, boolean>>(["staleModels"]) ||
|
queryClient.getQueryData<Map<string, boolean>>(["staleModels"]) ||
|
||||||
new Map();
|
new Map();
|
||||||
|
|||||||
@@ -214,6 +214,7 @@ export default function Settings() {
|
|||||||
Agent: false,
|
Agent: false,
|
||||||
Tools: false,
|
Tools: false,
|
||||||
ContextLength: 0,
|
ContextLength: 0,
|
||||||
|
AutoUpdateEnabled: true,
|
||||||
});
|
});
|
||||||
updateSettingsMutation.mutate(defaultSettings);
|
updateSettingsMutation.mutate(defaultSettings);
|
||||||
}
|
}
|
||||||
@@ -272,6 +273,10 @@ export default function Settings() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
const isWindows = navigator.platform.toLowerCase().includes("win");
|
const isWindows = navigator.platform.toLowerCase().includes("win");
|
||||||
|
const handleCloseSettings = () => {
|
||||||
|
const chatId = settings.LastHomeView === "chat" ? "new" : "launch";
|
||||||
|
navigate({ to: "/c/$chatId", params: { chatId } });
|
||||||
|
};
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<main className="flex h-screen w-full flex-col select-none dark:bg-neutral-900">
|
<main className="flex h-screen w-full flex-col select-none dark:bg-neutral-900">
|
||||||
@@ -285,7 +290,7 @@ export default function Settings() {
|
|||||||
>
|
>
|
||||||
{isWindows && (
|
{isWindows && (
|
||||||
<button
|
<button
|
||||||
onClick={() => navigate({ to: "/" })}
|
onClick={handleCloseSettings}
|
||||||
className="hover:bg-neutral-100 mr-3 dark:hover:bg-neutral-800 rounded-full p-1.5"
|
className="hover:bg-neutral-100 mr-3 dark:hover:bg-neutral-800 rounded-full p-1.5"
|
||||||
>
|
>
|
||||||
<ArrowLeftIcon className="w-5 h-5 dark:text-white" />
|
<ArrowLeftIcon className="w-5 h-5 dark:text-white" />
|
||||||
@@ -295,7 +300,7 @@ export default function Settings() {
|
|||||||
</h1>
|
</h1>
|
||||||
{!isWindows && (
|
{!isWindows && (
|
||||||
<button
|
<button
|
||||||
onClick={() => navigate({ to: "/" })}
|
onClick={handleCloseSettings}
|
||||||
className="p-1 hover:bg-neutral-100 mr-3 dark:hover:bg-neutral-800 rounded-full"
|
className="p-1 hover:bg-neutral-100 mr-3 dark:hover:bg-neutral-800 rounded-full"
|
||||||
>
|
>
|
||||||
<XMarkIcon className="w-6 h-6 dark:text-white" />
|
<XMarkIcon className="w-6 h-6 dark:text-white" />
|
||||||
|
|||||||
@@ -65,7 +65,7 @@ export const BadgeButton = forwardRef(function BadgeButton(
|
|||||||
),
|
),
|
||||||
ref: React.ForwardedRef<HTMLElement>,
|
ref: React.ForwardedRef<HTMLElement>,
|
||||||
) {
|
) {
|
||||||
let classes = clsx(
|
const classes = clsx(
|
||||||
className,
|
className,
|
||||||
"group relative inline-flex rounded-md focus:not-data-focus:outline-hidden data-focus:outline-2 data-focus:outline-offset-2 data-focus:outline-blue-500",
|
"group relative inline-flex rounded-md focus:not-data-focus:outline-hidden data-focus:outline-2 data-focus:outline-offset-2 data-focus:outline-blue-500",
|
||||||
);
|
);
|
||||||
|
|||||||
@@ -171,7 +171,7 @@ export const Button = forwardRef(function Button(
|
|||||||
{ color, outline, plain, className, children, ...props }: ButtonProps,
|
{ color, outline, plain, className, children, ...props }: ButtonProps,
|
||||||
ref: React.ForwardedRef<HTMLElement>,
|
ref: React.ForwardedRef<HTMLElement>,
|
||||||
) {
|
) {
|
||||||
let classes = clsx(
|
const classes = clsx(
|
||||||
className,
|
className,
|
||||||
styles.base,
|
styles.base,
|
||||||
outline
|
outline
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ interface SettingsState {
|
|||||||
webSearchEnabled: boolean;
|
webSearchEnabled: boolean;
|
||||||
selectedModel: string;
|
selectedModel: string;
|
||||||
sidebarOpen: boolean;
|
sidebarOpen: boolean;
|
||||||
|
lastHomeView: string;
|
||||||
thinkEnabled: boolean;
|
thinkEnabled: boolean;
|
||||||
thinkLevel: string;
|
thinkLevel: string;
|
||||||
}
|
}
|
||||||
@@ -21,6 +22,7 @@ type SettingsUpdate = Partial<{
|
|||||||
ThinkLevel: string;
|
ThinkLevel: string;
|
||||||
SelectedModel: string;
|
SelectedModel: string;
|
||||||
SidebarOpen: boolean;
|
SidebarOpen: boolean;
|
||||||
|
LastHomeView: string;
|
||||||
}>;
|
}>;
|
||||||
|
|
||||||
export function useSettings() {
|
export function useSettings() {
|
||||||
@@ -50,6 +52,7 @@ export function useSettings() {
|
|||||||
thinkLevel: settingsData?.settings?.ThinkLevel ?? "none",
|
thinkLevel: settingsData?.settings?.ThinkLevel ?? "none",
|
||||||
selectedModel: settingsData?.settings?.SelectedModel ?? "",
|
selectedModel: settingsData?.settings?.SelectedModel ?? "",
|
||||||
sidebarOpen: settingsData?.settings?.SidebarOpen ?? false,
|
sidebarOpen: settingsData?.settings?.SidebarOpen ?? false,
|
||||||
|
lastHomeView: settingsData?.settings?.LastHomeView ?? "launch",
|
||||||
}),
|
}),
|
||||||
[settingsData?.settings],
|
[settingsData?.settings],
|
||||||
);
|
);
|
||||||
|
|||||||
@@ -4,12 +4,37 @@ import Chat from "@/components/Chat";
|
|||||||
import { getChat } from "@/api";
|
import { getChat } from "@/api";
|
||||||
import { SidebarLayout } from "@/components/layout/layout";
|
import { SidebarLayout } from "@/components/layout/layout";
|
||||||
import { ChatSidebar } from "@/components/ChatSidebar";
|
import { ChatSidebar } from "@/components/ChatSidebar";
|
||||||
|
import LaunchCommands from "@/components/LaunchCommands";
|
||||||
|
import { useEffect, useRef } from "react";
|
||||||
|
import { useSettings } from "@/hooks/useSettings";
|
||||||
|
|
||||||
|
const launchSidebarRequestedKey = "ollama.launchSidebarRequested";
|
||||||
|
const launchSidebarSeenKey = "ollama.launchSidebarSeen";
|
||||||
|
const fallbackSessionState = new Map<string, string>();
|
||||||
|
|
||||||
|
function getSessionState() {
|
||||||
|
if (typeof sessionStorage !== "undefined") {
|
||||||
|
return sessionStorage;
|
||||||
|
}
|
||||||
|
|
||||||
|
return {
|
||||||
|
getItem(key: string) {
|
||||||
|
return fallbackSessionState.get(key) ?? null;
|
||||||
|
},
|
||||||
|
setItem(key: string, value: string) {
|
||||||
|
fallbackSessionState.set(key, value);
|
||||||
|
},
|
||||||
|
removeItem(key: string) {
|
||||||
|
fallbackSessionState.delete(key);
|
||||||
|
},
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
export const Route = createFileRoute("/c/$chatId")({
|
export const Route = createFileRoute("/c/$chatId")({
|
||||||
component: RouteComponent,
|
component: RouteComponent,
|
||||||
loader: async ({ context, params }) => {
|
loader: async ({ context, params }) => {
|
||||||
// Skip loading for "new" chat
|
// Skip loading for special non-chat views
|
||||||
if (params.chatId !== "new") {
|
if (params.chatId !== "new" && params.chatId !== "launch") {
|
||||||
context.queryClient.ensureQueryData({
|
context.queryClient.ensureQueryData({
|
||||||
queryKey: ["chat", params.chatId],
|
queryKey: ["chat", params.chatId],
|
||||||
queryFn: () => getChat(params.chatId),
|
queryFn: () => getChat(params.chatId),
|
||||||
@@ -21,13 +46,70 @@ export const Route = createFileRoute("/c/$chatId")({
|
|||||||
|
|
||||||
function RouteComponent() {
|
function RouteComponent() {
|
||||||
const { chatId } = Route.useParams();
|
const { chatId } = Route.useParams();
|
||||||
|
const { settingsData, setSettings } = useSettings();
|
||||||
|
const previousChatIdRef = useRef<string | null>(null);
|
||||||
|
|
||||||
// Always call hooks at the top level - use a flag to skip data when chatId is "new"
|
// Always call hooks at the top level - use a flag to skip data when chatId is a special view
|
||||||
const {
|
const {
|
||||||
data: chatData,
|
data: chatData,
|
||||||
isLoading: chatLoading,
|
isLoading: chatLoading,
|
||||||
error: chatError,
|
error: chatError,
|
||||||
} = useChat(chatId === "new" ? "" : chatId);
|
} = useChat(chatId === "new" || chatId === "launch" ? "" : chatId);
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
if (!settingsData) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const previousChatId = previousChatIdRef.current;
|
||||||
|
previousChatIdRef.current = chatId;
|
||||||
|
|
||||||
|
if (chatId === "launch") {
|
||||||
|
const sessionState = getSessionState();
|
||||||
|
const shouldOpenSidebar =
|
||||||
|
previousChatId !== "launch" &&
|
||||||
|
(() => {
|
||||||
|
if (sessionState.getItem(launchSidebarRequestedKey) === "1") {
|
||||||
|
sessionState.removeItem(launchSidebarRequestedKey);
|
||||||
|
sessionState.setItem(launchSidebarSeenKey, "1");
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (sessionState.getItem(launchSidebarSeenKey) !== "1") {
|
||||||
|
sessionState.setItem(launchSidebarSeenKey, "1");
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
return false;
|
||||||
|
})();
|
||||||
|
const updates: { LastHomeView?: string; SidebarOpen?: boolean } = {};
|
||||||
|
|
||||||
|
if (settingsData.LastHomeView !== "launch") {
|
||||||
|
updates.LastHomeView = "launch";
|
||||||
|
}
|
||||||
|
|
||||||
|
if (shouldOpenSidebar && !settingsData.SidebarOpen) {
|
||||||
|
updates.SidebarOpen = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (Object.keys(updates).length === 0) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
setSettings(updates).catch(() => {
|
||||||
|
// Best effort persistence for home view preference.
|
||||||
|
});
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (settingsData.LastHomeView === "chat") {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
setSettings({ LastHomeView: "chat" }).catch(() => {
|
||||||
|
// Best effort persistence for home view preference.
|
||||||
|
});
|
||||||
|
}, [chatId, settingsData, setSettings]);
|
||||||
|
|
||||||
// Handle "new" chat case - just use Chat component which handles everything
|
// Handle "new" chat case - just use Chat component which handles everything
|
||||||
if (chatId === "new") {
|
if (chatId === "new") {
|
||||||
@@ -38,6 +120,14 @@ function RouteComponent() {
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (chatId === "launch") {
|
||||||
|
return (
|
||||||
|
<SidebarLayout sidebar={<ChatSidebar currentChatId={chatId} />}>
|
||||||
|
<LaunchCommands />
|
||||||
|
</SidebarLayout>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
// Handle existing chat case
|
// Handle existing chat case
|
||||||
if (chatLoading) {
|
if (chatLoading) {
|
||||||
return (
|
return (
|
||||||
|
|||||||
@@ -1,10 +1,18 @@
|
|||||||
import { createFileRoute, redirect } from "@tanstack/react-router";
|
import { createFileRoute, redirect } from "@tanstack/react-router";
|
||||||
|
import { getSettings } from "@/api";
|
||||||
|
|
||||||
export const Route = createFileRoute("/")({
|
export const Route = createFileRoute("/")({
|
||||||
beforeLoad: () => {
|
beforeLoad: async ({ context }) => {
|
||||||
|
const settingsData = await context.queryClient.ensureQueryData({
|
||||||
|
queryKey: ["settings"],
|
||||||
|
queryFn: getSettings,
|
||||||
|
});
|
||||||
|
const chatId =
|
||||||
|
settingsData?.settings?.LastHomeView === "chat" ? "new" : "launch";
|
||||||
|
|
||||||
throw redirect({
|
throw redirect({
|
||||||
to: "/c/$chatId",
|
to: "/c/$chatId",
|
||||||
params: { chatId: "new" },
|
params: { chatId },
|
||||||
mask: {
|
mask: {
|
||||||
to: "/",
|
to: "/",
|
||||||
},
|
},
|
||||||
|
|||||||
57
app/ui/app/src/utils/clipboard.test.ts
Normal file
@@ -0,0 +1,57 @@
|
|||||||
|
import { describe, expect, it, vi, beforeEach } from "vitest";
|
||||||
|
import { copyTextToClipboard } from "./clipboard";
|
||||||
|
|
||||||
|
describe("copyTextToClipboard", () => {
|
||||||
|
beforeEach(() => {
|
||||||
|
vi.restoreAllMocks();
|
||||||
|
});
|
||||||
|
|
||||||
|
it("copies via Clipboard API when available", async () => {
|
||||||
|
const writeText = vi.fn().mockResolvedValue(undefined);
|
||||||
|
vi.stubGlobal("navigator", {
|
||||||
|
clipboard: {
|
||||||
|
writeText,
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
const copied = await copyTextToClipboard("ollama launch claude");
|
||||||
|
|
||||||
|
expect(copied).toBe(true);
|
||||||
|
expect(writeText).toHaveBeenCalledWith("ollama launch claude");
|
||||||
|
});
|
||||||
|
|
||||||
|
it("falls back to execCommand when Clipboard API fails", async () => {
|
||||||
|
const writeText = vi.fn().mockRejectedValue(new Error("not allowed"));
|
||||||
|
vi.stubGlobal("navigator", {
|
||||||
|
clipboard: {
|
||||||
|
writeText,
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
const textarea = {
|
||||||
|
value: "",
|
||||||
|
setAttribute: vi.fn(),
|
||||||
|
style: {} as Record<string, string>,
|
||||||
|
focus: vi.fn(),
|
||||||
|
select: vi.fn(),
|
||||||
|
};
|
||||||
|
const appendChild = vi.fn();
|
||||||
|
const removeChild = vi.fn();
|
||||||
|
const execCommand = vi.fn().mockReturnValue(true);
|
||||||
|
vi.stubGlobal("document", {
|
||||||
|
createElement: vi.fn().mockReturnValue(textarea),
|
||||||
|
body: {
|
||||||
|
appendChild,
|
||||||
|
removeChild,
|
||||||
|
},
|
||||||
|
execCommand,
|
||||||
|
});
|
||||||
|
|
||||||
|
const copied = await copyTextToClipboard("ollama launch openclaw");
|
||||||
|
|
||||||
|
expect(copied).toBe(true);
|
||||||
|
expect(execCommand).toHaveBeenCalledWith("copy");
|
||||||
|
expect(appendChild).toHaveBeenCalled();
|
||||||
|
expect(removeChild).toHaveBeenCalled();
|
||||||
|
});
|
||||||
|
});
|
||||||
30
app/ui/app/src/utils/clipboard.ts
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
export async function copyTextToClipboard(text: string): Promise<boolean> {
|
||||||
|
try {
|
||||||
|
await navigator.clipboard.writeText(text);
|
||||||
|
return true;
|
||||||
|
} catch (clipboardError) {
|
||||||
|
console.error(
|
||||||
|
"Clipboard API failed, falling back to execCommand",
|
||||||
|
clipboardError,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
try {
|
||||||
|
const textarea = document.createElement("textarea");
|
||||||
|
textarea.value = text;
|
||||||
|
textarea.setAttribute("readonly", "true");
|
||||||
|
textarea.style.position = "fixed";
|
||||||
|
textarea.style.left = "-9999px";
|
||||||
|
textarea.style.opacity = "0";
|
||||||
|
document.body.appendChild(textarea);
|
||||||
|
textarea.focus();
|
||||||
|
textarea.select();
|
||||||
|
|
||||||
|
const copied = document.execCommand("copy");
|
||||||
|
document.body.removeChild(textarea);
|
||||||
|
return copied;
|
||||||
|
} catch (fallbackError) {
|
||||||
|
console.error("Fallback copy failed", fallbackError);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -29,13 +29,15 @@ describe("fileValidation", () => {
|
|||||||
expect(result.valid).toBe(true);
|
expect(result.valid).toBe(true);
|
||||||
});
|
});
|
||||||
|
|
||||||
it("should reject WebP images when vision capability is disabled", () => {
|
it("should accept images regardless of vision capability", () => {
|
||||||
|
// Vision capability check is handled at the UI layer (ChatForm),
|
||||||
|
// not at validation time, so users can switch models without
|
||||||
|
// needing to re-upload files.
|
||||||
const file = createMockFile("test.webp", 1024, "image/webp");
|
const file = createMockFile("test.webp", 1024, "image/webp");
|
||||||
const result = validateFile(file, {
|
const result = validateFile(file, {
|
||||||
hasVisionCapability: false,
|
hasVisionCapability: false,
|
||||||
});
|
});
|
||||||
expect(result.valid).toBe(false);
|
expect(result.valid).toBe(true);
|
||||||
expect(result.error).toBe("This model does not support images");
|
|
||||||
});
|
});
|
||||||
|
|
||||||
it("should accept PNG images when vision capability is enabled", () => {
|
it("should accept PNG images when vision capability is enabled", () => {
|
||||||
|
|||||||
@@ -63,7 +63,6 @@ export function validateFile(
|
|||||||
const {
|
const {
|
||||||
maxFileSize = 10,
|
maxFileSize = 10,
|
||||||
allowedExtensions = [...TEXT_FILE_EXTENSIONS, ...IMAGE_EXTENSIONS],
|
allowedExtensions = [...TEXT_FILE_EXTENSIONS, ...IMAGE_EXTENSIONS],
|
||||||
hasVisionCapability = false,
|
|
||||||
customValidator,
|
customValidator,
|
||||||
} = options;
|
} = options;
|
||||||
|
|
||||||
@@ -83,10 +82,6 @@ export function validateFile(
|
|||||||
return { valid: false, error: "File type not supported" };
|
return { valid: false, error: "File type not supported" };
|
||||||
}
|
}
|
||||||
|
|
||||||
if (IMAGE_EXTENSIONS.includes(fileExtension) && !hasVisionCapability) {
|
|
||||||
return { valid: false, error: "This model does not support images" };
|
|
||||||
}
|
|
||||||
|
|
||||||
// File size validation
|
// File size validation
|
||||||
if (file.size > MAX_FILE_SIZE) {
|
if (file.size > MAX_FILE_SIZE) {
|
||||||
return { valid: false, error: "File too large" };
|
return { valid: false, error: "File too large" };
|
||||||
|
|||||||
@@ -2,27 +2,28 @@ import { Model } from "@/gotypes";
|
|||||||
|
|
||||||
// Featured models list (in priority order)
|
// Featured models list (in priority order)
|
||||||
export const FEATURED_MODELS = [
|
export const FEATURED_MODELS = [
|
||||||
|
"kimi-k2.5:cloud",
|
||||||
|
"glm-5:cloud",
|
||||||
|
"minimax-m2.7:cloud",
|
||||||
|
"gemma4:31b-cloud",
|
||||||
|
"qwen3.5:397b-cloud",
|
||||||
"gpt-oss:120b-cloud",
|
"gpt-oss:120b-cloud",
|
||||||
"gpt-oss:20b-cloud",
|
"gpt-oss:20b-cloud",
|
||||||
"deepseek-v3.1:671b-cloud",
|
"deepseek-v3.1:671b-cloud",
|
||||||
"qwen3-coder:480b-cloud",
|
|
||||||
"qwen3-vl:235b-cloud",
|
|
||||||
"minimax-m2:cloud",
|
|
||||||
"glm-4.6:cloud",
|
|
||||||
"gpt-oss:120b",
|
"gpt-oss:120b",
|
||||||
"gpt-oss:20b",
|
"gpt-oss:20b",
|
||||||
"gemma3:27b",
|
"gemma4:31b",
|
||||||
"gemma3:12b",
|
"gemma4:26b",
|
||||||
"gemma3:4b",
|
"gemma4:e4b",
|
||||||
"gemma3:1b",
|
"gemma4:e2b",
|
||||||
"deepseek-r1:8b",
|
"deepseek-r1:8b",
|
||||||
"qwen3-coder:30b",
|
"qwen3-coder:30b",
|
||||||
"qwen3-vl:30b",
|
"qwen3-vl:30b",
|
||||||
"qwen3-vl:8b",
|
"qwen3-vl:8b",
|
||||||
"qwen3-vl:4b",
|
"qwen3-vl:4b",
|
||||||
"qwen3:30b",
|
"qwen3.5:27b",
|
||||||
"qwen3:8b",
|
"qwen3.5:9b",
|
||||||
"qwen3:4b",
|
"qwen3.5:4b",
|
||||||
];
|
];
|
||||||
|
|
||||||
function alphabeticalSort(a: Model, b: Model): number {
|
function alphabeticalSort(a: Model, b: Model): number {
|
||||||
|
|||||||
@@ -133,9 +133,8 @@ type Error struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type ModelUpstreamResponse struct {
|
type ModelUpstreamResponse struct {
|
||||||
Digest string `json:"digest,omitempty"`
|
Stale bool `json:"stale"`
|
||||||
PushTime int64 `json:"pushTime"`
|
Error string `json:"error,omitempty"`
|
||||||
Error string `json:"error,omitempty"`
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Serializable data for the browser state
|
// Serializable data for the browser state
|
||||||
|
|||||||
39
app/ui/ui.go
@@ -32,6 +32,7 @@ import (
|
|||||||
"github.com/ollama/ollama/app/version"
|
"github.com/ollama/ollama/app/version"
|
||||||
ollamaAuth "github.com/ollama/ollama/auth"
|
ollamaAuth "github.com/ollama/ollama/auth"
|
||||||
"github.com/ollama/ollama/envconfig"
|
"github.com/ollama/ollama/envconfig"
|
||||||
|
"github.com/ollama/ollama/manifest"
|
||||||
"github.com/ollama/ollama/types/model"
|
"github.com/ollama/ollama/types/model"
|
||||||
_ "github.com/tkrajina/typescriptify-golang-structs/typescriptify"
|
_ "github.com/tkrajina/typescriptify-golang-structs/typescriptify"
|
||||||
)
|
)
|
||||||
@@ -155,7 +156,7 @@ func (s *Server) ollamaProxy() http.Handler {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
target := envconfig.Host()
|
target := envconfig.ConnectableHost()
|
||||||
s.log().Info("configuring ollama proxy", "target", target.String())
|
s.log().Info("configuring ollama proxy", "target", target.String())
|
||||||
|
|
||||||
newProxy := httputil.NewSingleHostReverseProxy(target)
|
newProxy := httputil.NewSingleHostReverseProxy(target)
|
||||||
@@ -193,7 +194,7 @@ func (s *Server) Handler() http.Handler {
|
|||||||
if CORS() {
|
if CORS() {
|
||||||
w.Header().Set("Access-Control-Allow-Origin", "*")
|
w.Header().Set("Access-Control-Allow-Origin", "*")
|
||||||
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
|
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
|
||||||
w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization, X-Requested-With")
|
w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization, User-Agent, Accept, X-Requested-With")
|
||||||
w.Header().Set("Access-Control-Allow-Credentials", "true")
|
w.Header().Set("Access-Control-Allow-Credentials", "true")
|
||||||
|
|
||||||
// Handle preflight requests
|
// Handle preflight requests
|
||||||
@@ -318,7 +319,7 @@ func (s *Server) handleError(w http.ResponseWriter, e error) {
|
|||||||
if CORS() {
|
if CORS() {
|
||||||
w.Header().Set("Access-Control-Allow-Origin", "*")
|
w.Header().Set("Access-Control-Allow-Origin", "*")
|
||||||
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
|
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
|
||||||
w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization, X-Requested-With")
|
w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization, User-Agent, Accept, X-Requested-With")
|
||||||
w.Header().Set("Access-Control-Allow-Credentials", "true")
|
w.Header().Set("Access-Control-Allow-Credentials", "true")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -341,8 +342,18 @@ func (t *userAgentTransport) RoundTrip(req *http.Request) (*http.Response, error
|
|||||||
|
|
||||||
// httpClient returns an HTTP client that automatically adds the User-Agent header
|
// httpClient returns an HTTP client that automatically adds the User-Agent header
|
||||||
func (s *Server) httpClient() *http.Client {
|
func (s *Server) httpClient() *http.Client {
|
||||||
|
return userAgentHTTPClient(10 * time.Second)
|
||||||
|
}
|
||||||
|
|
||||||
|
// inferenceClient uses almost the same HTTP client, but without a timeout so
|
||||||
|
// long requests aren't truncated
|
||||||
|
func (s *Server) inferenceClient() *api.Client {
|
||||||
|
return api.NewClient(envconfig.Host(), userAgentHTTPClient(0))
|
||||||
|
}
|
||||||
|
|
||||||
|
func userAgentHTTPClient(timeout time.Duration) *http.Client {
|
||||||
return &http.Client{
|
return &http.Client{
|
||||||
Timeout: 10 * time.Second,
|
Timeout: timeout,
|
||||||
Transport: &userAgentTransport{
|
Transport: &userAgentTransport{
|
||||||
base: http.DefaultTransport,
|
base: http.DefaultTransport,
|
||||||
},
|
},
|
||||||
@@ -720,11 +731,7 @@ func (s *Server) chat(w http.ResponseWriter, r *http.Request) error {
|
|||||||
_, cancelLoading := context.WithCancel(ctx)
|
_, cancelLoading := context.WithCancel(ctx)
|
||||||
loading := false
|
loading := false
|
||||||
|
|
||||||
c, err := api.ClientFromEnvironment()
|
c := s.inferenceClient()
|
||||||
if err != nil {
|
|
||||||
cancelLoading()
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check if the model exists locally by trying to show it
|
// Check if the model exists locally by trying to show it
|
||||||
// TODO (jmorganca): skip this round trip and instead just act
|
// TODO (jmorganca): skip this round trip and instead just act
|
||||||
@@ -1572,9 +1579,18 @@ func (s *Server) modelUpstream(w http.ResponseWriter, r *http.Request) error {
|
|||||||
return json.NewEncoder(w).Encode(response)
|
return json.NewEncoder(w).Encode(response)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
n := model.ParseName(req.Model)
|
||||||
|
stale := true
|
||||||
|
if m, err := manifest.ParseNamedManifest(n); err == nil {
|
||||||
|
if m.Digest() == digest {
|
||||||
|
stale = false
|
||||||
|
} else if pushTime > 0 && m.FileInfo().ModTime().Unix() >= pushTime {
|
||||||
|
stale = false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
response := responses.ModelUpstreamResponse{
|
response := responses.ModelUpstreamResponse{
|
||||||
Digest: digest,
|
Stale: stale,
|
||||||
PushTime: pushTime,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
w.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
@@ -1672,7 +1688,6 @@ func supportsBrowserTools(model string) bool {
|
|||||||
return strings.HasPrefix(strings.ToLower(model), "gpt-oss")
|
return strings.HasPrefix(strings.ToLower(model), "gpt-oss")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
// buildChatRequest converts store.Chat to api.ChatRequest
|
// buildChatRequest converts store.Chat to api.ChatRequest
|
||||||
func (s *Server) buildChatRequest(chat *store.Chat, model string, think any, availableTools []map[string]any) (*api.ChatRequest, error) {
|
func (s *Server) buildChatRequest(chat *store.Chat, model string, think any, availableTools []map[string]any) (*api.ChatRequest, error) {
|
||||||
var msgs []api.Message
|
var msgs []api.Message
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ import (
|
|||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/api"
|
||||||
"github.com/ollama/ollama/app/store"
|
"github.com/ollama/ollama/app/store"
|
||||||
"github.com/ollama/ollama/app/updater"
|
"github.com/ollama/ollama/app/updater"
|
||||||
)
|
)
|
||||||
@@ -526,6 +527,33 @@ func TestUserAgentTransport(t *testing.T) {
|
|||||||
t.Logf("User-Agent transport successfully set: %s", receivedUA)
|
t.Logf("User-Agent transport successfully set: %s", receivedUA)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestInferenceClientUsesUserAgent(t *testing.T) {
|
||||||
|
var gotUserAgent atomic.Value
|
||||||
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
gotUserAgent.Store(r.Header.Get("User-Agent"))
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.Write([]byte(`{}`))
|
||||||
|
}))
|
||||||
|
defer ts.Close()
|
||||||
|
|
||||||
|
t.Setenv("OLLAMA_HOST", ts.URL)
|
||||||
|
|
||||||
|
server := &Server{}
|
||||||
|
client := server.inferenceClient()
|
||||||
|
|
||||||
|
_, err := client.Show(context.Background(), &api.ShowRequest{Model: "test"})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("show request failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
receivedUA, _ := gotUserAgent.Load().(string)
|
||||||
|
expectedUA := userAgent()
|
||||||
|
|
||||||
|
if receivedUA != expectedUA {
|
||||||
|
t.Errorf("User-Agent mismatch\nExpected: %s\nReceived: %s", expectedUA, receivedUA)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestSupportsBrowserTools(t *testing.T) {
|
func TestSupportsBrowserTools(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
model string
|
model string
|
||||||
|
|||||||
@@ -1,27 +1,31 @@
|
|||||||
Ollama Benchmark Tool
|
Ollama Benchmark Tool
|
||||||
---------------------
|
---------------------
|
||||||
|
|
||||||
A Go-based command-line tool for benchmarking Ollama models with configurable parameters and multiple output formats.
|
A Go-based command-line tool for benchmarking Ollama models with configurable parameters, warmup phases, TTFT tracking, VRAM monitoring, and benchstat/CSV output.
|
||||||
|
|
||||||
## Features
|
## Features
|
||||||
|
|
||||||
* Benchmark multiple models in a single run
|
* Benchmark multiple models in a single run
|
||||||
* Support for both text and image prompts
|
* Support for both text and image prompts
|
||||||
* Configurable generation parameters (temperature, max tokens, seed, etc.)
|
* Configurable generation parameters (temperature, max tokens, seed, etc.)
|
||||||
* Supports benchstat and CSV output formats
|
* Warmup phase before timed epochs to stabilize measurements
|
||||||
* Detailed performance metrics (prefill, generate, load, total durations)
|
* Time-to-first-token (TTFT) tracking per epoch
|
||||||
|
* Model metadata display (parameter size, quantization level, family)
|
||||||
|
* VRAM and CPU memory usage tracking via running process info
|
||||||
|
* Controlled prompt token length for reproducible benchmarks
|
||||||
|
* Benchstat and CSV output formats
|
||||||
|
|
||||||
## Building from Source
|
## Building from Source
|
||||||
|
|
||||||
```
|
```
|
||||||
go build -o ollama-bench bench.go
|
go build -o ollama-bench ./cmd/bench
|
||||||
./ollama-bench -model gpt-oss:20b -epochs 6 -format csv
|
./ollama-bench -model gemma3 -epochs 6 -format csv
|
||||||
```
|
```
|
||||||
|
|
||||||
Using Go Run (without building)
|
Using Go Run (without building)
|
||||||
|
|
||||||
```
|
```
|
||||||
go run bench.go -model gpt-oss:20b -epochs 3
|
go run ./cmd/bench -model gemma3 -epochs 3
|
||||||
```
|
```
|
||||||
|
|
||||||
## Usage
|
## Usage
|
||||||
@@ -45,10 +49,16 @@ benchstat -col /name gemma.bench
|
|||||||
./ollama-bench -model qwen3-vl -image photo.jpg -epochs 6 -max-tokens 100 -p "Describe this image"
|
./ollama-bench -model qwen3-vl -image photo.jpg -epochs 6 -max-tokens 100 -p "Describe this image"
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### Controlled Prompt Length
|
||||||
|
|
||||||
|
```
|
||||||
|
./ollama-bench -model gemma3 -epochs 6 -prompt-tokens 512
|
||||||
|
```
|
||||||
|
|
||||||
### Advanced Example
|
### Advanced Example
|
||||||
|
|
||||||
```
|
```
|
||||||
./ollama-bench -model llama3 -epochs 10 -temperature 0.7 -max-tokens 500 -seed 42 -format csv -output results.csv
|
./ollama-bench -model llama3 -epochs 10 -temperature 0.7 -max-tokens 500 -seed 42 -warmup 2 -format csv -output results.csv
|
||||||
```
|
```
|
||||||
|
|
||||||
## Command Line Options
|
## Command Line Options
|
||||||
@@ -56,41 +66,48 @@ benchstat -col /name gemma.bench
|
|||||||
| Option | Description | Default |
|
| Option | Description | Default |
|
||||||
|----------|-------------|---------|
|
|----------|-------------|---------|
|
||||||
| -model | Comma-separated list of models to benchmark | (required) |
|
| -model | Comma-separated list of models to benchmark | (required) |
|
||||||
| -epochs | Number of iterations per model | 1 |
|
| -epochs | Number of iterations per model | 6 |
|
||||||
| -max-tokens | Maximum tokens for model response | 0 (unlimited) |
|
| -max-tokens | Maximum tokens for model response | 200 |
|
||||||
| -temperature | Temperature parameter | 0.0 |
|
| -temperature | Temperature parameter | 0.0 |
|
||||||
| -seed | Random seed | 0 (random) |
|
| -seed | Random seed | 0 (random) |
|
||||||
| -timeout | Timeout in seconds | 300 |
|
| -timeout | Timeout in seconds | 300 |
|
||||||
| -p | Prompt text | "Write a long story." |
|
| -p | Prompt text | (default story prompt) |
|
||||||
| -image | Image file to include in prompt | |
|
| -image | Image file to include in prompt | |
|
||||||
| -k | Keep-alive duration in seconds | 0 |
|
| -k | Keep-alive duration in seconds | 0 |
|
||||||
| -format | Output format (benchstat, csv) | benchstat |
|
| -format | Output format (benchstat, csv) | benchstat |
|
||||||
| -output | Output file for results | "" (stdout) |
|
| -output | Output file for results | "" (stdout) |
|
||||||
|
| -warmup | Number of warmup requests before timing | 1 |
|
||||||
|
| -prompt-tokens | Generate prompt targeting ~N tokens (0 = use -p) | 0 |
|
||||||
| -v | Verbose mode | false |
|
| -v | Verbose mode | false |
|
||||||
| -debug | Show debug information | false |
|
| -debug | Show debug information | false |
|
||||||
|
|
||||||
## Output Formats
|
## Output Formats
|
||||||
|
|
||||||
### Markdown Format
|
### Benchstat Format (default)
|
||||||
|
|
||||||
The default markdown format is suitable for copying and pasting into a GitHub issue and will look like:
|
Compatible with Go's benchstat tool for statistical analysis. Uses one value/unit pair per line, standard `ns/op` for timing metrics, and `ns/token` for throughput. Each epoch produces one set of lines -- benchstat aggregates across repeated runs to compute statistics.
|
||||||
```
|
|
||||||
Model | Step | Count | Duration | nsPerToken | tokensPerSec |
|
|
||||||
|-------|------|-------|----------|------------|--------------|
|
|
||||||
| gpt-oss:20b | prefill | 124 | 30.006458ms | 241987.56 | 4132.44 |
|
|
||||||
| gpt-oss:20b | generate | 200 | 2.646843954s | 13234219.77 | 75.56 |
|
|
||||||
| gpt-oss:20b | load | 1 | 121.674208ms | - | - |
|
|
||||||
| gpt-oss:20b | total | 1 | 2.861047625s | - | - |
|
|
||||||
```
|
|
||||||
|
|
||||||
### Benchstat Format
|
|
||||||
|
|
||||||
Compatible with Go's benchstat tool for statistical analysis:
|
|
||||||
|
|
||||||
```
|
```
|
||||||
BenchmarkModel/name=gpt-oss:20b/step=prefill 128 78125.00 ns/token 12800.00 token/sec
|
# Model: gemma3 | Params: 4.3B | Quant: Q4_K_M | Family: gemma3 | Size: 4080218931 | VRAM: 4080218931
|
||||||
BenchmarkModel/name=gpt-oss:20b/step=generate 512 19531.25 ns/token 51200.00 token/sec
|
BenchmarkModel/name=gemma3/step=prefill 1 78125.00 ns/token 12800.00 token/sec
|
||||||
BenchmarkModel/name=gpt-oss:20b/step=load 1 1500000000 ns/request
|
BenchmarkModel/name=gemma3/step=generate 1 19531.25 ns/token 51200.00 token/sec
|
||||||
|
BenchmarkModel/name=gemma3/step=ttft 1 45123000 ns/op
|
||||||
|
BenchmarkModel/name=gemma3/step=load 1 1500000000 ns/op
|
||||||
|
BenchmarkModel/name=gemma3/step=total 1 2861047625 ns/op
|
||||||
|
```
|
||||||
|
|
||||||
|
Use with benchstat:
|
||||||
|
```
|
||||||
|
./ollama-bench -model gemma3 -epochs 6 > gemma3.bench
|
||||||
|
benchstat -col /step gemma3.bench
|
||||||
|
```
|
||||||
|
|
||||||
|
Compare two runs:
|
||||||
|
```
|
||||||
|
./ollama-bench -model gemma3 -epochs 6 > before.bench
|
||||||
|
# ... make changes ...
|
||||||
|
./ollama-bench -model gemma3 -epochs 6 > after.bench
|
||||||
|
benchstat before.bench after.bench
|
||||||
```
|
```
|
||||||
|
|
||||||
### CSV Format
|
### CSV Format
|
||||||
@@ -99,17 +116,28 @@ Machine-readable comma-separated values:
|
|||||||
|
|
||||||
```
|
```
|
||||||
NAME,STEP,COUNT,NS_PER_COUNT,TOKEN_PER_SEC
|
NAME,STEP,COUNT,NS_PER_COUNT,TOKEN_PER_SEC
|
||||||
gpt-oss:20b,prefill,128,78125.00,12800.00
|
# Model: gemma3 | Params: 4.3B | Quant: Q4_K_M | Family: gemma3 | Size: 4080218931 | VRAM: 4080218931
|
||||||
gpt-oss:20b,generate,512,19531.25,51200.00
|
gemma3,prefill,128,78125.00,12800.00
|
||||||
gpt-oss:20b,load,1,1500000000,0
|
gemma3,generate,512,19531.25,51200.00
|
||||||
|
gemma3,ttft,1,45123000,0
|
||||||
|
gemma3,load,1,1500000000,0
|
||||||
|
gemma3,total,1,2861047625,0
|
||||||
```
|
```
|
||||||
|
|
||||||
## Metrics Explained
|
## Metrics Explained
|
||||||
|
|
||||||
The tool reports four types of metrics for each model:
|
The tool reports the following metrics for each epoch:
|
||||||
|
|
||||||
* prefill: Time spent processing the prompt
|
* **prefill**: Time spent processing the prompt (ns/token)
|
||||||
* generate: Time spent generating the response
|
* **generate**: Time spent generating the response (ns/token)
|
||||||
* load: Model loading time (one-time cost)
|
* **ttft**: Time to first token -- latency from request start to first response content
|
||||||
* total: Total request duration
|
* **load**: Model loading time (one-time cost)
|
||||||
|
* **total**: Total request duration
|
||||||
|
|
||||||
|
Additionally, the model info comment line (displayed once per model before epochs) includes:
|
||||||
|
|
||||||
|
* **Params**: Model parameter count (e.g., 4.3B)
|
||||||
|
* **Quant**: Quantization level (e.g., Q4_K_M)
|
||||||
|
* **Family**: Model family (e.g., gemma3)
|
||||||
|
* **Size**: Total model memory in bytes
|
||||||
|
* **VRAM**: GPU memory used by the loaded model (when Size > VRAM, the difference is CPU spill)
|
||||||
|
|||||||
@@ -17,19 +17,22 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type flagOptions struct {
|
type flagOptions struct {
|
||||||
models *string
|
models *string
|
||||||
epochs *int
|
epochs *int
|
||||||
maxTokens *int
|
maxTokens *int
|
||||||
temperature *float64
|
temperature *float64
|
||||||
seed *int
|
seed *int
|
||||||
timeout *int
|
timeout *int
|
||||||
prompt *string
|
prompt *string
|
||||||
imageFile *string
|
imageFile *string
|
||||||
keepAlive *float64
|
keepAlive *float64
|
||||||
format *string
|
format *string
|
||||||
outputFile *string
|
outputFile *string
|
||||||
debug *bool
|
debug *bool
|
||||||
verbose *bool
|
verbose *bool
|
||||||
|
warmup *int
|
||||||
|
promptTokens *int
|
||||||
|
numCtx *int
|
||||||
}
|
}
|
||||||
|
|
||||||
type Metrics struct {
|
type Metrics struct {
|
||||||
@@ -39,48 +42,203 @@ type Metrics struct {
|
|||||||
Duration time.Duration
|
Duration time.Duration
|
||||||
}
|
}
|
||||||
|
|
||||||
var once sync.Once
|
type ModelInfo struct {
|
||||||
|
Name string
|
||||||
|
ParameterSize string
|
||||||
|
QuantizationLevel string
|
||||||
|
Family string
|
||||||
|
SizeBytes int64
|
||||||
|
VRAMBytes int64
|
||||||
|
NumCtx int64
|
||||||
|
}
|
||||||
|
|
||||||
const DefaultPrompt = `Please write a descriptive story about a llama named Alonso who grows up to be President of the Land of Llamas. Include details about Alonso's childhood, adolescent years, and how he grew up to be a political mover and shaker. Write the story with a sense of whimsy.`
|
const DefaultPrompt = `Please write a descriptive story about a llama named Alonso who grows up to be President of the Land of Llamas. Include details about Alonso's childhood, adolescent years, and how he grew up to be a political mover and shaker. Write the story with a sense of whimsy.`
|
||||||
|
|
||||||
|
// Word list for generating prompts targeting a specific token count.
|
||||||
|
var promptWordList = []string{
|
||||||
|
"the", "quick", "brown", "fox", "jumps", "over", "lazy", "dog",
|
||||||
|
"a", "bright", "sunny", "day", "in", "the", "meadow", "where",
|
||||||
|
"flowers", "bloom", "and", "birds", "sing", "their", "morning",
|
||||||
|
"songs", "while", "gentle", "breeze", "carries", "sweet", "scent",
|
||||||
|
"of", "pine", "trees", "across", "rolling", "hills", "toward",
|
||||||
|
"distant", "mountains", "covered", "with", "fresh", "snow",
|
||||||
|
"beneath", "clear", "blue", "sky", "children", "play", "near",
|
||||||
|
"old", "stone", "bridge", "that", "crosses", "winding", "river",
|
||||||
|
}
|
||||||
|
|
||||||
|
// tokensPerWord is the calibrated ratio of tokens to words for the current model.
|
||||||
|
// Initialized with a heuristic, then updated during warmup based on actual tokenization.
|
||||||
|
var tokensPerWord = 1.3
|
||||||
|
|
||||||
|
func generatePromptForTokenCount(targetTokens int, epoch int) string {
|
||||||
|
targetWords := int(float64(targetTokens) / tokensPerWord)
|
||||||
|
if targetWords < 1 {
|
||||||
|
targetWords = 1
|
||||||
|
}
|
||||||
|
|
||||||
|
// Vary the starting offset by epoch to defeat KV cache prefix matching
|
||||||
|
offset := epoch * 7 // stride by a prime to get good distribution
|
||||||
|
n := len(promptWordList)
|
||||||
|
words := make([]string, targetWords)
|
||||||
|
for i := range words {
|
||||||
|
words[i] = promptWordList[((i+offset)%n+n)%n]
|
||||||
|
}
|
||||||
|
return strings.Join(words, " ")
|
||||||
|
}
|
||||||
|
|
||||||
|
// calibratePromptTokens adjusts tokensPerWord based on actual tokenization from a warmup run.
|
||||||
|
func calibratePromptTokens(targetTokens, actualTokens, wordCount int) {
|
||||||
|
if actualTokens <= 0 || wordCount <= 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
tokensPerWord = float64(actualTokens) / float64(wordCount)
|
||||||
|
newWords := int(float64(targetTokens) / tokensPerWord)
|
||||||
|
fmt.Fprintf(os.Stderr, "bench: calibrated %.2f tokens/word (target=%d, got=%d, words=%d → %d)\n",
|
||||||
|
tokensPerWord, targetTokens, actualTokens, wordCount, newWords)
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildGenerateRequest(model string, fOpt flagOptions, imgData api.ImageData, epoch int) *api.GenerateRequest {
|
||||||
|
options := make(map[string]interface{})
|
||||||
|
if *fOpt.maxTokens > 0 {
|
||||||
|
options["num_predict"] = *fOpt.maxTokens
|
||||||
|
}
|
||||||
|
options["temperature"] = *fOpt.temperature
|
||||||
|
if fOpt.seed != nil && *fOpt.seed > 0 {
|
||||||
|
options["seed"] = *fOpt.seed
|
||||||
|
}
|
||||||
|
if fOpt.numCtx != nil && *fOpt.numCtx > 0 {
|
||||||
|
options["num_ctx"] = *fOpt.numCtx
|
||||||
|
}
|
||||||
|
|
||||||
|
var keepAliveDuration *api.Duration
|
||||||
|
if *fOpt.keepAlive > 0 {
|
||||||
|
duration := api.Duration{Duration: time.Duration(*fOpt.keepAlive * float64(time.Second))}
|
||||||
|
keepAliveDuration = &duration
|
||||||
|
}
|
||||||
|
|
||||||
|
prompt := *fOpt.prompt
|
||||||
|
if *fOpt.promptTokens > 0 {
|
||||||
|
prompt = generatePromptForTokenCount(*fOpt.promptTokens, epoch)
|
||||||
|
} else {
|
||||||
|
// Vary the prompt per epoch to defeat KV cache prefix matching
|
||||||
|
prompt = fmt.Sprintf("[%d] %s", epoch, prompt)
|
||||||
|
}
|
||||||
|
|
||||||
|
req := &api.GenerateRequest{
|
||||||
|
Model: model,
|
||||||
|
Prompt: prompt,
|
||||||
|
Raw: true,
|
||||||
|
Options: options,
|
||||||
|
KeepAlive: keepAliveDuration,
|
||||||
|
}
|
||||||
|
|
||||||
|
if imgData != nil {
|
||||||
|
req.Images = []api.ImageData{imgData}
|
||||||
|
}
|
||||||
|
|
||||||
|
return req
|
||||||
|
}
|
||||||
|
|
||||||
|
func fetchModelInfo(ctx context.Context, client *api.Client, model string) ModelInfo {
|
||||||
|
info := ModelInfo{Name: model}
|
||||||
|
resp, err := client.Show(ctx, &api.ShowRequest{Model: model})
|
||||||
|
if err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "WARNING: Could not fetch model info for '%s': %v\n", model, err)
|
||||||
|
return info
|
||||||
|
}
|
||||||
|
info.ParameterSize = resp.Details.ParameterSize
|
||||||
|
info.QuantizationLevel = resp.Details.QuantizationLevel
|
||||||
|
info.Family = resp.Details.Family
|
||||||
|
return info
|
||||||
|
}
|
||||||
|
|
||||||
|
func fetchMemoryUsage(ctx context.Context, client *api.Client, model string) (size, vram int64) {
|
||||||
|
resp, err := client.ListRunning(ctx)
|
||||||
|
if err != nil {
|
||||||
|
if debug := os.Getenv("OLLAMA_DEBUG"); debug != "" {
|
||||||
|
fmt.Fprintf(os.Stderr, "WARNING: Could not fetch memory usage: %v\n", err)
|
||||||
|
}
|
||||||
|
return 0, 0
|
||||||
|
}
|
||||||
|
for _, m := range resp.Models {
|
||||||
|
if m.Name == model || m.Model == model {
|
||||||
|
return m.Size, m.SizeVRAM
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for _, m := range resp.Models {
|
||||||
|
if strings.HasPrefix(m.Name, model) || strings.HasPrefix(m.Model, model) {
|
||||||
|
return m.Size, m.SizeVRAM
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return 0, 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func fetchContextLength(ctx context.Context, client *api.Client, model string) int64 {
|
||||||
|
resp, err := client.ListRunning(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
for _, m := range resp.Models {
|
||||||
|
if m.Name == model || m.Model == model || strings.HasPrefix(m.Name, model) || strings.HasPrefix(m.Model, model) {
|
||||||
|
return int64(m.ContextLength)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func outputFormatHeader(w io.Writer, format string, verbose bool) {
|
||||||
|
switch format {
|
||||||
|
case "benchstat":
|
||||||
|
if verbose {
|
||||||
|
fmt.Fprintf(w, "goos: %s\n", runtime.GOOS)
|
||||||
|
fmt.Fprintf(w, "goarch: %s\n", runtime.GOARCH)
|
||||||
|
}
|
||||||
|
case "csv":
|
||||||
|
headings := []string{"NAME", "STEP", "COUNT", "NS_PER_COUNT", "TOKEN_PER_SEC"}
|
||||||
|
fmt.Fprintln(w, strings.Join(headings, ","))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func outputModelInfo(w io.Writer, format string, info ModelInfo) {
|
||||||
|
params := cmp.Or(info.ParameterSize, "unknown")
|
||||||
|
quant := cmp.Or(info.QuantizationLevel, "unknown")
|
||||||
|
family := cmp.Or(info.Family, "unknown")
|
||||||
|
|
||||||
|
memStr := ""
|
||||||
|
if info.SizeBytes > 0 {
|
||||||
|
memStr = fmt.Sprintf(" | Size: %d | VRAM: %d", info.SizeBytes, info.VRAMBytes)
|
||||||
|
}
|
||||||
|
ctxStr := ""
|
||||||
|
if info.NumCtx > 0 {
|
||||||
|
ctxStr = fmt.Sprintf(" | NumCtx: %d", info.NumCtx)
|
||||||
|
}
|
||||||
|
fmt.Fprintf(w, "# Model: %s | Params: %s | Quant: %s | Family: %s%s%s\n",
|
||||||
|
info.Name, params, quant, family, memStr, ctxStr)
|
||||||
|
}
|
||||||
|
|
||||||
func OutputMetrics(w io.Writer, format string, metrics []Metrics, verbose bool) {
|
func OutputMetrics(w io.Writer, format string, metrics []Metrics, verbose bool) {
|
||||||
switch format {
|
switch format {
|
||||||
case "benchstat":
|
case "benchstat":
|
||||||
if verbose {
|
|
||||||
printHeader := func() {
|
|
||||||
fmt.Fprintf(w, "sysname: %s\n", runtime.GOOS)
|
|
||||||
fmt.Fprintf(w, "machine: %s\n", runtime.GOARCH)
|
|
||||||
}
|
|
||||||
once.Do(printHeader)
|
|
||||||
}
|
|
||||||
for _, m := range metrics {
|
for _, m := range metrics {
|
||||||
if m.Step == "generate" || m.Step == "prefill" {
|
if m.Step == "generate" || m.Step == "prefill" {
|
||||||
if m.Count > 0 {
|
if m.Count > 0 {
|
||||||
nsPerToken := float64(m.Duration.Nanoseconds()) / float64(m.Count)
|
nsPerToken := float64(m.Duration.Nanoseconds()) / float64(m.Count)
|
||||||
tokensPerSec := float64(m.Count) / (float64(m.Duration.Nanoseconds()) + 1e-12) * 1e9
|
tokensPerSec := float64(m.Count) / (float64(m.Duration.Nanoseconds()) + 1e-12) * 1e9
|
||||||
|
fmt.Fprintf(w, "BenchmarkModel/name=%s/step=%s 1 %.2f ns/token %.2f token/sec\n",
|
||||||
fmt.Fprintf(w, "BenchmarkModel/name=%s/step=%s %d %.2f ns/token %.2f token/sec\n",
|
m.Model, m.Step, nsPerToken, tokensPerSec)
|
||||||
m.Model, m.Step, m.Count, nsPerToken, tokensPerSec)
|
|
||||||
} else {
|
} else {
|
||||||
fmt.Fprintf(w, "BenchmarkModel/name=%s/step=%s %d 0 ns/token 0 token/sec\n",
|
fmt.Fprintf(w, "BenchmarkModel/name=%s/step=%s 1 0 ns/token 0 token/sec\n",
|
||||||
m.Model, m.Step, m.Count)
|
m.Model, m.Step)
|
||||||
}
|
}
|
||||||
|
} else if m.Step == "ttft" {
|
||||||
|
fmt.Fprintf(w, "BenchmarkModel/name=%s/step=ttft 1 %d ns/op\n",
|
||||||
|
m.Model, m.Duration.Nanoseconds())
|
||||||
} else {
|
} else {
|
||||||
var suffix string
|
fmt.Fprintf(w, "BenchmarkModel/name=%s/step=%s 1 %d ns/op\n",
|
||||||
if m.Step == "load" {
|
m.Model, m.Step, m.Duration.Nanoseconds())
|
||||||
suffix = "/step=load"
|
|
||||||
}
|
|
||||||
fmt.Fprintf(w, "BenchmarkModel/name=%s%s 1 %d ns/request\n",
|
|
||||||
m.Model, suffix, m.Duration.Nanoseconds())
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
case "csv":
|
case "csv":
|
||||||
printHeader := func() {
|
|
||||||
headings := []string{"NAME", "STEP", "COUNT", "NS_PER_COUNT", "TOKEN_PER_SEC"}
|
|
||||||
fmt.Fprintln(w, strings.Join(headings, ","))
|
|
||||||
}
|
|
||||||
once.Do(printHeader)
|
|
||||||
|
|
||||||
for _, m := range metrics {
|
for _, m := range metrics {
|
||||||
if m.Step == "generate" || m.Step == "prefill" {
|
if m.Step == "generate" || m.Step == "prefill" {
|
||||||
var nsPerToken float64
|
var nsPerToken float64
|
||||||
@@ -94,39 +252,14 @@ func OutputMetrics(w io.Writer, format string, metrics []Metrics, verbose bool)
|
|||||||
fmt.Fprintf(w, "%s,%s,1,%d,0\n", m.Model, m.Step, m.Duration.Nanoseconds())
|
fmt.Fprintf(w, "%s,%s,1,%d,0\n", m.Model, m.Step, m.Duration.Nanoseconds())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
case "markdown":
|
|
||||||
printHeader := func() {
|
|
||||||
fmt.Fprintln(w, "| Model | Step | Count | Duration | nsPerToken | tokensPerSec |")
|
|
||||||
fmt.Fprintln(w, "|-------|------|-------|----------|------------|--------------|")
|
|
||||||
}
|
|
||||||
once.Do(printHeader)
|
|
||||||
|
|
||||||
for _, m := range metrics {
|
|
||||||
var nsPerToken, tokensPerSec float64
|
|
||||||
var nsPerTokenStr, tokensPerSecStr string
|
|
||||||
|
|
||||||
if m.Step == "generate" || m.Step == "prefill" {
|
|
||||||
nsPerToken = float64(m.Duration.Nanoseconds()) / float64(m.Count)
|
|
||||||
tokensPerSec = float64(m.Count) / (float64(m.Duration.Nanoseconds()) + 1e-12) * 1e9
|
|
||||||
nsPerTokenStr = fmt.Sprintf("%.2f", nsPerToken)
|
|
||||||
tokensPerSecStr = fmt.Sprintf("%.2f", tokensPerSec)
|
|
||||||
} else {
|
|
||||||
nsPerTokenStr = "-"
|
|
||||||
tokensPerSecStr = "-"
|
|
||||||
}
|
|
||||||
|
|
||||||
fmt.Fprintf(w, "| %s | %s | %d | %v | %s | %s |\n",
|
|
||||||
m.Model, m.Step, m.Count, m.Duration, nsPerTokenStr, tokensPerSecStr)
|
|
||||||
}
|
|
||||||
default:
|
default:
|
||||||
fmt.Fprintf(os.Stderr, "Unknown output format '%s'\n", format)
|
fmt.Fprintf(os.Stderr, "Unknown output format '%s'\n", format)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func BenchmarkChat(fOpt flagOptions) error {
|
func BenchmarkModel(fOpt flagOptions) error {
|
||||||
models := strings.Split(*fOpt.models, ",")
|
models := strings.Split(*fOpt.models, ",")
|
||||||
|
|
||||||
// todo - add multi-image support
|
|
||||||
var imgData api.ImageData
|
var imgData api.ImageData
|
||||||
var err error
|
var err error
|
||||||
if *fOpt.imageFile != "" {
|
if *fOpt.imageFile != "" {
|
||||||
@@ -158,71 +291,141 @@ func BenchmarkChat(fOpt flagOptions) error {
|
|||||||
out = f
|
out = f
|
||||||
}
|
}
|
||||||
|
|
||||||
|
outputFormatHeader(out, *fOpt.format, *fOpt.verbose)
|
||||||
|
|
||||||
|
// Log prompt-tokens info in debug mode
|
||||||
|
if *fOpt.debug && *fOpt.promptTokens > 0 {
|
||||||
|
prompt := generatePromptForTokenCount(*fOpt.promptTokens, 0)
|
||||||
|
wordCount := len(strings.Fields(prompt))
|
||||||
|
fmt.Fprintf(os.Stderr, "Generated prompt targeting ~%d tokens (%d words, varied per epoch)\n", *fOpt.promptTokens, wordCount)
|
||||||
|
}
|
||||||
|
|
||||||
for _, model := range models {
|
for _, model := range models {
|
||||||
for range *fOpt.epochs {
|
// Fetch model info
|
||||||
options := make(map[string]interface{})
|
infoCtx, infoCancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||||
if *fOpt.maxTokens > 0 {
|
info := fetchModelInfo(infoCtx, client, model)
|
||||||
options["num_predict"] = *fOpt.maxTokens
|
infoCancel()
|
||||||
}
|
|
||||||
options["temperature"] = *fOpt.temperature
|
|
||||||
if fOpt.seed != nil && *fOpt.seed > 0 {
|
|
||||||
options["seed"] = *fOpt.seed
|
|
||||||
}
|
|
||||||
|
|
||||||
var keepAliveDuration *api.Duration
|
|
||||||
if *fOpt.keepAlive > 0 {
|
|
||||||
duration := api.Duration{Duration: time.Duration(*fOpt.keepAlive * float64(time.Second))}
|
|
||||||
keepAliveDuration = &duration
|
|
||||||
}
|
|
||||||
|
|
||||||
req := &api.ChatRequest{
|
|
||||||
Model: model,
|
|
||||||
Messages: []api.Message{
|
|
||||||
{
|
|
||||||
Role: "user",
|
|
||||||
Content: *fOpt.prompt,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
Options: options,
|
|
||||||
KeepAlive: keepAliveDuration,
|
|
||||||
}
|
|
||||||
|
|
||||||
if imgData != nil {
|
|
||||||
req.Messages[0].Images = []api.ImageData{imgData}
|
|
||||||
}
|
|
||||||
|
|
||||||
var responseMetrics *api.Metrics
|
|
||||||
|
|
||||||
|
// Warmup phase (uses negative epoch numbers to avoid colliding with timed epochs)
|
||||||
|
for i := range *fOpt.warmup {
|
||||||
|
req := buildGenerateRequest(model, fOpt, imgData, -(i + 1))
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(*fOpt.timeout)*time.Second)
|
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(*fOpt.timeout)*time.Second)
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
err = client.Chat(ctx, req, func(resp api.ChatResponse) error {
|
|
||||||
if *fOpt.debug {
|
|
||||||
fmt.Fprintf(os.Stderr, "%s", cmp.Or(resp.Message.Thinking, resp.Message.Content))
|
|
||||||
}
|
|
||||||
|
|
||||||
|
var warmupMetrics *api.Metrics
|
||||||
|
err = client.Generate(ctx, req, func(resp api.GenerateResponse) error {
|
||||||
if resp.Done {
|
if resp.Done {
|
||||||
responseMetrics = &resp.Metrics
|
warmupMetrics = &resp.Metrics
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
|
cancel()
|
||||||
if *fOpt.debug {
|
|
||||||
fmt.Fprintln(os.Stderr)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if ctx.Err() == context.DeadlineExceeded {
|
fmt.Fprintf(os.Stderr, "WARNING: Warmup %d/%d for %s failed: %v\n", i+1, *fOpt.warmup, model, err)
|
||||||
fmt.Fprintf(os.Stderr, "ERROR: Chat request timed out with model '%s' after %vs\n", model, 1)
|
} else {
|
||||||
continue
|
if *fOpt.debug {
|
||||||
|
fmt.Fprintf(os.Stderr, "Warmup %d/%d for %s complete\n", i+1, *fOpt.warmup, model)
|
||||||
}
|
}
|
||||||
fmt.Fprintf(os.Stderr, "ERROR: Couldn't chat with model '%s': %v\n", model, err)
|
// Calibrate prompt token count on last warmup run
|
||||||
|
if i == *fOpt.warmup-1 && *fOpt.promptTokens > 0 && warmupMetrics != nil {
|
||||||
|
prompt := generatePromptForTokenCount(*fOpt.promptTokens, -(i + 1))
|
||||||
|
wordCount := len(strings.Fields(prompt))
|
||||||
|
calibratePromptTokens(*fOpt.promptTokens, warmupMetrics.PromptEvalCount, wordCount)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fetch memory/context info once after warmup (model is loaded and stable)
|
||||||
|
memCtx, memCancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
|
info.SizeBytes, info.VRAMBytes = fetchMemoryUsage(memCtx, client, model)
|
||||||
|
if fOpt.numCtx != nil && *fOpt.numCtx > 0 {
|
||||||
|
info.NumCtx = int64(*fOpt.numCtx)
|
||||||
|
} else {
|
||||||
|
info.NumCtx = fetchContextLength(memCtx, client, model)
|
||||||
|
}
|
||||||
|
memCancel()
|
||||||
|
|
||||||
|
outputModelInfo(out, *fOpt.format, info)
|
||||||
|
|
||||||
|
// Timed epoch loop
|
||||||
|
shortCount := 0
|
||||||
|
for epoch := range *fOpt.epochs {
|
||||||
|
var responseMetrics *api.Metrics
|
||||||
|
var ttft time.Duration
|
||||||
|
short := false
|
||||||
|
|
||||||
|
// Retry loop: if the model hits a stop token before max-tokens,
|
||||||
|
// retry with a different prompt (up to maxRetries times).
|
||||||
|
const maxRetries = 3
|
||||||
|
for attempt := range maxRetries + 1 {
|
||||||
|
responseMetrics = nil
|
||||||
|
ttft = 0
|
||||||
|
var ttftOnce sync.Once
|
||||||
|
|
||||||
|
req := buildGenerateRequest(model, fOpt, imgData, epoch+attempt*1000)
|
||||||
|
requestStart := time.Now()
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(*fOpt.timeout)*time.Second)
|
||||||
|
|
||||||
|
err = client.Generate(ctx, req, func(resp api.GenerateResponse) error {
|
||||||
|
if *fOpt.debug {
|
||||||
|
fmt.Fprintf(os.Stderr, "%s", cmp.Or(resp.Thinking, resp.Response))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Capture TTFT on first content
|
||||||
|
ttftOnce.Do(func() {
|
||||||
|
if resp.Response != "" || resp.Thinking != "" {
|
||||||
|
ttft = time.Since(requestStart)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
if resp.Done {
|
||||||
|
responseMetrics = &resp.Metrics
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
cancel()
|
||||||
|
|
||||||
|
if *fOpt.debug {
|
||||||
|
fmt.Fprintln(os.Stderr)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
if ctx.Err() == context.DeadlineExceeded {
|
||||||
|
fmt.Fprintf(os.Stderr, "ERROR: Request timed out with model '%s' after %vs\n", model, *fOpt.timeout)
|
||||||
|
} else {
|
||||||
|
fmt.Fprintf(os.Stderr, "ERROR: Couldn't generate with model '%s': %v\n", model, err)
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
if responseMetrics == nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "ERROR: No metrics received for model '%s'\n", model)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if the response was shorter than requested
|
||||||
|
short = *fOpt.maxTokens > 0 && responseMetrics.EvalCount < *fOpt.maxTokens
|
||||||
|
if !short || attempt == maxRetries {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
if *fOpt.debug {
|
||||||
|
fmt.Fprintf(os.Stderr, "Short response (%d/%d tokens), retrying with different prompt (attempt %d/%d)\n",
|
||||||
|
responseMetrics.EvalCount, *fOpt.maxTokens, attempt+1, maxRetries)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil || responseMetrics == nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if responseMetrics == nil {
|
if short {
|
||||||
fmt.Fprintf(os.Stderr, "ERROR: No metrics received for model '%s'\n", model)
|
shortCount++
|
||||||
continue
|
if *fOpt.debug {
|
||||||
|
fmt.Fprintf(os.Stderr, "WARNING: Short response (%d/%d tokens) after %d retries for epoch %d\n",
|
||||||
|
responseMetrics.EvalCount, *fOpt.maxTokens, maxRetries, epoch+1)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
metrics := []Metrics{
|
metrics := []Metrics{
|
||||||
@@ -238,6 +441,12 @@ func BenchmarkChat(fOpt flagOptions) error {
|
|||||||
Count: responseMetrics.EvalCount,
|
Count: responseMetrics.EvalCount,
|
||||||
Duration: responseMetrics.EvalDuration,
|
Duration: responseMetrics.EvalDuration,
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
Model: model,
|
||||||
|
Step: "ttft",
|
||||||
|
Count: 1,
|
||||||
|
Duration: ttft,
|
||||||
|
},
|
||||||
{
|
{
|
||||||
Model: model,
|
Model: model,
|
||||||
Step: "load",
|
Step: "load",
|
||||||
@@ -254,15 +463,42 @@ func BenchmarkChat(fOpt flagOptions) error {
|
|||||||
|
|
||||||
OutputMetrics(out, *fOpt.format, metrics, *fOpt.verbose)
|
OutputMetrics(out, *fOpt.format, metrics, *fOpt.verbose)
|
||||||
|
|
||||||
|
if *fOpt.debug && *fOpt.promptTokens > 0 {
|
||||||
|
fmt.Fprintf(os.Stderr, "Generated prompt targeting ~%d tokens (actual: %d)\n",
|
||||||
|
*fOpt.promptTokens, responseMetrics.PromptEvalCount)
|
||||||
|
}
|
||||||
|
|
||||||
if *fOpt.keepAlive > 0 {
|
if *fOpt.keepAlive > 0 {
|
||||||
time.Sleep(time.Duration(*fOpt.keepAlive*float64(time.Second)) + 200*time.Millisecond)
|
time.Sleep(time.Duration(*fOpt.keepAlive*float64(time.Second)) + 200*time.Millisecond)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if shortCount > 0 {
|
||||||
|
fmt.Fprintf(os.Stderr, "WARNING: %d/%d epochs for '%s' had short responses (<%d tokens). Generation metrics may be unreliable.\n",
|
||||||
|
shortCount, *fOpt.epochs, model, *fOpt.maxTokens)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unload model before moving to the next one
|
||||||
|
unloadModel(client, model, *fOpt.timeout)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func unloadModel(client *api.Client, model string, timeout int) {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(timeout)*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
zero := api.Duration{Duration: 0}
|
||||||
|
req := &api.GenerateRequest{
|
||||||
|
Model: model,
|
||||||
|
KeepAlive: &zero,
|
||||||
|
}
|
||||||
|
_ = client.Generate(ctx, req, func(resp api.GenerateResponse) error {
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func readImage(filePath string) (api.ImageData, error) {
|
func readImage(filePath string) (api.ImageData, error) {
|
||||||
file, err := os.Open(filePath)
|
file, err := os.Open(filePath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -280,19 +516,22 @@ func readImage(filePath string) (api.ImageData, error) {
|
|||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
fOpt := flagOptions{
|
fOpt := flagOptions{
|
||||||
models: flag.String("model", "", "Model to benchmark"),
|
models: flag.String("model", "", "Model to benchmark"),
|
||||||
epochs: flag.Int("epochs", 6, "Number of epochs (iterations) per model"),
|
epochs: flag.Int("epochs", 6, "Number of epochs (iterations) per model"),
|
||||||
maxTokens: flag.Int("max-tokens", 200, "Maximum tokens for model response"),
|
maxTokens: flag.Int("max-tokens", 200, "Maximum tokens for model response"),
|
||||||
temperature: flag.Float64("temperature", 0, "Temperature parameter"),
|
temperature: flag.Float64("temperature", 0, "Temperature parameter"),
|
||||||
seed: flag.Int("seed", 0, "Random seed"),
|
seed: flag.Int("seed", 0, "Random seed"),
|
||||||
timeout: flag.Int("timeout", 60*5, "Timeout in seconds (default 300s)"),
|
timeout: flag.Int("timeout", 60*5, "Timeout in seconds (default 300s)"),
|
||||||
prompt: flag.String("p", DefaultPrompt, "Prompt to use"),
|
prompt: flag.String("p", DefaultPrompt, "Prompt to use"),
|
||||||
imageFile: flag.String("image", "", "Filename for an image to include"),
|
imageFile: flag.String("image", "", "Filename for an image to include"),
|
||||||
keepAlive: flag.Float64("k", 0, "Keep alive duration in seconds"),
|
keepAlive: flag.Float64("k", 0, "Keep alive duration in seconds"),
|
||||||
format: flag.String("format", "markdown", "Output format [benchstat|csv] (default benchstat)"),
|
format: flag.String("format", "benchstat", "Output format [benchstat|csv]"),
|
||||||
outputFile: flag.String("output", "", "Output file for results (stdout if empty)"),
|
outputFile: flag.String("output", "", "Output file for results (stdout if empty)"),
|
||||||
verbose: flag.Bool("v", false, "Show system information"),
|
verbose: flag.Bool("v", false, "Show system information"),
|
||||||
debug: flag.Bool("debug", false, "Show debug information"),
|
debug: flag.Bool("debug", false, "Show debug information"),
|
||||||
|
warmup: flag.Int("warmup", 1, "Number of warmup requests before timing"),
|
||||||
|
promptTokens: flag.Int("prompt-tokens", 0, "Generate prompt targeting ~N tokens (0 = use -p prompt)"),
|
||||||
|
numCtx: flag.Int("num-ctx", 0, "Context size (0 = server default)"),
|
||||||
}
|
}
|
||||||
|
|
||||||
flag.Usage = func() {
|
flag.Usage = func() {
|
||||||
@@ -302,11 +541,12 @@ func main() {
|
|||||||
fmt.Fprintf(os.Stderr, "Options:\n")
|
fmt.Fprintf(os.Stderr, "Options:\n")
|
||||||
flag.PrintDefaults()
|
flag.PrintDefaults()
|
||||||
fmt.Fprintf(os.Stderr, "\nExamples:\n")
|
fmt.Fprintf(os.Stderr, "\nExamples:\n")
|
||||||
fmt.Fprintf(os.Stderr, " bench -model gpt-oss:20b -epochs 3 -temperature 0.7\n")
|
fmt.Fprintf(os.Stderr, " bench -model gemma3,llama3 -epochs 6\n")
|
||||||
|
fmt.Fprintf(os.Stderr, " bench -model gemma3 -epochs 6 -prompt-tokens 512 -format csv\n")
|
||||||
}
|
}
|
||||||
flag.Parse()
|
flag.Parse()
|
||||||
|
|
||||||
if !slices.Contains([]string{"markdown", "benchstat", "csv"}, *fOpt.format) {
|
if !slices.Contains([]string{"benchstat", "csv"}, *fOpt.format) {
|
||||||
fmt.Fprintf(os.Stderr, "ERROR: Unknown format '%s'\n", *fOpt.format)
|
fmt.Fprintf(os.Stderr, "ERROR: Unknown format '%s'\n", *fOpt.format)
|
||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
@@ -317,5 +557,5 @@ func main() {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
BenchmarkChat(fOpt)
|
BenchmarkModel(fOpt)
|
||||||
}
|
}
|
||||||
|
|||||||
506
cmd/cmd.go
@@ -11,6 +11,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"log"
|
"log"
|
||||||
|
"log/slog"
|
||||||
"math"
|
"math"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
@@ -38,9 +39,12 @@ import (
|
|||||||
|
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
"github.com/ollama/ollama/cmd/config"
|
"github.com/ollama/ollama/cmd/config"
|
||||||
|
"github.com/ollama/ollama/cmd/launch"
|
||||||
"github.com/ollama/ollama/cmd/tui"
|
"github.com/ollama/ollama/cmd/tui"
|
||||||
"github.com/ollama/ollama/envconfig"
|
"github.com/ollama/ollama/envconfig"
|
||||||
"github.com/ollama/ollama/format"
|
"github.com/ollama/ollama/format"
|
||||||
|
"github.com/ollama/ollama/internal/modelref"
|
||||||
|
"github.com/ollama/ollama/logutil"
|
||||||
"github.com/ollama/ollama/parser"
|
"github.com/ollama/ollama/parser"
|
||||||
"github.com/ollama/ollama/progress"
|
"github.com/ollama/ollama/progress"
|
||||||
"github.com/ollama/ollama/readline"
|
"github.com/ollama/ollama/readline"
|
||||||
@@ -50,46 +54,45 @@ import (
|
|||||||
"github.com/ollama/ollama/types/syncmap"
|
"github.com/ollama/ollama/types/syncmap"
|
||||||
"github.com/ollama/ollama/version"
|
"github.com/ollama/ollama/version"
|
||||||
xcmd "github.com/ollama/ollama/x/cmd"
|
xcmd "github.com/ollama/ollama/x/cmd"
|
||||||
"github.com/ollama/ollama/x/create"
|
|
||||||
xcreateclient "github.com/ollama/ollama/x/create/client"
|
xcreateclient "github.com/ollama/ollama/x/create/client"
|
||||||
"github.com/ollama/ollama/x/imagegen"
|
"github.com/ollama/ollama/x/imagegen"
|
||||||
)
|
)
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
// Override default selectors to use Bubbletea TUI instead of raw terminal I/O.
|
// Override default selectors to use Bubbletea TUI instead of raw terminal I/O.
|
||||||
config.DefaultSingleSelector = func(title string, items []config.ModelItem, current string) (string, error) {
|
launch.DefaultSingleSelector = func(title string, items []launch.ModelItem, current string) (string, error) {
|
||||||
|
if !term.IsTerminal(int(os.Stdin.Fd())) || !term.IsTerminal(int(os.Stdout.Fd())) {
|
||||||
|
return "", fmt.Errorf("model selection requires an interactive terminal; use --model to run in headless mode")
|
||||||
|
}
|
||||||
tuiItems := tui.ReorderItems(tui.ConvertItems(items))
|
tuiItems := tui.ReorderItems(tui.ConvertItems(items))
|
||||||
result, err := tui.SelectSingle(title, tuiItems, current)
|
result, err := tui.SelectSingle(title, tuiItems, current)
|
||||||
if errors.Is(err, tui.ErrCancelled) {
|
if errors.Is(err, tui.ErrCancelled) {
|
||||||
return "", config.ErrCancelled
|
return "", launch.ErrCancelled
|
||||||
}
|
}
|
||||||
return result, err
|
return result, err
|
||||||
}
|
}
|
||||||
|
|
||||||
config.DefaultMultiSelector = func(title string, items []config.ModelItem, preChecked []string) ([]string, error) {
|
launch.DefaultMultiSelector = func(title string, items []launch.ModelItem, preChecked []string) ([]string, error) {
|
||||||
|
if !term.IsTerminal(int(os.Stdin.Fd())) || !term.IsTerminal(int(os.Stdout.Fd())) {
|
||||||
|
return nil, fmt.Errorf("model selection requires an interactive terminal; use --model to run in headless mode")
|
||||||
|
}
|
||||||
tuiItems := tui.ReorderItems(tui.ConvertItems(items))
|
tuiItems := tui.ReorderItems(tui.ConvertItems(items))
|
||||||
result, err := tui.SelectMultiple(title, tuiItems, preChecked)
|
result, err := tui.SelectMultiple(title, tuiItems, preChecked)
|
||||||
if errors.Is(err, tui.ErrCancelled) {
|
if errors.Is(err, tui.ErrCancelled) {
|
||||||
return nil, config.ErrCancelled
|
return nil, launch.ErrCancelled
|
||||||
}
|
}
|
||||||
return result, err
|
return result, err
|
||||||
}
|
}
|
||||||
|
|
||||||
config.DefaultSignIn = func(modelName, signInURL string) (string, error) {
|
launch.DefaultSignIn = func(modelName, signInURL string) (string, error) {
|
||||||
userName, err := tui.RunSignIn(modelName, signInURL)
|
userName, err := tui.RunSignIn(modelName, signInURL)
|
||||||
if errors.Is(err, tui.ErrCancelled) {
|
if errors.Is(err, tui.ErrCancelled) {
|
||||||
return "", config.ErrCancelled
|
return "", launch.ErrCancelled
|
||||||
}
|
}
|
||||||
return userName, err
|
return userName, err
|
||||||
}
|
}
|
||||||
|
|
||||||
config.DefaultConfirmPrompt = func(prompt string) (bool, error) {
|
launch.DefaultConfirmPrompt = tui.RunConfirmWithOptions
|
||||||
ok, err := tui.RunConfirm(prompt)
|
|
||||||
if errors.Is(err, tui.ErrCancelled) {
|
|
||||||
return false, config.ErrCancelled
|
|
||||||
}
|
|
||||||
return ok, err
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
const ConnectInstructions = "If your browser did not open, navigate to:\n %s\n\n"
|
const ConnectInstructions = "If your browser did not open, navigate to:\n %s\n\n"
|
||||||
@@ -131,6 +134,17 @@ func getModelfileName(cmd *cobra.Command) (string, error) {
|
|||||||
return absName, nil
|
return absName, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// isLocalhost returns true if the configured Ollama host is a loopback or unspecified address.
|
||||||
|
func isLocalhost() bool {
|
||||||
|
host := envconfig.Host()
|
||||||
|
h, _, _ := net.SplitHostPort(host.Host)
|
||||||
|
if h == "localhost" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
ip := net.ParseIP(h)
|
||||||
|
return ip != nil && (ip.IsLoopback() || ip.IsUnspecified())
|
||||||
|
}
|
||||||
|
|
||||||
func CreateHandler(cmd *cobra.Command, args []string) error {
|
func CreateHandler(cmd *cobra.Command, args []string) error {
|
||||||
p := progress.NewProgress(os.Stderr)
|
p := progress.NewProgress(os.Stderr)
|
||||||
defer p.Stop()
|
defer p.Stop()
|
||||||
@@ -143,8 +157,13 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Check for --experimental flag for safetensors model creation
|
// Check for --experimental flag for safetensors model creation
|
||||||
|
// This gates both safetensors LLM and imagegen model creation
|
||||||
experimental, _ := cmd.Flags().GetBool("experimental")
|
experimental, _ := cmd.Flags().GetBool("experimental")
|
||||||
if experimental {
|
if experimental {
|
||||||
|
if !isLocalhost() {
|
||||||
|
return errors.New("remote safetensor model creation not yet supported")
|
||||||
|
}
|
||||||
|
|
||||||
// Get Modelfile content - either from -f flag or default to "FROM ."
|
// Get Modelfile content - either from -f flag or default to "FROM ."
|
||||||
var reader io.Reader
|
var reader io.Reader
|
||||||
filename, err := getModelfileName(cmd)
|
filename, err := getModelfileName(cmd)
|
||||||
@@ -168,29 +187,9 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
|
|||||||
return fmt.Errorf("failed to parse Modelfile: %w", err)
|
return fmt.Errorf("failed to parse Modelfile: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Extract FROM path and configuration
|
modelDir, mfConfig, err := xcreateclient.ConfigFromModelfile(modelfile)
|
||||||
var modelDir string
|
if err != nil {
|
||||||
mfConfig := &xcreateclient.ModelfileConfig{}
|
return err
|
||||||
|
|
||||||
for _, cmd := range modelfile.Commands {
|
|
||||||
switch cmd.Name {
|
|
||||||
case "model":
|
|
||||||
modelDir = cmd.Args
|
|
||||||
case "template":
|
|
||||||
mfConfig.Template = cmd.Args
|
|
||||||
case "system":
|
|
||||||
mfConfig.System = cmd.Args
|
|
||||||
case "license":
|
|
||||||
mfConfig.License = cmd.Args
|
|
||||||
case "parser":
|
|
||||||
mfConfig.Parser = cmd.Args
|
|
||||||
case "renderer":
|
|
||||||
mfConfig.Renderer = cmd.Args
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if modelDir == "" {
|
|
||||||
modelDir = "."
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Resolve relative paths based on Modelfile location
|
// Resolve relative paths based on Modelfile location
|
||||||
@@ -207,20 +206,12 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
|
|||||||
}, p)
|
}, p)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Standard Modelfile + API path
|
||||||
var reader io.Reader
|
var reader io.Reader
|
||||||
|
|
||||||
filename, err := getModelfileName(cmd)
|
filename, err := getModelfileName(cmd)
|
||||||
if os.IsNotExist(err) {
|
if os.IsNotExist(err) {
|
||||||
if filename == "" {
|
if filename == "" {
|
||||||
// No Modelfile found - check if current directory is an image gen model
|
|
||||||
if create.IsTensorModelDir(".") {
|
|
||||||
quantize, _ := cmd.Flags().GetString("quantize")
|
|
||||||
return xcreateclient.CreateModel(xcreateclient.CreateOptions{
|
|
||||||
ModelName: modelName,
|
|
||||||
ModelDir: ".",
|
|
||||||
Quantize: quantize,
|
|
||||||
}, p)
|
|
||||||
}
|
|
||||||
reader = strings.NewReader("FROM .\n")
|
reader = strings.NewReader("FROM .\n")
|
||||||
} else {
|
} else {
|
||||||
return errModelfileNotFound
|
return errModelfileNotFound
|
||||||
@@ -406,12 +397,14 @@ func loadOrUnloadModel(cmd *cobra.Command, opts *runOptions) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
requestedCloud := modelref.HasExplicitCloudSource(opts.Model)
|
||||||
|
|
||||||
if info, err := client.Show(cmd.Context(), &api.ShowRequest{Model: opts.Model}); err != nil {
|
if info, err := client.Show(cmd.Context(), &api.ShowRequest{Model: opts.Model}); err != nil {
|
||||||
return err
|
return err
|
||||||
} else if info.RemoteHost != "" {
|
} else if info.RemoteHost != "" || requestedCloud {
|
||||||
// Cloud model, no need to load/unload
|
// Cloud model, no need to load/unload
|
||||||
|
|
||||||
isCloud := strings.HasPrefix(info.RemoteHost, "https://ollama.com")
|
isCloud := requestedCloud || strings.HasPrefix(info.RemoteHost, "https://ollama.com")
|
||||||
|
|
||||||
// Check if user is signed in for ollama.com cloud models
|
// Check if user is signed in for ollama.com cloud models
|
||||||
if isCloud {
|
if isCloud {
|
||||||
@@ -422,10 +415,14 @@ func loadOrUnloadModel(cmd *cobra.Command, opts *runOptions) error {
|
|||||||
|
|
||||||
if opts.ShowConnect {
|
if opts.ShowConnect {
|
||||||
p.StopAndClear()
|
p.StopAndClear()
|
||||||
|
remoteModel := info.RemoteModel
|
||||||
|
if remoteModel == "" {
|
||||||
|
remoteModel = opts.Model
|
||||||
|
}
|
||||||
if isCloud {
|
if isCloud {
|
||||||
fmt.Fprintf(os.Stderr, "Connecting to '%s' on 'ollama.com' ⚡\n", info.RemoteModel)
|
fmt.Fprintf(os.Stderr, "Connecting to '%s' on 'ollama.com' ⚡\n", remoteModel)
|
||||||
} else {
|
} else {
|
||||||
fmt.Fprintf(os.Stderr, "Connecting to '%s' on '%s'\n", info.RemoteModel, info.RemoteHost)
|
fmt.Fprintf(os.Stderr, "Connecting to '%s' on '%s'\n", remoteModel, info.RemoteHost)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -497,6 +494,64 @@ func generateEmbedding(cmd *cobra.Command, modelName, input string, keepAlive *a
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO(parthsareen): consolidate with TUI signin flow
|
||||||
|
func handleCloudAuthorizationError(err error) bool {
|
||||||
|
var authErr api.AuthorizationError
|
||||||
|
if errors.As(err, &authErr) && authErr.StatusCode == http.StatusUnauthorized {
|
||||||
|
fmt.Printf("You need to be signed in to Ollama to run Cloud models.\n\n")
|
||||||
|
if authErr.SigninURL != "" {
|
||||||
|
fmt.Printf(ConnectInstructions, authErr.SigninURL)
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// TEMP(drifkin): To match legacy `ollama run some-model:cloud` behavior, we
|
||||||
|
// best-effort pull cloud stub files for any explicit cloud source models.
|
||||||
|
// Remove this once `/api/tags` is cloud-aware.
|
||||||
|
func ensureCloudStub(ctx context.Context, client *api.Client, modelName string) {
|
||||||
|
if !modelref.HasExplicitCloudSource(modelName) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
normalizedName, _, err := modelref.NormalizePullName(modelName)
|
||||||
|
if err != nil {
|
||||||
|
slog.Warn("failed to normalize pull name", "model", modelName, "error", err, "normalizedName", normalizedName)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
listResp, err := client.List(ctx)
|
||||||
|
if err != nil {
|
||||||
|
slog.Warn("failed to list models", "error", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if hasListedModelName(listResp.Models, modelName) || hasListedModelName(listResp.Models, normalizedName) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
logutil.Trace("pulling cloud stub", "model", modelName, "normalizedName", normalizedName)
|
||||||
|
err = client.Pull(ctx, &api.PullRequest{
|
||||||
|
Model: normalizedName,
|
||||||
|
}, func(api.ProgressResponse) error {
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
slog.Warn("failed to pull cloud stub", "model", modelName, "error", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func hasListedModelName(models []api.ListModelResponse, name string) bool {
|
||||||
|
for _, m := range models {
|
||||||
|
if strings.EqualFold(m.Name, name) || strings.EqualFold(m.Model, name) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
func RunHandler(cmd *cobra.Command, args []string) error {
|
func RunHandler(cmd *cobra.Command, args []string) error {
|
||||||
interactive := true
|
interactive := true
|
||||||
|
|
||||||
@@ -585,17 +640,6 @@ func RunHandler(cmd *cobra.Command, args []string) error {
|
|||||||
}
|
}
|
||||||
opts.WordWrap = !nowrap
|
opts.WordWrap = !nowrap
|
||||||
|
|
||||||
useImagegen := false
|
|
||||||
if cmd.Flags().Lookup("imagegen") != nil {
|
|
||||||
useImagegen, err = cmd.Flags().GetBool("imagegen")
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if useImagegen {
|
|
||||||
opts.Options["use_imagegen_runner"] = true
|
|
||||||
}
|
|
||||||
|
|
||||||
// Fill out the rest of the options based on information about the
|
// Fill out the rest of the options based on information about the
|
||||||
// model.
|
// model.
|
||||||
client, err := api.ClientFromEnvironment()
|
client, err := api.ClientFromEnvironment()
|
||||||
@@ -604,12 +648,16 @@ func RunHandler(cmd *cobra.Command, args []string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
name := args[0]
|
name := args[0]
|
||||||
|
requestedCloud := modelref.HasExplicitCloudSource(name)
|
||||||
|
|
||||||
info, err := func() (*api.ShowResponse, error) {
|
info, err := func() (*api.ShowResponse, error) {
|
||||||
showReq := &api.ShowRequest{Name: name}
|
showReq := &api.ShowRequest{Name: name}
|
||||||
info, err := client.Show(cmd.Context(), showReq)
|
info, err := client.Show(cmd.Context(), showReq)
|
||||||
var se api.StatusError
|
var se api.StatusError
|
||||||
if errors.As(err, &se) && se.StatusCode == http.StatusNotFound {
|
if errors.As(err, &se) && se.StatusCode == http.StatusNotFound {
|
||||||
|
if requestedCloud {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
if err := PullHandler(cmd, []string{name}); err != nil {
|
if err := PullHandler(cmd, []string{name}); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -618,15 +666,21 @@ func RunHandler(cmd *cobra.Command, args []string) error {
|
|||||||
return info, err
|
return info, err
|
||||||
}()
|
}()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
if handleCloudAuthorizationError(err) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ensureCloudStub(cmd.Context(), client, name)
|
||||||
|
|
||||||
opts.Think, err = inferThinkingOption(&info.Capabilities, &opts, thinkFlag.Changed)
|
opts.Think, err = inferThinkingOption(&info.Capabilities, &opts, thinkFlag.Changed)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
opts.MultiModal = slices.Contains(info.Capabilities, model.CapabilityVision)
|
audioCapable := slices.Contains(info.Capabilities, model.CapabilityAudio)
|
||||||
|
opts.MultiModal = slices.Contains(info.Capabilities, model.CapabilityVision) || audioCapable
|
||||||
|
|
||||||
// TODO: remove the projector info and vision info checks below,
|
// TODO: remove the projector info and vision info checks below,
|
||||||
// these are left in for backwards compatibility with older servers
|
// these are left in for backwards compatibility with older servers
|
||||||
@@ -641,7 +695,7 @@ func RunHandler(cmd *cobra.Command, args []string) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
opts.ParentModel = info.Details.ParentModel
|
applyShowResponseToRunOptions(&opts, info)
|
||||||
|
|
||||||
// Check if this is an embedding model
|
// Check if this is an embedding model
|
||||||
isEmbeddingModel := slices.Contains(info.Capabilities, model.CapabilityEmbedding)
|
isEmbeddingModel := slices.Contains(info.Capabilities, model.CapabilityEmbedding)
|
||||||
@@ -712,7 +766,13 @@ func RunHandler(cmd *cobra.Command, args []string) error {
|
|||||||
|
|
||||||
return generateInteractive(cmd, opts)
|
return generateInteractive(cmd, opts)
|
||||||
}
|
}
|
||||||
return generate(cmd, opts)
|
if err := generate(cmd, opts); err != nil {
|
||||||
|
if handleCloudAuthorizationError(err) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func SigninHandler(cmd *cobra.Command, args []string) error {
|
func SigninHandler(cmd *cobra.Command, args []string) error {
|
||||||
@@ -1351,23 +1411,30 @@ func PullHandler(cmd *cobra.Command, args []string) error {
|
|||||||
type generateContextKey string
|
type generateContextKey string
|
||||||
|
|
||||||
type runOptions struct {
|
type runOptions struct {
|
||||||
Model string
|
Model string
|
||||||
ParentModel string
|
ParentModel string
|
||||||
Prompt string
|
LoadedMessages []api.Message
|
||||||
Messages []api.Message
|
Prompt string
|
||||||
WordWrap bool
|
Messages []api.Message
|
||||||
Format string
|
WordWrap bool
|
||||||
System string
|
Format string
|
||||||
Images []api.ImageData
|
System string
|
||||||
Options map[string]any
|
Images []api.ImageData
|
||||||
MultiModal bool
|
Options map[string]any
|
||||||
KeepAlive *api.Duration
|
MultiModal bool
|
||||||
Think *api.ThinkValue
|
KeepAlive *api.Duration
|
||||||
HideThinking bool
|
Think *api.ThinkValue
|
||||||
ShowConnect bool
|
HideThinking bool
|
||||||
|
ShowConnect bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r runOptions) Copy() runOptions {
|
func (r runOptions) Copy() runOptions {
|
||||||
|
var loadedMessages []api.Message
|
||||||
|
if r.LoadedMessages != nil {
|
||||||
|
loadedMessages = make([]api.Message, len(r.LoadedMessages))
|
||||||
|
copy(loadedMessages, r.LoadedMessages)
|
||||||
|
}
|
||||||
|
|
||||||
var messages []api.Message
|
var messages []api.Message
|
||||||
if r.Messages != nil {
|
if r.Messages != nil {
|
||||||
messages = make([]api.Message, len(r.Messages))
|
messages = make([]api.Message, len(r.Messages))
|
||||||
@@ -1395,23 +1462,29 @@ func (r runOptions) Copy() runOptions {
|
|||||||
}
|
}
|
||||||
|
|
||||||
return runOptions{
|
return runOptions{
|
||||||
Model: r.Model,
|
Model: r.Model,
|
||||||
ParentModel: r.ParentModel,
|
ParentModel: r.ParentModel,
|
||||||
Prompt: r.Prompt,
|
LoadedMessages: loadedMessages,
|
||||||
Messages: messages,
|
Prompt: r.Prompt,
|
||||||
WordWrap: r.WordWrap,
|
Messages: messages,
|
||||||
Format: r.Format,
|
WordWrap: r.WordWrap,
|
||||||
System: r.System,
|
Format: r.Format,
|
||||||
Images: images,
|
System: r.System,
|
||||||
Options: opts,
|
Images: images,
|
||||||
MultiModal: r.MultiModal,
|
Options: opts,
|
||||||
KeepAlive: r.KeepAlive,
|
MultiModal: r.MultiModal,
|
||||||
Think: think,
|
KeepAlive: r.KeepAlive,
|
||||||
HideThinking: r.HideThinking,
|
Think: think,
|
||||||
ShowConnect: r.ShowConnect,
|
HideThinking: r.HideThinking,
|
||||||
|
ShowConnect: r.ShowConnect,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func applyShowResponseToRunOptions(opts *runOptions, info *api.ShowResponse) {
|
||||||
|
opts.ParentModel = info.Details.ParentModel
|
||||||
|
opts.LoadedMessages = slices.Clone(info.Messages)
|
||||||
|
}
|
||||||
|
|
||||||
type displayResponseState struct {
|
type displayResponseState struct {
|
||||||
lineLength int
|
lineLength int
|
||||||
wordBuffer string
|
wordBuffer string
|
||||||
@@ -1419,6 +1492,9 @@ type displayResponseState struct {
|
|||||||
|
|
||||||
func displayResponse(content string, wordWrap bool, state *displayResponseState) {
|
func displayResponse(content string, wordWrap bool, state *displayResponseState) {
|
||||||
termWidth, _, _ := term.GetSize(int(os.Stdout.Fd()))
|
termWidth, _, _ := term.GetSize(int(os.Stdout.Fd()))
|
||||||
|
if termWidth == 0 {
|
||||||
|
termWidth = 80
|
||||||
|
}
|
||||||
if wordWrap && termWidth >= 10 {
|
if wordWrap && termWidth >= 10 {
|
||||||
for _, ch := range content {
|
for _, ch := range content {
|
||||||
if state.lineLength+1 > termWidth-5 {
|
if state.lineLength+1 > termWidth-5 {
|
||||||
@@ -1892,6 +1968,24 @@ func ensureServerRunning(ctx context.Context) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func launchInteractiveModel(cmd *cobra.Command, modelName string) error {
|
||||||
|
opts := runOptions{
|
||||||
|
Model: modelName,
|
||||||
|
WordWrap: os.Getenv("TERM") == "xterm-256color",
|
||||||
|
Options: map[string]any{},
|
||||||
|
ShowConnect: true,
|
||||||
|
}
|
||||||
|
// loadOrUnloadModel is cloud-safe here: remote/cloud models skip local preload
|
||||||
|
// and only validate auth/connectivity before interactive chat starts.
|
||||||
|
if err := loadOrUnloadModel(cmd, &opts); err != nil {
|
||||||
|
return fmt.Errorf("error loading model: %w", err)
|
||||||
|
}
|
||||||
|
if err := generateInteractive(cmd, opts); err != nil {
|
||||||
|
return fmt.Errorf("error running model: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// runInteractiveTUI runs the main interactive TUI menu.
|
// runInteractiveTUI runs the main interactive TUI menu.
|
||||||
func runInteractiveTUI(cmd *cobra.Command) {
|
func runInteractiveTUI(cmd *cobra.Command) {
|
||||||
// Ensure the server is running before showing the TUI
|
// Ensure the server is running before showing the TUI
|
||||||
@@ -1900,175 +1994,85 @@ func runInteractiveTUI(cmd *cobra.Command) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Selector adapters for tui
|
deps := launcherDeps{
|
||||||
singleSelector := func(title string, items []config.ModelItem, current string) (string, error) {
|
buildState: launch.BuildLauncherState,
|
||||||
tuiItems := tui.ReorderItems(tui.ConvertItems(items))
|
runMenu: tui.RunMenu,
|
||||||
result, err := tui.SelectSingle(title, tuiItems, current)
|
resolveRunModel: launch.ResolveRunModel,
|
||||||
if errors.Is(err, tui.ErrCancelled) {
|
launchIntegration: launch.LaunchIntegration,
|
||||||
return "", config.ErrCancelled
|
runModel: launchInteractiveModel,
|
||||||
}
|
|
||||||
return result, err
|
|
||||||
}
|
|
||||||
|
|
||||||
multiSelector := func(title string, items []config.ModelItem, preChecked []string) ([]string, error) {
|
|
||||||
tuiItems := tui.ReorderItems(tui.ConvertItems(items))
|
|
||||||
result, err := tui.SelectMultiple(title, tuiItems, preChecked)
|
|
||||||
if errors.Is(err, tui.ErrCancelled) {
|
|
||||||
return nil, config.ErrCancelled
|
|
||||||
}
|
|
||||||
return result, err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for {
|
for {
|
||||||
result, err := tui.Run()
|
continueLoop, err := runInteractiveTUIStep(cmd, deps)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
|
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
|
||||||
|
}
|
||||||
|
if !continueLoop {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
runModel := func(modelName string) {
|
type launcherDeps struct {
|
||||||
client, err := api.ClientFromEnvironment()
|
buildState func(context.Context) (*launch.LauncherState, error)
|
||||||
if err != nil {
|
runMenu func(*launch.LauncherState) (tui.TUIAction, error)
|
||||||
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
|
resolveRunModel func(context.Context, launch.RunModelRequest) (string, error)
|
||||||
return
|
launchIntegration func(context.Context, launch.IntegrationLaunchRequest) error
|
||||||
}
|
runModel func(*cobra.Command, string) error
|
||||||
if err := config.ShowOrPull(cmd.Context(), client, modelName); err != nil {
|
}
|
||||||
if errors.Is(err, config.ErrCancelled) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
_ = config.SetLastModel(modelName)
|
|
||||||
opts := runOptions{
|
|
||||||
Model: modelName,
|
|
||||||
WordWrap: os.Getenv("TERM") == "xterm-256color",
|
|
||||||
Options: map[string]any{},
|
|
||||||
ShowConnect: true,
|
|
||||||
}
|
|
||||||
if err := loadOrUnloadModel(cmd, &opts); err != nil {
|
|
||||||
fmt.Fprintf(os.Stderr, "Error loading model: %v\n", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if err := generateInteractive(cmd, opts); err != nil {
|
|
||||||
fmt.Fprintf(os.Stderr, "Error running model: %v\n", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
launchIntegration := func(name string) bool {
|
func runInteractiveTUIStep(cmd *cobra.Command, deps launcherDeps) (bool, error) {
|
||||||
if err := config.EnsureInstalled(name); err != nil {
|
state, err := deps.buildState(cmd.Context())
|
||||||
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
|
if err != nil {
|
||||||
return true
|
return false, fmt.Errorf("build launcher state: %w", err)
|
||||||
}
|
}
|
||||||
// If not configured or model no longer exists, prompt for model selection
|
|
||||||
configuredModel := config.IntegrationModel(name)
|
|
||||||
if configuredModel == "" || !config.ModelExists(cmd.Context(), configuredModel) || config.IsCloudModelDisabled(cmd.Context(), configuredModel) {
|
|
||||||
err := config.ConfigureIntegrationWithSelectors(cmd.Context(), name, singleSelector, multiSelector)
|
|
||||||
if errors.Is(err, config.ErrCancelled) {
|
|
||||||
return false // Return to main menu
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
fmt.Fprintf(os.Stderr, "Error configuring %s: %v\n", name, err)
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if err := config.LaunchIntegration(name); err != nil {
|
|
||||||
fmt.Fprintf(os.Stderr, "Error launching %s: %v\n", name, err)
|
|
||||||
}
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
switch result.Selection {
|
action, err := deps.runMenu(state)
|
||||||
case tui.SelectionNone:
|
if err != nil {
|
||||||
// User quit
|
return false, fmt.Errorf("run launcher menu: %w", err)
|
||||||
return
|
}
|
||||||
case tui.SelectionRunModel:
|
|
||||||
_ = config.SetLastSelection("run")
|
return runLauncherAction(cmd, action, deps)
|
||||||
if modelName := config.LastModel(); modelName != "" && !config.IsCloudModelDisabled(cmd.Context(), modelName) {
|
}
|
||||||
runModel(modelName)
|
|
||||||
} else {
|
func saveLauncherSelection(action tui.TUIAction) {
|
||||||
modelName, err := config.SelectModelWithSelector(cmd.Context(), singleSelector)
|
// Best effort only: this affects menu recall, not launch correctness.
|
||||||
if errors.Is(err, config.ErrCancelled) {
|
_ = config.SetLastSelection(action.LastSelection())
|
||||||
continue // Return to main menu
|
}
|
||||||
}
|
|
||||||
if err != nil {
|
func runLauncherAction(cmd *cobra.Command, action tui.TUIAction, deps launcherDeps) (bool, error) {
|
||||||
fmt.Fprintf(os.Stderr, "Error selecting model: %v\n", err)
|
switch action.Kind {
|
||||||
continue
|
case tui.TUIActionNone:
|
||||||
}
|
return false, nil
|
||||||
runModel(modelName)
|
case tui.TUIActionRunModel:
|
||||||
}
|
saveLauncherSelection(action)
|
||||||
case tui.SelectionChangeRunModel:
|
modelName, err := deps.resolveRunModel(cmd.Context(), action.RunModelRequest())
|
||||||
_ = config.SetLastSelection("run")
|
if errors.Is(err, launch.ErrCancelled) {
|
||||||
// Use model from modal if selected, otherwise show picker
|
return true, nil
|
||||||
modelName := result.Model
|
|
||||||
if modelName == "" {
|
|
||||||
var err error
|
|
||||||
modelName, err = config.SelectModelWithSelector(cmd.Context(), singleSelector)
|
|
||||||
if errors.Is(err, config.ErrCancelled) {
|
|
||||||
continue // Return to main menu
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
fmt.Fprintf(os.Stderr, "Error selecting model: %v\n", err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if config.IsCloudModelDisabled(cmd.Context(), modelName) {
|
|
||||||
continue // Return to main menu
|
|
||||||
}
|
|
||||||
runModel(modelName)
|
|
||||||
case tui.SelectionIntegration:
|
|
||||||
_ = config.SetLastSelection(result.Integration)
|
|
||||||
if !launchIntegration(result.Integration) {
|
|
||||||
continue // Return to main menu
|
|
||||||
}
|
|
||||||
case tui.SelectionChangeIntegration:
|
|
||||||
_ = config.SetLastSelection(result.Integration)
|
|
||||||
if len(result.Models) > 0 {
|
|
||||||
// Filter out cloud-disabled models
|
|
||||||
var filtered []string
|
|
||||||
for _, m := range result.Models {
|
|
||||||
if !config.IsCloudModelDisabled(cmd.Context(), m) {
|
|
||||||
filtered = append(filtered, m)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if len(filtered) == 0 {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
result.Models = filtered
|
|
||||||
// Multi-select from modal (Editor integrations)
|
|
||||||
if err := config.SaveAndEditIntegration(result.Integration, result.Models); err != nil {
|
|
||||||
fmt.Fprintf(os.Stderr, "Error configuring %s: %v\n", result.Integration, err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if err := config.LaunchIntegrationWithModel(result.Integration, result.Models[0]); err != nil {
|
|
||||||
fmt.Fprintf(os.Stderr, "Error launching %s: %v\n", result.Integration, err)
|
|
||||||
}
|
|
||||||
} else if result.Model != "" {
|
|
||||||
if config.IsCloudModelDisabled(cmd.Context(), result.Model) {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
// Single-select from modal - save and launch
|
|
||||||
if err := config.SaveIntegration(result.Integration, []string{result.Model}); err != nil {
|
|
||||||
fmt.Fprintf(os.Stderr, "Error saving config: %v\n", err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if err := config.LaunchIntegrationWithModel(result.Integration, result.Model); err != nil {
|
|
||||||
fmt.Fprintf(os.Stderr, "Error launching %s: %v\n", result.Integration, err)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
err := config.ConfigureIntegrationWithSelectors(cmd.Context(), result.Integration, singleSelector, multiSelector)
|
|
||||||
if errors.Is(err, config.ErrCancelled) {
|
|
||||||
continue // Return to main menu
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
fmt.Fprintf(os.Stderr, "Error configuring %s: %v\n", result.Integration, err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if err := config.LaunchIntegration(result.Integration); err != nil {
|
|
||||||
fmt.Fprintf(os.Stderr, "Error launching %s: %v\n", result.Integration, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
if err != nil {
|
||||||
|
return true, fmt.Errorf("selecting model: %w", err)
|
||||||
|
}
|
||||||
|
if err := deps.runModel(cmd, modelName); err != nil {
|
||||||
|
return true, err
|
||||||
|
}
|
||||||
|
return true, nil
|
||||||
|
case tui.TUIActionLaunchIntegration:
|
||||||
|
saveLauncherSelection(action)
|
||||||
|
err := deps.launchIntegration(cmd.Context(), action.IntegrationLaunchRequest())
|
||||||
|
if errors.Is(err, launch.ErrCancelled) {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return true, fmt.Errorf("launching %s: %w", action.Integration, err)
|
||||||
|
}
|
||||||
|
// VS Code is a GUI app — exit the TUI loop after launching
|
||||||
|
if action.Integration == "vscode" {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
return true, nil
|
||||||
|
default:
|
||||||
|
return false, fmt.Errorf("unknown launcher action: %d", action.Kind)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -2338,7 +2342,7 @@ func NewCLI() *cobra.Command {
|
|||||||
copyCmd,
|
copyCmd,
|
||||||
deleteCmd,
|
deleteCmd,
|
||||||
runnerCmd,
|
runnerCmd,
|
||||||
config.LaunchCmd(checkServerHeartbeat, runInteractiveTUI),
|
launch.LaunchCmd(checkServerHeartbeat, runInteractiveTUI),
|
||||||
)
|
)
|
||||||
|
|
||||||
return rootCmd
|
return rootCmd
|
||||||
|
|||||||
270
cmd/cmd_launcher_test.go
Normal file
@@ -0,0 +1,270 @@
|
|||||||
|
package cmd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/spf13/cobra"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/cmd/config"
|
||||||
|
"github.com/ollama/ollama/cmd/launch"
|
||||||
|
"github.com/ollama/ollama/cmd/tui"
|
||||||
|
)
|
||||||
|
|
||||||
|
func setCmdTestHome(t *testing.T, dir string) {
|
||||||
|
t.Helper()
|
||||||
|
t.Setenv("HOME", dir)
|
||||||
|
t.Setenv("USERPROFILE", dir)
|
||||||
|
}
|
||||||
|
|
||||||
|
func unexpectedRunModelResolution(t *testing.T) func(context.Context, launch.RunModelRequest) (string, error) {
|
||||||
|
t.Helper()
|
||||||
|
return func(ctx context.Context, req launch.RunModelRequest) (string, error) {
|
||||||
|
t.Fatalf("did not expect run-model resolution: %+v", req)
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func unexpectedIntegrationLaunch(t *testing.T) func(context.Context, launch.IntegrationLaunchRequest) error {
|
||||||
|
t.Helper()
|
||||||
|
return func(ctx context.Context, req launch.IntegrationLaunchRequest) error {
|
||||||
|
t.Fatalf("did not expect integration launch: %+v", req)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func unexpectedModelLaunch(t *testing.T) func(*cobra.Command, string) error {
|
||||||
|
t.Helper()
|
||||||
|
return func(cmd *cobra.Command, model string) error {
|
||||||
|
t.Fatalf("did not expect chat launch: %s", model)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRunInteractiveTUI_RunModelActionsUseResolveRunModel(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
action tui.TUIAction
|
||||||
|
wantForce bool
|
||||||
|
wantModel string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "enter uses saved model flow",
|
||||||
|
action: tui.TUIAction{Kind: tui.TUIActionRunModel},
|
||||||
|
wantModel: "qwen3:8b",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "right forces picker",
|
||||||
|
action: tui.TUIAction{Kind: tui.TUIActionRunModel, ForceConfigure: true},
|
||||||
|
wantForce: true,
|
||||||
|
wantModel: "glm-5:cloud",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
setCmdTestHome(t, t.TempDir())
|
||||||
|
|
||||||
|
var menuCalls int
|
||||||
|
runMenu := func(state *launch.LauncherState) (tui.TUIAction, error) {
|
||||||
|
menuCalls++
|
||||||
|
if menuCalls == 1 {
|
||||||
|
return tt.action, nil
|
||||||
|
}
|
||||||
|
return tui.TUIAction{Kind: tui.TUIActionNone}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var gotReq launch.RunModelRequest
|
||||||
|
var launched string
|
||||||
|
deps := launcherDeps{
|
||||||
|
buildState: func(ctx context.Context) (*launch.LauncherState, error) {
|
||||||
|
return &launch.LauncherState{}, nil
|
||||||
|
},
|
||||||
|
runMenu: runMenu,
|
||||||
|
resolveRunModel: func(ctx context.Context, req launch.RunModelRequest) (string, error) {
|
||||||
|
gotReq = req
|
||||||
|
return tt.wantModel, nil
|
||||||
|
},
|
||||||
|
launchIntegration: unexpectedIntegrationLaunch(t),
|
||||||
|
runModel: func(cmd *cobra.Command, model string) error {
|
||||||
|
launched = model
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd := &cobra.Command{}
|
||||||
|
cmd.SetContext(context.Background())
|
||||||
|
for {
|
||||||
|
continueLoop, err := runInteractiveTUIStep(cmd, deps)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected step error: %v", err)
|
||||||
|
}
|
||||||
|
if !continueLoop {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if gotReq.ForcePicker != tt.wantForce {
|
||||||
|
t.Fatalf("expected ForcePicker=%v, got %v", tt.wantForce, gotReq.ForcePicker)
|
||||||
|
}
|
||||||
|
if launched != tt.wantModel {
|
||||||
|
t.Fatalf("expected interactive launcher to run %q, got %q", tt.wantModel, launched)
|
||||||
|
}
|
||||||
|
if got := config.LastSelection(); got != "run" {
|
||||||
|
t.Fatalf("expected last selection to be run, got %q", got)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRunInteractiveTUI_IntegrationActionsUseLaunchIntegration(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
action tui.TUIAction
|
||||||
|
wantForce bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "enter launches integration",
|
||||||
|
action: tui.TUIAction{Kind: tui.TUIActionLaunchIntegration, Integration: "claude"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "right forces configure",
|
||||||
|
action: tui.TUIAction{Kind: tui.TUIActionLaunchIntegration, Integration: "claude", ForceConfigure: true},
|
||||||
|
wantForce: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
setCmdTestHome(t, t.TempDir())
|
||||||
|
|
||||||
|
var menuCalls int
|
||||||
|
runMenu := func(state *launch.LauncherState) (tui.TUIAction, error) {
|
||||||
|
menuCalls++
|
||||||
|
if menuCalls == 1 {
|
||||||
|
return tt.action, nil
|
||||||
|
}
|
||||||
|
return tui.TUIAction{Kind: tui.TUIActionNone}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var gotReq launch.IntegrationLaunchRequest
|
||||||
|
deps := launcherDeps{
|
||||||
|
buildState: func(ctx context.Context) (*launch.LauncherState, error) {
|
||||||
|
return &launch.LauncherState{}, nil
|
||||||
|
},
|
||||||
|
runMenu: runMenu,
|
||||||
|
resolveRunModel: unexpectedRunModelResolution(t),
|
||||||
|
launchIntegration: func(ctx context.Context, req launch.IntegrationLaunchRequest) error {
|
||||||
|
gotReq = req
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
runModel: unexpectedModelLaunch(t),
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd := &cobra.Command{}
|
||||||
|
cmd.SetContext(context.Background())
|
||||||
|
for {
|
||||||
|
continueLoop, err := runInteractiveTUIStep(cmd, deps)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected step error: %v", err)
|
||||||
|
}
|
||||||
|
if !continueLoop {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if gotReq.Name != "claude" {
|
||||||
|
t.Fatalf("expected integration name to be passed through, got %q", gotReq.Name)
|
||||||
|
}
|
||||||
|
if gotReq.ForceConfigure != tt.wantForce {
|
||||||
|
t.Fatalf("expected ForceConfigure=%v, got %v", tt.wantForce, gotReq.ForceConfigure)
|
||||||
|
}
|
||||||
|
if got := config.LastSelection(); got != "claude" {
|
||||||
|
t.Fatalf("expected last selection to be claude, got %q", got)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRunLauncherAction_RunModelContinuesAfterCancellation(t *testing.T) {
|
||||||
|
setCmdTestHome(t, t.TempDir())
|
||||||
|
|
||||||
|
cmd := &cobra.Command{}
|
||||||
|
cmd.SetContext(context.Background())
|
||||||
|
|
||||||
|
continueLoop, err := runLauncherAction(cmd, tui.TUIAction{Kind: tui.TUIActionRunModel}, launcherDeps{
|
||||||
|
buildState: nil,
|
||||||
|
runMenu: nil,
|
||||||
|
resolveRunModel: func(ctx context.Context, req launch.RunModelRequest) (string, error) {
|
||||||
|
return "", launch.ErrCancelled
|
||||||
|
},
|
||||||
|
launchIntegration: unexpectedIntegrationLaunch(t),
|
||||||
|
runModel: unexpectedModelLaunch(t),
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("expected nil error on cancellation, got %v", err)
|
||||||
|
}
|
||||||
|
if !continueLoop {
|
||||||
|
t.Fatal("expected cancellation to continue the menu loop")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRunLauncherAction_VSCodeExitsTUILoop(t *testing.T) {
|
||||||
|
setCmdTestHome(t, t.TempDir())
|
||||||
|
|
||||||
|
cmd := &cobra.Command{}
|
||||||
|
cmd.SetContext(context.Background())
|
||||||
|
|
||||||
|
// VS Code should exit the TUI loop (return false) after a successful launch.
|
||||||
|
continueLoop, err := runLauncherAction(cmd, tui.TUIAction{Kind: tui.TUIActionLaunchIntegration, Integration: "vscode"}, launcherDeps{
|
||||||
|
resolveRunModel: unexpectedRunModelResolution(t),
|
||||||
|
launchIntegration: func(ctx context.Context, req launch.IntegrationLaunchRequest) error {
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
runModel: unexpectedModelLaunch(t),
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("expected nil error, got %v", err)
|
||||||
|
}
|
||||||
|
if continueLoop {
|
||||||
|
t.Fatal("expected vscode launch to exit the TUI loop (return false)")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Other integrations should continue the TUI loop (return true).
|
||||||
|
continueLoop, err = runLauncherAction(cmd, tui.TUIAction{Kind: tui.TUIActionLaunchIntegration, Integration: "claude"}, launcherDeps{
|
||||||
|
resolveRunModel: unexpectedRunModelResolution(t),
|
||||||
|
launchIntegration: func(ctx context.Context, req launch.IntegrationLaunchRequest) error {
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
runModel: unexpectedModelLaunch(t),
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("expected nil error, got %v", err)
|
||||||
|
}
|
||||||
|
if !continueLoop {
|
||||||
|
t.Fatal("expected non-vscode integration to continue the TUI loop (return true)")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRunLauncherAction_IntegrationContinuesAfterCancellation(t *testing.T) {
|
||||||
|
setCmdTestHome(t, t.TempDir())
|
||||||
|
|
||||||
|
cmd := &cobra.Command{}
|
||||||
|
cmd.SetContext(context.Background())
|
||||||
|
|
||||||
|
continueLoop, err := runLauncherAction(cmd, tui.TUIAction{Kind: tui.TUIActionLaunchIntegration, Integration: "claude"}, launcherDeps{
|
||||||
|
buildState: nil,
|
||||||
|
runMenu: nil,
|
||||||
|
resolveRunModel: unexpectedRunModelResolution(t),
|
||||||
|
launchIntegration: func(ctx context.Context, req launch.IntegrationLaunchRequest) error {
|
||||||
|
return launch.ErrCancelled
|
||||||
|
},
|
||||||
|
runModel: unexpectedModelLaunch(t),
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("expected nil error on cancellation, got %v", err)
|
||||||
|
}
|
||||||
|
if !continueLoop {
|
||||||
|
t.Fatal("expected cancellation to continue the menu loop")
|
||||||
|
}
|
||||||
|
}
|
||||||
585
cmd/cmd_test.go
@@ -301,7 +301,7 @@ Weigh anchor!
|
|||||||
ParameterSize: "7B",
|
ParameterSize: "7B",
|
||||||
QuantizationLevel: "FP16",
|
QuantizationLevel: "FP16",
|
||||||
},
|
},
|
||||||
Requires: "0.14.0",
|
Requires: "0.19.0",
|
||||||
}, false, &b); err != nil {
|
}, false, &b); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@@ -310,10 +310,17 @@ Weigh anchor!
|
|||||||
architecture test
|
architecture test
|
||||||
parameters 7B
|
parameters 7B
|
||||||
quantization FP16
|
quantization FP16
|
||||||
requires 0.14.0
|
requires 0.19.0
|
||||||
|
|
||||||
`
|
`
|
||||||
if diff := cmp.Diff(expect, b.String()); diff != "" {
|
trimLinePadding := func(s string) string {
|
||||||
|
lines := strings.Split(s, "\n")
|
||||||
|
for i, line := range lines {
|
||||||
|
lines[i] = strings.TrimRight(line, " \t\r")
|
||||||
|
}
|
||||||
|
return strings.Join(lines, "\n")
|
||||||
|
}
|
||||||
|
if diff := cmp.Diff(trimLinePadding(expect), trimLinePadding(b.String())); diff != "" {
|
||||||
t.Errorf("unexpected output (-want +got):\n%s", diff)
|
t.Errorf("unexpected output (-want +got):\n%s", diff)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
@@ -705,6 +712,347 @@ func TestRunEmbeddingModelNoInput(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestRunHandler_CloudAuthErrorOnShow_PrintsSigninMessage(t *testing.T) {
|
||||||
|
var generateCalled bool
|
||||||
|
|
||||||
|
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
switch {
|
||||||
|
case r.URL.Path == "/api/show" && r.Method == http.MethodPost:
|
||||||
|
w.WriteHeader(http.StatusUnauthorized)
|
||||||
|
if err := json.NewEncoder(w).Encode(map[string]string{
|
||||||
|
"error": "unauthorized",
|
||||||
|
"signin_url": "https://ollama.com/signin",
|
||||||
|
}); err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
case r.URL.Path == "/api/generate" && r.Method == http.MethodPost:
|
||||||
|
generateCalled = true
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
if err := json.NewEncoder(w).Encode(api.GenerateResponse{Done: true}); err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
http.NotFound(w, r)
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
|
||||||
|
t.Setenv("OLLAMA_HOST", mockServer.URL)
|
||||||
|
t.Cleanup(mockServer.Close)
|
||||||
|
|
||||||
|
cmd := &cobra.Command{}
|
||||||
|
cmd.SetContext(t.Context())
|
||||||
|
cmd.Flags().String("keepalive", "", "")
|
||||||
|
cmd.Flags().Bool("truncate", false, "")
|
||||||
|
cmd.Flags().Int("dimensions", 0, "")
|
||||||
|
cmd.Flags().Bool("verbose", false, "")
|
||||||
|
cmd.Flags().Bool("insecure", false, "")
|
||||||
|
cmd.Flags().Bool("nowordwrap", false, "")
|
||||||
|
cmd.Flags().String("format", "", "")
|
||||||
|
cmd.Flags().String("think", "", "")
|
||||||
|
cmd.Flags().Bool("hidethinking", false, "")
|
||||||
|
|
||||||
|
oldStdout := os.Stdout
|
||||||
|
readOut, writeOut, _ := os.Pipe()
|
||||||
|
os.Stdout = writeOut
|
||||||
|
t.Cleanup(func() { os.Stdout = oldStdout })
|
||||||
|
|
||||||
|
err := RunHandler(cmd, []string{"gpt-oss:20b:cloud", "hi"})
|
||||||
|
|
||||||
|
_ = writeOut.Close()
|
||||||
|
var out bytes.Buffer
|
||||||
|
_, _ = io.Copy(&out, readOut)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("RunHandler returned error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if generateCalled {
|
||||||
|
t.Fatal("expected run to stop before /api/generate after unauthorized /api/show")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !strings.Contains(out.String(), "You need to be signed in to Ollama to run Cloud models.") {
|
||||||
|
t.Fatalf("expected sign-in guidance message, got %q", out.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
if !strings.Contains(out.String(), "https://ollama.com/signin") {
|
||||||
|
t.Fatalf("expected signin_url in output, got %q", out.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRunHandler_CloudAuthErrorOnGenerate_PrintsSigninMessage(t *testing.T) {
|
||||||
|
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
switch {
|
||||||
|
case r.URL.Path == "/api/show" && r.Method == http.MethodPost:
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
if err := json.NewEncoder(w).Encode(api.ShowResponse{
|
||||||
|
Capabilities: []model.Capability{model.CapabilityCompletion},
|
||||||
|
}); err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
case r.URL.Path == "/api/generate" && r.Method == http.MethodPost:
|
||||||
|
w.WriteHeader(http.StatusUnauthorized)
|
||||||
|
if err := json.NewEncoder(w).Encode(map[string]string{
|
||||||
|
"error": "unauthorized",
|
||||||
|
"signin_url": "https://ollama.com/signin",
|
||||||
|
}); err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
http.NotFound(w, r)
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
|
||||||
|
t.Setenv("OLLAMA_HOST", mockServer.URL)
|
||||||
|
t.Cleanup(mockServer.Close)
|
||||||
|
|
||||||
|
cmd := &cobra.Command{}
|
||||||
|
cmd.SetContext(t.Context())
|
||||||
|
cmd.Flags().String("keepalive", "", "")
|
||||||
|
cmd.Flags().Bool("truncate", false, "")
|
||||||
|
cmd.Flags().Int("dimensions", 0, "")
|
||||||
|
cmd.Flags().Bool("verbose", false, "")
|
||||||
|
cmd.Flags().Bool("insecure", false, "")
|
||||||
|
cmd.Flags().Bool("nowordwrap", false, "")
|
||||||
|
cmd.Flags().String("format", "", "")
|
||||||
|
cmd.Flags().String("think", "", "")
|
||||||
|
cmd.Flags().Bool("hidethinking", false, "")
|
||||||
|
|
||||||
|
oldStdout := os.Stdout
|
||||||
|
readOut, writeOut, _ := os.Pipe()
|
||||||
|
os.Stdout = writeOut
|
||||||
|
t.Cleanup(func() { os.Stdout = oldStdout })
|
||||||
|
|
||||||
|
err := RunHandler(cmd, []string{"gpt-oss:20b:cloud", "hi"})
|
||||||
|
|
||||||
|
_ = writeOut.Close()
|
||||||
|
var out bytes.Buffer
|
||||||
|
_, _ = io.Copy(&out, readOut)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("RunHandler returned error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !strings.Contains(out.String(), "You need to be signed in to Ollama to run Cloud models.") {
|
||||||
|
t.Fatalf("expected sign-in guidance message, got %q", out.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
if !strings.Contains(out.String(), "https://ollama.com/signin") {
|
||||||
|
t.Fatalf("expected signin_url in output, got %q", out.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRunHandler_ExplicitCloudStubMissing_PullsNormalizedNameTEMP(t *testing.T) {
|
||||||
|
var pulledModel string
|
||||||
|
var generateCalled bool
|
||||||
|
|
||||||
|
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
switch {
|
||||||
|
case r.URL.Path == "/api/show" && r.Method == http.MethodPost:
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
if err := json.NewEncoder(w).Encode(api.ShowResponse{
|
||||||
|
Capabilities: []model.Capability{model.CapabilityCompletion},
|
||||||
|
RemoteModel: "gpt-oss:20b",
|
||||||
|
}); err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
case r.URL.Path == "/api/tags" && r.Method == http.MethodGet:
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
if err := json.NewEncoder(w).Encode(api.ListResponse{Models: nil}); err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
case r.URL.Path == "/api/pull" && r.Method == http.MethodPost:
|
||||||
|
var req api.PullRequest
|
||||||
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
pulledModel = req.Model
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
if err := json.NewEncoder(w).Encode(api.ProgressResponse{Status: "success"}); err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
case r.URL.Path == "/api/generate" && r.Method == http.MethodPost:
|
||||||
|
generateCalled = true
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
if err := json.NewEncoder(w).Encode(api.GenerateResponse{Done: true}); err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
http.NotFound(w, r)
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
|
||||||
|
t.Setenv("OLLAMA_HOST", mockServer.URL)
|
||||||
|
t.Cleanup(mockServer.Close)
|
||||||
|
|
||||||
|
cmd := &cobra.Command{}
|
||||||
|
cmd.SetContext(t.Context())
|
||||||
|
cmd.Flags().String("keepalive", "", "")
|
||||||
|
cmd.Flags().Bool("truncate", false, "")
|
||||||
|
cmd.Flags().Int("dimensions", 0, "")
|
||||||
|
cmd.Flags().Bool("verbose", false, "")
|
||||||
|
cmd.Flags().Bool("insecure", false, "")
|
||||||
|
cmd.Flags().Bool("nowordwrap", false, "")
|
||||||
|
cmd.Flags().String("format", "", "")
|
||||||
|
cmd.Flags().String("think", "", "")
|
||||||
|
cmd.Flags().Bool("hidethinking", false, "")
|
||||||
|
|
||||||
|
err := RunHandler(cmd, []string{"gpt-oss:20b:cloud", "hi"})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("RunHandler returned error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if pulledModel != "gpt-oss:20b-cloud" {
|
||||||
|
t.Fatalf("expected normalized pull model %q, got %q", "gpt-oss:20b-cloud", pulledModel)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !generateCalled {
|
||||||
|
t.Fatal("expected /api/generate to be called")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRunHandler_ExplicitCloudStubPresent_SkipsPullTEMP(t *testing.T) {
|
||||||
|
var pullCalled bool
|
||||||
|
var generateCalled bool
|
||||||
|
|
||||||
|
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
switch {
|
||||||
|
case r.URL.Path == "/api/show" && r.Method == http.MethodPost:
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
if err := json.NewEncoder(w).Encode(api.ShowResponse{
|
||||||
|
Capabilities: []model.Capability{model.CapabilityCompletion},
|
||||||
|
RemoteModel: "gpt-oss:20b",
|
||||||
|
}); err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
case r.URL.Path == "/api/tags" && r.Method == http.MethodGet:
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
if err := json.NewEncoder(w).Encode(api.ListResponse{
|
||||||
|
Models: []api.ListModelResponse{{Name: "gpt-oss:20b-cloud"}},
|
||||||
|
}); err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
case r.URL.Path == "/api/pull" && r.Method == http.MethodPost:
|
||||||
|
pullCalled = true
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
if err := json.NewEncoder(w).Encode(api.ProgressResponse{Status: "success"}); err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
case r.URL.Path == "/api/generate" && r.Method == http.MethodPost:
|
||||||
|
generateCalled = true
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
if err := json.NewEncoder(w).Encode(api.GenerateResponse{Done: true}); err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
http.NotFound(w, r)
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
|
||||||
|
t.Setenv("OLLAMA_HOST", mockServer.URL)
|
||||||
|
t.Cleanup(mockServer.Close)
|
||||||
|
|
||||||
|
cmd := &cobra.Command{}
|
||||||
|
cmd.SetContext(t.Context())
|
||||||
|
cmd.Flags().String("keepalive", "", "")
|
||||||
|
cmd.Flags().Bool("truncate", false, "")
|
||||||
|
cmd.Flags().Int("dimensions", 0, "")
|
||||||
|
cmd.Flags().Bool("verbose", false, "")
|
||||||
|
cmd.Flags().Bool("insecure", false, "")
|
||||||
|
cmd.Flags().Bool("nowordwrap", false, "")
|
||||||
|
cmd.Flags().String("format", "", "")
|
||||||
|
cmd.Flags().String("think", "", "")
|
||||||
|
cmd.Flags().Bool("hidethinking", false, "")
|
||||||
|
|
||||||
|
err := RunHandler(cmd, []string{"gpt-oss:20b:cloud", "hi"})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("RunHandler returned error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if pullCalled {
|
||||||
|
t.Fatal("expected /api/pull not to be called when cloud stub already exists")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !generateCalled {
|
||||||
|
t.Fatal("expected /api/generate to be called")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRunHandler_ExplicitCloudStubPullFailure_IsBestEffortTEMP(t *testing.T) {
|
||||||
|
var generateCalled bool
|
||||||
|
|
||||||
|
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
switch {
|
||||||
|
case r.URL.Path == "/api/show" && r.Method == http.MethodPost:
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
if err := json.NewEncoder(w).Encode(api.ShowResponse{
|
||||||
|
Capabilities: []model.Capability{model.CapabilityCompletion},
|
||||||
|
RemoteModel: "gpt-oss:20b",
|
||||||
|
}); err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
case r.URL.Path == "/api/tags" && r.Method == http.MethodGet:
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
if err := json.NewEncoder(w).Encode(api.ListResponse{Models: nil}); err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
case r.URL.Path == "/api/pull" && r.Method == http.MethodPost:
|
||||||
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
|
if err := json.NewEncoder(w).Encode(map[string]string{"error": "pull failed"}); err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
case r.URL.Path == "/api/generate" && r.Method == http.MethodPost:
|
||||||
|
generateCalled = true
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
if err := json.NewEncoder(w).Encode(api.GenerateResponse{Done: true}); err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
http.NotFound(w, r)
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
|
||||||
|
t.Setenv("OLLAMA_HOST", mockServer.URL)
|
||||||
|
t.Cleanup(mockServer.Close)
|
||||||
|
|
||||||
|
cmd := &cobra.Command{}
|
||||||
|
cmd.SetContext(t.Context())
|
||||||
|
cmd.Flags().String("keepalive", "", "")
|
||||||
|
cmd.Flags().Bool("truncate", false, "")
|
||||||
|
cmd.Flags().Int("dimensions", 0, "")
|
||||||
|
cmd.Flags().Bool("verbose", false, "")
|
||||||
|
cmd.Flags().Bool("insecure", false, "")
|
||||||
|
cmd.Flags().Bool("nowordwrap", false, "")
|
||||||
|
cmd.Flags().String("format", "", "")
|
||||||
|
cmd.Flags().String("think", "", "")
|
||||||
|
cmd.Flags().Bool("hidethinking", false, "")
|
||||||
|
|
||||||
|
err := RunHandler(cmd, []string{"gpt-oss:20b:cloud", "hi"})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("RunHandler returned error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !generateCalled {
|
||||||
|
t.Fatal("expected /api/generate to be called despite pull failure")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestGetModelfileName(t *testing.T) {
|
func TestGetModelfileName(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
@@ -1212,6 +1560,20 @@ func TestNewCreateRequest(t *testing.T) {
|
|||||||
Model: "newmodel",
|
Model: "newmodel",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"explicit cloud model preserves source when parent lacks it",
|
||||||
|
"newmodel",
|
||||||
|
runOptions{
|
||||||
|
Model: "qwen3.5:cloud",
|
||||||
|
ParentModel: "qwen3.5",
|
||||||
|
Messages: []api.Message{},
|
||||||
|
WordWrap: true,
|
||||||
|
},
|
||||||
|
&api.CreateRequest{
|
||||||
|
From: "qwen3.5:cloud",
|
||||||
|
Model: "newmodel",
|
||||||
|
},
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"parent model as filepath test",
|
"parent model as filepath test",
|
||||||
"newmodel",
|
"newmodel",
|
||||||
@@ -1293,6 +1655,24 @@ func TestNewCreateRequest(t *testing.T) {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"loaded messages are preserved when saving",
|
||||||
|
"newmodel",
|
||||||
|
runOptions{
|
||||||
|
Model: "mymodel",
|
||||||
|
ParentModel: "parentmodel",
|
||||||
|
LoadedMessages: []api.Message{{Role: "assistant", Content: "loaded"}},
|
||||||
|
Messages: []api.Message{{Role: "user", Content: "new"}},
|
||||||
|
},
|
||||||
|
&api.CreateRequest{
|
||||||
|
From: "parentmodel",
|
||||||
|
Model: "newmodel",
|
||||||
|
Messages: []api.Message{
|
||||||
|
{Role: "assistant", Content: "loaded"},
|
||||||
|
{Role: "user", Content: "new"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
@@ -1305,15 +1685,43 @@ func TestNewCreateRequest(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestApplyShowResponseToRunOptions(t *testing.T) {
|
||||||
|
opts := runOptions{}
|
||||||
|
info := &api.ShowResponse{
|
||||||
|
Details: api.ModelDetails{
|
||||||
|
ParentModel: "parentmodel",
|
||||||
|
},
|
||||||
|
Messages: []api.Message{
|
||||||
|
{Role: "assistant", Content: "loaded"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
applyShowResponseToRunOptions(&opts, info)
|
||||||
|
|
||||||
|
if opts.ParentModel != "parentmodel" {
|
||||||
|
t.Fatalf("ParentModel = %q, want %q", opts.ParentModel, "parentmodel")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !cmp.Equal(opts.LoadedMessages, info.Messages) {
|
||||||
|
t.Fatalf("LoadedMessages = %#v, want %#v", opts.LoadedMessages, info.Messages)
|
||||||
|
}
|
||||||
|
|
||||||
|
info.Messages[0].Content = "modified"
|
||||||
|
if opts.LoadedMessages[0].Content == "modified" {
|
||||||
|
t.Fatal("LoadedMessages should be copied independently from ShowResponse")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestRunOptions_Copy(t *testing.T) {
|
func TestRunOptions_Copy(t *testing.T) {
|
||||||
// Setup test data
|
// Setup test data
|
||||||
originalKeepAlive := &api.Duration{Duration: 5 * time.Minute}
|
originalKeepAlive := &api.Duration{Duration: 5 * time.Minute}
|
||||||
originalThink := &api.ThinkValue{Value: "test reasoning"}
|
originalThink := &api.ThinkValue{Value: "test reasoning"}
|
||||||
|
|
||||||
original := runOptions{
|
original := runOptions{
|
||||||
Model: "test-model",
|
Model: "test-model",
|
||||||
ParentModel: "parent-model",
|
ParentModel: "parent-model",
|
||||||
Prompt: "test prompt",
|
LoadedMessages: []api.Message{{Role: "assistant", Content: "loaded hello"}},
|
||||||
|
Prompt: "test prompt",
|
||||||
Messages: []api.Message{
|
Messages: []api.Message{
|
||||||
{Role: "user", Content: "hello"},
|
{Role: "user", Content: "hello"},
|
||||||
{Role: "assistant", Content: "hi there"},
|
{Role: "assistant", Content: "hi there"},
|
||||||
@@ -1353,6 +1761,7 @@ func TestRunOptions_Copy(t *testing.T) {
|
|||||||
}{
|
}{
|
||||||
{"Model", copied.Model, original.Model},
|
{"Model", copied.Model, original.Model},
|
||||||
{"ParentModel", copied.ParentModel, original.ParentModel},
|
{"ParentModel", copied.ParentModel, original.ParentModel},
|
||||||
|
{"LoadedMessages", copied.LoadedMessages, original.LoadedMessages},
|
||||||
{"Prompt", copied.Prompt, original.Prompt},
|
{"Prompt", copied.Prompt, original.Prompt},
|
||||||
{"WordWrap", copied.WordWrap, original.WordWrap},
|
{"WordWrap", copied.WordWrap, original.WordWrap},
|
||||||
{"Format", copied.Format, original.Format},
|
{"Format", copied.Format, original.Format},
|
||||||
@@ -1457,13 +1866,18 @@ func TestRunOptions_Copy(t *testing.T) {
|
|||||||
func TestRunOptions_Copy_EmptySlicesAndMaps(t *testing.T) {
|
func TestRunOptions_Copy_EmptySlicesAndMaps(t *testing.T) {
|
||||||
// Test with empty slices and maps
|
// Test with empty slices and maps
|
||||||
original := runOptions{
|
original := runOptions{
|
||||||
Messages: []api.Message{},
|
LoadedMessages: []api.Message{},
|
||||||
Images: []api.ImageData{},
|
Messages: []api.Message{},
|
||||||
Options: map[string]any{},
|
Images: []api.ImageData{},
|
||||||
|
Options: map[string]any{},
|
||||||
}
|
}
|
||||||
|
|
||||||
copied := original.Copy()
|
copied := original.Copy()
|
||||||
|
|
||||||
|
if copied.LoadedMessages == nil {
|
||||||
|
t.Error("Empty LoadedMessages slice should remain empty, not nil")
|
||||||
|
}
|
||||||
|
|
||||||
if copied.Messages == nil {
|
if copied.Messages == nil {
|
||||||
t.Error("Empty Messages slice should remain empty, not nil")
|
t.Error("Empty Messages slice should remain empty, not nil")
|
||||||
}
|
}
|
||||||
@@ -1480,6 +1894,10 @@ func TestRunOptions_Copy_EmptySlicesAndMaps(t *testing.T) {
|
|||||||
t.Error("Empty Messages slice should remain empty")
|
t.Error("Empty Messages slice should remain empty")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if len(copied.LoadedMessages) != 0 {
|
||||||
|
t.Error("Empty LoadedMessages slice should remain empty")
|
||||||
|
}
|
||||||
|
|
||||||
if len(copied.Images) != 0 {
|
if len(copied.Images) != 0 {
|
||||||
t.Error("Empty Images slice should remain empty")
|
t.Error("Empty Images slice should remain empty")
|
||||||
}
|
}
|
||||||
@@ -1557,7 +1975,7 @@ func TestShowInfoImageGen(t *testing.T) {
|
|||||||
QuantizationLevel: "Q8",
|
QuantizationLevel: "Q8",
|
||||||
},
|
},
|
||||||
Capabilities: []model.Capability{model.CapabilityImage},
|
Capabilities: []model.Capability{model.CapabilityImage},
|
||||||
Requires: "0.14.0",
|
Requires: "0.19.0",
|
||||||
}, false, &b)
|
}, false, &b)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
@@ -1567,7 +1985,7 @@ func TestShowInfoImageGen(t *testing.T) {
|
|||||||
" architecture ZImagePipeline \n" +
|
" architecture ZImagePipeline \n" +
|
||||||
" parameters 10.3B \n" +
|
" parameters 10.3B \n" +
|
||||||
" quantization Q8 \n" +
|
" quantization Q8 \n" +
|
||||||
" requires 0.14.0 \n" +
|
" requires 0.19.0 \n" +
|
||||||
"\n" +
|
"\n" +
|
||||||
" Capabilities\n" +
|
" Capabilities\n" +
|
||||||
" image \n" +
|
" image \n" +
|
||||||
@@ -1625,16 +2043,20 @@ func TestRunOptions_Copy_Independence(t *testing.T) {
|
|||||||
// Test that modifications to original don't affect copy
|
// Test that modifications to original don't affect copy
|
||||||
originalThink := &api.ThinkValue{Value: "original"}
|
originalThink := &api.ThinkValue{Value: "original"}
|
||||||
original := runOptions{
|
original := runOptions{
|
||||||
Model: "original-model",
|
Model: "original-model",
|
||||||
Messages: []api.Message{{Role: "user", Content: "original"}},
|
LoadedMessages: []api.Message{{Role: "assistant", Content: "loaded"}},
|
||||||
Options: map[string]any{"key": "value"},
|
Messages: []api.Message{{Role: "user", Content: "original"}},
|
||||||
Think: originalThink,
|
Options: map[string]any{"key": "value"},
|
||||||
|
Think: originalThink,
|
||||||
}
|
}
|
||||||
|
|
||||||
copied := original.Copy()
|
copied := original.Copy()
|
||||||
|
|
||||||
// Modify original
|
// Modify original
|
||||||
original.Model = "modified-model"
|
original.Model = "modified-model"
|
||||||
|
if len(original.LoadedMessages) > 0 {
|
||||||
|
original.LoadedMessages[0].Content = "modified loaded"
|
||||||
|
}
|
||||||
if len(original.Messages) > 0 {
|
if len(original.Messages) > 0 {
|
||||||
original.Messages[0].Content = "modified"
|
original.Messages[0].Content = "modified"
|
||||||
}
|
}
|
||||||
@@ -1648,6 +2070,10 @@ func TestRunOptions_Copy_Independence(t *testing.T) {
|
|||||||
t.Error("Copy Model should not be affected by original modification")
|
t.Error("Copy Model should not be affected by original modification")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if len(copied.LoadedMessages) > 0 && copied.LoadedMessages[0].Content == "modified loaded" {
|
||||||
|
t.Error("Copy LoadedMessages should not be affected by original modification")
|
||||||
|
}
|
||||||
|
|
||||||
if len(copied.Messages) > 0 && copied.Messages[0].Content == "modified" {
|
if len(copied.Messages) > 0 && copied.Messages[0].Content == "modified" {
|
||||||
t.Error("Copy Messages should not be affected by original modification")
|
t.Error("Copy Messages should not be affected by original modification")
|
||||||
}
|
}
|
||||||
@@ -1663,31 +2089,81 @@ func TestRunOptions_Copy_Independence(t *testing.T) {
|
|||||||
|
|
||||||
func TestLoadOrUnloadModel_CloudModelAuth(t *testing.T) {
|
func TestLoadOrUnloadModel_CloudModelAuth(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
remoteHost string
|
model string
|
||||||
whoamiStatus int
|
showStatus int
|
||||||
whoamiResp any
|
remoteHost string
|
||||||
expectedError string
|
remoteModel string
|
||||||
|
whoamiStatus int
|
||||||
|
whoamiResp any
|
||||||
|
expectWhoami bool
|
||||||
|
expectedError string
|
||||||
|
expectAuthError bool
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "ollama.com cloud model - user signed in",
|
name: "ollama.com cloud model - user signed in",
|
||||||
|
model: "test-cloud-model",
|
||||||
remoteHost: "https://ollama.com",
|
remoteHost: "https://ollama.com",
|
||||||
|
remoteModel: "test-model",
|
||||||
whoamiStatus: http.StatusOK,
|
whoamiStatus: http.StatusOK,
|
||||||
whoamiResp: api.UserResponse{Name: "testuser"},
|
whoamiResp: api.UserResponse{Name: "testuser"},
|
||||||
|
expectWhoami: true,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "ollama.com cloud model - user not signed in",
|
name: "ollama.com cloud model - user not signed in",
|
||||||
|
model: "test-cloud-model",
|
||||||
remoteHost: "https://ollama.com",
|
remoteHost: "https://ollama.com",
|
||||||
|
remoteModel: "test-model",
|
||||||
whoamiStatus: http.StatusUnauthorized,
|
whoamiStatus: http.StatusUnauthorized,
|
||||||
whoamiResp: map[string]string{
|
whoamiResp: map[string]string{
|
||||||
"error": "unauthorized",
|
"error": "unauthorized",
|
||||||
"signin_url": "https://ollama.com/signin",
|
"signin_url": "https://ollama.com/signin",
|
||||||
},
|
},
|
||||||
expectedError: "unauthorized",
|
expectWhoami: true,
|
||||||
|
expectedError: "unauthorized",
|
||||||
|
expectAuthError: true,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "non-ollama.com remote - no auth check",
|
name: "non-ollama.com remote - no auth check",
|
||||||
|
model: "test-cloud-model",
|
||||||
remoteHost: "https://other-remote.com",
|
remoteHost: "https://other-remote.com",
|
||||||
|
remoteModel: "test-model",
|
||||||
|
whoamiStatus: http.StatusUnauthorized, // should not be called
|
||||||
|
whoamiResp: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "explicit :cloud model - auth check without remote metadata",
|
||||||
|
model: "kimi-k2.5:cloud",
|
||||||
|
remoteHost: "",
|
||||||
|
remoteModel: "",
|
||||||
|
whoamiStatus: http.StatusOK,
|
||||||
|
whoamiResp: api.UserResponse{Name: "testuser"},
|
||||||
|
expectWhoami: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "explicit :cloud model without local stub returns not found by default",
|
||||||
|
model: "minimax-m2.7:cloud",
|
||||||
|
showStatus: http.StatusNotFound,
|
||||||
|
whoamiStatus: http.StatusOK,
|
||||||
|
whoamiResp: api.UserResponse{Name: "testuser"},
|
||||||
|
expectedError: "not found",
|
||||||
|
expectWhoami: false,
|
||||||
|
expectAuthError: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "explicit -cloud model - auth check without remote metadata",
|
||||||
|
model: "kimi-k2.5:latest-cloud",
|
||||||
|
remoteHost: "",
|
||||||
|
remoteModel: "",
|
||||||
|
whoamiStatus: http.StatusOK,
|
||||||
|
whoamiResp: api.UserResponse{Name: "testuser"},
|
||||||
|
expectWhoami: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "dash cloud-like name without explicit source does not require auth",
|
||||||
|
model: "test-cloud-model",
|
||||||
|
remoteHost: "",
|
||||||
|
remoteModel: "",
|
||||||
whoamiStatus: http.StatusUnauthorized, // should not be called
|
whoamiStatus: http.StatusUnauthorized, // should not be called
|
||||||
whoamiResp: nil,
|
whoamiResp: nil,
|
||||||
},
|
},
|
||||||
@@ -1699,10 +2175,15 @@ func TestLoadOrUnloadModel_CloudModelAuth(t *testing.T) {
|
|||||||
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
switch r.URL.Path {
|
switch r.URL.Path {
|
||||||
case "/api/show":
|
case "/api/show":
|
||||||
|
if tt.showStatus != 0 && tt.showStatus != http.StatusOK {
|
||||||
|
w.WriteHeader(tt.showStatus)
|
||||||
|
_ = json.NewEncoder(w).Encode(map[string]string{"error": "not found"})
|
||||||
|
return
|
||||||
|
}
|
||||||
w.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
if err := json.NewEncoder(w).Encode(api.ShowResponse{
|
if err := json.NewEncoder(w).Encode(api.ShowResponse{
|
||||||
RemoteHost: tt.remoteHost,
|
RemoteHost: tt.remoteHost,
|
||||||
RemoteModel: "test-model",
|
RemoteModel: tt.remoteModel,
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
@@ -1715,6 +2196,8 @@ func TestLoadOrUnloadModel_CloudModelAuth(t *testing.T) {
|
|||||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
case "/api/generate":
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
default:
|
default:
|
||||||
http.NotFound(w, r)
|
http.NotFound(w, r)
|
||||||
}
|
}
|
||||||
@@ -1727,29 +2210,28 @@ func TestLoadOrUnloadModel_CloudModelAuth(t *testing.T) {
|
|||||||
cmd.SetContext(t.Context())
|
cmd.SetContext(t.Context())
|
||||||
|
|
||||||
opts := &runOptions{
|
opts := &runOptions{
|
||||||
Model: "test-cloud-model",
|
Model: tt.model,
|
||||||
ShowConnect: false,
|
ShowConnect: false,
|
||||||
}
|
}
|
||||||
|
|
||||||
err := loadOrUnloadModel(cmd, opts)
|
err := loadOrUnloadModel(cmd, opts)
|
||||||
|
|
||||||
if strings.HasPrefix(tt.remoteHost, "https://ollama.com") {
|
if whoamiCalled != tt.expectWhoami {
|
||||||
if !whoamiCalled {
|
t.Errorf("whoami called = %v, want %v", whoamiCalled, tt.expectWhoami)
|
||||||
t.Error("expected whoami to be called for ollama.com cloud model")
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
if whoamiCalled {
|
|
||||||
t.Error("whoami should not be called for non-ollama.com remote")
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if tt.expectedError != "" {
|
if tt.expectedError != "" {
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Errorf("expected error containing %q, got nil", tt.expectedError)
|
t.Errorf("expected error containing %q, got nil", tt.expectedError)
|
||||||
} else {
|
} else {
|
||||||
var authErr api.AuthorizationError
|
if !tt.expectAuthError && !strings.Contains(strings.ToLower(err.Error()), strings.ToLower(tt.expectedError)) {
|
||||||
if !errors.As(err, &authErr) {
|
t.Errorf("expected error containing %q, got %v", tt.expectedError, err)
|
||||||
t.Errorf("expected AuthorizationError, got %T: %v", err, err)
|
}
|
||||||
|
if tt.expectAuthError {
|
||||||
|
var authErr api.AuthorizationError
|
||||||
|
if !errors.As(err, &authErr) {
|
||||||
|
t.Errorf("expected AuthorizationError, got %T: %v", err, err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
@@ -1760,3 +2242,38 @@ func TestLoadOrUnloadModel_CloudModelAuth(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestIsLocalhost(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
host string
|
||||||
|
expected bool
|
||||||
|
}{
|
||||||
|
{"default empty", "", true},
|
||||||
|
{"localhost no port", "localhost", true},
|
||||||
|
{"localhost with port", "localhost:11435", true},
|
||||||
|
{"127.0.0.1 no port", "127.0.0.1", true},
|
||||||
|
{"127.0.0.1 with port", "127.0.0.1:11434", true},
|
||||||
|
{"0.0.0.0 no port", "0.0.0.0", true},
|
||||||
|
{"0.0.0.0 with port", "0.0.0.0:11434", true},
|
||||||
|
{"::1 no port", "::1", true},
|
||||||
|
{"[::1] with port", "[::1]:11434", true},
|
||||||
|
{"loopback with scheme", "http://localhost:11434", true},
|
||||||
|
{"remote hostname", "example.com", false},
|
||||||
|
{"remote hostname with port", "example.com:11434", false},
|
||||||
|
{"remote IP", "192.168.1.1", false},
|
||||||
|
{"remote IP with port", "192.168.1.1:11434", false},
|
||||||
|
{"remote with scheme", "http://example.com:11434", false},
|
||||||
|
{"https remote", "https://example.com:443", false},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
t.Setenv("OLLAMA_HOST", tt.host)
|
||||||
|
got := isLocalhost()
|
||||||
|
if got != tt.expected {
|
||||||
|
t.Errorf("isLocalhost() with OLLAMA_HOST=%q = %v, want %v", tt.host, got, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,192 +0,0 @@
|
|||||||
package config
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
"os"
|
|
||||||
"os/exec"
|
|
||||||
"path/filepath"
|
|
||||||
"runtime"
|
|
||||||
|
|
||||||
"github.com/ollama/ollama/api"
|
|
||||||
"github.com/ollama/ollama/envconfig"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Claude implements Runner and AliasConfigurer for Claude Code integration
|
|
||||||
type Claude struct{}
|
|
||||||
|
|
||||||
// Compile-time check that Claude implements AliasConfigurer
|
|
||||||
var _ AliasConfigurer = (*Claude)(nil)
|
|
||||||
|
|
||||||
func (c *Claude) String() string { return "Claude Code" }
|
|
||||||
|
|
||||||
func (c *Claude) args(model string, extra []string) []string {
|
|
||||||
var args []string
|
|
||||||
if model != "" {
|
|
||||||
args = append(args, "--model", model)
|
|
||||||
}
|
|
||||||
args = append(args, extra...)
|
|
||||||
return args
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Claude) findPath() (string, error) {
|
|
||||||
if p, err := exec.LookPath("claude"); err == nil {
|
|
||||||
return p, nil
|
|
||||||
}
|
|
||||||
home, err := os.UserHomeDir()
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
name := "claude"
|
|
||||||
if runtime.GOOS == "windows" {
|
|
||||||
name = "claude.exe"
|
|
||||||
}
|
|
||||||
fallback := filepath.Join(home, ".claude", "local", name)
|
|
||||||
if _, err := os.Stat(fallback); err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
return fallback, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Claude) Run(model string, args []string) error {
|
|
||||||
claudePath, err := c.findPath()
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("claude is not installed, install from https://code.claude.com/docs/en/quickstart")
|
|
||||||
}
|
|
||||||
|
|
||||||
cmd := exec.Command(claudePath, c.args(model, args)...)
|
|
||||||
cmd.Stdin = os.Stdin
|
|
||||||
cmd.Stdout = os.Stdout
|
|
||||||
cmd.Stderr = os.Stderr
|
|
||||||
|
|
||||||
env := append(os.Environ(),
|
|
||||||
"ANTHROPIC_BASE_URL="+envconfig.Host().String(),
|
|
||||||
"ANTHROPIC_API_KEY=",
|
|
||||||
"ANTHROPIC_AUTH_TOKEN=ollama",
|
|
||||||
)
|
|
||||||
|
|
||||||
env = append(env, c.modelEnvVars(model)...)
|
|
||||||
|
|
||||||
cmd.Env = env
|
|
||||||
return cmd.Run()
|
|
||||||
}
|
|
||||||
|
|
||||||
// modelEnvVars returns Claude Code env vars that route all model tiers through Ollama.
|
|
||||||
func (c *Claude) modelEnvVars(model string) []string {
|
|
||||||
primary := model
|
|
||||||
fast := model
|
|
||||||
if cfg, err := loadIntegration("claude"); err == nil && cfg.Aliases != nil {
|
|
||||||
if p := cfg.Aliases["primary"]; p != "" {
|
|
||||||
primary = p
|
|
||||||
}
|
|
||||||
if f := cfg.Aliases["fast"]; f != "" {
|
|
||||||
fast = f
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return []string{
|
|
||||||
"ANTHROPIC_DEFAULT_OPUS_MODEL=" + primary,
|
|
||||||
"ANTHROPIC_DEFAULT_SONNET_MODEL=" + primary,
|
|
||||||
"ANTHROPIC_DEFAULT_HAIKU_MODEL=" + fast,
|
|
||||||
"CLAUDE_CODE_SUBAGENT_MODEL=" + primary,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ConfigureAliases sets up model aliases for Claude Code.
|
|
||||||
// model: the model to use (if empty, user will be prompted to select)
|
|
||||||
// aliases: existing alias configuration to preserve/update
|
|
||||||
// Cloud-only: subagent routing (fast model) is gated to cloud models only until
|
|
||||||
// there is a better strategy for prompt caching on local models.
|
|
||||||
func (c *Claude) ConfigureAliases(ctx context.Context, model string, existingAliases map[string]string, force bool) (map[string]string, bool, error) {
|
|
||||||
aliases := make(map[string]string)
|
|
||||||
for k, v := range existingAliases {
|
|
||||||
aliases[k] = v
|
|
||||||
}
|
|
||||||
|
|
||||||
if model != "" {
|
|
||||||
aliases["primary"] = model
|
|
||||||
}
|
|
||||||
|
|
||||||
if !force && aliases["primary"] != "" {
|
|
||||||
client, _ := api.ClientFromEnvironment()
|
|
||||||
if isCloudModel(ctx, client, aliases["primary"]) {
|
|
||||||
if isCloudModel(ctx, client, aliases["fast"]) {
|
|
||||||
return aliases, false, nil
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
delete(aliases, "fast")
|
|
||||||
return aliases, false, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
items, existingModels, cloudModels, client, err := listModels(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return nil, false, err
|
|
||||||
}
|
|
||||||
|
|
||||||
fmt.Fprintf(os.Stderr, "\n%sModel Configuration%s\n\n", ansiBold, ansiReset)
|
|
||||||
|
|
||||||
if aliases["primary"] == "" || force {
|
|
||||||
primary, err := DefaultSingleSelector("Select model:", items, aliases["primary"])
|
|
||||||
if err != nil {
|
|
||||||
return nil, false, err
|
|
||||||
}
|
|
||||||
if err := pullIfNeeded(ctx, client, existingModels, primary); err != nil {
|
|
||||||
return nil, false, err
|
|
||||||
}
|
|
||||||
if err := ensureAuth(ctx, client, cloudModels, []string{primary}); err != nil {
|
|
||||||
return nil, false, err
|
|
||||||
}
|
|
||||||
aliases["primary"] = primary
|
|
||||||
}
|
|
||||||
|
|
||||||
if isCloudModel(ctx, client, aliases["primary"]) {
|
|
||||||
if aliases["fast"] == "" || !isCloudModel(ctx, client, aliases["fast"]) {
|
|
||||||
aliases["fast"] = aliases["primary"]
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
delete(aliases, "fast")
|
|
||||||
}
|
|
||||||
|
|
||||||
return aliases, true, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetAliases syncs the configured aliases to the Ollama server using prefix matching.
|
|
||||||
// Cloud-only: for local models (fast is empty), we delete any existing aliases to
|
|
||||||
// prevent stale routing to a previous cloud model.
|
|
||||||
func (c *Claude) SetAliases(ctx context.Context, aliases map[string]string) error {
|
|
||||||
client, err := api.ClientFromEnvironment()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
prefixes := []string{"claude-sonnet-", "claude-haiku-"}
|
|
||||||
|
|
||||||
if aliases["fast"] == "" {
|
|
||||||
for _, prefix := range prefixes {
|
|
||||||
_ = client.DeleteAliasExperimental(ctx, &api.AliasDeleteRequest{Alias: prefix})
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
prefixAliases := map[string]string{
|
|
||||||
"claude-sonnet-": aliases["primary"],
|
|
||||||
"claude-haiku-": aliases["fast"],
|
|
||||||
}
|
|
||||||
|
|
||||||
var errs []string
|
|
||||||
for prefix, target := range prefixAliases {
|
|
||||||
req := &api.AliasRequest{
|
|
||||||
Alias: prefix,
|
|
||||||
Target: target,
|
|
||||||
PrefixMatching: true,
|
|
||||||
}
|
|
||||||
if err := client.SetAliasExperimental(ctx, req); err != nil {
|
|
||||||
errs = append(errs, prefix)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(errs) > 0 {
|
|
||||||
return fmt.Errorf("failed to set aliases: %v", errs)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
@@ -1,67 +0,0 @@
|
|||||||
package config
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"os"
|
|
||||||
"os/exec"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
"github.com/ollama/ollama/envconfig"
|
|
||||||
"golang.org/x/mod/semver"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Codex implements Runner for Codex integration
|
|
||||||
type Codex struct{}
|
|
||||||
|
|
||||||
func (c *Codex) String() string { return "Codex" }
|
|
||||||
|
|
||||||
func (c *Codex) args(model string, extra []string) []string {
|
|
||||||
args := []string{"--oss"}
|
|
||||||
if model != "" {
|
|
||||||
args = append(args, "-m", model)
|
|
||||||
}
|
|
||||||
args = append(args, extra...)
|
|
||||||
return args
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Codex) Run(model string, args []string) error {
|
|
||||||
if err := checkCodexVersion(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
cmd := exec.Command("codex", c.args(model, args)...)
|
|
||||||
cmd.Stdin = os.Stdin
|
|
||||||
cmd.Stdout = os.Stdout
|
|
||||||
cmd.Stderr = os.Stderr
|
|
||||||
cmd.Env = append(os.Environ(),
|
|
||||||
"OPENAI_BASE_URL="+envconfig.Host().String()+"/v1/",
|
|
||||||
"OPENAI_API_KEY=ollama",
|
|
||||||
)
|
|
||||||
return cmd.Run()
|
|
||||||
}
|
|
||||||
|
|
||||||
func checkCodexVersion() error {
|
|
||||||
if _, err := exec.LookPath("codex"); err != nil {
|
|
||||||
return fmt.Errorf("codex is not installed, install with: npm install -g @openai/codex")
|
|
||||||
}
|
|
||||||
|
|
||||||
out, err := exec.Command("codex", "--version").Output()
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to get codex version: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Parse output like "codex-cli 0.87.0"
|
|
||||||
fields := strings.Fields(strings.TrimSpace(string(out)))
|
|
||||||
if len(fields) < 2 {
|
|
||||||
return fmt.Errorf("unexpected codex version output: %s", string(out))
|
|
||||||
}
|
|
||||||
|
|
||||||
version := "v" + fields[len(fields)-1]
|
|
||||||
minVersion := "v0.81.0"
|
|
||||||
|
|
||||||
if semver.Compare(version, minVersion) < 0 {
|
|
||||||
return fmt.Errorf("codex version %s is too old, minimum required is %s, update with: npm update -g @openai/codex", fields[len(fields)-1], "0.81.0")
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
@@ -1,31 +0,0 @@
|
|||||||
package config
|
|
||||||
|
|
||||||
import (
|
|
||||||
"slices"
|
|
||||||
"testing"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestCodexArgs(t *testing.T) {
|
|
||||||
c := &Codex{}
|
|
||||||
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
model string
|
|
||||||
args []string
|
|
||||||
want []string
|
|
||||||
}{
|
|
||||||
{"with model", "llama3.2", nil, []string{"--oss", "-m", "llama3.2"}},
|
|
||||||
{"empty model", "", nil, []string{"--oss"}},
|
|
||||||
{"with model and profile", "qwen3-coder", []string{"-p", "myprofile"}, []string{"--oss", "-m", "qwen3-coder", "-p", "myprofile"}},
|
|
||||||
{"with sandbox flag", "llama3.2", []string{"--sandbox", "workspace-write"}, []string{"--oss", "-m", "llama3.2", "--sandbox", "workspace-write"}},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
got := c.args(tt.model, tt.args)
|
|
||||||
if !slices.Equal(got, tt.want) {
|
|
||||||
t.Errorf("args(%q, %v) = %v, want %v", tt.model, tt.args, got, tt.want)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -3,7 +3,6 @@
|
|||||||
package config
|
package config
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
@@ -11,7 +10,7 @@ import (
|
|||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/cmd/internal/fileutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
type integration struct {
|
type integration struct {
|
||||||
@@ -20,6 +19,9 @@ type integration struct {
|
|||||||
Onboarded bool `json:"onboarded,omitempty"`
|
Onboarded bool `json:"onboarded,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// IntegrationConfig is the persisted config for one integration.
|
||||||
|
type IntegrationConfig = integration
|
||||||
|
|
||||||
type config struct {
|
type config struct {
|
||||||
Integrations map[string]*integration `json:"integrations"`
|
Integrations map[string]*integration `json:"integrations"`
|
||||||
LastModel string `json:"last_model,omitempty"`
|
LastModel string `json:"last_model,omitempty"`
|
||||||
@@ -124,7 +126,7 @@ func save(cfg *config) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return writeWithBackup(path, data)
|
return fileutil.WriteWithBackup(path, data)
|
||||||
}
|
}
|
||||||
|
|
||||||
func SaveIntegration(appName string, models []string) error {
|
func SaveIntegration(appName string, models []string) error {
|
||||||
@@ -155,8 +157,8 @@ func SaveIntegration(appName string, models []string) error {
|
|||||||
return save(cfg)
|
return save(cfg)
|
||||||
}
|
}
|
||||||
|
|
||||||
// integrationOnboarded marks an integration as onboarded in ollama's config.
|
// MarkIntegrationOnboarded marks an integration as onboarded in Ollama's config.
|
||||||
func integrationOnboarded(appName string) error {
|
func MarkIntegrationOnboarded(appName string) error {
|
||||||
cfg, err := load()
|
cfg, err := load()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -174,7 +176,7 @@ func integrationOnboarded(appName string) error {
|
|||||||
|
|
||||||
// IntegrationModel returns the first configured model for an integration, or empty string if not configured.
|
// IntegrationModel returns the first configured model for an integration, or empty string if not configured.
|
||||||
func IntegrationModel(appName string) string {
|
func IntegrationModel(appName string) string {
|
||||||
integrationConfig, err := loadIntegration(appName)
|
integrationConfig, err := LoadIntegration(appName)
|
||||||
if err != nil || len(integrationConfig.Models) == 0 {
|
if err != nil || len(integrationConfig.Models) == 0 {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
@@ -183,7 +185,7 @@ func IntegrationModel(appName string) string {
|
|||||||
|
|
||||||
// IntegrationModels returns all configured models for an integration, or nil.
|
// IntegrationModels returns all configured models for an integration, or nil.
|
||||||
func IntegrationModels(appName string) []string {
|
func IntegrationModels(appName string) []string {
|
||||||
integrationConfig, err := loadIntegration(appName)
|
integrationConfig, err := LoadIntegration(appName)
|
||||||
if err != nil || len(integrationConfig.Models) == 0 {
|
if err != nil || len(integrationConfig.Models) == 0 {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -228,28 +230,8 @@ func SetLastSelection(selection string) error {
|
|||||||
return save(cfg)
|
return save(cfg)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ModelExists checks if a model exists on the Ollama server.
|
// LoadIntegration returns the saved config for one integration.
|
||||||
func ModelExists(ctx context.Context, name string) bool {
|
func LoadIntegration(appName string) (*integration, error) {
|
||||||
if name == "" {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
client, err := api.ClientFromEnvironment()
|
|
||||||
if err != nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
models, err := client.List(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
for _, m := range models.Models {
|
|
||||||
if m.Name == name || strings.HasPrefix(m.Name, name+":") {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func loadIntegration(appName string) (*integration, error) {
|
|
||||||
cfg, err := load()
|
cfg, err := load()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -263,7 +245,8 @@ func loadIntegration(appName string) (*integration, error) {
|
|||||||
return integrationConfig, nil
|
return integrationConfig, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func saveAliases(appName string, aliases map[string]string) error {
|
// SaveAliases replaces the saved aliases for one integration.
|
||||||
|
func SaveAliases(appName string, aliases map[string]string) error {
|
||||||
if appName == "" {
|
if appName == "" {
|
||||||
return errors.New("app name cannot be empty")
|
return errors.New("app name cannot be empty")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
package config
|
package config
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"errors"
|
"errors"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
@@ -45,12 +44,12 @@ func TestSaveAliases_ReplacesNotMerges(t *testing.T) {
|
|||||||
"primary": "cloud-model",
|
"primary": "cloud-model",
|
||||||
"fast": "cloud-model",
|
"fast": "cloud-model",
|
||||||
}
|
}
|
||||||
if err := saveAliases("claude", initial); err != nil {
|
if err := SaveAliases("claude", initial); err != nil {
|
||||||
t.Fatalf("failed to save initial aliases: %v", err)
|
t.Fatalf("failed to save initial aliases: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Verify both are saved
|
// Verify both are saved
|
||||||
loaded, err := loadIntegration("claude")
|
loaded, err := LoadIntegration("claude")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to load: %v", err)
|
t.Fatalf("failed to load: %v", err)
|
||||||
}
|
}
|
||||||
@@ -63,12 +62,12 @@ func TestSaveAliases_ReplacesNotMerges(t *testing.T) {
|
|||||||
"primary": "local-model",
|
"primary": "local-model",
|
||||||
// fast intentionally missing
|
// fast intentionally missing
|
||||||
}
|
}
|
||||||
if err := saveAliases("claude", updated); err != nil {
|
if err := SaveAliases("claude", updated); err != nil {
|
||||||
t.Fatalf("failed to save updated aliases: %v", err)
|
t.Fatalf("failed to save updated aliases: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Verify fast is GONE (not merged/preserved)
|
// Verify fast is GONE (not merged/preserved)
|
||||||
loaded, err = loadIntegration("claude")
|
loaded, err = LoadIntegration("claude")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to load after update: %v", err)
|
t.Fatalf("failed to load after update: %v", err)
|
||||||
}
|
}
|
||||||
@@ -91,12 +90,12 @@ func TestSaveAliases_PreservesModels(t *testing.T) {
|
|||||||
|
|
||||||
// Then update aliases
|
// Then update aliases
|
||||||
aliases := map[string]string{"primary": "new-model"}
|
aliases := map[string]string{"primary": "new-model"}
|
||||||
if err := saveAliases("claude", aliases); err != nil {
|
if err := SaveAliases("claude", aliases); err != nil {
|
||||||
t.Fatalf("failed to save aliases: %v", err)
|
t.Fatalf("failed to save aliases: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Verify models are preserved
|
// Verify models are preserved
|
||||||
loaded, err := loadIntegration("claude")
|
loaded, err := LoadIntegration("claude")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to load: %v", err)
|
t.Fatalf("failed to load: %v", err)
|
||||||
}
|
}
|
||||||
@@ -111,16 +110,16 @@ func TestSaveAliases_EmptyMap(t *testing.T) {
|
|||||||
setTestHome(t, tmpDir)
|
setTestHome(t, tmpDir)
|
||||||
|
|
||||||
// Save with aliases
|
// Save with aliases
|
||||||
if err := saveAliases("claude", map[string]string{"primary": "model", "fast": "model"}); err != nil {
|
if err := SaveAliases("claude", map[string]string{"primary": "model", "fast": "model"}); err != nil {
|
||||||
t.Fatalf("failed to save: %v", err)
|
t.Fatalf("failed to save: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Save empty map
|
// Save empty map
|
||||||
if err := saveAliases("claude", map[string]string{}); err != nil {
|
if err := SaveAliases("claude", map[string]string{}); err != nil {
|
||||||
t.Fatalf("failed to save empty: %v", err)
|
t.Fatalf("failed to save empty: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
loaded, err := loadIntegration("claude")
|
loaded, err := LoadIntegration("claude")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to load: %v", err)
|
t.Fatalf("failed to load: %v", err)
|
||||||
}
|
}
|
||||||
@@ -135,16 +134,16 @@ func TestSaveAliases_NilMap(t *testing.T) {
|
|||||||
setTestHome(t, tmpDir)
|
setTestHome(t, tmpDir)
|
||||||
|
|
||||||
// Save with aliases first
|
// Save with aliases first
|
||||||
if err := saveAliases("claude", map[string]string{"primary": "model"}); err != nil {
|
if err := SaveAliases("claude", map[string]string{"primary": "model"}); err != nil {
|
||||||
t.Fatalf("failed to save: %v", err)
|
t.Fatalf("failed to save: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Save nil map - should clear aliases
|
// Save nil map - should clear aliases
|
||||||
if err := saveAliases("claude", nil); err != nil {
|
if err := SaveAliases("claude", nil); err != nil {
|
||||||
t.Fatalf("failed to save nil: %v", err)
|
t.Fatalf("failed to save nil: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
loaded, err := loadIntegration("claude")
|
loaded, err := LoadIntegration("claude")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to load: %v", err)
|
t.Fatalf("failed to load: %v", err)
|
||||||
}
|
}
|
||||||
@@ -155,7 +154,7 @@ func TestSaveAliases_NilMap(t *testing.T) {
|
|||||||
|
|
||||||
// TestSaveAliases_EmptyAppName returns error
|
// TestSaveAliases_EmptyAppName returns error
|
||||||
func TestSaveAliases_EmptyAppName(t *testing.T) {
|
func TestSaveAliases_EmptyAppName(t *testing.T) {
|
||||||
err := saveAliases("", map[string]string{"primary": "model"})
|
err := SaveAliases("", map[string]string{"primary": "model"})
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Error("expected error for empty app name")
|
t.Error("expected error for empty app name")
|
||||||
}
|
}
|
||||||
@@ -165,12 +164,12 @@ func TestSaveAliases_CaseInsensitive(t *testing.T) {
|
|||||||
tmpDir := t.TempDir()
|
tmpDir := t.TempDir()
|
||||||
setTestHome(t, tmpDir)
|
setTestHome(t, tmpDir)
|
||||||
|
|
||||||
if err := saveAliases("Claude", map[string]string{"primary": "model1"}); err != nil {
|
if err := SaveAliases("Claude", map[string]string{"primary": "model1"}); err != nil {
|
||||||
t.Fatalf("failed to save: %v", err)
|
t.Fatalf("failed to save: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Load with different case
|
// Load with different case
|
||||||
loaded, err := loadIntegration("claude")
|
loaded, err := LoadIntegration("claude")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to load: %v", err)
|
t.Fatalf("failed to load: %v", err)
|
||||||
}
|
}
|
||||||
@@ -179,11 +178,11 @@ func TestSaveAliases_CaseInsensitive(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Update with different case
|
// Update with different case
|
||||||
if err := saveAliases("CLAUDE", map[string]string{"primary": "model2"}); err != nil {
|
if err := SaveAliases("CLAUDE", map[string]string{"primary": "model2"}); err != nil {
|
||||||
t.Fatalf("failed to update: %v", err)
|
t.Fatalf("failed to update: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
loaded, err = loadIntegration("claude")
|
loaded, err = LoadIntegration("claude")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to load after update: %v", err)
|
t.Fatalf("failed to load after update: %v", err)
|
||||||
}
|
}
|
||||||
@@ -198,11 +197,11 @@ func TestSaveAliases_CreatesIntegration(t *testing.T) {
|
|||||||
setTestHome(t, tmpDir)
|
setTestHome(t, tmpDir)
|
||||||
|
|
||||||
// Save aliases for non-existent integration
|
// Save aliases for non-existent integration
|
||||||
if err := saveAliases("newintegration", map[string]string{"primary": "model"}); err != nil {
|
if err := SaveAliases("newintegration", map[string]string{"primary": "model"}); err != nil {
|
||||||
t.Fatalf("failed to save: %v", err)
|
t.Fatalf("failed to save: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
loaded, err := loadIntegration("newintegration")
|
loaded, err := LoadIntegration("newintegration")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to load: %v", err)
|
t.Fatalf("failed to load: %v", err)
|
||||||
}
|
}
|
||||||
@@ -371,12 +370,12 @@ func TestAtomicUpdate_ServerSucceedsConfigSaved(t *testing.T) {
|
|||||||
t.Fatal("server should succeed")
|
t.Fatal("server should succeed")
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := saveAliases("claude", map[string]string{"primary": "model"}); err != nil {
|
if err := SaveAliases("claude", map[string]string{"primary": "model"}); err != nil {
|
||||||
t.Fatalf("saveAliases failed: %v", err)
|
t.Fatalf("saveAliases failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Verify it was actually saved
|
// Verify it was actually saved
|
||||||
loaded, err := loadIntegration("claude")
|
loaded, err := LoadIntegration("claude")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to load: %v", err)
|
t.Fatalf("failed to load: %v", err)
|
||||||
}
|
}
|
||||||
@@ -408,7 +407,7 @@ func TestConfigFile_PreservesUnknownFields(t *testing.T) {
|
|||||||
os.WriteFile(configPath, []byte(initialConfig), 0o644)
|
os.WriteFile(configPath, []byte(initialConfig), 0o644)
|
||||||
|
|
||||||
// Update aliases
|
// Update aliases
|
||||||
if err := saveAliases("claude", map[string]string{"primary": "model2"}); err != nil {
|
if err := SaveAliases("claude", map[string]string{"primary": "model2"}); err != nil {
|
||||||
t.Fatalf("failed to save: %v", err)
|
t.Fatalf("failed to save: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -440,11 +439,6 @@ func containsHelper(s, substr string) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestClaudeImplementsAliasConfigurer(t *testing.T) {
|
|
||||||
c := &Claude{}
|
|
||||||
var _ AliasConfigurer = c // Compile-time check
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestModelNameEdgeCases(t *testing.T) {
|
func TestModelNameEdgeCases(t *testing.T) {
|
||||||
testCases := []struct {
|
testCases := []struct {
|
||||||
name string
|
name string
|
||||||
@@ -464,11 +458,11 @@ func TestModelNameEdgeCases(t *testing.T) {
|
|||||||
setTestHome(t, tmpDir)
|
setTestHome(t, tmpDir)
|
||||||
|
|
||||||
aliases := map[string]string{"primary": tc.model}
|
aliases := map[string]string{"primary": tc.model}
|
||||||
if err := saveAliases("claude", aliases); err != nil {
|
if err := SaveAliases("claude", aliases); err != nil {
|
||||||
t.Fatalf("failed to save model %q: %v", tc.model, err)
|
t.Fatalf("failed to save model %q: %v", tc.model, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
loaded, err := loadIntegration("claude")
|
loaded, err := LoadIntegration("claude")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to load: %v", err)
|
t.Fatalf("failed to load: %v", err)
|
||||||
}
|
}
|
||||||
@@ -485,7 +479,7 @@ func TestSwitchingScenarios(t *testing.T) {
|
|||||||
setTestHome(t, tmpDir)
|
setTestHome(t, tmpDir)
|
||||||
|
|
||||||
// Initial cloud config
|
// Initial cloud config
|
||||||
if err := saveAliases("claude", map[string]string{
|
if err := SaveAliases("claude", map[string]string{
|
||||||
"primary": "cloud-model",
|
"primary": "cloud-model",
|
||||||
"fast": "cloud-model",
|
"fast": "cloud-model",
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
@@ -493,13 +487,13 @@ func TestSwitchingScenarios(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Switch to local (no fast)
|
// Switch to local (no fast)
|
||||||
if err := saveAliases("claude", map[string]string{
|
if err := SaveAliases("claude", map[string]string{
|
||||||
"primary": "local-model",
|
"primary": "local-model",
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
loaded, _ := loadIntegration("claude")
|
loaded, _ := LoadIntegration("claude")
|
||||||
if loaded.Aliases["fast"] != "" {
|
if loaded.Aliases["fast"] != "" {
|
||||||
t.Errorf("fast should be removed, got %q", loaded.Aliases["fast"])
|
t.Errorf("fast should be removed, got %q", loaded.Aliases["fast"])
|
||||||
}
|
}
|
||||||
@@ -513,21 +507,21 @@ func TestSwitchingScenarios(t *testing.T) {
|
|||||||
setTestHome(t, tmpDir)
|
setTestHome(t, tmpDir)
|
||||||
|
|
||||||
// Initial local config
|
// Initial local config
|
||||||
if err := saveAliases("claude", map[string]string{
|
if err := SaveAliases("claude", map[string]string{
|
||||||
"primary": "local-model",
|
"primary": "local-model",
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Switch to cloud (with fast)
|
// Switch to cloud (with fast)
|
||||||
if err := saveAliases("claude", map[string]string{
|
if err := SaveAliases("claude", map[string]string{
|
||||||
"primary": "cloud-model",
|
"primary": "cloud-model",
|
||||||
"fast": "cloud-model",
|
"fast": "cloud-model",
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
loaded, _ := loadIntegration("claude")
|
loaded, _ := LoadIntegration("claude")
|
||||||
if loaded.Aliases["fast"] != "cloud-model" {
|
if loaded.Aliases["fast"] != "cloud-model" {
|
||||||
t.Errorf("fast should be cloud-model, got %q", loaded.Aliases["fast"])
|
t.Errorf("fast should be cloud-model, got %q", loaded.Aliases["fast"])
|
||||||
}
|
}
|
||||||
@@ -538,7 +532,7 @@ func TestSwitchingScenarios(t *testing.T) {
|
|||||||
setTestHome(t, tmpDir)
|
setTestHome(t, tmpDir)
|
||||||
|
|
||||||
// Initial cloud config
|
// Initial cloud config
|
||||||
if err := saveAliases("claude", map[string]string{
|
if err := SaveAliases("claude", map[string]string{
|
||||||
"primary": "cloud-model-1",
|
"primary": "cloud-model-1",
|
||||||
"fast": "cloud-model-1",
|
"fast": "cloud-model-1",
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
@@ -546,14 +540,14 @@ func TestSwitchingScenarios(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Switch to different cloud
|
// Switch to different cloud
|
||||||
if err := saveAliases("claude", map[string]string{
|
if err := SaveAliases("claude", map[string]string{
|
||||||
"primary": "cloud-model-2",
|
"primary": "cloud-model-2",
|
||||||
"fast": "cloud-model-2",
|
"fast": "cloud-model-2",
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
loaded, _ := loadIntegration("claude")
|
loaded, _ := LoadIntegration("claude")
|
||||||
if loaded.Aliases["primary"] != "cloud-model-2" {
|
if loaded.Aliases["primary"] != "cloud-model-2" {
|
||||||
t.Errorf("primary should be cloud-model-2, got %q", loaded.Aliases["primary"])
|
t.Errorf("primary should be cloud-model-2, got %q", loaded.Aliases["primary"])
|
||||||
}
|
}
|
||||||
@@ -563,43 +557,13 @@ func TestSwitchingScenarios(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestToolCapabilityFiltering(t *testing.T) {
|
|
||||||
t.Run("all models checked for tool capability", func(t *testing.T) {
|
|
||||||
// Both cloud and local models are checked for tool capability via Show API
|
|
||||||
// Only models with "tools" in capabilities are included
|
|
||||||
m := modelInfo{Name: "tool-model", Remote: false, ToolCapable: true}
|
|
||||||
if !m.ToolCapable {
|
|
||||||
t.Error("tool capable model should be marked as such")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("modelInfo includes ToolCapable field", func(t *testing.T) {
|
|
||||||
m := modelInfo{Name: "test", Remote: true, ToolCapable: true}
|
|
||||||
if !m.ToolCapable {
|
|
||||||
t.Error("ToolCapable field should be accessible")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestIsCloudModel_RequiresClient(t *testing.T) {
|
|
||||||
t.Run("nil client always returns false", func(t *testing.T) {
|
|
||||||
// isCloudModel now only uses Show API, no suffix detection
|
|
||||||
if isCloudModel(context.Background(), nil, "model:cloud") {
|
|
||||||
t.Error("nil client should return false regardless of suffix")
|
|
||||||
}
|
|
||||||
if isCloudModel(context.Background(), nil, "local-model") {
|
|
||||||
t.Error("nil client should return false")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestModelsAndAliasesMustStayInSync(t *testing.T) {
|
func TestModelsAndAliasesMustStayInSync(t *testing.T) {
|
||||||
t.Run("saveAliases followed by saveIntegration keeps them in sync", func(t *testing.T) {
|
t.Run("saveAliases followed by saveIntegration keeps them in sync", func(t *testing.T) {
|
||||||
tmpDir := t.TempDir()
|
tmpDir := t.TempDir()
|
||||||
setTestHome(t, tmpDir)
|
setTestHome(t, tmpDir)
|
||||||
|
|
||||||
// Save aliases with one model
|
// Save aliases with one model
|
||||||
if err := saveAliases("claude", map[string]string{"primary": "model-a"}); err != nil {
|
if err := SaveAliases("claude", map[string]string{"primary": "model-a"}); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -608,7 +572,7 @@ func TestModelsAndAliasesMustStayInSync(t *testing.T) {
|
|||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
loaded, _ := loadIntegration("claude")
|
loaded, _ := LoadIntegration("claude")
|
||||||
if loaded.Aliases["primary"] != loaded.Models[0] {
|
if loaded.Aliases["primary"] != loaded.Models[0] {
|
||||||
t.Errorf("aliases.primary (%q) != models[0] (%q)", loaded.Aliases["primary"], loaded.Models[0])
|
t.Errorf("aliases.primary (%q) != models[0] (%q)", loaded.Aliases["primary"], loaded.Models[0])
|
||||||
}
|
}
|
||||||
@@ -622,11 +586,11 @@ func TestModelsAndAliasesMustStayInSync(t *testing.T) {
|
|||||||
if err := SaveIntegration("claude", []string{"old-model"}); err != nil {
|
if err := SaveIntegration("claude", []string{"old-model"}); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
if err := saveAliases("claude", map[string]string{"primary": "new-model"}); err != nil {
|
if err := SaveAliases("claude", map[string]string{"primary": "new-model"}); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
loaded, _ := loadIntegration("claude")
|
loaded, _ := LoadIntegration("claude")
|
||||||
|
|
||||||
// They should be different (this is the bug state)
|
// They should be different (this is the bug state)
|
||||||
if loaded.Models[0] == loaded.Aliases["primary"] {
|
if loaded.Models[0] == loaded.Aliases["primary"] {
|
||||||
@@ -638,7 +602,7 @@ func TestModelsAndAliasesMustStayInSync(t *testing.T) {
|
|||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
loaded, _ = loadIntegration("claude")
|
loaded, _ = LoadIntegration("claude")
|
||||||
if loaded.Models[0] != loaded.Aliases["primary"] {
|
if loaded.Models[0] != loaded.Aliases["primary"] {
|
||||||
t.Errorf("after fix: models[0] (%q) should equal aliases.primary (%q)",
|
t.Errorf("after fix: models[0] (%q) should equal aliases.primary (%q)",
|
||||||
loaded.Models[0], loaded.Aliases["primary"])
|
loaded.Models[0], loaded.Aliases["primary"])
|
||||||
@@ -653,20 +617,20 @@ func TestModelsAndAliasesMustStayInSync(t *testing.T) {
|
|||||||
if err := SaveIntegration("claude", []string{"initial-model"}); err != nil {
|
if err := SaveIntegration("claude", []string{"initial-model"}); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
if err := saveAliases("claude", map[string]string{"primary": "initial-model"}); err != nil {
|
if err := SaveAliases("claude", map[string]string{"primary": "initial-model"}); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Update aliases AND models together
|
// Update aliases AND models together
|
||||||
newAliases := map[string]string{"primary": "updated-model"}
|
newAliases := map[string]string{"primary": "updated-model"}
|
||||||
if err := saveAliases("claude", newAliases); err != nil {
|
if err := SaveAliases("claude", newAliases); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
if err := SaveIntegration("claude", []string{newAliases["primary"]}); err != nil {
|
if err := SaveIntegration("claude", []string{newAliases["primary"]}); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
loaded, _ := loadIntegration("claude")
|
loaded, _ := LoadIntegration("claude")
|
||||||
if loaded.Models[0] != "updated-model" {
|
if loaded.Models[0] != "updated-model" {
|
||||||
t.Errorf("models[0] should be updated-model, got %q", loaded.Models[0])
|
t.Errorf("models[0] should be updated-model, got %q", loaded.Models[0])
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -10,17 +10,10 @@ import (
|
|||||||
// setTestHome sets both HOME (Unix) and USERPROFILE (Windows) for cross-platform tests
|
// setTestHome sets both HOME (Unix) and USERPROFILE (Windows) for cross-platform tests
|
||||||
func setTestHome(t *testing.T, dir string) {
|
func setTestHome(t *testing.T, dir string) {
|
||||||
t.Setenv("HOME", dir)
|
t.Setenv("HOME", dir)
|
||||||
|
t.Setenv("TMPDIR", dir)
|
||||||
t.Setenv("USERPROFILE", dir)
|
t.Setenv("USERPROFILE", dir)
|
||||||
}
|
}
|
||||||
|
|
||||||
// editorPaths is a test helper that safely calls Paths if the runner implements Editor
|
|
||||||
func editorPaths(r Runner) []string {
|
|
||||||
if editor, ok := r.(Editor); ok {
|
|
||||||
return editor.Paths()
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestIntegrationConfig(t *testing.T) {
|
func TestIntegrationConfig(t *testing.T) {
|
||||||
tmpDir := t.TempDir()
|
tmpDir := t.TempDir()
|
||||||
setTestHome(t, tmpDir)
|
setTestHome(t, tmpDir)
|
||||||
@@ -31,7 +24,7 @@ func TestIntegrationConfig(t *testing.T) {
|
|||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
config, err := loadIntegration("claude")
|
config, err := LoadIntegration("claude")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@@ -55,11 +48,11 @@ func TestIntegrationConfig(t *testing.T) {
|
|||||||
"primary": "llama3.2:70b",
|
"primary": "llama3.2:70b",
|
||||||
"fast": "llama3.2:8b",
|
"fast": "llama3.2:8b",
|
||||||
}
|
}
|
||||||
if err := saveAliases("claude", aliases); err != nil {
|
if err := SaveAliases("claude", aliases); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
config, err := loadIntegration("claude")
|
config, err := LoadIntegration("claude")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@@ -77,14 +70,14 @@ func TestIntegrationConfig(t *testing.T) {
|
|||||||
if err := SaveIntegration("claude", []string{"model-a"}); err != nil {
|
if err := SaveIntegration("claude", []string{"model-a"}); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
if err := saveAliases("claude", map[string]string{"primary": "model-a", "fast": "model-small"}); err != nil {
|
if err := SaveAliases("claude", map[string]string{"primary": "model-a", "fast": "model-small"}); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := SaveIntegration("claude", []string{"model-b"}); err != nil {
|
if err := SaveIntegration("claude", []string{"model-b"}); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
config, err := loadIntegration("claude")
|
config, err := LoadIntegration("claude")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@@ -96,7 +89,7 @@ func TestIntegrationConfig(t *testing.T) {
|
|||||||
t.Run("defaultModel returns first model", func(t *testing.T) {
|
t.Run("defaultModel returns first model", func(t *testing.T) {
|
||||||
SaveIntegration("codex", []string{"model-a", "model-b"})
|
SaveIntegration("codex", []string{"model-a", "model-b"})
|
||||||
|
|
||||||
config, _ := loadIntegration("codex")
|
config, _ := LoadIntegration("codex")
|
||||||
defaultModel := ""
|
defaultModel := ""
|
||||||
if len(config.Models) > 0 {
|
if len(config.Models) > 0 {
|
||||||
defaultModel = config.Models[0]
|
defaultModel = config.Models[0]
|
||||||
@@ -120,7 +113,7 @@ func TestIntegrationConfig(t *testing.T) {
|
|||||||
t.Run("app name is case-insensitive", func(t *testing.T) {
|
t.Run("app name is case-insensitive", func(t *testing.T) {
|
||||||
SaveIntegration("Claude", []string{"model-x"})
|
SaveIntegration("Claude", []string{"model-x"})
|
||||||
|
|
||||||
config, err := loadIntegration("claude")
|
config, err := LoadIntegration("claude")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@@ -137,8 +130,8 @@ func TestIntegrationConfig(t *testing.T) {
|
|||||||
SaveIntegration("app1", []string{"model-1"})
|
SaveIntegration("app1", []string{"model-1"})
|
||||||
SaveIntegration("app2", []string{"model-2"})
|
SaveIntegration("app2", []string{"model-2"})
|
||||||
|
|
||||||
config1, _ := loadIntegration("app1")
|
config1, _ := LoadIntegration("app1")
|
||||||
config2, _ := loadIntegration("app2")
|
config2, _ := LoadIntegration("app2")
|
||||||
|
|
||||||
defaultModel1 := ""
|
defaultModel1 := ""
|
||||||
if len(config1.Models) > 0 {
|
if len(config1.Models) > 0 {
|
||||||
@@ -185,64 +178,6 @@ func TestListIntegrations(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestEditorPaths(t *testing.T) {
|
|
||||||
tmpDir := t.TempDir()
|
|
||||||
setTestHome(t, tmpDir)
|
|
||||||
|
|
||||||
t.Run("returns empty for claude (no Editor)", func(t *testing.T) {
|
|
||||||
r := integrations["claude"]
|
|
||||||
paths := editorPaths(r)
|
|
||||||
if len(paths) != 0 {
|
|
||||||
t.Errorf("expected no paths for claude, got %v", paths)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("returns empty for codex (no Editor)", func(t *testing.T) {
|
|
||||||
r := integrations["codex"]
|
|
||||||
paths := editorPaths(r)
|
|
||||||
if len(paths) != 0 {
|
|
||||||
t.Errorf("expected no paths for codex, got %v", paths)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("returns empty for droid when no config exists", func(t *testing.T) {
|
|
||||||
r := integrations["droid"]
|
|
||||||
paths := editorPaths(r)
|
|
||||||
if len(paths) != 0 {
|
|
||||||
t.Errorf("expected no paths, got %v", paths)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("returns path for droid when config exists", func(t *testing.T) {
|
|
||||||
settingsDir, _ := os.UserHomeDir()
|
|
||||||
settingsDir = filepath.Join(settingsDir, ".factory")
|
|
||||||
os.MkdirAll(settingsDir, 0o755)
|
|
||||||
os.WriteFile(filepath.Join(settingsDir, "settings.json"), []byte(`{}`), 0o644)
|
|
||||||
|
|
||||||
r := integrations["droid"]
|
|
||||||
paths := editorPaths(r)
|
|
||||||
if len(paths) != 1 {
|
|
||||||
t.Errorf("expected 1 path, got %d", len(paths))
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("returns paths for opencode when configs exist", func(t *testing.T) {
|
|
||||||
home, _ := os.UserHomeDir()
|
|
||||||
configDir := filepath.Join(home, ".config", "opencode")
|
|
||||||
stateDir := filepath.Join(home, ".local", "state", "opencode")
|
|
||||||
os.MkdirAll(configDir, 0o755)
|
|
||||||
os.MkdirAll(stateDir, 0o755)
|
|
||||||
os.WriteFile(filepath.Join(configDir, "opencode.json"), []byte(`{}`), 0o644)
|
|
||||||
os.WriteFile(filepath.Join(stateDir, "model.json"), []byte(`{}`), 0o644)
|
|
||||||
|
|
||||||
r := integrations["opencode"]
|
|
||||||
paths := editorPaths(r)
|
|
||||||
if len(paths) != 2 {
|
|
||||||
t.Errorf("expected 2 paths, got %d: %v", len(paths), paths)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestLoadIntegration_CorruptedJSON(t *testing.T) {
|
func TestLoadIntegration_CorruptedJSON(t *testing.T) {
|
||||||
tmpDir := t.TempDir()
|
tmpDir := t.TempDir()
|
||||||
setTestHome(t, tmpDir)
|
setTestHome(t, tmpDir)
|
||||||
@@ -251,7 +186,7 @@ func TestLoadIntegration_CorruptedJSON(t *testing.T) {
|
|||||||
os.MkdirAll(dir, 0o755)
|
os.MkdirAll(dir, 0o755)
|
||||||
os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{corrupted json`), 0o644)
|
os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{corrupted json`), 0o644)
|
||||||
|
|
||||||
_, err := loadIntegration("test")
|
_, err := LoadIntegration("test")
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Error("expected error for nonexistent integration in corrupted file")
|
t.Error("expected error for nonexistent integration in corrupted file")
|
||||||
}
|
}
|
||||||
@@ -265,7 +200,7 @@ func TestSaveIntegration_NilModels(t *testing.T) {
|
|||||||
t.Fatalf("saveIntegration with nil models failed: %v", err)
|
t.Fatalf("saveIntegration with nil models failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
config, err := loadIntegration("test")
|
config, err := LoadIntegration("test")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("loadIntegration failed: %v", err)
|
t.Fatalf("loadIntegration failed: %v", err)
|
||||||
}
|
}
|
||||||
@@ -294,7 +229,7 @@ func TestLoadIntegration_NonexistentIntegration(t *testing.T) {
|
|||||||
tmpDir := t.TempDir()
|
tmpDir := t.TempDir()
|
||||||
setTestHome(t, tmpDir)
|
setTestHome(t, tmpDir)
|
||||||
|
|
||||||
_, err := loadIntegration("nonexistent")
|
_, err := LoadIntegration("nonexistent")
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Error("expected error for nonexistent integration, got nil")
|
t.Error("expected error for nonexistent integration, got nil")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,279 +0,0 @@
|
|||||||
package config
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"encoding/json"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"maps"
|
|
||||||
"os"
|
|
||||||
"os/exec"
|
|
||||||
"path/filepath"
|
|
||||||
"slices"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
"github.com/ollama/ollama/api"
|
|
||||||
"github.com/ollama/ollama/envconfig"
|
|
||||||
)
|
|
||||||
|
|
||||||
// OpenCode implements Runner and Editor for OpenCode integration
|
|
||||||
type OpenCode struct{}
|
|
||||||
|
|
||||||
// cloudModelLimit holds context and output token limits for a cloud model.
|
|
||||||
type cloudModelLimit struct {
|
|
||||||
Context int
|
|
||||||
Output int
|
|
||||||
}
|
|
||||||
|
|
||||||
// lookupCloudModelLimit returns the token limits for a cloud model.
|
|
||||||
// It tries the exact name first, then strips the ":cloud" suffix.
|
|
||||||
func lookupCloudModelLimit(name string) (cloudModelLimit, bool) {
|
|
||||||
if l, ok := cloudModelLimits[name]; ok {
|
|
||||||
return l, true
|
|
||||||
}
|
|
||||||
base := strings.TrimSuffix(name, ":cloud")
|
|
||||||
if base != name {
|
|
||||||
if l, ok := cloudModelLimits[base]; ok {
|
|
||||||
return l, true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return cloudModelLimit{}, false
|
|
||||||
}
|
|
||||||
|
|
||||||
func (o *OpenCode) String() string { return "OpenCode" }
|
|
||||||
|
|
||||||
func (o *OpenCode) Run(model string, args []string) error {
|
|
||||||
if _, err := exec.LookPath("opencode"); err != nil {
|
|
||||||
return fmt.Errorf("opencode is not installed, install from https://opencode.ai")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Call Edit() to ensure config is up-to-date before launch
|
|
||||||
models := []string{model}
|
|
||||||
if config, err := loadIntegration("opencode"); err == nil && len(config.Models) > 0 {
|
|
||||||
models = config.Models
|
|
||||||
}
|
|
||||||
var err error
|
|
||||||
models, err = resolveEditorModels("opencode", models, func() ([]string, error) {
|
|
||||||
return selectModels(context.Background(), "opencode", "")
|
|
||||||
})
|
|
||||||
if errors.Is(err, errCancelled) {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if err := o.Edit(models); err != nil {
|
|
||||||
return fmt.Errorf("setup failed: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
cmd := exec.Command("opencode", args...)
|
|
||||||
cmd.Stdin = os.Stdin
|
|
||||||
cmd.Stdout = os.Stdout
|
|
||||||
cmd.Stderr = os.Stderr
|
|
||||||
return cmd.Run()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (o *OpenCode) Paths() []string {
|
|
||||||
home, err := os.UserHomeDir()
|
|
||||||
if err != nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
var paths []string
|
|
||||||
p := filepath.Join(home, ".config", "opencode", "opencode.json")
|
|
||||||
if _, err := os.Stat(p); err == nil {
|
|
||||||
paths = append(paths, p)
|
|
||||||
}
|
|
||||||
sp := filepath.Join(home, ".local", "state", "opencode", "model.json")
|
|
||||||
if _, err := os.Stat(sp); err == nil {
|
|
||||||
paths = append(paths, sp)
|
|
||||||
}
|
|
||||||
return paths
|
|
||||||
}
|
|
||||||
|
|
||||||
func (o *OpenCode) Edit(modelList []string) error {
|
|
||||||
if len(modelList) == 0 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
home, err := os.UserHomeDir()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
configPath := filepath.Join(home, ".config", "opencode", "opencode.json")
|
|
||||||
if err := os.MkdirAll(filepath.Dir(configPath), 0o755); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
config := make(map[string]any)
|
|
||||||
if data, err := os.ReadFile(configPath); err == nil {
|
|
||||||
_ = json.Unmarshal(data, &config) // Ignore parse errors; treat missing/corrupt files as empty
|
|
||||||
}
|
|
||||||
|
|
||||||
config["$schema"] = "https://opencode.ai/config.json"
|
|
||||||
|
|
||||||
provider, ok := config["provider"].(map[string]any)
|
|
||||||
if !ok {
|
|
||||||
provider = make(map[string]any)
|
|
||||||
}
|
|
||||||
|
|
||||||
ollama, ok := provider["ollama"].(map[string]any)
|
|
||||||
if !ok {
|
|
||||||
ollama = map[string]any{
|
|
||||||
"npm": "@ai-sdk/openai-compatible",
|
|
||||||
"name": "Ollama (local)",
|
|
||||||
"options": map[string]any{
|
|
||||||
"baseURL": envconfig.Host().String() + "/v1",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
models, ok := ollama["models"].(map[string]any)
|
|
||||||
if !ok {
|
|
||||||
models = make(map[string]any)
|
|
||||||
}
|
|
||||||
|
|
||||||
selectedSet := make(map[string]bool)
|
|
||||||
for _, m := range modelList {
|
|
||||||
selectedSet[m] = true
|
|
||||||
}
|
|
||||||
|
|
||||||
for name, cfg := range models {
|
|
||||||
if cfgMap, ok := cfg.(map[string]any); ok {
|
|
||||||
if isOllamaModel(cfgMap) && !selectedSet[name] {
|
|
||||||
delete(models, name)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
client, _ := api.ClientFromEnvironment()
|
|
||||||
|
|
||||||
for _, model := range modelList {
|
|
||||||
if existing, ok := models[model].(map[string]any); ok {
|
|
||||||
// migrate existing models without _launch marker
|
|
||||||
if isOllamaModel(existing) {
|
|
||||||
existing["_launch"] = true
|
|
||||||
if name, ok := existing["name"].(string); ok {
|
|
||||||
existing["name"] = strings.TrimSuffix(name, " [Ollama]")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if isCloudModel(context.Background(), client, model) {
|
|
||||||
if l, ok := lookupCloudModelLimit(model); ok {
|
|
||||||
existing["limit"] = map[string]any{
|
|
||||||
"context": l.Context,
|
|
||||||
"output": l.Output,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
entry := map[string]any{
|
|
||||||
"name": model,
|
|
||||||
"_launch": true,
|
|
||||||
}
|
|
||||||
if isCloudModel(context.Background(), client, model) {
|
|
||||||
if l, ok := lookupCloudModelLimit(model); ok {
|
|
||||||
entry["limit"] = map[string]any{
|
|
||||||
"context": l.Context,
|
|
||||||
"output": l.Output,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
models[model] = entry
|
|
||||||
}
|
|
||||||
|
|
||||||
ollama["models"] = models
|
|
||||||
provider["ollama"] = ollama
|
|
||||||
config["provider"] = provider
|
|
||||||
|
|
||||||
configData, err := json.MarshalIndent(config, "", " ")
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if err := writeWithBackup(configPath, configData); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
statePath := filepath.Join(home, ".local", "state", "opencode", "model.json")
|
|
||||||
if err := os.MkdirAll(filepath.Dir(statePath), 0o755); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
state := map[string]any{
|
|
||||||
"recent": []any{},
|
|
||||||
"favorite": []any{},
|
|
||||||
"variant": map[string]any{},
|
|
||||||
}
|
|
||||||
if data, err := os.ReadFile(statePath); err == nil {
|
|
||||||
_ = json.Unmarshal(data, &state) // Ignore parse errors; use defaults
|
|
||||||
}
|
|
||||||
|
|
||||||
recent, _ := state["recent"].([]any)
|
|
||||||
|
|
||||||
modelSet := make(map[string]bool)
|
|
||||||
for _, m := range modelList {
|
|
||||||
modelSet[m] = true
|
|
||||||
}
|
|
||||||
|
|
||||||
// Filter out existing Ollama models we're about to re-add
|
|
||||||
newRecent := slices.DeleteFunc(slices.Clone(recent), func(entry any) bool {
|
|
||||||
e, ok := entry.(map[string]any)
|
|
||||||
if !ok || e["providerID"] != "ollama" {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
modelID, _ := e["modelID"].(string)
|
|
||||||
return modelSet[modelID]
|
|
||||||
})
|
|
||||||
|
|
||||||
// Prepend models in reverse order so first model ends up first
|
|
||||||
for _, model := range slices.Backward(modelList) {
|
|
||||||
newRecent = slices.Insert(newRecent, 0, any(map[string]any{
|
|
||||||
"providerID": "ollama",
|
|
||||||
"modelID": model,
|
|
||||||
}))
|
|
||||||
}
|
|
||||||
|
|
||||||
const maxRecentModels = 10
|
|
||||||
newRecent = newRecent[:min(len(newRecent), maxRecentModels)]
|
|
||||||
|
|
||||||
state["recent"] = newRecent
|
|
||||||
|
|
||||||
stateData, err := json.MarshalIndent(state, "", " ")
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return writeWithBackup(statePath, stateData)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (o *OpenCode) Models() []string {
|
|
||||||
home, err := os.UserHomeDir()
|
|
||||||
if err != nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
config, err := readJSONFile(filepath.Join(home, ".config", "opencode", "opencode.json"))
|
|
||||||
if err != nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
provider, _ := config["provider"].(map[string]any)
|
|
||||||
ollama, _ := provider["ollama"].(map[string]any)
|
|
||||||
models, _ := ollama["models"].(map[string]any)
|
|
||||||
if len(models) == 0 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
keys := slices.Collect(maps.Keys(models))
|
|
||||||
slices.Sort(keys)
|
|
||||||
return keys
|
|
||||||
}
|
|
||||||
|
|
||||||
// isOllamaModel reports whether a model config entry is managed by us
|
|
||||||
func isOllamaModel(cfg map[string]any) bool {
|
|
||||||
if v, ok := cfg["_launch"].(bool); ok && v {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
// previously used [Ollama] as a suffix for the model managed by ollama launch
|
|
||||||
if name, ok := cfg["name"].(string); ok {
|
|
||||||
return strings.HasSuffix(name, "[Ollama]")
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
@@ -1,668 +0,0 @@
|
|||||||
package config
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
|
||||||
"os"
|
|
||||||
"path/filepath"
|
|
||||||
"testing"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestOpenCodeIntegration(t *testing.T) {
|
|
||||||
o := &OpenCode{}
|
|
||||||
|
|
||||||
t.Run("String", func(t *testing.T) {
|
|
||||||
if got := o.String(); got != "OpenCode" {
|
|
||||||
t.Errorf("String() = %q, want %q", got, "OpenCode")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("implements Runner", func(t *testing.T) {
|
|
||||||
var _ Runner = o
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("implements Editor", func(t *testing.T) {
|
|
||||||
var _ Editor = o
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestOpenCodeEdit(t *testing.T) {
|
|
||||||
o := &OpenCode{}
|
|
||||||
tmpDir := t.TempDir()
|
|
||||||
setTestHome(t, tmpDir)
|
|
||||||
|
|
||||||
configDir := filepath.Join(tmpDir, ".config", "opencode")
|
|
||||||
configPath := filepath.Join(configDir, "opencode.json")
|
|
||||||
stateDir := filepath.Join(tmpDir, ".local", "state", "opencode")
|
|
||||||
statePath := filepath.Join(stateDir, "model.json")
|
|
||||||
|
|
||||||
cleanup := func() {
|
|
||||||
os.RemoveAll(configDir)
|
|
||||||
os.RemoveAll(stateDir)
|
|
||||||
}
|
|
||||||
|
|
||||||
t.Run("fresh install", func(t *testing.T) {
|
|
||||||
cleanup()
|
|
||||||
if err := o.Edit([]string{"llama3.2"}); err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
assertOpenCodeModelExists(t, configPath, "llama3.2")
|
|
||||||
assertOpenCodeRecentModel(t, statePath, 0, "ollama", "llama3.2")
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("preserve other providers", func(t *testing.T) {
|
|
||||||
cleanup()
|
|
||||||
os.MkdirAll(configDir, 0o755)
|
|
||||||
os.WriteFile(configPath, []byte(`{"provider":{"anthropic":{"apiKey":"xxx"}}}`), 0o644)
|
|
||||||
if err := o.Edit([]string{"llama3.2"}); err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
data, _ := os.ReadFile(configPath)
|
|
||||||
var cfg map[string]any
|
|
||||||
json.Unmarshal(data, &cfg)
|
|
||||||
provider := cfg["provider"].(map[string]any)
|
|
||||||
if provider["anthropic"] == nil {
|
|
||||||
t.Error("anthropic provider was removed")
|
|
||||||
}
|
|
||||||
assertOpenCodeModelExists(t, configPath, "llama3.2")
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("preserve other models", func(t *testing.T) {
|
|
||||||
cleanup()
|
|
||||||
os.MkdirAll(configDir, 0o755)
|
|
||||||
os.WriteFile(configPath, []byte(`{"provider":{"ollama":{"models":{"mistral":{"name":"Mistral"}}}}}`), 0o644)
|
|
||||||
if err := o.Edit([]string{"llama3.2"}); err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
assertOpenCodeModelExists(t, configPath, "mistral")
|
|
||||||
assertOpenCodeModelExists(t, configPath, "llama3.2")
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("update existing model", func(t *testing.T) {
|
|
||||||
cleanup()
|
|
||||||
o.Edit([]string{"llama3.2"})
|
|
||||||
o.Edit([]string{"llama3.2"})
|
|
||||||
assertOpenCodeModelExists(t, configPath, "llama3.2")
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("preserve top-level keys", func(t *testing.T) {
|
|
||||||
cleanup()
|
|
||||||
os.MkdirAll(configDir, 0o755)
|
|
||||||
os.WriteFile(configPath, []byte(`{"theme":"dark","keybindings":{}}`), 0o644)
|
|
||||||
if err := o.Edit([]string{"llama3.2"}); err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
data, _ := os.ReadFile(configPath)
|
|
||||||
var cfg map[string]any
|
|
||||||
json.Unmarshal(data, &cfg)
|
|
||||||
if cfg["theme"] != "dark" {
|
|
||||||
t.Error("theme was removed")
|
|
||||||
}
|
|
||||||
if cfg["keybindings"] == nil {
|
|
||||||
t.Error("keybindings was removed")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("model state - insert at index 0", func(t *testing.T) {
|
|
||||||
cleanup()
|
|
||||||
os.MkdirAll(stateDir, 0o755)
|
|
||||||
os.WriteFile(statePath, []byte(`{"recent":[{"providerID":"anthropic","modelID":"claude"}],"favorite":[],"variant":{}}`), 0o644)
|
|
||||||
if err := o.Edit([]string{"llama3.2"}); err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
assertOpenCodeRecentModel(t, statePath, 0, "ollama", "llama3.2")
|
|
||||||
assertOpenCodeRecentModel(t, statePath, 1, "anthropic", "claude")
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("model state - preserve favorites and variants", func(t *testing.T) {
|
|
||||||
cleanup()
|
|
||||||
os.MkdirAll(stateDir, 0o755)
|
|
||||||
os.WriteFile(statePath, []byte(`{"recent":[],"favorite":[{"providerID":"x","modelID":"y"}],"variant":{"a":"b"}}`), 0o644)
|
|
||||||
if err := o.Edit([]string{"llama3.2"}); err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
data, _ := os.ReadFile(statePath)
|
|
||||||
var state map[string]any
|
|
||||||
json.Unmarshal(data, &state)
|
|
||||||
if len(state["favorite"].([]any)) != 1 {
|
|
||||||
t.Error("favorite was modified")
|
|
||||||
}
|
|
||||||
if state["variant"].(map[string]any)["a"] != "b" {
|
|
||||||
t.Error("variant was modified")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("model state - deduplicate on re-add", func(t *testing.T) {
|
|
||||||
cleanup()
|
|
||||||
os.MkdirAll(stateDir, 0o755)
|
|
||||||
os.WriteFile(statePath, []byte(`{"recent":[{"providerID":"ollama","modelID":"llama3.2"},{"providerID":"anthropic","modelID":"claude"}],"favorite":[],"variant":{}}`), 0o644)
|
|
||||||
if err := o.Edit([]string{"llama3.2"}); err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
data, _ := os.ReadFile(statePath)
|
|
||||||
var state map[string]any
|
|
||||||
json.Unmarshal(data, &state)
|
|
||||||
recent := state["recent"].([]any)
|
|
||||||
if len(recent) != 2 {
|
|
||||||
t.Errorf("expected 2 recent entries, got %d", len(recent))
|
|
||||||
}
|
|
||||||
assertOpenCodeRecentModel(t, statePath, 0, "ollama", "llama3.2")
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("remove model", func(t *testing.T) {
|
|
||||||
cleanup()
|
|
||||||
// First add two models
|
|
||||||
o.Edit([]string{"llama3.2", "mistral"})
|
|
||||||
assertOpenCodeModelExists(t, configPath, "llama3.2")
|
|
||||||
assertOpenCodeModelExists(t, configPath, "mistral")
|
|
||||||
|
|
||||||
// Then remove one by only selecting the other
|
|
||||||
o.Edit([]string{"llama3.2"})
|
|
||||||
assertOpenCodeModelExists(t, configPath, "llama3.2")
|
|
||||||
assertOpenCodeModelNotExists(t, configPath, "mistral")
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("preserve user customizations on managed models", func(t *testing.T) {
|
|
||||||
cleanup()
|
|
||||||
if err := o.Edit([]string{"llama3.2"}); err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Add custom fields to the model entry (simulating user edits)
|
|
||||||
data, _ := os.ReadFile(configPath)
|
|
||||||
var cfg map[string]any
|
|
||||||
json.Unmarshal(data, &cfg)
|
|
||||||
provider := cfg["provider"].(map[string]any)
|
|
||||||
ollama := provider["ollama"].(map[string]any)
|
|
||||||
models := ollama["models"].(map[string]any)
|
|
||||||
entry := models["llama3.2"].(map[string]any)
|
|
||||||
entry["_myPref"] = "custom-value"
|
|
||||||
entry["_myNum"] = 42
|
|
||||||
configData, _ := json.MarshalIndent(cfg, "", " ")
|
|
||||||
os.WriteFile(configPath, configData, 0o644)
|
|
||||||
|
|
||||||
// Re-run Edit — should preserve custom fields
|
|
||||||
if err := o.Edit([]string{"llama3.2"}); err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
data, _ = os.ReadFile(configPath)
|
|
||||||
json.Unmarshal(data, &cfg)
|
|
||||||
provider = cfg["provider"].(map[string]any)
|
|
||||||
ollama = provider["ollama"].(map[string]any)
|
|
||||||
models = ollama["models"].(map[string]any)
|
|
||||||
entry = models["llama3.2"].(map[string]any)
|
|
||||||
|
|
||||||
if entry["_myPref"] != "custom-value" {
|
|
||||||
t.Errorf("_myPref was lost: got %v", entry["_myPref"])
|
|
||||||
}
|
|
||||||
if entry["_myNum"] != float64(42) {
|
|
||||||
t.Errorf("_myNum was lost: got %v", entry["_myNum"])
|
|
||||||
}
|
|
||||||
if v, ok := entry["_launch"].(bool); !ok || !v {
|
|
||||||
t.Errorf("_launch marker missing or false: got %v", entry["_launch"])
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("migrate legacy [Ollama] suffix entries", func(t *testing.T) {
|
|
||||||
cleanup()
|
|
||||||
// Write a config with a legacy entry (has [Ollama] suffix but no _launch marker)
|
|
||||||
os.MkdirAll(configDir, 0o755)
|
|
||||||
os.WriteFile(configPath, []byte(`{"provider":{"ollama":{"models":{"llama3.2":{"name":"llama3.2 [Ollama]"}}}}}`), 0o644)
|
|
||||||
|
|
||||||
if err := o.Edit([]string{"llama3.2"}); err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
data, _ := os.ReadFile(configPath)
|
|
||||||
var cfg map[string]any
|
|
||||||
json.Unmarshal(data, &cfg)
|
|
||||||
provider := cfg["provider"].(map[string]any)
|
|
||||||
ollama := provider["ollama"].(map[string]any)
|
|
||||||
models := ollama["models"].(map[string]any)
|
|
||||||
entry := models["llama3.2"].(map[string]any)
|
|
||||||
|
|
||||||
// _launch marker should be added
|
|
||||||
if v, ok := entry["_launch"].(bool); !ok || !v {
|
|
||||||
t.Errorf("_launch marker not added during migration: got %v", entry["_launch"])
|
|
||||||
}
|
|
||||||
// [Ollama] suffix should be stripped
|
|
||||||
if name, ok := entry["name"].(string); !ok || name != "llama3.2" {
|
|
||||||
t.Errorf("name suffix not stripped: got %q", entry["name"])
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("remove model preserves non-ollama models", func(t *testing.T) {
|
|
||||||
cleanup()
|
|
||||||
os.MkdirAll(configDir, 0o755)
|
|
||||||
// Add a non-Ollama model manually
|
|
||||||
os.WriteFile(configPath, []byte(`{"provider":{"ollama":{"models":{"external":{"name":"External Model"}}}}}`), 0o644)
|
|
||||||
|
|
||||||
o.Edit([]string{"llama3.2"})
|
|
||||||
assertOpenCodeModelExists(t, configPath, "llama3.2")
|
|
||||||
assertOpenCodeModelExists(t, configPath, "external") // Should be preserved
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func assertOpenCodeModelExists(t *testing.T, path, model string) {
|
|
||||||
t.Helper()
|
|
||||||
data, err := os.ReadFile(path)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
var cfg map[string]any
|
|
||||||
if err := json.Unmarshal(data, &cfg); err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
provider, ok := cfg["provider"].(map[string]any)
|
|
||||||
if !ok {
|
|
||||||
t.Fatal("provider not found")
|
|
||||||
}
|
|
||||||
ollama, ok := provider["ollama"].(map[string]any)
|
|
||||||
if !ok {
|
|
||||||
t.Fatal("ollama provider not found")
|
|
||||||
}
|
|
||||||
models, ok := ollama["models"].(map[string]any)
|
|
||||||
if !ok {
|
|
||||||
t.Fatal("models not found")
|
|
||||||
}
|
|
||||||
if models[model] == nil {
|
|
||||||
t.Errorf("model %s not found", model)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func assertOpenCodeModelNotExists(t *testing.T, path, model string) {
|
|
||||||
t.Helper()
|
|
||||||
data, err := os.ReadFile(path)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
var cfg map[string]any
|
|
||||||
if err := json.Unmarshal(data, &cfg); err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
provider, ok := cfg["provider"].(map[string]any)
|
|
||||||
if !ok {
|
|
||||||
return // No provider means no model
|
|
||||||
}
|
|
||||||
ollama, ok := provider["ollama"].(map[string]any)
|
|
||||||
if !ok {
|
|
||||||
return // No ollama means no model
|
|
||||||
}
|
|
||||||
models, ok := ollama["models"].(map[string]any)
|
|
||||||
if !ok {
|
|
||||||
return // No models means no model
|
|
||||||
}
|
|
||||||
if models[model] != nil {
|
|
||||||
t.Errorf("model %s should not exist but was found", model)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func assertOpenCodeRecentModel(t *testing.T, path string, index int, providerID, modelID string) {
|
|
||||||
t.Helper()
|
|
||||||
data, err := os.ReadFile(path)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
var state map[string]any
|
|
||||||
if err := json.Unmarshal(data, &state); err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
recent, ok := state["recent"].([]any)
|
|
||||||
if !ok {
|
|
||||||
t.Fatal("recent not found")
|
|
||||||
}
|
|
||||||
if index >= len(recent) {
|
|
||||||
t.Fatalf("index %d out of range (len=%d)", index, len(recent))
|
|
||||||
}
|
|
||||||
entry, ok := recent[index].(map[string]any)
|
|
||||||
if !ok {
|
|
||||||
t.Fatal("entry is not a map")
|
|
||||||
}
|
|
||||||
if entry["providerID"] != providerID {
|
|
||||||
t.Errorf("expected providerID %s, got %s", providerID, entry["providerID"])
|
|
||||||
}
|
|
||||||
if entry["modelID"] != modelID {
|
|
||||||
t.Errorf("expected modelID %s, got %s", modelID, entry["modelID"])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Edge case tests for opencode.go
|
|
||||||
|
|
||||||
func TestOpenCodeEdit_CorruptedConfigJSON(t *testing.T) {
|
|
||||||
o := &OpenCode{}
|
|
||||||
tmpDir := t.TempDir()
|
|
||||||
setTestHome(t, tmpDir)
|
|
||||||
|
|
||||||
configDir := filepath.Join(tmpDir, ".config", "opencode")
|
|
||||||
configPath := filepath.Join(configDir, "opencode.json")
|
|
||||||
|
|
||||||
os.MkdirAll(configDir, 0o755)
|
|
||||||
os.WriteFile(configPath, []byte(`{corrupted json content`), 0o644)
|
|
||||||
|
|
||||||
// Should not panic - corrupted JSON should be treated as empty
|
|
||||||
err := o.Edit([]string{"llama3.2"})
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Edit failed with corrupted config: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Verify valid JSON was created
|
|
||||||
data, _ := os.ReadFile(configPath)
|
|
||||||
var cfg map[string]any
|
|
||||||
if err := json.Unmarshal(data, &cfg); err != nil {
|
|
||||||
t.Errorf("resulting config is not valid JSON: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestOpenCodeEdit_CorruptedStateJSON(t *testing.T) {
|
|
||||||
o := &OpenCode{}
|
|
||||||
tmpDir := t.TempDir()
|
|
||||||
setTestHome(t, tmpDir)
|
|
||||||
|
|
||||||
stateDir := filepath.Join(tmpDir, ".local", "state", "opencode")
|
|
||||||
statePath := filepath.Join(stateDir, "model.json")
|
|
||||||
|
|
||||||
os.MkdirAll(stateDir, 0o755)
|
|
||||||
os.WriteFile(statePath, []byte(`{corrupted state`), 0o644)
|
|
||||||
|
|
||||||
err := o.Edit([]string{"llama3.2"})
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Edit failed with corrupted state: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Verify valid state was created
|
|
||||||
data, _ := os.ReadFile(statePath)
|
|
||||||
var state map[string]any
|
|
||||||
if err := json.Unmarshal(data, &state); err != nil {
|
|
||||||
t.Errorf("resulting state is not valid JSON: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestOpenCodeEdit_WrongTypeProvider(t *testing.T) {
|
|
||||||
o := &OpenCode{}
|
|
||||||
tmpDir := t.TempDir()
|
|
||||||
setTestHome(t, tmpDir)
|
|
||||||
|
|
||||||
configDir := filepath.Join(tmpDir, ".config", "opencode")
|
|
||||||
configPath := filepath.Join(configDir, "opencode.json")
|
|
||||||
|
|
||||||
os.MkdirAll(configDir, 0o755)
|
|
||||||
os.WriteFile(configPath, []byte(`{"provider": "not a map"}`), 0o644)
|
|
||||||
|
|
||||||
err := o.Edit([]string{"llama3.2"})
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Edit with wrong type provider failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Verify provider is now correct type
|
|
||||||
data, _ := os.ReadFile(configPath)
|
|
||||||
var cfg map[string]any
|
|
||||||
json.Unmarshal(data, &cfg)
|
|
||||||
|
|
||||||
provider, ok := cfg["provider"].(map[string]any)
|
|
||||||
if !ok {
|
|
||||||
t.Fatalf("provider should be map after setup, got %T", cfg["provider"])
|
|
||||||
}
|
|
||||||
if provider["ollama"] == nil {
|
|
||||||
t.Error("ollama provider should be created")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestOpenCodeEdit_WrongTypeRecent(t *testing.T) {
|
|
||||||
o := &OpenCode{}
|
|
||||||
tmpDir := t.TempDir()
|
|
||||||
setTestHome(t, tmpDir)
|
|
||||||
|
|
||||||
stateDir := filepath.Join(tmpDir, ".local", "state", "opencode")
|
|
||||||
statePath := filepath.Join(stateDir, "model.json")
|
|
||||||
|
|
||||||
os.MkdirAll(stateDir, 0o755)
|
|
||||||
os.WriteFile(statePath, []byte(`{"recent": "not an array", "favorite": [], "variant": {}}`), 0o644)
|
|
||||||
|
|
||||||
err := o.Edit([]string{"llama3.2"})
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Edit with wrong type recent failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// The function should handle this gracefully
|
|
||||||
data, _ := os.ReadFile(statePath)
|
|
||||||
var state map[string]any
|
|
||||||
json.Unmarshal(data, &state)
|
|
||||||
|
|
||||||
// recent should be properly set after setup
|
|
||||||
recent, ok := state["recent"].([]any)
|
|
||||||
if !ok {
|
|
||||||
t.Logf("Note: recent type after setup is %T (documenting behavior)", state["recent"])
|
|
||||||
} else if len(recent) == 0 {
|
|
||||||
t.Logf("Note: recent is empty (documenting behavior)")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestOpenCodeEdit_EmptyModels(t *testing.T) {
|
|
||||||
o := &OpenCode{}
|
|
||||||
tmpDir := t.TempDir()
|
|
||||||
setTestHome(t, tmpDir)
|
|
||||||
|
|
||||||
configDir := filepath.Join(tmpDir, ".config", "opencode")
|
|
||||||
configPath := filepath.Join(configDir, "opencode.json")
|
|
||||||
|
|
||||||
os.MkdirAll(configDir, 0o755)
|
|
||||||
originalContent := `{"provider":{"ollama":{"models":{"existing":{}}}}}`
|
|
||||||
os.WriteFile(configPath, []byte(originalContent), 0o644)
|
|
||||||
|
|
||||||
// Empty models should be no-op
|
|
||||||
err := o.Edit([]string{})
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Edit with empty models failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Original content should be preserved (file not modified)
|
|
||||||
data, _ := os.ReadFile(configPath)
|
|
||||||
if string(data) != originalContent {
|
|
||||||
t.Errorf("empty models should not modify file, but content changed")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestOpenCodeEdit_SpecialCharsInModelName(t *testing.T) {
|
|
||||||
o := &OpenCode{}
|
|
||||||
tmpDir := t.TempDir()
|
|
||||||
setTestHome(t, tmpDir)
|
|
||||||
|
|
||||||
// Model name with special characters (though unusual)
|
|
||||||
specialModel := `model-with-"quotes"`
|
|
||||||
|
|
||||||
err := o.Edit([]string{specialModel})
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Edit with special chars failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Verify it was stored correctly
|
|
||||||
configDir := filepath.Join(tmpDir, ".config", "opencode")
|
|
||||||
configPath := filepath.Join(configDir, "opencode.json")
|
|
||||||
data, _ := os.ReadFile(configPath)
|
|
||||||
|
|
||||||
var cfg map[string]any
|
|
||||||
if err := json.Unmarshal(data, &cfg); err != nil {
|
|
||||||
t.Fatalf("resulting config is invalid JSON: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Model should be accessible
|
|
||||||
provider, _ := cfg["provider"].(map[string]any)
|
|
||||||
ollama, _ := provider["ollama"].(map[string]any)
|
|
||||||
models, _ := ollama["models"].(map[string]any)
|
|
||||||
|
|
||||||
if models[specialModel] == nil {
|
|
||||||
t.Errorf("model with special chars not found in config")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func readOpenCodeModel(t *testing.T, configPath, model string) map[string]any {
|
|
||||||
t.Helper()
|
|
||||||
data, err := os.ReadFile(configPath)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
var cfg map[string]any
|
|
||||||
json.Unmarshal(data, &cfg)
|
|
||||||
provider := cfg["provider"].(map[string]any)
|
|
||||||
ollama := provider["ollama"].(map[string]any)
|
|
||||||
models := ollama["models"].(map[string]any)
|
|
||||||
entry, ok := models[model].(map[string]any)
|
|
||||||
if !ok {
|
|
||||||
t.Fatalf("model %s not found in config", model)
|
|
||||||
}
|
|
||||||
return entry
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestOpenCodeEdit_LocalModelNoLimit(t *testing.T) {
|
|
||||||
o := &OpenCode{}
|
|
||||||
tmpDir := t.TempDir()
|
|
||||||
setTestHome(t, tmpDir)
|
|
||||||
|
|
||||||
configPath := filepath.Join(tmpDir, ".config", "opencode", "opencode.json")
|
|
||||||
|
|
||||||
if err := o.Edit([]string{"llama3.2"}); err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
entry := readOpenCodeModel(t, configPath, "llama3.2")
|
|
||||||
if entry["limit"] != nil {
|
|
||||||
t.Errorf("local model should not have limit set, got %v", entry["limit"])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestOpenCodeEdit_PreservesUserLimit(t *testing.T) {
|
|
||||||
o := &OpenCode{}
|
|
||||||
tmpDir := t.TempDir()
|
|
||||||
setTestHome(t, tmpDir)
|
|
||||||
|
|
||||||
configDir := filepath.Join(tmpDir, ".config", "opencode")
|
|
||||||
configPath := filepath.Join(configDir, "opencode.json")
|
|
||||||
|
|
||||||
// Set up a model with a user-configured limit
|
|
||||||
os.MkdirAll(configDir, 0o755)
|
|
||||||
os.WriteFile(configPath, []byte(`{
|
|
||||||
"provider": {
|
|
||||||
"ollama": {
|
|
||||||
"models": {
|
|
||||||
"llama3.2": {
|
|
||||||
"name": "llama3.2",
|
|
||||||
"_launch": true,
|
|
||||||
"limit": {"context": 8192, "output": 4096}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}`), 0o644)
|
|
||||||
|
|
||||||
// Re-edit should preserve the user's limit (not delete it)
|
|
||||||
if err := o.Edit([]string{"llama3.2"}); err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
entry := readOpenCodeModel(t, configPath, "llama3.2")
|
|
||||||
limit, ok := entry["limit"].(map[string]any)
|
|
||||||
if !ok {
|
|
||||||
t.Fatal("user-configured limit was removed")
|
|
||||||
}
|
|
||||||
if limit["context"] != float64(8192) {
|
|
||||||
t.Errorf("context limit changed: got %v, want 8192", limit["context"])
|
|
||||||
}
|
|
||||||
if limit["output"] != float64(4096) {
|
|
||||||
t.Errorf("output limit changed: got %v, want 4096", limit["output"])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestOpenCodeEdit_CloudModelLimitStructure(t *testing.T) {
|
|
||||||
// Verify that when a cloud model entry has limits set (as Edit would do),
|
|
||||||
// the structure matches what opencode expects and re-edit preserves them.
|
|
||||||
o := &OpenCode{}
|
|
||||||
tmpDir := t.TempDir()
|
|
||||||
setTestHome(t, tmpDir)
|
|
||||||
|
|
||||||
configDir := filepath.Join(tmpDir, ".config", "opencode")
|
|
||||||
configPath := filepath.Join(configDir, "opencode.json")
|
|
||||||
|
|
||||||
expected := cloudModelLimits["glm-4.7"]
|
|
||||||
|
|
||||||
// Simulate a cloud model that already has the limit set by a previous Edit
|
|
||||||
os.MkdirAll(configDir, 0o755)
|
|
||||||
os.WriteFile(configPath, []byte(fmt.Sprintf(`{
|
|
||||||
"provider": {
|
|
||||||
"ollama": {
|
|
||||||
"models": {
|
|
||||||
"glm-4.7:cloud": {
|
|
||||||
"name": "glm-4.7:cloud",
|
|
||||||
"_launch": true,
|
|
||||||
"limit": {"context": %d, "output": %d}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}`, expected.Context, expected.Output)), 0o644)
|
|
||||||
|
|
||||||
// Re-edit should preserve the cloud model limit
|
|
||||||
if err := o.Edit([]string{"glm-4.7:cloud"}); err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
entry := readOpenCodeModel(t, configPath, "glm-4.7:cloud")
|
|
||||||
limit, ok := entry["limit"].(map[string]any)
|
|
||||||
if !ok {
|
|
||||||
t.Fatal("cloud model limit was removed on re-edit")
|
|
||||||
}
|
|
||||||
if limit["context"] != float64(expected.Context) {
|
|
||||||
t.Errorf("context = %v, want %d", limit["context"], expected.Context)
|
|
||||||
}
|
|
||||||
if limit["output"] != float64(expected.Output) {
|
|
||||||
t.Errorf("output = %v, want %d", limit["output"], expected.Output)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestLookupCloudModelLimit(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
wantOK bool
|
|
||||||
wantContext int
|
|
||||||
wantOutput int
|
|
||||||
}{
|
|
||||||
{"glm-4.7", true, 202_752, 131_072},
|
|
||||||
{"glm-4.7:cloud", true, 202_752, 131_072},
|
|
||||||
{"kimi-k2.5", true, 262_144, 262_144},
|
|
||||||
{"kimi-k2.5:cloud", true, 262_144, 262_144},
|
|
||||||
{"deepseek-v3.2", true, 163_840, 65_536},
|
|
||||||
{"deepseek-v3.2:cloud", true, 163_840, 65_536},
|
|
||||||
{"qwen3-coder:480b", true, 262_144, 65_536},
|
|
||||||
{"qwen3-coder-next:cloud", true, 262_144, 32_768},
|
|
||||||
{"llama3.2", false, 0, 0},
|
|
||||||
{"unknown-model:cloud", false, 0, 0},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
l, ok := lookupCloudModelLimit(tt.name)
|
|
||||||
if ok != tt.wantOK {
|
|
||||||
t.Errorf("lookupCloudModelLimit(%q) ok = %v, want %v", tt.name, ok, tt.wantOK)
|
|
||||||
}
|
|
||||||
if ok {
|
|
||||||
if l.Context != tt.wantContext {
|
|
||||||
t.Errorf("context = %d, want %d", l.Context, tt.wantContext)
|
|
||||||
}
|
|
||||||
if l.Output != tt.wantOutput {
|
|
||||||
t.Errorf("output = %d, want %d", l.Output, tt.wantOutput)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestOpenCodeModels_NoConfig(t *testing.T) {
|
|
||||||
o := &OpenCode{}
|
|
||||||
tmpDir := t.TempDir()
|
|
||||||
setTestHome(t, tmpDir)
|
|
||||||
|
|
||||||
models := o.Models()
|
|
||||||
if len(models) > 0 {
|
|
||||||
t.Errorf("expected nil/empty for missing config, got %v", models)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
237
cmd/config/pi.go
@@ -1,237 +0,0 @@
|
|||||||
package config
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
|
||||||
"net/http"
|
|
||||||
"os"
|
|
||||||
"os/exec"
|
|
||||||
"path/filepath"
|
|
||||||
"slices"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
"github.com/ollama/ollama/api"
|
|
||||||
"github.com/ollama/ollama/envconfig"
|
|
||||||
"github.com/ollama/ollama/types/model"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Pi implements Runner and Editor for Pi (Pi Coding Agent) integration
|
|
||||||
type Pi struct{}
|
|
||||||
|
|
||||||
func (p *Pi) String() string { return "Pi" }
|
|
||||||
|
|
||||||
func (p *Pi) Run(model string, args []string) error {
|
|
||||||
if _, err := exec.LookPath("pi"); err != nil {
|
|
||||||
return fmt.Errorf("pi is not installed, install with: npm install -g @mariozechner/pi-coding-agent")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Call Edit() to ensure config is up-to-date before launch
|
|
||||||
models := []string{model}
|
|
||||||
if config, err := loadIntegration("pi"); err == nil && len(config.Models) > 0 {
|
|
||||||
models = config.Models
|
|
||||||
}
|
|
||||||
if err := p.Edit(models); err != nil {
|
|
||||||
return fmt.Errorf("setup failed: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
cmd := exec.Command("pi", args...)
|
|
||||||
cmd.Stdin = os.Stdin
|
|
||||||
cmd.Stdout = os.Stdout
|
|
||||||
cmd.Stderr = os.Stderr
|
|
||||||
return cmd.Run()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *Pi) Paths() []string {
|
|
||||||
home, err := os.UserHomeDir()
|
|
||||||
if err != nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
var paths []string
|
|
||||||
modelsPath := filepath.Join(home, ".pi", "agent", "models.json")
|
|
||||||
if _, err := os.Stat(modelsPath); err == nil {
|
|
||||||
paths = append(paths, modelsPath)
|
|
||||||
}
|
|
||||||
settingsPath := filepath.Join(home, ".pi", "agent", "settings.json")
|
|
||||||
if _, err := os.Stat(settingsPath); err == nil {
|
|
||||||
paths = append(paths, settingsPath)
|
|
||||||
}
|
|
||||||
return paths
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *Pi) Edit(models []string) error {
|
|
||||||
if len(models) == 0 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
home, err := os.UserHomeDir()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
configPath := filepath.Join(home, ".pi", "agent", "models.json")
|
|
||||||
if err := os.MkdirAll(filepath.Dir(configPath), 0o755); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
config := make(map[string]any)
|
|
||||||
if data, err := os.ReadFile(configPath); err == nil {
|
|
||||||
_ = json.Unmarshal(data, &config)
|
|
||||||
}
|
|
||||||
|
|
||||||
providers, ok := config["providers"].(map[string]any)
|
|
||||||
if !ok {
|
|
||||||
providers = make(map[string]any)
|
|
||||||
}
|
|
||||||
|
|
||||||
ollama, ok := providers["ollama"].(map[string]any)
|
|
||||||
if !ok {
|
|
||||||
ollama = map[string]any{
|
|
||||||
"baseUrl": envconfig.Host().String() + "/v1",
|
|
||||||
"api": "openai-completions",
|
|
||||||
"apiKey": "ollama",
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
existingModels, ok := ollama["models"].([]any)
|
|
||||||
if !ok {
|
|
||||||
existingModels = make([]any, 0)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Build set of selected models to track which need to be added
|
|
||||||
selectedSet := make(map[string]bool, len(models))
|
|
||||||
for _, m := range models {
|
|
||||||
selectedSet[m] = true
|
|
||||||
}
|
|
||||||
|
|
||||||
// Build new models list:
|
|
||||||
// 1. Keep user-managed models (no _launch marker) - untouched
|
|
||||||
// 2. Keep ollama-managed models (_launch marker) that are still selected
|
|
||||||
// 3. Add new ollama-managed models
|
|
||||||
var newModels []any
|
|
||||||
for _, m := range existingModels {
|
|
||||||
if modelObj, ok := m.(map[string]any); ok {
|
|
||||||
if id, ok := modelObj["id"].(string); ok {
|
|
||||||
// User-managed model (no _launch marker) - always preserve
|
|
||||||
if !isPiOllamaModel(modelObj) {
|
|
||||||
newModels = append(newModels, m)
|
|
||||||
} else if selectedSet[id] {
|
|
||||||
// Ollama-managed and still selected - keep it
|
|
||||||
newModels = append(newModels, m)
|
|
||||||
selectedSet[id] = false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Add newly selected models that weren't already in the list
|
|
||||||
client := api.NewClient(envconfig.Host(), http.DefaultClient)
|
|
||||||
ctx := context.Background()
|
|
||||||
for _, model := range models {
|
|
||||||
if selectedSet[model] {
|
|
||||||
newModels = append(newModels, createConfig(ctx, client, model))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
ollama["models"] = newModels
|
|
||||||
providers["ollama"] = ollama
|
|
||||||
config["providers"] = providers
|
|
||||||
|
|
||||||
configData, err := json.MarshalIndent(config, "", " ")
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if err := writeWithBackup(configPath, configData); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Update settings.json with default provider and model
|
|
||||||
settingsPath := filepath.Join(home, ".pi", "agent", "settings.json")
|
|
||||||
settings := make(map[string]any)
|
|
||||||
if data, err := os.ReadFile(settingsPath); err == nil {
|
|
||||||
_ = json.Unmarshal(data, &settings)
|
|
||||||
}
|
|
||||||
|
|
||||||
settings["defaultProvider"] = "ollama"
|
|
||||||
settings["defaultModel"] = models[0]
|
|
||||||
|
|
||||||
settingsData, err := json.MarshalIndent(settings, "", " ")
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return writeWithBackup(settingsPath, settingsData)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *Pi) Models() []string {
|
|
||||||
home, err := os.UserHomeDir()
|
|
||||||
if err != nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
configPath := filepath.Join(home, ".pi", "agent", "models.json")
|
|
||||||
config, err := readJSONFile(configPath)
|
|
||||||
if err != nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
providers, _ := config["providers"].(map[string]any)
|
|
||||||
ollama, _ := providers["ollama"].(map[string]any)
|
|
||||||
models, _ := ollama["models"].([]any)
|
|
||||||
|
|
||||||
var result []string
|
|
||||||
for _, m := range models {
|
|
||||||
if modelObj, ok := m.(map[string]any); ok {
|
|
||||||
if id, ok := modelObj["id"].(string); ok {
|
|
||||||
result = append(result, id)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
slices.Sort(result)
|
|
||||||
return result
|
|
||||||
}
|
|
||||||
|
|
||||||
// isPiOllamaModel reports whether a model config entry is managed by ollama launch
|
|
||||||
func isPiOllamaModel(cfg map[string]any) bool {
|
|
||||||
if v, ok := cfg["_launch"].(bool); ok && v {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
// createConfig builds Pi model config with capability detection
|
|
||||||
func createConfig(ctx context.Context, client *api.Client, modelID string) map[string]any {
|
|
||||||
cfg := map[string]any{
|
|
||||||
"id": modelID,
|
|
||||||
"_launch": true,
|
|
||||||
}
|
|
||||||
|
|
||||||
resp, err := client.Show(ctx, &api.ShowRequest{Model: modelID})
|
|
||||||
if err != nil {
|
|
||||||
return cfg
|
|
||||||
}
|
|
||||||
|
|
||||||
// Set input types based on vision capability
|
|
||||||
if slices.Contains(resp.Capabilities, model.CapabilityVision) {
|
|
||||||
cfg["input"] = []string{"text", "image"}
|
|
||||||
} else {
|
|
||||||
cfg["input"] = []string{"text"}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Set reasoning based on thinking capability
|
|
||||||
if slices.Contains(resp.Capabilities, model.CapabilityThinking) {
|
|
||||||
cfg["reasoning"] = true
|
|
||||||
}
|
|
||||||
|
|
||||||
// Extract context window from ModelInfo
|
|
||||||
for key, val := range resp.ModelInfo {
|
|
||||||
if strings.HasSuffix(key, ".context_length") {
|
|
||||||
if ctxLen, ok := val.(float64); ok && ctxLen > 0 {
|
|
||||||
cfg["contextWindow"] = int(ctxLen)
|
|
||||||
}
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return cfg
|
|
||||||
}
|
|
||||||
@@ -1,59 +0,0 @@
|
|||||||
package config
|
|
||||||
|
|
||||||
import (
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"os"
|
|
||||||
|
|
||||||
"golang.org/x/term"
|
|
||||||
)
|
|
||||||
|
|
||||||
// ANSI escape sequences for terminal formatting.
|
|
||||||
const (
|
|
||||||
ansiBold = "\033[1m"
|
|
||||||
ansiReset = "\033[0m"
|
|
||||||
ansiGray = "\033[37m"
|
|
||||||
ansiGreen = "\033[32m"
|
|
||||||
ansiYellow = "\033[33m"
|
|
||||||
)
|
|
||||||
|
|
||||||
// ErrCancelled is returned when the user cancels a selection.
|
|
||||||
var ErrCancelled = errors.New("cancelled")
|
|
||||||
|
|
||||||
// errCancelled is kept as an alias for backward compatibility within the package.
|
|
||||||
var errCancelled = ErrCancelled
|
|
||||||
|
|
||||||
// DefaultConfirmPrompt provides a TUI-based confirmation prompt.
|
|
||||||
// When set, confirmPrompt delegates to it instead of using raw terminal I/O.
|
|
||||||
var DefaultConfirmPrompt func(prompt string) (bool, error)
|
|
||||||
|
|
||||||
func confirmPrompt(prompt string) (bool, error) {
|
|
||||||
if DefaultConfirmPrompt != nil {
|
|
||||||
return DefaultConfirmPrompt(prompt)
|
|
||||||
}
|
|
||||||
|
|
||||||
fd := int(os.Stdin.Fd())
|
|
||||||
oldState, err := term.MakeRaw(fd)
|
|
||||||
if err != nil {
|
|
||||||
return false, err
|
|
||||||
}
|
|
||||||
defer term.Restore(fd, oldState)
|
|
||||||
|
|
||||||
fmt.Fprintf(os.Stderr, "%s (\033[1my\033[0m/n) ", prompt)
|
|
||||||
|
|
||||||
buf := make([]byte, 1)
|
|
||||||
for {
|
|
||||||
if _, err := os.Stdin.Read(buf); err != nil {
|
|
||||||
return false, err
|
|
||||||
}
|
|
||||||
|
|
||||||
switch buf[0] {
|
|
||||||
case 'Y', 'y', 13:
|
|
||||||
fmt.Fprintf(os.Stderr, "yes\r\n")
|
|
||||||
return true, nil
|
|
||||||
case 'N', 'n', 27, 3:
|
|
||||||
fmt.Fprintf(os.Stderr, "no\r\n")
|
|
||||||
return false, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,19 +0,0 @@
|
|||||||
package config
|
|
||||||
|
|
||||||
import (
|
|
||||||
"testing"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestErrCancelled(t *testing.T) {
|
|
||||||
t.Run("NotNil", func(t *testing.T) {
|
|
||||||
if errCancelled == nil {
|
|
||||||
t.Error("errCancelled should not be nil")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("Message", func(t *testing.T) {
|
|
||||||
if errCancelled.Error() != "cancelled" {
|
|
||||||
t.Errorf("expected 'cancelled', got %q", errCancelled.Error())
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
@@ -17,6 +17,7 @@ import (
|
|||||||
|
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
"github.com/ollama/ollama/envconfig"
|
"github.com/ollama/ollama/envconfig"
|
||||||
|
"github.com/ollama/ollama/internal/modelref"
|
||||||
"github.com/ollama/ollama/readline"
|
"github.com/ollama/ollama/readline"
|
||||||
"github.com/ollama/ollama/types/errtypes"
|
"github.com/ollama/ollama/types/errtypes"
|
||||||
"github.com/ollama/ollama/types/model"
|
"github.com/ollama/ollama/types/model"
|
||||||
@@ -46,7 +47,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
|
|||||||
fmt.Fprintln(os.Stderr, "Use \"\"\" to begin a multi-line message.")
|
fmt.Fprintln(os.Stderr, "Use \"\"\" to begin a multi-line message.")
|
||||||
|
|
||||||
if opts.MultiModal {
|
if opts.MultiModal {
|
||||||
fmt.Fprintf(os.Stderr, "Use %s to include .jpg, .png, or .webp images.\n", filepath.FromSlash("/path/to/file"))
|
fmt.Fprintf(os.Stderr, "Use %s to include .jpg, .png, .webp images, or .wav audio files.\n", filepath.FromSlash("/path/to/file"))
|
||||||
}
|
}
|
||||||
|
|
||||||
fmt.Fprintln(os.Stderr, "")
|
fmt.Fprintln(os.Stderr, "")
|
||||||
@@ -213,10 +214,17 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
|
|||||||
}
|
}
|
||||||
origOpts := opts.Copy()
|
origOpts := opts.Copy()
|
||||||
|
|
||||||
|
client, err := api.ClientFromEnvironment()
|
||||||
|
if err != nil {
|
||||||
|
fmt.Println("error: couldn't connect to ollama server")
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
opts.Model = args[1]
|
opts.Model = args[1]
|
||||||
opts.Messages = []api.Message{}
|
opts.Messages = []api.Message{}
|
||||||
|
opts.LoadedMessages = nil
|
||||||
fmt.Printf("Loading model '%s'\n", opts.Model)
|
fmt.Printf("Loading model '%s'\n", opts.Model)
|
||||||
opts.Think, err = inferThinkingOption(nil, &opts, thinkExplicitlySet)
|
info, err := client.Show(cmd.Context(), &api.ShowRequest{Model: opts.Model})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if strings.Contains(err.Error(), "not found") {
|
if strings.Contains(err.Error(), "not found") {
|
||||||
fmt.Printf("Couldn't find model '%s'\n", opts.Model)
|
fmt.Printf("Couldn't find model '%s'\n", opts.Model)
|
||||||
@@ -225,6 +233,11 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
|
|||||||
}
|
}
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
applyShowResponseToRunOptions(&opts, info)
|
||||||
|
opts.Think, err = inferThinkingOption(&info.Capabilities, &opts, thinkExplicitlySet)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
if err := loadOrUnloadModel(cmd, &opts); err != nil {
|
if err := loadOrUnloadModel(cmd, &opts); err != nil {
|
||||||
if strings.Contains(err.Error(), "not found") {
|
if strings.Contains(err.Error(), "not found") {
|
||||||
fmt.Printf("Couldn't find model '%s'\n", opts.Model)
|
fmt.Printf("Couldn't find model '%s'\n", opts.Model)
|
||||||
@@ -540,6 +553,13 @@ func NewCreateRequest(name string, opts runOptions) *api.CreateRequest {
|
|||||||
parentModel = ""
|
parentModel = ""
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Preserve explicit cloud intent for sessions started with `:cloud`.
|
||||||
|
// Cloud model metadata can return a source-less parent_model (for example
|
||||||
|
// "qwen3.5"), which would otherwise make `/save` create a local derivative.
|
||||||
|
if modelref.HasExplicitCloudSource(opts.Model) && !modelref.HasExplicitCloudSource(parentModel) {
|
||||||
|
parentModel = ""
|
||||||
|
}
|
||||||
|
|
||||||
req := &api.CreateRequest{
|
req := &api.CreateRequest{
|
||||||
Model: name,
|
Model: name,
|
||||||
From: cmp.Or(parentModel, opts.Model),
|
From: cmp.Or(parentModel, opts.Model),
|
||||||
@@ -553,8 +573,10 @@ func NewCreateRequest(name string, opts runOptions) *api.CreateRequest {
|
|||||||
req.Parameters = opts.Options
|
req.Parameters = opts.Options
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(opts.Messages) > 0 {
|
messages := slices.Clone(opts.LoadedMessages)
|
||||||
req.Messages = opts.Messages
|
messages = append(messages, opts.Messages...)
|
||||||
|
if len(messages) > 0 {
|
||||||
|
req.Messages = messages
|
||||||
}
|
}
|
||||||
|
|
||||||
return req
|
return req
|
||||||
@@ -584,7 +606,7 @@ func extractFileNames(input string) []string {
|
|||||||
// Regex to match file paths starting with optional drive letter, / ./ \ or .\ and include escaped or unescaped spaces (\ or %20)
|
// Regex to match file paths starting with optional drive letter, / ./ \ or .\ and include escaped or unescaped spaces (\ or %20)
|
||||||
// and followed by more characters and a file extension
|
// and followed by more characters and a file extension
|
||||||
// This will capture non filename strings, but we'll check for file existence to remove mismatches
|
// This will capture non filename strings, but we'll check for file existence to remove mismatches
|
||||||
regexPattern := `(?:[a-zA-Z]:)?(?:\./|/|\\)[\S\\ ]+?\.(?i:jpg|jpeg|png|webp)\b`
|
regexPattern := `(?:[a-zA-Z]:)?(?:\./|/|\\)[\S\\ ]+?\.(?i:jpg|jpeg|png|webp|wav)\b`
|
||||||
re := regexp.MustCompile(regexPattern)
|
re := regexp.MustCompile(regexPattern)
|
||||||
|
|
||||||
return re.FindAllString(input, -1)
|
return re.FindAllString(input, -1)
|
||||||
@@ -600,10 +622,16 @@ func extractFileData(input string) (string, []api.ImageData, error) {
|
|||||||
if errors.Is(err, os.ErrNotExist) {
|
if errors.Is(err, os.ErrNotExist) {
|
||||||
continue
|
continue
|
||||||
} else if err != nil {
|
} else if err != nil {
|
||||||
fmt.Fprintf(os.Stderr, "Couldn't process image: %q\n", err)
|
fmt.Fprintf(os.Stderr, "Couldn't process file: %q\n", err)
|
||||||
return "", imgs, err
|
return "", imgs, err
|
||||||
}
|
}
|
||||||
fmt.Fprintf(os.Stderr, "Added image '%s'\n", nfp)
|
ext := strings.ToLower(filepath.Ext(nfp))
|
||||||
|
switch ext {
|
||||||
|
case ".wav":
|
||||||
|
fmt.Fprintf(os.Stderr, "Added audio '%s'\n", nfp)
|
||||||
|
default:
|
||||||
|
fmt.Fprintf(os.Stderr, "Added image '%s'\n", nfp)
|
||||||
|
}
|
||||||
input = strings.ReplaceAll(input, "'"+nfp+"'", "")
|
input = strings.ReplaceAll(input, "'"+nfp+"'", "")
|
||||||
input = strings.ReplaceAll(input, "'"+fp+"'", "")
|
input = strings.ReplaceAll(input, "'"+fp+"'", "")
|
||||||
input = strings.ReplaceAll(input, fp, "")
|
input = strings.ReplaceAll(input, fp, "")
|
||||||
@@ -677,9 +705,9 @@ func getImageData(filePath string) ([]byte, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
contentType := http.DetectContentType(buf)
|
contentType := http.DetectContentType(buf)
|
||||||
allowedTypes := []string{"image/jpeg", "image/jpg", "image/png", "image/webp"}
|
allowedTypes := []string{"image/jpeg", "image/jpg", "image/png", "image/webp", "audio/wave"}
|
||||||
if !slices.Contains(allowedTypes, contentType) {
|
if !slices.Contains(allowedTypes, contentType) {
|
||||||
return nil, fmt.Errorf("invalid image type: %s", contentType)
|
return nil, fmt.Errorf("invalid file type: %s", contentType)
|
||||||
}
|
}
|
||||||
|
|
||||||
info, err := file.Stat()
|
info, err := file.Stat()
|
||||||
@@ -687,8 +715,7 @@ func getImageData(filePath string) ([]byte, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if the file size exceeds 100MB
|
var maxSize int64 = 100 * 1024 * 1024 // 100MB
|
||||||
var maxSize int64 = 100 * 1024 * 1024 // 100MB in bytes
|
|
||||||
if info.Size() > maxSize {
|
if info.Size() > maxSize {
|
||||||
return nil, errors.New("file size exceeds maximum limit (100MB)")
|
return nil, errors.New("file size exceeds maximum limit (100MB)")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -84,3 +84,33 @@ func TestExtractFileDataRemovesQuotedFilepath(t *testing.T) {
|
|||||||
assert.Len(t, imgs, 1)
|
assert.Len(t, imgs, 1)
|
||||||
assert.Equal(t, cleaned, "before after")
|
assert.Equal(t, cleaned, "before after")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestExtractFileDataWAV(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
fp := filepath.Join(dir, "sample.wav")
|
||||||
|
data := make([]byte, 600)
|
||||||
|
copy(data[:44], []byte{
|
||||||
|
'R', 'I', 'F', 'F',
|
||||||
|
0x58, 0x02, 0x00, 0x00, // file size - 8
|
||||||
|
'W', 'A', 'V', 'E',
|
||||||
|
'f', 'm', 't', ' ',
|
||||||
|
0x10, 0x00, 0x00, 0x00, // fmt chunk size
|
||||||
|
0x01, 0x00, // PCM
|
||||||
|
0x01, 0x00, // mono
|
||||||
|
0x80, 0x3e, 0x00, 0x00, // 16000 Hz
|
||||||
|
0x00, 0x7d, 0x00, 0x00, // byte rate
|
||||||
|
0x02, 0x00, // block align
|
||||||
|
0x10, 0x00, // 16-bit
|
||||||
|
'd', 'a', 't', 'a',
|
||||||
|
0x34, 0x02, 0x00, 0x00, // data size
|
||||||
|
})
|
||||||
|
if err := os.WriteFile(fp, data, 0o600); err != nil {
|
||||||
|
t.Fatalf("failed to write test audio: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
input := "before " + fp + " after"
|
||||||
|
cleaned, imgs, err := extractFileData(input)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Len(t, imgs, 1)
|
||||||
|
assert.Equal(t, "before after", cleaned)
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,4 +1,6 @@
|
|||||||
package config
|
// Package fileutil provides small shared helpers for reading JSON files
|
||||||
|
// and writing config files with backup-on-overwrite semantics.
|
||||||
|
package fileutil
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
@@ -9,7 +11,8 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
func readJSONFile(path string) (map[string]any, error) {
|
// ReadJSON reads a JSON object file into a generic map.
|
||||||
|
func ReadJSON(path string) (map[string]any, error) {
|
||||||
data, err := os.ReadFile(path)
|
data, err := os.ReadFile(path)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -33,12 +36,13 @@ func copyFile(src, dst string) error {
|
|||||||
return os.WriteFile(dst, data, info.Mode().Perm())
|
return os.WriteFile(dst, data, info.Mode().Perm())
|
||||||
}
|
}
|
||||||
|
|
||||||
func backupDir() string {
|
// BackupDir returns the shared backup directory used before overwriting files.
|
||||||
|
func BackupDir() string {
|
||||||
return filepath.Join(os.TempDir(), "ollama-backups")
|
return filepath.Join(os.TempDir(), "ollama-backups")
|
||||||
}
|
}
|
||||||
|
|
||||||
func backupToTmp(srcPath string) (string, error) {
|
func backupToTmp(srcPath string) (string, error) {
|
||||||
dir := backupDir()
|
dir := BackupDir()
|
||||||
if err := os.MkdirAll(dir, 0o755); err != nil {
|
if err := os.MkdirAll(dir, 0o755); err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
@@ -50,8 +54,8 @@ func backupToTmp(srcPath string) (string, error) {
|
|||||||
return backupPath, nil
|
return backupPath, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// writeWithBackup writes data to path via temp file + rename, backing up any existing file first
|
// WriteWithBackup writes data to path via temp file + rename, backing up any existing file first.
|
||||||
func writeWithBackup(path string, data []byte) error {
|
func WriteWithBackup(path string, data []byte) error {
|
||||||
var backupPath string
|
var backupPath string
|
||||||
// backup must be created before any writes to the target file
|
// backup must be created before any writes to the target file
|
||||||
if existingContent, err := os.ReadFile(path); err == nil {
|
if existingContent, err := os.ReadFile(path); err == nil {
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package config
|
package fileutil
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
@@ -9,6 +9,21 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func TestMain(m *testing.M) {
|
||||||
|
tmpRoot, err := os.MkdirTemp("", "fileutil-test-*")
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := os.Setenv("TMPDIR", tmpRoot); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
code := m.Run()
|
||||||
|
_ = os.RemoveAll(tmpRoot)
|
||||||
|
os.Exit(code)
|
||||||
|
}
|
||||||
|
|
||||||
func mustMarshal(t *testing.T, v any) []byte {
|
func mustMarshal(t *testing.T, v any) []byte {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
data, err := json.MarshalIndent(v, "", " ")
|
data, err := json.MarshalIndent(v, "", " ")
|
||||||
@@ -18,14 +33,19 @@ func mustMarshal(t *testing.T, v any) []byte {
|
|||||||
return data
|
return data
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func isolatedTempDir(t *testing.T) string {
|
||||||
|
t.Helper()
|
||||||
|
return t.TempDir()
|
||||||
|
}
|
||||||
|
|
||||||
func TestWriteWithBackup(t *testing.T) {
|
func TestWriteWithBackup(t *testing.T) {
|
||||||
tmpDir := t.TempDir()
|
tmpDir := isolatedTempDir(t)
|
||||||
|
|
||||||
t.Run("creates file", func(t *testing.T) {
|
t.Run("creates file", func(t *testing.T) {
|
||||||
path := filepath.Join(tmpDir, "new.json")
|
path := filepath.Join(tmpDir, "new.json")
|
||||||
data := mustMarshal(t, map[string]string{"key": "value"})
|
data := mustMarshal(t, map[string]string{"key": "value"})
|
||||||
|
|
||||||
if err := writeWithBackup(path, data); err != nil {
|
if err := WriteWithBackup(path, data); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -43,17 +63,17 @@ func TestWriteWithBackup(t *testing.T) {
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("creates backup in /tmp/ollama-backups", func(t *testing.T) {
|
t.Run("creates backup in the temp backup directory", func(t *testing.T) {
|
||||||
path := filepath.Join(tmpDir, "backup.json")
|
path := filepath.Join(tmpDir, "backup.json")
|
||||||
|
|
||||||
os.WriteFile(path, []byte(`{"original": true}`), 0o644)
|
os.WriteFile(path, []byte(`{"original": true}`), 0o644)
|
||||||
|
|
||||||
data := mustMarshal(t, map[string]bool{"updated": true})
|
data := mustMarshal(t, map[string]bool{"updated": true})
|
||||||
if err := writeWithBackup(path, data); err != nil {
|
if err := WriteWithBackup(path, data); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
entries, err := os.ReadDir(backupDir())
|
entries, err := os.ReadDir(BackupDir())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal("backup directory not created")
|
t.Fatal("backup directory not created")
|
||||||
}
|
}
|
||||||
@@ -63,7 +83,7 @@ func TestWriteWithBackup(t *testing.T) {
|
|||||||
if filepath.Ext(entry.Name()) != ".json" {
|
if filepath.Ext(entry.Name()) != ".json" {
|
||||||
name := entry.Name()
|
name := entry.Name()
|
||||||
if len(name) > len("backup.json.") && name[:len("backup.json.")] == "backup.json." {
|
if len(name) > len("backup.json.") && name[:len("backup.json.")] == "backup.json." {
|
||||||
backupPath := filepath.Join(backupDir(), name)
|
backupPath := filepath.Join(BackupDir(), name)
|
||||||
backup, err := os.ReadFile(backupPath)
|
backup, err := os.ReadFile(backupPath)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
var backupData map[string]bool
|
var backupData map[string]bool
|
||||||
@@ -79,7 +99,7 @@ func TestWriteWithBackup(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if !foundBackup {
|
if !foundBackup {
|
||||||
t.Error("backup file not created in /tmp/ollama-backups")
|
t.Error("backup file not created in backup directory")
|
||||||
}
|
}
|
||||||
|
|
||||||
current, _ := os.ReadFile(path)
|
current, _ := os.ReadFile(path)
|
||||||
@@ -94,11 +114,11 @@ func TestWriteWithBackup(t *testing.T) {
|
|||||||
path := filepath.Join(tmpDir, "nobak.json")
|
path := filepath.Join(tmpDir, "nobak.json")
|
||||||
|
|
||||||
data := mustMarshal(t, map[string]string{"new": "file"})
|
data := mustMarshal(t, map[string]string{"new": "file"})
|
||||||
if err := writeWithBackup(path, data); err != nil {
|
if err := WriteWithBackup(path, data); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
entries, _ := os.ReadDir(backupDir())
|
entries, _ := os.ReadDir(BackupDir())
|
||||||
for _, entry := range entries {
|
for _, entry := range entries {
|
||||||
if len(entry.Name()) > len("nobak.json.") && entry.Name()[:len("nobak.json.")] == "nobak.json." {
|
if len(entry.Name()) > len("nobak.json.") && entry.Name()[:len("nobak.json.")] == "nobak.json." {
|
||||||
t.Error("backup should not exist for new file")
|
t.Error("backup should not exist for new file")
|
||||||
@@ -111,11 +131,11 @@ func TestWriteWithBackup(t *testing.T) {
|
|||||||
|
|
||||||
data := mustMarshal(t, map[string]string{"key": "value"})
|
data := mustMarshal(t, map[string]string{"key": "value"})
|
||||||
|
|
||||||
if err := writeWithBackup(path, data); err != nil {
|
if err := WriteWithBackup(path, data); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
entries1, _ := os.ReadDir(backupDir())
|
entries1, _ := os.ReadDir(BackupDir())
|
||||||
countBefore := 0
|
countBefore := 0
|
||||||
for _, e := range entries1 {
|
for _, e := range entries1 {
|
||||||
if len(e.Name()) > len("unchanged.json.") && e.Name()[:len("unchanged.json.")] == "unchanged.json." {
|
if len(e.Name()) > len("unchanged.json.") && e.Name()[:len("unchanged.json.")] == "unchanged.json." {
|
||||||
@@ -123,11 +143,11 @@ func TestWriteWithBackup(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := writeWithBackup(path, data); err != nil {
|
if err := WriteWithBackup(path, data); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
entries2, _ := os.ReadDir(backupDir())
|
entries2, _ := os.ReadDir(BackupDir())
|
||||||
countAfter := 0
|
countAfter := 0
|
||||||
for _, e := range entries2 {
|
for _, e := range entries2 {
|
||||||
if len(e.Name()) > len("unchanged.json.") && e.Name()[:len("unchanged.json.")] == "unchanged.json." {
|
if len(e.Name()) > len("unchanged.json.") && e.Name()[:len("unchanged.json.")] == "unchanged.json." {
|
||||||
@@ -145,11 +165,11 @@ func TestWriteWithBackup(t *testing.T) {
|
|||||||
|
|
||||||
os.WriteFile(path, []byte(`{"v": 1}`), 0o644)
|
os.WriteFile(path, []byte(`{"v": 1}`), 0o644)
|
||||||
data := mustMarshal(t, map[string]int{"v": 2})
|
data := mustMarshal(t, map[string]int{"v": 2})
|
||||||
if err := writeWithBackup(path, data); err != nil {
|
if err := WriteWithBackup(path, data); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
entries, _ := os.ReadDir(backupDir())
|
entries, _ := os.ReadDir(BackupDir())
|
||||||
var found bool
|
var found bool
|
||||||
for _, entry := range entries {
|
for _, entry := range entries {
|
||||||
name := entry.Name()
|
name := entry.Name()
|
||||||
@@ -161,7 +181,7 @@ func TestWriteWithBackup(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
found = true
|
found = true
|
||||||
os.Remove(filepath.Join(backupDir(), name))
|
os.Remove(filepath.Join(BackupDir(), name))
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -180,7 +200,7 @@ func TestWriteWithBackup_FailsIfBackupFails(t *testing.T) {
|
|||||||
t.Skip("permission tests unreliable on Windows")
|
t.Skip("permission tests unreliable on Windows")
|
||||||
}
|
}
|
||||||
|
|
||||||
tmpDir := t.TempDir()
|
tmpDir := isolatedTempDir(t)
|
||||||
path := filepath.Join(tmpDir, "config.json")
|
path := filepath.Join(tmpDir, "config.json")
|
||||||
|
|
||||||
// Create original file
|
// Create original file
|
||||||
@@ -188,13 +208,13 @@ func TestWriteWithBackup_FailsIfBackupFails(t *testing.T) {
|
|||||||
os.WriteFile(path, originalContent, 0o644)
|
os.WriteFile(path, originalContent, 0o644)
|
||||||
|
|
||||||
// Make backup directory read-only to force backup failure
|
// Make backup directory read-only to force backup failure
|
||||||
backupDir := backupDir()
|
backupDir := BackupDir()
|
||||||
os.MkdirAll(backupDir, 0o755)
|
os.MkdirAll(backupDir, 0o755)
|
||||||
os.Chmod(backupDir, 0o444) // Read-only
|
os.Chmod(backupDir, 0o444) // Read-only
|
||||||
defer os.Chmod(backupDir, 0o755)
|
defer os.Chmod(backupDir, 0o755)
|
||||||
|
|
||||||
newContent := []byte(`{"updated": true}`)
|
newContent := []byte(`{"updated": true}`)
|
||||||
err := writeWithBackup(path, newContent)
|
err := WriteWithBackup(path, newContent)
|
||||||
|
|
||||||
// Should fail because backup couldn't be created
|
// Should fail because backup couldn't be created
|
||||||
if err == nil {
|
if err == nil {
|
||||||
@@ -215,7 +235,7 @@ func TestWriteWithBackup_PermissionDenied(t *testing.T) {
|
|||||||
t.Skip("permission tests unreliable on Windows")
|
t.Skip("permission tests unreliable on Windows")
|
||||||
}
|
}
|
||||||
|
|
||||||
tmpDir := t.TempDir()
|
tmpDir := isolatedTempDir(t)
|
||||||
|
|
||||||
// Create a read-only directory
|
// Create a read-only directory
|
||||||
readOnlyDir := filepath.Join(tmpDir, "readonly")
|
readOnlyDir := filepath.Join(tmpDir, "readonly")
|
||||||
@@ -224,7 +244,7 @@ func TestWriteWithBackup_PermissionDenied(t *testing.T) {
|
|||||||
defer os.Chmod(readOnlyDir, 0o755)
|
defer os.Chmod(readOnlyDir, 0o755)
|
||||||
|
|
||||||
path := filepath.Join(readOnlyDir, "config.json")
|
path := filepath.Join(readOnlyDir, "config.json")
|
||||||
err := writeWithBackup(path, []byte(`{"test": true}`))
|
err := WriteWithBackup(path, []byte(`{"test": true}`))
|
||||||
|
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Error("expected permission error, got nil")
|
t.Error("expected permission error, got nil")
|
||||||
@@ -234,10 +254,10 @@ func TestWriteWithBackup_PermissionDenied(t *testing.T) {
|
|||||||
// TestWriteWithBackup_DirectoryDoesNotExist verifies behavior when target directory doesn't exist.
|
// TestWriteWithBackup_DirectoryDoesNotExist verifies behavior when target directory doesn't exist.
|
||||||
// writeWithBackup doesn't create directories - caller is responsible.
|
// writeWithBackup doesn't create directories - caller is responsible.
|
||||||
func TestWriteWithBackup_DirectoryDoesNotExist(t *testing.T) {
|
func TestWriteWithBackup_DirectoryDoesNotExist(t *testing.T) {
|
||||||
tmpDir := t.TempDir()
|
tmpDir := isolatedTempDir(t)
|
||||||
path := filepath.Join(tmpDir, "nonexistent", "subdir", "config.json")
|
path := filepath.Join(tmpDir, "nonexistent", "subdir", "config.json")
|
||||||
|
|
||||||
err := writeWithBackup(path, []byte(`{"test": true}`))
|
err := WriteWithBackup(path, []byte(`{"test": true}`))
|
||||||
|
|
||||||
// Should fail because directory doesn't exist
|
// Should fail because directory doesn't exist
|
||||||
if err == nil {
|
if err == nil {
|
||||||
@@ -252,7 +272,7 @@ func TestWriteWithBackup_SymlinkTarget(t *testing.T) {
|
|||||||
t.Skip("symlink tests may require admin on Windows")
|
t.Skip("symlink tests may require admin on Windows")
|
||||||
}
|
}
|
||||||
|
|
||||||
tmpDir := t.TempDir()
|
tmpDir := isolatedTempDir(t)
|
||||||
realFile := filepath.Join(tmpDir, "real.json")
|
realFile := filepath.Join(tmpDir, "real.json")
|
||||||
symlink := filepath.Join(tmpDir, "link.json")
|
symlink := filepath.Join(tmpDir, "link.json")
|
||||||
|
|
||||||
@@ -261,7 +281,7 @@ func TestWriteWithBackup_SymlinkTarget(t *testing.T) {
|
|||||||
os.Symlink(realFile, symlink)
|
os.Symlink(realFile, symlink)
|
||||||
|
|
||||||
// Write through symlink
|
// Write through symlink
|
||||||
err := writeWithBackup(symlink, []byte(`{"v": 2}`))
|
err := WriteWithBackup(symlink, []byte(`{"v": 2}`))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("writeWithBackup through symlink failed: %v", err)
|
t.Fatalf("writeWithBackup through symlink failed: %v", err)
|
||||||
}
|
}
|
||||||
@@ -276,7 +296,7 @@ func TestWriteWithBackup_SymlinkTarget(t *testing.T) {
|
|||||||
// TestBackupToTmp_SpecialCharsInFilename verifies backup works with special characters.
|
// TestBackupToTmp_SpecialCharsInFilename verifies backup works with special characters.
|
||||||
// User may have config files with unusual names.
|
// User may have config files with unusual names.
|
||||||
func TestBackupToTmp_SpecialCharsInFilename(t *testing.T) {
|
func TestBackupToTmp_SpecialCharsInFilename(t *testing.T) {
|
||||||
tmpDir := t.TempDir()
|
tmpDir := isolatedTempDir(t)
|
||||||
|
|
||||||
// File with spaces and special chars
|
// File with spaces and special chars
|
||||||
path := filepath.Join(tmpDir, "my config (backup).json")
|
path := filepath.Join(tmpDir, "my config (backup).json")
|
||||||
@@ -305,7 +325,7 @@ func TestCopyFile_PreservesPermissions(t *testing.T) {
|
|||||||
t.Skip("permission preservation tests unreliable on Windows")
|
t.Skip("permission preservation tests unreliable on Windows")
|
||||||
}
|
}
|
||||||
|
|
||||||
tmpDir := t.TempDir()
|
tmpDir := isolatedTempDir(t)
|
||||||
src := filepath.Join(tmpDir, "src.json")
|
src := filepath.Join(tmpDir, "src.json")
|
||||||
dst := filepath.Join(tmpDir, "dst.json")
|
dst := filepath.Join(tmpDir, "dst.json")
|
||||||
|
|
||||||
@@ -327,7 +347,7 @@ func TestCopyFile_PreservesPermissions(t *testing.T) {
|
|||||||
|
|
||||||
// TestCopyFile_SourceNotFound verifies clear error when source doesn't exist.
|
// TestCopyFile_SourceNotFound verifies clear error when source doesn't exist.
|
||||||
func TestCopyFile_SourceNotFound(t *testing.T) {
|
func TestCopyFile_SourceNotFound(t *testing.T) {
|
||||||
tmpDir := t.TempDir()
|
tmpDir := isolatedTempDir(t)
|
||||||
src := filepath.Join(tmpDir, "nonexistent.json")
|
src := filepath.Join(tmpDir, "nonexistent.json")
|
||||||
dst := filepath.Join(tmpDir, "dst.json")
|
dst := filepath.Join(tmpDir, "dst.json")
|
||||||
|
|
||||||
@@ -339,11 +359,11 @@ func TestCopyFile_SourceNotFound(t *testing.T) {
|
|||||||
|
|
||||||
// TestWriteWithBackup_TargetIsDirectory verifies error when path points to a directory.
|
// TestWriteWithBackup_TargetIsDirectory verifies error when path points to a directory.
|
||||||
func TestWriteWithBackup_TargetIsDirectory(t *testing.T) {
|
func TestWriteWithBackup_TargetIsDirectory(t *testing.T) {
|
||||||
tmpDir := t.TempDir()
|
tmpDir := isolatedTempDir(t)
|
||||||
dirPath := filepath.Join(tmpDir, "actualdir")
|
dirPath := filepath.Join(tmpDir, "actualdir")
|
||||||
os.MkdirAll(dirPath, 0o755)
|
os.MkdirAll(dirPath, 0o755)
|
||||||
|
|
||||||
err := writeWithBackup(dirPath, []byte(`{"test": true}`))
|
err := WriteWithBackup(dirPath, []byte(`{"test": true}`))
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Error("expected error when target is a directory, got nil")
|
t.Error("expected error when target is a directory, got nil")
|
||||||
}
|
}
|
||||||
@@ -351,10 +371,10 @@ func TestWriteWithBackup_TargetIsDirectory(t *testing.T) {
|
|||||||
|
|
||||||
// TestWriteWithBackup_EmptyData verifies writing zero bytes works correctly.
|
// TestWriteWithBackup_EmptyData verifies writing zero bytes works correctly.
|
||||||
func TestWriteWithBackup_EmptyData(t *testing.T) {
|
func TestWriteWithBackup_EmptyData(t *testing.T) {
|
||||||
tmpDir := t.TempDir()
|
tmpDir := isolatedTempDir(t)
|
||||||
path := filepath.Join(tmpDir, "empty.json")
|
path := filepath.Join(tmpDir, "empty.json")
|
||||||
|
|
||||||
err := writeWithBackup(path, []byte{})
|
err := WriteWithBackup(path, []byte{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("writeWithBackup with empty data failed: %v", err)
|
t.Fatalf("writeWithBackup with empty data failed: %v", err)
|
||||||
}
|
}
|
||||||
@@ -375,7 +395,7 @@ func TestWriteWithBackup_FileUnreadableButDirWritable(t *testing.T) {
|
|||||||
t.Skip("permission tests unreliable on Windows")
|
t.Skip("permission tests unreliable on Windows")
|
||||||
}
|
}
|
||||||
|
|
||||||
tmpDir := t.TempDir()
|
tmpDir := isolatedTempDir(t)
|
||||||
path := filepath.Join(tmpDir, "unreadable.json")
|
path := filepath.Join(tmpDir, "unreadable.json")
|
||||||
|
|
||||||
// Create file and make it unreadable
|
// Create file and make it unreadable
|
||||||
@@ -384,7 +404,7 @@ func TestWriteWithBackup_FileUnreadableButDirWritable(t *testing.T) {
|
|||||||
defer os.Chmod(path, 0o644)
|
defer os.Chmod(path, 0o644)
|
||||||
|
|
||||||
// Should fail because we can't read the file to compare/backup
|
// Should fail because we can't read the file to compare/backup
|
||||||
err := writeWithBackup(path, []byte(`{"updated": true}`))
|
err := WriteWithBackup(path, []byte(`{"updated": true}`))
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Error("expected error when file is unreadable, got nil")
|
t.Error("expected error when file is unreadable, got nil")
|
||||||
}
|
}
|
||||||
@@ -393,7 +413,7 @@ func TestWriteWithBackup_FileUnreadableButDirWritable(t *testing.T) {
|
|||||||
// TestWriteWithBackup_RapidSuccessiveWrites verifies backup works with multiple writes
|
// TestWriteWithBackup_RapidSuccessiveWrites verifies backup works with multiple writes
|
||||||
// within the same second (timestamp collision scenario).
|
// within the same second (timestamp collision scenario).
|
||||||
func TestWriteWithBackup_RapidSuccessiveWrites(t *testing.T) {
|
func TestWriteWithBackup_RapidSuccessiveWrites(t *testing.T) {
|
||||||
tmpDir := t.TempDir()
|
tmpDir := isolatedTempDir(t)
|
||||||
path := filepath.Join(tmpDir, "rapid.json")
|
path := filepath.Join(tmpDir, "rapid.json")
|
||||||
|
|
||||||
// Create initial file
|
// Create initial file
|
||||||
@@ -402,7 +422,7 @@ func TestWriteWithBackup_RapidSuccessiveWrites(t *testing.T) {
|
|||||||
// Rapid successive writes
|
// Rapid successive writes
|
||||||
for i := 1; i <= 3; i++ {
|
for i := 1; i <= 3; i++ {
|
||||||
data := []byte(fmt.Sprintf(`{"v": %d}`, i))
|
data := []byte(fmt.Sprintf(`{"v": %d}`, i))
|
||||||
if err := writeWithBackup(path, data); err != nil {
|
if err := WriteWithBackup(path, data); err != nil {
|
||||||
t.Fatalf("write %d failed: %v", i, err)
|
t.Fatalf("write %d failed: %v", i, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -414,7 +434,7 @@ func TestWriteWithBackup_RapidSuccessiveWrites(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Verify at least one backup exists
|
// Verify at least one backup exists
|
||||||
entries, _ := os.ReadDir(backupDir())
|
entries, _ := os.ReadDir(BackupDir())
|
||||||
var backupCount int
|
var backupCount int
|
||||||
for _, e := range entries {
|
for _, e := range entries {
|
||||||
if len(e.Name()) > len("rapid.json.") && e.Name()[:len("rapid.json.")] == "rapid.json." {
|
if len(e.Name()) > len("rapid.json.") && e.Name()[:len("rapid.json.")] == "rapid.json." {
|
||||||
@@ -432,8 +452,9 @@ func TestWriteWithBackup_BackupDirIsFile(t *testing.T) {
|
|||||||
t.Skip("test modifies system temp directory")
|
t.Skip("test modifies system temp directory")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
tmpDir := isolatedTempDir(t)
|
||||||
// Create a file at the backup directory path
|
// Create a file at the backup directory path
|
||||||
backupPath := backupDir()
|
backupPath := BackupDir()
|
||||||
// Clean up any existing directory first
|
// Clean up any existing directory first
|
||||||
os.RemoveAll(backupPath)
|
os.RemoveAll(backupPath)
|
||||||
// Create a file instead of directory
|
// Create a file instead of directory
|
||||||
@@ -443,11 +464,10 @@ func TestWriteWithBackup_BackupDirIsFile(t *testing.T) {
|
|||||||
os.MkdirAll(backupPath, 0o755)
|
os.MkdirAll(backupPath, 0o755)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
tmpDir := t.TempDir()
|
|
||||||
path := filepath.Join(tmpDir, "test.json")
|
path := filepath.Join(tmpDir, "test.json")
|
||||||
os.WriteFile(path, []byte(`{"original": true}`), 0o644)
|
os.WriteFile(path, []byte(`{"original": true}`), 0o644)
|
||||||
|
|
||||||
err := writeWithBackup(path, []byte(`{"updated": true}`))
|
err := WriteWithBackup(path, []byte(`{"updated": true}`))
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Error("expected error when backup dir is a file, got nil")
|
t.Error("expected error when backup dir is a file, got nil")
|
||||||
}
|
}
|
||||||
@@ -459,7 +479,7 @@ func TestWriteWithBackup_NoOrphanTempFiles(t *testing.T) {
|
|||||||
t.Skip("permission tests unreliable on Windows")
|
t.Skip("permission tests unreliable on Windows")
|
||||||
}
|
}
|
||||||
|
|
||||||
tmpDir := t.TempDir()
|
tmpDir := isolatedTempDir(t)
|
||||||
|
|
||||||
// Count existing temp files
|
// Count existing temp files
|
||||||
countTempFiles := func() int {
|
countTempFiles := func() int {
|
||||||
@@ -493,7 +513,7 @@ func TestWriteWithBackup_NoOrphanTempFiles(t *testing.T) {
|
|||||||
badPath := filepath.Join(tmpDir, "isdir")
|
badPath := filepath.Join(tmpDir, "isdir")
|
||||||
os.MkdirAll(badPath, 0o755)
|
os.MkdirAll(badPath, 0o755)
|
||||||
|
|
||||||
_ = writeWithBackup(badPath, []byte(`{"test": true}`))
|
_ = WriteWithBackup(badPath, []byte(`{"test": true}`))
|
||||||
|
|
||||||
after := countTempFiles()
|
after := countTempFiles()
|
||||||
if after > before {
|
if after > before {
|
||||||
87
cmd/launch/claude.go
Normal file
@@ -0,0 +1,87 @@
|
|||||||
|
package launch
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"os/exec"
|
||||||
|
"path/filepath"
|
||||||
|
"runtime"
|
||||||
|
"strconv"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/envconfig"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Claude implements Runner for Claude Code integration.
|
||||||
|
type Claude struct{}
|
||||||
|
|
||||||
|
func (c *Claude) String() string { return "Claude Code" }
|
||||||
|
|
||||||
|
func (c *Claude) args(model string, extra []string) []string {
|
||||||
|
var args []string
|
||||||
|
if model != "" {
|
||||||
|
args = append(args, "--model", model)
|
||||||
|
}
|
||||||
|
args = append(args, extra...)
|
||||||
|
return args
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Claude) findPath() (string, error) {
|
||||||
|
if p, err := exec.LookPath("claude"); err == nil {
|
||||||
|
return p, nil
|
||||||
|
}
|
||||||
|
home, err := os.UserHomeDir()
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
name := "claude"
|
||||||
|
if runtime.GOOS == "windows" {
|
||||||
|
name = "claude.exe"
|
||||||
|
}
|
||||||
|
fallback := filepath.Join(home, ".claude", "local", name)
|
||||||
|
if _, err := os.Stat(fallback); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return fallback, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Claude) Run(model string, args []string) error {
|
||||||
|
claudePath, err := c.findPath()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("claude is not installed, install from https://code.claude.com/docs/en/quickstart")
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd := exec.Command(claudePath, c.args(model, args)...)
|
||||||
|
cmd.Stdin = os.Stdin
|
||||||
|
cmd.Stdout = os.Stdout
|
||||||
|
cmd.Stderr = os.Stderr
|
||||||
|
|
||||||
|
env := append(os.Environ(),
|
||||||
|
"ANTHROPIC_BASE_URL="+envconfig.Host().String(),
|
||||||
|
"ANTHROPIC_API_KEY=",
|
||||||
|
"ANTHROPIC_AUTH_TOKEN=ollama",
|
||||||
|
"CLAUDE_CODE_ATTRIBUTION_HEADER=0",
|
||||||
|
)
|
||||||
|
|
||||||
|
env = append(env, c.modelEnvVars(model)...)
|
||||||
|
|
||||||
|
cmd.Env = env
|
||||||
|
return cmd.Run()
|
||||||
|
}
|
||||||
|
|
||||||
|
// modelEnvVars returns Claude Code env vars that route all model tiers through Ollama.
|
||||||
|
func (c *Claude) modelEnvVars(model string) []string {
|
||||||
|
env := []string{
|
||||||
|
"ANTHROPIC_DEFAULT_OPUS_MODEL=" + model,
|
||||||
|
"ANTHROPIC_DEFAULT_SONNET_MODEL=" + model,
|
||||||
|
"ANTHROPIC_DEFAULT_HAIKU_MODEL=" + model,
|
||||||
|
"CLAUDE_CODE_SUBAGENT_MODEL=" + model,
|
||||||
|
}
|
||||||
|
|
||||||
|
if isCloudModelName(model) {
|
||||||
|
if l, ok := lookupCloudModelLimit(model); ok {
|
||||||
|
env = append(env, "CLAUDE_CODE_AUTO_COMPACT_WINDOW="+strconv.Itoa(l.Context))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return env
|
||||||
|
}
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package config
|
package launch
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"os"
|
"os"
|
||||||
@@ -117,10 +117,7 @@ func TestClaudeModelEnvVars(t *testing.T) {
|
|||||||
return m
|
return m
|
||||||
}
|
}
|
||||||
|
|
||||||
t.Run("falls back to model param when no aliases saved", func(t *testing.T) {
|
t.Run("maps all Claude model env vars to the provided model", func(t *testing.T) {
|
||||||
tmpDir := t.TempDir()
|
|
||||||
setTestHome(t, tmpDir)
|
|
||||||
|
|
||||||
got := envMap(c.modelEnvVars("llama3.2"))
|
got := envMap(c.modelEnvVars("llama3.2"))
|
||||||
if got["ANTHROPIC_DEFAULT_OPUS_MODEL"] != "llama3.2" {
|
if got["ANTHROPIC_DEFAULT_OPUS_MODEL"] != "llama3.2" {
|
||||||
t.Errorf("OPUS = %q, want llama3.2", got["ANTHROPIC_DEFAULT_OPUS_MODEL"])
|
t.Errorf("OPUS = %q, want llama3.2", got["ANTHROPIC_DEFAULT_OPUS_MODEL"])
|
||||||
@@ -134,65 +131,41 @@ func TestClaudeModelEnvVars(t *testing.T) {
|
|||||||
if got["CLAUDE_CODE_SUBAGENT_MODEL"] != "llama3.2" {
|
if got["CLAUDE_CODE_SUBAGENT_MODEL"] != "llama3.2" {
|
||||||
t.Errorf("SUBAGENT = %q, want llama3.2", got["CLAUDE_CODE_SUBAGENT_MODEL"])
|
t.Errorf("SUBAGENT = %q, want llama3.2", got["CLAUDE_CODE_SUBAGENT_MODEL"])
|
||||||
}
|
}
|
||||||
})
|
if got["CLAUDE_CODE_AUTO_COMPACT_WINDOW"] != "" {
|
||||||
|
t.Errorf("AUTO_COMPACT_WINDOW = %q, want empty for local models", got["CLAUDE_CODE_AUTO_COMPACT_WINDOW"])
|
||||||
t.Run("uses primary alias for opus sonnet and subagent", func(t *testing.T) {
|
|
||||||
tmpDir := t.TempDir()
|
|
||||||
setTestHome(t, tmpDir)
|
|
||||||
|
|
||||||
SaveIntegration("claude", []string{"qwen3:8b"})
|
|
||||||
saveAliases("claude", map[string]string{"primary": "qwen3:8b"})
|
|
||||||
|
|
||||||
got := envMap(c.modelEnvVars("qwen3:8b"))
|
|
||||||
if got["ANTHROPIC_DEFAULT_OPUS_MODEL"] != "qwen3:8b" {
|
|
||||||
t.Errorf("OPUS = %q, want qwen3:8b", got["ANTHROPIC_DEFAULT_OPUS_MODEL"])
|
|
||||||
}
|
|
||||||
if got["ANTHROPIC_DEFAULT_SONNET_MODEL"] != "qwen3:8b" {
|
|
||||||
t.Errorf("SONNET = %q, want qwen3:8b", got["ANTHROPIC_DEFAULT_SONNET_MODEL"])
|
|
||||||
}
|
|
||||||
if got["ANTHROPIC_DEFAULT_HAIKU_MODEL"] != "qwen3:8b" {
|
|
||||||
t.Errorf("HAIKU = %q, want qwen3:8b (no fast alias)", got["ANTHROPIC_DEFAULT_HAIKU_MODEL"])
|
|
||||||
}
|
|
||||||
if got["CLAUDE_CODE_SUBAGENT_MODEL"] != "qwen3:8b" {
|
|
||||||
t.Errorf("SUBAGENT = %q, want qwen3:8b", got["CLAUDE_CODE_SUBAGENT_MODEL"])
|
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("uses fast alias for haiku", func(t *testing.T) {
|
t.Run("supports empty model", func(t *testing.T) {
|
||||||
tmpDir := t.TempDir()
|
got := envMap(c.modelEnvVars(""))
|
||||||
setTestHome(t, tmpDir)
|
if got["ANTHROPIC_DEFAULT_OPUS_MODEL"] != "" {
|
||||||
|
t.Errorf("OPUS = %q, want empty", got["ANTHROPIC_DEFAULT_OPUS_MODEL"])
|
||||||
SaveIntegration("claude", []string{"llama3.2:70b"})
|
|
||||||
saveAliases("claude", map[string]string{
|
|
||||||
"primary": "llama3.2:70b",
|
|
||||||
"fast": "llama3.2:8b",
|
|
||||||
})
|
|
||||||
|
|
||||||
got := envMap(c.modelEnvVars("llama3.2:70b"))
|
|
||||||
if got["ANTHROPIC_DEFAULT_OPUS_MODEL"] != "llama3.2:70b" {
|
|
||||||
t.Errorf("OPUS = %q, want llama3.2:70b", got["ANTHROPIC_DEFAULT_OPUS_MODEL"])
|
|
||||||
}
|
}
|
||||||
if got["ANTHROPIC_DEFAULT_SONNET_MODEL"] != "llama3.2:70b" {
|
if got["ANTHROPIC_DEFAULT_SONNET_MODEL"] != "" {
|
||||||
t.Errorf("SONNET = %q, want llama3.2:70b", got["ANTHROPIC_DEFAULT_SONNET_MODEL"])
|
t.Errorf("SONNET = %q, want empty", got["ANTHROPIC_DEFAULT_SONNET_MODEL"])
|
||||||
}
|
}
|
||||||
if got["ANTHROPIC_DEFAULT_HAIKU_MODEL"] != "llama3.2:8b" {
|
if got["ANTHROPIC_DEFAULT_HAIKU_MODEL"] != "" {
|
||||||
t.Errorf("HAIKU = %q, want llama3.2:8b", got["ANTHROPIC_DEFAULT_HAIKU_MODEL"])
|
t.Errorf("HAIKU = %q, want empty", got["ANTHROPIC_DEFAULT_HAIKU_MODEL"])
|
||||||
}
|
}
|
||||||
if got["CLAUDE_CODE_SUBAGENT_MODEL"] != "llama3.2:70b" {
|
if got["CLAUDE_CODE_SUBAGENT_MODEL"] != "" {
|
||||||
t.Errorf("SUBAGENT = %q, want llama3.2:70b", got["CLAUDE_CODE_SUBAGENT_MODEL"])
|
t.Errorf("SUBAGENT = %q, want empty", got["CLAUDE_CODE_SUBAGENT_MODEL"])
|
||||||
|
}
|
||||||
|
if got["CLAUDE_CODE_AUTO_COMPACT_WINDOW"] != "" {
|
||||||
|
t.Errorf("AUTO_COMPACT_WINDOW = %q, want empty", got["CLAUDE_CODE_AUTO_COMPACT_WINDOW"])
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("alias primary overrides model param", func(t *testing.T) {
|
t.Run("sets auto compact window for known cloud models", func(t *testing.T) {
|
||||||
tmpDir := t.TempDir()
|
got := envMap(c.modelEnvVars("glm-5:cloud"))
|
||||||
setTestHome(t, tmpDir)
|
if got["CLAUDE_CODE_AUTO_COMPACT_WINDOW"] != "202752" {
|
||||||
|
t.Errorf("AUTO_COMPACT_WINDOW = %q, want 202752", got["CLAUDE_CODE_AUTO_COMPACT_WINDOW"])
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
SaveIntegration("claude", []string{"saved-model"})
|
t.Run("does not set auto compact window for unknown cloud models", func(t *testing.T) {
|
||||||
saveAliases("claude", map[string]string{"primary": "saved-model"})
|
got := envMap(c.modelEnvVars("unknown-model:cloud"))
|
||||||
|
if got["CLAUDE_CODE_AUTO_COMPACT_WINDOW"] != "" {
|
||||||
got := envMap(c.modelEnvVars("different-model"))
|
t.Errorf("AUTO_COMPACT_WINDOW = %q, want empty", got["CLAUDE_CODE_AUTO_COMPACT_WINDOW"])
|
||||||
if got["ANTHROPIC_DEFAULT_OPUS_MODEL"] != "saved-model" {
|
|
||||||
t.Errorf("OPUS = %q, want saved-model", got["ANTHROPIC_DEFAULT_OPUS_MODEL"])
|
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -1,14 +1,13 @@
|
|||||||
package config
|
package launch
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
"os/exec"
|
"os/exec"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/cmd/internal/fileutil"
|
||||||
"github.com/ollama/ollama/envconfig"
|
"github.com/ollama/ollama/envconfig"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -22,24 +21,6 @@ func (c *Cline) Run(model string, args []string) error {
|
|||||||
return fmt.Errorf("cline is not installed, install with: npm install -g cline")
|
return fmt.Errorf("cline is not installed, install with: npm install -g cline")
|
||||||
}
|
}
|
||||||
|
|
||||||
models := []string{model}
|
|
||||||
if config, err := loadIntegration("cline"); err == nil && len(config.Models) > 0 {
|
|
||||||
models = config.Models
|
|
||||||
}
|
|
||||||
var err error
|
|
||||||
models, err = resolveEditorModels("cline", models, func() ([]string, error) {
|
|
||||||
return selectModels(context.Background(), "cline", "")
|
|
||||||
})
|
|
||||||
if errors.Is(err, errCancelled) {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if err := c.Edit(models); err != nil {
|
|
||||||
return fmt.Errorf("setup failed: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
cmd := exec.Command("cline", args...)
|
cmd := exec.Command("cline", args...)
|
||||||
cmd.Stdin = os.Stdin
|
cmd.Stdin = os.Stdin
|
||||||
cmd.Stdout = os.Stdout
|
cmd.Stdout = os.Stdout
|
||||||
@@ -97,7 +78,7 @@ func (c *Cline) Edit(models []string) error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return writeWithBackup(configPath, data)
|
return fileutil.WriteWithBackup(configPath, data)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Cline) Models() []string {
|
func (c *Cline) Models() []string {
|
||||||
@@ -106,7 +87,7 @@ func (c *Cline) Models() []string {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
config, err := readJSONFile(filepath.Join(home, ".cline", "data", "globalState.json"))
|
config, err := fileutil.ReadJSON(filepath.Join(home, ".cline", "data", "globalState.json"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package config
|
package launch
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
148
cmd/launch/codex.go
Normal file
@@ -0,0 +1,148 @@
|
|||||||
|
package launch
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"os/exec"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/envconfig"
|
||||||
|
"golang.org/x/mod/semver"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Codex implements Runner for Codex integration
|
||||||
|
type Codex struct{}
|
||||||
|
|
||||||
|
func (c *Codex) String() string { return "Codex" }
|
||||||
|
|
||||||
|
const codexProfileName = "ollama-launch"
|
||||||
|
|
||||||
|
func (c *Codex) args(model string, extra []string) []string {
|
||||||
|
args := []string{"--profile", codexProfileName}
|
||||||
|
if model != "" {
|
||||||
|
args = append(args, "-m", model)
|
||||||
|
}
|
||||||
|
args = append(args, extra...)
|
||||||
|
return args
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Codex) Run(model string, args []string) error {
|
||||||
|
if err := checkCodexVersion(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := ensureCodexConfig(); err != nil {
|
||||||
|
return fmt.Errorf("failed to configure codex: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd := exec.Command("codex", c.args(model, args)...)
|
||||||
|
cmd.Stdin = os.Stdin
|
||||||
|
cmd.Stdout = os.Stdout
|
||||||
|
cmd.Stderr = os.Stderr
|
||||||
|
cmd.Env = append(os.Environ(),
|
||||||
|
"OPENAI_API_KEY=ollama",
|
||||||
|
)
|
||||||
|
return cmd.Run()
|
||||||
|
}
|
||||||
|
|
||||||
|
// ensureCodexConfig writes a [profiles.ollama-launch] section to ~/.codex/config.toml
|
||||||
|
// with openai_base_url pointing to the local Ollama server.
|
||||||
|
func ensureCodexConfig() error {
|
||||||
|
home, err := os.UserHomeDir()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
codexDir := filepath.Join(home, ".codex")
|
||||||
|
if err := os.MkdirAll(codexDir, 0o755); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
configPath := filepath.Join(codexDir, "config.toml")
|
||||||
|
return writeCodexProfile(configPath)
|
||||||
|
}
|
||||||
|
|
||||||
|
// writeCodexProfile ensures ~/.codex/config.toml has the ollama-launch profile
|
||||||
|
// and model provider sections with the correct base URL.
|
||||||
|
func writeCodexProfile(configPath string) error {
|
||||||
|
baseURL := envconfig.Host().String() + "/v1/"
|
||||||
|
|
||||||
|
sections := []struct {
|
||||||
|
header string
|
||||||
|
lines []string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
header: fmt.Sprintf("[profiles.%s]", codexProfileName),
|
||||||
|
lines: []string{
|
||||||
|
fmt.Sprintf("openai_base_url = %q", baseURL),
|
||||||
|
`forced_login_method = "api"`,
|
||||||
|
fmt.Sprintf("model_provider = %q", codexProfileName),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
header: fmt.Sprintf("[model_providers.%s]", codexProfileName),
|
||||||
|
lines: []string{
|
||||||
|
`name = "Ollama"`,
|
||||||
|
fmt.Sprintf("base_url = %q", baseURL),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
content, readErr := os.ReadFile(configPath)
|
||||||
|
text := ""
|
||||||
|
if readErr == nil {
|
||||||
|
text = string(content)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, s := range sections {
|
||||||
|
block := strings.Join(append([]string{s.header}, s.lines...), "\n") + "\n"
|
||||||
|
|
||||||
|
if idx := strings.Index(text, s.header); idx >= 0 {
|
||||||
|
// Replace the existing section up to the next section header.
|
||||||
|
rest := text[idx+len(s.header):]
|
||||||
|
if endIdx := strings.Index(rest, "\n["); endIdx >= 0 {
|
||||||
|
text = text[:idx] + block + rest[endIdx+1:]
|
||||||
|
} else {
|
||||||
|
text = text[:idx] + block
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Append the section.
|
||||||
|
if text != "" && !strings.HasSuffix(text, "\n") {
|
||||||
|
text += "\n"
|
||||||
|
}
|
||||||
|
if text != "" {
|
||||||
|
text += "\n"
|
||||||
|
}
|
||||||
|
text += block
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return os.WriteFile(configPath, []byte(text), 0o644)
|
||||||
|
}
|
||||||
|
|
||||||
|
func checkCodexVersion() error {
|
||||||
|
if _, err := exec.LookPath("codex"); err != nil {
|
||||||
|
return fmt.Errorf("codex is not installed, install with: npm install -g @openai/codex")
|
||||||
|
}
|
||||||
|
|
||||||
|
out, err := exec.Command("codex", "--version").Output()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to get codex version: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse output like "codex-cli 0.87.0"
|
||||||
|
fields := strings.Fields(strings.TrimSpace(string(out)))
|
||||||
|
if len(fields) < 2 {
|
||||||
|
return fmt.Errorf("unexpected codex version output: %s", string(out))
|
||||||
|
}
|
||||||
|
|
||||||
|
version := "v" + fields[len(fields)-1]
|
||||||
|
minVersion := "v0.81.0"
|
||||||
|
|
||||||
|
if semver.Compare(version, minVersion) < 0 {
|
||||||
|
return fmt.Errorf("codex version %s is too old, minimum required is %s, update with: npm update -g @openai/codex", fields[len(fields)-1], "0.81.0")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
229
cmd/launch/codex_test.go
Normal file
@@ -0,0 +1,229 @@
|
|||||||
|
package launch
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"slices"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestCodexArgs(t *testing.T) {
|
||||||
|
c := &Codex{}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
model string
|
||||||
|
args []string
|
||||||
|
want []string
|
||||||
|
}{
|
||||||
|
{"with model", "llama3.2", nil, []string{"--profile", "ollama-launch", "-m", "llama3.2"}},
|
||||||
|
{"empty model", "", nil, []string{"--profile", "ollama-launch"}},
|
||||||
|
{"with model and extra args", "qwen3.5", []string{"-p", "myprofile"}, []string{"--profile", "ollama-launch", "-m", "qwen3.5", "-p", "myprofile"}},
|
||||||
|
{"with sandbox flag", "llama3.2", []string{"--sandbox", "workspace-write"}, []string{"--profile", "ollama-launch", "-m", "llama3.2", "--sandbox", "workspace-write"}},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := c.args(tt.model, tt.args)
|
||||||
|
if !slices.Equal(got, tt.want) {
|
||||||
|
t.Errorf("args(%q, %v) = %v, want %v", tt.model, tt.args, got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWriteCodexProfile(t *testing.T) {
|
||||||
|
t.Run("creates new file when none exists", func(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
configPath := filepath.Join(tmpDir, "config.toml")
|
||||||
|
|
||||||
|
if err := writeCodexProfile(configPath); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := os.ReadFile(configPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
content := string(data)
|
||||||
|
if !strings.Contains(content, "[profiles.ollama-launch]") {
|
||||||
|
t.Error("missing [profiles.ollama-launch] header")
|
||||||
|
}
|
||||||
|
if !strings.Contains(content, "openai_base_url") {
|
||||||
|
t.Error("missing openai_base_url key")
|
||||||
|
}
|
||||||
|
if !strings.Contains(content, "/v1/") {
|
||||||
|
t.Error("missing /v1/ suffix in base URL")
|
||||||
|
}
|
||||||
|
if !strings.Contains(content, `forced_login_method = "api"`) {
|
||||||
|
t.Error("missing forced_login_method key")
|
||||||
|
}
|
||||||
|
if !strings.Contains(content, `model_provider = "ollama-launch"`) {
|
||||||
|
t.Error("missing model_provider key")
|
||||||
|
}
|
||||||
|
if !strings.Contains(content, "[model_providers.ollama-launch]") {
|
||||||
|
t.Error("missing [model_providers.ollama-launch] section")
|
||||||
|
}
|
||||||
|
if !strings.Contains(content, `name = "Ollama"`) {
|
||||||
|
t.Error("missing model provider name")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("appends profile to existing file without profile", func(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
configPath := filepath.Join(tmpDir, "config.toml")
|
||||||
|
existing := "[some_other_section]\nkey = \"value\"\n"
|
||||||
|
os.WriteFile(configPath, []byte(existing), 0o644)
|
||||||
|
|
||||||
|
if err := writeCodexProfile(configPath); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
data, _ := os.ReadFile(configPath)
|
||||||
|
content := string(data)
|
||||||
|
|
||||||
|
if !strings.Contains(content, "[some_other_section]") {
|
||||||
|
t.Error("existing section was removed")
|
||||||
|
}
|
||||||
|
if !strings.Contains(content, "[profiles.ollama-launch]") {
|
||||||
|
t.Error("missing [profiles.ollama-launch] header")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("replaces existing profile section", func(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
configPath := filepath.Join(tmpDir, "config.toml")
|
||||||
|
existing := "[profiles.ollama-launch]\nopenai_base_url = \"http://old:1234/v1/\"\n\n[model_providers.ollama-launch]\nname = \"Ollama\"\nbase_url = \"http://old:1234/v1/\"\n"
|
||||||
|
os.WriteFile(configPath, []byte(existing), 0o644)
|
||||||
|
|
||||||
|
if err := writeCodexProfile(configPath); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
data, _ := os.ReadFile(configPath)
|
||||||
|
content := string(data)
|
||||||
|
|
||||||
|
if strings.Contains(content, "old:1234") {
|
||||||
|
t.Error("old URL was not replaced")
|
||||||
|
}
|
||||||
|
if strings.Count(content, "[profiles.ollama-launch]") != 1 {
|
||||||
|
t.Errorf("expected exactly one [profiles.ollama-launch] section, got %d", strings.Count(content, "[profiles.ollama-launch]"))
|
||||||
|
}
|
||||||
|
if strings.Count(content, "[model_providers.ollama-launch]") != 1 {
|
||||||
|
t.Errorf("expected exactly one [model_providers.ollama-launch] section, got %d", strings.Count(content, "[model_providers.ollama-launch]"))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("replaces profile while preserving following sections", func(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
configPath := filepath.Join(tmpDir, "config.toml")
|
||||||
|
existing := "[profiles.ollama-launch]\nopenai_base_url = \"http://old:1234/v1/\"\n[another_section]\nfoo = \"bar\"\n"
|
||||||
|
os.WriteFile(configPath, []byte(existing), 0o644)
|
||||||
|
|
||||||
|
if err := writeCodexProfile(configPath); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
data, _ := os.ReadFile(configPath)
|
||||||
|
content := string(data)
|
||||||
|
|
||||||
|
if strings.Contains(content, "old:1234") {
|
||||||
|
t.Error("old URL was not replaced")
|
||||||
|
}
|
||||||
|
if !strings.Contains(content, "[another_section]") {
|
||||||
|
t.Error("following section was removed")
|
||||||
|
}
|
||||||
|
if !strings.Contains(content, "foo = \"bar\"") {
|
||||||
|
t.Error("following section content was removed")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("appends newline to file not ending with newline", func(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
configPath := filepath.Join(tmpDir, "config.toml")
|
||||||
|
existing := "[other]\nkey = \"val\""
|
||||||
|
os.WriteFile(configPath, []byte(existing), 0o644)
|
||||||
|
|
||||||
|
if err := writeCodexProfile(configPath); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
data, _ := os.ReadFile(configPath)
|
||||||
|
content := string(data)
|
||||||
|
|
||||||
|
if !strings.Contains(content, "[profiles.ollama-launch]") {
|
||||||
|
t.Error("missing [profiles.ollama-launch] header")
|
||||||
|
}
|
||||||
|
// Should not have double blank lines from missing trailing newline
|
||||||
|
if strings.Contains(content, "\n\n\n") {
|
||||||
|
t.Error("unexpected triple newline in output")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("uses custom OLLAMA_HOST", func(t *testing.T) {
|
||||||
|
t.Setenv("OLLAMA_HOST", "http://myhost:9999")
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
configPath := filepath.Join(tmpDir, "config.toml")
|
||||||
|
|
||||||
|
if err := writeCodexProfile(configPath); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
data, _ := os.ReadFile(configPath)
|
||||||
|
content := string(data)
|
||||||
|
|
||||||
|
if !strings.Contains(content, "myhost:9999/v1/") {
|
||||||
|
t.Errorf("expected custom host in URL, got:\n%s", content)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEnsureCodexConfig(t *testing.T) {
|
||||||
|
t.Run("creates .codex dir and config.toml", func(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
setTestHome(t, tmpDir)
|
||||||
|
|
||||||
|
if err := ensureCodexConfig(); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
configPath := filepath.Join(tmpDir, ".codex", "config.toml")
|
||||||
|
data, err := os.ReadFile(configPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("config.toml not created: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
content := string(data)
|
||||||
|
if !strings.Contains(content, "[profiles.ollama-launch]") {
|
||||||
|
t.Error("missing [profiles.ollama-launch] header")
|
||||||
|
}
|
||||||
|
if !strings.Contains(content, "openai_base_url") {
|
||||||
|
t.Error("missing openai_base_url key")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("is idempotent", func(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
setTestHome(t, tmpDir)
|
||||||
|
|
||||||
|
if err := ensureCodexConfig(); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if err := ensureCodexConfig(); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
configPath := filepath.Join(tmpDir, ".codex", "config.toml")
|
||||||
|
data, _ := os.ReadFile(configPath)
|
||||||
|
content := string(data)
|
||||||
|
|
||||||
|
if strings.Count(content, "[profiles.ollama-launch]") != 1 {
|
||||||
|
t.Errorf("expected exactly one [profiles.ollama-launch] section after two calls, got %d", strings.Count(content, "[profiles.ollama-launch]"))
|
||||||
|
}
|
||||||
|
if strings.Count(content, "[model_providers.ollama-launch]") != 1 {
|
||||||
|
t.Errorf("expected exactly one [model_providers.ollama-launch] section after two calls, got %d", strings.Count(content, "[model_providers.ollama-launch]"))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
598
cmd/launch/command_test.go
Normal file
@@ -0,0 +1,598 @@
|
|||||||
|
package launch
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"os"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/google/go-cmp/cmp"
|
||||||
|
"github.com/ollama/ollama/cmd/config"
|
||||||
|
"github.com/spf13/cobra"
|
||||||
|
)
|
||||||
|
|
||||||
|
func captureStderr(t *testing.T, fn func()) string {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
oldStderr := os.Stderr
|
||||||
|
r, w, err := os.Pipe()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to create stderr pipe: %v", err)
|
||||||
|
}
|
||||||
|
os.Stderr = w
|
||||||
|
defer func() {
|
||||||
|
os.Stderr = oldStderr
|
||||||
|
}()
|
||||||
|
|
||||||
|
done := make(chan string, 1)
|
||||||
|
go func() {
|
||||||
|
var buf bytes.Buffer
|
||||||
|
_, _ = io.Copy(&buf, r)
|
||||||
|
done <- buf.String()
|
||||||
|
}()
|
||||||
|
|
||||||
|
fn()
|
||||||
|
|
||||||
|
_ = w.Close()
|
||||||
|
return <-done
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLaunchCmd(t *testing.T) {
|
||||||
|
mockCheck := func(cmd *cobra.Command, args []string) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
mockTUI := func(cmd *cobra.Command) {}
|
||||||
|
cmd := LaunchCmd(mockCheck, mockTUI)
|
||||||
|
|
||||||
|
t.Run("command structure", func(t *testing.T) {
|
||||||
|
if cmd.Use != "launch [INTEGRATION] [-- [EXTRA_ARGS...]]" {
|
||||||
|
t.Errorf("Use = %q, want %q", cmd.Use, "launch [INTEGRATION] [-- [EXTRA_ARGS...]]")
|
||||||
|
}
|
||||||
|
if cmd.Short == "" {
|
||||||
|
t.Error("Short description should not be empty")
|
||||||
|
}
|
||||||
|
if cmd.Long == "" {
|
||||||
|
t.Error("Long description should not be empty")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("flags exist", func(t *testing.T) {
|
||||||
|
if cmd.Flags().Lookup("model") == nil {
|
||||||
|
t.Error("--model flag should exist")
|
||||||
|
}
|
||||||
|
if cmd.Flags().Lookup("config") == nil {
|
||||||
|
t.Error("--config flag should exist")
|
||||||
|
}
|
||||||
|
if cmd.Flags().Lookup("yes") == nil {
|
||||||
|
t.Error("--yes flag should exist")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("PreRunE is set", func(t *testing.T) {
|
||||||
|
if cmd.PreRunE == nil {
|
||||||
|
t.Error("PreRunE should be set to checkServerHeartbeat")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLaunchCmdTUICallback(t *testing.T) {
|
||||||
|
mockCheck := func(cmd *cobra.Command, args []string) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("no args calls TUI", func(t *testing.T) {
|
||||||
|
tuiCalled := false
|
||||||
|
mockTUI := func(cmd *cobra.Command) {
|
||||||
|
tuiCalled = true
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd := LaunchCmd(mockCheck, mockTUI)
|
||||||
|
cmd.SetArgs([]string{})
|
||||||
|
_ = cmd.Execute()
|
||||||
|
|
||||||
|
if !tuiCalled {
|
||||||
|
t.Error("TUI callback should be called when no args provided")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("integration arg bypasses TUI", func(t *testing.T) {
|
||||||
|
srv := httptest.NewServer(http.NotFoundHandler())
|
||||||
|
defer srv.Close()
|
||||||
|
t.Setenv("OLLAMA_HOST", srv.URL)
|
||||||
|
|
||||||
|
tuiCalled := false
|
||||||
|
mockTUI := func(cmd *cobra.Command) {
|
||||||
|
tuiCalled = true
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd := LaunchCmd(mockCheck, mockTUI)
|
||||||
|
cmd.SetArgs([]string{"claude"})
|
||||||
|
_ = cmd.Execute()
|
||||||
|
|
||||||
|
if tuiCalled {
|
||||||
|
t.Error("TUI callback should NOT be called when integration arg provided")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("--model flag without integration returns error", func(t *testing.T) {
|
||||||
|
tuiCalled := false
|
||||||
|
mockTUI := func(cmd *cobra.Command) {
|
||||||
|
tuiCalled = true
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd := LaunchCmd(mockCheck, mockTUI)
|
||||||
|
cmd.SetArgs([]string{"--model", "test-model"})
|
||||||
|
err := cmd.Execute()
|
||||||
|
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected --model without an integration to fail")
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "require an integration name") {
|
||||||
|
t.Fatalf("expected integration-name guidance, got %v", err)
|
||||||
|
}
|
||||||
|
if tuiCalled {
|
||||||
|
t.Error("TUI callback should NOT be called when --model is provided without an integration")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("--config flag without integration returns error", func(t *testing.T) {
|
||||||
|
tuiCalled := false
|
||||||
|
mockTUI := func(cmd *cobra.Command) {
|
||||||
|
tuiCalled = true
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd := LaunchCmd(mockCheck, mockTUI)
|
||||||
|
cmd.SetArgs([]string{"--config"})
|
||||||
|
err := cmd.Execute()
|
||||||
|
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected --config without an integration to fail")
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "require an integration name") {
|
||||||
|
t.Fatalf("expected integration-name guidance, got %v", err)
|
||||||
|
}
|
||||||
|
if tuiCalled {
|
||||||
|
t.Error("TUI callback should NOT be called when --config is provided without an integration")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("--yes flag without integration returns error", func(t *testing.T) {
|
||||||
|
tuiCalled := false
|
||||||
|
mockTUI := func(cmd *cobra.Command) {
|
||||||
|
tuiCalled = true
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd := LaunchCmd(mockCheck, mockTUI)
|
||||||
|
cmd.SetArgs([]string{"--yes"})
|
||||||
|
err := cmd.Execute()
|
||||||
|
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected --yes without an integration to fail")
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "require an integration name") {
|
||||||
|
t.Fatalf("expected integration-name guidance, got %v", err)
|
||||||
|
}
|
||||||
|
if tuiCalled {
|
||||||
|
t.Error("TUI callback should NOT be called when --yes is provided without an integration")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("extra args without integration return error", func(t *testing.T) {
|
||||||
|
tuiCalled := false
|
||||||
|
mockTUI := func(cmd *cobra.Command) {
|
||||||
|
tuiCalled = true
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd := LaunchCmd(mockCheck, mockTUI)
|
||||||
|
cmd.SetArgs([]string{"--model", "test-model", "--", "--sandbox", "workspace-write"})
|
||||||
|
err := cmd.Execute()
|
||||||
|
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected flags and extra args without an integration to fail")
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "require an integration name") {
|
||||||
|
t.Fatalf("expected integration-name guidance, got %v", err)
|
||||||
|
}
|
||||||
|
if tuiCalled {
|
||||||
|
t.Error("TUI callback should NOT be called when flags or extra args are provided without an integration")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLaunchCmdNilHeartbeat(t *testing.T) {
|
||||||
|
cmd := LaunchCmd(nil, nil)
|
||||||
|
if cmd == nil {
|
||||||
|
t.Fatal("LaunchCmd returned nil")
|
||||||
|
}
|
||||||
|
if cmd.PreRunE != nil {
|
||||||
|
t.Log("Note: PreRunE is set even when nil is passed (acceptable)")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLaunchCmdModelFlagFiltersDisabledCloudFromSavedConfig(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
setLaunchTestHome(t, tmpDir)
|
||||||
|
|
||||||
|
if err := config.SaveIntegration("stubeditor", []string{"glm-5:cloud"}); err != nil {
|
||||||
|
t.Fatalf("failed to seed saved config: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
switch r.URL.Path {
|
||||||
|
case "/api/status":
|
||||||
|
fmt.Fprintf(w, `{"cloud":{"disabled":true,"source":"config"}}`)
|
||||||
|
case "/api/show":
|
||||||
|
fmt.Fprintf(w, `{"model":"llama3.2"}`)
|
||||||
|
default:
|
||||||
|
w.WriteHeader(http.StatusNotFound)
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
t.Setenv("OLLAMA_HOST", srv.URL)
|
||||||
|
|
||||||
|
stub := &launcherEditorRunner{}
|
||||||
|
restore := OverrideIntegration("stubeditor", stub)
|
||||||
|
defer restore()
|
||||||
|
|
||||||
|
cmd := LaunchCmd(func(cmd *cobra.Command, args []string) error { return nil }, func(cmd *cobra.Command) {})
|
||||||
|
cmd.SetArgs([]string{"stubeditor", "--model", "llama3.2"})
|
||||||
|
if err := cmd.Execute(); err != nil {
|
||||||
|
t.Fatalf("launch command failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
saved, err := config.LoadIntegration("stubeditor")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to reload integration config: %v", err)
|
||||||
|
}
|
||||||
|
if diff := cmp.Diff([]string{"llama3.2"}, saved.Models); diff != "" {
|
||||||
|
t.Fatalf("saved models mismatch (-want +got):\n%s", diff)
|
||||||
|
}
|
||||||
|
if diff := cmp.Diff([][]string{{"llama3.2"}}, stub.edited); diff != "" {
|
||||||
|
t.Fatalf("editor models mismatch (-want +got):\n%s", diff)
|
||||||
|
}
|
||||||
|
if stub.ranModel != "llama3.2" {
|
||||||
|
t.Fatalf("expected launch to run with llama3.2, got %q", stub.ranModel)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLaunchCmdModelFlagClearsDisabledCloudOverride(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
setLaunchTestHome(t, tmpDir)
|
||||||
|
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
switch r.URL.Path {
|
||||||
|
case "/api/status":
|
||||||
|
fmt.Fprintf(w, `{"cloud":{"disabled":true,"source":"config"}}`)
|
||||||
|
case "/api/tags":
|
||||||
|
fmt.Fprint(w, `{"models":[{"name":"llama3.2"}]}`)
|
||||||
|
case "/api/show":
|
||||||
|
fmt.Fprint(w, `{"model":"llama3.2"}`)
|
||||||
|
default:
|
||||||
|
w.WriteHeader(http.StatusNotFound)
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
t.Setenv("OLLAMA_HOST", srv.URL)
|
||||||
|
|
||||||
|
stub := &launcherSingleRunner{}
|
||||||
|
restore := OverrideIntegration("stubapp", stub)
|
||||||
|
defer restore()
|
||||||
|
|
||||||
|
oldSelector := DefaultSingleSelector
|
||||||
|
defer func() { DefaultSingleSelector = oldSelector }()
|
||||||
|
|
||||||
|
var selectorCalls int
|
||||||
|
var gotCurrent string
|
||||||
|
DefaultSingleSelector = func(title string, items []ModelItem, current string) (string, error) {
|
||||||
|
selectorCalls++
|
||||||
|
gotCurrent = current
|
||||||
|
return "llama3.2", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd := LaunchCmd(func(cmd *cobra.Command, args []string) error { return nil }, func(cmd *cobra.Command) {})
|
||||||
|
cmd.SetArgs([]string{"stubapp", "--model", "glm-5:cloud"})
|
||||||
|
stderr := captureStderr(t, func() {
|
||||||
|
if err := cmd.Execute(); err != nil {
|
||||||
|
t.Fatalf("launch command failed: %v", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
if selectorCalls != 1 {
|
||||||
|
t.Fatalf("expected disabled cloud override to fall back to selector, got %d calls", selectorCalls)
|
||||||
|
}
|
||||||
|
if gotCurrent != "" {
|
||||||
|
t.Fatalf("expected disabled override to be cleared before selection, got current %q", gotCurrent)
|
||||||
|
}
|
||||||
|
if stub.ranModel != "llama3.2" {
|
||||||
|
t.Fatalf("expected launch to run with replacement local model, got %q", stub.ranModel)
|
||||||
|
}
|
||||||
|
if !strings.Contains(stderr, "Warning: ignoring --model glm-5:cloud because cloud is disabled") {
|
||||||
|
t.Fatalf("expected disabled-cloud warning, got stderr: %q", stderr)
|
||||||
|
}
|
||||||
|
|
||||||
|
saved, err := config.LoadIntegration("stubapp")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to reload integration config: %v", err)
|
||||||
|
}
|
||||||
|
if diff := cmp.Diff([]string{"llama3.2"}, saved.Models); diff != "" {
|
||||||
|
t.Fatalf("saved models mismatch (-want +got):\n%s", diff)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLaunchCmdYes_AutoConfirmsLaunchPromptPath(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
setLaunchTestHome(t, tmpDir)
|
||||||
|
withLauncherHooks(t)
|
||||||
|
withInteractiveSession(t, false)
|
||||||
|
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
switch r.URL.Path {
|
||||||
|
case "/api/show":
|
||||||
|
fmt.Fprint(w, `{"model":"llama3.2"}`)
|
||||||
|
case "/api/status":
|
||||||
|
w.WriteHeader(http.StatusNotFound)
|
||||||
|
fmt.Fprint(w, `{"error":"not found"}`)
|
||||||
|
default:
|
||||||
|
w.WriteHeader(http.StatusNotFound)
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
t.Setenv("OLLAMA_HOST", srv.URL)
|
||||||
|
|
||||||
|
stub := &launcherEditorRunner{paths: []string{"/tmp/stubeditor.json"}}
|
||||||
|
restore := OverrideIntegration("stubeditor", stub)
|
||||||
|
defer restore()
|
||||||
|
|
||||||
|
DefaultConfirmPrompt = func(prompt string, options ConfirmOptions) (bool, error) {
|
||||||
|
t.Fatalf("unexpected prompt with --yes: %q", prompt)
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd := LaunchCmd(func(cmd *cobra.Command, args []string) error { return nil }, func(cmd *cobra.Command) {})
|
||||||
|
cmd.SetArgs([]string{"stubeditor", "--model", "llama3.2", "--yes"})
|
||||||
|
if err := cmd.Execute(); err != nil {
|
||||||
|
t.Fatalf("launch command with --yes failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := cmp.Diff([][]string{{"llama3.2"}}, stub.edited); diff != "" {
|
||||||
|
t.Fatalf("editor models mismatch (-want +got):\n%s", diff)
|
||||||
|
}
|
||||||
|
if stub.ranModel != "llama3.2" {
|
||||||
|
t.Fatalf("expected launch to run with llama3.2, got %q", stub.ranModel)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLaunchCmdHeadlessWithYes_AutoPullsMissingLocalModel(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
setLaunchTestHome(t, tmpDir)
|
||||||
|
withLauncherHooks(t)
|
||||||
|
withInteractiveSession(t, false)
|
||||||
|
|
||||||
|
var pullCalled bool
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
switch r.URL.Path {
|
||||||
|
case "/api/show":
|
||||||
|
w.WriteHeader(http.StatusNotFound)
|
||||||
|
fmt.Fprint(w, `{"error":"model not found"}`)
|
||||||
|
case "/api/pull":
|
||||||
|
pullCalled = true
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
fmt.Fprint(w, `{"status":"success"}`)
|
||||||
|
default:
|
||||||
|
w.WriteHeader(http.StatusNotFound)
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
t.Setenv("OLLAMA_HOST", srv.URL)
|
||||||
|
|
||||||
|
stub := &launcherSingleRunner{}
|
||||||
|
restore := OverrideIntegration("stubapp", stub)
|
||||||
|
defer restore()
|
||||||
|
|
||||||
|
DefaultConfirmPrompt = func(prompt string, options ConfirmOptions) (bool, error) {
|
||||||
|
t.Fatalf("unexpected prompt with --yes in headless autopull path: %q", prompt)
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd := LaunchCmd(func(cmd *cobra.Command, args []string) error { return nil }, func(cmd *cobra.Command) {})
|
||||||
|
cmd.SetArgs([]string{"stubapp", "--model", "missing-model", "--yes"})
|
||||||
|
if err := cmd.Execute(); err != nil {
|
||||||
|
t.Fatalf("launch command with --yes failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !pullCalled {
|
||||||
|
t.Fatal("expected missing local model to be auto-pulled with --yes in headless mode")
|
||||||
|
}
|
||||||
|
if stub.ranModel != "missing-model" {
|
||||||
|
t.Fatalf("expected launch to run with pulled model, got %q", stub.ranModel)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLaunchCmdHeadlessWithoutYes_ReturnsActionableConfirmError(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
setLaunchTestHome(t, tmpDir)
|
||||||
|
withLauncherHooks(t)
|
||||||
|
withInteractiveSession(t, false)
|
||||||
|
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
switch r.URL.Path {
|
||||||
|
case "/api/show":
|
||||||
|
fmt.Fprint(w, `{"model":"llama3.2"}`)
|
||||||
|
case "/api/status":
|
||||||
|
w.WriteHeader(http.StatusNotFound)
|
||||||
|
fmt.Fprint(w, `{"error":"not found"}`)
|
||||||
|
default:
|
||||||
|
w.WriteHeader(http.StatusNotFound)
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
t.Setenv("OLLAMA_HOST", srv.URL)
|
||||||
|
|
||||||
|
stub := &launcherEditorRunner{paths: []string{"/tmp/stubeditor.json"}}
|
||||||
|
restore := OverrideIntegration("stubeditor", stub)
|
||||||
|
defer restore()
|
||||||
|
|
||||||
|
DefaultConfirmPrompt = func(prompt string, options ConfirmOptions) (bool, error) {
|
||||||
|
t.Fatalf("unexpected prompt in headless non-yes mode: %q", prompt)
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd := LaunchCmd(func(cmd *cobra.Command, args []string) error { return nil }, func(cmd *cobra.Command) {})
|
||||||
|
cmd.SetArgs([]string{"stubeditor", "--model", "llama3.2"})
|
||||||
|
err := cmd.Execute()
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected launch command to fail without --yes in headless mode")
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "re-run with --yes") {
|
||||||
|
t.Fatalf("expected actionable --yes guidance, got %v", err)
|
||||||
|
}
|
||||||
|
if len(stub.edited) != 0 {
|
||||||
|
t.Fatalf("expected no editor writes when confirmation is blocked, got %v", stub.edited)
|
||||||
|
}
|
||||||
|
if stub.ranModel != "" {
|
||||||
|
t.Fatalf("expected launch to abort before run, got %q", stub.ranModel)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLaunchCmdIntegrationArgPromptsForModelWithSavedSelection(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
setLaunchTestHome(t, tmpDir)
|
||||||
|
|
||||||
|
if err := config.SaveIntegration("stubapp", []string{"llama3.2"}); err != nil {
|
||||||
|
t.Fatalf("failed to seed saved config: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
switch r.URL.Path {
|
||||||
|
case "/api/tags":
|
||||||
|
fmt.Fprint(w, `{"models":[{"name":"llama3.2"},{"name":"qwen3:8b"}]}`)
|
||||||
|
case "/api/show":
|
||||||
|
fmt.Fprint(w, `{"model":"qwen3:8b"}`)
|
||||||
|
default:
|
||||||
|
w.WriteHeader(http.StatusNotFound)
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
t.Setenv("OLLAMA_HOST", srv.URL)
|
||||||
|
|
||||||
|
stub := &launcherSingleRunner{}
|
||||||
|
restore := OverrideIntegration("stubapp", stub)
|
||||||
|
defer restore()
|
||||||
|
|
||||||
|
oldSelector := DefaultSingleSelector
|
||||||
|
defer func() { DefaultSingleSelector = oldSelector }()
|
||||||
|
|
||||||
|
var gotCurrent string
|
||||||
|
DefaultSingleSelector = func(title string, items []ModelItem, current string) (string, error) {
|
||||||
|
gotCurrent = current
|
||||||
|
return "qwen3:8b", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd := LaunchCmd(func(cmd *cobra.Command, args []string) error { return nil }, func(cmd *cobra.Command) {})
|
||||||
|
cmd.SetArgs([]string{"stubapp"})
|
||||||
|
if err := cmd.Execute(); err != nil {
|
||||||
|
t.Fatalf("launch command failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if gotCurrent != "llama3.2" {
|
||||||
|
t.Fatalf("expected selector current model to be saved model llama3.2, got %q", gotCurrent)
|
||||||
|
}
|
||||||
|
if stub.ranModel != "qwen3:8b" {
|
||||||
|
t.Fatalf("expected launch to run selected model qwen3:8b, got %q", stub.ranModel)
|
||||||
|
}
|
||||||
|
|
||||||
|
saved, err := config.LoadIntegration("stubapp")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to reload integration config: %v", err)
|
||||||
|
}
|
||||||
|
if diff := cmp.Diff([]string{"qwen3:8b"}, saved.Models); diff != "" {
|
||||||
|
t.Fatalf("saved models mismatch (-want +got):\n%s", diff)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLaunchCmdHeadlessYes_IntegrationRequiresModelEvenWhenSaved(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
setLaunchTestHome(t, tmpDir)
|
||||||
|
withLauncherHooks(t)
|
||||||
|
withInteractiveSession(t, false)
|
||||||
|
|
||||||
|
if err := config.SaveIntegration("stubapp", []string{"llama3.2"}); err != nil {
|
||||||
|
t.Fatalf("failed to seed saved config: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
switch r.URL.Path {
|
||||||
|
case "/api/show":
|
||||||
|
fmt.Fprint(w, `{"model":"llama3.2"}`)
|
||||||
|
default:
|
||||||
|
w.WriteHeader(http.StatusNotFound)
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
t.Setenv("OLLAMA_HOST", srv.URL)
|
||||||
|
|
||||||
|
stub := &launcherSingleRunner{}
|
||||||
|
restore := OverrideIntegration("stubapp", stub)
|
||||||
|
defer restore()
|
||||||
|
|
||||||
|
oldSelector := DefaultSingleSelector
|
||||||
|
defer func() { DefaultSingleSelector = oldSelector }()
|
||||||
|
DefaultSingleSelector = func(title string, items []ModelItem, current string) (string, error) {
|
||||||
|
t.Fatal("selector should not be called for headless --yes saved-model launch")
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd := LaunchCmd(func(cmd *cobra.Command, args []string) error { return nil }, func(cmd *cobra.Command) {})
|
||||||
|
cmd.SetArgs([]string{"stubapp", "--yes"})
|
||||||
|
err := cmd.Execute()
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected launch command to fail when --yes is used headlessly without --model")
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "requires --model <model>") {
|
||||||
|
t.Fatalf("expected actionable --model guidance, got %v", err)
|
||||||
|
}
|
||||||
|
if stub.ranModel != "" {
|
||||||
|
t.Fatalf("expected launch to abort before run, got %q", stub.ranModel)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLaunchCmdHeadlessYes_IntegrationWithoutSavedModelReturnsError(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
setLaunchTestHome(t, tmpDir)
|
||||||
|
withLauncherHooks(t)
|
||||||
|
withInteractiveSession(t, false)
|
||||||
|
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusNotFound)
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
t.Setenv("OLLAMA_HOST", srv.URL)
|
||||||
|
|
||||||
|
stub := &launcherSingleRunner{}
|
||||||
|
restore := OverrideIntegration("stubapp", stub)
|
||||||
|
defer restore()
|
||||||
|
|
||||||
|
oldSelector := DefaultSingleSelector
|
||||||
|
defer func() { DefaultSingleSelector = oldSelector }()
|
||||||
|
DefaultSingleSelector = func(title string, items []ModelItem, current string) (string, error) {
|
||||||
|
t.Fatal("selector should not be called for headless --yes without saved model")
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd := LaunchCmd(func(cmd *cobra.Command, args []string) error { return nil }, func(cmd *cobra.Command) {})
|
||||||
|
cmd.SetArgs([]string{"stubapp", "--yes"})
|
||||||
|
err := cmd.Execute()
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected launch command to fail when --yes is used headlessly without --model")
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "requires --model <model>") {
|
||||||
|
t.Fatalf("expected actionable --model guidance, got %v", err)
|
||||||
|
}
|
||||||
|
if stub.ranModel != "" {
|
||||||
|
t.Fatalf("expected launch to abort before run, got %q", stub.ranModel)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,16 +1,14 @@
|
|||||||
package config
|
package launch
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
"os/exec"
|
"os/exec"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"slices"
|
"slices"
|
||||||
|
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/cmd/internal/fileutil"
|
||||||
"github.com/ollama/ollama/envconfig"
|
"github.com/ollama/ollama/envconfig"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -47,25 +45,6 @@ func (d *Droid) Run(model string, args []string) error {
|
|||||||
return fmt.Errorf("droid is not installed, install from https://docs.factory.ai/cli/getting-started/quickstart")
|
return fmt.Errorf("droid is not installed, install from https://docs.factory.ai/cli/getting-started/quickstart")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Call Edit() to ensure config is up-to-date before launch
|
|
||||||
models := []string{model}
|
|
||||||
if config, err := loadIntegration("droid"); err == nil && len(config.Models) > 0 {
|
|
||||||
models = config.Models
|
|
||||||
}
|
|
||||||
var err error
|
|
||||||
models, err = resolveEditorModels("droid", models, func() ([]string, error) {
|
|
||||||
return selectModels(context.Background(), "droid", "")
|
|
||||||
})
|
|
||||||
if errors.Is(err, errCancelled) {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if err := d.Edit(models); err != nil {
|
|
||||||
return fmt.Errorf("setup failed: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
cmd := exec.Command("droid", args...)
|
cmd := exec.Command("droid", args...)
|
||||||
cmd.Stdin = os.Stdin
|
cmd.Stdin = os.Stdin
|
||||||
cmd.Stdout = os.Stdout
|
cmd.Stdout = os.Stdout
|
||||||
@@ -111,6 +90,16 @@ func (d *Droid) Edit(models []string) error {
|
|||||||
json.Unmarshal(data, &settings) // ignore error, zero values are fine
|
json.Unmarshal(data, &settings) // ignore error, zero values are fine
|
||||||
}
|
}
|
||||||
|
|
||||||
|
settingsMap = updateDroidSettings(settingsMap, settings, models)
|
||||||
|
|
||||||
|
data, err := json.MarshalIndent(settingsMap, "", " ")
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return fileutil.WriteWithBackup(settingsPath, data)
|
||||||
|
}
|
||||||
|
|
||||||
|
func updateDroidSettings(settingsMap map[string]any, settings droidSettings, models []string) map[string]any {
|
||||||
// Keep only non-Ollama models from the raw map (preserves extra fields)
|
// Keep only non-Ollama models from the raw map (preserves extra fields)
|
||||||
// Rebuild Ollama models
|
// Rebuild Ollama models
|
||||||
var nonOllamaModels []any
|
var nonOllamaModels []any
|
||||||
@@ -125,13 +114,12 @@ func (d *Droid) Edit(models []string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Build new Ollama model entries with sequential indices (0, 1, 2, ...)
|
// Build new Ollama model entries with sequential indices (0, 1, 2, ...)
|
||||||
client, _ := api.ClientFromEnvironment()
|
|
||||||
|
|
||||||
var newModels []any
|
var newModels []any
|
||||||
var defaultModelID string
|
var defaultModelID string
|
||||||
for i, model := range models {
|
for i, model := range models {
|
||||||
maxOutput := 64000
|
maxOutput := 64000
|
||||||
if isCloudModel(context.Background(), client, model) {
|
if isCloudModelName(model) {
|
||||||
if l, ok := lookupCloudModelLimit(model); ok {
|
if l, ok := lookupCloudModelLimit(model); ok {
|
||||||
maxOutput = l.Output
|
maxOutput = l.Output
|
||||||
}
|
}
|
||||||
@@ -167,12 +155,7 @@ func (d *Droid) Edit(models []string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
settingsMap["sessionDefaultSettings"] = sessionSettings
|
settingsMap["sessionDefaultSettings"] = sessionSettings
|
||||||
|
return settingsMap
|
||||||
data, err := json.MarshalIndent(settingsMap, "", " ")
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return writeWithBackup(settingsPath, data)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Droid) Models() []string {
|
func (d *Droid) Models() []string {
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package config
|
package launch
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
@@ -6,6 +6,8 @@ import (
|
|||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/cmd/internal/fileutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestDroidIntegration(t *testing.T) {
|
func TestDroidIntegration(t *testing.T) {
|
||||||
@@ -362,7 +364,7 @@ func TestDroidEdit_DuplicateModels(t *testing.T) {
|
|||||||
t.Fatalf("Edit with duplicates failed: %v", err)
|
t.Fatalf("Edit with duplicates failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
settings, err := readJSONFile(settingsPath)
|
settings, err := fileutil.ReadJSON(settingsPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("readJSONFile failed: %v", err)
|
t.Fatalf("readJSONFile failed: %v", err)
|
||||||
}
|
}
|
||||||
@@ -392,7 +394,7 @@ func TestDroidEdit_MalformedModelEntry(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Malformed entries (non-object) are dropped - only valid model objects are preserved
|
// Malformed entries (non-object) are dropped - only valid model objects are preserved
|
||||||
settings, _ := readJSONFile(settingsPath)
|
settings, _ := fileutil.ReadJSON(settingsPath)
|
||||||
customModels, _ := settings["customModels"].([]any)
|
customModels, _ := settings["customModels"].([]any)
|
||||||
|
|
||||||
// Should have: 1 new Ollama model only (malformed entries dropped)
|
// Should have: 1 new Ollama model only (malformed entries dropped)
|
||||||
@@ -419,7 +421,7 @@ func TestDroidEdit_WrongTypeSessionSettings(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Should create proper sessionDefaultSettings
|
// Should create proper sessionDefaultSettings
|
||||||
settings, _ := readJSONFile(settingsPath)
|
settings, _ := fileutil.ReadJSON(settingsPath)
|
||||||
session, ok := settings["sessionDefaultSettings"].(map[string]any)
|
session, ok := settings["sessionDefaultSettings"].(map[string]any)
|
||||||
if !ok {
|
if !ok {
|
||||||
t.Fatalf("sessionDefaultSettings should be map after setup, got %T", settings["sessionDefaultSettings"])
|
t.Fatalf("sessionDefaultSettings should be map after setup, got %T", settings["sessionDefaultSettings"])
|
||||||
@@ -1008,34 +1010,34 @@ func TestDroidEdit_ModelNamesWithSpecialCharacters(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestDroidEdit_MissingCustomModelsKey(t *testing.T) {
|
func TestDroidEdit_MissingCustomModelsKey(t *testing.T) {
|
||||||
d := &Droid{}
|
|
||||||
tmpDir := t.TempDir()
|
|
||||||
setTestHome(t, tmpDir)
|
|
||||||
|
|
||||||
settingsDir := filepath.Join(tmpDir, ".factory")
|
|
||||||
settingsPath := filepath.Join(settingsDir, "settings.json")
|
|
||||||
|
|
||||||
os.MkdirAll(settingsDir, 0o755)
|
|
||||||
|
|
||||||
// No customModels key at all
|
// No customModels key at all
|
||||||
original := `{
|
original := `{
|
||||||
"diffMode": "github",
|
"diffMode": "github",
|
||||||
"sessionDefaultSettings": {"autonomyMode": "auto-high"}
|
"sessionDefaultSettings": {"autonomyMode": "auto-high"}
|
||||||
}`
|
}`
|
||||||
os.WriteFile(settingsPath, []byte(original), 0o644)
|
|
||||||
|
|
||||||
if err := d.Edit([]string{"model-a"}); err != nil {
|
var settingsStruct droidSettings
|
||||||
|
var settings map[string]any
|
||||||
|
if err := json.Unmarshal([]byte(original), &settings); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal([]byte(original), &settingsStruct); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
data, _ := os.ReadFile(settingsPath)
|
settings = updateDroidSettings(settings, settingsStruct, []string{"model-a"})
|
||||||
var settings map[string]any
|
|
||||||
json.Unmarshal(data, &settings)
|
|
||||||
|
|
||||||
// Original fields preserved
|
// Original fields preserved
|
||||||
if settings["diffMode"] != "github" {
|
if settings["diffMode"] != "github" {
|
||||||
t.Error("diffMode not preserved")
|
t.Error("diffMode not preserved")
|
||||||
}
|
}
|
||||||
|
session, ok := settings["sessionDefaultSettings"].(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("sessionDefaultSettings not preserved")
|
||||||
|
}
|
||||||
|
if session["autonomyMode"] != "auto-high" {
|
||||||
|
t.Error("sessionDefaultSettings.autonomyMode not preserved")
|
||||||
|
}
|
||||||
|
|
||||||
// customModels created
|
// customModels created
|
||||||
models, ok := settings["customModels"].([]any)
|
models, ok := settings["customModels"].([]any)
|
||||||
@@ -1276,25 +1278,17 @@ func TestDroidEdit_LocalModelDefaultMaxOutput(t *testing.T) {
|
|||||||
|
|
||||||
func TestDroidEdit_CloudModelLimitsUsed(t *testing.T) {
|
func TestDroidEdit_CloudModelLimitsUsed(t *testing.T) {
|
||||||
// Verify that every cloud model in cloudModelLimits has a valid output
|
// Verify that every cloud model in cloudModelLimits has a valid output
|
||||||
// value that would be used for maxOutputTokens when isCloudModel returns true.
|
// value that would be used for maxOutputTokens when the selected model uses
|
||||||
// :cloud suffix stripping must also work since that's how users specify them.
|
// the explicit :cloud source tag.
|
||||||
for name, expected := range cloudModelLimits {
|
for name, expected := range cloudModelLimits {
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
l, ok := lookupCloudModelLimit(name)
|
|
||||||
if !ok {
|
|
||||||
t.Fatalf("lookupCloudModelLimit(%q) returned false", name)
|
|
||||||
}
|
|
||||||
if l.Output != expected.Output {
|
|
||||||
t.Errorf("output = %d, want %d", l.Output, expected.Output)
|
|
||||||
}
|
|
||||||
// Also verify :cloud suffix lookup
|
|
||||||
cloudName := name + ":cloud"
|
cloudName := name + ":cloud"
|
||||||
l2, ok := lookupCloudModelLimit(cloudName)
|
l, ok := lookupCloudModelLimit(cloudName)
|
||||||
if !ok {
|
if !ok {
|
||||||
t.Fatalf("lookupCloudModelLimit(%q) returned false", cloudName)
|
t.Fatalf("lookupCloudModelLimit(%q) returned false", cloudName)
|
||||||
}
|
}
|
||||||
if l2.Output != expected.Output {
|
if l.Output != expected.Output {
|
||||||
t.Errorf(":cloud output = %d, want %d", l2.Output, expected.Output)
|
t.Errorf("output = %d, want %d", l.Output, expected.Output)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
879
cmd/launch/launch.go
Normal file
@@ -0,0 +1,879 @@
|
|||||||
|
package launch
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
"slices"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/api"
|
||||||
|
"github.com/ollama/ollama/cmd/config"
|
||||||
|
"github.com/spf13/cobra"
|
||||||
|
"golang.org/x/term"
|
||||||
|
)
|
||||||
|
|
||||||
|
// LauncherState is the launch-owned snapshot used to render the root launcher menu.
|
||||||
|
type LauncherState struct {
|
||||||
|
LastSelection string
|
||||||
|
RunModel string
|
||||||
|
RunModelUsable bool
|
||||||
|
Integrations map[string]LauncherIntegrationState
|
||||||
|
}
|
||||||
|
|
||||||
|
// LauncherIntegrationState is the launch-owned status for one launcher integration.
|
||||||
|
type LauncherIntegrationState struct {
|
||||||
|
Name string
|
||||||
|
DisplayName string
|
||||||
|
Description string
|
||||||
|
Installed bool
|
||||||
|
AutoInstallable bool
|
||||||
|
Selectable bool
|
||||||
|
Changeable bool
|
||||||
|
CurrentModel string
|
||||||
|
ModelUsable bool
|
||||||
|
InstallHint string
|
||||||
|
Editor bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// RunModelRequest controls how the root launcher resolves the chat model.
|
||||||
|
type RunModelRequest struct {
|
||||||
|
ForcePicker bool
|
||||||
|
Policy *LaunchPolicy
|
||||||
|
}
|
||||||
|
|
||||||
|
// LaunchConfirmMode controls confirmation behavior across launch flows.
|
||||||
|
type LaunchConfirmMode int
|
||||||
|
|
||||||
|
const (
|
||||||
|
// LaunchConfirmPrompt prompts the user for confirmation.
|
||||||
|
LaunchConfirmPrompt LaunchConfirmMode = iota
|
||||||
|
// LaunchConfirmAutoApprove skips prompts and treats confirmation as accepted.
|
||||||
|
LaunchConfirmAutoApprove
|
||||||
|
// LaunchConfirmRequireYes rejects confirmation requests with a --yes hint.
|
||||||
|
LaunchConfirmRequireYes
|
||||||
|
)
|
||||||
|
|
||||||
|
// LaunchMissingModelMode controls local missing-model handling in launch flows.
|
||||||
|
type LaunchMissingModelMode int
|
||||||
|
|
||||||
|
const (
|
||||||
|
// LaunchMissingModelPromptToPull prompts to pull a missing local model.
|
||||||
|
LaunchMissingModelPromptToPull LaunchMissingModelMode = iota
|
||||||
|
// LaunchMissingModelAutoPull pulls a missing local model without prompting.
|
||||||
|
LaunchMissingModelAutoPull
|
||||||
|
// LaunchMissingModelFail fails immediately when a local model is missing.
|
||||||
|
LaunchMissingModelFail
|
||||||
|
)
|
||||||
|
|
||||||
|
// LaunchPolicy controls launch behavior that may vary by caller context.
|
||||||
|
type LaunchPolicy struct {
|
||||||
|
Confirm LaunchConfirmMode
|
||||||
|
MissingModel LaunchMissingModelMode
|
||||||
|
}
|
||||||
|
|
||||||
|
func defaultLaunchPolicy(interactive bool, yes bool) LaunchPolicy {
|
||||||
|
policy := LaunchPolicy{
|
||||||
|
Confirm: LaunchConfirmPrompt,
|
||||||
|
MissingModel: LaunchMissingModelPromptToPull,
|
||||||
|
}
|
||||||
|
switch {
|
||||||
|
case yes:
|
||||||
|
// if yes flag is set, auto approve and auto pull
|
||||||
|
policy.Confirm = LaunchConfirmAutoApprove
|
||||||
|
policy.MissingModel = LaunchMissingModelAutoPull
|
||||||
|
case !interactive:
|
||||||
|
// otherwise make sure to stop when needed
|
||||||
|
policy.Confirm = LaunchConfirmRequireYes
|
||||||
|
policy.MissingModel = LaunchMissingModelFail
|
||||||
|
}
|
||||||
|
return policy
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p LaunchPolicy) confirmPolicy() launchConfirmPolicy {
|
||||||
|
switch p.Confirm {
|
||||||
|
case LaunchConfirmAutoApprove:
|
||||||
|
return launchConfirmPolicy{yes: true}
|
||||||
|
case LaunchConfirmRequireYes:
|
||||||
|
return launchConfirmPolicy{requireYesMessage: true}
|
||||||
|
default:
|
||||||
|
return launchConfirmPolicy{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p LaunchPolicy) missingModelPolicy() missingModelPolicy {
|
||||||
|
switch p.MissingModel {
|
||||||
|
case LaunchMissingModelAutoPull:
|
||||||
|
return missingModelAutoPull
|
||||||
|
case LaunchMissingModelFail:
|
||||||
|
return missingModelFail
|
||||||
|
default:
|
||||||
|
return missingModelPromptPull
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// IntegrationLaunchRequest controls the canonical integration launcher flow.
|
||||||
|
type IntegrationLaunchRequest struct {
|
||||||
|
Name string
|
||||||
|
ModelOverride string
|
||||||
|
ForceConfigure bool
|
||||||
|
ConfigureOnly bool
|
||||||
|
ExtraArgs []string
|
||||||
|
Policy *LaunchPolicy
|
||||||
|
}
|
||||||
|
|
||||||
|
var isInteractiveSession = func() bool {
|
||||||
|
return term.IsTerminal(int(os.Stdin.Fd())) && term.IsTerminal(int(os.Stdout.Fd()))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Runner executes a model with an integration.
|
||||||
|
type Runner interface {
|
||||||
|
Run(model string, args []string) error
|
||||||
|
String() string
|
||||||
|
}
|
||||||
|
|
||||||
|
// Editor can edit config files for integrations that support model configuration.
|
||||||
|
type Editor interface {
|
||||||
|
Paths() []string
|
||||||
|
Edit(models []string) error
|
||||||
|
Models() []string
|
||||||
|
}
|
||||||
|
|
||||||
|
type modelInfo struct {
|
||||||
|
Name string
|
||||||
|
Remote bool
|
||||||
|
ToolCapable bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// ModelInfo re-exports launcher model inventory details for callers.
|
||||||
|
type ModelInfo = modelInfo
|
||||||
|
|
||||||
|
// ModelItem represents a model for selection UIs.
|
||||||
|
type ModelItem struct {
|
||||||
|
Name string
|
||||||
|
Description string
|
||||||
|
Recommended bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// LaunchCmd returns the cobra command for launching integrations.
|
||||||
|
// The runTUI callback is called when the root launcher UI should be shown.
|
||||||
|
func LaunchCmd(checkServerHeartbeat func(cmd *cobra.Command, args []string) error, runTUI func(cmd *cobra.Command)) *cobra.Command {
|
||||||
|
var modelFlag string
|
||||||
|
var configFlag bool
|
||||||
|
var yesFlag bool
|
||||||
|
|
||||||
|
cmd := &cobra.Command{
|
||||||
|
Use: "launch [INTEGRATION] [-- [EXTRA_ARGS...]]",
|
||||||
|
Short: "Launch the Ollama menu or an integration",
|
||||||
|
Long: `Launch the Ollama interactive menu, or directly launch a specific integration.
|
||||||
|
|
||||||
|
Without arguments, this is equivalent to running 'ollama' directly.
|
||||||
|
Flags and extra arguments require an integration name.
|
||||||
|
|
||||||
|
Supported integrations:
|
||||||
|
claude Claude Code
|
||||||
|
cline Cline
|
||||||
|
codex Codex
|
||||||
|
droid Droid
|
||||||
|
opencode OpenCode
|
||||||
|
openclaw OpenClaw (aliases: clawdbot, moltbot)
|
||||||
|
pi Pi
|
||||||
|
vscode VS Code (aliases: code)
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
ollama launch
|
||||||
|
ollama launch claude
|
||||||
|
ollama launch claude --model <model>
|
||||||
|
ollama launch droid --config (does not auto-launch)
|
||||||
|
ollama launch codex -- -p myprofile (pass extra args to integration)
|
||||||
|
ollama launch codex -- --sandbox workspace-write`,
|
||||||
|
Args: cobra.ArbitraryArgs,
|
||||||
|
PreRunE: checkServerHeartbeat,
|
||||||
|
RunE: func(cmd *cobra.Command, args []string) error {
|
||||||
|
policy := defaultLaunchPolicy(isInteractiveSession(), yesFlag)
|
||||||
|
// reset when done to make sure state doens't leak between launches
|
||||||
|
restoreConfirmPolicy := withLaunchConfirmPolicy(policy.confirmPolicy())
|
||||||
|
defer restoreConfirmPolicy()
|
||||||
|
|
||||||
|
var name string
|
||||||
|
var passArgs []string
|
||||||
|
dashIdx := cmd.ArgsLenAtDash()
|
||||||
|
|
||||||
|
if dashIdx == -1 {
|
||||||
|
if len(args) > 1 {
|
||||||
|
return fmt.Errorf("unexpected arguments: %v\nUse '--' to pass extra arguments to the integration", args[1:])
|
||||||
|
}
|
||||||
|
if len(args) == 1 {
|
||||||
|
name = args[0]
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if dashIdx > 1 {
|
||||||
|
return fmt.Errorf("expected at most 1 integration name before '--', got %d", dashIdx)
|
||||||
|
}
|
||||||
|
if dashIdx == 1 {
|
||||||
|
name = args[0]
|
||||||
|
}
|
||||||
|
passArgs = args[dashIdx:]
|
||||||
|
}
|
||||||
|
|
||||||
|
if name == "" {
|
||||||
|
if cmd.Flags().Changed("model") || cmd.Flags().Changed("config") || cmd.Flags().Changed("yes") || len(passArgs) > 0 {
|
||||||
|
return fmt.Errorf("flags and extra args require an integration name, for example: 'ollama launch claude --model qwen3.5'")
|
||||||
|
}
|
||||||
|
runTUI(cmd)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if modelFlag != "" && isCloudModelName(modelFlag) {
|
||||||
|
if client, err := api.ClientFromEnvironment(); err == nil {
|
||||||
|
if disabled, _ := cloudStatusDisabled(cmd.Context(), client); disabled {
|
||||||
|
fmt.Fprintf(os.Stderr, "Warning: ignoring --model %s because cloud is disabled\n", modelFlag)
|
||||||
|
modelFlag = ""
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
headlessYes := yesFlag && !isInteractiveSession()
|
||||||
|
err := LaunchIntegration(cmd.Context(), IntegrationLaunchRequest{
|
||||||
|
Name: name,
|
||||||
|
ModelOverride: modelFlag,
|
||||||
|
ForceConfigure: configFlag || (modelFlag == "" && !headlessYes),
|
||||||
|
ConfigureOnly: configFlag,
|
||||||
|
ExtraArgs: passArgs,
|
||||||
|
Policy: &policy,
|
||||||
|
})
|
||||||
|
if errors.Is(err, ErrCancelled) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd.Flags().StringVar(&modelFlag, "model", "", "Model to use")
|
||||||
|
cmd.Flags().BoolVar(&configFlag, "config", false, "Configure without launching")
|
||||||
|
cmd.Flags().BoolVarP(&yesFlag, "yes", "y", false, "Automatically answer yes to confirmation prompts")
|
||||||
|
return cmd
|
||||||
|
}
|
||||||
|
|
||||||
|
type launcherClient struct {
|
||||||
|
apiClient *api.Client
|
||||||
|
modelInventory []ModelInfo
|
||||||
|
inventoryLoaded bool
|
||||||
|
policy LaunchPolicy
|
||||||
|
}
|
||||||
|
|
||||||
|
func newLauncherClient(policy LaunchPolicy) (*launcherClient, error) {
|
||||||
|
apiClient, err := api.ClientFromEnvironment()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &launcherClient{
|
||||||
|
apiClient: apiClient,
|
||||||
|
policy: policy,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// BuildLauncherState returns the launch-owned root launcher menu snapshot.
|
||||||
|
func BuildLauncherState(ctx context.Context) (*LauncherState, error) {
|
||||||
|
launchClient, err := newLauncherClient(defaultLaunchPolicy(isInteractiveSession(), false))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return launchClient.buildLauncherState(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResolveRunModel returns the model that should be used for interactive chat.
|
||||||
|
func ResolveRunModel(ctx context.Context, req RunModelRequest) (string, error) {
|
||||||
|
// Called by the launcher TUI "Run a model" action (cmd/runLauncherAction),
|
||||||
|
// which resolves models separately from LaunchIntegration. Callers can pass
|
||||||
|
// Policy directly; otherwise we fall back to ambient --yes/session defaults.
|
||||||
|
policy := defaultLaunchPolicy(isInteractiveSession(), currentLaunchConfirmPolicy.yes)
|
||||||
|
if req.Policy != nil {
|
||||||
|
policy = *req.Policy
|
||||||
|
}
|
||||||
|
|
||||||
|
launchClient, err := newLauncherClient(policy)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return launchClient.resolveRunModel(ctx, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
// LaunchIntegration runs the canonical launcher flow for one integration.
|
||||||
|
func LaunchIntegration(ctx context.Context, req IntegrationLaunchRequest) error {
|
||||||
|
name, runner, err := LookupIntegration(req.Name)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if !req.ConfigureOnly {
|
||||||
|
if err := EnsureIntegrationInstalled(name, runner); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var policy LaunchPolicy
|
||||||
|
// TUI does not set a policy, whereas ollama launch <app> does as it can have flags which change the behavior
|
||||||
|
if req.Policy == nil {
|
||||||
|
policy = defaultLaunchPolicy(isInteractiveSession(), false)
|
||||||
|
} else {
|
||||||
|
policy = *req.Policy
|
||||||
|
}
|
||||||
|
|
||||||
|
launchClient, err := newLauncherClient(policy)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
saved, _ := loadStoredIntegrationConfig(name)
|
||||||
|
// In headless --yes mode we cannot prompt, so require an explicit --model.
|
||||||
|
if policy.Confirm == LaunchConfirmAutoApprove && !isInteractiveSession() && req.ModelOverride == "" {
|
||||||
|
return fmt.Errorf("headless --yes launch for %s requires --model <model>", name)
|
||||||
|
}
|
||||||
|
|
||||||
|
if editor, ok := runner.(Editor); ok {
|
||||||
|
return launchClient.launchEditorIntegration(ctx, name, runner, editor, saved, req)
|
||||||
|
}
|
||||||
|
return launchClient.launchSingleIntegration(ctx, name, runner, saved, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *launcherClient) buildLauncherState(ctx context.Context) (*LauncherState, error) {
|
||||||
|
_ = c.loadModelInventoryOnce(ctx)
|
||||||
|
|
||||||
|
state := &LauncherState{
|
||||||
|
LastSelection: config.LastSelection(),
|
||||||
|
RunModel: config.LastModel(),
|
||||||
|
Integrations: make(map[string]LauncherIntegrationState),
|
||||||
|
}
|
||||||
|
runModelUsable, err := c.savedModelUsable(ctx, state.RunModel)
|
||||||
|
if err != nil {
|
||||||
|
runModelUsable = false
|
||||||
|
}
|
||||||
|
state.RunModelUsable = runModelUsable
|
||||||
|
|
||||||
|
for _, info := range ListIntegrationInfos() {
|
||||||
|
integrationState, err := c.buildLauncherIntegrationState(ctx, info)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
state.Integrations[info.Name] = integrationState
|
||||||
|
}
|
||||||
|
|
||||||
|
return state, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *launcherClient) buildLauncherIntegrationState(ctx context.Context, info IntegrationInfo) (LauncherIntegrationState, error) {
|
||||||
|
integration, err := integrationFor(info.Name)
|
||||||
|
if err != nil {
|
||||||
|
return LauncherIntegrationState{}, err
|
||||||
|
}
|
||||||
|
currentModel, usable, err := c.launcherModelState(ctx, info.Name, integration.editor)
|
||||||
|
if err != nil {
|
||||||
|
return LauncherIntegrationState{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return LauncherIntegrationState{
|
||||||
|
Name: info.Name,
|
||||||
|
DisplayName: info.DisplayName,
|
||||||
|
Description: info.Description,
|
||||||
|
Installed: integration.installed,
|
||||||
|
AutoInstallable: integration.autoInstallable,
|
||||||
|
Selectable: integration.installed || integration.autoInstallable,
|
||||||
|
Changeable: integration.installed || integration.autoInstallable,
|
||||||
|
CurrentModel: currentModel,
|
||||||
|
ModelUsable: usable,
|
||||||
|
InstallHint: integration.installHint,
|
||||||
|
Editor: integration.editor,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *launcherClient) launcherModelState(ctx context.Context, name string, isEditor bool) (string, bool, error) {
|
||||||
|
cfg, loadErr := loadStoredIntegrationConfig(name)
|
||||||
|
hasModels := loadErr == nil && len(cfg.Models) > 0
|
||||||
|
if !hasModels {
|
||||||
|
return "", false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if isEditor {
|
||||||
|
filtered := c.filterDisabledCloudModels(ctx, cfg.Models)
|
||||||
|
if len(filtered) > 0 {
|
||||||
|
return filtered[0], true, nil
|
||||||
|
}
|
||||||
|
return cfg.Models[0], false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
model := cfg.Models[0]
|
||||||
|
usable, usableErr := c.savedModelUsable(ctx, model)
|
||||||
|
return model, usableErr == nil && usable, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *launcherClient) resolveRunModel(ctx context.Context, req RunModelRequest) (string, error) {
|
||||||
|
current := config.LastModel()
|
||||||
|
if !req.ForcePicker && current != "" && c.policy.Confirm == LaunchConfirmAutoApprove && !isInteractiveSession() {
|
||||||
|
if err := c.ensureModelsReady(ctx, []string{current}); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
fmt.Fprintf(os.Stderr, "Headless mode: auto-selected last used model %q\n", current)
|
||||||
|
return current, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if !req.ForcePicker {
|
||||||
|
usable, err := c.savedModelUsable(ctx, current)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
if usable {
|
||||||
|
if err := c.ensureModelsReady(ctx, []string{current}); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return current, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
model, err := c.selectSingleModelWithSelector(ctx, "Select model to run:", current, DefaultSingleSelector)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
if model != current {
|
||||||
|
if err := config.SetLastModel(model); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return model, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *launcherClient) launchSingleIntegration(ctx context.Context, name string, runner Runner, saved *config.IntegrationConfig, req IntegrationLaunchRequest) error {
|
||||||
|
current := primaryModelFromConfig(saved)
|
||||||
|
target := req.ModelOverride
|
||||||
|
needsConfigure := req.ForceConfigure
|
||||||
|
|
||||||
|
if target == "" {
|
||||||
|
target = current
|
||||||
|
usable, err := c.savedModelUsable(ctx, target)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if !usable {
|
||||||
|
needsConfigure = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if needsConfigure {
|
||||||
|
selected, err := c.selectSingleModelWithSelector(ctx, fmt.Sprintf("Select model for %s:", runner), target, DefaultSingleSelector)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
target = selected
|
||||||
|
} else if err := c.ensureModelsReady(ctx, []string{target}); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if target == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if target != current {
|
||||||
|
if err := config.SaveIntegration(name, []string{target}); err != nil {
|
||||||
|
return fmt.Errorf("failed to save: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return launchAfterConfiguration(name, runner, target, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *launcherClient) launchEditorIntegration(ctx context.Context, name string, runner Runner, editor Editor, saved *config.IntegrationConfig, req IntegrationLaunchRequest) error {
|
||||||
|
models, needsConfigure := c.resolveEditorLaunchModels(ctx, saved, req)
|
||||||
|
|
||||||
|
if needsConfigure {
|
||||||
|
selected, err := c.selectMultiModelsForIntegration(ctx, runner, models)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
models = selected
|
||||||
|
} else if len(models) > 0 {
|
||||||
|
if err := c.ensureModelsReady(ctx, models[:1]); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(models) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if (needsConfigure || req.ModelOverride != "") && !savedMatchesModels(saved, models) {
|
||||||
|
if err := prepareEditorIntegration(name, runner, editor, models); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return launchAfterConfiguration(name, runner, models[0], req)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *launcherClient) selectSingleModelWithSelector(ctx context.Context, title, current string, selector SingleSelector) (string, error) {
|
||||||
|
if selector == nil {
|
||||||
|
return "", fmt.Errorf("no selector configured")
|
||||||
|
}
|
||||||
|
|
||||||
|
items, _, err := c.loadSelectableModels(ctx, nil, current, "no models available, run 'ollama pull <model>' first")
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
selected, err := selector(title, items, current)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
if err := c.ensureModelsReady(ctx, []string{selected}); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return selected, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *launcherClient) selectMultiModelsForIntegration(ctx context.Context, runner Runner, preChecked []string) ([]string, error) {
|
||||||
|
if DefaultMultiSelector == nil {
|
||||||
|
return nil, fmt.Errorf("no selector configured")
|
||||||
|
}
|
||||||
|
|
||||||
|
current := firstModel(preChecked)
|
||||||
|
|
||||||
|
items, orderedChecked, err := c.loadSelectableModels(ctx, preChecked, current, "no models available")
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
selected, err := DefaultMultiSelector(fmt.Sprintf("Select models for %s:", runner), items, orderedChecked)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
accepted, skipped, err := c.selectReadyModelsForSave(ctx, selected)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
for _, skip := range skipped {
|
||||||
|
fmt.Fprintf(os.Stderr, "Skipped %s: %s\n", skip.model, skip.reason)
|
||||||
|
}
|
||||||
|
return accepted, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *launcherClient) loadSelectableModels(ctx context.Context, preChecked []string, current, emptyMessage string) ([]ModelItem, []string, error) {
|
||||||
|
if err := c.loadModelInventoryOnce(ctx); err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
cloudDisabled, _ := cloudStatusDisabled(ctx, c.apiClient)
|
||||||
|
items, orderedChecked, _, _ := buildModelList(c.modelInventory, preChecked, current)
|
||||||
|
if cloudDisabled {
|
||||||
|
items = filterCloudItems(items)
|
||||||
|
orderedChecked = c.filterDisabledCloudModels(ctx, orderedChecked)
|
||||||
|
}
|
||||||
|
if len(items) == 0 {
|
||||||
|
return nil, nil, errors.New(emptyMessage)
|
||||||
|
}
|
||||||
|
return items, orderedChecked, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *launcherClient) ensureModelsReady(ctx context.Context, models []string) error {
|
||||||
|
models = dedupeModelList(models)
|
||||||
|
if len(models) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
cloudModels := make(map[string]bool, len(models))
|
||||||
|
for _, model := range models {
|
||||||
|
isCloudModel := isCloudModelName(model)
|
||||||
|
if isCloudModel {
|
||||||
|
cloudModels[model] = true
|
||||||
|
}
|
||||||
|
if err := showOrPullWithPolicy(ctx, c.apiClient, model, c.policy.missingModelPolicy(), isCloudModel); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ensureAuth(ctx, c.apiClient, cloudModels, models)
|
||||||
|
}
|
||||||
|
|
||||||
|
func dedupeModelList(models []string) []string {
|
||||||
|
deduped := make([]string, 0, len(models))
|
||||||
|
seen := make(map[string]bool, len(models))
|
||||||
|
for _, model := range models {
|
||||||
|
if model == "" || seen[model] {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
seen[model] = true
|
||||||
|
deduped = append(deduped, model)
|
||||||
|
}
|
||||||
|
return deduped
|
||||||
|
}
|
||||||
|
|
||||||
|
type skippedModel struct {
|
||||||
|
model string
|
||||||
|
reason string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *launcherClient) selectReadyModelsForSave(ctx context.Context, selected []string) ([]string, []skippedModel, error) {
|
||||||
|
selected = dedupeModelList(selected)
|
||||||
|
accepted := make([]string, 0, len(selected))
|
||||||
|
skipped := make([]skippedModel, 0, len(selected))
|
||||||
|
|
||||||
|
for _, model := range selected {
|
||||||
|
if err := c.ensureModelsReady(ctx, []string{model}); err != nil {
|
||||||
|
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
skipped = append(skipped, skippedModel{
|
||||||
|
model: model,
|
||||||
|
reason: skippedModelReason(model, err),
|
||||||
|
})
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
accepted = append(accepted, model)
|
||||||
|
}
|
||||||
|
|
||||||
|
return accepted, skipped, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func skippedModelReason(model string, err error) string {
|
||||||
|
if errors.Is(err, ErrCancelled) {
|
||||||
|
if isCloudModelName(model) {
|
||||||
|
return "sign in was cancelled"
|
||||||
|
}
|
||||||
|
return "download was cancelled"
|
||||||
|
}
|
||||||
|
return err.Error()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *launcherClient) resolveEditorLaunchModels(ctx context.Context, saved *config.IntegrationConfig, req IntegrationLaunchRequest) ([]string, bool) {
|
||||||
|
if req.ForceConfigure {
|
||||||
|
return editorPreCheckedModels(saved, req.ModelOverride), true
|
||||||
|
}
|
||||||
|
|
||||||
|
if req.ModelOverride != "" {
|
||||||
|
models := append([]string{req.ModelOverride}, additionalSavedModels(saved, req.ModelOverride)...)
|
||||||
|
models = c.filterDisabledCloudModels(ctx, models)
|
||||||
|
return models, len(models) == 0
|
||||||
|
}
|
||||||
|
|
||||||
|
if saved == nil || len(saved.Models) == 0 {
|
||||||
|
return nil, true
|
||||||
|
}
|
||||||
|
|
||||||
|
models := c.filterDisabledCloudModels(ctx, saved.Models)
|
||||||
|
return models, len(models) == 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *launcherClient) filterDisabledCloudModels(ctx context.Context, models []string) []string {
|
||||||
|
// if connection cannot be established or there is a 404, cloud models will continue to be displayed
|
||||||
|
cloudDisabled, _ := cloudStatusDisabled(ctx, c.apiClient)
|
||||||
|
if !cloudDisabled {
|
||||||
|
return append([]string(nil), models...)
|
||||||
|
}
|
||||||
|
|
||||||
|
filtered := make([]string, 0, len(models))
|
||||||
|
for _, model := range models {
|
||||||
|
if !isCloudModelName(model) {
|
||||||
|
filtered = append(filtered, model)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return filtered
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *launcherClient) savedModelUsable(ctx context.Context, name string) (bool, error) {
|
||||||
|
if err := c.loadModelInventoryOnce(ctx); err != nil {
|
||||||
|
return c.showBasedModelUsable(ctx, name)
|
||||||
|
}
|
||||||
|
return c.singleModelUsable(ctx, name), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *launcherClient) showBasedModelUsable(ctx context.Context, name string) (bool, error) {
|
||||||
|
if name == "" {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
info, err := c.apiClient.Show(ctx, &api.ShowRequest{Model: name})
|
||||||
|
if err != nil {
|
||||||
|
var statusErr api.StatusError
|
||||||
|
if errors.As(err, &statusErr) && statusErr.StatusCode == http.StatusNotFound {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if isCloudModelName(name) || info.RemoteModel != "" {
|
||||||
|
cloudDisabled, _ := cloudStatusDisabled(ctx, c.apiClient)
|
||||||
|
|
||||||
|
return !cloudDisabled, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *launcherClient) singleModelUsable(ctx context.Context, name string) bool {
|
||||||
|
if name == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if isCloudModelName(name) {
|
||||||
|
cloudDisabled, _ := cloudStatusDisabled(ctx, c.apiClient)
|
||||||
|
return !cloudDisabled
|
||||||
|
}
|
||||||
|
return c.hasLocalModel(name)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *launcherClient) hasLocalModel(name string) bool {
|
||||||
|
for _, model := range c.modelInventory {
|
||||||
|
if model.Remote {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if model.Name == name || strings.HasPrefix(model.Name, name+":") {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *launcherClient) loadModelInventoryOnce(ctx context.Context) error {
|
||||||
|
if c.inventoryLoaded {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := c.apiClient.List(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
c.modelInventory = c.modelInventory[:0]
|
||||||
|
for _, model := range resp.Models {
|
||||||
|
c.modelInventory = append(c.modelInventory, ModelInfo{
|
||||||
|
Name: model.Name,
|
||||||
|
Remote: model.RemoteModel != "",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
cloudDisabled, _ := cloudStatusDisabled(ctx, c.apiClient)
|
||||||
|
if cloudDisabled {
|
||||||
|
c.modelInventory = filterCloudModels(c.modelInventory)
|
||||||
|
}
|
||||||
|
c.inventoryLoaded = true
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func runIntegration(runner Runner, modelName string, args []string) error {
|
||||||
|
return runner.Run(modelName, args)
|
||||||
|
}
|
||||||
|
|
||||||
|
func launchAfterConfiguration(name string, runner Runner, model string, req IntegrationLaunchRequest) error {
|
||||||
|
if req.ConfigureOnly {
|
||||||
|
launch, err := ConfirmPrompt(fmt.Sprintf("Launch %s now?", runner))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if !launch {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err := EnsureIntegrationInstalled(name, runner); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return runIntegration(runner, model, req.ExtraArgs)
|
||||||
|
}
|
||||||
|
|
||||||
|
func loadStoredIntegrationConfig(name string) (*config.IntegrationConfig, error) {
|
||||||
|
cfg, err := config.LoadIntegration(name)
|
||||||
|
if err == nil {
|
||||||
|
return cfg, nil
|
||||||
|
}
|
||||||
|
if !errors.Is(err, os.ErrNotExist) {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
spec, specErr := LookupIntegrationSpec(name)
|
||||||
|
if specErr != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, alias := range spec.Aliases {
|
||||||
|
legacy, legacyErr := config.LoadIntegration(alias)
|
||||||
|
if legacyErr == nil {
|
||||||
|
migrateLegacyIntegrationConfig(spec.Name, legacy)
|
||||||
|
if migrated, migratedErr := config.LoadIntegration(spec.Name); migratedErr == nil {
|
||||||
|
return migrated, nil
|
||||||
|
}
|
||||||
|
return legacy, nil
|
||||||
|
}
|
||||||
|
if legacyErr != nil && !errors.Is(legacyErr, os.ErrNotExist) {
|
||||||
|
return nil, legacyErr
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func migrateLegacyIntegrationConfig(canonical string, legacy *config.IntegrationConfig) {
|
||||||
|
if legacy == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
_ = config.SaveIntegration(canonical, append([]string(nil), legacy.Models...))
|
||||||
|
if len(legacy.Aliases) > 0 {
|
||||||
|
_ = config.SaveAliases(canonical, cloneAliases(legacy.Aliases))
|
||||||
|
}
|
||||||
|
if legacy.Onboarded {
|
||||||
|
_ = config.MarkIntegrationOnboarded(canonical)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func primaryModelFromConfig(cfg *config.IntegrationConfig) string {
|
||||||
|
if cfg == nil || len(cfg.Models) == 0 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return cfg.Models[0]
|
||||||
|
}
|
||||||
|
|
||||||
|
func cloneAliases(aliases map[string]string) map[string]string {
|
||||||
|
if len(aliases) == 0 {
|
||||||
|
return make(map[string]string)
|
||||||
|
}
|
||||||
|
|
||||||
|
cloned := make(map[string]string, len(aliases))
|
||||||
|
for key, value := range aliases {
|
||||||
|
cloned[key] = value
|
||||||
|
}
|
||||||
|
return cloned
|
||||||
|
}
|
||||||
|
|
||||||
|
func firstModel(models []string) string {
|
||||||
|
if len(models) == 0 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return models[0]
|
||||||
|
}
|
||||||
|
|
||||||
|
func savedMatchesModels(saved *config.IntegrationConfig, models []string) bool {
|
||||||
|
if saved == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return slices.Equal(saved.Models, models)
|
||||||
|
}
|
||||||
|
|
||||||
|
func editorPreCheckedModels(saved *config.IntegrationConfig, override string) []string {
|
||||||
|
if override == "" {
|
||||||
|
if saved == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return append([]string(nil), saved.Models...)
|
||||||
|
}
|
||||||
|
return append([]string{override}, additionalSavedModels(saved, override)...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func additionalSavedModels(saved *config.IntegrationConfig, exclude string) []string {
|
||||||
|
if saved == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var models []string
|
||||||
|
for _, model := range saved.Models {
|
||||||
|
if model != exclude {
|
||||||
|
models = append(models, model)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return models
|
||||||
|
}
|
||||||
1990
cmd/launch/launch_test.go
Normal file
507
cmd/launch/models.go
Normal file
@@ -0,0 +1,507 @@
|
|||||||
|
package launch
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
"os/exec"
|
||||||
|
"runtime"
|
||||||
|
"slices"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/api"
|
||||||
|
"github.com/ollama/ollama/cmd/config"
|
||||||
|
"github.com/ollama/ollama/cmd/internal/fileutil"
|
||||||
|
internalcloud "github.com/ollama/ollama/internal/cloud"
|
||||||
|
"github.com/ollama/ollama/internal/modelref"
|
||||||
|
"github.com/ollama/ollama/progress"
|
||||||
|
)
|
||||||
|
|
||||||
|
var recommendedModels = []ModelItem{
|
||||||
|
{Name: "kimi-k2.5:cloud", Description: "Multimodal reasoning with subagents", Recommended: true},
|
||||||
|
{Name: "qwen3.5:cloud", Description: "Reasoning, coding, and agentic tool use with vision", Recommended: true},
|
||||||
|
{Name: "glm-5.1:cloud", Description: "Reasoning and code generation", Recommended: true},
|
||||||
|
{Name: "minimax-m2.7:cloud", Description: "Fast, efficient coding and real-world productivity", Recommended: true},
|
||||||
|
{Name: "gemma4", Description: "Reasoning and code generation locally", Recommended: true},
|
||||||
|
{Name: "qwen3.5", Description: "Reasoning, coding, and visual understanding locally", Recommended: true},
|
||||||
|
}
|
||||||
|
|
||||||
|
var recommendedVRAM = map[string]string{
|
||||||
|
"gemma4": "~16GB",
|
||||||
|
"qwen3.5": "~11GB",
|
||||||
|
}
|
||||||
|
|
||||||
|
// cloudModelLimit holds context and output token limits for a cloud model.
|
||||||
|
type cloudModelLimit struct {
|
||||||
|
Context int
|
||||||
|
Output int
|
||||||
|
}
|
||||||
|
|
||||||
|
// cloudModelLimits maps cloud model base names to their token limits.
|
||||||
|
// TODO(parthsareen): grab context/output limits from model info instead of hardcoding
|
||||||
|
var cloudModelLimits = map[string]cloudModelLimit{
|
||||||
|
"minimax-m2.7": {Context: 204_800, Output: 128_000},
|
||||||
|
"cogito-2.1:671b": {Context: 163_840, Output: 65_536},
|
||||||
|
"deepseek-v3.1:671b": {Context: 163_840, Output: 163_840},
|
||||||
|
"deepseek-v3.2": {Context: 163_840, Output: 65_536},
|
||||||
|
"gemma4:31b": {Context: 262_144, Output: 131_072},
|
||||||
|
"glm-4.6": {Context: 202_752, Output: 131_072},
|
||||||
|
"glm-4.7": {Context: 202_752, Output: 131_072},
|
||||||
|
"glm-5": {Context: 202_752, Output: 131_072},
|
||||||
|
"glm-5.1": {Context: 202_752, Output: 131_072},
|
||||||
|
"gpt-oss:120b": {Context: 131_072, Output: 131_072},
|
||||||
|
"gpt-oss:20b": {Context: 131_072, Output: 131_072},
|
||||||
|
"kimi-k2:1t": {Context: 262_144, Output: 262_144},
|
||||||
|
"kimi-k2.5": {Context: 262_144, Output: 262_144},
|
||||||
|
"kimi-k2-thinking": {Context: 262_144, Output: 262_144},
|
||||||
|
"nemotron-3-nano:30b": {Context: 1_048_576, Output: 131_072},
|
||||||
|
"qwen3-coder:480b": {Context: 262_144, Output: 65_536},
|
||||||
|
"qwen3-coder-next": {Context: 262_144, Output: 32_768},
|
||||||
|
"qwen3-next:80b": {Context: 262_144, Output: 32_768},
|
||||||
|
"qwen3.5": {Context: 262_144, Output: 32_768},
|
||||||
|
}
|
||||||
|
|
||||||
|
// lookupCloudModelLimit returns the token limits for a cloud model.
|
||||||
|
// It normalizes explicit cloud source suffixes before checking the shared limit map.
|
||||||
|
func lookupCloudModelLimit(name string) (cloudModelLimit, bool) {
|
||||||
|
base, stripped := modelref.StripCloudSourceTag(name)
|
||||||
|
if stripped {
|
||||||
|
if l, ok := cloudModelLimits[base]; ok {
|
||||||
|
return l, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return cloudModelLimit{}, false
|
||||||
|
}
|
||||||
|
|
||||||
|
// missingModelPolicy controls how model-not-found errors should be handled.
|
||||||
|
type missingModelPolicy int
|
||||||
|
|
||||||
|
const (
|
||||||
|
// missingModelPromptPull prompts the user to download missing local models.
|
||||||
|
missingModelPromptPull missingModelPolicy = iota
|
||||||
|
// missingModelAutoPull downloads missing local models without prompting.
|
||||||
|
missingModelAutoPull
|
||||||
|
// missingModelFail returns an error for missing local models without prompting.
|
||||||
|
missingModelFail
|
||||||
|
)
|
||||||
|
|
||||||
|
// OpenBrowser opens the URL in the user's browser.
|
||||||
|
func OpenBrowser(url string) {
|
||||||
|
switch runtime.GOOS {
|
||||||
|
case "darwin":
|
||||||
|
_ = exec.Command("open", url).Start()
|
||||||
|
case "linux":
|
||||||
|
// Skip on headless systems where no display server is available
|
||||||
|
if os.Getenv("DISPLAY") == "" && os.Getenv("WAYLAND_DISPLAY") == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
_ = exec.Command("xdg-open", url).Start()
|
||||||
|
case "windows":
|
||||||
|
_ = exec.Command("rundll32", "url.dll,FileProtocolHandler", url).Start()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ensureAuth ensures the user is signed in before cloud-backed models run.
|
||||||
|
func ensureAuth(ctx context.Context, client *api.Client, cloudModels map[string]bool, selected []string) error {
|
||||||
|
var selectedCloudModels []string
|
||||||
|
for _, m := range selected {
|
||||||
|
if cloudModels[m] {
|
||||||
|
selectedCloudModels = append(selectedCloudModels, m)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(selectedCloudModels) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if disabled, known := cloudStatusDisabled(ctx, client); known && disabled {
|
||||||
|
return errors.New(internalcloud.DisabledError("remote inference is unavailable"))
|
||||||
|
}
|
||||||
|
|
||||||
|
user, err := client.Whoami(ctx)
|
||||||
|
if err == nil && user != nil && user.Name != "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var aErr api.AuthorizationError
|
||||||
|
if !errors.As(err, &aErr) || aErr.SigninURL == "" {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
modelList := strings.Join(selectedCloudModels, ", ")
|
||||||
|
|
||||||
|
if DefaultSignIn != nil {
|
||||||
|
_, err := DefaultSignIn(modelList, aErr.SigninURL)
|
||||||
|
if errors.Is(err, ErrCancelled) {
|
||||||
|
return ErrCancelled
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("%s requires sign in", modelList)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
yes, err := ConfirmPrompt(fmt.Sprintf("sign in to use %s?", modelList))
|
||||||
|
if errors.Is(err, ErrCancelled) {
|
||||||
|
return ErrCancelled
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if !yes {
|
||||||
|
return ErrCancelled
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Fprintf(os.Stderr, "\nTo sign in, navigate to:\n %s\n\n", aErr.SigninURL)
|
||||||
|
OpenBrowser(aErr.SigninURL)
|
||||||
|
|
||||||
|
spinnerFrames := []string{"|", "/", "-", "\\"}
|
||||||
|
frame := 0
|
||||||
|
fmt.Fprintf(os.Stderr, "\033[90mwaiting for sign in to complete... %s\033[0m", spinnerFrames[0])
|
||||||
|
|
||||||
|
ticker := time.NewTicker(200 * time.Millisecond)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
fmt.Fprintf(os.Stderr, "\r\033[K")
|
||||||
|
return ctx.Err()
|
||||||
|
case <-ticker.C:
|
||||||
|
frame++
|
||||||
|
fmt.Fprintf(os.Stderr, "\r\033[90mwaiting for sign in to complete... %s\033[0m", spinnerFrames[frame%len(spinnerFrames)])
|
||||||
|
|
||||||
|
if frame%10 == 0 {
|
||||||
|
u, err := client.Whoami(ctx)
|
||||||
|
if err == nil && u != nil && u.Name != "" {
|
||||||
|
fmt.Fprintf(os.Stderr, "\r\033[K\033[A\r\033[K\033[1msigned in:\033[0m %s\n", u.Name)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// showOrPullWithPolicy checks if a model exists and applies the provided missing-model policy.
|
||||||
|
func showOrPullWithPolicy(ctx context.Context, client *api.Client, model string, policy missingModelPolicy, isCloudModel bool) error {
|
||||||
|
if _, err := client.Show(ctx, &api.ShowRequest{Model: model}); err == nil {
|
||||||
|
return nil
|
||||||
|
} else {
|
||||||
|
var statusErr api.StatusError
|
||||||
|
if !errors.As(err, &statusErr) || statusErr.StatusCode != http.StatusNotFound {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if isCloudModel {
|
||||||
|
if disabled, known := cloudStatusDisabled(ctx, client); known && disabled {
|
||||||
|
return errors.New(internalcloud.DisabledError("remote inference is unavailable"))
|
||||||
|
}
|
||||||
|
return fmt.Errorf("model %q not found", model)
|
||||||
|
}
|
||||||
|
|
||||||
|
switch policy {
|
||||||
|
case missingModelAutoPull:
|
||||||
|
return pullMissingModel(ctx, client, model)
|
||||||
|
case missingModelFail:
|
||||||
|
return fmt.Errorf("model %q not found; run 'ollama pull %s' first, or use --yes to auto-pull", model, model)
|
||||||
|
default:
|
||||||
|
return confirmAndPull(ctx, client, model)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func confirmAndPull(ctx context.Context, client *api.Client, model string) error {
|
||||||
|
if ok, err := ConfirmPrompt(fmt.Sprintf("Download %s?", model)); err != nil {
|
||||||
|
return err
|
||||||
|
} else if !ok {
|
||||||
|
return errCancelled
|
||||||
|
}
|
||||||
|
fmt.Fprintf(os.Stderr, "\n")
|
||||||
|
return pullMissingModel(ctx, client, model)
|
||||||
|
}
|
||||||
|
|
||||||
|
func pullMissingModel(ctx context.Context, client *api.Client, model string) error {
|
||||||
|
if err := pullModel(ctx, client, model, false); err != nil {
|
||||||
|
return fmt.Errorf("failed to pull %s: %w", model, err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// prepareEditorIntegration persists models and applies editor-managed config files.
|
||||||
|
func prepareEditorIntegration(name string, runner Runner, editor Editor, models []string) error {
|
||||||
|
if ok, err := confirmEditorEdit(runner, editor); err != nil {
|
||||||
|
return err
|
||||||
|
} else if !ok {
|
||||||
|
return errCancelled
|
||||||
|
}
|
||||||
|
if err := editor.Edit(models); err != nil {
|
||||||
|
return fmt.Errorf("setup failed: %w", err)
|
||||||
|
}
|
||||||
|
if err := config.SaveIntegration(name, models); err != nil {
|
||||||
|
return fmt.Errorf("failed to save: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func confirmEditorEdit(runner Runner, editor Editor) (bool, error) {
|
||||||
|
paths := editor.Paths()
|
||||||
|
if len(paths) == 0 {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Fprintf(os.Stderr, "This will modify your %s configuration:\n", runner)
|
||||||
|
for _, path := range paths {
|
||||||
|
fmt.Fprintf(os.Stderr, " %s\n", path)
|
||||||
|
}
|
||||||
|
fmt.Fprintf(os.Stderr, "Backups will be saved to %s/\n\n", fileutil.BackupDir())
|
||||||
|
|
||||||
|
return ConfirmPrompt("Proceed?")
|
||||||
|
}
|
||||||
|
|
||||||
|
// buildModelList merges existing models with recommendations for selection UIs.
|
||||||
|
func buildModelList(existing []modelInfo, preChecked []string, current string) (items []ModelItem, orderedChecked []string, existingModels, cloudModels map[string]bool) {
|
||||||
|
existingModels = make(map[string]bool)
|
||||||
|
cloudModels = make(map[string]bool)
|
||||||
|
recommended := make(map[string]bool)
|
||||||
|
var hasLocalModel, hasCloudModel bool
|
||||||
|
|
||||||
|
recDesc := make(map[string]string)
|
||||||
|
for _, rec := range recommendedModels {
|
||||||
|
recommended[rec.Name] = true
|
||||||
|
recDesc[rec.Name] = rec.Description
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, m := range existing {
|
||||||
|
existingModels[m.Name] = true
|
||||||
|
if m.Remote {
|
||||||
|
cloudModels[m.Name] = true
|
||||||
|
hasCloudModel = true
|
||||||
|
} else {
|
||||||
|
hasLocalModel = true
|
||||||
|
}
|
||||||
|
displayName := strings.TrimSuffix(m.Name, ":latest")
|
||||||
|
existingModels[displayName] = true
|
||||||
|
item := ModelItem{Name: displayName, Recommended: recommended[displayName], Description: recDesc[displayName]}
|
||||||
|
items = append(items, item)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, rec := range recommendedModels {
|
||||||
|
if existingModels[rec.Name] || existingModels[rec.Name+":latest"] {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
items = append(items, rec)
|
||||||
|
if isCloudModelName(rec.Name) {
|
||||||
|
cloudModels[rec.Name] = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
checked := make(map[string]bool, len(preChecked))
|
||||||
|
for _, n := range preChecked {
|
||||||
|
checked[n] = true
|
||||||
|
}
|
||||||
|
|
||||||
|
if current != "" {
|
||||||
|
matchedCurrent := false
|
||||||
|
for _, item := range items {
|
||||||
|
if item.Name == current {
|
||||||
|
current = item.Name
|
||||||
|
matchedCurrent = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !matchedCurrent {
|
||||||
|
for _, item := range items {
|
||||||
|
if strings.HasPrefix(item.Name, current+":") {
|
||||||
|
current = item.Name
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if checked[current] {
|
||||||
|
preChecked = append([]string{current}, slices.DeleteFunc(preChecked, func(m string) bool { return m == current })...)
|
||||||
|
}
|
||||||
|
|
||||||
|
notInstalled := make(map[string]bool)
|
||||||
|
for i := range items {
|
||||||
|
if !existingModels[items[i].Name] && !cloudModels[items[i].Name] {
|
||||||
|
notInstalled[items[i].Name] = true
|
||||||
|
var parts []string
|
||||||
|
if items[i].Description != "" {
|
||||||
|
parts = append(parts, items[i].Description)
|
||||||
|
}
|
||||||
|
if vram := recommendedVRAM[items[i].Name]; vram != "" {
|
||||||
|
parts = append(parts, vram)
|
||||||
|
}
|
||||||
|
parts = append(parts, "(not downloaded)")
|
||||||
|
items[i].Description = strings.Join(parts, ", ")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
recRank := make(map[string]int)
|
||||||
|
for i, rec := range recommendedModels {
|
||||||
|
recRank[rec.Name] = i + 1
|
||||||
|
}
|
||||||
|
|
||||||
|
onlyLocal := hasLocalModel && !hasCloudModel
|
||||||
|
|
||||||
|
if hasLocalModel || hasCloudModel {
|
||||||
|
slices.SortStableFunc(items, func(a, b ModelItem) int {
|
||||||
|
ac, bc := checked[a.Name], checked[b.Name]
|
||||||
|
aNew, bNew := notInstalled[a.Name], notInstalled[b.Name]
|
||||||
|
aRec, bRec := recRank[a.Name] > 0, recRank[b.Name] > 0
|
||||||
|
aCloud, bCloud := cloudModels[a.Name], cloudModels[b.Name]
|
||||||
|
|
||||||
|
if ac != bc {
|
||||||
|
if ac {
|
||||||
|
return -1
|
||||||
|
}
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
if aRec != bRec {
|
||||||
|
if aRec {
|
||||||
|
return -1
|
||||||
|
}
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
if aRec && bRec {
|
||||||
|
if aCloud != bCloud {
|
||||||
|
if onlyLocal {
|
||||||
|
if aCloud {
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
return -1
|
||||||
|
}
|
||||||
|
if aCloud {
|
||||||
|
return -1
|
||||||
|
}
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
return recRank[a.Name] - recRank[b.Name]
|
||||||
|
}
|
||||||
|
// Among checked non-recommended items - put the default first
|
||||||
|
if ac && !aRec && current != "" {
|
||||||
|
aCurrent := a.Name == current
|
||||||
|
bCurrent := b.Name == current
|
||||||
|
if aCurrent != bCurrent {
|
||||||
|
if aCurrent {
|
||||||
|
return -1
|
||||||
|
}
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if aNew != bNew {
|
||||||
|
if aNew {
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
return -1
|
||||||
|
}
|
||||||
|
return strings.Compare(strings.ToLower(a.Name), strings.ToLower(b.Name))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return items, preChecked, existingModels, cloudModels
|
||||||
|
}
|
||||||
|
|
||||||
|
// isCloudModelName reports whether the model name has an explicit cloud source.
|
||||||
|
func isCloudModelName(name string) bool {
|
||||||
|
return modelref.HasExplicitCloudSource(name)
|
||||||
|
}
|
||||||
|
|
||||||
|
// filterCloudModels drops remote-only models from the given inventory.
|
||||||
|
func filterCloudModels(existing []modelInfo) []modelInfo {
|
||||||
|
filtered := existing[:0]
|
||||||
|
for _, m := range existing {
|
||||||
|
if !m.Remote {
|
||||||
|
filtered = append(filtered, m)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return filtered
|
||||||
|
}
|
||||||
|
|
||||||
|
// filterCloudItems removes cloud models from selection items.
|
||||||
|
func filterCloudItems(items []ModelItem) []ModelItem {
|
||||||
|
filtered := items[:0]
|
||||||
|
for _, item := range items {
|
||||||
|
if !isCloudModelName(item.Name) {
|
||||||
|
filtered = append(filtered, item)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return filtered
|
||||||
|
}
|
||||||
|
|
||||||
|
func isCloudModel(ctx context.Context, client *api.Client, name string) bool {
|
||||||
|
if client == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
resp, err := client.Show(ctx, &api.ShowRequest{Model: name})
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return resp.RemoteModel != ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// cloudStatusDisabled returns whether cloud usage is currently disabled.
|
||||||
|
func cloudStatusDisabled(ctx context.Context, client *api.Client) (disabled bool, known bool) {
|
||||||
|
status, err := client.CloudStatusExperimental(ctx)
|
||||||
|
if err != nil {
|
||||||
|
var statusErr api.StatusError
|
||||||
|
if errors.As(err, &statusErr) && statusErr.StatusCode == http.StatusNotFound {
|
||||||
|
return false, false
|
||||||
|
}
|
||||||
|
return false, false
|
||||||
|
}
|
||||||
|
return status.Cloud.Disabled, true
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO(parthsareen): this duplicates the pull progress UI in cmd.PullHandler.
|
||||||
|
// Move the shared pull rendering to a small utility once the package boundary settles.
|
||||||
|
func pullModel(ctx context.Context, client *api.Client, model string, insecure bool) error {
|
||||||
|
p := progress.NewProgress(os.Stderr)
|
||||||
|
defer p.Stop()
|
||||||
|
|
||||||
|
bars := make(map[string]*progress.Bar)
|
||||||
|
var status string
|
||||||
|
var spinner *progress.Spinner
|
||||||
|
|
||||||
|
fn := func(resp api.ProgressResponse) error {
|
||||||
|
if resp.Digest != "" {
|
||||||
|
if resp.Completed == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if spinner != nil {
|
||||||
|
spinner.Stop()
|
||||||
|
}
|
||||||
|
|
||||||
|
bar, ok := bars[resp.Digest]
|
||||||
|
if !ok {
|
||||||
|
name, isDigest := strings.CutPrefix(resp.Digest, "sha256:")
|
||||||
|
name = strings.TrimSpace(name)
|
||||||
|
if isDigest {
|
||||||
|
name = name[:min(12, len(name))]
|
||||||
|
}
|
||||||
|
bar = progress.NewBar(fmt.Sprintf("pulling %s:", name), resp.Total, resp.Completed)
|
||||||
|
bars[resp.Digest] = bar
|
||||||
|
p.Add(resp.Digest, bar)
|
||||||
|
}
|
||||||
|
|
||||||
|
bar.Set(resp.Completed)
|
||||||
|
} else if status != resp.Status {
|
||||||
|
if spinner != nil {
|
||||||
|
spinner.Stop()
|
||||||
|
}
|
||||||
|
|
||||||
|
status = resp.Status
|
||||||
|
spinner = progress.NewSpinner(status)
|
||||||
|
p.Add(status, spinner)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
request := api.PullRequest{Name: model, Insecure: insecure}
|
||||||
|
return client.Pull(ctx, &request, fn)
|
||||||
|
}
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package config
|
package launch
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
@@ -14,7 +14,10 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"golang.org/x/mod/semver"
|
||||||
|
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
|
"github.com/ollama/ollama/cmd/internal/fileutil"
|
||||||
"github.com/ollama/ollama/envconfig"
|
"github.com/ollama/ollama/envconfig"
|
||||||
"github.com/ollama/ollama/types/model"
|
"github.com/ollama/ollama/types/model"
|
||||||
)
|
)
|
||||||
@@ -24,6 +27,9 @@ const defaultGatewayPort = 18789
|
|||||||
// Bound model capability probing so launch/config cannot hang on slow/unreachable API calls.
|
// Bound model capability probing so launch/config cannot hang on slow/unreachable API calls.
|
||||||
var openclawModelShowTimeout = 5 * time.Second
|
var openclawModelShowTimeout = 5 * time.Second
|
||||||
|
|
||||||
|
// openclawFreshInstall is set to true when ensureOpenclawInstalled performs an install
|
||||||
|
var openclawFreshInstall bool
|
||||||
|
|
||||||
type Openclaw struct{}
|
type Openclaw struct{}
|
||||||
|
|
||||||
func (c *Openclaw) String() string { return "OpenClaw" }
|
func (c *Openclaw) String() string { return "OpenClaw" }
|
||||||
@@ -34,10 +40,7 @@ func (c *Openclaw) Run(model string, args []string) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
firstLaunch := true
|
firstLaunch := !c.onboarded()
|
||||||
if integrationConfig, err := loadIntegration("openclaw"); err == nil {
|
|
||||||
firstLaunch = !integrationConfig.Onboarded
|
|
||||||
}
|
|
||||||
|
|
||||||
if firstLaunch {
|
if firstLaunch {
|
||||||
fmt.Fprintf(os.Stderr, "\n%sSecurity%s\n\n", ansiBold, ansiReset)
|
fmt.Fprintf(os.Stderr, "\n%sSecurity%s\n\n", ansiBold, ansiReset)
|
||||||
@@ -45,28 +48,46 @@ func (c *Openclaw) Run(model string, args []string) error {
|
|||||||
fmt.Fprintf(os.Stderr, " A bad prompt can trick it into doing unsafe things.\n\n")
|
fmt.Fprintf(os.Stderr, " A bad prompt can trick it into doing unsafe things.\n\n")
|
||||||
fmt.Fprintf(os.Stderr, "%s Learn more: https://docs.openclaw.ai/gateway/security%s\n\n", ansiGray, ansiReset)
|
fmt.Fprintf(os.Stderr, "%s Learn more: https://docs.openclaw.ai/gateway/security%s\n\n", ansiGray, ansiReset)
|
||||||
|
|
||||||
ok, err := confirmPrompt("I understand the risks. Continue?")
|
ok, err := ConfirmPrompt("I understand the risks. Continue?")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
if !c.onboarded() {
|
// Ensure the latest version is installed before onboarding so we get
|
||||||
|
// the newest wizard flags (e.g. --auth-choice ollama).
|
||||||
|
if !openclawFreshInstall {
|
||||||
|
update := exec.Command(bin, "update")
|
||||||
|
update.Stdout = os.Stdout
|
||||||
|
update.Stderr = os.Stderr
|
||||||
|
_ = update.Run() // best-effort; continue even if update fails
|
||||||
|
}
|
||||||
|
|
||||||
fmt.Fprintf(os.Stderr, "\n%sSetting up OpenClaw with Ollama...%s\n", ansiGreen, ansiReset)
|
fmt.Fprintf(os.Stderr, "\n%sSetting up OpenClaw with Ollama...%s\n", ansiGreen, ansiReset)
|
||||||
fmt.Fprintf(os.Stderr, "%s Model: %s%s\n\n", ansiGray, model, ansiReset)
|
fmt.Fprintf(os.Stderr, "%s Model: %s%s\n\n", ansiGray, model, ansiReset)
|
||||||
|
|
||||||
cmd := exec.Command(bin, "onboard",
|
onboardArgs := []string{
|
||||||
|
"onboard",
|
||||||
"--non-interactive",
|
"--non-interactive",
|
||||||
"--accept-risk",
|
"--accept-risk",
|
||||||
"--auth-choice", "skip",
|
"--auth-choice", "ollama",
|
||||||
"--gateway-token", "ollama",
|
"--custom-base-url", envconfig.Host().String(),
|
||||||
"--install-daemon",
|
"--custom-model-id", model,
|
||||||
"--skip-channels",
|
"--skip-channels",
|
||||||
"--skip-skills",
|
"--skip-skills",
|
||||||
)
|
}
|
||||||
|
if canInstallDaemon() {
|
||||||
|
onboardArgs = append(onboardArgs, "--install-daemon")
|
||||||
|
} else {
|
||||||
|
// When we can't install a daemon (e.g. no systemd, sudo dropped
|
||||||
|
// XDG_RUNTIME_DIR, or container environment), skip the gateway
|
||||||
|
// health check so non-interactive onboarding completes. The
|
||||||
|
// gateway is started as a foreground child process after onboarding.
|
||||||
|
onboardArgs = append(onboardArgs, "--skip-health")
|
||||||
|
}
|
||||||
|
cmd := exec.Command(bin, onboardArgs...)
|
||||||
cmd.Stdin = os.Stdin
|
cmd.Stdin = os.Stdin
|
||||||
cmd.Stdout = os.Stdout
|
cmd.Stdout = os.Stdout
|
||||||
cmd.Stderr = os.Stderr
|
cmd.Stderr = os.Stderr
|
||||||
@@ -75,24 +96,10 @@ func (c *Openclaw) Run(model string, args []string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
patchDeviceScopes()
|
patchDeviceScopes()
|
||||||
|
|
||||||
// Onboarding overwrites openclaw.json, so re-apply the model config
|
|
||||||
// that Edit() wrote before Run() was called.
|
|
||||||
if err := c.Edit([]string{model}); err != nil {
|
|
||||||
fmt.Fprintf(os.Stderr, "%s Warning: could not re-apply model config: %v%s\n", ansiYellow, err, ansiReset)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if strings.HasSuffix(model, ":cloud") || strings.HasSuffix(model, "-cloud") {
|
if ensureWebSearchPlugin() {
|
||||||
if ensureWebSearchPlugin() {
|
registerWebSearchPlugin()
|
||||||
registerWebSearchPlugin()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if firstLaunch {
|
|
||||||
fmt.Fprintf(os.Stderr, "\n%sPreparing your assistant — this may take a moment...%s\n\n", ansiGray, ansiReset)
|
|
||||||
} else {
|
|
||||||
fmt.Fprintf(os.Stderr, "\n%sStarting your assistant — this may take a moment...%s\n\n", ansiGray, ansiReset)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// When extra args are passed through, run exactly what the user asked for
|
// When extra args are passed through, run exactly what the user asked for
|
||||||
@@ -106,19 +113,23 @@ func (c *Openclaw) Run(model string, args []string) error {
|
|||||||
if err := cmd.Run(); err != nil {
|
if err := cmd.Run(); err != nil {
|
||||||
return windowsHint(err)
|
return windowsHint(err)
|
||||||
}
|
}
|
||||||
if firstLaunch {
|
|
||||||
if err := integrationOnboarded("openclaw"); err != nil {
|
|
||||||
return fmt.Errorf("failed to save onboarding state: %w", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := c.runChannelSetupPreflight(bin); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
// Keep local pairing scopes up to date before the gateway lifecycle
|
||||||
|
// (restart/start) regardless of channel preflight branch behavior.
|
||||||
|
patchDeviceScopes()
|
||||||
|
|
||||||
|
fmt.Fprintf(os.Stderr, "\n%sStarting your assistant — this may take a moment...%s\n\n", ansiGray, ansiReset)
|
||||||
|
|
||||||
token, port := c.gatewayInfo()
|
token, port := c.gatewayInfo()
|
||||||
addr := fmt.Sprintf("localhost:%d", port)
|
addr := fmt.Sprintf("localhost:%d", port)
|
||||||
|
|
||||||
// If the gateway is already running (e.g. via the daemon), restart it
|
// If the gateway is already running (e.g. via the daemon), restart it
|
||||||
// so it picks up any config changes from Edit() above (model, provider, etc.).
|
// so it picks up any config changes (model, provider, etc.).
|
||||||
if portOpen(addr) {
|
if portOpen(addr) {
|
||||||
restart := exec.Command(bin, "daemon", "restart")
|
restart := exec.Command(bin, "daemon", "restart")
|
||||||
restart.Env = openclawEnv()
|
restart.Env = openclawEnv()
|
||||||
@@ -165,12 +176,95 @@ func (c *Openclaw) Run(model string, args []string) error {
|
|||||||
return windowsHint(err)
|
return windowsHint(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if firstLaunch {
|
return nil
|
||||||
if err := integrationOnboarded("openclaw"); err != nil {
|
}
|
||||||
return fmt.Errorf("failed to save onboarding state: %w", err)
|
|
||||||
|
// runChannelSetupPreflight prompts users to connect a messaging channel before
|
||||||
|
// starting the built-in gateway+TUI flow. In interactive sessions, it loops
|
||||||
|
// until a channel is configured, unless the user chooses "Set up later".
|
||||||
|
func (c *Openclaw) runChannelSetupPreflight(bin string) error {
|
||||||
|
if !isInteractiveSession() {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
// --yes is headless; channel setup spawns an interactive picker we can't
|
||||||
|
// auto-answer, so skip it. Users can run `openclaw channels add` later.
|
||||||
|
if currentLaunchConfirmPolicy.yes {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
for {
|
||||||
|
if c.channelsConfigured() {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Fprintf(os.Stderr, "\nYour assistant can message you on WhatsApp, Telegram, Discord, and more.\n\n")
|
||||||
|
ok, err := ConfirmPromptWithOptions("Connect a channel (messaging app) now?", ConfirmOptions{
|
||||||
|
YesLabel: "Yes",
|
||||||
|
NoLabel: "Set up later",
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if !ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd := exec.Command(bin, "channels", "add")
|
||||||
|
cmd.Env = openclawEnv()
|
||||||
|
cmd.Stdin = os.Stdin
|
||||||
|
cmd.Stdout = os.Stdout
|
||||||
|
cmd.Stderr = os.Stderr
|
||||||
|
if err := cmd.Run(); err != nil {
|
||||||
|
return windowsHint(fmt.Errorf("openclaw channel setup failed: %w\n\nTry running: %s channels add", err, bin))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil
|
}
|
||||||
|
|
||||||
|
// channelsConfigured reports whether local OpenClaw config contains at least
|
||||||
|
// one meaningfully configured channel entry.
|
||||||
|
func (c *Openclaw) channelsConfigured() bool {
|
||||||
|
home, err := os.UserHomeDir()
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, path := range []string{
|
||||||
|
filepath.Join(home, ".openclaw", "openclaw.json"),
|
||||||
|
filepath.Join(home, ".clawdbot", "clawdbot.json"),
|
||||||
|
} {
|
||||||
|
data, err := os.ReadFile(path)
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
var cfg map[string]any
|
||||||
|
if json.Unmarshal(data, &cfg) != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
channels, _ := cfg["channels"].(map[string]any)
|
||||||
|
if channels == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
for key, value := range channels {
|
||||||
|
if key == "defaults" || key == "modelByChannel" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
entry, ok := value.(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
for entryKey := range entry {
|
||||||
|
if entryKey != "enabled" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// gatewayInfo reads the gateway auth token and port from the OpenClaw config.
|
// gatewayInfo reads the gateway auth token and port from the OpenClaw config.
|
||||||
@@ -219,12 +313,9 @@ func printOpenclawReady(bin, token string, port int, firstLaunch bool) {
|
|||||||
if firstLaunch {
|
if firstLaunch {
|
||||||
fmt.Fprintf(os.Stderr, "%s Quick start:%s\n", ansiBold, ansiReset)
|
fmt.Fprintf(os.Stderr, "%s Quick start:%s\n", ansiBold, ansiReset)
|
||||||
fmt.Fprintf(os.Stderr, "%s /help see all commands%s\n", ansiGray, ansiReset)
|
fmt.Fprintf(os.Stderr, "%s /help see all commands%s\n", ansiGray, ansiReset)
|
||||||
fmt.Fprintf(os.Stderr, "%s %s configure --section channels connect WhatsApp, Telegram, etc.%s\n", ansiGray, bin, ansiReset)
|
|
||||||
fmt.Fprintf(os.Stderr, "%s %s skills browse and install skills%s\n\n", ansiGray, bin, ansiReset)
|
fmt.Fprintf(os.Stderr, "%s %s skills browse and install skills%s\n\n", ansiGray, bin, ansiReset)
|
||||||
fmt.Fprintf(os.Stderr, "%s The OpenClaw gateway is running in the background.%s\n", ansiYellow, ansiReset)
|
fmt.Fprintf(os.Stderr, "%s The OpenClaw gateway is running in the background.%s\n", ansiYellow, ansiReset)
|
||||||
fmt.Fprintf(os.Stderr, "%s Stop it with: %s gateway stop%s\n\n", ansiYellow, bin, ansiReset)
|
fmt.Fprintf(os.Stderr, "%s Stop it with: %s gateway stop%s\n\n", ansiYellow, bin, ansiReset)
|
||||||
} else {
|
|
||||||
fmt.Fprintf(os.Stderr, "%sTip: connect WhatsApp, Telegram, and more with: %s configure --section channels%s\n", ansiGray, bin, ansiReset)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -313,9 +404,10 @@ func (c *Openclaw) onboarded() bool {
|
|||||||
return lastRunAt != ""
|
return lastRunAt != ""
|
||||||
}
|
}
|
||||||
|
|
||||||
// patchDeviceScopes upgrades the local CLI device's paired scopes to include
|
// patchDeviceScopes upgrades the local CLI device's paired operator scopes so
|
||||||
// operator.admin. Only patches the local device, not remote ones.
|
// newer gateway auth baselines (approvedScopes) allow launch+TUI reconnects
|
||||||
// Best-effort: silently returns on any error.
|
// without forcing an interactive re-pair. Only patches the local device,
|
||||||
|
// not remote ones. Best-effort: silently returns on any error.
|
||||||
func patchDeviceScopes() {
|
func patchDeviceScopes() {
|
||||||
home, err := os.UserHomeDir()
|
home, err := os.UserHomeDir()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -351,9 +443,15 @@ func patchDeviceScopes() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
changed := patchScopes(dev, "scopes", required)
|
changed := patchScopes(dev, "scopes", required)
|
||||||
|
if patchScopes(dev, "approvedScopes", required) {
|
||||||
|
changed = true
|
||||||
|
}
|
||||||
if tokens, ok := dev["tokens"].(map[string]any); ok {
|
if tokens, ok := dev["tokens"].(map[string]any); ok {
|
||||||
for _, tok := range tokens {
|
for role, tok := range tokens {
|
||||||
if tokenMap, ok := tok.(map[string]any); ok {
|
if tokenMap, ok := tok.(map[string]any); ok {
|
||||||
|
if !isOperatorToken(role, tokenMap) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
if patchScopes(tokenMap, "scopes", required) {
|
if patchScopes(tokenMap, "scopes", required) {
|
||||||
changed = true
|
changed = true
|
||||||
}
|
}
|
||||||
@@ -409,6 +507,33 @@ func patchScopes(obj map[string]any, key string, required []string) bool {
|
|||||||
return added
|
return added
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func isOperatorToken(tokenRole string, token map[string]any) bool {
|
||||||
|
if strings.EqualFold(strings.TrimSpace(tokenRole), "operator") {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
role, _ := token["role"].(string)
|
||||||
|
return strings.EqualFold(strings.TrimSpace(role), "operator")
|
||||||
|
}
|
||||||
|
|
||||||
|
// canInstallDaemon reports whether the openclaw daemon can be installed as a
|
||||||
|
// background service. Returns false on Linux when systemd is absent (e.g.
|
||||||
|
// containers) so that --install-daemon is omitted and the gateway is started
|
||||||
|
// as a foreground child process instead. Returns true in all other cases.
|
||||||
|
func canInstallDaemon() bool {
|
||||||
|
if runtime.GOOS != "linux" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
// /run/systemd/system exists as a directory when systemd is the init system.
|
||||||
|
// This is absent in most containers.
|
||||||
|
fi, err := os.Stat("/run/systemd/system")
|
||||||
|
if err != nil || !fi.IsDir() {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
// Even when systemd is the init system, user services require a user
|
||||||
|
// manager instance. XDG_RUNTIME_DIR being set is a prerequisite.
|
||||||
|
return os.Getenv("XDG_RUNTIME_DIR") != ""
|
||||||
|
}
|
||||||
|
|
||||||
func ensureOpenclawInstalled() (string, error) {
|
func ensureOpenclawInstalled() (string, error) {
|
||||||
if _, err := exec.LookPath("openclaw"); err == nil {
|
if _, err := exec.LookPath("openclaw"); err == nil {
|
||||||
return "openclaw", nil
|
return "openclaw", nil
|
||||||
@@ -417,16 +542,20 @@ func ensureOpenclawInstalled() (string, error) {
|
|||||||
return "clawdbot", nil
|
return "clawdbot", nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, err := exec.LookPath("npm"); err != nil {
|
_, npmErr := exec.LookPath("npm")
|
||||||
return "", fmt.Errorf("openclaw is not installed and npm was not found\n\n" +
|
_, gitErr := exec.LookPath("git")
|
||||||
"Install Node.js first:\n" +
|
if npmErr != nil || gitErr != nil {
|
||||||
" https://nodejs.org/\n\n" +
|
var missing []string
|
||||||
"Then rerun:\n" +
|
if npmErr != nil {
|
||||||
" ollama launch\n" +
|
missing = append(missing, "npm (Node.js): https://nodejs.org/")
|
||||||
"and select OpenClaw")
|
}
|
||||||
|
if gitErr != nil {
|
||||||
|
missing = append(missing, "git: https://git-scm.com/")
|
||||||
|
}
|
||||||
|
return "", fmt.Errorf("OpenClaw is not installed and required dependencies are missing\n\nInstall the following first:\n %s\n\nThen re-run:\n ollama launch openclaw", strings.Join(missing, "\n "))
|
||||||
}
|
}
|
||||||
|
|
||||||
ok, err := confirmPrompt("OpenClaw is not installed. Install with npm?")
|
ok, err := ConfirmPrompt("OpenClaw is not installed. Install with npm?")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
@@ -448,6 +577,7 @@ func ensureOpenclawInstalled() (string, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fmt.Fprintf(os.Stderr, "%sOpenClaw installed successfully%s\n\n", ansiGreen, ansiReset)
|
fmt.Fprintf(os.Stderr, "%sOpenClaw installed successfully%s\n\n", ansiGreen, ansiReset)
|
||||||
|
openclawFreshInstall = true
|
||||||
return "openclaw", nil
|
return "openclaw", nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -502,7 +632,7 @@ func (c *Openclaw) Edit(models []string) error {
|
|||||||
ollama = make(map[string]any)
|
ollama = make(map[string]any)
|
||||||
}
|
}
|
||||||
|
|
||||||
ollama["baseUrl"] = envconfig.Host().String() + "/v1"
|
ollama["baseUrl"] = envconfig.Host().String()
|
||||||
// needed to register provider
|
// needed to register provider
|
||||||
ollama["apiKey"] = "ollama-local"
|
ollama["apiKey"] = "ollama-local"
|
||||||
ollama["api"] = "ollama"
|
ollama["api"] = "ollama"
|
||||||
@@ -561,7 +691,7 @@ func (c *Openclaw) Edit(models []string) error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if err := writeWithBackup(configPath, data); err != nil {
|
if err := fileutil.WriteWithBackup(configPath, data); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -592,6 +722,8 @@ func clearSessionModelOverride(primary string) {
|
|||||||
if override, _ := sess["modelOverride"].(string); override != "" && override != primary {
|
if override, _ := sess["modelOverride"].(string); override != "" && override != primary {
|
||||||
delete(sess, "modelOverride")
|
delete(sess, "modelOverride")
|
||||||
delete(sess, "providerOverride")
|
delete(sess, "providerOverride")
|
||||||
|
}
|
||||||
|
if model, _ := sess["model"].(string); model != "" && model != primary {
|
||||||
sess["model"] = primary
|
sess["model"] = primary
|
||||||
changed = true
|
changed = true
|
||||||
}
|
}
|
||||||
@@ -606,11 +738,15 @@ func clearSessionModelOverride(primary string) {
|
|||||||
_ = os.WriteFile(path, out, 0o600)
|
_ = os.WriteFile(path, out, 0o600)
|
||||||
}
|
}
|
||||||
|
|
||||||
const webSearchNpmPackage = "@ollama/openclaw-web-search"
|
const (
|
||||||
|
webSearchNpmPackage = "@ollama/openclaw-web-search"
|
||||||
|
webSearchMinVersion = "0.2.1"
|
||||||
|
)
|
||||||
|
|
||||||
// ensureWebSearchPlugin installs the openclaw-web-search extension into the
|
// ensureWebSearchPlugin installs the openclaw-web-search extension into the
|
||||||
// user-level extensions directory (~/.openclaw/extensions/) if it isn't already
|
// user-level extensions directory (~/.openclaw/extensions/) if it isn't already
|
||||||
// present. Returns true if the extension is available.
|
// present, or re-installs if the installed version is older than webSearchMinVersion.
|
||||||
|
// Returns true if the extension is available.
|
||||||
func ensureWebSearchPlugin() bool {
|
func ensureWebSearchPlugin() bool {
|
||||||
home, err := os.UserHomeDir()
|
home, err := os.UserHomeDir()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -618,8 +754,8 @@ func ensureWebSearchPlugin() bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pluginDir := filepath.Join(home, ".openclaw", "extensions", "openclaw-web-search")
|
pluginDir := filepath.Join(home, ".openclaw", "extensions", "openclaw-web-search")
|
||||||
if _, err := os.Stat(filepath.Join(pluginDir, "index.ts")); err == nil {
|
if webSearchPluginUpToDate(pluginDir) {
|
||||||
return true // already installed
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
npmBin, err := exec.LookPath("npm")
|
npmBin, err := exec.LookPath("npm")
|
||||||
@@ -649,10 +785,38 @@ func ensureWebSearchPlugin() bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
fmt.Fprintf(os.Stderr, "%s ✓ Installed web search plugin%s\n", ansiGreen, ansiReset)
|
fmt.Fprintf(os.Stderr, "%s ✓ Installed Ollama web search %s\n", ansiGreen, ansiReset)
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// webSearchPluginUpToDate returns true if the plugin is installed and its
|
||||||
|
// package.json version is >= webSearchMinVersion.
|
||||||
|
func webSearchPluginUpToDate(pluginDir string) bool {
|
||||||
|
data, err := os.ReadFile(filepath.Join(pluginDir, "package.json"))
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
var pkg struct {
|
||||||
|
Version string `json:"version"`
|
||||||
|
}
|
||||||
|
if json.Unmarshal(data, &pkg) != nil || pkg.Version == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return !versionLessThan(pkg.Version, webSearchMinVersion)
|
||||||
|
}
|
||||||
|
|
||||||
|
// versionLessThan compares two semver version strings (major.minor.patch).
|
||||||
|
// Inputs may omit the "v" prefix; it is added automatically for semver.Compare.
|
||||||
|
func versionLessThan(a, b string) bool {
|
||||||
|
if !strings.HasPrefix(a, "v") {
|
||||||
|
a = "v" + a
|
||||||
|
}
|
||||||
|
if !strings.HasPrefix(b, "v") {
|
||||||
|
b = "v" + b
|
||||||
|
}
|
||||||
|
return semver.Compare(a, b) < 0
|
||||||
|
}
|
||||||
|
|
||||||
// registerWebSearchPlugin adds plugins.entries.openclaw-web-search to the OpenClaw
|
// registerWebSearchPlugin adds plugins.entries.openclaw-web-search to the OpenClaw
|
||||||
// config so the gateway activates it on next start. Best-effort; silently returns
|
// config so the gateway activates it on next start. Best-effort; silently returns
|
||||||
// on any error.
|
// on any error.
|
||||||
@@ -679,23 +843,67 @@ func registerWebSearchPlugin() {
|
|||||||
if entries == nil {
|
if entries == nil {
|
||||||
entries = make(map[string]any)
|
entries = make(map[string]any)
|
||||||
}
|
}
|
||||||
if _, ok := entries["openclaw-web-search"]; ok {
|
|
||||||
return // already registered
|
|
||||||
}
|
|
||||||
entries["openclaw-web-search"] = map[string]any{"enabled": true}
|
entries["openclaw-web-search"] = map[string]any{"enabled": true}
|
||||||
plugins["entries"] = entries
|
plugins["entries"] = entries
|
||||||
|
|
||||||
|
// Pin trust so the gateway doesn't warn about untracked plugins.
|
||||||
|
allow, _ := plugins["allow"].([]any)
|
||||||
|
hasAllow := false
|
||||||
|
for _, v := range allow {
|
||||||
|
if s, ok := v.(string); ok && s == "openclaw-web-search" {
|
||||||
|
hasAllow = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !hasAllow {
|
||||||
|
allow = append(allow, "openclaw-web-search")
|
||||||
|
}
|
||||||
|
plugins["allow"] = allow
|
||||||
|
|
||||||
|
// Record install provenance so the loader can verify the plugin origin.
|
||||||
|
installs, _ := plugins["installs"].(map[string]any)
|
||||||
|
if installs == nil {
|
||||||
|
installs = make(map[string]any)
|
||||||
|
}
|
||||||
|
pluginDir := filepath.Join(home, ".openclaw", "extensions", "openclaw-web-search")
|
||||||
|
installs["openclaw-web-search"] = map[string]any{
|
||||||
|
"source": "npm",
|
||||||
|
"spec": webSearchNpmPackage,
|
||||||
|
"installPath": pluginDir,
|
||||||
|
}
|
||||||
|
plugins["installs"] = installs
|
||||||
|
|
||||||
config["plugins"] = plugins
|
config["plugins"] = plugins
|
||||||
|
|
||||||
// Disable the built-in web search since our plugin replaces it.
|
// Add plugin tools to tools.alsoAllow so they survive the coding profile's
|
||||||
|
// policy pipeline (which has an explicit allow list of core tools only).
|
||||||
tools, _ := config["tools"].(map[string]any)
|
tools, _ := config["tools"].(map[string]any)
|
||||||
if tools == nil {
|
if tools == nil {
|
||||||
tools = make(map[string]any)
|
tools = make(map[string]any)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
alsoAllow, _ := tools["alsoAllow"].([]any)
|
||||||
|
needed := []string{"ollama_web_search", "ollama_web_fetch"}
|
||||||
|
have := make(map[string]bool, len(alsoAllow))
|
||||||
|
for _, v := range alsoAllow {
|
||||||
|
if s, ok := v.(string); ok {
|
||||||
|
have[s] = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for _, name := range needed {
|
||||||
|
if !have[name] {
|
||||||
|
alsoAllow = append(alsoAllow, name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
tools["alsoAllow"] = alsoAllow
|
||||||
|
|
||||||
|
// Disable built-in web search/fetch since our plugin replaces them.
|
||||||
web, _ := tools["web"].(map[string]any)
|
web, _ := tools["web"].(map[string]any)
|
||||||
if web == nil {
|
if web == nil {
|
||||||
web = make(map[string]any)
|
web = make(map[string]any)
|
||||||
}
|
}
|
||||||
web["search"] = map[string]any{"enabled": false}
|
web["search"] = map[string]any{"enabled": false}
|
||||||
|
web["fetch"] = map[string]any{"enabled": false}
|
||||||
tools["web"] = web
|
tools["web"] = web
|
||||||
config["tools"] = tools
|
config["tools"] = tools
|
||||||
|
|
||||||
@@ -776,9 +984,9 @@ func (c *Openclaw) Models() []string {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
config, err := readJSONFile(filepath.Join(home, ".openclaw", "openclaw.json"))
|
config, err := fileutil.ReadJSON(filepath.Join(home, ".openclaw", "openclaw.json"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
config, err = readJSONFile(filepath.Join(home, ".clawdbot", "clawdbot.json"))
|
config, err = fileutil.ReadJSON(filepath.Join(home, ".clawdbot", "clawdbot.json"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
298
cmd/launch/opencode.go
Normal file
@@ -0,0 +1,298 @@
|
|||||||
|
package launch
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
"os/exec"
|
||||||
|
"path/filepath"
|
||||||
|
"runtime"
|
||||||
|
"slices"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/api"
|
||||||
|
"github.com/ollama/ollama/cmd/internal/fileutil"
|
||||||
|
"github.com/ollama/ollama/envconfig"
|
||||||
|
modeltype "github.com/ollama/ollama/types/model"
|
||||||
|
)
|
||||||
|
|
||||||
|
// OpenCode implements Runner and Editor for OpenCode integration.
|
||||||
|
// Config is passed via OPENCODE_CONFIG_CONTENT env var at launch time
|
||||||
|
// instead of writing to opencode's config files.
|
||||||
|
type OpenCode struct {
|
||||||
|
configContent string // JSON config built by Edit, passed to Run via env var
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o *OpenCode) String() string { return "OpenCode" }
|
||||||
|
|
||||||
|
// findOpenCode returns the opencode binary path, checking PATH first then the
|
||||||
|
// curl installer location (~/.opencode/bin) which may not be on PATH yet.
|
||||||
|
func findOpenCode() (string, bool) {
|
||||||
|
if p, err := exec.LookPath("opencode"); err == nil {
|
||||||
|
return p, true
|
||||||
|
}
|
||||||
|
home, err := os.UserHomeDir()
|
||||||
|
if err != nil {
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
name := "opencode"
|
||||||
|
if runtime.GOOS == "windows" {
|
||||||
|
name = "opencode.exe"
|
||||||
|
}
|
||||||
|
fallback := filepath.Join(home, ".opencode", "bin", name)
|
||||||
|
if _, err := os.Stat(fallback); err == nil {
|
||||||
|
return fallback, true
|
||||||
|
}
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o *OpenCode) Run(model string, args []string) error {
|
||||||
|
opencodePath, ok := findOpenCode()
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("opencode is not installed, install from https://opencode.ai")
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd := exec.Command(opencodePath, args...)
|
||||||
|
cmd.Stdin = os.Stdin
|
||||||
|
cmd.Stdout = os.Stdout
|
||||||
|
cmd.Stderr = os.Stderr
|
||||||
|
cmd.Env = os.Environ()
|
||||||
|
if content := o.resolveContent(model); content != "" {
|
||||||
|
cmd.Env = append(cmd.Env, "OPENCODE_CONFIG_CONTENT="+content)
|
||||||
|
}
|
||||||
|
return cmd.Run()
|
||||||
|
}
|
||||||
|
|
||||||
|
// resolveContent returns the inline config to send via OPENCODE_CONFIG_CONTENT.
|
||||||
|
// Returns content built by Edit if available, otherwise builds from model.json
|
||||||
|
// with the requested model as primary (e.g. re-launch with saved config).
|
||||||
|
func (o *OpenCode) resolveContent(model string) string {
|
||||||
|
if o.configContent != "" {
|
||||||
|
return o.configContent
|
||||||
|
}
|
||||||
|
models := readModelJSONModels()
|
||||||
|
if !slices.Contains(models, model) {
|
||||||
|
models = append([]string{model}, models...)
|
||||||
|
}
|
||||||
|
content, err := buildInlineConfig(model, models)
|
||||||
|
if err != nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return content
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o *OpenCode) Paths() []string {
|
||||||
|
sp, err := openCodeStatePath()
|
||||||
|
if err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if _, err := os.Stat(sp); err == nil {
|
||||||
|
return []string{sp}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// openCodeStatePath returns the path to opencode's model state file.
|
||||||
|
// TODO: this hardcodes the Linux/macOS XDG path. On Windows, opencode stores
|
||||||
|
// state under %LOCALAPPDATA% (or similar) — verify and branch on runtime.GOOS.
|
||||||
|
func openCodeStatePath() (string, error) {
|
||||||
|
home, err := os.UserHomeDir()
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return filepath.Join(home, ".local", "state", "opencode", "model.json"), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o *OpenCode) Edit(modelList []string) error {
|
||||||
|
if len(modelList) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
content, err := buildInlineConfig(modelList[0], modelList)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
o.configContent = content
|
||||||
|
|
||||||
|
// Write model state file so models appear in OpenCode's model picker
|
||||||
|
statePath, err := openCodeStatePath()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := os.MkdirAll(filepath.Dir(statePath), 0o755); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
state := map[string]any{
|
||||||
|
"recent": []any{},
|
||||||
|
"favorite": []any{},
|
||||||
|
"variant": map[string]any{},
|
||||||
|
}
|
||||||
|
if data, err := os.ReadFile(statePath); err == nil {
|
||||||
|
_ = json.Unmarshal(data, &state) // Ignore parse errors; use defaults
|
||||||
|
}
|
||||||
|
|
||||||
|
recent, _ := state["recent"].([]any)
|
||||||
|
|
||||||
|
modelSet := make(map[string]bool)
|
||||||
|
for _, m := range modelList {
|
||||||
|
modelSet[m] = true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Filter out existing Ollama models we're about to re-add
|
||||||
|
newRecent := slices.DeleteFunc(slices.Clone(recent), func(entry any) bool {
|
||||||
|
e, ok := entry.(map[string]any)
|
||||||
|
if !ok || e["providerID"] != "ollama" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
modelID, _ := e["modelID"].(string)
|
||||||
|
return modelSet[modelID]
|
||||||
|
})
|
||||||
|
|
||||||
|
// Prepend models in reverse order so first model ends up first
|
||||||
|
for _, model := range slices.Backward(modelList) {
|
||||||
|
newRecent = slices.Insert(newRecent, 0, any(map[string]any{
|
||||||
|
"providerID": "ollama",
|
||||||
|
"modelID": model,
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
|
||||||
|
const maxRecentModels = 10
|
||||||
|
newRecent = newRecent[:min(len(newRecent), maxRecentModels)]
|
||||||
|
|
||||||
|
state["recent"] = newRecent
|
||||||
|
|
||||||
|
stateData, err := json.MarshalIndent(state, "", " ")
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return fileutil.WriteWithBackup(statePath, stateData)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o *OpenCode) Models() []string {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// buildInlineConfig produces the JSON string for OPENCODE_CONFIG_CONTENT.
|
||||||
|
// primary is the model to launch with, models is the full list of available models.
|
||||||
|
func buildInlineConfig(primary string, models []string) (string, error) {
|
||||||
|
if primary == "" || len(models) == 0 {
|
||||||
|
return "", fmt.Errorf("buildInlineConfig: primary and models are required")
|
||||||
|
}
|
||||||
|
config := map[string]any{
|
||||||
|
"$schema": "https://opencode.ai/config.json",
|
||||||
|
"provider": map[string]any{
|
||||||
|
"ollama": map[string]any{
|
||||||
|
"npm": "@ai-sdk/openai-compatible",
|
||||||
|
"name": "Ollama",
|
||||||
|
"options": map[string]any{
|
||||||
|
"baseURL": envconfig.Host().String() + "/v1",
|
||||||
|
},
|
||||||
|
"models": buildModelEntries(models),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"model": "ollama/" + primary,
|
||||||
|
}
|
||||||
|
data, err := json.Marshal(config)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return string(data), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// readModelJSONModels reads ollama model IDs from the opencode model.json state file
|
||||||
|
func readModelJSONModels() []string {
|
||||||
|
statePath, err := openCodeStatePath()
|
||||||
|
if err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
data, err := os.ReadFile(statePath)
|
||||||
|
if err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
var state map[string]any
|
||||||
|
if err := json.Unmarshal(data, &state); err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
recent, _ := state["recent"].([]any)
|
||||||
|
var models []string
|
||||||
|
for _, entry := range recent {
|
||||||
|
e, ok := entry.(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if e["providerID"] != "ollama" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if id, ok := e["modelID"].(string); ok && id != "" {
|
||||||
|
models = append(models, id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return models
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildModelEntries(modelList []string) map[string]any {
|
||||||
|
client := api.NewClient(envconfig.Host(), http.DefaultClient)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
models := make(map[string]any)
|
||||||
|
for _, model := range modelList {
|
||||||
|
entry := map[string]any{
|
||||||
|
"name": model,
|
||||||
|
}
|
||||||
|
if isCloudModelName(model) {
|
||||||
|
if l, ok := lookupCloudModelLimit(model); ok {
|
||||||
|
entry["limit"] = map[string]any{
|
||||||
|
"context": l.Context,
|
||||||
|
"output": l.Output,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
applyOpenCodeReasoning(ctx, client, model, entry)
|
||||||
|
models[model] = entry
|
||||||
|
}
|
||||||
|
return models
|
||||||
|
}
|
||||||
|
|
||||||
|
// applyOpenCodeReasoning detects thinking capability and sets reasoning config
|
||||||
|
// on the model entry. When the model supports thinking, it sets "reasoning": true
|
||||||
|
// and configures variants for the OpenCode TUI:
|
||||||
|
// - GPT-OSS: supports variable effort levels (low/medium/high) and defaults to
|
||||||
|
// medium via options. Thinking cannot be turned off.
|
||||||
|
// - Other models: only support on/off. Disables built-in low/medium/high variants
|
||||||
|
// and adds a "none" variant so users can toggle thinking off via Ctrl+T.
|
||||||
|
//
|
||||||
|
// When the model does not support thinking, no reasoning config is set.
|
||||||
|
func applyOpenCodeReasoning(ctx context.Context, client *api.Client, modelName string, entry map[string]any) {
|
||||||
|
resp, err := client.Show(ctx, &api.ShowRequest{Model: modelName})
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if slices.Contains(resp.Capabilities, modeltype.CapabilityThinking) {
|
||||||
|
entry["reasoning"] = true
|
||||||
|
|
||||||
|
if strings.Contains(modelName, "gpt-oss") {
|
||||||
|
// GPT-OSS models support variable thinking effort levels
|
||||||
|
// and cannot turn thinking off. Keep the built-in
|
||||||
|
// low/medium/high variants as-is and default to medium.
|
||||||
|
options, ok := entry["options"].(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
options = make(map[string]any)
|
||||||
|
}
|
||||||
|
options["reasoningEffort"] = "medium"
|
||||||
|
entry["options"] = options
|
||||||
|
} else {
|
||||||
|
// Most models only support thinking on or off.
|
||||||
|
// Disable the built-in low/medium/high variants and add none.
|
||||||
|
entry["variants"] = map[string]any{
|
||||||
|
"none": map[string]any{"reasoningEffort": "none"},
|
||||||
|
"low": map[string]any{"disabled": true},
|
||||||
|
"medium": map[string]any{"disabled": true},
|
||||||
|
"high": map[string]any{"disabled": true},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
907
cmd/launch/opencode_test.go
Normal file
@@ -0,0 +1,907 @@
|
|||||||
|
package launch
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"runtime"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestOpenCodeIntegration(t *testing.T) {
|
||||||
|
o := &OpenCode{}
|
||||||
|
|
||||||
|
t.Run("String", func(t *testing.T) {
|
||||||
|
if got := o.String(); got != "OpenCode" {
|
||||||
|
t.Errorf("String() = %q, want %q", got, "OpenCode")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("implements Runner", func(t *testing.T) {
|
||||||
|
var _ Runner = o
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("implements Editor", func(t *testing.T) {
|
||||||
|
var _ Editor = o
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOpenCodeEdit(t *testing.T) {
|
||||||
|
t.Run("builds config content with provider", func(t *testing.T) {
|
||||||
|
setTestHome(t, t.TempDir())
|
||||||
|
o := &OpenCode{}
|
||||||
|
if err := o.Edit([]string{"llama3.2"}); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var cfg map[string]any
|
||||||
|
if err := json.Unmarshal([]byte(o.configContent), &cfg); err != nil {
|
||||||
|
t.Fatalf("configContent is not valid JSON: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify provider structure
|
||||||
|
provider, _ := cfg["provider"].(map[string]any)
|
||||||
|
ollama, _ := provider["ollama"].(map[string]any)
|
||||||
|
if ollama["name"] != "Ollama" {
|
||||||
|
t.Errorf("provider name = %v, want Ollama", ollama["name"])
|
||||||
|
}
|
||||||
|
if ollama["npm"] != "@ai-sdk/openai-compatible" {
|
||||||
|
t.Errorf("npm = %v, want @ai-sdk/openai-compatible", ollama["npm"])
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify model exists
|
||||||
|
models, _ := ollama["models"].(map[string]any)
|
||||||
|
if models["llama3.2"] == nil {
|
||||||
|
t.Error("model llama3.2 not found in config content")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify default model
|
||||||
|
if cfg["model"] != "ollama/llama3.2" {
|
||||||
|
t.Errorf("model = %v, want ollama/llama3.2", cfg["model"])
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("multiple models", func(t *testing.T) {
|
||||||
|
setTestHome(t, t.TempDir())
|
||||||
|
o := &OpenCode{}
|
||||||
|
if err := o.Edit([]string{"llama3.2", "qwen3:32b"}); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var cfg map[string]any
|
||||||
|
json.Unmarshal([]byte(o.configContent), &cfg)
|
||||||
|
provider, _ := cfg["provider"].(map[string]any)
|
||||||
|
ollama, _ := provider["ollama"].(map[string]any)
|
||||||
|
models, _ := ollama["models"].(map[string]any)
|
||||||
|
|
||||||
|
if models["llama3.2"] == nil {
|
||||||
|
t.Error("model llama3.2 not found")
|
||||||
|
}
|
||||||
|
if models["qwen3:32b"] == nil {
|
||||||
|
t.Error("model qwen3:32b not found")
|
||||||
|
}
|
||||||
|
// First model should be the default
|
||||||
|
if cfg["model"] != "ollama/llama3.2" {
|
||||||
|
t.Errorf("default model = %v, want ollama/llama3.2", cfg["model"])
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("empty models is no-op", func(t *testing.T) {
|
||||||
|
setTestHome(t, t.TempDir())
|
||||||
|
o := &OpenCode{}
|
||||||
|
if err := o.Edit([]string{}); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if o.configContent != "" {
|
||||||
|
t.Errorf("expected empty configContent for no models, got %s", o.configContent)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("does not write config files", func(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
setTestHome(t, tmpDir)
|
||||||
|
o := &OpenCode{}
|
||||||
|
o.Edit([]string{"llama3.2"})
|
||||||
|
|
||||||
|
configDir := filepath.Join(tmpDir, ".config", "opencode")
|
||||||
|
|
||||||
|
if _, err := os.Stat(filepath.Join(configDir, "opencode.json")); !os.IsNotExist(err) {
|
||||||
|
t.Error("opencode.json should not be created")
|
||||||
|
}
|
||||||
|
if _, err := os.Stat(filepath.Join(configDir, "opencode.jsonc")); !os.IsNotExist(err) {
|
||||||
|
t.Error("opencode.jsonc should not be created")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("cloud model has limits", func(t *testing.T) {
|
||||||
|
setTestHome(t, t.TempDir())
|
||||||
|
o := &OpenCode{}
|
||||||
|
if err := o.Edit([]string{"glm-4.7:cloud"}); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var cfg map[string]any
|
||||||
|
json.Unmarshal([]byte(o.configContent), &cfg)
|
||||||
|
provider, _ := cfg["provider"].(map[string]any)
|
||||||
|
ollama, _ := provider["ollama"].(map[string]any)
|
||||||
|
models, _ := ollama["models"].(map[string]any)
|
||||||
|
entry, _ := models["glm-4.7:cloud"].(map[string]any)
|
||||||
|
|
||||||
|
limit, ok := entry["limit"].(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("cloud model should have limit set")
|
||||||
|
}
|
||||||
|
expected := cloudModelLimits["glm-4.7"]
|
||||||
|
if limit["context"] != float64(expected.Context) {
|
||||||
|
t.Errorf("context = %v, want %d", limit["context"], expected.Context)
|
||||||
|
}
|
||||||
|
if limit["output"] != float64(expected.Output) {
|
||||||
|
t.Errorf("output = %v, want %d", limit["output"], expected.Output)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("local model has no limits", func(t *testing.T) {
|
||||||
|
setTestHome(t, t.TempDir())
|
||||||
|
o := &OpenCode{}
|
||||||
|
o.Edit([]string{"llama3.2"})
|
||||||
|
|
||||||
|
var cfg map[string]any
|
||||||
|
json.Unmarshal([]byte(o.configContent), &cfg)
|
||||||
|
provider, _ := cfg["provider"].(map[string]any)
|
||||||
|
ollama, _ := provider["ollama"].(map[string]any)
|
||||||
|
models, _ := ollama["models"].(map[string]any)
|
||||||
|
entry, _ := models["llama3.2"].(map[string]any)
|
||||||
|
|
||||||
|
if entry["limit"] != nil {
|
||||||
|
t.Errorf("local model should not have limit, got %v", entry["limit"])
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOpenCodeModels_ReturnsNil(t *testing.T) {
|
||||||
|
o := &OpenCode{}
|
||||||
|
if models := o.Models(); models != nil {
|
||||||
|
t.Errorf("Models() = %v, want nil", models)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOpenCodePaths(t *testing.T) {
|
||||||
|
t.Run("returns nil when model.json does not exist", func(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
setTestHome(t, tmpDir)
|
||||||
|
|
||||||
|
o := &OpenCode{}
|
||||||
|
if paths := o.Paths(); paths != nil {
|
||||||
|
t.Errorf("Paths() = %v, want nil", paths)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("returns model.json path when it exists", func(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
setTestHome(t, tmpDir)
|
||||||
|
|
||||||
|
stateDir := filepath.Join(tmpDir, ".local", "state", "opencode")
|
||||||
|
os.MkdirAll(stateDir, 0o755)
|
||||||
|
os.WriteFile(filepath.Join(stateDir, "model.json"), []byte(`{}`), 0o644)
|
||||||
|
|
||||||
|
o := &OpenCode{}
|
||||||
|
paths := o.Paths()
|
||||||
|
if len(paths) != 1 {
|
||||||
|
t.Fatalf("Paths() returned %d paths, want 1", len(paths))
|
||||||
|
}
|
||||||
|
if paths[0] != filepath.Join(stateDir, "model.json") {
|
||||||
|
t.Errorf("Paths() = %v, want %v", paths[0], filepath.Join(stateDir, "model.json"))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLookupCloudModelLimit(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
wantOK bool
|
||||||
|
wantContext int
|
||||||
|
wantOutput int
|
||||||
|
}{
|
||||||
|
{"glm-4.7", false, 0, 0},
|
||||||
|
{"glm-4.7:cloud", true, 202_752, 131_072},
|
||||||
|
{"glm-5:cloud", true, 202_752, 131_072},
|
||||||
|
{"glm-5.1:cloud", true, 202_752, 131_072},
|
||||||
|
{"gemma4:31b-cloud", true, 262_144, 131_072},
|
||||||
|
{"gpt-oss:120b-cloud", true, 131_072, 131_072},
|
||||||
|
{"gpt-oss:20b-cloud", true, 131_072, 131_072},
|
||||||
|
{"kimi-k2.5", false, 0, 0},
|
||||||
|
{"kimi-k2.5:cloud", true, 262_144, 262_144},
|
||||||
|
{"deepseek-v3.2", false, 0, 0},
|
||||||
|
{"deepseek-v3.2:cloud", true, 163_840, 65_536},
|
||||||
|
{"qwen3.5", false, 0, 0},
|
||||||
|
{"qwen3.5:cloud", true, 262_144, 32_768},
|
||||||
|
{"qwen3-coder:480b", false, 0, 0},
|
||||||
|
{"qwen3-coder:480b:cloud", true, 262_144, 65_536},
|
||||||
|
{"qwen3-coder-next:cloud", true, 262_144, 32_768},
|
||||||
|
{"llama3.2", false, 0, 0},
|
||||||
|
{"unknown-model:cloud", false, 0, 0},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
l, ok := lookupCloudModelLimit(tt.name)
|
||||||
|
if ok != tt.wantOK {
|
||||||
|
t.Errorf("lookupCloudModelLimit(%q) ok = %v, want %v", tt.name, ok, tt.wantOK)
|
||||||
|
}
|
||||||
|
if ok {
|
||||||
|
if l.Context != tt.wantContext {
|
||||||
|
t.Errorf("context = %d, want %d", l.Context, tt.wantContext)
|
||||||
|
}
|
||||||
|
if l.Output != tt.wantOutput {
|
||||||
|
t.Errorf("output = %d, want %d", l.Output, tt.wantOutput)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// inlineConfigModel extracts a model entry from the inline config content.
|
||||||
|
func inlineConfigModel(t *testing.T, content, model string) map[string]any {
|
||||||
|
t.Helper()
|
||||||
|
var cfg map[string]any
|
||||||
|
if err := json.Unmarshal([]byte(content), &cfg); err != nil {
|
||||||
|
t.Fatalf("configContent is not valid JSON: %v", err)
|
||||||
|
}
|
||||||
|
provider, _ := cfg["provider"].(map[string]any)
|
||||||
|
ollama, _ := provider["ollama"].(map[string]any)
|
||||||
|
models, _ := ollama["models"].(map[string]any)
|
||||||
|
entry, ok := models[model].(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("model %s not found in inline config", model)
|
||||||
|
}
|
||||||
|
return entry
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOpenCodeEdit_ReasoningOnThinkingModel(t *testing.T) {
|
||||||
|
setTestHome(t, t.TempDir())
|
||||||
|
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.URL.Path == "/api/show" {
|
||||||
|
fmt.Fprintf(w, `{"capabilities":["thinking"],"model_info":{}}`)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
w.WriteHeader(http.StatusNotFound)
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
t.Setenv("OLLAMA_HOST", srv.URL)
|
||||||
|
|
||||||
|
o := &OpenCode{}
|
||||||
|
if err := o.Edit([]string{"qwq"}); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
entry := inlineConfigModel(t, o.configContent, "qwq")
|
||||||
|
if entry["reasoning"] != true {
|
||||||
|
t.Error("expected reasoning = true for thinking model")
|
||||||
|
}
|
||||||
|
variants, ok := entry["variants"].(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("expected variants to be set")
|
||||||
|
}
|
||||||
|
none, ok := variants["none"].(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("expected none variant to be set")
|
||||||
|
}
|
||||||
|
if none["reasoningEffort"] != "none" {
|
||||||
|
t.Errorf("none variant reasoningEffort = %v, want none", none["reasoningEffort"])
|
||||||
|
}
|
||||||
|
// Built-in low/medium/high should be disabled
|
||||||
|
for _, level := range []string{"low", "medium", "high"} {
|
||||||
|
v, ok := variants[level].(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
t.Errorf("expected %s variant to exist", level)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if v["disabled"] != true {
|
||||||
|
t.Errorf("expected %s variant to be disabled", level)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOpenCodeEdit_ReasoningLevelsOnGptOss(t *testing.T) {
|
||||||
|
setTestHome(t, t.TempDir())
|
||||||
|
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.URL.Path == "/api/show" {
|
||||||
|
fmt.Fprintf(w, `{"capabilities":["thinking"],"model_info":{}}`)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
w.WriteHeader(http.StatusNotFound)
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
t.Setenv("OLLAMA_HOST", srv.URL)
|
||||||
|
|
||||||
|
o := &OpenCode{}
|
||||||
|
if err := o.Edit([]string{"gpt-oss:120b-cloud"}); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
entry := inlineConfigModel(t, o.configContent, "gpt-oss:120b-cloud")
|
||||||
|
if entry["reasoning"] != true {
|
||||||
|
t.Error("expected reasoning = true")
|
||||||
|
}
|
||||||
|
// GPT-OSS cannot turn thinking off and supports levels,
|
||||||
|
// so no custom variants should be written.
|
||||||
|
if entry["variants"] != nil {
|
||||||
|
t.Errorf("expected no variants for gpt-oss, got %v", entry["variants"])
|
||||||
|
}
|
||||||
|
// Should default to medium reasoning effort
|
||||||
|
opts, ok := entry["options"].(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("expected options to be set for gpt-oss")
|
||||||
|
}
|
||||||
|
if opts["reasoningEffort"] != "medium" {
|
||||||
|
t.Errorf("reasoningEffort = %v, want medium", opts["reasoningEffort"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOpenCodeEdit_NoReasoningOnNonThinkingModel(t *testing.T) {
|
||||||
|
setTestHome(t, t.TempDir())
|
||||||
|
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.URL.Path == "/api/show" {
|
||||||
|
fmt.Fprintf(w, `{"capabilities":[],"model_info":{}}`)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
w.WriteHeader(http.StatusNotFound)
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
t.Setenv("OLLAMA_HOST", srv.URL)
|
||||||
|
|
||||||
|
o := &OpenCode{}
|
||||||
|
if err := o.Edit([]string{"llama3.2"}); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
entry := inlineConfigModel(t, o.configContent, "llama3.2")
|
||||||
|
if entry["reasoning"] != nil {
|
||||||
|
t.Errorf("expected no reasoning for non-thinking model, got %v", entry["reasoning"])
|
||||||
|
}
|
||||||
|
if entry["variants"] != nil {
|
||||||
|
t.Errorf("expected no variants for non-thinking model, got %v", entry["variants"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFindOpenCode(t *testing.T) {
|
||||||
|
t.Run("fallback to ~/.opencode/bin", func(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
setTestHome(t, tmpDir)
|
||||||
|
|
||||||
|
// Ensure opencode is not on PATH
|
||||||
|
t.Setenv("PATH", tmpDir)
|
||||||
|
|
||||||
|
// Without the fallback binary, findOpenCode should fail
|
||||||
|
if _, ok := findOpenCode(); ok {
|
||||||
|
t.Fatal("findOpenCode should fail when binary is not on PATH or in fallback location")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a fake binary at the curl install fallback location
|
||||||
|
binDir := filepath.Join(tmpDir, ".opencode", "bin")
|
||||||
|
os.MkdirAll(binDir, 0o755)
|
||||||
|
name := "opencode"
|
||||||
|
if runtime.GOOS == "windows" {
|
||||||
|
name = "opencode.exe"
|
||||||
|
}
|
||||||
|
fakeBin := filepath.Join(binDir, name)
|
||||||
|
os.WriteFile(fakeBin, []byte("#!/bin/sh\n"), 0o755)
|
||||||
|
|
||||||
|
// Now findOpenCode should succeed via fallback
|
||||||
|
path, ok := findOpenCode()
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("findOpenCode should succeed with fallback binary")
|
||||||
|
}
|
||||||
|
if path != fakeBin {
|
||||||
|
t.Errorf("findOpenCode = %q, want %q", path, fakeBin)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify that the BackfillsCloudModelLimitOnExistingEntry test from the old
|
||||||
|
// file-based approach is covered by the new inline config approach.
|
||||||
|
func TestOpenCodeEdit_CloudModelLimitStructure(t *testing.T) {
|
||||||
|
o := &OpenCode{}
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
setTestHome(t, tmpDir)
|
||||||
|
|
||||||
|
expected := cloudModelLimits["glm-4.7"]
|
||||||
|
|
||||||
|
if err := o.Edit([]string{"glm-4.7:cloud"}); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var cfg map[string]any
|
||||||
|
json.Unmarshal([]byte(o.configContent), &cfg)
|
||||||
|
provider, _ := cfg["provider"].(map[string]any)
|
||||||
|
ollama, _ := provider["ollama"].(map[string]any)
|
||||||
|
models, _ := ollama["models"].(map[string]any)
|
||||||
|
entry, _ := models["glm-4.7:cloud"].(map[string]any)
|
||||||
|
|
||||||
|
limit, ok := entry["limit"].(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("cloud model limit was not set")
|
||||||
|
}
|
||||||
|
if limit["context"] != float64(expected.Context) {
|
||||||
|
t.Errorf("context = %v, want %d", limit["context"], expected.Context)
|
||||||
|
}
|
||||||
|
if limit["output"] != float64(expected.Output) {
|
||||||
|
t.Errorf("output = %v, want %d", limit["output"], expected.Output)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOpenCodeEdit_SpecialCharsInModelName(t *testing.T) {
|
||||||
|
o := &OpenCode{}
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
setTestHome(t, tmpDir)
|
||||||
|
|
||||||
|
specialModel := `model-with-"quotes"`
|
||||||
|
|
||||||
|
err := o.Edit([]string{specialModel})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Edit with special chars failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var cfg map[string]any
|
||||||
|
if err := json.Unmarshal([]byte(o.configContent), &cfg); err != nil {
|
||||||
|
t.Fatalf("resulting config is invalid JSON: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
provider, _ := cfg["provider"].(map[string]any)
|
||||||
|
ollama, _ := provider["ollama"].(map[string]any)
|
||||||
|
models, _ := ollama["models"].(map[string]any)
|
||||||
|
if models[specialModel] == nil {
|
||||||
|
t.Errorf("model with special chars not found in config")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReadModelJSONModels(t *testing.T) {
|
||||||
|
t.Run("reads ollama models from model.json", func(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
setTestHome(t, tmpDir)
|
||||||
|
|
||||||
|
stateDir := filepath.Join(tmpDir, ".local", "state", "opencode")
|
||||||
|
os.MkdirAll(stateDir, 0o755)
|
||||||
|
state := map[string]any{
|
||||||
|
"recent": []any{
|
||||||
|
map[string]any{"providerID": "ollama", "modelID": "llama3.2"},
|
||||||
|
map[string]any{"providerID": "ollama", "modelID": "qwen3:32b"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
data, _ := json.MarshalIndent(state, "", " ")
|
||||||
|
os.WriteFile(filepath.Join(stateDir, "model.json"), data, 0o644)
|
||||||
|
|
||||||
|
models := readModelJSONModels()
|
||||||
|
if len(models) != 2 {
|
||||||
|
t.Fatalf("got %d models, want 2", len(models))
|
||||||
|
}
|
||||||
|
if models[0] != "llama3.2" || models[1] != "qwen3:32b" {
|
||||||
|
t.Errorf("got %v, want [llama3.2 qwen3:32b]", models)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("skips non-ollama providers", func(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
setTestHome(t, tmpDir)
|
||||||
|
|
||||||
|
stateDir := filepath.Join(tmpDir, ".local", "state", "opencode")
|
||||||
|
os.MkdirAll(stateDir, 0o755)
|
||||||
|
state := map[string]any{
|
||||||
|
"recent": []any{
|
||||||
|
map[string]any{"providerID": "openai", "modelID": "gpt-4"},
|
||||||
|
map[string]any{"providerID": "ollama", "modelID": "llama3.2"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
data, _ := json.MarshalIndent(state, "", " ")
|
||||||
|
os.WriteFile(filepath.Join(stateDir, "model.json"), data, 0o644)
|
||||||
|
|
||||||
|
models := readModelJSONModels()
|
||||||
|
if len(models) != 1 || models[0] != "llama3.2" {
|
||||||
|
t.Errorf("got %v, want [llama3.2]", models)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("returns nil when file does not exist", func(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
setTestHome(t, tmpDir)
|
||||||
|
|
||||||
|
if models := readModelJSONModels(); models != nil {
|
||||||
|
t.Errorf("got %v, want nil", models)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("returns nil for corrupt JSON", func(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
setTestHome(t, tmpDir)
|
||||||
|
|
||||||
|
stateDir := filepath.Join(tmpDir, ".local", "state", "opencode")
|
||||||
|
os.MkdirAll(stateDir, 0o755)
|
||||||
|
os.WriteFile(filepath.Join(stateDir, "model.json"), []byte(`{corrupt`), 0o644)
|
||||||
|
|
||||||
|
if models := readModelJSONModels(); models != nil {
|
||||||
|
t.Errorf("got %v, want nil", models)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOpenCodeResolveContent(t *testing.T) {
|
||||||
|
t.Run("returns Edit's content when set", func(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
setTestHome(t, tmpDir)
|
||||||
|
|
||||||
|
o := &OpenCode{}
|
||||||
|
if err := o.Edit([]string{"gemma4"}); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
editContent := o.configContent
|
||||||
|
|
||||||
|
// Write a different model.json — should be ignored
|
||||||
|
stateDir := filepath.Join(tmpDir, ".local", "state", "opencode")
|
||||||
|
state := map[string]any{
|
||||||
|
"recent": []any{
|
||||||
|
map[string]any{"providerID": "ollama", "modelID": "different-model"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
data, _ := json.MarshalIndent(state, "", " ")
|
||||||
|
os.WriteFile(filepath.Join(stateDir, "model.json"), data, 0o644)
|
||||||
|
|
||||||
|
got := o.resolveContent("gemma4")
|
||||||
|
if got != editContent {
|
||||||
|
t.Errorf("resolveContent returned different content than Edit set\ngot: %s\nwant: %s", got, editContent)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("falls back to model.json when Edit was not called", func(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
setTestHome(t, tmpDir)
|
||||||
|
|
||||||
|
stateDir := filepath.Join(tmpDir, ".local", "state", "opencode")
|
||||||
|
os.MkdirAll(stateDir, 0o755)
|
||||||
|
state := map[string]any{
|
||||||
|
"recent": []any{
|
||||||
|
map[string]any{"providerID": "ollama", "modelID": "llama3.2"},
|
||||||
|
map[string]any{"providerID": "ollama", "modelID": "qwen3:32b"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
data, _ := json.MarshalIndent(state, "", " ")
|
||||||
|
os.WriteFile(filepath.Join(stateDir, "model.json"), data, 0o644)
|
||||||
|
|
||||||
|
o := &OpenCode{}
|
||||||
|
content := o.resolveContent("llama3.2")
|
||||||
|
if content == "" {
|
||||||
|
t.Fatal("resolveContent returned empty")
|
||||||
|
}
|
||||||
|
|
||||||
|
var cfg map[string]any
|
||||||
|
json.Unmarshal([]byte(content), &cfg)
|
||||||
|
if cfg["model"] != "ollama/llama3.2" {
|
||||||
|
t.Errorf("primary = %v, want ollama/llama3.2", cfg["model"])
|
||||||
|
}
|
||||||
|
provider, _ := cfg["provider"].(map[string]any)
|
||||||
|
ollama, _ := provider["ollama"].(map[string]any)
|
||||||
|
cfgModels, _ := ollama["models"].(map[string]any)
|
||||||
|
if cfgModels["llama3.2"] == nil || cfgModels["qwen3:32b"] == nil {
|
||||||
|
t.Errorf("expected both models in config, got %v", cfgModels)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("uses requested model as primary even when not first in model.json", func(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
setTestHome(t, tmpDir)
|
||||||
|
|
||||||
|
stateDir := filepath.Join(tmpDir, ".local", "state", "opencode")
|
||||||
|
os.MkdirAll(stateDir, 0o755)
|
||||||
|
state := map[string]any{
|
||||||
|
"recent": []any{
|
||||||
|
map[string]any{"providerID": "ollama", "modelID": "llama3.2"},
|
||||||
|
map[string]any{"providerID": "ollama", "modelID": "qwen3:32b"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
data, _ := json.MarshalIndent(state, "", " ")
|
||||||
|
os.WriteFile(filepath.Join(stateDir, "model.json"), data, 0o644)
|
||||||
|
|
||||||
|
o := &OpenCode{}
|
||||||
|
content := o.resolveContent("qwen3:32b")
|
||||||
|
|
||||||
|
var cfg map[string]any
|
||||||
|
json.Unmarshal([]byte(content), &cfg)
|
||||||
|
if cfg["model"] != "ollama/qwen3:32b" {
|
||||||
|
t.Errorf("primary = %v, want ollama/qwen3:32b", cfg["model"])
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("injects requested model when missing from model.json", func(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
setTestHome(t, tmpDir)
|
||||||
|
|
||||||
|
stateDir := filepath.Join(tmpDir, ".local", "state", "opencode")
|
||||||
|
os.MkdirAll(stateDir, 0o755)
|
||||||
|
state := map[string]any{
|
||||||
|
"recent": []any{
|
||||||
|
map[string]any{"providerID": "ollama", "modelID": "llama3.2"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
data, _ := json.MarshalIndent(state, "", " ")
|
||||||
|
os.WriteFile(filepath.Join(stateDir, "model.json"), data, 0o644)
|
||||||
|
|
||||||
|
o := &OpenCode{}
|
||||||
|
content := o.resolveContent("gemma4")
|
||||||
|
|
||||||
|
var cfg map[string]any
|
||||||
|
json.Unmarshal([]byte(content), &cfg)
|
||||||
|
provider, _ := cfg["provider"].(map[string]any)
|
||||||
|
ollama, _ := provider["ollama"].(map[string]any)
|
||||||
|
cfgModels, _ := ollama["models"].(map[string]any)
|
||||||
|
if cfgModels["gemma4"] == nil {
|
||||||
|
t.Error("requested model gemma4 not injected into config")
|
||||||
|
}
|
||||||
|
if cfg["model"] != "ollama/gemma4" {
|
||||||
|
t.Errorf("primary = %v, want ollama/gemma4", cfg["model"])
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("returns empty when no model.json and no model param", func(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
setTestHome(t, tmpDir)
|
||||||
|
|
||||||
|
o := &OpenCode{}
|
||||||
|
if got := o.resolveContent(""); got != "" {
|
||||||
|
t.Errorf("resolveContent(\"\") = %q, want empty", got)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("does not mutate configContent on fallback", func(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
setTestHome(t, tmpDir)
|
||||||
|
|
||||||
|
stateDir := filepath.Join(tmpDir, ".local", "state", "opencode")
|
||||||
|
os.MkdirAll(stateDir, 0o755)
|
||||||
|
state := map[string]any{
|
||||||
|
"recent": []any{
|
||||||
|
map[string]any{"providerID": "ollama", "modelID": "llama3.2"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
data, _ := json.MarshalIndent(state, "", " ")
|
||||||
|
os.WriteFile(filepath.Join(stateDir, "model.json"), data, 0o644)
|
||||||
|
|
||||||
|
o := &OpenCode{}
|
||||||
|
_ = o.resolveContent("llama3.2")
|
||||||
|
if o.configContent != "" {
|
||||||
|
t.Errorf("resolveContent should not mutate configContent, got %q", o.configContent)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildInlineConfig(t *testing.T) {
|
||||||
|
t.Run("returns error for empty primary", func(t *testing.T) {
|
||||||
|
if _, err := buildInlineConfig("", []string{"llama3.2"}); err == nil {
|
||||||
|
t.Error("expected error for empty primary")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("returns error for empty models", func(t *testing.T) {
|
||||||
|
if _, err := buildInlineConfig("llama3.2", nil); err == nil {
|
||||||
|
t.Error("expected error for empty models")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("primary differs from first model in list", func(t *testing.T) {
|
||||||
|
content, err := buildInlineConfig("qwen3:32b", []string{"llama3.2", "qwen3:32b"})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
var cfg map[string]any
|
||||||
|
json.Unmarshal([]byte(content), &cfg)
|
||||||
|
if cfg["model"] != "ollama/qwen3:32b" {
|
||||||
|
t.Errorf("primary = %v, want ollama/qwen3:32b", cfg["model"])
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOpenCodeEdit_PreservesRecentEntries(t *testing.T) {
|
||||||
|
t.Run("prepends new models to existing recent", func(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
setTestHome(t, tmpDir)
|
||||||
|
|
||||||
|
stateDir := filepath.Join(tmpDir, ".local", "state", "opencode")
|
||||||
|
os.MkdirAll(stateDir, 0o755)
|
||||||
|
initial := map[string]any{
|
||||||
|
"recent": []any{
|
||||||
|
map[string]any{"providerID": "ollama", "modelID": "old-A"},
|
||||||
|
map[string]any{"providerID": "ollama", "modelID": "old-B"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
data, _ := json.MarshalIndent(initial, "", " ")
|
||||||
|
os.WriteFile(filepath.Join(stateDir, "model.json"), data, 0o644)
|
||||||
|
|
||||||
|
o := &OpenCode{}
|
||||||
|
if err := o.Edit([]string{"new-X"}); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
stored, _ := os.ReadFile(filepath.Join(stateDir, "model.json"))
|
||||||
|
var state map[string]any
|
||||||
|
json.Unmarshal(stored, &state)
|
||||||
|
recent, _ := state["recent"].([]any)
|
||||||
|
|
||||||
|
if len(recent) != 3 {
|
||||||
|
t.Fatalf("expected 3 entries, got %d", len(recent))
|
||||||
|
}
|
||||||
|
first, _ := recent[0].(map[string]any)
|
||||||
|
if first["modelID"] != "new-X" {
|
||||||
|
t.Errorf("first entry = %v, want new-X", first["modelID"])
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("prepends multiple new models in order", func(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
setTestHome(t, tmpDir)
|
||||||
|
|
||||||
|
stateDir := filepath.Join(tmpDir, ".local", "state", "opencode")
|
||||||
|
os.MkdirAll(stateDir, 0o755)
|
||||||
|
initial := map[string]any{
|
||||||
|
"recent": []any{
|
||||||
|
map[string]any{"providerID": "ollama", "modelID": "old-A"},
|
||||||
|
map[string]any{"providerID": "ollama", "modelID": "old-B"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
data, _ := json.MarshalIndent(initial, "", " ")
|
||||||
|
os.WriteFile(filepath.Join(stateDir, "model.json"), data, 0o644)
|
||||||
|
|
||||||
|
o := &OpenCode{}
|
||||||
|
if err := o.Edit([]string{"X", "Y", "Z"}); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
stored, _ := os.ReadFile(filepath.Join(stateDir, "model.json"))
|
||||||
|
var state map[string]any
|
||||||
|
json.Unmarshal(stored, &state)
|
||||||
|
recent, _ := state["recent"].([]any)
|
||||||
|
|
||||||
|
want := []string{"X", "Y", "Z", "old-A", "old-B"}
|
||||||
|
if len(recent) != len(want) {
|
||||||
|
t.Fatalf("expected %d entries, got %d", len(want), len(recent))
|
||||||
|
}
|
||||||
|
for i, w := range want {
|
||||||
|
e, _ := recent[i].(map[string]any)
|
||||||
|
if e["modelID"] != w {
|
||||||
|
t.Errorf("recent[%d] = %v, want %v", i, e["modelID"], w)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("preserves non-ollama entries", func(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
setTestHome(t, tmpDir)
|
||||||
|
|
||||||
|
stateDir := filepath.Join(tmpDir, ".local", "state", "opencode")
|
||||||
|
os.MkdirAll(stateDir, 0o755)
|
||||||
|
initial := map[string]any{
|
||||||
|
"recent": []any{
|
||||||
|
map[string]any{"providerID": "openai", "modelID": "gpt-4"},
|
||||||
|
map[string]any{"providerID": "ollama", "modelID": "llama3.2"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
data, _ := json.MarshalIndent(initial, "", " ")
|
||||||
|
os.WriteFile(filepath.Join(stateDir, "model.json"), data, 0o644)
|
||||||
|
|
||||||
|
o := &OpenCode{}
|
||||||
|
if err := o.Edit([]string{"qwen3:32b"}); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
stored, _ := os.ReadFile(filepath.Join(stateDir, "model.json"))
|
||||||
|
var state map[string]any
|
||||||
|
json.Unmarshal(stored, &state)
|
||||||
|
recent, _ := state["recent"].([]any)
|
||||||
|
|
||||||
|
// Should have: qwen3:32b (new), gpt-4 (preserved openai), llama3.2 (preserved ollama)
|
||||||
|
var foundOpenAI bool
|
||||||
|
for _, entry := range recent {
|
||||||
|
e, _ := entry.(map[string]any)
|
||||||
|
if e["providerID"] == "openai" && e["modelID"] == "gpt-4" {
|
||||||
|
foundOpenAI = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !foundOpenAI {
|
||||||
|
t.Errorf("non-ollama gpt-4 entry was not preserved, got %v", recent)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("deduplicates ollama models being re-added", func(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
setTestHome(t, tmpDir)
|
||||||
|
|
||||||
|
stateDir := filepath.Join(tmpDir, ".local", "state", "opencode")
|
||||||
|
os.MkdirAll(stateDir, 0o755)
|
||||||
|
initial := map[string]any{
|
||||||
|
"recent": []any{
|
||||||
|
map[string]any{"providerID": "ollama", "modelID": "llama3.2"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
data, _ := json.MarshalIndent(initial, "", " ")
|
||||||
|
os.WriteFile(filepath.Join(stateDir, "model.json"), data, 0o644)
|
||||||
|
|
||||||
|
o := &OpenCode{}
|
||||||
|
if err := o.Edit([]string{"llama3.2"}); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
stored, _ := os.ReadFile(filepath.Join(stateDir, "model.json"))
|
||||||
|
var state map[string]any
|
||||||
|
json.Unmarshal(stored, &state)
|
||||||
|
recent, _ := state["recent"].([]any)
|
||||||
|
|
||||||
|
count := 0
|
||||||
|
for _, entry := range recent {
|
||||||
|
e, _ := entry.(map[string]any)
|
||||||
|
if e["modelID"] == "llama3.2" {
|
||||||
|
count++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if count != 1 {
|
||||||
|
t.Errorf("expected 1 llama3.2 entry, got %d", count)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("caps recent list at 10", func(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
setTestHome(t, tmpDir)
|
||||||
|
|
||||||
|
stateDir := filepath.Join(tmpDir, ".local", "state", "opencode")
|
||||||
|
os.MkdirAll(stateDir, 0o755)
|
||||||
|
|
||||||
|
// Pre-populate with 9 distinct ollama models
|
||||||
|
recentEntries := make([]any, 0, 9)
|
||||||
|
for i := range 9 {
|
||||||
|
recentEntries = append(recentEntries, map[string]any{
|
||||||
|
"providerID": "ollama",
|
||||||
|
"modelID": fmt.Sprintf("old-%d", i),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
initial := map[string]any{"recent": recentEntries}
|
||||||
|
data, _ := json.MarshalIndent(initial, "", " ")
|
||||||
|
os.WriteFile(filepath.Join(stateDir, "model.json"), data, 0o644)
|
||||||
|
|
||||||
|
// Add 5 new models — should cap at 10 total
|
||||||
|
o := &OpenCode{}
|
||||||
|
if err := o.Edit([]string{"new-0", "new-1", "new-2", "new-3", "new-4"}); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
stored, _ := os.ReadFile(filepath.Join(stateDir, "model.json"))
|
||||||
|
var state map[string]any
|
||||||
|
json.Unmarshal(stored, &state)
|
||||||
|
recent, _ := state["recent"].([]any)
|
||||||
|
|
||||||
|
if len(recent) != 10 {
|
||||||
|
t.Errorf("expected 10 entries (capped), got %d", len(recent))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOpenCodeEdit_BaseURL(t *testing.T) {
|
||||||
|
o := &OpenCode{}
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
setTestHome(t, tmpDir)
|
||||||
|
|
||||||
|
// Default OLLAMA_HOST
|
||||||
|
o.Edit([]string{"llama3.2"})
|
||||||
|
|
||||||
|
var cfg map[string]any
|
||||||
|
json.Unmarshal([]byte(o.configContent), &cfg)
|
||||||
|
provider, _ := cfg["provider"].(map[string]any)
|
||||||
|
ollama, _ := provider["ollama"].(map[string]any)
|
||||||
|
options, _ := ollama["options"].(map[string]any)
|
||||||
|
|
||||||
|
baseURL, _ := options["baseURL"].(string)
|
||||||
|
if baseURL == "" {
|
||||||
|
t.Error("baseURL should be set")
|
||||||
|
}
|
||||||
|
}
|
||||||
396
cmd/launch/pi.go
Normal file
@@ -0,0 +1,396 @@
|
|||||||
|
package launch
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
"os/exec"
|
||||||
|
"path/filepath"
|
||||||
|
"slices"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/api"
|
||||||
|
"github.com/ollama/ollama/cmd/internal/fileutil"
|
||||||
|
"github.com/ollama/ollama/envconfig"
|
||||||
|
"github.com/ollama/ollama/types/model"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Pi implements Runner and Editor for Pi (Pi Coding Agent) integration
|
||||||
|
type Pi struct{}
|
||||||
|
|
||||||
|
const (
|
||||||
|
piNpmPackage = "@mariozechner/pi-coding-agent"
|
||||||
|
piWebSearchSource = "npm:@ollama/pi-web-search"
|
||||||
|
piWebSearchPkg = "@ollama/pi-web-search"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (p *Pi) String() string { return "Pi" }
|
||||||
|
|
||||||
|
func (p *Pi) Run(model string, args []string) error {
|
||||||
|
fmt.Fprintf(os.Stderr, "\n%sPreparing Pi...%s\n", ansiGray, ansiReset)
|
||||||
|
if err := ensureNpmInstalled(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Fprintf(os.Stderr, "%sChecking Pi installation...%s\n", ansiGray, ansiReset)
|
||||||
|
bin, err := ensurePiInstalled()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
ensurePiWebSearchPackage(bin)
|
||||||
|
|
||||||
|
fmt.Fprintf(os.Stderr, "\n%sLaunching Pi...%s\n\n", ansiGray, ansiReset)
|
||||||
|
|
||||||
|
cmd := exec.Command(bin, args...)
|
||||||
|
cmd.Stdin = os.Stdin
|
||||||
|
cmd.Stdout = os.Stdout
|
||||||
|
cmd.Stderr = os.Stderr
|
||||||
|
return cmd.Run()
|
||||||
|
}
|
||||||
|
|
||||||
|
func ensureNpmInstalled() error {
|
||||||
|
if _, err := exec.LookPath("npm"); err != nil {
|
||||||
|
return fmt.Errorf("npm (Node.js) is required to launch pi\n\nInstall it first:\n https://nodejs.org/\n\nThen re-run:\n ollama launch pi")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func ensurePiInstalled() (string, error) {
|
||||||
|
if _, err := exec.LookPath("pi"); err == nil {
|
||||||
|
return "pi", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := exec.LookPath("npm"); err != nil {
|
||||||
|
return "", fmt.Errorf("pi is not installed and required dependencies are missing\n\nInstall the following first:\n npm (Node.js): https://nodejs.org/\n\nThen re-run:\n ollama launch pi")
|
||||||
|
}
|
||||||
|
|
||||||
|
ok, err := ConfirmPrompt("Pi is not installed. Install with npm?")
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
if !ok {
|
||||||
|
return "", fmt.Errorf("pi installation cancelled")
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Fprintf(os.Stderr, "\nInstalling Pi...\n")
|
||||||
|
cmd := exec.Command("npm", "install", "-g", piNpmPackage+"@latest")
|
||||||
|
cmd.Stdout = os.Stdout
|
||||||
|
cmd.Stderr = os.Stderr
|
||||||
|
if err := cmd.Run(); err != nil {
|
||||||
|
return "", fmt.Errorf("failed to install pi: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := exec.LookPath("pi"); err != nil {
|
||||||
|
return "", fmt.Errorf("pi was installed but the binary was not found on PATH\n\nYou may need to restart your shell")
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Fprintf(os.Stderr, "%sPi installed successfully%s\n\n", ansiGreen, ansiReset)
|
||||||
|
return "pi", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func ensurePiWebSearchPackage(bin string) {
|
||||||
|
if !shouldManagePiWebSearch() {
|
||||||
|
fmt.Fprintf(os.Stderr, "%sCloud is disabled; skipping %s setup.%s\n", ansiGray, piWebSearchPkg, ansiReset)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Fprintf(os.Stderr, "%sChecking Pi web search package...%s\n", ansiGray, ansiReset)
|
||||||
|
|
||||||
|
installed, err := piPackageInstalled(bin, piWebSearchSource)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "%s Warning: could not check %s installation: %v%s\n", ansiYellow, piWebSearchPkg, err, ansiReset)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if !installed {
|
||||||
|
fmt.Fprintf(os.Stderr, "%sInstalling %s...%s\n", ansiGray, piWebSearchPkg, ansiReset)
|
||||||
|
cmd := exec.Command(bin, "install", piWebSearchSource)
|
||||||
|
cmd.Stdout = os.Stdout
|
||||||
|
cmd.Stderr = os.Stderr
|
||||||
|
if err := cmd.Run(); err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "%s Warning: could not install %s: %v%s\n", ansiYellow, piWebSearchPkg, err, ansiReset)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Fprintf(os.Stderr, "%s ✓ Installed %s%s\n", ansiGreen, piWebSearchPkg, ansiReset)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Fprintf(os.Stderr, "%sUpdating %s...%s\n", ansiGray, piWebSearchPkg, ansiReset)
|
||||||
|
cmd := exec.Command(bin, "update", piWebSearchSource)
|
||||||
|
cmd.Stdout = os.Stdout
|
||||||
|
cmd.Stderr = os.Stderr
|
||||||
|
if err := cmd.Run(); err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "%s Warning: could not update %s: %v%s\n", ansiYellow, piWebSearchPkg, err, ansiReset)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Fprintf(os.Stderr, "%s ✓ Updated %s%s\n", ansiGreen, piWebSearchPkg, ansiReset)
|
||||||
|
}
|
||||||
|
|
||||||
|
func shouldManagePiWebSearch() bool {
|
||||||
|
client, err := api.ClientFromEnvironment()
|
||||||
|
if err != nil {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
disabled, known := cloudStatusDisabled(context.Background(), client)
|
||||||
|
if known && disabled {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func piPackageInstalled(bin, source string) (bool, error) {
|
||||||
|
cmd := exec.Command(bin, "list")
|
||||||
|
out, err := cmd.CombinedOutput()
|
||||||
|
if err != nil {
|
||||||
|
msg := strings.TrimSpace(string(out))
|
||||||
|
if msg == "" {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
return false, fmt.Errorf("%w: %s", err, msg)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, line := range strings.Split(string(out), "\n") {
|
||||||
|
trimmed := strings.TrimSpace(line)
|
||||||
|
if strings.HasPrefix(trimmed, source) {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Pi) Paths() []string {
|
||||||
|
home, err := os.UserHomeDir()
|
||||||
|
if err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var paths []string
|
||||||
|
modelsPath := filepath.Join(home, ".pi", "agent", "models.json")
|
||||||
|
if _, err := os.Stat(modelsPath); err == nil {
|
||||||
|
paths = append(paths, modelsPath)
|
||||||
|
}
|
||||||
|
settingsPath := filepath.Join(home, ".pi", "agent", "settings.json")
|
||||||
|
if _, err := os.Stat(settingsPath); err == nil {
|
||||||
|
paths = append(paths, settingsPath)
|
||||||
|
}
|
||||||
|
return paths
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Pi) Edit(models []string) error {
|
||||||
|
if len(models) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
home, err := os.UserHomeDir()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
configPath := filepath.Join(home, ".pi", "agent", "models.json")
|
||||||
|
if err := os.MkdirAll(filepath.Dir(configPath), 0o755); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
config := make(map[string]any)
|
||||||
|
if data, err := os.ReadFile(configPath); err == nil {
|
||||||
|
_ = json.Unmarshal(data, &config)
|
||||||
|
}
|
||||||
|
|
||||||
|
providers, ok := config["providers"].(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
providers = make(map[string]any)
|
||||||
|
}
|
||||||
|
|
||||||
|
ollama, ok := providers["ollama"].(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
ollama = map[string]any{
|
||||||
|
"baseUrl": envconfig.Host().String() + "/v1",
|
||||||
|
"api": "openai-completions",
|
||||||
|
"apiKey": "ollama",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
existingModels, ok := ollama["models"].([]any)
|
||||||
|
if !ok {
|
||||||
|
existingModels = make([]any, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build set of selected models to track which need to be added
|
||||||
|
selectedSet := make(map[string]bool, len(models))
|
||||||
|
for _, m := range models {
|
||||||
|
selectedSet[m] = true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build new models list:
|
||||||
|
// 1. Keep user-managed models (no _launch marker) - untouched
|
||||||
|
// 2. Keep ollama-managed models (_launch marker) that are still selected,
|
||||||
|
// except stale cloud entries that should be rebuilt below
|
||||||
|
// 3. Add new ollama-managed models
|
||||||
|
var newModels []any
|
||||||
|
for _, m := range existingModels {
|
||||||
|
if modelObj, ok := m.(map[string]any); ok {
|
||||||
|
if id, ok := modelObj["id"].(string); ok {
|
||||||
|
// User-managed model (no _launch marker) - always preserve
|
||||||
|
if !isPiOllamaModel(modelObj) {
|
||||||
|
newModels = append(newModels, m)
|
||||||
|
} else if selectedSet[id] {
|
||||||
|
// Rebuild stale managed cloud entries so createConfig refreshes
|
||||||
|
// the whole entry instead of patching it in place.
|
||||||
|
if !hasContextWindow(modelObj) {
|
||||||
|
if _, ok := lookupCloudModelLimit(id); ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
newModels = append(newModels, m)
|
||||||
|
selectedSet[id] = false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add newly selected models that weren't already in the list
|
||||||
|
client := api.NewClient(envconfig.Host(), http.DefaultClient)
|
||||||
|
ctx := context.Background()
|
||||||
|
for _, model := range models {
|
||||||
|
if selectedSet[model] {
|
||||||
|
newModels = append(newModels, createConfig(ctx, client, model))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
ollama["models"] = newModels
|
||||||
|
providers["ollama"] = ollama
|
||||||
|
config["providers"] = providers
|
||||||
|
|
||||||
|
configData, err := json.MarshalIndent(config, "", " ")
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := fileutil.WriteWithBackup(configPath, configData); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update settings.json with default provider and model
|
||||||
|
settingsPath := filepath.Join(home, ".pi", "agent", "settings.json")
|
||||||
|
settings := make(map[string]any)
|
||||||
|
if data, err := os.ReadFile(settingsPath); err == nil {
|
||||||
|
_ = json.Unmarshal(data, &settings)
|
||||||
|
}
|
||||||
|
|
||||||
|
settings["defaultProvider"] = "ollama"
|
||||||
|
settings["defaultModel"] = models[0]
|
||||||
|
|
||||||
|
settingsData, err := json.MarshalIndent(settings, "", " ")
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return fileutil.WriteWithBackup(settingsPath, settingsData)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Pi) Models() []string {
|
||||||
|
home, err := os.UserHomeDir()
|
||||||
|
if err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
configPath := filepath.Join(home, ".pi", "agent", "models.json")
|
||||||
|
config, err := fileutil.ReadJSON(configPath)
|
||||||
|
if err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
providers, _ := config["providers"].(map[string]any)
|
||||||
|
ollama, _ := providers["ollama"].(map[string]any)
|
||||||
|
models, _ := ollama["models"].([]any)
|
||||||
|
|
||||||
|
var result []string
|
||||||
|
for _, m := range models {
|
||||||
|
if modelObj, ok := m.(map[string]any); ok {
|
||||||
|
if id, ok := modelObj["id"].(string); ok {
|
||||||
|
result = append(result, id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
slices.Sort(result)
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
// isPiOllamaModel reports whether a model config entry is managed by ollama launch
|
||||||
|
func isPiOllamaModel(cfg map[string]any) bool {
|
||||||
|
if v, ok := cfg["_launch"].(bool); ok && v {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func hasContextWindow(cfg map[string]any) bool {
|
||||||
|
switch v := cfg["contextWindow"].(type) {
|
||||||
|
case float64:
|
||||||
|
return v > 0
|
||||||
|
case int:
|
||||||
|
return v > 0
|
||||||
|
case int64:
|
||||||
|
return v > 0
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// createConfig builds Pi model config with capability detection
|
||||||
|
func createConfig(ctx context.Context, client *api.Client, modelID string) map[string]any {
|
||||||
|
cfg := map[string]any{
|
||||||
|
"id": modelID,
|
||||||
|
"_launch": true,
|
||||||
|
}
|
||||||
|
if l, ok := lookupCloudModelLimit(modelID); ok {
|
||||||
|
cfg["contextWindow"] = l.Context
|
||||||
|
}
|
||||||
|
|
||||||
|
applyCloudContextFallback := func() {
|
||||||
|
if l, ok := lookupCloudModelLimit(modelID); ok {
|
||||||
|
cfg["contextWindow"] = l.Context
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := client.Show(ctx, &api.ShowRequest{Model: modelID})
|
||||||
|
if err != nil {
|
||||||
|
applyCloudContextFallback()
|
||||||
|
return cfg
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set input types based on vision capability
|
||||||
|
if slices.Contains(resp.Capabilities, model.CapabilityVision) {
|
||||||
|
cfg["input"] = []string{"text", "image"}
|
||||||
|
} else {
|
||||||
|
cfg["input"] = []string{"text"}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set reasoning based on thinking capability
|
||||||
|
if slices.Contains(resp.Capabilities, model.CapabilityThinking) {
|
||||||
|
cfg["reasoning"] = true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract context window from ModelInfo. For known cloud models, the
|
||||||
|
// pre-filled shared limit remains unless the server provides a positive value.
|
||||||
|
hasContextWindow := false
|
||||||
|
for key, val := range resp.ModelInfo {
|
||||||
|
if strings.HasSuffix(key, ".context_length") {
|
||||||
|
if ctxLen, ok := val.(float64); ok && ctxLen > 0 {
|
||||||
|
cfg["contextWindow"] = int(ctxLen)
|
||||||
|
hasContextWindow = true
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !hasContextWindow {
|
||||||
|
applyCloudContextFallback()
|
||||||
|
}
|
||||||
|
|
||||||
|
return cfg
|
||||||
|
}
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package config
|
package launch
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
@@ -9,6 +9,8 @@ import (
|
|||||||
"net/url"
|
"net/url"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
"runtime"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
@@ -33,6 +35,341 @@ func TestPiIntegration(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestPiRun_InstallAndWebSearchLifecycle(t *testing.T) {
|
||||||
|
if runtime.GOOS == "windows" {
|
||||||
|
t.Skip("uses POSIX shell test binaries")
|
||||||
|
}
|
||||||
|
|
||||||
|
writeScript := func(t *testing.T, path, content string) {
|
||||||
|
t.Helper()
|
||||||
|
if err := os.WriteFile(path, []byte(content), 0o755); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
seedPiScript := func(t *testing.T, dir string) {
|
||||||
|
t.Helper()
|
||||||
|
piPath := filepath.Join(dir, "pi")
|
||||||
|
listPath := filepath.Join(dir, "pi-list.txt")
|
||||||
|
piScript := fmt.Sprintf(`#!/bin/sh
|
||||||
|
echo "$@" >> %q
|
||||||
|
if [ "$1" = "list" ]; then
|
||||||
|
if [ -f %q ]; then
|
||||||
|
/bin/cat %q
|
||||||
|
fi
|
||||||
|
exit 0
|
||||||
|
fi
|
||||||
|
if [ "$1" = "update" ] && [ "$PI_FAIL_UPDATE" = "1" ]; then
|
||||||
|
echo "update failed" >&2
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
if [ "$1" = "install" ] && [ "$PI_FAIL_INSTALL" = "1" ]; then
|
||||||
|
echo "install failed" >&2
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
exit 0
|
||||||
|
`, filepath.Join(dir, "pi.log"), listPath, listPath)
|
||||||
|
writeScript(t, piPath, piScript)
|
||||||
|
}
|
||||||
|
|
||||||
|
seedNpmNoop := func(t *testing.T, dir string) {
|
||||||
|
t.Helper()
|
||||||
|
writeScript(t, filepath.Join(dir, "npm"), "#!/bin/sh\nexit 0\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
withConfirm := func(t *testing.T, fn func(prompt string) (bool, error)) {
|
||||||
|
t.Helper()
|
||||||
|
oldConfirm := DefaultConfirmPrompt
|
||||||
|
DefaultConfirmPrompt = func(prompt string, options ConfirmOptions) (bool, error) {
|
||||||
|
return fn(prompt)
|
||||||
|
}
|
||||||
|
t.Cleanup(func() { DefaultConfirmPrompt = oldConfirm })
|
||||||
|
}
|
||||||
|
|
||||||
|
setCloudStatus := func(t *testing.T, disabled bool) {
|
||||||
|
t.Helper()
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.URL.Path == "/api/status" {
|
||||||
|
fmt.Fprintf(w, `{"cloud":{"disabled":%t,"source":"config"}}`, disabled)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
http.NotFound(w, r)
|
||||||
|
}))
|
||||||
|
t.Cleanup(srv.Close)
|
||||||
|
t.Setenv("OLLAMA_HOST", srv.URL)
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("pi missing + user accepts install", func(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
setTestHome(t, tmpDir)
|
||||||
|
t.Setenv("PATH", tmpDir)
|
||||||
|
setCloudStatus(t, false)
|
||||||
|
|
||||||
|
if err := os.WriteFile(filepath.Join(tmpDir, "pi-list.txt"), []byte("User packages:\n npm:@ollama/pi-web-search\n"), 0o644); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
npmScript := fmt.Sprintf(`#!/bin/sh
|
||||||
|
echo "$@" >> %q
|
||||||
|
if [ "$1" = "install" ] && [ "$2" = "-g" ] && [ "$3" = %q ]; then
|
||||||
|
/bin/cat > %q <<'EOS'
|
||||||
|
#!/bin/sh
|
||||||
|
echo "$@" >> %q
|
||||||
|
if [ "$1" = "list" ]; then
|
||||||
|
if [ -f %q ]; then
|
||||||
|
/bin/cat %q
|
||||||
|
fi
|
||||||
|
exit 0
|
||||||
|
fi
|
||||||
|
exit 0
|
||||||
|
EOS
|
||||||
|
/bin/chmod +x %q
|
||||||
|
fi
|
||||||
|
exit 0
|
||||||
|
`, filepath.Join(tmpDir, "npm.log"), piNpmPackage+"@latest", filepath.Join(tmpDir, "pi"), filepath.Join(tmpDir, "pi.log"), filepath.Join(tmpDir, "pi-list.txt"), filepath.Join(tmpDir, "pi-list.txt"), filepath.Join(tmpDir, "pi"))
|
||||||
|
writeScript(t, filepath.Join(tmpDir, "npm"), npmScript)
|
||||||
|
|
||||||
|
withConfirm(t, func(prompt string) (bool, error) {
|
||||||
|
if strings.Contains(prompt, "Pi is not installed.") {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
return true, nil
|
||||||
|
})
|
||||||
|
|
||||||
|
p := &Pi{}
|
||||||
|
if err := p.Run("ignored", []string{"--version"}); err != nil {
|
||||||
|
t.Fatalf("Run() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
npmCalls, err := os.ReadFile(filepath.Join(tmpDir, "npm.log"))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if !strings.Contains(string(npmCalls), "install -g "+piNpmPackage+"@latest") {
|
||||||
|
t.Fatalf("expected npm install call, got:\n%s", npmCalls)
|
||||||
|
}
|
||||||
|
|
||||||
|
piCalls, err := os.ReadFile(filepath.Join(tmpDir, "pi.log"))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
got := string(piCalls)
|
||||||
|
if !strings.Contains(got, "list\n") {
|
||||||
|
t.Fatalf("expected pi list call, got:\n%s", got)
|
||||||
|
}
|
||||||
|
if !strings.Contains(got, "update "+piWebSearchSource+"\n") {
|
||||||
|
t.Fatalf("expected pi update call, got:\n%s", got)
|
||||||
|
}
|
||||||
|
if !strings.Contains(got, "--version\n") {
|
||||||
|
t.Fatalf("expected final pi launch call, got:\n%s", got)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("pi missing + user declines install", func(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
setTestHome(t, tmpDir)
|
||||||
|
t.Setenv("PATH", tmpDir)
|
||||||
|
setCloudStatus(t, false)
|
||||||
|
writeScript(t, filepath.Join(tmpDir, "npm"), "#!/bin/sh\nexit 0\n")
|
||||||
|
|
||||||
|
withConfirm(t, func(prompt string) (bool, error) {
|
||||||
|
if strings.Contains(prompt, "Pi is not installed.") {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
return true, nil
|
||||||
|
})
|
||||||
|
|
||||||
|
p := &Pi{}
|
||||||
|
err := p.Run("ignored", nil)
|
||||||
|
if err == nil || !strings.Contains(err.Error(), "pi installation cancelled") {
|
||||||
|
t.Fatalf("expected install cancellation error, got %v", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("pi installed + web search missing auto-installs", func(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
setTestHome(t, tmpDir)
|
||||||
|
t.Setenv("PATH", tmpDir)
|
||||||
|
setCloudStatus(t, false)
|
||||||
|
if err := os.WriteFile(filepath.Join(tmpDir, "pi-list.txt"), []byte("User packages:\n"), 0o644); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
seedPiScript(t, tmpDir)
|
||||||
|
seedNpmNoop(t, tmpDir)
|
||||||
|
withConfirm(t, func(prompt string) (bool, error) {
|
||||||
|
t.Fatalf("did not expect confirmation prompt, got %q", prompt)
|
||||||
|
return false, nil
|
||||||
|
})
|
||||||
|
|
||||||
|
p := &Pi{}
|
||||||
|
if err := p.Run("ignored", []string{"session"}); err != nil {
|
||||||
|
t.Fatalf("Run() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
piCalls, err := os.ReadFile(filepath.Join(tmpDir, "pi.log"))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
got := string(piCalls)
|
||||||
|
if !strings.Contains(got, "list\n") {
|
||||||
|
t.Fatalf("expected pi list call, got:\n%s", got)
|
||||||
|
}
|
||||||
|
if !strings.Contains(got, "install "+piWebSearchSource+"\n") {
|
||||||
|
t.Fatalf("expected pi install call, got:\n%s", got)
|
||||||
|
}
|
||||||
|
if strings.Contains(got, "update "+piWebSearchSource+"\n") {
|
||||||
|
t.Fatalf("did not expect pi update call when package missing, got:\n%s", got)
|
||||||
|
}
|
||||||
|
if !strings.Contains(got, "session\n") {
|
||||||
|
t.Fatalf("expected final pi launch call, got:\n%s", got)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("pi installed + web search present updates every launch", func(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
setTestHome(t, tmpDir)
|
||||||
|
t.Setenv("PATH", tmpDir)
|
||||||
|
setCloudStatus(t, false)
|
||||||
|
if err := os.WriteFile(filepath.Join(tmpDir, "pi-list.txt"), []byte("User packages:\n "+piWebSearchSource+"\n"), 0o644); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
seedPiScript(t, tmpDir)
|
||||||
|
seedNpmNoop(t, tmpDir)
|
||||||
|
|
||||||
|
p := &Pi{}
|
||||||
|
if err := p.Run("ignored", []string{"doctor"}); err != nil {
|
||||||
|
t.Fatalf("Run() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
piCalls, err := os.ReadFile(filepath.Join(tmpDir, "pi.log"))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
got := string(piCalls)
|
||||||
|
if !strings.Contains(got, "update "+piWebSearchSource+"\n") {
|
||||||
|
t.Fatalf("expected pi update call, got:\n%s", got)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("web search update failure warns and continues", func(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
setTestHome(t, tmpDir)
|
||||||
|
t.Setenv("PATH", tmpDir)
|
||||||
|
setCloudStatus(t, false)
|
||||||
|
t.Setenv("PI_FAIL_UPDATE", "1")
|
||||||
|
if err := os.WriteFile(filepath.Join(tmpDir, "pi-list.txt"), []byte("User packages:\n "+piWebSearchSource+"\n"), 0o644); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
seedPiScript(t, tmpDir)
|
||||||
|
seedNpmNoop(t, tmpDir)
|
||||||
|
|
||||||
|
p := &Pi{}
|
||||||
|
stderr := captureStderr(t, func() {
|
||||||
|
if err := p.Run("ignored", []string{"session"}); err != nil {
|
||||||
|
t.Fatalf("Run() should continue after web search update failure, got %v", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
if !strings.Contains(stderr, "Warning: could not update "+piWebSearchPkg) {
|
||||||
|
t.Fatalf("expected update warning, got:\n%s", stderr)
|
||||||
|
}
|
||||||
|
|
||||||
|
piCalls, err := os.ReadFile(filepath.Join(tmpDir, "pi.log"))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if !strings.Contains(string(piCalls), "session\n") {
|
||||||
|
t.Fatalf("expected final pi launch call, got:\n%s", piCalls)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("web search install failure warns and continues", func(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
setTestHome(t, tmpDir)
|
||||||
|
t.Setenv("PATH", tmpDir)
|
||||||
|
setCloudStatus(t, false)
|
||||||
|
t.Setenv("PI_FAIL_INSTALL", "1")
|
||||||
|
if err := os.WriteFile(filepath.Join(tmpDir, "pi-list.txt"), []byte("User packages:\n"), 0o644); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
seedPiScript(t, tmpDir)
|
||||||
|
seedNpmNoop(t, tmpDir)
|
||||||
|
withConfirm(t, func(prompt string) (bool, error) {
|
||||||
|
t.Fatalf("did not expect confirmation prompt, got %q", prompt)
|
||||||
|
return false, nil
|
||||||
|
})
|
||||||
|
|
||||||
|
p := &Pi{}
|
||||||
|
stderr := captureStderr(t, func() {
|
||||||
|
if err := p.Run("ignored", []string{"session"}); err != nil {
|
||||||
|
t.Fatalf("Run() should continue after web search install failure, got %v", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
if !strings.Contains(stderr, "Warning: could not install "+piWebSearchPkg) {
|
||||||
|
t.Fatalf("expected install warning, got:\n%s", stderr)
|
||||||
|
}
|
||||||
|
|
||||||
|
piCalls, err := os.ReadFile(filepath.Join(tmpDir, "pi.log"))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if !strings.Contains(string(piCalls), "session\n") {
|
||||||
|
t.Fatalf("expected final pi launch call, got:\n%s", piCalls)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("cloud disabled skips web search package management", func(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
setTestHome(t, tmpDir)
|
||||||
|
t.Setenv("PATH", tmpDir)
|
||||||
|
setCloudStatus(t, true)
|
||||||
|
if err := os.WriteFile(filepath.Join(tmpDir, "pi-list.txt"), []byte("User packages:\n"), 0o644); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
seedPiScript(t, tmpDir)
|
||||||
|
seedNpmNoop(t, tmpDir)
|
||||||
|
|
||||||
|
p := &Pi{}
|
||||||
|
stderr := captureStderr(t, func() {
|
||||||
|
if err := p.Run("ignored", []string{"session"}); err != nil {
|
||||||
|
t.Fatalf("Run() error = %v", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
if !strings.Contains(stderr, "Cloud is disabled; skipping "+piWebSearchPkg+" setup.") {
|
||||||
|
t.Fatalf("expected cloud-disabled skip message, got:\n%s", stderr)
|
||||||
|
}
|
||||||
|
|
||||||
|
piCalls, err := os.ReadFile(filepath.Join(tmpDir, "pi.log"))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
got := string(piCalls)
|
||||||
|
if strings.Contains(got, "list\n") || strings.Contains(got, "install "+piWebSearchSource+"\n") || strings.Contains(got, "update "+piWebSearchSource+"\n") {
|
||||||
|
t.Fatalf("did not expect web search package management calls, got:\n%s", got)
|
||||||
|
}
|
||||||
|
if !strings.Contains(got, "session\n") {
|
||||||
|
t.Fatalf("expected final pi launch call, got:\n%s", got)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("missing npm returns error before pi flow", func(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
setTestHome(t, tmpDir)
|
||||||
|
t.Setenv("PATH", tmpDir)
|
||||||
|
setCloudStatus(t, false)
|
||||||
|
seedPiScript(t, tmpDir)
|
||||||
|
|
||||||
|
p := &Pi{}
|
||||||
|
err := p.Run("ignored", []string{"session"})
|
||||||
|
if err == nil || !strings.Contains(err.Error(), "npm (Node.js) is required to launch pi") {
|
||||||
|
t.Fatalf("expected missing npm error, got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, statErr := os.Stat(filepath.Join(tmpDir, "pi.log")); !os.IsNotExist(statErr) {
|
||||||
|
t.Fatalf("expected pi not to run when npm is missing, stat err = %v", statErr)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func TestPiPaths(t *testing.T) {
|
func TestPiPaths(t *testing.T) {
|
||||||
pi := &Pi{}
|
pi := &Pi{}
|
||||||
|
|
||||||
@@ -192,6 +529,48 @@ func TestPiEdit(t *testing.T) {
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
t.Run("rebuilds stale existing managed cloud model", func(t *testing.T) {
|
||||||
|
cleanup()
|
||||||
|
os.MkdirAll(configDir, 0o755)
|
||||||
|
|
||||||
|
existingConfig := `{
|
||||||
|
"providers": {
|
||||||
|
"ollama": {
|
||||||
|
"baseUrl": "http://localhost:11434/v1",
|
||||||
|
"api": "openai-completions",
|
||||||
|
"apiKey": "ollama",
|
||||||
|
"models": [
|
||||||
|
{"id": "glm-5:cloud", "_launch": true, "legacyField": "stale"}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}`
|
||||||
|
if err := os.WriteFile(configPath, []byte(existingConfig), 0o644); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := pi.Edit([]string{"glm-5:cloud"}); err != nil {
|
||||||
|
t.Fatalf("Edit() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg := readConfig()
|
||||||
|
providers := cfg["providers"].(map[string]any)
|
||||||
|
ollama := providers["ollama"].(map[string]any)
|
||||||
|
modelsArray := ollama["models"].([]any)
|
||||||
|
modelEntry := modelsArray[0].(map[string]any)
|
||||||
|
|
||||||
|
if modelEntry["contextWindow"] != float64(202_752) {
|
||||||
|
t.Errorf("contextWindow = %v, want 202752", modelEntry["contextWindow"])
|
||||||
|
}
|
||||||
|
input, ok := modelEntry["input"].([]any)
|
||||||
|
if !ok || len(input) != 1 || input[0] != "text" {
|
||||||
|
t.Errorf("input = %v, want [text]", modelEntry["input"])
|
||||||
|
}
|
||||||
|
if _, ok := modelEntry["legacyField"]; ok {
|
||||||
|
t.Error("legacyField should be removed when stale managed cloud entry is rebuilt")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
t.Run("replaces old models with new ones", func(t *testing.T) {
|
t.Run("replaces old models with new ones", func(t *testing.T) {
|
||||||
cleanup()
|
cleanup()
|
||||||
os.MkdirAll(configDir, 0o755)
|
os.MkdirAll(configDir, 0o755)
|
||||||
@@ -798,6 +1177,59 @@ func TestCreateConfig(t *testing.T) {
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
t.Run("cloud model falls back to hardcoded context when show fails", func(t *testing.T) {
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusNotFound)
|
||||||
|
fmt.Fprintf(w, `{"error":"model not found"}`)
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
u, _ := url.Parse(srv.URL)
|
||||||
|
client := api.NewClient(u, srv.Client())
|
||||||
|
|
||||||
|
cfg := createConfig(context.Background(), client, "kimi-k2.5:cloud")
|
||||||
|
|
||||||
|
if cfg["contextWindow"] != 262_144 {
|
||||||
|
t.Errorf("contextWindow = %v, want 262144", cfg["contextWindow"])
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("cloud model falls back to hardcoded context when show omits model info", func(t *testing.T) {
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.URL.Path == "/api/show" {
|
||||||
|
fmt.Fprintf(w, `{"capabilities":[],"model_info":{}}`)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
w.WriteHeader(http.StatusNotFound)
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
u, _ := url.Parse(srv.URL)
|
||||||
|
client := api.NewClient(u, srv.Client())
|
||||||
|
|
||||||
|
cfg := createConfig(context.Background(), client, "glm-5:cloud")
|
||||||
|
|
||||||
|
if cfg["contextWindow"] != 202_752 {
|
||||||
|
t.Errorf("contextWindow = %v, want 202752", cfg["contextWindow"])
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("cloud model with dash suffix falls back to hardcoded context", func(t *testing.T) {
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusNotFound)
|
||||||
|
fmt.Fprintf(w, `{"error":"model not found"}`)
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
u, _ := url.Parse(srv.URL)
|
||||||
|
client := api.NewClient(u, srv.Client())
|
||||||
|
|
||||||
|
cfg := createConfig(context.Background(), client, "gpt-oss:120b-cloud")
|
||||||
|
|
||||||
|
if cfg["contextWindow"] != 131_072 {
|
||||||
|
t.Errorf("contextWindow = %v, want 131072", cfg["contextWindow"])
|
||||||
|
}
|
||||||
|
})
|
||||||
t.Run("skips zero context length", func(t *testing.T) {
|
t.Run("skips zero context length", func(t *testing.T) {
|
||||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
if r.URL.Path == "/api/show" {
|
if r.URL.Path == "/api/show" {
|
||||||
373
cmd/launch/registry.go
Normal file
@@ -0,0 +1,373 @@
|
|||||||
|
package launch
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"os/exec"
|
||||||
|
"slices"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// IntegrationInstallSpec describes how launcher should detect and guide installation.
|
||||||
|
type IntegrationInstallSpec struct {
|
||||||
|
CheckInstalled func() bool
|
||||||
|
EnsureInstalled func() error
|
||||||
|
URL string
|
||||||
|
Command []string
|
||||||
|
}
|
||||||
|
|
||||||
|
// IntegrationSpec is the canonical registry entry for one integration.
|
||||||
|
type IntegrationSpec struct {
|
||||||
|
Name string
|
||||||
|
Runner Runner
|
||||||
|
Aliases []string
|
||||||
|
Hidden bool
|
||||||
|
Description string
|
||||||
|
Install IntegrationInstallSpec
|
||||||
|
}
|
||||||
|
|
||||||
|
// IntegrationInfo contains display information about a registered integration.
|
||||||
|
type IntegrationInfo struct {
|
||||||
|
Name string
|
||||||
|
DisplayName string
|
||||||
|
Description string
|
||||||
|
}
|
||||||
|
|
||||||
|
var launcherIntegrationOrder = []string{"opencode", "droid", "pi"}
|
||||||
|
|
||||||
|
var integrationSpecs = []*IntegrationSpec{
|
||||||
|
{
|
||||||
|
Name: "claude",
|
||||||
|
Runner: &Claude{},
|
||||||
|
Description: "Anthropic's coding tool with subagents",
|
||||||
|
Install: IntegrationInstallSpec{
|
||||||
|
CheckInstalled: func() bool {
|
||||||
|
_, err := (&Claude{}).findPath()
|
||||||
|
return err == nil
|
||||||
|
},
|
||||||
|
URL: "https://code.claude.com/docs/en/quickstart",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "cline",
|
||||||
|
Runner: &Cline{},
|
||||||
|
Description: "Autonomous coding agent with parallel execution",
|
||||||
|
Hidden: true,
|
||||||
|
Install: IntegrationInstallSpec{
|
||||||
|
CheckInstalled: func() bool {
|
||||||
|
_, err := exec.LookPath("cline")
|
||||||
|
return err == nil
|
||||||
|
},
|
||||||
|
Command: []string{"npm", "install", "-g", "cline"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "codex",
|
||||||
|
Runner: &Codex{},
|
||||||
|
Description: "OpenAI's open-source coding agent",
|
||||||
|
Install: IntegrationInstallSpec{
|
||||||
|
CheckInstalled: func() bool {
|
||||||
|
_, err := exec.LookPath("codex")
|
||||||
|
return err == nil
|
||||||
|
},
|
||||||
|
URL: "https://developers.openai.com/codex/cli/",
|
||||||
|
Command: []string{"npm", "install", "-g", "@openai/codex"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "droid",
|
||||||
|
Runner: &Droid{},
|
||||||
|
Description: "Factory's coding agent across terminal and IDEs",
|
||||||
|
Install: IntegrationInstallSpec{
|
||||||
|
CheckInstalled: func() bool {
|
||||||
|
_, err := exec.LookPath("droid")
|
||||||
|
return err == nil
|
||||||
|
},
|
||||||
|
URL: "https://docs.factory.ai/cli/getting-started/quickstart",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "opencode",
|
||||||
|
Runner: &OpenCode{},
|
||||||
|
Description: "Anomaly's open-source coding agent",
|
||||||
|
Install: IntegrationInstallSpec{
|
||||||
|
CheckInstalled: func() bool {
|
||||||
|
_, ok := findOpenCode()
|
||||||
|
return ok
|
||||||
|
},
|
||||||
|
URL: "https://opencode.ai",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "openclaw",
|
||||||
|
Runner: &Openclaw{},
|
||||||
|
Aliases: []string{"clawdbot", "moltbot"},
|
||||||
|
Description: "Personal AI with 100+ skills",
|
||||||
|
Install: IntegrationInstallSpec{
|
||||||
|
CheckInstalled: func() bool {
|
||||||
|
if _, err := exec.LookPath("openclaw"); err == nil {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if _, err := exec.LookPath("clawdbot"); err == nil {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
},
|
||||||
|
EnsureInstalled: func() error {
|
||||||
|
_, err := ensureOpenclawInstalled()
|
||||||
|
return err
|
||||||
|
},
|
||||||
|
URL: "https://docs.openclaw.ai",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "pi",
|
||||||
|
Runner: &Pi{},
|
||||||
|
Description: "Minimal AI agent toolkit with plugin support",
|
||||||
|
Install: IntegrationInstallSpec{
|
||||||
|
CheckInstalled: func() bool {
|
||||||
|
_, err := exec.LookPath("pi")
|
||||||
|
return err == nil
|
||||||
|
},
|
||||||
|
EnsureInstalled: func() error {
|
||||||
|
_, err := ensurePiInstalled()
|
||||||
|
return err
|
||||||
|
},
|
||||||
|
Command: []string{"npm", "install", "-g", "@mariozechner/pi-coding-agent@latest"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "vscode",
|
||||||
|
Runner: &VSCode{},
|
||||||
|
Aliases: []string{"code"},
|
||||||
|
Description: "Microsoft's open-source AI code editor",
|
||||||
|
Hidden: true,
|
||||||
|
Install: IntegrationInstallSpec{
|
||||||
|
CheckInstalled: func() bool {
|
||||||
|
return (&VSCode{}).findBinary() != ""
|
||||||
|
},
|
||||||
|
URL: "https://code.visualstudio.com",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
var integrationSpecsByName map[string]*IntegrationSpec
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
rebuildIntegrationSpecIndexes()
|
||||||
|
}
|
||||||
|
|
||||||
|
func hyperlink(url, text string) string {
|
||||||
|
return fmt.Sprintf("\033]8;;%s\033\\%s\033]8;;\033\\", url, text)
|
||||||
|
}
|
||||||
|
|
||||||
|
func rebuildIntegrationSpecIndexes() {
|
||||||
|
integrationSpecsByName = make(map[string]*IntegrationSpec, len(integrationSpecs))
|
||||||
|
|
||||||
|
canonical := make(map[string]bool, len(integrationSpecs))
|
||||||
|
for _, spec := range integrationSpecs {
|
||||||
|
key := strings.ToLower(spec.Name)
|
||||||
|
if key == "" {
|
||||||
|
panic("launch: integration spec missing name")
|
||||||
|
}
|
||||||
|
if canonical[key] {
|
||||||
|
panic(fmt.Sprintf("launch: duplicate integration name %q", key))
|
||||||
|
}
|
||||||
|
canonical[key] = true
|
||||||
|
integrationSpecsByName[key] = spec
|
||||||
|
}
|
||||||
|
|
||||||
|
seenAliases := make(map[string]string)
|
||||||
|
for _, spec := range integrationSpecs {
|
||||||
|
for _, alias := range spec.Aliases {
|
||||||
|
key := strings.ToLower(alias)
|
||||||
|
if key == "" {
|
||||||
|
panic(fmt.Sprintf("launch: integration %q has empty alias", spec.Name))
|
||||||
|
}
|
||||||
|
if canonical[key] {
|
||||||
|
panic(fmt.Sprintf("launch: alias %q collides with canonical integration name", key))
|
||||||
|
}
|
||||||
|
if owner, exists := seenAliases[key]; exists {
|
||||||
|
panic(fmt.Sprintf("launch: alias %q collides between %q and %q", key, owner, spec.Name))
|
||||||
|
}
|
||||||
|
seenAliases[key] = spec.Name
|
||||||
|
integrationSpecsByName[key] = spec
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
orderSeen := make(map[string]bool, len(launcherIntegrationOrder))
|
||||||
|
for _, name := range launcherIntegrationOrder {
|
||||||
|
key := strings.ToLower(name)
|
||||||
|
if orderSeen[key] {
|
||||||
|
panic(fmt.Sprintf("launch: duplicate launcher order entry %q", key))
|
||||||
|
}
|
||||||
|
orderSeen[key] = true
|
||||||
|
|
||||||
|
spec, ok := integrationSpecsByName[key]
|
||||||
|
if !ok {
|
||||||
|
panic(fmt.Sprintf("launch: unknown launcher order entry %q", key))
|
||||||
|
}
|
||||||
|
if spec.Name != key {
|
||||||
|
panic(fmt.Sprintf("launch: launcher order entry %q must use canonical name, not alias", key))
|
||||||
|
}
|
||||||
|
if spec.Hidden {
|
||||||
|
panic(fmt.Sprintf("launch: hidden integration %q cannot appear in launcher order", key))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// LookupIntegrationSpec resolves either a canonical integration name or alias to its spec.
|
||||||
|
func LookupIntegrationSpec(name string) (*IntegrationSpec, error) {
|
||||||
|
spec, ok := integrationSpecsByName[strings.ToLower(name)]
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("unknown integration: %s", name)
|
||||||
|
}
|
||||||
|
return spec, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// LookupIntegration resolves a registry name to the canonical key and runner.
|
||||||
|
func LookupIntegration(name string) (string, Runner, error) {
|
||||||
|
spec, err := LookupIntegrationSpec(name)
|
||||||
|
if err != nil {
|
||||||
|
return "", nil, err
|
||||||
|
}
|
||||||
|
return spec.Name, spec.Runner, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListVisibleIntegrationSpecs returns the canonical integrations that should appear in interactive UIs.
|
||||||
|
func ListVisibleIntegrationSpecs() []IntegrationSpec {
|
||||||
|
visible := make([]IntegrationSpec, 0, len(integrationSpecs))
|
||||||
|
for _, spec := range integrationSpecs {
|
||||||
|
if spec.Hidden {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
visible = append(visible, *spec)
|
||||||
|
}
|
||||||
|
|
||||||
|
orderRank := make(map[string]int, len(launcherIntegrationOrder))
|
||||||
|
for i, name := range launcherIntegrationOrder {
|
||||||
|
orderRank[name] = i + 1
|
||||||
|
}
|
||||||
|
|
||||||
|
slices.SortFunc(visible, func(a, b IntegrationSpec) int {
|
||||||
|
aRank, bRank := orderRank[a.Name], orderRank[b.Name]
|
||||||
|
if aRank > 0 && bRank > 0 {
|
||||||
|
return aRank - bRank
|
||||||
|
}
|
||||||
|
if aRank > 0 {
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
if bRank > 0 {
|
||||||
|
return -1
|
||||||
|
}
|
||||||
|
return strings.Compare(a.Name, b.Name)
|
||||||
|
})
|
||||||
|
|
||||||
|
return visible
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListIntegrationInfos returns the registered integrations in launcher display order.
|
||||||
|
func ListIntegrationInfos() []IntegrationInfo {
|
||||||
|
visible := ListVisibleIntegrationSpecs()
|
||||||
|
infos := make([]IntegrationInfo, 0, len(visible))
|
||||||
|
for _, spec := range visible {
|
||||||
|
infos = append(infos, IntegrationInfo{
|
||||||
|
Name: spec.Name,
|
||||||
|
DisplayName: spec.Runner.String(),
|
||||||
|
Description: spec.Description,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return infos
|
||||||
|
}
|
||||||
|
|
||||||
|
// IntegrationSelectionItems returns the sorted integration items shown by launcher selection UIs.
|
||||||
|
func IntegrationSelectionItems() ([]ModelItem, error) {
|
||||||
|
visible := ListVisibleIntegrationSpecs()
|
||||||
|
if len(visible) == 0 {
|
||||||
|
return nil, fmt.Errorf("no integrations available")
|
||||||
|
}
|
||||||
|
|
||||||
|
items := make([]ModelItem, 0, len(visible))
|
||||||
|
for _, spec := range visible {
|
||||||
|
description := spec.Runner.String()
|
||||||
|
if conn, err := loadStoredIntegrationConfig(spec.Name); err == nil && len(conn.Models) > 0 {
|
||||||
|
description = fmt.Sprintf("%s (%s)", spec.Runner.String(), conn.Models[0])
|
||||||
|
}
|
||||||
|
items = append(items, ModelItem{Name: spec.Name, Description: description})
|
||||||
|
}
|
||||||
|
return items, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsIntegrationInstalled checks if an integration binary is installed.
|
||||||
|
func IsIntegrationInstalled(name string) bool {
|
||||||
|
integration, err := integrationFor(name)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "Ollama couldn't find integration %q, so it'll show up as not installed.\n", name)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return integration.installed
|
||||||
|
}
|
||||||
|
|
||||||
|
// integration is resolved registry metadata used by launcher state and install checks.
|
||||||
|
// It combines immutable registry spec data with computed runtime traits.
|
||||||
|
type integration struct {
|
||||||
|
spec *IntegrationSpec
|
||||||
|
installed bool
|
||||||
|
autoInstallable bool
|
||||||
|
editor bool
|
||||||
|
installHint string
|
||||||
|
}
|
||||||
|
|
||||||
|
// integrationFor resolves an integration name into the canonical spec plus
|
||||||
|
// derived launcher/install traits used across registry and launch flows.
|
||||||
|
func integrationFor(name string) (integration, error) {
|
||||||
|
spec, err := LookupIntegrationSpec(name)
|
||||||
|
if err != nil {
|
||||||
|
return integration{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
installed := true
|
||||||
|
if spec.Install.CheckInstalled != nil {
|
||||||
|
installed = spec.Install.CheckInstalled()
|
||||||
|
}
|
||||||
|
|
||||||
|
_, editor := spec.Runner.(Editor)
|
||||||
|
hint := ""
|
||||||
|
if spec.Install.URL != "" {
|
||||||
|
hint = "Install from " + hyperlink(spec.Install.URL, spec.Install.URL)
|
||||||
|
} else if len(spec.Install.Command) > 0 {
|
||||||
|
hint = "Install with: " + strings.Join(spec.Install.Command, " ")
|
||||||
|
}
|
||||||
|
|
||||||
|
return integration{
|
||||||
|
spec: spec,
|
||||||
|
installed: installed,
|
||||||
|
autoInstallable: spec.Install.EnsureInstalled != nil,
|
||||||
|
editor: editor,
|
||||||
|
installHint: hint,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// EnsureIntegrationInstalled installs auto-installable integrations when missing.
|
||||||
|
func EnsureIntegrationInstalled(name string, runner Runner) error {
|
||||||
|
integration, err := integrationFor(name)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("%s is not installed", runner)
|
||||||
|
}
|
||||||
|
|
||||||
|
if integration.installed {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if integration.autoInstallable {
|
||||||
|
return integration.spec.Install.EnsureInstalled()
|
||||||
|
}
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case integration.spec.Install.URL != "":
|
||||||
|
return fmt.Errorf("%s is not installed, install from %s", integration.spec.Name, integration.spec.Install.URL)
|
||||||
|
case len(integration.spec.Install.Command) > 0:
|
||||||
|
return fmt.Errorf("%s is not installed, install with: %s", integration.spec.Name, strings.Join(integration.spec.Install.Command, " "))
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("%s is not installed", runner)
|
||||||
|
}
|
||||||
|
}
|
||||||
21
cmd/launch/registry_test_helpers_test.go
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
package launch
|
||||||
|
|
||||||
|
import "strings"
|
||||||
|
|
||||||
|
// OverrideIntegration replaces one registry entry's runner for tests and returns a restore function.
|
||||||
|
func OverrideIntegration(name string, runner Runner) func() {
|
||||||
|
spec, err := LookupIntegrationSpec(name)
|
||||||
|
if err != nil {
|
||||||
|
key := strings.ToLower(name)
|
||||||
|
integrationSpecsByName[key] = &IntegrationSpec{Name: key, Runner: runner}
|
||||||
|
return func() {
|
||||||
|
delete(integrationSpecsByName, key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
original := spec.Runner
|
||||||
|
spec.Runner = runner
|
||||||
|
return func() {
|
||||||
|
spec.Runner = original
|
||||||
|
}
|
||||||
|
}
|
||||||
71
cmd/launch/runner_exec_only_test.go
Normal file
@@ -0,0 +1,71 @@
|
|||||||
|
package launch
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestEditorRunsDoNotRewriteConfig(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
binary string
|
||||||
|
runner Runner
|
||||||
|
checkPath func(home string) string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "droid",
|
||||||
|
binary: "droid",
|
||||||
|
runner: &Droid{},
|
||||||
|
checkPath: func(home string) string {
|
||||||
|
return filepath.Join(home, ".factory", "settings.json")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "opencode",
|
||||||
|
binary: "opencode",
|
||||||
|
runner: &OpenCode{},
|
||||||
|
checkPath: func(home string) string {
|
||||||
|
return filepath.Join(home, ".local", "state", "opencode", "model.json")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "cline",
|
||||||
|
binary: "cline",
|
||||||
|
runner: &Cline{},
|
||||||
|
checkPath: func(home string) string {
|
||||||
|
return filepath.Join(home, ".cline", "data", "globalState.json")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "pi",
|
||||||
|
binary: "pi",
|
||||||
|
runner: &Pi{},
|
||||||
|
checkPath: func(home string) string {
|
||||||
|
return filepath.Join(home, ".pi", "agent", "models.json")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
home := t.TempDir()
|
||||||
|
setTestHome(t, home)
|
||||||
|
|
||||||
|
binDir := t.TempDir()
|
||||||
|
writeFakeBinary(t, binDir, tt.binary)
|
||||||
|
if tt.name == "pi" {
|
||||||
|
writeFakeBinary(t, binDir, "npm")
|
||||||
|
}
|
||||||
|
t.Setenv("PATH", binDir)
|
||||||
|
|
||||||
|
configPath := tt.checkPath(home)
|
||||||
|
if err := tt.runner.Run("llama3.2", nil); err != nil {
|
||||||
|
t.Fatalf("Run returned error: %v", err)
|
||||||
|
}
|
||||||
|
if _, err := os.Stat(configPath); !os.IsNotExist(err) {
|
||||||
|
t.Fatalf("expected Run to leave %s untouched, got err=%v", configPath, err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
115
cmd/launch/selector_hooks.go
Normal file
@@ -0,0 +1,115 @@
|
|||||||
|
package launch
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
|
||||||
|
"golang.org/x/term"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ANSI escape sequences for terminal formatting.
|
||||||
|
const (
|
||||||
|
ansiBold = "\033[1m"
|
||||||
|
ansiReset = "\033[0m"
|
||||||
|
ansiGray = "\033[37m"
|
||||||
|
ansiGreen = "\033[32m"
|
||||||
|
ansiYellow = "\033[33m"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ErrCancelled is returned when the user cancels a selection.
|
||||||
|
var ErrCancelled = errors.New("cancelled")
|
||||||
|
|
||||||
|
// errCancelled is kept as an internal alias for existing call sites.
|
||||||
|
var errCancelled = ErrCancelled
|
||||||
|
|
||||||
|
// DefaultConfirmPrompt provides a TUI-based confirmation prompt.
|
||||||
|
// When set, ConfirmPrompt delegates to it instead of using raw terminal I/O.
|
||||||
|
var DefaultConfirmPrompt func(prompt string, options ConfirmOptions) (bool, error)
|
||||||
|
|
||||||
|
// ConfirmOptions customizes labels for confirmation prompts.
|
||||||
|
type ConfirmOptions struct {
|
||||||
|
YesLabel string
|
||||||
|
NoLabel string
|
||||||
|
}
|
||||||
|
|
||||||
|
// SingleSelector is a function type for single item selection.
|
||||||
|
// current is the name of the previously selected item to highlight; empty means no pre-selection.
|
||||||
|
type SingleSelector func(title string, items []ModelItem, current string) (string, error)
|
||||||
|
|
||||||
|
// MultiSelector is a function type for multi item selection.
|
||||||
|
type MultiSelector func(title string, items []ModelItem, preChecked []string) ([]string, error)
|
||||||
|
|
||||||
|
// DefaultSingleSelector is the default single-select implementation.
|
||||||
|
var DefaultSingleSelector SingleSelector
|
||||||
|
|
||||||
|
// DefaultMultiSelector is the default multi-select implementation.
|
||||||
|
var DefaultMultiSelector MultiSelector
|
||||||
|
|
||||||
|
// DefaultSignIn provides a TUI-based sign-in flow.
|
||||||
|
// When set, ensureAuth uses it instead of plain text prompts.
|
||||||
|
// Returns the signed-in username or an error.
|
||||||
|
var DefaultSignIn func(modelName, signInURL string) (string, error)
|
||||||
|
|
||||||
|
type launchConfirmPolicy struct {
|
||||||
|
yes bool
|
||||||
|
requireYesMessage bool
|
||||||
|
}
|
||||||
|
|
||||||
|
var currentLaunchConfirmPolicy launchConfirmPolicy
|
||||||
|
|
||||||
|
func withLaunchConfirmPolicy(policy launchConfirmPolicy) func() {
|
||||||
|
old := currentLaunchConfirmPolicy
|
||||||
|
currentLaunchConfirmPolicy = policy
|
||||||
|
return func() {
|
||||||
|
currentLaunchConfirmPolicy = old
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ConfirmPrompt is the shared confirmation gate for launch flows (integration
|
||||||
|
// edits, missing-model pulls, sign-in prompts, OpenClaw install/security, etc).
|
||||||
|
// Behavior is controlled by currentLaunchConfirmPolicy, typically scoped by
|
||||||
|
// withLaunchConfirmPolicy in LaunchCmd (e.g. auto-approve with --yes).
|
||||||
|
func ConfirmPrompt(prompt string) (bool, error) {
|
||||||
|
return ConfirmPromptWithOptions(prompt, ConfirmOptions{})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ConfirmPromptWithOptions is the shared confirmation gate for launch flows
|
||||||
|
// that need custom yes/no labels in interactive UIs.
|
||||||
|
func ConfirmPromptWithOptions(prompt string, options ConfirmOptions) (bool, error) {
|
||||||
|
if currentLaunchConfirmPolicy.yes {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
if currentLaunchConfirmPolicy.requireYesMessage {
|
||||||
|
return false, fmt.Errorf("%s requires confirmation; re-run with --yes to continue", prompt)
|
||||||
|
}
|
||||||
|
|
||||||
|
if DefaultConfirmPrompt != nil {
|
||||||
|
return DefaultConfirmPrompt(prompt, options)
|
||||||
|
}
|
||||||
|
|
||||||
|
fd := int(os.Stdin.Fd())
|
||||||
|
oldState, err := term.MakeRaw(fd)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
defer term.Restore(fd, oldState)
|
||||||
|
|
||||||
|
fmt.Fprintf(os.Stderr, "%s (\033[1my\033[0m/n) ", prompt)
|
||||||
|
|
||||||
|
buf := make([]byte, 1)
|
||||||
|
for {
|
||||||
|
if _, err := os.Stdin.Read(buf); err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
switch buf[0] {
|
||||||
|
case 'Y', 'y', 13:
|
||||||
|
fmt.Fprintf(os.Stderr, "yes\r\n")
|
||||||
|
return true, nil
|
||||||
|
case 'N', 'n', 27, 3:
|
||||||
|
fmt.Fprintf(os.Stderr, "no\r\n")
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
112
cmd/launch/selector_test.go
Normal file
@@ -0,0 +1,112 @@
|
|||||||
|
package launch
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestErrCancelled(t *testing.T) {
|
||||||
|
t.Run("NotNil", func(t *testing.T) {
|
||||||
|
if errCancelled == nil {
|
||||||
|
t.Error("errCancelled should not be nil")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Message", func(t *testing.T) {
|
||||||
|
if errCancelled.Error() != "cancelled" {
|
||||||
|
t.Errorf("expected 'cancelled', got %q", errCancelled.Error())
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWithLaunchConfirmPolicy_ScopesAndRestores(t *testing.T) {
|
||||||
|
oldPolicy := currentLaunchConfirmPolicy
|
||||||
|
oldHook := DefaultConfirmPrompt
|
||||||
|
t.Cleanup(func() {
|
||||||
|
currentLaunchConfirmPolicy = oldPolicy
|
||||||
|
DefaultConfirmPrompt = oldHook
|
||||||
|
})
|
||||||
|
|
||||||
|
currentLaunchConfirmPolicy = launchConfirmPolicy{}
|
||||||
|
var hookCalls int
|
||||||
|
DefaultConfirmPrompt = func(prompt string, options ConfirmOptions) (bool, error) {
|
||||||
|
hookCalls++
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
restoreOuter := withLaunchConfirmPolicy(launchConfirmPolicy{requireYesMessage: true})
|
||||||
|
restoreInner := withLaunchConfirmPolicy(launchConfirmPolicy{yes: true})
|
||||||
|
|
||||||
|
ok, err := ConfirmPrompt("test prompt")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("expected --yes policy to allow prompt, got error: %v", err)
|
||||||
|
}
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("expected --yes policy to auto-accept prompt")
|
||||||
|
}
|
||||||
|
if hookCalls != 0 {
|
||||||
|
t.Fatalf("expected --yes to skip hook, got %d hook calls", hookCalls)
|
||||||
|
}
|
||||||
|
|
||||||
|
restoreInner()
|
||||||
|
|
||||||
|
_, err = ConfirmPrompt("test prompt")
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected requireYesMessage policy to block prompt")
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "re-run with --yes") {
|
||||||
|
t.Fatalf("expected actionable --yes error, got: %v", err)
|
||||||
|
}
|
||||||
|
if hookCalls != 0 {
|
||||||
|
t.Fatalf("expected blocking policy to skip hook, got %d hook calls", hookCalls)
|
||||||
|
}
|
||||||
|
|
||||||
|
restoreOuter()
|
||||||
|
|
||||||
|
ok, err = ConfirmPrompt("test prompt")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("expected restored default behavior to use hook, got error: %v", err)
|
||||||
|
}
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("expected hook to return true")
|
||||||
|
}
|
||||||
|
if hookCalls != 1 {
|
||||||
|
t.Fatalf("expected one hook call after restore, got %d", hookCalls)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConfirmPromptWithOptions_DelegatesToOptionsHook(t *testing.T) {
|
||||||
|
oldPolicy := currentLaunchConfirmPolicy
|
||||||
|
oldHook := DefaultConfirmPrompt
|
||||||
|
t.Cleanup(func() {
|
||||||
|
currentLaunchConfirmPolicy = oldPolicy
|
||||||
|
DefaultConfirmPrompt = oldHook
|
||||||
|
})
|
||||||
|
|
||||||
|
currentLaunchConfirmPolicy = launchConfirmPolicy{}
|
||||||
|
called := false
|
||||||
|
DefaultConfirmPrompt = func(prompt string, options ConfirmOptions) (bool, error) {
|
||||||
|
called = true
|
||||||
|
if prompt != "Connect now?" {
|
||||||
|
t.Fatalf("unexpected prompt: %q", prompt)
|
||||||
|
}
|
||||||
|
if options.YesLabel != "Yes" || options.NoLabel != "Set up later" {
|
||||||
|
t.Fatalf("unexpected options: %+v", options)
|
||||||
|
}
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
ok, err := ConfirmPromptWithOptions("Connect now?", ConfirmOptions{
|
||||||
|
YesLabel: "Yes",
|
||||||
|
NoLabel: "Set up later",
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ConfirmPromptWithOptions() error = %v", err)
|
||||||
|
}
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("expected confirm to return true")
|
||||||
|
}
|
||||||
|
if !called {
|
||||||
|
t.Fatal("expected options hook to be called")
|
||||||
|
}
|
||||||
|
}
|
||||||
82
cmd/launch/test_config_helpers_test.go
Normal file
@@ -0,0 +1,82 @@
|
|||||||
|
package launch
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/cmd/config"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
integrations map[string]Runner
|
||||||
|
integrationAliases map[string]bool
|
||||||
|
integrationOrder = launcherIntegrationOrder
|
||||||
|
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
integrations = buildTestIntegrations()
|
||||||
|
integrationAliases = buildTestIntegrationAliases()
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildTestIntegrations() map[string]Runner {
|
||||||
|
result := make(map[string]Runner, len(integrationSpecsByName))
|
||||||
|
for name, spec := range integrationSpecsByName {
|
||||||
|
result[strings.ToLower(name)] = spec.Runner
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildTestIntegrationAliases() map[string]bool {
|
||||||
|
result := make(map[string]bool)
|
||||||
|
for _, spec := range integrationSpecs {
|
||||||
|
for _, alias := range spec.Aliases {
|
||||||
|
result[strings.ToLower(alias)] = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
func setTestHome(t *testing.T, dir string) {
|
||||||
|
t.Helper()
|
||||||
|
setLaunchTestHome(t, dir)
|
||||||
|
}
|
||||||
|
|
||||||
|
func SaveIntegration(appName string, models []string) error {
|
||||||
|
return config.SaveIntegration(appName, models)
|
||||||
|
}
|
||||||
|
|
||||||
|
func LoadIntegration(appName string) (*config.IntegrationConfig, error) {
|
||||||
|
return config.LoadIntegration(appName)
|
||||||
|
}
|
||||||
|
|
||||||
|
func SaveAliases(appName string, aliases map[string]string) error {
|
||||||
|
return config.SaveAliases(appName, aliases)
|
||||||
|
}
|
||||||
|
|
||||||
|
func LastModel() string {
|
||||||
|
return config.LastModel()
|
||||||
|
}
|
||||||
|
|
||||||
|
func SetLastModel(model string) error {
|
||||||
|
return config.SetLastModel(model)
|
||||||
|
}
|
||||||
|
|
||||||
|
func LastSelection() string {
|
||||||
|
return config.LastSelection()
|
||||||
|
}
|
||||||
|
|
||||||
|
func SetLastSelection(selection string) error {
|
||||||
|
return config.SetLastSelection(selection)
|
||||||
|
}
|
||||||
|
|
||||||
|
func IntegrationModel(appName string) string {
|
||||||
|
return config.IntegrationModel(appName)
|
||||||
|
}
|
||||||
|
|
||||||
|
func IntegrationModels(appName string) []string {
|
||||||
|
return config.IntegrationModels(appName)
|
||||||
|
}
|
||||||
|
|
||||||
|
func integrationOnboarded(appName string) error {
|
||||||
|
return config.MarkIntegrationOnboarded(appName)
|
||||||
|
}
|
||||||
591
cmd/launch/vscode.go
Normal file
@@ -0,0 +1,591 @@
|
|||||||
|
package launch
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"os/exec"
|
||||||
|
"path/filepath"
|
||||||
|
"runtime"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
_ "github.com/mattn/go-sqlite3"
|
||||||
|
"github.com/ollama/ollama/api"
|
||||||
|
"github.com/ollama/ollama/cmd/internal/fileutil"
|
||||||
|
"github.com/ollama/ollama/envconfig"
|
||||||
|
)
|
||||||
|
|
||||||
|
// VSCode implements Runner and Editor for Visual Studio Code integration.
|
||||||
|
type VSCode struct{}
|
||||||
|
|
||||||
|
func (v *VSCode) String() string { return "Visual Studio Code" }
|
||||||
|
|
||||||
|
// findBinary returns the path/command to launch VS Code, or "" if not found.
|
||||||
|
// It checks platform-specific locations only.
|
||||||
|
func (v *VSCode) findBinary() string {
|
||||||
|
var candidates []string
|
||||||
|
switch runtime.GOOS {
|
||||||
|
case "darwin":
|
||||||
|
candidates = []string{
|
||||||
|
"/Applications/Visual Studio Code.app",
|
||||||
|
}
|
||||||
|
case "windows":
|
||||||
|
if localAppData := os.Getenv("LOCALAPPDATA"); localAppData != "" {
|
||||||
|
candidates = append(candidates, filepath.Join(localAppData, "Programs", "Microsoft VS Code", "bin", "code.cmd"))
|
||||||
|
}
|
||||||
|
default: // linux
|
||||||
|
candidates = []string{
|
||||||
|
"/usr/bin/code",
|
||||||
|
"/snap/bin/code",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for _, c := range candidates {
|
||||||
|
if _, err := os.Stat(c); err == nil {
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsRunning reports whether VS Code is currently running.
|
||||||
|
// Each platform uses a pattern specific enough to avoid matching Cursor or
|
||||||
|
// other VS Code forks.
|
||||||
|
func (v *VSCode) IsRunning() bool {
|
||||||
|
switch runtime.GOOS {
|
||||||
|
case "darwin":
|
||||||
|
out, err := exec.Command("pgrep", "-f", "Visual Studio Code.app/Contents/MacOS/Code").Output()
|
||||||
|
return err == nil && len(out) > 0
|
||||||
|
case "windows":
|
||||||
|
// Match VS Code by executable path to avoid matching Cursor or other forks.
|
||||||
|
out, err := exec.Command("powershell", "-NoProfile", "-Command",
|
||||||
|
`Get-Process Code -ErrorAction SilentlyContinue | Where-Object { $_.Path -like '*Microsoft VS Code*' } | Select-Object -First 1`).Output()
|
||||||
|
return err == nil && len(strings.TrimSpace(string(out))) > 0
|
||||||
|
default:
|
||||||
|
// Match VS Code specifically by its install path to avoid matching
|
||||||
|
// Cursor (/cursor/) or other forks.
|
||||||
|
for _, pattern := range []string{"/usr/share/code/", "/snap/code/"} {
|
||||||
|
out, err := exec.Command("pgrep", "-f", pattern).Output()
|
||||||
|
if err == nil && len(out) > 0 {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Quit gracefully quits VS Code and waits for it to exit so that it flushes
|
||||||
|
// its in-memory state back to the database.
|
||||||
|
func (v *VSCode) Quit() {
|
||||||
|
if !v.IsRunning() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
switch runtime.GOOS {
|
||||||
|
case "darwin":
|
||||||
|
_ = exec.Command("osascript", "-e", `quit app "Visual Studio Code"`).Run()
|
||||||
|
case "windows":
|
||||||
|
// Kill VS Code by executable path to avoid killing Cursor or other forks.
|
||||||
|
_ = exec.Command("powershell", "-NoProfile", "-Command",
|
||||||
|
`Get-Process Code -ErrorAction SilentlyContinue | Where-Object { $_.Path -like '*Microsoft VS Code*' } | Stop-Process -Force`).Run()
|
||||||
|
default:
|
||||||
|
for _, pattern := range []string{"/usr/share/code/", "/snap/code/"} {
|
||||||
|
_ = exec.Command("pkill", "-f", pattern).Run()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Wait for the process to fully exit and flush its state to disk
|
||||||
|
// TODO(hoyyeva): update spinner to use bubble tea
|
||||||
|
spinnerFrames := []string{"|", "/", "-", "\\"}
|
||||||
|
frame := 0
|
||||||
|
fmt.Fprintf(os.Stderr, "\033[90mRestarting VS Code... %s\033[0m", spinnerFrames[0])
|
||||||
|
|
||||||
|
ticker := time.NewTicker(200 * time.Millisecond)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
for range 150 { // 150 ticks × 200ms = 30s timeout
|
||||||
|
<-ticker.C
|
||||||
|
frame++
|
||||||
|
fmt.Fprintf(os.Stderr, "\r\033[90mRestarting VS Code... %s\033[0m", spinnerFrames[frame%len(spinnerFrames)])
|
||||||
|
|
||||||
|
if frame%5 == 0 { // check every ~1s
|
||||||
|
if !v.IsRunning() {
|
||||||
|
fmt.Fprintf(os.Stderr, "\r\033[K")
|
||||||
|
// Give VS Code a moment to finish writing its state DB
|
||||||
|
time.Sleep(1 * time.Second)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
fmt.Fprintf(os.Stderr, "\r\033[K")
|
||||||
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
minCopilotChatVersion = "0.41.0"
|
||||||
|
minVSCodeVersion = "1.113"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (v *VSCode) Run(model string, args []string) error {
|
||||||
|
v.checkVSCodeVersion()
|
||||||
|
v.checkCopilotChatVersion()
|
||||||
|
|
||||||
|
// Get all configured models (saved by the launcher framework before Run is called)
|
||||||
|
models := []string{model}
|
||||||
|
if cfg, err := loadStoredIntegrationConfig("vscode"); err == nil && len(cfg.Models) > 0 {
|
||||||
|
models = cfg.Models
|
||||||
|
}
|
||||||
|
|
||||||
|
// VS Code discovers models from ollama ls. Cloud models that pass Show
|
||||||
|
// (the server knows about them) but aren't in ls need to be pulled to
|
||||||
|
// register them so VS Code can find them.
|
||||||
|
if client, err := api.ClientFromEnvironment(); err == nil {
|
||||||
|
v.ensureModelsRegistered(context.Background(), client, models)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Warn if the default model doesn't support tool calling
|
||||||
|
if client, err := api.ClientFromEnvironment(); err == nil {
|
||||||
|
if resp, err := client.Show(context.Background(), &api.ShowRequest{Model: models[0]}); err == nil {
|
||||||
|
hasTools := false
|
||||||
|
for _, c := range resp.Capabilities {
|
||||||
|
if c == "tools" {
|
||||||
|
hasTools = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !hasTools {
|
||||||
|
fmt.Fprintf(os.Stderr, "Note: %s does not support tool calling and may not appear in the Copilot Chat model picker.\n", models[0])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
v.printModelAccessTip()
|
||||||
|
|
||||||
|
if v.IsRunning() {
|
||||||
|
restart, err := ConfirmPrompt("Restart VS Code?")
|
||||||
|
if err != nil {
|
||||||
|
restart = false
|
||||||
|
}
|
||||||
|
if restart {
|
||||||
|
v.Quit()
|
||||||
|
if err := v.ShowInModelPicker(models); err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "%s Warning: could not update VS Code model picker: %v%s\n", ansiYellow, err, ansiReset)
|
||||||
|
}
|
||||||
|
v.FocusVSCode()
|
||||||
|
} else {
|
||||||
|
fmt.Fprintf(os.Stderr, "\nTo get the latest model configuration, restart VS Code when you're ready.\n")
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if err := v.ShowInModelPicker(models); err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "%s Warning: could not update VS Code model picker: %v%s\n", ansiYellow, err, ansiReset)
|
||||||
|
}
|
||||||
|
v.FocusVSCode()
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ensureModelsRegistered pulls models that the server knows about (Show succeeds)
|
||||||
|
// but aren't in ollama ls yet. This is needed for cloud models so that VS Code
|
||||||
|
// can discover them from the Ollama API.
|
||||||
|
func (v *VSCode) ensureModelsRegistered(ctx context.Context, client *api.Client, models []string) {
|
||||||
|
listed, err := client.List(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
registered := make(map[string]bool, len(listed.Models))
|
||||||
|
for _, m := range listed.Models {
|
||||||
|
registered[m.Name] = true
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, model := range models {
|
||||||
|
if registered[model] {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// Also check without :latest suffix
|
||||||
|
if !strings.Contains(model, ":") && registered[model+":latest"] {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if err := pullModel(ctx, client, model, false); err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "%s Warning: could not register model %s: %v%s\n", ansiYellow, model, err, ansiReset)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// FocusVSCode brings VS Code to the foreground.
|
||||||
|
func (v *VSCode) FocusVSCode() {
|
||||||
|
binary := v.findBinary()
|
||||||
|
if binary == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if runtime.GOOS == "darwin" && strings.HasSuffix(binary, ".app") {
|
||||||
|
_ = exec.Command("open", "-a", binary).Run()
|
||||||
|
} else {
|
||||||
|
_ = exec.Command(binary).Start()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// printModelAccessTip shows instructions for finding Ollama models in VS Code.
|
||||||
|
func (v *VSCode) printModelAccessTip() {
|
||||||
|
fmt.Fprintf(os.Stderr, "\nTip: To use Ollama models, open Copilot Chat and click the model picker.\n")
|
||||||
|
fmt.Fprintf(os.Stderr, " If you don't see your models, click \"Other models\" to find them.\n\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (v *VSCode) Paths() []string {
|
||||||
|
if p := v.chatLanguageModelsPath(); fileExists(p) {
|
||||||
|
return []string{p}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (v *VSCode) Edit(models []string) error {
|
||||||
|
if len(models) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write chatLanguageModels.json with Ollama vendor entry
|
||||||
|
clmPath := v.chatLanguageModelsPath()
|
||||||
|
if err := os.MkdirAll(filepath.Dir(clmPath), 0o755); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
var entries []map[string]any
|
||||||
|
if data, err := os.ReadFile(clmPath); err == nil {
|
||||||
|
_ = json.Unmarshal(data, &entries)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove any existing Ollama entries, preserve others
|
||||||
|
filtered := make([]map[string]any, 0, len(entries))
|
||||||
|
for _, entry := range entries {
|
||||||
|
if vendor, _ := entry["vendor"].(string); vendor != "ollama" {
|
||||||
|
filtered = append(filtered, entry)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add new Ollama entry
|
||||||
|
filtered = append(filtered, map[string]any{
|
||||||
|
"vendor": "ollama",
|
||||||
|
"name": "Ollama",
|
||||||
|
"url": envconfig.Host().String(),
|
||||||
|
})
|
||||||
|
|
||||||
|
data, err := json.MarshalIndent(filtered, "", " ")
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := fileutil.WriteWithBackup(clmPath, data); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clean up legacy settings from older Ollama integrations
|
||||||
|
v.updateSettings()
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (v *VSCode) Models() []string {
|
||||||
|
if !v.hasOllamaVendor() {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if cfg, err := loadStoredIntegrationConfig("vscode"); err == nil {
|
||||||
|
return cfg.Models
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// hasOllamaVendor checks if chatLanguageModels.json contains an Ollama vendor entry.
|
||||||
|
func (v *VSCode) hasOllamaVendor() bool {
|
||||||
|
data, err := os.ReadFile(v.chatLanguageModelsPath())
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
var entries []map[string]any
|
||||||
|
if err := json.Unmarshal(data, &entries); err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, entry := range entries {
|
||||||
|
if vendor, _ := entry["vendor"].(string); vendor == "ollama" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (v *VSCode) chatLanguageModelsPath() string {
|
||||||
|
return v.vscodePath("chatLanguageModels.json")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (v *VSCode) settingsPath() string {
|
||||||
|
return v.vscodePath("settings.json")
|
||||||
|
}
|
||||||
|
|
||||||
|
// updateSettings cleans up legacy settings from older Ollama integrations.
|
||||||
|
func (v *VSCode) updateSettings() {
|
||||||
|
settingsPath := v.settingsPath()
|
||||||
|
data, err := os.ReadFile(settingsPath)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var settings map[string]any
|
||||||
|
if err := json.Unmarshal(data, &settings); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
changed := false
|
||||||
|
for _, key := range []string{"github.copilot.chat.byok.ollamaEndpoint", "ollama.launch.configured"} {
|
||||||
|
if _, ok := settings[key]; ok {
|
||||||
|
delete(settings, key)
|
||||||
|
changed = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !changed {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
updated, err := json.MarshalIndent(settings, "", " ")
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
_ = fileutil.WriteWithBackup(settingsPath, updated)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (v *VSCode) statePath() string {
|
||||||
|
return v.vscodePath("globalStorage", "state.vscdb")
|
||||||
|
}
|
||||||
|
|
||||||
|
// ShowInModelPicker ensures the given models are visible in VS Code's Copilot
|
||||||
|
// Chat model picker. It sets the configured models to true in the picker
|
||||||
|
// preferences so they appear in the dropdown. Models use the VS Code identifier
|
||||||
|
// format "ollama/Ollama/<name>".
|
||||||
|
func (v *VSCode) ShowInModelPicker(models []string) error {
|
||||||
|
if len(models) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
dbPath := v.statePath()
|
||||||
|
needsCreate := !fileExists(dbPath)
|
||||||
|
if needsCreate {
|
||||||
|
if err := os.MkdirAll(filepath.Dir(dbPath), 0o755); err != nil {
|
||||||
|
return fmt.Errorf("creating state directory: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
db, err := sql.Open("sqlite3", dbPath+"?_busy_timeout=5000")
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("opening state database: %w", err)
|
||||||
|
}
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
// Create the table if this is a fresh DB. Schema must match what VS Code creates.
|
||||||
|
if needsCreate {
|
||||||
|
if _, err := db.Exec("CREATE TABLE ItemTable (key TEXT UNIQUE ON CONFLICT REPLACE, value BLOB)"); err != nil {
|
||||||
|
return fmt.Errorf("initializing state database: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read existing preferences
|
||||||
|
prefs := make(map[string]bool)
|
||||||
|
var prefsJSON string
|
||||||
|
if err := db.QueryRow("SELECT value FROM ItemTable WHERE key = 'chatModelPickerPreferences'").Scan(&prefsJSON); err == nil {
|
||||||
|
_ = json.Unmarshal([]byte(prefsJSON), &prefs)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build name→ID map from VS Code's cached model list.
|
||||||
|
// VS Code uses numeric IDs like "ollama/Ollama/4", not "ollama/Ollama/kimi-k2.5:cloud".
|
||||||
|
nameToID := make(map[string]string)
|
||||||
|
var cacheJSON string
|
||||||
|
if err := db.QueryRow("SELECT value FROM ItemTable WHERE key = 'chat.cachedLanguageModels.v2'").Scan(&cacheJSON); err == nil {
|
||||||
|
var cached []map[string]any
|
||||||
|
if json.Unmarshal([]byte(cacheJSON), &cached) == nil {
|
||||||
|
for _, entry := range cached {
|
||||||
|
meta, _ := entry["metadata"].(map[string]any)
|
||||||
|
if meta == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if vendor, _ := meta["vendor"].(string); vendor == "ollama" {
|
||||||
|
name, _ := meta["name"].(string)
|
||||||
|
id, _ := entry["identifier"].(string)
|
||||||
|
if name != "" && id != "" {
|
||||||
|
nameToID[name] = id
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ollama config is authoritative: always show configured models,
|
||||||
|
// hide Ollama models that are no longer in the config.
|
||||||
|
configuredIDs := make(map[string]bool)
|
||||||
|
for _, m := range models {
|
||||||
|
for _, id := range v.modelVSCodeIDs(m, nameToID) {
|
||||||
|
prefs[id] = true
|
||||||
|
configuredIDs[id] = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for id := range prefs {
|
||||||
|
if strings.HasPrefix(id, "ollama/") && !configuredIDs[id] {
|
||||||
|
prefs[id] = false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
data, _ := json.Marshal(prefs)
|
||||||
|
if _, err = db.Exec("INSERT OR REPLACE INTO ItemTable (key, value) VALUES ('chatModelPickerPreferences', ?)", string(data)); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// modelVSCodeIDs returns all possible VS Code picker IDs for a model name.
|
||||||
|
func (v *VSCode) modelVSCodeIDs(model string, nameToID map[string]string) []string {
|
||||||
|
var ids []string
|
||||||
|
if id, ok := nameToID[model]; ok {
|
||||||
|
ids = append(ids, id)
|
||||||
|
} else if !strings.Contains(model, ":") {
|
||||||
|
if id, ok := nameToID[model+":latest"]; ok {
|
||||||
|
ids = append(ids, id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ids = append(ids, "ollama/Ollama/"+model)
|
||||||
|
if !strings.Contains(model, ":") {
|
||||||
|
ids = append(ids, "ollama/Ollama/"+model+":latest")
|
||||||
|
}
|
||||||
|
return ids
|
||||||
|
}
|
||||||
|
|
||||||
|
func (v *VSCode) vscodePath(parts ...string) string {
|
||||||
|
home, _ := os.UserHomeDir()
|
||||||
|
var base string
|
||||||
|
switch runtime.GOOS {
|
||||||
|
case "darwin":
|
||||||
|
base = filepath.Join(home, "Library", "Application Support", "Code", "User")
|
||||||
|
case "windows":
|
||||||
|
base = filepath.Join(os.Getenv("APPDATA"), "Code", "User")
|
||||||
|
default:
|
||||||
|
base = filepath.Join(home, ".config", "Code", "User")
|
||||||
|
}
|
||||||
|
return filepath.Join(append([]string{base}, parts...)...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// checkVSCodeVersion warns if VS Code is older than minVSCodeVersion.
|
||||||
|
func (v *VSCode) checkVSCodeVersion() {
|
||||||
|
codeCLI := v.findCodeCLI()
|
||||||
|
if codeCLI == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
out, err := exec.Command(codeCLI, "--version").Output()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// "code --version" outputs: version\ncommit\narch
|
||||||
|
lines := strings.Split(strings.TrimSpace(string(out)), "\n")
|
||||||
|
if len(lines) == 0 || lines[0] == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
version := strings.TrimSpace(lines[0])
|
||||||
|
|
||||||
|
if compareVersions(version, minVSCodeVersion) < 0 {
|
||||||
|
fmt.Fprintf(os.Stderr, "\n%sWarning: VS Code version (%s) is older than the recommended version (%s)%s\n", ansiYellow, version, minVSCodeVersion, ansiReset)
|
||||||
|
fmt.Fprintf(os.Stderr, "Please update VS Code to the latest version.\n\n")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// checkCopilotChatVersion warns if the GitHub Copilot Chat extension is
|
||||||
|
// missing or older than minCopilotChatVersion.
|
||||||
|
func (v *VSCode) checkCopilotChatVersion() {
|
||||||
|
codeCLI := v.findCodeCLI()
|
||||||
|
if codeCLI == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
out, err := exec.Command(codeCLI, "--list-extensions", "--show-versions").Output()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
installed, version := parseCopilotChatVersion(string(out))
|
||||||
|
if !installed {
|
||||||
|
fmt.Fprintf(os.Stderr, "\n%sWarning: GitHub Copilot Chat extension is not installed%s\n", ansiYellow, ansiReset)
|
||||||
|
fmt.Fprintf(os.Stderr, "Install it in VS Code: Extensions → search \"GitHub Copilot Chat\" → Install\n\n")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if compareVersions(version, minCopilotChatVersion) < 0 {
|
||||||
|
fmt.Fprintf(os.Stderr, "\n%sWarning: GitHub Copilot Chat extension version (%s) is older than the recommended version (%s)%s\n", ansiYellow, version, minCopilotChatVersion, ansiReset)
|
||||||
|
fmt.Fprintf(os.Stderr, "Please update it in VS Code: Extensions → search \"GitHub Copilot Chat\" → Update\n\n")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// findCodeCLI returns the path to the VS Code CLI for querying extensions.
|
||||||
|
// On macOS, findBinary may return an .app bundle which can't run --list-extensions,
|
||||||
|
// so this resolves to the actual CLI binary inside the bundle.
|
||||||
|
func (v *VSCode) findCodeCLI() string {
|
||||||
|
binary := v.findBinary()
|
||||||
|
if binary == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
if runtime.GOOS == "darwin" && strings.HasSuffix(binary, ".app") {
|
||||||
|
bundleCLI := binary + "/Contents/Resources/app/bin/code"
|
||||||
|
if _, err := os.Stat(bundleCLI); err == nil {
|
||||||
|
return bundleCLI
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return binary
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseCopilotChatVersion extracts the version of the GitHub Copilot Chat
|
||||||
|
// extension from "code --list-extensions --show-versions" output.
|
||||||
|
func parseCopilotChatVersion(output string) (installed bool, version string) {
|
||||||
|
for _, line := range strings.Split(output, "\n") {
|
||||||
|
// Format: github.copilot-chat@0.40.1
|
||||||
|
if !strings.HasPrefix(strings.ToLower(line), "github.copilot-chat@") {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
parts := strings.SplitN(line, "@", 2)
|
||||||
|
if len(parts) != 2 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
return true, strings.TrimSpace(parts[1])
|
||||||
|
}
|
||||||
|
return false, ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// compareVersions compares two dot-separated version strings.
|
||||||
|
// Returns -1 if a < b, 0 if a == b, 1 if a > b.
|
||||||
|
func compareVersions(a, b string) int {
|
||||||
|
aParts := strings.Split(a, ".")
|
||||||
|
bParts := strings.Split(b, ".")
|
||||||
|
|
||||||
|
maxLen := len(aParts)
|
||||||
|
if len(bParts) > maxLen {
|
||||||
|
maxLen = len(bParts)
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := range maxLen {
|
||||||
|
var aNum, bNum int
|
||||||
|
if i < len(aParts) {
|
||||||
|
aNum, _ = strconv.Atoi(aParts[i])
|
||||||
|
}
|
||||||
|
if i < len(bParts) {
|
||||||
|
bNum, _ = strconv.Atoi(bParts[i])
|
||||||
|
}
|
||||||
|
if aNum < bNum {
|
||||||
|
return -1
|
||||||
|
}
|
||||||
|
if aNum > bNum {
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func fileExists(path string) bool {
|
||||||
|
_, err := os.Stat(path)
|
||||||
|
return err == nil
|
||||||
|
}
|
||||||
486
cmd/launch/vscode_test.go
Normal file
@@ -0,0 +1,486 @@
|
|||||||
|
package launch
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"encoding/json"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"runtime"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
_ "github.com/mattn/go-sqlite3"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestVSCodeIntegration(t *testing.T) {
|
||||||
|
v := &VSCode{}
|
||||||
|
|
||||||
|
t.Run("String", func(t *testing.T) {
|
||||||
|
if got := v.String(); got != "Visual Studio Code" {
|
||||||
|
t.Errorf("String() = %q, want %q", got, "Visual Studio Code")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("implements Runner", func(t *testing.T) {
|
||||||
|
var _ Runner = v
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("implements Editor", func(t *testing.T) {
|
||||||
|
var _ Editor = v
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestVSCodeEdit(t *testing.T) {
|
||||||
|
v := &VSCode{}
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
setTestHome(t, tmpDir)
|
||||||
|
t.Setenv("XDG_CONFIG_HOME", "")
|
||||||
|
clmPath := testVSCodePath(t, tmpDir, "chatLanguageModels.json")
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
setup string // initial chatLanguageModels.json content, empty means no file
|
||||||
|
models []string
|
||||||
|
validate func(t *testing.T, data []byte)
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "fresh install",
|
||||||
|
models: []string{"llama3.2"},
|
||||||
|
validate: func(t *testing.T, data []byte) {
|
||||||
|
assertOllamaVendorConfigured(t, data)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "preserve other vendor entries",
|
||||||
|
setup: `[{"vendor": "azure", "name": "Azure", "url": "https://example.com"}]`,
|
||||||
|
models: []string{"llama3.2"},
|
||||||
|
validate: func(t *testing.T, data []byte) {
|
||||||
|
var entries []map[string]any
|
||||||
|
json.Unmarshal(data, &entries)
|
||||||
|
if len(entries) != 2 {
|
||||||
|
t.Errorf("expected 2 entries, got %d", len(entries))
|
||||||
|
}
|
||||||
|
// Check Azure entry preserved
|
||||||
|
found := false
|
||||||
|
for _, e := range entries {
|
||||||
|
if v, _ := e["vendor"].(string); v == "azure" {
|
||||||
|
found = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !found {
|
||||||
|
t.Error("azure vendor entry was not preserved")
|
||||||
|
}
|
||||||
|
assertOllamaVendorConfigured(t, data)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "update existing ollama entry",
|
||||||
|
setup: `[{"vendor": "ollama", "name": "Ollama", "url": "http://old:11434"}]`,
|
||||||
|
models: []string{"llama3.2"},
|
||||||
|
validate: func(t *testing.T, data []byte) {
|
||||||
|
assertOllamaVendorConfigured(t, data)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty models is no-op",
|
||||||
|
setup: `[{"vendor": "azure", "name": "Azure"}]`,
|
||||||
|
models: []string{},
|
||||||
|
validate: func(t *testing.T, data []byte) {
|
||||||
|
if string(data) != `[{"vendor": "azure", "name": "Azure"}]` {
|
||||||
|
t.Error("empty models should not modify file")
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "corrupted JSON treated as empty",
|
||||||
|
setup: `{corrupted json`,
|
||||||
|
models: []string{"llama3.2"},
|
||||||
|
validate: func(t *testing.T, data []byte) {
|
||||||
|
var entries []map[string]any
|
||||||
|
if err := json.Unmarshal(data, &entries); err != nil {
|
||||||
|
t.Errorf("result is not valid JSON: %v", err)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
os.RemoveAll(filepath.Dir(clmPath))
|
||||||
|
|
||||||
|
if tt.setup != "" {
|
||||||
|
os.MkdirAll(filepath.Dir(clmPath), 0o755)
|
||||||
|
os.WriteFile(clmPath, []byte(tt.setup), 0o644)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := v.Edit(tt.models); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
data, _ := os.ReadFile(clmPath)
|
||||||
|
tt.validate(t, data)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestVSCodeEditCleansUpOldSettings(t *testing.T) {
|
||||||
|
v := &VSCode{}
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
setTestHome(t, tmpDir)
|
||||||
|
t.Setenv("XDG_CONFIG_HOME", "")
|
||||||
|
settingsPath := testVSCodePath(t, tmpDir, "settings.json")
|
||||||
|
|
||||||
|
// Create settings.json with old byok setting
|
||||||
|
os.MkdirAll(filepath.Dir(settingsPath), 0o755)
|
||||||
|
os.WriteFile(settingsPath, []byte(`{"github.copilot.chat.byok.ollamaEndpoint": "http://old:11434", "ollama.launch.configured": true, "editor.fontSize": 14}`), 0o644)
|
||||||
|
|
||||||
|
if err := v.Edit([]string{"llama3.2"}); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify old settings were removed
|
||||||
|
data, err := os.ReadFile(settingsPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var settings map[string]any
|
||||||
|
json.Unmarshal(data, &settings)
|
||||||
|
if _, ok := settings["github.copilot.chat.byok.ollamaEndpoint"]; ok {
|
||||||
|
t.Error("github.copilot.chat.byok.ollamaEndpoint should have been removed")
|
||||||
|
}
|
||||||
|
if _, ok := settings["ollama.launch.configured"]; ok {
|
||||||
|
t.Error("ollama.launch.configured should have been removed")
|
||||||
|
}
|
||||||
|
if settings["editor.fontSize"] != float64(14) {
|
||||||
|
t.Error("editor.fontSize should have been preserved")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestVSCodePaths(t *testing.T) {
|
||||||
|
v := &VSCode{}
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
setTestHome(t, tmpDir)
|
||||||
|
t.Setenv("XDG_CONFIG_HOME", "")
|
||||||
|
clmPath := testVSCodePath(t, tmpDir, "chatLanguageModels.json")
|
||||||
|
|
||||||
|
t.Run("no file returns nil", func(t *testing.T) {
|
||||||
|
os.Remove(clmPath)
|
||||||
|
if paths := v.Paths(); paths != nil {
|
||||||
|
t.Errorf("expected nil, got %v", paths)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("existing file returns path", func(t *testing.T) {
|
||||||
|
os.MkdirAll(filepath.Dir(clmPath), 0o755)
|
||||||
|
os.WriteFile(clmPath, []byte(`[]`), 0o644)
|
||||||
|
|
||||||
|
if paths := v.Paths(); len(paths) != 1 {
|
||||||
|
t.Errorf("expected 1 path, got %d", len(paths))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// testVSCodePath returns the expected VS Code config path for the given file in tests.
|
||||||
|
func testVSCodePath(t *testing.T, tmpDir, filename string) string {
|
||||||
|
t.Helper()
|
||||||
|
switch runtime.GOOS {
|
||||||
|
case "darwin":
|
||||||
|
return filepath.Join(tmpDir, "Library", "Application Support", "Code", "User", filename)
|
||||||
|
case "windows":
|
||||||
|
t.Setenv("APPDATA", tmpDir)
|
||||||
|
return filepath.Join(tmpDir, "Code", "User", filename)
|
||||||
|
default:
|
||||||
|
return filepath.Join(tmpDir, ".config", "Code", "User", filename)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func assertOllamaVendorConfigured(t *testing.T, data []byte) {
|
||||||
|
t.Helper()
|
||||||
|
var entries []map[string]any
|
||||||
|
if err := json.Unmarshal(data, &entries); err != nil {
|
||||||
|
t.Fatalf("invalid JSON: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, entry := range entries {
|
||||||
|
if vendor, _ := entry["vendor"].(string); vendor == "ollama" {
|
||||||
|
if name, _ := entry["name"].(string); name != "Ollama" {
|
||||||
|
t.Errorf("expected name \"Ollama\", got %q", name)
|
||||||
|
}
|
||||||
|
if url, _ := entry["url"].(string); url == "" {
|
||||||
|
t.Error("url not set")
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
t.Error("no ollama vendor entry found")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestShowInModelPicker(t *testing.T) {
|
||||||
|
v := &VSCode{}
|
||||||
|
|
||||||
|
// helper to create a state DB with optional seed data
|
||||||
|
setupDB := func(t *testing.T, tmpDir string, seedPrefs map[string]bool, seedCache []map[string]any) string {
|
||||||
|
t.Helper()
|
||||||
|
dbDir := filepath.Join(tmpDir, "globalStorage")
|
||||||
|
os.MkdirAll(dbDir, 0o755)
|
||||||
|
dbPath := filepath.Join(dbDir, "state.vscdb")
|
||||||
|
|
||||||
|
db, err := sql.Open("sqlite3", dbPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
if _, err := db.Exec("CREATE TABLE ItemTable (key TEXT UNIQUE ON CONFLICT REPLACE, value BLOB)"); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if seedPrefs != nil {
|
||||||
|
data, _ := json.Marshal(seedPrefs)
|
||||||
|
db.Exec("INSERT INTO ItemTable (key, value) VALUES ('chatModelPickerPreferences', ?)", string(data))
|
||||||
|
}
|
||||||
|
if seedCache != nil {
|
||||||
|
data, _ := json.Marshal(seedCache)
|
||||||
|
db.Exec("INSERT INTO ItemTable (key, value) VALUES ('chat.cachedLanguageModels.v2', ?)", string(data))
|
||||||
|
}
|
||||||
|
return dbPath
|
||||||
|
}
|
||||||
|
|
||||||
|
// helper to read prefs back from DB
|
||||||
|
readPrefs := func(t *testing.T, dbPath string) map[string]bool {
|
||||||
|
t.Helper()
|
||||||
|
db, err := sql.Open("sqlite3", dbPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
var raw string
|
||||||
|
if err := db.QueryRow("SELECT value FROM ItemTable WHERE key = 'chatModelPickerPreferences'").Scan(&raw); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
prefs := make(map[string]bool)
|
||||||
|
json.Unmarshal([]byte(raw), &prefs)
|
||||||
|
return prefs
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("fresh DB creates table and shows models", func(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
setTestHome(t, tmpDir)
|
||||||
|
t.Setenv("XDG_CONFIG_HOME", "")
|
||||||
|
if runtime.GOOS == "windows" {
|
||||||
|
t.Setenv("APPDATA", tmpDir)
|
||||||
|
}
|
||||||
|
|
||||||
|
err := v.ShowInModelPicker([]string{"llama3.2"})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
dbPath := testVSCodePath(t, tmpDir, filepath.Join("globalStorage", "state.vscdb"))
|
||||||
|
prefs := readPrefs(t, dbPath)
|
||||||
|
if !prefs["ollama/Ollama/llama3.2"] {
|
||||||
|
t.Error("expected llama3.2 to be shown")
|
||||||
|
}
|
||||||
|
if !prefs["ollama/Ollama/llama3.2:latest"] {
|
||||||
|
t.Error("expected llama3.2:latest to be shown")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("configured models are shown", func(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
setTestHome(t, tmpDir)
|
||||||
|
t.Setenv("XDG_CONFIG_HOME", "")
|
||||||
|
dbPath := setupDB(t, testVSCodePath(t, tmpDir, ""), nil, nil)
|
||||||
|
|
||||||
|
err := v.ShowInModelPicker([]string{"llama3.2", "qwen3:8b"})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
prefs := readPrefs(t, dbPath)
|
||||||
|
if !prefs["ollama/Ollama/llama3.2"] {
|
||||||
|
t.Error("expected llama3.2 to be shown")
|
||||||
|
}
|
||||||
|
if !prefs["ollama/Ollama/qwen3:8b"] {
|
||||||
|
t.Error("expected qwen3:8b to be shown")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("removed models are hidden", func(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
setTestHome(t, tmpDir)
|
||||||
|
t.Setenv("XDG_CONFIG_HOME", "")
|
||||||
|
dbPath := setupDB(t, testVSCodePath(t, tmpDir, ""), map[string]bool{
|
||||||
|
"ollama/Ollama/llama3.2": true,
|
||||||
|
"ollama/Ollama/llama3.2:latest": true,
|
||||||
|
"ollama/Ollama/mistral": true,
|
||||||
|
"ollama/Ollama/mistral:latest": true,
|
||||||
|
}, nil)
|
||||||
|
|
||||||
|
// Only configure llama3.2 — mistral should get hidden
|
||||||
|
err := v.ShowInModelPicker([]string{"llama3.2"})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
prefs := readPrefs(t, dbPath)
|
||||||
|
if !prefs["ollama/Ollama/llama3.2"] {
|
||||||
|
t.Error("expected llama3.2 to stay shown")
|
||||||
|
}
|
||||||
|
if prefs["ollama/Ollama/mistral"] {
|
||||||
|
t.Error("expected mistral to be hidden")
|
||||||
|
}
|
||||||
|
if prefs["ollama/Ollama/mistral:latest"] {
|
||||||
|
t.Error("expected mistral:latest to be hidden")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("non-ollama prefs are preserved", func(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
setTestHome(t, tmpDir)
|
||||||
|
t.Setenv("XDG_CONFIG_HOME", "")
|
||||||
|
dbPath := setupDB(t, testVSCodePath(t, tmpDir, ""), map[string]bool{
|
||||||
|
"copilot/gpt-4o": true,
|
||||||
|
}, nil)
|
||||||
|
|
||||||
|
err := v.ShowInModelPicker([]string{"llama3.2"})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
prefs := readPrefs(t, dbPath)
|
||||||
|
if !prefs["copilot/gpt-4o"] {
|
||||||
|
t.Error("expected copilot/gpt-4o to stay shown")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("uses cached numeric IDs when available", func(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
setTestHome(t, tmpDir)
|
||||||
|
t.Setenv("XDG_CONFIG_HOME", "")
|
||||||
|
cache := []map[string]any{
|
||||||
|
{
|
||||||
|
"identifier": "ollama/Ollama/4",
|
||||||
|
"metadata": map[string]any{"vendor": "ollama", "name": "llama3.2"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
dbPath := setupDB(t, testVSCodePath(t, tmpDir, ""), nil, cache)
|
||||||
|
|
||||||
|
err := v.ShowInModelPicker([]string{"llama3.2"})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
prefs := readPrefs(t, dbPath)
|
||||||
|
if !prefs["ollama/Ollama/4"] {
|
||||||
|
t.Error("expected numeric ID ollama/Ollama/4 to be shown")
|
||||||
|
}
|
||||||
|
// Name-based fallback should also be set
|
||||||
|
if !prefs["ollama/Ollama/llama3.2"] {
|
||||||
|
t.Error("expected name-based ID to also be shown")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("empty models is no-op", func(t *testing.T) {
|
||||||
|
err := v.ShowInModelPicker([]string{})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("previously hidden model is re-shown when configured", func(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
setTestHome(t, tmpDir)
|
||||||
|
t.Setenv("XDG_CONFIG_HOME", "")
|
||||||
|
dbPath := setupDB(t, testVSCodePath(t, tmpDir, ""), map[string]bool{
|
||||||
|
"ollama/Ollama/llama3.2": false,
|
||||||
|
"ollama/Ollama/llama3.2:latest": false,
|
||||||
|
}, nil)
|
||||||
|
|
||||||
|
// Ollama config is authoritative — should override the hidden state
|
||||||
|
err := v.ShowInModelPicker([]string{"llama3.2"})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
prefs := readPrefs(t, dbPath)
|
||||||
|
if !prefs["ollama/Ollama/llama3.2"] {
|
||||||
|
t.Error("expected llama3.2 to be re-shown")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseCopilotChatVersion(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
output string
|
||||||
|
wantInstalled bool
|
||||||
|
wantVersion string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "found among other extensions",
|
||||||
|
output: "ms-python.python@2024.1.1\ngithub.copilot-chat@0.40.1\ngithub.copilot@1.200.0\n",
|
||||||
|
wantInstalled: true,
|
||||||
|
wantVersion: "0.40.1",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "only extension",
|
||||||
|
output: "GitHub.copilot-chat@0.41.0\n",
|
||||||
|
wantInstalled: true,
|
||||||
|
wantVersion: "0.41.0",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "not installed",
|
||||||
|
output: "ms-python.python@2024.1.1\ngithub.copilot@1.200.0\n",
|
||||||
|
wantInstalled: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty output",
|
||||||
|
output: "",
|
||||||
|
wantInstalled: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "case insensitive match",
|
||||||
|
output: "GitHub.Copilot-Chat@0.39.0\n",
|
||||||
|
wantInstalled: true,
|
||||||
|
wantVersion: "0.39.0",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
installed, version := parseCopilotChatVersion(tt.output)
|
||||||
|
if installed != tt.wantInstalled {
|
||||||
|
t.Errorf("installed = %v, want %v", installed, tt.wantInstalled)
|
||||||
|
}
|
||||||
|
if installed && version != tt.wantVersion {
|
||||||
|
t.Errorf("version = %q, want %q", version, tt.wantVersion)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCompareVersions(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
a, b string
|
||||||
|
want int
|
||||||
|
}{
|
||||||
|
{"0.40.1", "0.40.1", 0},
|
||||||
|
{"0.40.2", "0.40.1", 1},
|
||||||
|
{"0.40.0", "0.40.1", -1},
|
||||||
|
{"0.41.0", "0.40.1", 1},
|
||||||
|
{"0.39.9", "0.40.1", -1},
|
||||||
|
{"1.0.0", "0.40.1", 1},
|
||||||
|
{"0.40", "0.40.1", -1},
|
||||||
|
{"0.40.1.1", "0.40.1", 1},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.a+"_vs_"+tt.b, func(t *testing.T) {
|
||||||
|
got := compareVersions(tt.a, tt.b)
|
||||||
|
if got != tt.want {
|
||||||
|
t.Errorf("compareVersions(%q, %q) = %d, want %d", tt.a, tt.b, got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -5,6 +5,7 @@ import (
|
|||||||
|
|
||||||
tea "github.com/charmbracelet/bubbletea"
|
tea "github.com/charmbracelet/bubbletea"
|
||||||
"github.com/charmbracelet/lipgloss"
|
"github.com/charmbracelet/lipgloss"
|
||||||
|
"github.com/ollama/ollama/cmd/launch"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@@ -18,12 +19,16 @@ var (
|
|||||||
|
|
||||||
type confirmModel struct {
|
type confirmModel struct {
|
||||||
prompt string
|
prompt string
|
||||||
|
yesLabel string
|
||||||
|
noLabel string
|
||||||
yes bool
|
yes bool
|
||||||
confirmed bool
|
confirmed bool
|
||||||
cancelled bool
|
cancelled bool
|
||||||
width int
|
width int
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type ConfirmOptions = launch.ConfirmOptions
|
||||||
|
|
||||||
func (m confirmModel) Init() tea.Cmd {
|
func (m confirmModel) Init() tea.Cmd {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -40,22 +45,16 @@ func (m confirmModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
|||||||
|
|
||||||
case tea.KeyMsg:
|
case tea.KeyMsg:
|
||||||
switch msg.String() {
|
switch msg.String() {
|
||||||
case "ctrl+c", "esc", "n":
|
case "ctrl+c", "esc":
|
||||||
m.cancelled = true
|
m.cancelled = true
|
||||||
return m, tea.Quit
|
return m, tea.Quit
|
||||||
case "y":
|
|
||||||
m.yes = true
|
|
||||||
m.confirmed = true
|
|
||||||
return m, tea.Quit
|
|
||||||
case "enter":
|
case "enter":
|
||||||
m.confirmed = true
|
m.confirmed = true
|
||||||
return m, tea.Quit
|
return m, tea.Quit
|
||||||
case "left", "h":
|
case "left":
|
||||||
m.yes = true
|
m.yes = true
|
||||||
case "right", "l":
|
case "right":
|
||||||
m.yes = false
|
m.yes = false
|
||||||
case "tab":
|
|
||||||
m.yes = !m.yes
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -68,12 +67,20 @@ func (m confirmModel) View() string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var yesBtn, noBtn string
|
var yesBtn, noBtn string
|
||||||
|
yesLabel := m.yesLabel
|
||||||
|
if yesLabel == "" {
|
||||||
|
yesLabel = "Yes"
|
||||||
|
}
|
||||||
|
noLabel := m.noLabel
|
||||||
|
if noLabel == "" {
|
||||||
|
noLabel = "No"
|
||||||
|
}
|
||||||
if m.yes {
|
if m.yes {
|
||||||
yesBtn = confirmActiveStyle.Render(" Yes ")
|
yesBtn = confirmActiveStyle.Render(" " + yesLabel + " ")
|
||||||
noBtn = confirmInactiveStyle.Render(" No ")
|
noBtn = confirmInactiveStyle.Render(" " + noLabel + " ")
|
||||||
} else {
|
} else {
|
||||||
yesBtn = confirmInactiveStyle.Render(" Yes ")
|
yesBtn = confirmInactiveStyle.Render(" " + yesLabel + " ")
|
||||||
noBtn = confirmActiveStyle.Render(" No ")
|
noBtn = confirmActiveStyle.Render(" " + noLabel + " ")
|
||||||
}
|
}
|
||||||
|
|
||||||
s := selectorTitleStyle.Render(m.prompt) + "\n\n"
|
s := selectorTitleStyle.Render(m.prompt) + "\n\n"
|
||||||
@@ -89,9 +96,26 @@ func (m confirmModel) View() string {
|
|||||||
// RunConfirm shows a bubbletea yes/no confirmation prompt.
|
// RunConfirm shows a bubbletea yes/no confirmation prompt.
|
||||||
// Returns true if the user confirmed, false if cancelled.
|
// Returns true if the user confirmed, false if cancelled.
|
||||||
func RunConfirm(prompt string) (bool, error) {
|
func RunConfirm(prompt string) (bool, error) {
|
||||||
|
return RunConfirmWithOptions(prompt, ConfirmOptions{})
|
||||||
|
}
|
||||||
|
|
||||||
|
// RunConfirmWithOptions shows a bubbletea yes/no confirmation prompt with
|
||||||
|
// optional custom button labels.
|
||||||
|
func RunConfirmWithOptions(prompt string, options ConfirmOptions) (bool, error) {
|
||||||
|
yesLabel := options.YesLabel
|
||||||
|
if yesLabel == "" {
|
||||||
|
yesLabel = "Yes"
|
||||||
|
}
|
||||||
|
noLabel := options.NoLabel
|
||||||
|
if noLabel == "" {
|
||||||
|
noLabel = "No"
|
||||||
|
}
|
||||||
|
|
||||||
m := confirmModel{
|
m := confirmModel{
|
||||||
prompt: prompt,
|
prompt: prompt,
|
||||||
yes: true, // default to yes
|
yesLabel: yesLabel,
|
||||||
|
noLabel: noLabel,
|
||||||
|
yes: true, // default to yes
|
||||||
}
|
}
|
||||||
|
|
||||||
p := tea.NewProgram(m)
|
p := tea.NewProgram(m)
|
||||||
|
|||||||
@@ -33,6 +33,22 @@ func TestConfirmModel_View_ContainsButtons(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestConfirmModel_View_ContainsCustomButtons(t *testing.T) {
|
||||||
|
m := confirmModel{
|
||||||
|
prompt: "Connect a messaging app now?",
|
||||||
|
yesLabel: "Yes",
|
||||||
|
noLabel: "Set up later",
|
||||||
|
yes: true,
|
||||||
|
}
|
||||||
|
got := m.View()
|
||||||
|
if !strings.Contains(got, "Yes") {
|
||||||
|
t.Error("should contain custom yes button")
|
||||||
|
}
|
||||||
|
if !strings.Contains(got, "Set up later") {
|
||||||
|
t.Error("should contain custom no button")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestConfirmModel_View_ContainsHelp(t *testing.T) {
|
func TestConfirmModel_View_ContainsHelp(t *testing.T) {
|
||||||
m := confirmModel{prompt: "Download?", yes: true}
|
m := confirmModel{prompt: "Download?", yes: true}
|
||||||
got := m.View()
|
got := m.View()
|
||||||
@@ -109,30 +125,33 @@ func TestConfirmModel_CtrlCCancels(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestConfirmModel_NCancels(t *testing.T) {
|
func TestConfirmModel_NDoesNothing(t *testing.T) {
|
||||||
m := confirmModel{prompt: "Download?", yes: true}
|
m := confirmModel{prompt: "Download?", yes: true}
|
||||||
updated, cmd := m.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune{'n'}})
|
updated, cmd := m.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune{'n'}})
|
||||||
fm := updated.(confirmModel)
|
fm := updated.(confirmModel)
|
||||||
if !fm.cancelled {
|
if fm.cancelled {
|
||||||
t.Error("'n' should set cancelled=true")
|
t.Error("'n' should not cancel")
|
||||||
}
|
}
|
||||||
if cmd == nil {
|
if fm.confirmed {
|
||||||
t.Error("'n' should return tea.Quit")
|
t.Error("'n' should not confirm")
|
||||||
|
}
|
||||||
|
if cmd != nil {
|
||||||
|
t.Error("'n' should not quit")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestConfirmModel_YConfirmsYes(t *testing.T) {
|
func TestConfirmModel_YDoesNothing(t *testing.T) {
|
||||||
m := confirmModel{prompt: "Download?", yes: false}
|
m := confirmModel{prompt: "Download?", yes: false}
|
||||||
updated, cmd := m.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune{'y'}})
|
updated, cmd := m.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune{'y'}})
|
||||||
fm := updated.(confirmModel)
|
fm := updated.(confirmModel)
|
||||||
if !fm.confirmed {
|
if fm.confirmed {
|
||||||
t.Error("'y' should set confirmed=true")
|
t.Error("'y' should not confirm")
|
||||||
}
|
}
|
||||||
if !fm.yes {
|
if fm.yes {
|
||||||
t.Error("'y' should set yes=true")
|
t.Error("'y' should not change selection")
|
||||||
}
|
}
|
||||||
if cmd == nil {
|
if cmd != nil {
|
||||||
t.Error("'y' should return tea.Quit")
|
t.Error("'y' should not quit")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -140,36 +159,33 @@ func TestConfirmModel_ArrowKeysNavigate(t *testing.T) {
|
|||||||
m := confirmModel{prompt: "Download?", yes: true}
|
m := confirmModel{prompt: "Download?", yes: true}
|
||||||
|
|
||||||
// Right moves to No
|
// Right moves to No
|
||||||
updated, _ := m.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune{'l'}})
|
updated, _ := m.Update(tea.KeyMsg{Type: tea.KeyRight})
|
||||||
fm := updated.(confirmModel)
|
fm := updated.(confirmModel)
|
||||||
if fm.yes {
|
if fm.yes {
|
||||||
t.Error("right/l should move to No")
|
t.Error("right should move to No")
|
||||||
}
|
}
|
||||||
if fm.confirmed || fm.cancelled {
|
if fm.confirmed || fm.cancelled {
|
||||||
t.Error("navigation should not confirm or cancel")
|
t.Error("navigation should not confirm or cancel")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Left moves back to Yes
|
// Left moves back to Yes
|
||||||
updated, _ = fm.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune{'h'}})
|
updated, _ = fm.Update(tea.KeyMsg{Type: tea.KeyLeft})
|
||||||
fm = updated.(confirmModel)
|
fm = updated.(confirmModel)
|
||||||
if !fm.yes {
|
if !fm.yes {
|
||||||
t.Error("left/h should move to Yes")
|
t.Error("left should move to Yes")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestConfirmModel_TabToggles(t *testing.T) {
|
func TestConfirmModel_TabDoesNothing(t *testing.T) {
|
||||||
m := confirmModel{prompt: "Download?", yes: true}
|
m := confirmModel{prompt: "Download?", yes: true}
|
||||||
|
|
||||||
updated, _ := m.Update(tea.KeyMsg{Type: tea.KeyTab})
|
updated, _ := m.Update(tea.KeyMsg{Type: tea.KeyTab})
|
||||||
fm := updated.(confirmModel)
|
fm := updated.(confirmModel)
|
||||||
if fm.yes {
|
|
||||||
t.Error("tab should toggle from Yes to No")
|
|
||||||
}
|
|
||||||
|
|
||||||
updated, _ = fm.Update(tea.KeyMsg{Type: tea.KeyTab})
|
|
||||||
fm = updated.(confirmModel)
|
|
||||||
if !fm.yes {
|
if !fm.yes {
|
||||||
t.Error("tab should toggle from No to Yes")
|
t.Error("tab should not change selection")
|
||||||
|
}
|
||||||
|
if fm.confirmed || fm.cancelled {
|
||||||
|
t.Error("tab should not confirm or cancel")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,13 +1,12 @@
|
|||||||
package tui
|
package tui
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
tea "github.com/charmbracelet/bubbletea"
|
tea "github.com/charmbracelet/bubbletea"
|
||||||
"github.com/charmbracelet/lipgloss"
|
"github.com/charmbracelet/lipgloss"
|
||||||
"github.com/ollama/ollama/cmd/config"
|
"github.com/ollama/ollama/cmd/launch"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@@ -56,7 +55,7 @@ var (
|
|||||||
const maxSelectorItems = 10
|
const maxSelectorItems = 10
|
||||||
|
|
||||||
// ErrCancelled is returned when the user cancels the selection.
|
// ErrCancelled is returned when the user cancels the selection.
|
||||||
var ErrCancelled = errors.New("cancelled")
|
var ErrCancelled = launch.ErrCancelled
|
||||||
|
|
||||||
type SelectItem struct {
|
type SelectItem struct {
|
||||||
Name string
|
Name string
|
||||||
@@ -64,8 +63,8 @@ type SelectItem struct {
|
|||||||
Recommended bool
|
Recommended bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// ConvertItems converts config.ModelItem slice to SelectItem slice.
|
// ConvertItems converts launch.ModelItem slice to SelectItem slice.
|
||||||
func ConvertItems(items []config.ModelItem) []SelectItem {
|
func ConvertItems(items []launch.ModelItem) []SelectItem {
|
||||||
out := make([]SelectItem, len(items))
|
out := make([]SelectItem, len(items))
|
||||||
for i, item := range items {
|
for i, item := range items {
|
||||||
out[i] = SelectItem{Name: item.Name, Description: item.Description, Recommended: item.Recommended}
|
out[i] = SelectItem{Name: item.Name, Description: item.Description, Recommended: item.Recommended}
|
||||||
@@ -101,6 +100,16 @@ type selectorModel struct {
|
|||||||
width int
|
width int
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func selectorModelWithCurrent(title string, items []SelectItem, current string) selectorModel {
|
||||||
|
m := selectorModel{
|
||||||
|
title: title,
|
||||||
|
items: items,
|
||||||
|
cursor: cursorForCurrent(items, current),
|
||||||
|
}
|
||||||
|
m.updateScroll(m.otherStart())
|
||||||
|
return m
|
||||||
|
}
|
||||||
|
|
||||||
func (m selectorModel) filteredItems() []SelectItem {
|
func (m selectorModel) filteredItems() []SelectItem {
|
||||||
if m.filter == "" {
|
if m.filter == "" {
|
||||||
return m.items
|
return m.items
|
||||||
@@ -232,6 +241,10 @@ func (m selectorModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
|||||||
m.cancelled = true
|
m.cancelled = true
|
||||||
return m, tea.Quit
|
return m, tea.Quit
|
||||||
|
|
||||||
|
case tea.KeyLeft:
|
||||||
|
m.cancelled = true
|
||||||
|
return m, tea.Quit
|
||||||
|
|
||||||
case tea.KeyEnter:
|
case tea.KeyEnter:
|
||||||
filtered := m.filteredItems()
|
filtered := m.filteredItems()
|
||||||
if len(filtered) > 0 && m.cursor < len(filtered) {
|
if len(filtered) > 0 && m.cursor < len(filtered) {
|
||||||
@@ -344,7 +357,7 @@ func (m selectorModel) renderContent() string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
s.WriteString("\n")
|
s.WriteString("\n")
|
||||||
help := "↑/↓ navigate • enter select • esc cancel"
|
help := "↑/↓ navigate • enter select • ← back"
|
||||||
if m.helpText != "" {
|
if m.helpText != "" {
|
||||||
help = m.helpText
|
help = m.helpText
|
||||||
}
|
}
|
||||||
@@ -367,13 +380,24 @@ func (m selectorModel) View() string {
|
|||||||
|
|
||||||
// cursorForCurrent returns the item index matching current, or 0 if not found.
|
// cursorForCurrent returns the item index matching current, or 0 if not found.
|
||||||
func cursorForCurrent(items []SelectItem, current string) int {
|
func cursorForCurrent(items []SelectItem, current string) int {
|
||||||
if current != "" {
|
if current == "" {
|
||||||
for i, item := range items {
|
return 0
|
||||||
if item.Name == current || strings.HasPrefix(item.Name, current+":") || strings.HasPrefix(current, item.Name+":") {
|
}
|
||||||
return i
|
|
||||||
}
|
// Prefer exact name matches before tag-prefix fallback so "qwen3.5" does not
|
||||||
|
// incorrectly select "qwen3.5:cloud" (and vice versa) based on list order.
|
||||||
|
for i, item := range items {
|
||||||
|
if item.Name == current {
|
||||||
|
return i
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
for i, item := range items {
|
||||||
|
if strings.HasPrefix(item.Name, current+":") || strings.HasPrefix(current, item.Name+":") {
|
||||||
|
return i
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -382,11 +406,7 @@ func SelectSingle(title string, items []SelectItem, current string) (string, err
|
|||||||
return "", fmt.Errorf("no items to select from")
|
return "", fmt.Errorf("no items to select from")
|
||||||
}
|
}
|
||||||
|
|
||||||
m := selectorModel{
|
m := selectorModelWithCurrent(title, items, current)
|
||||||
title: title,
|
|
||||||
items: items,
|
|
||||||
cursor: cursorForCurrent(items, current),
|
|
||||||
}
|
|
||||||
|
|
||||||
p := tea.NewProgram(m)
|
p := tea.NewProgram(m)
|
||||||
finalModel, err := p.Run()
|
finalModel, err := p.Run()
|
||||||
@@ -523,6 +543,7 @@ func (m *multiSelectorModel) toggleItem() {
|
|||||||
origIdx := m.itemIndex[item.Name]
|
origIdx := m.itemIndex[item.Name]
|
||||||
|
|
||||||
if m.checked[origIdx] {
|
if m.checked[origIdx] {
|
||||||
|
wasDefault := len(m.checkOrder) > 0 && m.checkOrder[len(m.checkOrder)-1] == origIdx
|
||||||
delete(m.checked, origIdx)
|
delete(m.checked, origIdx)
|
||||||
for i, idx := range m.checkOrder {
|
for i, idx := range m.checkOrder {
|
||||||
if idx == origIdx {
|
if idx == origIdx {
|
||||||
@@ -530,6 +551,34 @@ func (m *multiSelectorModel) toggleItem() {
|
|||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if wasDefault {
|
||||||
|
// When removing the default, pick the nearest checked model above it
|
||||||
|
// (or below if none above) so default fallback follows list order.
|
||||||
|
newDefault := -1
|
||||||
|
for i := origIdx - 1; i >= 0; i-- {
|
||||||
|
if m.checked[i] {
|
||||||
|
newDefault = i
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if newDefault == -1 {
|
||||||
|
for i := origIdx + 1; i < len(m.items); i++ {
|
||||||
|
if m.checked[i] {
|
||||||
|
newDefault = i
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if newDefault != -1 {
|
||||||
|
for i, idx := range m.checkOrder {
|
||||||
|
if idx == newDefault {
|
||||||
|
m.checkOrder = append(m.checkOrder[:i], m.checkOrder[i+1:]...)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
m.checkOrder = append(m.checkOrder, newDefault)
|
||||||
|
}
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
m.checked[origIdx] = true
|
m.checked[origIdx] = true
|
||||||
m.checkOrder = append(m.checkOrder, origIdx)
|
m.checkOrder = append(m.checkOrder, origIdx)
|
||||||
@@ -562,6 +611,10 @@ func (m multiSelectorModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
|||||||
m.cancelled = true
|
m.cancelled = true
|
||||||
return m, tea.Quit
|
return m, tea.Quit
|
||||||
|
|
||||||
|
case tea.KeyLeft:
|
||||||
|
m.cancelled = true
|
||||||
|
return m, tea.Quit
|
||||||
|
|
||||||
case tea.KeyTab:
|
case tea.KeyTab:
|
||||||
m.multi = !m.multi
|
m.multi = !m.multi
|
||||||
|
|
||||||
@@ -763,17 +816,21 @@ func (m multiSelectorModel) View() string {
|
|||||||
|
|
||||||
s.WriteString("\n")
|
s.WriteString("\n")
|
||||||
|
|
||||||
|
count := m.selectedCount()
|
||||||
if !m.multi {
|
if !m.multi {
|
||||||
s.WriteString(selectorHelpStyle.Render("↑/↓ navigate • enter select • tab add multiple • esc cancel"))
|
if count > 0 {
|
||||||
|
s.WriteString(sectionHeaderStyle.Render(fmt.Sprintf("%d models selected - press tab to edit", count)))
|
||||||
|
s.WriteString("\n\n")
|
||||||
|
}
|
||||||
|
s.WriteString(selectorHelpStyle.Render("↑/↓ navigate • enter select • tab add multiple • ← back"))
|
||||||
} else {
|
} else {
|
||||||
count := m.selectedCount()
|
|
||||||
if count == 0 {
|
if count == 0 {
|
||||||
s.WriteString(selectorDescStyle.Render(" Select at least one model."))
|
s.WriteString(sectionHeaderStyle.Render("Select at least one model."))
|
||||||
} else {
|
} else {
|
||||||
s.WriteString(selectorDescStyle.Render(fmt.Sprintf(" %d selected - press enter to continue", count)))
|
s.WriteString(sectionHeaderStyle.Render(fmt.Sprintf("%d models selected - press enter to continue", count)))
|
||||||
}
|
}
|
||||||
s.WriteString("\n\n")
|
s.WriteString("\n\n")
|
||||||
s.WriteString(selectorHelpStyle.Render("↑/↓ navigate • space toggle • tab select single • enter confirm • esc cancel"))
|
s.WriteString(selectorHelpStyle.Render("↑/↓ navigate • space toggle • tab select single • enter confirm • ← back"))
|
||||||
}
|
}
|
||||||
|
|
||||||
result := s.String()
|
result := s.String()
|
||||||
|
|||||||
@@ -216,6 +216,41 @@ func TestUpdateScroll(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestSelectorModelWithCurrent_ScrollsToCurrentInMoreSection(t *testing.T) {
|
||||||
|
m := selectorModelWithCurrent("Pick:", mixedItems(), "other-10")
|
||||||
|
|
||||||
|
if m.cursor != 11 {
|
||||||
|
t.Fatalf("cursor = %d, want 11", m.cursor)
|
||||||
|
}
|
||||||
|
if m.scrollOffset == 0 {
|
||||||
|
t.Fatal("scrollOffset should move to reveal current item in More section")
|
||||||
|
}
|
||||||
|
|
||||||
|
content := m.renderContent()
|
||||||
|
if !strings.Contains(content, "▸ other-10") {
|
||||||
|
t.Fatalf("expected current item to be visible and highlighted\n%s", content)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSelectorModelWithCurrent_HighlightsExactLocalWhenCloudVariantExists(t *testing.T) {
|
||||||
|
m := selectorModelWithCurrent("Pick:", []SelectItem{
|
||||||
|
{Name: "qwen3.5:cloud", Recommended: true},
|
||||||
|
{Name: "qwen3.5", Recommended: true},
|
||||||
|
}, "qwen3.5")
|
||||||
|
|
||||||
|
if m.cursor != 1 {
|
||||||
|
t.Fatalf("cursor = %d, want 1", m.cursor)
|
||||||
|
}
|
||||||
|
|
||||||
|
content := m.renderContent()
|
||||||
|
if !strings.Contains(content, "▸ qwen3.5") {
|
||||||
|
t.Fatalf("expected local qwen3.5 to be highlighted\n%s", content)
|
||||||
|
}
|
||||||
|
if strings.Contains(content, "▸ qwen3.5:cloud") {
|
||||||
|
t.Fatalf("did not expect cloud qwen3.5:cloud to be highlighted\n%s", content)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestRenderContent_SectionHeaders(t *testing.T) {
|
func TestRenderContent_SectionHeaders(t *testing.T) {
|
||||||
m := selectorModel{
|
m := selectorModel{
|
||||||
title: "Pick:",
|
title: "Pick:",
|
||||||
@@ -418,6 +453,28 @@ func TestCursorForCurrent(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestCursorForCurrent_PrefersExactLocalOverCloudPrefix(t *testing.T) {
|
||||||
|
testItems := []SelectItem{
|
||||||
|
{Name: "qwen3.5:cloud", Recommended: true},
|
||||||
|
{Name: "qwen3.5", Recommended: true},
|
||||||
|
}
|
||||||
|
|
||||||
|
if got := cursorForCurrent(testItems, "qwen3.5"); got != 1 {
|
||||||
|
t.Errorf("cursorForCurrent(%q) = %d, want %d", "qwen3.5", got, 1)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCursorForCurrent_PrefersExactCloudOverLocalPrefix(t *testing.T) {
|
||||||
|
testItems := []SelectItem{
|
||||||
|
{Name: "qwen3.5", Recommended: true},
|
||||||
|
{Name: "qwen3.5:cloud", Recommended: true},
|
||||||
|
}
|
||||||
|
|
||||||
|
if got := cursorForCurrent(testItems, "qwen3.5:cloud"); got != 1 {
|
||||||
|
t.Errorf("cursorForCurrent(%q) = %d, want %d", "qwen3.5:cloud", got, 1)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// --- ReorderItems ---
|
// --- ReorderItems ---
|
||||||
|
|
||||||
func TestReorderItems(t *testing.T) {
|
func TestReorderItems(t *testing.T) {
|
||||||
@@ -725,6 +782,9 @@ func TestMulti_MultiModeHelpText(t *testing.T) {
|
|||||||
if !strings.Contains(content, "tab select single") {
|
if !strings.Contains(content, "tab select single") {
|
||||||
t.Error("multi mode should show 'tab select single' in help")
|
t.Error("multi mode should show 'tab select single' in help")
|
||||||
}
|
}
|
||||||
|
if !strings.Contains(content, "← back") {
|
||||||
|
t.Error("multi mode should show '← back' in help")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// --- preChecked initialization order ---
|
// --- preChecked initialization order ---
|
||||||
@@ -783,6 +843,74 @@ func TestMulti_LastCheckedIsDefault(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestMulti_UncheckingDefaultFallsBackToNearestCheckedAbove(t *testing.T) {
|
||||||
|
// Default is "b", and checked models are "a", "b", "c".
|
||||||
|
// Unticking default should make "a" (the nearest checked item above) default.
|
||||||
|
m := newMultiSelectorModel("Pick:", items("a", "b", "c"), []string{"b", "c", "a"})
|
||||||
|
m.multi = true
|
||||||
|
m.cursor = 1 // "b"
|
||||||
|
m.toggleItem()
|
||||||
|
|
||||||
|
lastIdx := m.checkOrder[len(m.checkOrder)-1]
|
||||||
|
if m.items[lastIdx].Name != "a" {
|
||||||
|
t.Fatalf("expected default to fall back to 'a', got %q", m.items[lastIdx].Name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMulti_UncheckingTopDefaultFallsBackToNearestCheckedBelow(t *testing.T) {
|
||||||
|
// Default is top item "a". With no checked item above, fallback should pick
|
||||||
|
// the nearest checked item below ("b").
|
||||||
|
m := newMultiSelectorModel("Pick:", items("a", "b", "c"), []string{"a", "c", "b"})
|
||||||
|
m.multi = true
|
||||||
|
m.cursor = 0 // "a"
|
||||||
|
m.toggleItem()
|
||||||
|
|
||||||
|
lastIdx := m.checkOrder[len(m.checkOrder)-1]
|
||||||
|
if m.items[lastIdx].Name != "b" {
|
||||||
|
t.Fatalf("expected default to fall back to 'b', got %q", m.items[lastIdx].Name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Left arrow back navigation ---
|
||||||
|
|
||||||
|
func TestSelectorLeftArrowCancelsWhenNoFilter(t *testing.T) {
|
||||||
|
m := selectorModelWithCurrent("Pick:", items("a", "b", "c"), "")
|
||||||
|
updated, _ := m.Update(tea.KeyMsg{Type: tea.KeyLeft})
|
||||||
|
got := updated.(selectorModel)
|
||||||
|
if !got.cancelled {
|
||||||
|
t.Error("left arrow with empty filter should cancel (go back)")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSelectorLeftArrowCancelsWhenFiltering(t *testing.T) {
|
||||||
|
m := selectorModelWithCurrent("Pick:", items("a", "b", "c"), "")
|
||||||
|
m.filter = "a"
|
||||||
|
updated, _ := m.Update(tea.KeyMsg{Type: tea.KeyLeft})
|
||||||
|
got := updated.(selectorModel)
|
||||||
|
if !got.cancelled {
|
||||||
|
t.Error("left arrow with active filter should still cancel (go back)")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMultiSelectorLeftArrowCancelsWhenNoFilter(t *testing.T) {
|
||||||
|
m := newMultiSelectorModel("Pick:", items("a", "b", "c"), nil)
|
||||||
|
updated, _ := m.Update(tea.KeyMsg{Type: tea.KeyLeft})
|
||||||
|
got := updated.(multiSelectorModel)
|
||||||
|
if !got.cancelled {
|
||||||
|
t.Error("left arrow with empty filter should cancel (go back)")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMultiSelectorLeftArrowCancelsWhenFiltering(t *testing.T) {
|
||||||
|
m := newMultiSelectorModel("Pick:", items("a", "b", "c"), nil)
|
||||||
|
m.filter = "a"
|
||||||
|
updated, _ := m.Update(tea.KeyMsg{Type: tea.KeyLeft})
|
||||||
|
got := updated.(multiSelectorModel)
|
||||||
|
if !got.cancelled {
|
||||||
|
t.Error("left arrow with active filter should still cancel (go back)")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Key message helpers for testing
|
// Key message helpers for testing
|
||||||
|
|
||||||
type keyType = int
|
type keyType = int
|
||||||
|
|||||||