mirror of
https://github.com/ollama/ollama.git
synced 2026-04-21 00:05:40 +02:00
Compare commits
104 Commits
brucemacd/
...
v0.18.2-rc
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
676d9845ba | ||
|
|
e37a9b4c01 | ||
|
|
d727aacd04 | ||
|
|
fa69b833cd | ||
|
|
bbbad97686 | ||
|
|
bcf6d55b54 | ||
|
|
810d4f9c22 | ||
|
|
856c047a6c | ||
|
|
79c1e93c00 | ||
|
|
f8b657c967 | ||
|
|
10fefe0d57 | ||
|
|
2f9a68f9e9 | ||
|
|
3980c0217d | ||
|
|
870599f5da | ||
|
|
abf8e8e9c8 | ||
|
|
f3f31a8192 | ||
|
|
9e7ba835da | ||
|
|
347f17b8d1 | ||
|
|
081b9eb423 | ||
|
|
bb867c6fdb | ||
|
|
81f4506a61 | ||
|
|
76925f1284 | ||
|
|
f676231de9 | ||
|
|
af5f7c0a9e | ||
|
|
a6b27d776b | ||
|
|
539741199e | ||
|
|
8f45236d09 | ||
|
|
97013a190c | ||
|
|
c222735c02 | ||
|
|
87d21c7fc0 | ||
|
|
54e05172a0 | ||
|
|
464186e995 | ||
|
|
8c4d5d6c2f | ||
|
|
bc72b14016 | ||
|
|
61086083eb | ||
|
|
62d1f01ab4 | ||
|
|
10e51c5177 | ||
|
|
3e06bde643 | ||
|
|
6be2de8214 | ||
|
|
ebb1b9ec14 | ||
|
|
d126467d5d | ||
|
|
afb4c62fbf | ||
|
|
e790dc435b | ||
|
|
288077c3a3 | ||
|
|
4425c54eda | ||
|
|
778899a5d2 | ||
|
|
4eab60c1e2 | ||
|
|
1af850e6e3 | ||
|
|
9b0c7cc7b9 | ||
|
|
6928630601 | ||
|
|
9896e3627f | ||
|
|
15732f0ea7 | ||
|
|
562c76d7cc | ||
|
|
122c68c151 | ||
|
|
82848a7806 | ||
|
|
39982a954e | ||
|
|
e9f6ea232f | ||
|
|
110eff01a9 | ||
|
|
799e51d419 | ||
|
|
e8fcb29586 | ||
|
|
97d2f05a6d | ||
|
|
8207e55ec7 | ||
|
|
ad16bffc7d | ||
|
|
c1e3ef4bcc | ||
|
|
a3093cd5e5 | ||
|
|
23d4cad1a2 | ||
|
|
86513cb697 | ||
|
|
3490e9590b | ||
|
|
8da09b1e7e | ||
|
|
a60b9adcce | ||
|
|
a16f96658b | ||
|
|
18ab09b431 | ||
|
|
638faeac54 | ||
|
|
dd5eb6337d | ||
|
|
79917cf80b | ||
|
|
cc90a035a0 | ||
|
|
d98dda4676 | ||
|
|
d69ddc1edc | ||
|
|
9bf41969f0 | ||
|
|
0f23b7bff5 | ||
|
|
4e57d2094e | ||
|
|
7f9efd53df | ||
|
|
da70c3222e | ||
|
|
9d902d63ce | ||
|
|
f4f0a4a471 | ||
|
|
3323c1d319 | ||
|
|
f20dc6b698 | ||
|
|
4b2ac1f369 | ||
|
|
8daf47fb3a | ||
|
|
6c980579cd | ||
|
|
5c73c4e2ee | ||
|
|
5daf59cc66 | ||
|
|
0ade9205cc | ||
|
|
06edabdde1 | ||
|
|
8b4e5a82a8 | ||
|
|
3445223311 | ||
|
|
fa6c0127e6 | ||
|
|
97323d1c68 | ||
|
|
458dd1b9d9 | ||
|
|
9d02d1d767 | ||
|
|
1a636fb47a | ||
|
|
0759fface9 | ||
|
|
325b72bc31 | ||
|
|
f01a9a7859 |
64
.github/workflows/release.yaml
vendored
64
.github/workflows/release.yaml
vendored
@@ -117,6 +117,25 @@ jobs:
|
|||||||
install: https://sdk.lunarg.com/sdk/download/1.4.321.1/windows/vulkansdk-windows-X64-1.4.321.1.exe
|
install: https://sdk.lunarg.com/sdk/download/1.4.321.1/windows/vulkansdk-windows-X64-1.4.321.1.exe
|
||||||
flags: ''
|
flags: ''
|
||||||
runner_dir: 'vulkan'
|
runner_dir: 'vulkan'
|
||||||
|
- os: windows
|
||||||
|
arch: amd64
|
||||||
|
preset: 'MLX CUDA 13'
|
||||||
|
install: https://developer.download.nvidia.com/compute/cuda/13.0.0/local_installers/cuda_13.0.0_windows.exe
|
||||||
|
cudnn-install: https://developer.download.nvidia.com/compute/cudnn/redist/cudnn/windows-x86_64/cudnn-windows-x86_64-9.18.1.3_cuda13-archive.zip
|
||||||
|
cuda-components:
|
||||||
|
- '"cudart"'
|
||||||
|
- '"nvcc"'
|
||||||
|
- '"cublas"'
|
||||||
|
- '"cublas_dev"'
|
||||||
|
- '"cufft"'
|
||||||
|
- '"cufft_dev"'
|
||||||
|
- '"nvrtc"'
|
||||||
|
- '"nvrtc_dev"'
|
||||||
|
- '"crt"'
|
||||||
|
- '"nvvm"'
|
||||||
|
- '"nvptxcompiler"'
|
||||||
|
cuda-version: '13.0'
|
||||||
|
flags: ''
|
||||||
runs-on: ${{ matrix.arch == 'arm64' && format('{0}-{1}', matrix.os, matrix.arch) || matrix.os }}
|
runs-on: ${{ matrix.arch == 'arm64' && format('{0}-{1}', matrix.os, matrix.arch) || matrix.os }}
|
||||||
environment: release
|
environment: release
|
||||||
env:
|
env:
|
||||||
@@ -125,8 +144,10 @@ jobs:
|
|||||||
- name: Install system dependencies
|
- name: Install system dependencies
|
||||||
run: |
|
run: |
|
||||||
choco install -y --no-progress ccache ninja
|
choco install -y --no-progress ccache ninja
|
||||||
|
if (Get-Command ccache -ErrorAction SilentlyContinue) {
|
||||||
ccache -o cache_dir=${{ github.workspace }}\.ccache
|
ccache -o cache_dir=${{ github.workspace }}\.ccache
|
||||||
- if: startsWith(matrix.preset, 'CUDA ') || startsWith(matrix.preset, 'ROCm ') || startsWith(matrix.preset, 'Vulkan')
|
}
|
||||||
|
- if: startsWith(matrix.preset, 'CUDA ') || startsWith(matrix.preset, 'ROCm ') || startsWith(matrix.preset, 'Vulkan') || startsWith(matrix.preset, 'MLX ')
|
||||||
id: cache-install
|
id: cache-install
|
||||||
uses: actions/cache/restore@v4
|
uses: actions/cache/restore@v4
|
||||||
with:
|
with:
|
||||||
@@ -134,8 +155,9 @@ jobs:
|
|||||||
C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA
|
C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA
|
||||||
C:\Program Files\AMD\ROCm
|
C:\Program Files\AMD\ROCm
|
||||||
C:\VulkanSDK
|
C:\VulkanSDK
|
||||||
key: ${{ matrix.install }}
|
C:\Program Files\NVIDIA\CUDNN
|
||||||
- if: startsWith(matrix.preset, 'CUDA ')
|
key: ${{ matrix.install }}-${{ matrix.cudnn-install }}
|
||||||
|
- if: startsWith(matrix.preset, 'CUDA ') || startsWith(matrix.preset, 'MLX ')
|
||||||
name: Install CUDA ${{ matrix.cuda-version }}
|
name: Install CUDA ${{ matrix.cuda-version }}
|
||||||
run: |
|
run: |
|
||||||
$ErrorActionPreference = "Stop"
|
$ErrorActionPreference = "Stop"
|
||||||
@@ -179,6 +201,23 @@ jobs:
|
|||||||
run: |
|
run: |
|
||||||
echo "CC=clang.exe" | Out-File -FilePath $env:GITHUB_ENV -Append
|
echo "CC=clang.exe" | Out-File -FilePath $env:GITHUB_ENV -Append
|
||||||
echo "CXX=clang++.exe" | Out-File -FilePath $env:GITHUB_ENV -Append
|
echo "CXX=clang++.exe" | Out-File -FilePath $env:GITHUB_ENV -Append
|
||||||
|
- if: startsWith(matrix.preset, 'MLX ')
|
||||||
|
name: Install cuDNN for MLX
|
||||||
|
run: |
|
||||||
|
$ErrorActionPreference = "Stop"
|
||||||
|
$cudnnRoot = "C:\Program Files\NVIDIA\CUDNN"
|
||||||
|
if ("${{ steps.cache-install.outputs.cache-hit }}" -ne 'true') {
|
||||||
|
Invoke-WebRequest -Uri "${{ matrix.cudnn-install }}" -OutFile "cudnn.zip"
|
||||||
|
Expand-Archive -Path cudnn.zip -DestinationPath cudnn-extracted
|
||||||
|
$cudnnDir = (Get-ChildItem -Path cudnn-extracted -Directory)[0].FullName
|
||||||
|
New-Item -ItemType Directory -Force -Path $cudnnRoot
|
||||||
|
Copy-Item -Path "$cudnnDir\*" -Destination "$cudnnRoot\" -Recurse
|
||||||
|
}
|
||||||
|
|
||||||
|
echo "CUDNN_ROOT_DIR=$cudnnRoot" | Out-File -FilePath $env:GITHUB_ENV -Append
|
||||||
|
echo "CUDNN_INCLUDE_PATH=$cudnnRoot\include" | Out-File -FilePath $env:GITHUB_ENV -Append
|
||||||
|
echo "CUDNN_LIBRARY_PATH=$cudnnRoot\lib\x64" | Out-File -FilePath $env:GITHUB_ENV -Append
|
||||||
|
echo "$cudnnRoot\bin\x64" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
|
||||||
- if: ${{ !cancelled() && steps.cache-install.outputs.cache-hit != 'true' }}
|
- if: ${{ !cancelled() && steps.cache-install.outputs.cache-hit != 'true' }}
|
||||||
uses: actions/cache/save@v4
|
uses: actions/cache/save@v4
|
||||||
with:
|
with:
|
||||||
@@ -186,7 +225,8 @@ jobs:
|
|||||||
C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA
|
C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA
|
||||||
C:\Program Files\AMD\ROCm
|
C:\Program Files\AMD\ROCm
|
||||||
C:\VulkanSDK
|
C:\VulkanSDK
|
||||||
key: ${{ matrix.install }}
|
C:\Program Files\NVIDIA\CUDNN
|
||||||
|
key: ${{ matrix.install }}-${{ matrix.cudnn-install }}
|
||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v4
|
||||||
- uses: actions/cache@v4
|
- uses: actions/cache@v4
|
||||||
with:
|
with:
|
||||||
@@ -198,7 +238,7 @@ jobs:
|
|||||||
Enter-VsDevShell -VsInstallPath 'C:\Program Files\Microsoft Visual Studio\2022\Enterprise' -SkipAutomaticLocation -DevCmdArguments '-arch=x64 -no_logo'
|
Enter-VsDevShell -VsInstallPath 'C:\Program Files\Microsoft Visual Studio\2022\Enterprise' -SkipAutomaticLocation -DevCmdArguments '-arch=x64 -no_logo'
|
||||||
cmake --preset "${{ matrix.preset }}" ${{ matrix.flags }} --install-prefix "$((pwd).Path)\dist\${{ matrix.os }}-${{ matrix.arch }}"
|
cmake --preset "${{ matrix.preset }}" ${{ matrix.flags }} --install-prefix "$((pwd).Path)\dist\${{ matrix.os }}-${{ matrix.arch }}"
|
||||||
cmake --build --parallel ([Environment]::ProcessorCount) --preset "${{ matrix.preset }}"
|
cmake --build --parallel ([Environment]::ProcessorCount) --preset "${{ matrix.preset }}"
|
||||||
cmake --install build --component "${{ startsWith(matrix.preset, 'CUDA ') && 'CUDA' || startsWith(matrix.preset, 'ROCm ') && 'HIP' || startsWith(matrix.preset, 'Vulkan') && 'Vulkan' || 'CPU' }}" --strip
|
cmake --install build --component "${{ startsWith(matrix.preset, 'MLX ') && 'MLX' || startsWith(matrix.preset, 'CUDA ') && 'CUDA' || startsWith(matrix.preset, 'ROCm ') && 'HIP' || startsWith(matrix.preset, 'Vulkan') && 'Vulkan' || 'CPU' }}" --strip
|
||||||
Remove-Item -Path dist\lib\ollama\rocm\rocblas\library\*gfx906* -ErrorAction SilentlyContinue
|
Remove-Item -Path dist\lib\ollama\rocm\rocblas\library\*gfx906* -ErrorAction SilentlyContinue
|
||||||
env:
|
env:
|
||||||
CMAKE_GENERATOR: Ninja
|
CMAKE_GENERATOR: Ninja
|
||||||
@@ -543,11 +583,19 @@ jobs:
|
|||||||
for payload in dist/*.txt dist/*.zip dist/*.tgz dist/*.tar.zst dist/*.exe dist/*.dmg dist/*.ps1 dist/*.sh ; do
|
for payload in dist/*.txt dist/*.zip dist/*.tgz dist/*.tar.zst dist/*.exe dist/*.dmg dist/*.ps1 dist/*.sh ; do
|
||||||
echo "Uploading $payload"
|
echo "Uploading $payload"
|
||||||
gh release upload ${GITHUB_REF_NAME} $payload --clobber &
|
gh release upload ${GITHUB_REF_NAME} $payload --clobber &
|
||||||
pids[$!]=$!
|
pids+=($!)
|
||||||
sleep 1
|
sleep 1
|
||||||
done
|
done
|
||||||
echo "Waiting for uploads to complete"
|
echo "Waiting for uploads to complete"
|
||||||
for pid in "${pids[*]}"; do
|
failed=0
|
||||||
wait $pid
|
for pid in "${pids[@]}"; do
|
||||||
|
if ! wait $pid; then
|
||||||
|
echo "::error::Upload failed (pid $pid)"
|
||||||
|
failed=1
|
||||||
|
fi
|
||||||
done
|
done
|
||||||
|
if [ $failed -ne 0 ]; then
|
||||||
|
echo "One or more uploads failed"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
echo "done"
|
echo "done"
|
||||||
|
|||||||
62
.github/workflows/test.yaml
vendored
62
.github/workflows/test.yaml
vendored
@@ -37,7 +37,7 @@ jobs:
|
|||||||
| xargs python3 -c "import sys; from pathlib import Path; print(any(Path(x).match(glob) for x in sys.argv[1:] for glob in '$*'.split(' ')))"
|
| xargs python3 -c "import sys; from pathlib import Path; print(any(Path(x).match(glob) for x in sys.argv[1:] for glob in '$*'.split(' ')))"
|
||||||
}
|
}
|
||||||
|
|
||||||
echo changed=$(changed 'llama/llama.cpp/**/*' 'ml/backend/ggml/ggml/**/*') | tee -a $GITHUB_OUTPUT
|
echo changed=$(changed 'llama/llama.cpp/**/*' 'ml/backend/ggml/ggml/**/*' '.github/**/*') | tee -a $GITHUB_OUTPUT
|
||||||
echo vendorsha=$(make -f Makefile.sync print-base) | tee -a $GITHUB_OUTPUT
|
echo vendorsha=$(make -f Makefile.sync print-base) | tee -a $GITHUB_OUTPUT
|
||||||
|
|
||||||
linux:
|
linux:
|
||||||
@@ -51,7 +51,7 @@ jobs:
|
|||||||
container: nvidia/cuda:13.0.0-devel-ubuntu22.04
|
container: nvidia/cuda:13.0.0-devel-ubuntu22.04
|
||||||
flags: '-DCMAKE_CUDA_ARCHITECTURES=87'
|
flags: '-DCMAKE_CUDA_ARCHITECTURES=87'
|
||||||
- preset: ROCm
|
- preset: ROCm
|
||||||
container: rocm/dev-ubuntu-22.04:6.1.2
|
container: rocm/dev-ubuntu-22.04:7.2
|
||||||
extra-packages: rocm-libs
|
extra-packages: rocm-libs
|
||||||
flags: '-DAMDGPU_TARGETS=gfx1010 -DCMAKE_PREFIX_PATH=/opt/rocm'
|
flags: '-DAMDGPU_TARGETS=gfx1010 -DCMAKE_PREFIX_PATH=/opt/rocm'
|
||||||
- preset: Vulkan
|
- preset: Vulkan
|
||||||
@@ -60,6 +60,10 @@ jobs:
|
|||||||
mesa-vulkan-drivers vulkan-tools
|
mesa-vulkan-drivers vulkan-tools
|
||||||
libvulkan1 libvulkan-dev
|
libvulkan1 libvulkan-dev
|
||||||
vulkan-sdk cmake ccache g++ make
|
vulkan-sdk cmake ccache g++ make
|
||||||
|
- preset: 'MLX CUDA 13'
|
||||||
|
container: nvidia/cuda:13.0.0-devel-ubuntu22.04
|
||||||
|
extra-packages: libcudnn9-dev-cuda-13 libopenblas-dev liblapack-dev liblapacke-dev git curl
|
||||||
|
flags: '-DCMAKE_CUDA_ARCHITECTURES=87 -DBLAS_INCLUDE_DIRS=/usr/include/x86_64-linux-gnu -DLAPACK_INCLUDE_DIRS=/usr/include/x86_64-linux-gnu'
|
||||||
runs-on: linux
|
runs-on: linux
|
||||||
container: ${{ matrix.container }}
|
container: ${{ matrix.container }}
|
||||||
steps:
|
steps:
|
||||||
@@ -76,6 +80,10 @@ jobs:
|
|||||||
$sudo apt-get update
|
$sudo apt-get update
|
||||||
fi
|
fi
|
||||||
$sudo apt-get install -y cmake ccache ${{ matrix.extra-packages }}
|
$sudo apt-get install -y cmake ccache ${{ matrix.extra-packages }}
|
||||||
|
# MLX requires CMake 3.25+, install from official releases
|
||||||
|
if [ "${{ matrix.preset }}" = "MLX CUDA 13" ]; then
|
||||||
|
curl -fsSL https://github.com/Kitware/CMake/releases/download/v3.31.2/cmake-3.31.2-linux-$(uname -m).tar.gz | $sudo tar xz -C /usr/local --strip-components 1
|
||||||
|
fi
|
||||||
# Export VULKAN_SDK if provided by LunarG package (defensive)
|
# Export VULKAN_SDK if provided by LunarG package (defensive)
|
||||||
if [ -d "/usr/lib/x86_64-linux-gnu/vulkan" ] && [ "${{ matrix.preset }}" = "Vulkan" ]; then
|
if [ -d "/usr/lib/x86_64-linux-gnu/vulkan" ] && [ "${{ matrix.preset }}" = "Vulkan" ]; then
|
||||||
echo "VULKAN_SDK=/usr" >> $GITHUB_ENV
|
echo "VULKAN_SDK=/usr" >> $GITHUB_ENV
|
||||||
@@ -87,8 +95,8 @@ jobs:
|
|||||||
path: /github/home/.cache/ccache
|
path: /github/home/.cache/ccache
|
||||||
key: ccache-${{ runner.os }}-${{ runner.arch }}-${{ matrix.preset }}-${{ needs.changes.outputs.vendorsha }}
|
key: ccache-${{ runner.os }}-${{ runner.arch }}-${{ matrix.preset }}-${{ needs.changes.outputs.vendorsha }}
|
||||||
- run: |
|
- run: |
|
||||||
cmake --preset ${{ matrix.preset }} ${{ matrix.flags }}
|
cmake --preset "${{ matrix.preset }}" ${{ matrix.flags }}
|
||||||
cmake --build --preset ${{ matrix.preset }} --parallel
|
cmake --build --preset "${{ matrix.preset }}" --parallel
|
||||||
|
|
||||||
windows:
|
windows:
|
||||||
needs: [changes]
|
needs: [changes]
|
||||||
@@ -114,12 +122,31 @@ jobs:
|
|||||||
flags: '-DAMDGPU_TARGETS=gfx1010 -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -DCMAKE_C_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma" -DCMAKE_CXX_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma"'
|
flags: '-DAMDGPU_TARGETS=gfx1010 -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -DCMAKE_C_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma" -DCMAKE_CXX_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma"'
|
||||||
- preset: Vulkan
|
- preset: Vulkan
|
||||||
install: https://sdk.lunarg.com/sdk/download/1.4.321.1/windows/vulkansdk-windows-X64-1.4.321.1.exe
|
install: https://sdk.lunarg.com/sdk/download/1.4.321.1/windows/vulkansdk-windows-X64-1.4.321.1.exe
|
||||||
|
- preset: 'MLX CUDA 13'
|
||||||
|
install: https://developer.download.nvidia.com/compute/cuda/13.0.0/local_installers/cuda_13.0.0_windows.exe
|
||||||
|
cudnn-install: https://developer.download.nvidia.com/compute/cudnn/redist/cudnn/windows-x86_64/cudnn-windows-x86_64-9.18.1.3_cuda13-archive.zip
|
||||||
|
flags: '-DCMAKE_CUDA_ARCHITECTURES=80'
|
||||||
|
cuda-components:
|
||||||
|
- '"cudart"'
|
||||||
|
- '"nvcc"'
|
||||||
|
- '"cublas"'
|
||||||
|
- '"cublas_dev"'
|
||||||
|
- '"cufft"'
|
||||||
|
- '"cufft_dev"'
|
||||||
|
- '"nvrtc"'
|
||||||
|
- '"nvrtc_dev"'
|
||||||
|
- '"crt"'
|
||||||
|
- '"nvvm"'
|
||||||
|
- '"nvptxcompiler"'
|
||||||
|
cuda-version: '13.0'
|
||||||
runs-on: windows
|
runs-on: windows
|
||||||
steps:
|
steps:
|
||||||
- run: |
|
- run: |
|
||||||
choco install -y --no-progress ccache ninja
|
choco install -y --no-progress ccache ninja
|
||||||
|
if (Get-Command ccache -ErrorAction SilentlyContinue) {
|
||||||
ccache -o cache_dir=${{ github.workspace }}\.ccache
|
ccache -o cache_dir=${{ github.workspace }}\.ccache
|
||||||
- if: matrix.preset == 'CUDA' || matrix.preset == 'ROCm' || matrix.preset == 'Vulkan'
|
}
|
||||||
|
- if: matrix.preset == 'CUDA' || matrix.preset == 'ROCm' || matrix.preset == 'Vulkan' || matrix.preset == 'MLX CUDA 13'
|
||||||
id: cache-install
|
id: cache-install
|
||||||
uses: actions/cache/restore@v4
|
uses: actions/cache/restore@v4
|
||||||
with:
|
with:
|
||||||
@@ -127,8 +154,9 @@ jobs:
|
|||||||
C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA
|
C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA
|
||||||
C:\Program Files\AMD\ROCm
|
C:\Program Files\AMD\ROCm
|
||||||
C:\VulkanSDK
|
C:\VulkanSDK
|
||||||
key: ${{ matrix.install }}
|
C:\Program Files\NVIDIA\CUDNN
|
||||||
- if: matrix.preset == 'CUDA'
|
key: ${{ matrix.install }}-${{ matrix.cudnn-install }}
|
||||||
|
- if: matrix.preset == 'CUDA' || matrix.preset == 'MLX CUDA 13'
|
||||||
name: Install CUDA ${{ matrix.cuda-version }}
|
name: Install CUDA ${{ matrix.cuda-version }}
|
||||||
run: |
|
run: |
|
||||||
$ErrorActionPreference = "Stop"
|
$ErrorActionPreference = "Stop"
|
||||||
@@ -168,6 +196,23 @@ jobs:
|
|||||||
$vulkanPath = (Resolve-Path "C:\VulkanSDK\*").path
|
$vulkanPath = (Resolve-Path "C:\VulkanSDK\*").path
|
||||||
echo "$vulkanPath\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
|
echo "$vulkanPath\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
|
||||||
echo "VULKAN_SDK=$vulkanPath" >> $env:GITHUB_ENV
|
echo "VULKAN_SDK=$vulkanPath" >> $env:GITHUB_ENV
|
||||||
|
- if: matrix.preset == 'MLX CUDA 13'
|
||||||
|
name: Install cuDNN for MLX
|
||||||
|
run: |
|
||||||
|
$ErrorActionPreference = "Stop"
|
||||||
|
$cudnnRoot = "C:\Program Files\NVIDIA\CUDNN"
|
||||||
|
if ("${{ steps.cache-install.outputs.cache-hit }}" -ne 'true') {
|
||||||
|
Invoke-WebRequest -Uri "${{ matrix.cudnn-install }}" -OutFile "cudnn.zip"
|
||||||
|
Expand-Archive -Path cudnn.zip -DestinationPath cudnn-extracted
|
||||||
|
$cudnnDir = (Get-ChildItem -Path cudnn-extracted -Directory)[0].FullName
|
||||||
|
New-Item -ItemType Directory -Force -Path $cudnnRoot
|
||||||
|
Copy-Item -Path "$cudnnDir\*" -Destination "$cudnnRoot\" -Recurse
|
||||||
|
}
|
||||||
|
|
||||||
|
echo "CUDNN_ROOT_DIR=$cudnnRoot" | Out-File -FilePath $env:GITHUB_ENV -Append
|
||||||
|
echo "CUDNN_INCLUDE_PATH=$cudnnRoot\include" | Out-File -FilePath $env:GITHUB_ENV -Append
|
||||||
|
echo "CUDNN_LIBRARY_PATH=$cudnnRoot\lib\x64" | Out-File -FilePath $env:GITHUB_ENV -Append
|
||||||
|
echo "$cudnnRoot\bin\x64" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
|
||||||
- if: ${{ !cancelled() && steps.cache-install.outputs.cache-hit != 'true' }}
|
- if: ${{ !cancelled() && steps.cache-install.outputs.cache-hit != 'true' }}
|
||||||
uses: actions/cache/save@v4
|
uses: actions/cache/save@v4
|
||||||
with:
|
with:
|
||||||
@@ -175,7 +220,8 @@ jobs:
|
|||||||
C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA
|
C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA
|
||||||
C:\Program Files\AMD\ROCm
|
C:\Program Files\AMD\ROCm
|
||||||
C:\VulkanSDK
|
C:\VulkanSDK
|
||||||
key: ${{ matrix.install }}
|
C:\Program Files\NVIDIA\CUDNN
|
||||||
|
key: ${{ matrix.install }}-${{ matrix.cudnn-install }}
|
||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v4
|
||||||
- uses: actions/cache@v4
|
- uses: actions/cache@v4
|
||||||
with:
|
with:
|
||||||
|
|||||||
108
CMakeLists.txt
108
CMakeLists.txt
@@ -64,10 +64,15 @@ set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${OLLAMA_BUILD_DIR})
|
|||||||
set(CMAKE_LIBRARY_OUTPUT_DIRECTORY_DEBUG ${OLLAMA_BUILD_DIR})
|
set(CMAKE_LIBRARY_OUTPUT_DIRECTORY_DEBUG ${OLLAMA_BUILD_DIR})
|
||||||
set(CMAKE_LIBRARY_OUTPUT_DIRECTORY_RELEASE ${OLLAMA_BUILD_DIR})
|
set(CMAKE_LIBRARY_OUTPUT_DIRECTORY_RELEASE ${OLLAMA_BUILD_DIR})
|
||||||
|
|
||||||
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src)
|
# Store ggml include paths for use with target_include_directories later.
|
||||||
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/include)
|
# We avoid global include_directories() to prevent polluting the include path
|
||||||
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-cpu)
|
# for other projects like MLX (whose openblas dependency has its own common.h).
|
||||||
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-cpu/amx)
|
set(GGML_INCLUDE_DIRS
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/include
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-cpu
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-cpu/amx
|
||||||
|
)
|
||||||
|
|
||||||
add_compile_definitions(NDEBUG GGML_VERSION=0x0 GGML_COMMIT=0x0)
|
add_compile_definitions(NDEBUG GGML_VERSION=0x0 GGML_COMMIT=0x0)
|
||||||
|
|
||||||
@@ -87,6 +92,14 @@ if(NOT CPU_VARIANTS)
|
|||||||
set(CPU_VARIANTS "ggml-cpu")
|
set(CPU_VARIANTS "ggml-cpu")
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
|
# Apply ggml include directories to ggml targets only (not globally)
|
||||||
|
target_include_directories(ggml-base PRIVATE ${GGML_INCLUDE_DIRS})
|
||||||
|
foreach(variant ${CPU_VARIANTS})
|
||||||
|
if(TARGET ${variant})
|
||||||
|
target_include_directories(${variant} PRIVATE ${GGML_INCLUDE_DIRS})
|
||||||
|
endif()
|
||||||
|
endforeach()
|
||||||
|
|
||||||
install(TARGETS ggml-base ${CPU_VARIANTS}
|
install(TARGETS ggml-base ${CPU_VARIANTS}
|
||||||
RUNTIME_DEPENDENCIES
|
RUNTIME_DEPENDENCIES
|
||||||
PRE_EXCLUDE_REGEXES ".*"
|
PRE_EXCLUDE_REGEXES ".*"
|
||||||
@@ -103,6 +116,7 @@ if(CMAKE_CUDA_COMPILER)
|
|||||||
|
|
||||||
find_package(CUDAToolkit)
|
find_package(CUDAToolkit)
|
||||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-cuda)
|
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-cuda)
|
||||||
|
target_include_directories(ggml-cuda PRIVATE ${GGML_INCLUDE_DIRS})
|
||||||
install(TARGETS ggml-cuda
|
install(TARGETS ggml-cuda
|
||||||
RUNTIME_DEPENDENCIES
|
RUNTIME_DEPENDENCIES
|
||||||
DIRECTORIES ${CUDAToolkit_BIN_DIR} ${CUDAToolkit_BIN_DIR}/x64 ${CUDAToolkit_LIBRARY_DIR}
|
DIRECTORIES ${CUDAToolkit_BIN_DIR} ${CUDAToolkit_BIN_DIR}/x64 ${CUDAToolkit_LIBRARY_DIR}
|
||||||
@@ -134,6 +148,7 @@ if(CMAKE_HIP_COMPILER)
|
|||||||
if(AMDGPU_TARGETS)
|
if(AMDGPU_TARGETS)
|
||||||
find_package(hip REQUIRED)
|
find_package(hip REQUIRED)
|
||||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-hip)
|
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-hip)
|
||||||
|
target_include_directories(ggml-hip PRIVATE ${GGML_INCLUDE_DIRS})
|
||||||
|
|
||||||
if (WIN32)
|
if (WIN32)
|
||||||
target_compile_definitions(ggml-hip PRIVATE GGML_CUDA_NO_PEER_COPY)
|
target_compile_definitions(ggml-hip PRIVATE GGML_CUDA_NO_PEER_COPY)
|
||||||
@@ -148,7 +163,7 @@ if(CMAKE_HIP_COMPILER)
|
|||||||
)
|
)
|
||||||
install(RUNTIME_DEPENDENCY_SET rocm
|
install(RUNTIME_DEPENDENCY_SET rocm
|
||||||
DIRECTORIES ${HIP_BIN_INSTALL_DIR} ${HIP_LIB_INSTALL_DIR}
|
DIRECTORIES ${HIP_BIN_INSTALL_DIR} ${HIP_LIB_INSTALL_DIR}
|
||||||
PRE_INCLUDE_REGEXES hipblas rocblas amdhip64 rocsolver amd_comgr hsa-runtime64 rocsparse tinfo rocprofiler-register drm drm_amdgpu numa elf
|
PRE_INCLUDE_REGEXES hipblas rocblas amdhip64 rocsolver amd_comgr hsa-runtime64 rocsparse tinfo rocprofiler-register roctx64 rocroller drm drm_amdgpu numa elf
|
||||||
PRE_EXCLUDE_REGEXES ".*"
|
PRE_EXCLUDE_REGEXES ".*"
|
||||||
POST_EXCLUDE_REGEXES "system32"
|
POST_EXCLUDE_REGEXES "system32"
|
||||||
RUNTIME DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT HIP
|
RUNTIME DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT HIP
|
||||||
@@ -168,6 +183,7 @@ if(NOT APPLE)
|
|||||||
find_package(Vulkan)
|
find_package(Vulkan)
|
||||||
if(Vulkan_FOUND)
|
if(Vulkan_FOUND)
|
||||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-vulkan)
|
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-vulkan)
|
||||||
|
target_include_directories(ggml-vulkan PRIVATE ${GGML_INCLUDE_DIRS})
|
||||||
install(TARGETS ggml-vulkan
|
install(TARGETS ggml-vulkan
|
||||||
RUNTIME_DEPENDENCIES
|
RUNTIME_DEPENDENCIES
|
||||||
PRE_INCLUDE_REGEXES vulkan
|
PRE_INCLUDE_REGEXES vulkan
|
||||||
@@ -179,7 +195,6 @@ if(NOT APPLE)
|
|||||||
endif()
|
endif()
|
||||||
|
|
||||||
option(MLX_ENGINE "Enable MLX backend" OFF)
|
option(MLX_ENGINE "Enable MLX backend" OFF)
|
||||||
|
|
||||||
if(MLX_ENGINE)
|
if(MLX_ENGINE)
|
||||||
message(STATUS "Setting up MLX (this takes a while...)")
|
message(STATUS "Setting up MLX (this takes a while...)")
|
||||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/x/imagegen/mlx)
|
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/x/imagegen/mlx)
|
||||||
@@ -187,10 +202,36 @@ if(MLX_ENGINE)
|
|||||||
# Find CUDA toolkit if MLX is built with CUDA support
|
# Find CUDA toolkit if MLX is built with CUDA support
|
||||||
find_package(CUDAToolkit)
|
find_package(CUDAToolkit)
|
||||||
|
|
||||||
|
# Build list of directories for runtime dependency resolution
|
||||||
|
set(MLX_RUNTIME_DIRS ${CUDAToolkit_BIN_DIR} ${CUDAToolkit_BIN_DIR}/x64 ${CUDAToolkit_LIBRARY_DIR})
|
||||||
|
# Add cuDNN bin paths for DLLs (Windows MLX CUDA builds)
|
||||||
|
# CUDNN_ROOT_DIR is the standard CMake variable for cuDNN location
|
||||||
|
if(DEFINED ENV{CUDNN_ROOT_DIR})
|
||||||
|
# cuDNN 9.x has versioned subdirectories under bin/ (e.g., bin/13.0/)
|
||||||
|
file(GLOB CUDNN_BIN_SUBDIRS "$ENV{CUDNN_ROOT_DIR}/bin/*")
|
||||||
|
list(APPEND MLX_RUNTIME_DIRS ${CUDNN_BIN_SUBDIRS})
|
||||||
|
endif()
|
||||||
|
# Add build output directory and MLX dependency build directories
|
||||||
|
list(APPEND MLX_RUNTIME_DIRS ${OLLAMA_BUILD_DIR})
|
||||||
|
# OpenBLAS DLL location (pre-built zip extracts into openblas-src/bin/)
|
||||||
|
list(APPEND MLX_RUNTIME_DIRS ${CMAKE_BINARY_DIR}/_deps/openblas-src/bin)
|
||||||
|
# NCCL: on Linux, if real NCCL is found, cmake bundles libnccl.so via the
|
||||||
|
# regex below. If NCCL is not found, MLX links a static stub (OBJECT lib)
|
||||||
|
# so there is no runtime dependency. This path covers the stub build dir
|
||||||
|
# for windows so we include the DLL in our dependencies.
|
||||||
|
list(APPEND MLX_RUNTIME_DIRS ${CMAKE_BINARY_DIR}/_deps/mlx-build/mlx/distributed/nccl/nccl_stub-prefix/src/nccl_stub-build/Release)
|
||||||
|
|
||||||
|
# Base regexes for runtime dependencies (cross-platform)
|
||||||
|
set(MLX_INCLUDE_REGEXES cublas cublasLt cudart cufft nvrtc nvrtc-builtins cudnn nccl openblas gfortran)
|
||||||
|
# On Windows, also include dl.dll (dlfcn-win32 POSIX emulation layer)
|
||||||
|
if(WIN32)
|
||||||
|
list(APPEND MLX_INCLUDE_REGEXES "^dl\\.dll$")
|
||||||
|
endif()
|
||||||
|
|
||||||
install(TARGETS mlx mlxc
|
install(TARGETS mlx mlxc
|
||||||
RUNTIME_DEPENDENCIES
|
RUNTIME_DEPENDENCIES
|
||||||
DIRECTORIES ${CUDAToolkit_BIN_DIR} ${CUDAToolkit_BIN_DIR}/x64 ${CUDAToolkit_LIBRARY_DIR}
|
DIRECTORIES ${MLX_RUNTIME_DIRS}
|
||||||
PRE_INCLUDE_REGEXES cublas cublasLt cudart nvrtc nvrtc-builtins cudnn nccl openblas gfortran
|
PRE_INCLUDE_REGEXES ${MLX_INCLUDE_REGEXES}
|
||||||
PRE_EXCLUDE_REGEXES ".*"
|
PRE_EXCLUDE_REGEXES ".*"
|
||||||
RUNTIME DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT MLX
|
RUNTIME DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT MLX
|
||||||
LIBRARY DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT MLX
|
LIBRARY DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT MLX
|
||||||
@@ -205,13 +246,54 @@ if(MLX_ENGINE)
|
|||||||
COMPONENT MLX)
|
COMPONENT MLX)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
# Manually install cudart and cublas since they might not be picked up as direct dependencies
|
# Install CCCL headers for NVRTC JIT compilation at runtime.
|
||||||
|
# MLX's own install rules use the default component so they get skipped by
|
||||||
|
# --component MLX. Headers are installed alongside libmlx in OLLAMA_INSTALL_DIR.
|
||||||
|
# On Linux, MLX's jit_module.cpp resolves CCCL via
|
||||||
|
# current_binary_dir().parent_path() / "include" / "cccl", so we create a
|
||||||
|
# symlink from lib/ollama/include -> ${OLLAMA_RUNNER_DIR}/include
|
||||||
|
# This will need refinement if we add multiple CUDA versions for MLX in the future.
|
||||||
|
if(EXISTS ${CMAKE_BINARY_DIR}/_deps/cccl-src/include/cuda)
|
||||||
|
install(DIRECTORY ${CMAKE_BINARY_DIR}/_deps/cccl-src/include/cuda
|
||||||
|
DESTINATION ${OLLAMA_INSTALL_DIR}/include/cccl
|
||||||
|
COMPONENT MLX)
|
||||||
|
install(DIRECTORY ${CMAKE_BINARY_DIR}/_deps/cccl-src/include/nv
|
||||||
|
DESTINATION ${OLLAMA_INSTALL_DIR}/include/cccl
|
||||||
|
COMPONENT MLX)
|
||||||
|
if(NOT WIN32 AND NOT APPLE)
|
||||||
|
install(CODE "
|
||||||
|
set(_link \"${CMAKE_INSTALL_PREFIX}/lib/ollama/include\")
|
||||||
|
set(_target \"${OLLAMA_RUNNER_DIR}/include\")
|
||||||
|
if(NOT EXISTS \${_link})
|
||||||
|
execute_process(COMMAND \${CMAKE_COMMAND} -E create_symlink \${_target} \${_link})
|
||||||
|
endif()
|
||||||
|
" COMPONENT MLX)
|
||||||
|
endif()
|
||||||
|
endif()
|
||||||
|
|
||||||
|
# On Windows, explicitly install dl.dll (dlfcn-win32 POSIX dlopen emulation)
|
||||||
|
# RUNTIME_DEPENDENCIES auto-excludes it via POST_EXCLUDE_FILES_STRICT because
|
||||||
|
# dlfcn-win32 is a known CMake target with its own install rules (which install
|
||||||
|
# to the wrong destination). We must install it explicitly here.
|
||||||
|
if(WIN32)
|
||||||
|
install(FILES ${OLLAMA_BUILD_DIR}/dl.dll
|
||||||
|
DESTINATION ${OLLAMA_INSTALL_DIR}
|
||||||
|
COMPONENT MLX)
|
||||||
|
endif()
|
||||||
|
|
||||||
|
# Manually install CUDA runtime libraries that MLX loads via dlopen
|
||||||
|
# (not detected by RUNTIME_DEPENDENCIES since they aren't link-time deps)
|
||||||
if(CUDAToolkit_FOUND)
|
if(CUDAToolkit_FOUND)
|
||||||
file(GLOB CUDART_LIBS
|
file(GLOB MLX_CUDA_LIBS
|
||||||
"${CUDAToolkit_LIBRARY_DIR}/libcudart.so*"
|
"${CUDAToolkit_LIBRARY_DIR}/libcudart.so*"
|
||||||
"${CUDAToolkit_LIBRARY_DIR}/libcublas.so*")
|
"${CUDAToolkit_LIBRARY_DIR}/libcublas.so*"
|
||||||
if(CUDART_LIBS)
|
"${CUDAToolkit_LIBRARY_DIR}/libcublasLt.so*"
|
||||||
install(FILES ${CUDART_LIBS}
|
"${CUDAToolkit_LIBRARY_DIR}/libnvrtc.so*"
|
||||||
|
"${CUDAToolkit_LIBRARY_DIR}/libnvrtc-builtins.so*"
|
||||||
|
"${CUDAToolkit_LIBRARY_DIR}/libcufft.so*"
|
||||||
|
"${CUDAToolkit_LIBRARY_DIR}/libcudnn.so*")
|
||||||
|
if(MLX_CUDA_LIBS)
|
||||||
|
install(FILES ${MLX_CUDA_LIBS}
|
||||||
DESTINATION ${OLLAMA_INSTALL_DIR}
|
DESTINATION ${OLLAMA_INSTALL_DIR}
|
||||||
COMPONENT MLX)
|
COMPONENT MLX)
|
||||||
endif()
|
endif()
|
||||||
|
|||||||
@@ -77,6 +77,15 @@
|
|||||||
"OLLAMA_RUNNER_DIR": "rocm"
|
"OLLAMA_RUNNER_DIR": "rocm"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"name": "ROCm 7",
|
||||||
|
"inherits": [ "ROCm" ],
|
||||||
|
"cacheVariables": {
|
||||||
|
"CMAKE_HIP_FLAGS": "-parallel-jobs=4",
|
||||||
|
"AMDGPU_TARGETS": "gfx942;gfx950;gfx1010;gfx1012;gfx1030;gfx1100;gfx1101;gfx1102;gfx1103;gfx1150;gfx1151;gfx1200;gfx1201;gfx908:xnack-;gfx90a:xnack+;gfx90a:xnack-",
|
||||||
|
"OLLAMA_RUNNER_DIR": "rocm"
|
||||||
|
}
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"name": "Vulkan",
|
"name": "Vulkan",
|
||||||
"inherits": [ "Default" ],
|
"inherits": [ "Default" ],
|
||||||
@@ -103,6 +112,7 @@
|
|||||||
"name": "MLX CUDA 13",
|
"name": "MLX CUDA 13",
|
||||||
"inherits": [ "MLX", "CUDA 13" ],
|
"inherits": [ "MLX", "CUDA 13" ],
|
||||||
"cacheVariables": {
|
"cacheVariables": {
|
||||||
|
"MLX_CUDA_ARCHITECTURES": "86;89;90;90a;100;103;75-virtual;80-virtual;110-virtual;120-virtual;121-virtual",
|
||||||
"OLLAMA_RUNNER_DIR": "mlx_cuda_v13"
|
"OLLAMA_RUNNER_DIR": "mlx_cuda_v13"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -158,6 +168,11 @@
|
|||||||
"inherits": [ "ROCm" ],
|
"inherits": [ "ROCm" ],
|
||||||
"configurePreset": "ROCm 6"
|
"configurePreset": "ROCm 6"
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"name": "ROCm 7",
|
||||||
|
"inherits": [ "ROCm" ],
|
||||||
|
"configurePreset": "ROCm 7"
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"name": "Vulkan",
|
"name": "Vulkan",
|
||||||
"targets": [ "ggml-vulkan" ],
|
"targets": [ "ggml-vulkan" ],
|
||||||
|
|||||||
131
Dockerfile
131
Dockerfile
@@ -1,33 +1,23 @@
|
|||||||
# vim: filetype=dockerfile
|
# vim: filetype=dockerfile
|
||||||
|
|
||||||
ARG FLAVOR=${TARGETARCH}
|
ARG FLAVOR=${TARGETARCH}
|
||||||
ARG PARALLEL=8
|
|
||||||
|
|
||||||
ARG ROCMVERSION=6.3.3
|
ARG ROCMVERSION=7.2
|
||||||
ARG JETPACK5VERSION=r35.4.1
|
ARG JETPACK5VERSION=r35.4.1
|
||||||
ARG JETPACK6VERSION=r36.4.0
|
ARG JETPACK6VERSION=r36.4.0
|
||||||
ARG CMAKEVERSION=3.31.2
|
ARG CMAKEVERSION=3.31.2
|
||||||
|
ARG NINJAVERSION=1.12.1
|
||||||
ARG VULKANVERSION=1.4.321.1
|
ARG VULKANVERSION=1.4.321.1
|
||||||
|
|
||||||
# We require gcc v10 minimum. v10.3 has regressions, so the rockylinux 8.5 AppStream has the latest compatible version
|
# Default empty stages for local MLX source overrides.
|
||||||
|
# Override with: docker build --build-context local-mlx=../mlx --build-context local-mlx-c=../mlx-c
|
||||||
|
FROM scratch AS local-mlx
|
||||||
|
FROM scratch AS local-mlx-c
|
||||||
|
|
||||||
FROM --platform=linux/amd64 rocm/dev-almalinux-8:${ROCMVERSION}-complete AS base-amd64
|
FROM --platform=linux/amd64 rocm/dev-almalinux-8:${ROCMVERSION}-complete AS base-amd64
|
||||||
RUN yum install -y yum-utils \
|
RUN dnf install -y yum-utils ccache gcc-toolset-11-gcc gcc-toolset-11-gcc-c++ gcc-toolset-11-binutils \
|
||||||
&& yum-config-manager --add-repo https://dl.rockylinux.org/vault/rocky/8.5/AppStream/\$basearch/os/ \
|
|
||||||
&& rpm --import https://dl.rockylinux.org/pub/rocky/RPM-GPG-KEY-Rocky-8 \
|
|
||||||
&& dnf install -y yum-utils ccache gcc-toolset-10-gcc-10.2.1-8.2.el8 gcc-toolset-10-gcc-c++-10.2.1-8.2.el8 gcc-toolset-10-binutils-2.35-11.el8 \
|
|
||||||
&& dnf install -y ccache \
|
|
||||||
&& yum-config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel8/x86_64/cuda-rhel8.repo
|
&& yum-config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel8/x86_64/cuda-rhel8.repo
|
||||||
ENV PATH=/opt/rh/gcc-toolset-10/root/usr/bin:$PATH
|
ENV PATH=/opt/rh/gcc-toolset-11/root/usr/bin:$PATH
|
||||||
ARG VULKANVERSION
|
|
||||||
RUN wget https://sdk.lunarg.com/sdk/download/${VULKANVERSION}/linux/vulkansdk-linux-x86_64-${VULKANVERSION}.tar.xz -O /tmp/vulkansdk-linux-x86_64-${VULKANVERSION}.tar.xz \
|
|
||||||
&& tar xvf /tmp/vulkansdk-linux-x86_64-${VULKANVERSION}.tar.xz \
|
|
||||||
&& dnf -y install ninja-build \
|
|
||||||
&& ln -s /usr/bin/python3 /usr/bin/python \
|
|
||||||
&& /${VULKANVERSION}/vulkansdk -j 8 vulkan-headers \
|
|
||||||
&& /${VULKANVERSION}/vulkansdk -j 8 shaderc
|
|
||||||
RUN cp -r /${VULKANVERSION}/x86_64/include/* /usr/local/include/ \
|
|
||||||
&& cp -r /${VULKANVERSION}/x86_64/lib/* /usr/local/lib
|
|
||||||
ENV PATH=/${VULKANVERSION}/x86_64/bin:$PATH
|
|
||||||
|
|
||||||
FROM --platform=linux/arm64 almalinux:8 AS base-arm64
|
FROM --platform=linux/arm64 almalinux:8 AS base-arm64
|
||||||
# install epel-release for ccache
|
# install epel-release for ccache
|
||||||
@@ -38,100 +28,119 @@ ENV CC=clang CXX=clang++
|
|||||||
|
|
||||||
FROM base-${TARGETARCH} AS base
|
FROM base-${TARGETARCH} AS base
|
||||||
ARG CMAKEVERSION
|
ARG CMAKEVERSION
|
||||||
|
ARG NINJAVERSION
|
||||||
RUN curl -fsSL https://github.com/Kitware/CMake/releases/download/v${CMAKEVERSION}/cmake-${CMAKEVERSION}-linux-$(uname -m).tar.gz | tar xz -C /usr/local --strip-components 1
|
RUN curl -fsSL https://github.com/Kitware/CMake/releases/download/v${CMAKEVERSION}/cmake-${CMAKEVERSION}-linux-$(uname -m).tar.gz | tar xz -C /usr/local --strip-components 1
|
||||||
|
RUN dnf install -y unzip \
|
||||||
|
&& curl -fsSL -o /tmp/ninja.zip https://github.com/ninja-build/ninja/releases/download/v${NINJAVERSION}/ninja-linux$([ "$(uname -m)" = "aarch64" ] && echo "-aarch64").zip \
|
||||||
|
&& unzip /tmp/ninja.zip -d /usr/local/bin \
|
||||||
|
&& rm /tmp/ninja.zip
|
||||||
|
ENV CMAKE_GENERATOR=Ninja
|
||||||
ENV LDFLAGS=-s
|
ENV LDFLAGS=-s
|
||||||
|
|
||||||
FROM base AS cpu
|
FROM base AS cpu
|
||||||
RUN dnf install -y gcc-toolset-11-gcc gcc-toolset-11-gcc-c++
|
RUN dnf install -y gcc-toolset-11-gcc gcc-toolset-11-gcc-c++
|
||||||
ENV PATH=/opt/rh/gcc-toolset-11/root/usr/bin:$PATH
|
ENV PATH=/opt/rh/gcc-toolset-11/root/usr/bin:$PATH
|
||||||
ARG PARALLEL
|
|
||||||
COPY CMakeLists.txt CMakePresets.json .
|
COPY CMakeLists.txt CMakePresets.json .
|
||||||
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
|
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
|
||||||
RUN --mount=type=cache,target=/root/.ccache \
|
RUN --mount=type=cache,target=/root/.ccache \
|
||||||
cmake --preset 'CPU' \
|
cmake --preset 'CPU' \
|
||||||
&& cmake --build --parallel ${PARALLEL} --preset 'CPU' \
|
&& cmake --build --preset 'CPU' -- -l $(nproc) \
|
||||||
&& cmake --install build --component CPU --strip --parallel ${PARALLEL}
|
&& cmake --install build --component CPU --strip
|
||||||
|
|
||||||
FROM base AS cuda-11
|
FROM base AS cuda-11
|
||||||
ARG CUDA11VERSION=11.8
|
ARG CUDA11VERSION=11.8
|
||||||
RUN dnf install -y cuda-toolkit-${CUDA11VERSION//./-}
|
RUN dnf install -y cuda-toolkit-${CUDA11VERSION//./-}
|
||||||
ENV PATH=/usr/local/cuda-11/bin:$PATH
|
ENV PATH=/usr/local/cuda-11/bin:$PATH
|
||||||
ARG PARALLEL
|
|
||||||
COPY CMakeLists.txt CMakePresets.json .
|
COPY CMakeLists.txt CMakePresets.json .
|
||||||
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
|
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
|
||||||
RUN --mount=type=cache,target=/root/.ccache \
|
RUN --mount=type=cache,target=/root/.ccache \
|
||||||
cmake --preset 'CUDA 11' \
|
cmake --preset 'CUDA 11' \
|
||||||
&& cmake --build --parallel ${PARALLEL} --preset 'CUDA 11' \
|
&& cmake --build --preset 'CUDA 11' -- -l $(nproc) \
|
||||||
&& cmake --install build --component CUDA --strip --parallel ${PARALLEL}
|
&& cmake --install build --component CUDA --strip
|
||||||
|
|
||||||
FROM base AS cuda-12
|
FROM base AS cuda-12
|
||||||
ARG CUDA12VERSION=12.8
|
ARG CUDA12VERSION=12.8
|
||||||
RUN dnf install -y cuda-toolkit-${CUDA12VERSION//./-}
|
RUN dnf install -y cuda-toolkit-${CUDA12VERSION//./-}
|
||||||
ENV PATH=/usr/local/cuda-12/bin:$PATH
|
ENV PATH=/usr/local/cuda-12/bin:$PATH
|
||||||
ARG PARALLEL
|
|
||||||
COPY CMakeLists.txt CMakePresets.json .
|
COPY CMakeLists.txt CMakePresets.json .
|
||||||
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
|
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
|
||||||
RUN --mount=type=cache,target=/root/.ccache \
|
RUN --mount=type=cache,target=/root/.ccache \
|
||||||
cmake --preset 'CUDA 12' \
|
cmake --preset 'CUDA 12' \
|
||||||
&& cmake --build --parallel ${PARALLEL} --preset 'CUDA 12' \
|
&& cmake --build --preset 'CUDA 12' -- -l $(nproc) \
|
||||||
&& cmake --install build --component CUDA --strip --parallel ${PARALLEL}
|
&& cmake --install build --component CUDA --strip
|
||||||
|
|
||||||
|
|
||||||
FROM base AS cuda-13
|
FROM base AS cuda-13
|
||||||
ARG CUDA13VERSION=13.0
|
ARG CUDA13VERSION=13.0
|
||||||
RUN dnf install -y cuda-toolkit-${CUDA13VERSION//./-}
|
RUN dnf install -y cuda-toolkit-${CUDA13VERSION//./-}
|
||||||
ENV PATH=/usr/local/cuda-13/bin:$PATH
|
ENV PATH=/usr/local/cuda-13/bin:$PATH
|
||||||
ARG PARALLEL
|
|
||||||
COPY CMakeLists.txt CMakePresets.json .
|
COPY CMakeLists.txt CMakePresets.json .
|
||||||
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
|
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
|
||||||
RUN --mount=type=cache,target=/root/.ccache \
|
RUN --mount=type=cache,target=/root/.ccache \
|
||||||
cmake --preset 'CUDA 13' \
|
cmake --preset 'CUDA 13' \
|
||||||
&& cmake --build --parallel ${PARALLEL} --preset 'CUDA 13' \
|
&& cmake --build --preset 'CUDA 13' -- -l $(nproc) \
|
||||||
&& cmake --install build --component CUDA --strip --parallel ${PARALLEL}
|
&& cmake --install build --component CUDA --strip
|
||||||
|
|
||||||
|
|
||||||
FROM base AS rocm-6
|
FROM base AS rocm-7
|
||||||
ENV PATH=/opt/rocm/hcc/bin:/opt/rocm/hip/bin:/opt/rocm/bin:/opt/rocm/hcc/bin:$PATH
|
ENV PATH=/opt/rocm/hcc/bin:/opt/rocm/hip/bin:/opt/rocm/bin:/opt/rocm/hcc/bin:$PATH
|
||||||
ARG PARALLEL
|
|
||||||
COPY CMakeLists.txt CMakePresets.json .
|
COPY CMakeLists.txt CMakePresets.json .
|
||||||
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
|
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
|
||||||
RUN --mount=type=cache,target=/root/.ccache \
|
RUN --mount=type=cache,target=/root/.ccache \
|
||||||
cmake --preset 'ROCm 6' \
|
cmake --preset 'ROCm 7' \
|
||||||
&& cmake --build --parallel ${PARALLEL} --preset 'ROCm 6' \
|
&& cmake --build --preset 'ROCm 7' -- -l $(nproc) \
|
||||||
&& cmake --install build --component HIP --strip --parallel ${PARALLEL}
|
&& cmake --install build --component HIP --strip
|
||||||
RUN rm -f dist/lib/ollama/rocm/rocblas/library/*gfx90[06]*
|
RUN rm -f dist/lib/ollama/rocm/rocblas/library/*gfx90[06]*
|
||||||
|
|
||||||
FROM --platform=linux/arm64 nvcr.io/nvidia/l4t-jetpack:${JETPACK5VERSION} AS jetpack-5
|
FROM --platform=linux/arm64 nvcr.io/nvidia/l4t-jetpack:${JETPACK5VERSION} AS jetpack-5
|
||||||
ARG CMAKEVERSION
|
ARG CMAKEVERSION
|
||||||
RUN apt-get update && apt-get install -y curl ccache \
|
ARG NINJAVERSION
|
||||||
&& curl -fsSL https://github.com/Kitware/CMake/releases/download/v${CMAKEVERSION}/cmake-${CMAKEVERSION}-linux-$(uname -m).tar.gz | tar xz -C /usr/local --strip-components 1
|
RUN apt-get update && apt-get install -y curl ccache unzip \
|
||||||
|
&& curl -fsSL https://github.com/Kitware/CMake/releases/download/v${CMAKEVERSION}/cmake-${CMAKEVERSION}-linux-$(uname -m).tar.gz | tar xz -C /usr/local --strip-components 1 \
|
||||||
|
&& curl -fsSL -o /tmp/ninja.zip https://github.com/ninja-build/ninja/releases/download/v${NINJAVERSION}/ninja-linux-aarch64.zip \
|
||||||
|
&& unzip /tmp/ninja.zip -d /usr/local/bin \
|
||||||
|
&& rm /tmp/ninja.zip
|
||||||
|
ENV CMAKE_GENERATOR=Ninja
|
||||||
COPY CMakeLists.txt CMakePresets.json .
|
COPY CMakeLists.txt CMakePresets.json .
|
||||||
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
|
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
|
||||||
ARG PARALLEL
|
|
||||||
RUN --mount=type=cache,target=/root/.ccache \
|
RUN --mount=type=cache,target=/root/.ccache \
|
||||||
cmake --preset 'JetPack 5' \
|
cmake --preset 'JetPack 5' \
|
||||||
&& cmake --build --parallel ${PARALLEL} --preset 'JetPack 5' \
|
&& cmake --build --preset 'JetPack 5' -- -l $(nproc) \
|
||||||
&& cmake --install build --component CUDA --strip --parallel ${PARALLEL}
|
&& cmake --install build --component CUDA --strip
|
||||||
|
|
||||||
FROM --platform=linux/arm64 nvcr.io/nvidia/l4t-jetpack:${JETPACK6VERSION} AS jetpack-6
|
FROM --platform=linux/arm64 nvcr.io/nvidia/l4t-jetpack:${JETPACK6VERSION} AS jetpack-6
|
||||||
ARG CMAKEVERSION
|
ARG CMAKEVERSION
|
||||||
RUN apt-get update && apt-get install -y curl ccache \
|
ARG NINJAVERSION
|
||||||
&& curl -fsSL https://github.com/Kitware/CMake/releases/download/v${CMAKEVERSION}/cmake-${CMAKEVERSION}-linux-$(uname -m).tar.gz | tar xz -C /usr/local --strip-components 1
|
RUN apt-get update && apt-get install -y curl ccache unzip \
|
||||||
|
&& curl -fsSL https://github.com/Kitware/CMake/releases/download/v${CMAKEVERSION}/cmake-${CMAKEVERSION}-linux-$(uname -m).tar.gz | tar xz -C /usr/local --strip-components 1 \
|
||||||
|
&& curl -fsSL -o /tmp/ninja.zip https://github.com/ninja-build/ninja/releases/download/v${NINJAVERSION}/ninja-linux-aarch64.zip \
|
||||||
|
&& unzip /tmp/ninja.zip -d /usr/local/bin \
|
||||||
|
&& rm /tmp/ninja.zip
|
||||||
|
ENV CMAKE_GENERATOR=Ninja
|
||||||
COPY CMakeLists.txt CMakePresets.json .
|
COPY CMakeLists.txt CMakePresets.json .
|
||||||
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
|
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
|
||||||
ARG PARALLEL
|
|
||||||
RUN --mount=type=cache,target=/root/.ccache \
|
RUN --mount=type=cache,target=/root/.ccache \
|
||||||
cmake --preset 'JetPack 6' \
|
cmake --preset 'JetPack 6' \
|
||||||
&& cmake --build --parallel ${PARALLEL} --preset 'JetPack 6' \
|
&& cmake --build --preset 'JetPack 6' -- -l $(nproc) \
|
||||||
&& cmake --install build --component CUDA --strip --parallel ${PARALLEL}
|
&& cmake --install build --component CUDA --strip
|
||||||
|
|
||||||
FROM base AS vulkan
|
FROM base AS vulkan
|
||||||
|
ARG VULKANVERSION
|
||||||
|
RUN ln -s /usr/bin/python3 /usr/bin/python \
|
||||||
|
&& wget https://sdk.lunarg.com/sdk/download/${VULKANVERSION}/linux/vulkansdk-linux-x86_64-${VULKANVERSION}.tar.xz -O /tmp/vulkansdk.tar.xz \
|
||||||
|
&& tar xvf /tmp/vulkansdk.tar.xz -C /tmp \
|
||||||
|
&& /tmp/${VULKANVERSION}/vulkansdk -j 8 vulkan-headers \
|
||||||
|
&& /tmp/${VULKANVERSION}/vulkansdk -j 8 shaderc \
|
||||||
|
&& cp -r /tmp/${VULKANVERSION}/x86_64/include/* /usr/local/include/ \
|
||||||
|
&& cp -r /tmp/${VULKANVERSION}/x86_64/lib/* /usr/local/lib \
|
||||||
|
&& cp -r /tmp/${VULKANVERSION}/x86_64/bin/* /usr/local/bin/ \
|
||||||
|
&& rm -rf /tmp/${VULKANVERSION} /tmp/vulkansdk.tar.xz
|
||||||
COPY CMakeLists.txt CMakePresets.json .
|
COPY CMakeLists.txt CMakePresets.json .
|
||||||
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
|
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
|
||||||
RUN --mount=type=cache,target=/root/.ccache \
|
RUN --mount=type=cache,target=/root/.ccache \
|
||||||
cmake --preset 'Vulkan' \
|
cmake --preset 'Vulkan' \
|
||||||
&& cmake --build --parallel --preset 'Vulkan' \
|
&& cmake --build --preset 'Vulkan' -- -l $(nproc) \
|
||||||
&& cmake --install build --component Vulkan --strip --parallel 8
|
&& cmake --install build --component Vulkan --strip
|
||||||
|
|
||||||
FROM base AS mlx
|
FROM base AS mlx
|
||||||
ARG CUDA13VERSION=13.0
|
ARG CUDA13VERSION=13.0
|
||||||
@@ -143,20 +152,27 @@ ENV PATH=/usr/local/cuda-13/bin:$PATH
|
|||||||
ENV BLAS_INCLUDE_DIRS=/usr/include/openblas
|
ENV BLAS_INCLUDE_DIRS=/usr/include/openblas
|
||||||
ENV LAPACK_INCLUDE_DIRS=/usr/include/openblas
|
ENV LAPACK_INCLUDE_DIRS=/usr/include/openblas
|
||||||
ENV CGO_LDFLAGS="-L/usr/local/cuda-13/lib64 -L/usr/local/cuda-13/targets/x86_64-linux/lib/stubs"
|
ENV CGO_LDFLAGS="-L/usr/local/cuda-13/lib64 -L/usr/local/cuda-13/targets/x86_64-linux/lib/stubs"
|
||||||
ARG PARALLEL
|
|
||||||
WORKDIR /go/src/github.com/ollama/ollama
|
WORKDIR /go/src/github.com/ollama/ollama
|
||||||
COPY CMakeLists.txt CMakePresets.json .
|
COPY CMakeLists.txt CMakePresets.json .
|
||||||
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
|
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
|
||||||
COPY x/imagegen/mlx x/imagegen/mlx
|
COPY x/imagegen/mlx x/imagegen/mlx
|
||||||
COPY go.mod go.sum .
|
COPY go.mod go.sum .
|
||||||
COPY MLX_VERSION .
|
COPY MLX_VERSION MLX_CORE_VERSION .
|
||||||
RUN curl -fsSL https://golang.org/dl/go$(awk '/^go/ { print $2 }' go.mod).linux-$(case $(uname -m) in x86_64) echo amd64 ;; aarch64) echo arm64 ;; esac).tar.gz | tar xz -C /usr/local
|
RUN curl -fsSL https://golang.org/dl/go$(awk '/^go/ { print $2 }' go.mod).linux-$(case $(uname -m) in x86_64) echo amd64 ;; aarch64) echo arm64 ;; esac).tar.gz | tar xz -C /usr/local
|
||||||
ENV PATH=/usr/local/go/bin:$PATH
|
ENV PATH=/usr/local/go/bin:$PATH
|
||||||
RUN go mod download
|
RUN go mod download
|
||||||
RUN --mount=type=cache,target=/root/.ccache \
|
RUN --mount=type=cache,target=/root/.ccache \
|
||||||
cmake --preset 'MLX CUDA 13' -DBLAS_INCLUDE_DIRS=/usr/include/openblas -DLAPACK_INCLUDE_DIRS=/usr/include/openblas \
|
--mount=type=bind,from=local-mlx,target=/tmp/local-mlx \
|
||||||
&& cmake --build --parallel ${PARALLEL} --preset 'MLX CUDA 13' \
|
--mount=type=bind,from=local-mlx-c,target=/tmp/local-mlx-c \
|
||||||
&& cmake --install build --component MLX --strip --parallel ${PARALLEL}
|
if [ -f /tmp/local-mlx/CMakeLists.txt ]; then \
|
||||||
|
export OLLAMA_MLX_SOURCE=/tmp/local-mlx; \
|
||||||
|
fi \
|
||||||
|
&& if [ -f /tmp/local-mlx-c/CMakeLists.txt ]; then \
|
||||||
|
export OLLAMA_MLX_C_SOURCE=/tmp/local-mlx-c; \
|
||||||
|
fi \
|
||||||
|
&& cmake --preset 'MLX CUDA 13' -DBLAS_INCLUDE_DIRS=/usr/include/openblas -DLAPACK_INCLUDE_DIRS=/usr/include/openblas \
|
||||||
|
&& cmake --build --preset 'MLX CUDA 13' -- -l $(nproc) \
|
||||||
|
&& cmake --install build --component MLX --strip
|
||||||
|
|
||||||
FROM base AS build
|
FROM base AS build
|
||||||
WORKDIR /go/src/github.com/ollama/ollama
|
WORKDIR /go/src/github.com/ollama/ollama
|
||||||
@@ -165,16 +181,14 @@ RUN curl -fsSL https://golang.org/dl/go$(awk '/^go/ { print $2 }' go.mod).linux-
|
|||||||
ENV PATH=/usr/local/go/bin:$PATH
|
ENV PATH=/usr/local/go/bin:$PATH
|
||||||
RUN go mod download
|
RUN go mod download
|
||||||
COPY . .
|
COPY . .
|
||||||
# Clone mlx-c headers for CGO (version from MLX_VERSION file)
|
|
||||||
RUN git clone --depth 1 --branch "$(cat MLX_VERSION)" https://github.com/ml-explore/mlx-c.git build/_deps/mlx-c-src
|
|
||||||
ARG GOFLAGS="'-ldflags=-w -s'"
|
ARG GOFLAGS="'-ldflags=-w -s'"
|
||||||
ENV CGO_ENABLED=1
|
ENV CGO_ENABLED=1
|
||||||
ARG CGO_CFLAGS
|
ARG CGO_CFLAGS
|
||||||
ARG CGO_CXXFLAGS
|
ARG CGO_CXXFLAGS
|
||||||
ENV CGO_CFLAGS="${CGO_CFLAGS} -I/go/src/github.com/ollama/ollama/build/_deps/mlx-c-src"
|
ENV CGO_CFLAGS="${CGO_CFLAGS}"
|
||||||
ENV CGO_CXXFLAGS="${CGO_CXXFLAGS}"
|
ENV CGO_CXXFLAGS="${CGO_CXXFLAGS}"
|
||||||
RUN --mount=type=cache,target=/root/.cache/go-build \
|
RUN --mount=type=cache,target=/root/.cache/go-build \
|
||||||
go build -tags mlx -trimpath -buildmode=pie -o /bin/ollama .
|
go build -trimpath -buildmode=pie -o /bin/ollama .
|
||||||
|
|
||||||
FROM --platform=linux/amd64 scratch AS amd64
|
FROM --platform=linux/amd64 scratch AS amd64
|
||||||
# COPY --from=cuda-11 dist/lib/ollama/ /lib/ollama/
|
# COPY --from=cuda-11 dist/lib/ollama/ /lib/ollama/
|
||||||
@@ -191,10 +205,9 @@ COPY --from=jetpack-5 dist/lib/ollama/ /lib/ollama/
|
|||||||
COPY --from=jetpack-6 dist/lib/ollama/ /lib/ollama/
|
COPY --from=jetpack-6 dist/lib/ollama/ /lib/ollama/
|
||||||
|
|
||||||
FROM scratch AS rocm
|
FROM scratch AS rocm
|
||||||
COPY --from=rocm-6 dist/lib/ollama /lib/ollama
|
COPY --from=rocm-7 dist/lib/ollama /lib/ollama
|
||||||
|
|
||||||
FROM ${FLAVOR} AS archive
|
FROM ${FLAVOR} AS archive
|
||||||
ARG VULKANVERSION
|
|
||||||
COPY --from=cpu dist/lib/ollama /lib/ollama
|
COPY --from=cpu dist/lib/ollama /lib/ollama
|
||||||
COPY --from=build /bin/ollama /bin/ollama
|
COPY --from=build /bin/ollama /bin/ollama
|
||||||
|
|
||||||
|
|||||||
1
MLX_CORE_VERSION
Normal file
1
MLX_CORE_VERSION
Normal file
@@ -0,0 +1 @@
|
|||||||
|
v0.30.6
|
||||||
@@ -1 +1 @@
|
|||||||
v0.4.1
|
v0.5.0
|
||||||
|
|||||||
@@ -852,6 +852,19 @@ func (c *StreamConverter) Process(r api.ChatResponse) []StreamEvent {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Close thinking block if still open (thinking → tool_use without text in between)
|
||||||
|
if c.thinkingStarted && !c.thinkingDone {
|
||||||
|
c.thinkingDone = true
|
||||||
|
events = append(events, StreamEvent{
|
||||||
|
Event: "content_block_stop",
|
||||||
|
Data: ContentBlockStopEvent{
|
||||||
|
Type: "content_block_stop",
|
||||||
|
Index: c.contentIndex,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
c.contentIndex++
|
||||||
|
}
|
||||||
|
|
||||||
if c.textStarted {
|
if c.textStarted {
|
||||||
events = append(events, StreamEvent{
|
events = append(events, StreamEvent{
|
||||||
Event: "content_block_stop",
|
Event: "content_block_stop",
|
||||||
|
|||||||
@@ -799,6 +799,107 @@ func TestStreamConverter_WithToolCalls(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestStreamConverter_ThinkingDirectlyFollowedByToolCall verifies that when a
|
||||||
|
// model emits a thinking block followed directly by a tool_use block (with no
|
||||||
|
// text block in between), the streaming converter correctly closes the thinking
|
||||||
|
// block and increments the content index before opening the tool_use block.
|
||||||
|
// Previously, the converter reused contentIndex=0 for the tool_use block,
|
||||||
|
// which caused "Content block not found" errors in clients. See #14816.
|
||||||
|
func TestStreamConverter_ThinkingDirectlyFollowedByToolCall(t *testing.T) {
|
||||||
|
conv := NewStreamConverter("msg_123", "test-model", 0)
|
||||||
|
|
||||||
|
// First chunk: thinking content (no text)
|
||||||
|
resp1 := api.ChatResponse{
|
||||||
|
Model: "test-model",
|
||||||
|
Message: api.Message{
|
||||||
|
Role: "assistant",
|
||||||
|
Thinking: "I should call the tool.",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
events1 := conv.Process(resp1)
|
||||||
|
|
||||||
|
// Should have: message_start, content_block_start(thinking), content_block_delta(thinking)
|
||||||
|
if len(events1) < 3 {
|
||||||
|
t.Fatalf("expected at least 3 events for thinking chunk, got %d", len(events1))
|
||||||
|
}
|
||||||
|
if events1[0].Event != "message_start" {
|
||||||
|
t.Errorf("expected first event 'message_start', got %q", events1[0].Event)
|
||||||
|
}
|
||||||
|
thinkingStart, ok := events1[1].Data.(ContentBlockStartEvent)
|
||||||
|
if !ok || thinkingStart.ContentBlock.Type != "thinking" {
|
||||||
|
t.Errorf("expected content_block_start(thinking) as second event, got %+v", events1[1])
|
||||||
|
}
|
||||||
|
if thinkingStart.Index != 0 {
|
||||||
|
t.Errorf("expected thinking block at index 0, got %d", thinkingStart.Index)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Second chunk: tool call (no text between thinking and tool)
|
||||||
|
resp2 := api.ChatResponse{
|
||||||
|
Model: "test-model",
|
||||||
|
Message: api.Message{
|
||||||
|
Role: "assistant",
|
||||||
|
ToolCalls: []api.ToolCall{
|
||||||
|
{
|
||||||
|
ID: "call_abc",
|
||||||
|
Function: api.ToolCallFunction{
|
||||||
|
Name: "ask_user",
|
||||||
|
Arguments: testArgs(map[string]any{"question": "cats or dogs?"}),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Done: true,
|
||||||
|
DoneReason: "stop",
|
||||||
|
Metrics: api.Metrics{PromptEvalCount: 10, EvalCount: 5},
|
||||||
|
}
|
||||||
|
events2 := conv.Process(resp2)
|
||||||
|
|
||||||
|
// Expect: content_block_stop(index=0), content_block_start(tool_use, index=1),
|
||||||
|
// content_block_delta(input_json_delta, index=1), content_block_stop(index=1),
|
||||||
|
// message_delta, message_stop
|
||||||
|
var thinkingStop, toolStart, toolDelta, toolStop *StreamEvent
|
||||||
|
for i := range events2 {
|
||||||
|
e := &events2[i]
|
||||||
|
switch e.Event {
|
||||||
|
case "content_block_stop":
|
||||||
|
if stop, ok := e.Data.(ContentBlockStopEvent); ok {
|
||||||
|
if stop.Index == 0 && thinkingStop == nil {
|
||||||
|
thinkingStop = e
|
||||||
|
} else if stop.Index == 1 {
|
||||||
|
toolStop = e
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case "content_block_start":
|
||||||
|
if start, ok := e.Data.(ContentBlockStartEvent); ok && start.ContentBlock.Type == "tool_use" {
|
||||||
|
toolStart = e
|
||||||
|
}
|
||||||
|
case "content_block_delta":
|
||||||
|
if delta, ok := e.Data.(ContentBlockDeltaEvent); ok && delta.Delta.Type == "input_json_delta" {
|
||||||
|
toolDelta = e
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if thinkingStop == nil {
|
||||||
|
t.Error("expected content_block_stop for thinking block (index 0)")
|
||||||
|
}
|
||||||
|
if toolStart == nil {
|
||||||
|
t.Fatal("expected content_block_start for tool_use block")
|
||||||
|
}
|
||||||
|
if start, ok := toolStart.Data.(ContentBlockStartEvent); !ok || start.Index != 1 {
|
||||||
|
t.Errorf("expected tool_use block at index 1, got %+v", toolStart.Data)
|
||||||
|
}
|
||||||
|
if toolDelta == nil {
|
||||||
|
t.Fatal("expected input_json_delta event for tool call")
|
||||||
|
}
|
||||||
|
if delta, ok := toolDelta.Data.(ContentBlockDeltaEvent); !ok || delta.Index != 1 {
|
||||||
|
t.Errorf("expected tool delta at index 1, got %+v", toolDelta.Data)
|
||||||
|
}
|
||||||
|
if toolStop == nil {
|
||||||
|
t.Error("expected content_block_stop for tool_use block (index 1)")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestStreamConverter_ToolCallWithUnmarshalableArgs(t *testing.T) {
|
func TestStreamConverter_ToolCallWithUnmarshalableArgs(t *testing.T) {
|
||||||
// Test that unmarshalable arguments (like channels) are handled gracefully
|
// Test that unmarshalable arguments (like channels) are handled gracefully
|
||||||
// and don't cause a panic or corrupt stream
|
// and don't cause a panic or corrupt stream
|
||||||
|
|||||||
@@ -476,25 +476,3 @@ func (c *Client) Whoami(ctx context.Context) (*UserResponse, error) {
|
|||||||
}
|
}
|
||||||
return &resp, nil
|
return &resp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// AliasRequest is the request body for creating or updating a model alias.
|
|
||||||
type AliasRequest struct {
|
|
||||||
Alias string `json:"alias"`
|
|
||||||
Target string `json:"target"`
|
|
||||||
PrefixMatching bool `json:"prefix_matching,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetAliasExperimental creates or updates a model alias via the experimental aliases API.
|
|
||||||
func (c *Client) SetAliasExperimental(ctx context.Context, req *AliasRequest) error {
|
|
||||||
return c.do(ctx, http.MethodPost, "/api/experimental/aliases", req, nil)
|
|
||||||
}
|
|
||||||
|
|
||||||
// AliasDeleteRequest is the request body for deleting a model alias.
|
|
||||||
type AliasDeleteRequest struct {
|
|
||||||
Alias string `json:"alias"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// DeleteAliasExperimental deletes a model alias via the experimental aliases API.
|
|
||||||
func (c *Client) DeleteAliasExperimental(ctx context.Context, req *AliasDeleteRequest) error {
|
|
||||||
return c.do(ctx, http.MethodDelete, "/api/experimental/aliases", req, nil)
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -35,6 +35,7 @@ import (
|
|||||||
var (
|
var (
|
||||||
wv = &Webview{}
|
wv = &Webview{}
|
||||||
uiServerPort int
|
uiServerPort int
|
||||||
|
appStore *store.Store
|
||||||
)
|
)
|
||||||
|
|
||||||
var debug = strings.EqualFold(os.Getenv("OLLAMA_DEBUG"), "true") || os.Getenv("OLLAMA_DEBUG") == "1"
|
var debug = strings.EqualFold(os.Getenv("OLLAMA_DEBUG"), "true") || os.Getenv("OLLAMA_DEBUG") == "1"
|
||||||
@@ -208,6 +209,7 @@ func main() {
|
|||||||
uiServerPort = port
|
uiServerPort = port
|
||||||
|
|
||||||
st := &store.Store{}
|
st := &store.Store{}
|
||||||
|
appStore = st
|
||||||
|
|
||||||
// Enable CORS in development mode
|
// Enable CORS in development mode
|
||||||
if devMode {
|
if devMode {
|
||||||
@@ -253,6 +255,8 @@ func main() {
|
|||||||
done <- osrv.Run(octx)
|
done <- osrv.Run(octx)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
upd := &updater.Updater{Store: st}
|
||||||
|
|
||||||
uiServer := ui.Server{
|
uiServer := ui.Server{
|
||||||
Token: token,
|
Token: token,
|
||||||
Restart: func() {
|
Restart: func() {
|
||||||
@@ -267,6 +271,10 @@ func main() {
|
|||||||
ToolRegistry: toolRegistry,
|
ToolRegistry: toolRegistry,
|
||||||
Dev: devMode,
|
Dev: devMode,
|
||||||
Logger: slog.Default(),
|
Logger: slog.Default(),
|
||||||
|
Updater: upd,
|
||||||
|
UpdateAvailableFunc: func() {
|
||||||
|
UpdateAvailable("")
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
srv := &http.Server{
|
srv := &http.Server{
|
||||||
@@ -284,8 +292,20 @@ func main() {
|
|||||||
slog.Debug("background desktop server done")
|
slog.Debug("background desktop server done")
|
||||||
}()
|
}()
|
||||||
|
|
||||||
updater := &updater.Updater{Store: st}
|
upd.StartBackgroundUpdaterChecker(ctx, UpdateAvailable)
|
||||||
updater.StartBackgroundUpdaterChecker(ctx, UpdateAvailable)
|
|
||||||
|
// Check for pending updates on startup (show tray notification if update is ready)
|
||||||
|
if updater.IsUpdatePending() {
|
||||||
|
// On Windows, the tray is initialized in osRun(). Calling UpdateAvailable
|
||||||
|
// before that would dereference a nil tray callback.
|
||||||
|
// TODO: refactor so the update check runs after platform init on all platforms.
|
||||||
|
if runtime.GOOS == "windows" {
|
||||||
|
slog.Debug("update pending on startup, deferring tray notification until tray initialization")
|
||||||
|
} else {
|
||||||
|
slog.Debug("update pending on startup, showing tray notification")
|
||||||
|
UpdateAvailable("")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
hasCompletedFirstRun, err := st.HasCompletedFirstRun()
|
hasCompletedFirstRun, err := st.HasCompletedFirstRun()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -348,6 +368,17 @@ func startHiddenTasks() {
|
|||||||
// CLI triggered app startup use-case
|
// CLI triggered app startup use-case
|
||||||
slog.Info("deferring pending update for fast startup")
|
slog.Info("deferring pending update for fast startup")
|
||||||
} else {
|
} else {
|
||||||
|
// Check if auto-update is enabled before automatically upgrading
|
||||||
|
settings, err := appStore.Settings()
|
||||||
|
if err != nil {
|
||||||
|
slog.Warn("failed to load settings for upgrade check", "error", err)
|
||||||
|
} else if !settings.AutoUpdateEnabled {
|
||||||
|
slog.Info("auto-update disabled, skipping automatic upgrade at startup")
|
||||||
|
// Still show tray notification so user knows update is ready
|
||||||
|
UpdateAvailable("")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
if err := updater.DoUpgradeAtStartup(); err != nil {
|
if err := updater.DoUpgradeAtStartup(); err != nil {
|
||||||
slog.Info("unable to perform upgrade at startup", "error", err)
|
slog.Info("unable to perform upgrade at startup", "error", err)
|
||||||
// Make sure the restart to upgrade menu shows so we can attempt an interactive upgrade to get authorization
|
// Make sure the restart to upgrade menu shows so we can attempt an interactive upgrade to get authorization
|
||||||
|
|||||||
@@ -154,6 +154,10 @@ func handleURLSchemeRequest(urlScheme string) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func UpdateAvailable(ver string) error {
|
func UpdateAvailable(ver string) error {
|
||||||
|
if app.t == nil {
|
||||||
|
slog.Debug("tray not yet initialized, skipping update notification")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
return app.t.UpdateAvailable(ver)
|
return app.t.UpdateAvailable(ver)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -165,6 +169,14 @@ func osRun(shutdown func(), hasCompletedFirstRun, startHidden bool) {
|
|||||||
log.Fatalf("Failed to start: %s", err)
|
log.Fatalf("Failed to start: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Check for pending updates now that the tray is initialized.
|
||||||
|
// The platform-independent check in app.go fires before osRun,
|
||||||
|
// when app.t is still nil, so we must re-check here.
|
||||||
|
if updater.IsUpdatePending() {
|
||||||
|
slog.Debug("update pending on startup, showing tray notification")
|
||||||
|
UpdateAvailable("")
|
||||||
|
}
|
||||||
|
|
||||||
signals := make(chan os.Signal, 1)
|
signals := make(chan os.Signal, 1)
|
||||||
signal.Notify(signals, syscall.SIGINT, syscall.SIGTERM)
|
signal.Notify(signals, syscall.SIGINT, syscall.SIGTERM)
|
||||||
|
|
||||||
|
|||||||
@@ -41,6 +41,11 @@ type InferenceCompute struct {
|
|||||||
VRAM string
|
VRAM string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type InferenceInfo struct {
|
||||||
|
Computes []InferenceCompute
|
||||||
|
DefaultContextLength int
|
||||||
|
}
|
||||||
|
|
||||||
func New(s *store.Store, devMode bool) *Server {
|
func New(s *store.Store, devMode bool) *Server {
|
||||||
p := resolvePath("ollama")
|
p := resolvePath("ollama")
|
||||||
return &Server{store: s, bin: p, dev: devMode}
|
return &Server{store: s, bin: p, dev: devMode}
|
||||||
@@ -272,9 +277,12 @@ func openRotatingLog() (io.WriteCloser, error) {
|
|||||||
|
|
||||||
// Attempt to retrieve inference compute information from the server
|
// Attempt to retrieve inference compute information from the server
|
||||||
// log. Set ctx to timeout to control how long to wait for the logs to appear
|
// log. Set ctx to timeout to control how long to wait for the logs to appear
|
||||||
func GetInferenceComputer(ctx context.Context) ([]InferenceCompute, error) {
|
func GetInferenceInfo(ctx context.Context) (*InferenceInfo, error) {
|
||||||
inference := []InferenceCompute{}
|
info := &InferenceInfo{}
|
||||||
marker := regexp.MustCompile(`inference compute.*library=`)
|
computeMarker := regexp.MustCompile(`inference compute.*library=`)
|
||||||
|
defaultCtxMarker := regexp.MustCompile(`vram-based default context`)
|
||||||
|
defaultCtxRegex := regexp.MustCompile(`default_num_ctx=(\d+)`)
|
||||||
|
|
||||||
q := `inference compute.*%s=["]([^"]*)["]`
|
q := `inference compute.*%s=["]([^"]*)["]`
|
||||||
nq := `inference compute.*%s=(\S+)\s`
|
nq := `inference compute.*%s=(\S+)\s`
|
||||||
type regex struct {
|
type regex struct {
|
||||||
@@ -340,8 +348,8 @@ func GetInferenceComputer(ctx context.Context) ([]InferenceCompute, error) {
|
|||||||
scanner := bufio.NewScanner(file)
|
scanner := bufio.NewScanner(file)
|
||||||
for scanner.Scan() {
|
for scanner.Scan() {
|
||||||
line := scanner.Text()
|
line := scanner.Text()
|
||||||
match := marker.FindStringSubmatch(line)
|
// Check for inference compute lines
|
||||||
if len(match) > 0 {
|
if computeMarker.MatchString(line) {
|
||||||
ic := InferenceCompute{
|
ic := InferenceCompute{
|
||||||
Library: get("library", line),
|
Library: get("library", line),
|
||||||
Variant: get("variant", line),
|
Variant: get("variant", line),
|
||||||
@@ -352,12 +360,25 @@ func GetInferenceComputer(ctx context.Context) ([]InferenceCompute, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
slog.Info("Matched", "inference compute", ic)
|
slog.Info("Matched", "inference compute", ic)
|
||||||
inference = append(inference, ic)
|
info.Computes = append(info.Computes, ic)
|
||||||
} else {
|
continue
|
||||||
// Break out on first non matching line after we start matching
|
|
||||||
if len(inference) > 0 {
|
|
||||||
return inference, nil
|
|
||||||
}
|
}
|
||||||
|
// Check for default context length line
|
||||||
|
if defaultCtxMarker.MatchString(line) {
|
||||||
|
match := defaultCtxRegex.FindStringSubmatch(line)
|
||||||
|
if len(match) > 1 {
|
||||||
|
numCtx, err := strconv.Atoi(match[1])
|
||||||
|
if err == nil {
|
||||||
|
info.DefaultContextLength = numCtx
|
||||||
|
slog.Info("Matched default context length", "default_num_ctx", numCtx)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return info, nil
|
||||||
|
}
|
||||||
|
// If we've found compute info but hit a non-matching line, return what we have
|
||||||
|
// This handles older server versions that don't log the default context line
|
||||||
|
if len(info.Computes) > 0 {
|
||||||
|
return info, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
time.Sleep(100 * time.Millisecond)
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
|||||||
@@ -205,44 +205,50 @@ func TestServerCmdCloudSettingEnv(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestGetInferenceComputer(t *testing.T) {
|
func TestGetInferenceInfo(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
log string
|
log string
|
||||||
exp []InferenceCompute
|
expComputes []InferenceCompute
|
||||||
|
expDefaultCtxLen int
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "metal",
|
name: "metal",
|
||||||
log: `time=2025-06-30T09:23:07.374-07:00 level=DEBUG source=sched.go:108 msg="starting llm scheduler"
|
log: `time=2025-06-30T09:23:07.374-07:00 level=DEBUG source=sched.go:108 msg="starting llm scheduler"
|
||||||
time=2025-06-30T09:23:07.416-07:00 level=INFO source=types.go:130 msg="inference compute" id=0 library=metal variant="" compute="" driver=0.0 name="" total="96.0 GiB" available="96.0 GiB"
|
time=2025-06-30T09:23:07.416-07:00 level=INFO source=types.go:130 msg="inference compute" id=0 library=metal variant="" compute="" driver=0.0 name="" total="96.0 GiB" available="96.0 GiB"
|
||||||
|
time=2025-06-30T09:23:07.417-07:00 level=INFO source=routes.go:1721 msg="vram-based default context" total_vram="96.0 GiB" default_num_ctx=262144
|
||||||
time=2025-06-30T09:25:56.197-07:00 level=DEBUG source=ggml.go:155 msg="key not found" key=general.alignment default=32
|
time=2025-06-30T09:25:56.197-07:00 level=DEBUG source=ggml.go:155 msg="key not found" key=general.alignment default=32
|
||||||
`,
|
`,
|
||||||
exp: []InferenceCompute{{
|
expComputes: []InferenceCompute{{
|
||||||
Library: "metal",
|
Library: "metal",
|
||||||
Driver: "0.0",
|
Driver: "0.0",
|
||||||
VRAM: "96.0 GiB",
|
VRAM: "96.0 GiB",
|
||||||
}},
|
}},
|
||||||
|
expDefaultCtxLen: 262144,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "cpu",
|
name: "cpu",
|
||||||
log: `time=2025-07-01T17:59:51.470Z level=INFO source=gpu.go:377 msg="no compatible GPUs were discovered"
|
log: `time=2025-07-01T17:59:51.470Z level=INFO source=gpu.go:377 msg="no compatible GPUs were discovered"
|
||||||
time=2025-07-01T17:59:51.470Z level=INFO source=types.go:130 msg="inference compute" id=0 library=cpu variant="" compute="" driver=0.0 name="" total="31.3 GiB" available="30.4 GiB"
|
time=2025-07-01T17:59:51.470Z level=INFO source=types.go:130 msg="inference compute" id=0 library=cpu variant="" compute="" driver=0.0 name="" total="31.3 GiB" available="30.4 GiB"
|
||||||
|
time=2025-07-01T17:59:51.471Z level=INFO source=routes.go:1721 msg="vram-based default context" total_vram="31.3 GiB" default_num_ctx=32768
|
||||||
[GIN] 2025/07/01 - 18:00:09 | 200 | 50.263µs | 100.126.204.152 | HEAD "/"
|
[GIN] 2025/07/01 - 18:00:09 | 200 | 50.263µs | 100.126.204.152 | HEAD "/"
|
||||||
`,
|
`,
|
||||||
exp: []InferenceCompute{{
|
expComputes: []InferenceCompute{{
|
||||||
Library: "cpu",
|
Library: "cpu",
|
||||||
Driver: "0.0",
|
Driver: "0.0",
|
||||||
VRAM: "31.3 GiB",
|
VRAM: "31.3 GiB",
|
||||||
}},
|
}},
|
||||||
|
expDefaultCtxLen: 32768,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "cuda1",
|
name: "cuda1",
|
||||||
log: `time=2025-07-01T19:33:43.162Z level=DEBUG source=amd_linux.go:419 msg="amdgpu driver not detected /sys/module/amdgpu"
|
log: `time=2025-07-01T19:33:43.162Z level=DEBUG source=amd_linux.go:419 msg="amdgpu driver not detected /sys/module/amdgpu"
|
||||||
releasing cuda driver library
|
releasing cuda driver library
|
||||||
time=2025-07-01T19:33:43.162Z level=INFO source=types.go:130 msg="inference compute" id=GPU-452cac9f-6960-839c-4fb3-0cec83699196 library=cuda variant=v12 compute=6.1 driver=12.7 name="NVIDIA GeForce GT 1030" total="3.9 GiB" available="3.9 GiB"
|
time=2025-07-01T19:33:43.162Z level=INFO source=types.go:130 msg="inference compute" id=GPU-452cac9f-6960-839c-4fb3-0cec83699196 library=cuda variant=v12 compute=6.1 driver=12.7 name="NVIDIA GeForce GT 1030" total="3.9 GiB" available="3.9 GiB"
|
||||||
|
time=2025-07-01T19:33:43.163Z level=INFO source=routes.go:1721 msg="vram-based default context" total_vram="3.9 GiB" default_num_ctx=4096
|
||||||
[GIN] 2025/07/01 - 18:00:09 | 200 | 50.263µs | 100.126.204.152 | HEAD "/"
|
[GIN] 2025/07/01 - 18:00:09 | 200 | 50.263µs | 100.126.204.152 | HEAD "/"
|
||||||
`,
|
`,
|
||||||
exp: []InferenceCompute{{
|
expComputes: []InferenceCompute{{
|
||||||
Library: "cuda",
|
Library: "cuda",
|
||||||
Variant: "v12",
|
Variant: "v12",
|
||||||
Compute: "6.1",
|
Compute: "6.1",
|
||||||
@@ -250,6 +256,7 @@ time=2025-07-01T19:33:43.162Z level=INFO source=types.go:130 msg="inference comp
|
|||||||
Name: "NVIDIA GeForce GT 1030",
|
Name: "NVIDIA GeForce GT 1030",
|
||||||
VRAM: "3.9 GiB",
|
VRAM: "3.9 GiB",
|
||||||
}},
|
}},
|
||||||
|
expDefaultCtxLen: 4096,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "frank",
|
name: "frank",
|
||||||
@@ -257,9 +264,10 @@ time=2025-07-01T19:33:43.162Z level=INFO source=types.go:130 msg="inference comp
|
|||||||
releasing cuda driver library
|
releasing cuda driver library
|
||||||
time=2025-07-01T19:36:13.315Z level=INFO source=types.go:130 msg="inference compute" id=GPU-d6de3398-9932-6902-11ec-fee8e424c8a2 library=cuda variant=v12 compute=7.5 driver=12.8 name="NVIDIA GeForce RTX 2080 Ti" total="10.6 GiB" available="10.4 GiB"
|
time=2025-07-01T19:36:13.315Z level=INFO source=types.go:130 msg="inference compute" id=GPU-d6de3398-9932-6902-11ec-fee8e424c8a2 library=cuda variant=v12 compute=7.5 driver=12.8 name="NVIDIA GeForce RTX 2080 Ti" total="10.6 GiB" available="10.4 GiB"
|
||||||
time=2025-07-01T19:36:13.315Z level=INFO source=types.go:130 msg="inference compute" id=GPU-9abb57639fa80c50 library=rocm variant="" compute=gfx1030 driver=6.3 name=1002:73bf total="16.0 GiB" available="1.3 GiB"
|
time=2025-07-01T19:36:13.315Z level=INFO source=types.go:130 msg="inference compute" id=GPU-9abb57639fa80c50 library=rocm variant="" compute=gfx1030 driver=6.3 name=1002:73bf total="16.0 GiB" available="1.3 GiB"
|
||||||
|
time=2025-07-01T19:36:13.316Z level=INFO source=routes.go:1721 msg="vram-based default context" total_vram="26.6 GiB" default_num_ctx=32768
|
||||||
[GIN] 2025/07/01 - 18:00:09 | 200 | 50.263µs | 100.126.204.152 | HEAD "/"
|
[GIN] 2025/07/01 - 18:00:09 | 200 | 50.263µs | 100.126.204.152 | HEAD "/"
|
||||||
`,
|
`,
|
||||||
exp: []InferenceCompute{
|
expComputes: []InferenceCompute{
|
||||||
{
|
{
|
||||||
Library: "cuda",
|
Library: "cuda",
|
||||||
Variant: "v12",
|
Variant: "v12",
|
||||||
@@ -276,6 +284,20 @@ time=2025-07-01T19:33:43.162Z level=INFO source=types.go:130 msg="inference comp
|
|||||||
VRAM: "16.0 GiB",
|
VRAM: "16.0 GiB",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
expDefaultCtxLen: 32768,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "missing_default_context",
|
||||||
|
log: `time=2025-06-30T09:23:07.374-07:00 level=DEBUG source=sched.go:108 msg="starting llm scheduler"
|
||||||
|
time=2025-06-30T09:23:07.416-07:00 level=INFO source=types.go:130 msg="inference compute" id=0 library=metal variant="" compute="" driver=0.0 name="" total="96.0 GiB" available="96.0 GiB"
|
||||||
|
time=2025-06-30T09:25:56.197-07:00 level=DEBUG source=ggml.go:155 msg="key not found" key=general.alignment default=32
|
||||||
|
`,
|
||||||
|
expComputes: []InferenceCompute{{
|
||||||
|
Library: "metal",
|
||||||
|
Driver: "0.0",
|
||||||
|
VRAM: "96.0 GiB",
|
||||||
|
}},
|
||||||
|
expDefaultCtxLen: 0, // No default context line, should return 0
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
@@ -288,18 +310,21 @@ time=2025-07-01T19:33:43.162Z level=INFO source=types.go:130 msg="inference comp
|
|||||||
}
|
}
|
||||||
ctx, cancel := context.WithTimeout(t.Context(), 10*time.Millisecond)
|
ctx, cancel := context.WithTimeout(t.Context(), 10*time.Millisecond)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
ics, err := GetInferenceComputer(ctx)
|
info, err := GetInferenceInfo(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf(" failed to get inference compute: %v", err)
|
t.Fatalf("failed to get inference info: %v", err)
|
||||||
}
|
}
|
||||||
if !reflect.DeepEqual(ics, tt.exp) {
|
if !reflect.DeepEqual(info.Computes, tt.expComputes) {
|
||||||
t.Fatalf("got:\n%#v\nwant:\n%#v", ics, tt.exp)
|
t.Fatalf("computes mismatch\ngot:\n%#v\nwant:\n%#v", info.Computes, tt.expComputes)
|
||||||
|
}
|
||||||
|
if info.DefaultContextLength != tt.expDefaultCtxLen {
|
||||||
|
t.Fatalf("default context length mismatch: got %d, want %d", info.DefaultContextLength, tt.expDefaultCtxLen)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestGetInferenceComputerTimeout(t *testing.T) {
|
func TestGetInferenceInfoTimeout(t *testing.T) {
|
||||||
ctx, cancel := context.WithTimeout(t.Context(), 10*time.Millisecond)
|
ctx, cancel := context.WithTimeout(t.Context(), 10*time.Millisecond)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
tmpDir := t.TempDir()
|
tmpDir := t.TempDir()
|
||||||
@@ -308,7 +333,7 @@ func TestGetInferenceComputerTimeout(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to write log file %s: %s", serverLogPath, err)
|
t.Fatalf("failed to write log file %s: %s", serverLogPath, err)
|
||||||
}
|
}
|
||||||
_, err = GetInferenceComputer(ctx)
|
_, err = GetInferenceInfo(ctx)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Fatal("expected timeout")
|
t.Fatal("expected timeout")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -9,12 +9,12 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
sqlite3 "github.com/mattn/go-sqlite3"
|
_ "github.com/mattn/go-sqlite3"
|
||||||
)
|
)
|
||||||
|
|
||||||
// currentSchemaVersion defines the current database schema version.
|
// currentSchemaVersion defines the current database schema version.
|
||||||
// Increment this when making schema changes that require migrations.
|
// Increment this when making schema changes that require migrations.
|
||||||
const currentSchemaVersion = 13
|
const currentSchemaVersion = 15
|
||||||
|
|
||||||
// database wraps the SQLite connection.
|
// database wraps the SQLite connection.
|
||||||
// SQLite handles its own locking for concurrent access:
|
// SQLite handles its own locking for concurrent access:
|
||||||
@@ -73,7 +73,7 @@ func (db *database) init() error {
|
|||||||
agent BOOLEAN NOT NULL DEFAULT 0,
|
agent BOOLEAN NOT NULL DEFAULT 0,
|
||||||
tools BOOLEAN NOT NULL DEFAULT 0,
|
tools BOOLEAN NOT NULL DEFAULT 0,
|
||||||
working_dir TEXT NOT NULL DEFAULT '',
|
working_dir TEXT NOT NULL DEFAULT '',
|
||||||
context_length INTEGER NOT NULL DEFAULT 4096,
|
context_length INTEGER NOT NULL DEFAULT 0,
|
||||||
window_width INTEGER NOT NULL DEFAULT 0,
|
window_width INTEGER NOT NULL DEFAULT 0,
|
||||||
window_height INTEGER NOT NULL DEFAULT 0,
|
window_height INTEGER NOT NULL DEFAULT 0,
|
||||||
config_migrated BOOLEAN NOT NULL DEFAULT 0,
|
config_migrated BOOLEAN NOT NULL DEFAULT 0,
|
||||||
@@ -86,6 +86,7 @@ func (db *database) init() error {
|
|||||||
think_level TEXT NOT NULL DEFAULT '',
|
think_level TEXT NOT NULL DEFAULT '',
|
||||||
cloud_setting_migrated BOOLEAN NOT NULL DEFAULT 0,
|
cloud_setting_migrated BOOLEAN NOT NULL DEFAULT 0,
|
||||||
remote TEXT NOT NULL DEFAULT '', -- deprecated
|
remote TEXT NOT NULL DEFAULT '', -- deprecated
|
||||||
|
auto_update_enabled BOOLEAN NOT NULL DEFAULT 1,
|
||||||
schema_version INTEGER NOT NULL DEFAULT %d
|
schema_version INTEGER NOT NULL DEFAULT %d
|
||||||
);
|
);
|
||||||
|
|
||||||
@@ -251,6 +252,18 @@ func (db *database) migrate() error {
|
|||||||
return fmt.Errorf("migrate v12 to v13: %w", err)
|
return fmt.Errorf("migrate v12 to v13: %w", err)
|
||||||
}
|
}
|
||||||
version = 13
|
version = 13
|
||||||
|
case 13:
|
||||||
|
// change default context_length from 4096 to 0 (VRAM-based tiered defaults)
|
||||||
|
if err := db.migrateV13ToV14(); err != nil {
|
||||||
|
return fmt.Errorf("migrate v13 to v14: %w", err)
|
||||||
|
}
|
||||||
|
version = 14
|
||||||
|
case 14:
|
||||||
|
// add auto_update_enabled column to settings table
|
||||||
|
if err := db.migrateV14ToV15(); err != nil {
|
||||||
|
return fmt.Errorf("migrate v14 to v15: %w", err)
|
||||||
|
}
|
||||||
|
version = 15
|
||||||
default:
|
default:
|
||||||
// If we have a version we don't recognize, just set it to current
|
// If we have a version we don't recognize, just set it to current
|
||||||
// This might happen during development
|
// This might happen during development
|
||||||
@@ -474,6 +487,37 @@ func (db *database) migrateV12ToV13() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// migrateV13ToV14 changes the default context_length from 4096 to 0.
|
||||||
|
// When context_length is 0, the ollama server uses VRAM-based tiered defaults.
|
||||||
|
func (db *database) migrateV13ToV14() error {
|
||||||
|
_, err := db.conn.Exec(`UPDATE settings SET context_length = 0 WHERE context_length = 4096`)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("update context_length default: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = db.conn.Exec(`UPDATE settings SET schema_version = 14`)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("update schema version: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// migrateV14ToV15 adds the auto_update_enabled column to the settings table
|
||||||
|
func (db *database) migrateV14ToV15() error {
|
||||||
|
_, err := db.conn.Exec(`ALTER TABLE settings ADD COLUMN auto_update_enabled BOOLEAN NOT NULL DEFAULT 1`)
|
||||||
|
if err != nil && !duplicateColumnError(err) {
|
||||||
|
return fmt.Errorf("add auto_update_enabled column: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = db.conn.Exec(`UPDATE settings SET schema_version = 15`)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("update schema version: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// cleanupOrphanedData removes orphaned records that may exist due to the foreign key bug
|
// cleanupOrphanedData removes orphaned records that may exist due to the foreign key bug
|
||||||
func (db *database) cleanupOrphanedData() error {
|
func (db *database) cleanupOrphanedData() error {
|
||||||
_, err := db.conn.Exec(`
|
_, err := db.conn.Exec(`
|
||||||
@@ -504,19 +548,11 @@ func (db *database) cleanupOrphanedData() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func duplicateColumnError(err error) bool {
|
func duplicateColumnError(err error) bool {
|
||||||
if sqlite3Err, ok := err.(sqlite3.Error); ok {
|
return err != nil && strings.Contains(err.Error(), "duplicate column name")
|
||||||
return sqlite3Err.Code == sqlite3.ErrError &&
|
|
||||||
strings.Contains(sqlite3Err.Error(), "duplicate column name")
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func columnNotExists(err error) bool {
|
func columnNotExists(err error) bool {
|
||||||
if sqlite3Err, ok := err.(sqlite3.Error); ok {
|
return err != nil && strings.Contains(err.Error(), "no such column")
|
||||||
return sqlite3Err.Code == sqlite3.ErrError &&
|
|
||||||
strings.Contains(sqlite3Err.Error(), "no such column")
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *database) getAllChats() ([]Chat, error) {
|
func (db *database) getAllChats() ([]Chat, error) {
|
||||||
@@ -1130,9 +1166,9 @@ func (db *database) getSettings() (Settings, error) {
|
|||||||
var s Settings
|
var s Settings
|
||||||
|
|
||||||
err := db.conn.QueryRow(`
|
err := db.conn.QueryRow(`
|
||||||
SELECT expose, survey, browser, models, agent, tools, working_dir, context_length, turbo_enabled, websearch_enabled, selected_model, sidebar_open, think_enabled, think_level
|
SELECT expose, survey, browser, models, agent, tools, working_dir, context_length, turbo_enabled, websearch_enabled, selected_model, sidebar_open, think_enabled, think_level, auto_update_enabled
|
||||||
FROM settings
|
FROM settings
|
||||||
`).Scan(&s.Expose, &s.Survey, &s.Browser, &s.Models, &s.Agent, &s.Tools, &s.WorkingDir, &s.ContextLength, &s.TurboEnabled, &s.WebSearchEnabled, &s.SelectedModel, &s.SidebarOpen, &s.ThinkEnabled, &s.ThinkLevel)
|
`).Scan(&s.Expose, &s.Survey, &s.Browser, &s.Models, &s.Agent, &s.Tools, &s.WorkingDir, &s.ContextLength, &s.TurboEnabled, &s.WebSearchEnabled, &s.SelectedModel, &s.SidebarOpen, &s.ThinkEnabled, &s.ThinkLevel, &s.AutoUpdateEnabled)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return Settings{}, fmt.Errorf("get settings: %w", err)
|
return Settings{}, fmt.Errorf("get settings: %w", err)
|
||||||
}
|
}
|
||||||
@@ -1143,8 +1179,8 @@ func (db *database) getSettings() (Settings, error) {
|
|||||||
func (db *database) setSettings(s Settings) error {
|
func (db *database) setSettings(s Settings) error {
|
||||||
_, err := db.conn.Exec(`
|
_, err := db.conn.Exec(`
|
||||||
UPDATE settings
|
UPDATE settings
|
||||||
SET expose = ?, survey = ?, browser = ?, models = ?, agent = ?, tools = ?, working_dir = ?, context_length = ?, turbo_enabled = ?, websearch_enabled = ?, selected_model = ?, sidebar_open = ?, think_enabled = ?, think_level = ?
|
SET expose = ?, survey = ?, browser = ?, models = ?, agent = ?, tools = ?, working_dir = ?, context_length = ?, turbo_enabled = ?, websearch_enabled = ?, selected_model = ?, sidebar_open = ?, think_enabled = ?, think_level = ?, auto_update_enabled = ?
|
||||||
`, s.Expose, s.Survey, s.Browser, s.Models, s.Agent, s.Tools, s.WorkingDir, s.ContextLength, s.TurboEnabled, s.WebSearchEnabled, s.SelectedModel, s.SidebarOpen, s.ThinkEnabled, s.ThinkLevel)
|
`, s.Expose, s.Survey, s.Browser, s.Models, s.Agent, s.Tools, s.WorkingDir, s.ContextLength, s.TurboEnabled, s.WebSearchEnabled, s.SelectedModel, s.SidebarOpen, s.ThinkEnabled, s.ThinkLevel, s.AutoUpdateEnabled)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("set settings: %w", err)
|
return fmt.Errorf("set settings: %w", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -98,6 +98,43 @@ func TestSchemaMigrations(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestMigrationV13ToV14ContextLength(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
dbPath := filepath.Join(tmpDir, "test.db")
|
||||||
|
|
||||||
|
db, err := newDatabase(dbPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to create database: %v", err)
|
||||||
|
}
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
_, err = db.conn.Exec("UPDATE settings SET context_length = 4096, schema_version = 13")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to seed v13 settings row: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := db.migrate(); err != nil {
|
||||||
|
t.Fatalf("migration from v13 to v14 failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var contextLength int
|
||||||
|
if err := db.conn.QueryRow("SELECT context_length FROM settings").Scan(&contextLength); err != nil {
|
||||||
|
t.Fatalf("failed to read context_length: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if contextLength != 0 {
|
||||||
|
t.Fatalf("expected context_length to migrate to 0, got %d", contextLength)
|
||||||
|
}
|
||||||
|
|
||||||
|
version, err := db.getSchemaVersion()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to get schema version: %v", err)
|
||||||
|
}
|
||||||
|
if version != currentSchemaVersion {
|
||||||
|
t.Fatalf("expected schema version %d, got %d", currentSchemaVersion, version)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestChatDeletionWithCascade(t *testing.T) {
|
func TestChatDeletionWithCascade(t *testing.T) {
|
||||||
t.Run("chat deletion cascades to related messages", func(t *testing.T) {
|
t.Run("chat deletion cascades to related messages", func(t *testing.T) {
|
||||||
tmpDir := t.TempDir()
|
tmpDir := t.TempDir()
|
||||||
|
|||||||
@@ -166,6 +166,9 @@ type Settings struct {
|
|||||||
|
|
||||||
// SidebarOpen indicates if the chat sidebar is open
|
// SidebarOpen indicates if the chat sidebar is open
|
||||||
SidebarOpen bool
|
SidebarOpen bool
|
||||||
|
|
||||||
|
// AutoUpdateEnabled indicates if automatic updates should be downloaded
|
||||||
|
AutoUpdateEnabled bool
|
||||||
}
|
}
|
||||||
|
|
||||||
type Store struct {
|
type Store struct {
|
||||||
|
|||||||
2
app/store/testdata/schema.sql
vendored
2
app/store/testdata/schema.sql
vendored
@@ -13,7 +13,7 @@ CREATE TABLE IF NOT EXISTS settings (
|
|||||||
agent BOOLEAN NOT NULL DEFAULT 0,
|
agent BOOLEAN NOT NULL DEFAULT 0,
|
||||||
tools BOOLEAN NOT NULL DEFAULT 0,
|
tools BOOLEAN NOT NULL DEFAULT 0,
|
||||||
working_dir TEXT NOT NULL DEFAULT '',
|
working_dir TEXT NOT NULL DEFAULT '',
|
||||||
context_length INTEGER NOT NULL DEFAULT 4096,
|
context_length INTEGER NOT NULL DEFAULT 0,
|
||||||
window_width INTEGER NOT NULL DEFAULT 0,
|
window_width INTEGER NOT NULL DEFAULT 0,
|
||||||
window_height INTEGER NOT NULL DEFAULT 0,
|
window_height INTEGER NOT NULL DEFAULT 0,
|
||||||
config_migrated BOOLEAN NOT NULL DEFAULT 0,
|
config_migrated BOOLEAN NOT NULL DEFAULT 0,
|
||||||
|
|||||||
@@ -289,10 +289,12 @@ export class InferenceCompute {
|
|||||||
}
|
}
|
||||||
export class InferenceComputeResponse {
|
export class InferenceComputeResponse {
|
||||||
inferenceComputes: InferenceCompute[];
|
inferenceComputes: InferenceCompute[];
|
||||||
|
defaultContextLength: number;
|
||||||
|
|
||||||
constructor(source: any = {}) {
|
constructor(source: any = {}) {
|
||||||
if ('string' === typeof source) source = JSON.parse(source);
|
if ('string' === typeof source) source = JSON.parse(source);
|
||||||
this.inferenceComputes = this.convertValues(source["inferenceComputes"], InferenceCompute);
|
this.inferenceComputes = this.convertValues(source["inferenceComputes"], InferenceCompute);
|
||||||
|
this.defaultContextLength = source["defaultContextLength"];
|
||||||
}
|
}
|
||||||
|
|
||||||
convertValues(a: any, classs: any, asMap: boolean = false): any {
|
convertValues(a: any, classs: any, asMap: boolean = false): any {
|
||||||
@@ -412,6 +414,7 @@ export class Settings {
|
|||||||
ThinkLevel: string;
|
ThinkLevel: string;
|
||||||
SelectedModel: string;
|
SelectedModel: string;
|
||||||
SidebarOpen: boolean;
|
SidebarOpen: boolean;
|
||||||
|
AutoUpdateEnabled: boolean;
|
||||||
|
|
||||||
constructor(source: any = {}) {
|
constructor(source: any = {}) {
|
||||||
if ('string' === typeof source) source = JSON.parse(source);
|
if ('string' === typeof source) source = JSON.parse(source);
|
||||||
@@ -429,6 +432,7 @@ export class Settings {
|
|||||||
this.ThinkLevel = source["ThinkLevel"];
|
this.ThinkLevel = source["ThinkLevel"];
|
||||||
this.SelectedModel = source["SelectedModel"];
|
this.SelectedModel = source["SelectedModel"];
|
||||||
this.SidebarOpen = source["SidebarOpen"];
|
this.SidebarOpen = source["SidebarOpen"];
|
||||||
|
this.AutoUpdateEnabled = source["AutoUpdateEnabled"];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
export class SettingsResponse {
|
export class SettingsResponse {
|
||||||
|
|||||||
@@ -4,7 +4,6 @@ import {
|
|||||||
ChatEvent,
|
ChatEvent,
|
||||||
DownloadEvent,
|
DownloadEvent,
|
||||||
ErrorEvent,
|
ErrorEvent,
|
||||||
InferenceCompute,
|
|
||||||
InferenceComputeResponse,
|
InferenceComputeResponse,
|
||||||
ModelCapabilitiesResponse,
|
ModelCapabilitiesResponse,
|
||||||
Model,
|
Model,
|
||||||
@@ -407,7 +406,7 @@ export async function* pullModel(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
export async function getInferenceCompute(): Promise<InferenceCompute[]> {
|
export async function getInferenceCompute(): Promise<InferenceComputeResponse> {
|
||||||
const response = await fetch(`${API_BASE}/api/v1/inference-compute`);
|
const response = await fetch(`${API_BASE}/api/v1/inference-compute`);
|
||||||
if (!response.ok) {
|
if (!response.ok) {
|
||||||
throw new Error(
|
throw new Error(
|
||||||
@@ -416,8 +415,7 @@ export async function getInferenceCompute(): Promise<InferenceCompute[]> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
const data = await response.json();
|
const data = await response.json();
|
||||||
const inferenceComputeResponse = new InferenceComputeResponse(data);
|
return new InferenceComputeResponse(data);
|
||||||
return inferenceComputeResponse.inferenceComputes || [];
|
|
||||||
}
|
}
|
||||||
|
|
||||||
export async function fetchHealth(): Promise<boolean> {
|
export async function fetchHealth(): Promise<boolean> {
|
||||||
|
|||||||
@@ -17,7 +17,10 @@ import {
|
|||||||
} from "@/hooks/useChats";
|
} from "@/hooks/useChats";
|
||||||
import { useNavigate } from "@tanstack/react-router";
|
import { useNavigate } from "@tanstack/react-router";
|
||||||
import { useSelectedModel } from "@/hooks/useSelectedModel";
|
import { useSelectedModel } from "@/hooks/useSelectedModel";
|
||||||
import { useHasVisionCapability } from "@/hooks/useModelCapabilities";
|
import {
|
||||||
|
useHasVisionCapability,
|
||||||
|
useHasToolsCapability,
|
||||||
|
} from "@/hooks/useModelCapabilities";
|
||||||
import { useUser } from "@/hooks/useUser";
|
import { useUser } from "@/hooks/useUser";
|
||||||
import { DisplayLogin } from "@/components/DisplayLogin";
|
import { DisplayLogin } from "@/components/DisplayLogin";
|
||||||
import { ErrorEvent, Message } from "@/gotypes";
|
import { ErrorEvent, Message } from "@/gotypes";
|
||||||
@@ -149,12 +152,7 @@ function ChatForm({
|
|||||||
} = useSettings();
|
} = useSettings();
|
||||||
const { cloudDisabled } = useCloudStatus();
|
const { cloudDisabled } = useCloudStatus();
|
||||||
|
|
||||||
// current supported models for web search
|
const supportsWebSearch = useHasToolsCapability(selectedModel?.model);
|
||||||
const modelLower = selectedModel?.model.toLowerCase() || "";
|
|
||||||
const supportsWebSearch =
|
|
||||||
modelLower.startsWith("gpt-oss") ||
|
|
||||||
modelLower.startsWith("qwen3") ||
|
|
||||||
modelLower.startsWith("deepseek-v3");
|
|
||||||
// Use per-chat thinking level instead of global
|
// Use per-chat thinking level instead of global
|
||||||
const thinkLevel: ThinkingLevel =
|
const thinkLevel: ThinkingLevel =
|
||||||
settingsThinkLevel === "none" || !settingsThinkLevel
|
settingsThinkLevel === "none" || !settingsThinkLevel
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ import {
|
|||||||
XMarkIcon,
|
XMarkIcon,
|
||||||
CogIcon,
|
CogIcon,
|
||||||
ArrowLeftIcon,
|
ArrowLeftIcon,
|
||||||
|
ArrowDownTrayIcon,
|
||||||
} from "@heroicons/react/20/solid";
|
} from "@heroicons/react/20/solid";
|
||||||
import { Settings as SettingsType } from "@/gotypes";
|
import { Settings as SettingsType } from "@/gotypes";
|
||||||
import { useNavigate } from "@tanstack/react-router";
|
import { useNavigate } from "@tanstack/react-router";
|
||||||
@@ -26,6 +27,7 @@ import {
|
|||||||
type CloudStatusResponse,
|
type CloudStatusResponse,
|
||||||
updateCloudSetting,
|
updateCloudSetting,
|
||||||
updateSettings,
|
updateSettings,
|
||||||
|
getInferenceCompute,
|
||||||
} from "@/api";
|
} from "@/api";
|
||||||
|
|
||||||
function AnimatedDots() {
|
function AnimatedDots() {
|
||||||
@@ -77,6 +79,13 @@ export default function Settings() {
|
|||||||
|
|
||||||
const settings = settingsData?.settings || null;
|
const settings = settingsData?.settings || null;
|
||||||
|
|
||||||
|
const { data: inferenceComputeResponse } = useQuery({
|
||||||
|
queryKey: ["inferenceCompute"],
|
||||||
|
queryFn: getInferenceCompute,
|
||||||
|
});
|
||||||
|
|
||||||
|
const defaultContextLength = inferenceComputeResponse?.defaultContextLength;
|
||||||
|
|
||||||
const updateSettingsMutation = useMutation({
|
const updateSettingsMutation = useMutation({
|
||||||
mutationFn: updateSettings,
|
mutationFn: updateSettings,
|
||||||
onSuccess: () => {
|
onSuccess: () => {
|
||||||
@@ -204,7 +213,8 @@ export default function Settings() {
|
|||||||
Models: "",
|
Models: "",
|
||||||
Agent: false,
|
Agent: false,
|
||||||
Tools: false,
|
Tools: false,
|
||||||
ContextLength: 4096,
|
ContextLength: 0,
|
||||||
|
AutoUpdateEnabled: true,
|
||||||
});
|
});
|
||||||
updateSettingsMutation.mutate(defaultSettings);
|
updateSettingsMutation.mutate(defaultSettings);
|
||||||
}
|
}
|
||||||
@@ -432,6 +442,29 @@ export default function Settings() {
|
|||||||
</div>
|
</div>
|
||||||
</Field>
|
</Field>
|
||||||
|
|
||||||
|
{/* Auto Update */}
|
||||||
|
<Field>
|
||||||
|
<div className="flex items-start justify-between gap-4">
|
||||||
|
<div className="flex items-start space-x-3 flex-1">
|
||||||
|
<ArrowDownTrayIcon className="mt-1 h-5 w-5 flex-shrink-0 text-black dark:text-neutral-100" />
|
||||||
|
<div>
|
||||||
|
<Label>Auto-download updates</Label>
|
||||||
|
<Description>
|
||||||
|
{settings.AutoUpdateEnabled
|
||||||
|
? "Automatically download updates when available."
|
||||||
|
: "Updates will not be downloaded automatically."}
|
||||||
|
</Description>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<div className="flex-shrink-0">
|
||||||
|
<Switch
|
||||||
|
checked={settings.AutoUpdateEnabled}
|
||||||
|
onChange={(checked) => handleChange("AutoUpdateEnabled", checked)}
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</Field>
|
||||||
|
|
||||||
{/* Expose Ollama */}
|
{/* Expose Ollama */}
|
||||||
<Field>
|
<Field>
|
||||||
<div className="flex items-start justify-between gap-4">
|
<div className="flex items-start justify-between gap-4">
|
||||||
@@ -507,13 +540,11 @@ export default function Settings() {
|
|||||||
</Description>
|
</Description>
|
||||||
<div className="mt-3">
|
<div className="mt-3">
|
||||||
<Slider
|
<Slider
|
||||||
value={(() => {
|
value={settings.ContextLength || defaultContextLength || 0}
|
||||||
// Otherwise use the settings value
|
|
||||||
return settings.ContextLength || 4096;
|
|
||||||
})()}
|
|
||||||
onChange={(value) => {
|
onChange={(value) => {
|
||||||
handleChange("ContextLength", value);
|
handleChange("ContextLength", value);
|
||||||
}}
|
}}
|
||||||
|
disabled={!defaultContextLength}
|
||||||
options={[
|
options={[
|
||||||
{ value: 4096, label: "4k" },
|
{ value: 4096, label: "4k" },
|
||||||
{ value: 8192, label: "8k" },
|
{ value: 8192, label: "8k" },
|
||||||
|
|||||||
@@ -6,10 +6,11 @@ export interface SliderProps {
|
|||||||
value?: number;
|
value?: number;
|
||||||
onChange?: (value: number) => void;
|
onChange?: (value: number) => void;
|
||||||
className?: string;
|
className?: string;
|
||||||
|
disabled?: boolean;
|
||||||
}
|
}
|
||||||
|
|
||||||
const Slider = React.forwardRef<HTMLDivElement, SliderProps>(
|
const Slider = React.forwardRef<HTMLDivElement, SliderProps>(
|
||||||
({ label, options, value = 0, onChange }, ref) => {
|
({ label, options, value = 0, onChange, disabled = false }, ref) => {
|
||||||
const [selectedValue, setSelectedValue] = React.useState(value);
|
const [selectedValue, setSelectedValue] = React.useState(value);
|
||||||
const [isDragging, setIsDragging] = React.useState(false);
|
const [isDragging, setIsDragging] = React.useState(false);
|
||||||
const containerRef = React.useRef<HTMLDivElement>(null);
|
const containerRef = React.useRef<HTMLDivElement>(null);
|
||||||
@@ -20,6 +21,7 @@ const Slider = React.forwardRef<HTMLDivElement, SliderProps>(
|
|||||||
}, [value]);
|
}, [value]);
|
||||||
|
|
||||||
const handleClick = (optionValue: number) => {
|
const handleClick = (optionValue: number) => {
|
||||||
|
if (disabled) return;
|
||||||
setSelectedValue(optionValue);
|
setSelectedValue(optionValue);
|
||||||
onChange?.(optionValue);
|
onChange?.(optionValue);
|
||||||
};
|
};
|
||||||
@@ -39,6 +41,7 @@ const Slider = React.forwardRef<HTMLDivElement, SliderProps>(
|
|||||||
};
|
};
|
||||||
|
|
||||||
const handleMouseDown = (e: React.MouseEvent) => {
|
const handleMouseDown = (e: React.MouseEvent) => {
|
||||||
|
if (disabled) return;
|
||||||
setIsDragging(true);
|
setIsDragging(true);
|
||||||
e.preventDefault();
|
e.preventDefault();
|
||||||
};
|
};
|
||||||
@@ -77,7 +80,7 @@ const Slider = React.forwardRef<HTMLDivElement, SliderProps>(
|
|||||||
}
|
}
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<div className="space-y-2" ref={ref}>
|
<div className={`space-y-2 ${disabled ? "opacity-50" : ""}`} ref={ref}>
|
||||||
{label && <label className="text-sm font-medium">{label}</label>}
|
{label && <label className="text-sm font-medium">{label}</label>}
|
||||||
<div className="relative">
|
<div className="relative">
|
||||||
<div className="absolute top-[9px] left-2 right-2 h-1 bg-neutral-200 dark:bg-neutral-700 pointer-events-none rounded-full" />
|
<div className="absolute top-[9px] left-2 right-2 h-1 bg-neutral-200 dark:bg-neutral-700 pointer-events-none rounded-full" />
|
||||||
@@ -88,10 +91,11 @@ const Slider = React.forwardRef<HTMLDivElement, SliderProps>(
|
|||||||
<button
|
<button
|
||||||
onClick={() => handleClick(option.value)}
|
onClick={() => handleClick(option.value)}
|
||||||
onMouseDown={handleMouseDown}
|
onMouseDown={handleMouseDown}
|
||||||
className="relative px-3 py-6 -mx-3 -my-6 z-10 cursor-pointer"
|
disabled={disabled}
|
||||||
|
className={`relative px-3 py-6 -mx-3 -my-6 z-10 ${disabled ? "cursor-not-allowed" : "cursor-pointer"}`}
|
||||||
>
|
>
|
||||||
<div className="relative w-5 h-5 flex items-center justify-center">
|
<div className="relative w-5 h-5 flex items-center justify-center">
|
||||||
{selectedValue === option.value && (
|
{selectedValue === option.value && !disabled && (
|
||||||
<div className="w-4 h-4 bg-white dark:bg-white border border-neutral-400 dark:border-neutral-500 rounded-full cursor-grab active:cursor-grabbing" />
|
<div className="w-4 h-4 bg-white dark:bg-white border border-neutral-400 dark:border-neutral-500 rounded-full cursor-grab active:cursor-grabbing" />
|
||||||
)}
|
)}
|
||||||
</div>
|
</div>
|
||||||
|
|||||||
@@ -20,3 +20,8 @@ export function useHasVisionCapability(modelName: string | undefined) {
|
|||||||
const { data: capabilitiesResponse } = useModelCapabilities(modelName);
|
const { data: capabilitiesResponse } = useModelCapabilities(modelName);
|
||||||
return capabilitiesResponse?.capabilities?.includes("vision") ?? false;
|
return capabilitiesResponse?.capabilities?.includes("vision") ?? false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export function useHasToolsCapability(modelName: string | undefined) {
|
||||||
|
const { data: capabilitiesResponse } = useModelCapabilities(modelName);
|
||||||
|
return capabilitiesResponse?.capabilities?.includes("tools") ?? false;
|
||||||
|
}
|
||||||
|
|||||||
@@ -28,12 +28,14 @@ export function useSelectedModel(currentChatId?: string, searchQuery?: string) {
|
|||||||
currentChatId && currentChatId !== "new" ? currentChatId : "",
|
currentChatId && currentChatId !== "new" ? currentChatId : "",
|
||||||
);
|
);
|
||||||
|
|
||||||
const { data: inferenceComputes = [] } = useQuery({
|
const { data: inferenceComputeResponse } = useQuery({
|
||||||
queryKey: ["inference-compute"],
|
queryKey: ["inferenceCompute"],
|
||||||
queryFn: getInferenceCompute,
|
queryFn: getInferenceCompute,
|
||||||
enabled: !settings.selectedModel, // Only fetch if no model is selected
|
enabled: !settings.selectedModel, // Only fetch if no model is selected
|
||||||
});
|
});
|
||||||
|
|
||||||
|
const inferenceComputes = inferenceComputeResponse?.inferenceComputes || [];
|
||||||
|
|
||||||
const totalVRAM = useMemo(
|
const totalVRAM = useMemo(
|
||||||
() => getTotalVRAM(inferenceComputes),
|
() => getTotalVRAM(inferenceComputes),
|
||||||
[inferenceComputes],
|
[inferenceComputes],
|
||||||
|
|||||||
@@ -46,6 +46,7 @@ type InferenceCompute struct {
|
|||||||
|
|
||||||
type InferenceComputeResponse struct {
|
type InferenceComputeResponse struct {
|
||||||
InferenceComputes []InferenceCompute `json:"inferenceComputes"`
|
InferenceComputes []InferenceCompute `json:"inferenceComputes"`
|
||||||
|
DefaultContextLength int `json:"defaultContextLength"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type ModelCapabilitiesResponse struct {
|
type ModelCapabilitiesResponse struct {
|
||||||
|
|||||||
55
app/ui/ui.go
55
app/ui/ui.go
@@ -28,6 +28,7 @@ import (
|
|||||||
"github.com/ollama/ollama/app/tools"
|
"github.com/ollama/ollama/app/tools"
|
||||||
"github.com/ollama/ollama/app/types/not"
|
"github.com/ollama/ollama/app/types/not"
|
||||||
"github.com/ollama/ollama/app/ui/responses"
|
"github.com/ollama/ollama/app/ui/responses"
|
||||||
|
"github.com/ollama/ollama/app/updater"
|
||||||
"github.com/ollama/ollama/app/version"
|
"github.com/ollama/ollama/app/version"
|
||||||
ollamaAuth "github.com/ollama/ollama/auth"
|
ollamaAuth "github.com/ollama/ollama/auth"
|
||||||
"github.com/ollama/ollama/envconfig"
|
"github.com/ollama/ollama/envconfig"
|
||||||
@@ -106,6 +107,10 @@ type Server struct {
|
|||||||
|
|
||||||
// Dev is true if the server is running in development mode
|
// Dev is true if the server is running in development mode
|
||||||
Dev bool
|
Dev bool
|
||||||
|
|
||||||
|
// Updater for checking and downloading updates
|
||||||
|
Updater *updater.Updater
|
||||||
|
UpdateAvailableFunc func()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) log() *slog.Logger {
|
func (s *Server) log() *slog.Logger {
|
||||||
@@ -829,8 +834,9 @@ func (s *Server) chat(w http.ResponseWriter, r *http.Request) error {
|
|||||||
|
|
||||||
if !hasAttachments {
|
if !hasAttachments {
|
||||||
WebSearchEnabled := req.WebSearch != nil && *req.WebSearch
|
WebSearchEnabled := req.WebSearch != nil && *req.WebSearch
|
||||||
|
hasToolsCapability := slices.Contains(details.Capabilities, model.CapabilityTools)
|
||||||
|
|
||||||
if WebSearchEnabled {
|
if WebSearchEnabled && hasToolsCapability {
|
||||||
if supportsBrowserTools(req.Model) {
|
if supportsBrowserTools(req.Model) {
|
||||||
browserState, ok := s.browserState(chat)
|
browserState, ok := s.browserState(chat)
|
||||||
if !ok {
|
if !ok {
|
||||||
@@ -840,7 +846,7 @@ func (s *Server) chat(w http.ResponseWriter, r *http.Request) error {
|
|||||||
registry.Register(tools.NewBrowserSearch(browser))
|
registry.Register(tools.NewBrowserSearch(browser))
|
||||||
registry.Register(tools.NewBrowserOpen(browser))
|
registry.Register(tools.NewBrowserOpen(browser))
|
||||||
registry.Register(tools.NewBrowserFind(browser))
|
registry.Register(tools.NewBrowserFind(browser))
|
||||||
} else if supportsWebSearchTools(req.Model) {
|
} else {
|
||||||
registry.Register(&tools.WebSearch{})
|
registry.Register(&tools.WebSearch{})
|
||||||
registry.Register(&tools.WebFetch{})
|
registry.Register(&tools.WebFetch{})
|
||||||
}
|
}
|
||||||
@@ -1420,11 +1426,6 @@ func (s *Server) getSettings(w http.ResponseWriter, r *http.Request) error {
|
|||||||
settings.Models = envconfig.Models()
|
settings.Models = envconfig.Models()
|
||||||
}
|
}
|
||||||
|
|
||||||
// set default context length if not set
|
|
||||||
if settings.ContextLength == 0 {
|
|
||||||
settings.ContextLength = 4096
|
|
||||||
}
|
|
||||||
|
|
||||||
// Include current runtime settings
|
// Include current runtime settings
|
||||||
settings.Agent = s.Agent
|
settings.Agent = s.Agent
|
||||||
settings.Tools = s.Tools
|
settings.Tools = s.Tools
|
||||||
@@ -1451,6 +1452,24 @@ func (s *Server) settings(w http.ResponseWriter, r *http.Request) error {
|
|||||||
return fmt.Errorf("failed to save settings: %w", err)
|
return fmt.Errorf("failed to save settings: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Handle auto-update toggle changes
|
||||||
|
if old.AutoUpdateEnabled != settings.AutoUpdateEnabled {
|
||||||
|
if !settings.AutoUpdateEnabled {
|
||||||
|
// Auto-update disabled: cancel any ongoing download
|
||||||
|
if s.Updater != nil {
|
||||||
|
s.Updater.CancelOngoingDownload()
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Auto-update re-enabled: show notification if update is already staged, or trigger immediate check
|
||||||
|
if (updater.IsUpdatePending() || updater.UpdateDownloaded) && s.UpdateAvailableFunc != nil {
|
||||||
|
s.UpdateAvailableFunc()
|
||||||
|
} else if s.Updater != nil {
|
||||||
|
// Trigger the background checker to run immediately
|
||||||
|
s.Updater.TriggerImmediateCheck()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if old.ContextLength != settings.ContextLength ||
|
if old.ContextLength != settings.ContextLength ||
|
||||||
old.Models != settings.Models ||
|
old.Models != settings.Models ||
|
||||||
old.Expose != settings.Expose {
|
old.Expose != settings.Expose {
|
||||||
@@ -1500,14 +1519,14 @@ func (s *Server) writeCloudStatus(w http.ResponseWriter) error {
|
|||||||
func (s *Server) getInferenceCompute(w http.ResponseWriter, r *http.Request) error {
|
func (s *Server) getInferenceCompute(w http.ResponseWriter, r *http.Request) error {
|
||||||
ctx, cancel := context.WithTimeout(r.Context(), 500*time.Millisecond)
|
ctx, cancel := context.WithTimeout(r.Context(), 500*time.Millisecond)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
serverInferenceComputes, err := server.GetInferenceComputer(ctx)
|
info, err := server.GetInferenceInfo(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.log().Error("failed to get inference compute", "error", err)
|
s.log().Error("failed to get inference info", "error", err)
|
||||||
return fmt.Errorf("failed to get inference compute: %w", err)
|
return fmt.Errorf("failed to get inference info: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
inferenceComputes := make([]responses.InferenceCompute, len(serverInferenceComputes))
|
inferenceComputes := make([]responses.InferenceCompute, len(info.Computes))
|
||||||
for i, ic := range serverInferenceComputes {
|
for i, ic := range info.Computes {
|
||||||
inferenceComputes[i] = responses.InferenceCompute{
|
inferenceComputes[i] = responses.InferenceCompute{
|
||||||
Library: ic.Library,
|
Library: ic.Library,
|
||||||
Variant: ic.Variant,
|
Variant: ic.Variant,
|
||||||
@@ -1520,6 +1539,7 @@ func (s *Server) getInferenceCompute(w http.ResponseWriter, r *http.Request) err
|
|||||||
|
|
||||||
response := responses.InferenceComputeResponse{
|
response := responses.InferenceComputeResponse{
|
||||||
InferenceComputes: inferenceComputes,
|
InferenceComputes: inferenceComputes,
|
||||||
|
DefaultContextLength: info.DefaultContextLength,
|
||||||
}
|
}
|
||||||
|
|
||||||
w.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
@@ -1652,17 +1672,6 @@ func supportsBrowserTools(model string) bool {
|
|||||||
return strings.HasPrefix(strings.ToLower(model), "gpt-oss")
|
return strings.HasPrefix(strings.ToLower(model), "gpt-oss")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Web search tools are simpler, providing only basic web search and fetch capabilities (e.g., "web_search", "web_fetch") without simulating a browser. Currently only qwen3 and deepseek-v3 support web search tools.
|
|
||||||
func supportsWebSearchTools(model string) bool {
|
|
||||||
model = strings.ToLower(model)
|
|
||||||
prefixes := []string{"qwen3", "deepseek-v3"}
|
|
||||||
for _, p := range prefixes {
|
|
||||||
if strings.HasPrefix(model, p) {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
// buildChatRequest converts store.Chat to api.ChatRequest
|
// buildChatRequest converts store.Chat to api.ChatRequest
|
||||||
func (s *Server) buildChatRequest(chat *store.Chat, model string, think any, availableTools []map[string]any) (*api.ChatRequest, error) {
|
func (s *Server) buildChatRequest(chat *store.Chat, model string, think any, availableTools []map[string]any) (*api.ChatRequest, error) {
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ package ui
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
@@ -11,9 +12,11 @@ import (
|
|||||||
"path/filepath"
|
"path/filepath"
|
||||||
"runtime"
|
"runtime"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync/atomic"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/ollama/ollama/app/store"
|
"github.com/ollama/ollama/app/store"
|
||||||
|
"github.com/ollama/ollama/app/updater"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestHandlePostApiSettings(t *testing.T) {
|
func TestHandlePostApiSettings(t *testing.T) {
|
||||||
@@ -522,3 +525,290 @@ func TestUserAgentTransport(t *testing.T) {
|
|||||||
|
|
||||||
t.Logf("User-Agent transport successfully set: %s", receivedUA)
|
t.Logf("User-Agent transport successfully set: %s", receivedUA)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestSupportsBrowserTools(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
model string
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
{"gpt-oss", true},
|
||||||
|
{"gpt-oss-latest", true},
|
||||||
|
{"GPT-OSS", true},
|
||||||
|
{"Gpt-Oss-v2", true},
|
||||||
|
{"qwen3", false},
|
||||||
|
{"deepseek-v3", false},
|
||||||
|
{"llama3.3", false},
|
||||||
|
{"", false},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.model, func(t *testing.T) {
|
||||||
|
if got := supportsBrowserTools(tt.model); got != tt.want {
|
||||||
|
t.Errorf("supportsBrowserTools(%q) = %v, want %v", tt.model, got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWebSearchToolRegistration(t *testing.T) {
|
||||||
|
// Validates that the capability-gating logic in chat() correctly
|
||||||
|
// decides which tools to register based on model capabilities and
|
||||||
|
// the web search flag.
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
webSearchEnabled bool
|
||||||
|
hasToolsCap bool
|
||||||
|
model string
|
||||||
|
wantBrowser bool // expects browser tools (gpt-oss)
|
||||||
|
wantWebSearch bool // expects basic web search/fetch tools
|
||||||
|
wantNone bool // expects no tools registered
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "web search enabled with tools capability - browser model",
|
||||||
|
webSearchEnabled: true,
|
||||||
|
hasToolsCap: true,
|
||||||
|
model: "gpt-oss-latest",
|
||||||
|
wantBrowser: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "web search enabled with tools capability - non-browser model",
|
||||||
|
webSearchEnabled: true,
|
||||||
|
hasToolsCap: true,
|
||||||
|
model: "qwen3",
|
||||||
|
wantWebSearch: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "web search enabled without tools capability",
|
||||||
|
webSearchEnabled: true,
|
||||||
|
hasToolsCap: false,
|
||||||
|
model: "llama3.3",
|
||||||
|
wantNone: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "web search disabled with tools capability",
|
||||||
|
webSearchEnabled: false,
|
||||||
|
hasToolsCap: true,
|
||||||
|
model: "qwen3",
|
||||||
|
wantNone: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "web search disabled without tools capability",
|
||||||
|
webSearchEnabled: false,
|
||||||
|
hasToolsCap: false,
|
||||||
|
model: "llama3.3",
|
||||||
|
wantNone: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
// Replicate the decision logic from chat() handler
|
||||||
|
gotBrowser := false
|
||||||
|
gotWebSearch := false
|
||||||
|
|
||||||
|
if tt.webSearchEnabled && tt.hasToolsCap {
|
||||||
|
if supportsBrowserTools(tt.model) {
|
||||||
|
gotBrowser = true
|
||||||
|
} else {
|
||||||
|
gotWebSearch = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if tt.wantBrowser && !gotBrowser {
|
||||||
|
t.Error("expected browser tools to be registered")
|
||||||
|
}
|
||||||
|
if tt.wantWebSearch && !gotWebSearch {
|
||||||
|
t.Error("expected web search tools to be registered")
|
||||||
|
}
|
||||||
|
if tt.wantNone && (gotBrowser || gotWebSearch) {
|
||||||
|
t.Error("expected no tools to be registered")
|
||||||
|
}
|
||||||
|
if !tt.wantBrowser && gotBrowser {
|
||||||
|
t.Error("unexpected browser tools registered")
|
||||||
|
}
|
||||||
|
if !tt.wantWebSearch && gotWebSearch {
|
||||||
|
t.Error("unexpected web search tools registered")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSettingsToggleAutoUpdateOff_CancelsDownload(t *testing.T) {
|
||||||
|
testStore := &store.Store{
|
||||||
|
DBPath: filepath.Join(t.TempDir(), "db.sqlite"),
|
||||||
|
}
|
||||||
|
defer testStore.Close()
|
||||||
|
|
||||||
|
// Start with auto-update enabled
|
||||||
|
settings, err := testStore.Settings()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
settings.AutoUpdateEnabled = true
|
||||||
|
if err := testStore.SetSettings(settings); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
upd := &updater.Updater{Store: &store.Store{
|
||||||
|
DBPath: filepath.Join(t.TempDir(), "db2.sqlite"),
|
||||||
|
}}
|
||||||
|
defer upd.Store.Close()
|
||||||
|
|
||||||
|
// We can't easily mock CancelOngoingDownload, but we can verify
|
||||||
|
// the full settings handler flow works without error
|
||||||
|
server := &Server{
|
||||||
|
Store: testStore,
|
||||||
|
Restart: func() {},
|
||||||
|
Updater: upd,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Disable auto-update via settings API
|
||||||
|
settings.AutoUpdateEnabled = false
|
||||||
|
body, err := json.Marshal(settings)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
req := httptest.NewRequest("POST", "/api/v1/settings", bytes.NewReader(body))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
|
||||||
|
if err := server.settings(rr, req); err != nil {
|
||||||
|
t.Fatalf("settings() error = %v", err)
|
||||||
|
}
|
||||||
|
if rr.Code != http.StatusOK {
|
||||||
|
t.Fatalf("settings() status = %d, want %d", rr.Code, http.StatusOK)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify settings were saved with auto-update disabled
|
||||||
|
saved, err := testStore.Settings()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if saved.AutoUpdateEnabled {
|
||||||
|
t.Fatal("expected AutoUpdateEnabled to be false after toggle off")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSettingsToggleAutoUpdateOn_WithPendingUpdate_ShowsNotification(t *testing.T) {
|
||||||
|
testStore := &store.Store{
|
||||||
|
DBPath: filepath.Join(t.TempDir(), "db.sqlite"),
|
||||||
|
}
|
||||||
|
defer testStore.Close()
|
||||||
|
|
||||||
|
// Start with auto-update disabled
|
||||||
|
settings, err := testStore.Settings()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
settings.AutoUpdateEnabled = false
|
||||||
|
if err := testStore.SetSettings(settings); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Simulate that an update was previously downloaded
|
||||||
|
oldVal := updater.UpdateDownloaded
|
||||||
|
updater.UpdateDownloaded = true
|
||||||
|
defer func() { updater.UpdateDownloaded = oldVal }()
|
||||||
|
|
||||||
|
var notificationCalled atomic.Bool
|
||||||
|
server := &Server{
|
||||||
|
Store: testStore,
|
||||||
|
Restart: func() {},
|
||||||
|
UpdateAvailableFunc: func() {
|
||||||
|
notificationCalled.Store(true)
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Re-enable auto-update via settings API
|
||||||
|
settings.AutoUpdateEnabled = true
|
||||||
|
body, err := json.Marshal(settings)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
req := httptest.NewRequest("POST", "/api/v1/settings", bytes.NewReader(body))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
|
||||||
|
if err := server.settings(rr, req); err != nil {
|
||||||
|
t.Fatalf("settings() error = %v", err)
|
||||||
|
}
|
||||||
|
if rr.Code != http.StatusOK {
|
||||||
|
t.Fatalf("settings() status = %d, want %d", rr.Code, http.StatusOK)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !notificationCalled.Load() {
|
||||||
|
t.Fatal("expected UpdateAvailableFunc to be called when re-enabling with a downloaded update")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSettingsToggleAutoUpdateOn_NoPendingUpdate_TriggersCheck(t *testing.T) {
|
||||||
|
testStore := &store.Store{
|
||||||
|
DBPath: filepath.Join(t.TempDir(), "db.sqlite"),
|
||||||
|
}
|
||||||
|
defer testStore.Close()
|
||||||
|
|
||||||
|
// Start with auto-update disabled
|
||||||
|
settings, err := testStore.Settings()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
settings.AutoUpdateEnabled = false
|
||||||
|
if err := testStore.SetSettings(settings); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure no pending update - clear both the downloaded flag and the stage dir
|
||||||
|
oldVal := updater.UpdateDownloaded
|
||||||
|
updater.UpdateDownloaded = false
|
||||||
|
defer func() { updater.UpdateDownloaded = oldVal }()
|
||||||
|
|
||||||
|
oldStageDir := updater.UpdateStageDir
|
||||||
|
updater.UpdateStageDir = t.TempDir() // empty dir means IsUpdatePending() returns false
|
||||||
|
defer func() { updater.UpdateStageDir = oldStageDir }()
|
||||||
|
|
||||||
|
upd := &updater.Updater{Store: &store.Store{
|
||||||
|
DBPath: filepath.Join(t.TempDir(), "db2.sqlite"),
|
||||||
|
}}
|
||||||
|
defer upd.Store.Close()
|
||||||
|
|
||||||
|
// Initialize the checkNow channel by starting (and immediately stopping) the checker
|
||||||
|
// so TriggerImmediateCheck doesn't panic on nil channel
|
||||||
|
ctx, cancel := context.WithCancel(t.Context())
|
||||||
|
upd.StartBackgroundUpdaterChecker(ctx, func(string) error { return nil })
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
var notificationCalled atomic.Bool
|
||||||
|
server := &Server{
|
||||||
|
Store: testStore,
|
||||||
|
Restart: func() {},
|
||||||
|
Updater: upd,
|
||||||
|
UpdateAvailableFunc: func() {
|
||||||
|
notificationCalled.Store(true)
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Re-enable auto-update via settings API
|
||||||
|
settings.AutoUpdateEnabled = true
|
||||||
|
body, err := json.Marshal(settings)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
req := httptest.NewRequest("POST", "/api/v1/settings", bytes.NewReader(body))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
|
||||||
|
if err := server.settings(rr, req); err != nil {
|
||||||
|
t.Fatalf("settings() error = %v", err)
|
||||||
|
}
|
||||||
|
if rr.Code != http.StatusOK {
|
||||||
|
t.Fatalf("settings() status = %d, want %d", rr.Code, http.StatusOK)
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateAvailableFunc should NOT be called since there's no pending update
|
||||||
|
if notificationCalled.Load() {
|
||||||
|
t.Fatal("UpdateAvailableFunc should not be called when there is no pending update")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ import (
|
|||||||
"runtime"
|
"runtime"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/ollama/ollama/app/store"
|
"github.com/ollama/ollama/app/store"
|
||||||
@@ -58,7 +59,8 @@ func (u *Updater) checkForUpdate(ctx context.Context) (bool, UpdateResponse) {
|
|||||||
query := requestURL.Query()
|
query := requestURL.Query()
|
||||||
query.Add("os", runtime.GOOS)
|
query.Add("os", runtime.GOOS)
|
||||||
query.Add("arch", runtime.GOARCH)
|
query.Add("arch", runtime.GOARCH)
|
||||||
query.Add("version", version.Version)
|
currentVersion := version.Version
|
||||||
|
query.Add("version", currentVersion)
|
||||||
query.Add("ts", strconv.FormatInt(time.Now().Unix(), 10))
|
query.Add("ts", strconv.FormatInt(time.Now().Unix(), 10))
|
||||||
|
|
||||||
// The original macOS app used to use the device ID
|
// The original macOS app used to use the device ID
|
||||||
@@ -131,15 +133,27 @@ func (u *Updater) checkForUpdate(ctx context.Context) (bool, UpdateResponse) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (u *Updater) DownloadNewRelease(ctx context.Context, updateResp UpdateResponse) error {
|
func (u *Updater) DownloadNewRelease(ctx context.Context, updateResp UpdateResponse) error {
|
||||||
|
// Create a cancellable context for this download
|
||||||
|
downloadCtx, cancel := context.WithCancel(ctx)
|
||||||
|
u.cancelDownloadLock.Lock()
|
||||||
|
u.cancelDownload = cancel
|
||||||
|
u.cancelDownloadLock.Unlock()
|
||||||
|
defer func() {
|
||||||
|
u.cancelDownloadLock.Lock()
|
||||||
|
u.cancelDownload = nil
|
||||||
|
u.cancelDownloadLock.Unlock()
|
||||||
|
cancel()
|
||||||
|
}()
|
||||||
|
|
||||||
// Do a head first to check etag info
|
// Do a head first to check etag info
|
||||||
req, err := http.NewRequestWithContext(ctx, http.MethodHead, updateResp.UpdateURL, nil)
|
req, err := http.NewRequestWithContext(downloadCtx, http.MethodHead, updateResp.UpdateURL, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// In case of slow downloads, continue the update check in the background
|
// In case of slow downloads, continue the update check in the background
|
||||||
bgctx, cancel := context.WithCancel(ctx)
|
bgctx, bgcancel := context.WithCancel(downloadCtx)
|
||||||
defer cancel()
|
defer bgcancel()
|
||||||
go func() {
|
go func() {
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
@@ -176,6 +190,7 @@ func (u *Updater) DownloadNewRelease(ctx context.Context, updateResp UpdateRespo
|
|||||||
_, err = os.Stat(stageFilename)
|
_, err = os.Stat(stageFilename)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
slog.Info("update already downloaded", "bundle", stageFilename)
|
slog.Info("update already downloaded", "bundle", stageFilename)
|
||||||
|
UpdateDownloaded = true
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -245,32 +260,84 @@ func cleanupOldDownloads(stageDir string) {
|
|||||||
|
|
||||||
type Updater struct {
|
type Updater struct {
|
||||||
Store *store.Store
|
Store *store.Store
|
||||||
|
cancelDownload context.CancelFunc
|
||||||
|
cancelDownloadLock sync.Mutex
|
||||||
|
checkNow chan struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// CancelOngoingDownload cancels any currently running download
|
||||||
|
func (u *Updater) CancelOngoingDownload() {
|
||||||
|
u.cancelDownloadLock.Lock()
|
||||||
|
defer u.cancelDownloadLock.Unlock()
|
||||||
|
if u.cancelDownload != nil {
|
||||||
|
slog.Info("cancelling ongoing update download")
|
||||||
|
u.cancelDownload()
|
||||||
|
u.cancelDownload = nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TriggerImmediateCheck signals the background checker to check for updates immediately
|
||||||
|
func (u *Updater) TriggerImmediateCheck() {
|
||||||
|
if u.checkNow != nil {
|
||||||
|
select {
|
||||||
|
case u.checkNow <- struct{}{}:
|
||||||
|
default:
|
||||||
|
// Check already pending, no need to queue another
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *Updater) StartBackgroundUpdaterChecker(ctx context.Context, cb func(string) error) {
|
func (u *Updater) StartBackgroundUpdaterChecker(ctx context.Context, cb func(string) error) {
|
||||||
|
u.checkNow = make(chan struct{}, 1)
|
||||||
|
u.checkNow <- struct{}{} // Trigger first check after initial delay
|
||||||
go func() {
|
go func() {
|
||||||
// Don't blast an update message immediately after startup
|
// Don't blast an update message immediately after startup
|
||||||
time.Sleep(UpdateCheckInitialDelay)
|
time.Sleep(UpdateCheckInitialDelay)
|
||||||
slog.Info("beginning update checker", "interval", UpdateCheckInterval)
|
slog.Info("beginning update checker", "interval", UpdateCheckInterval)
|
||||||
|
ticker := time.NewTicker(UpdateCheckInterval)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
for {
|
for {
|
||||||
available, resp := u.checkForUpdate(ctx)
|
|
||||||
if available {
|
|
||||||
err := u.DownloadNewRelease(ctx, resp)
|
|
||||||
if err != nil {
|
|
||||||
slog.Error(fmt.Sprintf("failed to download new release: %s", err))
|
|
||||||
} else {
|
|
||||||
err = cb(resp.UpdateVersion)
|
|
||||||
if err != nil {
|
|
||||||
slog.Warn(fmt.Sprintf("failed to register update available with tray: %s", err))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
slog.Debug("stopping background update checker")
|
slog.Debug("stopping background update checker")
|
||||||
return
|
return
|
||||||
default:
|
case <-u.checkNow:
|
||||||
time.Sleep(UpdateCheckInterval)
|
// Immediate check triggered
|
||||||
|
case <-ticker.C:
|
||||||
|
// Regular interval check
|
||||||
|
}
|
||||||
|
|
||||||
|
// Always check for updates
|
||||||
|
available, resp := u.checkForUpdate(ctx)
|
||||||
|
if !available {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update is available - check if auto-update is enabled for downloading
|
||||||
|
settings, err := u.Store.Settings()
|
||||||
|
if err != nil {
|
||||||
|
slog.Error("failed to load settings", "error", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if !settings.AutoUpdateEnabled {
|
||||||
|
// Auto-update disabled - don't download, just log
|
||||||
|
slog.Debug("update available but auto-update disabled", "version", resp.UpdateVersion)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Auto-update is enabled - download
|
||||||
|
err = u.DownloadNewRelease(ctx, resp)
|
||||||
|
if err != nil {
|
||||||
|
slog.Error("failed to download new release", "error", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Download successful - show tray notification
|
||||||
|
err = cb(resp.UpdateVersion)
|
||||||
|
if err != nil {
|
||||||
|
slog.Warn("failed to register update available with tray", "error", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|||||||
@@ -11,6 +11,8 @@ import (
|
|||||||
"log/slog"
|
"log/slog"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
|
"path/filepath"
|
||||||
|
"sync/atomic"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -33,7 +35,7 @@ func TestIsNewReleaseAvailable(t *testing.T) {
|
|||||||
defer server.Close()
|
defer server.Close()
|
||||||
slog.Debug("server", "url", server.URL)
|
slog.Debug("server", "url", server.URL)
|
||||||
|
|
||||||
updater := &Updater{Store: &store.Store{}}
|
updater := &Updater{Store: &store.Store{DBPath: filepath.Join(t.TempDir(), "test.db")}}
|
||||||
defer updater.Store.Close() // Ensure database is closed
|
defer updater.Store.Close() // Ensure database is closed
|
||||||
UpdateCheckURLBase = server.URL + "/update.json"
|
UpdateCheckURLBase = server.URL + "/update.json"
|
||||||
updatePresent, resp := updater.checkForUpdate(t.Context())
|
updatePresent, resp := updater.checkForUpdate(t.Context())
|
||||||
@@ -84,8 +86,18 @@ func TestBackgoundChecker(t *testing.T) {
|
|||||||
defer server.Close()
|
defer server.Close()
|
||||||
UpdateCheckURLBase = server.URL + "/update.json"
|
UpdateCheckURLBase = server.URL + "/update.json"
|
||||||
|
|
||||||
updater := &Updater{Store: &store.Store{}}
|
updater := &Updater{Store: &store.Store{DBPath: filepath.Join(t.TempDir(), "test.db")}}
|
||||||
defer updater.Store.Close() // Ensure database is closed
|
defer updater.Store.Close()
|
||||||
|
|
||||||
|
settings, err := updater.Store.Settings()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
settings.AutoUpdateEnabled = true
|
||||||
|
if err := updater.Store.SetSettings(settings); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
updater.StartBackgroundUpdaterChecker(ctx, cb)
|
updater.StartBackgroundUpdaterChecker(ctx, cb)
|
||||||
select {
|
select {
|
||||||
case <-stallTimer.C:
|
case <-stallTimer.C:
|
||||||
@@ -99,3 +111,267 @@ func TestBackgoundChecker(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestAutoUpdateDisabledSkipsDownload(t *testing.T) {
|
||||||
|
UpdateStageDir = t.TempDir()
|
||||||
|
var downloadAttempted atomic.Bool
|
||||||
|
done := make(chan struct{})
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(t.Context())
|
||||||
|
defer cancel()
|
||||||
|
UpdateCheckInitialDelay = 5 * time.Millisecond
|
||||||
|
UpdateCheckInterval = 5 * time.Millisecond
|
||||||
|
VerifyDownload = func() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var server *httptest.Server
|
||||||
|
server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.URL.Path == "/update.json" {
|
||||||
|
w.Write([]byte(
|
||||||
|
fmt.Sprintf(`{"version": "9.9.9", "url": "%s"}`,
|
||||||
|
server.URL+"/9.9.9/"+Installer)))
|
||||||
|
} else if r.URL.Path == "/9.9.9/"+Installer {
|
||||||
|
downloadAttempted.Store(true)
|
||||||
|
buf := &bytes.Buffer{}
|
||||||
|
zw := zip.NewWriter(buf)
|
||||||
|
zw.Close()
|
||||||
|
io.Copy(w, buf)
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
UpdateCheckURLBase = server.URL + "/update.json"
|
||||||
|
|
||||||
|
updater := &Updater{Store: &store.Store{DBPath: filepath.Join(t.TempDir(), "test.db")}}
|
||||||
|
defer updater.Store.Close()
|
||||||
|
|
||||||
|
// Ensure auto-update is disabled
|
||||||
|
settings, err := updater.Store.Settings()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
settings.AutoUpdateEnabled = false
|
||||||
|
if err := updater.Store.SetSettings(settings); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cb := func(ver string) error {
|
||||||
|
t.Fatal("callback should not be called when auto-update is disabled")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
updater.StartBackgroundUpdaterChecker(ctx, cb)
|
||||||
|
|
||||||
|
// Wait enough time for multiple check cycles
|
||||||
|
time.Sleep(50 * time.Millisecond)
|
||||||
|
close(done)
|
||||||
|
|
||||||
|
if downloadAttempted.Load() {
|
||||||
|
t.Fatal("download should not be attempted when auto-update is disabled")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAutoUpdateReenabledDownloadsUpdate(t *testing.T) {
|
||||||
|
UpdateStageDir = t.TempDir()
|
||||||
|
var downloadAttempted atomic.Bool
|
||||||
|
callbackCalled := make(chan struct{}, 1)
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(t.Context())
|
||||||
|
defer cancel()
|
||||||
|
UpdateCheckInitialDelay = 5 * time.Millisecond
|
||||||
|
UpdateCheckInterval = 5 * time.Millisecond
|
||||||
|
VerifyDownload = func() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var server *httptest.Server
|
||||||
|
server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.URL.Path == "/update.json" {
|
||||||
|
w.Write([]byte(
|
||||||
|
fmt.Sprintf(`{"version": "9.9.9", "url": "%s"}`,
|
||||||
|
server.URL+"/9.9.9/"+Installer)))
|
||||||
|
} else if r.URL.Path == "/9.9.9/"+Installer {
|
||||||
|
downloadAttempted.Store(true)
|
||||||
|
buf := &bytes.Buffer{}
|
||||||
|
zw := zip.NewWriter(buf)
|
||||||
|
zw.Close()
|
||||||
|
io.Copy(w, buf)
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
UpdateCheckURLBase = server.URL + "/update.json"
|
||||||
|
|
||||||
|
upd := &Updater{Store: &store.Store{DBPath: filepath.Join(t.TempDir(), "test.db")}}
|
||||||
|
defer upd.Store.Close()
|
||||||
|
|
||||||
|
// Start with auto-update disabled
|
||||||
|
settings, err := upd.Store.Settings()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
settings.AutoUpdateEnabled = false
|
||||||
|
if err := upd.Store.SetSettings(settings); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cb := func(ver string) error {
|
||||||
|
select {
|
||||||
|
case callbackCalled <- struct{}{}:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
upd.StartBackgroundUpdaterChecker(ctx, cb)
|
||||||
|
|
||||||
|
// Wait for a few cycles with auto-update disabled - no download should happen
|
||||||
|
time.Sleep(50 * time.Millisecond)
|
||||||
|
if downloadAttempted.Load() {
|
||||||
|
t.Fatal("download should not happen while auto-update is disabled")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Re-enable auto-update
|
||||||
|
settings.AutoUpdateEnabled = true
|
||||||
|
if err := upd.Store.SetSettings(settings); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for the checker to pick it up and download
|
||||||
|
select {
|
||||||
|
case <-callbackCalled:
|
||||||
|
// Success: download happened and callback was called after re-enabling
|
||||||
|
if !downloadAttempted.Load() {
|
||||||
|
t.Fatal("expected download to be attempted after re-enabling")
|
||||||
|
}
|
||||||
|
case <-time.After(5 * time.Second):
|
||||||
|
t.Fatal("expected download and callback after re-enabling auto-update")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCancelOngoingDownload(t *testing.T) {
|
||||||
|
UpdateStageDir = t.TempDir()
|
||||||
|
downloadStarted := make(chan struct{})
|
||||||
|
downloadCancelled := make(chan struct{})
|
||||||
|
|
||||||
|
ctx := t.Context()
|
||||||
|
VerifyDownload = func() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var server *httptest.Server
|
||||||
|
server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.URL.Path == "/update.json" {
|
||||||
|
w.Write([]byte(
|
||||||
|
fmt.Sprintf(`{"version": "9.9.9", "url": "%s"}`,
|
||||||
|
server.URL+"/9.9.9/"+Installer)))
|
||||||
|
} else if r.URL.Path == "/9.9.9/"+Installer {
|
||||||
|
if r.Method == http.MethodHead {
|
||||||
|
w.Header().Set("Content-Length", "1000000")
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// Signal that download has started
|
||||||
|
close(downloadStarted)
|
||||||
|
// Wait for cancellation or timeout
|
||||||
|
select {
|
||||||
|
case <-r.Context().Done():
|
||||||
|
close(downloadCancelled)
|
||||||
|
return
|
||||||
|
case <-time.After(5 * time.Second):
|
||||||
|
t.Error("download was not cancelled in time")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
UpdateCheckURLBase = server.URL + "/update.json"
|
||||||
|
|
||||||
|
updater := &Updater{Store: &store.Store{DBPath: filepath.Join(t.TempDir(), "test.db")}}
|
||||||
|
defer updater.Store.Close()
|
||||||
|
|
||||||
|
_, resp := updater.checkForUpdate(ctx)
|
||||||
|
|
||||||
|
// Start download in goroutine
|
||||||
|
go func() {
|
||||||
|
_ = updater.DownloadNewRelease(ctx, resp)
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Wait for download to start
|
||||||
|
select {
|
||||||
|
case <-downloadStarted:
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Fatal("download did not start in time")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Cancel the download
|
||||||
|
updater.CancelOngoingDownload()
|
||||||
|
|
||||||
|
// Verify cancellation was received
|
||||||
|
select {
|
||||||
|
case <-downloadCancelled:
|
||||||
|
// Success
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Fatal("download cancellation was not received by server")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTriggerImmediateCheck(t *testing.T) {
|
||||||
|
UpdateStageDir = t.TempDir()
|
||||||
|
checkCount := atomic.Int32{}
|
||||||
|
checkDone := make(chan struct{}, 10)
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(t.Context())
|
||||||
|
defer cancel()
|
||||||
|
// Set a very long interval so only TriggerImmediateCheck causes checks
|
||||||
|
UpdateCheckInitialDelay = 1 * time.Millisecond
|
||||||
|
UpdateCheckInterval = 1 * time.Hour
|
||||||
|
VerifyDownload = func() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.URL.Path == "/update.json" {
|
||||||
|
checkCount.Add(1)
|
||||||
|
select {
|
||||||
|
case checkDone <- struct{}{}:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
// Return no update available
|
||||||
|
w.WriteHeader(http.StatusNoContent)
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
UpdateCheckURLBase = server.URL + "/update.json"
|
||||||
|
|
||||||
|
updater := &Updater{Store: &store.Store{DBPath: filepath.Join(t.TempDir(), "test.db")}}
|
||||||
|
defer updater.Store.Close()
|
||||||
|
|
||||||
|
cb := func(ver string) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
updater.StartBackgroundUpdaterChecker(ctx, cb)
|
||||||
|
|
||||||
|
// Wait for the initial check that fires after the initial delay
|
||||||
|
select {
|
||||||
|
case <-checkDone:
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Fatal("initial check did not happen")
|
||||||
|
}
|
||||||
|
|
||||||
|
initialCount := checkCount.Load()
|
||||||
|
|
||||||
|
// Trigger immediate check
|
||||||
|
updater.TriggerImmediateCheck()
|
||||||
|
|
||||||
|
// Wait for the triggered check
|
||||||
|
select {
|
||||||
|
case <-checkDone:
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Fatal("triggered check did not happen")
|
||||||
|
}
|
||||||
|
|
||||||
|
finalCount := checkCount.Load()
|
||||||
|
if finalCount <= initialCount {
|
||||||
|
t.Fatalf("TriggerImmediateCheck did not cause additional check: initial=%d, final=%d", initialCount, finalCount)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -369,25 +369,6 @@ func (t *winTray) addSeparatorMenuItem(menuItemId, parentId uint32) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// func (t *winTray) hideMenuItem(menuItemId, parentId uint32) error {
|
|
||||||
// const ERROR_SUCCESS syscall.Errno = 0
|
|
||||||
|
|
||||||
// t.muMenus.RLock()
|
|
||||||
// menu := uintptr(t.menus[parentId])
|
|
||||||
// t.muMenus.RUnlock()
|
|
||||||
// res, _, err := pRemoveMenu.Call(
|
|
||||||
// menu,
|
|
||||||
// uintptr(menuItemId),
|
|
||||||
// MF_BYCOMMAND,
|
|
||||||
// )
|
|
||||||
// if res == 0 && err.(syscall.Errno) != ERROR_SUCCESS {
|
|
||||||
// return err
|
|
||||||
// }
|
|
||||||
// t.delFromVisibleItems(parentId, menuItemId)
|
|
||||||
|
|
||||||
// return nil
|
|
||||||
// }
|
|
||||||
|
|
||||||
func (t *winTray) showMenu() error {
|
func (t *winTray) showMenu() error {
|
||||||
p := point{}
|
p := point{}
|
||||||
boolRet, _, err := pGetCursorPos.Call(uintptr(unsafe.Pointer(&p)))
|
boolRet, _, err := pGetCursorPos.Call(uintptr(unsafe.Pointer(&p)))
|
||||||
|
|||||||
@@ -51,7 +51,6 @@ const (
|
|||||||
IMAGE_ICON = 1 // Loads an icon
|
IMAGE_ICON = 1 // Loads an icon
|
||||||
LR_DEFAULTSIZE = 0x00000040 // Loads default-size icon for windows(SM_CXICON x SM_CYICON) if cx, cy are set to zero
|
LR_DEFAULTSIZE = 0x00000040 // Loads default-size icon for windows(SM_CXICON x SM_CYICON) if cx, cy are set to zero
|
||||||
LR_LOADFROMFILE = 0x00000010 // Loads the stand-alone image from the file
|
LR_LOADFROMFILE = 0x00000010 // Loads the stand-alone image from the file
|
||||||
MF_BYCOMMAND = 0x00000000
|
|
||||||
MFS_DISABLED = 0x00000003
|
MFS_DISABLED = 0x00000003
|
||||||
MFT_SEPARATOR = 0x00000800
|
MFT_SEPARATOR = 0x00000800
|
||||||
MFT_STRING = 0x00000000
|
MFT_STRING = 0x00000000
|
||||||
|
|||||||
@@ -1,27 +1,31 @@
|
|||||||
Ollama Benchmark Tool
|
Ollama Benchmark Tool
|
||||||
---------------------
|
---------------------
|
||||||
|
|
||||||
A Go-based command-line tool for benchmarking Ollama models with configurable parameters and multiple output formats.
|
A Go-based command-line tool for benchmarking Ollama models with configurable parameters, warmup phases, TTFT tracking, VRAM monitoring, and benchstat/CSV output.
|
||||||
|
|
||||||
## Features
|
## Features
|
||||||
|
|
||||||
* Benchmark multiple models in a single run
|
* Benchmark multiple models in a single run
|
||||||
* Support for both text and image prompts
|
* Support for both text and image prompts
|
||||||
* Configurable generation parameters (temperature, max tokens, seed, etc.)
|
* Configurable generation parameters (temperature, max tokens, seed, etc.)
|
||||||
* Supports benchstat and CSV output formats
|
* Warmup phase before timed epochs to stabilize measurements
|
||||||
* Detailed performance metrics (prefill, generate, load, total durations)
|
* Time-to-first-token (TTFT) tracking per epoch
|
||||||
|
* Model metadata display (parameter size, quantization level, family)
|
||||||
|
* VRAM and CPU memory usage tracking via running process info
|
||||||
|
* Controlled prompt token length for reproducible benchmarks
|
||||||
|
* Benchstat and CSV output formats
|
||||||
|
|
||||||
## Building from Source
|
## Building from Source
|
||||||
|
|
||||||
```
|
```
|
||||||
go build -o ollama-bench bench.go
|
go build -o ollama-bench ./cmd/bench
|
||||||
./ollama-bench -model gpt-oss:20b -epochs 6 -format csv
|
./ollama-bench -model gemma3 -epochs 6 -format csv
|
||||||
```
|
```
|
||||||
|
|
||||||
Using Go Run (without building)
|
Using Go Run (without building)
|
||||||
|
|
||||||
```
|
```
|
||||||
go run bench.go -model gpt-oss:20b -epochs 3
|
go run ./cmd/bench -model gemma3 -epochs 3
|
||||||
```
|
```
|
||||||
|
|
||||||
## Usage
|
## Usage
|
||||||
@@ -45,10 +49,16 @@ benchstat -col /name gemma.bench
|
|||||||
./ollama-bench -model qwen3-vl -image photo.jpg -epochs 6 -max-tokens 100 -p "Describe this image"
|
./ollama-bench -model qwen3-vl -image photo.jpg -epochs 6 -max-tokens 100 -p "Describe this image"
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### Controlled Prompt Length
|
||||||
|
|
||||||
|
```
|
||||||
|
./ollama-bench -model gemma3 -epochs 6 -prompt-tokens 512
|
||||||
|
```
|
||||||
|
|
||||||
### Advanced Example
|
### Advanced Example
|
||||||
|
|
||||||
```
|
```
|
||||||
./ollama-bench -model llama3 -epochs 10 -temperature 0.7 -max-tokens 500 -seed 42 -format csv -output results.csv
|
./ollama-bench -model llama3 -epochs 10 -temperature 0.7 -max-tokens 500 -seed 42 -warmup 2 -format csv -output results.csv
|
||||||
```
|
```
|
||||||
|
|
||||||
## Command Line Options
|
## Command Line Options
|
||||||
@@ -56,41 +66,48 @@ benchstat -col /name gemma.bench
|
|||||||
| Option | Description | Default |
|
| Option | Description | Default |
|
||||||
|----------|-------------|---------|
|
|----------|-------------|---------|
|
||||||
| -model | Comma-separated list of models to benchmark | (required) |
|
| -model | Comma-separated list of models to benchmark | (required) |
|
||||||
| -epochs | Number of iterations per model | 1 |
|
| -epochs | Number of iterations per model | 6 |
|
||||||
| -max-tokens | Maximum tokens for model response | 0 (unlimited) |
|
| -max-tokens | Maximum tokens for model response | 200 |
|
||||||
| -temperature | Temperature parameter | 0.0 |
|
| -temperature | Temperature parameter | 0.0 |
|
||||||
| -seed | Random seed | 0 (random) |
|
| -seed | Random seed | 0 (random) |
|
||||||
| -timeout | Timeout in seconds | 300 |
|
| -timeout | Timeout in seconds | 300 |
|
||||||
| -p | Prompt text | "Write a long story." |
|
| -p | Prompt text | (default story prompt) |
|
||||||
| -image | Image file to include in prompt | |
|
| -image | Image file to include in prompt | |
|
||||||
| -k | Keep-alive duration in seconds | 0 |
|
| -k | Keep-alive duration in seconds | 0 |
|
||||||
| -format | Output format (benchstat, csv) | benchstat |
|
| -format | Output format (benchstat, csv) | benchstat |
|
||||||
| -output | Output file for results | "" (stdout) |
|
| -output | Output file for results | "" (stdout) |
|
||||||
|
| -warmup | Number of warmup requests before timing | 1 |
|
||||||
|
| -prompt-tokens | Generate prompt targeting ~N tokens (0 = use -p) | 0 |
|
||||||
| -v | Verbose mode | false |
|
| -v | Verbose mode | false |
|
||||||
| -debug | Show debug information | false |
|
| -debug | Show debug information | false |
|
||||||
|
|
||||||
## Output Formats
|
## Output Formats
|
||||||
|
|
||||||
### Markdown Format
|
### Benchstat Format (default)
|
||||||
|
|
||||||
The default markdown format is suitable for copying and pasting into a GitHub issue and will look like:
|
Compatible with Go's benchstat tool for statistical analysis. Uses one value/unit pair per line, standard `ns/op` for timing metrics, and `ns/token` for throughput. Each epoch produces one set of lines -- benchstat aggregates across repeated runs to compute statistics.
|
||||||
```
|
|
||||||
Model | Step | Count | Duration | nsPerToken | tokensPerSec |
|
|
||||||
|-------|------|-------|----------|------------|--------------|
|
|
||||||
| gpt-oss:20b | prefill | 124 | 30.006458ms | 241987.56 | 4132.44 |
|
|
||||||
| gpt-oss:20b | generate | 200 | 2.646843954s | 13234219.77 | 75.56 |
|
|
||||||
| gpt-oss:20b | load | 1 | 121.674208ms | - | - |
|
|
||||||
| gpt-oss:20b | total | 1 | 2.861047625s | - | - |
|
|
||||||
```
|
|
||||||
|
|
||||||
### Benchstat Format
|
|
||||||
|
|
||||||
Compatible with Go's benchstat tool for statistical analysis:
|
|
||||||
|
|
||||||
```
|
```
|
||||||
BenchmarkModel/name=gpt-oss:20b/step=prefill 128 78125.00 ns/token 12800.00 token/sec
|
# Model: gemma3 | Params: 4.3B | Quant: Q4_K_M | Family: gemma3 | Size: 4080218931 | VRAM: 4080218931
|
||||||
BenchmarkModel/name=gpt-oss:20b/step=generate 512 19531.25 ns/token 51200.00 token/sec
|
BenchmarkModel/name=gemma3/step=prefill 1 78125.00 ns/token 12800.00 token/sec
|
||||||
BenchmarkModel/name=gpt-oss:20b/step=load 1 1500000000 ns/request
|
BenchmarkModel/name=gemma3/step=generate 1 19531.25 ns/token 51200.00 token/sec
|
||||||
|
BenchmarkModel/name=gemma3/step=ttft 1 45123000 ns/op
|
||||||
|
BenchmarkModel/name=gemma3/step=load 1 1500000000 ns/op
|
||||||
|
BenchmarkModel/name=gemma3/step=total 1 2861047625 ns/op
|
||||||
|
```
|
||||||
|
|
||||||
|
Use with benchstat:
|
||||||
|
```
|
||||||
|
./ollama-bench -model gemma3 -epochs 6 > gemma3.bench
|
||||||
|
benchstat -col /step gemma3.bench
|
||||||
|
```
|
||||||
|
|
||||||
|
Compare two runs:
|
||||||
|
```
|
||||||
|
./ollama-bench -model gemma3 -epochs 6 > before.bench
|
||||||
|
# ... make changes ...
|
||||||
|
./ollama-bench -model gemma3 -epochs 6 > after.bench
|
||||||
|
benchstat before.bench after.bench
|
||||||
```
|
```
|
||||||
|
|
||||||
### CSV Format
|
### CSV Format
|
||||||
@@ -99,17 +116,28 @@ Machine-readable comma-separated values:
|
|||||||
|
|
||||||
```
|
```
|
||||||
NAME,STEP,COUNT,NS_PER_COUNT,TOKEN_PER_SEC
|
NAME,STEP,COUNT,NS_PER_COUNT,TOKEN_PER_SEC
|
||||||
gpt-oss:20b,prefill,128,78125.00,12800.00
|
# Model: gemma3 | Params: 4.3B | Quant: Q4_K_M | Family: gemma3 | Size: 4080218931 | VRAM: 4080218931
|
||||||
gpt-oss:20b,generate,512,19531.25,51200.00
|
gemma3,prefill,128,78125.00,12800.00
|
||||||
gpt-oss:20b,load,1,1500000000,0
|
gemma3,generate,512,19531.25,51200.00
|
||||||
|
gemma3,ttft,1,45123000,0
|
||||||
|
gemma3,load,1,1500000000,0
|
||||||
|
gemma3,total,1,2861047625,0
|
||||||
```
|
```
|
||||||
|
|
||||||
## Metrics Explained
|
## Metrics Explained
|
||||||
|
|
||||||
The tool reports four types of metrics for each model:
|
The tool reports the following metrics for each epoch:
|
||||||
|
|
||||||
* prefill: Time spent processing the prompt
|
* **prefill**: Time spent processing the prompt (ns/token)
|
||||||
* generate: Time spent generating the response
|
* **generate**: Time spent generating the response (ns/token)
|
||||||
* load: Model loading time (one-time cost)
|
* **ttft**: Time to first token -- latency from request start to first response content
|
||||||
* total: Total request duration
|
* **load**: Model loading time (one-time cost)
|
||||||
|
* **total**: Total request duration
|
||||||
|
|
||||||
|
Additionally, the model info comment line (displayed once per model before epochs) includes:
|
||||||
|
|
||||||
|
* **Params**: Model parameter count (e.g., 4.3B)
|
||||||
|
* **Quant**: Quantization level (e.g., Q4_K_M)
|
||||||
|
* **Family**: Model family (e.g., gemma3)
|
||||||
|
* **Size**: Total model memory in bytes
|
||||||
|
* **VRAM**: GPU memory used by the loaded model (when Size > VRAM, the difference is CPU spill)
|
||||||
|
|||||||
@@ -30,6 +30,8 @@ type flagOptions struct {
|
|||||||
outputFile *string
|
outputFile *string
|
||||||
debug *bool
|
debug *bool
|
||||||
verbose *bool
|
verbose *bool
|
||||||
|
warmup *int
|
||||||
|
promptTokens *int
|
||||||
}
|
}
|
||||||
|
|
||||||
type Metrics struct {
|
type Metrics struct {
|
||||||
@@ -39,48 +41,169 @@ type Metrics struct {
|
|||||||
Duration time.Duration
|
Duration time.Duration
|
||||||
}
|
}
|
||||||
|
|
||||||
var once sync.Once
|
type ModelInfo struct {
|
||||||
|
Name string
|
||||||
|
ParameterSize string
|
||||||
|
QuantizationLevel string
|
||||||
|
Family string
|
||||||
|
SizeBytes int64
|
||||||
|
VRAMBytes int64
|
||||||
|
}
|
||||||
|
|
||||||
const DefaultPrompt = `Please write a descriptive story about a llama named Alonso who grows up to be President of the Land of Llamas. Include details about Alonso's childhood, adolescent years, and how he grew up to be a political mover and shaker. Write the story with a sense of whimsy.`
|
const DefaultPrompt = `Please write a descriptive story about a llama named Alonso who grows up to be President of the Land of Llamas. Include details about Alonso's childhood, adolescent years, and how he grew up to be a political mover and shaker. Write the story with a sense of whimsy.`
|
||||||
|
|
||||||
|
// Word list for generating prompts targeting a specific token count.
|
||||||
|
var promptWordList = []string{
|
||||||
|
"the", "quick", "brown", "fox", "jumps", "over", "lazy", "dog",
|
||||||
|
"a", "bright", "sunny", "day", "in", "the", "meadow", "where",
|
||||||
|
"flowers", "bloom", "and", "birds", "sing", "their", "morning",
|
||||||
|
"songs", "while", "gentle", "breeze", "carries", "sweet", "scent",
|
||||||
|
"of", "pine", "trees", "across", "rolling", "hills", "toward",
|
||||||
|
"distant", "mountains", "covered", "with", "fresh", "snow",
|
||||||
|
"beneath", "clear", "blue", "sky", "children", "play", "near",
|
||||||
|
"old", "stone", "bridge", "that", "crosses", "winding", "river",
|
||||||
|
}
|
||||||
|
|
||||||
|
func generatePromptForTokenCount(targetTokens int, epoch int) string {
|
||||||
|
// ~1.3 tokens per word heuristic
|
||||||
|
targetWords := int(float64(targetTokens) / 1.3)
|
||||||
|
if targetWords < 1 {
|
||||||
|
targetWords = 1
|
||||||
|
}
|
||||||
|
|
||||||
|
// Vary the starting offset by epoch to defeat KV cache prefix matching
|
||||||
|
offset := epoch * 7 // stride by a prime to get good distribution
|
||||||
|
n := len(promptWordList)
|
||||||
|
words := make([]string, targetWords)
|
||||||
|
for i := range words {
|
||||||
|
words[i] = promptWordList[((i+offset)%n+n)%n]
|
||||||
|
}
|
||||||
|
return strings.Join(words, " ")
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildGenerateRequest(model string, fOpt flagOptions, imgData api.ImageData, epoch int) *api.GenerateRequest {
|
||||||
|
options := make(map[string]interface{})
|
||||||
|
if *fOpt.maxTokens > 0 {
|
||||||
|
options["num_predict"] = *fOpt.maxTokens
|
||||||
|
}
|
||||||
|
options["temperature"] = *fOpt.temperature
|
||||||
|
if fOpt.seed != nil && *fOpt.seed > 0 {
|
||||||
|
options["seed"] = *fOpt.seed
|
||||||
|
}
|
||||||
|
|
||||||
|
var keepAliveDuration *api.Duration
|
||||||
|
if *fOpt.keepAlive > 0 {
|
||||||
|
duration := api.Duration{Duration: time.Duration(*fOpt.keepAlive * float64(time.Second))}
|
||||||
|
keepAliveDuration = &duration
|
||||||
|
}
|
||||||
|
|
||||||
|
prompt := *fOpt.prompt
|
||||||
|
if *fOpt.promptTokens > 0 {
|
||||||
|
prompt = generatePromptForTokenCount(*fOpt.promptTokens, epoch)
|
||||||
|
} else {
|
||||||
|
// Vary the prompt per epoch to defeat KV cache prefix matching
|
||||||
|
prompt = fmt.Sprintf("[%d] %s", epoch, prompt)
|
||||||
|
}
|
||||||
|
|
||||||
|
req := &api.GenerateRequest{
|
||||||
|
Model: model,
|
||||||
|
Prompt: prompt,
|
||||||
|
Raw: true,
|
||||||
|
Options: options,
|
||||||
|
KeepAlive: keepAliveDuration,
|
||||||
|
}
|
||||||
|
|
||||||
|
if imgData != nil {
|
||||||
|
req.Images = []api.ImageData{imgData}
|
||||||
|
}
|
||||||
|
|
||||||
|
return req
|
||||||
|
}
|
||||||
|
|
||||||
|
func fetchModelInfo(ctx context.Context, client *api.Client, model string) ModelInfo {
|
||||||
|
info := ModelInfo{Name: model}
|
||||||
|
resp, err := client.Show(ctx, &api.ShowRequest{Model: model})
|
||||||
|
if err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "WARNING: Could not fetch model info for '%s': %v\n", model, err)
|
||||||
|
return info
|
||||||
|
}
|
||||||
|
info.ParameterSize = resp.Details.ParameterSize
|
||||||
|
info.QuantizationLevel = resp.Details.QuantizationLevel
|
||||||
|
info.Family = resp.Details.Family
|
||||||
|
return info
|
||||||
|
}
|
||||||
|
|
||||||
|
func fetchMemoryUsage(ctx context.Context, client *api.Client, model string) (size, vram int64) {
|
||||||
|
resp, err := client.ListRunning(ctx)
|
||||||
|
if err != nil {
|
||||||
|
if debug := os.Getenv("OLLAMA_DEBUG"); debug != "" {
|
||||||
|
fmt.Fprintf(os.Stderr, "WARNING: Could not fetch memory usage: %v\n", err)
|
||||||
|
}
|
||||||
|
return 0, 0
|
||||||
|
}
|
||||||
|
for _, m := range resp.Models {
|
||||||
|
if m.Name == model || m.Model == model {
|
||||||
|
return m.Size, m.SizeVRAM
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Try prefix match (model names may include :latest or tags)
|
||||||
|
for _, m := range resp.Models {
|
||||||
|
if strings.HasPrefix(m.Name, model) || strings.HasPrefix(m.Model, model) {
|
||||||
|
return m.Size, m.SizeVRAM
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return 0, 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func outputFormatHeader(w io.Writer, format string, verbose bool) {
|
||||||
|
switch format {
|
||||||
|
case "benchstat":
|
||||||
|
if verbose {
|
||||||
|
fmt.Fprintf(w, "goos: %s\n", runtime.GOOS)
|
||||||
|
fmt.Fprintf(w, "goarch: %s\n", runtime.GOARCH)
|
||||||
|
}
|
||||||
|
case "csv":
|
||||||
|
headings := []string{"NAME", "STEP", "COUNT", "NS_PER_COUNT", "TOKEN_PER_SEC"}
|
||||||
|
fmt.Fprintln(w, strings.Join(headings, ","))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func outputModelInfo(w io.Writer, format string, info ModelInfo) {
|
||||||
|
params := cmp.Or(info.ParameterSize, "unknown")
|
||||||
|
quant := cmp.Or(info.QuantizationLevel, "unknown")
|
||||||
|
family := cmp.Or(info.Family, "unknown")
|
||||||
|
|
||||||
|
memStr := ""
|
||||||
|
if info.SizeBytes > 0 {
|
||||||
|
memStr = fmt.Sprintf(" | Size: %d | VRAM: %d", info.SizeBytes, info.VRAMBytes)
|
||||||
|
}
|
||||||
|
fmt.Fprintf(w, "# Model: %s | Params: %s | Quant: %s | Family: %s%s\n",
|
||||||
|
info.Name, params, quant, family, memStr)
|
||||||
|
}
|
||||||
|
|
||||||
func OutputMetrics(w io.Writer, format string, metrics []Metrics, verbose bool) {
|
func OutputMetrics(w io.Writer, format string, metrics []Metrics, verbose bool) {
|
||||||
switch format {
|
switch format {
|
||||||
case "benchstat":
|
case "benchstat":
|
||||||
if verbose {
|
|
||||||
printHeader := func() {
|
|
||||||
fmt.Fprintf(w, "sysname: %s\n", runtime.GOOS)
|
|
||||||
fmt.Fprintf(w, "machine: %s\n", runtime.GOARCH)
|
|
||||||
}
|
|
||||||
once.Do(printHeader)
|
|
||||||
}
|
|
||||||
for _, m := range metrics {
|
for _, m := range metrics {
|
||||||
if m.Step == "generate" || m.Step == "prefill" {
|
if m.Step == "generate" || m.Step == "prefill" {
|
||||||
if m.Count > 0 {
|
if m.Count > 0 {
|
||||||
nsPerToken := float64(m.Duration.Nanoseconds()) / float64(m.Count)
|
nsPerToken := float64(m.Duration.Nanoseconds()) / float64(m.Count)
|
||||||
tokensPerSec := float64(m.Count) / (float64(m.Duration.Nanoseconds()) + 1e-12) * 1e9
|
tokensPerSec := float64(m.Count) / (float64(m.Duration.Nanoseconds()) + 1e-12) * 1e9
|
||||||
|
fmt.Fprintf(w, "BenchmarkModel/name=%s/step=%s 1 %.2f ns/token %.2f token/sec\n",
|
||||||
fmt.Fprintf(w, "BenchmarkModel/name=%s/step=%s %d %.2f ns/token %.2f token/sec\n",
|
m.Model, m.Step, nsPerToken, tokensPerSec)
|
||||||
m.Model, m.Step, m.Count, nsPerToken, tokensPerSec)
|
|
||||||
} else {
|
} else {
|
||||||
fmt.Fprintf(w, "BenchmarkModel/name=%s/step=%s %d 0 ns/token 0 token/sec\n",
|
fmt.Fprintf(w, "BenchmarkModel/name=%s/step=%s 1 0 ns/token 0 token/sec\n",
|
||||||
m.Model, m.Step, m.Count)
|
m.Model, m.Step)
|
||||||
}
|
}
|
||||||
|
} else if m.Step == "ttft" {
|
||||||
|
fmt.Fprintf(w, "BenchmarkModel/name=%s/step=ttft 1 %d ns/op\n",
|
||||||
|
m.Model, m.Duration.Nanoseconds())
|
||||||
} else {
|
} else {
|
||||||
var suffix string
|
fmt.Fprintf(w, "BenchmarkModel/name=%s/step=%s 1 %d ns/op\n",
|
||||||
if m.Step == "load" {
|
m.Model, m.Step, m.Duration.Nanoseconds())
|
||||||
suffix = "/step=load"
|
|
||||||
}
|
|
||||||
fmt.Fprintf(w, "BenchmarkModel/name=%s%s 1 %d ns/request\n",
|
|
||||||
m.Model, suffix, m.Duration.Nanoseconds())
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
case "csv":
|
case "csv":
|
||||||
printHeader := func() {
|
|
||||||
headings := []string{"NAME", "STEP", "COUNT", "NS_PER_COUNT", "TOKEN_PER_SEC"}
|
|
||||||
fmt.Fprintln(w, strings.Join(headings, ","))
|
|
||||||
}
|
|
||||||
once.Do(printHeader)
|
|
||||||
|
|
||||||
for _, m := range metrics {
|
for _, m := range metrics {
|
||||||
if m.Step == "generate" || m.Step == "prefill" {
|
if m.Step == "generate" || m.Step == "prefill" {
|
||||||
var nsPerToken float64
|
var nsPerToken float64
|
||||||
@@ -94,39 +217,14 @@ func OutputMetrics(w io.Writer, format string, metrics []Metrics, verbose bool)
|
|||||||
fmt.Fprintf(w, "%s,%s,1,%d,0\n", m.Model, m.Step, m.Duration.Nanoseconds())
|
fmt.Fprintf(w, "%s,%s,1,%d,0\n", m.Model, m.Step, m.Duration.Nanoseconds())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
case "markdown":
|
|
||||||
printHeader := func() {
|
|
||||||
fmt.Fprintln(w, "| Model | Step | Count | Duration | nsPerToken | tokensPerSec |")
|
|
||||||
fmt.Fprintln(w, "|-------|------|-------|----------|------------|--------------|")
|
|
||||||
}
|
|
||||||
once.Do(printHeader)
|
|
||||||
|
|
||||||
for _, m := range metrics {
|
|
||||||
var nsPerToken, tokensPerSec float64
|
|
||||||
var nsPerTokenStr, tokensPerSecStr string
|
|
||||||
|
|
||||||
if m.Step == "generate" || m.Step == "prefill" {
|
|
||||||
nsPerToken = float64(m.Duration.Nanoseconds()) / float64(m.Count)
|
|
||||||
tokensPerSec = float64(m.Count) / (float64(m.Duration.Nanoseconds()) + 1e-12) * 1e9
|
|
||||||
nsPerTokenStr = fmt.Sprintf("%.2f", nsPerToken)
|
|
||||||
tokensPerSecStr = fmt.Sprintf("%.2f", tokensPerSec)
|
|
||||||
} else {
|
|
||||||
nsPerTokenStr = "-"
|
|
||||||
tokensPerSecStr = "-"
|
|
||||||
}
|
|
||||||
|
|
||||||
fmt.Fprintf(w, "| %s | %s | %d | %v | %s | %s |\n",
|
|
||||||
m.Model, m.Step, m.Count, m.Duration, nsPerTokenStr, tokensPerSecStr)
|
|
||||||
}
|
|
||||||
default:
|
default:
|
||||||
fmt.Fprintf(os.Stderr, "Unknown output format '%s'\n", format)
|
fmt.Fprintf(os.Stderr, "Unknown output format '%s'\n", format)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func BenchmarkChat(fOpt flagOptions) error {
|
func BenchmarkModel(fOpt flagOptions) error {
|
||||||
models := strings.Split(*fOpt.models, ",")
|
models := strings.Split(*fOpt.models, ",")
|
||||||
|
|
||||||
// todo - add multi-image support
|
|
||||||
var imgData api.ImageData
|
var imgData api.ImageData
|
||||||
var err error
|
var err error
|
||||||
if *fOpt.imageFile != "" {
|
if *fOpt.imageFile != "" {
|
||||||
@@ -158,54 +256,83 @@ func BenchmarkChat(fOpt flagOptions) error {
|
|||||||
out = f
|
out = f
|
||||||
}
|
}
|
||||||
|
|
||||||
|
outputFormatHeader(out, *fOpt.format, *fOpt.verbose)
|
||||||
|
|
||||||
|
// Log prompt-tokens info in debug mode
|
||||||
|
if *fOpt.debug && *fOpt.promptTokens > 0 {
|
||||||
|
prompt := generatePromptForTokenCount(*fOpt.promptTokens, 0)
|
||||||
|
wordCount := len(strings.Fields(prompt))
|
||||||
|
fmt.Fprintf(os.Stderr, "Generated prompt targeting ~%d tokens (%d words, varied per epoch)\n", *fOpt.promptTokens, wordCount)
|
||||||
|
}
|
||||||
|
|
||||||
for _, model := range models {
|
for _, model := range models {
|
||||||
for range *fOpt.epochs {
|
// Fetch model info
|
||||||
options := make(map[string]interface{})
|
infoCtx, infoCancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||||
if *fOpt.maxTokens > 0 {
|
info := fetchModelInfo(infoCtx, client, model)
|
||||||
options["num_predict"] = *fOpt.maxTokens
|
infoCancel()
|
||||||
|
|
||||||
|
// Warmup phase (uses negative epoch numbers to avoid colliding with timed epochs)
|
||||||
|
for i := range *fOpt.warmup {
|
||||||
|
req := buildGenerateRequest(model, fOpt, imgData, -(i + 1))
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(*fOpt.timeout)*time.Second)
|
||||||
|
|
||||||
|
err = client.Generate(ctx, req, func(resp api.GenerateResponse) error {
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
cancel()
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "WARNING: Warmup %d/%d for %s failed: %v\n", i+1, *fOpt.warmup, model, err)
|
||||||
|
} else if *fOpt.debug {
|
||||||
|
fmt.Fprintf(os.Stderr, "Warmup %d/%d for %s complete\n", i+1, *fOpt.warmup, model)
|
||||||
}
|
}
|
||||||
options["temperature"] = *fOpt.temperature
|
|
||||||
if fOpt.seed != nil && *fOpt.seed > 0 {
|
|
||||||
options["seed"] = *fOpt.seed
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var keepAliveDuration *api.Duration
|
// Fetch memory usage once after warmup (model is loaded and stable)
|
||||||
if *fOpt.keepAlive > 0 {
|
memCtx, memCancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
duration := api.Duration{Duration: time.Duration(*fOpt.keepAlive * float64(time.Second))}
|
info.SizeBytes, info.VRAMBytes = fetchMemoryUsage(memCtx, client, model)
|
||||||
keepAliveDuration = &duration
|
memCancel()
|
||||||
}
|
|
||||||
|
|
||||||
req := &api.ChatRequest{
|
outputModelInfo(out, *fOpt.format, info)
|
||||||
Model: model,
|
|
||||||
Messages: []api.Message{
|
|
||||||
{
|
|
||||||
Role: "user",
|
|
||||||
Content: *fOpt.prompt,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
Options: options,
|
|
||||||
KeepAlive: keepAliveDuration,
|
|
||||||
}
|
|
||||||
|
|
||||||
if imgData != nil {
|
|
||||||
req.Messages[0].Images = []api.ImageData{imgData}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
// Timed epoch loop
|
||||||
|
shortCount := 0
|
||||||
|
for epoch := range *fOpt.epochs {
|
||||||
var responseMetrics *api.Metrics
|
var responseMetrics *api.Metrics
|
||||||
|
var ttft time.Duration
|
||||||
|
short := false
|
||||||
|
|
||||||
|
// Retry loop: if the model hits a stop token before max-tokens,
|
||||||
|
// retry with a different prompt (up to maxRetries times).
|
||||||
|
const maxRetries = 3
|
||||||
|
for attempt := range maxRetries + 1 {
|
||||||
|
responseMetrics = nil
|
||||||
|
ttft = 0
|
||||||
|
var ttftOnce sync.Once
|
||||||
|
|
||||||
|
req := buildGenerateRequest(model, fOpt, imgData, epoch+attempt*1000)
|
||||||
|
requestStart := time.Now()
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(*fOpt.timeout)*time.Second)
|
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(*fOpt.timeout)*time.Second)
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
err = client.Chat(ctx, req, func(resp api.ChatResponse) error {
|
err = client.Generate(ctx, req, func(resp api.GenerateResponse) error {
|
||||||
if *fOpt.debug {
|
if *fOpt.debug {
|
||||||
fmt.Fprintf(os.Stderr, "%s", cmp.Or(resp.Message.Thinking, resp.Message.Content))
|
fmt.Fprintf(os.Stderr, "%s", cmp.Or(resp.Thinking, resp.Response))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Capture TTFT on first content
|
||||||
|
ttftOnce.Do(func() {
|
||||||
|
if resp.Response != "" || resp.Thinking != "" {
|
||||||
|
ttft = time.Since(requestStart)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
if resp.Done {
|
if resp.Done {
|
||||||
responseMetrics = &resp.Metrics
|
responseMetrics = &resp.Metrics
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
|
cancel()
|
||||||
|
|
||||||
if *fOpt.debug {
|
if *fOpt.debug {
|
||||||
fmt.Fprintln(os.Stderr)
|
fmt.Fprintln(os.Stderr)
|
||||||
@@ -213,18 +340,42 @@ func BenchmarkChat(fOpt flagOptions) error {
|
|||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if ctx.Err() == context.DeadlineExceeded {
|
if ctx.Err() == context.DeadlineExceeded {
|
||||||
fmt.Fprintf(os.Stderr, "ERROR: Chat request timed out with model '%s' after %vs\n", model, 1)
|
fmt.Fprintf(os.Stderr, "ERROR: Request timed out with model '%s' after %vs\n", model, *fOpt.timeout)
|
||||||
continue
|
} else {
|
||||||
|
fmt.Fprintf(os.Stderr, "ERROR: Couldn't generate with model '%s': %v\n", model, err)
|
||||||
}
|
}
|
||||||
fmt.Fprintf(os.Stderr, "ERROR: Couldn't chat with model '%s': %v\n", model, err)
|
break
|
||||||
continue
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if responseMetrics == nil {
|
if responseMetrics == nil {
|
||||||
fmt.Fprintf(os.Stderr, "ERROR: No metrics received for model '%s'\n", model)
|
fmt.Fprintf(os.Stderr, "ERROR: No metrics received for model '%s'\n", model)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if the response was shorter than requested
|
||||||
|
short = *fOpt.maxTokens > 0 && responseMetrics.EvalCount < *fOpt.maxTokens
|
||||||
|
if !short || attempt == maxRetries {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
if *fOpt.debug {
|
||||||
|
fmt.Fprintf(os.Stderr, "Short response (%d/%d tokens), retrying with different prompt (attempt %d/%d)\n",
|
||||||
|
responseMetrics.EvalCount, *fOpt.maxTokens, attempt+1, maxRetries)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil || responseMetrics == nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if short {
|
||||||
|
shortCount++
|
||||||
|
if *fOpt.debug {
|
||||||
|
fmt.Fprintf(os.Stderr, "WARNING: Short response (%d/%d tokens) after %d retries for epoch %d\n",
|
||||||
|
responseMetrics.EvalCount, *fOpt.maxTokens, maxRetries, epoch+1)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
metrics := []Metrics{
|
metrics := []Metrics{
|
||||||
{
|
{
|
||||||
Model: model,
|
Model: model,
|
||||||
@@ -238,6 +389,12 @@ func BenchmarkChat(fOpt flagOptions) error {
|
|||||||
Count: responseMetrics.EvalCount,
|
Count: responseMetrics.EvalCount,
|
||||||
Duration: responseMetrics.EvalDuration,
|
Duration: responseMetrics.EvalDuration,
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
Model: model,
|
||||||
|
Step: "ttft",
|
||||||
|
Count: 1,
|
||||||
|
Duration: ttft,
|
||||||
|
},
|
||||||
{
|
{
|
||||||
Model: model,
|
Model: model,
|
||||||
Step: "load",
|
Step: "load",
|
||||||
@@ -254,15 +411,42 @@ func BenchmarkChat(fOpt flagOptions) error {
|
|||||||
|
|
||||||
OutputMetrics(out, *fOpt.format, metrics, *fOpt.verbose)
|
OutputMetrics(out, *fOpt.format, metrics, *fOpt.verbose)
|
||||||
|
|
||||||
|
if *fOpt.debug && *fOpt.promptTokens > 0 {
|
||||||
|
fmt.Fprintf(os.Stderr, "Generated prompt targeting ~%d tokens (actual: %d)\n",
|
||||||
|
*fOpt.promptTokens, responseMetrics.PromptEvalCount)
|
||||||
|
}
|
||||||
|
|
||||||
if *fOpt.keepAlive > 0 {
|
if *fOpt.keepAlive > 0 {
|
||||||
time.Sleep(time.Duration(*fOpt.keepAlive*float64(time.Second)) + 200*time.Millisecond)
|
time.Sleep(time.Duration(*fOpt.keepAlive*float64(time.Second)) + 200*time.Millisecond)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if shortCount > 0 {
|
||||||
|
fmt.Fprintf(os.Stderr, "WARNING: %d/%d epochs for '%s' had short responses (<%d tokens). Generation metrics may be unreliable.\n",
|
||||||
|
shortCount, *fOpt.epochs, model, *fOpt.maxTokens)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unload model before moving to the next one
|
||||||
|
unloadModel(client, model, *fOpt.timeout)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func unloadModel(client *api.Client, model string, timeout int) {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(timeout)*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
zero := api.Duration{Duration: 0}
|
||||||
|
req := &api.GenerateRequest{
|
||||||
|
Model: model,
|
||||||
|
KeepAlive: &zero,
|
||||||
|
}
|
||||||
|
_ = client.Generate(ctx, req, func(resp api.GenerateResponse) error {
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func readImage(filePath string) (api.ImageData, error) {
|
func readImage(filePath string) (api.ImageData, error) {
|
||||||
file, err := os.Open(filePath)
|
file, err := os.Open(filePath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -289,10 +473,12 @@ func main() {
|
|||||||
prompt: flag.String("p", DefaultPrompt, "Prompt to use"),
|
prompt: flag.String("p", DefaultPrompt, "Prompt to use"),
|
||||||
imageFile: flag.String("image", "", "Filename for an image to include"),
|
imageFile: flag.String("image", "", "Filename for an image to include"),
|
||||||
keepAlive: flag.Float64("k", 0, "Keep alive duration in seconds"),
|
keepAlive: flag.Float64("k", 0, "Keep alive duration in seconds"),
|
||||||
format: flag.String("format", "markdown", "Output format [benchstat|csv] (default benchstat)"),
|
format: flag.String("format", "benchstat", "Output format [benchstat|csv]"),
|
||||||
outputFile: flag.String("output", "", "Output file for results (stdout if empty)"),
|
outputFile: flag.String("output", "", "Output file for results (stdout if empty)"),
|
||||||
verbose: flag.Bool("v", false, "Show system information"),
|
verbose: flag.Bool("v", false, "Show system information"),
|
||||||
debug: flag.Bool("debug", false, "Show debug information"),
|
debug: flag.Bool("debug", false, "Show debug information"),
|
||||||
|
warmup: flag.Int("warmup", 1, "Number of warmup requests before timing"),
|
||||||
|
promptTokens: flag.Int("prompt-tokens", 0, "Generate prompt targeting ~N tokens (0 = use -p prompt)"),
|
||||||
}
|
}
|
||||||
|
|
||||||
flag.Usage = func() {
|
flag.Usage = func() {
|
||||||
@@ -302,11 +488,12 @@ func main() {
|
|||||||
fmt.Fprintf(os.Stderr, "Options:\n")
|
fmt.Fprintf(os.Stderr, "Options:\n")
|
||||||
flag.PrintDefaults()
|
flag.PrintDefaults()
|
||||||
fmt.Fprintf(os.Stderr, "\nExamples:\n")
|
fmt.Fprintf(os.Stderr, "\nExamples:\n")
|
||||||
fmt.Fprintf(os.Stderr, " bench -model gpt-oss:20b -epochs 3 -temperature 0.7\n")
|
fmt.Fprintf(os.Stderr, " bench -model gemma3,llama3 -epochs 6\n")
|
||||||
|
fmt.Fprintf(os.Stderr, " bench -model gemma3 -epochs 6 -prompt-tokens 512 -format csv\n")
|
||||||
}
|
}
|
||||||
flag.Parse()
|
flag.Parse()
|
||||||
|
|
||||||
if !slices.Contains([]string{"markdown", "benchstat", "csv"}, *fOpt.format) {
|
if !slices.Contains([]string{"benchstat", "csv"}, *fOpt.format) {
|
||||||
fmt.Fprintf(os.Stderr, "ERROR: Unknown format '%s'\n", *fOpt.format)
|
fmt.Fprintf(os.Stderr, "ERROR: Unknown format '%s'\n", *fOpt.format)
|
||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
@@ -317,5 +504,5 @@ func main() {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
BenchmarkChat(fOpt)
|
BenchmarkModel(fOpt)
|
||||||
}
|
}
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
391
cmd/cmd.go
391
cmd/cmd.go
@@ -11,6 +11,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"log"
|
"log"
|
||||||
|
"log/slog"
|
||||||
"math"
|
"math"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
@@ -38,9 +39,12 @@ import (
|
|||||||
|
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
"github.com/ollama/ollama/cmd/config"
|
"github.com/ollama/ollama/cmd/config"
|
||||||
|
"github.com/ollama/ollama/cmd/launch"
|
||||||
"github.com/ollama/ollama/cmd/tui"
|
"github.com/ollama/ollama/cmd/tui"
|
||||||
"github.com/ollama/ollama/envconfig"
|
"github.com/ollama/ollama/envconfig"
|
||||||
"github.com/ollama/ollama/format"
|
"github.com/ollama/ollama/format"
|
||||||
|
"github.com/ollama/ollama/internal/modelref"
|
||||||
|
"github.com/ollama/ollama/logutil"
|
||||||
"github.com/ollama/ollama/parser"
|
"github.com/ollama/ollama/parser"
|
||||||
"github.com/ollama/ollama/progress"
|
"github.com/ollama/ollama/progress"
|
||||||
"github.com/ollama/ollama/readline"
|
"github.com/ollama/ollama/readline"
|
||||||
@@ -57,36 +61,42 @@ import (
|
|||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
// Override default selectors to use Bubbletea TUI instead of raw terminal I/O.
|
// Override default selectors to use Bubbletea TUI instead of raw terminal I/O.
|
||||||
config.DefaultSingleSelector = func(title string, items []config.ModelItem, current string) (string, error) {
|
launch.DefaultSingleSelector = func(title string, items []launch.ModelItem, current string) (string, error) {
|
||||||
|
if !term.IsTerminal(int(os.Stdin.Fd())) || !term.IsTerminal(int(os.Stdout.Fd())) {
|
||||||
|
return "", fmt.Errorf("model selection requires an interactive terminal; use --model to run in headless mode")
|
||||||
|
}
|
||||||
tuiItems := tui.ReorderItems(tui.ConvertItems(items))
|
tuiItems := tui.ReorderItems(tui.ConvertItems(items))
|
||||||
result, err := tui.SelectSingle(title, tuiItems, current)
|
result, err := tui.SelectSingle(title, tuiItems, current)
|
||||||
if errors.Is(err, tui.ErrCancelled) {
|
if errors.Is(err, tui.ErrCancelled) {
|
||||||
return "", config.ErrCancelled
|
return "", launch.ErrCancelled
|
||||||
}
|
}
|
||||||
return result, err
|
return result, err
|
||||||
}
|
}
|
||||||
|
|
||||||
config.DefaultMultiSelector = func(title string, items []config.ModelItem, preChecked []string) ([]string, error) {
|
launch.DefaultMultiSelector = func(title string, items []launch.ModelItem, preChecked []string) ([]string, error) {
|
||||||
|
if !term.IsTerminal(int(os.Stdin.Fd())) || !term.IsTerminal(int(os.Stdout.Fd())) {
|
||||||
|
return nil, fmt.Errorf("model selection requires an interactive terminal; use --model to run in headless mode")
|
||||||
|
}
|
||||||
tuiItems := tui.ReorderItems(tui.ConvertItems(items))
|
tuiItems := tui.ReorderItems(tui.ConvertItems(items))
|
||||||
result, err := tui.SelectMultiple(title, tuiItems, preChecked)
|
result, err := tui.SelectMultiple(title, tuiItems, preChecked)
|
||||||
if errors.Is(err, tui.ErrCancelled) {
|
if errors.Is(err, tui.ErrCancelled) {
|
||||||
return nil, config.ErrCancelled
|
return nil, launch.ErrCancelled
|
||||||
}
|
}
|
||||||
return result, err
|
return result, err
|
||||||
}
|
}
|
||||||
|
|
||||||
config.DefaultSignIn = func(modelName, signInURL string) (string, error) {
|
launch.DefaultSignIn = func(modelName, signInURL string) (string, error) {
|
||||||
userName, err := tui.RunSignIn(modelName, signInURL)
|
userName, err := tui.RunSignIn(modelName, signInURL)
|
||||||
if errors.Is(err, tui.ErrCancelled) {
|
if errors.Is(err, tui.ErrCancelled) {
|
||||||
return "", config.ErrCancelled
|
return "", launch.ErrCancelled
|
||||||
}
|
}
|
||||||
return userName, err
|
return userName, err
|
||||||
}
|
}
|
||||||
|
|
||||||
config.DefaultConfirmPrompt = func(prompt string) (bool, error) {
|
launch.DefaultConfirmPrompt = func(prompt string) (bool, error) {
|
||||||
ok, err := tui.RunConfirm(prompt)
|
ok, err := tui.RunConfirm(prompt)
|
||||||
if errors.Is(err, tui.ErrCancelled) {
|
if errors.Is(err, tui.ErrCancelled) {
|
||||||
return false, config.ErrCancelled
|
return false, launch.ErrCancelled
|
||||||
}
|
}
|
||||||
return ok, err
|
return ok, err
|
||||||
}
|
}
|
||||||
@@ -131,6 +141,17 @@ func getModelfileName(cmd *cobra.Command) (string, error) {
|
|||||||
return absName, nil
|
return absName, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// isLocalhost returns true if the configured Ollama host is a loopback or unspecified address.
|
||||||
|
func isLocalhost() bool {
|
||||||
|
host := envconfig.Host()
|
||||||
|
h, _, _ := net.SplitHostPort(host.Host)
|
||||||
|
if h == "localhost" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
ip := net.ParseIP(h)
|
||||||
|
return ip != nil && (ip.IsLoopback() || ip.IsUnspecified())
|
||||||
|
}
|
||||||
|
|
||||||
func CreateHandler(cmd *cobra.Command, args []string) error {
|
func CreateHandler(cmd *cobra.Command, args []string) error {
|
||||||
p := progress.NewProgress(os.Stderr)
|
p := progress.NewProgress(os.Stderr)
|
||||||
defer p.Stop()
|
defer p.Stop()
|
||||||
@@ -145,6 +166,9 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
|
|||||||
// Check for --experimental flag for safetensors model creation
|
// Check for --experimental flag for safetensors model creation
|
||||||
experimental, _ := cmd.Flags().GetBool("experimental")
|
experimental, _ := cmd.Flags().GetBool("experimental")
|
||||||
if experimental {
|
if experimental {
|
||||||
|
if !isLocalhost() {
|
||||||
|
return errors.New("remote safetensor model creation not yet supported")
|
||||||
|
}
|
||||||
// Get Modelfile content - either from -f flag or default to "FROM ."
|
// Get Modelfile content - either from -f flag or default to "FROM ."
|
||||||
var reader io.Reader
|
var reader io.Reader
|
||||||
filename, err := getModelfileName(cmd)
|
filename, err := getModelfileName(cmd)
|
||||||
@@ -168,29 +192,9 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
|
|||||||
return fmt.Errorf("failed to parse Modelfile: %w", err)
|
return fmt.Errorf("failed to parse Modelfile: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Extract FROM path and configuration
|
modelDir, mfConfig, err := xcreateclient.ConfigFromModelfile(modelfile)
|
||||||
var modelDir string
|
if err != nil {
|
||||||
mfConfig := &xcreateclient.ModelfileConfig{}
|
return err
|
||||||
|
|
||||||
for _, cmd := range modelfile.Commands {
|
|
||||||
switch cmd.Name {
|
|
||||||
case "model":
|
|
||||||
modelDir = cmd.Args
|
|
||||||
case "template":
|
|
||||||
mfConfig.Template = cmd.Args
|
|
||||||
case "system":
|
|
||||||
mfConfig.System = cmd.Args
|
|
||||||
case "license":
|
|
||||||
mfConfig.License = cmd.Args
|
|
||||||
case "parser":
|
|
||||||
mfConfig.Parser = cmd.Args
|
|
||||||
case "renderer":
|
|
||||||
mfConfig.Renderer = cmd.Args
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if modelDir == "" {
|
|
||||||
modelDir = "."
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Resolve relative paths based on Modelfile location
|
// Resolve relative paths based on Modelfile location
|
||||||
@@ -214,6 +218,9 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
|
|||||||
if filename == "" {
|
if filename == "" {
|
||||||
// No Modelfile found - check if current directory is an image gen model
|
// No Modelfile found - check if current directory is an image gen model
|
||||||
if create.IsTensorModelDir(".") {
|
if create.IsTensorModelDir(".") {
|
||||||
|
if !isLocalhost() {
|
||||||
|
return errors.New("remote safetensor model creation not yet supported")
|
||||||
|
}
|
||||||
quantize, _ := cmd.Flags().GetString("quantize")
|
quantize, _ := cmd.Flags().GetString("quantize")
|
||||||
return xcreateclient.CreateModel(xcreateclient.CreateOptions{
|
return xcreateclient.CreateModel(xcreateclient.CreateOptions{
|
||||||
ModelName: modelName,
|
ModelName: modelName,
|
||||||
@@ -406,12 +413,14 @@ func loadOrUnloadModel(cmd *cobra.Command, opts *runOptions) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
requestedCloud := modelref.HasExplicitCloudSource(opts.Model)
|
||||||
|
|
||||||
if info, err := client.Show(cmd.Context(), &api.ShowRequest{Model: opts.Model}); err != nil {
|
if info, err := client.Show(cmd.Context(), &api.ShowRequest{Model: opts.Model}); err != nil {
|
||||||
return err
|
return err
|
||||||
} else if info.RemoteHost != "" {
|
} else if info.RemoteHost != "" || requestedCloud {
|
||||||
// Cloud model, no need to load/unload
|
// Cloud model, no need to load/unload
|
||||||
|
|
||||||
isCloud := strings.HasPrefix(info.RemoteHost, "https://ollama.com")
|
isCloud := requestedCloud || strings.HasPrefix(info.RemoteHost, "https://ollama.com")
|
||||||
|
|
||||||
// Check if user is signed in for ollama.com cloud models
|
// Check if user is signed in for ollama.com cloud models
|
||||||
if isCloud {
|
if isCloud {
|
||||||
@@ -422,10 +431,14 @@ func loadOrUnloadModel(cmd *cobra.Command, opts *runOptions) error {
|
|||||||
|
|
||||||
if opts.ShowConnect {
|
if opts.ShowConnect {
|
||||||
p.StopAndClear()
|
p.StopAndClear()
|
||||||
|
remoteModel := info.RemoteModel
|
||||||
|
if remoteModel == "" {
|
||||||
|
remoteModel = opts.Model
|
||||||
|
}
|
||||||
if isCloud {
|
if isCloud {
|
||||||
fmt.Fprintf(os.Stderr, "Connecting to '%s' on 'ollama.com' ⚡\n", info.RemoteModel)
|
fmt.Fprintf(os.Stderr, "Connecting to '%s' on 'ollama.com' ⚡\n", remoteModel)
|
||||||
} else {
|
} else {
|
||||||
fmt.Fprintf(os.Stderr, "Connecting to '%s' on '%s'\n", info.RemoteModel, info.RemoteHost)
|
fmt.Fprintf(os.Stderr, "Connecting to '%s' on '%s'\n", remoteModel, info.RemoteHost)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -497,6 +510,64 @@ func generateEmbedding(cmd *cobra.Command, modelName, input string, keepAlive *a
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO(parthsareen): consolidate with TUI signin flow
|
||||||
|
func handleCloudAuthorizationError(err error) bool {
|
||||||
|
var authErr api.AuthorizationError
|
||||||
|
if errors.As(err, &authErr) && authErr.StatusCode == http.StatusUnauthorized {
|
||||||
|
fmt.Printf("You need to be signed in to Ollama to run Cloud models.\n\n")
|
||||||
|
if authErr.SigninURL != "" {
|
||||||
|
fmt.Printf(ConnectInstructions, authErr.SigninURL)
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// TEMP(drifkin): To match legacy `ollama run some-model:cloud` behavior, we
|
||||||
|
// best-effort pull cloud stub files for any explicit cloud source models.
|
||||||
|
// Remove this once `/api/tags` is cloud-aware.
|
||||||
|
func ensureCloudStub(ctx context.Context, client *api.Client, modelName string) {
|
||||||
|
if !modelref.HasExplicitCloudSource(modelName) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
normalizedName, _, err := modelref.NormalizePullName(modelName)
|
||||||
|
if err != nil {
|
||||||
|
slog.Warn("failed to normalize pull name", "model", modelName, "error", err, "normalizedName", normalizedName)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
listResp, err := client.List(ctx)
|
||||||
|
if err != nil {
|
||||||
|
slog.Warn("failed to list models", "error", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if hasListedModelName(listResp.Models, modelName) || hasListedModelName(listResp.Models, normalizedName) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
logutil.Trace("pulling cloud stub", "model", modelName, "normalizedName", normalizedName)
|
||||||
|
err = client.Pull(ctx, &api.PullRequest{
|
||||||
|
Model: normalizedName,
|
||||||
|
}, func(api.ProgressResponse) error {
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
slog.Warn("failed to pull cloud stub", "model", modelName, "error", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func hasListedModelName(models []api.ListModelResponse, name string) bool {
|
||||||
|
for _, m := range models {
|
||||||
|
if strings.EqualFold(m.Name, name) || strings.EqualFold(m.Model, name) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
func RunHandler(cmd *cobra.Command, args []string) error {
|
func RunHandler(cmd *cobra.Command, args []string) error {
|
||||||
interactive := true
|
interactive := true
|
||||||
|
|
||||||
@@ -585,17 +656,6 @@ func RunHandler(cmd *cobra.Command, args []string) error {
|
|||||||
}
|
}
|
||||||
opts.WordWrap = !nowrap
|
opts.WordWrap = !nowrap
|
||||||
|
|
||||||
useImagegen := false
|
|
||||||
if cmd.Flags().Lookup("imagegen") != nil {
|
|
||||||
useImagegen, err = cmd.Flags().GetBool("imagegen")
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if useImagegen {
|
|
||||||
opts.Options["use_imagegen_runner"] = true
|
|
||||||
}
|
|
||||||
|
|
||||||
// Fill out the rest of the options based on information about the
|
// Fill out the rest of the options based on information about the
|
||||||
// model.
|
// model.
|
||||||
client, err := api.ClientFromEnvironment()
|
client, err := api.ClientFromEnvironment()
|
||||||
@@ -604,12 +664,16 @@ func RunHandler(cmd *cobra.Command, args []string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
name := args[0]
|
name := args[0]
|
||||||
|
requestedCloud := modelref.HasExplicitCloudSource(name)
|
||||||
|
|
||||||
info, err := func() (*api.ShowResponse, error) {
|
info, err := func() (*api.ShowResponse, error) {
|
||||||
showReq := &api.ShowRequest{Name: name}
|
showReq := &api.ShowRequest{Name: name}
|
||||||
info, err := client.Show(cmd.Context(), showReq)
|
info, err := client.Show(cmd.Context(), showReq)
|
||||||
var se api.StatusError
|
var se api.StatusError
|
||||||
if errors.As(err, &se) && se.StatusCode == http.StatusNotFound {
|
if errors.As(err, &se) && se.StatusCode == http.StatusNotFound {
|
||||||
|
if requestedCloud {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
if err := PullHandler(cmd, []string{name}); err != nil {
|
if err := PullHandler(cmd, []string{name}); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -618,9 +682,14 @@ func RunHandler(cmd *cobra.Command, args []string) error {
|
|||||||
return info, err
|
return info, err
|
||||||
}()
|
}()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
if handleCloudAuthorizationError(err) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ensureCloudStub(cmd.Context(), client, name)
|
||||||
|
|
||||||
opts.Think, err = inferThinkingOption(&info.Capabilities, &opts, thinkFlag.Changed)
|
opts.Think, err = inferThinkingOption(&info.Capabilities, &opts, thinkFlag.Changed)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -712,7 +781,13 @@ func RunHandler(cmd *cobra.Command, args []string) error {
|
|||||||
|
|
||||||
return generateInteractive(cmd, opts)
|
return generateInteractive(cmd, opts)
|
||||||
}
|
}
|
||||||
return generate(cmd, opts)
|
if err := generate(cmd, opts); err != nil {
|
||||||
|
if handleCloudAuthorizationError(err) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func SigninHandler(cmd *cobra.Command, args []string) error {
|
func SigninHandler(cmd *cobra.Command, args []string) error {
|
||||||
@@ -1892,6 +1967,24 @@ func ensureServerRunning(ctx context.Context) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func launchInteractiveModel(cmd *cobra.Command, modelName string) error {
|
||||||
|
opts := runOptions{
|
||||||
|
Model: modelName,
|
||||||
|
WordWrap: os.Getenv("TERM") == "xterm-256color",
|
||||||
|
Options: map[string]any{},
|
||||||
|
ShowConnect: true,
|
||||||
|
}
|
||||||
|
// loadOrUnloadModel is cloud-safe here: remote/cloud models skip local preload
|
||||||
|
// and only validate auth/connectivity before interactive chat starts.
|
||||||
|
if err := loadOrUnloadModel(cmd, &opts); err != nil {
|
||||||
|
return fmt.Errorf("error loading model: %w", err)
|
||||||
|
}
|
||||||
|
if err := generateInteractive(cmd, opts); err != nil {
|
||||||
|
return fmt.Errorf("error running model: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// runInteractiveTUI runs the main interactive TUI menu.
|
// runInteractiveTUI runs the main interactive TUI menu.
|
||||||
func runInteractiveTUI(cmd *cobra.Command) {
|
func runInteractiveTUI(cmd *cobra.Command) {
|
||||||
// Ensure the server is running before showing the TUI
|
// Ensure the server is running before showing the TUI
|
||||||
@@ -1900,171 +1993,81 @@ func runInteractiveTUI(cmd *cobra.Command) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Selector adapters for tui
|
deps := launcherDeps{
|
||||||
singleSelector := func(title string, items []config.ModelItem, current string) (string, error) {
|
buildState: launch.BuildLauncherState,
|
||||||
tuiItems := tui.ReorderItems(tui.ConvertItems(items))
|
runMenu: tui.RunMenu,
|
||||||
result, err := tui.SelectSingle(title, tuiItems, current)
|
resolveRunModel: launch.ResolveRunModel,
|
||||||
if errors.Is(err, tui.ErrCancelled) {
|
launchIntegration: launch.LaunchIntegration,
|
||||||
return "", config.ErrCancelled
|
runModel: launchInteractiveModel,
|
||||||
}
|
|
||||||
return result, err
|
|
||||||
}
|
|
||||||
|
|
||||||
multiSelector := func(title string, items []config.ModelItem, preChecked []string) ([]string, error) {
|
|
||||||
tuiItems := tui.ReorderItems(tui.ConvertItems(items))
|
|
||||||
result, err := tui.SelectMultiple(title, tuiItems, preChecked)
|
|
||||||
if errors.Is(err, tui.ErrCancelled) {
|
|
||||||
return nil, config.ErrCancelled
|
|
||||||
}
|
|
||||||
return result, err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for {
|
for {
|
||||||
result, err := tui.Run()
|
continueLoop, err := runInteractiveTUIStep(cmd, deps)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
|
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
|
||||||
|
}
|
||||||
|
if !continueLoop {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
runModel := func(modelName string) {
|
|
||||||
client, err := api.ClientFromEnvironment()
|
|
||||||
if err != nil {
|
|
||||||
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if err := config.ShowOrPull(cmd.Context(), client, modelName); err != nil {
|
|
||||||
if errors.Is(err, config.ErrCancelled) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
_ = config.SetLastModel(modelName)
|
|
||||||
opts := runOptions{
|
|
||||||
Model: modelName,
|
|
||||||
WordWrap: os.Getenv("TERM") == "xterm-256color",
|
|
||||||
Options: map[string]any{},
|
|
||||||
ShowConnect: true,
|
|
||||||
}
|
|
||||||
if err := loadOrUnloadModel(cmd, &opts); err != nil {
|
|
||||||
fmt.Fprintf(os.Stderr, "Error loading model: %v\n", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if err := generateInteractive(cmd, opts); err != nil {
|
|
||||||
fmt.Fprintf(os.Stderr, "Error running model: %v\n", err)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
launchIntegration := func(name string) bool {
|
type launcherDeps struct {
|
||||||
// If not configured or model no longer exists, prompt for model selection
|
buildState func(context.Context) (*launch.LauncherState, error)
|
||||||
configuredModel := config.IntegrationModel(name)
|
runMenu func(*launch.LauncherState) (tui.TUIAction, error)
|
||||||
if configuredModel == "" || !config.ModelExists(cmd.Context(), configuredModel) || config.IsCloudModelDisabled(cmd.Context(), configuredModel) {
|
resolveRunModel func(context.Context, launch.RunModelRequest) (string, error)
|
||||||
err := config.ConfigureIntegrationWithSelectors(cmd.Context(), name, singleSelector, multiSelector)
|
launchIntegration func(context.Context, launch.IntegrationLaunchRequest) error
|
||||||
if errors.Is(err, config.ErrCancelled) {
|
runModel func(*cobra.Command, string) error
|
||||||
return false // Return to main menu
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
fmt.Fprintf(os.Stderr, "Error configuring %s: %v\n", name, err)
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if err := config.LaunchIntegration(name); err != nil {
|
|
||||||
fmt.Fprintf(os.Stderr, "Error launching %s: %v\n", name, err)
|
|
||||||
}
|
|
||||||
return true
|
|
||||||
}
|
}
|
||||||
|
|
||||||
switch result.Selection {
|
func runInteractiveTUIStep(cmd *cobra.Command, deps launcherDeps) (bool, error) {
|
||||||
case tui.SelectionNone:
|
state, err := deps.buildState(cmd.Context())
|
||||||
// User quit
|
if err != nil {
|
||||||
return
|
return false, fmt.Errorf("build launcher state: %w", err)
|
||||||
case tui.SelectionRunModel:
|
}
|
||||||
_ = config.SetLastSelection("run")
|
|
||||||
if modelName := config.LastModel(); modelName != "" && !config.IsCloudModelDisabled(cmd.Context(), modelName) {
|
action, err := deps.runMenu(state)
|
||||||
runModel(modelName)
|
if err != nil {
|
||||||
} else {
|
return false, fmt.Errorf("run launcher menu: %w", err)
|
||||||
modelName, err := config.SelectModelWithSelector(cmd.Context(), singleSelector)
|
}
|
||||||
if errors.Is(err, config.ErrCancelled) {
|
|
||||||
continue // Return to main menu
|
return runLauncherAction(cmd, action, deps)
|
||||||
|
}
|
||||||
|
|
||||||
|
func saveLauncherSelection(action tui.TUIAction) {
|
||||||
|
// Best effort only: this affects menu recall, not launch correctness.
|
||||||
|
_ = config.SetLastSelection(action.LastSelection())
|
||||||
|
}
|
||||||
|
|
||||||
|
func runLauncherAction(cmd *cobra.Command, action tui.TUIAction, deps launcherDeps) (bool, error) {
|
||||||
|
switch action.Kind {
|
||||||
|
case tui.TUIActionNone:
|
||||||
|
return false, nil
|
||||||
|
case tui.TUIActionRunModel:
|
||||||
|
saveLauncherSelection(action)
|
||||||
|
modelName, err := deps.resolveRunModel(cmd.Context(), action.RunModelRequest())
|
||||||
|
if errors.Is(err, launch.ErrCancelled) {
|
||||||
|
return true, nil
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Fprintf(os.Stderr, "Error selecting model: %v\n", err)
|
return true, fmt.Errorf("selecting model: %w", err)
|
||||||
continue
|
|
||||||
}
|
}
|
||||||
runModel(modelName)
|
if err := deps.runModel(cmd, modelName); err != nil {
|
||||||
|
return true, err
|
||||||
}
|
}
|
||||||
case tui.SelectionChangeRunModel:
|
return true, nil
|
||||||
_ = config.SetLastSelection("run")
|
case tui.TUIActionLaunchIntegration:
|
||||||
// Use model from modal if selected, otherwise show picker
|
saveLauncherSelection(action)
|
||||||
modelName := result.Model
|
err := deps.launchIntegration(cmd.Context(), action.IntegrationLaunchRequest())
|
||||||
if modelName == "" {
|
if errors.Is(err, launch.ErrCancelled) {
|
||||||
var err error
|
return true, nil
|
||||||
modelName, err = config.SelectModelWithSelector(cmd.Context(), singleSelector)
|
|
||||||
if errors.Is(err, config.ErrCancelled) {
|
|
||||||
continue // Return to main menu
|
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Fprintf(os.Stderr, "Error selecting model: %v\n", err)
|
return true, fmt.Errorf("launching %s: %w", action.Integration, err)
|
||||||
continue
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if config.IsCloudModelDisabled(cmd.Context(), modelName) {
|
|
||||||
continue // Return to main menu
|
|
||||||
}
|
|
||||||
runModel(modelName)
|
|
||||||
case tui.SelectionIntegration:
|
|
||||||
_ = config.SetLastSelection(result.Integration)
|
|
||||||
if !launchIntegration(result.Integration) {
|
|
||||||
continue // Return to main menu
|
|
||||||
}
|
|
||||||
case tui.SelectionChangeIntegration:
|
|
||||||
_ = config.SetLastSelection(result.Integration)
|
|
||||||
if len(result.Models) > 0 {
|
|
||||||
// Filter out cloud-disabled models
|
|
||||||
var filtered []string
|
|
||||||
for _, m := range result.Models {
|
|
||||||
if !config.IsCloudModelDisabled(cmd.Context(), m) {
|
|
||||||
filtered = append(filtered, m)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if len(filtered) == 0 {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
result.Models = filtered
|
|
||||||
// Multi-select from modal (Editor integrations)
|
|
||||||
if err := config.SaveAndEditIntegration(result.Integration, result.Models); err != nil {
|
|
||||||
fmt.Fprintf(os.Stderr, "Error configuring %s: %v\n", result.Integration, err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if err := config.LaunchIntegrationWithModel(result.Integration, result.Models[0]); err != nil {
|
|
||||||
fmt.Fprintf(os.Stderr, "Error launching %s: %v\n", result.Integration, err)
|
|
||||||
}
|
|
||||||
} else if result.Model != "" {
|
|
||||||
if config.IsCloudModelDisabled(cmd.Context(), result.Model) {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
// Single-select from modal - save and launch
|
|
||||||
if err := config.SaveIntegration(result.Integration, []string{result.Model}); err != nil {
|
|
||||||
fmt.Fprintf(os.Stderr, "Error saving config: %v\n", err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if err := config.LaunchIntegrationWithModel(result.Integration, result.Model); err != nil {
|
|
||||||
fmt.Fprintf(os.Stderr, "Error launching %s: %v\n", result.Integration, err)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
err := config.ConfigureIntegrationWithSelectors(cmd.Context(), result.Integration, singleSelector, multiSelector)
|
|
||||||
if errors.Is(err, config.ErrCancelled) {
|
|
||||||
continue // Return to main menu
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
fmt.Fprintf(os.Stderr, "Error configuring %s: %v\n", result.Integration, err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if err := config.LaunchIntegration(result.Integration); err != nil {
|
|
||||||
fmt.Fprintf(os.Stderr, "Error launching %s: %v\n", result.Integration, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
return true, nil
|
||||||
|
default:
|
||||||
|
return false, fmt.Errorf("unknown launcher action: %d", action.Kind)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -2334,7 +2337,7 @@ func NewCLI() *cobra.Command {
|
|||||||
copyCmd,
|
copyCmd,
|
||||||
deleteCmd,
|
deleteCmd,
|
||||||
runnerCmd,
|
runnerCmd,
|
||||||
config.LaunchCmd(checkServerHeartbeat, runInteractiveTUI),
|
launch.LaunchCmd(checkServerHeartbeat, runInteractiveTUI),
|
||||||
)
|
)
|
||||||
|
|
||||||
return rootCmd
|
return rootCmd
|
||||||
|
|||||||
233
cmd/cmd_launcher_test.go
Normal file
233
cmd/cmd_launcher_test.go
Normal file
@@ -0,0 +1,233 @@
|
|||||||
|
package cmd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/spf13/cobra"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/cmd/config"
|
||||||
|
"github.com/ollama/ollama/cmd/launch"
|
||||||
|
"github.com/ollama/ollama/cmd/tui"
|
||||||
|
)
|
||||||
|
|
||||||
|
func setCmdTestHome(t *testing.T, dir string) {
|
||||||
|
t.Helper()
|
||||||
|
t.Setenv("HOME", dir)
|
||||||
|
t.Setenv("USERPROFILE", dir)
|
||||||
|
}
|
||||||
|
|
||||||
|
func unexpectedRunModelResolution(t *testing.T) func(context.Context, launch.RunModelRequest) (string, error) {
|
||||||
|
t.Helper()
|
||||||
|
return func(ctx context.Context, req launch.RunModelRequest) (string, error) {
|
||||||
|
t.Fatalf("did not expect run-model resolution: %+v", req)
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func unexpectedIntegrationLaunch(t *testing.T) func(context.Context, launch.IntegrationLaunchRequest) error {
|
||||||
|
t.Helper()
|
||||||
|
return func(ctx context.Context, req launch.IntegrationLaunchRequest) error {
|
||||||
|
t.Fatalf("did not expect integration launch: %+v", req)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func unexpectedModelLaunch(t *testing.T) func(*cobra.Command, string) error {
|
||||||
|
t.Helper()
|
||||||
|
return func(cmd *cobra.Command, model string) error {
|
||||||
|
t.Fatalf("did not expect chat launch: %s", model)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRunInteractiveTUI_RunModelActionsUseResolveRunModel(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
action tui.TUIAction
|
||||||
|
wantForce bool
|
||||||
|
wantModel string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "enter uses saved model flow",
|
||||||
|
action: tui.TUIAction{Kind: tui.TUIActionRunModel},
|
||||||
|
wantModel: "qwen3:8b",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "right forces picker",
|
||||||
|
action: tui.TUIAction{Kind: tui.TUIActionRunModel, ForceConfigure: true},
|
||||||
|
wantForce: true,
|
||||||
|
wantModel: "glm-5:cloud",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
setCmdTestHome(t, t.TempDir())
|
||||||
|
|
||||||
|
var menuCalls int
|
||||||
|
runMenu := func(state *launch.LauncherState) (tui.TUIAction, error) {
|
||||||
|
menuCalls++
|
||||||
|
if menuCalls == 1 {
|
||||||
|
return tt.action, nil
|
||||||
|
}
|
||||||
|
return tui.TUIAction{Kind: tui.TUIActionNone}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var gotReq launch.RunModelRequest
|
||||||
|
var launched string
|
||||||
|
deps := launcherDeps{
|
||||||
|
buildState: func(ctx context.Context) (*launch.LauncherState, error) {
|
||||||
|
return &launch.LauncherState{}, nil
|
||||||
|
},
|
||||||
|
runMenu: runMenu,
|
||||||
|
resolveRunModel: func(ctx context.Context, req launch.RunModelRequest) (string, error) {
|
||||||
|
gotReq = req
|
||||||
|
return tt.wantModel, nil
|
||||||
|
},
|
||||||
|
launchIntegration: unexpectedIntegrationLaunch(t),
|
||||||
|
runModel: func(cmd *cobra.Command, model string) error {
|
||||||
|
launched = model
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd := &cobra.Command{}
|
||||||
|
cmd.SetContext(context.Background())
|
||||||
|
for {
|
||||||
|
continueLoop, err := runInteractiveTUIStep(cmd, deps)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected step error: %v", err)
|
||||||
|
}
|
||||||
|
if !continueLoop {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if gotReq.ForcePicker != tt.wantForce {
|
||||||
|
t.Fatalf("expected ForcePicker=%v, got %v", tt.wantForce, gotReq.ForcePicker)
|
||||||
|
}
|
||||||
|
if launched != tt.wantModel {
|
||||||
|
t.Fatalf("expected interactive launcher to run %q, got %q", tt.wantModel, launched)
|
||||||
|
}
|
||||||
|
if got := config.LastSelection(); got != "run" {
|
||||||
|
t.Fatalf("expected last selection to be run, got %q", got)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRunInteractiveTUI_IntegrationActionsUseLaunchIntegration(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
action tui.TUIAction
|
||||||
|
wantForce bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "enter launches integration",
|
||||||
|
action: tui.TUIAction{Kind: tui.TUIActionLaunchIntegration, Integration: "claude"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "right forces configure",
|
||||||
|
action: tui.TUIAction{Kind: tui.TUIActionLaunchIntegration, Integration: "claude", ForceConfigure: true},
|
||||||
|
wantForce: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
setCmdTestHome(t, t.TempDir())
|
||||||
|
|
||||||
|
var menuCalls int
|
||||||
|
runMenu := func(state *launch.LauncherState) (tui.TUIAction, error) {
|
||||||
|
menuCalls++
|
||||||
|
if menuCalls == 1 {
|
||||||
|
return tt.action, nil
|
||||||
|
}
|
||||||
|
return tui.TUIAction{Kind: tui.TUIActionNone}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var gotReq launch.IntegrationLaunchRequest
|
||||||
|
deps := launcherDeps{
|
||||||
|
buildState: func(ctx context.Context) (*launch.LauncherState, error) {
|
||||||
|
return &launch.LauncherState{}, nil
|
||||||
|
},
|
||||||
|
runMenu: runMenu,
|
||||||
|
resolveRunModel: unexpectedRunModelResolution(t),
|
||||||
|
launchIntegration: func(ctx context.Context, req launch.IntegrationLaunchRequest) error {
|
||||||
|
gotReq = req
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
runModel: unexpectedModelLaunch(t),
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd := &cobra.Command{}
|
||||||
|
cmd.SetContext(context.Background())
|
||||||
|
for {
|
||||||
|
continueLoop, err := runInteractiveTUIStep(cmd, deps)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected step error: %v", err)
|
||||||
|
}
|
||||||
|
if !continueLoop {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if gotReq.Name != "claude" {
|
||||||
|
t.Fatalf("expected integration name to be passed through, got %q", gotReq.Name)
|
||||||
|
}
|
||||||
|
if gotReq.ForceConfigure != tt.wantForce {
|
||||||
|
t.Fatalf("expected ForceConfigure=%v, got %v", tt.wantForce, gotReq.ForceConfigure)
|
||||||
|
}
|
||||||
|
if got := config.LastSelection(); got != "claude" {
|
||||||
|
t.Fatalf("expected last selection to be claude, got %q", got)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRunLauncherAction_RunModelContinuesAfterCancellation(t *testing.T) {
|
||||||
|
setCmdTestHome(t, t.TempDir())
|
||||||
|
|
||||||
|
cmd := &cobra.Command{}
|
||||||
|
cmd.SetContext(context.Background())
|
||||||
|
|
||||||
|
continueLoop, err := runLauncherAction(cmd, tui.TUIAction{Kind: tui.TUIActionRunModel}, launcherDeps{
|
||||||
|
buildState: nil,
|
||||||
|
runMenu: nil,
|
||||||
|
resolveRunModel: func(ctx context.Context, req launch.RunModelRequest) (string, error) {
|
||||||
|
return "", launch.ErrCancelled
|
||||||
|
},
|
||||||
|
launchIntegration: unexpectedIntegrationLaunch(t),
|
||||||
|
runModel: unexpectedModelLaunch(t),
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("expected nil error on cancellation, got %v", err)
|
||||||
|
}
|
||||||
|
if !continueLoop {
|
||||||
|
t.Fatal("expected cancellation to continue the menu loop")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRunLauncherAction_IntegrationContinuesAfterCancellation(t *testing.T) {
|
||||||
|
setCmdTestHome(t, t.TempDir())
|
||||||
|
|
||||||
|
cmd := &cobra.Command{}
|
||||||
|
cmd.SetContext(context.Background())
|
||||||
|
|
||||||
|
continueLoop, err := runLauncherAction(cmd, tui.TUIAction{Kind: tui.TUIActionLaunchIntegration, Integration: "claude"}, launcherDeps{
|
||||||
|
buildState: nil,
|
||||||
|
runMenu: nil,
|
||||||
|
resolveRunModel: unexpectedRunModelResolution(t),
|
||||||
|
launchIntegration: func(ctx context.Context, req launch.IntegrationLaunchRequest) error {
|
||||||
|
return launch.ErrCancelled
|
||||||
|
},
|
||||||
|
runModel: unexpectedModelLaunch(t),
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("expected nil error on cancellation, got %v", err)
|
||||||
|
}
|
||||||
|
if !continueLoop {
|
||||||
|
t.Fatal("expected cancellation to continue the menu loop")
|
||||||
|
}
|
||||||
|
}
|
||||||
466
cmd/cmd_test.go
466
cmd/cmd_test.go
@@ -705,6 +705,347 @@ func TestRunEmbeddingModelNoInput(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestRunHandler_CloudAuthErrorOnShow_PrintsSigninMessage(t *testing.T) {
|
||||||
|
var generateCalled bool
|
||||||
|
|
||||||
|
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
switch {
|
||||||
|
case r.URL.Path == "/api/show" && r.Method == http.MethodPost:
|
||||||
|
w.WriteHeader(http.StatusUnauthorized)
|
||||||
|
if err := json.NewEncoder(w).Encode(map[string]string{
|
||||||
|
"error": "unauthorized",
|
||||||
|
"signin_url": "https://ollama.com/signin",
|
||||||
|
}); err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
case r.URL.Path == "/api/generate" && r.Method == http.MethodPost:
|
||||||
|
generateCalled = true
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
if err := json.NewEncoder(w).Encode(api.GenerateResponse{Done: true}); err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
http.NotFound(w, r)
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
|
||||||
|
t.Setenv("OLLAMA_HOST", mockServer.URL)
|
||||||
|
t.Cleanup(mockServer.Close)
|
||||||
|
|
||||||
|
cmd := &cobra.Command{}
|
||||||
|
cmd.SetContext(t.Context())
|
||||||
|
cmd.Flags().String("keepalive", "", "")
|
||||||
|
cmd.Flags().Bool("truncate", false, "")
|
||||||
|
cmd.Flags().Int("dimensions", 0, "")
|
||||||
|
cmd.Flags().Bool("verbose", false, "")
|
||||||
|
cmd.Flags().Bool("insecure", false, "")
|
||||||
|
cmd.Flags().Bool("nowordwrap", false, "")
|
||||||
|
cmd.Flags().String("format", "", "")
|
||||||
|
cmd.Flags().String("think", "", "")
|
||||||
|
cmd.Flags().Bool("hidethinking", false, "")
|
||||||
|
|
||||||
|
oldStdout := os.Stdout
|
||||||
|
readOut, writeOut, _ := os.Pipe()
|
||||||
|
os.Stdout = writeOut
|
||||||
|
t.Cleanup(func() { os.Stdout = oldStdout })
|
||||||
|
|
||||||
|
err := RunHandler(cmd, []string{"gpt-oss:20b:cloud", "hi"})
|
||||||
|
|
||||||
|
_ = writeOut.Close()
|
||||||
|
var out bytes.Buffer
|
||||||
|
_, _ = io.Copy(&out, readOut)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("RunHandler returned error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if generateCalled {
|
||||||
|
t.Fatal("expected run to stop before /api/generate after unauthorized /api/show")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !strings.Contains(out.String(), "You need to be signed in to Ollama to run Cloud models.") {
|
||||||
|
t.Fatalf("expected sign-in guidance message, got %q", out.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
if !strings.Contains(out.String(), "https://ollama.com/signin") {
|
||||||
|
t.Fatalf("expected signin_url in output, got %q", out.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRunHandler_CloudAuthErrorOnGenerate_PrintsSigninMessage(t *testing.T) {
|
||||||
|
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
switch {
|
||||||
|
case r.URL.Path == "/api/show" && r.Method == http.MethodPost:
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
if err := json.NewEncoder(w).Encode(api.ShowResponse{
|
||||||
|
Capabilities: []model.Capability{model.CapabilityCompletion},
|
||||||
|
}); err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
case r.URL.Path == "/api/generate" && r.Method == http.MethodPost:
|
||||||
|
w.WriteHeader(http.StatusUnauthorized)
|
||||||
|
if err := json.NewEncoder(w).Encode(map[string]string{
|
||||||
|
"error": "unauthorized",
|
||||||
|
"signin_url": "https://ollama.com/signin",
|
||||||
|
}); err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
http.NotFound(w, r)
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
|
||||||
|
t.Setenv("OLLAMA_HOST", mockServer.URL)
|
||||||
|
t.Cleanup(mockServer.Close)
|
||||||
|
|
||||||
|
cmd := &cobra.Command{}
|
||||||
|
cmd.SetContext(t.Context())
|
||||||
|
cmd.Flags().String("keepalive", "", "")
|
||||||
|
cmd.Flags().Bool("truncate", false, "")
|
||||||
|
cmd.Flags().Int("dimensions", 0, "")
|
||||||
|
cmd.Flags().Bool("verbose", false, "")
|
||||||
|
cmd.Flags().Bool("insecure", false, "")
|
||||||
|
cmd.Flags().Bool("nowordwrap", false, "")
|
||||||
|
cmd.Flags().String("format", "", "")
|
||||||
|
cmd.Flags().String("think", "", "")
|
||||||
|
cmd.Flags().Bool("hidethinking", false, "")
|
||||||
|
|
||||||
|
oldStdout := os.Stdout
|
||||||
|
readOut, writeOut, _ := os.Pipe()
|
||||||
|
os.Stdout = writeOut
|
||||||
|
t.Cleanup(func() { os.Stdout = oldStdout })
|
||||||
|
|
||||||
|
err := RunHandler(cmd, []string{"gpt-oss:20b:cloud", "hi"})
|
||||||
|
|
||||||
|
_ = writeOut.Close()
|
||||||
|
var out bytes.Buffer
|
||||||
|
_, _ = io.Copy(&out, readOut)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("RunHandler returned error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !strings.Contains(out.String(), "You need to be signed in to Ollama to run Cloud models.") {
|
||||||
|
t.Fatalf("expected sign-in guidance message, got %q", out.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
if !strings.Contains(out.String(), "https://ollama.com/signin") {
|
||||||
|
t.Fatalf("expected signin_url in output, got %q", out.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRunHandler_ExplicitCloudStubMissing_PullsNormalizedNameTEMP(t *testing.T) {
|
||||||
|
var pulledModel string
|
||||||
|
var generateCalled bool
|
||||||
|
|
||||||
|
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
switch {
|
||||||
|
case r.URL.Path == "/api/show" && r.Method == http.MethodPost:
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
if err := json.NewEncoder(w).Encode(api.ShowResponse{
|
||||||
|
Capabilities: []model.Capability{model.CapabilityCompletion},
|
||||||
|
RemoteModel: "gpt-oss:20b",
|
||||||
|
}); err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
case r.URL.Path == "/api/tags" && r.Method == http.MethodGet:
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
if err := json.NewEncoder(w).Encode(api.ListResponse{Models: nil}); err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
case r.URL.Path == "/api/pull" && r.Method == http.MethodPost:
|
||||||
|
var req api.PullRequest
|
||||||
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
pulledModel = req.Model
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
if err := json.NewEncoder(w).Encode(api.ProgressResponse{Status: "success"}); err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
case r.URL.Path == "/api/generate" && r.Method == http.MethodPost:
|
||||||
|
generateCalled = true
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
if err := json.NewEncoder(w).Encode(api.GenerateResponse{Done: true}); err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
http.NotFound(w, r)
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
|
||||||
|
t.Setenv("OLLAMA_HOST", mockServer.URL)
|
||||||
|
t.Cleanup(mockServer.Close)
|
||||||
|
|
||||||
|
cmd := &cobra.Command{}
|
||||||
|
cmd.SetContext(t.Context())
|
||||||
|
cmd.Flags().String("keepalive", "", "")
|
||||||
|
cmd.Flags().Bool("truncate", false, "")
|
||||||
|
cmd.Flags().Int("dimensions", 0, "")
|
||||||
|
cmd.Flags().Bool("verbose", false, "")
|
||||||
|
cmd.Flags().Bool("insecure", false, "")
|
||||||
|
cmd.Flags().Bool("nowordwrap", false, "")
|
||||||
|
cmd.Flags().String("format", "", "")
|
||||||
|
cmd.Flags().String("think", "", "")
|
||||||
|
cmd.Flags().Bool("hidethinking", false, "")
|
||||||
|
|
||||||
|
err := RunHandler(cmd, []string{"gpt-oss:20b:cloud", "hi"})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("RunHandler returned error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if pulledModel != "gpt-oss:20b-cloud" {
|
||||||
|
t.Fatalf("expected normalized pull model %q, got %q", "gpt-oss:20b-cloud", pulledModel)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !generateCalled {
|
||||||
|
t.Fatal("expected /api/generate to be called")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRunHandler_ExplicitCloudStubPresent_SkipsPullTEMP(t *testing.T) {
|
||||||
|
var pullCalled bool
|
||||||
|
var generateCalled bool
|
||||||
|
|
||||||
|
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
switch {
|
||||||
|
case r.URL.Path == "/api/show" && r.Method == http.MethodPost:
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
if err := json.NewEncoder(w).Encode(api.ShowResponse{
|
||||||
|
Capabilities: []model.Capability{model.CapabilityCompletion},
|
||||||
|
RemoteModel: "gpt-oss:20b",
|
||||||
|
}); err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
case r.URL.Path == "/api/tags" && r.Method == http.MethodGet:
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
if err := json.NewEncoder(w).Encode(api.ListResponse{
|
||||||
|
Models: []api.ListModelResponse{{Name: "gpt-oss:20b-cloud"}},
|
||||||
|
}); err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
case r.URL.Path == "/api/pull" && r.Method == http.MethodPost:
|
||||||
|
pullCalled = true
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
if err := json.NewEncoder(w).Encode(api.ProgressResponse{Status: "success"}); err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
case r.URL.Path == "/api/generate" && r.Method == http.MethodPost:
|
||||||
|
generateCalled = true
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
if err := json.NewEncoder(w).Encode(api.GenerateResponse{Done: true}); err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
http.NotFound(w, r)
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
|
||||||
|
t.Setenv("OLLAMA_HOST", mockServer.URL)
|
||||||
|
t.Cleanup(mockServer.Close)
|
||||||
|
|
||||||
|
cmd := &cobra.Command{}
|
||||||
|
cmd.SetContext(t.Context())
|
||||||
|
cmd.Flags().String("keepalive", "", "")
|
||||||
|
cmd.Flags().Bool("truncate", false, "")
|
||||||
|
cmd.Flags().Int("dimensions", 0, "")
|
||||||
|
cmd.Flags().Bool("verbose", false, "")
|
||||||
|
cmd.Flags().Bool("insecure", false, "")
|
||||||
|
cmd.Flags().Bool("nowordwrap", false, "")
|
||||||
|
cmd.Flags().String("format", "", "")
|
||||||
|
cmd.Flags().String("think", "", "")
|
||||||
|
cmd.Flags().Bool("hidethinking", false, "")
|
||||||
|
|
||||||
|
err := RunHandler(cmd, []string{"gpt-oss:20b:cloud", "hi"})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("RunHandler returned error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if pullCalled {
|
||||||
|
t.Fatal("expected /api/pull not to be called when cloud stub already exists")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !generateCalled {
|
||||||
|
t.Fatal("expected /api/generate to be called")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRunHandler_ExplicitCloudStubPullFailure_IsBestEffortTEMP(t *testing.T) {
|
||||||
|
var generateCalled bool
|
||||||
|
|
||||||
|
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
switch {
|
||||||
|
case r.URL.Path == "/api/show" && r.Method == http.MethodPost:
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
if err := json.NewEncoder(w).Encode(api.ShowResponse{
|
||||||
|
Capabilities: []model.Capability{model.CapabilityCompletion},
|
||||||
|
RemoteModel: "gpt-oss:20b",
|
||||||
|
}); err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
case r.URL.Path == "/api/tags" && r.Method == http.MethodGet:
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
if err := json.NewEncoder(w).Encode(api.ListResponse{Models: nil}); err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
case r.URL.Path == "/api/pull" && r.Method == http.MethodPost:
|
||||||
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
|
if err := json.NewEncoder(w).Encode(map[string]string{"error": "pull failed"}); err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
case r.URL.Path == "/api/generate" && r.Method == http.MethodPost:
|
||||||
|
generateCalled = true
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
if err := json.NewEncoder(w).Encode(api.GenerateResponse{Done: true}); err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
http.NotFound(w, r)
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
|
||||||
|
t.Setenv("OLLAMA_HOST", mockServer.URL)
|
||||||
|
t.Cleanup(mockServer.Close)
|
||||||
|
|
||||||
|
cmd := &cobra.Command{}
|
||||||
|
cmd.SetContext(t.Context())
|
||||||
|
cmd.Flags().String("keepalive", "", "")
|
||||||
|
cmd.Flags().Bool("truncate", false, "")
|
||||||
|
cmd.Flags().Int("dimensions", 0, "")
|
||||||
|
cmd.Flags().Bool("verbose", false, "")
|
||||||
|
cmd.Flags().Bool("insecure", false, "")
|
||||||
|
cmd.Flags().Bool("nowordwrap", false, "")
|
||||||
|
cmd.Flags().String("format", "", "")
|
||||||
|
cmd.Flags().String("think", "", "")
|
||||||
|
cmd.Flags().Bool("hidethinking", false, "")
|
||||||
|
|
||||||
|
err := RunHandler(cmd, []string{"gpt-oss:20b:cloud", "hi"})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("RunHandler returned error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !generateCalled {
|
||||||
|
t.Fatal("expected /api/generate to be called despite pull failure")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestGetModelfileName(t *testing.T) {
|
func TestGetModelfileName(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
@@ -1212,6 +1553,20 @@ func TestNewCreateRequest(t *testing.T) {
|
|||||||
Model: "newmodel",
|
Model: "newmodel",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"explicit cloud model preserves source when parent lacks it",
|
||||||
|
"newmodel",
|
||||||
|
runOptions{
|
||||||
|
Model: "qwen3.5:cloud",
|
||||||
|
ParentModel: "qwen3.5",
|
||||||
|
Messages: []api.Message{},
|
||||||
|
WordWrap: true,
|
||||||
|
},
|
||||||
|
&api.CreateRequest{
|
||||||
|
From: "qwen3.5:cloud",
|
||||||
|
Model: "newmodel",
|
||||||
|
},
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"parent model as filepath test",
|
"parent model as filepath test",
|
||||||
"newmodel",
|
"newmodel",
|
||||||
@@ -1664,30 +2019,80 @@ func TestRunOptions_Copy_Independence(t *testing.T) {
|
|||||||
func TestLoadOrUnloadModel_CloudModelAuth(t *testing.T) {
|
func TestLoadOrUnloadModel_CloudModelAuth(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
|
model string
|
||||||
|
showStatus int
|
||||||
remoteHost string
|
remoteHost string
|
||||||
|
remoteModel string
|
||||||
whoamiStatus int
|
whoamiStatus int
|
||||||
whoamiResp any
|
whoamiResp any
|
||||||
|
expectWhoami bool
|
||||||
expectedError string
|
expectedError string
|
||||||
|
expectAuthError bool
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "ollama.com cloud model - user signed in",
|
name: "ollama.com cloud model - user signed in",
|
||||||
|
model: "test-cloud-model",
|
||||||
remoteHost: "https://ollama.com",
|
remoteHost: "https://ollama.com",
|
||||||
|
remoteModel: "test-model",
|
||||||
whoamiStatus: http.StatusOK,
|
whoamiStatus: http.StatusOK,
|
||||||
whoamiResp: api.UserResponse{Name: "testuser"},
|
whoamiResp: api.UserResponse{Name: "testuser"},
|
||||||
|
expectWhoami: true,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "ollama.com cloud model - user not signed in",
|
name: "ollama.com cloud model - user not signed in",
|
||||||
|
model: "test-cloud-model",
|
||||||
remoteHost: "https://ollama.com",
|
remoteHost: "https://ollama.com",
|
||||||
|
remoteModel: "test-model",
|
||||||
whoamiStatus: http.StatusUnauthorized,
|
whoamiStatus: http.StatusUnauthorized,
|
||||||
whoamiResp: map[string]string{
|
whoamiResp: map[string]string{
|
||||||
"error": "unauthorized",
|
"error": "unauthorized",
|
||||||
"signin_url": "https://ollama.com/signin",
|
"signin_url": "https://ollama.com/signin",
|
||||||
},
|
},
|
||||||
|
expectWhoami: true,
|
||||||
expectedError: "unauthorized",
|
expectedError: "unauthorized",
|
||||||
|
expectAuthError: true,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "non-ollama.com remote - no auth check",
|
name: "non-ollama.com remote - no auth check",
|
||||||
|
model: "test-cloud-model",
|
||||||
remoteHost: "https://other-remote.com",
|
remoteHost: "https://other-remote.com",
|
||||||
|
remoteModel: "test-model",
|
||||||
|
whoamiStatus: http.StatusUnauthorized, // should not be called
|
||||||
|
whoamiResp: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "explicit :cloud model - auth check without remote metadata",
|
||||||
|
model: "kimi-k2.5:cloud",
|
||||||
|
remoteHost: "",
|
||||||
|
remoteModel: "",
|
||||||
|
whoamiStatus: http.StatusOK,
|
||||||
|
whoamiResp: api.UserResponse{Name: "testuser"},
|
||||||
|
expectWhoami: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "explicit :cloud model without local stub returns not found by default",
|
||||||
|
model: "minimax-m2.5:cloud",
|
||||||
|
showStatus: http.StatusNotFound,
|
||||||
|
whoamiStatus: http.StatusOK,
|
||||||
|
whoamiResp: api.UserResponse{Name: "testuser"},
|
||||||
|
expectedError: "not found",
|
||||||
|
expectWhoami: false,
|
||||||
|
expectAuthError: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "explicit -cloud model - auth check without remote metadata",
|
||||||
|
model: "kimi-k2.5:latest-cloud",
|
||||||
|
remoteHost: "",
|
||||||
|
remoteModel: "",
|
||||||
|
whoamiStatus: http.StatusOK,
|
||||||
|
whoamiResp: api.UserResponse{Name: "testuser"},
|
||||||
|
expectWhoami: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "dash cloud-like name without explicit source does not require auth",
|
||||||
|
model: "test-cloud-model",
|
||||||
|
remoteHost: "",
|
||||||
|
remoteModel: "",
|
||||||
whoamiStatus: http.StatusUnauthorized, // should not be called
|
whoamiStatus: http.StatusUnauthorized, // should not be called
|
||||||
whoamiResp: nil,
|
whoamiResp: nil,
|
||||||
},
|
},
|
||||||
@@ -1699,10 +2104,15 @@ func TestLoadOrUnloadModel_CloudModelAuth(t *testing.T) {
|
|||||||
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
switch r.URL.Path {
|
switch r.URL.Path {
|
||||||
case "/api/show":
|
case "/api/show":
|
||||||
|
if tt.showStatus != 0 && tt.showStatus != http.StatusOK {
|
||||||
|
w.WriteHeader(tt.showStatus)
|
||||||
|
_ = json.NewEncoder(w).Encode(map[string]string{"error": "not found"})
|
||||||
|
return
|
||||||
|
}
|
||||||
w.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
if err := json.NewEncoder(w).Encode(api.ShowResponse{
|
if err := json.NewEncoder(w).Encode(api.ShowResponse{
|
||||||
RemoteHost: tt.remoteHost,
|
RemoteHost: tt.remoteHost,
|
||||||
RemoteModel: "test-model",
|
RemoteModel: tt.remoteModel,
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
@@ -1715,6 +2125,8 @@ func TestLoadOrUnloadModel_CloudModelAuth(t *testing.T) {
|
|||||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
case "/api/generate":
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
default:
|
default:
|
||||||
http.NotFound(w, r)
|
http.NotFound(w, r)
|
||||||
}
|
}
|
||||||
@@ -1727,31 +2139,30 @@ func TestLoadOrUnloadModel_CloudModelAuth(t *testing.T) {
|
|||||||
cmd.SetContext(t.Context())
|
cmd.SetContext(t.Context())
|
||||||
|
|
||||||
opts := &runOptions{
|
opts := &runOptions{
|
||||||
Model: "test-cloud-model",
|
Model: tt.model,
|
||||||
ShowConnect: false,
|
ShowConnect: false,
|
||||||
}
|
}
|
||||||
|
|
||||||
err := loadOrUnloadModel(cmd, opts)
|
err := loadOrUnloadModel(cmd, opts)
|
||||||
|
|
||||||
if strings.HasPrefix(tt.remoteHost, "https://ollama.com") {
|
if whoamiCalled != tt.expectWhoami {
|
||||||
if !whoamiCalled {
|
t.Errorf("whoami called = %v, want %v", whoamiCalled, tt.expectWhoami)
|
||||||
t.Error("expected whoami to be called for ollama.com cloud model")
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
if whoamiCalled {
|
|
||||||
t.Error("whoami should not be called for non-ollama.com remote")
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if tt.expectedError != "" {
|
if tt.expectedError != "" {
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Errorf("expected error containing %q, got nil", tt.expectedError)
|
t.Errorf("expected error containing %q, got nil", tt.expectedError)
|
||||||
} else {
|
} else {
|
||||||
|
if !tt.expectAuthError && !strings.Contains(strings.ToLower(err.Error()), strings.ToLower(tt.expectedError)) {
|
||||||
|
t.Errorf("expected error containing %q, got %v", tt.expectedError, err)
|
||||||
|
}
|
||||||
|
if tt.expectAuthError {
|
||||||
var authErr api.AuthorizationError
|
var authErr api.AuthorizationError
|
||||||
if !errors.As(err, &authErr) {
|
if !errors.As(err, &authErr) {
|
||||||
t.Errorf("expected AuthorizationError, got %T: %v", err, err)
|
t.Errorf("expected AuthorizationError, got %T: %v", err, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("expected no error, got %v", err)
|
t.Errorf("expected no error, got %v", err)
|
||||||
@@ -1760,3 +2171,38 @@ func TestLoadOrUnloadModel_CloudModelAuth(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestIsLocalhost(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
host string
|
||||||
|
expected bool
|
||||||
|
}{
|
||||||
|
{"default empty", "", true},
|
||||||
|
{"localhost no port", "localhost", true},
|
||||||
|
{"localhost with port", "localhost:11435", true},
|
||||||
|
{"127.0.0.1 no port", "127.0.0.1", true},
|
||||||
|
{"127.0.0.1 with port", "127.0.0.1:11434", true},
|
||||||
|
{"0.0.0.0 no port", "0.0.0.0", true},
|
||||||
|
{"0.0.0.0 with port", "0.0.0.0:11434", true},
|
||||||
|
{"::1 no port", "::1", true},
|
||||||
|
{"[::1] with port", "[::1]:11434", true},
|
||||||
|
{"loopback with scheme", "http://localhost:11434", true},
|
||||||
|
{"remote hostname", "example.com", false},
|
||||||
|
{"remote hostname with port", "example.com:11434", false},
|
||||||
|
{"remote IP", "192.168.1.1", false},
|
||||||
|
{"remote IP with port", "192.168.1.1:11434", false},
|
||||||
|
{"remote with scheme", "http://example.com:11434", false},
|
||||||
|
{"https remote", "https://example.com:443", false},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
t.Setenv("OLLAMA_HOST", tt.host)
|
||||||
|
got := isLocalhost()
|
||||||
|
if got != tt.expected {
|
||||||
|
t.Errorf("isLocalhost() with OLLAMA_HOST=%q = %v, want %v", tt.host, got, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,192 +0,0 @@
|
|||||||
package config
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
"os"
|
|
||||||
"os/exec"
|
|
||||||
"path/filepath"
|
|
||||||
"runtime"
|
|
||||||
|
|
||||||
"github.com/ollama/ollama/api"
|
|
||||||
"github.com/ollama/ollama/envconfig"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Claude implements Runner and AliasConfigurer for Claude Code integration
|
|
||||||
type Claude struct{}
|
|
||||||
|
|
||||||
// Compile-time check that Claude implements AliasConfigurer
|
|
||||||
var _ AliasConfigurer = (*Claude)(nil)
|
|
||||||
|
|
||||||
func (c *Claude) String() string { return "Claude Code" }
|
|
||||||
|
|
||||||
func (c *Claude) args(model string, extra []string) []string {
|
|
||||||
var args []string
|
|
||||||
if model != "" {
|
|
||||||
args = append(args, "--model", model)
|
|
||||||
}
|
|
||||||
args = append(args, extra...)
|
|
||||||
return args
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Claude) findPath() (string, error) {
|
|
||||||
if p, err := exec.LookPath("claude"); err == nil {
|
|
||||||
return p, nil
|
|
||||||
}
|
|
||||||
home, err := os.UserHomeDir()
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
name := "claude"
|
|
||||||
if runtime.GOOS == "windows" {
|
|
||||||
name = "claude.exe"
|
|
||||||
}
|
|
||||||
fallback := filepath.Join(home, ".claude", "local", name)
|
|
||||||
if _, err := os.Stat(fallback); err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
return fallback, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Claude) Run(model string, args []string) error {
|
|
||||||
claudePath, err := c.findPath()
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("claude is not installed, install from https://code.claude.com/docs/en/quickstart")
|
|
||||||
}
|
|
||||||
|
|
||||||
cmd := exec.Command(claudePath, c.args(model, args)...)
|
|
||||||
cmd.Stdin = os.Stdin
|
|
||||||
cmd.Stdout = os.Stdout
|
|
||||||
cmd.Stderr = os.Stderr
|
|
||||||
|
|
||||||
env := append(os.Environ(),
|
|
||||||
"ANTHROPIC_BASE_URL="+envconfig.Host().String(),
|
|
||||||
"ANTHROPIC_API_KEY=",
|
|
||||||
"ANTHROPIC_AUTH_TOKEN=ollama",
|
|
||||||
)
|
|
||||||
|
|
||||||
env = append(env, c.modelEnvVars(model)...)
|
|
||||||
|
|
||||||
cmd.Env = env
|
|
||||||
return cmd.Run()
|
|
||||||
}
|
|
||||||
|
|
||||||
// modelEnvVars returns Claude Code env vars that route all model tiers through Ollama.
|
|
||||||
func (c *Claude) modelEnvVars(model string) []string {
|
|
||||||
primary := model
|
|
||||||
fast := model
|
|
||||||
if cfg, err := loadIntegration("claude"); err == nil && cfg.Aliases != nil {
|
|
||||||
if p := cfg.Aliases["primary"]; p != "" {
|
|
||||||
primary = p
|
|
||||||
}
|
|
||||||
if f := cfg.Aliases["fast"]; f != "" {
|
|
||||||
fast = f
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return []string{
|
|
||||||
"ANTHROPIC_DEFAULT_OPUS_MODEL=" + primary,
|
|
||||||
"ANTHROPIC_DEFAULT_SONNET_MODEL=" + primary,
|
|
||||||
"ANTHROPIC_DEFAULT_HAIKU_MODEL=" + fast,
|
|
||||||
"CLAUDE_CODE_SUBAGENT_MODEL=" + primary,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ConfigureAliases sets up model aliases for Claude Code.
|
|
||||||
// model: the model to use (if empty, user will be prompted to select)
|
|
||||||
// aliases: existing alias configuration to preserve/update
|
|
||||||
// Cloud-only: subagent routing (fast model) is gated to cloud models only until
|
|
||||||
// there is a better strategy for prompt caching on local models.
|
|
||||||
func (c *Claude) ConfigureAliases(ctx context.Context, model string, existingAliases map[string]string, force bool) (map[string]string, bool, error) {
|
|
||||||
aliases := make(map[string]string)
|
|
||||||
for k, v := range existingAliases {
|
|
||||||
aliases[k] = v
|
|
||||||
}
|
|
||||||
|
|
||||||
if model != "" {
|
|
||||||
aliases["primary"] = model
|
|
||||||
}
|
|
||||||
|
|
||||||
if !force && aliases["primary"] != "" {
|
|
||||||
client, _ := api.ClientFromEnvironment()
|
|
||||||
if isCloudModel(ctx, client, aliases["primary"]) {
|
|
||||||
if isCloudModel(ctx, client, aliases["fast"]) {
|
|
||||||
return aliases, false, nil
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
delete(aliases, "fast")
|
|
||||||
return aliases, false, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
items, existingModels, cloudModels, client, err := listModels(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return nil, false, err
|
|
||||||
}
|
|
||||||
|
|
||||||
fmt.Fprintf(os.Stderr, "\n%sModel Configuration%s\n\n", ansiBold, ansiReset)
|
|
||||||
|
|
||||||
if aliases["primary"] == "" || force {
|
|
||||||
primary, err := DefaultSingleSelector("Select model:", items, aliases["primary"])
|
|
||||||
if err != nil {
|
|
||||||
return nil, false, err
|
|
||||||
}
|
|
||||||
if err := pullIfNeeded(ctx, client, existingModels, primary); err != nil {
|
|
||||||
return nil, false, err
|
|
||||||
}
|
|
||||||
if err := ensureAuth(ctx, client, cloudModels, []string{primary}); err != nil {
|
|
||||||
return nil, false, err
|
|
||||||
}
|
|
||||||
aliases["primary"] = primary
|
|
||||||
}
|
|
||||||
|
|
||||||
if isCloudModel(ctx, client, aliases["primary"]) {
|
|
||||||
if aliases["fast"] == "" || !isCloudModel(ctx, client, aliases["fast"]) {
|
|
||||||
aliases["fast"] = aliases["primary"]
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
delete(aliases, "fast")
|
|
||||||
}
|
|
||||||
|
|
||||||
return aliases, true, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetAliases syncs the configured aliases to the Ollama server using prefix matching.
|
|
||||||
// Cloud-only: for local models (fast is empty), we delete any existing aliases to
|
|
||||||
// prevent stale routing to a previous cloud model.
|
|
||||||
func (c *Claude) SetAliases(ctx context.Context, aliases map[string]string) error {
|
|
||||||
client, err := api.ClientFromEnvironment()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
prefixes := []string{"claude-sonnet-", "claude-haiku-"}
|
|
||||||
|
|
||||||
if aliases["fast"] == "" {
|
|
||||||
for _, prefix := range prefixes {
|
|
||||||
_ = client.DeleteAliasExperimental(ctx, &api.AliasDeleteRequest{Alias: prefix})
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
prefixAliases := map[string]string{
|
|
||||||
"claude-sonnet-": aliases["primary"],
|
|
||||||
"claude-haiku-": aliases["fast"],
|
|
||||||
}
|
|
||||||
|
|
||||||
var errs []string
|
|
||||||
for prefix, target := range prefixAliases {
|
|
||||||
req := &api.AliasRequest{
|
|
||||||
Alias: prefix,
|
|
||||||
Target: target,
|
|
||||||
PrefixMatching: true,
|
|
||||||
}
|
|
||||||
if err := client.SetAliasExperimental(ctx, req); err != nil {
|
|
||||||
errs = append(errs, prefix)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(errs) > 0 {
|
|
||||||
return fmt.Errorf("failed to set aliases: %v", errs)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
@@ -3,7 +3,6 @@
|
|||||||
package config
|
package config
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
@@ -11,14 +10,18 @@ import (
|
|||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/cmd/internal/fileutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
type integration struct {
|
type integration struct {
|
||||||
Models []string `json:"models"`
|
Models []string `json:"models"`
|
||||||
Aliases map[string]string `json:"aliases,omitempty"`
|
Aliases map[string]string `json:"aliases,omitempty"`
|
||||||
|
Onboarded bool `json:"onboarded,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// IntegrationConfig is the persisted config for one integration.
|
||||||
|
type IntegrationConfig = integration
|
||||||
|
|
||||||
type config struct {
|
type config struct {
|
||||||
Integrations map[string]*integration `json:"integrations"`
|
Integrations map[string]*integration `json:"integrations"`
|
||||||
LastModel string `json:"last_model,omitempty"`
|
LastModel string `json:"last_model,omitempty"`
|
||||||
@@ -123,7 +126,7 @@ func save(cfg *config) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return writeWithBackup(path, data)
|
return fileutil.WriteWithBackup(path, data)
|
||||||
}
|
}
|
||||||
|
|
||||||
func SaveIntegration(appName string, models []string) error {
|
func SaveIntegration(appName string, models []string) error {
|
||||||
@@ -139,34 +142,54 @@ func SaveIntegration(appName string, models []string) error {
|
|||||||
key := strings.ToLower(appName)
|
key := strings.ToLower(appName)
|
||||||
existing := cfg.Integrations[key]
|
existing := cfg.Integrations[key]
|
||||||
var aliases map[string]string
|
var aliases map[string]string
|
||||||
if existing != nil && existing.Aliases != nil {
|
var onboarded bool
|
||||||
|
if existing != nil {
|
||||||
aliases = existing.Aliases
|
aliases = existing.Aliases
|
||||||
|
onboarded = existing.Onboarded
|
||||||
}
|
}
|
||||||
|
|
||||||
cfg.Integrations[key] = &integration{
|
cfg.Integrations[key] = &integration{
|
||||||
Models: models,
|
Models: models,
|
||||||
Aliases: aliases,
|
Aliases: aliases,
|
||||||
|
Onboarded: onboarded,
|
||||||
}
|
}
|
||||||
|
|
||||||
return save(cfg)
|
return save(cfg)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// MarkIntegrationOnboarded marks an integration as onboarded in Ollama's config.
|
||||||
|
func MarkIntegrationOnboarded(appName string) error {
|
||||||
|
cfg, err := load()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
key := strings.ToLower(appName)
|
||||||
|
existing := cfg.Integrations[key]
|
||||||
|
if existing == nil {
|
||||||
|
existing = &integration{}
|
||||||
|
}
|
||||||
|
existing.Onboarded = true
|
||||||
|
cfg.Integrations[key] = existing
|
||||||
|
return save(cfg)
|
||||||
|
}
|
||||||
|
|
||||||
// IntegrationModel returns the first configured model for an integration, or empty string if not configured.
|
// IntegrationModel returns the first configured model for an integration, or empty string if not configured.
|
||||||
func IntegrationModel(appName string) string {
|
func IntegrationModel(appName string) string {
|
||||||
ic, err := loadIntegration(appName)
|
integrationConfig, err := LoadIntegration(appName)
|
||||||
if err != nil || len(ic.Models) == 0 {
|
if err != nil || len(integrationConfig.Models) == 0 {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
return ic.Models[0]
|
return integrationConfig.Models[0]
|
||||||
}
|
}
|
||||||
|
|
||||||
// IntegrationModels returns all configured models for an integration, or nil.
|
// IntegrationModels returns all configured models for an integration, or nil.
|
||||||
func IntegrationModels(appName string) []string {
|
func IntegrationModels(appName string) []string {
|
||||||
ic, err := loadIntegration(appName)
|
integrationConfig, err := LoadIntegration(appName)
|
||||||
if err != nil || len(ic.Models) == 0 {
|
if err != nil || len(integrationConfig.Models) == 0 {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
return ic.Models
|
return integrationConfig.Models
|
||||||
}
|
}
|
||||||
|
|
||||||
// LastModel returns the last model that was run, or empty string if none.
|
// LastModel returns the last model that was run, or empty string if none.
|
||||||
@@ -207,42 +230,23 @@ func SetLastSelection(selection string) error {
|
|||||||
return save(cfg)
|
return save(cfg)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ModelExists checks if a model exists on the Ollama server.
|
// LoadIntegration returns the saved config for one integration.
|
||||||
func ModelExists(ctx context.Context, name string) bool {
|
func LoadIntegration(appName string) (*integration, error) {
|
||||||
if name == "" {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
client, err := api.ClientFromEnvironment()
|
|
||||||
if err != nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
models, err := client.List(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
for _, m := range models.Models {
|
|
||||||
if m.Name == name || strings.HasPrefix(m.Name, name+":") {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func loadIntegration(appName string) (*integration, error) {
|
|
||||||
cfg, err := load()
|
cfg, err := load()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
ic, ok := cfg.Integrations[strings.ToLower(appName)]
|
integrationConfig, ok := cfg.Integrations[strings.ToLower(appName)]
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, os.ErrNotExist
|
return nil, os.ErrNotExist
|
||||||
}
|
}
|
||||||
|
|
||||||
return ic, nil
|
return integrationConfig, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func saveAliases(appName string, aliases map[string]string) error {
|
// SaveAliases replaces the saved aliases for one integration.
|
||||||
|
func SaveAliases(appName string, aliases map[string]string) error {
|
||||||
if appName == "" {
|
if appName == "" {
|
||||||
return errors.New("app name cannot be empty")
|
return errors.New("app name cannot be empty")
|
||||||
}
|
}
|
||||||
@@ -272,8 +276,8 @@ func listIntegrations() ([]integration, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
result := make([]integration, 0, len(cfg.Integrations))
|
result := make([]integration, 0, len(cfg.Integrations))
|
||||||
for _, ic := range cfg.Integrations {
|
for _, integrationConfig := range cfg.Integrations {
|
||||||
result = append(result, *ic)
|
result = append(result, *integrationConfig)
|
||||||
}
|
}
|
||||||
|
|
||||||
return result, nil
|
return result, nil
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
package config
|
package config
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"errors"
|
"errors"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
@@ -45,12 +44,12 @@ func TestSaveAliases_ReplacesNotMerges(t *testing.T) {
|
|||||||
"primary": "cloud-model",
|
"primary": "cloud-model",
|
||||||
"fast": "cloud-model",
|
"fast": "cloud-model",
|
||||||
}
|
}
|
||||||
if err := saveAliases("claude", initial); err != nil {
|
if err := SaveAliases("claude", initial); err != nil {
|
||||||
t.Fatalf("failed to save initial aliases: %v", err)
|
t.Fatalf("failed to save initial aliases: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Verify both are saved
|
// Verify both are saved
|
||||||
loaded, err := loadIntegration("claude")
|
loaded, err := LoadIntegration("claude")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to load: %v", err)
|
t.Fatalf("failed to load: %v", err)
|
||||||
}
|
}
|
||||||
@@ -63,12 +62,12 @@ func TestSaveAliases_ReplacesNotMerges(t *testing.T) {
|
|||||||
"primary": "local-model",
|
"primary": "local-model",
|
||||||
// fast intentionally missing
|
// fast intentionally missing
|
||||||
}
|
}
|
||||||
if err := saveAliases("claude", updated); err != nil {
|
if err := SaveAliases("claude", updated); err != nil {
|
||||||
t.Fatalf("failed to save updated aliases: %v", err)
|
t.Fatalf("failed to save updated aliases: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Verify fast is GONE (not merged/preserved)
|
// Verify fast is GONE (not merged/preserved)
|
||||||
loaded, err = loadIntegration("claude")
|
loaded, err = LoadIntegration("claude")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to load after update: %v", err)
|
t.Fatalf("failed to load after update: %v", err)
|
||||||
}
|
}
|
||||||
@@ -91,12 +90,12 @@ func TestSaveAliases_PreservesModels(t *testing.T) {
|
|||||||
|
|
||||||
// Then update aliases
|
// Then update aliases
|
||||||
aliases := map[string]string{"primary": "new-model"}
|
aliases := map[string]string{"primary": "new-model"}
|
||||||
if err := saveAliases("claude", aliases); err != nil {
|
if err := SaveAliases("claude", aliases); err != nil {
|
||||||
t.Fatalf("failed to save aliases: %v", err)
|
t.Fatalf("failed to save aliases: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Verify models are preserved
|
// Verify models are preserved
|
||||||
loaded, err := loadIntegration("claude")
|
loaded, err := LoadIntegration("claude")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to load: %v", err)
|
t.Fatalf("failed to load: %v", err)
|
||||||
}
|
}
|
||||||
@@ -111,16 +110,16 @@ func TestSaveAliases_EmptyMap(t *testing.T) {
|
|||||||
setTestHome(t, tmpDir)
|
setTestHome(t, tmpDir)
|
||||||
|
|
||||||
// Save with aliases
|
// Save with aliases
|
||||||
if err := saveAliases("claude", map[string]string{"primary": "model", "fast": "model"}); err != nil {
|
if err := SaveAliases("claude", map[string]string{"primary": "model", "fast": "model"}); err != nil {
|
||||||
t.Fatalf("failed to save: %v", err)
|
t.Fatalf("failed to save: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Save empty map
|
// Save empty map
|
||||||
if err := saveAliases("claude", map[string]string{}); err != nil {
|
if err := SaveAliases("claude", map[string]string{}); err != nil {
|
||||||
t.Fatalf("failed to save empty: %v", err)
|
t.Fatalf("failed to save empty: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
loaded, err := loadIntegration("claude")
|
loaded, err := LoadIntegration("claude")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to load: %v", err)
|
t.Fatalf("failed to load: %v", err)
|
||||||
}
|
}
|
||||||
@@ -135,16 +134,16 @@ func TestSaveAliases_NilMap(t *testing.T) {
|
|||||||
setTestHome(t, tmpDir)
|
setTestHome(t, tmpDir)
|
||||||
|
|
||||||
// Save with aliases first
|
// Save with aliases first
|
||||||
if err := saveAliases("claude", map[string]string{"primary": "model"}); err != nil {
|
if err := SaveAliases("claude", map[string]string{"primary": "model"}); err != nil {
|
||||||
t.Fatalf("failed to save: %v", err)
|
t.Fatalf("failed to save: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Save nil map - should clear aliases
|
// Save nil map - should clear aliases
|
||||||
if err := saveAliases("claude", nil); err != nil {
|
if err := SaveAliases("claude", nil); err != nil {
|
||||||
t.Fatalf("failed to save nil: %v", err)
|
t.Fatalf("failed to save nil: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
loaded, err := loadIntegration("claude")
|
loaded, err := LoadIntegration("claude")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to load: %v", err)
|
t.Fatalf("failed to load: %v", err)
|
||||||
}
|
}
|
||||||
@@ -155,7 +154,7 @@ func TestSaveAliases_NilMap(t *testing.T) {
|
|||||||
|
|
||||||
// TestSaveAliases_EmptyAppName returns error
|
// TestSaveAliases_EmptyAppName returns error
|
||||||
func TestSaveAliases_EmptyAppName(t *testing.T) {
|
func TestSaveAliases_EmptyAppName(t *testing.T) {
|
||||||
err := saveAliases("", map[string]string{"primary": "model"})
|
err := SaveAliases("", map[string]string{"primary": "model"})
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Error("expected error for empty app name")
|
t.Error("expected error for empty app name")
|
||||||
}
|
}
|
||||||
@@ -165,12 +164,12 @@ func TestSaveAliases_CaseInsensitive(t *testing.T) {
|
|||||||
tmpDir := t.TempDir()
|
tmpDir := t.TempDir()
|
||||||
setTestHome(t, tmpDir)
|
setTestHome(t, tmpDir)
|
||||||
|
|
||||||
if err := saveAliases("Claude", map[string]string{"primary": "model1"}); err != nil {
|
if err := SaveAliases("Claude", map[string]string{"primary": "model1"}); err != nil {
|
||||||
t.Fatalf("failed to save: %v", err)
|
t.Fatalf("failed to save: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Load with different case
|
// Load with different case
|
||||||
loaded, err := loadIntegration("claude")
|
loaded, err := LoadIntegration("claude")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to load: %v", err)
|
t.Fatalf("failed to load: %v", err)
|
||||||
}
|
}
|
||||||
@@ -179,11 +178,11 @@ func TestSaveAliases_CaseInsensitive(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Update with different case
|
// Update with different case
|
||||||
if err := saveAliases("CLAUDE", map[string]string{"primary": "model2"}); err != nil {
|
if err := SaveAliases("CLAUDE", map[string]string{"primary": "model2"}); err != nil {
|
||||||
t.Fatalf("failed to update: %v", err)
|
t.Fatalf("failed to update: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
loaded, err = loadIntegration("claude")
|
loaded, err = LoadIntegration("claude")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to load after update: %v", err)
|
t.Fatalf("failed to load after update: %v", err)
|
||||||
}
|
}
|
||||||
@@ -198,11 +197,11 @@ func TestSaveAliases_CreatesIntegration(t *testing.T) {
|
|||||||
setTestHome(t, tmpDir)
|
setTestHome(t, tmpDir)
|
||||||
|
|
||||||
// Save aliases for non-existent integration
|
// Save aliases for non-existent integration
|
||||||
if err := saveAliases("newintegration", map[string]string{"primary": "model"}); err != nil {
|
if err := SaveAliases("newintegration", map[string]string{"primary": "model"}); err != nil {
|
||||||
t.Fatalf("failed to save: %v", err)
|
t.Fatalf("failed to save: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
loaded, err := loadIntegration("newintegration")
|
loaded, err := LoadIntegration("newintegration")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to load: %v", err)
|
t.Fatalf("failed to load: %v", err)
|
||||||
}
|
}
|
||||||
@@ -371,12 +370,12 @@ func TestAtomicUpdate_ServerSucceedsConfigSaved(t *testing.T) {
|
|||||||
t.Fatal("server should succeed")
|
t.Fatal("server should succeed")
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := saveAliases("claude", map[string]string{"primary": "model"}); err != nil {
|
if err := SaveAliases("claude", map[string]string{"primary": "model"}); err != nil {
|
||||||
t.Fatalf("saveAliases failed: %v", err)
|
t.Fatalf("saveAliases failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Verify it was actually saved
|
// Verify it was actually saved
|
||||||
loaded, err := loadIntegration("claude")
|
loaded, err := LoadIntegration("claude")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to load: %v", err)
|
t.Fatalf("failed to load: %v", err)
|
||||||
}
|
}
|
||||||
@@ -408,7 +407,7 @@ func TestConfigFile_PreservesUnknownFields(t *testing.T) {
|
|||||||
os.WriteFile(configPath, []byte(initialConfig), 0o644)
|
os.WriteFile(configPath, []byte(initialConfig), 0o644)
|
||||||
|
|
||||||
// Update aliases
|
// Update aliases
|
||||||
if err := saveAliases("claude", map[string]string{"primary": "model2"}); err != nil {
|
if err := SaveAliases("claude", map[string]string{"primary": "model2"}); err != nil {
|
||||||
t.Fatalf("failed to save: %v", err)
|
t.Fatalf("failed to save: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -440,11 +439,6 @@ func containsHelper(s, substr string) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestClaudeImplementsAliasConfigurer(t *testing.T) {
|
|
||||||
c := &Claude{}
|
|
||||||
var _ AliasConfigurer = c // Compile-time check
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestModelNameEdgeCases(t *testing.T) {
|
func TestModelNameEdgeCases(t *testing.T) {
|
||||||
testCases := []struct {
|
testCases := []struct {
|
||||||
name string
|
name string
|
||||||
@@ -464,11 +458,11 @@ func TestModelNameEdgeCases(t *testing.T) {
|
|||||||
setTestHome(t, tmpDir)
|
setTestHome(t, tmpDir)
|
||||||
|
|
||||||
aliases := map[string]string{"primary": tc.model}
|
aliases := map[string]string{"primary": tc.model}
|
||||||
if err := saveAliases("claude", aliases); err != nil {
|
if err := SaveAliases("claude", aliases); err != nil {
|
||||||
t.Fatalf("failed to save model %q: %v", tc.model, err)
|
t.Fatalf("failed to save model %q: %v", tc.model, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
loaded, err := loadIntegration("claude")
|
loaded, err := LoadIntegration("claude")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to load: %v", err)
|
t.Fatalf("failed to load: %v", err)
|
||||||
}
|
}
|
||||||
@@ -485,7 +479,7 @@ func TestSwitchingScenarios(t *testing.T) {
|
|||||||
setTestHome(t, tmpDir)
|
setTestHome(t, tmpDir)
|
||||||
|
|
||||||
// Initial cloud config
|
// Initial cloud config
|
||||||
if err := saveAliases("claude", map[string]string{
|
if err := SaveAliases("claude", map[string]string{
|
||||||
"primary": "cloud-model",
|
"primary": "cloud-model",
|
||||||
"fast": "cloud-model",
|
"fast": "cloud-model",
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
@@ -493,13 +487,13 @@ func TestSwitchingScenarios(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Switch to local (no fast)
|
// Switch to local (no fast)
|
||||||
if err := saveAliases("claude", map[string]string{
|
if err := SaveAliases("claude", map[string]string{
|
||||||
"primary": "local-model",
|
"primary": "local-model",
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
loaded, _ := loadIntegration("claude")
|
loaded, _ := LoadIntegration("claude")
|
||||||
if loaded.Aliases["fast"] != "" {
|
if loaded.Aliases["fast"] != "" {
|
||||||
t.Errorf("fast should be removed, got %q", loaded.Aliases["fast"])
|
t.Errorf("fast should be removed, got %q", loaded.Aliases["fast"])
|
||||||
}
|
}
|
||||||
@@ -513,21 +507,21 @@ func TestSwitchingScenarios(t *testing.T) {
|
|||||||
setTestHome(t, tmpDir)
|
setTestHome(t, tmpDir)
|
||||||
|
|
||||||
// Initial local config
|
// Initial local config
|
||||||
if err := saveAliases("claude", map[string]string{
|
if err := SaveAliases("claude", map[string]string{
|
||||||
"primary": "local-model",
|
"primary": "local-model",
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Switch to cloud (with fast)
|
// Switch to cloud (with fast)
|
||||||
if err := saveAliases("claude", map[string]string{
|
if err := SaveAliases("claude", map[string]string{
|
||||||
"primary": "cloud-model",
|
"primary": "cloud-model",
|
||||||
"fast": "cloud-model",
|
"fast": "cloud-model",
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
loaded, _ := loadIntegration("claude")
|
loaded, _ := LoadIntegration("claude")
|
||||||
if loaded.Aliases["fast"] != "cloud-model" {
|
if loaded.Aliases["fast"] != "cloud-model" {
|
||||||
t.Errorf("fast should be cloud-model, got %q", loaded.Aliases["fast"])
|
t.Errorf("fast should be cloud-model, got %q", loaded.Aliases["fast"])
|
||||||
}
|
}
|
||||||
@@ -538,7 +532,7 @@ func TestSwitchingScenarios(t *testing.T) {
|
|||||||
setTestHome(t, tmpDir)
|
setTestHome(t, tmpDir)
|
||||||
|
|
||||||
// Initial cloud config
|
// Initial cloud config
|
||||||
if err := saveAliases("claude", map[string]string{
|
if err := SaveAliases("claude", map[string]string{
|
||||||
"primary": "cloud-model-1",
|
"primary": "cloud-model-1",
|
||||||
"fast": "cloud-model-1",
|
"fast": "cloud-model-1",
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
@@ -546,14 +540,14 @@ func TestSwitchingScenarios(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Switch to different cloud
|
// Switch to different cloud
|
||||||
if err := saveAliases("claude", map[string]string{
|
if err := SaveAliases("claude", map[string]string{
|
||||||
"primary": "cloud-model-2",
|
"primary": "cloud-model-2",
|
||||||
"fast": "cloud-model-2",
|
"fast": "cloud-model-2",
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
loaded, _ := loadIntegration("claude")
|
loaded, _ := LoadIntegration("claude")
|
||||||
if loaded.Aliases["primary"] != "cloud-model-2" {
|
if loaded.Aliases["primary"] != "cloud-model-2" {
|
||||||
t.Errorf("primary should be cloud-model-2, got %q", loaded.Aliases["primary"])
|
t.Errorf("primary should be cloud-model-2, got %q", loaded.Aliases["primary"])
|
||||||
}
|
}
|
||||||
@@ -563,43 +557,13 @@ func TestSwitchingScenarios(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestToolCapabilityFiltering(t *testing.T) {
|
|
||||||
t.Run("all models checked for tool capability", func(t *testing.T) {
|
|
||||||
// Both cloud and local models are checked for tool capability via Show API
|
|
||||||
// Only models with "tools" in capabilities are included
|
|
||||||
m := modelInfo{Name: "tool-model", Remote: false, ToolCapable: true}
|
|
||||||
if !m.ToolCapable {
|
|
||||||
t.Error("tool capable model should be marked as such")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("modelInfo includes ToolCapable field", func(t *testing.T) {
|
|
||||||
m := modelInfo{Name: "test", Remote: true, ToolCapable: true}
|
|
||||||
if !m.ToolCapable {
|
|
||||||
t.Error("ToolCapable field should be accessible")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestIsCloudModel_RequiresClient(t *testing.T) {
|
|
||||||
t.Run("nil client always returns false", func(t *testing.T) {
|
|
||||||
// isCloudModel now only uses Show API, no suffix detection
|
|
||||||
if isCloudModel(context.Background(), nil, "model:cloud") {
|
|
||||||
t.Error("nil client should return false regardless of suffix")
|
|
||||||
}
|
|
||||||
if isCloudModel(context.Background(), nil, "local-model") {
|
|
||||||
t.Error("nil client should return false")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestModelsAndAliasesMustStayInSync(t *testing.T) {
|
func TestModelsAndAliasesMustStayInSync(t *testing.T) {
|
||||||
t.Run("saveAliases followed by saveIntegration keeps them in sync", func(t *testing.T) {
|
t.Run("saveAliases followed by saveIntegration keeps them in sync", func(t *testing.T) {
|
||||||
tmpDir := t.TempDir()
|
tmpDir := t.TempDir()
|
||||||
setTestHome(t, tmpDir)
|
setTestHome(t, tmpDir)
|
||||||
|
|
||||||
// Save aliases with one model
|
// Save aliases with one model
|
||||||
if err := saveAliases("claude", map[string]string{"primary": "model-a"}); err != nil {
|
if err := SaveAliases("claude", map[string]string{"primary": "model-a"}); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -608,7 +572,7 @@ func TestModelsAndAliasesMustStayInSync(t *testing.T) {
|
|||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
loaded, _ := loadIntegration("claude")
|
loaded, _ := LoadIntegration("claude")
|
||||||
if loaded.Aliases["primary"] != loaded.Models[0] {
|
if loaded.Aliases["primary"] != loaded.Models[0] {
|
||||||
t.Errorf("aliases.primary (%q) != models[0] (%q)", loaded.Aliases["primary"], loaded.Models[0])
|
t.Errorf("aliases.primary (%q) != models[0] (%q)", loaded.Aliases["primary"], loaded.Models[0])
|
||||||
}
|
}
|
||||||
@@ -622,11 +586,11 @@ func TestModelsAndAliasesMustStayInSync(t *testing.T) {
|
|||||||
if err := SaveIntegration("claude", []string{"old-model"}); err != nil {
|
if err := SaveIntegration("claude", []string{"old-model"}); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
if err := saveAliases("claude", map[string]string{"primary": "new-model"}); err != nil {
|
if err := SaveAliases("claude", map[string]string{"primary": "new-model"}); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
loaded, _ := loadIntegration("claude")
|
loaded, _ := LoadIntegration("claude")
|
||||||
|
|
||||||
// They should be different (this is the bug state)
|
// They should be different (this is the bug state)
|
||||||
if loaded.Models[0] == loaded.Aliases["primary"] {
|
if loaded.Models[0] == loaded.Aliases["primary"] {
|
||||||
@@ -638,7 +602,7 @@ func TestModelsAndAliasesMustStayInSync(t *testing.T) {
|
|||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
loaded, _ = loadIntegration("claude")
|
loaded, _ = LoadIntegration("claude")
|
||||||
if loaded.Models[0] != loaded.Aliases["primary"] {
|
if loaded.Models[0] != loaded.Aliases["primary"] {
|
||||||
t.Errorf("after fix: models[0] (%q) should equal aliases.primary (%q)",
|
t.Errorf("after fix: models[0] (%q) should equal aliases.primary (%q)",
|
||||||
loaded.Models[0], loaded.Aliases["primary"])
|
loaded.Models[0], loaded.Aliases["primary"])
|
||||||
@@ -653,20 +617,20 @@ func TestModelsAndAliasesMustStayInSync(t *testing.T) {
|
|||||||
if err := SaveIntegration("claude", []string{"initial-model"}); err != nil {
|
if err := SaveIntegration("claude", []string{"initial-model"}); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
if err := saveAliases("claude", map[string]string{"primary": "initial-model"}); err != nil {
|
if err := SaveAliases("claude", map[string]string{"primary": "initial-model"}); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Update aliases AND models together
|
// Update aliases AND models together
|
||||||
newAliases := map[string]string{"primary": "updated-model"}
|
newAliases := map[string]string{"primary": "updated-model"}
|
||||||
if err := saveAliases("claude", newAliases); err != nil {
|
if err := SaveAliases("claude", newAliases); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
if err := SaveIntegration("claude", []string{newAliases["primary"]}); err != nil {
|
if err := SaveIntegration("claude", []string{newAliases["primary"]}); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
loaded, _ := loadIntegration("claude")
|
loaded, _ := LoadIntegration("claude")
|
||||||
if loaded.Models[0] != "updated-model" {
|
if loaded.Models[0] != "updated-model" {
|
||||||
t.Errorf("models[0] should be updated-model, got %q", loaded.Models[0])
|
t.Errorf("models[0] should be updated-model, got %q", loaded.Models[0])
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -10,17 +10,10 @@ import (
|
|||||||
// setTestHome sets both HOME (Unix) and USERPROFILE (Windows) for cross-platform tests
|
// setTestHome sets both HOME (Unix) and USERPROFILE (Windows) for cross-platform tests
|
||||||
func setTestHome(t *testing.T, dir string) {
|
func setTestHome(t *testing.T, dir string) {
|
||||||
t.Setenv("HOME", dir)
|
t.Setenv("HOME", dir)
|
||||||
|
t.Setenv("TMPDIR", dir)
|
||||||
t.Setenv("USERPROFILE", dir)
|
t.Setenv("USERPROFILE", dir)
|
||||||
}
|
}
|
||||||
|
|
||||||
// editorPaths is a test helper that safely calls Paths if the runner implements Editor
|
|
||||||
func editorPaths(r Runner) []string {
|
|
||||||
if editor, ok := r.(Editor); ok {
|
|
||||||
return editor.Paths()
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestIntegrationConfig(t *testing.T) {
|
func TestIntegrationConfig(t *testing.T) {
|
||||||
tmpDir := t.TempDir()
|
tmpDir := t.TempDir()
|
||||||
setTestHome(t, tmpDir)
|
setTestHome(t, tmpDir)
|
||||||
@@ -31,7 +24,7 @@ func TestIntegrationConfig(t *testing.T) {
|
|||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
config, err := loadIntegration("claude")
|
config, err := LoadIntegration("claude")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@@ -55,11 +48,11 @@ func TestIntegrationConfig(t *testing.T) {
|
|||||||
"primary": "llama3.2:70b",
|
"primary": "llama3.2:70b",
|
||||||
"fast": "llama3.2:8b",
|
"fast": "llama3.2:8b",
|
||||||
}
|
}
|
||||||
if err := saveAliases("claude", aliases); err != nil {
|
if err := SaveAliases("claude", aliases); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
config, err := loadIntegration("claude")
|
config, err := LoadIntegration("claude")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@@ -77,14 +70,14 @@ func TestIntegrationConfig(t *testing.T) {
|
|||||||
if err := SaveIntegration("claude", []string{"model-a"}); err != nil {
|
if err := SaveIntegration("claude", []string{"model-a"}); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
if err := saveAliases("claude", map[string]string{"primary": "model-a", "fast": "model-small"}); err != nil {
|
if err := SaveAliases("claude", map[string]string{"primary": "model-a", "fast": "model-small"}); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := SaveIntegration("claude", []string{"model-b"}); err != nil {
|
if err := SaveIntegration("claude", []string{"model-b"}); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
config, err := loadIntegration("claude")
|
config, err := LoadIntegration("claude")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@@ -96,7 +89,7 @@ func TestIntegrationConfig(t *testing.T) {
|
|||||||
t.Run("defaultModel returns first model", func(t *testing.T) {
|
t.Run("defaultModel returns first model", func(t *testing.T) {
|
||||||
SaveIntegration("codex", []string{"model-a", "model-b"})
|
SaveIntegration("codex", []string{"model-a", "model-b"})
|
||||||
|
|
||||||
config, _ := loadIntegration("codex")
|
config, _ := LoadIntegration("codex")
|
||||||
defaultModel := ""
|
defaultModel := ""
|
||||||
if len(config.Models) > 0 {
|
if len(config.Models) > 0 {
|
||||||
defaultModel = config.Models[0]
|
defaultModel = config.Models[0]
|
||||||
@@ -120,7 +113,7 @@ func TestIntegrationConfig(t *testing.T) {
|
|||||||
t.Run("app name is case-insensitive", func(t *testing.T) {
|
t.Run("app name is case-insensitive", func(t *testing.T) {
|
||||||
SaveIntegration("Claude", []string{"model-x"})
|
SaveIntegration("Claude", []string{"model-x"})
|
||||||
|
|
||||||
config, err := loadIntegration("claude")
|
config, err := LoadIntegration("claude")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@@ -137,8 +130,8 @@ func TestIntegrationConfig(t *testing.T) {
|
|||||||
SaveIntegration("app1", []string{"model-1"})
|
SaveIntegration("app1", []string{"model-1"})
|
||||||
SaveIntegration("app2", []string{"model-2"})
|
SaveIntegration("app2", []string{"model-2"})
|
||||||
|
|
||||||
config1, _ := loadIntegration("app1")
|
config1, _ := LoadIntegration("app1")
|
||||||
config2, _ := loadIntegration("app2")
|
config2, _ := LoadIntegration("app2")
|
||||||
|
|
||||||
defaultModel1 := ""
|
defaultModel1 := ""
|
||||||
if len(config1.Models) > 0 {
|
if len(config1.Models) > 0 {
|
||||||
@@ -185,64 +178,6 @@ func TestListIntegrations(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestEditorPaths(t *testing.T) {
|
|
||||||
tmpDir := t.TempDir()
|
|
||||||
setTestHome(t, tmpDir)
|
|
||||||
|
|
||||||
t.Run("returns empty for claude (no Editor)", func(t *testing.T) {
|
|
||||||
r := integrations["claude"]
|
|
||||||
paths := editorPaths(r)
|
|
||||||
if len(paths) != 0 {
|
|
||||||
t.Errorf("expected no paths for claude, got %v", paths)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("returns empty for codex (no Editor)", func(t *testing.T) {
|
|
||||||
r := integrations["codex"]
|
|
||||||
paths := editorPaths(r)
|
|
||||||
if len(paths) != 0 {
|
|
||||||
t.Errorf("expected no paths for codex, got %v", paths)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("returns empty for droid when no config exists", func(t *testing.T) {
|
|
||||||
r := integrations["droid"]
|
|
||||||
paths := editorPaths(r)
|
|
||||||
if len(paths) != 0 {
|
|
||||||
t.Errorf("expected no paths, got %v", paths)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("returns path for droid when config exists", func(t *testing.T) {
|
|
||||||
settingsDir, _ := os.UserHomeDir()
|
|
||||||
settingsDir = filepath.Join(settingsDir, ".factory")
|
|
||||||
os.MkdirAll(settingsDir, 0o755)
|
|
||||||
os.WriteFile(filepath.Join(settingsDir, "settings.json"), []byte(`{}`), 0o644)
|
|
||||||
|
|
||||||
r := integrations["droid"]
|
|
||||||
paths := editorPaths(r)
|
|
||||||
if len(paths) != 1 {
|
|
||||||
t.Errorf("expected 1 path, got %d", len(paths))
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("returns paths for opencode when configs exist", func(t *testing.T) {
|
|
||||||
home, _ := os.UserHomeDir()
|
|
||||||
configDir := filepath.Join(home, ".config", "opencode")
|
|
||||||
stateDir := filepath.Join(home, ".local", "state", "opencode")
|
|
||||||
os.MkdirAll(configDir, 0o755)
|
|
||||||
os.MkdirAll(stateDir, 0o755)
|
|
||||||
os.WriteFile(filepath.Join(configDir, "opencode.json"), []byte(`{}`), 0o644)
|
|
||||||
os.WriteFile(filepath.Join(stateDir, "model.json"), []byte(`{}`), 0o644)
|
|
||||||
|
|
||||||
r := integrations["opencode"]
|
|
||||||
paths := editorPaths(r)
|
|
||||||
if len(paths) != 2 {
|
|
||||||
t.Errorf("expected 2 paths, got %d: %v", len(paths), paths)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestLoadIntegration_CorruptedJSON(t *testing.T) {
|
func TestLoadIntegration_CorruptedJSON(t *testing.T) {
|
||||||
tmpDir := t.TempDir()
|
tmpDir := t.TempDir()
|
||||||
setTestHome(t, tmpDir)
|
setTestHome(t, tmpDir)
|
||||||
@@ -251,7 +186,7 @@ func TestLoadIntegration_CorruptedJSON(t *testing.T) {
|
|||||||
os.MkdirAll(dir, 0o755)
|
os.MkdirAll(dir, 0o755)
|
||||||
os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{corrupted json`), 0o644)
|
os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{corrupted json`), 0o644)
|
||||||
|
|
||||||
_, err := loadIntegration("test")
|
_, err := LoadIntegration("test")
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Error("expected error for nonexistent integration in corrupted file")
|
t.Error("expected error for nonexistent integration in corrupted file")
|
||||||
}
|
}
|
||||||
@@ -265,7 +200,7 @@ func TestSaveIntegration_NilModels(t *testing.T) {
|
|||||||
t.Fatalf("saveIntegration with nil models failed: %v", err)
|
t.Fatalf("saveIntegration with nil models failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
config, err := loadIntegration("test")
|
config, err := LoadIntegration("test")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("loadIntegration failed: %v", err)
|
t.Fatalf("loadIntegration failed: %v", err)
|
||||||
}
|
}
|
||||||
@@ -294,7 +229,7 @@ func TestLoadIntegration_NonexistentIntegration(t *testing.T) {
|
|||||||
tmpDir := t.TempDir()
|
tmpDir := t.TempDir()
|
||||||
setTestHome(t, tmpDir)
|
setTestHome(t, tmpDir)
|
||||||
|
|
||||||
_, err := loadIntegration("nonexistent")
|
_, err := LoadIntegration("nonexistent")
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Error("expected error for nonexistent integration, got nil")
|
t.Error("expected error for nonexistent integration, got nil")
|
||||||
}
|
}
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -1,264 +0,0 @@
|
|||||||
package config
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"context"
|
|
||||||
"encoding/json"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"os"
|
|
||||||
"os/exec"
|
|
||||||
"path/filepath"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
"github.com/ollama/ollama/envconfig"
|
|
||||||
)
|
|
||||||
|
|
||||||
type Openclaw struct{}
|
|
||||||
|
|
||||||
func (c *Openclaw) String() string { return "OpenClaw" }
|
|
||||||
|
|
||||||
func (c *Openclaw) Run(model string, args []string) error {
|
|
||||||
bin := "openclaw"
|
|
||||||
if _, err := exec.LookPath(bin); err != nil {
|
|
||||||
bin = "clawdbot"
|
|
||||||
if _, err := exec.LookPath(bin); err != nil {
|
|
||||||
return fmt.Errorf("openclaw is not installed, install from https://docs.openclaw.ai")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
models := []string{model}
|
|
||||||
if config, err := loadIntegration("openclaw"); err == nil && len(config.Models) > 0 {
|
|
||||||
models = config.Models
|
|
||||||
} else if config, err := loadIntegration("clawdbot"); err == nil && len(config.Models) > 0 {
|
|
||||||
models = config.Models
|
|
||||||
}
|
|
||||||
var err error
|
|
||||||
models, err = resolveEditorModels("openclaw", models, func() ([]string, error) {
|
|
||||||
return selectModels(context.Background(), "openclaw", "")
|
|
||||||
})
|
|
||||||
if errors.Is(err, errCancelled) {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if err := c.Edit(models); err != nil {
|
|
||||||
return fmt.Errorf("setup failed: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if !c.onboarded() {
|
|
||||||
// Onboarding not completed: run it (model already set via Edit)
|
|
||||||
// Use "ollama" as gateway token for simple local access
|
|
||||||
cmd := exec.Command(bin, "onboard",
|
|
||||||
"--auth-choice", "skip",
|
|
||||||
"--gateway-token", "ollama",
|
|
||||||
)
|
|
||||||
cmd.Stdin = os.Stdin
|
|
||||||
cmd.Stdout = os.Stdout
|
|
||||||
cmd.Stderr = os.Stderr
|
|
||||||
return cmd.Run()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Onboarding completed: run gateway
|
|
||||||
cmd := exec.Command(bin, append([]string{"gateway"}, args...)...)
|
|
||||||
cmd.Stdin = os.Stdin
|
|
||||||
|
|
||||||
// Capture output to detect "already running" message
|
|
||||||
var outputBuf bytes.Buffer
|
|
||||||
cmd.Stdout = io.MultiWriter(os.Stdout, &outputBuf)
|
|
||||||
cmd.Stderr = io.MultiWriter(os.Stderr, &outputBuf)
|
|
||||||
|
|
||||||
err = cmd.Run()
|
|
||||||
if err != nil && strings.Contains(outputBuf.String(), "Gateway already running") {
|
|
||||||
fmt.Fprintf(os.Stderr, "%sOpenClaw has been configured with Ollama. Gateway is already running.%s\n", ansiGreen, ansiReset)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// onboarded checks if OpenClaw onboarding wizard was completed
|
|
||||||
// by looking for the wizard.lastRunAt marker in the config
|
|
||||||
func (c *Openclaw) onboarded() bool {
|
|
||||||
home, err := os.UserHomeDir()
|
|
||||||
if err != nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
configPath := filepath.Join(home, ".openclaw", "openclaw.json")
|
|
||||||
legacyPath := filepath.Join(home, ".clawdbot", "clawdbot.json")
|
|
||||||
|
|
||||||
config := make(map[string]any)
|
|
||||||
if data, err := os.ReadFile(configPath); err == nil {
|
|
||||||
_ = json.Unmarshal(data, &config)
|
|
||||||
} else if data, err := os.ReadFile(legacyPath); err == nil {
|
|
||||||
_ = json.Unmarshal(data, &config)
|
|
||||||
} else {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check for wizard.lastRunAt marker (set when onboarding completes)
|
|
||||||
wizard, _ := config["wizard"].(map[string]any)
|
|
||||||
if wizard == nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
lastRunAt, _ := wizard["lastRunAt"].(string)
|
|
||||||
return lastRunAt != ""
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Openclaw) Paths() []string {
|
|
||||||
home, _ := os.UserHomeDir()
|
|
||||||
p := filepath.Join(home, ".openclaw", "openclaw.json")
|
|
||||||
if _, err := os.Stat(p); err == nil {
|
|
||||||
return []string{p}
|
|
||||||
}
|
|
||||||
legacy := filepath.Join(home, ".clawdbot", "clawdbot.json")
|
|
||||||
if _, err := os.Stat(legacy); err == nil {
|
|
||||||
return []string{legacy}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Openclaw) Edit(models []string) error {
|
|
||||||
if len(models) == 0 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
home, err := os.UserHomeDir()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
configPath := filepath.Join(home, ".openclaw", "openclaw.json")
|
|
||||||
legacyPath := filepath.Join(home, ".clawdbot", "clawdbot.json")
|
|
||||||
if err := os.MkdirAll(filepath.Dir(configPath), 0o755); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Read into map[string]any to preserve unknown fields
|
|
||||||
config := make(map[string]any)
|
|
||||||
if data, err := os.ReadFile(configPath); err == nil {
|
|
||||||
_ = json.Unmarshal(data, &config)
|
|
||||||
} else if data, err := os.ReadFile(legacyPath); err == nil {
|
|
||||||
_ = json.Unmarshal(data, &config)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Navigate/create: models.providers.ollama (preserving other providers)
|
|
||||||
modelsSection, _ := config["models"].(map[string]any)
|
|
||||||
if modelsSection == nil {
|
|
||||||
modelsSection = make(map[string]any)
|
|
||||||
}
|
|
||||||
providers, _ := modelsSection["providers"].(map[string]any)
|
|
||||||
if providers == nil {
|
|
||||||
providers = make(map[string]any)
|
|
||||||
}
|
|
||||||
ollama, _ := providers["ollama"].(map[string]any)
|
|
||||||
if ollama == nil {
|
|
||||||
ollama = make(map[string]any)
|
|
||||||
}
|
|
||||||
|
|
||||||
ollama["baseUrl"] = envconfig.Host().String() + "/v1"
|
|
||||||
// needed to register provider
|
|
||||||
ollama["apiKey"] = "ollama-local"
|
|
||||||
// TODO(parthsareen): potentially move to responses
|
|
||||||
ollama["api"] = "openai-completions"
|
|
||||||
|
|
||||||
// Build map of existing models to preserve user customizations
|
|
||||||
existingModels, _ := ollama["models"].([]any)
|
|
||||||
existingByID := make(map[string]map[string]any)
|
|
||||||
for _, m := range existingModels {
|
|
||||||
if entry, ok := m.(map[string]any); ok {
|
|
||||||
if id, ok := entry["id"].(string); ok {
|
|
||||||
existingByID[id] = entry
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
var newModels []any
|
|
||||||
for _, model := range models {
|
|
||||||
entry := map[string]any{
|
|
||||||
"id": model,
|
|
||||||
"name": model,
|
|
||||||
"reasoning": false,
|
|
||||||
"input": []any{"text"},
|
|
||||||
"cost": map[string]any{
|
|
||||||
"input": 0,
|
|
||||||
"output": 0,
|
|
||||||
"cacheRead": 0,
|
|
||||||
"cacheWrite": 0,
|
|
||||||
},
|
|
||||||
// TODO(parthsareen): get these values from API
|
|
||||||
"contextWindow": 131072,
|
|
||||||
"maxTokens": 16384,
|
|
||||||
}
|
|
||||||
// Merge existing fields (user customizations)
|
|
||||||
if existing, ok := existingByID[model]; ok {
|
|
||||||
for k, v := range existing {
|
|
||||||
if _, isNew := entry[k]; !isNew {
|
|
||||||
entry[k] = v
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
newModels = append(newModels, entry)
|
|
||||||
}
|
|
||||||
ollama["models"] = newModels
|
|
||||||
|
|
||||||
providers["ollama"] = ollama
|
|
||||||
modelsSection["providers"] = providers
|
|
||||||
config["models"] = modelsSection
|
|
||||||
|
|
||||||
// Update agents.defaults.model.primary (preserving other agent settings)
|
|
||||||
agents, _ := config["agents"].(map[string]any)
|
|
||||||
if agents == nil {
|
|
||||||
agents = make(map[string]any)
|
|
||||||
}
|
|
||||||
defaults, _ := agents["defaults"].(map[string]any)
|
|
||||||
if defaults == nil {
|
|
||||||
defaults = make(map[string]any)
|
|
||||||
}
|
|
||||||
modelConfig, _ := defaults["model"].(map[string]any)
|
|
||||||
if modelConfig == nil {
|
|
||||||
modelConfig = make(map[string]any)
|
|
||||||
}
|
|
||||||
modelConfig["primary"] = "ollama/" + models[0]
|
|
||||||
defaults["model"] = modelConfig
|
|
||||||
agents["defaults"] = defaults
|
|
||||||
config["agents"] = agents
|
|
||||||
|
|
||||||
data, err := json.MarshalIndent(config, "", " ")
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return writeWithBackup(configPath, data)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Openclaw) Models() []string {
|
|
||||||
home, err := os.UserHomeDir()
|
|
||||||
if err != nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
config, err := readJSONFile(filepath.Join(home, ".openclaw", "openclaw.json"))
|
|
||||||
if err != nil {
|
|
||||||
config, err = readJSONFile(filepath.Join(home, ".clawdbot", "clawdbot.json"))
|
|
||||||
if err != nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
modelsSection, _ := config["models"].(map[string]any)
|
|
||||||
providers, _ := modelsSection["providers"].(map[string]any)
|
|
||||||
ollama, _ := providers["ollama"].(map[string]any)
|
|
||||||
modelList, _ := ollama["models"].([]any)
|
|
||||||
|
|
||||||
var result []string
|
|
||||||
for _, m := range modelList {
|
|
||||||
if entry, ok := m.(map[string]any); ok {
|
|
||||||
if id, ok := entry["id"].(string); ok {
|
|
||||||
result = append(result, id)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return result
|
|
||||||
}
|
|
||||||
@@ -1,878 +0,0 @@
|
|||||||
package config
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
|
||||||
"os"
|
|
||||||
"path/filepath"
|
|
||||||
"testing"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestOpenclawIntegration(t *testing.T) {
|
|
||||||
c := &Openclaw{}
|
|
||||||
|
|
||||||
t.Run("String", func(t *testing.T) {
|
|
||||||
if got := c.String(); got != "OpenClaw" {
|
|
||||||
t.Errorf("String() = %q, want %q", got, "OpenClaw")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("implements Runner", func(t *testing.T) {
|
|
||||||
var _ Runner = c
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("implements Editor", func(t *testing.T) {
|
|
||||||
var _ Editor = c
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestOpenclawEdit(t *testing.T) {
|
|
||||||
c := &Openclaw{}
|
|
||||||
tmpDir := t.TempDir()
|
|
||||||
setTestHome(t, tmpDir)
|
|
||||||
|
|
||||||
configDir := filepath.Join(tmpDir, ".openclaw")
|
|
||||||
configPath := filepath.Join(configDir, "openclaw.json")
|
|
||||||
|
|
||||||
cleanup := func() { os.RemoveAll(configDir) }
|
|
||||||
|
|
||||||
t.Run("fresh install", func(t *testing.T) {
|
|
||||||
cleanup()
|
|
||||||
if err := c.Edit([]string{"llama3.2"}); err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
assertOpenclawModelExists(t, configPath, "llama3.2")
|
|
||||||
assertOpenclawPrimaryModel(t, configPath, "ollama/llama3.2")
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("multiple models - first is primary", func(t *testing.T) {
|
|
||||||
cleanup()
|
|
||||||
if err := c.Edit([]string{"llama3.2", "mistral"}); err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
assertOpenclawModelExists(t, configPath, "llama3.2")
|
|
||||||
assertOpenclawModelExists(t, configPath, "mistral")
|
|
||||||
assertOpenclawPrimaryModel(t, configPath, "ollama/llama3.2")
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("preserve other providers", func(t *testing.T) {
|
|
||||||
cleanup()
|
|
||||||
os.MkdirAll(configDir, 0o755)
|
|
||||||
os.WriteFile(configPath, []byte(`{"models":{"providers":{"anthropic":{"apiKey":"xxx"}}}}`), 0o644)
|
|
||||||
if err := c.Edit([]string{"llama3.2"}); err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
data, _ := os.ReadFile(configPath)
|
|
||||||
var cfg map[string]any
|
|
||||||
json.Unmarshal(data, &cfg)
|
|
||||||
models := cfg["models"].(map[string]any)
|
|
||||||
providers := models["providers"].(map[string]any)
|
|
||||||
if providers["anthropic"] == nil {
|
|
||||||
t.Error("anthropic provider was removed")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("preserve top-level keys", func(t *testing.T) {
|
|
||||||
cleanup()
|
|
||||||
os.MkdirAll(configDir, 0o755)
|
|
||||||
os.WriteFile(configPath, []byte(`{"theme":"dark","mcp":{"servers":{}}}`), 0o644)
|
|
||||||
if err := c.Edit([]string{"llama3.2"}); err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
data, _ := os.ReadFile(configPath)
|
|
||||||
var cfg map[string]any
|
|
||||||
json.Unmarshal(data, &cfg)
|
|
||||||
if cfg["theme"] != "dark" {
|
|
||||||
t.Error("theme was removed")
|
|
||||||
}
|
|
||||||
if cfg["mcp"] == nil {
|
|
||||||
t.Error("mcp was removed")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("preserve user customizations on models", func(t *testing.T) {
|
|
||||||
cleanup()
|
|
||||||
c.Edit([]string{"llama3.2"})
|
|
||||||
|
|
||||||
// User adds custom field
|
|
||||||
data, _ := os.ReadFile(configPath)
|
|
||||||
var cfg map[string]any
|
|
||||||
json.Unmarshal(data, &cfg)
|
|
||||||
models := cfg["models"].(map[string]any)
|
|
||||||
providers := models["providers"].(map[string]any)
|
|
||||||
ollama := providers["ollama"].(map[string]any)
|
|
||||||
modelList := ollama["models"].([]any)
|
|
||||||
entry := modelList[0].(map[string]any)
|
|
||||||
entry["customField"] = "user-value"
|
|
||||||
configData, _ := json.MarshalIndent(cfg, "", " ")
|
|
||||||
os.WriteFile(configPath, configData, 0o644)
|
|
||||||
|
|
||||||
// Re-run Edit
|
|
||||||
c.Edit([]string{"llama3.2"})
|
|
||||||
|
|
||||||
data, _ = os.ReadFile(configPath)
|
|
||||||
json.Unmarshal(data, &cfg)
|
|
||||||
models = cfg["models"].(map[string]any)
|
|
||||||
providers = models["providers"].(map[string]any)
|
|
||||||
ollama = providers["ollama"].(map[string]any)
|
|
||||||
modelList = ollama["models"].([]any)
|
|
||||||
entry = modelList[0].(map[string]any)
|
|
||||||
if entry["customField"] != "user-value" {
|
|
||||||
t.Error("custom field was lost")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("edit replaces models list", func(t *testing.T) {
|
|
||||||
cleanup()
|
|
||||||
c.Edit([]string{"llama3.2", "mistral"})
|
|
||||||
c.Edit([]string{"llama3.2"})
|
|
||||||
|
|
||||||
assertOpenclawModelExists(t, configPath, "llama3.2")
|
|
||||||
assertOpenclawModelNotExists(t, configPath, "mistral")
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("empty models is no-op", func(t *testing.T) {
|
|
||||||
cleanup()
|
|
||||||
os.MkdirAll(configDir, 0o755)
|
|
||||||
original := `{"existing":"data"}`
|
|
||||||
os.WriteFile(configPath, []byte(original), 0o644)
|
|
||||||
|
|
||||||
c.Edit([]string{})
|
|
||||||
|
|
||||||
data, _ := os.ReadFile(configPath)
|
|
||||||
if string(data) != original {
|
|
||||||
t.Error("empty models should not modify file")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("corrupted JSON treated as empty", func(t *testing.T) {
|
|
||||||
cleanup()
|
|
||||||
os.MkdirAll(configDir, 0o755)
|
|
||||||
os.WriteFile(configPath, []byte(`{corrupted`), 0o644)
|
|
||||||
|
|
||||||
if err := c.Edit([]string{"llama3.2"}); err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
data, _ := os.ReadFile(configPath)
|
|
||||||
var cfg map[string]any
|
|
||||||
if err := json.Unmarshal(data, &cfg); err != nil {
|
|
||||||
t.Error("result should be valid JSON")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("wrong type models section", func(t *testing.T) {
|
|
||||||
cleanup()
|
|
||||||
os.MkdirAll(configDir, 0o755)
|
|
||||||
os.WriteFile(configPath, []byte(`{"models":"not a map"}`), 0o644)
|
|
||||||
|
|
||||||
if err := c.Edit([]string{"llama3.2"}); err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
assertOpenclawModelExists(t, configPath, "llama3.2")
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestOpenclawModels(t *testing.T) {
|
|
||||||
c := &Openclaw{}
|
|
||||||
tmpDir := t.TempDir()
|
|
||||||
setTestHome(t, tmpDir)
|
|
||||||
|
|
||||||
t.Run("no config returns nil", func(t *testing.T) {
|
|
||||||
if models := c.Models(); len(models) > 0 {
|
|
||||||
t.Errorf("expected nil/empty, got %v", models)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("returns all ollama models", func(t *testing.T) {
|
|
||||||
configDir := filepath.Join(tmpDir, ".openclaw")
|
|
||||||
os.MkdirAll(configDir, 0o755)
|
|
||||||
os.WriteFile(filepath.Join(configDir, "openclaw.json"), []byte(`{
|
|
||||||
"models":{"providers":{"ollama":{"models":[
|
|
||||||
{"id":"llama3.2"},
|
|
||||||
{"id":"mistral"}
|
|
||||||
]}}}
|
|
||||||
}`), 0o644)
|
|
||||||
|
|
||||||
models := c.Models()
|
|
||||||
if len(models) != 2 {
|
|
||||||
t.Errorf("expected 2 models, got %v", models)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// Helper functions
|
|
||||||
func assertOpenclawModelExists(t *testing.T, path, model string) {
|
|
||||||
t.Helper()
|
|
||||||
data, _ := os.ReadFile(path)
|
|
||||||
var cfg map[string]any
|
|
||||||
json.Unmarshal(data, &cfg)
|
|
||||||
models := cfg["models"].(map[string]any)
|
|
||||||
providers := models["providers"].(map[string]any)
|
|
||||||
ollama := providers["ollama"].(map[string]any)
|
|
||||||
modelList := ollama["models"].([]any)
|
|
||||||
for _, m := range modelList {
|
|
||||||
if entry, ok := m.(map[string]any); ok {
|
|
||||||
if entry["id"] == model {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
t.Errorf("model %s not found", model)
|
|
||||||
}
|
|
||||||
|
|
||||||
func assertOpenclawModelNotExists(t *testing.T, path, model string) {
|
|
||||||
t.Helper()
|
|
||||||
data, _ := os.ReadFile(path)
|
|
||||||
var cfg map[string]any
|
|
||||||
json.Unmarshal(data, &cfg)
|
|
||||||
models, _ := cfg["models"].(map[string]any)
|
|
||||||
providers, _ := models["providers"].(map[string]any)
|
|
||||||
ollama, _ := providers["ollama"].(map[string]any)
|
|
||||||
modelList, _ := ollama["models"].([]any)
|
|
||||||
for _, m := range modelList {
|
|
||||||
if entry, ok := m.(map[string]any); ok {
|
|
||||||
if entry["id"] == model {
|
|
||||||
t.Errorf("model %s should not exist", model)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func assertOpenclawPrimaryModel(t *testing.T, path, expected string) {
|
|
||||||
t.Helper()
|
|
||||||
data, _ := os.ReadFile(path)
|
|
||||||
var cfg map[string]any
|
|
||||||
json.Unmarshal(data, &cfg)
|
|
||||||
agents := cfg["agents"].(map[string]any)
|
|
||||||
defaults := agents["defaults"].(map[string]any)
|
|
||||||
model := defaults["model"].(map[string]any)
|
|
||||||
if model["primary"] != expected {
|
|
||||||
t.Errorf("primary model = %v, want %v", model["primary"], expected)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestOpenclawPaths(t *testing.T) {
|
|
||||||
c := &Openclaw{}
|
|
||||||
|
|
||||||
t.Run("returns path when config exists", func(t *testing.T) {
|
|
||||||
tmpDir := t.TempDir()
|
|
||||||
setTestHome(t, tmpDir)
|
|
||||||
configDir := filepath.Join(tmpDir, ".openclaw")
|
|
||||||
os.MkdirAll(configDir, 0o755)
|
|
||||||
os.WriteFile(filepath.Join(configDir, "openclaw.json"), []byte(`{}`), 0o644)
|
|
||||||
|
|
||||||
paths := c.Paths()
|
|
||||||
if len(paths) != 1 {
|
|
||||||
t.Errorf("expected 1 path, got %d", len(paths))
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("returns nil when config missing", func(t *testing.T) {
|
|
||||||
tmpDir := t.TempDir()
|
|
||||||
setTestHome(t, tmpDir)
|
|
||||||
if paths := c.Paths(); paths != nil {
|
|
||||||
t.Errorf("expected nil, got %v", paths)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestOpenclawModelsEdgeCases(t *testing.T) {
|
|
||||||
c := &Openclaw{}
|
|
||||||
tmpDir := t.TempDir()
|
|
||||||
setTestHome(t, tmpDir)
|
|
||||||
configDir := filepath.Join(tmpDir, ".openclaw")
|
|
||||||
configPath := filepath.Join(configDir, "openclaw.json")
|
|
||||||
cleanup := func() { os.RemoveAll(configDir) }
|
|
||||||
|
|
||||||
t.Run("corrupted JSON returns nil", func(t *testing.T) {
|
|
||||||
cleanup()
|
|
||||||
os.MkdirAll(configDir, 0o755)
|
|
||||||
os.WriteFile(configPath, []byte(`{corrupted`), 0o644)
|
|
||||||
if models := c.Models(); models != nil {
|
|
||||||
t.Errorf("expected nil, got %v", models)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("wrong type at models level", func(t *testing.T) {
|
|
||||||
cleanup()
|
|
||||||
os.MkdirAll(configDir, 0o755)
|
|
||||||
os.WriteFile(configPath, []byte(`{"models":"string"}`), 0o644)
|
|
||||||
if models := c.Models(); models != nil {
|
|
||||||
t.Errorf("expected nil, got %v", models)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("wrong type at providers level", func(t *testing.T) {
|
|
||||||
cleanup()
|
|
||||||
os.MkdirAll(configDir, 0o755)
|
|
||||||
os.WriteFile(configPath, []byte(`{"models":{"providers":"string"}}`), 0o644)
|
|
||||||
if models := c.Models(); models != nil {
|
|
||||||
t.Errorf("expected nil, got %v", models)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("wrong type at ollama level", func(t *testing.T) {
|
|
||||||
cleanup()
|
|
||||||
os.MkdirAll(configDir, 0o755)
|
|
||||||
os.WriteFile(configPath, []byte(`{"models":{"providers":{"ollama":"string"}}}`), 0o644)
|
|
||||||
if models := c.Models(); models != nil {
|
|
||||||
t.Errorf("expected nil, got %v", models)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("model entry missing id", func(t *testing.T) {
|
|
||||||
cleanup()
|
|
||||||
os.MkdirAll(configDir, 0o755)
|
|
||||||
os.WriteFile(configPath, []byte(`{"models":{"providers":{"ollama":{"models":[{"name":"test"}]}}}}`), 0o644)
|
|
||||||
if len(c.Models()) != 0 {
|
|
||||||
t.Error("expected empty for missing id")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("model id is not string", func(t *testing.T) {
|
|
||||||
cleanup()
|
|
||||||
os.MkdirAll(configDir, 0o755)
|
|
||||||
os.WriteFile(configPath, []byte(`{"models":{"providers":{"ollama":{"models":[{"id":123}]}}}}`), 0o644)
|
|
||||||
if len(c.Models()) != 0 {
|
|
||||||
t.Error("expected empty for non-string id")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestOpenclawEditSchemaFields(t *testing.T) {
|
|
||||||
c := &Openclaw{}
|
|
||||||
tmpDir := t.TempDir()
|
|
||||||
setTestHome(t, tmpDir)
|
|
||||||
configPath := filepath.Join(tmpDir, ".openclaw", "openclaw.json")
|
|
||||||
|
|
||||||
if err := c.Edit([]string{"llama3.2"}); err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
data, _ := os.ReadFile(configPath)
|
|
||||||
var cfg map[string]any
|
|
||||||
json.Unmarshal(data, &cfg)
|
|
||||||
models := cfg["models"].(map[string]any)
|
|
||||||
providers := models["providers"].(map[string]any)
|
|
||||||
ollama := providers["ollama"].(map[string]any)
|
|
||||||
modelList := ollama["models"].([]any)
|
|
||||||
entry := modelList[0].(map[string]any)
|
|
||||||
|
|
||||||
// Verify required schema fields
|
|
||||||
if entry["reasoning"] != false {
|
|
||||||
t.Error("reasoning should be false")
|
|
||||||
}
|
|
||||||
if entry["input"] == nil {
|
|
||||||
t.Error("input should be set")
|
|
||||||
}
|
|
||||||
if entry["contextWindow"] == nil {
|
|
||||||
t.Error("contextWindow should be set")
|
|
||||||
}
|
|
||||||
if entry["maxTokens"] == nil {
|
|
||||||
t.Error("maxTokens should be set")
|
|
||||||
}
|
|
||||||
cost := entry["cost"].(map[string]any)
|
|
||||||
if cost["cacheRead"] == nil {
|
|
||||||
t.Error("cost.cacheRead should be set")
|
|
||||||
}
|
|
||||||
if cost["cacheWrite"] == nil {
|
|
||||||
t.Error("cost.cacheWrite should be set")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestOpenclawEditModelNames(t *testing.T) {
|
|
||||||
c := &Openclaw{}
|
|
||||||
tmpDir := t.TempDir()
|
|
||||||
setTestHome(t, tmpDir)
|
|
||||||
configPath := filepath.Join(tmpDir, ".openclaw", "openclaw.json")
|
|
||||||
cleanup := func() { os.RemoveAll(filepath.Join(tmpDir, ".openclaw")) }
|
|
||||||
|
|
||||||
t.Run("model with colon tag", func(t *testing.T) {
|
|
||||||
cleanup()
|
|
||||||
if err := c.Edit([]string{"llama3.2:70b"}); err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
assertOpenclawModelExists(t, configPath, "llama3.2:70b")
|
|
||||||
assertOpenclawPrimaryModel(t, configPath, "ollama/llama3.2:70b")
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("model with slash", func(t *testing.T) {
|
|
||||||
cleanup()
|
|
||||||
if err := c.Edit([]string{"library/model:tag"}); err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
assertOpenclawModelExists(t, configPath, "library/model:tag")
|
|
||||||
assertOpenclawPrimaryModel(t, configPath, "ollama/library/model:tag")
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("model with hyphen", func(t *testing.T) {
|
|
||||||
cleanup()
|
|
||||||
if err := c.Edit([]string{"test-model"}); err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
assertOpenclawModelExists(t, configPath, "test-model")
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestOpenclawEditAgentsPreservation(t *testing.T) {
|
|
||||||
c := &Openclaw{}
|
|
||||||
tmpDir := t.TempDir()
|
|
||||||
setTestHome(t, tmpDir)
|
|
||||||
configDir := filepath.Join(tmpDir, ".openclaw")
|
|
||||||
configPath := filepath.Join(configDir, "openclaw.json")
|
|
||||||
cleanup := func() { os.RemoveAll(configDir) }
|
|
||||||
|
|
||||||
t.Run("preserve other agent defaults", func(t *testing.T) {
|
|
||||||
cleanup()
|
|
||||||
os.MkdirAll(configDir, 0o755)
|
|
||||||
os.WriteFile(configPath, []byte(`{"agents":{"defaults":{"model":{"primary":"old"},"temperature":0.7}}}`), 0o644)
|
|
||||||
|
|
||||||
c.Edit([]string{"llama3.2"})
|
|
||||||
|
|
||||||
data, _ := os.ReadFile(configPath)
|
|
||||||
var cfg map[string]any
|
|
||||||
json.Unmarshal(data, &cfg)
|
|
||||||
agents := cfg["agents"].(map[string]any)
|
|
||||||
defaults := agents["defaults"].(map[string]any)
|
|
||||||
if defaults["temperature"] != 0.7 {
|
|
||||||
t.Error("temperature setting was lost")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("preserve other agents besides defaults", func(t *testing.T) {
|
|
||||||
cleanup()
|
|
||||||
os.MkdirAll(configDir, 0o755)
|
|
||||||
os.WriteFile(configPath, []byte(`{"agents":{"defaults":{},"custom-agent":{"foo":"bar"}}}`), 0o644)
|
|
||||||
|
|
||||||
c.Edit([]string{"llama3.2"})
|
|
||||||
|
|
||||||
data, _ := os.ReadFile(configPath)
|
|
||||||
var cfg map[string]any
|
|
||||||
json.Unmarshal(data, &cfg)
|
|
||||||
agents := cfg["agents"].(map[string]any)
|
|
||||||
if agents["custom-agent"] == nil {
|
|
||||||
t.Error("custom-agent was lost")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
const testOpenclawFixture = `{
|
|
||||||
"theme": "dark",
|
|
||||||
"mcp": {"servers": {"custom": {"enabled": true}}},
|
|
||||||
"models": {
|
|
||||||
"providers": {
|
|
||||||
"anthropic": {"apiKey": "xxx"},
|
|
||||||
"ollama": {
|
|
||||||
"baseUrl": "http://127.0.0.1:11434/v1",
|
|
||||||
"models": [{"id": "old-model", "customField": "preserved"}]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"agents": {
|
|
||||||
"defaults": {"model": {"primary": "old"}, "temperature": 0.7},
|
|
||||||
"custom-agent": {"foo": "bar"}
|
|
||||||
}
|
|
||||||
}`
|
|
||||||
|
|
||||||
func TestOpenclawEdit_RoundTrip(t *testing.T) {
|
|
||||||
c := &Openclaw{}
|
|
||||||
tmpDir := t.TempDir()
|
|
||||||
setTestHome(t, tmpDir)
|
|
||||||
configDir := filepath.Join(tmpDir, ".openclaw")
|
|
||||||
configPath := filepath.Join(configDir, "openclaw.json")
|
|
||||||
|
|
||||||
os.MkdirAll(configDir, 0o755)
|
|
||||||
os.WriteFile(configPath, []byte(testOpenclawFixture), 0o644)
|
|
||||||
|
|
||||||
if err := c.Edit([]string{"llama3.2", "mistral"}); err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
data, _ := os.ReadFile(configPath)
|
|
||||||
var cfg map[string]any
|
|
||||||
json.Unmarshal(data, &cfg)
|
|
||||||
|
|
||||||
// Verify top-level preserved
|
|
||||||
if cfg["theme"] != "dark" {
|
|
||||||
t.Error("theme not preserved")
|
|
||||||
}
|
|
||||||
mcp := cfg["mcp"].(map[string]any)
|
|
||||||
servers := mcp["servers"].(map[string]any)
|
|
||||||
if servers["custom"] == nil {
|
|
||||||
t.Error("mcp.servers.custom not preserved")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Verify other providers preserved
|
|
||||||
models := cfg["models"].(map[string]any)
|
|
||||||
providers := models["providers"].(map[string]any)
|
|
||||||
if providers["anthropic"] == nil {
|
|
||||||
t.Error("anthropic provider not preserved")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Verify agents preserved
|
|
||||||
agents := cfg["agents"].(map[string]any)
|
|
||||||
if agents["custom-agent"] == nil {
|
|
||||||
t.Error("custom-agent not preserved")
|
|
||||||
}
|
|
||||||
defaults := agents["defaults"].(map[string]any)
|
|
||||||
if defaults["temperature"] != 0.7 {
|
|
||||||
t.Error("temperature not preserved")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestOpenclawEdit_Idempotent(t *testing.T) {
|
|
||||||
c := &Openclaw{}
|
|
||||||
tmpDir := t.TempDir()
|
|
||||||
setTestHome(t, tmpDir)
|
|
||||||
configDir := filepath.Join(tmpDir, ".openclaw")
|
|
||||||
configPath := filepath.Join(configDir, "openclaw.json")
|
|
||||||
|
|
||||||
os.MkdirAll(configDir, 0o755)
|
|
||||||
os.WriteFile(configPath, []byte(testOpenclawFixture), 0o644)
|
|
||||||
|
|
||||||
c.Edit([]string{"llama3.2", "mistral"})
|
|
||||||
firstData, _ := os.ReadFile(configPath)
|
|
||||||
|
|
||||||
c.Edit([]string{"llama3.2", "mistral"})
|
|
||||||
secondData, _ := os.ReadFile(configPath)
|
|
||||||
|
|
||||||
if string(firstData) != string(secondData) {
|
|
||||||
t.Error("repeated edits with same models produced different results")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestOpenclawEdit_MultipleConsecutiveEdits(t *testing.T) {
|
|
||||||
c := &Openclaw{}
|
|
||||||
tmpDir := t.TempDir()
|
|
||||||
setTestHome(t, tmpDir)
|
|
||||||
configDir := filepath.Join(tmpDir, ".openclaw")
|
|
||||||
configPath := filepath.Join(configDir, "openclaw.json")
|
|
||||||
|
|
||||||
os.MkdirAll(configDir, 0o755)
|
|
||||||
os.WriteFile(configPath, []byte(testOpenclawFixture), 0o644)
|
|
||||||
|
|
||||||
for i := range 10 {
|
|
||||||
models := []string{"model-a", "model-b"}
|
|
||||||
if i%2 == 0 {
|
|
||||||
models = []string{"model-x", "model-y", "model-z"}
|
|
||||||
}
|
|
||||||
if err := c.Edit(models); err != nil {
|
|
||||||
t.Fatalf("edit %d failed: %v", i, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
data, _ := os.ReadFile(configPath)
|
|
||||||
var cfg map[string]any
|
|
||||||
if err := json.Unmarshal(data, &cfg); err != nil {
|
|
||||||
t.Fatalf("file is not valid JSON after multiple edits: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if cfg["theme"] != "dark" {
|
|
||||||
t.Error("theme lost after multiple edits")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestOpenclawEdit_BackupCreated(t *testing.T) {
|
|
||||||
c := &Openclaw{}
|
|
||||||
tmpDir := t.TempDir()
|
|
||||||
setTestHome(t, tmpDir)
|
|
||||||
configDir := filepath.Join(tmpDir, ".openclaw")
|
|
||||||
configPath := filepath.Join(configDir, "openclaw.json")
|
|
||||||
backupDir := filepath.Join(os.TempDir(), "ollama-backups")
|
|
||||||
|
|
||||||
os.MkdirAll(configDir, 0o755)
|
|
||||||
uniqueMarker := fmt.Sprintf("test-marker-%d", os.Getpid())
|
|
||||||
original := fmt.Sprintf(`{"theme": "%s"}`, uniqueMarker)
|
|
||||||
os.WriteFile(configPath, []byte(original), 0o644)
|
|
||||||
|
|
||||||
if err := c.Edit([]string{"model-a"}); err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
backups, _ := filepath.Glob(filepath.Join(backupDir, "openclaw.json.*"))
|
|
||||||
foundBackup := false
|
|
||||||
for _, backup := range backups {
|
|
||||||
data, _ := os.ReadFile(backup)
|
|
||||||
if string(data) == original {
|
|
||||||
foundBackup = true
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if !foundBackup {
|
|
||||||
t.Error("backup with original content not found")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestOpenclawClawdbotAlias(t *testing.T) {
|
|
||||||
for _, alias := range []string{"clawdbot", "moltbot"} {
|
|
||||||
t.Run(alias+" alias resolves to Openclaw runner", func(t *testing.T) {
|
|
||||||
r, ok := integrations[alias]
|
|
||||||
if !ok {
|
|
||||||
t.Fatalf("%s not found in integrations", alias)
|
|
||||||
}
|
|
||||||
if _, ok := r.(*Openclaw); !ok {
|
|
||||||
t.Errorf("%s integration is %T, want *Openclaw", alias, r)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run(alias+" is hidden from selector", func(t *testing.T) {
|
|
||||||
if !integrationAliases[alias] {
|
|
||||||
t.Errorf("%s should be in integrationAliases", alias)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestOpenclawLegacyPaths(t *testing.T) {
|
|
||||||
c := &Openclaw{}
|
|
||||||
|
|
||||||
t.Run("falls back to legacy clawdbot path", func(t *testing.T) {
|
|
||||||
tmpDir := t.TempDir()
|
|
||||||
setTestHome(t, tmpDir)
|
|
||||||
legacyDir := filepath.Join(tmpDir, ".clawdbot")
|
|
||||||
os.MkdirAll(legacyDir, 0o755)
|
|
||||||
os.WriteFile(filepath.Join(legacyDir, "clawdbot.json"), []byte(`{}`), 0o644)
|
|
||||||
|
|
||||||
paths := c.Paths()
|
|
||||||
if len(paths) != 1 {
|
|
||||||
t.Fatalf("expected 1 path, got %d", len(paths))
|
|
||||||
}
|
|
||||||
if paths[0] != filepath.Join(legacyDir, "clawdbot.json") {
|
|
||||||
t.Errorf("expected legacy path, got %s", paths[0])
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("prefers new path over legacy", func(t *testing.T) {
|
|
||||||
tmpDir := t.TempDir()
|
|
||||||
setTestHome(t, tmpDir)
|
|
||||||
newDir := filepath.Join(tmpDir, ".openclaw")
|
|
||||||
legacyDir := filepath.Join(tmpDir, ".clawdbot")
|
|
||||||
os.MkdirAll(newDir, 0o755)
|
|
||||||
os.MkdirAll(legacyDir, 0o755)
|
|
||||||
os.WriteFile(filepath.Join(newDir, "openclaw.json"), []byte(`{}`), 0o644)
|
|
||||||
os.WriteFile(filepath.Join(legacyDir, "clawdbot.json"), []byte(`{}`), 0o644)
|
|
||||||
|
|
||||||
paths := c.Paths()
|
|
||||||
if len(paths) != 1 {
|
|
||||||
t.Fatalf("expected 1 path, got %d", len(paths))
|
|
||||||
}
|
|
||||||
if paths[0] != filepath.Join(newDir, "openclaw.json") {
|
|
||||||
t.Errorf("expected new path, got %s", paths[0])
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("Models reads from legacy path", func(t *testing.T) {
|
|
||||||
tmpDir := t.TempDir()
|
|
||||||
setTestHome(t, tmpDir)
|
|
||||||
legacyDir := filepath.Join(tmpDir, ".clawdbot")
|
|
||||||
os.MkdirAll(legacyDir, 0o755)
|
|
||||||
os.WriteFile(filepath.Join(legacyDir, "clawdbot.json"), []byte(`{
|
|
||||||
"models":{"providers":{"ollama":{"models":[{"id":"llama3.2"}]}}}
|
|
||||||
}`), 0o644)
|
|
||||||
|
|
||||||
models := c.Models()
|
|
||||||
if len(models) != 1 || models[0] != "llama3.2" {
|
|
||||||
t.Errorf("expected [llama3.2], got %v", models)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("Models prefers new path over legacy", func(t *testing.T) {
|
|
||||||
tmpDir := t.TempDir()
|
|
||||||
setTestHome(t, tmpDir)
|
|
||||||
newDir := filepath.Join(tmpDir, ".openclaw")
|
|
||||||
legacyDir := filepath.Join(tmpDir, ".clawdbot")
|
|
||||||
os.MkdirAll(newDir, 0o755)
|
|
||||||
os.MkdirAll(legacyDir, 0o755)
|
|
||||||
os.WriteFile(filepath.Join(newDir, "openclaw.json"), []byte(`{
|
|
||||||
"models":{"providers":{"ollama":{"models":[{"id":"new-model"}]}}}
|
|
||||||
}`), 0o644)
|
|
||||||
os.WriteFile(filepath.Join(legacyDir, "clawdbot.json"), []byte(`{
|
|
||||||
"models":{"providers":{"ollama":{"models":[{"id":"legacy-model"}]}}}
|
|
||||||
}`), 0o644)
|
|
||||||
|
|
||||||
models := c.Models()
|
|
||||||
if len(models) != 1 || models[0] != "new-model" {
|
|
||||||
t.Errorf("expected [new-model], got %v", models)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("Edit reads new path over legacy when both exist", func(t *testing.T) {
|
|
||||||
tmpDir := t.TempDir()
|
|
||||||
setTestHome(t, tmpDir)
|
|
||||||
newDir := filepath.Join(tmpDir, ".openclaw")
|
|
||||||
legacyDir := filepath.Join(tmpDir, ".clawdbot")
|
|
||||||
os.MkdirAll(newDir, 0o755)
|
|
||||||
os.MkdirAll(legacyDir, 0o755)
|
|
||||||
os.WriteFile(filepath.Join(newDir, "openclaw.json"), []byte(`{"theme":"new"}`), 0o644)
|
|
||||||
os.WriteFile(filepath.Join(legacyDir, "clawdbot.json"), []byte(`{"theme":"legacy"}`), 0o644)
|
|
||||||
|
|
||||||
if err := c.Edit([]string{"llama3.2"}); err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
data, _ := os.ReadFile(filepath.Join(newDir, "openclaw.json"))
|
|
||||||
var cfg map[string]any
|
|
||||||
json.Unmarshal(data, &cfg)
|
|
||||||
if cfg["theme"] != "new" {
|
|
||||||
t.Errorf("expected theme from new config, got %v", cfg["theme"])
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("Edit migrates from legacy config", func(t *testing.T) {
|
|
||||||
tmpDir := t.TempDir()
|
|
||||||
setTestHome(t, tmpDir)
|
|
||||||
legacyDir := filepath.Join(tmpDir, ".clawdbot")
|
|
||||||
os.MkdirAll(legacyDir, 0o755)
|
|
||||||
os.WriteFile(filepath.Join(legacyDir, "clawdbot.json"), []byte(`{"theme":"dark"}`), 0o644)
|
|
||||||
|
|
||||||
if err := c.Edit([]string{"llama3.2"}); err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Should write to new path
|
|
||||||
newPath := filepath.Join(tmpDir, ".openclaw", "openclaw.json")
|
|
||||||
data, err := os.ReadFile(newPath)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal("expected new config file to be created")
|
|
||||||
}
|
|
||||||
var cfg map[string]any
|
|
||||||
json.Unmarshal(data, &cfg)
|
|
||||||
if cfg["theme"] != "dark" {
|
|
||||||
t.Error("legacy theme setting was not migrated")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestOpenclawEdit_CreatesDirectoryIfMissing(t *testing.T) {
|
|
||||||
c := &Openclaw{}
|
|
||||||
tmpDir := t.TempDir()
|
|
||||||
setTestHome(t, tmpDir)
|
|
||||||
configDir := filepath.Join(tmpDir, ".openclaw")
|
|
||||||
|
|
||||||
if _, err := os.Stat(configDir); !os.IsNotExist(err) {
|
|
||||||
t.Fatal("directory should not exist before test")
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := c.Edit([]string{"model-a"}); err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, err := os.Stat(configDir); os.IsNotExist(err) {
|
|
||||||
t.Fatal("directory was not created")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestOpenclawOnboarded(t *testing.T) {
|
|
||||||
c := &Openclaw{}
|
|
||||||
|
|
||||||
t.Run("returns false when no config exists", func(t *testing.T) {
|
|
||||||
tmpDir := t.TempDir()
|
|
||||||
setTestHome(t, tmpDir)
|
|
||||||
if c.onboarded() {
|
|
||||||
t.Error("expected false when no config exists")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("returns false when config exists but no wizard section", func(t *testing.T) {
|
|
||||||
tmpDir := t.TempDir()
|
|
||||||
setTestHome(t, tmpDir)
|
|
||||||
configDir := filepath.Join(tmpDir, ".openclaw")
|
|
||||||
os.MkdirAll(configDir, 0o755)
|
|
||||||
os.WriteFile(filepath.Join(configDir, "openclaw.json"), []byte(`{"theme":"dark"}`), 0o644)
|
|
||||||
|
|
||||||
if c.onboarded() {
|
|
||||||
t.Error("expected false when no wizard section")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("returns false when wizard section exists but no lastRunAt", func(t *testing.T) {
|
|
||||||
tmpDir := t.TempDir()
|
|
||||||
setTestHome(t, tmpDir)
|
|
||||||
configDir := filepath.Join(tmpDir, ".openclaw")
|
|
||||||
os.MkdirAll(configDir, 0o755)
|
|
||||||
os.WriteFile(filepath.Join(configDir, "openclaw.json"), []byte(`{"wizard":{}}`), 0o644)
|
|
||||||
|
|
||||||
if c.onboarded() {
|
|
||||||
t.Error("expected false when wizard.lastRunAt is missing")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("returns false when wizard.lastRunAt is empty string", func(t *testing.T) {
|
|
||||||
tmpDir := t.TempDir()
|
|
||||||
setTestHome(t, tmpDir)
|
|
||||||
configDir := filepath.Join(tmpDir, ".openclaw")
|
|
||||||
os.MkdirAll(configDir, 0o755)
|
|
||||||
os.WriteFile(filepath.Join(configDir, "openclaw.json"), []byte(`{"wizard":{"lastRunAt":""}}`), 0o644)
|
|
||||||
|
|
||||||
if c.onboarded() {
|
|
||||||
t.Error("expected false when wizard.lastRunAt is empty")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("returns true when wizard.lastRunAt is set", func(t *testing.T) {
|
|
||||||
tmpDir := t.TempDir()
|
|
||||||
setTestHome(t, tmpDir)
|
|
||||||
configDir := filepath.Join(tmpDir, ".openclaw")
|
|
||||||
os.MkdirAll(configDir, 0o755)
|
|
||||||
os.WriteFile(filepath.Join(configDir, "openclaw.json"), []byte(`{"wizard":{"lastRunAt":"2024-01-01T00:00:00Z"}}`), 0o644)
|
|
||||||
|
|
||||||
if !c.onboarded() {
|
|
||||||
t.Error("expected true when wizard.lastRunAt is set")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("checks legacy clawdbot path", func(t *testing.T) {
|
|
||||||
tmpDir := t.TempDir()
|
|
||||||
setTestHome(t, tmpDir)
|
|
||||||
legacyDir := filepath.Join(tmpDir, ".clawdbot")
|
|
||||||
os.MkdirAll(legacyDir, 0o755)
|
|
||||||
os.WriteFile(filepath.Join(legacyDir, "clawdbot.json"), []byte(`{"wizard":{"lastRunAt":"2024-01-01T00:00:00Z"}}`), 0o644)
|
|
||||||
|
|
||||||
if !c.onboarded() {
|
|
||||||
t.Error("expected true when legacy config has wizard.lastRunAt")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("prefers new path over legacy", func(t *testing.T) {
|
|
||||||
tmpDir := t.TempDir()
|
|
||||||
setTestHome(t, tmpDir)
|
|
||||||
newDir := filepath.Join(tmpDir, ".openclaw")
|
|
||||||
legacyDir := filepath.Join(tmpDir, ".clawdbot")
|
|
||||||
os.MkdirAll(newDir, 0o755)
|
|
||||||
os.MkdirAll(legacyDir, 0o755)
|
|
||||||
// New path has no wizard marker
|
|
||||||
os.WriteFile(filepath.Join(newDir, "openclaw.json"), []byte(`{}`), 0o644)
|
|
||||||
// Legacy has wizard marker
|
|
||||||
os.WriteFile(filepath.Join(legacyDir, "clawdbot.json"), []byte(`{"wizard":{"lastRunAt":"2024-01-01T00:00:00Z"}}`), 0o644)
|
|
||||||
|
|
||||||
if c.onboarded() {
|
|
||||||
t.Error("expected false - should prefer new path which has no wizard marker")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("handles corrupted JSON gracefully", func(t *testing.T) {
|
|
||||||
tmpDir := t.TempDir()
|
|
||||||
setTestHome(t, tmpDir)
|
|
||||||
configDir := filepath.Join(tmpDir, ".openclaw")
|
|
||||||
os.MkdirAll(configDir, 0o755)
|
|
||||||
os.WriteFile(filepath.Join(configDir, "openclaw.json"), []byte(`{corrupted`), 0o644)
|
|
||||||
|
|
||||||
if c.onboarded() {
|
|
||||||
t.Error("expected false for corrupted JSON")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("handles wrong type for wizard section", func(t *testing.T) {
|
|
||||||
tmpDir := t.TempDir()
|
|
||||||
setTestHome(t, tmpDir)
|
|
||||||
configDir := filepath.Join(tmpDir, ".openclaw")
|
|
||||||
os.MkdirAll(configDir, 0o755)
|
|
||||||
os.WriteFile(filepath.Join(configDir, "openclaw.json"), []byte(`{"wizard":"not a map"}`), 0o644)
|
|
||||||
|
|
||||||
if c.onboarded() {
|
|
||||||
t.Error("expected false when wizard is wrong type")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
@@ -1,58 +0,0 @@
|
|||||||
package config
|
|
||||||
|
|
||||||
import (
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"os"
|
|
||||||
|
|
||||||
"golang.org/x/term"
|
|
||||||
)
|
|
||||||
|
|
||||||
// ANSI escape sequences for terminal formatting.
|
|
||||||
const (
|
|
||||||
ansiBold = "\033[1m"
|
|
||||||
ansiReset = "\033[0m"
|
|
||||||
ansiGray = "\033[37m"
|
|
||||||
ansiGreen = "\033[32m"
|
|
||||||
)
|
|
||||||
|
|
||||||
// ErrCancelled is returned when the user cancels a selection.
|
|
||||||
var ErrCancelled = errors.New("cancelled")
|
|
||||||
|
|
||||||
// errCancelled is kept as an alias for backward compatibility within the package.
|
|
||||||
var errCancelled = ErrCancelled
|
|
||||||
|
|
||||||
// DefaultConfirmPrompt provides a TUI-based confirmation prompt.
|
|
||||||
// When set, confirmPrompt delegates to it instead of using raw terminal I/O.
|
|
||||||
var DefaultConfirmPrompt func(prompt string) (bool, error)
|
|
||||||
|
|
||||||
func confirmPrompt(prompt string) (bool, error) {
|
|
||||||
if DefaultConfirmPrompt != nil {
|
|
||||||
return DefaultConfirmPrompt(prompt)
|
|
||||||
}
|
|
||||||
|
|
||||||
fd := int(os.Stdin.Fd())
|
|
||||||
oldState, err := term.MakeRaw(fd)
|
|
||||||
if err != nil {
|
|
||||||
return false, err
|
|
||||||
}
|
|
||||||
defer term.Restore(fd, oldState)
|
|
||||||
|
|
||||||
fmt.Fprintf(os.Stderr, "%s (\033[1my\033[0m/n) ", prompt)
|
|
||||||
|
|
||||||
buf := make([]byte, 1)
|
|
||||||
for {
|
|
||||||
if _, err := os.Stdin.Read(buf); err != nil {
|
|
||||||
return false, err
|
|
||||||
}
|
|
||||||
|
|
||||||
switch buf[0] {
|
|
||||||
case 'Y', 'y', 13:
|
|
||||||
fmt.Fprintf(os.Stderr, "yes\r\n")
|
|
||||||
return true, nil
|
|
||||||
case 'N', 'n', 27, 3:
|
|
||||||
fmt.Fprintf(os.Stderr, "no\r\n")
|
|
||||||
return false, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,19 +0,0 @@
|
|||||||
package config
|
|
||||||
|
|
||||||
import (
|
|
||||||
"testing"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestErrCancelled(t *testing.T) {
|
|
||||||
t.Run("NotNil", func(t *testing.T) {
|
|
||||||
if errCancelled == nil {
|
|
||||||
t.Error("errCancelled should not be nil")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("Message", func(t *testing.T) {
|
|
||||||
if errCancelled.Error() != "cancelled" {
|
|
||||||
t.Errorf("expected 'cancelled', got %q", errCancelled.Error())
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
@@ -17,6 +17,7 @@ import (
|
|||||||
|
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
"github.com/ollama/ollama/envconfig"
|
"github.com/ollama/ollama/envconfig"
|
||||||
|
"github.com/ollama/ollama/internal/modelref"
|
||||||
"github.com/ollama/ollama/readline"
|
"github.com/ollama/ollama/readline"
|
||||||
"github.com/ollama/ollama/types/errtypes"
|
"github.com/ollama/ollama/types/errtypes"
|
||||||
"github.com/ollama/ollama/types/model"
|
"github.com/ollama/ollama/types/model"
|
||||||
@@ -540,6 +541,13 @@ func NewCreateRequest(name string, opts runOptions) *api.CreateRequest {
|
|||||||
parentModel = ""
|
parentModel = ""
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Preserve explicit cloud intent for sessions started with `:cloud`.
|
||||||
|
// Cloud model metadata can return a source-less parent_model (for example
|
||||||
|
// "qwen3.5"), which would otherwise make `/save` create a local derivative.
|
||||||
|
if modelref.HasExplicitCloudSource(opts.Model) && !modelref.HasExplicitCloudSource(parentModel) {
|
||||||
|
parentModel = ""
|
||||||
|
}
|
||||||
|
|
||||||
req := &api.CreateRequest{
|
req := &api.CreateRequest{
|
||||||
Model: name,
|
Model: name,
|
||||||
From: cmp.Or(parentModel, opts.Model),
|
From: cmp.Or(parentModel, opts.Model),
|
||||||
|
|||||||
@@ -1,4 +1,6 @@
|
|||||||
package config
|
// Package fileutil provides small shared helpers for reading JSON files
|
||||||
|
// and writing config files with backup-on-overwrite semantics.
|
||||||
|
package fileutil
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
@@ -9,7 +11,8 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
func readJSONFile(path string) (map[string]any, error) {
|
// ReadJSON reads a JSON object file into a generic map.
|
||||||
|
func ReadJSON(path string) (map[string]any, error) {
|
||||||
data, err := os.ReadFile(path)
|
data, err := os.ReadFile(path)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -33,12 +36,13 @@ func copyFile(src, dst string) error {
|
|||||||
return os.WriteFile(dst, data, info.Mode().Perm())
|
return os.WriteFile(dst, data, info.Mode().Perm())
|
||||||
}
|
}
|
||||||
|
|
||||||
func backupDir() string {
|
// BackupDir returns the shared backup directory used before overwriting files.
|
||||||
|
func BackupDir() string {
|
||||||
return filepath.Join(os.TempDir(), "ollama-backups")
|
return filepath.Join(os.TempDir(), "ollama-backups")
|
||||||
}
|
}
|
||||||
|
|
||||||
func backupToTmp(srcPath string) (string, error) {
|
func backupToTmp(srcPath string) (string, error) {
|
||||||
dir := backupDir()
|
dir := BackupDir()
|
||||||
if err := os.MkdirAll(dir, 0o755); err != nil {
|
if err := os.MkdirAll(dir, 0o755); err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
@@ -50,8 +54,8 @@ func backupToTmp(srcPath string) (string, error) {
|
|||||||
return backupPath, nil
|
return backupPath, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// writeWithBackup writes data to path via temp file + rename, backing up any existing file first
|
// WriteWithBackup writes data to path via temp file + rename, backing up any existing file first.
|
||||||
func writeWithBackup(path string, data []byte) error {
|
func WriteWithBackup(path string, data []byte) error {
|
||||||
var backupPath string
|
var backupPath string
|
||||||
// backup must be created before any writes to the target file
|
// backup must be created before any writes to the target file
|
||||||
if existingContent, err := os.ReadFile(path); err == nil {
|
if existingContent, err := os.ReadFile(path); err == nil {
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package config
|
package fileutil
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
@@ -9,6 +9,21 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func TestMain(m *testing.M) {
|
||||||
|
tmpRoot, err := os.MkdirTemp("", "fileutil-test-*")
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := os.Setenv("TMPDIR", tmpRoot); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
code := m.Run()
|
||||||
|
_ = os.RemoveAll(tmpRoot)
|
||||||
|
os.Exit(code)
|
||||||
|
}
|
||||||
|
|
||||||
func mustMarshal(t *testing.T, v any) []byte {
|
func mustMarshal(t *testing.T, v any) []byte {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
data, err := json.MarshalIndent(v, "", " ")
|
data, err := json.MarshalIndent(v, "", " ")
|
||||||
@@ -18,14 +33,19 @@ func mustMarshal(t *testing.T, v any) []byte {
|
|||||||
return data
|
return data
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func isolatedTempDir(t *testing.T) string {
|
||||||
|
t.Helper()
|
||||||
|
return t.TempDir()
|
||||||
|
}
|
||||||
|
|
||||||
func TestWriteWithBackup(t *testing.T) {
|
func TestWriteWithBackup(t *testing.T) {
|
||||||
tmpDir := t.TempDir()
|
tmpDir := isolatedTempDir(t)
|
||||||
|
|
||||||
t.Run("creates file", func(t *testing.T) {
|
t.Run("creates file", func(t *testing.T) {
|
||||||
path := filepath.Join(tmpDir, "new.json")
|
path := filepath.Join(tmpDir, "new.json")
|
||||||
data := mustMarshal(t, map[string]string{"key": "value"})
|
data := mustMarshal(t, map[string]string{"key": "value"})
|
||||||
|
|
||||||
if err := writeWithBackup(path, data); err != nil {
|
if err := WriteWithBackup(path, data); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -43,17 +63,17 @@ func TestWriteWithBackup(t *testing.T) {
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("creates backup in /tmp/ollama-backups", func(t *testing.T) {
|
t.Run("creates backup in the temp backup directory", func(t *testing.T) {
|
||||||
path := filepath.Join(tmpDir, "backup.json")
|
path := filepath.Join(tmpDir, "backup.json")
|
||||||
|
|
||||||
os.WriteFile(path, []byte(`{"original": true}`), 0o644)
|
os.WriteFile(path, []byte(`{"original": true}`), 0o644)
|
||||||
|
|
||||||
data := mustMarshal(t, map[string]bool{"updated": true})
|
data := mustMarshal(t, map[string]bool{"updated": true})
|
||||||
if err := writeWithBackup(path, data); err != nil {
|
if err := WriteWithBackup(path, data); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
entries, err := os.ReadDir(backupDir())
|
entries, err := os.ReadDir(BackupDir())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal("backup directory not created")
|
t.Fatal("backup directory not created")
|
||||||
}
|
}
|
||||||
@@ -63,7 +83,7 @@ func TestWriteWithBackup(t *testing.T) {
|
|||||||
if filepath.Ext(entry.Name()) != ".json" {
|
if filepath.Ext(entry.Name()) != ".json" {
|
||||||
name := entry.Name()
|
name := entry.Name()
|
||||||
if len(name) > len("backup.json.") && name[:len("backup.json.")] == "backup.json." {
|
if len(name) > len("backup.json.") && name[:len("backup.json.")] == "backup.json." {
|
||||||
backupPath := filepath.Join(backupDir(), name)
|
backupPath := filepath.Join(BackupDir(), name)
|
||||||
backup, err := os.ReadFile(backupPath)
|
backup, err := os.ReadFile(backupPath)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
var backupData map[string]bool
|
var backupData map[string]bool
|
||||||
@@ -79,7 +99,7 @@ func TestWriteWithBackup(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if !foundBackup {
|
if !foundBackup {
|
||||||
t.Error("backup file not created in /tmp/ollama-backups")
|
t.Error("backup file not created in backup directory")
|
||||||
}
|
}
|
||||||
|
|
||||||
current, _ := os.ReadFile(path)
|
current, _ := os.ReadFile(path)
|
||||||
@@ -94,11 +114,11 @@ func TestWriteWithBackup(t *testing.T) {
|
|||||||
path := filepath.Join(tmpDir, "nobak.json")
|
path := filepath.Join(tmpDir, "nobak.json")
|
||||||
|
|
||||||
data := mustMarshal(t, map[string]string{"new": "file"})
|
data := mustMarshal(t, map[string]string{"new": "file"})
|
||||||
if err := writeWithBackup(path, data); err != nil {
|
if err := WriteWithBackup(path, data); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
entries, _ := os.ReadDir(backupDir())
|
entries, _ := os.ReadDir(BackupDir())
|
||||||
for _, entry := range entries {
|
for _, entry := range entries {
|
||||||
if len(entry.Name()) > len("nobak.json.") && entry.Name()[:len("nobak.json.")] == "nobak.json." {
|
if len(entry.Name()) > len("nobak.json.") && entry.Name()[:len("nobak.json.")] == "nobak.json." {
|
||||||
t.Error("backup should not exist for new file")
|
t.Error("backup should not exist for new file")
|
||||||
@@ -111,11 +131,11 @@ func TestWriteWithBackup(t *testing.T) {
|
|||||||
|
|
||||||
data := mustMarshal(t, map[string]string{"key": "value"})
|
data := mustMarshal(t, map[string]string{"key": "value"})
|
||||||
|
|
||||||
if err := writeWithBackup(path, data); err != nil {
|
if err := WriteWithBackup(path, data); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
entries1, _ := os.ReadDir(backupDir())
|
entries1, _ := os.ReadDir(BackupDir())
|
||||||
countBefore := 0
|
countBefore := 0
|
||||||
for _, e := range entries1 {
|
for _, e := range entries1 {
|
||||||
if len(e.Name()) > len("unchanged.json.") && e.Name()[:len("unchanged.json.")] == "unchanged.json." {
|
if len(e.Name()) > len("unchanged.json.") && e.Name()[:len("unchanged.json.")] == "unchanged.json." {
|
||||||
@@ -123,11 +143,11 @@ func TestWriteWithBackup(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := writeWithBackup(path, data); err != nil {
|
if err := WriteWithBackup(path, data); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
entries2, _ := os.ReadDir(backupDir())
|
entries2, _ := os.ReadDir(BackupDir())
|
||||||
countAfter := 0
|
countAfter := 0
|
||||||
for _, e := range entries2 {
|
for _, e := range entries2 {
|
||||||
if len(e.Name()) > len("unchanged.json.") && e.Name()[:len("unchanged.json.")] == "unchanged.json." {
|
if len(e.Name()) > len("unchanged.json.") && e.Name()[:len("unchanged.json.")] == "unchanged.json." {
|
||||||
@@ -145,11 +165,11 @@ func TestWriteWithBackup(t *testing.T) {
|
|||||||
|
|
||||||
os.WriteFile(path, []byte(`{"v": 1}`), 0o644)
|
os.WriteFile(path, []byte(`{"v": 1}`), 0o644)
|
||||||
data := mustMarshal(t, map[string]int{"v": 2})
|
data := mustMarshal(t, map[string]int{"v": 2})
|
||||||
if err := writeWithBackup(path, data); err != nil {
|
if err := WriteWithBackup(path, data); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
entries, _ := os.ReadDir(backupDir())
|
entries, _ := os.ReadDir(BackupDir())
|
||||||
var found bool
|
var found bool
|
||||||
for _, entry := range entries {
|
for _, entry := range entries {
|
||||||
name := entry.Name()
|
name := entry.Name()
|
||||||
@@ -161,7 +181,7 @@ func TestWriteWithBackup(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
found = true
|
found = true
|
||||||
os.Remove(filepath.Join(backupDir(), name))
|
os.Remove(filepath.Join(BackupDir(), name))
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -180,7 +200,7 @@ func TestWriteWithBackup_FailsIfBackupFails(t *testing.T) {
|
|||||||
t.Skip("permission tests unreliable on Windows")
|
t.Skip("permission tests unreliable on Windows")
|
||||||
}
|
}
|
||||||
|
|
||||||
tmpDir := t.TempDir()
|
tmpDir := isolatedTempDir(t)
|
||||||
path := filepath.Join(tmpDir, "config.json")
|
path := filepath.Join(tmpDir, "config.json")
|
||||||
|
|
||||||
// Create original file
|
// Create original file
|
||||||
@@ -188,13 +208,13 @@ func TestWriteWithBackup_FailsIfBackupFails(t *testing.T) {
|
|||||||
os.WriteFile(path, originalContent, 0o644)
|
os.WriteFile(path, originalContent, 0o644)
|
||||||
|
|
||||||
// Make backup directory read-only to force backup failure
|
// Make backup directory read-only to force backup failure
|
||||||
backupDir := backupDir()
|
backupDir := BackupDir()
|
||||||
os.MkdirAll(backupDir, 0o755)
|
os.MkdirAll(backupDir, 0o755)
|
||||||
os.Chmod(backupDir, 0o444) // Read-only
|
os.Chmod(backupDir, 0o444) // Read-only
|
||||||
defer os.Chmod(backupDir, 0o755)
|
defer os.Chmod(backupDir, 0o755)
|
||||||
|
|
||||||
newContent := []byte(`{"updated": true}`)
|
newContent := []byte(`{"updated": true}`)
|
||||||
err := writeWithBackup(path, newContent)
|
err := WriteWithBackup(path, newContent)
|
||||||
|
|
||||||
// Should fail because backup couldn't be created
|
// Should fail because backup couldn't be created
|
||||||
if err == nil {
|
if err == nil {
|
||||||
@@ -215,7 +235,7 @@ func TestWriteWithBackup_PermissionDenied(t *testing.T) {
|
|||||||
t.Skip("permission tests unreliable on Windows")
|
t.Skip("permission tests unreliable on Windows")
|
||||||
}
|
}
|
||||||
|
|
||||||
tmpDir := t.TempDir()
|
tmpDir := isolatedTempDir(t)
|
||||||
|
|
||||||
// Create a read-only directory
|
// Create a read-only directory
|
||||||
readOnlyDir := filepath.Join(tmpDir, "readonly")
|
readOnlyDir := filepath.Join(tmpDir, "readonly")
|
||||||
@@ -224,7 +244,7 @@ func TestWriteWithBackup_PermissionDenied(t *testing.T) {
|
|||||||
defer os.Chmod(readOnlyDir, 0o755)
|
defer os.Chmod(readOnlyDir, 0o755)
|
||||||
|
|
||||||
path := filepath.Join(readOnlyDir, "config.json")
|
path := filepath.Join(readOnlyDir, "config.json")
|
||||||
err := writeWithBackup(path, []byte(`{"test": true}`))
|
err := WriteWithBackup(path, []byte(`{"test": true}`))
|
||||||
|
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Error("expected permission error, got nil")
|
t.Error("expected permission error, got nil")
|
||||||
@@ -234,10 +254,10 @@ func TestWriteWithBackup_PermissionDenied(t *testing.T) {
|
|||||||
// TestWriteWithBackup_DirectoryDoesNotExist verifies behavior when target directory doesn't exist.
|
// TestWriteWithBackup_DirectoryDoesNotExist verifies behavior when target directory doesn't exist.
|
||||||
// writeWithBackup doesn't create directories - caller is responsible.
|
// writeWithBackup doesn't create directories - caller is responsible.
|
||||||
func TestWriteWithBackup_DirectoryDoesNotExist(t *testing.T) {
|
func TestWriteWithBackup_DirectoryDoesNotExist(t *testing.T) {
|
||||||
tmpDir := t.TempDir()
|
tmpDir := isolatedTempDir(t)
|
||||||
path := filepath.Join(tmpDir, "nonexistent", "subdir", "config.json")
|
path := filepath.Join(tmpDir, "nonexistent", "subdir", "config.json")
|
||||||
|
|
||||||
err := writeWithBackup(path, []byte(`{"test": true}`))
|
err := WriteWithBackup(path, []byte(`{"test": true}`))
|
||||||
|
|
||||||
// Should fail because directory doesn't exist
|
// Should fail because directory doesn't exist
|
||||||
if err == nil {
|
if err == nil {
|
||||||
@@ -252,7 +272,7 @@ func TestWriteWithBackup_SymlinkTarget(t *testing.T) {
|
|||||||
t.Skip("symlink tests may require admin on Windows")
|
t.Skip("symlink tests may require admin on Windows")
|
||||||
}
|
}
|
||||||
|
|
||||||
tmpDir := t.TempDir()
|
tmpDir := isolatedTempDir(t)
|
||||||
realFile := filepath.Join(tmpDir, "real.json")
|
realFile := filepath.Join(tmpDir, "real.json")
|
||||||
symlink := filepath.Join(tmpDir, "link.json")
|
symlink := filepath.Join(tmpDir, "link.json")
|
||||||
|
|
||||||
@@ -261,7 +281,7 @@ func TestWriteWithBackup_SymlinkTarget(t *testing.T) {
|
|||||||
os.Symlink(realFile, symlink)
|
os.Symlink(realFile, symlink)
|
||||||
|
|
||||||
// Write through symlink
|
// Write through symlink
|
||||||
err := writeWithBackup(symlink, []byte(`{"v": 2}`))
|
err := WriteWithBackup(symlink, []byte(`{"v": 2}`))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("writeWithBackup through symlink failed: %v", err)
|
t.Fatalf("writeWithBackup through symlink failed: %v", err)
|
||||||
}
|
}
|
||||||
@@ -276,7 +296,7 @@ func TestWriteWithBackup_SymlinkTarget(t *testing.T) {
|
|||||||
// TestBackupToTmp_SpecialCharsInFilename verifies backup works with special characters.
|
// TestBackupToTmp_SpecialCharsInFilename verifies backup works with special characters.
|
||||||
// User may have config files with unusual names.
|
// User may have config files with unusual names.
|
||||||
func TestBackupToTmp_SpecialCharsInFilename(t *testing.T) {
|
func TestBackupToTmp_SpecialCharsInFilename(t *testing.T) {
|
||||||
tmpDir := t.TempDir()
|
tmpDir := isolatedTempDir(t)
|
||||||
|
|
||||||
// File with spaces and special chars
|
// File with spaces and special chars
|
||||||
path := filepath.Join(tmpDir, "my config (backup).json")
|
path := filepath.Join(tmpDir, "my config (backup).json")
|
||||||
@@ -305,7 +325,7 @@ func TestCopyFile_PreservesPermissions(t *testing.T) {
|
|||||||
t.Skip("permission preservation tests unreliable on Windows")
|
t.Skip("permission preservation tests unreliable on Windows")
|
||||||
}
|
}
|
||||||
|
|
||||||
tmpDir := t.TempDir()
|
tmpDir := isolatedTempDir(t)
|
||||||
src := filepath.Join(tmpDir, "src.json")
|
src := filepath.Join(tmpDir, "src.json")
|
||||||
dst := filepath.Join(tmpDir, "dst.json")
|
dst := filepath.Join(tmpDir, "dst.json")
|
||||||
|
|
||||||
@@ -327,7 +347,7 @@ func TestCopyFile_PreservesPermissions(t *testing.T) {
|
|||||||
|
|
||||||
// TestCopyFile_SourceNotFound verifies clear error when source doesn't exist.
|
// TestCopyFile_SourceNotFound verifies clear error when source doesn't exist.
|
||||||
func TestCopyFile_SourceNotFound(t *testing.T) {
|
func TestCopyFile_SourceNotFound(t *testing.T) {
|
||||||
tmpDir := t.TempDir()
|
tmpDir := isolatedTempDir(t)
|
||||||
src := filepath.Join(tmpDir, "nonexistent.json")
|
src := filepath.Join(tmpDir, "nonexistent.json")
|
||||||
dst := filepath.Join(tmpDir, "dst.json")
|
dst := filepath.Join(tmpDir, "dst.json")
|
||||||
|
|
||||||
@@ -339,11 +359,11 @@ func TestCopyFile_SourceNotFound(t *testing.T) {
|
|||||||
|
|
||||||
// TestWriteWithBackup_TargetIsDirectory verifies error when path points to a directory.
|
// TestWriteWithBackup_TargetIsDirectory verifies error when path points to a directory.
|
||||||
func TestWriteWithBackup_TargetIsDirectory(t *testing.T) {
|
func TestWriteWithBackup_TargetIsDirectory(t *testing.T) {
|
||||||
tmpDir := t.TempDir()
|
tmpDir := isolatedTempDir(t)
|
||||||
dirPath := filepath.Join(tmpDir, "actualdir")
|
dirPath := filepath.Join(tmpDir, "actualdir")
|
||||||
os.MkdirAll(dirPath, 0o755)
|
os.MkdirAll(dirPath, 0o755)
|
||||||
|
|
||||||
err := writeWithBackup(dirPath, []byte(`{"test": true}`))
|
err := WriteWithBackup(dirPath, []byte(`{"test": true}`))
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Error("expected error when target is a directory, got nil")
|
t.Error("expected error when target is a directory, got nil")
|
||||||
}
|
}
|
||||||
@@ -351,10 +371,10 @@ func TestWriteWithBackup_TargetIsDirectory(t *testing.T) {
|
|||||||
|
|
||||||
// TestWriteWithBackup_EmptyData verifies writing zero bytes works correctly.
|
// TestWriteWithBackup_EmptyData verifies writing zero bytes works correctly.
|
||||||
func TestWriteWithBackup_EmptyData(t *testing.T) {
|
func TestWriteWithBackup_EmptyData(t *testing.T) {
|
||||||
tmpDir := t.TempDir()
|
tmpDir := isolatedTempDir(t)
|
||||||
path := filepath.Join(tmpDir, "empty.json")
|
path := filepath.Join(tmpDir, "empty.json")
|
||||||
|
|
||||||
err := writeWithBackup(path, []byte{})
|
err := WriteWithBackup(path, []byte{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("writeWithBackup with empty data failed: %v", err)
|
t.Fatalf("writeWithBackup with empty data failed: %v", err)
|
||||||
}
|
}
|
||||||
@@ -375,7 +395,7 @@ func TestWriteWithBackup_FileUnreadableButDirWritable(t *testing.T) {
|
|||||||
t.Skip("permission tests unreliable on Windows")
|
t.Skip("permission tests unreliable on Windows")
|
||||||
}
|
}
|
||||||
|
|
||||||
tmpDir := t.TempDir()
|
tmpDir := isolatedTempDir(t)
|
||||||
path := filepath.Join(tmpDir, "unreadable.json")
|
path := filepath.Join(tmpDir, "unreadable.json")
|
||||||
|
|
||||||
// Create file and make it unreadable
|
// Create file and make it unreadable
|
||||||
@@ -384,7 +404,7 @@ func TestWriteWithBackup_FileUnreadableButDirWritable(t *testing.T) {
|
|||||||
defer os.Chmod(path, 0o644)
|
defer os.Chmod(path, 0o644)
|
||||||
|
|
||||||
// Should fail because we can't read the file to compare/backup
|
// Should fail because we can't read the file to compare/backup
|
||||||
err := writeWithBackup(path, []byte(`{"updated": true}`))
|
err := WriteWithBackup(path, []byte(`{"updated": true}`))
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Error("expected error when file is unreadable, got nil")
|
t.Error("expected error when file is unreadable, got nil")
|
||||||
}
|
}
|
||||||
@@ -393,7 +413,7 @@ func TestWriteWithBackup_FileUnreadableButDirWritable(t *testing.T) {
|
|||||||
// TestWriteWithBackup_RapidSuccessiveWrites verifies backup works with multiple writes
|
// TestWriteWithBackup_RapidSuccessiveWrites verifies backup works with multiple writes
|
||||||
// within the same second (timestamp collision scenario).
|
// within the same second (timestamp collision scenario).
|
||||||
func TestWriteWithBackup_RapidSuccessiveWrites(t *testing.T) {
|
func TestWriteWithBackup_RapidSuccessiveWrites(t *testing.T) {
|
||||||
tmpDir := t.TempDir()
|
tmpDir := isolatedTempDir(t)
|
||||||
path := filepath.Join(tmpDir, "rapid.json")
|
path := filepath.Join(tmpDir, "rapid.json")
|
||||||
|
|
||||||
// Create initial file
|
// Create initial file
|
||||||
@@ -402,7 +422,7 @@ func TestWriteWithBackup_RapidSuccessiveWrites(t *testing.T) {
|
|||||||
// Rapid successive writes
|
// Rapid successive writes
|
||||||
for i := 1; i <= 3; i++ {
|
for i := 1; i <= 3; i++ {
|
||||||
data := []byte(fmt.Sprintf(`{"v": %d}`, i))
|
data := []byte(fmt.Sprintf(`{"v": %d}`, i))
|
||||||
if err := writeWithBackup(path, data); err != nil {
|
if err := WriteWithBackup(path, data); err != nil {
|
||||||
t.Fatalf("write %d failed: %v", i, err)
|
t.Fatalf("write %d failed: %v", i, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -414,7 +434,7 @@ func TestWriteWithBackup_RapidSuccessiveWrites(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Verify at least one backup exists
|
// Verify at least one backup exists
|
||||||
entries, _ := os.ReadDir(backupDir())
|
entries, _ := os.ReadDir(BackupDir())
|
||||||
var backupCount int
|
var backupCount int
|
||||||
for _, e := range entries {
|
for _, e := range entries {
|
||||||
if len(e.Name()) > len("rapid.json.") && e.Name()[:len("rapid.json.")] == "rapid.json." {
|
if len(e.Name()) > len("rapid.json.") && e.Name()[:len("rapid.json.")] == "rapid.json." {
|
||||||
@@ -432,8 +452,9 @@ func TestWriteWithBackup_BackupDirIsFile(t *testing.T) {
|
|||||||
t.Skip("test modifies system temp directory")
|
t.Skip("test modifies system temp directory")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
tmpDir := isolatedTempDir(t)
|
||||||
// Create a file at the backup directory path
|
// Create a file at the backup directory path
|
||||||
backupPath := backupDir()
|
backupPath := BackupDir()
|
||||||
// Clean up any existing directory first
|
// Clean up any existing directory first
|
||||||
os.RemoveAll(backupPath)
|
os.RemoveAll(backupPath)
|
||||||
// Create a file instead of directory
|
// Create a file instead of directory
|
||||||
@@ -443,11 +464,10 @@ func TestWriteWithBackup_BackupDirIsFile(t *testing.T) {
|
|||||||
os.MkdirAll(backupPath, 0o755)
|
os.MkdirAll(backupPath, 0o755)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
tmpDir := t.TempDir()
|
|
||||||
path := filepath.Join(tmpDir, "test.json")
|
path := filepath.Join(tmpDir, "test.json")
|
||||||
os.WriteFile(path, []byte(`{"original": true}`), 0o644)
|
os.WriteFile(path, []byte(`{"original": true}`), 0o644)
|
||||||
|
|
||||||
err := writeWithBackup(path, []byte(`{"updated": true}`))
|
err := WriteWithBackup(path, []byte(`{"updated": true}`))
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Error("expected error when backup dir is a file, got nil")
|
t.Error("expected error when backup dir is a file, got nil")
|
||||||
}
|
}
|
||||||
@@ -459,7 +479,7 @@ func TestWriteWithBackup_NoOrphanTempFiles(t *testing.T) {
|
|||||||
t.Skip("permission tests unreliable on Windows")
|
t.Skip("permission tests unreliable on Windows")
|
||||||
}
|
}
|
||||||
|
|
||||||
tmpDir := t.TempDir()
|
tmpDir := isolatedTempDir(t)
|
||||||
|
|
||||||
// Count existing temp files
|
// Count existing temp files
|
||||||
countTempFiles := func() int {
|
countTempFiles := func() int {
|
||||||
@@ -493,7 +513,7 @@ func TestWriteWithBackup_NoOrphanTempFiles(t *testing.T) {
|
|||||||
badPath := filepath.Join(tmpDir, "isdir")
|
badPath := filepath.Join(tmpDir, "isdir")
|
||||||
os.MkdirAll(badPath, 0o755)
|
os.MkdirAll(badPath, 0o755)
|
||||||
|
|
||||||
_ = writeWithBackup(badPath, []byte(`{"test": true}`))
|
_ = WriteWithBackup(badPath, []byte(`{"test": true}`))
|
||||||
|
|
||||||
after := countTempFiles()
|
after := countTempFiles()
|
||||||
if after > before {
|
if after > before {
|
||||||
86
cmd/launch/claude.go
Normal file
86
cmd/launch/claude.go
Normal file
@@ -0,0 +1,86 @@
|
|||||||
|
package launch
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"os/exec"
|
||||||
|
"path/filepath"
|
||||||
|
"runtime"
|
||||||
|
"strconv"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/envconfig"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Claude implements Runner for Claude Code integration.
|
||||||
|
type Claude struct{}
|
||||||
|
|
||||||
|
func (c *Claude) String() string { return "Claude Code" }
|
||||||
|
|
||||||
|
func (c *Claude) args(model string, extra []string) []string {
|
||||||
|
var args []string
|
||||||
|
if model != "" {
|
||||||
|
args = append(args, "--model", model)
|
||||||
|
}
|
||||||
|
args = append(args, extra...)
|
||||||
|
return args
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Claude) findPath() (string, error) {
|
||||||
|
if p, err := exec.LookPath("claude"); err == nil {
|
||||||
|
return p, nil
|
||||||
|
}
|
||||||
|
home, err := os.UserHomeDir()
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
name := "claude"
|
||||||
|
if runtime.GOOS == "windows" {
|
||||||
|
name = "claude.exe"
|
||||||
|
}
|
||||||
|
fallback := filepath.Join(home, ".claude", "local", name)
|
||||||
|
if _, err := os.Stat(fallback); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return fallback, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Claude) Run(model string, args []string) error {
|
||||||
|
claudePath, err := c.findPath()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("claude is not installed, install from https://code.claude.com/docs/en/quickstart")
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd := exec.Command(claudePath, c.args(model, args)...)
|
||||||
|
cmd.Stdin = os.Stdin
|
||||||
|
cmd.Stdout = os.Stdout
|
||||||
|
cmd.Stderr = os.Stderr
|
||||||
|
|
||||||
|
env := append(os.Environ(),
|
||||||
|
"ANTHROPIC_BASE_URL="+envconfig.Host().String(),
|
||||||
|
"ANTHROPIC_API_KEY=",
|
||||||
|
"ANTHROPIC_AUTH_TOKEN=ollama",
|
||||||
|
)
|
||||||
|
|
||||||
|
env = append(env, c.modelEnvVars(model)...)
|
||||||
|
|
||||||
|
cmd.Env = env
|
||||||
|
return cmd.Run()
|
||||||
|
}
|
||||||
|
|
||||||
|
// modelEnvVars returns Claude Code env vars that route all model tiers through Ollama.
|
||||||
|
func (c *Claude) modelEnvVars(model string) []string {
|
||||||
|
env := []string{
|
||||||
|
"ANTHROPIC_DEFAULT_OPUS_MODEL=" + model,
|
||||||
|
"ANTHROPIC_DEFAULT_SONNET_MODEL=" + model,
|
||||||
|
"ANTHROPIC_DEFAULT_HAIKU_MODEL=" + model,
|
||||||
|
"CLAUDE_CODE_SUBAGENT_MODEL=" + model,
|
||||||
|
}
|
||||||
|
|
||||||
|
if isCloudModelName(model) {
|
||||||
|
if l, ok := lookupCloudModelLimit(model); ok {
|
||||||
|
env = append(env, "CLAUDE_CODE_AUTO_COMPACT_WINDOW="+strconv.Itoa(l.Context))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return env
|
||||||
|
}
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package config
|
package launch
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"os"
|
"os"
|
||||||
@@ -117,10 +117,7 @@ func TestClaudeModelEnvVars(t *testing.T) {
|
|||||||
return m
|
return m
|
||||||
}
|
}
|
||||||
|
|
||||||
t.Run("falls back to model param when no aliases saved", func(t *testing.T) {
|
t.Run("maps all Claude model env vars to the provided model", func(t *testing.T) {
|
||||||
tmpDir := t.TempDir()
|
|
||||||
setTestHome(t, tmpDir)
|
|
||||||
|
|
||||||
got := envMap(c.modelEnvVars("llama3.2"))
|
got := envMap(c.modelEnvVars("llama3.2"))
|
||||||
if got["ANTHROPIC_DEFAULT_OPUS_MODEL"] != "llama3.2" {
|
if got["ANTHROPIC_DEFAULT_OPUS_MODEL"] != "llama3.2" {
|
||||||
t.Errorf("OPUS = %q, want llama3.2", got["ANTHROPIC_DEFAULT_OPUS_MODEL"])
|
t.Errorf("OPUS = %q, want llama3.2", got["ANTHROPIC_DEFAULT_OPUS_MODEL"])
|
||||||
@@ -134,65 +131,41 @@ func TestClaudeModelEnvVars(t *testing.T) {
|
|||||||
if got["CLAUDE_CODE_SUBAGENT_MODEL"] != "llama3.2" {
|
if got["CLAUDE_CODE_SUBAGENT_MODEL"] != "llama3.2" {
|
||||||
t.Errorf("SUBAGENT = %q, want llama3.2", got["CLAUDE_CODE_SUBAGENT_MODEL"])
|
t.Errorf("SUBAGENT = %q, want llama3.2", got["CLAUDE_CODE_SUBAGENT_MODEL"])
|
||||||
}
|
}
|
||||||
})
|
if got["CLAUDE_CODE_AUTO_COMPACT_WINDOW"] != "" {
|
||||||
|
t.Errorf("AUTO_COMPACT_WINDOW = %q, want empty for local models", got["CLAUDE_CODE_AUTO_COMPACT_WINDOW"])
|
||||||
t.Run("uses primary alias for opus sonnet and subagent", func(t *testing.T) {
|
|
||||||
tmpDir := t.TempDir()
|
|
||||||
setTestHome(t, tmpDir)
|
|
||||||
|
|
||||||
SaveIntegration("claude", []string{"qwen3:8b"})
|
|
||||||
saveAliases("claude", map[string]string{"primary": "qwen3:8b"})
|
|
||||||
|
|
||||||
got := envMap(c.modelEnvVars("qwen3:8b"))
|
|
||||||
if got["ANTHROPIC_DEFAULT_OPUS_MODEL"] != "qwen3:8b" {
|
|
||||||
t.Errorf("OPUS = %q, want qwen3:8b", got["ANTHROPIC_DEFAULT_OPUS_MODEL"])
|
|
||||||
}
|
|
||||||
if got["ANTHROPIC_DEFAULT_SONNET_MODEL"] != "qwen3:8b" {
|
|
||||||
t.Errorf("SONNET = %q, want qwen3:8b", got["ANTHROPIC_DEFAULT_SONNET_MODEL"])
|
|
||||||
}
|
|
||||||
if got["ANTHROPIC_DEFAULT_HAIKU_MODEL"] != "qwen3:8b" {
|
|
||||||
t.Errorf("HAIKU = %q, want qwen3:8b (no fast alias)", got["ANTHROPIC_DEFAULT_HAIKU_MODEL"])
|
|
||||||
}
|
|
||||||
if got["CLAUDE_CODE_SUBAGENT_MODEL"] != "qwen3:8b" {
|
|
||||||
t.Errorf("SUBAGENT = %q, want qwen3:8b", got["CLAUDE_CODE_SUBAGENT_MODEL"])
|
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("uses fast alias for haiku", func(t *testing.T) {
|
t.Run("supports empty model", func(t *testing.T) {
|
||||||
tmpDir := t.TempDir()
|
got := envMap(c.modelEnvVars(""))
|
||||||
setTestHome(t, tmpDir)
|
if got["ANTHROPIC_DEFAULT_OPUS_MODEL"] != "" {
|
||||||
|
t.Errorf("OPUS = %q, want empty", got["ANTHROPIC_DEFAULT_OPUS_MODEL"])
|
||||||
SaveIntegration("claude", []string{"llama3.2:70b"})
|
|
||||||
saveAliases("claude", map[string]string{
|
|
||||||
"primary": "llama3.2:70b",
|
|
||||||
"fast": "llama3.2:8b",
|
|
||||||
})
|
|
||||||
|
|
||||||
got := envMap(c.modelEnvVars("llama3.2:70b"))
|
|
||||||
if got["ANTHROPIC_DEFAULT_OPUS_MODEL"] != "llama3.2:70b" {
|
|
||||||
t.Errorf("OPUS = %q, want llama3.2:70b", got["ANTHROPIC_DEFAULT_OPUS_MODEL"])
|
|
||||||
}
|
}
|
||||||
if got["ANTHROPIC_DEFAULT_SONNET_MODEL"] != "llama3.2:70b" {
|
if got["ANTHROPIC_DEFAULT_SONNET_MODEL"] != "" {
|
||||||
t.Errorf("SONNET = %q, want llama3.2:70b", got["ANTHROPIC_DEFAULT_SONNET_MODEL"])
|
t.Errorf("SONNET = %q, want empty", got["ANTHROPIC_DEFAULT_SONNET_MODEL"])
|
||||||
}
|
}
|
||||||
if got["ANTHROPIC_DEFAULT_HAIKU_MODEL"] != "llama3.2:8b" {
|
if got["ANTHROPIC_DEFAULT_HAIKU_MODEL"] != "" {
|
||||||
t.Errorf("HAIKU = %q, want llama3.2:8b", got["ANTHROPIC_DEFAULT_HAIKU_MODEL"])
|
t.Errorf("HAIKU = %q, want empty", got["ANTHROPIC_DEFAULT_HAIKU_MODEL"])
|
||||||
}
|
}
|
||||||
if got["CLAUDE_CODE_SUBAGENT_MODEL"] != "llama3.2:70b" {
|
if got["CLAUDE_CODE_SUBAGENT_MODEL"] != "" {
|
||||||
t.Errorf("SUBAGENT = %q, want llama3.2:70b", got["CLAUDE_CODE_SUBAGENT_MODEL"])
|
t.Errorf("SUBAGENT = %q, want empty", got["CLAUDE_CODE_SUBAGENT_MODEL"])
|
||||||
|
}
|
||||||
|
if got["CLAUDE_CODE_AUTO_COMPACT_WINDOW"] != "" {
|
||||||
|
t.Errorf("AUTO_COMPACT_WINDOW = %q, want empty", got["CLAUDE_CODE_AUTO_COMPACT_WINDOW"])
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("alias primary overrides model param", func(t *testing.T) {
|
t.Run("sets auto compact window for known cloud models", func(t *testing.T) {
|
||||||
tmpDir := t.TempDir()
|
got := envMap(c.modelEnvVars("glm-5:cloud"))
|
||||||
setTestHome(t, tmpDir)
|
if got["CLAUDE_CODE_AUTO_COMPACT_WINDOW"] != "202752" {
|
||||||
|
t.Errorf("AUTO_COMPACT_WINDOW = %q, want 202752", got["CLAUDE_CODE_AUTO_COMPACT_WINDOW"])
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
SaveIntegration("claude", []string{"saved-model"})
|
t.Run("does not set auto compact window for unknown cloud models", func(t *testing.T) {
|
||||||
saveAliases("claude", map[string]string{"primary": "saved-model"})
|
got := envMap(c.modelEnvVars("unknown-model:cloud"))
|
||||||
|
if got["CLAUDE_CODE_AUTO_COMPACT_WINDOW"] != "" {
|
||||||
got := envMap(c.modelEnvVars("different-model"))
|
t.Errorf("AUTO_COMPACT_WINDOW = %q, want empty", got["CLAUDE_CODE_AUTO_COMPACT_WINDOW"])
|
||||||
if got["ANTHROPIC_DEFAULT_OPUS_MODEL"] != "saved-model" {
|
|
||||||
t.Errorf("OPUS = %q, want saved-model", got["ANTHROPIC_DEFAULT_OPUS_MODEL"])
|
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -1,14 +1,13 @@
|
|||||||
package config
|
package launch
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
"os/exec"
|
"os/exec"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/cmd/internal/fileutil"
|
||||||
"github.com/ollama/ollama/envconfig"
|
"github.com/ollama/ollama/envconfig"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -22,24 +21,6 @@ func (c *Cline) Run(model string, args []string) error {
|
|||||||
return fmt.Errorf("cline is not installed, install with: npm install -g cline")
|
return fmt.Errorf("cline is not installed, install with: npm install -g cline")
|
||||||
}
|
}
|
||||||
|
|
||||||
models := []string{model}
|
|
||||||
if config, err := loadIntegration("cline"); err == nil && len(config.Models) > 0 {
|
|
||||||
models = config.Models
|
|
||||||
}
|
|
||||||
var err error
|
|
||||||
models, err = resolveEditorModels("cline", models, func() ([]string, error) {
|
|
||||||
return selectModels(context.Background(), "cline", "")
|
|
||||||
})
|
|
||||||
if errors.Is(err, errCancelled) {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if err := c.Edit(models); err != nil {
|
|
||||||
return fmt.Errorf("setup failed: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
cmd := exec.Command("cline", args...)
|
cmd := exec.Command("cline", args...)
|
||||||
cmd.Stdin = os.Stdin
|
cmd.Stdin = os.Stdin
|
||||||
cmd.Stdout = os.Stdout
|
cmd.Stdout = os.Stdout
|
||||||
@@ -97,7 +78,7 @@ func (c *Cline) Edit(models []string) error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return writeWithBackup(configPath, data)
|
return fileutil.WriteWithBackup(configPath, data)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Cline) Models() []string {
|
func (c *Cline) Models() []string {
|
||||||
@@ -106,7 +87,7 @@ func (c *Cline) Models() []string {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
config, err := readJSONFile(filepath.Join(home, ".cline", "data", "globalState.json"))
|
config, err := fileutil.ReadJSON(filepath.Join(home, ".cline", "data", "globalState.json"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package config
|
package launch
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package config
|
package launch
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
@@ -6,6 +6,7 @@ import (
|
|||||||
"os/exec"
|
"os/exec"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/envconfig"
|
||||||
"golang.org/x/mod/semver"
|
"golang.org/x/mod/semver"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -32,6 +33,10 @@ func (c *Codex) Run(model string, args []string) error {
|
|||||||
cmd.Stdin = os.Stdin
|
cmd.Stdin = os.Stdin
|
||||||
cmd.Stdout = os.Stdout
|
cmd.Stdout = os.Stdout
|
||||||
cmd.Stderr = os.Stderr
|
cmd.Stderr = os.Stderr
|
||||||
|
cmd.Env = append(os.Environ(),
|
||||||
|
"OPENAI_BASE_URL="+envconfig.Host().String()+"/v1/",
|
||||||
|
"OPENAI_API_KEY=ollama",
|
||||||
|
)
|
||||||
return cmd.Run()
|
return cmd.Run()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package config
|
package launch
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"slices"
|
"slices"
|
||||||
@@ -16,7 +16,7 @@ func TestCodexArgs(t *testing.T) {
|
|||||||
}{
|
}{
|
||||||
{"with model", "llama3.2", nil, []string{"--oss", "-m", "llama3.2"}},
|
{"with model", "llama3.2", nil, []string{"--oss", "-m", "llama3.2"}},
|
||||||
{"empty model", "", nil, []string{"--oss"}},
|
{"empty model", "", nil, []string{"--oss"}},
|
||||||
{"with model and profile", "qwen3-coder", []string{"-p", "myprofile"}, []string{"--oss", "-m", "qwen3-coder", "-p", "myprofile"}},
|
{"with model and profile", "qwen3.5", []string{"-p", "myprofile"}, []string{"--oss", "-m", "qwen3.5", "-p", "myprofile"}},
|
||||||
{"with sandbox flag", "llama3.2", []string{"--sandbox", "workspace-write"}, []string{"--oss", "-m", "llama3.2", "--sandbox", "workspace-write"}},
|
{"with sandbox flag", "llama3.2", []string{"--sandbox", "workspace-write"}, []string{"--oss", "-m", "llama3.2", "--sandbox", "workspace-write"}},
|
||||||
}
|
}
|
||||||
|
|
||||||
598
cmd/launch/command_test.go
Normal file
598
cmd/launch/command_test.go
Normal file
@@ -0,0 +1,598 @@
|
|||||||
|
package launch
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"os"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/google/go-cmp/cmp"
|
||||||
|
"github.com/ollama/ollama/cmd/config"
|
||||||
|
"github.com/spf13/cobra"
|
||||||
|
)
|
||||||
|
|
||||||
|
func captureStderr(t *testing.T, fn func()) string {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
oldStderr := os.Stderr
|
||||||
|
r, w, err := os.Pipe()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to create stderr pipe: %v", err)
|
||||||
|
}
|
||||||
|
os.Stderr = w
|
||||||
|
defer func() {
|
||||||
|
os.Stderr = oldStderr
|
||||||
|
}()
|
||||||
|
|
||||||
|
done := make(chan string, 1)
|
||||||
|
go func() {
|
||||||
|
var buf bytes.Buffer
|
||||||
|
_, _ = io.Copy(&buf, r)
|
||||||
|
done <- buf.String()
|
||||||
|
}()
|
||||||
|
|
||||||
|
fn()
|
||||||
|
|
||||||
|
_ = w.Close()
|
||||||
|
return <-done
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLaunchCmd(t *testing.T) {
|
||||||
|
mockCheck := func(cmd *cobra.Command, args []string) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
mockTUI := func(cmd *cobra.Command) {}
|
||||||
|
cmd := LaunchCmd(mockCheck, mockTUI)
|
||||||
|
|
||||||
|
t.Run("command structure", func(t *testing.T) {
|
||||||
|
if cmd.Use != "launch [INTEGRATION] [-- [EXTRA_ARGS...]]" {
|
||||||
|
t.Errorf("Use = %q, want %q", cmd.Use, "launch [INTEGRATION] [-- [EXTRA_ARGS...]]")
|
||||||
|
}
|
||||||
|
if cmd.Short == "" {
|
||||||
|
t.Error("Short description should not be empty")
|
||||||
|
}
|
||||||
|
if cmd.Long == "" {
|
||||||
|
t.Error("Long description should not be empty")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("flags exist", func(t *testing.T) {
|
||||||
|
if cmd.Flags().Lookup("model") == nil {
|
||||||
|
t.Error("--model flag should exist")
|
||||||
|
}
|
||||||
|
if cmd.Flags().Lookup("config") == nil {
|
||||||
|
t.Error("--config flag should exist")
|
||||||
|
}
|
||||||
|
if cmd.Flags().Lookup("yes") == nil {
|
||||||
|
t.Error("--yes flag should exist")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("PreRunE is set", func(t *testing.T) {
|
||||||
|
if cmd.PreRunE == nil {
|
||||||
|
t.Error("PreRunE should be set to checkServerHeartbeat")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLaunchCmdTUICallback(t *testing.T) {
|
||||||
|
mockCheck := func(cmd *cobra.Command, args []string) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("no args calls TUI", func(t *testing.T) {
|
||||||
|
tuiCalled := false
|
||||||
|
mockTUI := func(cmd *cobra.Command) {
|
||||||
|
tuiCalled = true
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd := LaunchCmd(mockCheck, mockTUI)
|
||||||
|
cmd.SetArgs([]string{})
|
||||||
|
_ = cmd.Execute()
|
||||||
|
|
||||||
|
if !tuiCalled {
|
||||||
|
t.Error("TUI callback should be called when no args provided")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("integration arg bypasses TUI", func(t *testing.T) {
|
||||||
|
srv := httptest.NewServer(http.NotFoundHandler())
|
||||||
|
defer srv.Close()
|
||||||
|
t.Setenv("OLLAMA_HOST", srv.URL)
|
||||||
|
|
||||||
|
tuiCalled := false
|
||||||
|
mockTUI := func(cmd *cobra.Command) {
|
||||||
|
tuiCalled = true
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd := LaunchCmd(mockCheck, mockTUI)
|
||||||
|
cmd.SetArgs([]string{"claude"})
|
||||||
|
_ = cmd.Execute()
|
||||||
|
|
||||||
|
if tuiCalled {
|
||||||
|
t.Error("TUI callback should NOT be called when integration arg provided")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("--model flag without integration returns error", func(t *testing.T) {
|
||||||
|
tuiCalled := false
|
||||||
|
mockTUI := func(cmd *cobra.Command) {
|
||||||
|
tuiCalled = true
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd := LaunchCmd(mockCheck, mockTUI)
|
||||||
|
cmd.SetArgs([]string{"--model", "test-model"})
|
||||||
|
err := cmd.Execute()
|
||||||
|
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected --model without an integration to fail")
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "require an integration name") {
|
||||||
|
t.Fatalf("expected integration-name guidance, got %v", err)
|
||||||
|
}
|
||||||
|
if tuiCalled {
|
||||||
|
t.Error("TUI callback should NOT be called when --model is provided without an integration")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("--config flag without integration returns error", func(t *testing.T) {
|
||||||
|
tuiCalled := false
|
||||||
|
mockTUI := func(cmd *cobra.Command) {
|
||||||
|
tuiCalled = true
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd := LaunchCmd(mockCheck, mockTUI)
|
||||||
|
cmd.SetArgs([]string{"--config"})
|
||||||
|
err := cmd.Execute()
|
||||||
|
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected --config without an integration to fail")
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "require an integration name") {
|
||||||
|
t.Fatalf("expected integration-name guidance, got %v", err)
|
||||||
|
}
|
||||||
|
if tuiCalled {
|
||||||
|
t.Error("TUI callback should NOT be called when --config is provided without an integration")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("--yes flag without integration returns error", func(t *testing.T) {
|
||||||
|
tuiCalled := false
|
||||||
|
mockTUI := func(cmd *cobra.Command) {
|
||||||
|
tuiCalled = true
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd := LaunchCmd(mockCheck, mockTUI)
|
||||||
|
cmd.SetArgs([]string{"--yes"})
|
||||||
|
err := cmd.Execute()
|
||||||
|
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected --yes without an integration to fail")
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "require an integration name") {
|
||||||
|
t.Fatalf("expected integration-name guidance, got %v", err)
|
||||||
|
}
|
||||||
|
if tuiCalled {
|
||||||
|
t.Error("TUI callback should NOT be called when --yes is provided without an integration")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("extra args without integration return error", func(t *testing.T) {
|
||||||
|
tuiCalled := false
|
||||||
|
mockTUI := func(cmd *cobra.Command) {
|
||||||
|
tuiCalled = true
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd := LaunchCmd(mockCheck, mockTUI)
|
||||||
|
cmd.SetArgs([]string{"--model", "test-model", "--", "--sandbox", "workspace-write"})
|
||||||
|
err := cmd.Execute()
|
||||||
|
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected flags and extra args without an integration to fail")
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "require an integration name") {
|
||||||
|
t.Fatalf("expected integration-name guidance, got %v", err)
|
||||||
|
}
|
||||||
|
if tuiCalled {
|
||||||
|
t.Error("TUI callback should NOT be called when flags or extra args are provided without an integration")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLaunchCmdNilHeartbeat(t *testing.T) {
|
||||||
|
cmd := LaunchCmd(nil, nil)
|
||||||
|
if cmd == nil {
|
||||||
|
t.Fatal("LaunchCmd returned nil")
|
||||||
|
}
|
||||||
|
if cmd.PreRunE != nil {
|
||||||
|
t.Log("Note: PreRunE is set even when nil is passed (acceptable)")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLaunchCmdModelFlagFiltersDisabledCloudFromSavedConfig(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
setLaunchTestHome(t, tmpDir)
|
||||||
|
|
||||||
|
if err := config.SaveIntegration("stubeditor", []string{"glm-5:cloud"}); err != nil {
|
||||||
|
t.Fatalf("failed to seed saved config: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
switch r.URL.Path {
|
||||||
|
case "/api/status":
|
||||||
|
fmt.Fprintf(w, `{"cloud":{"disabled":true,"source":"config"}}`)
|
||||||
|
case "/api/show":
|
||||||
|
fmt.Fprintf(w, `{"model":"llama3.2"}`)
|
||||||
|
default:
|
||||||
|
w.WriteHeader(http.StatusNotFound)
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
t.Setenv("OLLAMA_HOST", srv.URL)
|
||||||
|
|
||||||
|
stub := &launcherEditorRunner{}
|
||||||
|
restore := OverrideIntegration("stubeditor", stub)
|
||||||
|
defer restore()
|
||||||
|
|
||||||
|
cmd := LaunchCmd(func(cmd *cobra.Command, args []string) error { return nil }, func(cmd *cobra.Command) {})
|
||||||
|
cmd.SetArgs([]string{"stubeditor", "--model", "llama3.2"})
|
||||||
|
if err := cmd.Execute(); err != nil {
|
||||||
|
t.Fatalf("launch command failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
saved, err := config.LoadIntegration("stubeditor")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to reload integration config: %v", err)
|
||||||
|
}
|
||||||
|
if diff := cmp.Diff([]string{"llama3.2"}, saved.Models); diff != "" {
|
||||||
|
t.Fatalf("saved models mismatch (-want +got):\n%s", diff)
|
||||||
|
}
|
||||||
|
if diff := cmp.Diff([][]string{{"llama3.2"}}, stub.edited); diff != "" {
|
||||||
|
t.Fatalf("editor models mismatch (-want +got):\n%s", diff)
|
||||||
|
}
|
||||||
|
if stub.ranModel != "llama3.2" {
|
||||||
|
t.Fatalf("expected launch to run with llama3.2, got %q", stub.ranModel)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLaunchCmdModelFlagClearsDisabledCloudOverride(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
setLaunchTestHome(t, tmpDir)
|
||||||
|
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
switch r.URL.Path {
|
||||||
|
case "/api/status":
|
||||||
|
fmt.Fprintf(w, `{"cloud":{"disabled":true,"source":"config"}}`)
|
||||||
|
case "/api/tags":
|
||||||
|
fmt.Fprint(w, `{"models":[{"name":"llama3.2"}]}`)
|
||||||
|
case "/api/show":
|
||||||
|
fmt.Fprint(w, `{"model":"llama3.2"}`)
|
||||||
|
default:
|
||||||
|
w.WriteHeader(http.StatusNotFound)
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
t.Setenv("OLLAMA_HOST", srv.URL)
|
||||||
|
|
||||||
|
stub := &launcherSingleRunner{}
|
||||||
|
restore := OverrideIntegration("stubapp", stub)
|
||||||
|
defer restore()
|
||||||
|
|
||||||
|
oldSelector := DefaultSingleSelector
|
||||||
|
defer func() { DefaultSingleSelector = oldSelector }()
|
||||||
|
|
||||||
|
var selectorCalls int
|
||||||
|
var gotCurrent string
|
||||||
|
DefaultSingleSelector = func(title string, items []ModelItem, current string) (string, error) {
|
||||||
|
selectorCalls++
|
||||||
|
gotCurrent = current
|
||||||
|
return "llama3.2", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd := LaunchCmd(func(cmd *cobra.Command, args []string) error { return nil }, func(cmd *cobra.Command) {})
|
||||||
|
cmd.SetArgs([]string{"stubapp", "--model", "glm-5:cloud"})
|
||||||
|
stderr := captureStderr(t, func() {
|
||||||
|
if err := cmd.Execute(); err != nil {
|
||||||
|
t.Fatalf("launch command failed: %v", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
if selectorCalls != 1 {
|
||||||
|
t.Fatalf("expected disabled cloud override to fall back to selector, got %d calls", selectorCalls)
|
||||||
|
}
|
||||||
|
if gotCurrent != "" {
|
||||||
|
t.Fatalf("expected disabled override to be cleared before selection, got current %q", gotCurrent)
|
||||||
|
}
|
||||||
|
if stub.ranModel != "llama3.2" {
|
||||||
|
t.Fatalf("expected launch to run with replacement local model, got %q", stub.ranModel)
|
||||||
|
}
|
||||||
|
if !strings.Contains(stderr, "Warning: ignoring --model glm-5:cloud because cloud is disabled") {
|
||||||
|
t.Fatalf("expected disabled-cloud warning, got stderr: %q", stderr)
|
||||||
|
}
|
||||||
|
|
||||||
|
saved, err := config.LoadIntegration("stubapp")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to reload integration config: %v", err)
|
||||||
|
}
|
||||||
|
if diff := cmp.Diff([]string{"llama3.2"}, saved.Models); diff != "" {
|
||||||
|
t.Fatalf("saved models mismatch (-want +got):\n%s", diff)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLaunchCmdYes_AutoConfirmsLaunchPromptPath(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
setLaunchTestHome(t, tmpDir)
|
||||||
|
withLauncherHooks(t)
|
||||||
|
withInteractiveSession(t, false)
|
||||||
|
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
switch r.URL.Path {
|
||||||
|
case "/api/show":
|
||||||
|
fmt.Fprint(w, `{"model":"llama3.2"}`)
|
||||||
|
case "/api/status":
|
||||||
|
w.WriteHeader(http.StatusNotFound)
|
||||||
|
fmt.Fprint(w, `{"error":"not found"}`)
|
||||||
|
default:
|
||||||
|
w.WriteHeader(http.StatusNotFound)
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
t.Setenv("OLLAMA_HOST", srv.URL)
|
||||||
|
|
||||||
|
stub := &launcherEditorRunner{paths: []string{"/tmp/stubeditor.json"}}
|
||||||
|
restore := OverrideIntegration("stubeditor", stub)
|
||||||
|
defer restore()
|
||||||
|
|
||||||
|
DefaultConfirmPrompt = func(prompt string) (bool, error) {
|
||||||
|
t.Fatalf("unexpected prompt with --yes: %q", prompt)
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd := LaunchCmd(func(cmd *cobra.Command, args []string) error { return nil }, func(cmd *cobra.Command) {})
|
||||||
|
cmd.SetArgs([]string{"stubeditor", "--model", "llama3.2", "--yes"})
|
||||||
|
if err := cmd.Execute(); err != nil {
|
||||||
|
t.Fatalf("launch command with --yes failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := cmp.Diff([][]string{{"llama3.2"}}, stub.edited); diff != "" {
|
||||||
|
t.Fatalf("editor models mismatch (-want +got):\n%s", diff)
|
||||||
|
}
|
||||||
|
if stub.ranModel != "llama3.2" {
|
||||||
|
t.Fatalf("expected launch to run with llama3.2, got %q", stub.ranModel)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLaunchCmdHeadlessWithYes_AutoPullsMissingLocalModel(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
setLaunchTestHome(t, tmpDir)
|
||||||
|
withLauncherHooks(t)
|
||||||
|
withInteractiveSession(t, false)
|
||||||
|
|
||||||
|
var pullCalled bool
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
switch r.URL.Path {
|
||||||
|
case "/api/show":
|
||||||
|
w.WriteHeader(http.StatusNotFound)
|
||||||
|
fmt.Fprint(w, `{"error":"model not found"}`)
|
||||||
|
case "/api/pull":
|
||||||
|
pullCalled = true
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
fmt.Fprint(w, `{"status":"success"}`)
|
||||||
|
default:
|
||||||
|
w.WriteHeader(http.StatusNotFound)
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
t.Setenv("OLLAMA_HOST", srv.URL)
|
||||||
|
|
||||||
|
stub := &launcherSingleRunner{}
|
||||||
|
restore := OverrideIntegration("stubapp", stub)
|
||||||
|
defer restore()
|
||||||
|
|
||||||
|
DefaultConfirmPrompt = func(prompt string) (bool, error) {
|
||||||
|
t.Fatalf("unexpected prompt with --yes in headless autopull path: %q", prompt)
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd := LaunchCmd(func(cmd *cobra.Command, args []string) error { return nil }, func(cmd *cobra.Command) {})
|
||||||
|
cmd.SetArgs([]string{"stubapp", "--model", "missing-model", "--yes"})
|
||||||
|
if err := cmd.Execute(); err != nil {
|
||||||
|
t.Fatalf("launch command with --yes failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !pullCalled {
|
||||||
|
t.Fatal("expected missing local model to be auto-pulled with --yes in headless mode")
|
||||||
|
}
|
||||||
|
if stub.ranModel != "missing-model" {
|
||||||
|
t.Fatalf("expected launch to run with pulled model, got %q", stub.ranModel)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLaunchCmdHeadlessWithoutYes_ReturnsActionableConfirmError(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
setLaunchTestHome(t, tmpDir)
|
||||||
|
withLauncherHooks(t)
|
||||||
|
withInteractiveSession(t, false)
|
||||||
|
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
switch r.URL.Path {
|
||||||
|
case "/api/show":
|
||||||
|
fmt.Fprint(w, `{"model":"llama3.2"}`)
|
||||||
|
case "/api/status":
|
||||||
|
w.WriteHeader(http.StatusNotFound)
|
||||||
|
fmt.Fprint(w, `{"error":"not found"}`)
|
||||||
|
default:
|
||||||
|
w.WriteHeader(http.StatusNotFound)
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
t.Setenv("OLLAMA_HOST", srv.URL)
|
||||||
|
|
||||||
|
stub := &launcherEditorRunner{paths: []string{"/tmp/stubeditor.json"}}
|
||||||
|
restore := OverrideIntegration("stubeditor", stub)
|
||||||
|
defer restore()
|
||||||
|
|
||||||
|
DefaultConfirmPrompt = func(prompt string) (bool, error) {
|
||||||
|
t.Fatalf("unexpected prompt in headless non-yes mode: %q", prompt)
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd := LaunchCmd(func(cmd *cobra.Command, args []string) error { return nil }, func(cmd *cobra.Command) {})
|
||||||
|
cmd.SetArgs([]string{"stubeditor", "--model", "llama3.2"})
|
||||||
|
err := cmd.Execute()
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected launch command to fail without --yes in headless mode")
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "re-run with --yes") {
|
||||||
|
t.Fatalf("expected actionable --yes guidance, got %v", err)
|
||||||
|
}
|
||||||
|
if len(stub.edited) != 0 {
|
||||||
|
t.Fatalf("expected no editor writes when confirmation is blocked, got %v", stub.edited)
|
||||||
|
}
|
||||||
|
if stub.ranModel != "" {
|
||||||
|
t.Fatalf("expected launch to abort before run, got %q", stub.ranModel)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLaunchCmdIntegrationArgPromptsForModelWithSavedSelection(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
setLaunchTestHome(t, tmpDir)
|
||||||
|
|
||||||
|
if err := config.SaveIntegration("stubapp", []string{"llama3.2"}); err != nil {
|
||||||
|
t.Fatalf("failed to seed saved config: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
switch r.URL.Path {
|
||||||
|
case "/api/tags":
|
||||||
|
fmt.Fprint(w, `{"models":[{"name":"llama3.2"},{"name":"qwen3:8b"}]}`)
|
||||||
|
case "/api/show":
|
||||||
|
fmt.Fprint(w, `{"model":"qwen3:8b"}`)
|
||||||
|
default:
|
||||||
|
w.WriteHeader(http.StatusNotFound)
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
t.Setenv("OLLAMA_HOST", srv.URL)
|
||||||
|
|
||||||
|
stub := &launcherSingleRunner{}
|
||||||
|
restore := OverrideIntegration("stubapp", stub)
|
||||||
|
defer restore()
|
||||||
|
|
||||||
|
oldSelector := DefaultSingleSelector
|
||||||
|
defer func() { DefaultSingleSelector = oldSelector }()
|
||||||
|
|
||||||
|
var gotCurrent string
|
||||||
|
DefaultSingleSelector = func(title string, items []ModelItem, current string) (string, error) {
|
||||||
|
gotCurrent = current
|
||||||
|
return "qwen3:8b", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd := LaunchCmd(func(cmd *cobra.Command, args []string) error { return nil }, func(cmd *cobra.Command) {})
|
||||||
|
cmd.SetArgs([]string{"stubapp"})
|
||||||
|
if err := cmd.Execute(); err != nil {
|
||||||
|
t.Fatalf("launch command failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if gotCurrent != "llama3.2" {
|
||||||
|
t.Fatalf("expected selector current model to be saved model llama3.2, got %q", gotCurrent)
|
||||||
|
}
|
||||||
|
if stub.ranModel != "qwen3:8b" {
|
||||||
|
t.Fatalf("expected launch to run selected model qwen3:8b, got %q", stub.ranModel)
|
||||||
|
}
|
||||||
|
|
||||||
|
saved, err := config.LoadIntegration("stubapp")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to reload integration config: %v", err)
|
||||||
|
}
|
||||||
|
if diff := cmp.Diff([]string{"qwen3:8b"}, saved.Models); diff != "" {
|
||||||
|
t.Fatalf("saved models mismatch (-want +got):\n%s", diff)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLaunchCmdHeadlessYes_IntegrationRequiresModelEvenWhenSaved(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
setLaunchTestHome(t, tmpDir)
|
||||||
|
withLauncherHooks(t)
|
||||||
|
withInteractiveSession(t, false)
|
||||||
|
|
||||||
|
if err := config.SaveIntegration("stubapp", []string{"llama3.2"}); err != nil {
|
||||||
|
t.Fatalf("failed to seed saved config: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
switch r.URL.Path {
|
||||||
|
case "/api/show":
|
||||||
|
fmt.Fprint(w, `{"model":"llama3.2"}`)
|
||||||
|
default:
|
||||||
|
w.WriteHeader(http.StatusNotFound)
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
t.Setenv("OLLAMA_HOST", srv.URL)
|
||||||
|
|
||||||
|
stub := &launcherSingleRunner{}
|
||||||
|
restore := OverrideIntegration("stubapp", stub)
|
||||||
|
defer restore()
|
||||||
|
|
||||||
|
oldSelector := DefaultSingleSelector
|
||||||
|
defer func() { DefaultSingleSelector = oldSelector }()
|
||||||
|
DefaultSingleSelector = func(title string, items []ModelItem, current string) (string, error) {
|
||||||
|
t.Fatal("selector should not be called for headless --yes saved-model launch")
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd := LaunchCmd(func(cmd *cobra.Command, args []string) error { return nil }, func(cmd *cobra.Command) {})
|
||||||
|
cmd.SetArgs([]string{"stubapp", "--yes"})
|
||||||
|
err := cmd.Execute()
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected launch command to fail when --yes is used headlessly without --model")
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "requires --model <model>") {
|
||||||
|
t.Fatalf("expected actionable --model guidance, got %v", err)
|
||||||
|
}
|
||||||
|
if stub.ranModel != "" {
|
||||||
|
t.Fatalf("expected launch to abort before run, got %q", stub.ranModel)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLaunchCmdHeadlessYes_IntegrationWithoutSavedModelReturnsError(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
setLaunchTestHome(t, tmpDir)
|
||||||
|
withLauncherHooks(t)
|
||||||
|
withInteractiveSession(t, false)
|
||||||
|
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusNotFound)
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
t.Setenv("OLLAMA_HOST", srv.URL)
|
||||||
|
|
||||||
|
stub := &launcherSingleRunner{}
|
||||||
|
restore := OverrideIntegration("stubapp", stub)
|
||||||
|
defer restore()
|
||||||
|
|
||||||
|
oldSelector := DefaultSingleSelector
|
||||||
|
defer func() { DefaultSingleSelector = oldSelector }()
|
||||||
|
DefaultSingleSelector = func(title string, items []ModelItem, current string) (string, error) {
|
||||||
|
t.Fatal("selector should not be called for headless --yes without saved model")
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd := LaunchCmd(func(cmd *cobra.Command, args []string) error { return nil }, func(cmd *cobra.Command) {})
|
||||||
|
cmd.SetArgs([]string{"stubapp", "--yes"})
|
||||||
|
err := cmd.Execute()
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected launch command to fail when --yes is used headlessly without --model")
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "requires --model <model>") {
|
||||||
|
t.Fatalf("expected actionable --model guidance, got %v", err)
|
||||||
|
}
|
||||||
|
if stub.ranModel != "" {
|
||||||
|
t.Fatalf("expected launch to abort before run, got %q", stub.ranModel)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,16 +1,14 @@
|
|||||||
package config
|
package launch
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
"os/exec"
|
"os/exec"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"slices"
|
"slices"
|
||||||
|
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/cmd/internal/fileutil"
|
||||||
"github.com/ollama/ollama/envconfig"
|
"github.com/ollama/ollama/envconfig"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -47,25 +45,6 @@ func (d *Droid) Run(model string, args []string) error {
|
|||||||
return fmt.Errorf("droid is not installed, install from https://docs.factory.ai/cli/getting-started/quickstart")
|
return fmt.Errorf("droid is not installed, install from https://docs.factory.ai/cli/getting-started/quickstart")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Call Edit() to ensure config is up-to-date before launch
|
|
||||||
models := []string{model}
|
|
||||||
if config, err := loadIntegration("droid"); err == nil && len(config.Models) > 0 {
|
|
||||||
models = config.Models
|
|
||||||
}
|
|
||||||
var err error
|
|
||||||
models, err = resolveEditorModels("droid", models, func() ([]string, error) {
|
|
||||||
return selectModels(context.Background(), "droid", "")
|
|
||||||
})
|
|
||||||
if errors.Is(err, errCancelled) {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if err := d.Edit(models); err != nil {
|
|
||||||
return fmt.Errorf("setup failed: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
cmd := exec.Command("droid", args...)
|
cmd := exec.Command("droid", args...)
|
||||||
cmd.Stdin = os.Stdin
|
cmd.Stdin = os.Stdin
|
||||||
cmd.Stdout = os.Stdout
|
cmd.Stdout = os.Stdout
|
||||||
@@ -111,6 +90,16 @@ func (d *Droid) Edit(models []string) error {
|
|||||||
json.Unmarshal(data, &settings) // ignore error, zero values are fine
|
json.Unmarshal(data, &settings) // ignore error, zero values are fine
|
||||||
}
|
}
|
||||||
|
|
||||||
|
settingsMap = updateDroidSettings(settingsMap, settings, models)
|
||||||
|
|
||||||
|
data, err := json.MarshalIndent(settingsMap, "", " ")
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return fileutil.WriteWithBackup(settingsPath, data)
|
||||||
|
}
|
||||||
|
|
||||||
|
func updateDroidSettings(settingsMap map[string]any, settings droidSettings, models []string) map[string]any {
|
||||||
// Keep only non-Ollama models from the raw map (preserves extra fields)
|
// Keep only non-Ollama models from the raw map (preserves extra fields)
|
||||||
// Rebuild Ollama models
|
// Rebuild Ollama models
|
||||||
var nonOllamaModels []any
|
var nonOllamaModels []any
|
||||||
@@ -125,13 +114,12 @@ func (d *Droid) Edit(models []string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Build new Ollama model entries with sequential indices (0, 1, 2, ...)
|
// Build new Ollama model entries with sequential indices (0, 1, 2, ...)
|
||||||
client, _ := api.ClientFromEnvironment()
|
|
||||||
|
|
||||||
var newModels []any
|
var newModels []any
|
||||||
var defaultModelID string
|
var defaultModelID string
|
||||||
for i, model := range models {
|
for i, model := range models {
|
||||||
maxOutput := 64000
|
maxOutput := 64000
|
||||||
if isCloudModel(context.Background(), client, model) {
|
if isCloudModelName(model) {
|
||||||
if l, ok := lookupCloudModelLimit(model); ok {
|
if l, ok := lookupCloudModelLimit(model); ok {
|
||||||
maxOutput = l.Output
|
maxOutput = l.Output
|
||||||
}
|
}
|
||||||
@@ -167,12 +155,7 @@ func (d *Droid) Edit(models []string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
settingsMap["sessionDefaultSettings"] = sessionSettings
|
settingsMap["sessionDefaultSettings"] = sessionSettings
|
||||||
|
return settingsMap
|
||||||
data, err := json.MarshalIndent(settingsMap, "", " ")
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return writeWithBackup(settingsPath, data)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Droid) Models() []string {
|
func (d *Droid) Models() []string {
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package config
|
package launch
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
@@ -6,6 +6,8 @@ import (
|
|||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/cmd/internal/fileutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestDroidIntegration(t *testing.T) {
|
func TestDroidIntegration(t *testing.T) {
|
||||||
@@ -362,7 +364,7 @@ func TestDroidEdit_DuplicateModels(t *testing.T) {
|
|||||||
t.Fatalf("Edit with duplicates failed: %v", err)
|
t.Fatalf("Edit with duplicates failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
settings, err := readJSONFile(settingsPath)
|
settings, err := fileutil.ReadJSON(settingsPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("readJSONFile failed: %v", err)
|
t.Fatalf("readJSONFile failed: %v", err)
|
||||||
}
|
}
|
||||||
@@ -392,7 +394,7 @@ func TestDroidEdit_MalformedModelEntry(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Malformed entries (non-object) are dropped - only valid model objects are preserved
|
// Malformed entries (non-object) are dropped - only valid model objects are preserved
|
||||||
settings, _ := readJSONFile(settingsPath)
|
settings, _ := fileutil.ReadJSON(settingsPath)
|
||||||
customModels, _ := settings["customModels"].([]any)
|
customModels, _ := settings["customModels"].([]any)
|
||||||
|
|
||||||
// Should have: 1 new Ollama model only (malformed entries dropped)
|
// Should have: 1 new Ollama model only (malformed entries dropped)
|
||||||
@@ -419,7 +421,7 @@ func TestDroidEdit_WrongTypeSessionSettings(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Should create proper sessionDefaultSettings
|
// Should create proper sessionDefaultSettings
|
||||||
settings, _ := readJSONFile(settingsPath)
|
settings, _ := fileutil.ReadJSON(settingsPath)
|
||||||
session, ok := settings["sessionDefaultSettings"].(map[string]any)
|
session, ok := settings["sessionDefaultSettings"].(map[string]any)
|
||||||
if !ok {
|
if !ok {
|
||||||
t.Fatalf("sessionDefaultSettings should be map after setup, got %T", settings["sessionDefaultSettings"])
|
t.Fatalf("sessionDefaultSettings should be map after setup, got %T", settings["sessionDefaultSettings"])
|
||||||
@@ -1008,34 +1010,34 @@ func TestDroidEdit_ModelNamesWithSpecialCharacters(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestDroidEdit_MissingCustomModelsKey(t *testing.T) {
|
func TestDroidEdit_MissingCustomModelsKey(t *testing.T) {
|
||||||
d := &Droid{}
|
|
||||||
tmpDir := t.TempDir()
|
|
||||||
setTestHome(t, tmpDir)
|
|
||||||
|
|
||||||
settingsDir := filepath.Join(tmpDir, ".factory")
|
|
||||||
settingsPath := filepath.Join(settingsDir, "settings.json")
|
|
||||||
|
|
||||||
os.MkdirAll(settingsDir, 0o755)
|
|
||||||
|
|
||||||
// No customModels key at all
|
// No customModels key at all
|
||||||
original := `{
|
original := `{
|
||||||
"diffMode": "github",
|
"diffMode": "github",
|
||||||
"sessionDefaultSettings": {"autonomyMode": "auto-high"}
|
"sessionDefaultSettings": {"autonomyMode": "auto-high"}
|
||||||
}`
|
}`
|
||||||
os.WriteFile(settingsPath, []byte(original), 0o644)
|
|
||||||
|
|
||||||
if err := d.Edit([]string{"model-a"}); err != nil {
|
var settingsStruct droidSettings
|
||||||
|
var settings map[string]any
|
||||||
|
if err := json.Unmarshal([]byte(original), &settings); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal([]byte(original), &settingsStruct); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
data, _ := os.ReadFile(settingsPath)
|
settings = updateDroidSettings(settings, settingsStruct, []string{"model-a"})
|
||||||
var settings map[string]any
|
|
||||||
json.Unmarshal(data, &settings)
|
|
||||||
|
|
||||||
// Original fields preserved
|
// Original fields preserved
|
||||||
if settings["diffMode"] != "github" {
|
if settings["diffMode"] != "github" {
|
||||||
t.Error("diffMode not preserved")
|
t.Error("diffMode not preserved")
|
||||||
}
|
}
|
||||||
|
session, ok := settings["sessionDefaultSettings"].(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("sessionDefaultSettings not preserved")
|
||||||
|
}
|
||||||
|
if session["autonomyMode"] != "auto-high" {
|
||||||
|
t.Error("sessionDefaultSettings.autonomyMode not preserved")
|
||||||
|
}
|
||||||
|
|
||||||
// customModels created
|
// customModels created
|
||||||
models, ok := settings["customModels"].([]any)
|
models, ok := settings["customModels"].([]any)
|
||||||
@@ -1276,25 +1278,17 @@ func TestDroidEdit_LocalModelDefaultMaxOutput(t *testing.T) {
|
|||||||
|
|
||||||
func TestDroidEdit_CloudModelLimitsUsed(t *testing.T) {
|
func TestDroidEdit_CloudModelLimitsUsed(t *testing.T) {
|
||||||
// Verify that every cloud model in cloudModelLimits has a valid output
|
// Verify that every cloud model in cloudModelLimits has a valid output
|
||||||
// value that would be used for maxOutputTokens when isCloudModel returns true.
|
// value that would be used for maxOutputTokens when the selected model uses
|
||||||
// :cloud suffix stripping must also work since that's how users specify them.
|
// the explicit :cloud source tag.
|
||||||
for name, expected := range cloudModelLimits {
|
for name, expected := range cloudModelLimits {
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
l, ok := lookupCloudModelLimit(name)
|
|
||||||
if !ok {
|
|
||||||
t.Fatalf("lookupCloudModelLimit(%q) returned false", name)
|
|
||||||
}
|
|
||||||
if l.Output != expected.Output {
|
|
||||||
t.Errorf("output = %d, want %d", l.Output, expected.Output)
|
|
||||||
}
|
|
||||||
// Also verify :cloud suffix lookup
|
|
||||||
cloudName := name + ":cloud"
|
cloudName := name + ":cloud"
|
||||||
l2, ok := lookupCloudModelLimit(cloudName)
|
l, ok := lookupCloudModelLimit(cloudName)
|
||||||
if !ok {
|
if !ok {
|
||||||
t.Fatalf("lookupCloudModelLimit(%q) returned false", cloudName)
|
t.Fatalf("lookupCloudModelLimit(%q) returned false", cloudName)
|
||||||
}
|
}
|
||||||
if l2.Output != expected.Output {
|
if l.Output != expected.Output {
|
||||||
t.Errorf(":cloud output = %d, want %d", l2.Output, expected.Output)
|
t.Errorf("output = %d, want %d", l.Output, expected.Output)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
File diff suppressed because it is too large
Load Diff
842
cmd/launch/launch.go
Normal file
842
cmd/launch/launch.go
Normal file
@@ -0,0 +1,842 @@
|
|||||||
|
package launch
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/api"
|
||||||
|
"github.com/ollama/ollama/cmd/config"
|
||||||
|
"github.com/spf13/cobra"
|
||||||
|
"golang.org/x/term"
|
||||||
|
)
|
||||||
|
|
||||||
|
// LauncherState is the launch-owned snapshot used to render the root launcher menu.
|
||||||
|
type LauncherState struct {
|
||||||
|
LastSelection string
|
||||||
|
RunModel string
|
||||||
|
RunModelUsable bool
|
||||||
|
Integrations map[string]LauncherIntegrationState
|
||||||
|
}
|
||||||
|
|
||||||
|
// LauncherIntegrationState is the launch-owned status for one launcher integration.
|
||||||
|
type LauncherIntegrationState struct {
|
||||||
|
Name string
|
||||||
|
DisplayName string
|
||||||
|
Description string
|
||||||
|
Installed bool
|
||||||
|
AutoInstallable bool
|
||||||
|
Selectable bool
|
||||||
|
Changeable bool
|
||||||
|
CurrentModel string
|
||||||
|
ModelUsable bool
|
||||||
|
InstallHint string
|
||||||
|
Editor bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// RunModelRequest controls how the root launcher resolves the chat model.
|
||||||
|
type RunModelRequest struct {
|
||||||
|
ForcePicker bool
|
||||||
|
Policy *LaunchPolicy
|
||||||
|
}
|
||||||
|
|
||||||
|
// LaunchConfirmMode controls confirmation behavior across launch flows.
|
||||||
|
type LaunchConfirmMode int
|
||||||
|
|
||||||
|
const (
|
||||||
|
// LaunchConfirmPrompt prompts the user for confirmation.
|
||||||
|
LaunchConfirmPrompt LaunchConfirmMode = iota
|
||||||
|
// LaunchConfirmAutoApprove skips prompts and treats confirmation as accepted.
|
||||||
|
LaunchConfirmAutoApprove
|
||||||
|
// LaunchConfirmRequireYes rejects confirmation requests with a --yes hint.
|
||||||
|
LaunchConfirmRequireYes
|
||||||
|
)
|
||||||
|
|
||||||
|
// LaunchMissingModelMode controls local missing-model handling in launch flows.
|
||||||
|
type LaunchMissingModelMode int
|
||||||
|
|
||||||
|
const (
|
||||||
|
// LaunchMissingModelPromptToPull prompts to pull a missing local model.
|
||||||
|
LaunchMissingModelPromptToPull LaunchMissingModelMode = iota
|
||||||
|
// LaunchMissingModelAutoPull pulls a missing local model without prompting.
|
||||||
|
LaunchMissingModelAutoPull
|
||||||
|
// LaunchMissingModelFail fails immediately when a local model is missing.
|
||||||
|
LaunchMissingModelFail
|
||||||
|
)
|
||||||
|
|
||||||
|
// LaunchPolicy controls launch behavior that may vary by caller context.
|
||||||
|
type LaunchPolicy struct {
|
||||||
|
Confirm LaunchConfirmMode
|
||||||
|
MissingModel LaunchMissingModelMode
|
||||||
|
}
|
||||||
|
|
||||||
|
func defaultLaunchPolicy(interactive bool, yes bool) LaunchPolicy {
|
||||||
|
policy := LaunchPolicy{
|
||||||
|
Confirm: LaunchConfirmPrompt,
|
||||||
|
MissingModel: LaunchMissingModelPromptToPull,
|
||||||
|
}
|
||||||
|
switch {
|
||||||
|
case yes:
|
||||||
|
// if yes flag is set, auto approve and auto pull
|
||||||
|
policy.Confirm = LaunchConfirmAutoApprove
|
||||||
|
policy.MissingModel = LaunchMissingModelAutoPull
|
||||||
|
case !interactive:
|
||||||
|
// otherwise make sure to stop when needed
|
||||||
|
policy.Confirm = LaunchConfirmRequireYes
|
||||||
|
policy.MissingModel = LaunchMissingModelFail
|
||||||
|
}
|
||||||
|
return policy
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p LaunchPolicy) confirmPolicy() launchConfirmPolicy {
|
||||||
|
switch p.Confirm {
|
||||||
|
case LaunchConfirmAutoApprove:
|
||||||
|
return launchConfirmPolicy{yes: true}
|
||||||
|
case LaunchConfirmRequireYes:
|
||||||
|
return launchConfirmPolicy{requireYesMessage: true}
|
||||||
|
default:
|
||||||
|
return launchConfirmPolicy{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p LaunchPolicy) missingModelPolicy() missingModelPolicy {
|
||||||
|
switch p.MissingModel {
|
||||||
|
case LaunchMissingModelAutoPull:
|
||||||
|
return missingModelAutoPull
|
||||||
|
case LaunchMissingModelFail:
|
||||||
|
return missingModelFail
|
||||||
|
default:
|
||||||
|
return missingModelPromptPull
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// IntegrationLaunchRequest controls the canonical integration launcher flow.
|
||||||
|
type IntegrationLaunchRequest struct {
|
||||||
|
Name string
|
||||||
|
ModelOverride string
|
||||||
|
ForceConfigure bool
|
||||||
|
ConfigureOnly bool
|
||||||
|
ExtraArgs []string
|
||||||
|
Policy *LaunchPolicy
|
||||||
|
}
|
||||||
|
|
||||||
|
var isInteractiveSession = func() bool {
|
||||||
|
return term.IsTerminal(int(os.Stdin.Fd())) && term.IsTerminal(int(os.Stdout.Fd()))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Runner executes a model with an integration.
|
||||||
|
type Runner interface {
|
||||||
|
Run(model string, args []string) error
|
||||||
|
String() string
|
||||||
|
}
|
||||||
|
|
||||||
|
// Editor can edit config files for integrations that support model configuration.
|
||||||
|
type Editor interface {
|
||||||
|
Paths() []string
|
||||||
|
Edit(models []string) error
|
||||||
|
Models() []string
|
||||||
|
}
|
||||||
|
|
||||||
|
type modelInfo struct {
|
||||||
|
Name string
|
||||||
|
Remote bool
|
||||||
|
ToolCapable bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// ModelInfo re-exports launcher model inventory details for callers.
|
||||||
|
type ModelInfo = modelInfo
|
||||||
|
|
||||||
|
// ModelItem represents a model for selection UIs.
|
||||||
|
type ModelItem struct {
|
||||||
|
Name string
|
||||||
|
Description string
|
||||||
|
Recommended bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// LaunchCmd returns the cobra command for launching integrations.
|
||||||
|
// The runTUI callback is called when the root launcher UI should be shown.
|
||||||
|
func LaunchCmd(checkServerHeartbeat func(cmd *cobra.Command, args []string) error, runTUI func(cmd *cobra.Command)) *cobra.Command {
|
||||||
|
var modelFlag string
|
||||||
|
var configFlag bool
|
||||||
|
var yesFlag bool
|
||||||
|
|
||||||
|
cmd := &cobra.Command{
|
||||||
|
Use: "launch [INTEGRATION] [-- [EXTRA_ARGS...]]",
|
||||||
|
Short: "Launch the Ollama menu or an integration",
|
||||||
|
Long: `Launch the Ollama interactive menu, or directly launch a specific integration.
|
||||||
|
|
||||||
|
Without arguments, this is equivalent to running 'ollama' directly.
|
||||||
|
Flags and extra arguments require an integration name.
|
||||||
|
|
||||||
|
Supported integrations:
|
||||||
|
claude Claude Code
|
||||||
|
cline Cline
|
||||||
|
codex Codex
|
||||||
|
droid Droid
|
||||||
|
opencode OpenCode
|
||||||
|
openclaw OpenClaw (aliases: clawdbot, moltbot)
|
||||||
|
pi Pi
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
ollama launch
|
||||||
|
ollama launch claude
|
||||||
|
ollama launch claude --model <model>
|
||||||
|
ollama launch droid --config (does not auto-launch)
|
||||||
|
ollama launch codex -- -p myprofile (pass extra args to integration)
|
||||||
|
ollama launch codex -- --sandbox workspace-write`,
|
||||||
|
Args: cobra.ArbitraryArgs,
|
||||||
|
PreRunE: checkServerHeartbeat,
|
||||||
|
RunE: func(cmd *cobra.Command, args []string) error {
|
||||||
|
policy := defaultLaunchPolicy(isInteractiveSession(), yesFlag)
|
||||||
|
// reset when done to make sure state doens't leak between launches
|
||||||
|
restoreConfirmPolicy := withLaunchConfirmPolicy(policy.confirmPolicy())
|
||||||
|
defer restoreConfirmPolicy()
|
||||||
|
|
||||||
|
var name string
|
||||||
|
var passArgs []string
|
||||||
|
dashIdx := cmd.ArgsLenAtDash()
|
||||||
|
|
||||||
|
if dashIdx == -1 {
|
||||||
|
if len(args) > 1 {
|
||||||
|
return fmt.Errorf("unexpected arguments: %v\nUse '--' to pass extra arguments to the integration", args[1:])
|
||||||
|
}
|
||||||
|
if len(args) == 1 {
|
||||||
|
name = args[0]
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if dashIdx > 1 {
|
||||||
|
return fmt.Errorf("expected at most 1 integration name before '--', got %d", dashIdx)
|
||||||
|
}
|
||||||
|
if dashIdx == 1 {
|
||||||
|
name = args[0]
|
||||||
|
}
|
||||||
|
passArgs = args[dashIdx:]
|
||||||
|
}
|
||||||
|
|
||||||
|
if name == "" {
|
||||||
|
if cmd.Flags().Changed("model") || cmd.Flags().Changed("config") || cmd.Flags().Changed("yes") || len(passArgs) > 0 {
|
||||||
|
return fmt.Errorf("flags and extra args require an integration name, for example: 'ollama launch claude --model qwen3.5'")
|
||||||
|
}
|
||||||
|
runTUI(cmd)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if modelFlag != "" && isCloudModelName(modelFlag) {
|
||||||
|
if client, err := api.ClientFromEnvironment(); err == nil {
|
||||||
|
if disabled, _ := cloudStatusDisabled(cmd.Context(), client); disabled {
|
||||||
|
fmt.Fprintf(os.Stderr, "Warning: ignoring --model %s because cloud is disabled\n", modelFlag)
|
||||||
|
modelFlag = ""
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
headlessYes := yesFlag && !isInteractiveSession()
|
||||||
|
err := LaunchIntegration(cmd.Context(), IntegrationLaunchRequest{
|
||||||
|
Name: name,
|
||||||
|
ModelOverride: modelFlag,
|
||||||
|
ForceConfigure: configFlag || (modelFlag == "" && !headlessYes),
|
||||||
|
ConfigureOnly: configFlag,
|
||||||
|
ExtraArgs: passArgs,
|
||||||
|
Policy: &policy,
|
||||||
|
})
|
||||||
|
if errors.Is(err, ErrCancelled) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd.Flags().StringVar(&modelFlag, "model", "", "Model to use")
|
||||||
|
cmd.Flags().BoolVar(&configFlag, "config", false, "Configure without launching")
|
||||||
|
cmd.Flags().BoolVarP(&yesFlag, "yes", "y", false, "Automatically answer yes to confirmation prompts")
|
||||||
|
return cmd
|
||||||
|
}
|
||||||
|
|
||||||
|
type launcherClient struct {
|
||||||
|
apiClient *api.Client
|
||||||
|
modelInventory []ModelInfo
|
||||||
|
inventoryLoaded bool
|
||||||
|
policy LaunchPolicy
|
||||||
|
}
|
||||||
|
|
||||||
|
func newLauncherClient(policy LaunchPolicy) (*launcherClient, error) {
|
||||||
|
apiClient, err := api.ClientFromEnvironment()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &launcherClient{
|
||||||
|
apiClient: apiClient,
|
||||||
|
policy: policy,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// BuildLauncherState returns the launch-owned root launcher menu snapshot.
|
||||||
|
func BuildLauncherState(ctx context.Context) (*LauncherState, error) {
|
||||||
|
launchClient, err := newLauncherClient(defaultLaunchPolicy(isInteractiveSession(), false))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return launchClient.buildLauncherState(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResolveRunModel returns the model that should be used for interactive chat.
|
||||||
|
func ResolveRunModel(ctx context.Context, req RunModelRequest) (string, error) {
|
||||||
|
// Called by the launcher TUI "Run a model" action (cmd/runLauncherAction),
|
||||||
|
// which resolves models separately from LaunchIntegration. Callers can pass
|
||||||
|
// Policy directly; otherwise we fall back to ambient --yes/session defaults.
|
||||||
|
policy := defaultLaunchPolicy(isInteractiveSession(), currentLaunchConfirmPolicy.yes)
|
||||||
|
if req.Policy != nil {
|
||||||
|
policy = *req.Policy
|
||||||
|
}
|
||||||
|
|
||||||
|
launchClient, err := newLauncherClient(policy)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return launchClient.resolveRunModel(ctx, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
// LaunchIntegration runs the canonical launcher flow for one integration.
|
||||||
|
func LaunchIntegration(ctx context.Context, req IntegrationLaunchRequest) error {
|
||||||
|
name, runner, err := LookupIntegration(req.Name)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if !req.ConfigureOnly {
|
||||||
|
if err := EnsureIntegrationInstalled(name, runner); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var policy LaunchPolicy
|
||||||
|
// TUI does not set a policy, whereas ollama launch <app> does as it can have flags which change the behavior
|
||||||
|
if req.Policy == nil {
|
||||||
|
policy = defaultLaunchPolicy(isInteractiveSession(), false)
|
||||||
|
} else {
|
||||||
|
policy = *req.Policy
|
||||||
|
}
|
||||||
|
|
||||||
|
launchClient, err := newLauncherClient(policy)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
saved, _ := loadStoredIntegrationConfig(name)
|
||||||
|
// In headless --yes mode we cannot prompt, so require an explicit --model.
|
||||||
|
if policy.Confirm == LaunchConfirmAutoApprove && !isInteractiveSession() && req.ModelOverride == "" {
|
||||||
|
return fmt.Errorf("headless --yes launch for %s requires --model <model>", name)
|
||||||
|
}
|
||||||
|
|
||||||
|
if editor, ok := runner.(Editor); ok {
|
||||||
|
return launchClient.launchEditorIntegration(ctx, name, runner, editor, saved, req)
|
||||||
|
}
|
||||||
|
return launchClient.launchSingleIntegration(ctx, name, runner, saved, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *launcherClient) buildLauncherState(ctx context.Context) (*LauncherState, error) {
|
||||||
|
_ = c.loadModelInventoryOnce(ctx)
|
||||||
|
|
||||||
|
state := &LauncherState{
|
||||||
|
LastSelection: config.LastSelection(),
|
||||||
|
RunModel: config.LastModel(),
|
||||||
|
Integrations: make(map[string]LauncherIntegrationState),
|
||||||
|
}
|
||||||
|
runModelUsable, err := c.savedModelUsable(ctx, state.RunModel)
|
||||||
|
if err != nil {
|
||||||
|
runModelUsable = false
|
||||||
|
}
|
||||||
|
state.RunModelUsable = runModelUsable
|
||||||
|
|
||||||
|
for _, info := range ListIntegrationInfos() {
|
||||||
|
integrationState, err := c.buildLauncherIntegrationState(ctx, info)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
state.Integrations[info.Name] = integrationState
|
||||||
|
}
|
||||||
|
|
||||||
|
return state, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *launcherClient) buildLauncherIntegrationState(ctx context.Context, info IntegrationInfo) (LauncherIntegrationState, error) {
|
||||||
|
integration, err := integrationFor(info.Name)
|
||||||
|
if err != nil {
|
||||||
|
return LauncherIntegrationState{}, err
|
||||||
|
}
|
||||||
|
currentModel, usable, err := c.launcherModelState(ctx, info.Name, integration.editor)
|
||||||
|
if err != nil {
|
||||||
|
return LauncherIntegrationState{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return LauncherIntegrationState{
|
||||||
|
Name: info.Name,
|
||||||
|
DisplayName: info.DisplayName,
|
||||||
|
Description: info.Description,
|
||||||
|
Installed: integration.installed,
|
||||||
|
AutoInstallable: integration.autoInstallable,
|
||||||
|
Selectable: integration.installed || integration.autoInstallable,
|
||||||
|
Changeable: integration.installed || integration.autoInstallable,
|
||||||
|
CurrentModel: currentModel,
|
||||||
|
ModelUsable: usable,
|
||||||
|
InstallHint: integration.installHint,
|
||||||
|
Editor: integration.editor,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *launcherClient) launcherModelState(ctx context.Context, name string, isEditor bool) (string, bool, error) {
|
||||||
|
cfg, loadErr := loadStoredIntegrationConfig(name)
|
||||||
|
hasModels := loadErr == nil && len(cfg.Models) > 0
|
||||||
|
if !hasModels {
|
||||||
|
return "", false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if isEditor {
|
||||||
|
filtered := c.filterDisabledCloudModels(ctx, cfg.Models)
|
||||||
|
if len(filtered) > 0 {
|
||||||
|
return filtered[0], true, nil
|
||||||
|
}
|
||||||
|
return cfg.Models[0], false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
model := cfg.Models[0]
|
||||||
|
usable, usableErr := c.savedModelUsable(ctx, model)
|
||||||
|
return model, usableErr == nil && usable, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *launcherClient) resolveRunModel(ctx context.Context, req RunModelRequest) (string, error) {
|
||||||
|
current := config.LastModel()
|
||||||
|
if !req.ForcePicker && current != "" && c.policy.Confirm == LaunchConfirmAutoApprove && !isInteractiveSession() {
|
||||||
|
if err := c.ensureModelsReady(ctx, []string{current}); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
fmt.Fprintf(os.Stderr, "Headless mode: auto-selected last used model %q\n", current)
|
||||||
|
if err := config.SetLastModel(current); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return current, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if !req.ForcePicker {
|
||||||
|
usable, err := c.savedModelUsable(ctx, current)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
if usable {
|
||||||
|
if err := c.ensureModelsReady(ctx, []string{current}); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
if err := config.SetLastModel(current); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return current, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
model, err := c.selectSingleModelWithSelector(ctx, "Select model to run:", current, DefaultSingleSelector)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
if err := config.SetLastModel(model); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return model, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *launcherClient) launchSingleIntegration(ctx context.Context, name string, runner Runner, saved *config.IntegrationConfig, req IntegrationLaunchRequest) error {
|
||||||
|
current := primaryModelFromConfig(saved)
|
||||||
|
target := req.ModelOverride
|
||||||
|
needsConfigure := req.ForceConfigure
|
||||||
|
|
||||||
|
if target == "" {
|
||||||
|
target = current
|
||||||
|
usable, err := c.savedModelUsable(ctx, target)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if !usable {
|
||||||
|
needsConfigure = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if needsConfigure {
|
||||||
|
selected, err := c.selectSingleModelWithSelector(ctx, fmt.Sprintf("Select model for %s:", runner), target, DefaultSingleSelector)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
target = selected
|
||||||
|
} else if err := c.ensureModelsReady(ctx, []string{target}); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if target == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := config.SaveIntegration(name, []string{target}); err != nil {
|
||||||
|
return fmt.Errorf("failed to save: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return launchAfterConfiguration(name, runner, target, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *launcherClient) launchEditorIntegration(ctx context.Context, name string, runner Runner, editor Editor, saved *config.IntegrationConfig, req IntegrationLaunchRequest) error {
|
||||||
|
models, needsConfigure := c.resolveEditorLaunchModels(ctx, saved, req)
|
||||||
|
|
||||||
|
if needsConfigure {
|
||||||
|
selected, err := c.selectMultiModelsForIntegration(ctx, runner, models)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
models = selected
|
||||||
|
} else if err := c.ensureModelsReady(ctx, models); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(models) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if needsConfigure || req.ModelOverride != "" {
|
||||||
|
if err := prepareEditorIntegration(name, runner, editor, models); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return launchAfterConfiguration(name, runner, models[0], req)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *launcherClient) selectSingleModelWithSelector(ctx context.Context, title, current string, selector SingleSelector) (string, error) {
|
||||||
|
if selector == nil {
|
||||||
|
return "", fmt.Errorf("no selector configured")
|
||||||
|
}
|
||||||
|
|
||||||
|
items, _, err := c.loadSelectableModels(ctx, nil, current, "no models available, run 'ollama pull <model>' first")
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
selected, err := selector(title, items, current)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
if err := c.ensureModelsReady(ctx, []string{selected}); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return selected, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// selectMultiModelsForIntegration shows a multi-select picker for an
// integration's models, then ensures every chosen model is ready.
// preChecked entries are pre-ticked; the first of them is treated as "current".
func (c *launcherClient) selectMultiModelsForIntegration(ctx context.Context, runner Runner, preChecked []string) ([]string, error) {
	if DefaultMultiSelector == nil {
		return nil, fmt.Errorf("no selector configured")
	}

	current := firstModel(preChecked)

	items, orderedChecked, err := c.loadSelectableModels(ctx, preChecked, current, "no models available")
	if err != nil {
		return nil, err
	}
	if len(preChecked) > 0 {
		// Keep list order stable in multi-select even when there are existing checks.
		// checked/default state still comes from orderedChecked.
		stableItems, _, stableErr := c.loadSelectableModels(ctx, nil, current, "no models available")
		if stableErr != nil {
			return nil, stableErr
		}
		items = stableItems
	}

	selected, err := DefaultMultiSelector(fmt.Sprintf("Select models for %s:", runner), items, orderedChecked)
	if err != nil {
		return nil, err
	}
	// Pull any missing models and complete cloud sign-in before returning.
	if err := c.ensureModelsReady(ctx, selected); err != nil {
		return nil, err
	}
	return selected, nil
}
|
||||||
|
|
||||||
|
// loadSelectableModels builds the picker item list from the cached model
// inventory, dropping cloud models when cloud usage is disabled.
// It returns the display items, the (possibly filtered) pre-checked names in
// order, and errors.New(emptyMessage) when nothing is selectable.
func (c *launcherClient) loadSelectableModels(ctx context.Context, preChecked []string, current, emptyMessage string) ([]ModelItem, []string, error) {
	if err := c.loadModelInventoryOnce(ctx); err != nil {
		return nil, nil, err
	}

	cloudDisabled, _ := cloudStatusDisabled(ctx, c.apiClient)
	items, orderedChecked, _, _ := buildModelList(c.modelInventory, preChecked, current)
	if cloudDisabled {
		// Hide cloud entries and strip them from the checked set as well.
		items = filterCloudItems(items)
		orderedChecked = c.filterDisabledCloudModels(ctx, orderedChecked)
	}
	if len(items) == 0 {
		return nil, nil, errors.New(emptyMessage)
	}
	return items, orderedChecked, nil
}
|
||||||
|
|
||||||
|
// ensureModelsReady makes every named model usable: it dedupes the list,
// pulls (or prompts to pull) missing local models per the client's policy,
// and completes sign-in if any selected model is cloud-backed.
func (c *launcherClient) ensureModelsReady(ctx context.Context, models []string) error {
	// Dedupe while preserving order; empty names are dropped.
	var deduped []string
	seen := make(map[string]bool, len(models))
	for _, model := range models {
		if model == "" || seen[model] {
			continue
		}
		seen[model] = true
		deduped = append(deduped, model)
	}
	models = deduped
	if len(models) == 0 {
		return nil
	}

	// Track which models are cloud-backed so ensureAuth can decide whether
	// sign-in is required.
	cloudModels := make(map[string]bool, len(models))
	for _, model := range models {
		isCloudModel := isCloudModelName(model)
		if isCloudModel {
			cloudModels[model] = true
		}
		if err := showOrPullWithPolicy(ctx, c.apiClient, model, c.policy.missingModelPolicy(), isCloudModel); err != nil {
			return err
		}
	}
	return ensureAuth(ctx, c.apiClient, cloudModels, models)
}
|
||||||
|
|
||||||
|
// resolveEditorLaunchModels decides which models an editor launch should use
// and whether the configure step is needed. The bool result is true when
// (re)configuration is required: forced explicitly, or when the resolved
// model list ends up empty after filtering.
func (c *launcherClient) resolveEditorLaunchModels(ctx context.Context, saved *config.IntegrationConfig, req IntegrationLaunchRequest) ([]string, bool) {
	if req.ForceConfigure {
		return editorPreCheckedModels(saved, req.ModelOverride), true
	}

	if req.ModelOverride != "" {
		// The override goes first; previously saved models follow it.
		models := append([]string{req.ModelOverride}, additionalSavedModels(saved, req.ModelOverride)...)
		models = c.filterDisabledCloudModels(ctx, models)
		return models, len(models) == 0
	}

	if saved == nil || len(saved.Models) == 0 {
		// Nothing saved yet: force configuration.
		return nil, true
	}

	models := c.filterDisabledCloudModels(ctx, saved.Models)
	return models, len(models) == 0
}
|
||||||
|
|
||||||
|
// filterDisabledCloudModels drops cloud model names when cloud usage is
// disabled. It always returns a fresh slice so callers may mutate the result.
func (c *launcherClient) filterDisabledCloudModels(ctx context.Context, models []string) []string {
	// if connection cannot be established or there is a 404, cloud models will continue to be displayed
	cloudDisabled, _ := cloudStatusDisabled(ctx, c.apiClient)
	if !cloudDisabled {
		// Copy so the caller never aliases the input slice.
		return append([]string(nil), models...)
	}

	filtered := make([]string, 0, len(models))
	for _, model := range models {
		if !isCloudModelName(model) {
			filtered = append(filtered, model)
		}
	}
	return filtered
}
|
||||||
|
|
||||||
|
// savedModelUsable reports whether a previously saved model can still be
// used. It prefers the cached inventory; if loading the inventory fails it
// falls back to a direct Show-based check against the API.
func (c *launcherClient) savedModelUsable(ctx context.Context, name string) (bool, error) {
	if err := c.loadModelInventoryOnce(ctx); err != nil {
		return c.showBasedModelUsable(ctx, name)
	}
	return c.singleModelUsable(ctx, name), nil
}
|
||||||
|
|
||||||
|
// showBasedModelUsable checks model usability via api.Client.Show, without
// relying on the cached inventory. A 404 means "not usable" (not an error);
// cloud-backed models are usable only while cloud is not disabled.
func (c *launcherClient) showBasedModelUsable(ctx context.Context, name string) (bool, error) {
	if name == "" {
		return false, nil
	}

	info, err := c.apiClient.Show(ctx, &api.ShowRequest{Model: name})
	if err != nil {
		var statusErr api.StatusError
		if errors.As(err, &statusErr) && statusErr.StatusCode == http.StatusNotFound {
			// Model does not exist: unusable, but not a failure.
			return false, nil
		}
		return false, err
	}

	// Either an explicit cloud name or a model the server reports as remote.
	if isCloudModelName(name) || info.RemoteModel != "" {
		cloudDisabled, _ := cloudStatusDisabled(ctx, c.apiClient)

		return !cloudDisabled, nil
	}

	return true, nil
}
|
||||||
|
|
||||||
|
// singleModelUsable reports usability using only local knowledge: cloud
// models require cloud to be enabled; other names must appear in the cached
// inventory as local models.
func (c *launcherClient) singleModelUsable(ctx context.Context, name string) bool {
	if name == "" {
		return false
	}
	if isCloudModelName(name) {
		cloudDisabled, _ := cloudStatusDisabled(ctx, c.apiClient)
		return !cloudDisabled
	}
	return c.hasLocalModel(name)
}
|
||||||
|
|
||||||
|
// hasLocalModel reports whether the cached inventory contains a non-remote
// model matching name exactly or as an untagged prefix (name matches
// "name:<tag>").
func (c *launcherClient) hasLocalModel(name string) bool {
	for _, model := range c.modelInventory {
		if model.Remote {
			continue
		}
		// Exact match, or name given without a tag (e.g. "llama3" vs "llama3:8b").
		if model.Name == name || strings.HasPrefix(model.Name, name+":") {
			return true
		}
	}
	return false
}
|
||||||
|
|
||||||
|
// loadModelInventoryOnce fetches the model list from the API on first use and
// caches it on the client. Subsequent calls are no-ops. When cloud is
// disabled, remote models are removed from the cached inventory.
// NOTE(review): no locking is visible here — presumably launcherClient is
// used from a single goroutine; confirm before sharing across goroutines.
func (c *launcherClient) loadModelInventoryOnce(ctx context.Context) error {
	if c.inventoryLoaded {
		return nil
	}

	resp, err := c.apiClient.List(ctx)
	if err != nil {
		return err
	}

	// Reuse the backing array; rebuild the inventory from the response.
	c.modelInventory = c.modelInventory[:0]
	for _, model := range resp.Models {
		c.modelInventory = append(c.modelInventory, ModelInfo{
			Name:   model.Name,
			Remote: model.RemoteModel != "",
		})
	}

	cloudDisabled, _ := cloudStatusDisabled(ctx, c.apiClient)
	if cloudDisabled {
		c.modelInventory = filterCloudModels(c.modelInventory)
	}
	c.inventoryLoaded = true
	return nil
}
|
||||||
|
|
||||||
|
// runIntegration announces the launch on stderr and hands control to the
// runner with the chosen model and pass-through args.
func runIntegration(runner Runner, modelName string, args []string) error {
	fmt.Fprintf(os.Stderr, "\nLaunching %s with %s...\n", runner, modelName)
	return runner.Run(modelName, args)
}
|
||||||
|
|
||||||
|
// launchAfterConfiguration optionally confirms the launch (when the request
// was configure-only), makes sure the integration binary is installed, and
// then runs it with the primary model.
func launchAfterConfiguration(name string, runner Runner, model string, req IntegrationLaunchRequest) error {
	if req.ConfigureOnly {
		// Configuration was the goal; launching is opt-in.
		launch, err := ConfirmPrompt(fmt.Sprintf("Launch %s now?", runner))
		if err != nil {
			return err
		}
		if !launch {
			return nil
		}
	}
	if err := EnsureIntegrationInstalled(name, runner); err != nil {
		return err
	}
	return runIntegration(runner, model, req.ExtraArgs)
}
|
||||||
|
|
||||||
|
// loadStoredIntegrationConfig loads the saved config for an integration,
// falling back to legacy alias names. When a legacy config is found it is
// migrated to the canonical name; the migrated copy is preferred, with the
// legacy config as a fallback if re-loading fails.
func loadStoredIntegrationConfig(name string) (*config.IntegrationConfig, error) {
	cfg, err := config.LoadIntegration(name)
	if err == nil {
		return cfg, nil
	}
	if !errors.Is(err, os.ErrNotExist) {
		return nil, err
	}

	spec, specErr := LookupIntegrationSpec(name)
	if specErr != nil {
		// Intentionally return the original not-found error, not specErr,
		// so callers can keep matching on os.ErrNotExist.
		return nil, err
	}

	for _, alias := range spec.Aliases {
		legacy, legacyErr := config.LoadIntegration(alias)
		if legacyErr == nil {
			// Best-effort migration to the canonical name.
			migrateLegacyIntegrationConfig(spec.Name, legacy)
			if migrated, migratedErr := config.LoadIntegration(spec.Name); migratedErr == nil {
				return migrated, nil
			}
			// Migration write/read failed; the legacy config is still valid.
			return legacy, nil
		}
		if legacyErr != nil && !errors.Is(legacyErr, os.ErrNotExist) {
			return nil, legacyErr
		}
	}

	// No canonical or legacy config found: propagate the original error.
	return nil, err
}
|
||||||
|
|
||||||
|
// migrateLegacyIntegrationConfig copies a legacy config (models, aliases,
// onboarded flag) to the canonical integration name. All writes are
// best-effort: failures are ignored and the caller falls back to the legacy
// config.
func migrateLegacyIntegrationConfig(canonical string, legacy *config.IntegrationConfig) {
	if legacy == nil {
		return
	}

	// Copy the models slice so the saved config does not alias legacy.Models.
	_ = config.SaveIntegration(canonical, append([]string(nil), legacy.Models...))
	if len(legacy.Aliases) > 0 {
		_ = config.SaveAliases(canonical, cloneAliases(legacy.Aliases))
	}
	if legacy.Onboarded {
		_ = config.MarkIntegrationOnboarded(canonical)
	}
}
|
||||||
|
|
||||||
|
func primaryModelFromConfig(cfg *config.IntegrationConfig) string {
|
||||||
|
if cfg == nil || len(cfg.Models) == 0 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return cfg.Models[0]
|
||||||
|
}
|
||||||
|
|
||||||
|
// cloneAliases returns an independent copy of the alias map. The result is
// always non-nil, even for a nil or empty input.
func cloneAliases(aliases map[string]string) map[string]string {
	cloned := make(map[string]string, len(aliases))
	for name, target := range aliases {
		cloned[name] = target
	}
	return cloned
}
|
||||||
|
|
||||||
|
// singleModelPrechecked wraps the current model in a one-element slice for
// pre-check lists; an empty name yields nil.
func singleModelPrechecked(current string) []string {
	var checked []string
	if current != "" {
		checked = append(checked, current)
	}
	return checked
}
|
||||||
|
|
||||||
|
// firstModel returns the first entry of models, or "" for an empty slice.
func firstModel(models []string) string {
	for _, m := range models {
		return m
	}
	return ""
}
|
||||||
|
|
||||||
|
// editorPreCheckedModels builds the pre-checked list for the editor picker.
// With no override it returns a copy of the saved models; with an override
// the override comes first, followed by the remaining saved models.
func editorPreCheckedModels(saved *config.IntegrationConfig, override string) []string {
	if override == "" {
		if saved == nil {
			return nil
		}
		// Copy so the caller never aliases saved.Models.
		return append([]string(nil), saved.Models...)
	}
	return append([]string{override}, additionalSavedModels(saved, override)...)
}
|
||||||
|
|
||||||
|
// additionalSavedModels returns the saved models minus every occurrence of
// exclude, preserving order. A nil config yields nil.
func additionalSavedModels(saved *config.IntegrationConfig, exclude string) []string {
	if saved == nil {
		return nil
	}

	var models []string
	for _, model := range saved.Models {
		if model != exclude {
			models = append(models, model)
		}
	}
	return models
}
|
||||||
1498
cmd/launch/launch_test.go
Normal file
1498
cmd/launch/launch_test.go
Normal file
File diff suppressed because it is too large
Load Diff
490
cmd/launch/models.go
Normal file
490
cmd/launch/models.go
Normal file
@@ -0,0 +1,490 @@
|
|||||||
|
package launch
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
"os/exec"
|
||||||
|
"runtime"
|
||||||
|
"slices"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/api"
|
||||||
|
"github.com/ollama/ollama/cmd/config"
|
||||||
|
"github.com/ollama/ollama/cmd/internal/fileutil"
|
||||||
|
internalcloud "github.com/ollama/ollama/internal/cloud"
|
||||||
|
"github.com/ollama/ollama/internal/modelref"
|
||||||
|
"github.com/ollama/ollama/progress"
|
||||||
|
)
|
||||||
|
|
||||||
|
// recommendedModels is the curated list surfaced in selection UIs, in
// display-rank order (index order is used as the recommendation rank).
var recommendedModels = []ModelItem{
	{Name: "kimi-k2.5:cloud", Description: "Multimodal reasoning with subagents", Recommended: true},
	{Name: "qwen3.5:cloud", Description: "Reasoning, coding, and agentic tool use with vision", Recommended: true},
	{Name: "glm-5:cloud", Description: "Reasoning and code generation", Recommended: true},
	{Name: "minimax-m2.5:cloud", Description: "Fast, efficient coding and real-world productivity", Recommended: true},
	{Name: "glm-4.7-flash", Description: "Reasoning and code generation locally", Recommended: true},
	{Name: "qwen3.5", Description: "Reasoning, coding, and visual understanding locally", Recommended: true},
}
|
||||||
|
|
||||||
|
// recommendedVRAM maps local recommended models to an approximate VRAM
// requirement, appended to the description of not-yet-downloaded entries.
var recommendedVRAM = map[string]string{
	"glm-4.7-flash": "~25GB",
	"qwen3.5":       "~11GB",
}
|
||||||
|
|
||||||
|
// cloudModelLimit holds context and output token limits for a cloud model.
type cloudModelLimit struct {
	Context int // maximum context window, in tokens
	Output  int // maximum output length, in tokens
}
|
||||||
|
|
||||||
|
// cloudModelLimits maps cloud model base names to their token limits.
// TODO(parthsareen): grab context/output limits from model info instead of hardcoding
var cloudModelLimits = map[string]cloudModelLimit{
	"minimax-m2.5":        {Context: 204_800, Output: 128_000},
	"cogito-2.1:671b":     {Context: 163_840, Output: 65_536},
	"deepseek-v3.1:671b":  {Context: 163_840, Output: 163_840},
	"deepseek-v3.2":       {Context: 163_840, Output: 65_536},
	"glm-4.6":             {Context: 202_752, Output: 131_072},
	"glm-4.7":             {Context: 202_752, Output: 131_072},
	"glm-5":               {Context: 202_752, Output: 131_072},
	"gpt-oss:120b":        {Context: 131_072, Output: 131_072},
	"gpt-oss:20b":         {Context: 131_072, Output: 131_072},
	"kimi-k2:1t":          {Context: 262_144, Output: 262_144},
	"kimi-k2.5":           {Context: 262_144, Output: 262_144},
	"kimi-k2-thinking":    {Context: 262_144, Output: 262_144},
	"nemotron-3-nano:30b": {Context: 1_048_576, Output: 131_072},
	"qwen3-coder:480b":    {Context: 262_144, Output: 65_536},
	"qwen3-coder-next":    {Context: 262_144, Output: 32_768},
	"qwen3-next:80b":      {Context: 262_144, Output: 32_768},
	"qwen3.5":             {Context: 262_144, Output: 32_768},
}
|
||||||
|
|
||||||
|
// lookupCloudModelLimit returns the token limits for a cloud model.
// It normalizes explicit cloud source suffixes before checking the shared limit map.
// Names without a cloud source tag always report false.
func lookupCloudModelLimit(name string) (cloudModelLimit, bool) {
	base, stripped := modelref.StripCloudSourceTag(name)
	if stripped {
		if l, ok := cloudModelLimits[base]; ok {
			return l, true
		}
	}
	return cloudModelLimit{}, false
}
|
||||||
|
|
||||||
|
// missingModelPolicy controls how model-not-found errors should be handled.
type missingModelPolicy int

const (
	// missingModelPromptPull prompts the user to download missing local models.
	missingModelPromptPull missingModelPolicy = iota
	// missingModelAutoPull downloads missing local models without prompting.
	missingModelAutoPull
	// missingModelFail returns an error for missing local models without prompting.
	missingModelFail
)
|
||||||
|
|
||||||
|
// OpenBrowser opens the URL in the user's browser.
// Best-effort: launch errors are ignored and unknown platforms are a no-op.
func OpenBrowser(url string) {
	switch runtime.GOOS {
	case "darwin":
		_ = exec.Command("open", url).Start()
	case "linux":
		_ = exec.Command("xdg-open", url).Start()
	case "windows":
		_ = exec.Command("rundll32", "url.dll,FileProtocolHandler", url).Start()
	}
}
|
||||||
|
|
||||||
|
// ensureAuth ensures the user is signed in before cloud-backed models run.
// It is a no-op when no selected model is cloud-backed. Otherwise it checks
// Whoami; on an AuthorizationError it either delegates to DefaultSignIn or
// falls back to an interactive flow: confirm, open the sign-in URL in a
// browser, and poll Whoami with a terminal spinner until sign-in completes
// or ctx is cancelled.
func ensureAuth(ctx context.Context, client *api.Client, cloudModels map[string]bool, selected []string) error {
	var selectedCloudModels []string
	for _, m := range selected {
		if cloudModels[m] {
			selectedCloudModels = append(selectedCloudModels, m)
		}
	}
	if len(selectedCloudModels) == 0 {
		// Only local models selected: no auth needed.
		return nil
	}
	if disabled, known := cloudStatusDisabled(ctx, client); known && disabled {
		return errors.New(internalcloud.DisabledError("remote inference is unavailable"))
	}

	user, err := client.Whoami(ctx)
	if err == nil && user != nil && user.Name != "" {
		// Already signed in.
		return nil
	}

	var aErr api.AuthorizationError
	if !errors.As(err, &aErr) || aErr.SigninURL == "" {
		// Not an auth problem we can resolve interactively.
		return err
	}

	modelList := strings.Join(selectedCloudModels, ", ")

	// Prefer an injected sign-in flow (e.g. a GUI) when one is configured.
	if DefaultSignIn != nil {
		_, err := DefaultSignIn(modelList, aErr.SigninURL)
		if errors.Is(err, ErrCancelled) {
			return ErrCancelled
		}
		if err != nil {
			return fmt.Errorf("%s requires sign in", modelList)
		}
		return nil
	}

	yes, err := ConfirmPrompt(fmt.Sprintf("sign in to use %s?", modelList))
	if errors.Is(err, ErrCancelled) {
		return ErrCancelled
	}
	if err != nil {
		return err
	}
	if !yes {
		return ErrCancelled
	}

	fmt.Fprintf(os.Stderr, "\nTo sign in, navigate to:\n %s\n\n", aErr.SigninURL)
	OpenBrowser(aErr.SigninURL)

	// Animated spinner while polling for sign-in completion.
	spinnerFrames := []string{"|", "/", "-", "\\"}
	frame := 0
	fmt.Fprintf(os.Stderr, "\033[90mwaiting for sign in to complete... %s\033[0m", spinnerFrames[0])

	ticker := time.NewTicker(200 * time.Millisecond)
	defer ticker.Stop()

	for {
		select {
		case <-ctx.Done():
			// Clear the spinner line before bailing out.
			fmt.Fprintf(os.Stderr, "\r\033[K")
			return ctx.Err()
		case <-ticker.C:
			frame++
			fmt.Fprintf(os.Stderr, "\r\033[90mwaiting for sign in to complete... %s\033[0m", spinnerFrames[frame%len(spinnerFrames)])

			// Poll Whoami every 10 frames (~2s) to avoid hammering the API.
			if frame%10 == 0 {
				u, err := client.Whoami(ctx)
				if err == nil && u != nil && u.Name != "" {
					fmt.Fprintf(os.Stderr, "\r\033[K\033[A\r\033[K\033[1msigned in:\033[0m %s\n", u.Name)
					return nil
				}
			}
		}
	}
}
|
||||||
|
|
||||||
|
// showOrPullWithPolicy checks if a model exists and applies the provided missing-model policy.
// Existing models return nil immediately. Non-404 errors propagate. Missing
// cloud models are never pulled: they error out (with a clearer message when
// cloud is disabled). Missing local models are pulled, failed, or confirmed
// per policy.
func showOrPullWithPolicy(ctx context.Context, client *api.Client, model string, policy missingModelPolicy, isCloudModel bool) error {
	if _, err := client.Show(ctx, &api.ShowRequest{Model: model}); err == nil {
		return nil
	} else {
		var statusErr api.StatusError
		// Only a 404 means "missing"; anything else is a real failure.
		if !errors.As(err, &statusErr) || statusErr.StatusCode != http.StatusNotFound {
			return err
		}
	}

	if isCloudModel {
		if disabled, known := cloudStatusDisabled(ctx, client); known && disabled {
			return errors.New(internalcloud.DisabledError("remote inference is unavailable"))
		}
		return fmt.Errorf("model %q not found", model)
	}

	switch policy {
	case missingModelAutoPull:
		return pullMissingModel(ctx, client, model)
	case missingModelFail:
		return fmt.Errorf("model %q not found; run 'ollama pull %s' first, or use --yes to auto-pull", model, model)
	default:
		// missingModelPromptPull (and any unknown value): ask first.
		return confirmAndPull(ctx, client, model)
	}
}
|
||||||
|
|
||||||
|
// confirmAndPull asks the user before downloading a missing model; a "no"
// answer returns errCancelled.
// NOTE(review): this returns errCancelled while ensureAuth uses ErrCancelled —
// presumably both exist in this package; confirm they are related/consistent.
func confirmAndPull(ctx context.Context, client *api.Client, model string) error {
	if ok, err := ConfirmPrompt(fmt.Sprintf("Download %s?", model)); err != nil {
		return err
	} else if !ok {
		return errCancelled
	}
	fmt.Fprintf(os.Stderr, "\n")
	return pullMissingModel(ctx, client, model)
}
|
||||||
|
|
||||||
|
// pullMissingModel pulls a model and wraps any failure with the model name
// for clearer error messages.
func pullMissingModel(ctx context.Context, client *api.Client, model string) error {
	if err := pullModel(ctx, client, model, false); err != nil {
		return fmt.Errorf("failed to pull %s: %w", model, err)
	}
	return nil
}
|
||||||
|
|
||||||
|
// prepareEditorIntegration persists models and applies editor-managed config files.
// The user must confirm the file edits first; declining returns errCancelled.
func prepareEditorIntegration(name string, runner Runner, editor Editor, models []string) error {
	if ok, err := confirmEditorEdit(runner, editor); err != nil {
		return err
	} else if !ok {
		return errCancelled
	}
	// Write the editor's config files before persisting our own record.
	if err := editor.Edit(models); err != nil {
		return fmt.Errorf("setup failed: %w", err)
	}
	if err := config.SaveIntegration(name, models); err != nil {
		return fmt.Errorf("failed to save: %w", err)
	}
	return nil
}
|
||||||
|
|
||||||
|
// confirmEditorEdit lists the config files the editor will modify and asks
// for confirmation. Editors that touch no files are auto-approved.
func confirmEditorEdit(runner Runner, editor Editor) (bool, error) {
	paths := editor.Paths()
	if len(paths) == 0 {
		return true, nil
	}

	fmt.Fprintf(os.Stderr, "This will modify your %s configuration:\n", runner)
	for _, path := range paths {
		fmt.Fprintf(os.Stderr, " %s\n", path)
	}
	fmt.Fprintf(os.Stderr, "Backups will be saved to %s/\n\n", fileutil.BackupDir())

	return ConfirmPrompt("Proceed?")
}
|
||||||
|
|
||||||
|
// buildModelList merges existing models with recommendations for selection UIs.
// It returns the display items, the pre-checked names (with current promoted
// to the front when checked), and lookup sets for existing and cloud models.
// Sort order (when any model is installed): checked first, then recommended
// (cloud vs local preference depends on what's installed), then installed
// before not-yet-downloaded, then case-insensitive name.
func buildModelList(existing []modelInfo, preChecked []string, current string) (items []ModelItem, orderedChecked []string, existingModels, cloudModels map[string]bool) {
	existingModels = make(map[string]bool)
	cloudModels = make(map[string]bool)
	recommended := make(map[string]bool)
	var hasLocalModel, hasCloudModel bool

	recDesc := make(map[string]string)
	for _, rec := range recommendedModels {
		recommended[rec.Name] = true
		recDesc[rec.Name] = rec.Description
	}

	// Installed models first; register both the raw and ":latest"-trimmed names
	// so recommendation matching works for either form.
	for _, m := range existing {
		existingModels[m.Name] = true
		if m.Remote {
			cloudModels[m.Name] = true
			hasCloudModel = true
		} else {
			hasLocalModel = true
		}
		displayName := strings.TrimSuffix(m.Name, ":latest")
		existingModels[displayName] = true
		item := ModelItem{Name: displayName, Recommended: recommended[displayName], Description: recDesc[displayName]}
		items = append(items, item)
	}

	// Append recommendations that are not already installed.
	for _, rec := range recommendedModels {
		if existingModels[rec.Name] || existingModels[rec.Name+":latest"] {
			continue
		}
		items = append(items, rec)
		if isCloudModelName(rec.Name) {
			cloudModels[rec.Name] = true
		}
	}

	checked := make(map[string]bool, len(preChecked))
	for _, n := range preChecked {
		checked[n] = true
	}

	// Resolve current against the item list: exact match first, then as an
	// untagged prefix ("name" matching "name:<tag>").
	if current != "" {
		matchedCurrent := false
		for _, item := range items {
			if item.Name == current {
				current = item.Name
				matchedCurrent = true
				break
			}
		}
		if !matchedCurrent {
			for _, item := range items {
				if strings.HasPrefix(item.Name, current+":") {
					current = item.Name
					break
				}
			}
		}
	}

	// Promote current to the front of the checked list when it is checked.
	if checked[current] {
		preChecked = append([]string{current}, slices.DeleteFunc(preChecked, func(m string) bool { return m == current })...)
	}

	// Annotate recommendations that are not installed locally.
	notInstalled := make(map[string]bool)
	for i := range items {
		if !existingModels[items[i].Name] && !cloudModels[items[i].Name] {
			notInstalled[items[i].Name] = true
			var parts []string
			if items[i].Description != "" {
				parts = append(parts, items[i].Description)
			}
			if vram := recommendedVRAM[items[i].Name]; vram != "" {
				parts = append(parts, vram)
			}
			parts = append(parts, "(not downloaded)")
			items[i].Description = strings.Join(parts, ", ")
		}
	}

	// Rank recommendations by their position in recommendedModels (1-based;
	// 0 means "not recommended").
	recRank := make(map[string]int)
	for i, rec := range recommendedModels {
		recRank[rec.Name] = i + 1
	}

	onlyLocal := hasLocalModel && !hasCloudModel

	if hasLocalModel || hasCloudModel {
		slices.SortStableFunc(items, func(a, b ModelItem) int {
			ac, bc := checked[a.Name], checked[b.Name]
			aNew, bNew := notInstalled[a.Name], notInstalled[b.Name]
			aRec, bRec := recRank[a.Name] > 0, recRank[b.Name] > 0
			aCloud, bCloud := cloudModels[a.Name], cloudModels[b.Name]

			// 1. Checked entries first.
			if ac != bc {
				if ac {
					return -1
				}
				return 1
			}
			// 2. Recommended entries next.
			if aRec != bRec {
				if aRec {
					return -1
				}
				return 1
			}
			// 3. Among recommendations: prefer local when only local models
			// are installed, cloud otherwise; tie-break by rank.
			if aRec && bRec {
				if aCloud != bCloud {
					if onlyLocal {
						if aCloud {
							return 1
						}
						return -1
					}
					if aCloud {
						return -1
					}
					return 1
				}
				return recRank[a.Name] - recRank[b.Name]
			}
			// 4. Installed before not-downloaded.
			if aNew != bNew {
				if aNew {
					return 1
				}
				return -1
			}
			// 5. Case-insensitive alphabetical.
			return strings.Compare(strings.ToLower(a.Name), strings.ToLower(b.Name))
		})
	}

	return items, preChecked, existingModels, cloudModels
}
|
||||||
|
|
||||||
|
// isCloudModelName reports whether the model name has an explicit cloud source.
func isCloudModelName(name string) bool {
	return modelref.HasExplicitCloudSource(name)
}
|
||||||
|
|
||||||
|
// filterCloudModels drops remote-only models from the given inventory.
// It filters in place, reusing the input slice's backing array.
func filterCloudModels(existing []modelInfo) []modelInfo {
	filtered := existing[:0]
	for _, m := range existing {
		if !m.Remote {
			filtered = append(filtered, m)
		}
	}
	return filtered
}
|
||||||
|
|
||||||
|
// filterCloudItems removes cloud models from selection items.
// It filters in place, reusing the input slice's backing array.
func filterCloudItems(items []ModelItem) []ModelItem {
	filtered := items[:0]
	for _, item := range items {
		if !isCloudModelName(item.Name) {
			filtered = append(filtered, item)
		}
	}
	return filtered
}
|
||||||
|
|
||||||
|
// isCloudModel asks the server whether a model is remote-backed.
// Any error (including model-not-found) is treated as "not cloud".
func isCloudModel(ctx context.Context, client *api.Client, name string) bool {
	if client == nil {
		return false
	}
	resp, err := client.Show(ctx, &api.ShowRequest{Model: name})
	if err != nil {
		return false
	}
	return resp.RemoteModel != ""
}
|
||||||
|
|
||||||
|
// cloudStatusDisabled returns whether cloud usage is currently disabled.
|
||||||
|
func cloudStatusDisabled(ctx context.Context, client *api.Client) (disabled bool, known bool) {
|
||||||
|
status, err := client.CloudStatusExperimental(ctx)
|
||||||
|
if err != nil {
|
||||||
|
var statusErr api.StatusError
|
||||||
|
if errors.As(err, &statusErr) && statusErr.StatusCode == http.StatusNotFound {
|
||||||
|
return false, false
|
||||||
|
}
|
||||||
|
return false, false
|
||||||
|
}
|
||||||
|
return status.Cloud.Disabled, true
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO(parthsareen): this duplicates the pull progress UI in cmd.PullHandler.
// Move the shared pull rendering to a small utility once the package boundary settles.
// pullModel downloads a model, rendering per-layer progress bars (keyed by
// digest) and a spinner for status-only updates on stderr.
func pullModel(ctx context.Context, client *api.Client, model string, insecure bool) error {
	p := progress.NewProgress(os.Stderr)
	defer p.Stop()

	bars := make(map[string]*progress.Bar) // one bar per layer digest
	var status string
	var spinner *progress.Spinner

	fn := func(resp api.ProgressResponse) error {
		if resp.Digest != "" {
			// Skip layers that have not started downloading yet.
			if resp.Completed == 0 {
				return nil
			}

			// A layer update supersedes any status spinner.
			if spinner != nil {
				spinner.Stop()
			}

			bar, ok := bars[resp.Digest]
			if !ok {
				// Label the bar with a short digest (first 12 hex chars).
				name, isDigest := strings.CutPrefix(resp.Digest, "sha256:")
				name = strings.TrimSpace(name)
				if isDigest {
					name = name[:min(12, len(name))]
				}
				bar = progress.NewBar(fmt.Sprintf("pulling %s:", name), resp.Total, resp.Completed)
				bars[resp.Digest] = bar
				p.Add(resp.Digest, bar)
			}

			bar.Set(resp.Completed)
		} else if status != resp.Status {
			// Status-only update: restart the spinner with the new text.
			if spinner != nil {
				spinner.Stop()
			}

			status = resp.Status
			spinner = progress.NewSpinner(status)
			p.Add(status, spinner)
		}

		return nil
	}

	request := api.PullRequest{Name: model, Insecure: insecure}
	return client.Pull(ctx, &request, fn)
}
|
||||||
890
cmd/launch/openclaw.go
Normal file
890
cmd/launch/openclaw.go
Normal file
@@ -0,0 +1,890 @@
|
|||||||
|
package launch
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"net/url"
|
||||||
|
"os"
|
||||||
|
"os/exec"
|
||||||
|
"path/filepath"
|
||||||
|
"runtime"
|
||||||
|
"slices"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"golang.org/x/mod/semver"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/api"
|
||||||
|
"github.com/ollama/ollama/cmd/internal/fileutil"
|
||||||
|
"github.com/ollama/ollama/envconfig"
|
||||||
|
"github.com/ollama/ollama/types/model"
|
||||||
|
)
|
||||||
|
|
||||||
|
// defaultGatewayPort is the default local port for the OpenClaw gateway.
const defaultGatewayPort = 18789

// Bound model capability probing so launch/config cannot hang on slow/unreachable API calls.
var openclawModelShowTimeout = 5 * time.Second

// openclawFreshInstall is set to true when ensureOpenclawInstalled performs an install
var openclawFreshInstall bool

// Openclaw is the Runner implementation for the OpenClaw integration.
type Openclaw struct{}

func (c *Openclaw) String() string { return "OpenClaw" }
|
||||||
|
// Run launches OpenClaw against the given Ollama model. On first launch it
// shows a security notice, runs the OpenClaw onboarding wizard pointed at the
// local Ollama host, and patches device scopes. It then ensures the gateway
// is running (restarting or spawning it as needed) and opens the TUI.
// When extra args are supplied, they are executed verbatim instead of the
// gateway+TUI convenience flow.
func (c *Openclaw) Run(model string, args []string) error {
	bin, err := ensureOpenclawInstalled()
	if err != nil {
		return err
	}

	// "First launch" means onboarding has never completed (no wizard marker).
	firstLaunch := !c.onboarded()

	if firstLaunch {
		fmt.Fprintf(os.Stderr, "\n%sSecurity%s\n\n", ansiBold, ansiReset)
		fmt.Fprintf(os.Stderr, " OpenClaw can read files and run actions when tools are enabled.\n")
		fmt.Fprintf(os.Stderr, " A bad prompt can trick it into doing unsafe things.\n\n")
		fmt.Fprintf(os.Stderr, "%s Learn more: https://docs.openclaw.ai/gateway/security%s\n\n", ansiGray, ansiReset)

		// Require explicit acknowledgement before touching any config.
		ok, err := ConfirmPrompt("I understand the risks. Continue?")
		if err != nil {
			return err
		}
		if !ok {
			return nil
		}

		// Ensure the latest version is installed before onboarding so we get
		// the newest wizard flags (e.g. --auth-choice ollama).
		if !openclawFreshInstall {
			update := exec.Command(bin, "update")
			update.Stdout = os.Stdout
			update.Stderr = os.Stderr
			_ = update.Run() // best-effort; continue even if update fails
		}

		fmt.Fprintf(os.Stderr, "\n%sSetting up OpenClaw with Ollama...%s\n", ansiGreen, ansiReset)
		fmt.Fprintf(os.Stderr, "%s Model: %s%s\n\n", ansiGray, model, ansiReset)

		// Non-interactive onboarding pre-wired to the local Ollama host.
		onboardArgs := []string{
			"onboard",
			"--non-interactive",
			"--accept-risk",
			"--auth-choice", "ollama",
			"--custom-base-url", envconfig.Host().String(),
			"--custom-model-id", model,
			"--skip-channels",
			"--skip-skills",
		}
		if canInstallDaemon() {
			onboardArgs = append(onboardArgs, "--install-daemon")
		}
		cmd := exec.Command(bin, onboardArgs...)
		cmd.Stdin = os.Stdin
		cmd.Stdout = os.Stdout
		cmd.Stderr = os.Stderr
		if err := cmd.Run(); err != nil {
			return windowsHint(fmt.Errorf("openclaw onboarding failed: %w\n\nTry running: openclaw onboard", err))
		}

		// Best-effort: grant the local CLI device operator.admin scopes.
		patchDeviceScopes()
	}

	// Install/register the web search plugin (best-effort).
	if ensureWebSearchPlugin() {
		registerWebSearchPlugin()
	}

	fmt.Fprintf(os.Stderr, "\n%sStarting your assistant — this may take a moment...%s\n\n", ansiGray, ansiReset)

	// When extra args are passed through, run exactly what the user asked for
	// after setup and skip the built-in gateway+TUI convenience flow.
	if len(args) > 0 {
		cmd := exec.Command(bin, args...)
		cmd.Env = openclawEnv()
		cmd.Stdin = os.Stdin
		cmd.Stdout = os.Stdout
		cmd.Stderr = os.Stderr
		if err := cmd.Run(); err != nil {
			return windowsHint(err)
		}
		return nil
	}

	token, port := c.gatewayInfo()
	addr := fmt.Sprintf("localhost:%d", port)

	// If the gateway is already running (e.g. via the daemon), restart it
	// so it picks up any config changes (model, provider, etc.).
	if portOpen(addr) {
		restart := exec.Command(bin, "daemon", "restart")
		restart.Env = openclawEnv()
		if err := restart.Run(); err != nil {
			fmt.Fprintf(os.Stderr, "%s Warning: daemon restart failed: %v%s\n", ansiYellow, err, ansiReset)
		}
		if !waitForPort(addr, 10*time.Second) {
			fmt.Fprintf(os.Stderr, "%s Warning: gateway did not come back after restart%s\n", ansiYellow, ansiReset)
		}
	}

	// If the gateway isn't running, start it as a background child process.
	if !portOpen(addr) {
		gw := exec.Command(bin, "gateway", "run", "--force")
		gw.Env = openclawEnv()
		if err := gw.Start(); err != nil {
			return windowsHint(fmt.Errorf("failed to start gateway: %w", err))
		}
		// The spawned gateway is tied to this process; kill and reap it on exit.
		defer func() {
			if gw.Process != nil {
				_ = gw.Process.Kill()
				_ = gw.Wait()
			}
		}()
	}

	fmt.Fprintf(os.Stderr, "%sStarting gateway...%s\n", ansiGray, ansiReset)
	if !waitForPort(addr, 30*time.Second) {
		return windowsHint(fmt.Errorf("gateway did not start on %s", addr))
	}

	printOpenclawReady(bin, token, port, firstLaunch)

	tuiArgs := []string{"tui"}
	if firstLaunch {
		tuiArgs = append(tuiArgs, "--message", "Wake up, my friend!")
	}
	tui := exec.Command(bin, tuiArgs...)
	tui.Env = openclawEnv()
	tui.Stdin = os.Stdin
	tui.Stdout = os.Stdout
	tui.Stderr = os.Stderr
	if err := tui.Run(); err != nil {
		return windowsHint(err)
	}

	return nil
}
|
||||||
|
|
||||||
|
// gatewayInfo reads the gateway auth token and port from the OpenClaw config.
|
||||||
|
func (c *Openclaw) gatewayInfo() (token string, port int) {
|
||||||
|
port = defaultGatewayPort
|
||||||
|
home, err := os.UserHomeDir()
|
||||||
|
if err != nil {
|
||||||
|
return "", port
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, path := range []string{
|
||||||
|
filepath.Join(home, ".openclaw", "openclaw.json"),
|
||||||
|
filepath.Join(home, ".clawdbot", "clawdbot.json"),
|
||||||
|
} {
|
||||||
|
data, err := os.ReadFile(path)
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
var config map[string]any
|
||||||
|
if json.Unmarshal(data, &config) != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
gw, _ := config["gateway"].(map[string]any)
|
||||||
|
if p, ok := gw["port"].(float64); ok && p > 0 {
|
||||||
|
port = int(p)
|
||||||
|
}
|
||||||
|
auth, _ := gw["auth"].(map[string]any)
|
||||||
|
if t, _ := auth["token"].(string); t != "" {
|
||||||
|
token = t
|
||||||
|
}
|
||||||
|
return token, port
|
||||||
|
}
|
||||||
|
return "", port
|
||||||
|
}
|
||||||
|
|
||||||
|
func printOpenclawReady(bin, token string, port int, firstLaunch bool) {
|
||||||
|
u := fmt.Sprintf("http://localhost:%d", port)
|
||||||
|
if token != "" {
|
||||||
|
u += "/#token=" + url.QueryEscape(token)
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Fprintf(os.Stderr, "\n%s✓ OpenClaw is running%s\n\n", ansiGreen, ansiReset)
|
||||||
|
fmt.Fprintf(os.Stderr, " Open the Web UI:\n")
|
||||||
|
fmt.Fprintf(os.Stderr, " %s\n\n", hyperlink(u, u))
|
||||||
|
|
||||||
|
if firstLaunch {
|
||||||
|
fmt.Fprintf(os.Stderr, "%s Quick start:%s\n", ansiBold, ansiReset)
|
||||||
|
fmt.Fprintf(os.Stderr, "%s /help see all commands%s\n", ansiGray, ansiReset)
|
||||||
|
fmt.Fprintf(os.Stderr, "%s %s configure --section channels connect WhatsApp, Telegram, etc.%s\n", ansiGray, bin, ansiReset)
|
||||||
|
fmt.Fprintf(os.Stderr, "%s %s skills browse and install skills%s\n\n", ansiGray, bin, ansiReset)
|
||||||
|
fmt.Fprintf(os.Stderr, "%s The OpenClaw gateway is running in the background.%s\n", ansiYellow, ansiReset)
|
||||||
|
fmt.Fprintf(os.Stderr, "%s Stop it with: %s gateway stop%s\n\n", ansiYellow, bin, ansiReset)
|
||||||
|
} else {
|
||||||
|
fmt.Fprintf(os.Stderr, "%sTip: connect WhatsApp, Telegram, and more with: %s configure --section channels%s\n", ansiGray, bin, ansiReset)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// openclawEnv returns the current environment with provider API keys cleared
|
||||||
|
// so openclaw only uses the Ollama gateway, not keys from the user's shell.
|
||||||
|
func openclawEnv() []string {
|
||||||
|
clear := map[string]bool{
|
||||||
|
"ANTHROPIC_API_KEY": true,
|
||||||
|
"ANTHROPIC_OAUTH_TOKEN": true,
|
||||||
|
"OPENAI_API_KEY": true,
|
||||||
|
"GEMINI_API_KEY": true,
|
||||||
|
"MISTRAL_API_KEY": true,
|
||||||
|
"GROQ_API_KEY": true,
|
||||||
|
"XAI_API_KEY": true,
|
||||||
|
"OPENROUTER_API_KEY": true,
|
||||||
|
}
|
||||||
|
var env []string
|
||||||
|
for _, e := range os.Environ() {
|
||||||
|
key, _, _ := strings.Cut(e, "=")
|
||||||
|
if !clear[key] {
|
||||||
|
env = append(env, e)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return env
|
||||||
|
}
|
||||||
|
|
||||||
|
// portOpen checks if a TCP port is currently accepting connections.
|
||||||
|
func portOpen(addr string) bool {
|
||||||
|
conn, err := net.DialTimeout("tcp", addr, 500*time.Millisecond)
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
conn.Close()
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func waitForPort(addr string, timeout time.Duration) bool {
|
||||||
|
deadline := time.Now().Add(timeout)
|
||||||
|
for time.Now().Before(deadline) {
|
||||||
|
conn, err := net.DialTimeout("tcp", addr, 500*time.Millisecond)
|
||||||
|
if err == nil {
|
||||||
|
conn.Close()
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
time.Sleep(250 * time.Millisecond)
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func windowsHint(err error) error {
|
||||||
|
if runtime.GOOS != "windows" {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return fmt.Errorf("%w\n\n"+
|
||||||
|
"OpenClaw runs best on WSL2.\n"+
|
||||||
|
"Quick setup: wsl --install\n"+
|
||||||
|
"Guide: https://docs.openclaw.ai/windows", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// onboarded checks if OpenClaw onboarding wizard was completed
|
||||||
|
// by looking for the wizard.lastRunAt marker in the config
|
||||||
|
func (c *Openclaw) onboarded() bool {
|
||||||
|
home, err := os.UserHomeDir()
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
configPath := filepath.Join(home, ".openclaw", "openclaw.json")
|
||||||
|
legacyPath := filepath.Join(home, ".clawdbot", "clawdbot.json")
|
||||||
|
|
||||||
|
config := make(map[string]any)
|
||||||
|
if data, err := os.ReadFile(configPath); err == nil {
|
||||||
|
_ = json.Unmarshal(data, &config)
|
||||||
|
} else if data, err := os.ReadFile(legacyPath); err == nil {
|
||||||
|
_ = json.Unmarshal(data, &config)
|
||||||
|
} else {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for wizard.lastRunAt marker (set when onboarding completes)
|
||||||
|
wizard, _ := config["wizard"].(map[string]any)
|
||||||
|
if wizard == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
lastRunAt, _ := wizard["lastRunAt"].(string)
|
||||||
|
return lastRunAt != ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// patchDeviceScopes upgrades the local CLI device's paired scopes to include
// operator.admin. Only patches the local device, not remote ones.
// Best-effort: silently returns on any error.
func patchDeviceScopes() {
	home, err := os.UserHomeDir()
	if err != nil {
		return
	}

	// Without a local device identity there is nothing to patch.
	deviceID := readLocalDeviceID(home)
	if deviceID == "" {
		return
	}

	path := filepath.Join(home, ".openclaw", "devices", "paired.json")
	data, err := os.ReadFile(path)
	if err != nil {
		return
	}

	// paired.json maps device IDs to device records.
	var devices map[string]map[string]any
	if err := json.Unmarshal(data, &devices); err != nil {
		return
	}

	dev, ok := devices[deviceID]
	if !ok {
		return
	}

	// Scopes the local CLI device needs for full operator control.
	required := []string{
		"operator.read",
		"operator.admin",
		"operator.approvals",
		"operator.pairing",
	}

	// Patch both the device-level scopes and the scopes on each issued token.
	changed := patchScopes(dev, "scopes", required)
	if tokens, ok := dev["tokens"].(map[string]any); ok {
		for _, tok := range tokens {
			if tokenMap, ok := tok.(map[string]any); ok {
				if patchScopes(tokenMap, "scopes", required) {
					changed = true
				}
			}
		}
	}

	// Avoid rewriting the file when nothing was added.
	if !changed {
		return
	}

	out, err := json.MarshalIndent(devices, "", " ")
	if err != nil {
		return
	}
	// 0600: device records carry auth material.
	_ = os.WriteFile(path, out, 0o600)
}
|
||||||
|
|
||||||
|
// readLocalDeviceID reads the local device ID from openclaw's identity file.
|
||||||
|
func readLocalDeviceID(home string) string {
|
||||||
|
data, err := os.ReadFile(filepath.Join(home, ".openclaw", "identity", "device-auth.json"))
|
||||||
|
if err != nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
var auth map[string]any
|
||||||
|
if err := json.Unmarshal(data, &auth); err != nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
id, _ := auth["deviceId"].(string)
|
||||||
|
return id
|
||||||
|
}
|
||||||
|
|
||||||
|
// patchScopes ensures obj[key] contains all required scopes. Returns true if
// any scopes were added; the existing slice order is preserved and missing
// scopes are appended in the order given.
func patchScopes(obj map[string]any, key string, required []string) bool {
	current, _ := obj[key].([]any)
	seen := make(map[string]bool, len(current))
	for _, v := range current {
		if name, ok := v.(string); ok {
			seen[name] = true
		}
	}

	var missing []any
	for _, name := range required {
		if !seen[name] {
			missing = append(missing, name)
		}
	}
	if len(missing) == 0 {
		return false
	}
	obj[key] = append(current, missing...)
	return true
}
|
||||||
|
|
||||||
|
// canInstallDaemon reports whether the openclaw daemon can be installed as a
|
||||||
|
// background service. Returns false on Linux when systemd is absent (e.g.
|
||||||
|
// containers) so that --install-daemon is omitted and the gateway is started
|
||||||
|
// as a foreground child process instead. Returns true in all other cases.
|
||||||
|
func canInstallDaemon() bool {
|
||||||
|
if runtime.GOOS != "linux" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
// /run/systemd/system exists as a directory when systemd is the init system.
|
||||||
|
// This is absent in most containers.
|
||||||
|
fi, err := os.Stat("/run/systemd/system")
|
||||||
|
if err != nil || !fi.IsDir() {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
// Even when systemd is the init system, user services require a user
|
||||||
|
// manager instance. XDG_RUNTIME_DIR being set is a prerequisite.
|
||||||
|
return os.Getenv("XDG_RUNTIME_DIR") != ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// ensureOpenclawInstalled locates the openclaw CLI, offering to install it
// via npm when missing. It returns the binary name to invoke. On a fresh
// install it sets openclawFreshInstall so callers can skip the self-update.
func ensureOpenclawInstalled() (string, error) {
	// Prefer an existing install: current name first, then the legacy name.
	if _, err := exec.LookPath("openclaw"); err == nil {
		return "openclaw", nil
	}
	if _, err := exec.LookPath("clawdbot"); err == nil {
		return "clawdbot", nil
	}

	// Installation requires npm; give actionable guidance when it's missing.
	if _, err := exec.LookPath("npm"); err != nil {
		return "", fmt.Errorf("openclaw is not installed and npm was not found\n\n" +
			"Install Node.js first:\n" +
			" https://nodejs.org/\n\n" +
			"Then rerun:\n" +
			" ollama launch\n" +
			"and select OpenClaw")
	}

	// Never install software without explicit user consent.
	ok, err := ConfirmPrompt("OpenClaw is not installed. Install with npm?")
	if err != nil {
		return "", err
	}
	if !ok {
		return "", fmt.Errorf("openclaw installation cancelled")
	}

	fmt.Fprintf(os.Stderr, "\nInstalling OpenClaw...\n")
	cmd := exec.Command("npm", "install", "-g", "openclaw@latest")
	cmd.Stdin = os.Stdin
	cmd.Stdout = os.Stdout
	cmd.Stderr = os.Stderr
	if err := cmd.Run(); err != nil {
		return "", fmt.Errorf("failed to install openclaw: %w", err)
	}

	// A global npm install can land outside PATH (e.g. npm prefix not exported).
	if _, err := exec.LookPath("openclaw"); err != nil {
		return "", fmt.Errorf("openclaw was installed but the binary was not found on PATH\n\nYou may need to restart your shell")
	}

	fmt.Fprintf(os.Stderr, "%sOpenClaw installed successfully%s\n\n", ansiGreen, ansiReset)
	openclawFreshInstall = true
	return "openclaw", nil
}
|
||||||
|
|
||||||
|
func (c *Openclaw) Paths() []string {
|
||||||
|
home, _ := os.UserHomeDir()
|
||||||
|
p := filepath.Join(home, ".openclaw", "openclaw.json")
|
||||||
|
if _, err := os.Stat(p); err == nil {
|
||||||
|
return []string{p}
|
||||||
|
}
|
||||||
|
legacy := filepath.Join(home, ".clawdbot", "clawdbot.json")
|
||||||
|
if _, err := os.Stat(legacy); err == nil {
|
||||||
|
return []string{legacy}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Edit writes the given models into the OpenClaw config as the Ollama
// provider's model list and sets the first model as the default primary.
// Unrelated config (other providers, agent settings, unknown fields) is
// preserved. No-op when models is empty.
func (c *Openclaw) Edit(models []string) error {
	if len(models) == 0 {
		return nil
	}

	home, err := os.UserHomeDir()
	if err != nil {
		return err
	}

	configPath := filepath.Join(home, ".openclaw", "openclaw.json")
	legacyPath := filepath.Join(home, ".clawdbot", "clawdbot.json")
	if err := os.MkdirAll(filepath.Dir(configPath), 0o755); err != nil {
		return err
	}

	// Read into map[string]any to preserve unknown fields
	config := make(map[string]any)
	if data, err := os.ReadFile(configPath); err == nil {
		_ = json.Unmarshal(data, &config)
	} else if data, err := os.ReadFile(legacyPath); err == nil {
		_ = json.Unmarshal(data, &config)
	}

	// Navigate/create: models.providers.ollama (preserving other providers)
	modelsSection, _ := config["models"].(map[string]any)
	if modelsSection == nil {
		modelsSection = make(map[string]any)
	}
	providers, _ := modelsSection["providers"].(map[string]any)
	if providers == nil {
		providers = make(map[string]any)
	}
	ollama, _ := providers["ollama"].(map[string]any)
	if ollama == nil {
		ollama = make(map[string]any)
	}

	ollama["baseUrl"] = envconfig.Host().String()
	// needed to register provider
	ollama["apiKey"] = "ollama-local"
	ollama["api"] = "ollama"

	// Build map of existing models to preserve user customizations
	existingModels, _ := ollama["models"].([]any)
	existingByID := make(map[string]map[string]any)
	for _, m := range existingModels {
		if entry, ok := m.(map[string]any); ok {
			if id, ok := entry["id"].(string); ok {
				existingByID[id] = entry
			}
		}
	}

	// A nil client is tolerated: openclawModelConfig then skips capability
	// probing and emits a default entry.
	client, _ := api.ClientFromEnvironment()

	var newModels []any
	for _, m := range models {
		entry, _ := openclawModelConfig(context.Background(), client, m)
		// Merge existing fields (user customizations)
		if existing, ok := existingByID[m]; ok {
			for k, v := range existing {
				if _, isNew := entry[k]; !isNew {
					entry[k] = v
				}
			}
		}
		newModels = append(newModels, entry)
	}
	ollama["models"] = newModels

	providers["ollama"] = ollama
	modelsSection["providers"] = providers
	config["models"] = modelsSection

	// Update agents.defaults.model.primary (preserving other agent settings)
	agents, _ := config["agents"].(map[string]any)
	if agents == nil {
		agents = make(map[string]any)
	}
	defaults, _ := agents["defaults"].(map[string]any)
	if defaults == nil {
		defaults = make(map[string]any)
	}
	modelConfig, _ := defaults["model"].(map[string]any)
	if modelConfig == nil {
		modelConfig = make(map[string]any)
	}
	modelConfig["primary"] = "ollama/" + models[0]
	defaults["model"] = modelConfig
	agents["defaults"] = defaults
	config["agents"] = agents

	data, err := json.MarshalIndent(config, "", " ")
	if err != nil {
		return err
	}
	// Backup-then-write so a failed write never destroys the user's config.
	if err := fileutil.WriteWithBackup(configPath, data); err != nil {
		return err
	}

	// Clear any per-session model overrides so the new primary takes effect
	// immediately rather than being shadowed by a cached modelOverride.
	clearSessionModelOverride(models[0])
	return nil
}
|
||||||
|
|
||||||
|
// clearSessionModelOverride removes per-session model overrides from the main
// agent session so the global primary model takes effect on the next TUI launch.
// Best-effort: silently returns on any error.
func clearSessionModelOverride(primary string) {
	home, err := os.UserHomeDir()
	if err != nil {
		return
	}
	path := filepath.Join(home, ".openclaw", "agents", "main", "sessions", "sessions.json")
	data, err := os.ReadFile(path)
	if err != nil {
		return
	}
	// sessions.json maps session IDs to session records.
	var sessions map[string]map[string]any
	if json.Unmarshal(data, &sessions) != nil {
		return
	}
	changed := false
	for _, sess := range sessions {
		// Only sessions overridden to a *different* model need clearing;
		// an override already matching the primary is left alone.
		if override, _ := sess["modelOverride"].(string); override != "" && override != primary {
			delete(sess, "modelOverride")
			delete(sess, "providerOverride")
			sess["model"] = primary
			changed = true
		}
	}
	// Skip the rewrite when nothing was modified.
	if !changed {
		return
	}
	out, err := json.MarshalIndent(sessions, "", " ")
	if err != nil {
		return
	}
	_ = os.WriteFile(path, out, 0o600)
}
|
||||||
|
|
||||||
|
const (
	// webSearchNpmPackage is the npm package providing the web search extension.
	webSearchNpmPackage = "@ollama/openclaw-web-search"
	// webSearchMinVersion is the minimum acceptable installed plugin version;
	// older installs trigger a re-install.
	webSearchMinVersion = "0.2.1"
)
|
||||||
|
|
||||||
|
// ensureWebSearchPlugin installs the openclaw-web-search extension into the
|
||||||
|
// user-level extensions directory (~/.openclaw/extensions/) if it isn't already
|
||||||
|
// present, or re-installs if the installed version is older than webSearchMinVersion.
|
||||||
|
// Returns true if the extension is available.
|
||||||
|
func ensureWebSearchPlugin() bool {
|
||||||
|
home, err := os.UserHomeDir()
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
pluginDir := filepath.Join(home, ".openclaw", "extensions", "openclaw-web-search")
|
||||||
|
if webSearchPluginUpToDate(pluginDir) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
npmBin, err := exec.LookPath("npm")
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := os.MkdirAll(pluginDir, 0o755); err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Download the tarball via `npm pack`, extract it flat into the plugin dir.
|
||||||
|
pack := exec.Command(npmBin, "pack", webSearchNpmPackage, "--pack-destination", pluginDir)
|
||||||
|
out, err := pack.Output()
|
||||||
|
if err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "%s Warning: could not download web search plugin: %v%s\n", ansiYellow, err, ansiReset)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
tgzName := strings.TrimSpace(string(out))
|
||||||
|
tgzPath := filepath.Join(pluginDir, tgzName)
|
||||||
|
defer os.Remove(tgzPath)
|
||||||
|
|
||||||
|
tar := exec.Command("tar", "xzf", tgzPath, "--strip-components=1", "-C", pluginDir)
|
||||||
|
if err := tar.Run(); err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "%s Warning: could not extract web search plugin: %v%s\n", ansiYellow, err, ansiReset)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Fprintf(os.Stderr, "%s ✓ Installed web search plugin%s\n", ansiGreen, ansiReset)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// webSearchPluginUpToDate returns true if the plugin is installed and its
|
||||||
|
// package.json version is >= webSearchMinVersion.
|
||||||
|
func webSearchPluginUpToDate(pluginDir string) bool {
|
||||||
|
data, err := os.ReadFile(filepath.Join(pluginDir, "package.json"))
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
var pkg struct {
|
||||||
|
Version string `json:"version"`
|
||||||
|
}
|
||||||
|
if json.Unmarshal(data, &pkg) != nil || pkg.Version == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return !versionLessThan(pkg.Version, webSearchMinVersion)
|
||||||
|
}
|
||||||
|
|
||||||
|
// versionLessThan compares two semver version strings (major.minor.patch).
|
||||||
|
// Inputs may omit the "v" prefix; it is added automatically for semver.Compare.
|
||||||
|
func versionLessThan(a, b string) bool {
|
||||||
|
if !strings.HasPrefix(a, "v") {
|
||||||
|
a = "v" + a
|
||||||
|
}
|
||||||
|
if !strings.HasPrefix(b, "v") {
|
||||||
|
b = "v" + b
|
||||||
|
}
|
||||||
|
return semver.Compare(a, b) < 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// registerWebSearchPlugin adds plugins.entries.openclaw-web-search to the OpenClaw
// config so the gateway activates it on next start. Best-effort; silently returns
// on any error.
func registerWebSearchPlugin() {
	home, err := os.UserHomeDir()
	if err != nil {
		return
	}
	configPath := filepath.Join(home, ".openclaw", "openclaw.json")
	data, err := os.ReadFile(configPath)
	if err != nil {
		return
	}
	var config map[string]any
	if json.Unmarshal(data, &config) != nil {
		return
	}

	// Enable the plugin under plugins.entries (creating sections as needed).
	plugins, _ := config["plugins"].(map[string]any)
	if plugins == nil {
		plugins = make(map[string]any)
	}
	entries, _ := plugins["entries"].(map[string]any)
	if entries == nil {
		entries = make(map[string]any)
	}
	entries["openclaw-web-search"] = map[string]any{"enabled": true}
	plugins["entries"] = entries

	// Pin trust so the gateway doesn't warn about untracked plugins.
	allow, _ := plugins["allow"].([]any)
	hasAllow := false
	for _, v := range allow {
		if s, ok := v.(string); ok && s == "openclaw-web-search" {
			hasAllow = true
			break
		}
	}
	if !hasAllow {
		allow = append(allow, "openclaw-web-search")
	}
	plugins["allow"] = allow

	// Record install provenance so the loader can verify the plugin origin.
	installs, _ := plugins["installs"].(map[string]any)
	if installs == nil {
		installs = make(map[string]any)
	}
	pluginDir := filepath.Join(home, ".openclaw", "extensions", "openclaw-web-search")
	installs["openclaw-web-search"] = map[string]any{
		"source": "npm",
		"spec": webSearchNpmPackage,
		"installPath": pluginDir,
	}
	plugins["installs"] = installs

	config["plugins"] = plugins

	// Add plugin tools to tools.alsoAllow so they survive the coding profile's
	// policy pipeline (which has an explicit allow list of core tools only).
	tools, _ := config["tools"].(map[string]any)
	if tools == nil {
		tools = make(map[string]any)
	}

	alsoAllow, _ := tools["alsoAllow"].([]any)
	needed := []string{"ollama_web_search", "ollama_web_fetch"}
	have := make(map[string]bool, len(alsoAllow))
	for _, v := range alsoAllow {
		if s, ok := v.(string); ok {
			have[s] = true
		}
	}
	for _, name := range needed {
		if !have[name] {
			alsoAllow = append(alsoAllow, name)
		}
	}
	tools["alsoAllow"] = alsoAllow

	// Disable built-in web search/fetch since our plugin replaces them.
	web, _ := tools["web"].(map[string]any)
	if web == nil {
		web = make(map[string]any)
	}
	web["search"] = map[string]any{"enabled": false}
	web["fetch"] = map[string]any{"enabled": false}
	tools["web"] = web
	config["tools"] = tools

	out, err := json.MarshalIndent(config, "", " ")
	if err != nil {
		return
	}
	_ = os.WriteFile(configPath, out, 0o600)
}
|
||||||
|
|
||||||
|
// openclawModelConfig builds an OpenClaw model config entry with capability detection.
// The second return value indicates whether the model is a cloud (remote) model.
// Probing failures (nil client, Show error) degrade to the default entry.
func openclawModelConfig(ctx context.Context, client *api.Client, modelID string) (map[string]any, bool) {
	// Defaults: text-only input, zero cost (local models are free to run).
	entry := map[string]any{
		"id": modelID,
		"name": modelID,
		"input": []any{"text"},
		"cost": map[string]any{
			"input": 0,
			"output": 0,
			"cacheRead": 0,
			"cacheWrite": 0,
		},
	}

	if client == nil {
		return entry, false
	}

	// Apply a probe timeout unless the caller already bounded the context.
	showCtx := ctx
	if _, hasDeadline := ctx.Deadline(); !hasDeadline {
		var cancel context.CancelFunc
		showCtx, cancel = context.WithTimeout(ctx, openclawModelShowTimeout)
		defer cancel()
	}

	resp, err := client.Show(showCtx, &api.ShowRequest{Model: modelID})
	if err != nil {
		return entry, false
	}

	// Set input types based on vision capability
	if slices.Contains(resp.Capabilities, model.CapabilityVision) {
		entry["input"] = []any{"text", "image"}
	}

	// Set reasoning based on thinking capability
	if slices.Contains(resp.Capabilities, model.CapabilityThinking) {
		entry["reasoning"] = true
	}

	// Cloud models: use hardcoded limits for context/output tokens.
	// Capability detection above still applies (vision, thinking).
	if resp.RemoteModel != "" {
		if l, ok := lookupCloudModelLimit(modelID); ok {
			entry["contextWindow"] = l.Context
			entry["maxTokens"] = l.Output
		}
		return entry, true
	}

	// Extract context window from ModelInfo (local models only)
	for key, val := range resp.ModelInfo {
		if strings.HasSuffix(key, ".context_length") {
			if ctxLen, ok := val.(float64); ok && ctxLen > 0 {
				entry["contextWindow"] = int(ctxLen)
			}
			break
		}
	}

	return entry, false
}
|
||||||
|
|
||||||
|
func (c *Openclaw) Models() []string {
|
||||||
|
home, err := os.UserHomeDir()
|
||||||
|
if err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
config, err := fileutil.ReadJSON(filepath.Join(home, ".openclaw", "openclaw.json"))
|
||||||
|
if err != nil {
|
||||||
|
config, err = fileutil.ReadJSON(filepath.Join(home, ".clawdbot", "clawdbot.json"))
|
||||||
|
if err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
modelsSection, _ := config["models"].(map[string]any)
|
||||||
|
providers, _ := modelsSection["providers"].(map[string]any)
|
||||||
|
ollama, _ := providers["ollama"].(map[string]any)
|
||||||
|
modelList, _ := ollama["models"].([]any)
|
||||||
|
|
||||||
|
var result []string
|
||||||
|
for _, m := range modelList {
|
||||||
|
if entry, ok := m.(map[string]any); ok {
|
||||||
|
if id, ok := entry["id"].(string); ok {
|
||||||
|
result = append(result, id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
1770
cmd/launch/openclaw_test.go
Normal file
1770
cmd/launch/openclaw_test.go
Normal file
File diff suppressed because it is too large
Load Diff
@@ -1,9 +1,7 @@
|
|||||||
package config
|
package launch
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"maps"
|
"maps"
|
||||||
"os"
|
"os"
|
||||||
@@ -12,34 +10,13 @@ import (
|
|||||||
"slices"
|
"slices"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/cmd/internal/fileutil"
|
||||||
"github.com/ollama/ollama/envconfig"
|
"github.com/ollama/ollama/envconfig"
|
||||||
)
|
)
|
||||||
|
|
||||||
// OpenCode implements Runner and Editor for OpenCode integration
|
// OpenCode implements Runner and Editor for OpenCode integration
|
||||||
type OpenCode struct{}
|
type OpenCode struct{}
|
||||||
|
|
||||||
// cloudModelLimit holds context and output token limits for a cloud model.
|
|
||||||
type cloudModelLimit struct {
|
|
||||||
Context int
|
|
||||||
Output int
|
|
||||||
}
|
|
||||||
|
|
||||||
// lookupCloudModelLimit returns the token limits for a cloud model.
|
|
||||||
// It tries the exact name first, then strips the ":cloud" suffix.
|
|
||||||
func lookupCloudModelLimit(name string) (cloudModelLimit, bool) {
|
|
||||||
if l, ok := cloudModelLimits[name]; ok {
|
|
||||||
return l, true
|
|
||||||
}
|
|
||||||
base := strings.TrimSuffix(name, ":cloud")
|
|
||||||
if base != name {
|
|
||||||
if l, ok := cloudModelLimits[base]; ok {
|
|
||||||
return l, true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return cloudModelLimit{}, false
|
|
||||||
}
|
|
||||||
|
|
||||||
func (o *OpenCode) String() string { return "OpenCode" }
|
func (o *OpenCode) String() string { return "OpenCode" }
|
||||||
|
|
||||||
func (o *OpenCode) Run(model string, args []string) error {
|
func (o *OpenCode) Run(model string, args []string) error {
|
||||||
@@ -47,25 +24,6 @@ func (o *OpenCode) Run(model string, args []string) error {
|
|||||||
return fmt.Errorf("opencode is not installed, install from https://opencode.ai")
|
return fmt.Errorf("opencode is not installed, install from https://opencode.ai")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Call Edit() to ensure config is up-to-date before launch
|
|
||||||
models := []string{model}
|
|
||||||
if config, err := loadIntegration("opencode"); err == nil && len(config.Models) > 0 {
|
|
||||||
models = config.Models
|
|
||||||
}
|
|
||||||
var err error
|
|
||||||
models, err = resolveEditorModels("opencode", models, func() ([]string, error) {
|
|
||||||
return selectModels(context.Background(), "opencode", "")
|
|
||||||
})
|
|
||||||
if errors.Is(err, errCancelled) {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if err := o.Edit(models); err != nil {
|
|
||||||
return fmt.Errorf("setup failed: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
cmd := exec.Command("opencode", args...)
|
cmd := exec.Command("opencode", args...)
|
||||||
cmd.Stdin = os.Stdin
|
cmd.Stdin = os.Stdin
|
||||||
cmd.Stdout = os.Stdout
|
cmd.Stdout = os.Stdout
|
||||||
@@ -122,13 +80,18 @@ func (o *OpenCode) Edit(modelList []string) error {
|
|||||||
if !ok {
|
if !ok {
|
||||||
ollama = map[string]any{
|
ollama = map[string]any{
|
||||||
"npm": "@ai-sdk/openai-compatible",
|
"npm": "@ai-sdk/openai-compatible",
|
||||||
"name": "Ollama (local)",
|
"name": "Ollama",
|
||||||
"options": map[string]any{
|
"options": map[string]any{
|
||||||
"baseURL": envconfig.Host().String() + "/v1",
|
"baseURL": envconfig.Host().String() + "/v1",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Migrate legacy provider name
|
||||||
|
if name, _ := ollama["name"].(string); name == "Ollama (local)" {
|
||||||
|
ollama["name"] = "Ollama"
|
||||||
|
}
|
||||||
|
|
||||||
models, ok := ollama["models"].(map[string]any)
|
models, ok := ollama["models"].(map[string]any)
|
||||||
if !ok {
|
if !ok {
|
||||||
models = make(map[string]any)
|
models = make(map[string]any)
|
||||||
@@ -147,8 +110,6 @@ func (o *OpenCode) Edit(modelList []string) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
client, _ := api.ClientFromEnvironment()
|
|
||||||
|
|
||||||
for _, model := range modelList {
|
for _, model := range modelList {
|
||||||
if existing, ok := models[model].(map[string]any); ok {
|
if existing, ok := models[model].(map[string]any); ok {
|
||||||
// migrate existing models without _launch marker
|
// migrate existing models without _launch marker
|
||||||
@@ -158,7 +119,7 @@ func (o *OpenCode) Edit(modelList []string) error {
|
|||||||
existing["name"] = strings.TrimSuffix(name, " [Ollama]")
|
existing["name"] = strings.TrimSuffix(name, " [Ollama]")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if isCloudModel(context.Background(), client, model) {
|
if isCloudModelName(model) {
|
||||||
if l, ok := lookupCloudModelLimit(model); ok {
|
if l, ok := lookupCloudModelLimit(model); ok {
|
||||||
existing["limit"] = map[string]any{
|
existing["limit"] = map[string]any{
|
||||||
"context": l.Context,
|
"context": l.Context,
|
||||||
@@ -172,7 +133,7 @@ func (o *OpenCode) Edit(modelList []string) error {
|
|||||||
"name": model,
|
"name": model,
|
||||||
"_launch": true,
|
"_launch": true,
|
||||||
}
|
}
|
||||||
if isCloudModel(context.Background(), client, model) {
|
if isCloudModelName(model) {
|
||||||
if l, ok := lookupCloudModelLimit(model); ok {
|
if l, ok := lookupCloudModelLimit(model); ok {
|
||||||
entry["limit"] = map[string]any{
|
entry["limit"] = map[string]any{
|
||||||
"context": l.Context,
|
"context": l.Context,
|
||||||
@@ -191,7 +152,7 @@ func (o *OpenCode) Edit(modelList []string) error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if err := writeWithBackup(configPath, configData); err != nil {
|
if err := fileutil.WriteWithBackup(configPath, configData); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -243,7 +204,7 @@ func (o *OpenCode) Edit(modelList []string) error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return writeWithBackup(statePath, stateData)
|
return fileutil.WriteWithBackup(statePath, stateData)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o *OpenCode) Models() []string {
|
func (o *OpenCode) Models() []string {
|
||||||
@@ -251,7 +212,7 @@ func (o *OpenCode) Models() []string {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
config, err := readJSONFile(filepath.Join(home, ".config", "opencode", "opencode.json"))
|
config, err := fileutil.ReadJSON(filepath.Join(home, ".config", "opencode", "opencode.json"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -1,8 +1,10 @@
|
|||||||
package config
|
package launch
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"testing"
|
"testing"
|
||||||
@@ -232,6 +234,44 @@ func TestOpenCodeEdit(t *testing.T) {
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
t.Run("migrate Ollama (local) provider name", func(t *testing.T) {
|
||||||
|
cleanup()
|
||||||
|
os.MkdirAll(configDir, 0o755)
|
||||||
|
os.WriteFile(configPath, []byte(`{"provider":{"ollama":{"name":"Ollama (local)","npm":"@ai-sdk/openai-compatible","options":{"baseURL":"http://localhost:11434/v1"}}}}`), 0o644)
|
||||||
|
|
||||||
|
if err := o.Edit([]string{"llama3.2"}); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
data, _ := os.ReadFile(configPath)
|
||||||
|
var cfg map[string]any
|
||||||
|
json.Unmarshal(data, &cfg)
|
||||||
|
provider := cfg["provider"].(map[string]any)
|
||||||
|
ollama := provider["ollama"].(map[string]any)
|
||||||
|
if ollama["name"] != "Ollama" {
|
||||||
|
t.Errorf("provider name not migrated: got %q, want %q", ollama["name"], "Ollama")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("preserve custom provider name", func(t *testing.T) {
|
||||||
|
cleanup()
|
||||||
|
os.MkdirAll(configDir, 0o755)
|
||||||
|
os.WriteFile(configPath, []byte(`{"provider":{"ollama":{"name":"My Custom Ollama","npm":"@ai-sdk/openai-compatible","options":{"baseURL":"http://localhost:11434/v1"}}}}`), 0o644)
|
||||||
|
|
||||||
|
if err := o.Edit([]string{"llama3.2"}); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
data, _ := os.ReadFile(configPath)
|
||||||
|
var cfg map[string]any
|
||||||
|
json.Unmarshal(data, &cfg)
|
||||||
|
provider := cfg["provider"].(map[string]any)
|
||||||
|
ollama := provider["ollama"].(map[string]any)
|
||||||
|
if ollama["name"] != "My Custom Ollama" {
|
||||||
|
t.Errorf("custom provider name was changed: got %q, want %q", ollama["name"], "My Custom Ollama")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
t.Run("remove model preserves non-ollama models", func(t *testing.T) {
|
t.Run("remove model preserves non-ollama models", func(t *testing.T) {
|
||||||
cleanup()
|
cleanup()
|
||||||
os.MkdirAll(configDir, 0o755)
|
os.MkdirAll(configDir, 0o755)
|
||||||
@@ -619,6 +659,54 @@ func TestOpenCodeEdit_CloudModelLimitStructure(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestOpenCodeEdit_BackfillsCloudModelLimitOnExistingEntry(t *testing.T) {
|
||||||
|
o := &OpenCode{}
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
setTestHome(t, tmpDir)
|
||||||
|
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.URL.Path == "/api/show" {
|
||||||
|
fmt.Fprintf(w, `{"capabilities":[],"model_info":{},"remote_model":"glm-5"}`)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
w.WriteHeader(http.StatusNotFound)
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
t.Setenv("OLLAMA_HOST", srv.URL)
|
||||||
|
|
||||||
|
configDir := filepath.Join(tmpDir, ".config", "opencode")
|
||||||
|
configPath := filepath.Join(configDir, "opencode.json")
|
||||||
|
os.MkdirAll(configDir, 0o755)
|
||||||
|
os.WriteFile(configPath, []byte(`{
|
||||||
|
"provider": {
|
||||||
|
"ollama": {
|
||||||
|
"models": {
|
||||||
|
"glm-5:cloud": {
|
||||||
|
"name": "glm-5:cloud",
|
||||||
|
"_launch": true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}`), 0o644)
|
||||||
|
|
||||||
|
if err := o.Edit([]string{"glm-5:cloud"}); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
entry := readOpenCodeModel(t, configPath, "glm-5:cloud")
|
||||||
|
limit, ok := entry["limit"].(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("cloud model limit was not added on re-edit")
|
||||||
|
}
|
||||||
|
if limit["context"] != float64(202_752) {
|
||||||
|
t.Errorf("context = %v, want 202752", limit["context"])
|
||||||
|
}
|
||||||
|
if limit["output"] != float64(131_072) {
|
||||||
|
t.Errorf("output = %v, want 131072", limit["output"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestLookupCloudModelLimit(t *testing.T) {
|
func TestLookupCloudModelLimit(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
@@ -626,13 +714,19 @@ func TestLookupCloudModelLimit(t *testing.T) {
|
|||||||
wantContext int
|
wantContext int
|
||||||
wantOutput int
|
wantOutput int
|
||||||
}{
|
}{
|
||||||
{"glm-4.7", true, 202_752, 131_072},
|
{"glm-4.7", false, 0, 0},
|
||||||
{"glm-4.7:cloud", true, 202_752, 131_072},
|
{"glm-4.7:cloud", true, 202_752, 131_072},
|
||||||
{"kimi-k2.5", true, 262_144, 262_144},
|
{"glm-5:cloud", true, 202_752, 131_072},
|
||||||
|
{"gpt-oss:120b-cloud", true, 131_072, 131_072},
|
||||||
|
{"gpt-oss:20b-cloud", true, 131_072, 131_072},
|
||||||
|
{"kimi-k2.5", false, 0, 0},
|
||||||
{"kimi-k2.5:cloud", true, 262_144, 262_144},
|
{"kimi-k2.5:cloud", true, 262_144, 262_144},
|
||||||
{"deepseek-v3.2", true, 163_840, 65_536},
|
{"deepseek-v3.2", false, 0, 0},
|
||||||
{"deepseek-v3.2:cloud", true, 163_840, 65_536},
|
{"deepseek-v3.2:cloud", true, 163_840, 65_536},
|
||||||
{"qwen3-coder:480b", true, 262_144, 65_536},
|
{"qwen3.5", false, 0, 0},
|
||||||
|
{"qwen3.5:cloud", true, 262_144, 32_768},
|
||||||
|
{"qwen3-coder:480b", false, 0, 0},
|
||||||
|
{"qwen3-coder:480b:cloud", true, 262_144, 65_536},
|
||||||
{"qwen3-coder-next:cloud", true, 262_144, 32_768},
|
{"qwen3-coder-next:cloud", true, 262_144, 32_768},
|
||||||
{"llama3.2", false, 0, 0},
|
{"llama3.2", false, 0, 0},
|
||||||
{"unknown-model:cloud", false, 0, 0},
|
{"unknown-model:cloud", false, 0, 0},
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package config
|
package launch
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
@@ -12,6 +12,7 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
|
"github.com/ollama/ollama/cmd/internal/fileutil"
|
||||||
"github.com/ollama/ollama/envconfig"
|
"github.com/ollama/ollama/envconfig"
|
||||||
"github.com/ollama/ollama/types/model"
|
"github.com/ollama/ollama/types/model"
|
||||||
)
|
)
|
||||||
@@ -26,15 +27,6 @@ func (p *Pi) Run(model string, args []string) error {
|
|||||||
return fmt.Errorf("pi is not installed, install with: npm install -g @mariozechner/pi-coding-agent")
|
return fmt.Errorf("pi is not installed, install with: npm install -g @mariozechner/pi-coding-agent")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Call Edit() to ensure config is up-to-date before launch
|
|
||||||
models := []string{model}
|
|
||||||
if config, err := loadIntegration("pi"); err == nil && len(config.Models) > 0 {
|
|
||||||
models = config.Models
|
|
||||||
}
|
|
||||||
if err := p.Edit(models); err != nil {
|
|
||||||
return fmt.Errorf("setup failed: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
cmd := exec.Command("pi", args...)
|
cmd := exec.Command("pi", args...)
|
||||||
cmd.Stdin = os.Stdin
|
cmd.Stdin = os.Stdin
|
||||||
cmd.Stdout = os.Stdout
|
cmd.Stdout = os.Stdout
|
||||||
@@ -107,7 +99,8 @@ func (p *Pi) Edit(models []string) error {
|
|||||||
|
|
||||||
// Build new models list:
|
// Build new models list:
|
||||||
// 1. Keep user-managed models (no _launch marker) - untouched
|
// 1. Keep user-managed models (no _launch marker) - untouched
|
||||||
// 2. Keep ollama-managed models (_launch marker) that are still selected
|
// 2. Keep ollama-managed models (_launch marker) that are still selected,
|
||||||
|
// except stale cloud entries that should be rebuilt below
|
||||||
// 3. Add new ollama-managed models
|
// 3. Add new ollama-managed models
|
||||||
var newModels []any
|
var newModels []any
|
||||||
for _, m := range existingModels {
|
for _, m := range existingModels {
|
||||||
@@ -117,7 +110,13 @@ func (p *Pi) Edit(models []string) error {
|
|||||||
if !isPiOllamaModel(modelObj) {
|
if !isPiOllamaModel(modelObj) {
|
||||||
newModels = append(newModels, m)
|
newModels = append(newModels, m)
|
||||||
} else if selectedSet[id] {
|
} else if selectedSet[id] {
|
||||||
// Ollama-managed and still selected - keep it
|
// Rebuild stale managed cloud entries so createConfig refreshes
|
||||||
|
// the whole entry instead of patching it in place.
|
||||||
|
if !hasContextWindow(modelObj) {
|
||||||
|
if _, ok := lookupCloudModelLimit(id); ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
newModels = append(newModels, m)
|
newModels = append(newModels, m)
|
||||||
selectedSet[id] = false
|
selectedSet[id] = false
|
||||||
}
|
}
|
||||||
@@ -142,7 +141,7 @@ func (p *Pi) Edit(models []string) error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if err := writeWithBackup(configPath, configData); err != nil {
|
if err := fileutil.WriteWithBackup(configPath, configData); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -160,7 +159,7 @@ func (p *Pi) Edit(models []string) error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return writeWithBackup(settingsPath, settingsData)
|
return fileutil.WriteWithBackup(settingsPath, settingsData)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Pi) Models() []string {
|
func (p *Pi) Models() []string {
|
||||||
@@ -170,7 +169,7 @@ func (p *Pi) Models() []string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
configPath := filepath.Join(home, ".pi", "agent", "models.json")
|
configPath := filepath.Join(home, ".pi", "agent", "models.json")
|
||||||
config, err := readJSONFile(configPath)
|
config, err := fileutil.ReadJSON(configPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -199,15 +198,38 @@ func isPiOllamaModel(cfg map[string]any) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func hasContextWindow(cfg map[string]any) bool {
|
||||||
|
switch v := cfg["contextWindow"].(type) {
|
||||||
|
case float64:
|
||||||
|
return v > 0
|
||||||
|
case int:
|
||||||
|
return v > 0
|
||||||
|
case int64:
|
||||||
|
return v > 0
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// createConfig builds Pi model config with capability detection
|
// createConfig builds Pi model config with capability detection
|
||||||
func createConfig(ctx context.Context, client *api.Client, modelID string) map[string]any {
|
func createConfig(ctx context.Context, client *api.Client, modelID string) map[string]any {
|
||||||
cfg := map[string]any{
|
cfg := map[string]any{
|
||||||
"id": modelID,
|
"id": modelID,
|
||||||
"_launch": true,
|
"_launch": true,
|
||||||
}
|
}
|
||||||
|
if l, ok := lookupCloudModelLimit(modelID); ok {
|
||||||
|
cfg["contextWindow"] = l.Context
|
||||||
|
}
|
||||||
|
|
||||||
|
applyCloudContextFallback := func() {
|
||||||
|
if l, ok := lookupCloudModelLimit(modelID); ok {
|
||||||
|
cfg["contextWindow"] = l.Context
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
resp, err := client.Show(ctx, &api.ShowRequest{Model: modelID})
|
resp, err := client.Show(ctx, &api.ShowRequest{Model: modelID})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
applyCloudContextFallback()
|
||||||
return cfg
|
return cfg
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -223,15 +245,21 @@ func createConfig(ctx context.Context, client *api.Client, modelID string) map[s
|
|||||||
cfg["reasoning"] = true
|
cfg["reasoning"] = true
|
||||||
}
|
}
|
||||||
|
|
||||||
// Extract context window from ModelInfo
|
// Extract context window from ModelInfo. For known cloud models, the
|
||||||
|
// pre-filled shared limit remains unless the server provides a positive value.
|
||||||
|
hasContextWindow := false
|
||||||
for key, val := range resp.ModelInfo {
|
for key, val := range resp.ModelInfo {
|
||||||
if strings.HasSuffix(key, ".context_length") {
|
if strings.HasSuffix(key, ".context_length") {
|
||||||
if ctxLen, ok := val.(float64); ok && ctxLen > 0 {
|
if ctxLen, ok := val.(float64); ok && ctxLen > 0 {
|
||||||
cfg["contextWindow"] = int(ctxLen)
|
cfg["contextWindow"] = int(ctxLen)
|
||||||
|
hasContextWindow = true
|
||||||
}
|
}
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if !hasContextWindow {
|
||||||
|
applyCloudContextFallback()
|
||||||
|
}
|
||||||
|
|
||||||
return cfg
|
return cfg
|
||||||
}
|
}
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package config
|
package launch
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
@@ -192,6 +192,48 @@ func TestPiEdit(t *testing.T) {
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
t.Run("rebuilds stale existing managed cloud model", func(t *testing.T) {
|
||||||
|
cleanup()
|
||||||
|
os.MkdirAll(configDir, 0o755)
|
||||||
|
|
||||||
|
existingConfig := `{
|
||||||
|
"providers": {
|
||||||
|
"ollama": {
|
||||||
|
"baseUrl": "http://localhost:11434/v1",
|
||||||
|
"api": "openai-completions",
|
||||||
|
"apiKey": "ollama",
|
||||||
|
"models": [
|
||||||
|
{"id": "glm-5:cloud", "_launch": true, "legacyField": "stale"}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}`
|
||||||
|
if err := os.WriteFile(configPath, []byte(existingConfig), 0o644); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := pi.Edit([]string{"glm-5:cloud"}); err != nil {
|
||||||
|
t.Fatalf("Edit() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg := readConfig()
|
||||||
|
providers := cfg["providers"].(map[string]any)
|
||||||
|
ollama := providers["ollama"].(map[string]any)
|
||||||
|
modelsArray := ollama["models"].([]any)
|
||||||
|
modelEntry := modelsArray[0].(map[string]any)
|
||||||
|
|
||||||
|
if modelEntry["contextWindow"] != float64(202_752) {
|
||||||
|
t.Errorf("contextWindow = %v, want 202752", modelEntry["contextWindow"])
|
||||||
|
}
|
||||||
|
input, ok := modelEntry["input"].([]any)
|
||||||
|
if !ok || len(input) != 1 || input[0] != "text" {
|
||||||
|
t.Errorf("input = %v, want [text]", modelEntry["input"])
|
||||||
|
}
|
||||||
|
if _, ok := modelEntry["legacyField"]; ok {
|
||||||
|
t.Error("legacyField should be removed when stale managed cloud entry is rebuilt")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
t.Run("replaces old models with new ones", func(t *testing.T) {
|
t.Run("replaces old models with new ones", func(t *testing.T) {
|
||||||
cleanup()
|
cleanup()
|
||||||
os.MkdirAll(configDir, 0o755)
|
os.MkdirAll(configDir, 0o755)
|
||||||
@@ -798,6 +840,59 @@ func TestCreateConfig(t *testing.T) {
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
t.Run("cloud model falls back to hardcoded context when show fails", func(t *testing.T) {
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusNotFound)
|
||||||
|
fmt.Fprintf(w, `{"error":"model not found"}`)
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
u, _ := url.Parse(srv.URL)
|
||||||
|
client := api.NewClient(u, srv.Client())
|
||||||
|
|
||||||
|
cfg := createConfig(context.Background(), client, "kimi-k2.5:cloud")
|
||||||
|
|
||||||
|
if cfg["contextWindow"] != 262_144 {
|
||||||
|
t.Errorf("contextWindow = %v, want 262144", cfg["contextWindow"])
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("cloud model falls back to hardcoded context when show omits model info", func(t *testing.T) {
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.URL.Path == "/api/show" {
|
||||||
|
fmt.Fprintf(w, `{"capabilities":[],"model_info":{}}`)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
w.WriteHeader(http.StatusNotFound)
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
u, _ := url.Parse(srv.URL)
|
||||||
|
client := api.NewClient(u, srv.Client())
|
||||||
|
|
||||||
|
cfg := createConfig(context.Background(), client, "glm-5:cloud")
|
||||||
|
|
||||||
|
if cfg["contextWindow"] != 202_752 {
|
||||||
|
t.Errorf("contextWindow = %v, want 202752", cfg["contextWindow"])
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("cloud model with dash suffix falls back to hardcoded context", func(t *testing.T) {
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusNotFound)
|
||||||
|
fmt.Fprintf(w, `{"error":"model not found"}`)
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
u, _ := url.Parse(srv.URL)
|
||||||
|
client := api.NewClient(u, srv.Client())
|
||||||
|
|
||||||
|
cfg := createConfig(context.Background(), client, "gpt-oss:120b-cloud")
|
||||||
|
|
||||||
|
if cfg["contextWindow"] != 131_072 {
|
||||||
|
t.Errorf("contextWindow = %v, want 131072", cfg["contextWindow"])
|
||||||
|
}
|
||||||
|
})
|
||||||
t.Run("skips zero context length", func(t *testing.T) {
|
t.Run("skips zero context length", func(t *testing.T) {
|
||||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
if r.URL.Path == "/api/show" {
|
if r.URL.Path == "/api/show" {
|
||||||
355
cmd/launch/registry.go
Normal file
355
cmd/launch/registry.go
Normal file
@@ -0,0 +1,355 @@
|
|||||||
|
package launch
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"os/exec"
|
||||||
|
"slices"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// IntegrationInstallSpec describes how launcher should detect and guide installation.
|
||||||
|
type IntegrationInstallSpec struct {
|
||||||
|
CheckInstalled func() bool
|
||||||
|
EnsureInstalled func() error
|
||||||
|
URL string
|
||||||
|
Command []string
|
||||||
|
}
|
||||||
|
|
||||||
|
// IntegrationSpec is the canonical registry entry for one integration.
|
||||||
|
type IntegrationSpec struct {
|
||||||
|
Name string
|
||||||
|
Runner Runner
|
||||||
|
Aliases []string
|
||||||
|
Hidden bool
|
||||||
|
Description string
|
||||||
|
Install IntegrationInstallSpec
|
||||||
|
}
|
||||||
|
|
||||||
|
// IntegrationInfo contains display information about a registered integration.
|
||||||
|
type IntegrationInfo struct {
|
||||||
|
Name string
|
||||||
|
DisplayName string
|
||||||
|
Description string
|
||||||
|
}
|
||||||
|
|
||||||
|
var launcherIntegrationOrder = []string{"opencode", "droid", "pi", "cline"}
|
||||||
|
|
||||||
|
var integrationSpecs = []*IntegrationSpec{
|
||||||
|
{
|
||||||
|
Name: "claude",
|
||||||
|
Runner: &Claude{},
|
||||||
|
Description: "Anthropic's coding tool with subagents",
|
||||||
|
Install: IntegrationInstallSpec{
|
||||||
|
CheckInstalled: func() bool {
|
||||||
|
_, err := (&Claude{}).findPath()
|
||||||
|
return err == nil
|
||||||
|
},
|
||||||
|
URL: "https://code.claude.com/docs/en/quickstart",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "cline",
|
||||||
|
Runner: &Cline{},
|
||||||
|
Description: "Autonomous coding agent with parallel execution",
|
||||||
|
Install: IntegrationInstallSpec{
|
||||||
|
CheckInstalled: func() bool {
|
||||||
|
_, err := exec.LookPath("cline")
|
||||||
|
return err == nil
|
||||||
|
},
|
||||||
|
Command: []string{"npm", "install", "-g", "cline"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "codex",
|
||||||
|
Runner: &Codex{},
|
||||||
|
Description: "OpenAI's open-source coding agent",
|
||||||
|
Install: IntegrationInstallSpec{
|
||||||
|
CheckInstalled: func() bool {
|
||||||
|
_, err := exec.LookPath("codex")
|
||||||
|
return err == nil
|
||||||
|
},
|
||||||
|
URL: "https://developers.openai.com/codex/cli/",
|
||||||
|
Command: []string{"npm", "install", "-g", "@openai/codex"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "droid",
|
||||||
|
Runner: &Droid{},
|
||||||
|
Description: "Factory's coding agent across terminal and IDEs",
|
||||||
|
Install: IntegrationInstallSpec{
|
||||||
|
CheckInstalled: func() bool {
|
||||||
|
_, err := exec.LookPath("droid")
|
||||||
|
return err == nil
|
||||||
|
},
|
||||||
|
URL: "https://docs.factory.ai/cli/getting-started/quickstart",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "opencode",
|
||||||
|
Runner: &OpenCode{},
|
||||||
|
Description: "Anomaly's open-source coding agent",
|
||||||
|
Install: IntegrationInstallSpec{
|
||||||
|
CheckInstalled: func() bool {
|
||||||
|
_, err := exec.LookPath("opencode")
|
||||||
|
return err == nil
|
||||||
|
},
|
||||||
|
URL: "https://opencode.ai",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "openclaw",
|
||||||
|
Runner: &Openclaw{},
|
||||||
|
Aliases: []string{"clawdbot", "moltbot"},
|
||||||
|
Description: "Personal AI with 100+ skills",
|
||||||
|
Install: IntegrationInstallSpec{
|
||||||
|
CheckInstalled: func() bool {
|
||||||
|
if _, err := exec.LookPath("openclaw"); err == nil {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if _, err := exec.LookPath("clawdbot"); err == nil {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
},
|
||||||
|
EnsureInstalled: func() error {
|
||||||
|
_, err := ensureOpenclawInstalled()
|
||||||
|
return err
|
||||||
|
},
|
||||||
|
URL: "https://docs.openclaw.ai",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "pi",
|
||||||
|
Runner: &Pi{},
|
||||||
|
Description: "Minimal AI agent toolkit with plugin support",
|
||||||
|
Install: IntegrationInstallSpec{
|
||||||
|
CheckInstalled: func() bool {
|
||||||
|
_, err := exec.LookPath("pi")
|
||||||
|
return err == nil
|
||||||
|
},
|
||||||
|
Command: []string{"npm", "install", "-g", "@mariozechner/pi-coding-agent"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
var integrationSpecsByName map[string]*IntegrationSpec
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
rebuildIntegrationSpecIndexes()
|
||||||
|
}
|
||||||
|
|
||||||
|
func hyperlink(url, text string) string {
|
||||||
|
return fmt.Sprintf("\033]8;;%s\033\\%s\033]8;;\033\\", url, text)
|
||||||
|
}
|
||||||
|
|
||||||
|
func rebuildIntegrationSpecIndexes() {
|
||||||
|
integrationSpecsByName = make(map[string]*IntegrationSpec, len(integrationSpecs))
|
||||||
|
|
||||||
|
canonical := make(map[string]bool, len(integrationSpecs))
|
||||||
|
for _, spec := range integrationSpecs {
|
||||||
|
key := strings.ToLower(spec.Name)
|
||||||
|
if key == "" {
|
||||||
|
panic("launch: integration spec missing name")
|
||||||
|
}
|
||||||
|
if canonical[key] {
|
||||||
|
panic(fmt.Sprintf("launch: duplicate integration name %q", key))
|
||||||
|
}
|
||||||
|
canonical[key] = true
|
||||||
|
integrationSpecsByName[key] = spec
|
||||||
|
}
|
||||||
|
|
||||||
|
seenAliases := make(map[string]string)
|
||||||
|
for _, spec := range integrationSpecs {
|
||||||
|
for _, alias := range spec.Aliases {
|
||||||
|
key := strings.ToLower(alias)
|
||||||
|
if key == "" {
|
||||||
|
panic(fmt.Sprintf("launch: integration %q has empty alias", spec.Name))
|
||||||
|
}
|
||||||
|
if canonical[key] {
|
||||||
|
panic(fmt.Sprintf("launch: alias %q collides with canonical integration name", key))
|
||||||
|
}
|
||||||
|
if owner, exists := seenAliases[key]; exists {
|
||||||
|
panic(fmt.Sprintf("launch: alias %q collides between %q and %q", key, owner, spec.Name))
|
||||||
|
}
|
||||||
|
seenAliases[key] = spec.Name
|
||||||
|
integrationSpecsByName[key] = spec
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
orderSeen := make(map[string]bool, len(launcherIntegrationOrder))
|
||||||
|
for _, name := range launcherIntegrationOrder {
|
||||||
|
key := strings.ToLower(name)
|
||||||
|
if orderSeen[key] {
|
||||||
|
panic(fmt.Sprintf("launch: duplicate launcher order entry %q", key))
|
||||||
|
}
|
||||||
|
orderSeen[key] = true
|
||||||
|
|
||||||
|
spec, ok := integrationSpecsByName[key]
|
||||||
|
if !ok {
|
||||||
|
panic(fmt.Sprintf("launch: unknown launcher order entry %q", key))
|
||||||
|
}
|
||||||
|
if spec.Name != key {
|
||||||
|
panic(fmt.Sprintf("launch: launcher order entry %q must use canonical name, not alias", key))
|
||||||
|
}
|
||||||
|
if spec.Hidden {
|
||||||
|
panic(fmt.Sprintf("launch: hidden integration %q cannot appear in launcher order", key))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// LookupIntegrationSpec resolves either a canonical integration name or alias to its spec.
|
||||||
|
func LookupIntegrationSpec(name string) (*IntegrationSpec, error) {
|
||||||
|
spec, ok := integrationSpecsByName[strings.ToLower(name)]
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("unknown integration: %s", name)
|
||||||
|
}
|
||||||
|
return spec, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// LookupIntegration resolves a registry name to the canonical key and runner.
|
||||||
|
func LookupIntegration(name string) (string, Runner, error) {
|
||||||
|
spec, err := LookupIntegrationSpec(name)
|
||||||
|
if err != nil {
|
||||||
|
return "", nil, err
|
||||||
|
}
|
||||||
|
return spec.Name, spec.Runner, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListVisibleIntegrationSpecs returns the canonical integrations that should appear in interactive UIs.
|
||||||
|
func ListVisibleIntegrationSpecs() []IntegrationSpec {
|
||||||
|
visible := make([]IntegrationSpec, 0, len(integrationSpecs))
|
||||||
|
for _, spec := range integrationSpecs {
|
||||||
|
if spec.Hidden {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
visible = append(visible, *spec)
|
||||||
|
}
|
||||||
|
|
||||||
|
orderRank := make(map[string]int, len(launcherIntegrationOrder))
|
||||||
|
for i, name := range launcherIntegrationOrder {
|
||||||
|
orderRank[name] = i + 1
|
||||||
|
}
|
||||||
|
|
||||||
|
slices.SortFunc(visible, func(a, b IntegrationSpec) int {
|
||||||
|
aRank, bRank := orderRank[a.Name], orderRank[b.Name]
|
||||||
|
if aRank > 0 && bRank > 0 {
|
||||||
|
return aRank - bRank
|
||||||
|
}
|
||||||
|
if aRank > 0 {
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
if bRank > 0 {
|
||||||
|
return -1
|
||||||
|
}
|
||||||
|
return strings.Compare(a.Name, b.Name)
|
||||||
|
})
|
||||||
|
|
||||||
|
return visible
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListIntegrationInfos returns the registered integrations in launcher display order.
|
||||||
|
func ListIntegrationInfos() []IntegrationInfo {
|
||||||
|
visible := ListVisibleIntegrationSpecs()
|
||||||
|
infos := make([]IntegrationInfo, 0, len(visible))
|
||||||
|
for _, spec := range visible {
|
||||||
|
infos = append(infos, IntegrationInfo{
|
||||||
|
Name: spec.Name,
|
||||||
|
DisplayName: spec.Runner.String(),
|
||||||
|
Description: spec.Description,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return infos
|
||||||
|
}
|
||||||
|
|
||||||
|
// IntegrationSelectionItems returns the sorted integration items shown by launcher selection UIs.
|
||||||
|
func IntegrationSelectionItems() ([]ModelItem, error) {
|
||||||
|
visible := ListVisibleIntegrationSpecs()
|
||||||
|
if len(visible) == 0 {
|
||||||
|
return nil, fmt.Errorf("no integrations available")
|
||||||
|
}
|
||||||
|
|
||||||
|
items := make([]ModelItem, 0, len(visible))
|
||||||
|
for _, spec := range visible {
|
||||||
|
description := spec.Runner.String()
|
||||||
|
if conn, err := loadStoredIntegrationConfig(spec.Name); err == nil && len(conn.Models) > 0 {
|
||||||
|
description = fmt.Sprintf("%s (%s)", spec.Runner.String(), conn.Models[0])
|
||||||
|
}
|
||||||
|
items = append(items, ModelItem{Name: spec.Name, Description: description})
|
||||||
|
}
|
||||||
|
return items, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsIntegrationInstalled checks if an integration binary is installed.
|
||||||
|
func IsIntegrationInstalled(name string) bool {
|
||||||
|
integration, err := integrationFor(name)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "Ollama couldn't find integration %q, so it'll show up as not installed.\n", name)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return integration.installed
|
||||||
|
}
|
||||||
|
|
||||||
|
// integration is resolved registry metadata used by launcher state and install checks.
|
||||||
|
// It combines immutable registry spec data with computed runtime traits.
|
||||||
|
type integration struct {
|
||||||
|
spec *IntegrationSpec
|
||||||
|
installed bool
|
||||||
|
autoInstallable bool
|
||||||
|
editor bool
|
||||||
|
installHint string
|
||||||
|
}
|
||||||
|
|
||||||
|
// integrationFor resolves an integration name into the canonical spec plus
|
||||||
|
// derived launcher/install traits used across registry and launch flows.
|
||||||
|
func integrationFor(name string) (integration, error) {
|
||||||
|
spec, err := LookupIntegrationSpec(name)
|
||||||
|
if err != nil {
|
||||||
|
return integration{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
installed := true
|
||||||
|
if spec.Install.CheckInstalled != nil {
|
||||||
|
installed = spec.Install.CheckInstalled()
|
||||||
|
}
|
||||||
|
|
||||||
|
_, editor := spec.Runner.(Editor)
|
||||||
|
hint := ""
|
||||||
|
if spec.Install.URL != "" {
|
||||||
|
hint = "Install from " + hyperlink(spec.Install.URL, spec.Install.URL)
|
||||||
|
} else if len(spec.Install.Command) > 0 {
|
||||||
|
hint = "Install with: " + strings.Join(spec.Install.Command, " ")
|
||||||
|
}
|
||||||
|
|
||||||
|
return integration{
|
||||||
|
spec: spec,
|
||||||
|
installed: installed,
|
||||||
|
autoInstallable: spec.Install.EnsureInstalled != nil,
|
||||||
|
editor: editor,
|
||||||
|
installHint: hint,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// EnsureIntegrationInstalled installs auto-installable integrations when missing.
|
||||||
|
func EnsureIntegrationInstalled(name string, runner Runner) error {
|
||||||
|
integration, err := integrationFor(name)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("%s is not installed", runner)
|
||||||
|
}
|
||||||
|
|
||||||
|
if integration.installed {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if integration.autoInstallable {
|
||||||
|
return integration.spec.Install.EnsureInstalled()
|
||||||
|
}
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case integration.spec.Install.URL != "":
|
||||||
|
return fmt.Errorf("%s is not installed, install from %s", integration.spec.Name, integration.spec.Install.URL)
|
||||||
|
case len(integration.spec.Install.Command) > 0:
|
||||||
|
return fmt.Errorf("%s is not installed, install with: %s", integration.spec.Name, strings.Join(integration.spec.Install.Command, " "))
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("%s is not installed", runner)
|
||||||
|
}
|
||||||
|
}
|
||||||
21
cmd/launch/registry_test_helpers_test.go
Normal file
21
cmd/launch/registry_test_helpers_test.go
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
package launch
|
||||||
|
|
||||||
|
import "strings"
|
||||||
|
|
||||||
|
// OverrideIntegration replaces one registry entry's runner for tests and returns a restore function.
|
||||||
|
func OverrideIntegration(name string, runner Runner) func() {
|
||||||
|
spec, err := LookupIntegrationSpec(name)
|
||||||
|
if err != nil {
|
||||||
|
key := strings.ToLower(name)
|
||||||
|
integrationSpecsByName[key] = &IntegrationSpec{Name: key, Runner: runner}
|
||||||
|
return func() {
|
||||||
|
delete(integrationSpecsByName, key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
original := spec.Runner
|
||||||
|
spec.Runner = runner
|
||||||
|
return func() {
|
||||||
|
spec.Runner = original
|
||||||
|
}
|
||||||
|
}
|
||||||
68
cmd/launch/runner_exec_only_test.go
Normal file
68
cmd/launch/runner_exec_only_test.go
Normal file
@@ -0,0 +1,68 @@
|
|||||||
|
package launch
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestEditorRunsDoNotRewriteConfig(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
binary string
|
||||||
|
runner Runner
|
||||||
|
checkPath func(home string) string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "droid",
|
||||||
|
binary: "droid",
|
||||||
|
runner: &Droid{},
|
||||||
|
checkPath: func(home string) string {
|
||||||
|
return filepath.Join(home, ".factory", "settings.json")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "opencode",
|
||||||
|
binary: "opencode",
|
||||||
|
runner: &OpenCode{},
|
||||||
|
checkPath: func(home string) string {
|
||||||
|
return filepath.Join(home, ".config", "opencode", "opencode.json")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "cline",
|
||||||
|
binary: "cline",
|
||||||
|
runner: &Cline{},
|
||||||
|
checkPath: func(home string) string {
|
||||||
|
return filepath.Join(home, ".cline", "data", "globalState.json")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "pi",
|
||||||
|
binary: "pi",
|
||||||
|
runner: &Pi{},
|
||||||
|
checkPath: func(home string) string {
|
||||||
|
return filepath.Join(home, ".pi", "agent", "models.json")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
home := t.TempDir()
|
||||||
|
setTestHome(t, home)
|
||||||
|
|
||||||
|
binDir := t.TempDir()
|
||||||
|
writeFakeBinary(t, binDir, tt.binary)
|
||||||
|
t.Setenv("PATH", binDir)
|
||||||
|
|
||||||
|
configPath := tt.checkPath(home)
|
||||||
|
if err := tt.runner.Run("llama3.2", nil); err != nil {
|
||||||
|
t.Fatalf("Run returned error: %v", err)
|
||||||
|
}
|
||||||
|
if _, err := os.Stat(configPath); !os.IsNotExist(err) {
|
||||||
|
t.Fatalf("expected Run to leave %s untouched, got err=%v", configPath, err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
103
cmd/launch/selector_hooks.go
Normal file
103
cmd/launch/selector_hooks.go
Normal file
@@ -0,0 +1,103 @@
|
|||||||
|
package launch
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
|
||||||
|
"golang.org/x/term"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ANSI escape sequences for terminal formatting.
|
||||||
|
const (
|
||||||
|
ansiBold = "\033[1m"
|
||||||
|
ansiReset = "\033[0m"
|
||||||
|
ansiGray = "\033[37m"
|
||||||
|
ansiGreen = "\033[32m"
|
||||||
|
ansiYellow = "\033[33m"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ErrCancelled is returned when the user cancels a selection.
|
||||||
|
var ErrCancelled = errors.New("cancelled")
|
||||||
|
|
||||||
|
// errCancelled is kept as an internal alias for existing call sites.
|
||||||
|
var errCancelled = ErrCancelled
|
||||||
|
|
||||||
|
// DefaultConfirmPrompt provides a TUI-based confirmation prompt.
|
||||||
|
// When set, ConfirmPrompt delegates to it instead of using raw terminal I/O.
|
||||||
|
var DefaultConfirmPrompt func(prompt string) (bool, error)
|
||||||
|
|
||||||
|
// SingleSelector is a function type for single item selection.
|
||||||
|
// current is the name of the previously selected item to highlight; empty means no pre-selection.
|
||||||
|
type SingleSelector func(title string, items []ModelItem, current string) (string, error)
|
||||||
|
|
||||||
|
// MultiSelector is a function type for multi item selection.
|
||||||
|
type MultiSelector func(title string, items []ModelItem, preChecked []string) ([]string, error)
|
||||||
|
|
||||||
|
// DefaultSingleSelector is the default single-select implementation.
|
||||||
|
var DefaultSingleSelector SingleSelector
|
||||||
|
|
||||||
|
// DefaultMultiSelector is the default multi-select implementation.
|
||||||
|
var DefaultMultiSelector MultiSelector
|
||||||
|
|
||||||
|
// DefaultSignIn provides a TUI-based sign-in flow.
|
||||||
|
// When set, ensureAuth uses it instead of plain text prompts.
|
||||||
|
// Returns the signed-in username or an error.
|
||||||
|
var DefaultSignIn func(modelName, signInURL string) (string, error)
|
||||||
|
|
||||||
|
type launchConfirmPolicy struct {
|
||||||
|
yes bool
|
||||||
|
requireYesMessage bool
|
||||||
|
}
|
||||||
|
|
||||||
|
var currentLaunchConfirmPolicy launchConfirmPolicy
|
||||||
|
|
||||||
|
func withLaunchConfirmPolicy(policy launchConfirmPolicy) func() {
|
||||||
|
old := currentLaunchConfirmPolicy
|
||||||
|
currentLaunchConfirmPolicy = policy
|
||||||
|
return func() {
|
||||||
|
currentLaunchConfirmPolicy = old
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ConfirmPrompt is the shared confirmation gate for launch flows (integration
|
||||||
|
// edits, missing-model pulls, sign-in prompts, OpenClaw install/security, etc).
|
||||||
|
// Behavior is controlled by currentLaunchConfirmPolicy, typically scoped by
|
||||||
|
// withLaunchConfirmPolicy in LaunchCmd (e.g. auto-approve with --yes).
|
||||||
|
func ConfirmPrompt(prompt string) (bool, error) {
|
||||||
|
if currentLaunchConfirmPolicy.yes {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
if currentLaunchConfirmPolicy.requireYesMessage {
|
||||||
|
return false, fmt.Errorf("%s requires confirmation; re-run with --yes to continue", prompt)
|
||||||
|
}
|
||||||
|
|
||||||
|
if DefaultConfirmPrompt != nil {
|
||||||
|
return DefaultConfirmPrompt(prompt)
|
||||||
|
}
|
||||||
|
|
||||||
|
fd := int(os.Stdin.Fd())
|
||||||
|
oldState, err := term.MakeRaw(fd)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
defer term.Restore(fd, oldState)
|
||||||
|
|
||||||
|
fmt.Fprintf(os.Stderr, "%s (\033[1my\033[0m/n) ", prompt)
|
||||||
|
|
||||||
|
buf := make([]byte, 1)
|
||||||
|
for {
|
||||||
|
if _, err := os.Stdin.Read(buf); err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
switch buf[0] {
|
||||||
|
case 'Y', 'y', 13:
|
||||||
|
fmt.Fprintf(os.Stderr, "yes\r\n")
|
||||||
|
return true, nil
|
||||||
|
case 'N', 'n', 27, 3:
|
||||||
|
fmt.Fprintf(os.Stderr, "no\r\n")
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
76
cmd/launch/selector_test.go
Normal file
76
cmd/launch/selector_test.go
Normal file
@@ -0,0 +1,76 @@
|
|||||||
|
package launch
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestErrCancelled(t *testing.T) {
|
||||||
|
t.Run("NotNil", func(t *testing.T) {
|
||||||
|
if errCancelled == nil {
|
||||||
|
t.Error("errCancelled should not be nil")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Message", func(t *testing.T) {
|
||||||
|
if errCancelled.Error() != "cancelled" {
|
||||||
|
t.Errorf("expected 'cancelled', got %q", errCancelled.Error())
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWithLaunchConfirmPolicy_ScopesAndRestores(t *testing.T) {
|
||||||
|
oldPolicy := currentLaunchConfirmPolicy
|
||||||
|
oldHook := DefaultConfirmPrompt
|
||||||
|
t.Cleanup(func() {
|
||||||
|
currentLaunchConfirmPolicy = oldPolicy
|
||||||
|
DefaultConfirmPrompt = oldHook
|
||||||
|
})
|
||||||
|
|
||||||
|
currentLaunchConfirmPolicy = launchConfirmPolicy{}
|
||||||
|
var hookCalls int
|
||||||
|
DefaultConfirmPrompt = func(prompt string) (bool, error) {
|
||||||
|
hookCalls++
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
restoreOuter := withLaunchConfirmPolicy(launchConfirmPolicy{requireYesMessage: true})
|
||||||
|
restoreInner := withLaunchConfirmPolicy(launchConfirmPolicy{yes: true})
|
||||||
|
|
||||||
|
ok, err := ConfirmPrompt("test prompt")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("expected --yes policy to allow prompt, got error: %v", err)
|
||||||
|
}
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("expected --yes policy to auto-accept prompt")
|
||||||
|
}
|
||||||
|
if hookCalls != 0 {
|
||||||
|
t.Fatalf("expected --yes to skip hook, got %d hook calls", hookCalls)
|
||||||
|
}
|
||||||
|
|
||||||
|
restoreInner()
|
||||||
|
|
||||||
|
_, err = ConfirmPrompt("test prompt")
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected requireYesMessage policy to block prompt")
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "re-run with --yes") {
|
||||||
|
t.Fatalf("expected actionable --yes error, got: %v", err)
|
||||||
|
}
|
||||||
|
if hookCalls != 0 {
|
||||||
|
t.Fatalf("expected blocking policy to skip hook, got %d hook calls", hookCalls)
|
||||||
|
}
|
||||||
|
|
||||||
|
restoreOuter()
|
||||||
|
|
||||||
|
ok, err = ConfirmPrompt("test prompt")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("expected restored default behavior to use hook, got error: %v", err)
|
||||||
|
}
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("expected hook to return true")
|
||||||
|
}
|
||||||
|
if hookCalls != 1 {
|
||||||
|
t.Fatalf("expected one hook call after restore, got %d", hookCalls)
|
||||||
|
}
|
||||||
|
}
|
||||||
82
cmd/launch/test_config_helpers_test.go
Normal file
82
cmd/launch/test_config_helpers_test.go
Normal file
@@ -0,0 +1,82 @@
|
|||||||
|
package launch
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/cmd/config"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
integrations map[string]Runner
|
||||||
|
integrationAliases map[string]bool
|
||||||
|
integrationOrder = launcherIntegrationOrder
|
||||||
|
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
integrations = buildTestIntegrations()
|
||||||
|
integrationAliases = buildTestIntegrationAliases()
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildTestIntegrations() map[string]Runner {
|
||||||
|
result := make(map[string]Runner, len(integrationSpecsByName))
|
||||||
|
for name, spec := range integrationSpecsByName {
|
||||||
|
result[strings.ToLower(name)] = spec.Runner
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildTestIntegrationAliases() map[string]bool {
|
||||||
|
result := make(map[string]bool)
|
||||||
|
for _, spec := range integrationSpecs {
|
||||||
|
for _, alias := range spec.Aliases {
|
||||||
|
result[strings.ToLower(alias)] = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
func setTestHome(t *testing.T, dir string) {
|
||||||
|
t.Helper()
|
||||||
|
setLaunchTestHome(t, dir)
|
||||||
|
}
|
||||||
|
|
||||||
|
func SaveIntegration(appName string, models []string) error {
|
||||||
|
return config.SaveIntegration(appName, models)
|
||||||
|
}
|
||||||
|
|
||||||
|
func LoadIntegration(appName string) (*config.IntegrationConfig, error) {
|
||||||
|
return config.LoadIntegration(appName)
|
||||||
|
}
|
||||||
|
|
||||||
|
func SaveAliases(appName string, aliases map[string]string) error {
|
||||||
|
return config.SaveAliases(appName, aliases)
|
||||||
|
}
|
||||||
|
|
||||||
|
func LastModel() string {
|
||||||
|
return config.LastModel()
|
||||||
|
}
|
||||||
|
|
||||||
|
func SetLastModel(model string) error {
|
||||||
|
return config.SetLastModel(model)
|
||||||
|
}
|
||||||
|
|
||||||
|
func LastSelection() string {
|
||||||
|
return config.LastSelection()
|
||||||
|
}
|
||||||
|
|
||||||
|
func SetLastSelection(selection string) error {
|
||||||
|
return config.SetLastSelection(selection)
|
||||||
|
}
|
||||||
|
|
||||||
|
func IntegrationModel(appName string) string {
|
||||||
|
return config.IntegrationModel(appName)
|
||||||
|
}
|
||||||
|
|
||||||
|
func IntegrationModels(appName string) []string {
|
||||||
|
return config.IntegrationModels(appName)
|
||||||
|
}
|
||||||
|
|
||||||
|
func integrationOnboarded(appName string) error {
|
||||||
|
return config.MarkIntegrationOnboarded(appName)
|
||||||
|
}
|
||||||
@@ -7,7 +7,7 @@ import (
|
|||||||
|
|
||||||
tea "github.com/charmbracelet/bubbletea"
|
tea "github.com/charmbracelet/bubbletea"
|
||||||
"github.com/charmbracelet/lipgloss"
|
"github.com/charmbracelet/lipgloss"
|
||||||
"github.com/ollama/ollama/cmd/config"
|
"github.com/ollama/ollama/cmd/launch"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@@ -64,8 +64,8 @@ type SelectItem struct {
|
|||||||
Recommended bool
|
Recommended bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// ConvertItems converts config.ModelItem slice to SelectItem slice.
|
// ConvertItems converts launch.ModelItem slice to SelectItem slice.
|
||||||
func ConvertItems(items []config.ModelItem) []SelectItem {
|
func ConvertItems(items []launch.ModelItem) []SelectItem {
|
||||||
out := make([]SelectItem, len(items))
|
out := make([]SelectItem, len(items))
|
||||||
for i, item := range items {
|
for i, item := range items {
|
||||||
out[i] = SelectItem{Name: item.Name, Description: item.Description, Recommended: item.Recommended}
|
out[i] = SelectItem{Name: item.Name, Description: item.Description, Recommended: item.Recommended}
|
||||||
@@ -101,6 +101,16 @@ type selectorModel struct {
|
|||||||
width int
|
width int
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func selectorModelWithCurrent(title string, items []SelectItem, current string) selectorModel {
|
||||||
|
m := selectorModel{
|
||||||
|
title: title,
|
||||||
|
items: items,
|
||||||
|
cursor: cursorForCurrent(items, current),
|
||||||
|
}
|
||||||
|
m.updateScroll(m.otherStart())
|
||||||
|
return m
|
||||||
|
}
|
||||||
|
|
||||||
func (m selectorModel) filteredItems() []SelectItem {
|
func (m selectorModel) filteredItems() []SelectItem {
|
||||||
if m.filter == "" {
|
if m.filter == "" {
|
||||||
return m.items
|
return m.items
|
||||||
@@ -367,13 +377,24 @@ func (m selectorModel) View() string {
|
|||||||
|
|
||||||
// cursorForCurrent returns the item index matching current, or 0 if not found.
|
// cursorForCurrent returns the item index matching current, or 0 if not found.
|
||||||
func cursorForCurrent(items []SelectItem, current string) int {
|
func cursorForCurrent(items []SelectItem, current string) int {
|
||||||
if current != "" {
|
if current == "" {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// Prefer exact name matches before tag-prefix fallback so "qwen3.5" does not
|
||||||
|
// incorrectly select "qwen3.5:cloud" (and vice versa) based on list order.
|
||||||
for i, item := range items {
|
for i, item := range items {
|
||||||
if item.Name == current || strings.HasPrefix(item.Name, current+":") || strings.HasPrefix(current, item.Name+":") {
|
if item.Name == current {
|
||||||
return i
|
return i
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
for i, item := range items {
|
||||||
|
if strings.HasPrefix(item.Name, current+":") || strings.HasPrefix(current, item.Name+":") {
|
||||||
|
return i
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -382,11 +403,7 @@ func SelectSingle(title string, items []SelectItem, current string) (string, err
|
|||||||
return "", fmt.Errorf("no items to select from")
|
return "", fmt.Errorf("no items to select from")
|
||||||
}
|
}
|
||||||
|
|
||||||
m := selectorModel{
|
m := selectorModelWithCurrent(title, items, current)
|
||||||
title: title,
|
|
||||||
items: items,
|
|
||||||
cursor: cursorForCurrent(items, current),
|
|
||||||
}
|
|
||||||
|
|
||||||
p := tea.NewProgram(m)
|
p := tea.NewProgram(m)
|
||||||
finalModel, err := p.Run()
|
finalModel, err := p.Run()
|
||||||
@@ -415,6 +432,12 @@ type multiSelectorModel struct {
|
|||||||
cancelled bool
|
cancelled bool
|
||||||
confirmed bool
|
confirmed bool
|
||||||
width int
|
width int
|
||||||
|
|
||||||
|
// multi enables full multi-select editing mode. The zero value (false)
|
||||||
|
// shows a single-select picker where Enter adds the chosen model to
|
||||||
|
// the existing list. Tab toggles between modes.
|
||||||
|
multi bool
|
||||||
|
singleAdd string // model picked in single mode
|
||||||
}
|
}
|
||||||
|
|
||||||
func newMultiSelectorModel(title string, items []SelectItem, preChecked []string) multiSelectorModel {
|
func newMultiSelectorModel(title string, items []SelectItem, preChecked []string) multiSelectorModel {
|
||||||
@@ -429,13 +452,23 @@ func newMultiSelectorModel(title string, items []SelectItem, preChecked []string
|
|||||||
m.itemIndex[item.Name] = i
|
m.itemIndex[item.Name] = i
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, name := range preChecked {
|
// Reverse order so preChecked[0] (the current default) ends up last
|
||||||
if idx, ok := m.itemIndex[name]; ok {
|
// in checkOrder, matching the "last checked = default" convention.
|
||||||
|
for i := len(preChecked) - 1; i >= 0; i-- {
|
||||||
|
if idx, ok := m.itemIndex[preChecked[i]]; ok {
|
||||||
m.checked[idx] = true
|
m.checked[idx] = true
|
||||||
m.checkOrder = append(m.checkOrder, idx)
|
m.checkOrder = append(m.checkOrder, idx)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Position cursor on the current default model
|
||||||
|
if len(preChecked) > 0 {
|
||||||
|
if idx, ok := m.itemIndex[preChecked[0]]; ok {
|
||||||
|
m.cursor = idx
|
||||||
|
m.updateScroll(m.otherStart())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return m
|
return m
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -507,6 +540,7 @@ func (m *multiSelectorModel) toggleItem() {
|
|||||||
origIdx := m.itemIndex[item.Name]
|
origIdx := m.itemIndex[item.Name]
|
||||||
|
|
||||||
if m.checked[origIdx] {
|
if m.checked[origIdx] {
|
||||||
|
wasDefault := len(m.checkOrder) > 0 && m.checkOrder[len(m.checkOrder)-1] == origIdx
|
||||||
delete(m.checked, origIdx)
|
delete(m.checked, origIdx)
|
||||||
for i, idx := range m.checkOrder {
|
for i, idx := range m.checkOrder {
|
||||||
if idx == origIdx {
|
if idx == origIdx {
|
||||||
@@ -514,6 +548,34 @@ func (m *multiSelectorModel) toggleItem() {
|
|||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if wasDefault {
|
||||||
|
// When removing the default, pick the nearest checked model above it
|
||||||
|
// (or below if none above) so default fallback follows list order.
|
||||||
|
newDefault := -1
|
||||||
|
for i := origIdx - 1; i >= 0; i-- {
|
||||||
|
if m.checked[i] {
|
||||||
|
newDefault = i
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if newDefault == -1 {
|
||||||
|
for i := origIdx + 1; i < len(m.items); i++ {
|
||||||
|
if m.checked[i] {
|
||||||
|
newDefault = i
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if newDefault != -1 {
|
||||||
|
for i, idx := range m.checkOrder {
|
||||||
|
if idx == newDefault {
|
||||||
|
m.checkOrder = append(m.checkOrder[:i], m.checkOrder[i+1:]...)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
m.checkOrder = append(m.checkOrder, newDefault)
|
||||||
|
}
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
m.checked[origIdx] = true
|
m.checked[origIdx] = true
|
||||||
m.checkOrder = append(m.checkOrder, origIdx)
|
m.checkOrder = append(m.checkOrder, origIdx)
|
||||||
@@ -546,14 +608,25 @@ func (m multiSelectorModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
|||||||
m.cancelled = true
|
m.cancelled = true
|
||||||
return m, tea.Quit
|
return m, tea.Quit
|
||||||
|
|
||||||
|
case tea.KeyTab:
|
||||||
|
m.multi = !m.multi
|
||||||
|
|
||||||
case tea.KeyEnter:
|
case tea.KeyEnter:
|
||||||
if len(m.checkOrder) > 0 {
|
if !m.multi {
|
||||||
|
if len(filtered) > 0 && m.cursor < len(filtered) {
|
||||||
|
m.singleAdd = filtered[m.cursor].Name
|
||||||
|
m.confirmed = true
|
||||||
|
return m, tea.Quit
|
||||||
|
}
|
||||||
|
} else if len(m.checkOrder) > 0 {
|
||||||
m.confirmed = true
|
m.confirmed = true
|
||||||
return m, tea.Quit
|
return m, tea.Quit
|
||||||
}
|
}
|
||||||
|
|
||||||
case tea.KeySpace:
|
case tea.KeySpace:
|
||||||
|
if m.multi {
|
||||||
m.toggleItem()
|
m.toggleItem()
|
||||||
|
}
|
||||||
|
|
||||||
case tea.KeyUp:
|
case tea.KeyUp:
|
||||||
if m.cursor > 0 {
|
if m.cursor > 0 {
|
||||||
@@ -592,7 +665,9 @@ func (m multiSelectorModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
|||||||
// On some terminals (e.g. Windows PowerShell), space arrives as
|
// On some terminals (e.g. Windows PowerShell), space arrives as
|
||||||
// KeyRunes instead of KeySpace. Intercept it so toggle still works.
|
// KeyRunes instead of KeySpace. Intercept it so toggle still works.
|
||||||
if len(msg.Runes) == 1 && msg.Runes[0] == ' ' {
|
if len(msg.Runes) == 1 && msg.Runes[0] == ' ' {
|
||||||
|
if m.multi {
|
||||||
m.toggleItem()
|
m.toggleItem()
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
m.filter += string(msg.Runes)
|
m.filter += string(msg.Runes)
|
||||||
m.cursor = 0
|
m.cursor = 0
|
||||||
@@ -604,6 +679,19 @@ func (m multiSelectorModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
|||||||
return m, nil
|
return m, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m multiSelectorModel) renderSingleItem(s *strings.Builder, item SelectItem, idx int) {
|
||||||
|
if idx == m.cursor {
|
||||||
|
s.WriteString(selectorSelectedItemStyle.Render("▸ " + item.Name))
|
||||||
|
} else {
|
||||||
|
s.WriteString(selectorItemStyle.Render(item.Name))
|
||||||
|
}
|
||||||
|
s.WriteString("\n")
|
||||||
|
if item.Description != "" {
|
||||||
|
s.WriteString(selectorDescLineStyle.Render(item.Description))
|
||||||
|
s.WriteString("\n")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (m multiSelectorModel) renderMultiItem(s *strings.Builder, item SelectItem, idx int) {
|
func (m multiSelectorModel) renderMultiItem(s *strings.Builder, item SelectItem, idx int) {
|
||||||
origIdx := m.itemIndex[item.Name]
|
origIdx := m.itemIndex[item.Name]
|
||||||
|
|
||||||
@@ -615,7 +703,7 @@ func (m multiSelectorModel) renderMultiItem(s *strings.Builder, item SelectItem,
|
|||||||
}
|
}
|
||||||
|
|
||||||
suffix := ""
|
suffix := ""
|
||||||
if len(m.checkOrder) > 0 && m.checkOrder[0] == origIdx {
|
if len(m.checkOrder) > 0 && m.checkOrder[len(m.checkOrder)-1] == origIdx {
|
||||||
suffix = " " + selectorDefaultTagStyle.Render("(default)")
|
suffix = " " + selectorDefaultTagStyle.Render("(default)")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -637,6 +725,11 @@ func (m multiSelectorModel) View() string {
|
|||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
|
renderItem := m.renderSingleItem
|
||||||
|
if m.multi {
|
||||||
|
renderItem = m.renderMultiItem
|
||||||
|
}
|
||||||
|
|
||||||
var s strings.Builder
|
var s strings.Builder
|
||||||
|
|
||||||
s.WriteString(selectorTitleStyle.Render(m.title))
|
s.WriteString(selectorTitleStyle.Render(m.title))
|
||||||
@@ -661,7 +754,7 @@ func (m multiSelectorModel) View() string {
|
|||||||
if idx >= len(filtered) {
|
if idx >= len(filtered) {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
m.renderMultiItem(&s, filtered[idx], idx)
|
renderItem(&s, filtered[idx], idx)
|
||||||
}
|
}
|
||||||
|
|
||||||
if remaining := len(filtered) - m.scrollOffset - displayCount; remaining > 0 {
|
if remaining := len(filtered) - m.scrollOffset - displayCount; remaining > 0 {
|
||||||
@@ -684,7 +777,7 @@ func (m multiSelectorModel) View() string {
|
|||||||
s.WriteString(sectionHeaderStyle.Render("Recommended"))
|
s.WriteString(sectionHeaderStyle.Render("Recommended"))
|
||||||
s.WriteString("\n")
|
s.WriteString("\n")
|
||||||
for _, idx := range recItems {
|
for _, idx := range recItems {
|
||||||
m.renderMultiItem(&s, filtered[idx], idx)
|
renderItem(&s, filtered[idx], idx)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -704,7 +797,7 @@ func (m multiSelectorModel) View() string {
|
|||||||
if idx >= len(otherItems) {
|
if idx >= len(otherItems) {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
m.renderMultiItem(&s, filtered[otherItems[idx]], otherItems[idx])
|
renderItem(&s, filtered[otherItems[idx]], otherItems[idx])
|
||||||
}
|
}
|
||||||
|
|
||||||
if remaining := len(otherItems) - m.scrollOffset - displayCount; remaining > 0 {
|
if remaining := len(otherItems) - m.scrollOffset - displayCount; remaining > 0 {
|
||||||
@@ -716,6 +809,9 @@ func (m multiSelectorModel) View() string {
|
|||||||
|
|
||||||
s.WriteString("\n")
|
s.WriteString("\n")
|
||||||
|
|
||||||
|
if !m.multi {
|
||||||
|
s.WriteString(selectorHelpStyle.Render("↑/↓ navigate • enter select • tab add multiple • esc cancel"))
|
||||||
|
} else {
|
||||||
count := m.selectedCount()
|
count := m.selectedCount()
|
||||||
if count == 0 {
|
if count == 0 {
|
||||||
s.WriteString(selectorDescStyle.Render(" Select at least one model."))
|
s.WriteString(selectorDescStyle.Render(" Select at least one model."))
|
||||||
@@ -723,8 +819,8 @@ func (m multiSelectorModel) View() string {
|
|||||||
s.WriteString(selectorDescStyle.Render(fmt.Sprintf(" %d selected - press enter to continue", count)))
|
s.WriteString(selectorDescStyle.Render(fmt.Sprintf(" %d selected - press enter to continue", count)))
|
||||||
}
|
}
|
||||||
s.WriteString("\n\n")
|
s.WriteString("\n\n")
|
||||||
|
s.WriteString(selectorHelpStyle.Render("↑/↓ navigate • space toggle • tab select single • enter confirm • esc cancel"))
|
||||||
s.WriteString(selectorHelpStyle.Render("↑/↓ navigate • space toggle • enter confirm • esc cancel"))
|
}
|
||||||
|
|
||||||
result := s.String()
|
result := s.String()
|
||||||
if m.width > 0 {
|
if m.width > 0 {
|
||||||
@@ -747,18 +843,28 @@ func SelectMultiple(title string, items []SelectItem, preChecked []string) ([]st
|
|||||||
}
|
}
|
||||||
|
|
||||||
fm := finalModel.(multiSelectorModel)
|
fm := finalModel.(multiSelectorModel)
|
||||||
if fm.cancelled {
|
if fm.cancelled || !fm.confirmed {
|
||||||
return nil, ErrCancelled
|
return nil, ErrCancelled
|
||||||
}
|
}
|
||||||
|
|
||||||
if !fm.confirmed {
|
// Single-add mode: prepend the picked model, keep existing models deduped
|
||||||
return nil, ErrCancelled
|
if fm.singleAdd != "" {
|
||||||
|
result := []string{fm.singleAdd}
|
||||||
|
for _, name := range preChecked {
|
||||||
|
if name != fm.singleAdd {
|
||||||
|
result = append(result, name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Multi-edit mode: last checked is default (first in result)
|
||||||
|
last := fm.checkOrder[len(fm.checkOrder)-1]
|
||||||
|
result := []string{fm.items[last].Name}
|
||||||
|
for _, idx := range fm.checkOrder {
|
||||||
|
if idx != last {
|
||||||
|
result = append(result, fm.items[idx].Name)
|
||||||
}
|
}
|
||||||
|
|
||||||
var result []string
|
|
||||||
for _, idx := range fm.checkOrder {
|
|
||||||
result = append(result, fm.items[idx].Name)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return result, nil
|
return result, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -216,6 +216,41 @@ func TestUpdateScroll(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestSelectorModelWithCurrent_ScrollsToCurrentInMoreSection(t *testing.T) {
|
||||||
|
m := selectorModelWithCurrent("Pick:", mixedItems(), "other-10")
|
||||||
|
|
||||||
|
if m.cursor != 11 {
|
||||||
|
t.Fatalf("cursor = %d, want 11", m.cursor)
|
||||||
|
}
|
||||||
|
if m.scrollOffset == 0 {
|
||||||
|
t.Fatal("scrollOffset should move to reveal current item in More section")
|
||||||
|
}
|
||||||
|
|
||||||
|
content := m.renderContent()
|
||||||
|
if !strings.Contains(content, "▸ other-10") {
|
||||||
|
t.Fatalf("expected current item to be visible and highlighted\n%s", content)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSelectorModelWithCurrent_HighlightsExactLocalWhenCloudVariantExists(t *testing.T) {
|
||||||
|
m := selectorModelWithCurrent("Pick:", []SelectItem{
|
||||||
|
{Name: "qwen3.5:cloud", Recommended: true},
|
||||||
|
{Name: "qwen3.5", Recommended: true},
|
||||||
|
}, "qwen3.5")
|
||||||
|
|
||||||
|
if m.cursor != 1 {
|
||||||
|
t.Fatalf("cursor = %d, want 1", m.cursor)
|
||||||
|
}
|
||||||
|
|
||||||
|
content := m.renderContent()
|
||||||
|
if !strings.Contains(content, "▸ qwen3.5") {
|
||||||
|
t.Fatalf("expected local qwen3.5 to be highlighted\n%s", content)
|
||||||
|
}
|
||||||
|
if strings.Contains(content, "▸ qwen3.5:cloud") {
|
||||||
|
t.Fatalf("did not expect cloud qwen3.5:cloud to be highlighted\n%s", content)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestRenderContent_SectionHeaders(t *testing.T) {
|
func TestRenderContent_SectionHeaders(t *testing.T) {
|
||||||
m := selectorModel{
|
m := selectorModel{
|
||||||
title: "Pick:",
|
title: "Pick:",
|
||||||
@@ -418,6 +453,28 @@ func TestCursorForCurrent(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestCursorForCurrent_PrefersExactLocalOverCloudPrefix(t *testing.T) {
|
||||||
|
testItems := []SelectItem{
|
||||||
|
{Name: "qwen3.5:cloud", Recommended: true},
|
||||||
|
{Name: "qwen3.5", Recommended: true},
|
||||||
|
}
|
||||||
|
|
||||||
|
if got := cursorForCurrent(testItems, "qwen3.5"); got != 1 {
|
||||||
|
t.Errorf("cursorForCurrent(%q) = %d, want %d", "qwen3.5", got, 1)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCursorForCurrent_PrefersExactCloudOverLocalPrefix(t *testing.T) {
|
||||||
|
testItems := []SelectItem{
|
||||||
|
{Name: "qwen3.5", Recommended: true},
|
||||||
|
{Name: "qwen3.5:cloud", Recommended: true},
|
||||||
|
}
|
||||||
|
|
||||||
|
if got := cursorForCurrent(testItems, "qwen3.5:cloud"); got != 1 {
|
||||||
|
t.Errorf("cursorForCurrent(%q) = %d, want %d", "qwen3.5:cloud", got, 1)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// --- ReorderItems ---
|
// --- ReorderItems ---
|
||||||
|
|
||||||
func TestReorderItems(t *testing.T) {
|
func TestReorderItems(t *testing.T) {
|
||||||
@@ -539,6 +596,7 @@ func TestMultiView_CursorIndicator(t *testing.T) {
|
|||||||
|
|
||||||
func TestMultiView_CheckedItemShowsX(t *testing.T) {
|
func TestMultiView_CheckedItemShowsX(t *testing.T) {
|
||||||
m := newMultiSelectorModel("Pick:", items("a", "b"), []string{"a"})
|
m := newMultiSelectorModel("Pick:", items("a", "b"), []string{"a"})
|
||||||
|
m.multi = true
|
||||||
content := m.View()
|
content := m.View()
|
||||||
|
|
||||||
if !strings.Contains(content, "[x]") {
|
if !strings.Contains(content, "[x]") {
|
||||||
@@ -550,11 +608,18 @@ func TestMultiView_CheckedItemShowsX(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestMultiView_DefaultTag(t *testing.T) {
|
func TestMultiView_DefaultTag(t *testing.T) {
|
||||||
m := newMultiSelectorModel("Pick:", items("a", "b"), []string{"a"})
|
m := newMultiSelectorModel("Pick:", items("a", "b", "c"), []string{"a", "b"})
|
||||||
|
m.multi = true
|
||||||
content := m.View()
|
content := m.View()
|
||||||
|
|
||||||
if !strings.Contains(content, "(default)") {
|
if !strings.Contains(content, "(default)") {
|
||||||
t.Error("first checked item should have (default) tag")
|
t.Error("should have (default) tag")
|
||||||
|
}
|
||||||
|
// preChecked[0] ("a") should be the default (last in checkOrder)
|
||||||
|
aIdx := strings.Index(content, "a")
|
||||||
|
defaultIdx := strings.Index(content, "(default)")
|
||||||
|
if defaultIdx < aIdx {
|
||||||
|
t.Error("(default) tag should appear after 'a' (the current default)")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -585,6 +650,7 @@ func TestMultiView_OverflowIndicator(t *testing.T) {
|
|||||||
|
|
||||||
func TestMultiUpdate_SpaceTogglesItem(t *testing.T) {
|
func TestMultiUpdate_SpaceTogglesItem(t *testing.T) {
|
||||||
m := newMultiSelectorModel("Pick:", items("a", "b", "c"), nil)
|
m := newMultiSelectorModel("Pick:", items("a", "b", "c"), nil)
|
||||||
|
m.multi = true
|
||||||
m.cursor = 1
|
m.cursor = 1
|
||||||
|
|
||||||
// Simulate space delivered as tea.KeySpace
|
// Simulate space delivered as tea.KeySpace
|
||||||
@@ -601,6 +667,7 @@ func TestMultiUpdate_SpaceTogglesItem(t *testing.T) {
|
|||||||
|
|
||||||
func TestMultiUpdate_SpaceRuneTogglesItem(t *testing.T) {
|
func TestMultiUpdate_SpaceRuneTogglesItem(t *testing.T) {
|
||||||
m := newMultiSelectorModel("Pick:", items("a", "b", "c"), nil)
|
m := newMultiSelectorModel("Pick:", items("a", "b", "c"), nil)
|
||||||
|
m.multi = true
|
||||||
m.cursor = 1
|
m.cursor = 1
|
||||||
|
|
||||||
// Simulate space delivered as tea.KeyRunes (Windows PowerShell behavior)
|
// Simulate space delivered as tea.KeyRunes (Windows PowerShell behavior)
|
||||||
@@ -618,6 +685,189 @@ func TestMultiUpdate_SpaceRuneTogglesItem(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// --- Single-add mode ---
|
||||||
|
|
||||||
|
func TestMulti_StartsInSingleMode(t *testing.T) {
|
||||||
|
m := newMultiSelectorModel("Pick:", items("a", "b"), nil)
|
||||||
|
if m.multi {
|
||||||
|
t.Error("should start in single mode (multi=false)")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMulti_SingleModeNoCheckboxes(t *testing.T) {
|
||||||
|
m := newMultiSelectorModel("Pick:", items("a", "b"), nil)
|
||||||
|
content := m.View()
|
||||||
|
if strings.Contains(content, "[x]") || strings.Contains(content, "[ ]") {
|
||||||
|
t.Error("single mode should not show checkboxes")
|
||||||
|
}
|
||||||
|
if !strings.Contains(content, "▸") {
|
||||||
|
t.Error("single mode should show cursor indicator")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMulti_SingleModeEnterPicksItem(t *testing.T) {
|
||||||
|
m := newMultiSelectorModel("Pick:", items("a", "b", "c"), nil)
|
||||||
|
m.cursor = 1
|
||||||
|
|
||||||
|
updated, _ := m.Update(tea.KeyMsg{Type: tea.KeyEnter})
|
||||||
|
m = updated.(multiSelectorModel)
|
||||||
|
|
||||||
|
if m.singleAdd != "b" {
|
||||||
|
t.Errorf("enter in single mode should pick cursor item, got %q", m.singleAdd)
|
||||||
|
}
|
||||||
|
if !m.confirmed {
|
||||||
|
t.Error("should set confirmed")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMulti_SingleModeSpaceIsNoop(t *testing.T) {
|
||||||
|
m := newMultiSelectorModel("Pick:", items("a", "b"), nil)
|
||||||
|
m.cursor = 0
|
||||||
|
|
||||||
|
updated, _ := m.Update(tea.KeyMsg{Type: tea.KeySpace})
|
||||||
|
m = updated.(multiSelectorModel)
|
||||||
|
|
||||||
|
if len(m.checked) != 0 {
|
||||||
|
t.Error("space in single mode should not toggle items")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMulti_SingleModeSpaceRuneIsNoop(t *testing.T) {
|
||||||
|
m := newMultiSelectorModel("Pick:", items("a", "b"), nil)
|
||||||
|
m.cursor = 0
|
||||||
|
|
||||||
|
updated, _ := m.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune{' '}})
|
||||||
|
m = updated.(multiSelectorModel)
|
||||||
|
|
||||||
|
if len(m.checked) != 0 {
|
||||||
|
t.Error("space rune in single mode should not toggle items")
|
||||||
|
}
|
||||||
|
if m.filter != "" {
|
||||||
|
t.Error("space rune in single mode should not add to filter")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMulti_TabTogglesMode(t *testing.T) {
|
||||||
|
m := newMultiSelectorModel("Pick:", items("a", "b"), nil)
|
||||||
|
|
||||||
|
if m.multi {
|
||||||
|
t.Fatal("should start in single mode")
|
||||||
|
}
|
||||||
|
|
||||||
|
updated, _ := m.Update(tea.KeyMsg{Type: tea.KeyTab})
|
||||||
|
m = updated.(multiSelectorModel)
|
||||||
|
if !m.multi {
|
||||||
|
t.Error("tab should switch to multi mode")
|
||||||
|
}
|
||||||
|
|
||||||
|
updated, _ = m.Update(tea.KeyMsg{Type: tea.KeyTab})
|
||||||
|
m = updated.(multiSelectorModel)
|
||||||
|
if m.multi {
|
||||||
|
t.Error("tab should switch back to single mode")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMulti_SingleModeHelpText(t *testing.T) {
|
||||||
|
m := newMultiSelectorModel("Pick:", items("a"), nil)
|
||||||
|
content := m.View()
|
||||||
|
if !strings.Contains(content, "tab add multiple") {
|
||||||
|
t.Error("single mode should show 'tab add multiple' in help")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMulti_MultiModeHelpText(t *testing.T) {
|
||||||
|
m := newMultiSelectorModel("Pick:", items("a"), nil)
|
||||||
|
m.multi = true
|
||||||
|
content := m.View()
|
||||||
|
if !strings.Contains(content, "tab select single") {
|
||||||
|
t.Error("multi mode should show 'tab select single' in help")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- preChecked initialization order ---
|
||||||
|
|
||||||
|
func TestMulti_PreCheckedDefaultIsLast(t *testing.T) {
|
||||||
|
// preChecked[0] ("a") is the current default and should end up
|
||||||
|
// last in checkOrder so it gets the (default) tag.
|
||||||
|
m := newMultiSelectorModel("Pick:", items("a", "b", "c"), []string{"a", "b", "c"})
|
||||||
|
|
||||||
|
if len(m.checkOrder) != 3 {
|
||||||
|
t.Fatalf("expected 3 in checkOrder, got %d", len(m.checkOrder))
|
||||||
|
}
|
||||||
|
lastIdx := m.checkOrder[len(m.checkOrder)-1]
|
||||||
|
if m.items[lastIdx].Name != "a" {
|
||||||
|
t.Errorf("preChecked[0] should be last in checkOrder, got %q", m.items[lastIdx].Name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMulti_CursorOnDefaultModel(t *testing.T) {
|
||||||
|
// preChecked[0] ("b") is the default; cursor should start on it
|
||||||
|
m := newMultiSelectorModel("Pick:", items("a", "b", "c"), []string{"b", "c"})
|
||||||
|
|
||||||
|
if m.cursor != 1 {
|
||||||
|
t.Errorf("cursor should be on preChecked[0] ('b') at index 1, got %d", m.cursor)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Multi-mode last-checked is default ---
|
||||||
|
|
||||||
|
func TestMulti_LastCheckedIsDefault(t *testing.T) {
|
||||||
|
m := newMultiSelectorModel("Pick:", items("alpha", "beta", "gamma"), nil)
|
||||||
|
m.multi = true
|
||||||
|
|
||||||
|
// Check "alpha" then "gamma"
|
||||||
|
m.cursor = 0
|
||||||
|
m.toggleItem()
|
||||||
|
m.cursor = 2
|
||||||
|
m.toggleItem()
|
||||||
|
|
||||||
|
// Last checked ("gamma") should be at the end of checkOrder
|
||||||
|
lastIdx := m.checkOrder[len(m.checkOrder)-1]
|
||||||
|
if m.items[lastIdx].Name != "gamma" {
|
||||||
|
t.Errorf("last checked should be 'gamma', got %q", m.items[lastIdx].Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
// The (default) tag renders based on checkOrder[len-1]
|
||||||
|
content := m.View()
|
||||||
|
if !strings.Contains(content, "(default)") {
|
||||||
|
t.Fatal("should show (default) tag")
|
||||||
|
}
|
||||||
|
// "alpha" line should NOT have the default tag
|
||||||
|
for _, line := range strings.Split(content, "\n") {
|
||||||
|
if strings.Contains(line, "alpha") && strings.Contains(line, "(default)") {
|
||||||
|
t.Error("'alpha' (first checked) should not have (default) tag")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMulti_UncheckingDefaultFallsBackToNearestCheckedAbove(t *testing.T) {
|
||||||
|
// Default is "b", and checked models are "a", "b", "c".
|
||||||
|
// Unticking default should make "a" (the nearest checked item above) default.
|
||||||
|
m := newMultiSelectorModel("Pick:", items("a", "b", "c"), []string{"b", "c", "a"})
|
||||||
|
m.multi = true
|
||||||
|
m.cursor = 1 // "b"
|
||||||
|
m.toggleItem()
|
||||||
|
|
||||||
|
lastIdx := m.checkOrder[len(m.checkOrder)-1]
|
||||||
|
if m.items[lastIdx].Name != "a" {
|
||||||
|
t.Fatalf("expected default to fall back to 'a', got %q", m.items[lastIdx].Name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMulti_UncheckingTopDefaultFallsBackToNearestCheckedBelow(t *testing.T) {
|
||||||
|
// Default is top item "a". With no checked item above, fallback should pick
|
||||||
|
// the nearest checked item below ("b").
|
||||||
|
m := newMultiSelectorModel("Pick:", items("a", "b", "c"), []string{"a", "c", "b"})
|
||||||
|
m.multi = true
|
||||||
|
m.cursor = 0 // "a"
|
||||||
|
m.toggleItem()
|
||||||
|
|
||||||
|
lastIdx := m.checkOrder[len(m.checkOrder)-1]
|
||||||
|
if m.items[lastIdx].Name != "b" {
|
||||||
|
t.Fatalf("expected default to fall back to 'b', got %q", m.items[lastIdx].Name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Key message helpers for testing
|
// Key message helpers for testing
|
||||||
|
|
||||||
type keyType = int
|
type keyType = int
|
||||||
|
|||||||
@@ -1,15 +1,24 @@
|
|||||||
package tui
|
package tui
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
tea "github.com/charmbracelet/bubbletea"
|
tea "github.com/charmbracelet/bubbletea"
|
||||||
"github.com/charmbracelet/lipgloss"
|
"github.com/charmbracelet/lipgloss"
|
||||||
"github.com/ollama/ollama/cmd/config"
|
"github.com/ollama/ollama/api"
|
||||||
|
"github.com/ollama/ollama/cmd/launch"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type signInTickMsg struct{}
|
||||||
|
|
||||||
|
type signInCheckMsg struct {
|
||||||
|
signedIn bool
|
||||||
|
userName string
|
||||||
|
}
|
||||||
|
|
||||||
type signInModel struct {
|
type signInModel struct {
|
||||||
modelName string
|
modelName string
|
||||||
signInURL string
|
signInURL string
|
||||||
@@ -104,9 +113,21 @@ func renderSignIn(modelName, signInURL string, spinner, width int) string {
|
|||||||
return lipgloss.NewStyle().PaddingLeft(2).Render(s.String())
|
return lipgloss.NewStyle().PaddingLeft(2).Render(s.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func checkSignIn() tea.Msg {
|
||||||
|
client, err := api.ClientFromEnvironment()
|
||||||
|
if err != nil {
|
||||||
|
return signInCheckMsg{signedIn: false}
|
||||||
|
}
|
||||||
|
user, err := client.Whoami(context.Background())
|
||||||
|
if err == nil && user != nil && user.Name != "" {
|
||||||
|
return signInCheckMsg{signedIn: true, userName: user.Name}
|
||||||
|
}
|
||||||
|
return signInCheckMsg{signedIn: false}
|
||||||
|
}
|
||||||
|
|
||||||
// RunSignIn shows a bubbletea sign-in dialog and polls until the user signs in or cancels.
|
// RunSignIn shows a bubbletea sign-in dialog and polls until the user signs in or cancels.
|
||||||
func RunSignIn(modelName, signInURL string) (string, error) {
|
func RunSignIn(modelName, signInURL string) (string, error) {
|
||||||
config.OpenBrowser(signInURL)
|
launch.OpenBrowser(signInURL)
|
||||||
|
|
||||||
m := signInModel{
|
m := signInModel{
|
||||||
modelName: modelName,
|
modelName: modelName,
|
||||||
|
|||||||
777
cmd/tui/tui.go
777
cmd/tui/tui.go
@@ -1,16 +1,11 @@
|
|||||||
package tui
|
package tui
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
tea "github.com/charmbracelet/bubbletea"
|
tea "github.com/charmbracelet/bubbletea"
|
||||||
"github.com/charmbracelet/lipgloss"
|
"github.com/charmbracelet/lipgloss"
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/cmd/launch"
|
||||||
"github.com/ollama/ollama/cmd/config"
|
|
||||||
"github.com/ollama/ollama/version"
|
"github.com/ollama/ollama/version"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -45,7 +40,7 @@ var (
|
|||||||
type menuItem struct {
|
type menuItem struct {
|
||||||
title string
|
title string
|
||||||
description string
|
description string
|
||||||
integration string // integration name for loading model config, empty if not an integration
|
integration string
|
||||||
isRunModel bool
|
isRunModel bool
|
||||||
isOthers bool
|
isOthers bool
|
||||||
}
|
}
|
||||||
@@ -57,18 +52,12 @@ var mainMenuItems = []menuItem{
|
|||||||
isRunModel: true,
|
isRunModel: true,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
title: "Launch Claude Code",
|
|
||||||
description: "Agentic coding across large codebases",
|
|
||||||
integration: "claude",
|
integration: "claude",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
title: "Launch Codex",
|
|
||||||
description: "OpenAI's open-source coding agent",
|
|
||||||
integration: "codex",
|
integration: "codex",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
title: "Launch OpenClaw",
|
|
||||||
description: "Personal AI with 100+ skills",
|
|
||||||
integration: "openclaw",
|
integration: "openclaw",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -79,277 +68,106 @@ var othersMenuItem = menuItem{
|
|||||||
isOthers: true,
|
isOthers: true,
|
||||||
}
|
}
|
||||||
|
|
||||||
// getOtherIntegrations dynamically builds the "Others" list from the integration
|
type model struct {
|
||||||
// registry, excluding any integrations already present in the pinned mainMenuItems.
|
state *launch.LauncherState
|
||||||
func getOtherIntegrations() []menuItem {
|
items []menuItem
|
||||||
pinned := map[string]bool{
|
cursor int
|
||||||
"run": true, // not an integration but in the pinned list
|
showOthers bool
|
||||||
|
width int
|
||||||
|
quitting bool
|
||||||
|
selected bool
|
||||||
|
action TUIAction
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func newModel(state *launch.LauncherState) model {
|
||||||
|
m := model{
|
||||||
|
state: state,
|
||||||
|
}
|
||||||
|
m.showOthers = shouldExpandOthers(state)
|
||||||
|
m.items = buildMenuItems(state, m.showOthers)
|
||||||
|
m.cursor = initialCursor(state, m.items)
|
||||||
|
return m
|
||||||
|
}
|
||||||
|
|
||||||
|
func shouldExpandOthers(state *launch.LauncherState) bool {
|
||||||
|
if state == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
for _, item := range otherIntegrationItems(state) {
|
||||||
|
if item.integration == state.LastSelection {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildMenuItems(state *launch.LauncherState, showOthers bool) []menuItem {
|
||||||
|
items := make([]menuItem, 0, len(mainMenuItems)+1)
|
||||||
for _, item := range mainMenuItems {
|
for _, item := range mainMenuItems {
|
||||||
if item.integration != "" {
|
if item.integration == "" {
|
||||||
pinned[item.integration] = true
|
items = append(items, item)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if integrationState, ok := state.Integrations[item.integration]; ok {
|
||||||
|
items = append(items, integrationMenuItem(integrationState))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
var others []menuItem
|
if showOthers {
|
||||||
for _, info := range config.ListIntegrationInfos() {
|
items = append(items, otherIntegrationItems(state)...)
|
||||||
|
} else {
|
||||||
|
items = append(items, othersMenuItem)
|
||||||
|
}
|
||||||
|
|
||||||
|
return items
|
||||||
|
}
|
||||||
|
|
||||||
|
func integrationMenuItem(state launch.LauncherIntegrationState) menuItem {
|
||||||
|
description := state.Description
|
||||||
|
if description == "" {
|
||||||
|
description = "Open " + state.DisplayName + " integration"
|
||||||
|
}
|
||||||
|
return menuItem{
|
||||||
|
title: "Launch " + state.DisplayName,
|
||||||
|
description: description,
|
||||||
|
integration: state.Name,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func otherIntegrationItems(state *launch.LauncherState) []menuItem {
|
||||||
|
pinned := map[string]bool{
|
||||||
|
"claude": true,
|
||||||
|
"codex": true,
|
||||||
|
"openclaw": true,
|
||||||
|
}
|
||||||
|
|
||||||
|
var items []menuItem
|
||||||
|
for _, info := range launch.ListIntegrationInfos() {
|
||||||
if pinned[info.Name] {
|
if pinned[info.Name] {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
desc := info.Description
|
integrationState, ok := state.Integrations[info.Name]
|
||||||
if desc == "" {
|
if !ok {
|
||||||
desc = "Open " + info.DisplayName + " integration"
|
|
||||||
}
|
|
||||||
others = append(others, menuItem{
|
|
||||||
title: "Launch " + info.DisplayName,
|
|
||||||
description: desc,
|
|
||||||
integration: info.Name,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
return others
|
|
||||||
}
|
|
||||||
|
|
||||||
type model struct {
|
|
||||||
items []menuItem
|
|
||||||
cursor int
|
|
||||||
quitting bool
|
|
||||||
selected bool
|
|
||||||
changeModel bool
|
|
||||||
changeModels []string // multi-select result for Editor integrations
|
|
||||||
showOthers bool
|
|
||||||
availableModels map[string]bool
|
|
||||||
err error
|
|
||||||
|
|
||||||
showingModal bool
|
|
||||||
modalSelector selectorModel
|
|
||||||
modalItems []SelectItem
|
|
||||||
|
|
||||||
showingMultiModal bool
|
|
||||||
multiModalSelector multiSelectorModel
|
|
||||||
|
|
||||||
showingSignIn bool
|
|
||||||
signInURL string
|
|
||||||
signInModel string
|
|
||||||
signInSpinner int
|
|
||||||
signInFromModal bool // true if sign-in was triggered from modal (not main menu)
|
|
||||||
|
|
||||||
width int // terminal width from WindowSizeMsg
|
|
||||||
statusMsg string // temporary status message shown near help text
|
|
||||||
}
|
|
||||||
|
|
||||||
type signInTickMsg struct{}
|
|
||||||
|
|
||||||
type signInCheckMsg struct {
|
|
||||||
signedIn bool
|
|
||||||
userName string
|
|
||||||
}
|
|
||||||
|
|
||||||
type clearStatusMsg struct{}
|
|
||||||
|
|
||||||
func (m *model) modelExists(name string) bool {
|
|
||||||
if m.availableModels == nil || name == "" {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
if m.availableModels[name] {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
// Check for prefix match (e.g., "llama2" matches "llama2:latest")
|
|
||||||
for modelName := range m.availableModels {
|
|
||||||
if strings.HasPrefix(modelName, name+":") {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *model) buildModalItems() []SelectItem {
|
|
||||||
modelItems, _ := config.GetModelItems(context.Background())
|
|
||||||
return ReorderItems(ConvertItems(modelItems))
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *model) openModelModal(currentModel string) {
|
|
||||||
m.modalItems = m.buildModalItems()
|
|
||||||
cursor := 0
|
|
||||||
if currentModel != "" {
|
|
||||||
for i, item := range m.modalItems {
|
|
||||||
if item.Name == currentModel || strings.HasPrefix(item.Name, currentModel+":") || strings.HasPrefix(currentModel, item.Name+":") {
|
|
||||||
cursor = i
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
m.modalSelector = selectorModel{
|
|
||||||
title: "Select model:",
|
|
||||||
items: m.modalItems,
|
|
||||||
cursor: cursor,
|
|
||||||
helpText: "↑/↓ navigate • enter select • ← back",
|
|
||||||
}
|
|
||||||
m.modalSelector.updateScroll(m.modalSelector.otherStart())
|
|
||||||
m.showingModal = true
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *model) openMultiModelModal(integration string) {
|
|
||||||
items := m.buildModalItems()
|
|
||||||
var preChecked []string
|
|
||||||
if models := config.IntegrationModels(integration); len(models) > 0 {
|
|
||||||
preChecked = models
|
|
||||||
}
|
|
||||||
m.multiModalSelector = newMultiSelectorModel("Select models:", items, preChecked)
|
|
||||||
// Set cursor to the first pre-checked (last used) model
|
|
||||||
if len(preChecked) > 0 {
|
|
||||||
for i, item := range items {
|
|
||||||
if item.Name == preChecked[0] {
|
|
||||||
m.multiModalSelector.cursor = i
|
|
||||||
m.multiModalSelector.updateScroll(m.multiModalSelector.otherStart())
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
m.showingMultiModal = true
|
|
||||||
}
|
|
||||||
|
|
||||||
func isCloudModel(name string) bool {
|
|
||||||
return strings.HasSuffix(name, ":cloud") || strings.HasSuffix(name, "-cloud")
|
|
||||||
}
|
|
||||||
|
|
||||||
func cloudStatusDisabled(client *api.Client) bool {
|
|
||||||
status, err := client.CloudStatusExperimental(context.Background())
|
|
||||||
if err != nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
return status.Cloud.Disabled
|
|
||||||
}
|
|
||||||
|
|
||||||
func cloudModelDisabled(name string) bool {
|
|
||||||
if !isCloudModel(name) {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
client, err := api.ClientFromEnvironment()
|
|
||||||
if err != nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
return cloudStatusDisabled(client)
|
|
||||||
}
|
|
||||||
|
|
||||||
// checkCloudSignIn checks if a cloud model needs sign-in.
|
|
||||||
// Returns a command to start sign-in if needed, or nil if already signed in.
|
|
||||||
func (m *model) checkCloudSignIn(modelName string, fromModal bool) tea.Cmd {
|
|
||||||
if modelName == "" || !isCloudModel(modelName) {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
client, err := api.ClientFromEnvironment()
|
|
||||||
if err != nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
if cloudStatusDisabled(client) {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
user, err := client.Whoami(context.Background())
|
|
||||||
if err == nil && user != nil && user.Name != "" {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
var aErr api.AuthorizationError
|
|
||||||
if errors.As(err, &aErr) && aErr.SigninURL != "" {
|
|
||||||
return m.startSignIn(modelName, aErr.SigninURL, fromModal)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// startSignIn initiates the sign-in flow for a cloud model.
|
|
||||||
// fromModal indicates if this was triggered from the model picker modal.
|
|
||||||
func (m *model) startSignIn(modelName, signInURL string, fromModal bool) tea.Cmd {
|
|
||||||
m.showingModal = false
|
|
||||||
m.showingSignIn = true
|
|
||||||
m.signInURL = signInURL
|
|
||||||
m.signInModel = modelName
|
|
||||||
m.signInSpinner = 0
|
|
||||||
m.signInFromModal = fromModal
|
|
||||||
|
|
||||||
config.OpenBrowser(signInURL)
|
|
||||||
|
|
||||||
return tea.Tick(200*time.Millisecond, func(t time.Time) tea.Msg {
|
|
||||||
return signInTickMsg{}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func checkSignIn() tea.Msg {
|
|
||||||
client, err := api.ClientFromEnvironment()
|
|
||||||
if err != nil {
|
|
||||||
return signInCheckMsg{signedIn: false}
|
|
||||||
}
|
|
||||||
user, err := client.Whoami(context.Background())
|
|
||||||
if err == nil && user != nil && user.Name != "" {
|
|
||||||
return signInCheckMsg{signedIn: true, userName: user.Name}
|
|
||||||
}
|
|
||||||
return signInCheckMsg{signedIn: false}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *model) loadAvailableModels() {
|
|
||||||
m.availableModels = make(map[string]bool)
|
|
||||||
client, err := api.ClientFromEnvironment()
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
models, err := client.List(context.Background())
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
cloudDisabled := cloudStatusDisabled(client)
|
|
||||||
for _, mdl := range models.Models {
|
|
||||||
if cloudDisabled && mdl.RemoteModel != "" {
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
m.availableModels[mdl.Name] = true
|
items = append(items, integrationMenuItem(integrationState))
|
||||||
}
|
}
|
||||||
|
return items
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *model) buildItems() {
|
func initialCursor(state *launch.LauncherState, items []menuItem) int {
|
||||||
others := getOtherIntegrations()
|
if state == nil || state.LastSelection == "" {
|
||||||
m.items = make([]menuItem, 0, len(mainMenuItems)+1+len(others))
|
return 0
|
||||||
m.items = append(m.items, mainMenuItems...)
|
}
|
||||||
|
for i, item := range items {
|
||||||
if m.showOthers {
|
if state.LastSelection == "run" && item.isRunModel {
|
||||||
m.items = append(m.items, others...)
|
return i
|
||||||
} else {
|
}
|
||||||
m.items = append(m.items, othersMenuItem)
|
if item.integration == state.LastSelection {
|
||||||
|
return i
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
return 0
|
||||||
func isOthersIntegration(name string) bool {
|
|
||||||
for _, item := range getOtherIntegrations() {
|
|
||||||
if item.integration == name {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func initialModel() model {
|
|
||||||
m := model{
|
|
||||||
cursor: 0,
|
|
||||||
}
|
|
||||||
m.loadAvailableModels()
|
|
||||||
|
|
||||||
lastSelection := config.LastSelection()
|
|
||||||
if isOthersIntegration(lastSelection) {
|
|
||||||
m.showOthers = true
|
|
||||||
}
|
|
||||||
|
|
||||||
m.buildItems()
|
|
||||||
|
|
||||||
if lastSelection != "" {
|
|
||||||
for i, item := range m.items {
|
|
||||||
if lastSelection == "run" && item.isRunModel {
|
|
||||||
m.cursor = i
|
|
||||||
break
|
|
||||||
} else if item.integration == lastSelection {
|
|
||||||
m.cursor = i
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return m
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m model) Init() tea.Cmd {
|
func (m model) Init() tea.Cmd {
|
||||||
@@ -357,127 +175,11 @@ func (m model) Init() tea.Cmd {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||||
if wmsg, ok := msg.(tea.WindowSizeMsg); ok {
|
|
||||||
wasSet := m.width > 0
|
|
||||||
m.width = wmsg.Width
|
|
||||||
if wasSet {
|
|
||||||
return m, tea.EnterAltScreen
|
|
||||||
}
|
|
||||||
return m, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, ok := msg.(clearStatusMsg); ok {
|
|
||||||
m.statusMsg = ""
|
|
||||||
return m, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if m.showingSignIn {
|
|
||||||
switch msg := msg.(type) {
|
switch msg := msg.(type) {
|
||||||
case tea.KeyMsg:
|
case tea.WindowSizeMsg:
|
||||||
switch msg.Type {
|
m.width = msg.Width
|
||||||
case tea.KeyCtrlC, tea.KeyEsc:
|
|
||||||
m.showingSignIn = false
|
|
||||||
if m.signInFromModal {
|
|
||||||
m.showingModal = true
|
|
||||||
}
|
|
||||||
return m, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
case signInTickMsg:
|
|
||||||
m.signInSpinner++
|
|
||||||
// Check sign-in status every 5th tick (~1 second)
|
|
||||||
if m.signInSpinner%5 == 0 {
|
|
||||||
return m, tea.Batch(
|
|
||||||
tea.Tick(200*time.Millisecond, func(t time.Time) tea.Msg {
|
|
||||||
return signInTickMsg{}
|
|
||||||
}),
|
|
||||||
checkSignIn,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
return m, tea.Tick(200*time.Millisecond, func(t time.Time) tea.Msg {
|
|
||||||
return signInTickMsg{}
|
|
||||||
})
|
|
||||||
|
|
||||||
case signInCheckMsg:
|
|
||||||
if msg.signedIn {
|
|
||||||
if m.signInFromModal {
|
|
||||||
m.modalSelector.selected = m.signInModel
|
|
||||||
m.changeModel = true
|
|
||||||
} else {
|
|
||||||
m.selected = true
|
|
||||||
}
|
|
||||||
m.quitting = true
|
|
||||||
return m, tea.Quit
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return m, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if m.showingMultiModal {
|
|
||||||
switch msg := msg.(type) {
|
|
||||||
case tea.KeyMsg:
|
|
||||||
if msg.Type == tea.KeyLeft {
|
|
||||||
m.showingMultiModal = false
|
|
||||||
return m, nil
|
|
||||||
}
|
|
||||||
updated, cmd := m.multiModalSelector.Update(msg)
|
|
||||||
m.multiModalSelector = updated.(multiSelectorModel)
|
|
||||||
|
|
||||||
if m.multiModalSelector.cancelled {
|
|
||||||
m.showingMultiModal = false
|
|
||||||
return m, nil
|
|
||||||
}
|
|
||||||
if m.multiModalSelector.confirmed {
|
|
||||||
var selected []string
|
|
||||||
for _, idx := range m.multiModalSelector.checkOrder {
|
|
||||||
selected = append(selected, m.multiModalSelector.items[idx].Name)
|
|
||||||
}
|
|
||||||
if len(selected) > 0 {
|
|
||||||
m.changeModels = selected
|
|
||||||
m.changeModel = true
|
|
||||||
m.quitting = true
|
|
||||||
return m, tea.Quit
|
|
||||||
}
|
|
||||||
m.multiModalSelector.confirmed = false
|
|
||||||
return m, nil
|
|
||||||
}
|
|
||||||
return m, cmd
|
|
||||||
}
|
|
||||||
return m, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if m.showingModal {
|
|
||||||
switch msg := msg.(type) {
|
|
||||||
case tea.KeyMsg:
|
|
||||||
switch msg.Type {
|
|
||||||
case tea.KeyCtrlC, tea.KeyEsc, tea.KeyLeft:
|
|
||||||
m.showingModal = false
|
|
||||||
return m, nil
|
return m, nil
|
||||||
|
|
||||||
case tea.KeyEnter:
|
|
||||||
filtered := m.modalSelector.filteredItems()
|
|
||||||
if len(filtered) > 0 && m.modalSelector.cursor < len(filtered) {
|
|
||||||
m.modalSelector.selected = filtered[m.modalSelector.cursor].Name
|
|
||||||
}
|
|
||||||
if m.modalSelector.selected != "" {
|
|
||||||
if cmd := m.checkCloudSignIn(m.modalSelector.selected, true); cmd != nil {
|
|
||||||
return m, cmd
|
|
||||||
}
|
|
||||||
m.changeModel = true
|
|
||||||
m.quitting = true
|
|
||||||
return m, tea.Quit
|
|
||||||
}
|
|
||||||
return m, nil
|
|
||||||
|
|
||||||
default:
|
|
||||||
// Delegate navigation (up/down/pgup/pgdown/filter/backspace) to selectorModel
|
|
||||||
m.modalSelector.updateNavigation(msg)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return m, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
switch msg := msg.(type) {
|
|
||||||
case tea.KeyMsg:
|
case tea.KeyMsg:
|
||||||
switch msg.String() {
|
switch msg.String() {
|
||||||
case "ctrl+c", "q", "esc":
|
case "ctrl+c", "q", "esc":
|
||||||
@@ -488,231 +190,204 @@ func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
|||||||
if m.cursor > 0 {
|
if m.cursor > 0 {
|
||||||
m.cursor--
|
m.cursor--
|
||||||
}
|
}
|
||||||
// Auto-collapse "Others" when cursor moves back into pinned items
|
|
||||||
if m.showOthers && m.cursor < len(mainMenuItems) {
|
if m.showOthers && m.cursor < len(mainMenuItems) {
|
||||||
m.showOthers = false
|
m.showOthers = false
|
||||||
m.buildItems()
|
m.items = buildMenuItems(m.state, false)
|
||||||
|
m.cursor = min(m.cursor, len(m.items)-1)
|
||||||
}
|
}
|
||||||
|
return m, nil
|
||||||
|
|
||||||
case "down", "j":
|
case "down", "j":
|
||||||
if m.cursor < len(m.items)-1 {
|
if m.cursor < len(m.items)-1 {
|
||||||
m.cursor++
|
m.cursor++
|
||||||
}
|
}
|
||||||
// Auto-expand "Others..." when cursor lands on it
|
|
||||||
if m.cursor < len(m.items) && m.items[m.cursor].isOthers && !m.showOthers {
|
if m.cursor < len(m.items) && m.items[m.cursor].isOthers && !m.showOthers {
|
||||||
m.showOthers = true
|
m.showOthers = true
|
||||||
m.buildItems()
|
m.items = buildMenuItems(m.state, true)
|
||||||
// cursor now points at the first "other" integration
|
|
||||||
}
|
}
|
||||||
|
return m, nil
|
||||||
|
|
||||||
case "enter", " ":
|
case "enter", " ":
|
||||||
item := m.items[m.cursor]
|
if m.selectableItem(m.items[m.cursor]) {
|
||||||
|
|
||||||
if item.integration != "" && !config.IsIntegrationInstalled(item.integration) {
|
|
||||||
return m, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
var configuredModel string
|
|
||||||
if item.isRunModel {
|
|
||||||
configuredModel = config.LastModel()
|
|
||||||
} else if item.integration != "" {
|
|
||||||
configuredModel = config.IntegrationModel(item.integration)
|
|
||||||
}
|
|
||||||
if cmd := m.checkCloudSignIn(configuredModel, false); cmd != nil {
|
|
||||||
return m, cmd
|
|
||||||
}
|
|
||||||
|
|
||||||
if configuredModel != "" && isCloudModel(configuredModel) && cloudModelDisabled(configuredModel) {
|
|
||||||
if item.integration != "" && config.IsEditorIntegration(item.integration) {
|
|
||||||
m.openMultiModelModal(item.integration)
|
|
||||||
} else {
|
|
||||||
m.openModelModal(configuredModel)
|
|
||||||
}
|
|
||||||
return m, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
m.selected = true
|
m.selected = true
|
||||||
|
m.action = actionForMenuItem(m.items[m.cursor], false)
|
||||||
m.quitting = true
|
m.quitting = true
|
||||||
return m, tea.Quit
|
return m, tea.Quit
|
||||||
|
}
|
||||||
|
return m, nil
|
||||||
|
|
||||||
case "right", "l":
|
case "right", "l":
|
||||||
item := m.items[m.cursor]
|
item := m.items[m.cursor]
|
||||||
if item.integration != "" || item.isRunModel {
|
if item.isRunModel || m.changeableItem(item) {
|
||||||
if item.integration != "" && !config.IsIntegrationInstalled(item.integration) {
|
m.selected = true
|
||||||
|
m.action = actionForMenuItem(item, true)
|
||||||
|
m.quitting = true
|
||||||
|
return m, tea.Quit
|
||||||
|
}
|
||||||
return m, nil
|
return m, nil
|
||||||
}
|
}
|
||||||
if item.integration != "" && config.IsEditorIntegration(item.integration) {
|
|
||||||
m.openMultiModelModal(item.integration)
|
|
||||||
} else {
|
|
||||||
var currentModel string
|
|
||||||
if item.isRunModel {
|
|
||||||
currentModel = config.LastModel()
|
|
||||||
} else if item.integration != "" {
|
|
||||||
currentModel = config.IntegrationModel(item.integration)
|
|
||||||
}
|
|
||||||
m.openModelModal(currentModel)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return m, nil
|
return m, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m model) selectableItem(item menuItem) bool {
|
||||||
|
if item.isRunModel {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if item.integration == "" || item.isOthers {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
state, ok := m.state.Integrations[item.integration]
|
||||||
|
return ok && state.Selectable
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m model) changeableItem(item menuItem) bool {
|
||||||
|
if item.integration == "" || item.isOthers {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
state, ok := m.state.Integrations[item.integration]
|
||||||
|
return ok && state.Changeable
|
||||||
|
}
|
||||||
|
|
||||||
func (m model) View() string {
|
func (m model) View() string {
|
||||||
if m.quitting {
|
if m.quitting {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
if m.showingSignIn {
|
|
||||||
return m.renderSignInDialog()
|
|
||||||
}
|
|
||||||
|
|
||||||
if m.showingMultiModal {
|
|
||||||
return m.multiModalSelector.View()
|
|
||||||
}
|
|
||||||
|
|
||||||
if m.showingModal {
|
|
||||||
return m.renderModal()
|
|
||||||
}
|
|
||||||
|
|
||||||
s := selectorTitleStyle.Render("Ollama "+versionStyle.Render(version.Version)) + "\n\n"
|
s := selectorTitleStyle.Render("Ollama "+versionStyle.Render(version.Version)) + "\n\n"
|
||||||
|
|
||||||
for i, item := range m.items {
|
for i, item := range m.items {
|
||||||
|
s += m.renderMenuItem(i, item)
|
||||||
|
}
|
||||||
|
|
||||||
|
s += "\n" + selectorHelpStyle.Render("↑/↓ navigate • enter launch • → configure • esc quit")
|
||||||
|
|
||||||
|
if m.width > 0 {
|
||||||
|
return lipgloss.NewStyle().MaxWidth(m.width).Render(s)
|
||||||
|
}
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m model) renderMenuItem(index int, item menuItem) string {
|
||||||
cursor := ""
|
cursor := ""
|
||||||
style := menuItemStyle
|
style := menuItemStyle
|
||||||
isInstalled := true
|
title := item.title
|
||||||
|
description := item.description
|
||||||
|
modelSuffix := ""
|
||||||
|
|
||||||
if item.integration != "" {
|
if m.cursor == index {
|
||||||
isInstalled = config.IsIntegrationInstalled(item.integration)
|
|
||||||
}
|
|
||||||
|
|
||||||
if m.cursor == i {
|
|
||||||
cursor = "▸ "
|
cursor = "▸ "
|
||||||
if isInstalled {
|
|
||||||
style = menuSelectedItemStyle
|
|
||||||
} else {
|
|
||||||
style = greyedSelectedStyle
|
|
||||||
}
|
}
|
||||||
} else if !isInstalled && item.integration != "" {
|
|
||||||
|
if item.isRunModel {
|
||||||
|
if m.cursor == index && m.state.RunModel != "" {
|
||||||
|
modelSuffix = " " + modelStyle.Render("("+m.state.RunModel+")")
|
||||||
|
}
|
||||||
|
if m.cursor == index {
|
||||||
|
style = menuSelectedItemStyle
|
||||||
|
}
|
||||||
|
} else if item.isOthers {
|
||||||
|
if m.cursor == index {
|
||||||
|
style = menuSelectedItemStyle
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
integrationState := m.state.Integrations[item.integration]
|
||||||
|
if !integrationState.Selectable {
|
||||||
|
if m.cursor == index {
|
||||||
|
style = greyedSelectedStyle
|
||||||
|
} else {
|
||||||
style = greyedStyle
|
style = greyedStyle
|
||||||
}
|
}
|
||||||
|
} else if m.cursor == index {
|
||||||
title := item.title
|
style = menuSelectedItemStyle
|
||||||
var modelSuffix string
|
|
||||||
if item.integration != "" {
|
|
||||||
if !isInstalled {
|
|
||||||
title += " " + notInstalledStyle.Render("(not installed)")
|
|
||||||
} else if m.cursor == i {
|
|
||||||
if mdl := config.IntegrationModel(item.integration); mdl != "" && m.modelExists(mdl) {
|
|
||||||
modelSuffix = " " + modelStyle.Render("("+mdl+")")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else if item.isRunModel && m.cursor == i {
|
|
||||||
if mdl := config.LastModel(); mdl != "" && m.modelExists(mdl) {
|
|
||||||
modelSuffix = " " + modelStyle.Render("("+mdl+")")
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
s += style.Render(cursor+title) + modelSuffix + "\n"
|
if m.cursor == index && integrationState.CurrentModel != "" {
|
||||||
|
modelSuffix = " " + modelStyle.Render("("+integrationState.CurrentModel+")")
|
||||||
|
}
|
||||||
|
|
||||||
desc := item.description
|
if !integrationState.Installed {
|
||||||
if !isInstalled && item.integration != "" && m.cursor == i {
|
if integrationState.AutoInstallable {
|
||||||
if hint := config.IntegrationInstallHint(item.integration); hint != "" {
|
title += " " + notInstalledStyle.Render("(install)")
|
||||||
desc = hint
|
|
||||||
} else {
|
} else {
|
||||||
desc = "not installed"
|
title += " " + notInstalledStyle.Render("(not installed)")
|
||||||
|
}
|
||||||
|
if m.cursor == index {
|
||||||
|
if integrationState.AutoInstallable {
|
||||||
|
description = "Press enter to install"
|
||||||
|
} else if integrationState.InstallHint != "" {
|
||||||
|
description = integrationState.InstallHint
|
||||||
|
} else {
|
||||||
|
description = "not installed"
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
s += menuDescStyle.Render(desc) + "\n\n"
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if m.statusMsg != "" {
|
return style.Render(cursor+title) + modelSuffix + "\n" + menuDescStyle.Render(description) + "\n\n"
|
||||||
s += "\n" + lipgloss.NewStyle().Foreground(lipgloss.AdaptiveColor{Light: "124", Dark: "210"}).Render(m.statusMsg) + "\n"
|
|
||||||
}
|
}
|
||||||
|
|
||||||
s += "\n" + selectorHelpStyle.Render("↑/↓ navigate • enter launch • → change model • esc quit")
|
type TUIActionKind int
|
||||||
|
|
||||||
if m.width > 0 {
|
|
||||||
return lipgloss.NewStyle().MaxWidth(m.width).Render(s)
|
|
||||||
}
|
|
||||||
return s
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m model) renderModal() string {
|
|
||||||
modalStyle := lipgloss.NewStyle().
|
|
||||||
PaddingBottom(1).
|
|
||||||
PaddingRight(2)
|
|
||||||
|
|
||||||
s := modalStyle.Render(m.modalSelector.renderContent())
|
|
||||||
if m.width > 0 {
|
|
||||||
return lipgloss.NewStyle().MaxWidth(m.width).Render(s)
|
|
||||||
}
|
|
||||||
return s
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m model) renderSignInDialog() string {
|
|
||||||
return renderSignIn(m.signInModel, m.signInURL, m.signInSpinner, m.width)
|
|
||||||
}
|
|
||||||
|
|
||||||
type Selection int
|
|
||||||
|
|
||||||
const (
|
const (
|
||||||
SelectionNone Selection = iota
|
TUIActionNone TUIActionKind = iota
|
||||||
SelectionRunModel
|
TUIActionRunModel
|
||||||
SelectionChangeRunModel
|
TUIActionLaunchIntegration
|
||||||
SelectionIntegration // Generic integration selection
|
|
||||||
SelectionChangeIntegration // Generic change model for integration
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type Result struct {
|
type TUIAction struct {
|
||||||
Selection Selection
|
Kind TUIActionKind
|
||||||
Integration string // integration name if applicable
|
Integration string
|
||||||
Model string // model name if selected from single-select modal
|
ForceConfigure bool
|
||||||
Models []string // models selected from multi-select modal (Editor integrations)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func Run() (Result, error) {
|
func (a TUIAction) LastSelection() string {
|
||||||
m := initialModel()
|
switch a.Kind {
|
||||||
p := tea.NewProgram(m)
|
case TUIActionRunModel:
|
||||||
|
return "run"
|
||||||
|
case TUIActionLaunchIntegration:
|
||||||
|
return a.Integration
|
||||||
|
default:
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
finalModel, err := p.Run()
|
func (a TUIAction) RunModelRequest() launch.RunModelRequest {
|
||||||
|
return launch.RunModelRequest{ForcePicker: a.ForceConfigure}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a TUIAction) IntegrationLaunchRequest() launch.IntegrationLaunchRequest {
|
||||||
|
return launch.IntegrationLaunchRequest{
|
||||||
|
Name: a.Integration,
|
||||||
|
ForceConfigure: a.ForceConfigure,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func actionForMenuItem(item menuItem, forceConfigure bool) TUIAction {
|
||||||
|
switch {
|
||||||
|
case item.isRunModel:
|
||||||
|
return TUIAction{Kind: TUIActionRunModel, ForceConfigure: forceConfigure}
|
||||||
|
case item.integration != "":
|
||||||
|
return TUIAction{Kind: TUIActionLaunchIntegration, Integration: item.integration, ForceConfigure: forceConfigure}
|
||||||
|
default:
|
||||||
|
return TUIAction{Kind: TUIActionNone}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func RunMenu(state *launch.LauncherState) (TUIAction, error) {
|
||||||
|
menu := newModel(state)
|
||||||
|
program := tea.NewProgram(menu)
|
||||||
|
|
||||||
|
finalModel, err := program.Run()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return Result{Selection: SelectionNone}, fmt.Errorf("error running TUI: %w", err)
|
return TUIAction{Kind: TUIActionNone}, fmt.Errorf("error running TUI: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
fm := finalModel.(model)
|
finalMenu := finalModel.(model)
|
||||||
if fm.err != nil {
|
if !finalMenu.selected {
|
||||||
return Result{Selection: SelectionNone}, fm.err
|
return TUIAction{Kind: TUIActionNone}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if !fm.selected && !fm.changeModel {
|
return finalMenu.action, nil
|
||||||
return Result{Selection: SelectionNone}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
item := fm.items[fm.cursor]
|
|
||||||
|
|
||||||
if fm.changeModel {
|
|
||||||
if item.isRunModel {
|
|
||||||
return Result{
|
|
||||||
Selection: SelectionChangeRunModel,
|
|
||||||
Model: fm.modalSelector.selected,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
return Result{
|
|
||||||
Selection: SelectionChangeIntegration,
|
|
||||||
Integration: item.integration,
|
|
||||||
Model: fm.modalSelector.selected,
|
|
||||||
Models: fm.changeModels,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if item.isRunModel {
|
|
||||||
return Result{Selection: SelectionRunModel}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return Result{
|
|
||||||
Selection: SelectionIntegration,
|
|
||||||
Integration: item.integration,
|
|
||||||
}, nil
|
|
||||||
}
|
}
|
||||||
|
|||||||
178
cmd/tui/tui_test.go
Normal file
178
cmd/tui/tui_test.go
Normal file
@@ -0,0 +1,178 @@
|
|||||||
|
package tui
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
tea "github.com/charmbracelet/bubbletea"
|
||||||
|
"github.com/ollama/ollama/cmd/launch"
|
||||||
|
)
|
||||||
|
|
||||||
|
func launcherTestState() *launch.LauncherState {
|
||||||
|
return &launch.LauncherState{
|
||||||
|
LastSelection: "run",
|
||||||
|
RunModel: "qwen3:8b",
|
||||||
|
Integrations: map[string]launch.LauncherIntegrationState{
|
||||||
|
"claude": {
|
||||||
|
Name: "claude",
|
||||||
|
DisplayName: "Claude Code",
|
||||||
|
Description: "Anthropic's coding tool with subagents",
|
||||||
|
Selectable: true,
|
||||||
|
Changeable: true,
|
||||||
|
CurrentModel: "glm-5:cloud",
|
||||||
|
},
|
||||||
|
"codex": {
|
||||||
|
Name: "codex",
|
||||||
|
DisplayName: "Codex",
|
||||||
|
Description: "OpenAI's open-source coding agent",
|
||||||
|
Selectable: true,
|
||||||
|
Changeable: true,
|
||||||
|
},
|
||||||
|
"openclaw": {
|
||||||
|
Name: "openclaw",
|
||||||
|
DisplayName: "OpenClaw",
|
||||||
|
Description: "Personal AI with 100+ skills",
|
||||||
|
Selectable: true,
|
||||||
|
Changeable: true,
|
||||||
|
AutoInstallable: true,
|
||||||
|
},
|
||||||
|
"droid": {
|
||||||
|
Name: "droid",
|
||||||
|
DisplayName: "Droid",
|
||||||
|
Description: "Factory's coding agent across terminal and IDEs",
|
||||||
|
Selectable: true,
|
||||||
|
Changeable: true,
|
||||||
|
},
|
||||||
|
"pi": {
|
||||||
|
Name: "pi",
|
||||||
|
DisplayName: "Pi",
|
||||||
|
Description: "Minimal AI agent toolkit with plugin support",
|
||||||
|
Selectable: true,
|
||||||
|
Changeable: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMenuRendersPinnedItemsAndMore(t *testing.T) {
|
||||||
|
view := newModel(launcherTestState()).View()
|
||||||
|
for _, want := range []string{"Run a model", "Launch Claude Code", "Launch Codex", "Launch OpenClaw", "More..."} {
|
||||||
|
if !strings.Contains(view, want) {
|
||||||
|
t.Fatalf("expected menu view to contain %q\n%s", want, view)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMenuExpandsOthersFromLastSelection(t *testing.T) {
|
||||||
|
state := launcherTestState()
|
||||||
|
state.LastSelection = "pi"
|
||||||
|
|
||||||
|
menu := newModel(state)
|
||||||
|
if !menu.showOthers {
|
||||||
|
t.Fatal("expected others section to expand when last selection is in the overflow list")
|
||||||
|
}
|
||||||
|
view := menu.View()
|
||||||
|
if !strings.Contains(view, "Launch Pi") {
|
||||||
|
t.Fatalf("expected expanded view to contain overflow integration\n%s", view)
|
||||||
|
}
|
||||||
|
if strings.Contains(view, "More...") {
|
||||||
|
t.Fatalf("expected expanded view to replace More... item\n%s", view)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMenuEnterOnRunSelectsRun(t *testing.T) {
|
||||||
|
menu := newModel(launcherTestState())
|
||||||
|
updated, _ := menu.Update(tea.KeyMsg{Type: tea.KeyEnter})
|
||||||
|
got := updated.(model)
|
||||||
|
want := TUIAction{Kind: TUIActionRunModel}
|
||||||
|
if !got.selected || got.action != want {
|
||||||
|
t.Fatalf("expected enter on run to select run action, got selected=%v action=%v", got.selected, got.action)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMenuRightOnRunSelectsChangeRun(t *testing.T) {
|
||||||
|
menu := newModel(launcherTestState())
|
||||||
|
updated, _ := menu.Update(tea.KeyMsg{Type: tea.KeyRight})
|
||||||
|
got := updated.(model)
|
||||||
|
want := TUIAction{Kind: TUIActionRunModel, ForceConfigure: true}
|
||||||
|
if !got.selected || got.action != want {
|
||||||
|
t.Fatalf("expected right on run to select change-run action, got selected=%v action=%v", got.selected, got.action)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMenuEnterOnIntegrationSelectsLaunch(t *testing.T) {
|
||||||
|
menu := newModel(launcherTestState())
|
||||||
|
menu.cursor = 1
|
||||||
|
updated, _ := menu.Update(tea.KeyMsg{Type: tea.KeyEnter})
|
||||||
|
got := updated.(model)
|
||||||
|
want := TUIAction{Kind: TUIActionLaunchIntegration, Integration: "claude"}
|
||||||
|
if !got.selected || got.action != want {
|
||||||
|
t.Fatalf("expected enter on integration to launch, got selected=%v action=%v", got.selected, got.action)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMenuRightOnIntegrationSelectsConfigure(t *testing.T) {
|
||||||
|
menu := newModel(launcherTestState())
|
||||||
|
menu.cursor = 1
|
||||||
|
updated, _ := menu.Update(tea.KeyMsg{Type: tea.KeyRight})
|
||||||
|
got := updated.(model)
|
||||||
|
want := TUIAction{Kind: TUIActionLaunchIntegration, Integration: "claude", ForceConfigure: true}
|
||||||
|
if !got.selected || got.action != want {
|
||||||
|
t.Fatalf("expected right on integration to configure, got selected=%v action=%v", got.selected, got.action)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMenuIgnoresDisabledActions(t *testing.T) {
|
||||||
|
state := launcherTestState()
|
||||||
|
claude := state.Integrations["claude"]
|
||||||
|
claude.Selectable = false
|
||||||
|
claude.Changeable = false
|
||||||
|
state.Integrations["claude"] = claude
|
||||||
|
|
||||||
|
menu := newModel(state)
|
||||||
|
menu.cursor = 1
|
||||||
|
|
||||||
|
updatedEnter, _ := menu.Update(tea.KeyMsg{Type: tea.KeyEnter})
|
||||||
|
if updatedEnter.(model).selected {
|
||||||
|
t.Fatal("expected non-selectable integration to ignore enter")
|
||||||
|
}
|
||||||
|
|
||||||
|
updatedRight, _ := menu.Update(tea.KeyMsg{Type: tea.KeyRight})
|
||||||
|
if updatedRight.(model).selected {
|
||||||
|
t.Fatal("expected non-changeable integration to ignore right")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMenuShowsCurrentModelSuffixes(t *testing.T) {
|
||||||
|
menu := newModel(launcherTestState())
|
||||||
|
runView := menu.View()
|
||||||
|
if !strings.Contains(runView, "(qwen3:8b)") {
|
||||||
|
t.Fatalf("expected run row to show current model suffix\n%s", runView)
|
||||||
|
}
|
||||||
|
|
||||||
|
menu.cursor = 1
|
||||||
|
integrationView := menu.View()
|
||||||
|
if !strings.Contains(integrationView, "(glm-5:cloud)") {
|
||||||
|
t.Fatalf("expected integration row to show current model suffix\n%s", integrationView)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMenuShowsInstallStatusAndHint(t *testing.T) {
|
||||||
|
state := launcherTestState()
|
||||||
|
codex := state.Integrations["codex"]
|
||||||
|
codex.Installed = false
|
||||||
|
codex.Selectable = false
|
||||||
|
codex.Changeable = false
|
||||||
|
codex.InstallHint = "Install from https://example.com/codex"
|
||||||
|
state.Integrations["codex"] = codex
|
||||||
|
|
||||||
|
menu := newModel(state)
|
||||||
|
menu.cursor = 2
|
||||||
|
view := menu.View()
|
||||||
|
if !strings.Contains(view, "(not installed)") {
|
||||||
|
t.Fatalf("expected not-installed marker\n%s", view)
|
||||||
|
}
|
||||||
|
if !strings.Contains(view, codex.InstallHint) {
|
||||||
|
t.Fatalf("expected install hint in description\n%s", view)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -257,10 +257,11 @@ func LoadModelMetadata(fsys fs.FS) (ModelKV, *Tokenizer, error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
bts = sanitizeNonFiniteJSON(bts)
|
||||||
|
|
||||||
var p ModelParameters
|
var p ModelParameters
|
||||||
if err := json.Unmarshal(bts, &p); err != nil {
|
if err := json.Unmarshal(bts, &p); err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, fmt.Errorf("parse config.json: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(p.Architectures) < 1 {
|
if len(p.Architectures) < 1 {
|
||||||
@@ -315,16 +316,20 @@ func LoadModelMetadata(fsys fs.FS) (ModelKV, *Tokenizer, error) {
|
|||||||
conv = &glm4MoeLiteModel{}
|
conv = &glm4MoeLiteModel{}
|
||||||
case "GlmOcrForConditionalGeneration":
|
case "GlmOcrForConditionalGeneration":
|
||||||
conv = &glmOcrModel{}
|
conv = &glmOcrModel{}
|
||||||
case "Lfm2ForCausalLM":
|
case "Lfm2ForCausalLM", "Lfm2MoeForCausalLM":
|
||||||
conv = &lfm2Model{}
|
conv = &lfm2Model{}
|
||||||
case "Qwen3NextForCausalLM":
|
case "Lfm2VlForConditionalGeneration":
|
||||||
|
conv = &lfm2VLTextModel{}
|
||||||
|
case "Qwen3NextForCausalLM", "Qwen3_5ForConditionalGeneration", "Qwen3_5MoeForConditionalGeneration":
|
||||||
conv = &qwen3NextModel{}
|
conv = &qwen3NextModel{}
|
||||||
|
case "NemotronHForCausalLM":
|
||||||
|
conv = &nemotronHModel{}
|
||||||
default:
|
default:
|
||||||
return nil, nil, fmt.Errorf("unsupported architecture %q", p.Architectures[0])
|
return nil, nil, fmt.Errorf("unsupported architecture %q", p.Architectures[0])
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := json.Unmarshal(bts, conv); err != nil {
|
if err := json.Unmarshal(bts, conv); err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, fmt.Errorf("parse config.json for %q: %w", p.Architectures[0], err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if t, ok := conv.(moreParser); ok {
|
if t, ok := conv.(moreParser); ok {
|
||||||
|
|||||||
@@ -1,6 +1,8 @@
|
|||||||
package convert
|
package convert
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"cmp"
|
||||||
|
"fmt"
|
||||||
"slices"
|
"slices"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
@@ -13,42 +15,149 @@ type lfm2Model struct {
|
|||||||
NumHiddenLayers uint32 `json:"num_hidden_layers"`
|
NumHiddenLayers uint32 `json:"num_hidden_layers"`
|
||||||
MaxPositionEmbeddings uint32 `json:"max_position_embeddings"`
|
MaxPositionEmbeddings uint32 `json:"max_position_embeddings"`
|
||||||
IntermediateSize uint32 `json:"intermediate_size"`
|
IntermediateSize uint32 `json:"intermediate_size"`
|
||||||
|
BlockFFDim uint32 `json:"block_ff_dim"`
|
||||||
|
BlockMultipleOf uint32 `json:"block_multiple_of"`
|
||||||
|
BlockAutoAdjustFFDim bool `json:"block_auto_adjust_ff_dim"`
|
||||||
|
BlockFFNDimMultiplier float32 `json:"block_ffn_dim_multiplier"`
|
||||||
NumAttentionHeads uint32 `json:"num_attention_heads"`
|
NumAttentionHeads uint32 `json:"num_attention_heads"`
|
||||||
NumKeyValueHeads uint32 `json:"num_key_value_heads"`
|
NumKeyValueHeads uint32 `json:"num_key_value_heads"`
|
||||||
RopeTheta float32 `json:"rope_theta"`
|
RopeTheta float32 `json:"rope_theta"`
|
||||||
NormEps float32 `json:"norm_eps"`
|
NormEps float32 `json:"norm_eps"`
|
||||||
ConvLCache uint32 `json:"conv_L_cache"`
|
ConvLCache uint32 `json:"conv_L_cache"`
|
||||||
|
MoEIntermediateSize uint32 `json:"moe_intermediate_size"`
|
||||||
|
NumExperts uint32 `json:"num_experts"`
|
||||||
|
NumLocalExperts uint32 `json:"num_local_experts"`
|
||||||
|
NumExpertsPerToken uint32 `json:"num_experts_per_tok"`
|
||||||
|
NumDenseLayers uint32 `json:"num_dense_layers"`
|
||||||
|
RoutedScalingFactor float32 `json:"routed_scaling_factor"`
|
||||||
LayerTypes []string `json:"layer_types"`
|
LayerTypes []string `json:"layer_types"`
|
||||||
TieEmbedding bool `json:"tie_embedding"`
|
TieEmbedding bool `json:"tie_embedding"`
|
||||||
|
RopeParameters struct {
|
||||||
|
RopeTheta float32 `json:"rope_theta"`
|
||||||
|
} `json:"rope_parameters"`
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ ModelConverter = (*lfm2Model)(nil)
|
var _ ModelConverter = (*lfm2Model)(nil)
|
||||||
|
|
||||||
|
const (
|
||||||
|
defaultMaxPositionEmbeddings = uint32(128_000)
|
||||||
|
fallbackContextLength = uint32(32_768)
|
||||||
|
)
|
||||||
|
|
||||||
|
func (p *lfm2Model) isMoE() bool {
|
||||||
|
return p.ModelType == "lfm2_moe" || p.expertCount() > 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *lfm2Model) ropeFreqBase() float32 {
|
||||||
|
if p.RopeTheta != 0 {
|
||||||
|
return p.RopeTheta
|
||||||
|
}
|
||||||
|
|
||||||
|
return p.RopeParameters.RopeTheta
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *lfm2Model) expertCount() uint32 {
|
||||||
|
if p.NumLocalExperts > 0 {
|
||||||
|
return p.NumLocalExperts
|
||||||
|
}
|
||||||
|
return p.NumExperts
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *lfm2Model) feedForwardLength() uint32 {
|
||||||
|
ff := p.IntermediateSize
|
||||||
|
if p.BlockFFDim != 0 {
|
||||||
|
ff = p.BlockFFDim
|
||||||
|
}
|
||||||
|
|
||||||
|
if !p.BlockAutoAdjustFFDim || p.BlockMultipleOf == 0 {
|
||||||
|
return ff
|
||||||
|
}
|
||||||
|
|
||||||
|
ff = (2 * ff) / 3
|
||||||
|
|
||||||
|
// Keep default multiplier behavior consistent with llama.cpp conversion.
|
||||||
|
if p.BlockFFNDimMultiplier != 0 {
|
||||||
|
ff = uint32(float32(ff) * p.BlockFFNDimMultiplier)
|
||||||
|
}
|
||||||
|
|
||||||
|
m := p.BlockMultipleOf
|
||||||
|
return m * ((ff + m - 1) / m)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *lfm2Model) hasKnownContextLengthFallbackSignature() bool {
|
||||||
|
return p.isMoE() &&
|
||||||
|
p.VocabSize == 65536 &&
|
||||||
|
p.HiddenSize == 2048 &&
|
||||||
|
p.NumHiddenLayers == 40 &&
|
||||||
|
p.IntermediateSize == 11776 &&
|
||||||
|
p.NumAttentionHeads == 32 &&
|
||||||
|
p.NumKeyValueHeads == 8 &&
|
||||||
|
p.NumDenseLayers == 2 &&
|
||||||
|
p.expertCount() == 64 &&
|
||||||
|
p.NumExpertsPerToken == 4 &&
|
||||||
|
p.MoEIntermediateSize == 1536
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *lfm2Model) contextLength() uint32 {
|
||||||
|
if p.MaxPositionEmbeddings == defaultMaxPositionEmbeddings && p.hasKnownContextLengthFallbackSignature() {
|
||||||
|
return fallbackContextLength
|
||||||
|
}
|
||||||
|
|
||||||
|
return p.MaxPositionEmbeddings
|
||||||
|
}
|
||||||
|
|
||||||
func (p *lfm2Model) KV(t *Tokenizer) KV {
|
func (p *lfm2Model) KV(t *Tokenizer) KV {
|
||||||
|
architecture := "lfm2"
|
||||||
|
if p.isMoE() {
|
||||||
|
architecture = "lfm2moe"
|
||||||
|
}
|
||||||
|
|
||||||
kv := p.ModelParameters.KV(t)
|
kv := p.ModelParameters.KV(t)
|
||||||
kv["general.architecture"] = "lfm2"
|
kv["general.architecture"] = architecture
|
||||||
kv["lfm2.vocab_size"] = p.VocabSize
|
kv["tokenizer.ggml.pre"] = "lfm2"
|
||||||
kv["lfm2.block_count"] = p.NumHiddenLayers
|
kv["vocab_size"] = p.VocabSize
|
||||||
kv["lfm2.embedding_length"] = p.HiddenSize
|
kv["block_count"] = p.NumHiddenLayers
|
||||||
kv["lfm2.feed_forward_length"] = p.IntermediateSize
|
kv["embedding_length"] = p.HiddenSize
|
||||||
kv["lfm2.context_length"] = p.MaxPositionEmbeddings
|
kv["feed_forward_length"] = p.feedForwardLength()
|
||||||
|
kv["context_length"] = p.contextLength()
|
||||||
|
|
||||||
// Build per-layer KV head count array based on layer_types
|
// Build per-layer KV head count array based on layer_types
|
||||||
// (0 = shortconv layer, non-zero = attention layer with that many KV heads)
|
// (0 = shortconv layer, non-zero = attention layer with that many KV heads).
|
||||||
|
//
|
||||||
|
// Dense LFM2 in HF defaults to all attention layers when layer_types is absent.
|
||||||
|
// Preserve that behavior to avoid accidentally emitting all-conv metadata.
|
||||||
kvHeadCounts := make([]uint32, p.NumHiddenLayers)
|
kvHeadCounts := make([]uint32, p.NumHiddenLayers)
|
||||||
|
if len(p.LayerTypes) == 0 {
|
||||||
|
for i := range p.NumHiddenLayers {
|
||||||
|
kvHeadCounts[i] = p.NumKeyValueHeads
|
||||||
|
}
|
||||||
|
} else {
|
||||||
for i := range p.NumHiddenLayers {
|
for i := range p.NumHiddenLayers {
|
||||||
if int(i) < len(p.LayerTypes) && p.LayerTypes[i] == "full_attention" {
|
if int(i) < len(p.LayerTypes) && p.LayerTypes[i] == "full_attention" {
|
||||||
kvHeadCounts[i] = p.NumKeyValueHeads
|
kvHeadCounts[i] = p.NumKeyValueHeads
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
kv["lfm2.attention.head_count"] = p.NumAttentionHeads
|
kv["attention.head_count"] = p.NumAttentionHeads
|
||||||
kv["lfm2.attention.head_count_kv"] = kvHeadCounts
|
kv["attention.head_count_kv"] = kvHeadCounts
|
||||||
kv["lfm2.attention.key_length"] = p.HiddenSize / p.NumAttentionHeads
|
kv["attention.key_length"] = p.HiddenSize / p.NumAttentionHeads
|
||||||
kv["lfm2.attention.value_length"] = p.HiddenSize / p.NumAttentionHeads
|
kv["attention.value_length"] = p.HiddenSize / p.NumAttentionHeads
|
||||||
kv["lfm2.attention.layer_norm_rms_epsilon"] = p.NormEps
|
kv["attention.layer_norm_rms_epsilon"] = p.NormEps
|
||||||
kv["lfm2.rope.freq_base"] = p.RopeTheta
|
kv["shortconv.l_cache"] = p.ConvLCache
|
||||||
kv["lfm2.shortconv.l_cache"] = p.ConvLCache
|
|
||||||
|
if ropeFreqBase := p.ropeFreqBase(); ropeFreqBase != 0 {
|
||||||
|
kv["rope.freq_base"] = ropeFreqBase
|
||||||
|
}
|
||||||
|
|
||||||
|
if p.isMoE() {
|
||||||
|
kv["expert_count"] = p.expertCount()
|
||||||
|
kv["expert_used_count"] = p.NumExpertsPerToken
|
||||||
|
kv["expert_feed_forward_length"] = p.MoEIntermediateSize
|
||||||
|
kv["leading_dense_block_count"] = p.NumDenseLayers
|
||||||
|
kv["expert_gating_func"] = uint32(2) // sigmoid
|
||||||
|
kv["expert_weights_scale"] = cmp.Or(p.RoutedScalingFactor, float32(1.0))
|
||||||
|
}
|
||||||
|
|
||||||
return kv
|
return kv
|
||||||
}
|
}
|
||||||
@@ -56,6 +165,30 @@ func (p *lfm2Model) KV(t *Tokenizer) KV {
|
|||||||
func (p *lfm2Model) Tensors(ts []Tensor) []*ggml.Tensor {
|
func (p *lfm2Model) Tensors(ts []Tensor) []*ggml.Tensor {
|
||||||
var out []*ggml.Tensor
|
var out []*ggml.Tensor
|
||||||
|
|
||||||
|
if p.isMoE() {
|
||||||
|
merges := make([]merge, 0, p.NumHiddenLayers*3)
|
||||||
|
for i := range p.NumHiddenLayers {
|
||||||
|
if i < p.NumDenseLayers {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
merges = append(merges, merge{
|
||||||
|
fmt.Sprintf("blk.%d.feed_forward.experts.*.w1.weight", i),
|
||||||
|
fmt.Sprintf("blk.%d.ffn_gate_exps.weight", i),
|
||||||
|
}, merge{
|
||||||
|
fmt.Sprintf("blk.%d.feed_forward.experts.*.w2.weight", i),
|
||||||
|
fmt.Sprintf("blk.%d.ffn_down_exps.weight", i),
|
||||||
|
}, merge{
|
||||||
|
fmt.Sprintf("blk.%d.feed_forward.experts.*.w3.weight", i),
|
||||||
|
fmt.Sprintf("blk.%d.ffn_up_exps.weight", i),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
merged, remaining := mergeTensors(ts, merges...)
|
||||||
|
out = append(out, merged...)
|
||||||
|
ts = remaining
|
||||||
|
}
|
||||||
|
|
||||||
for _, t := range ts {
|
for _, t := range ts {
|
||||||
shape := t.Shape()
|
shape := t.Shape()
|
||||||
|
|
||||||
@@ -80,7 +213,7 @@ func (p *lfm2Model) Tensors(ts []Tensor) []*ggml.Tensor {
|
|||||||
func (p *lfm2Model) Replacements() []string {
|
func (p *lfm2Model) Replacements() []string {
|
||||||
return []string{
|
return []string{
|
||||||
"model.embed_tokens", "token_embd",
|
"model.embed_tokens", "token_embd",
|
||||||
"model.embedding_norm", "output_norm",
|
"model.embedding_norm", "token_embd_norm",
|
||||||
"model.layers", "blk",
|
"model.layers", "blk",
|
||||||
"operator_norm", "attn_norm",
|
"operator_norm", "attn_norm",
|
||||||
"self_attn.q_proj", "attn_q",
|
"self_attn.q_proj", "attn_q",
|
||||||
@@ -92,6 +225,8 @@ func (p *lfm2Model) Replacements() []string {
|
|||||||
"conv.conv", "shortconv.conv",
|
"conv.conv", "shortconv.conv",
|
||||||
"conv.in_proj", "shortconv.in_proj",
|
"conv.in_proj", "shortconv.in_proj",
|
||||||
"conv.out_proj", "shortconv.out_proj",
|
"conv.out_proj", "shortconv.out_proj",
|
||||||
|
"feed_forward.gate", "ffn_gate_inp",
|
||||||
|
"feed_forward.expert_bias", "exp_probs_b.bias",
|
||||||
"feed_forward.w1", "ffn_gate",
|
"feed_forward.w1", "ffn_gate",
|
||||||
"feed_forward.w2", "ffn_down",
|
"feed_forward.w2", "ffn_down",
|
||||||
"feed_forward.w3", "ffn_up",
|
"feed_forward.w3", "ffn_up",
|
||||||
|
|||||||
271
convert/convert_lfm2_test.go
Normal file
271
convert/convert_lfm2_test.go
Normal file
@@ -0,0 +1,271 @@
|
|||||||
|
package convert
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io"
|
||||||
|
"slices"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
type lfm2StubTensor struct {
|
||||||
|
tensorBase
|
||||||
|
}
|
||||||
|
|
||||||
|
func newLFM2StubTensor(name string, shape []uint64) *lfm2StubTensor {
|
||||||
|
return &lfm2StubTensor{
|
||||||
|
tensorBase: tensorBase{
|
||||||
|
name: name,
|
||||||
|
shape: shape,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *lfm2StubTensor) WriteTo(io.Writer) (int64, error) {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *lfm2StubTensor) Clone() Tensor {
|
||||||
|
return &lfm2StubTensor{
|
||||||
|
tensorBase: tensorBase{
|
||||||
|
name: t.name,
|
||||||
|
shape: slices.Clone(t.shape),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLFM2MoEKV(t *testing.T) {
|
||||||
|
var p lfm2Model
|
||||||
|
p.ModelParameters.ModelType = "lfm2_moe"
|
||||||
|
p.VocabSize = 65536
|
||||||
|
p.HiddenSize = 2048
|
||||||
|
p.NumHiddenLayers = 4
|
||||||
|
p.MaxPositionEmbeddings = 128000
|
||||||
|
p.IntermediateSize = 11776
|
||||||
|
p.NumAttentionHeads = 32
|
||||||
|
p.NumKeyValueHeads = 8
|
||||||
|
p.LayerTypes = []string{"conv", "full_attention", "conv", "full_attention"}
|
||||||
|
p.NormEps = 1e-5
|
||||||
|
p.ConvLCache = 3
|
||||||
|
p.MoEIntermediateSize = 1536
|
||||||
|
p.NumExperts = 64
|
||||||
|
p.NumExpertsPerToken = 4
|
||||||
|
p.NumDenseLayers = 2
|
||||||
|
p.RopeParameters.RopeTheta = 1_000_000
|
||||||
|
|
||||||
|
kv := p.KV(&Tokenizer{Vocabulary: &Vocabulary{Model: "gpt2"}})
|
||||||
|
|
||||||
|
if got, want := kv["general.architecture"], "lfm2moe"; got != want {
|
||||||
|
t.Fatalf("general.architecture = %v, want %v", got, want)
|
||||||
|
}
|
||||||
|
if got, want := kv["tokenizer.ggml.pre"], "lfm2"; got != want {
|
||||||
|
t.Fatalf("tokenizer.ggml.pre = %v, want %v", got, want)
|
||||||
|
}
|
||||||
|
|
||||||
|
if got, want := kv["expert_count"], uint32(64); got != want {
|
||||||
|
t.Fatalf("expert_count = %v, want %v", got, want)
|
||||||
|
}
|
||||||
|
|
||||||
|
if got, want := kv["expert_used_count"], uint32(4); got != want {
|
||||||
|
t.Fatalf("expert_used_count = %v, want %v", got, want)
|
||||||
|
}
|
||||||
|
|
||||||
|
if got, want := kv["expert_feed_forward_length"], uint32(1536); got != want {
|
||||||
|
t.Fatalf("expert_feed_forward_length = %v, want %v", got, want)
|
||||||
|
}
|
||||||
|
|
||||||
|
if got, want := kv["leading_dense_block_count"], uint32(2); got != want {
|
||||||
|
t.Fatalf("leading_dense_block_count = %v, want %v", got, want)
|
||||||
|
}
|
||||||
|
|
||||||
|
if got, want := kv["expert_gating_func"], uint32(2); got != want {
|
||||||
|
t.Fatalf("expert_gating_func = %v, want %v", got, want)
|
||||||
|
}
|
||||||
|
|
||||||
|
gotHeadCounts, ok := kv["attention.head_count_kv"].([]uint32)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("attention.head_count_kv has unexpected type %T", kv["attention.head_count_kv"])
|
||||||
|
}
|
||||||
|
|
||||||
|
wantHeadCounts := []uint32{0, 8, 0, 8}
|
||||||
|
if !slices.Equal(gotHeadCounts, wantHeadCounts) {
|
||||||
|
t.Fatalf("attention.head_count_kv = %v, want %v", gotHeadCounts, wantHeadCounts)
|
||||||
|
}
|
||||||
|
|
||||||
|
if got, want := kv["rope.freq_base"], float32(1_000_000); got != want {
|
||||||
|
t.Fatalf("rope.freq_base = %v, want %v", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLFM2DenseKV(t *testing.T) {
|
||||||
|
p := lfm2Model{
|
||||||
|
ModelParameters: ModelParameters{ModelType: "lfm2", VocabSize: 32000},
|
||||||
|
HiddenSize: 1024,
|
||||||
|
NumHiddenLayers: 2,
|
||||||
|
MaxPositionEmbeddings: 32768,
|
||||||
|
IntermediateSize: 4096,
|
||||||
|
NumAttentionHeads: 16,
|
||||||
|
NumKeyValueHeads: 4,
|
||||||
|
LayerTypes: []string{"conv", "full_attention"},
|
||||||
|
NormEps: 1e-5,
|
||||||
|
ConvLCache: 3,
|
||||||
|
RopeTheta: 10000,
|
||||||
|
}
|
||||||
|
|
||||||
|
kv := p.KV(&Tokenizer{Vocabulary: &Vocabulary{Model: "gpt2"}})
|
||||||
|
|
||||||
|
if got, want := kv["general.architecture"], "lfm2"; got != want {
|
||||||
|
t.Fatalf("general.architecture = %v, want %v", got, want)
|
||||||
|
}
|
||||||
|
if got, want := kv["tokenizer.ggml.pre"], "lfm2"; got != want {
|
||||||
|
t.Fatalf("tokenizer.ggml.pre = %v, want %v", got, want)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, ok := kv["expert_count"]; ok {
|
||||||
|
t.Fatalf("expert_count should not be set for dense lfm2")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLFM2MoETensors(t *testing.T) {
|
||||||
|
p := lfm2Model{
|
||||||
|
ModelParameters: ModelParameters{ModelType: "lfm2_moe"},
|
||||||
|
NumHiddenLayers: 4,
|
||||||
|
NumDenseLayers: 2,
|
||||||
|
}
|
||||||
|
|
||||||
|
in := []Tensor{
|
||||||
|
newLFM2StubTensor("blk.2.feed_forward.experts.0.w1.weight", []uint64{1536, 2048}),
|
||||||
|
newLFM2StubTensor("blk.2.feed_forward.experts.1.w1.weight", []uint64{1536, 2048}),
|
||||||
|
newLFM2StubTensor("blk.2.feed_forward.experts.0.w2.weight", []uint64{2048, 1536}),
|
||||||
|
newLFM2StubTensor("blk.2.feed_forward.experts.1.w2.weight", []uint64{2048, 1536}),
|
||||||
|
newLFM2StubTensor("blk.2.feed_forward.experts.0.w3.weight", []uint64{1536, 2048}),
|
||||||
|
newLFM2StubTensor("blk.2.feed_forward.experts.1.w3.weight", []uint64{1536, 2048}),
|
||||||
|
newLFM2StubTensor("blk.0.shortconv.conv.weight", []uint64{2048, 1, 3}),
|
||||||
|
}
|
||||||
|
|
||||||
|
out := p.Tensors(in)
|
||||||
|
|
||||||
|
byName := make(map[string][]uint64, len(out))
|
||||||
|
for _, tns := range out {
|
||||||
|
byName[tns.Name] = tns.Shape
|
||||||
|
}
|
||||||
|
|
||||||
|
if got, ok := byName["blk.2.ffn_gate_exps.weight"]; !ok {
|
||||||
|
t.Fatalf("missing merged tensor blk.2.ffn_gate_exps.weight")
|
||||||
|
} else if !slices.Equal(got, []uint64{2, 1536, 2048}) {
|
||||||
|
t.Fatalf("blk.2.ffn_gate_exps.weight shape = %v, want [2 1536 2048]", got)
|
||||||
|
}
|
||||||
|
|
||||||
|
if got, ok := byName["blk.2.ffn_down_exps.weight"]; !ok {
|
||||||
|
t.Fatalf("missing merged tensor blk.2.ffn_down_exps.weight")
|
||||||
|
} else if !slices.Equal(got, []uint64{2, 2048, 1536}) {
|
||||||
|
t.Fatalf("blk.2.ffn_down_exps.weight shape = %v, want [2 2048 1536]", got)
|
||||||
|
}
|
||||||
|
|
||||||
|
if got, ok := byName["blk.2.ffn_up_exps.weight"]; !ok {
|
||||||
|
t.Fatalf("missing merged tensor blk.2.ffn_up_exps.weight")
|
||||||
|
} else if !slices.Equal(got, []uint64{2, 1536, 2048}) {
|
||||||
|
t.Fatalf("blk.2.ffn_up_exps.weight shape = %v, want [2 1536 2048]", got)
|
||||||
|
}
|
||||||
|
|
||||||
|
if got, ok := byName["blk.0.shortconv.conv.weight"]; !ok {
|
||||||
|
t.Fatalf("missing shortconv tensor")
|
||||||
|
} else if !slices.Equal(got, []uint64{2048, 3}) {
|
||||||
|
t.Fatalf("blk.0.shortconv.conv.weight shape = %v, want [2048 3]", got)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, ok := byName["blk.2.feed_forward.experts.0.w1.weight"]; ok {
|
||||||
|
t.Fatalf("unmerged expert tensor should not be present")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLFM2MoEReplacements(t *testing.T) {
|
||||||
|
p := lfm2Model{}
|
||||||
|
replacer := strings.NewReplacer(p.Replacements()...)
|
||||||
|
|
||||||
|
if got, want := replacer.Replace("model.layers.2.feed_forward.expert_bias"), "blk.2.exp_probs_b.bias"; got != want {
|
||||||
|
t.Fatalf("expert bias replacement = %q, want %q", got, want)
|
||||||
|
}
|
||||||
|
|
||||||
|
if got, want := replacer.Replace("model.layers.2.feed_forward.gate.weight"), "blk.2.ffn_gate_inp.weight"; got != want {
|
||||||
|
t.Fatalf("gate replacement = %q, want %q", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLFM2KVContextLengthEdgeCaseFallbackOverride(t *testing.T) {
|
||||||
|
p := lfm2Model{
|
||||||
|
ModelParameters: ModelParameters{ModelType: "lfm2_moe", VocabSize: 65536},
|
||||||
|
HiddenSize: 2048,
|
||||||
|
NumHiddenLayers: 40,
|
||||||
|
MaxPositionEmbeddings: 128000,
|
||||||
|
IntermediateSize: 11776,
|
||||||
|
NumAttentionHeads: 32,
|
||||||
|
NumKeyValueHeads: 8,
|
||||||
|
LayerTypes: make([]string, 40),
|
||||||
|
NormEps: 1e-5,
|
||||||
|
ConvLCache: 3,
|
||||||
|
MoEIntermediateSize: 1536,
|
||||||
|
NumExperts: 64,
|
||||||
|
NumExpertsPerToken: 4,
|
||||||
|
NumDenseLayers: 2,
|
||||||
|
}
|
||||||
|
for i := 0; i < len(p.LayerTypes); i++ {
|
||||||
|
p.LayerTypes[i] = "conv"
|
||||||
|
}
|
||||||
|
p.LayerTypes[2] = "full_attention"
|
||||||
|
|
||||||
|
kv := p.KV(&Tokenizer{Vocabulary: &Vocabulary{Model: "gpt2"}})
|
||||||
|
|
||||||
|
if got, want := kv["context_length"], uint32(32768); got != want {
|
||||||
|
t.Fatalf("context_length = %v, want %v", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLFM2KVContextLengthNoOverride(t *testing.T) {
|
||||||
|
p := lfm2Model{
|
||||||
|
ModelParameters: ModelParameters{ModelType: "lfm2_moe", VocabSize: 65536},
|
||||||
|
HiddenSize: 2048,
|
||||||
|
NumHiddenLayers: 39, // mismatch: should not trigger edge case
|
||||||
|
MaxPositionEmbeddings: 128000,
|
||||||
|
IntermediateSize: 11776,
|
||||||
|
NumAttentionHeads: 32,
|
||||||
|
NumKeyValueHeads: 8,
|
||||||
|
LayerTypes: []string{"conv", "full_attention"},
|
||||||
|
NormEps: 1e-5,
|
||||||
|
ConvLCache: 3,
|
||||||
|
MoEIntermediateSize: 1536,
|
||||||
|
NumExperts: 64,
|
||||||
|
NumExpertsPerToken: 4,
|
||||||
|
NumDenseLayers: 2,
|
||||||
|
}
|
||||||
|
|
||||||
|
kv := p.KV(&Tokenizer{Vocabulary: &Vocabulary{Model: "gpt2"}})
|
||||||
|
|
||||||
|
if got, want := kv["context_length"], uint32(128000); got != want {
|
||||||
|
t.Fatalf("context_length = %v, want %v", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLFM2KVFeedForwardLengthAutoAdjust(t *testing.T) {
|
||||||
|
p := lfm2Model{
|
||||||
|
ModelParameters: ModelParameters{ModelType: "lfm2", VocabSize: 65536},
|
||||||
|
HiddenSize: 2048,
|
||||||
|
NumHiddenLayers: 16,
|
||||||
|
MaxPositionEmbeddings: 128000,
|
||||||
|
IntermediateSize: 12288, // should be ignored when block_ff_dim is set
|
||||||
|
BlockFFDim: 12288,
|
||||||
|
BlockAutoAdjustFFDim: true,
|
||||||
|
BlockMultipleOf: 256,
|
||||||
|
BlockFFNDimMultiplier: 1.0,
|
||||||
|
NumAttentionHeads: 32,
|
||||||
|
NumKeyValueHeads: 8,
|
||||||
|
LayerTypes: []string{"conv", "full_attention"},
|
||||||
|
NormEps: 1e-5,
|
||||||
|
ConvLCache: 3,
|
||||||
|
}
|
||||||
|
|
||||||
|
kv := p.KV(&Tokenizer{Vocabulary: &Vocabulary{Model: "gpt2"}})
|
||||||
|
|
||||||
|
if got, want := kv["feed_forward_length"], uint32(8192); got != want {
|
||||||
|
t.Fatalf("feed_forward_length = %v, want %v", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
417
convert/convert_lfm2_vl.go
Normal file
417
convert/convert_lfm2_vl.go
Normal file
@@ -0,0 +1,417 @@
|
|||||||
|
package convert
|
||||||
|
|
||||||
|
import (
|
||||||
|
"cmp"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io/fs"
|
||||||
|
"slices"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/fs/ggml"
|
||||||
|
)
|
||||||
|
|
||||||
|
// lfm2VLTextModel converts the language model component of LFM2 VL checkpoints.
|
||||||
|
type lfm2VLTextModel struct {
|
||||||
|
TextConfig lfm2Model `json:"text_config"`
|
||||||
|
DoImageSplitting *bool `json:"do_image_splitting"`
|
||||||
|
DownsampleFactor uint32 `json:"downsample_factor"`
|
||||||
|
EncoderPatchSize uint32 `json:"encoder_patch_size"`
|
||||||
|
ImageTokenID uint32 `json:"image_token_id"`
|
||||||
|
MaxImageTokens uint32 `json:"max_image_tokens"`
|
||||||
|
MinImageTokens uint32 `json:"min_image_tokens"`
|
||||||
|
MaxTiles uint32 `json:"max_tiles"`
|
||||||
|
MinTiles uint32 `json:"min_tiles"`
|
||||||
|
TileSize uint32 `json:"tile_size"`
|
||||||
|
MaxPixelsTolerance float32 `json:"max_pixels_tolerance"`
|
||||||
|
ProjectorUseLayernorm bool `json:"projector_use_layernorm"`
|
||||||
|
ProjectorHiddenSize uint32 `json:"projector_hidden_size"`
|
||||||
|
ProjectorHiddenAct string `json:"projector_hidden_act"`
|
||||||
|
UseImageSpecialTokens *bool `json:"use_image_special_tokens"`
|
||||||
|
UseThumbnail *bool `json:"use_thumbnail"`
|
||||||
|
VisionConfig struct {
|
||||||
|
HiddenSize uint32 `json:"hidden_size"`
|
||||||
|
IntermediateSize uint32 `json:"intermediate_size"`
|
||||||
|
NumAttentionHeads uint32 `json:"num_attention_heads"`
|
||||||
|
NumHiddenLayers uint32 `json:"num_hidden_layers"`
|
||||||
|
NumChannels uint32 `json:"num_channels"`
|
||||||
|
PatchSize uint32 `json:"patch_size"`
|
||||||
|
LayerNormEpsilon float32 `json:"layer_norm_eps"`
|
||||||
|
} `json:"vision_config"`
|
||||||
|
Processor struct {
|
||||||
|
ImageProcessor struct {
|
||||||
|
DoImageSplitting *bool `json:"do_image_splitting"`
|
||||||
|
DownsampleFactor uint32 `json:"downsample_factor"`
|
||||||
|
MaxImageTokens uint32 `json:"max_image_tokens"`
|
||||||
|
MinImageTokens uint32 `json:"min_image_tokens"`
|
||||||
|
MaxTiles uint32 `json:"max_tiles"`
|
||||||
|
MinTiles uint32 `json:"min_tiles"`
|
||||||
|
MaxPixelsTol float32 `json:"max_pixels_tolerance"`
|
||||||
|
TileSize uint32 `json:"tile_size"`
|
||||||
|
UseThumbnail *bool `json:"use_thumbnail"`
|
||||||
|
ImageMean []float32 `json:"image_mean"`
|
||||||
|
ImageStd []float32 `json:"image_std"`
|
||||||
|
Size struct {
|
||||||
|
Height uint32 `json:"height"`
|
||||||
|
Width uint32 `json:"width"`
|
||||||
|
} `json:"size"`
|
||||||
|
} `json:"image_processor"`
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *lfm2VLTextModel) textModel() *lfm2Model {
|
||||||
|
return &p.TextConfig
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *lfm2VLTextModel) specialTokenTypes() []string {
|
||||||
|
return p.textModel().specialTokenTypes()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *lfm2VLTextModel) parseMore(fsys fs.FS) error {
|
||||||
|
bts, err := fs.ReadFile(fsys, "processor_config.json")
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, fs.ErrNotExist) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return json.Unmarshal(bts, &p.Processor)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *lfm2VLTextModel) visionImageSize() uint32 {
|
||||||
|
// LFM2-VL image processor operates on 512 tiles and downsamples by factor 2
|
||||||
|
// before projection. Keep a fixed square image size compatible with position
|
||||||
|
// embeddings and the simplified runtime image pipeline.
|
||||||
|
tile := cmp.Or(
|
||||||
|
p.Processor.ImageProcessor.TileSize,
|
||||||
|
p.Processor.ImageProcessor.Size.Height,
|
||||||
|
p.Processor.ImageProcessor.Size.Width,
|
||||||
|
uint32(512),
|
||||||
|
)
|
||||||
|
downsample := cmp.Or(p.DownsampleFactor, p.Processor.ImageProcessor.DownsampleFactor, uint32(2))
|
||||||
|
if downsample == 0 {
|
||||||
|
return tile
|
||||||
|
}
|
||||||
|
|
||||||
|
return max(uint32(1), tile/downsample)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *lfm2VLTextModel) KV(t *Tokenizer) KV {
|
||||||
|
kv := p.textModel().KV(t)
|
||||||
|
|
||||||
|
boolOr := func(defaultValue bool, values ...*bool) bool {
|
||||||
|
for _, v := range values {
|
||||||
|
if v != nil {
|
||||||
|
return *v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return defaultValue
|
||||||
|
}
|
||||||
|
|
||||||
|
kv["vision.block_count"] = cmp.Or(p.VisionConfig.NumHiddenLayers, uint32(27))
|
||||||
|
kv["vision.embedding_length"] = cmp.Or(p.VisionConfig.HiddenSize, uint32(1152))
|
||||||
|
kv["vision.feed_forward_length"] = cmp.Or(p.VisionConfig.IntermediateSize, uint32(4304))
|
||||||
|
kv["vision.attention.head_count"] = cmp.Or(p.VisionConfig.NumAttentionHeads, uint32(16))
|
||||||
|
kv["vision.attention.layer_norm_epsilon"] = cmp.Or(p.VisionConfig.LayerNormEpsilon, float32(1e-6))
|
||||||
|
kv["vision.patch_size"] = cmp.Or(p.VisionConfig.PatchSize, p.EncoderPatchSize, uint32(16))
|
||||||
|
kv["vision.num_channels"] = cmp.Or(p.VisionConfig.NumChannels, uint32(3))
|
||||||
|
kv["vision.image_size"] = p.visionImageSize()
|
||||||
|
kv["vision.projector.scale_factor"] = cmp.Or(p.DownsampleFactor, p.Processor.ImageProcessor.DownsampleFactor, uint32(2))
|
||||||
|
kv["vision.projector.use_layernorm"] = p.ProjectorUseLayernorm
|
||||||
|
kv["vision.do_image_splitting"] = boolOr(true, p.DoImageSplitting, p.Processor.ImageProcessor.DoImageSplitting)
|
||||||
|
kv["vision.min_tiles"] = cmp.Or(p.MinTiles, p.Processor.ImageProcessor.MinTiles, uint32(2))
|
||||||
|
kv["vision.max_tiles"] = cmp.Or(p.MaxTiles, p.Processor.ImageProcessor.MaxTiles, uint32(10))
|
||||||
|
kv["vision.tile_size"] = cmp.Or(p.TileSize, p.Processor.ImageProcessor.TileSize, uint32(512))
|
||||||
|
kv["vision.min_image_tokens"] = cmp.Or(p.MinImageTokens, p.Processor.ImageProcessor.MinImageTokens, uint32(64))
|
||||||
|
kv["vision.max_image_tokens"] = cmp.Or(p.MaxImageTokens, p.Processor.ImageProcessor.MaxImageTokens, uint32(256))
|
||||||
|
kv["vision.max_pixels_tolerance"] = cmp.Or(p.MaxPixelsTolerance, p.Processor.ImageProcessor.MaxPixelsTol, float32(2.0))
|
||||||
|
kv["vision.use_thumbnail"] = boolOr(true, p.UseThumbnail, p.Processor.ImageProcessor.UseThumbnail)
|
||||||
|
kv["vision.use_image_special_tokens"] = boolOr(true, p.UseImageSpecialTokens)
|
||||||
|
kv["vision.image_mean"] = slices.Clone(defaultFloat32Slice(p.Processor.ImageProcessor.ImageMean, []float32{0.5, 0.5, 0.5}))
|
||||||
|
kv["vision.image_std"] = slices.Clone(defaultFloat32Slice(p.Processor.ImageProcessor.ImageStd, []float32{0.5, 0.5, 0.5}))
|
||||||
|
kv["vision.image_token_id"] = cmp.Or(p.ImageTokenID, uint32(396))
|
||||||
|
|
||||||
|
setVisionTokenID := func(k, token string) {
|
||||||
|
if t == nil || t.Vocabulary == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for i, v := range t.Vocabulary.Tokens {
|
||||||
|
if v == token {
|
||||||
|
kv[k] = uint32(i)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
setVisionTokenID("vision.image_start_token_id", "<|image_start|>")
|
||||||
|
setVisionTokenID("vision.image_end_token_id", "<|image_end|>")
|
||||||
|
setVisionTokenID("vision.image_thumbnail_token_id", "<|img_thumbnail|>")
|
||||||
|
|
||||||
|
return kv
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *lfm2VLTextModel) Tensors(ts []Tensor) []*ggml.Tensor {
|
||||||
|
patchSize := int(cmp.Or(p.VisionConfig.PatchSize, p.EncoderPatchSize, uint32(16)))
|
||||||
|
numChannels := int(cmp.Or(p.VisionConfig.NumChannels, uint32(3)))
|
||||||
|
|
||||||
|
for _, t := range ts {
|
||||||
|
if t.Name() == "v.patch_embd.weight" {
|
||||||
|
shape := t.Shape()
|
||||||
|
if len(shape) == 2 {
|
||||||
|
inputDim := uint64(numChannels * patchSize * patchSize)
|
||||||
|
if shape[1] == inputDim {
|
||||||
|
channels := numChannels
|
||||||
|
patch := patchSize
|
||||||
|
t.SetRepacker(func(_ string, data []float32, srcShape []uint64) ([]float32, error) {
|
||||||
|
return repackPatchEmbeddingWeight(data, srcShape, channels, patch)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
out := p.textModel().Tensors(ts)
|
||||||
|
for _, t := range out {
|
||||||
|
if t.Name == "v.patch_embd.weight" && len(t.Shape) == 2 {
|
||||||
|
t.Shape = []uint64{t.Shape[0], uint64(numChannels), uint64(patchSize), uint64(patchSize)}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *lfm2VLTextModel) Replacements() []string {
|
||||||
|
out := make([]string, 0, 96)
|
||||||
|
|
||||||
|
addText := func(from, to string) {
|
||||||
|
out = append(out, from, to)
|
||||||
|
if strings.HasPrefix(from, "model.") {
|
||||||
|
suffix := strings.TrimPrefix(from, "model.")
|
||||||
|
out = append(out,
|
||||||
|
"model.language_model."+suffix, to,
|
||||||
|
"model.language_model.model."+suffix, to,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
base := p.textModel().Replacements()
|
||||||
|
for i := 0; i+1 < len(base); i += 2 {
|
||||||
|
addText(base[i], base[i+1])
|
||||||
|
}
|
||||||
|
|
||||||
|
// Vision tower + multimodal projector tensors (single-file conversion).
|
||||||
|
out = append(out,
|
||||||
|
"model.vision_tower.vision_model.embeddings.patch_embedding", "v.patch_embd",
|
||||||
|
"model.vision_tower.vision_model.embeddings.position_embedding", "v.position_embd",
|
||||||
|
"model.vision_tower.vision_model.encoder.layers", "v.blk",
|
||||||
|
"model.vision_tower.vision_model.post_layernorm", "v.post_ln",
|
||||||
|
"model.multi_modal_projector.layer_norm", "mm.layer_norm",
|
||||||
|
"model.multi_modal_projector.linear_1", "mm.1",
|
||||||
|
"model.multi_modal_projector.linear_2", "mm.2",
|
||||||
|
"self_attn.q_proj", "attn_q",
|
||||||
|
"self_attn.k_proj", "attn_k",
|
||||||
|
"self_attn.v_proj", "attn_v",
|
||||||
|
"self_attn.out_proj", "attn_out",
|
||||||
|
"layer_norm1", "ln1",
|
||||||
|
"layer_norm2", "ln2",
|
||||||
|
"mlp.fc1", "ffn_up",
|
||||||
|
"mlp.fc2", "ffn_down",
|
||||||
|
)
|
||||||
|
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
// lfm2VLProjectorModel converts the vision encoder + projector component of LFM2 VL checkpoints.
|
||||||
|
type lfm2VLProjectorModel struct {
|
||||||
|
ModelParameters
|
||||||
|
DownsampleFactor uint32 `json:"downsample_factor"`
|
||||||
|
ProjectorHiddenDim uint32 `json:"projector_hidden_size"`
|
||||||
|
VisionModel struct {
|
||||||
|
HiddenSize uint32 `json:"hidden_size"`
|
||||||
|
IntermediateSize uint32 `json:"intermediate_size"`
|
||||||
|
NumAttentionHeads uint32 `json:"num_attention_heads"`
|
||||||
|
NumHiddenLayers uint32 `json:"num_hidden_layers"`
|
||||||
|
NumChannels uint32 `json:"num_channels"`
|
||||||
|
PatchSize uint32 `json:"patch_size"`
|
||||||
|
LayerNormEpsilon float32 `json:"layer_norm_eps"`
|
||||||
|
ImageSize uint32 `json:"image_size"`
|
||||||
|
} `json:"vision_config"`
|
||||||
|
Processor struct {
|
||||||
|
ImageProcessor struct {
|
||||||
|
DownsampleFactor uint32 `json:"downsample_factor"`
|
||||||
|
TileSize uint32 `json:"tile_size"`
|
||||||
|
ImageMean []float32 `json:"image_mean"`
|
||||||
|
ImageStd []float32 `json:"image_std"`
|
||||||
|
Size struct {
|
||||||
|
Height uint32 `json:"height"`
|
||||||
|
Width uint32 `json:"width"`
|
||||||
|
} `json:"size"`
|
||||||
|
} `json:"image_processor"`
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
_ ModelConverter = (*lfm2VLTextModel)(nil)
|
||||||
|
_ ModelConverter = (*lfm2VLProjectorModel)(nil)
|
||||||
|
_ moreParser = (*lfm2VLTextModel)(nil)
|
||||||
|
_ moreParser = (*lfm2VLProjectorModel)(nil)
|
||||||
|
)
|
||||||
|
|
||||||
|
func (p *lfm2VLProjectorModel) parseMore(fsys fs.FS) error {
|
||||||
|
bts, err := fs.ReadFile(fsys, "processor_config.json")
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, fs.ErrNotExist) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return json.Unmarshal(bts, &p.Processor)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *lfm2VLProjectorModel) imageSize() uint32 {
|
||||||
|
if p.VisionModel.ImageSize > 0 {
|
||||||
|
return p.VisionModel.ImageSize
|
||||||
|
}
|
||||||
|
|
||||||
|
downsample := cmp.Or(p.DownsampleFactor, p.Processor.ImageProcessor.DownsampleFactor, uint32(2))
|
||||||
|
baseSize := cmp.Or(
|
||||||
|
p.Processor.ImageProcessor.TileSize,
|
||||||
|
p.Processor.ImageProcessor.Size.Height,
|
||||||
|
p.Processor.ImageProcessor.Size.Width,
|
||||||
|
uint32(256),
|
||||||
|
)
|
||||||
|
if downsample == 0 {
|
||||||
|
return baseSize
|
||||||
|
}
|
||||||
|
|
||||||
|
return max(uint32(1), baseSize/downsample)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *lfm2VLProjectorModel) KV(_ *Tokenizer) KV {
|
||||||
|
kv := KV{
|
||||||
|
"general.architecture": "clip",
|
||||||
|
"general.type": "mmproj",
|
||||||
|
"general.file_type": uint32(1),
|
||||||
|
"general.quantization_version": uint32(2),
|
||||||
|
"clip.has_vision_encoder": true,
|
||||||
|
"clip.projector_type": "lfm2",
|
||||||
|
"clip.use_gelu": true,
|
||||||
|
}
|
||||||
|
|
||||||
|
kv["clip.vision.block_count"] = cmp.Or(p.VisionModel.NumHiddenLayers, uint32(27))
|
||||||
|
kv["clip.vision.embedding_length"] = cmp.Or(p.VisionModel.HiddenSize, uint32(1152))
|
||||||
|
kv["clip.vision.feed_forward_length"] = cmp.Or(p.VisionModel.IntermediateSize, uint32(4304))
|
||||||
|
kv["clip.vision.attention.head_count"] = cmp.Or(p.VisionModel.NumAttentionHeads, uint32(16))
|
||||||
|
kv["clip.vision.attention.layer_norm_epsilon"] = cmp.Or(p.VisionModel.LayerNormEpsilon, float32(1e-6))
|
||||||
|
kv["clip.vision.patch_size"] = cmp.Or(p.VisionModel.PatchSize, uint32(16))
|
||||||
|
kv["clip.vision.image_size"] = p.imageSize()
|
||||||
|
kv["clip.vision.projection_dim"] = cmp.Or(p.ProjectorHiddenDim, uint32(2048))
|
||||||
|
kv["clip.vision.projector.scale_factor"] = cmp.Or(p.DownsampleFactor, p.Processor.ImageProcessor.DownsampleFactor, uint32(2))
|
||||||
|
kv["clip.vision.image_mean"] = slices.Clone(defaultFloat32Slice(p.Processor.ImageProcessor.ImageMean, []float32{0.5, 0.5, 0.5}))
|
||||||
|
kv["clip.vision.image_std"] = slices.Clone(defaultFloat32Slice(p.Processor.ImageProcessor.ImageStd, []float32{0.5, 0.5, 0.5}))
|
||||||
|
|
||||||
|
return kv
|
||||||
|
}
|
||||||
|
|
||||||
|
func defaultFloat32Slice(v, fallback []float32) []float32 {
|
||||||
|
if len(v) > 0 {
|
||||||
|
return v
|
||||||
|
}
|
||||||
|
|
||||||
|
return fallback
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *lfm2VLProjectorModel) Tensors(ts []Tensor) []*ggml.Tensor {
|
||||||
|
var out []*ggml.Tensor
|
||||||
|
|
||||||
|
numChannels := cmp.Or(p.VisionModel.NumChannels, uint32(3))
|
||||||
|
patchSize := cmp.Or(p.VisionModel.PatchSize, uint32(16))
|
||||||
|
|
||||||
|
for _, t := range ts {
|
||||||
|
name := t.Name()
|
||||||
|
if !(strings.HasPrefix(name, "v.") || strings.HasPrefix(name, "mm.")) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
shape := t.Shape()
|
||||||
|
if name == "v.patch_embd.weight" && len(shape) == 2 {
|
||||||
|
inputDim := uint64(numChannels * patchSize * patchSize)
|
||||||
|
if shape[1] == inputDim {
|
||||||
|
shape = []uint64{shape[0], uint64(numChannels), uint64(patchSize), uint64(patchSize)}
|
||||||
|
channels := int(numChannels)
|
||||||
|
patch := int(patchSize)
|
||||||
|
t.SetRepacker(func(_ string, data []float32, srcShape []uint64) ([]float32, error) {
|
||||||
|
return repackPatchEmbeddingWeight(data, srcShape, channels, patch)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
out = append(out, &ggml.Tensor{
|
||||||
|
Name: name,
|
||||||
|
Kind: t.Kind(),
|
||||||
|
Shape: slices.Clone(shape),
|
||||||
|
WriterTo: t,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *lfm2VLProjectorModel) Replacements() []string {
|
||||||
|
return []string{
|
||||||
|
"model.multi_modal_projector.linear_1", "mm.1",
|
||||||
|
"model.multi_modal_projector.linear_2", "mm.2",
|
||||||
|
"model.vision_tower.vision_model.embeddings.patch_embedding", "v.patch_embd",
|
||||||
|
"model.vision_tower.vision_model.embeddings.position_embedding", "v.position_embd",
|
||||||
|
"model.vision_tower.vision_model.encoder.layers", "v.blk",
|
||||||
|
"self_attn.q_proj", "attn_q",
|
||||||
|
"self_attn.k_proj", "attn_k",
|
||||||
|
"self_attn.v_proj", "attn_v",
|
||||||
|
"self_attn.out_proj", "attn_out",
|
||||||
|
"layer_norm1", "ln1",
|
||||||
|
"layer_norm2", "ln2",
|
||||||
|
"mlp.fc1", "ffn_up",
|
||||||
|
"mlp.fc2", "ffn_down",
|
||||||
|
"model.vision_tower.vision_model.post_layernorm", "v.post_ln",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func repackPatchEmbeddingWeight(data []float32, srcShape []uint64, channels, patch int) ([]float32, error) {
|
||||||
|
if len(srcShape) != 2 {
|
||||||
|
return nil, fmt.Errorf("invalid patch embedding shape rank: %d", len(srcShape))
|
||||||
|
}
|
||||||
|
|
||||||
|
outDim := int(srcShape[0])
|
||||||
|
flatInputDim := int(srcShape[1])
|
||||||
|
expectedInputDim := channels * patch * patch
|
||||||
|
if flatInputDim != expectedInputDim {
|
||||||
|
return nil, fmt.Errorf("invalid patch embedding input dim: got %d, want %d", flatInputDim, expectedInputDim)
|
||||||
|
}
|
||||||
|
|
||||||
|
expectedSize := outDim * flatInputDim
|
||||||
|
if len(data) != expectedSize {
|
||||||
|
return nil, fmt.Errorf("invalid patch embedding data size: got %d, want %d", len(data), expectedSize)
|
||||||
|
}
|
||||||
|
|
||||||
|
repacked := make([]float32, len(data))
|
||||||
|
perChannel := patch * patch
|
||||||
|
|
||||||
|
for o := range outDim {
|
||||||
|
inBase := o * flatInputDim
|
||||||
|
outBase := o * flatInputDim
|
||||||
|
|
||||||
|
for y := range patch {
|
||||||
|
for x := range patch {
|
||||||
|
inPixelBase := inBase + (y*patch+x)*channels
|
||||||
|
for c := range channels {
|
||||||
|
src := inPixelBase + c
|
||||||
|
dst := outBase + c*perChannel + y*patch + x
|
||||||
|
repacked[dst] = data[src]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return repacked, nil
|
||||||
|
}
|
||||||
249
convert/convert_lfm2_vl_test.go
Normal file
249
convert/convert_lfm2_vl_test.go
Normal file
@@ -0,0 +1,249 @@
|
|||||||
|
package convert
|
||||||
|
|
||||||
|
import (
|
||||||
|
"slices"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestLFM2VLTextModelKVUsesTextConfig verifies that the LFM2-VL text-model
// converter derives language-model keys from the nested text config, vision
// keys from the nested vision config, and image token ids from the supplied
// tokenizer vocabulary.
func TestLFM2VLTextModelKVUsesTextConfig(t *testing.T) {
	p := lfm2VLTextModel{
		TextConfig: lfm2Model{
			ModelParameters:       ModelParameters{ModelType: "lfm2", VocabSize: 65536},
			HiddenSize:            2048,
			NumHiddenLayers:       16,
			MaxPositionEmbeddings: 128000,
			IntermediateSize:      12288,
			BlockFFDim:            12288,
			BlockAutoAdjustFFDim:  true,
			BlockMultipleOf:       256,
			BlockFFNDimMultiplier: 1.0,
			NumAttentionHeads:     32,
			NumKeyValueHeads:      8,
			LayerTypes:            []string{"conv", "full_attention"},
			NormEps:               1e-5,
			ConvLCache:            3,
		},
		DownsampleFactor: 2,
		// The anonymous struct literal must match the VisionConfig field's
		// declared type exactly (field names, types, and JSON tags).
		VisionConfig: struct {
			HiddenSize        uint32  `json:"hidden_size"`
			IntermediateSize  uint32  `json:"intermediate_size"`
			NumAttentionHeads uint32  `json:"num_attention_heads"`
			NumHiddenLayers   uint32  `json:"num_hidden_layers"`
			NumChannels       uint32  `json:"num_channels"`
			PatchSize         uint32  `json:"patch_size"`
			LayerNormEpsilon  float32 `json:"layer_norm_eps"`
		}{
			HiddenSize:        1152,
			IntermediateSize:  4304,
			NumAttentionHeads: 16,
			NumHiddenLayers:   27,
			NumChannels:       3,
			PatchSize:         16,
			LayerNormEpsilon:  1e-6,
		},
	}
	p.Processor.ImageProcessor.TileSize = 512
	p.Processor.ImageProcessor.ImageMean = []float32{0.5, 0.5, 0.5}
	p.Processor.ImageProcessor.ImageStd = []float32{0.5, 0.5, 0.5}

	kv := p.KV(&Tokenizer{
		Vocabulary: &Vocabulary{
			Model:  "gpt2",
			Tokens: []string{"<|pad|>", "<image>", "<|image_start|>", "<|image_end|>", "<|img_thumbnail|>"},
		},
	})

	if got, want := kv["general.architecture"], "lfm2"; got != want {
		t.Fatalf("general.architecture = %v, want %v", got, want)
	}

	// 8192 is the BlockAutoAdjustFFDim-adjusted value of BlockFFDim=12288 —
	// NOTE(review): expected value depends on lfm2Model's FFN adjustment;
	// confirm against that implementation.
	if got, want := kv["feed_forward_length"], uint32(8192); got != want {
		t.Fatalf("feed_forward_length = %v, want %v", got, want)
	}

	if got, want := kv["vision.block_count"], uint32(27); got != want {
		t.Fatalf("vision.block_count = %v, want %v", got, want)
	}

	if got, want := kv["vision.image_size"], uint32(256); got != want {
		t.Fatalf("vision.image_size = %v, want %v", got, want)
	}

	// NOTE(review): 396 does not index the 5-token vocabulary above —
	// presumably the converter's default image token id; confirm.
	if got, want := kv["vision.image_token_id"], uint32(396); got != want {
		t.Fatalf("vision.image_token_id = %v, want %v", got, want)
	}

	// "<|image_start|>" is at index 2 of the vocabulary.
	if got, want := kv["vision.image_start_token_id"], uint32(2); got != want {
		t.Fatalf("vision.image_start_token_id = %v, want %v", got, want)
	}

	if got, want := kv["vision.do_image_splitting"], true; got != want {
		t.Fatalf("vision.do_image_splitting = %v, want %v", got, want)
	}
	if got, want := kv["vision.min_tiles"], uint32(2); got != want {
		t.Fatalf("vision.min_tiles = %v, want %v", got, want)
	}
	if got, want := kv["vision.max_tiles"], uint32(10); got != want {
		t.Fatalf("vision.max_tiles = %v, want %v", got, want)
	}
	if got, want := kv["vision.tile_size"], uint32(512); got != want {
		t.Fatalf("vision.tile_size = %v, want %v", got, want)
	}
	if got, want := kv["vision.use_thumbnail"], true; got != want {
		t.Fatalf("vision.use_thumbnail = %v, want %v", got, want)
	}
	if got, want := kv["vision.use_image_special_tokens"], true; got != want {
		t.Fatalf("vision.use_image_special_tokens = %v, want %v", got, want)
	}
}
|
||||||
|
|
||||||
|
func TestLFM2VLTextModelTensorsIncludeVision(t *testing.T) {
|
||||||
|
p := lfm2VLTextModel{}
|
||||||
|
p.VisionConfig.PatchSize = 16
|
||||||
|
p.VisionConfig.NumChannels = 3
|
||||||
|
input := []Tensor{
|
||||||
|
newLFM2StubTensor("model.embed_tokens.weight", []uint64{65536, 2048}),
|
||||||
|
newLFM2StubTensor("model.layers.0.ffn_norm.weight", []uint64{2048}),
|
||||||
|
newLFM2StubTensor("v.patch_embd.weight", []uint64{1152, 768}),
|
||||||
|
newLFM2StubTensor("v.blk.0.attn_q.weight", []uint64{1152, 1152}),
|
||||||
|
newLFM2StubTensor("mm.1.weight", []uint64{2048, 4608}),
|
||||||
|
}
|
||||||
|
|
||||||
|
out := p.Tensors(input)
|
||||||
|
if len(out) == 0 {
|
||||||
|
t.Fatal("expected non-empty tensor list")
|
||||||
|
}
|
||||||
|
|
||||||
|
foundPatch := false
|
||||||
|
foundVision := false
|
||||||
|
for _, tns := range out {
|
||||||
|
if tns.Name == "v.patch_embd.weight" {
|
||||||
|
foundPatch = true
|
||||||
|
if !slices.Equal(tns.Shape, []uint64{1152, 3, 16, 16}) {
|
||||||
|
t.Fatalf("v.patch_embd.weight shape = %v, want [1152 3 16 16]", tns.Shape)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if strings.HasPrefix(tns.Name, "v.") || strings.HasPrefix(tns.Name, "mm.") {
|
||||||
|
foundVision = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !foundPatch {
|
||||||
|
t.Fatal("expected v.patch_embd.weight in output tensors")
|
||||||
|
}
|
||||||
|
if !foundVision {
|
||||||
|
t.Fatal("expected at least one vision/projector tensor in output")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLFM2VLTextModelReplacements(t *testing.T) {
|
||||||
|
p := lfm2VLTextModel{}
|
||||||
|
r := strings.NewReplacer(p.Replacements()...)
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
in string
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "language_model_embed_tokens",
|
||||||
|
in: "model.language_model.embed_tokens.weight",
|
||||||
|
want: "token_embd.weight",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "language_model_layers",
|
||||||
|
in: "model.language_model.layers.2.self_attn.q_proj.weight",
|
||||||
|
want: "blk.2.attn_q.weight",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "nested_language_model_prefix",
|
||||||
|
in: "model.language_model.model.embedding_norm.weight",
|
||||||
|
want: "token_embd_norm.weight",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
if got := r.Replace(tt.in); got != tt.want {
|
||||||
|
t.Fatalf("replacement(%q) = %q, want %q", tt.in, got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestLFM2VLProjectorKV checks the CLIP-style KV pairs emitted for the
// standalone LFM2-VL projector/vision model.
func TestLFM2VLProjectorKV(t *testing.T) {
	p := lfm2VLProjectorModel{
		DownsampleFactor:   2,
		ProjectorHiddenDim: 2048,
	}
	p.VisionModel.NumHiddenLayers = 27
	p.VisionModel.HiddenSize = 1152
	p.VisionModel.IntermediateSize = 4304
	p.VisionModel.NumAttentionHeads = 16
	p.VisionModel.PatchSize = 16
	p.VisionModel.LayerNormEpsilon = 1e-6
	p.Processor.ImageProcessor.TileSize = 512
	p.Processor.ImageProcessor.ImageMean = []float32{0.5, 0.5, 0.5}
	p.Processor.ImageProcessor.ImageStd = []float32{0.5, 0.5, 0.5}

	// The projector KV is built without a tokenizer.
	kv := p.KV(nil)

	if got, want := kv["general.architecture"], "clip"; got != want {
		t.Fatalf("general.architecture = %v, want %v", got, want)
	}
	if got, want := kv["clip.projector_type"], "lfm2"; got != want {
		t.Fatalf("clip.projector_type = %v, want %v", got, want)
	}
	// NOTE(review): image_size is not set above, so 256 is presumably the
	// converter's default — confirm against lfm2VLProjectorModel.KV.
	if got, want := kv["clip.vision.image_size"], uint32(256); got != want {
		t.Fatalf("clip.vision.image_size = %v, want %v", got, want)
	}
}
|
||||||
|
|
||||||
|
func TestLFM2VLProjectorTensorsPatchReshape(t *testing.T) {
|
||||||
|
p := lfm2VLProjectorModel{}
|
||||||
|
p.VisionModel.NumChannels = 3
|
||||||
|
p.VisionModel.PatchSize = 16
|
||||||
|
|
||||||
|
input := []Tensor{
|
||||||
|
newLFM2StubTensor("v.patch_embd.weight", []uint64{1152, 768}),
|
||||||
|
newLFM2StubTensor("mm.1.weight", []uint64{2048, 4608}),
|
||||||
|
newLFM2StubTensor("model.embed_tokens.weight", []uint64{65536, 2048}),
|
||||||
|
}
|
||||||
|
|
||||||
|
out := p.Tensors(input)
|
||||||
|
if len(out) != 2 {
|
||||||
|
t.Fatalf("expected 2 tensors, got %d", len(out))
|
||||||
|
}
|
||||||
|
|
||||||
|
var patchShape []uint64
|
||||||
|
for _, tns := range out {
|
||||||
|
if tns.Name == "v.patch_embd.weight" {
|
||||||
|
patchShape = tns.Shape
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !slices.Equal(patchShape, []uint64{1152, 3, 16, 16}) {
|
||||||
|
t.Fatalf("v.patch_embd.weight shape = %v, want [1152 3 16 16]", patchShape)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRepackPatchEmbeddingWeight(t *testing.T) {
|
||||||
|
data := []float32{
|
||||||
|
0, 1, // y=0,x=0
|
||||||
|
2, 3, // y=0,x=1
|
||||||
|
4, 5, // y=1,x=0
|
||||||
|
6, 7, // y=1,x=1
|
||||||
|
}
|
||||||
|
|
||||||
|
got, err := repackPatchEmbeddingWeight(data, []uint64{1, 8}, 2, 2)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
want := []float32{0, 2, 4, 6, 1, 3, 5, 7}
|
||||||
|
if !slices.Equal(got, want) {
|
||||||
|
t.Fatalf("repacked data = %v, want %v", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
385
convert/convert_nemotron_h.go
Normal file
385
convert/convert_nemotron_h.go
Normal file
@@ -0,0 +1,385 @@
|
|||||||
|
package convert
|
||||||
|
|
||||||
|
import (
|
||||||
|
"cmp"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io/fs"
|
||||||
|
"math"
|
||||||
|
"slices"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/fs/ggml"
|
||||||
|
)
|
||||||
|
|
||||||
|
type hybridPattern string
|
||||||
|
|
||||||
|
func (p *hybridPattern) UnmarshalJSON(data []byte) error {
|
||||||
|
if string(data) == "null" {
|
||||||
|
*p = ""
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var single string
|
||||||
|
if err := json.Unmarshal(data, &single); err == nil {
|
||||||
|
*p = hybridPattern(strings.TrimSpace(single))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var parts []string
|
||||||
|
if err := json.Unmarshal(data, &parts); err == nil {
|
||||||
|
*p = hybridPattern(strings.Join(parts, ""))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return fmt.Errorf("hybrid_override_pattern must be a string or string array")
|
||||||
|
}
|
||||||
|
|
||||||
|
// nemotronHModel converts NVIDIA Nemotron-H (hybrid Mamba/attention, with
// optional MoE) checkpoints. Field names mirror Hugging Face config.json keys.
type nemotronHModel struct {
	ModelParameters
	MaxPositionEmbeddings uint32  `json:"max_position_embeddings"`
	HiddenSize            uint32  `json:"hidden_size"`
	NumHiddenLayers       uint32  `json:"num_hidden_layers"`
	NumAttentionHeads     uint32  `json:"num_attention_heads"`
	NumKeyValueHeads      uint32  `json:"num_key_value_heads"`
	HeadDim               uint32  `json:"head_dim"`
	LayerNormEpsilon      float32 `json:"layer_norm_epsilon"`
	NormEpsilon           float32 `json:"norm_eps"`
	RopeTheta             float32 `json:"rope_theta"`
	PartialRotaryFactor   float32 `json:"partial_rotary_factor"`
	// SSM (Mamba) parameters.
	ConvKernel   uint32 `json:"conv_kernel"`
	SSMStateSize uint32 `json:"ssm_state_size"`
	MambaNumHeads uint32 `json:"mamba_num_heads"`
	MambaHeadDim  uint32 `json:"mamba_head_dim"`
	NGroups       uint32 `json:"n_groups"`
	IntermediateSize uint32 `json:"intermediate_size"`
	// One rune per layer: 'M' recurrent, '*'/'A' attention, 'E' MoE FFN,
	// '-' dense FFN (see layerArrays).
	HybridOverridePattern hybridPattern `json:"hybrid_override_pattern"`

	// MoE — some checkpoints use num_* names, others n_*; the accessor
	// methods pick whichever is set.
	NumExperts                  uint32  `json:"num_experts"`
	NumSharedExperts            uint32  `json:"num_shared_experts"`
	NRoutedExperts              uint32  `json:"n_routed_experts"`
	NSharedExperts              uint32  `json:"n_shared_experts"`
	NumExpertsPerTok            uint32  `json:"num_experts_per_tok"`
	MoEIntermediateSize         uint32  `json:"moe_intermediate_size"`
	MoESharedExpertIntermediate uint32  `json:"moe_shared_expert_intermediate_size"`
	NormTopKProb                bool    `json:"norm_topk_prob"`
	RoutedScalingFactor         float32 `json:"routed_scaling_factor"`
	ExpertGroupCount            uint32  `json:"n_group"`
	ExpertGroupUsedCount        uint32  `json:"topk_group"`
}

// Compile-time assertion that nemotronHModel implements ModelConverter.
var _ ModelConverter = (*nemotronHModel)(nil)
|
||||||
|
|
||||||
|
// parseMore validates the decoded config after JSON unmarshalling. It fills
// in derivable defaults (HeadDim, NumKeyValueHeads, NGroups), eagerly checks
// the hybrid layer pattern, and enforces the MoE invariants when any MoE
// field is configured. Returns a descriptive error on the first violation.
func (n *nemotronHModel) parseMore(_ fs.FS) error {
	if n.NumHiddenLayers == 0 {
		return fmt.Errorf("nemotron_h: num_hidden_layers must be set")
	}
	if n.HiddenSize == 0 {
		return fmt.Errorf("nemotron_h: hidden_size must be set")
	}
	if n.NumAttentionHeads == 0 {
		return fmt.Errorf("nemotron_h: num_attention_heads must be set")
	}
	if n.HeadDim == 0 {
		// Derive the head dimension when the config omits it.
		if n.HiddenSize%n.NumAttentionHeads != 0 {
			return fmt.Errorf("nemotron_h: hidden_size (%d) must be divisible by num_attention_heads (%d)", n.HiddenSize, n.NumAttentionHeads)
		}
		n.HeadDim = n.HiddenSize / n.NumAttentionHeads
	}
	if n.NumKeyValueHeads == 0 {
		// No GQA configured: fall back to full multi-head attention.
		n.NumKeyValueHeads = n.NumAttentionHeads
	}
	if n.ConvKernel == 0 {
		return fmt.Errorf("nemotron_h: conv_kernel must be set")
	}
	if n.SSMStateSize == 0 {
		return fmt.Errorf("nemotron_h: ssm_state_size must be set")
	}
	if n.ssmHeadCount() == 0 {
		return fmt.Errorf("nemotron_h: mamba_num_heads must be set")
	}
	if n.MambaHeadDim == 0 {
		return fmt.Errorf("nemotron_h: mamba_head_dim must be set")
	}
	if n.NGroups == 0 {
		n.NGroups = 1
	}

	// Validate the hybrid layer pattern now so errors surface at load time
	// rather than during tensor conversion.
	if _, _, err := n.layerArrays(); err != nil {
		return err
	}

	if n.isMoE() {
		if n.routedExpertCount() == 0 {
			return fmt.Errorf("nemotron_h: routed expert count must be set for MoE models")
		}
		if n.NumExpertsPerTok == 0 {
			return fmt.Errorf("nemotron_h: num_experts_per_tok must be set for MoE models")
		}
		if n.NumExpertsPerTok > n.routedExpertCount() {
			return fmt.Errorf("nemotron_h: num_experts_per_tok (%d) cannot exceed expert_count (%d)", n.NumExpertsPerTok, n.routedExpertCount())
		}
		if n.moeIntermediateSize() == 0 {
			return fmt.Errorf("nemotron_h: moe_intermediate_size must be set for MoE models")
		}
	}

	return nil
}
|
||||||
|
|
||||||
|
func (n *nemotronHModel) isMoE() bool {
|
||||||
|
return cmp.Or(n.routedExpertCount(), n.NumExpertsPerTok, n.MoEIntermediateSize) > 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *nemotronHModel) routedExpertCount() uint32 {
|
||||||
|
return cmp.Or(n.NRoutedExperts, n.NumExperts)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *nemotronHModel) sharedExpertCount() uint32 {
|
||||||
|
return cmp.Or(n.NSharedExperts, n.NumSharedExperts)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *nemotronHModel) ssmHeadCount() uint32 {
|
||||||
|
return n.MambaNumHeads
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *nemotronHModel) ssmInnerSize() uint32 {
|
||||||
|
return n.MambaHeadDim * n.ssmHeadCount()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *nemotronHModel) epsilon() float32 {
|
||||||
|
return cmp.Or(n.NormEpsilon, n.LayerNormEpsilon, float32(1e-5))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *nemotronHModel) moeIntermediateSize() uint32 {
|
||||||
|
return cmp.Or(n.MoEIntermediateSize, n.IntermediateSize)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *nemotronHModel) denseIntermediateSize() uint32 {
|
||||||
|
return cmp.Or(n.IntermediateSize, n.MoEIntermediateSize)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *nemotronHModel) layerArrays() (headCountKV []uint32, ffnLengths []uint32, err error) {
|
||||||
|
pattern := strings.TrimSpace(string(n.HybridOverridePattern))
|
||||||
|
if pattern == "" {
|
||||||
|
return nil, nil, fmt.Errorf("nemotron_h: hybrid_override_pattern must be set")
|
||||||
|
}
|
||||||
|
|
||||||
|
runes := []rune(pattern)
|
||||||
|
if len(runes) != int(n.NumHiddenLayers) {
|
||||||
|
return nil, nil, fmt.Errorf("nemotron_h: hybrid_override_pattern length (%d) must match num_hidden_layers (%d)", len(runes), n.NumHiddenLayers)
|
||||||
|
}
|
||||||
|
|
||||||
|
headCountKV = make([]uint32, n.NumHiddenLayers)
|
||||||
|
ffnLengths = make([]uint32, n.NumHiddenLayers)
|
||||||
|
|
||||||
|
attnKVHeads := cmp.Or(n.NumKeyValueHeads, n.NumAttentionHeads)
|
||||||
|
moeFFN := n.moeIntermediateSize()
|
||||||
|
denseFFN := n.denseIntermediateSize()
|
||||||
|
|
||||||
|
for i, layerType := range runes {
|
||||||
|
switch layerType {
|
||||||
|
case 'M':
|
||||||
|
// Recurrent layer: no KV heads and no FFN.
|
||||||
|
case '*', 'A':
|
||||||
|
// Attention-only layer.
|
||||||
|
headCountKV[i] = attnKVHeads
|
||||||
|
case 'E':
|
||||||
|
// MoE layer.
|
||||||
|
if moeFFN == 0 {
|
||||||
|
return nil, nil, fmt.Errorf("nemotron_h: moe layer at index %d but moe_intermediate_size is zero", i)
|
||||||
|
}
|
||||||
|
ffnLengths[i] = moeFFN
|
||||||
|
case '-':
|
||||||
|
// Dense FFN layer.
|
||||||
|
if denseFFN == 0 {
|
||||||
|
return nil, nil, fmt.Errorf("nemotron_h: dense FFN layer at index %d but intermediate_size is zero", i)
|
||||||
|
}
|
||||||
|
ffnLengths[i] = denseFFN
|
||||||
|
default:
|
||||||
|
return nil, nil, fmt.Errorf("nemotron_h: unsupported layer type %q in hybrid_override_pattern at index %d", layerType, i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return headCountKV, ffnLengths, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// KV builds the GGUF metadata for the model on top of the shared
// ModelParameters keys. The architecture is "nemotron_h" for dense models
// and "nemotron_h_moe" when any MoE field is configured.
func (n *nemotronHModel) KV(t *Tokenizer) KV {
	kv := n.ModelParameters.KV(t)

	arch := "nemotron_h"
	if n.isMoE() {
		arch = "nemotron_h_moe"
	}
	kv["general.architecture"] = arch
	kv["block_count"] = n.NumHiddenLayers
	kv["context_length"] = n.MaxPositionEmbeddings
	kv["embedding_length"] = n.HiddenSize
	kv["attention.head_count"] = n.NumAttentionHeads
	kv["attention.key_length"] = n.HeadDim
	kv["attention.value_length"] = n.HeadDim
	// Both epsilon keys are written so either norm flavor resolves.
	kv["attention.layer_norm_epsilon"] = n.epsilon()
	kv["attention.layer_norm_rms_epsilon"] = n.epsilon()
	kv["rope.freq_base"] = cmp.Or(n.RopeTheta, float32(10000))
	if n.PartialRotaryFactor > 0 && n.PartialRotaryFactor <= 1 {
		// Partial RoPE: only a fraction of each head dimension is rotated.
		kv["rope.dimension_count"] = uint32(float32(n.HeadDim) * n.PartialRotaryFactor)
	}

	// Per-layer arrays; the error is silently skipped here because
	// parseMore already validated the pattern at load time.
	if headCountKV, ffnLengths, err := n.layerArrays(); err == nil {
		kv["attention.head_count_kv"] = headCountKV
		kv["feed_forward_length"] = ffnLengths
	}

	kv["ssm.conv_kernel"] = n.ConvKernel
	kv["ssm.inner_size"] = n.ssmInnerSize()
	kv["ssm.state_size"] = n.SSMStateSize
	kv["ssm.group_count"] = n.NGroups
	kv["ssm.time_step_rank"] = n.ssmHeadCount()

	if n.isMoE() {
		kv["expert_count"] = n.routedExpertCount()
		kv["expert_used_count"] = n.NumExpertsPerTok
		kv["expert_feed_forward_length"] = n.moeIntermediateSize()
		if n.sharedExpertCount() > 0 {
			kv["expert_shared_count"] = n.sharedExpertCount()
		}
		if n.MoESharedExpertIntermediate > 0 {
			kv["expert_shared_feed_forward_length"] = n.MoESharedExpertIntermediate
		}
		kv["expert_weights_norm"] = n.NormTopKProb
		kv["expert_weights_scale"] = n.RoutedScalingFactor
		if n.ExpertGroupCount > 0 {
			kv["expert_group_count"] = n.ExpertGroupCount
		}
		if n.ExpertGroupUsedCount > 0 {
			kv["expert_group_used_count"] = n.ExpertGroupUsedCount
		}
	}

	return kv
}
|
||||||
|
|
||||||
|
func normalizeVectorShapeToColumn(shape []uint64) []uint64 {
|
||||||
|
switch len(shape) {
|
||||||
|
case 1:
|
||||||
|
return []uint64{shape[0], 1}
|
||||||
|
case 2:
|
||||||
|
if shape[0] == 1 && shape[1] > 1 {
|
||||||
|
return []uint64{shape[1], 1}
|
||||||
|
}
|
||||||
|
if shape[1] == 1 && shape[0] > 1 {
|
||||||
|
return []uint64{shape[0], 1}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return slices.Clone(shape)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Tensors converts raw checkpoint tensors into GGUF tensors. For MoE models
// it first merges per-expert up/down projections into stacked expert
// tensors, then applies SSM-specific shape fixes and value transforms to
// the remaining tensors.
func (n *nemotronHModel) Tensors(ts []Tensor) []*ggml.Tensor {
	var out []*ggml.Tensor

	remaining := ts
	if n.isMoE() {
		// Stack the per-expert FFN weights: blk.N.mixer.experts.*.{up,down}_proj
		// each collapse into a single blk.N.ffn_{up,down}_exps tensor.
		merges := make([]merge, 0, n.NumHiddenLayers*2)
		for i := range n.NumHiddenLayers {
			merges = append(merges, merge{
				fmt.Sprintf("blk.%d.mixer.experts.*.up_proj.weight", i),
				fmt.Sprintf("blk.%d.ffn_up_exps.weight", i),
			}, merge{
				fmt.Sprintf("blk.%d.mixer.experts.*.down_proj.weight", i),
				fmt.Sprintf("blk.%d.ffn_down_exps.weight", i),
			})
		}

		merged, rest := mergeTensors(ts, merges...)
		out = append(out, merged...)
		remaining = rest
	}

	nGroups := uint64(cmp.Or(n.NGroups, uint32(1)))
	for _, t := range remaining {
		name := t.Name()
		shape := slices.Clone(t.Shape())

		switch {
		case strings.HasSuffix(name, ".ssm_a"):
			// A is stored as log values; emit -exp(A_log), reshaped to a
			// column vector.
			shape = normalizeVectorShapeToColumn(shape)
			t.SetRepacker(func(_ string, data []float32, _ []uint64) ([]float32, error) {
				out := make([]float32, len(data))
				for i, v := range data {
					out[i] = -float32(math.Exp(float64(v)))
				}
				return out, nil
			})
		case strings.HasSuffix(name, ".ssm_d"):
			// D is passed through unchanged, reshaped to a column vector.
			shape = normalizeVectorShapeToColumn(shape)
		case strings.HasSuffix(name, ".ssm_norm.weight"):
			// Reshape the group-norm weight to [n_groups, dim/n_groups],
			// only when the dimension divides evenly.
			switch len(shape) {
			case 1:
				if nGroups > 0 && shape[0]%nGroups == 0 {
					shape = []uint64{nGroups, shape[0] / nGroups}
				}
			case 2:
				if shape[0] == 1 && nGroups > 0 && shape[1]%nGroups == 0 {
					shape = []uint64{nGroups, shape[1] / nGroups}
				}
			}
		case strings.HasSuffix(name, ".ssm_conv1d.weight"):
			// Drop the singleton dimension of the 3-D conv weight.
			if len(shape) == 3 {
				if shape[0] == 1 {
					shape = []uint64{shape[1], shape[2]}
				} else if shape[1] == 1 {
					shape = []uint64{shape[0], shape[2]}
				}
			}
		}

		out = append(out, &ggml.Tensor{
			Name:     name,
			Kind:     t.Kind(),
			Shape:    shape,
			WriterTo: t,
		})
	}

	return out
}
|
||||||
|
|
||||||
|
// Replacements returns name-rewrite pairs for strings.NewReplacer. Order is
// significant: more specific keys ("mixer.gate.e_score_correction_bias",
// "mixer.shared_experts.up_proj", "mixer.norm.weight") are listed before
// their shorter counterparts ("mixer.gate", "mixer.up_proj", ".norm.weight")
// so the specific rewrite wins.
func (n *nemotronHModel) Replacements() []string {
	return []string{
		// Embedding and output
		"lm_head", "output",
		"backbone.embeddings", "token_embd",
		"backbone.norm_f", "output_norm",
		"backbone.layers", "blk",

		// Recurrent (Mamba2) tensors
		"mixer.in_proj", "ssm_in",
		"mixer.out_proj", "ssm_out",
		"mixer.dt_bias", "ssm_dt.bias",
		"mixer.A_log", "ssm_a",
		"mixer.D", "ssm_d",
		"mixer.conv1d", "ssm_conv1d",
		"mixer.norm.weight", "ssm_norm.weight",

		// Attention tensors
		"mixer.q_proj", "attn_q",
		"mixer.k_proj", "attn_k",
		"mixer.v_proj", "attn_v",
		"mixer.o_proj", "attn_output",

		// FFN / MoE tensors
		"mixer.gate.e_score_correction_bias", "exp_probs_b.bias",
		"mixer.gate", "ffn_gate_inp",
		"mixer.fc1_latent_proj", "ffn_latent_in",
		"mixer.fc2_latent_proj", "ffn_latent_out",
		"mixer.shared_experts.up_proj", "ffn_up_shexp",
		"mixer.shared_experts.down_proj", "ffn_down_shexp",
		"mixer.up_proj", "ffn_up",
		"mixer.down_proj", "ffn_down",

		// Per-layer pre-norm (after mixer.norm.weight above has matched).
		".norm.weight", ".attn_norm.weight",
	}
}
|
||||||
230
convert/convert_nemotron_h_test.go
Normal file
230
convert/convert_nemotron_h_test.go
Normal file
@@ -0,0 +1,230 @@
|
|||||||
|
package convert
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/binary"
|
||||||
|
"encoding/json"
|
||||||
|
"io"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"slices"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestHybridPatternUnmarshal(t *testing.T) {
|
||||||
|
t.Run("string", func(t *testing.T) {
|
||||||
|
var p hybridPattern
|
||||||
|
if err := json.Unmarshal([]byte(`"MEM*"`), &p); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if got, want := string(p), "MEM*"; got != want {
|
||||||
|
t.Fatalf("unexpected pattern: got %q want %q", got, want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("array", func(t *testing.T) {
|
||||||
|
var p hybridPattern
|
||||||
|
if err := json.Unmarshal([]byte(`["M","E","M","*"]`), &p); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if got, want := string(p), "MEM*"; got != want {
|
||||||
|
t.Fatalf("unexpected pattern: got %q want %q", got, want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNemotronHLayerArrays(t *testing.T) {
|
||||||
|
m := &nemotronHModel{
|
||||||
|
NumHiddenLayers: 5,
|
||||||
|
NumAttentionHeads: 32,
|
||||||
|
NumKeyValueHeads: 8,
|
||||||
|
HybridOverridePattern: "MEM*E",
|
||||||
|
NRoutedExperts: 128,
|
||||||
|
NumExpertsPerTok: 6,
|
||||||
|
MoEIntermediateSize: 1856,
|
||||||
|
}
|
||||||
|
|
||||||
|
headsKV, ffn, err := m.layerArrays()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if got, want := headsKV, []uint32{0, 0, 0, 8, 0}; !slices.Equal(got, want) {
|
||||||
|
t.Fatalf("unexpected head_count_kv: got %v want %v", got, want)
|
||||||
|
}
|
||||||
|
if got, want := ffn, []uint32{0, 1856, 0, 0, 1856}; !slices.Equal(got, want) {
|
||||||
|
t.Fatalf("unexpected feed_forward_length: got %v want %v", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestNemotronHKV runs a MoE config end-to-end through parseMore and KV and
// checks the architecture plus the per-layer KV-head and FFN-length arrays.
func TestNemotronHKV(t *testing.T) {
	m := &nemotronHModel{
		MaxPositionEmbeddings:       1048576,
		HiddenSize:                  2688,
		NumHiddenLayers:             5,
		NumAttentionHeads:           32,
		NumKeyValueHeads:            2,
		HeadDim:                     128,
		LayerNormEpsilon:            1e-5,
		RopeTheta:                   10000,
		PartialRotaryFactor:         0.5,
		ConvKernel:                  4,
		SSMStateSize:                128,
		MambaNumHeads:               64,
		MambaHeadDim:                64,
		NGroups:                     8,
		HybridOverridePattern:       "MEM*E",
		NRoutedExperts:              128,
		NSharedExperts:              1,
		NumExpertsPerTok:            6,
		MoEIntermediateSize:         1856,
		MoESharedExpertIntermediate: 3712,
		NormTopKProb:                true,
		RoutedScalingFactor:         2.5,
	}
	if err := m.parseMore(nil); err != nil {
		t.Fatal(err)
	}

	kv := m.KV(&Tokenizer{Vocabulary: &Vocabulary{}})
	// MoE fields are set, so the architecture gains the _moe suffix.
	if got, want := kv["general.architecture"], "nemotron_h_moe"; got != want {
		t.Fatalf("unexpected architecture: got %v want %v", got, want)
	}

	// Pattern "MEM*E": only layer 3 ('*') is attention.
	headCountKV, ok := kv["attention.head_count_kv"].([]uint32)
	if !ok {
		t.Fatalf("attention.head_count_kv has unexpected type: %T", kv["attention.head_count_kv"])
	}
	if got, want := headCountKV, []uint32{0, 0, 0, 2, 0}; !slices.Equal(got, want) {
		t.Fatalf("unexpected attention.head_count_kv: got %v want %v", got, want)
	}

	// Layers 1 and 4 ('E') carry the MoE intermediate size.
	ffnLength, ok := kv["feed_forward_length"].([]uint32)
	if !ok {
		t.Fatalf("feed_forward_length has unexpected type: %T", kv["feed_forward_length"])
	}
	if got, want := ffnLength, []uint32{0, 1856, 0, 0, 1856}; !slices.Equal(got, want) {
		t.Fatalf("unexpected feed_forward_length: got %v want %v", got, want)
	}
}
|
||||||
|
|
||||||
|
// TestNemotronHTensorsTransforms checks the SSM shape fixes applied by
// Tensors (column vectors, group-norm reshape, conv1d squeeze) and the
// -exp(A_log) value transform installed as a repacker on ssm_a.
func TestNemotronHTensorsTransforms(t *testing.T) {
	m := &nemotronHModel{NGroups: 8}
	in := []Tensor{
		&fakeTensor{
			name:  "blk.0.ssm_a",
			shape: []uint64{4},
			data:  []float32{0, 1, 2, 3},
		},
		&fakeTensor{
			name:  "blk.0.ssm_d",
			shape: []uint64{4},
			data:  []float32{0, 1, 2, 3},
		},
		&fakeTensor{
			name:  "blk.0.ssm_norm.weight",
			shape: []uint64{16},
			data:  make([]float32, 16),
		},
		&fakeTensor{
			name:  "blk.0.ssm_conv1d.weight",
			shape: []uint64{10, 1, 4},
			data:  make([]float32, 40),
		},
	}

	out := m.Tensors(in)
	if len(out) != len(in) {
		t.Fatalf("unexpected output tensor count: got %d want %d", len(out), len(in))
	}

	// Index the results by name for the shape assertions below.
	got := map[string]struct {
		shape  []uint64
		writer io.WriterTo
	}{}
	for _, t := range out {
		got[t.Name] = struct {
			shape  []uint64
			writer io.WriterTo
		}{shape: t.Shape, writer: t.WriterTo}
	}

	// Vectors become column vectors [n, 1].
	if shape := got["blk.0.ssm_a"].shape; !slices.Equal(shape, []uint64{4, 1}) {
		t.Fatalf("unexpected ssm_a shape: %v", shape)
	}
	if shape := got["blk.0.ssm_d"].shape; !slices.Equal(shape, []uint64{4, 1}) {
		t.Fatalf("unexpected ssm_d shape: %v", shape)
	}
	// Group-norm weight reshaped to [n_groups, dim/n_groups] = [8, 2].
	if shape := got["blk.0.ssm_norm.weight"].shape; !slices.Equal(shape, []uint64{8, 2}) {
		t.Fatalf("unexpected ssm_norm shape: %v", shape)
	}
	// Conv1d weight loses its singleton middle dimension.
	if shape := got["blk.0.ssm_conv1d.weight"].shape; !slices.Equal(shape, []uint64{10, 4}) {
		t.Fatalf("unexpected ssm_conv1d shape: %v", shape)
	}

	// Writing ssm_a runs the repacker, so values come out as -exp(v).
	var b bytes.Buffer
	if _, err := got["blk.0.ssm_a"].writer.WriteTo(&b); err != nil {
		t.Fatal(err)
	}
	values := make([]float32, 4)
	if err := binary.Read(&b, binary.LittleEndian, &values); err != nil {
		t.Fatal(err)
	}
	// 0 -> -exp(0) == -1
	if values[0] != -1 {
		t.Fatalf("unexpected transformed ssm_a[0]: got %v want -1", values[0])
	}
}
|
||||||
|
|
||||||
|
func TestNemotronHLoadModelMetadata(t *testing.T) {
|
||||||
|
tempDir := t.TempDir()
|
||||||
|
|
||||||
|
config := `{
|
||||||
|
"architectures": ["NemotronHForCausalLM"],
|
||||||
|
"model_type": "nemotron_h",
|
||||||
|
"num_hidden_layers": 4,
|
||||||
|
"hidden_size": 512,
|
||||||
|
"max_position_embeddings": 32768,
|
||||||
|
"num_attention_heads": 8,
|
||||||
|
"num_key_value_heads": 2,
|
||||||
|
"head_dim": 64,
|
||||||
|
"layer_norm_epsilon": 1e-5,
|
||||||
|
"conv_kernel": 4,
|
||||||
|
"ssm_state_size": 128,
|
||||||
|
"mamba_num_heads": 16,
|
||||||
|
"mamba_head_dim": 32,
|
||||||
|
"n_groups": 8,
|
||||||
|
"hybrid_override_pattern": "ME*M",
|
||||||
|
"n_routed_experts": 16,
|
||||||
|
"num_experts_per_tok": 4,
|
||||||
|
"moe_intermediate_size": 256
|
||||||
|
}`
|
||||||
|
|
||||||
|
if err := os.WriteFile(filepath.Join(tempDir, "config.json"), []byte(config), 0o644); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if err := os.WriteFile(filepath.Join(tempDir, "tokenizer.json"), []byte(`{}`), 0o644); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
kv, _, err := LoadModelMetadata(os.DirFS(tempDir))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if _, ok := kv.(*nemotronHModel); !ok {
|
||||||
|
t.Fatalf("unexpected converter type: %T", kv)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNemotronHReplacementsLatentProjections(t *testing.T) {
|
||||||
|
m := &nemotronHModel{}
|
||||||
|
r := strings.NewReplacer(m.Replacements()...)
|
||||||
|
|
||||||
|
if got, want := r.Replace("backbone.layers.1.mixer.fc1_latent_proj.weight"), "blk.1.ffn_latent_in.weight"; got != want {
|
||||||
|
t.Fatalf("unexpected fc1 replacement: got %q want %q", got, want)
|
||||||
|
}
|
||||||
|
if got, want := r.Replace("backbone.layers.1.mixer.fc2_latent_proj.weight"), "blk.1.ffn_latent_out.weight"; got != want {
|
||||||
|
t.Fatalf("unexpected fc2 replacement: got %q want %q", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,6 +1,7 @@
|
|||||||
package convert
|
package convert
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io/fs"
|
"io/fs"
|
||||||
"math"
|
"math"
|
||||||
@@ -13,8 +14,21 @@ import (
|
|||||||
"github.com/ollama/ollama/fs/ggml"
|
"github.com/ollama/ollama/fs/ggml"
|
||||||
)
|
)
|
||||||
|
|
||||||
type qwen3NextModel struct {
|
type qwen3NextRopeScaling struct {
|
||||||
ModelParameters
|
Type string `json:"type"`
|
||||||
|
Factor ropeFactor `json:"factor"`
|
||||||
|
MropeSection []int32 `json:"mrope_section"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type qwen3NextRopeParams struct {
|
||||||
|
MRopeInterleaved bool `json:"mrope_interleaved"`
|
||||||
|
MropeSection []int32 `json:"mrope_section"`
|
||||||
|
RopeType string `json:"rope_type"`
|
||||||
|
RopeTheta float32 `json:"rope_theta"`
|
||||||
|
PartialRotaryFactor float32 `json:"partial_rotary_factor"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type qwen3NextTextConfig struct {
|
||||||
MaxPositionEmbeddings uint32 `json:"max_position_embeddings"`
|
MaxPositionEmbeddings uint32 `json:"max_position_embeddings"`
|
||||||
HiddenSize uint32 `json:"hidden_size"`
|
HiddenSize uint32 `json:"hidden_size"`
|
||||||
NumHiddenLayers uint32 `json:"num_hidden_layers"`
|
NumHiddenLayers uint32 `json:"num_hidden_layers"`
|
||||||
@@ -28,12 +42,13 @@ type qwen3NextModel struct {
|
|||||||
// MoE config
|
// MoE config
|
||||||
NumExperts uint32 `json:"num_experts"`
|
NumExperts uint32 `json:"num_experts"`
|
||||||
NumExpertsPerToken uint32 `json:"num_experts_per_tok"`
|
NumExpertsPerToken uint32 `json:"num_experts_per_tok"`
|
||||||
NormTopkProb bool `json:"norm_topk_prob"`
|
NormTopkProb *bool `json:"norm_topk_prob"`
|
||||||
MoEIntermediateSize uint32 `json:"moe_intermediate_size"`
|
MoEIntermediateSize uint32 `json:"moe_intermediate_size"`
|
||||||
SharedExpertIntermSize uint32 `json:"shared_expert_intermediate_size"`
|
SharedExpertIntermSize uint32 `json:"shared_expert_intermediate_size"`
|
||||||
|
|
||||||
// Hybrid attention config
|
// Hybrid attention config
|
||||||
FullAttentionInterval uint32 `json:"full_attention_interval"`
|
FullAttentionInterval uint32 `json:"full_attention_interval"`
|
||||||
|
LayerTypes []string `json:"layer_types"`
|
||||||
|
|
||||||
// Linear attention (Gated Delta Net) config
|
// Linear attention (Gated Delta Net) config
|
||||||
LinearConvKernelDim uint32 `json:"linear_conv_kernel_dim"`
|
LinearConvKernelDim uint32 `json:"linear_conv_kernel_dim"`
|
||||||
@@ -44,15 +59,101 @@ type qwen3NextModel struct {
|
|||||||
|
|
||||||
// RoPE config
|
// RoPE config
|
||||||
PartialRotaryFactor float32 `json:"partial_rotary_factor"`
|
PartialRotaryFactor float32 `json:"partial_rotary_factor"`
|
||||||
RopeScaling struct {
|
RopeScaling qwen3NextRopeScaling `json:"rope_scaling"`
|
||||||
Type string `json:"type"`
|
RopeParameters qwen3NextRopeParams `json:"rope_parameters"`
|
||||||
Factor ropeFactor `json:"factor"`
|
}
|
||||||
} `json:"rope_scaling"`
|
|
||||||
|
type qwen3NextVisionConfig struct {
|
||||||
|
Depth uint32 `json:"depth"`
|
||||||
|
HiddenSize uint32 `json:"hidden_size"`
|
||||||
|
NumHeads uint32 `json:"num_heads"`
|
||||||
|
InChannels uint32 `json:"in_channels"`
|
||||||
|
PatchSize uint32 `json:"patch_size"`
|
||||||
|
SpatialMergeSize uint32 `json:"spatial_merge_size"`
|
||||||
|
RMSNormEps float32 `json:"layer_norm_epsilon"`
|
||||||
|
RopeTheta float32 `json:"rope_theta"`
|
||||||
|
TemporalPatchSize uint32 `json:"temporal_patch_size"`
|
||||||
|
DeepstackVisualIndexes []int32 `json:"deepstack_visual_indexes"`
|
||||||
|
|
||||||
|
Size struct {
|
||||||
|
ShortestEdge uint32 `json:"shortest_edge"`
|
||||||
|
LongestEdge uint32 `json:"longest_edge"`
|
||||||
|
} `json:"size"`
|
||||||
|
|
||||||
|
ImageMean []float32 `json:"image_mean"`
|
||||||
|
ImageStd []float32 `json:"image_std"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type qwen3NextModel struct {
|
||||||
|
ModelParameters
|
||||||
|
qwen3NextTextConfig
|
||||||
|
|
||||||
|
TextConfig *qwen3NextTextConfig `json:"text_config"`
|
||||||
|
VisionModel qwen3NextVisionConfig `json:"vision_config"`
|
||||||
|
|
||||||
|
ImageTokenID uint32 `json:"image_token_id"`
|
||||||
|
VisionStartTokenID uint32 `json:"vision_start_token_id"`
|
||||||
|
VisionEndTokenID uint32 `json:"vision_end_token_id"`
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ ModelConverter = (*qwen3NextModel)(nil)
|
var _ ModelConverter = (*qwen3NextModel)(nil)
|
||||||
|
|
||||||
func (q *qwen3NextModel) parseMore(_ fs.FS) error {
|
func (q *qwen3NextModel) parseMore(fsys fs.FS) error {
|
||||||
|
if q.TextConfig != nil {
|
||||||
|
q.qwen3NextTextConfig = *q.TextConfig
|
||||||
|
}
|
||||||
|
|
||||||
|
if q.RopeTheta == 0 {
|
||||||
|
q.RopeTheta = q.RopeParameters.RopeTheta
|
||||||
|
}
|
||||||
|
if q.PartialRotaryFactor == 0 {
|
||||||
|
q.PartialRotaryFactor = q.RopeParameters.PartialRotaryFactor
|
||||||
|
}
|
||||||
|
|
||||||
|
if q.RopeScaling.Type == "" && q.RopeParameters.RopeType != "" {
|
||||||
|
q.RopeScaling.Type = q.RopeParameters.RopeType
|
||||||
|
}
|
||||||
|
|
||||||
|
// Pull vision preprocessing fields when present.
|
||||||
|
if q.VisionModel.Depth > 0 {
|
||||||
|
if bts, err := fs.ReadFile(fsys, "preprocessor_config.json"); err == nil {
|
||||||
|
var pre struct {
|
||||||
|
Size struct {
|
||||||
|
ShortestEdge uint32 `json:"shortest_edge"`
|
||||||
|
LongestEdge uint32 `json:"longest_edge"`
|
||||||
|
} `json:"size"`
|
||||||
|
PatchSize uint32 `json:"patch_size"`
|
||||||
|
TemporalPatchSize uint32 `json:"temporal_patch_size"`
|
||||||
|
MergeSize uint32 `json:"merge_size"`
|
||||||
|
ImageMean []float32 `json:"image_mean"`
|
||||||
|
ImageStd []float32 `json:"image_std"`
|
||||||
|
}
|
||||||
|
if json.Unmarshal(bts, &pre) == nil {
|
||||||
|
if q.VisionModel.PatchSize == 0 {
|
||||||
|
q.VisionModel.PatchSize = pre.PatchSize
|
||||||
|
}
|
||||||
|
if q.VisionModel.TemporalPatchSize == 0 {
|
||||||
|
q.VisionModel.TemporalPatchSize = pre.TemporalPatchSize
|
||||||
|
}
|
||||||
|
if q.VisionModel.SpatialMergeSize == 0 {
|
||||||
|
q.VisionModel.SpatialMergeSize = pre.MergeSize
|
||||||
|
}
|
||||||
|
if q.VisionModel.Size.ShortestEdge == 0 {
|
||||||
|
q.VisionModel.Size.ShortestEdge = pre.Size.ShortestEdge
|
||||||
|
}
|
||||||
|
if q.VisionModel.Size.LongestEdge == 0 {
|
||||||
|
q.VisionModel.Size.LongestEdge = pre.Size.LongestEdge
|
||||||
|
}
|
||||||
|
if len(q.VisionModel.ImageMean) == 0 {
|
||||||
|
q.VisionModel.ImageMean = pre.ImageMean
|
||||||
|
}
|
||||||
|
if len(q.VisionModel.ImageStd) == 0 {
|
||||||
|
q.VisionModel.ImageStd = pre.ImageStd
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if q.NumHiddenLayers == 0 {
|
if q.NumHiddenLayers == 0 {
|
||||||
return fmt.Errorf("qwen3next: num_hidden_layers must be set")
|
return fmt.Errorf("qwen3next: num_hidden_layers must be set")
|
||||||
}
|
}
|
||||||
@@ -74,36 +175,96 @@ func (q *qwen3NextModel) parseMore(_ fs.FS) error {
|
|||||||
if q.LinearNumKeyHeads == 0 || q.LinearNumValueHeads == 0 || q.LinearKeyHeadDim == 0 || q.LinearValueHeadDim == 0 {
|
if q.LinearNumKeyHeads == 0 || q.LinearNumValueHeads == 0 || q.LinearKeyHeadDim == 0 || q.LinearValueHeadDim == 0 {
|
||||||
return fmt.Errorf("qwen3next: linear attention config must be set (linear_num_key_heads, linear_num_value_heads, linear_key_head_dim, linear_value_head_dim)")
|
return fmt.Errorf("qwen3next: linear attention config must be set (linear_num_key_heads, linear_num_value_heads, linear_key_head_dim, linear_value_head_dim)")
|
||||||
}
|
}
|
||||||
if q.FullAttentionInterval == 0 {
|
if _, err := q.kvHeadCounts(); err != nil {
|
||||||
return fmt.Errorf("qwen3next: full_attention_interval must be set")
|
return err
|
||||||
}
|
|
||||||
if q.FullAttentionInterval > q.NumHiddenLayers {
|
|
||||||
return fmt.Errorf("qwen3next: full_attention_interval (%d) exceeds num_hidden_layers (%d)", q.FullAttentionInterval, q.NumHiddenLayers)
|
|
||||||
}
|
|
||||||
|
|
||||||
hasFull := false
|
|
||||||
for i := range q.NumHiddenLayers {
|
|
||||||
if (i+1)%q.FullAttentionInterval == 0 {
|
|
||||||
hasFull = true
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if !hasFull {
|
|
||||||
return fmt.Errorf("qwen3next: head_count_kv would be all zeros (full_attention_interval=%d, num_hidden_layers=%d)", q.FullAttentionInterval, q.NumHiddenLayers)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (q *qwen3NextModel) kvHeadCounts() ([]uint32, error) {
|
||||||
|
if len(q.LayerTypes) > 0 {
|
||||||
|
kv := make([]uint32, q.NumHiddenLayers)
|
||||||
|
hasFull := false
|
||||||
|
hasRecurrent := false
|
||||||
|
for i := range q.NumHiddenLayers {
|
||||||
|
layerType := ""
|
||||||
|
if i < uint32(len(q.LayerTypes)) {
|
||||||
|
layerType = q.LayerTypes[i]
|
||||||
|
}
|
||||||
|
if layerType == "full_attention" {
|
||||||
|
kv[i] = q.NumKeyValueHeads
|
||||||
|
hasFull = true
|
||||||
|
} else {
|
||||||
|
hasRecurrent = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !hasFull || !hasRecurrent {
|
||||||
|
return nil, fmt.Errorf("qwen3next: layer_types must include both full_attention and linear_attention")
|
||||||
|
}
|
||||||
|
return kv, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if q.FullAttentionInterval == 0 {
|
||||||
|
return nil, fmt.Errorf("qwen3next: full_attention_interval must be set")
|
||||||
|
}
|
||||||
|
if q.FullAttentionInterval > q.NumHiddenLayers {
|
||||||
|
return nil, fmt.Errorf("qwen3next: full_attention_interval (%d) exceeds num_hidden_layers (%d)", q.FullAttentionInterval, q.NumHiddenLayers)
|
||||||
|
}
|
||||||
|
|
||||||
|
kv := make([]uint32, q.NumHiddenLayers)
|
||||||
|
hasFull := false
|
||||||
|
for i := range q.NumHiddenLayers {
|
||||||
|
if (i+1)%q.FullAttentionInterval == 0 {
|
||||||
|
kv[i] = q.NumKeyValueHeads
|
||||||
|
hasFull = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !hasFull {
|
||||||
|
return nil, fmt.Errorf("qwen3next: head_count_kv would be all zeros (full_attention_interval=%d, num_hidden_layers=%d)", q.FullAttentionInterval, q.NumHiddenLayers)
|
||||||
|
}
|
||||||
|
return kv, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *qwen3NextModel) ropeSections() []int32 {
|
||||||
|
if len(q.RopeParameters.MropeSection) > 0 {
|
||||||
|
return q.RopeParameters.MropeSection
|
||||||
|
}
|
||||||
|
return q.RopeScaling.MropeSection
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *qwen3NextModel) shouldReorderVHeads() bool {
|
||||||
|
modelType := strings.ToLower(q.ModelType)
|
||||||
|
if strings.Contains(modelType, "qwen3_next") || strings.Contains(modelType, "qwen3next") {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, arch := range q.Architectures {
|
||||||
|
arch = strings.ToLower(arch)
|
||||||
|
if strings.Contains(arch, "qwen3next") || strings.Contains(arch, "qwen3_next") {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Default to qwen3.5 layout for all other qwen3next-family imports.
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
func (q *qwen3NextModel) KV(t *Tokenizer) KV {
|
func (q *qwen3NextModel) KV(t *Tokenizer) KV {
|
||||||
kv := q.ModelParameters.KV(t)
|
kv := q.ModelParameters.KV(t)
|
||||||
kv["general.architecture"] = "qwen3next"
|
|
||||||
kv["tokenizer.ggml.pre"] = "qwen2"
|
arch := "qwen35"
|
||||||
|
if q.NumExperts > 0 {
|
||||||
|
arch = "qwen35moe"
|
||||||
|
}
|
||||||
|
kv["general.architecture"] = arch
|
||||||
|
kv["tokenizer.ggml.pre"] = "qwen35"
|
||||||
kv["block_count"] = q.NumHiddenLayers
|
kv["block_count"] = q.NumHiddenLayers
|
||||||
kv["context_length"] = q.MaxPositionEmbeddings
|
kv["context_length"] = q.MaxPositionEmbeddings
|
||||||
kv["embedding_length"] = q.HiddenSize
|
kv["embedding_length"] = q.HiddenSize
|
||||||
kv["feed_forward_length"] = q.IntermediateSize
|
kv["feed_forward_length"] = q.IntermediateSize
|
||||||
kv["attention.head_count"] = q.NumAttentionHeads
|
kv["attention.head_count"] = q.NumAttentionHeads
|
||||||
|
|
||||||
headDim := q.HeadDim
|
headDim := q.HeadDim
|
||||||
if headDim == 0 && q.NumAttentionHeads > 0 {
|
if headDim == 0 && q.NumAttentionHeads > 0 {
|
||||||
headDim = q.HiddenSize / q.NumAttentionHeads
|
headDim = q.HiddenSize / q.NumAttentionHeads
|
||||||
@@ -113,18 +274,31 @@ func (q *qwen3NextModel) KV(t *Tokenizer) KV {
|
|||||||
kv["attention.layer_norm_rms_epsilon"] = q.RMSNormEPS
|
kv["attention.layer_norm_rms_epsilon"] = q.RMSNormEPS
|
||||||
kv["rope.freq_base"] = q.RopeTheta
|
kv["rope.freq_base"] = q.RopeTheta
|
||||||
|
|
||||||
// RoPE dimension count (partial rotary)
|
|
||||||
// partial_rotary_factor = 0.25 means only 25% of head_dim uses RoPE
|
|
||||||
partialRotary := q.PartialRotaryFactor
|
partialRotary := q.PartialRotaryFactor
|
||||||
if partialRotary > 0 && partialRotary <= 1 {
|
if partialRotary > 0 && partialRotary <= 1 {
|
||||||
kv["rope.dimension_count"] = uint32(float32(headDim) * partialRotary)
|
kv["rope.dimension_count"] = uint32(float32(headDim) * partialRotary)
|
||||||
}
|
}
|
||||||
|
|
||||||
// MoE config
|
if sections := q.ropeSections(); len(sections) > 0 {
|
||||||
|
kv["mrope_sections"] = sections
|
||||||
|
kv["rope.mrope_section"] = sections
|
||||||
|
kv["rope.dimension_sections"] = sections
|
||||||
|
}
|
||||||
|
if q.RopeParameters.MRopeInterleaved {
|
||||||
|
kv["rope.mrope_interleaved"] = true
|
||||||
|
}
|
||||||
|
|
||||||
|
if q.RopeScaling.Type != "" && q.RopeScaling.Type != "default" {
|
||||||
|
kv["rope.scaling.type"] = q.RopeScaling.Type
|
||||||
|
kv["rope.scaling.factor"] = q.RopeScaling.Factor
|
||||||
|
}
|
||||||
|
|
||||||
if q.NumExperts > 0 {
|
if q.NumExperts > 0 {
|
||||||
kv["expert_count"] = q.NumExperts
|
kv["expert_count"] = q.NumExperts
|
||||||
kv["expert_used_count"] = q.NumExpertsPerToken
|
kv["expert_used_count"] = q.NumExpertsPerToken
|
||||||
kv["norm_top_k_prob"] = q.NormTopkProb
|
if q.NormTopkProb != nil {
|
||||||
|
kv["norm_top_k_prob"] = *q.NormTopkProb
|
||||||
|
}
|
||||||
if q.MoEIntermediateSize > 0 {
|
if q.MoEIntermediateSize > 0 {
|
||||||
kv["expert_feed_forward_length"] = q.MoEIntermediateSize
|
kv["expert_feed_forward_length"] = q.MoEIntermediateSize
|
||||||
}
|
}
|
||||||
@@ -133,33 +307,66 @@ func (q *qwen3NextModel) KV(t *Tokenizer) KV {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// SSM/Linear attention config
|
|
||||||
// d_inner = linear_value_head_dim * linear_num_value_heads
|
|
||||||
dInner := q.LinearValueHeadDim * q.LinearNumValueHeads
|
dInner := q.LinearValueHeadDim * q.LinearNumValueHeads
|
||||||
kv["ssm.inner_size"] = dInner
|
kv["ssm.inner_size"] = dInner
|
||||||
kv["ssm.state_size"] = q.LinearKeyHeadDim // head_k_dim
|
kv["ssm.state_size"] = q.LinearKeyHeadDim
|
||||||
kv["ssm.group_count"] = q.LinearNumKeyHeads // num_k_heads
|
kv["ssm.group_count"] = q.LinearNumKeyHeads
|
||||||
kv["ssm.time_step_rank"] = q.LinearNumValueHeads // num_v_heads
|
kv["ssm.time_step_rank"] = q.LinearNumValueHeads
|
||||||
kv["ssm.conv_kernel"] = q.LinearConvKernelDim
|
kv["ssm.conv_kernel"] = q.LinearConvKernelDim
|
||||||
interval := q.FullAttentionInterval
|
if q.shouldReorderVHeads() {
|
||||||
kv["full_attention_interval"] = interval
|
kv["ssm.v_head_reordered"] = true
|
||||||
|
|
||||||
// Build per-layer KV head count array to identify layer types
|
|
||||||
// 0 = recurrent (linear attention), non-zero = full attention
|
|
||||||
kvHeadCounts := make([]uint32, q.NumHiddenLayers)
|
|
||||||
for i := range q.NumHiddenLayers {
|
|
||||||
// Full attention every full_attention_interval layers (starting at interval-1)
|
|
||||||
if interval > 0 && (i+1)%interval == 0 {
|
|
||||||
kvHeadCounts[i] = q.NumKeyValueHeads
|
|
||||||
}
|
}
|
||||||
// else stays 0 (recurrent layer)
|
if q.FullAttentionInterval > 0 {
|
||||||
|
kv["full_attention_interval"] = q.FullAttentionInterval
|
||||||
}
|
}
|
||||||
kv["attention.head_count_kv"] = kvHeadCounts
|
|
||||||
|
|
||||||
// RoPE scaling
|
if headCounts, err := q.kvHeadCounts(); err == nil {
|
||||||
if q.RopeScaling.Type != "" {
|
kv["attention.head_count_kv"] = headCounts
|
||||||
kv["rope.scaling.type"] = q.RopeScaling.Type
|
}
|
||||||
kv["rope.scaling.factor"] = q.RopeScaling.Factor
|
|
||||||
|
if q.VisionModel.Depth > 0 {
|
||||||
|
kv["vision.block_count"] = q.VisionModel.Depth
|
||||||
|
kv["vision.embedding_length"] = q.VisionModel.HiddenSize
|
||||||
|
kv["vision.attention.head_count"] = q.VisionModel.NumHeads
|
||||||
|
kv["vision.num_channels"] = q.VisionModel.InChannels
|
||||||
|
if q.VisionModel.PatchSize > 0 {
|
||||||
|
kv["vision.patch_size"] = q.VisionModel.PatchSize
|
||||||
|
}
|
||||||
|
if q.VisionModel.SpatialMergeSize > 0 {
|
||||||
|
kv["vision.spatial_merge_size"] = q.VisionModel.SpatialMergeSize
|
||||||
|
}
|
||||||
|
if q.VisionModel.RMSNormEps > 0 {
|
||||||
|
kv["vision.attention.layer_norm_epsilon"] = q.VisionModel.RMSNormEps
|
||||||
|
}
|
||||||
|
if q.VisionModel.RopeTheta > 0 {
|
||||||
|
kv["vision.rope.freq_base"] = q.VisionModel.RopeTheta
|
||||||
|
}
|
||||||
|
if q.VisionModel.TemporalPatchSize > 0 {
|
||||||
|
kv["vision.temporal_patch_size"] = q.VisionModel.TemporalPatchSize
|
||||||
|
}
|
||||||
|
kv["vision.deepstack_visual_indexes"] = q.VisionModel.DeepstackVisualIndexes
|
||||||
|
if q.VisionModel.Size.ShortestEdge > 0 {
|
||||||
|
kv["vision.shortest_edge"] = q.VisionModel.Size.ShortestEdge
|
||||||
|
}
|
||||||
|
if q.VisionModel.Size.LongestEdge > 0 {
|
||||||
|
kv["vision.longest_edge"] = q.VisionModel.Size.LongestEdge
|
||||||
|
}
|
||||||
|
if len(q.VisionModel.ImageMean) > 0 {
|
||||||
|
kv["vision.image_mean"] = q.VisionModel.ImageMean
|
||||||
|
}
|
||||||
|
if len(q.VisionModel.ImageStd) > 0 {
|
||||||
|
kv["vision.image_std"] = q.VisionModel.ImageStd
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if q.ImageTokenID > 0 {
|
||||||
|
kv["image_token_id"] = q.ImageTokenID
|
||||||
|
}
|
||||||
|
if q.VisionStartTokenID > 0 {
|
||||||
|
kv["vision_start_token_id"] = q.VisionStartTokenID
|
||||||
|
}
|
||||||
|
if q.VisionEndTokenID > 0 {
|
||||||
|
kv["vision_end_token_id"] = q.VisionEndTokenID
|
||||||
}
|
}
|
||||||
|
|
||||||
return kv
|
return kv
|
||||||
@@ -168,7 +375,6 @@ func (q *qwen3NextModel) KV(t *Tokenizer) KV {
|
|||||||
func (q *qwen3NextModel) Tensors(ts []Tensor) []*ggml.Tensor {
|
func (q *qwen3NextModel) Tensors(ts []Tensor) []*ggml.Tensor {
|
||||||
var out []*ggml.Tensor
|
var out []*ggml.Tensor
|
||||||
|
|
||||||
// Create merges for expert tensors - stack individual experts into batched tensors
|
|
||||||
merges := make([]merge, q.NumHiddenLayers*3)
|
merges := make([]merge, q.NumHiddenLayers*3)
|
||||||
for i := range q.NumHiddenLayers {
|
for i := range q.NumHiddenLayers {
|
||||||
merges[i*3+0] = merge{
|
merges[i*3+0] = merge{
|
||||||
@@ -185,16 +391,13 @@ func (q *qwen3NextModel) Tensors(ts []Tensor) []*ggml.Tensor {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Merge expert tensors
|
|
||||||
merged, remaining := mergeTensors(ts, merges...)
|
merged, remaining := mergeTensors(ts, merges...)
|
||||||
out = append(out, merged...)
|
out = append(out, merged...)
|
||||||
|
|
||||||
// Process remaining tensors
|
|
||||||
for _, t := range remaining {
|
for _, t := range remaining {
|
||||||
name := t.Name()
|
name := t.Name()
|
||||||
shape := t.Shape()
|
shape := t.Shape()
|
||||||
|
|
||||||
// Split linear_attn.in_proj_qkvz (ssm_in) into attn_qkv + attn_gate when possible
|
|
||||||
if strings.HasSuffix(name, ".ssm_in.weight") {
|
if strings.HasSuffix(name, ".ssm_in.weight") {
|
||||||
if qkv, gate, ok := q.splitQKVZTensor(t); ok {
|
if qkv, gate, ok := q.splitQKVZTensor(t); ok {
|
||||||
out = append(out, qkv, gate)
|
out = append(out, qkv, gate)
|
||||||
@@ -204,84 +407,299 @@ func (q *qwen3NextModel) Tensors(ts []Tensor) []*ggml.Tensor {
|
|||||||
}
|
}
|
||||||
|
|
||||||
switch {
|
switch {
|
||||||
// Add 1 to norm weights (except ssm_norm which is linear_attn.norm)
|
case strings.Contains(name, ".mlp.experts.gate_up_proj"):
|
||||||
// This matches the Python converter behavior for qwen3next
|
out = append(out, slices.Collect(splitDim(t, 1,
|
||||||
|
split{Replacer: strings.NewReplacer(".mlp.experts.gate_up_proj", ".ffn_gate_exps.weight")},
|
||||||
|
split{Replacer: strings.NewReplacer(".mlp.experts.gate_up_proj", ".ffn_up_exps.weight")},
|
||||||
|
))...)
|
||||||
|
|
||||||
|
case strings.Contains(name, ".mlp.experts.down_proj"):
|
||||||
|
out = append(out, &ggml.Tensor{
|
||||||
|
Name: strings.NewReplacer(".mlp.experts.down_proj", ".ffn_down_exps.weight").Replace(name),
|
||||||
|
Kind: t.Kind(),
|
||||||
|
Shape: slices.Clone(shape),
|
||||||
|
WriterTo: t,
|
||||||
|
})
|
||||||
|
|
||||||
|
case strings.HasPrefix(name, "v.blk.") && strings.Contains(name, ".attn_qkv"):
|
||||||
|
out = append(out, slices.Collect(splitDim(t, 0,
|
||||||
|
split{Replacer: strings.NewReplacer("attn_qkv", "attn_q")},
|
||||||
|
split{Replacer: strings.NewReplacer("attn_qkv", "attn_k")},
|
||||||
|
split{Replacer: strings.NewReplacer("attn_qkv", "attn_v")},
|
||||||
|
))...)
|
||||||
|
|
||||||
|
case strings.Contains(name, "patch_embed") && strings.HasSuffix(name, "weight"):
|
||||||
|
out = append(out, &ggml.Tensor{
|
||||||
|
Name: name,
|
||||||
|
Kind: t.Kind(),
|
||||||
|
Shape: append([]uint64{shape[0] * shape[1]}, shape[2:]...),
|
||||||
|
WriterTo: t,
|
||||||
|
})
|
||||||
|
|
||||||
case strings.HasSuffix(name, "_norm.weight") && !strings.HasSuffix(name, ".ssm_norm.weight"):
|
case strings.HasSuffix(name, "_norm.weight") && !strings.HasSuffix(name, ".ssm_norm.weight"):
|
||||||
t.SetRepacker(q.addOne)
|
t.SetRepacker(q.addOne)
|
||||||
out = append(out, &ggml.Tensor{
|
out = append(out, &ggml.Tensor{Name: name, Kind: t.Kind(), Shape: slices.Clone(shape), WriterTo: t})
|
||||||
Name: name,
|
|
||||||
Kind: t.Kind(),
|
|
||||||
Shape: slices.Clone(shape),
|
|
||||||
WriterTo: t,
|
|
||||||
})
|
|
||||||
|
|
||||||
// Handle linear attention A_log -> ssm_a (negate and exp)
|
|
||||||
// Note: name has already been transformed by Replacements at this point
|
|
||||||
case strings.HasSuffix(name, ".ssm_a"):
|
case strings.HasSuffix(name, ".ssm_a"):
|
||||||
t.SetRepacker(func(_ string, data []float32, shape []uint64) ([]float32, error) {
|
t.SetRepacker(q.repackSSMA())
|
||||||
// Compute -exp(A_log)
|
out = append(out, &ggml.Tensor{Name: name, Kind: t.Kind(), Shape: slices.Clone(shape), WriterTo: t})
|
||||||
result := make([]float32, len(data))
|
|
||||||
for i, v := range data {
|
case strings.HasSuffix(name, ".attn_qkv.weight"):
|
||||||
// -exp(v)
|
if q.shouldReorderVHeads() {
|
||||||
result[i] = -float32(math.Exp(float64(v)))
|
t.SetRepacker(q.repackAttnQKV())
|
||||||
}
|
}
|
||||||
return result, nil
|
out = append(out, &ggml.Tensor{Name: name, Kind: t.Kind(), Shape: slices.Clone(shape), WriterTo: t})
|
||||||
})
|
|
||||||
out = append(out, &ggml.Tensor{
|
case strings.HasSuffix(name, ".attn_gate.weight"):
|
||||||
Name: name,
|
if q.shouldReorderVHeads() {
|
||||||
Kind: t.Kind(),
|
// HF tensor layout is [out_features, in_features]; reorder rows.
|
||||||
Shape: slices.Clone(shape),
|
t.SetRepacker(q.repackReorderDim(0, int(q.LinearValueHeadDim)))
|
||||||
WriterTo: t,
|
}
|
||||||
})
|
out = append(out, &ggml.Tensor{Name: name, Kind: t.Kind(), Shape: slices.Clone(shape), WriterTo: t})
|
||||||
|
|
||||||
|
case strings.HasSuffix(name, ".ssm_beta.weight"), strings.HasSuffix(name, ".ssm_alpha.weight"):
|
||||||
|
if q.shouldReorderVHeads() {
|
||||||
|
// HF tensor layout is [out_features, in_features]; reorder rows.
|
||||||
|
t.SetRepacker(q.repackReorderDim(0, 1))
|
||||||
|
}
|
||||||
|
out = append(out, &ggml.Tensor{Name: name, Kind: t.Kind(), Shape: slices.Clone(shape), WriterTo: t})
|
||||||
|
|
||||||
|
case strings.HasSuffix(name, ".ssm_dt"):
|
||||||
|
if q.shouldReorderVHeads() {
|
||||||
|
t.SetRepacker(q.repackReorderDim(0, 1))
|
||||||
|
}
|
||||||
|
out = append(out, &ggml.Tensor{Name: name, Kind: t.Kind(), Shape: slices.Clone(shape), WriterTo: t})
|
||||||
|
|
||||||
|
case strings.HasSuffix(name, ".ssm_out.weight"):
|
||||||
|
if q.shouldReorderVHeads() {
|
||||||
|
// HF out_proj layout is [out_features, in_features]; reorder columns.
|
||||||
|
t.SetRepacker(q.repackReorderDim(1, int(q.LinearValueHeadDim)))
|
||||||
|
}
|
||||||
|
out = append(out, &ggml.Tensor{Name: name, Kind: t.Kind(), Shape: slices.Clone(shape), WriterTo: t})
|
||||||
|
|
||||||
// Squeeze conv1d weights: [1, D, K] or [D, 1, K] -> [D, K]
|
|
||||||
case strings.HasSuffix(name, ".ssm_conv1d.weight"):
|
case strings.HasSuffix(name, ".ssm_conv1d.weight"):
|
||||||
newShape := slices.Clone(shape)
|
newShape := slices.Clone(shape)
|
||||||
if len(shape) == 3 {
|
if len(shape) == 3 {
|
||||||
if shape[0] == 1 {
|
if shape[0] == 1 {
|
||||||
// [1, D, K] -> [D, K]
|
|
||||||
newShape = []uint64{shape[1], shape[2]}
|
newShape = []uint64{shape[1], shape[2]}
|
||||||
} else if shape[1] == 1 {
|
} else if shape[1] == 1 {
|
||||||
// [D, 1, K] -> [D, K]
|
|
||||||
newShape = []uint64{shape[0], shape[2]}
|
newShape = []uint64{shape[0], shape[2]}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
out = append(out, &ggml.Tensor{
|
if q.shouldReorderVHeads() {
|
||||||
Name: name,
|
t.SetRepacker(q.repackConv1D())
|
||||||
Kind: t.Kind(),
|
|
||||||
Shape: newShape,
|
|
||||||
WriterTo: t,
|
|
||||||
})
|
|
||||||
// Squeeze shared expert gate: [D, 1] or [1, D] -> [D]
|
|
||||||
case strings.HasSuffix(name, ".ffn_gate_inp_shexp.weight"):
|
|
||||||
newShape := slices.Clone(shape)
|
|
||||||
if len(shape) == 2 {
|
|
||||||
if shape[0] == 1 && shape[1] > 1 {
|
|
||||||
newShape = []uint64{shape[1]}
|
|
||||||
} else if shape[1] == 1 && shape[0] > 1 {
|
|
||||||
newShape = []uint64{shape[0]}
|
|
||||||
}
|
}
|
||||||
}
|
out = append(out, &ggml.Tensor{Name: name, Kind: t.Kind(), Shape: newShape, WriterTo: t})
|
||||||
out = append(out, &ggml.Tensor{
|
|
||||||
Name: name,
|
|
||||||
Kind: t.Kind(),
|
|
||||||
Shape: newShape,
|
|
||||||
WriterTo: t,
|
|
||||||
})
|
|
||||||
|
|
||||||
default:
|
default:
|
||||||
out = append(out, &ggml.Tensor{
|
out = append(out, &ggml.Tensor{Name: name, Kind: t.Kind(), Shape: slices.Clone(shape), WriterTo: t})
|
||||||
Name: name,
|
|
||||||
Kind: t.Kind(),
|
|
||||||
Shape: slices.Clone(shape),
|
|
||||||
WriterTo: t,
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return out
|
return out
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (q *qwen3NextModel) repackReorderDim(dim, headDim int) Repacker {
|
||||||
|
return func(_ string, data []float32, shape []uint64) ([]float32, error) {
|
||||||
|
if !q.shouldReorderVHeads() {
|
||||||
|
return data, nil
|
||||||
|
}
|
||||||
|
numK := int(q.LinearNumKeyHeads)
|
||||||
|
numVPerK := int(q.LinearNumValueHeads / q.LinearNumKeyHeads)
|
||||||
|
return reorderHeadLayout(data, shape, dim, numK, numVPerK, headDim)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *qwen3NextModel) repackAttnQKV() Repacker {
|
||||||
|
return func(_ string, data []float32, shape []uint64) ([]float32, error) {
|
||||||
|
if !q.shouldReorderVHeads() || len(shape) != 2 {
|
||||||
|
return data, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
rows := int(shape[0])
|
||||||
|
cols := int(shape[1])
|
||||||
|
numK := int(q.LinearNumKeyHeads)
|
||||||
|
numV := int(q.LinearNumValueHeads)
|
||||||
|
headK := int(q.LinearKeyHeadDim)
|
||||||
|
headV := int(q.LinearValueHeadDim)
|
||||||
|
qDim := headK * numK
|
||||||
|
kDim := headK * numK
|
||||||
|
vDim := headV * numV
|
||||||
|
qkvDim := qDim + kDim + vDim
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case rows == qkvDim:
|
||||||
|
// HF layout: [out_features, in_features]. Keep Q/K rows unchanged and
|
||||||
|
// reorder only V rows from grouped -> tiled head layout.
|
||||||
|
out := make([]float32, len(data))
|
||||||
|
qkRows := qDim + kDim
|
||||||
|
qkSize := qkRows * cols
|
||||||
|
copy(out[:qkSize], data[:qkSize])
|
||||||
|
|
||||||
|
vStart := qkSize
|
||||||
|
vEnd := vStart + vDim*cols
|
||||||
|
reorderedV, err := reorderHeadLayout(data[vStart:vEnd], []uint64{uint64(vDim), uint64(cols)}, 0, numK, numV/numK, headV)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
copy(out[vStart:vEnd], reorderedV)
|
||||||
|
copy(out[vEnd:], data[vEnd:])
|
||||||
|
return out, nil
|
||||||
|
|
||||||
|
case cols == qkvDim:
|
||||||
|
// Fallback for already-transposed [in_features, out_features] tensors.
|
||||||
|
out := make([]float32, len(data))
|
||||||
|
copy(out, data)
|
||||||
|
for r := range rows {
|
||||||
|
base := r * cols
|
||||||
|
vStart := base + qDim + kDim
|
||||||
|
vEnd := vStart + vDim
|
||||||
|
reorderedV, err := reorderHeadLayout(out[vStart:vEnd], []uint64{uint64(vDim)}, 0, numK, numV/numK, headV)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
copy(out[vStart:vEnd], reorderedV)
|
||||||
|
}
|
||||||
|
return out, nil
|
||||||
|
|
||||||
|
default:
|
||||||
|
return data, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *qwen3NextModel) repackConv1D() Repacker {
|
||||||
|
return func(_ string, data []float32, shape []uint64) ([]float32, error) {
|
||||||
|
if !q.shouldReorderVHeads() {
|
||||||
|
return data, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
normShape := slices.Clone(shape)
|
||||||
|
if len(shape) == 3 {
|
||||||
|
if shape[0] == 1 {
|
||||||
|
normShape = []uint64{shape[1], shape[2]}
|
||||||
|
} else if shape[1] == 1 {
|
||||||
|
normShape = []uint64{shape[0], shape[2]}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(normShape) != 2 {
|
||||||
|
return data, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
rows := int(normShape[0])
|
||||||
|
cols := int(normShape[1])
|
||||||
|
numK := int(q.LinearNumKeyHeads)
|
||||||
|
numV := int(q.LinearNumValueHeads)
|
||||||
|
headK := int(q.LinearKeyHeadDim)
|
||||||
|
headV := int(q.LinearValueHeadDim)
|
||||||
|
qkChannels := 2 * headK * numK
|
||||||
|
totalChannels := qkChannels + headV*numV
|
||||||
|
if qkChannels <= 0 {
|
||||||
|
return data, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case rows == totalChannels:
|
||||||
|
// HF layout after squeeze: [channels, kernel]
|
||||||
|
out := make([]float32, len(data))
|
||||||
|
prefix := qkChannels * cols
|
||||||
|
copy(out[:prefix], data[:prefix])
|
||||||
|
reorderedV, err := reorderHeadLayout(data[prefix:], []uint64{uint64(totalChannels - qkChannels), uint64(cols)}, 0, numK, numV/numK, headV)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
copy(out[prefix:], reorderedV)
|
||||||
|
return out, nil
|
||||||
|
case cols == totalChannels:
|
||||||
|
// Fallback for transposed [kernel, channels]
|
||||||
|
out := make([]float32, len(data))
|
||||||
|
copy(out, data)
|
||||||
|
vChannels := totalChannels - qkChannels
|
||||||
|
for r := range rows {
|
||||||
|
base := r * cols
|
||||||
|
vStart := base + qkChannels
|
||||||
|
vEnd := vStart + vChannels
|
||||||
|
reorderedV, err := reorderHeadLayout(out[vStart:vEnd], []uint64{uint64(vChannels)}, 0, numK, numV/numK, headV)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
copy(out[vStart:vEnd], reorderedV)
|
||||||
|
}
|
||||||
|
return out, nil
|
||||||
|
default:
|
||||||
|
return data, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *qwen3NextModel) repackSSMA() Repacker {
|
||||||
|
return func(_ string, data []float32, shape []uint64) ([]float32, error) {
|
||||||
|
result := make([]float32, len(data))
|
||||||
|
for i, v := range data {
|
||||||
|
result[i] = -float32(math.Exp(float64(v)))
|
||||||
|
}
|
||||||
|
if !q.shouldReorderVHeads() {
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
numK := int(q.LinearNumKeyHeads)
|
||||||
|
numVPerK := int(q.LinearNumValueHeads / q.LinearNumKeyHeads)
|
||||||
|
return reorderHeadLayout(result, shape, 0, numK, numVPerK, 1)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func reorderHeadLayout(data []float32, shape []uint64, dim int, numKHeads, numVPerK, headDim int) ([]float32, error) {
|
||||||
|
if len(shape) == 0 || numKHeads <= 0 || numVPerK <= 0 || headDim <= 0 {
|
||||||
|
return data, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
dims := make([]int, len(shape))
|
||||||
|
for i := range shape {
|
||||||
|
dims[i] = int(shape[i])
|
||||||
|
}
|
||||||
|
if dim < 0 {
|
||||||
|
dim += len(dims)
|
||||||
|
}
|
||||||
|
if dim < 0 || dim >= len(dims) {
|
||||||
|
return data, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
expected := numKHeads * numVPerK * headDim
|
||||||
|
if dims[dim] != expected {
|
||||||
|
return data, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
newShape := make([]int, 0, len(dims)+2)
|
||||||
|
newShape = append(newShape, dims[:dim]...)
|
||||||
|
newShape = append(newShape, numKHeads, numVPerK, headDim)
|
||||||
|
newShape = append(newShape, dims[dim+1:]...)
|
||||||
|
|
||||||
|
var tt tensor.Tensor = tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data))
|
||||||
|
if err := tt.Reshape(newShape...); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
perm := make([]int, len(newShape))
|
||||||
|
for i := range perm {
|
||||||
|
perm[i] = i
|
||||||
|
}
|
||||||
|
perm[dim], perm[dim+1] = perm[dim+1], perm[dim]
|
||||||
|
|
||||||
|
tt, err := tensor.Transpose(tt, perm...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
tt = tensor.Materialize(tt)
|
||||||
|
|
||||||
|
total := 1
|
||||||
|
for _, d := range dims {
|
||||||
|
total *= d
|
||||||
|
}
|
||||||
|
if err := tt.Reshape(total); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return native.VectorF32(tt.(*tensor.Dense))
|
||||||
|
}
|
||||||
|
|
||||||
type qkvzSplitSpec struct {
|
type qkvzSplitSpec struct {
|
||||||
hidden int
|
hidden int
|
||||||
headKDim int
|
headKDim int
|
||||||
@@ -369,7 +787,6 @@ func (q *qwen3NextModel) repackQKVZ(spec qkvzSplitSpec, extractGate bool) Repack
|
|||||||
var tt tensor.Tensor = tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data))
|
var tt tensor.Tensor = tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data))
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
// Convert to [hidden, out_features] layout for slicing
|
|
||||||
tt, err = tensor.Transpose(tt, 1, 0)
|
tt, err = tensor.Transpose(tt, 1, 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -444,7 +861,6 @@ func (q *qwen3NextModel) repackQKVZ(spec qkvzSplitSpec, extractGate bool) Repack
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// addOne adds 1.0 to all elements in the tensor (for norm weights)
|
|
||||||
func (*qwen3NextModel) addOne(_ string, data []float32, shape []uint64) ([]float32, error) {
|
func (*qwen3NextModel) addOne(_ string, data []float32, shape []uint64) ([]float32, error) {
|
||||||
n := tensor.New(tensor.WithShape(int(shape[0])), tensor.WithBacking(data))
|
n := tensor.New(tensor.WithShape(int(shape[0])), tensor.WithBacking(data))
|
||||||
ones := tensor.Ones(tensor.Float32, int(shape[0]))
|
ones := tensor.Ones(tensor.Float32, int(shape[0]))
|
||||||
@@ -471,10 +887,21 @@ func (q *qwen3NextModel) Replacements() []string {
|
|||||||
return []string{
|
return []string{
|
||||||
// Embeddings and output
|
// Embeddings and output
|
||||||
"lm_head", "output",
|
"lm_head", "output",
|
||||||
|
"model.language_model.embed_tokens", "token_embd",
|
||||||
|
"model.language_model.norm", "output_norm",
|
||||||
|
"model.language_model.layers", "blk",
|
||||||
"model.embed_tokens", "token_embd",
|
"model.embed_tokens", "token_embd",
|
||||||
"model.norm", "output_norm",
|
"model.norm", "output_norm",
|
||||||
"model.layers", "blk",
|
"model.layers", "blk",
|
||||||
|
|
||||||
|
// Vision
|
||||||
|
"model.visual", "v",
|
||||||
|
"patch_embed.proj", "patch_embed",
|
||||||
|
"blocks", "blk",
|
||||||
|
"attn.qkv", "attn_qkv",
|
||||||
|
"attn.proj", "attn_out",
|
||||||
|
"deepstack_merger_list", "deepstack_merger",
|
||||||
|
|
||||||
// Layer norms
|
// Layer norms
|
||||||
"input_layernorm", "attn_norm",
|
"input_layernorm", "attn_norm",
|
||||||
"post_attention_layernorm", "post_attention_norm",
|
"post_attention_layernorm", "post_attention_norm",
|
||||||
@@ -487,9 +914,16 @@ func (q *qwen3NextModel) Replacements() []string {
|
|||||||
"self_attn.v_proj", "attn_v",
|
"self_attn.v_proj", "attn_v",
|
||||||
"self_attn.o_proj", "attn_output",
|
"self_attn.o_proj", "attn_output",
|
||||||
|
|
||||||
// Linear attention (Gated Delta Net)
|
// Linear attention (legacy qwen3next)
|
||||||
"linear_attn.in_proj_qkvz", "ssm_in",
|
"linear_attn.in_proj_qkvz", "ssm_in",
|
||||||
"linear_attn.in_proj_ba", "ssm_ba",
|
"linear_attn.in_proj_ba", "ssm_ba",
|
||||||
|
|
||||||
|
// Linear attention (qwen35)
|
||||||
|
"linear_attn.in_proj_qkv", "attn_qkv",
|
||||||
|
"linear_attn.in_proj_z", "attn_gate",
|
||||||
|
"linear_attn.in_proj_a", "ssm_alpha",
|
||||||
|
"linear_attn.in_proj_b", "ssm_beta",
|
||||||
|
|
||||||
"linear_attn.conv1d", "ssm_conv1d",
|
"linear_attn.conv1d", "ssm_conv1d",
|
||||||
"linear_attn.dt_bias", "ssm_dt",
|
"linear_attn.dt_bias", "ssm_dt",
|
||||||
"linear_attn.dt_proj", "ssm_dt",
|
"linear_attn.dt_proj", "ssm_dt",
|
||||||
@@ -497,14 +931,14 @@ func (q *qwen3NextModel) Replacements() []string {
|
|||||||
"linear_attn.norm", "ssm_norm",
|
"linear_attn.norm", "ssm_norm",
|
||||||
"linear_attn.out_proj", "ssm_out",
|
"linear_attn.out_proj", "ssm_out",
|
||||||
|
|
||||||
// MoE (experts are stacked via mergeTensors, not replaced here)
|
// MoE
|
||||||
"mlp.gate.weight", "ffn_gate_inp.weight",
|
"mlp.gate.weight", "ffn_gate_inp.weight",
|
||||||
"mlp.shared_expert.down_proj", "ffn_down_shexp",
|
"mlp.shared_expert.down_proj", "ffn_down_shexp",
|
||||||
"mlp.shared_expert.gate_proj", "ffn_gate_shexp",
|
"mlp.shared_expert.gate_proj", "ffn_gate_shexp",
|
||||||
"mlp.shared_expert.up_proj", "ffn_up_shexp",
|
"mlp.shared_expert.up_proj", "ffn_up_shexp",
|
||||||
"mlp.shared_expert_gate", "ffn_gate_inp_shexp",
|
"mlp.shared_expert_gate", "ffn_gate_inp_shexp",
|
||||||
|
|
||||||
// Dense FFN (if any layers use it)
|
// Dense FFN
|
||||||
"mlp.down_proj", "ffn_down",
|
"mlp.down_proj", "ffn_down",
|
||||||
"mlp.gate_proj", "ffn_gate",
|
"mlp.gate_proj", "ffn_gate",
|
||||||
"mlp.up_proj", "ffn_up",
|
"mlp.up_proj", "ffn_up",
|
||||||
|
|||||||
563
convert/convert_qwen3next_test.go
Normal file
563
convert/convert_qwen3next_test.go
Normal file
@@ -0,0 +1,563 @@
|
|||||||
|
package convert
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/binary"
|
||||||
|
"os"
|
||||||
|
"slices"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/fs/ggml"
|
||||||
|
)
|
||||||
|
|
||||||
|
func boolPtr(v bool) *bool {
|
||||||
|
return &v
|
||||||
|
}
|
||||||
|
|
||||||
|
func readTensorData(t *testing.T, tensor *ggml.Tensor) []float32 {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
var b bytes.Buffer
|
||||||
|
if _, err := tensor.WriteTo(&b); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
numel := 1
|
||||||
|
for _, d := range tensor.Shape {
|
||||||
|
numel *= int(d)
|
||||||
|
}
|
||||||
|
|
||||||
|
values := make([]float32, numel)
|
||||||
|
if err := binary.Read(&b, binary.LittleEndian, &values); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return values
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQwen3NextLegacyModelTypeDisablesReorder(t *testing.T) {
|
||||||
|
m := &qwen3NextModel{
|
||||||
|
ModelParameters: ModelParameters{
|
||||||
|
ModelType: "qwen3_next",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
if m.shouldReorderVHeads() {
|
||||||
|
t.Fatalf("legacy qwen3_next model_type should not reorder v-head layout")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQwen3NextLegacyArchitectureDisablesReorder(t *testing.T) {
|
||||||
|
m := &qwen3NextModel{
|
||||||
|
ModelParameters: ModelParameters{
|
||||||
|
Architectures: []string{"Qwen3NextForCausalLM"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
if m.shouldReorderVHeads() {
|
||||||
|
t.Fatalf("legacy Qwen3Next architecture should not reorder v-head layout")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQwen3NextKVLegacyConfig(t *testing.T) {
|
||||||
|
m := &qwen3NextModel{
|
||||||
|
ModelParameters: ModelParameters{
|
||||||
|
ModelType: "qwen3_next",
|
||||||
|
},
|
||||||
|
qwen3NextTextConfig: qwen3NextTextConfig{
|
||||||
|
MaxPositionEmbeddings: 8192,
|
||||||
|
HiddenSize: 512,
|
||||||
|
NumHiddenLayers: 4,
|
||||||
|
IntermediateSize: 2048,
|
||||||
|
NumAttentionHeads: 8,
|
||||||
|
NumKeyValueHeads: 2,
|
||||||
|
HeadDim: 64,
|
||||||
|
RopeTheta: 1_000_000,
|
||||||
|
RMSNormEPS: 1e-6,
|
||||||
|
|
||||||
|
NumExperts: 8,
|
||||||
|
NumExpertsPerToken: 2,
|
||||||
|
NormTopkProb: boolPtr(true),
|
||||||
|
MoEIntermediateSize: 256,
|
||||||
|
SharedExpertIntermSize: 512,
|
||||||
|
|
||||||
|
FullAttentionInterval: 2,
|
||||||
|
|
||||||
|
LinearConvKernelDim: 4,
|
||||||
|
LinearKeyHeadDim: 64,
|
||||||
|
LinearNumKeyHeads: 2,
|
||||||
|
LinearNumValueHeads: 4,
|
||||||
|
LinearValueHeadDim: 64,
|
||||||
|
|
||||||
|
PartialRotaryFactor: 0.25,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := m.parseMore(os.DirFS(t.TempDir())); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
kv := m.KV(&Tokenizer{Vocabulary: &Vocabulary{}})
|
||||||
|
if got, want := kv["general.architecture"], "qwen35moe"; got != want {
|
||||||
|
t.Fatalf("unexpected architecture: got %v want %v", got, want)
|
||||||
|
}
|
||||||
|
if got, want := kv["tokenizer.ggml.pre"], "qwen35"; got != want {
|
||||||
|
t.Fatalf("unexpected tokenizer pre: got %v want %v", got, want)
|
||||||
|
}
|
||||||
|
|
||||||
|
headCountKV, ok := kv["attention.head_count_kv"].([]uint32)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("attention.head_count_kv has unexpected type: %T", kv["attention.head_count_kv"])
|
||||||
|
}
|
||||||
|
if got, want := headCountKV, []uint32{0, 2, 0, 2}; !slices.Equal(got, want) {
|
||||||
|
t.Fatalf("unexpected attention.head_count_kv: got %v want %v", got, want)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, ok := kv["ssm.v_head_reordered"]; ok {
|
||||||
|
t.Fatalf("legacy qwen3next should not enable ssm.v_head_reordered")
|
||||||
|
}
|
||||||
|
if got, want := kv["norm_top_k_prob"], true; got != want {
|
||||||
|
t.Fatalf("unexpected norm_top_k_prob: got %v want %v", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQwen35MoeOmitsNormTopKProbWhenUnset(t *testing.T) {
|
||||||
|
m := &qwen3NextModel{
|
||||||
|
ModelParameters: ModelParameters{
|
||||||
|
ModelType: "qwen3_5",
|
||||||
|
},
|
||||||
|
qwen3NextTextConfig: qwen3NextTextConfig{
|
||||||
|
MaxPositionEmbeddings: 4096,
|
||||||
|
HiddenSize: 512,
|
||||||
|
NumHiddenLayers: 4,
|
||||||
|
IntermediateSize: 2048,
|
||||||
|
NumAttentionHeads: 8,
|
||||||
|
NumKeyValueHeads: 2,
|
||||||
|
HeadDim: 64,
|
||||||
|
RopeTheta: 1_000_000,
|
||||||
|
RMSNormEPS: 1e-6,
|
||||||
|
NumExperts: 8,
|
||||||
|
NumExpertsPerToken: 2,
|
||||||
|
FullAttentionInterval: 2,
|
||||||
|
LinearConvKernelDim: 4,
|
||||||
|
LinearKeyHeadDim: 64,
|
||||||
|
LinearNumKeyHeads: 2,
|
||||||
|
LinearNumValueHeads: 4,
|
||||||
|
LinearValueHeadDim: 64,
|
||||||
|
PartialRotaryFactor: 0.25,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := m.parseMore(os.DirFS(t.TempDir())); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
kv := m.KV(&Tokenizer{Vocabulary: &Vocabulary{}})
|
||||||
|
if _, ok := kv["norm_top_k_prob"]; ok {
|
||||||
|
t.Fatalf("expected norm_top_k_prob to be omitted when not set in config")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQwen35KVFromTextConfig(t *testing.T) {
|
||||||
|
m := &qwen3NextModel{
|
||||||
|
ModelParameters: ModelParameters{
|
||||||
|
ModelType: "qwen3_5",
|
||||||
|
},
|
||||||
|
TextConfig: &qwen3NextTextConfig{
|
||||||
|
MaxPositionEmbeddings: 16384,
|
||||||
|
HiddenSize: 1024,
|
||||||
|
NumHiddenLayers: 4,
|
||||||
|
IntermediateSize: 4096,
|
||||||
|
NumAttentionHeads: 8,
|
||||||
|
NumKeyValueHeads: 4,
|
||||||
|
HeadDim: 128,
|
||||||
|
RMSNormEPS: 1e-6,
|
||||||
|
|
||||||
|
LayerTypes: []string{
|
||||||
|
"linear_attention",
|
||||||
|
"full_attention",
|
||||||
|
"linear_attention",
|
||||||
|
"full_attention",
|
||||||
|
},
|
||||||
|
|
||||||
|
LinearConvKernelDim: 4,
|
||||||
|
LinearKeyHeadDim: 128,
|
||||||
|
LinearNumKeyHeads: 2,
|
||||||
|
LinearNumValueHeads: 4,
|
||||||
|
LinearValueHeadDim: 128,
|
||||||
|
|
||||||
|
RopeParameters: qwen3NextRopeParams{
|
||||||
|
MRopeInterleaved: true,
|
||||||
|
MropeSection: []int32{11, 11, 10},
|
||||||
|
RopeType: "default",
|
||||||
|
RopeTheta: 10_000_000,
|
||||||
|
PartialRotaryFactor: 0.25,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
VisionModel: qwen3NextVisionConfig{
|
||||||
|
Depth: 2,
|
||||||
|
HiddenSize: 128,
|
||||||
|
NumHeads: 4,
|
||||||
|
InChannels: 3,
|
||||||
|
PatchSize: 16,
|
||||||
|
SpatialMergeSize: 2,
|
||||||
|
RMSNormEps: 1e-6,
|
||||||
|
RopeTheta: 10_000,
|
||||||
|
TemporalPatchSize: 2,
|
||||||
|
DeepstackVisualIndexes: []int32{1},
|
||||||
|
},
|
||||||
|
ImageTokenID: 1001,
|
||||||
|
VisionStartTokenID: 1002,
|
||||||
|
VisionEndTokenID: 1003,
|
||||||
|
}
|
||||||
|
m.VisionModel.Size.ShortestEdge = 224
|
||||||
|
m.VisionModel.Size.LongestEdge = 4096
|
||||||
|
m.VisionModel.ImageMean = []float32{0.5, 0.5, 0.5}
|
||||||
|
m.VisionModel.ImageStd = []float32{0.2, 0.2, 0.2}
|
||||||
|
|
||||||
|
if err := m.parseMore(os.DirFS(t.TempDir())); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
kv := m.KV(&Tokenizer{Vocabulary: &Vocabulary{}})
|
||||||
|
if got, want := kv["general.architecture"], "qwen35"; got != want {
|
||||||
|
t.Fatalf("unexpected architecture: got %v want %v", got, want)
|
||||||
|
}
|
||||||
|
|
||||||
|
headCountKV, ok := kv["attention.head_count_kv"].([]uint32)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("attention.head_count_kv has unexpected type: %T", kv["attention.head_count_kv"])
|
||||||
|
}
|
||||||
|
if got, want := headCountKV, []uint32{0, 4, 0, 4}; !slices.Equal(got, want) {
|
||||||
|
t.Fatalf("unexpected attention.head_count_kv: got %v want %v", got, want)
|
||||||
|
}
|
||||||
|
|
||||||
|
if got, ok := kv["ssm.v_head_reordered"].(bool); !ok || !got {
|
||||||
|
t.Fatalf("expected ssm.v_head_reordered=true, got %v (%T)", kv["ssm.v_head_reordered"], kv["ssm.v_head_reordered"])
|
||||||
|
}
|
||||||
|
|
||||||
|
mrope, ok := kv["mrope_sections"].([]int32)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("mrope_sections has unexpected type: %T", kv["mrope_sections"])
|
||||||
|
}
|
||||||
|
if got, want := mrope, []int32{11, 11, 10}; !slices.Equal(got, want) {
|
||||||
|
t.Fatalf("unexpected mrope_sections: got %v want %v", got, want)
|
||||||
|
}
|
||||||
|
ropeSections, ok := kv["rope.dimension_sections"].([]int32)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("rope.dimension_sections has unexpected type: %T", kv["rope.dimension_sections"])
|
||||||
|
}
|
||||||
|
if got, want := ropeSections, []int32{11, 11, 10}; !slices.Equal(got, want) {
|
||||||
|
t.Fatalf("unexpected rope.dimension_sections: got %v want %v", got, want)
|
||||||
|
}
|
||||||
|
|
||||||
|
if got, ok := kv["rope.mrope_interleaved"].(bool); !ok || !got {
|
||||||
|
t.Fatalf("expected rope.mrope_interleaved=true, got %v (%T)", kv["rope.mrope_interleaved"], kv["rope.mrope_interleaved"])
|
||||||
|
}
|
||||||
|
|
||||||
|
if got, want := kv["vision.block_count"], uint32(2); got != want {
|
||||||
|
t.Fatalf("unexpected vision.block_count: got %v want %v", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQwen3NextReplacements(t *testing.T) {
|
||||||
|
r := strings.NewReplacer((&qwen3NextModel{}).Replacements()...)
|
||||||
|
|
||||||
|
if got, want := r.Replace("model.language_model.layers.1.linear_attn.in_proj_qkv.weight"), "blk.1.attn_qkv.weight"; got != want {
|
||||||
|
t.Fatalf("unexpected language-model replacement: got %q want %q", got, want)
|
||||||
|
}
|
||||||
|
if got, want := r.Replace("model.visual.blocks.0.attn.qkv.weight"), "v.blk.0.attn_qkv.weight"; got != want {
|
||||||
|
t.Fatalf("unexpected vision replacement: got %q want %q", got, want)
|
||||||
|
}
|
||||||
|
if got, want := r.Replace("model.layers.1.linear_attn.in_proj_qkvz.weight"), "blk.1.ssm_in.weight"; got != want {
|
||||||
|
t.Fatalf("unexpected legacy replacement: got %q want %q", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQwen35ReordersVHeads(t *testing.T) {
|
||||||
|
m := &qwen3NextModel{
|
||||||
|
ModelParameters: ModelParameters{
|
||||||
|
ModelType: "qwen3_5",
|
||||||
|
},
|
||||||
|
qwen3NextTextConfig: qwen3NextTextConfig{
|
||||||
|
LinearNumKeyHeads: 2,
|
||||||
|
LinearNumValueHeads: 4,
|
||||||
|
LinearValueHeadDim: 1,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
out := m.Tensors([]Tensor{
|
||||||
|
&fakeTensor{
|
||||||
|
name: "blk.0.attn_gate.weight",
|
||||||
|
shape: []uint64{4, 2},
|
||||||
|
data: []float32{0, 1, 2, 3, 4, 5, 6, 7},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
if len(out) != 1 {
|
||||||
|
t.Fatalf("unexpected output tensor count: got %d want 1", len(out))
|
||||||
|
}
|
||||||
|
|
||||||
|
if got, want := readTensorData(t, out[0]), []float32{0, 1, 4, 5, 2, 3, 6, 7}; !slices.Equal(got, want) {
|
||||||
|
t.Fatalf("unexpected data: got %v want %v", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQwen35ReordersAttnQKVOutputDim(t *testing.T) {
|
||||||
|
m := &qwen3NextModel{
|
||||||
|
ModelParameters: ModelParameters{
|
||||||
|
ModelType: "qwen3_5",
|
||||||
|
},
|
||||||
|
qwen3NextTextConfig: qwen3NextTextConfig{
|
||||||
|
LinearNumKeyHeads: 2,
|
||||||
|
LinearNumValueHeads: 4,
|
||||||
|
LinearKeyHeadDim: 1,
|
||||||
|
LinearValueHeadDim: 1,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
out := m.Tensors([]Tensor{
|
||||||
|
&fakeTensor{
|
||||||
|
name: "blk.0.attn_qkv.weight",
|
||||||
|
shape: []uint64{8, 2}, // [out_features, in_features] (HF layout)
|
||||||
|
data: []float32{
|
||||||
|
0, 1, // q0
|
||||||
|
2, 3, // q1
|
||||||
|
4, 5, // k0
|
||||||
|
6, 7, // k1
|
||||||
|
10, 11, // v(k0,v0)
|
||||||
|
12, 13, // v(k0,v1)
|
||||||
|
20, 21, // v(k1,v0)
|
||||||
|
22, 23, // v(k1,v1)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
if len(out) != 1 {
|
||||||
|
t.Fatalf("unexpected output tensor count: got %d want 1", len(out))
|
||||||
|
}
|
||||||
|
|
||||||
|
if got, want := readTensorData(t, out[0]), []float32{
|
||||||
|
0, 1, 2, 3, 4, 5, 6, 7,
|
||||||
|
10, 11, 20, 21, 12, 13, 22, 23,
|
||||||
|
}; !slices.Equal(got, want) {
|
||||||
|
t.Fatalf("unexpected qkv data: got %v want %v", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQwen35ReordersSsmOutInputDim(t *testing.T) {
|
||||||
|
m := &qwen3NextModel{
|
||||||
|
ModelParameters: ModelParameters{
|
||||||
|
ModelType: "qwen3_5",
|
||||||
|
},
|
||||||
|
qwen3NextTextConfig: qwen3NextTextConfig{
|
||||||
|
LinearNumKeyHeads: 2,
|
||||||
|
LinearNumValueHeads: 4,
|
||||||
|
LinearValueHeadDim: 1,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
out := m.Tensors([]Tensor{
|
||||||
|
&fakeTensor{
|
||||||
|
name: "blk.0.ssm_out.weight",
|
||||||
|
shape: []uint64{2, 4},
|
||||||
|
data: []float32{0, 1, 2, 3, 4, 5, 6, 7},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
if len(out) != 1 {
|
||||||
|
t.Fatalf("unexpected output tensor count: got %d want 1", len(out))
|
||||||
|
}
|
||||||
|
|
||||||
|
if got, want := readTensorData(t, out[0]), []float32{0, 2, 1, 3, 4, 6, 5, 7}; !slices.Equal(got, want) {
|
||||||
|
t.Fatalf("unexpected ssm_out data: got %v want %v", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQwen35ReordersSsmBetaRows(t *testing.T) {
|
||||||
|
m := &qwen3NextModel{
|
||||||
|
ModelParameters: ModelParameters{
|
||||||
|
ModelType: "qwen3_5",
|
||||||
|
},
|
||||||
|
qwen3NextTextConfig: qwen3NextTextConfig{
|
||||||
|
LinearNumKeyHeads: 2,
|
||||||
|
LinearNumValueHeads: 4,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
out := m.Tensors([]Tensor{
|
||||||
|
&fakeTensor{
|
||||||
|
name: "blk.0.ssm_beta.weight",
|
||||||
|
shape: []uint64{4, 2},
|
||||||
|
data: []float32{0, 1, 2, 3, 4, 5, 6, 7},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
if len(out) != 1 {
|
||||||
|
t.Fatalf("unexpected output tensor count: got %d want 1", len(out))
|
||||||
|
}
|
||||||
|
|
||||||
|
if got, want := readTensorData(t, out[0]), []float32{0, 1, 4, 5, 2, 3, 6, 7}; !slices.Equal(got, want) {
|
||||||
|
t.Fatalf("unexpected ssm_beta data: got %v want %v", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQwen35ReordersConv1DChannelDim(t *testing.T) {
|
||||||
|
m := &qwen3NextModel{
|
||||||
|
ModelParameters: ModelParameters{
|
||||||
|
ModelType: "qwen3_5",
|
||||||
|
},
|
||||||
|
qwen3NextTextConfig: qwen3NextTextConfig{
|
||||||
|
LinearNumKeyHeads: 2,
|
||||||
|
LinearNumValueHeads: 4,
|
||||||
|
LinearKeyHeadDim: 1,
|
||||||
|
LinearValueHeadDim: 1,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
out := m.Tensors([]Tensor{
|
||||||
|
&fakeTensor{
|
||||||
|
name: "blk.0.ssm_conv1d.weight",
|
||||||
|
shape: []uint64{8, 2}, // [channels, kernel] after squeeze
|
||||||
|
data: []float32{
|
||||||
|
0, 1, // q0
|
||||||
|
2, 3, // q1
|
||||||
|
4, 5, // k0
|
||||||
|
6, 7, // k1
|
||||||
|
10, 11, // v(k0,v0)
|
||||||
|
12, 13, // v(k0,v1)
|
||||||
|
20, 21, // v(k1,v0)
|
||||||
|
22, 23, // v(k1,v1)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
if len(out) != 1 {
|
||||||
|
t.Fatalf("unexpected output tensor count: got %d want 1", len(out))
|
||||||
|
}
|
||||||
|
|
||||||
|
if got, want := readTensorData(t, out[0]), []float32{
|
||||||
|
0, 1, 2, 3, 4, 5, 6, 7,
|
||||||
|
10, 11, 20, 21, 12, 13, 22, 23,
|
||||||
|
}; !slices.Equal(got, want) {
|
||||||
|
t.Fatalf("unexpected conv1d data: got %v want %v", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLegacyQwen3NextDoesNotReorderVHeads(t *testing.T) {
|
||||||
|
m := &qwen3NextModel{
|
||||||
|
ModelParameters: ModelParameters{
|
||||||
|
ModelType: "qwen3_next",
|
||||||
|
},
|
||||||
|
qwen3NextTextConfig: qwen3NextTextConfig{
|
||||||
|
LinearNumKeyHeads: 2,
|
||||||
|
LinearNumValueHeads: 4,
|
||||||
|
LinearValueHeadDim: 1,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
out := m.Tensors([]Tensor{
|
||||||
|
&fakeTensor{
|
||||||
|
name: "blk.0.attn_gate.weight",
|
||||||
|
shape: []uint64{4, 1},
|
||||||
|
data: []float32{0, 1, 2, 3},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
if len(out) != 1 {
|
||||||
|
t.Fatalf("unexpected output tensor count: got %d want 1", len(out))
|
||||||
|
}
|
||||||
|
|
||||||
|
if got, want := readTensorData(t, out[0]), []float32{0, 1, 2, 3}; !slices.Equal(got, want) {
|
||||||
|
t.Fatalf("unexpected data for legacy qwen3next: got %v want %v", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQwen35MoePackedExperts(t *testing.T) {
|
||||||
|
m := &qwen3NextModel{
|
||||||
|
qwen3NextTextConfig: qwen3NextTextConfig{
|
||||||
|
NumHiddenLayers: 1,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
out := m.Tensors([]Tensor{
|
||||||
|
&fakeTensor{
|
||||||
|
name: "blk.0.mlp.experts.gate_up_proj",
|
||||||
|
shape: []uint64{2, 4, 3},
|
||||||
|
data: []float32{
|
||||||
|
0, 1, 2,
|
||||||
|
3, 4, 5,
|
||||||
|
6, 7, 8,
|
||||||
|
9, 10, 11,
|
||||||
|
12, 13, 14,
|
||||||
|
15, 16, 17,
|
||||||
|
18, 19, 20,
|
||||||
|
21, 22, 23,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
&fakeTensor{
|
||||||
|
name: "blk.0.mlp.experts.down_proj",
|
||||||
|
shape: []uint64{2, 5, 3},
|
||||||
|
data: make([]float32, 2*5*3),
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
get := func(name string) *ggml.Tensor {
|
||||||
|
for _, tensor := range out {
|
||||||
|
if tensor.Name == name {
|
||||||
|
return tensor
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
gate := get("blk.0.ffn_gate_exps.weight")
|
||||||
|
if gate == nil {
|
||||||
|
t.Fatalf("missing tensor %q", "blk.0.ffn_gate_exps.weight")
|
||||||
|
}
|
||||||
|
if got, want := gate.Shape, []uint64{2, 2, 3}; !slices.Equal(got, want) {
|
||||||
|
t.Fatalf("unexpected gate shape: got %v want %v", got, want)
|
||||||
|
}
|
||||||
|
if got, want := readTensorData(t, gate), []float32{
|
||||||
|
0, 1, 2, 3, 4, 5,
|
||||||
|
12, 13, 14, 15, 16, 17,
|
||||||
|
}; !slices.Equal(got, want) {
|
||||||
|
t.Fatalf("unexpected gate values: got %v want %v", got, want)
|
||||||
|
}
|
||||||
|
|
||||||
|
up := get("blk.0.ffn_up_exps.weight")
|
||||||
|
if up == nil {
|
||||||
|
t.Fatalf("missing tensor %q", "blk.0.ffn_up_exps.weight")
|
||||||
|
}
|
||||||
|
if got, want := up.Shape, []uint64{2, 2, 3}; !slices.Equal(got, want) {
|
||||||
|
t.Fatalf("unexpected up shape: got %v want %v", got, want)
|
||||||
|
}
|
||||||
|
if got, want := readTensorData(t, up), []float32{
|
||||||
|
6, 7, 8, 9, 10, 11,
|
||||||
|
18, 19, 20, 21, 22, 23,
|
||||||
|
}; !slices.Equal(got, want) {
|
||||||
|
t.Fatalf("unexpected up values: got %v want %v", got, want)
|
||||||
|
}
|
||||||
|
|
||||||
|
down := get("blk.0.ffn_down_exps.weight")
|
||||||
|
if down == nil {
|
||||||
|
t.Fatalf("missing tensor %q", "blk.0.ffn_down_exps.weight")
|
||||||
|
}
|
||||||
|
if got, want := down.Shape, []uint64{2, 5, 3}; !slices.Equal(got, want) {
|
||||||
|
t.Fatalf("unexpected down shape: got %v want %v", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQwen35SharedExpertGateKeepsMatrixShape(t *testing.T) {
|
||||||
|
m := &qwen3NextModel{}
|
||||||
|
|
||||||
|
out := m.Tensors([]Tensor{
|
||||||
|
&fakeTensor{
|
||||||
|
name: "blk.0.ffn_gate_inp_shexp.weight",
|
||||||
|
shape: []uint64{1, 4},
|
||||||
|
data: []float32{0, 1, 2, 3},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
if len(out) != 1 {
|
||||||
|
t.Fatalf("unexpected output tensor count: got %d want 1", len(out))
|
||||||
|
}
|
||||||
|
|
||||||
|
if got, want := out[0].Shape, []uint64{1, 4}; !slices.Equal(got, want) {
|
||||||
|
t.Fatalf("unexpected shared gate shape: got %v want %v", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
97
convert/json_compat.go
Normal file
97
convert/json_compat.go
Normal file
@@ -0,0 +1,97 @@
|
|||||||
|
package convert
|
||||||
|
|
||||||
|
// sanitizeNonFiniteJSON rewrites non-standard JSON numeric tokens that some
|
||||||
|
// HF configs emit (Infinity, -Infinity, NaN) into standard JSON numbers.
|
||||||
|
//
|
||||||
|
// This is intentionally conservative:
|
||||||
|
// - only runs outside quoted strings
|
||||||
|
// - only rewrites full tokens
|
||||||
|
//
|
||||||
|
// We map these values to 0 because encoding/json rejects non-finite values,
|
||||||
|
// and these fields are typically model-side metadata not consumed by the
|
||||||
|
// converter.
|
||||||
|
func sanitizeNonFiniteJSON(in []byte) []byte {
|
||||||
|
if len(in) == 0 {
|
||||||
|
return in
|
||||||
|
}
|
||||||
|
|
||||||
|
out := make([]byte, 0, len(in))
|
||||||
|
inString := false
|
||||||
|
escape := false
|
||||||
|
|
||||||
|
for i := 0; i < len(in); {
|
||||||
|
c := in[i]
|
||||||
|
|
||||||
|
if inString {
|
||||||
|
out = append(out, c)
|
||||||
|
if escape {
|
||||||
|
escape = false
|
||||||
|
} else if c == '\\' {
|
||||||
|
escape = true
|
||||||
|
} else if c == '"' {
|
||||||
|
inString = false
|
||||||
|
}
|
||||||
|
i++
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if c == '"' {
|
||||||
|
inString = true
|
||||||
|
out = append(out, c)
|
||||||
|
i++
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if hasToken(in, i, "-Infinity") {
|
||||||
|
out = append(out, '0')
|
||||||
|
i += len("-Infinity")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if hasToken(in, i, "Infinity") {
|
||||||
|
out = append(out, '0')
|
||||||
|
i += len("Infinity")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if hasToken(in, i, "NaN") {
|
||||||
|
out = append(out, '0')
|
||||||
|
i += len("NaN")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
out = append(out, c)
|
||||||
|
i++
|
||||||
|
}
|
||||||
|
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func hasToken(in []byte, at int, tok string) bool {
|
||||||
|
end := at + len(tok)
|
||||||
|
if at < 0 || end > len(in) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if string(in[at:end]) != tok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if at > 0 && !isJSONValuePrefixBoundary(in[at-1]) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if end < len(in) && !isJSONValueSuffixBoundary(in[end]) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func isJSONWhitespace(b byte) bool {
|
||||||
|
return b == ' ' || b == '\t' || b == '\n' || b == '\r'
|
||||||
|
}
|
||||||
|
|
||||||
|
func isJSONValuePrefixBoundary(b byte) bool {
|
||||||
|
return isJSONWhitespace(b) || b == ':' || b == ',' || b == '['
|
||||||
|
}
|
||||||
|
|
||||||
|
func isJSONValueSuffixBoundary(b byte) bool {
|
||||||
|
return isJSONWhitespace(b) || b == ',' || b == ']' || b == '}'
|
||||||
|
}
|
||||||
46
convert/json_compat_test.go
Normal file
46
convert/json_compat_test.go
Normal file
@@ -0,0 +1,46 @@
|
|||||||
|
package convert
|
||||||
|
|
||||||
|
import "testing"
|
||||||
|
|
||||||
|
func TestSanitizeNonFiniteJSON(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
in string
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "infinity token",
|
||||||
|
in: `{"a":[0,Infinity,1]}`,
|
||||||
|
want: `{"a":[0,0,1]}`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "negative infinity token",
|
||||||
|
in: `{"a":-Infinity}`,
|
||||||
|
want: `{"a":0}`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "nan token",
|
||||||
|
in: `{"a":NaN}`,
|
||||||
|
want: `{"a":0}`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "tokens inside strings untouched",
|
||||||
|
in: `{"a":"Infinity -Infinity NaN","b":Infinity}`,
|
||||||
|
want: `{"a":"Infinity -Infinity NaN","b":0}`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "identifier-like token untouched",
|
||||||
|
in: `{"a":InfinityValue}`,
|
||||||
|
want: `{"a":InfinityValue}`,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := string(sanitizeNonFiniteJSON([]byte(tt.in)))
|
||||||
|
if got != tt.want {
|
||||||
|
t.Fatalf("sanitizeNonFiniteJSON() = %q, want %q", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -101,6 +101,8 @@ func parseTokenizer(fsys fs.FS, specialTokenTypes []string) (*Tokenizer, error)
|
|||||||
t.Pre = "deepseek-coder"
|
t.Pre = "deepseek-coder"
|
||||||
case "1ff7f41064896984db5d1bb6ff64fa4bc29007d08c1b439e505b7392777a319e":
|
case "1ff7f41064896984db5d1bb6ff64fa4bc29007d08c1b439e505b7392777a319e":
|
||||||
t.Pre = "qwen2"
|
t.Pre = "qwen2"
|
||||||
|
case "00431aed57e696b747435f734d1e3b9b1bfd931a121fb5cac7129e97c181e9ba":
|
||||||
|
t.Pre = "qwen35"
|
||||||
case "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855":
|
case "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855":
|
||||||
// noop, empty pretokenizer
|
// noop, empty pretokenizer
|
||||||
default:
|
default:
|
||||||
@@ -213,6 +215,11 @@ type tokenizer struct {
|
|||||||
PreTokenizer struct {
|
PreTokenizer struct {
|
||||||
PreTokenizers []struct {
|
PreTokenizers []struct {
|
||||||
Type string `json:"type"`
|
Type string `json:"type"`
|
||||||
|
Behavior string `json:"behavior"`
|
||||||
|
Invert bool `json:"invert"`
|
||||||
|
AddPrefixSpace bool `json:"add_prefix_space"`
|
||||||
|
TrimOffsets bool `json:"trim_offsets"`
|
||||||
|
UseRegex bool `json:"use_regex"`
|
||||||
Pattern struct {
|
Pattern struct {
|
||||||
Regex string `json:"Regex"`
|
Regex string `json:"Regex"`
|
||||||
} `json:"pattern"`
|
} `json:"pattern"`
|
||||||
|
|||||||
@@ -191,6 +191,84 @@ func TestParseTokenizer(t *testing.T) {
|
|||||||
Pre: "default",
|
Pre: "default",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "llama-bpe pretokenizer and control tokens",
|
||||||
|
fsys: createTokenizerFS(t, t.TempDir(), map[string]io.Reader{
|
||||||
|
"tokenizer.json": strings.NewReader(`{
|
||||||
|
"added_tokens": [
|
||||||
|
{"id": 1, "content": "<|startoftext|>", "special": true},
|
||||||
|
{"id": 6, "content": "<|im_start|>", "special": true},
|
||||||
|
{"id": 7, "content": "<|im_end|>", "special": true},
|
||||||
|
{"id": 8, "content": "<|tool_list_start|>", "special": true},
|
||||||
|
{"id": 9, "content": "<|tool_list_end|>", "special": true},
|
||||||
|
{"id": 10, "content": "<|tool_call_start|>", "special": true},
|
||||||
|
{"id": 11, "content": "<|tool_call_end|>", "special": true},
|
||||||
|
{"id": 12, "content": "<|tool_response_start|>", "special": true},
|
||||||
|
{"id": 13, "content": "<|tool_response_end|>", "special": true},
|
||||||
|
{"id": 396, "content": "<image>", "special": true},
|
||||||
|
{"id": 64400, "content": "<think>", "special": true},
|
||||||
|
{"id": 64401, "content": "</think>", "special": true}
|
||||||
|
],
|
||||||
|
"model": {
|
||||||
|
"vocab": {
|
||||||
|
"<|startoftext|>": 1,
|
||||||
|
"<|im_start|>": 6,
|
||||||
|
"<|im_end|>": 7,
|
||||||
|
"<|tool_list_start|>": 8,
|
||||||
|
"<|tool_list_end|>": 9,
|
||||||
|
"<|tool_call_start|>": 10,
|
||||||
|
"<|tool_call_end|>": 11,
|
||||||
|
"<|tool_response_start|>": 12,
|
||||||
|
"<|tool_response_end|>": 13,
|
||||||
|
"<image>": 396,
|
||||||
|
"<think>": 64400,
|
||||||
|
"</think>": 64401
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"pre_tokenizer": {
|
||||||
|
"type": "Sequence",
|
||||||
|
"pretokenizers": [
|
||||||
|
{
|
||||||
|
"type": "Split",
|
||||||
|
"pattern": {
|
||||||
|
"Regex": "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+"
|
||||||
|
},
|
||||||
|
"behavior": "Isolated",
|
||||||
|
"invert": false
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "ByteLevel",
|
||||||
|
"add_prefix_space": false,
|
||||||
|
"trim_offsets": true,
|
||||||
|
"use_regex": false
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}`),
|
||||||
|
}),
|
||||||
|
want: &Tokenizer{
|
||||||
|
Vocabulary: &Vocabulary{
|
||||||
|
Model: "gpt2",
|
||||||
|
Tokens: []string{
|
||||||
|
"<|startoftext|>",
|
||||||
|
"<|im_start|>",
|
||||||
|
"<|im_end|>",
|
||||||
|
"<|tool_list_start|>",
|
||||||
|
"<|tool_list_end|>",
|
||||||
|
"<|tool_call_start|>",
|
||||||
|
"<|tool_call_end|>",
|
||||||
|
"<|tool_response_start|>",
|
||||||
|
"<|tool_response_end|>",
|
||||||
|
"<image>",
|
||||||
|
"<think>",
|
||||||
|
"</think>",
|
||||||
|
},
|
||||||
|
Scores: []float32{1, 6, 7, 8, 9, 10, 11, 12, 13, 396, 64400, 64401},
|
||||||
|
Types: []int32{3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3},
|
||||||
|
},
|
||||||
|
Pre: "llama-bpe",
|
||||||
|
},
|
||||||
|
},
|
||||||
{
|
{
|
||||||
name: "list string merges",
|
name: "list string merges",
|
||||||
fsys: createTokenizerFS(t, t.TempDir(), map[string]io.Reader{
|
fsys: createTokenizerFS(t, t.TempDir(), map[string]io.Reader{
|
||||||
@@ -308,6 +386,28 @@ func TestParseTokenizer(t *testing.T) {
|
|||||||
Pre: "default",
|
Pre: "default",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "qwen35 pretokenizer",
|
||||||
|
fsys: createTokenizerFS(t, t.TempDir(), map[string]io.Reader{
|
||||||
|
"tokenizer.json": strings.NewReader(`{
|
||||||
|
"pre_tokenizer": {
|
||||||
|
"type": "Sequence",
|
||||||
|
"pretokenizers": [
|
||||||
|
{
|
||||||
|
"type": "Split",
|
||||||
|
"pattern": {
|
||||||
|
"Regex": "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?[\\p{L}\\p{M}]+|\\p{N}| ?[^\\s\\p{L}\\p{M}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}`),
|
||||||
|
}),
|
||||||
|
want: &Tokenizer{
|
||||||
|
Vocabulary: &Vocabulary{Model: "gpt2"},
|
||||||
|
Pre: "qwen35",
|
||||||
|
},
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range cases {
|
for _, tt := range cases {
|
||||||
|
|||||||
@@ -12,7 +12,6 @@ To use Ollama with tools that expect the Anthropic API (like Claude Code), set t
|
|||||||
|
|
||||||
```shell
|
```shell
|
||||||
export ANTHROPIC_AUTH_TOKEN=ollama # required but ignored
|
export ANTHROPIC_AUTH_TOKEN=ollama # required but ignored
|
||||||
export ANTHROPIC_API_KEY="" # required but ignored
|
|
||||||
export ANTHROPIC_BASE_URL=http://localhost:11434
|
export ANTHROPIC_BASE_URL=http://localhost:11434
|
||||||
```
|
```
|
||||||
|
|
||||||
@@ -269,7 +268,7 @@ ollama launch claude --config
|
|||||||
Set the environment variables and run Claude Code:
|
Set the environment variables and run Claude Code:
|
||||||
|
|
||||||
```shell
|
```shell
|
||||||
ANTHROPIC_AUTH_TOKEN=ollama ANTHROPIC_BASE_URL=http://localhost:11434 ANTHROPIC_API_KEY="" claude --model qwen3-coder
|
ANTHROPIC_AUTH_TOKEN=ollama ANTHROPIC_BASE_URL=http://localhost:11434 claude --model qwen3-coder
|
||||||
```
|
```
|
||||||
|
|
||||||
Or set the environment variables in your shell profile:
|
Or set the environment variables in your shell profile:
|
||||||
@@ -277,7 +276,6 @@ Or set the environment variables in your shell profile:
|
|||||||
```shell
|
```shell
|
||||||
export ANTHROPIC_AUTH_TOKEN=ollama
|
export ANTHROPIC_AUTH_TOKEN=ollama
|
||||||
export ANTHROPIC_BASE_URL=http://localhost:11434
|
export ANTHROPIC_BASE_URL=http://localhost:11434
|
||||||
export ANTHROPIC_API_KEY=""
|
|
||||||
```
|
```
|
||||||
|
|
||||||
Then run Claude Code with any Ollama model:
|
Then run Claude Code with any Ollama model:
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ Ollama provides compatibility with parts of the [OpenAI API](https://platform.op
|
|||||||
|
|
||||||
## Usage
|
## Usage
|
||||||
|
|
||||||
### Simple `v1/chat/completions` example
|
### Simple `/v1/chat/completions` example
|
||||||
|
|
||||||
<CodeGroup dropdown>
|
<CodeGroup dropdown>
|
||||||
|
|
||||||
@@ -57,7 +57,7 @@ curl -X POST http://localhost:11434/v1/chat/completions \
|
|||||||
|
|
||||||
</CodeGroup>
|
</CodeGroup>
|
||||||
|
|
||||||
### Simple `v1/responses` example
|
### Simple `/v1/responses` example
|
||||||
|
|
||||||
<CodeGroup dropdown>
|
<CodeGroup dropdown>
|
||||||
|
|
||||||
@@ -103,7 +103,7 @@ curl -X POST http://localhost:11434/v1/responses \
|
|||||||
|
|
||||||
</CodeGroup>
|
</CodeGroup>
|
||||||
|
|
||||||
### v1/chat/completions with vision example
|
### `/v1/chat/completions` with vision example
|
||||||
|
|
||||||
<CodeGroup dropdown>
|
<CodeGroup dropdown>
|
||||||
|
|
||||||
@@ -184,6 +184,7 @@ curl -X POST http://localhost:11434/v1/chat/completions \
|
|||||||
- [x] Reproducible outputs
|
- [x] Reproducible outputs
|
||||||
- [x] Vision
|
- [x] Vision
|
||||||
- [x] Tools
|
- [x] Tools
|
||||||
|
- [x] Reasoning/thinking control (for thinking models)
|
||||||
- [ ] Logprobs
|
- [ ] Logprobs
|
||||||
|
|
||||||
#### Supported request fields
|
#### Supported request fields
|
||||||
@@ -207,6 +208,9 @@ curl -X POST http://localhost:11434/v1/chat/completions \
|
|||||||
- [x] `top_p`
|
- [x] `top_p`
|
||||||
- [x] `max_tokens`
|
- [x] `max_tokens`
|
||||||
- [x] `tools`
|
- [x] `tools`
|
||||||
|
- [x] `reasoning_effort` (`"high"`, `"medium"`, `"low"`, `"none"`)
|
||||||
|
- [x] `reasoning`
|
||||||
|
- [x] `effort` (`"high"`, `"medium"`, `"low"`, `"none"`)
|
||||||
- [ ] `tool_choice`
|
- [ ] `tool_choice`
|
||||||
- [ ] `logit_bias`
|
- [ ] `logit_bias`
|
||||||
- [ ] `user`
|
- [ ] `user`
|
||||||
|
|||||||
@@ -40,7 +40,7 @@ ollama launch claude
|
|||||||
Launch with a specific model:
|
Launch with a specific model:
|
||||||
|
|
||||||
```
|
```
|
||||||
ollama launch claude --model qwen3-coder
|
ollama launch claude --model qwen3.5
|
||||||
```
|
```
|
||||||
|
|
||||||
Configure without launching:
|
Configure without launching:
|
||||||
|
|||||||
@@ -51,6 +51,9 @@ Install prerequisites:
|
|||||||
- [CUDA SDK](https://developer.nvidia.com/cuda-downloads?target_os=Windows&target_arch=x86_64&target_version=11&target_type=exe_network)
|
- [CUDA SDK](https://developer.nvidia.com/cuda-downloads?target_os=Windows&target_arch=x86_64&target_version=11&target_type=exe_network)
|
||||||
- (Optional) VULKAN GPU support
|
- (Optional) VULKAN GPU support
|
||||||
- [VULKAN SDK](https://vulkan.lunarg.com/sdk/home) - useful for AMD/Intel GPUs
|
- [VULKAN SDK](https://vulkan.lunarg.com/sdk/home) - useful for AMD/Intel GPUs
|
||||||
|
- (Optional) MLX engine support
|
||||||
|
- [CUDA 13+ SDK](https://developer.nvidia.com/cuda-downloads)
|
||||||
|
- [cuDNN 9+](https://developer.nvidia.com/cudnn)
|
||||||
|
|
||||||
Then, configure and build the project:
|
Then, configure and build the project:
|
||||||
|
|
||||||
@@ -101,6 +104,10 @@ Install prerequisites:
|
|||||||
- (Optional) VULKAN GPU support
|
- (Optional) VULKAN GPU support
|
||||||
- [VULKAN SDK](https://vulkan.lunarg.com/sdk/home) - useful for AMD/Intel GPUs
|
- [VULKAN SDK](https://vulkan.lunarg.com/sdk/home) - useful for AMD/Intel GPUs
|
||||||
- Or install via package manager: `sudo apt install vulkan-sdk` (Ubuntu/Debian) or `sudo dnf install vulkan-sdk` (Fedora/CentOS)
|
- Or install via package manager: `sudo apt install vulkan-sdk` (Ubuntu/Debian) or `sudo dnf install vulkan-sdk` (Fedora/CentOS)
|
||||||
|
- (Optional) MLX engine support
|
||||||
|
- [CUDA 13+ SDK](https://developer.nvidia.com/cuda-downloads)
|
||||||
|
- [cuDNN 9+](https://developer.nvidia.com/cudnn)
|
||||||
|
- OpenBLAS/LAPACK: `sudo apt install libopenblas-dev liblapack-dev liblapacke-dev` (Ubuntu/Debian)
|
||||||
> [!IMPORTANT]
|
> [!IMPORTANT]
|
||||||
> Ensure prerequisites are in `PATH` before running CMake.
|
> Ensure prerequisites are in `PATH` before running CMake.
|
||||||
|
|
||||||
@@ -118,6 +125,67 @@ Lastly, run Ollama:
|
|||||||
go run . serve
|
go run . serve
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## MLX Engine (Optional)
|
||||||
|
|
||||||
|
The MLX engine enables running safetensor based models. It requires building the [MLX](https://github.com/ml-explore/mlx) and [MLX-C](https://github.com/ml-explore/mlx-c) shared libraries separately via CMake. On macOS, MLX leverages the Metal library to run on the GPU, and on Windows and Linux, runs on NVIDIA GPUs via CUDA v13.
|
||||||
|
|
||||||
|
### macOS (Apple Silicon)
|
||||||
|
|
||||||
|
Requires the Metal toolchain. Install [Xcode](https://developer.apple.com/xcode/) first, then:
|
||||||
|
|
||||||
|
```shell
|
||||||
|
xcodebuild -downloadComponent MetalToolchain
|
||||||
|
```
|
||||||
|
|
||||||
|
Verify it's installed correctly (should print "no input files"):
|
||||||
|
|
||||||
|
```shell
|
||||||
|
xcrun metal
|
||||||
|
```
|
||||||
|
|
||||||
|
Then build:
|
||||||
|
|
||||||
|
```shell
|
||||||
|
cmake -B build --preset MLX
|
||||||
|
cmake --build build --preset MLX --parallel
|
||||||
|
cmake --install build --component MLX
|
||||||
|
```
|
||||||
|
|
||||||
|
> [!NOTE]
|
||||||
|
> Without the Metal toolchain, cmake will silently complete with Metal disabled. Check the cmake output for `Setting MLX_BUILD_METAL=OFF` which indicates the toolchain is missing.
|
||||||
|
|
||||||
|
### Windows / Linux (CUDA)
|
||||||
|
|
||||||
|
Requires CUDA 13+ and [cuDNN](https://developer.nvidia.com/cudnn) 9+.
|
||||||
|
|
||||||
|
```shell
|
||||||
|
cmake -B build --preset "MLX CUDA 13"
|
||||||
|
cmake --build build --target mlx --target mlxc --config Release --parallel
|
||||||
|
cmake --install build --component MLX --strip
|
||||||
|
```
|
||||||
|
|
||||||
|
### Local MLX source overrides
|
||||||
|
|
||||||
|
To build against a local checkout of MLX and/or MLX-C (useful for development), set environment variables before running CMake:
|
||||||
|
|
||||||
|
```shell
|
||||||
|
export OLLAMA_MLX_SOURCE=/path/to/mlx
|
||||||
|
export OLLAMA_MLX_C_SOURCE=/path/to/mlx-c
|
||||||
|
```
|
||||||
|
|
||||||
|
For example, using the helper scripts with local mlx and mlx-c repos:
|
||||||
|
```shell
|
||||||
|
OLLAMA_MLX_SOURCE=../mlx OLLAMA_MLX_C_SOURCE=../mlx-c ./scripts/build_linux.sh
|
||||||
|
|
||||||
|
OLLAMA_MLX_SOURCE=../mlx OLLAMA_MLX_C_SOURCE=../mlx-c ./scripts/build_darwin.sh
|
||||||
|
```
|
||||||
|
|
||||||
|
```powershell
|
||||||
|
$env:OLLAMA_MLX_SOURCE="../mlx"
|
||||||
|
$env:OLLAMA_MLX_C_SOURCE="../mlx-c"
|
||||||
|
./scripts/build_darwin.ps1
|
||||||
|
```
|
||||||
|
|
||||||
## Docker
|
## Docker
|
||||||
|
|
||||||
```shell
|
```shell
|
||||||
|
|||||||
31
docs/gpu.mdx
31
docs/gpu.mdx
@@ -61,11 +61,17 @@ Ollama supports the following AMD GPUs via the ROCm library:
|
|||||||
|
|
||||||
### Linux Support
|
### Linux Support
|
||||||
|
|
||||||
|
Ollama requires the AMD ROCm v7 driver on Linux. You can install or upgrade
|
||||||
|
using the `amdgpu-install` utility from
|
||||||
|
[AMD's ROCm documentation](https://rocm.docs.amd.com/projects/install-on-linux/en/latest/).
|
||||||
|
|
||||||
| Family | Cards and accelerators |
|
| Family | Cards and accelerators |
|
||||||
| -------------- | ---------------------------------------------------------------------------------------------------------------------------------------------- |
|
| -------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||||
| AMD Radeon RX | `7900 XTX` `7900 XT` `7900 GRE` `7800 XT` `7700 XT` `7600 XT` `7600` `6950 XT` `6900 XTX` `6900XT` `6800 XT` `6800` `Vega 64` |
|
| AMD Radeon RX | `9070 XT` `9070 GRE` `9070` `9060 XT` `9060 XT LP` `9060` `7900 XTX` `7900 XT` `7900 GRE` `7800 XT` `7700 XT` `7700` `7600 XT` `7600` `6950 XT` `6900 XTX` `6900XT` `6800 XT` `6800` `5700 XT` `5700` `5600 XT` `5500 XT` |
|
||||||
| AMD Radeon PRO | `W7900` `W7800` `W7700` `W7600` `W7500` `W6900X` `W6800X Duo` `W6800X` `W6800` `V620` `V420` `V340` `V320` `Vega II Duo` `Vega II` `SSG` |
|
| AMD Radeon AI PRO | `R9700` `R9600D` |
|
||||||
| AMD Instinct | `MI300X` `MI300A` `MI300` `MI250X` `MI250` `MI210` `MI200` `MI100` `MI60` |
|
| AMD Radeon PRO | `W7900` `W7800` `W7700` `W7600` `W7500` `W6900X` `W6800X Duo` `W6800X` `W6800` `V620` |
|
||||||
|
| AMD Ryzen AI | `Ryzen AI Max+ 395` `Ryzen AI Max 390` `Ryzen AI Max 385` `Ryzen AI 9 HX 475` `Ryzen AI 9 HX 470` `Ryzen AI 9 465` `Ryzen AI 9 HX 375` `Ryzen AI 9 HX 370` `Ryzen AI 9 365` |
|
||||||
|
| AMD Instinct | `MI350X` `MI300X` `MI300A` `MI250X` `MI250` `MI210` `MI100` |
|
||||||
|
|
||||||
### Windows Support
|
### Windows Support
|
||||||
|
|
||||||
@@ -97,17 +103,20 @@ This table shows some example GPUs that map to these LLVM targets:
|
|||||||
| **LLVM Target** | **An Example GPU** |
|
| **LLVM Target** | **An Example GPU** |
|
||||||
|-----------------|---------------------|
|
|-----------------|---------------------|
|
||||||
| gfx908 | Radeon Instinct MI100 |
|
| gfx908 | Radeon Instinct MI100 |
|
||||||
| gfx90a | Radeon Instinct MI210 |
|
| gfx90a | Radeon Instinct MI210/MI250 |
|
||||||
| gfx940 | Radeon Instinct MI300 |
|
| gfx942 | Radeon Instinct MI300X/MI300A |
|
||||||
| gfx941 | |
|
| gfx950 | Radeon Instinct MI350X |
|
||||||
| gfx942 | |
|
| gfx1010 | Radeon RX 5700 XT |
|
||||||
|
| gfx1012 | Radeon RX 5500 XT |
|
||||||
| gfx1030 | Radeon PRO V620 |
|
| gfx1030 | Radeon PRO V620 |
|
||||||
| gfx1100 | Radeon PRO W7900 |
|
| gfx1100 | Radeon PRO W7900 |
|
||||||
| gfx1101 | Radeon PRO W7700 |
|
| gfx1101 | Radeon PRO W7700 |
|
||||||
| gfx1102 | Radeon RX 7600 |
|
| gfx1102 | Radeon RX 7600 |
|
||||||
|
| gfx1103 | Radeon 780M |
|
||||||
AMD is working on enhancing ROCm v6 to broaden support for families of GPUs in a
|
| gfx1150 | Ryzen AI 9 HX 375 |
|
||||||
future release which should increase support for more GPUs.
|
| gfx1151 | Ryzen AI Max+ 395 |
|
||||||
|
| gfx1200 | Radeon RX 9070 |
|
||||||
|
| gfx1201 | Radeon RX 9070 XT |
|
||||||
|
|
||||||
Reach out on [Discord](https://discord.gg/ollama) or file an
|
Reach out on [Discord](https://discord.gg/ollama) or file an
|
||||||
[issue](https://github.com/ollama/ollama/issues) for additional help.
|
[issue](https://github.com/ollama/ollama/issues) for additional help.
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ title: Claude Code
|
|||||||
|
|
||||||
Claude Code is Anthropic's agentic coding tool that can read, modify, and execute code in your working directory.
|
Claude Code is Anthropic's agentic coding tool that can read, modify, and execute code in your working directory.
|
||||||
|
|
||||||
Open models can be used with Claude Code through Ollama's Anthropic-compatible API, enabling you to use models such as `glm-4.7`, `qwen3-coder`, `gpt-oss`.
|
Open models can be used with Claude Code through Ollama's Anthropic-compatible API, enabling you to use models such as `qwen3.5`, `glm-5:cloud`, `kimi-k2.5:cloud`.
|
||||||
|
|
||||||

|

|
||||||
|
|
||||||
@@ -32,13 +32,57 @@ irm https://claude.ai/install.ps1 | iex
|
|||||||
ollama launch claude
|
ollama launch claude
|
||||||
```
|
```
|
||||||
|
|
||||||
To configure without launching:
|
### Run directly with a model
|
||||||
|
|
||||||
```shell
|
```shell
|
||||||
ollama launch claude --config
|
ollama launch claude --model kimi-k2.5:cloud
|
||||||
```
|
```
|
||||||
|
|
||||||
### Manual setup
|
## Recommended Models
|
||||||
|
|
||||||
|
- `kimi-k2.5:cloud`
|
||||||
|
- `glm-5:cloud`
|
||||||
|
- `minimax-m2.5:cloud`
|
||||||
|
- `qwen3.5:cloud`
|
||||||
|
- `glm-4.7-flash`
|
||||||
|
- `qwen3.5`
|
||||||
|
|
||||||
|
Cloud models are also available at [ollama.com/search?c=cloud](https://ollama.com/search?c=cloud).
|
||||||
|
|
||||||
|
## Scheduled Tasks with `/loop`
|
||||||
|
|
||||||
|
The `/loop` command runs a prompt or slash command on a recurring schedule inside Claude Code. This is useful for automating repetitive tasks like checking PRs, running research, or setting reminders.
|
||||||
|
|
||||||
|
```
|
||||||
|
/loop <interval> <prompt or /command>
|
||||||
|
```
|
||||||
|
|
||||||
|
### Examples
|
||||||
|
|
||||||
|
**Check in on your PRs**
|
||||||
|
|
||||||
|
```
|
||||||
|
/loop 30m Check my open PRs and summarize their status
|
||||||
|
```
|
||||||
|
|
||||||
|
**Automate research tasks**
|
||||||
|
|
||||||
|
```
|
||||||
|
/loop 1h Research the latest AI news and summarize key developments
|
||||||
|
```
|
||||||
|
|
||||||
|
**Automate bug reporting and triaging**
|
||||||
|
|
||||||
|
```
|
||||||
|
/loop 15m Check for new GitHub issues and triage by priority
|
||||||
|
```
|
||||||
|
|
||||||
|
**Set reminders**
|
||||||
|
|
||||||
|
```
|
||||||
|
/loop 1h Remind me to review the deploy status
|
||||||
|
```
|
||||||
|
|
||||||
|
## Manual setup
|
||||||
|
|
||||||
Claude Code connects to Ollama using the Anthropic-compatible API.
|
Claude Code connects to Ollama using the Anthropic-compatible API.
|
||||||
|
|
||||||
@@ -53,23 +97,14 @@ export ANTHROPIC_BASE_URL=http://localhost:11434
|
|||||||
2. Run Claude Code with an Ollama model:
|
2. Run Claude Code with an Ollama model:
|
||||||
|
|
||||||
```shell
|
```shell
|
||||||
claude --model gpt-oss:20b
|
claude --model qwen3.5
|
||||||
```
|
```
|
||||||
|
|
||||||
Or run with environment variables inline:
|
Or run with environment variables inline:
|
||||||
|
|
||||||
```shell
|
```shell
|
||||||
ANTHROPIC_AUTH_TOKEN=ollama ANTHROPIC_BASE_URL=http://localhost:11434 ANTHROPIC_API_KEY="" claude --model qwen3-coder
|
ANTHROPIC_AUTH_TOKEN=ollama ANTHROPIC_BASE_URL=http://localhost:11434 ANTHROPIC_API_KEY="" claude --model glm-5:cloud
|
||||||
```
|
```
|
||||||
|
|
||||||
**Note:** Claude Code requires a large context window. We recommend at least 64k tokens. See the [context length documentation](/context-length) for how to adjust context length in Ollama.
|
**Note:** Claude Code requires a large context window. We recommend at least 64k tokens. See the [context length documentation](/context-length) for how to adjust context length in Ollama.
|
||||||
|
|
||||||
## Recommended Models
|
|
||||||
|
|
||||||
- `qwen3-coder`
|
|
||||||
- `glm-4.7`
|
|
||||||
- `gpt-oss:20b`
|
|
||||||
- `gpt-oss:120b`
|
|
||||||
|
|
||||||
Cloud models are also available at [ollama.com/search?c=cloud](https://ollama.com/search?c=cloud).
|
|
||||||
|
|
||||||
|
|||||||
@@ -4,47 +4,65 @@ title: OpenClaw
|
|||||||
|
|
||||||
OpenClaw is a personal AI assistant that runs on your own devices. It bridges messaging services (WhatsApp, Telegram, Slack, Discord, iMessage, and more) to AI coding agents through a centralized gateway.
|
OpenClaw is a personal AI assistant that runs on your own devices. It bridges messaging services (WhatsApp, Telegram, Slack, Discord, iMessage, and more) to AI coding agents through a centralized gateway.
|
||||||
|
|
||||||
## Install
|
## Quick start
|
||||||
|
|
||||||
Install [OpenClaw](https://openclaw.ai/)
|
|
||||||
|
|
||||||
```bash
|
|
||||||
npm install -g openclaw@latest
|
|
||||||
```
|
|
||||||
|
|
||||||
Then run the onboarding wizard:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
openclaw onboard --install-daemon
|
|
||||||
```
|
|
||||||
|
|
||||||
<Note>OpenClaw requires a larger context window. It is recommended to use a context window of at least 64k tokens. See [Context length](/context-length) for more information.</Note>
|
|
||||||
|
|
||||||
## Usage with Ollama
|
|
||||||
|
|
||||||
### Quick setup
|
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
ollama launch openclaw
|
ollama launch openclaw
|
||||||
```
|
```
|
||||||
|
|
||||||
|
Ollama handles everything automatically:
|
||||||
|
|
||||||
|
1. **Install** — If OpenClaw isn't installed, Ollama prompts to install it via npm
|
||||||
|
2. **Security** — On the first launch, a security notice explains the risks of tool access
|
||||||
|
3. **Model** — Pick a model from the selector (local or cloud)
|
||||||
|
4. **Onboarding** — Ollama configures the provider, installs the gateway daemon, and sets your model as the primary
|
||||||
|
5. **Gateway** — Starts in the background and opens the OpenClaw TUI
|
||||||
|
|
||||||
|
<Note>OpenClaw requires a larger context window. It is recommended to use a context window of at least 64k tokens if using local models. See [Context length](/context-length) for more information.</Note>
|
||||||
|
|
||||||
<Note>Previously known as Clawdbot. `ollama launch clawdbot` still works as an alias.</Note>
|
<Note>Previously known as Clawdbot. `ollama launch clawdbot` still works as an alias.</Note>
|
||||||
|
|
||||||
This configures OpenClaw to use Ollama and starts the gateway.
|
## Configure without launching
|
||||||
If the gateway is already running, no changes need to be made as the gateway will auto-reload the changes.
|
|
||||||
|
|
||||||
|
To change the model without starting the gateway and TUI:
|
||||||
|
|
||||||
To configure without launching:
|
```bash
|
||||||
|
|
||||||
```shell
|
|
||||||
ollama launch openclaw --config
|
ollama launch openclaw --config
|
||||||
```
|
```
|
||||||
|
|
||||||
## Recommended Models
|
To use a specific model directly:
|
||||||
|
|
||||||
- `qwen3-coder`
|
```bash
|
||||||
- `glm-4.7`
|
ollama launch openclaw --model kimi-k2.5:cloud
|
||||||
- `gpt-oss:20b`
|
```
|
||||||
- `gpt-oss:120b`
|
|
||||||
|
If the gateway is already running, it restarts automatically to pick up the new model.
|
||||||
|
|
||||||
|
## Recommended models
|
||||||
|
|
||||||
|
**Cloud models**:
|
||||||
|
|
||||||
|
- `kimi-k2.5:cloud` — Multimodal reasoning with subagents
|
||||||
|
- `minimax-m2.5:cloud` — Fast, efficient coding and real-world productivity
|
||||||
|
- `glm-5:cloud` — Reasoning and code generation
|
||||||
|
|
||||||
|
**Local models:**
|
||||||
|
|
||||||
|
- `glm-4.7-flash` — Reasoning and code generation locally (~25 GB VRAM)
|
||||||
|
|
||||||
|
More models at [ollama.com/search](https://ollama.com/search?c=cloud).
|
||||||
|
|
||||||
|
## Connect messaging apps
|
||||||
|
|
||||||
|
```bash
|
||||||
|
openclaw configure --section channels
|
||||||
|
```
|
||||||
|
|
||||||
|
Link WhatsApp, Telegram, Slack, Discord, or iMessage to chat with your local models from anywhere.
|
||||||
|
|
||||||
|
## Stopping the gateway
|
||||||
|
|
||||||
|
```bash
|
||||||
|
openclaw gateway stop
|
||||||
|
```
|
||||||
|
|
||||||
Cloud models are also available at [ollama.com/search?c=cloud](https://ollama.com/search?c=cloud).
|
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user