From 10e51c51771ea8536715876ee6707928712be41e Mon Sep 17 00:00:00 2001 From: Daniel Hiltgen Date: Mon, 9 Mar 2026 17:24:45 -0700 Subject: [PATCH] MLX: add header vendoring and remove go build tag (#14642) * prefer rocm v6 on windows Avoid building with v7 - more changes are needed * MLX: add header vendoring and remove go build tag This switches to using a vendoring approach for the mlx-c headers so that Go can build without requiring a cmake first. This enables building the new MLX based code by default. Every time cmake runs, the headers are refreshed, so we can easily keep them in sync when we bump mlx versions. Basic Windows and Linux support are verified. * ci: harden for flaky choco repo servers CI sometimes fails due to choco not actually installing cache. Since it just speeds up the build, we can proceed without. * review comments --- .github/workflows/release.yaml | 52 +- .github/workflows/test.yaml | 64 +- CMakeLists.txt | 106 +- CMakePresets.json | 1 + Dockerfile | 23 +- MLX_CORE_VERSION | 1 + docs/development.md | 68 + parser/parser.go | 3 + scripts/build_darwin.sh | 6 +- scripts/build_windows.ps1 | 205 ++- scripts/env.sh | 8 + x/create/client/quantize.go | 7 +- x/create/client/quantize_stub.go | 25 - x/imagegen/cache/cache.go | 2 - x/imagegen/cache/step.go | 2 - x/imagegen/cache/teacache.go | 2 - x/imagegen/cmd/engine/README.md | 2 +- x/imagegen/cmd/engine/generate.go | 2 - x/imagegen/cmd/engine/image.go | 2 - x/imagegen/cmd/engine/main.go | 2 - x/imagegen/cmd/engine/sample.go | 2 - x/imagegen/image.go | 2 - x/imagegen/image_processor.go | 2 - x/imagegen/imagegen.go | 2 - x/imagegen/manifest/weights.go | 2 - x/imagegen/mlx/CMakeLists.txt | 68 +- x/imagegen/mlx/compile.go | 2 - x/imagegen/mlx/doc.go | 4 +- x/imagegen/mlx/generate_wrappers.go | 15 +- x/imagegen/mlx/mlx.c | 1182 ++++++++-------- x/imagegen/mlx/mlx.go | 125 +- x/imagegen/mlx/mlx_dynamic.c | 139 +- x/imagegen/mlx/mlx_dynamic.h | 10 +- x/imagegen/mlx/mlx_test.go | 2 - x/imagegen/models/flux2/flux2.go | 2 - x/imagegen/models/flux2/rope.go | 2 - x/imagegen/models/flux2/scheduler.go | 2 - x/imagegen/models/flux2/transformer.go | 2 - x/imagegen/models/flux2/vae.go | 2 - x/imagegen/models/qwen3/text_encoder.go | 2 - x/imagegen/models/zimage/scheduler.go | 2 - x/imagegen/models/zimage/text_encoder.go | 2 - x/imagegen/models/zimage/transformer.go | 2 - x/imagegen/models/zimage/vae.go | 2 - x/imagegen/models/zimage/zimage.go | 2 - x/imagegen/nn/nn.go | 2 - x/imagegen/nn/nn_test.go | 2 - x/imagegen/runner.go | 2 - x/imagegen/runner_stub.go | 10 - x/imagegen/safetensors/loader.go | 2 - x/imagegen/safetensors/safetensors.go | 2 - x/imagegen/safetensors/safetensors_test.go | 2 - x/imagegen/tokenizer/tokenizer.go | 2 - x/imagegen/tokenizer/tokenizer_test.go | 2 - x/imagegen/vae/tiling.go | 2 - x/mlxrunner/cache.go | 2 - x/mlxrunner/cache/cache.go | 2 - x/mlxrunner/cache/recurrent.go | 2 - x/mlxrunner/client.go | 27 +- x/mlxrunner/imports.go | 2 - x/mlxrunner/mlx/CMakeLists.txt | 4 + x/mlxrunner/mlx/act.go | 2 - x/mlxrunner/mlx/array.go | 2 - x/mlxrunner/mlx/array_test.go | 11 +- x/mlxrunner/mlx/dtype.go | 2 - x/mlxrunner/mlx/dynamic.go | 40 +- x/mlxrunner/mlx/dynamic.h | 14 +- x/mlxrunner/mlx/fast.go | 2 - x/mlxrunner/mlx/gated_delta.go | 2 - x/mlxrunner/mlx/generated.c | 8 +- x/mlxrunner/mlx/generator/generated.c.gotmpl | 2 +- x/mlxrunner/mlx/generator/main.go | 14 +- x/mlxrunner/mlx/include/mlx/c/README.md | 12 + x/mlxrunner/mlx/include/mlx/c/array.h | 420 ++++++ x/mlxrunner/mlx/include/mlx/c/closure.h | 197 +++ x/mlxrunner/mlx/include/mlx/c/compile.h | 57 + x/mlxrunner/mlx/include/mlx/c/cuda.h | 39 + x/mlxrunner/mlx/include/mlx/c/device.h | 154 ++ x/mlxrunner/mlx/include/mlx/c/distributed.h | 83 ++ .../mlx/include/mlx/c/distributed_group.h | 58 + x/mlxrunner/mlx/include/mlx/c/error.h | 41 + x/mlxrunner/mlx/include/mlx/c/export.h | 75 + x/mlxrunner/mlx/include/mlx/c/fast.h | 206 +++ x/mlxrunner/mlx/include/mlx/c/fft.h | 138 ++ x/mlxrunner/mlx/include/mlx/c/half.h | 26 + x/mlxrunner/mlx/include/mlx/c/io.h | 63 + x/mlxrunner/mlx/include/mlx/c/io_types.h | 104 ++ x/mlxrunner/mlx/include/mlx/c/linalg.h | 128 ++ x/mlxrunner/mlx/include/mlx/c/map.h | 149 ++ x/mlxrunner/mlx/include/mlx/c/memory.h | 47 + x/mlxrunner/mlx/include/mlx/c/metal.h | 41 + x/mlxrunner/mlx/include/mlx/c/mlx.h | 34 + x/mlxrunner/mlx/include/mlx/c/ops.h | 1235 +++++++++++++++++ x/mlxrunner/mlx/include/mlx/c/optional.h | 51 + x/mlxrunner/mlx/include/mlx/c/random.h | 166 +++ x/mlxrunner/mlx/include/mlx/c/stream.h | 88 ++ x/mlxrunner/mlx/include/mlx/c/string.h | 55 + x/mlxrunner/mlx/include/mlx/c/transforms.h | 68 + .../mlx/include/mlx/c/transforms_impl.h | 54 + x/mlxrunner/mlx/include/mlx/c/vector.h | 133 ++ x/mlxrunner/mlx/include/mlx/c/version.h | 18 + x/mlxrunner/mlx/io.go | 16 +- x/mlxrunner/mlx/memory.go | 2 - x/mlxrunner/mlx/mlx.go | 19 +- x/mlxrunner/mlx/nn.go | 2 - x/mlxrunner/mlx/ops.go | 2 - x/mlxrunner/mlx/ops_extra.go | 2 - x/mlxrunner/mlx/random.go | 2 - x/mlxrunner/mlx/slice.go | 2 - x/mlxrunner/mlx/stream.go | 18 +- x/mlxrunner/model/base/base.go | 2 - x/mlxrunner/model/base/base_stub.go | 3 - x/mlxrunner/model/linear.go | 2 - x/mlxrunner/model/quant.go | 2 - x/mlxrunner/model/root.go | 2 - x/mlxrunner/model/root_stub.go | 3 - x/mlxrunner/pipeline.go | 2 - x/mlxrunner/runner.go | 2 - x/mlxrunner/sample/sample.go | 2 - x/mlxrunner/server.go | 9 +- x/mlxrunner/server_stub.go | 10 - x/models/gemma3/gemma3.go | 2 - x/models/glm4_moe_lite/glm4_moe_lite.go | 2 - x/models/glm4_moe_lite/parser.go | 2 - x/models/glm4_moe_lite/parser_test.go | 2 - x/models/glm4_moe_lite/render.go | 2 - x/models/glm4_moe_lite/render_test.go | 2 - x/models/llama/llama.go | 2 - x/models/nn/nn.go | 2 - x/models/qwen3/qwen3.go | 2 - x/models/qwen3_5/qwen3_5.go | 2 - x/models/qwen3_5/qwen3_5_test.go | 2 - x/models/qwen3_5_moe/qwen3_5_moe.go | 2 - x/tokenizer/tokenizer.go | 2 - x/tokenizer/tokenizer_benchmark_test.go | 2 - x/tokenizer/tokenizer_bpe.go | 2 - x/tokenizer/tokenizer_correctness_test.go | 2 - x/tokenizer/tokenizer_decode.go | 2 - x/tokenizer/tokenizer_encode.go | 2 - x/tokenizer/tokenizer_ggml_parity_test.go | 2 - x/tokenizer/tokenizer_load.go | 2 - x/tokenizer/tokenizer_load_test.go | 2 - 142 files changed, 5350 insertions(+), 1064 deletions(-) create mode 100644 MLX_CORE_VERSION delete mode 100644 x/create/client/quantize_stub.go delete mode 100644 x/imagegen/runner_stub.go create mode 100644 x/mlxrunner/mlx/include/mlx/c/README.md create mode 100644 x/mlxrunner/mlx/include/mlx/c/array.h create mode 100644 x/mlxrunner/mlx/include/mlx/c/closure.h create mode 100644 x/mlxrunner/mlx/include/mlx/c/compile.h create mode 100644 x/mlxrunner/mlx/include/mlx/c/cuda.h create mode 100644 x/mlxrunner/mlx/include/mlx/c/device.h create mode 100644 x/mlxrunner/mlx/include/mlx/c/distributed.h create mode 100644 x/mlxrunner/mlx/include/mlx/c/distributed_group.h create mode 100644 x/mlxrunner/mlx/include/mlx/c/error.h create mode 100644 x/mlxrunner/mlx/include/mlx/c/export.h create mode 100644 x/mlxrunner/mlx/include/mlx/c/fast.h create mode 100644 x/mlxrunner/mlx/include/mlx/c/fft.h create mode 100644 x/mlxrunner/mlx/include/mlx/c/half.h create mode 100644 x/mlxrunner/mlx/include/mlx/c/io.h create mode 100644 x/mlxrunner/mlx/include/mlx/c/io_types.h create mode 100644 x/mlxrunner/mlx/include/mlx/c/linalg.h create mode 100644 x/mlxrunner/mlx/include/mlx/c/map.h create mode 100644 x/mlxrunner/mlx/include/mlx/c/memory.h create mode 100644 x/mlxrunner/mlx/include/mlx/c/metal.h create mode 100644 x/mlxrunner/mlx/include/mlx/c/mlx.h create mode 100644 x/mlxrunner/mlx/include/mlx/c/ops.h create mode 100644 x/mlxrunner/mlx/include/mlx/c/optional.h create mode 100644 x/mlxrunner/mlx/include/mlx/c/random.h create mode 100644 x/mlxrunner/mlx/include/mlx/c/stream.h create mode 100644 x/mlxrunner/mlx/include/mlx/c/string.h create mode 100644 x/mlxrunner/mlx/include/mlx/c/transforms.h create mode 100644 x/mlxrunner/mlx/include/mlx/c/transforms_impl.h create mode 100644 x/mlxrunner/mlx/include/mlx/c/vector.h create mode 100644 x/mlxrunner/mlx/include/mlx/c/version.h delete mode 100644 x/mlxrunner/model/base/base_stub.go delete mode 100644 x/mlxrunner/model/root_stub.go delete mode 100644 x/mlxrunner/server_stub.go diff --git a/.github/workflows/release.yaml b/.github/workflows/release.yaml index 6d9596807..a51ca1b9c 100644 --- a/.github/workflows/release.yaml +++ b/.github/workflows/release.yaml @@ -117,6 +117,25 @@ jobs: install: https://sdk.lunarg.com/sdk/download/1.4.321.1/windows/vulkansdk-windows-X64-1.4.321.1.exe flags: '' runner_dir: 'vulkan' + - os: windows + arch: amd64 + preset: 'MLX CUDA 13' + install: https://developer.download.nvidia.com/compute/cuda/13.0.0/local_installers/cuda_13.0.0_windows.exe + cudnn-install: https://developer.download.nvidia.com/compute/cudnn/redist/cudnn/windows-x86_64/cudnn-windows-x86_64-9.18.1.3_cuda13-archive.zip + cuda-components: + - '"cudart"' + - '"nvcc"' + - '"cublas"' + - '"cublas_dev"' + - '"cufft"' + - '"cufft_dev"' + - '"nvrtc"' + - '"nvrtc_dev"' + - '"crt"' + - '"nvvm"' + - '"nvptxcompiler"' + cuda-version: '13.0' + flags: '' runs-on: ${{ matrix.arch == 'arm64' && format('{0}-{1}', matrix.os, matrix.arch) || matrix.os }} environment: release env: @@ -125,8 +144,10 @@ jobs: - name: Install system dependencies run: | choco install -y --no-progress ccache ninja - ccache -o cache_dir=${{ github.workspace }}\.ccache - - if: startsWith(matrix.preset, 'CUDA ') || startsWith(matrix.preset, 'ROCm ') || startsWith(matrix.preset, 'Vulkan') + if (Get-Command ccache -ErrorAction SilentlyContinue) { + ccache -o cache_dir=${{ github.workspace }}\.ccache + } + - if: startsWith(matrix.preset, 'CUDA ') || startsWith(matrix.preset, 'ROCm ') || startsWith(matrix.preset, 'Vulkan') || startsWith(matrix.preset, 'MLX ') id: cache-install uses: actions/cache/restore@v4 with: @@ -134,8 +155,9 @@ jobs: C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA C:\Program Files\AMD\ROCm C:\VulkanSDK - key: ${{ matrix.install }} - - if: startsWith(matrix.preset, 'CUDA ') + C:\Program Files\NVIDIA\CUDNN + key: ${{ matrix.install }}-${{ matrix.cudnn-install }} + - if: startsWith(matrix.preset, 'CUDA ') || startsWith(matrix.preset, 'MLX ') name: Install CUDA ${{ matrix.cuda-version }} run: | $ErrorActionPreference = "Stop" @@ -179,6 +201,23 @@ jobs: run: | echo "CC=clang.exe" | Out-File -FilePath $env:GITHUB_ENV -Append echo "CXX=clang++.exe" | Out-File -FilePath $env:GITHUB_ENV -Append + - if: startsWith(matrix.preset, 'MLX ') + name: Install cuDNN for MLX + run: | + $ErrorActionPreference = "Stop" + $cudnnRoot = "C:\Program Files\NVIDIA\CUDNN" + if ("${{ steps.cache-install.outputs.cache-hit }}" -ne 'true') { + Invoke-WebRequest -Uri "${{ matrix.cudnn-install }}" -OutFile "cudnn.zip" + Expand-Archive -Path cudnn.zip -DestinationPath cudnn-extracted + $cudnnDir = (Get-ChildItem -Path cudnn-extracted -Directory)[0].FullName + New-Item -ItemType Directory -Force -Path $cudnnRoot + Copy-Item -Path "$cudnnDir\*" -Destination "$cudnnRoot\" -Recurse + } + + echo "CUDNN_ROOT_DIR=$cudnnRoot" | Out-File -FilePath $env:GITHUB_ENV -Append + echo "CUDNN_INCLUDE_PATH=$cudnnRoot\include" | Out-File -FilePath $env:GITHUB_ENV -Append + echo "CUDNN_LIBRARY_PATH=$cudnnRoot\lib\x64" | Out-File -FilePath $env:GITHUB_ENV -Append + echo "$cudnnRoot\bin\x64" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append - if: ${{ !cancelled() && steps.cache-install.outputs.cache-hit != 'true' }} uses: actions/cache/save@v4 with: @@ -186,7 +225,8 @@ jobs: C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA C:\Program Files\AMD\ROCm C:\VulkanSDK - key: ${{ matrix.install }} + C:\Program Files\NVIDIA\CUDNN + key: ${{ matrix.install }}-${{ matrix.cudnn-install }} - uses: actions/checkout@v4 - uses: actions/cache@v4 with: @@ -198,7 +238,7 @@ jobs: Enter-VsDevShell -VsInstallPath 'C:\Program Files\Microsoft Visual Studio\2022\Enterprise' -SkipAutomaticLocation -DevCmdArguments '-arch=x64 -no_logo' cmake --preset "${{ matrix.preset }}" ${{ matrix.flags }} --install-prefix "$((pwd).Path)\dist\${{ matrix.os }}-${{ matrix.arch }}" cmake --build --parallel ([Environment]::ProcessorCount) --preset "${{ matrix.preset }}" - cmake --install build --component "${{ startsWith(matrix.preset, 'CUDA ') && 'CUDA' || startsWith(matrix.preset, 'ROCm ') && 'HIP' || startsWith(matrix.preset, 'Vulkan') && 'Vulkan' || 'CPU' }}" --strip + cmake --install build --component "${{ startsWith(matrix.preset, 'MLX ') && 'MLX' || startsWith(matrix.preset, 'CUDA ') && 'CUDA' || startsWith(matrix.preset, 'ROCm ') && 'HIP' || startsWith(matrix.preset, 'Vulkan') && 'Vulkan' || 'CPU' }}" --strip Remove-Item -Path dist\lib\ollama\rocm\rocblas\library\*gfx906* -ErrorAction SilentlyContinue env: CMAKE_GENERATOR: Ninja diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index a47156516..cf0545b56 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -37,7 +37,7 @@ jobs: | xargs python3 -c "import sys; from pathlib import Path; print(any(Path(x).match(glob) for x in sys.argv[1:] for glob in '$*'.split(' ')))" } - echo changed=$(changed 'llama/llama.cpp/**/*' 'ml/backend/ggml/ggml/**/*') | tee -a $GITHUB_OUTPUT + echo changed=$(changed 'llama/llama.cpp/**/*' 'ml/backend/ggml/ggml/**/*' '.github/**/*') | tee -a $GITHUB_OUTPUT echo vendorsha=$(make -f Makefile.sync print-base) | tee -a $GITHUB_OUTPUT linux: @@ -60,6 +60,10 @@ jobs: mesa-vulkan-drivers vulkan-tools libvulkan1 libvulkan-dev vulkan-sdk cmake ccache g++ make + - preset: 'MLX CUDA 13' + container: nvidia/cuda:13.0.0-devel-ubuntu22.04 + extra-packages: libcudnn9-dev-cuda-13 libopenblas-dev liblapack-dev liblapacke-dev git curl + flags: '-DCMAKE_CUDA_ARCHITECTURES=87 -DBLAS_INCLUDE_DIRS=/usr/include/x86_64-linux-gnu -DLAPACK_INCLUDE_DIRS=/usr/include/x86_64-linux-gnu' runs-on: linux container: ${{ matrix.container }} steps: @@ -76,6 +80,10 @@ jobs: $sudo apt-get update fi $sudo apt-get install -y cmake ccache ${{ matrix.extra-packages }} + # MLX requires CMake 3.25+, install from official releases + if [ "${{ matrix.preset }}" = "MLX CUDA 13" ]; then + curl -fsSL https://github.com/Kitware/CMake/releases/download/v3.31.2/cmake-3.31.2-linux-$(uname -m).tar.gz | $sudo tar xz -C /usr/local --strip-components 1 + fi # Export VULKAN_SDK if provided by LunarG package (defensive) if [ -d "/usr/lib/x86_64-linux-gnu/vulkan" ] && [ "${{ matrix.preset }}" = "Vulkan" ]; then echo "VULKAN_SDK=/usr" >> $GITHUB_ENV @@ -87,8 +95,8 @@ jobs: path: /github/home/.cache/ccache key: ccache-${{ runner.os }}-${{ runner.arch }}-${{ matrix.preset }}-${{ needs.changes.outputs.vendorsha }} - run: | - cmake --preset ${{ matrix.preset }} ${{ matrix.flags }} - cmake --build --preset ${{ matrix.preset }} --parallel + cmake --preset "${{ matrix.preset }}" ${{ matrix.flags }} + cmake --build --preset "${{ matrix.preset }}" --parallel windows: needs: [changes] @@ -114,12 +122,31 @@ jobs: flags: '-DAMDGPU_TARGETS=gfx1010 -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -DCMAKE_C_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma" -DCMAKE_CXX_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma"' - preset: Vulkan install: https://sdk.lunarg.com/sdk/download/1.4.321.1/windows/vulkansdk-windows-X64-1.4.321.1.exe + - preset: 'MLX CUDA 13' + install: https://developer.download.nvidia.com/compute/cuda/13.0.0/local_installers/cuda_13.0.0_windows.exe + cudnn-install: https://developer.download.nvidia.com/compute/cudnn/redist/cudnn/windows-x86_64/cudnn-windows-x86_64-9.18.1.3_cuda13-archive.zip + flags: '-DCMAKE_CUDA_ARCHITECTURES=80' + cuda-components: + - '"cudart"' + - '"nvcc"' + - '"cublas"' + - '"cublas_dev"' + - '"cufft"' + - '"cufft_dev"' + - '"nvrtc"' + - '"nvrtc_dev"' + - '"crt"' + - '"nvvm"' + - '"nvptxcompiler"' + cuda-version: '13.0' runs-on: windows steps: - run: | choco install -y --no-progress ccache ninja - ccache -o cache_dir=${{ github.workspace }}\.ccache - - if: matrix.preset == 'CUDA' || matrix.preset == 'ROCm' || matrix.preset == 'Vulkan' + if (Get-Command ccache -ErrorAction SilentlyContinue) { + ccache -o cache_dir=${{ github.workspace }}\.ccache + } + - if: matrix.preset == 'CUDA' || matrix.preset == 'ROCm' || matrix.preset == 'Vulkan' || matrix.preset == 'MLX CUDA 13' id: cache-install uses: actions/cache/restore@v4 with: @@ -127,8 +154,9 @@ jobs: C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA C:\Program Files\AMD\ROCm C:\VulkanSDK - key: ${{ matrix.install }} - - if: matrix.preset == 'CUDA' + C:\Program Files\NVIDIA\CUDNN + key: ${{ matrix.install }}-${{ matrix.cudnn-install }} + - if: matrix.preset == 'CUDA' || matrix.preset == 'MLX CUDA 13' name: Install CUDA ${{ matrix.cuda-version }} run: | $ErrorActionPreference = "Stop" @@ -164,10 +192,27 @@ jobs: Invoke-WebRequest -Uri "${{ matrix.install }}" -OutFile "install.exe" Start-Process -FilePath .\install.exe -ArgumentList "-c","--am","--al","in" -NoNewWindow -Wait } - + $vulkanPath = (Resolve-Path "C:\VulkanSDK\*").path echo "$vulkanPath\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append echo "VULKAN_SDK=$vulkanPath" >> $env:GITHUB_ENV + - if: matrix.preset == 'MLX CUDA 13' + name: Install cuDNN for MLX + run: | + $ErrorActionPreference = "Stop" + $cudnnRoot = "C:\Program Files\NVIDIA\CUDNN" + if ("${{ steps.cache-install.outputs.cache-hit }}" -ne 'true') { + Invoke-WebRequest -Uri "${{ matrix.cudnn-install }}" -OutFile "cudnn.zip" + Expand-Archive -Path cudnn.zip -DestinationPath cudnn-extracted + $cudnnDir = (Get-ChildItem -Path cudnn-extracted -Directory)[0].FullName + New-Item -ItemType Directory -Force -Path $cudnnRoot + Copy-Item -Path "$cudnnDir\*" -Destination "$cudnnRoot\" -Recurse + } + + echo "CUDNN_ROOT_DIR=$cudnnRoot" | Out-File -FilePath $env:GITHUB_ENV -Append + echo "CUDNN_INCLUDE_PATH=$cudnnRoot\include" | Out-File -FilePath $env:GITHUB_ENV -Append + echo "CUDNN_LIBRARY_PATH=$cudnnRoot\lib\x64" | Out-File -FilePath $env:GITHUB_ENV -Append + echo "$cudnnRoot\bin\x64" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append - if: ${{ !cancelled() && steps.cache-install.outputs.cache-hit != 'true' }} uses: actions/cache/save@v4 with: @@ -175,7 +220,8 @@ jobs: C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA C:\Program Files\AMD\ROCm C:\VulkanSDK - key: ${{ matrix.install }} + C:\Program Files\NVIDIA\CUDNN + key: ${{ matrix.install }}-${{ matrix.cudnn-install }} - uses: actions/checkout@v4 - uses: actions/cache@v4 with: diff --git a/CMakeLists.txt b/CMakeLists.txt index a9e4471d8..8c37d3374 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -64,10 +64,15 @@ set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${OLLAMA_BUILD_DIR}) set(CMAKE_LIBRARY_OUTPUT_DIRECTORY_DEBUG ${OLLAMA_BUILD_DIR}) set(CMAKE_LIBRARY_OUTPUT_DIRECTORY_RELEASE ${OLLAMA_BUILD_DIR}) -include_directories(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src) -include_directories(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/include) -include_directories(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-cpu) -include_directories(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-cpu/amx) +# Store ggml include paths for use with target_include_directories later. +# We avoid global include_directories() to prevent polluting the include path +# for other projects like MLX (whose openblas dependency has its own common.h). +set(GGML_INCLUDE_DIRS + ${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src + ${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/include + ${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-cpu + ${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-cpu/amx +) add_compile_definitions(NDEBUG GGML_VERSION=0x0 GGML_COMMIT=0x0) @@ -87,6 +92,14 @@ if(NOT CPU_VARIANTS) set(CPU_VARIANTS "ggml-cpu") endif() +# Apply ggml include directories to ggml targets only (not globally) +target_include_directories(ggml-base PRIVATE ${GGML_INCLUDE_DIRS}) +foreach(variant ${CPU_VARIANTS}) + if(TARGET ${variant}) + target_include_directories(${variant} PRIVATE ${GGML_INCLUDE_DIRS}) + endif() +endforeach() + install(TARGETS ggml-base ${CPU_VARIANTS} RUNTIME_DEPENDENCIES PRE_EXCLUDE_REGEXES ".*" @@ -103,6 +116,7 @@ if(CMAKE_CUDA_COMPILER) find_package(CUDAToolkit) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-cuda) + target_include_directories(ggml-cuda PRIVATE ${GGML_INCLUDE_DIRS}) install(TARGETS ggml-cuda RUNTIME_DEPENDENCIES DIRECTORIES ${CUDAToolkit_BIN_DIR} ${CUDAToolkit_BIN_DIR}/x64 ${CUDAToolkit_LIBRARY_DIR} @@ -134,6 +148,7 @@ if(CMAKE_HIP_COMPILER) if(AMDGPU_TARGETS) find_package(hip REQUIRED) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-hip) + target_include_directories(ggml-hip PRIVATE ${GGML_INCLUDE_DIRS}) if (WIN32) target_compile_definitions(ggml-hip PRIVATE GGML_CUDA_NO_PEER_COPY) @@ -168,6 +183,7 @@ if(NOT APPLE) find_package(Vulkan) if(Vulkan_FOUND) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-vulkan) + target_include_directories(ggml-vulkan PRIVATE ${GGML_INCLUDE_DIRS}) install(TARGETS ggml-vulkan RUNTIME_DEPENDENCIES PRE_INCLUDE_REGEXES vulkan @@ -179,7 +195,6 @@ if(NOT APPLE) endif() option(MLX_ENGINE "Enable MLX backend" OFF) - if(MLX_ENGINE) message(STATUS "Setting up MLX (this takes a while...)") add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/x/imagegen/mlx) @@ -187,10 +202,36 @@ if(MLX_ENGINE) # Find CUDA toolkit if MLX is built with CUDA support find_package(CUDAToolkit) + # Build list of directories for runtime dependency resolution + set(MLX_RUNTIME_DIRS ${CUDAToolkit_BIN_DIR} ${CUDAToolkit_BIN_DIR}/x64 ${CUDAToolkit_LIBRARY_DIR}) + # Add cuDNN bin paths for DLLs (Windows MLX CUDA builds) + # CUDNN_ROOT_DIR is the standard CMake variable for cuDNN location + if(DEFINED ENV{CUDNN_ROOT_DIR}) + # cuDNN 9.x has versioned subdirectories under bin/ (e.g., bin/13.0/) + file(GLOB CUDNN_BIN_SUBDIRS "$ENV{CUDNN_ROOT_DIR}/bin/*") + list(APPEND MLX_RUNTIME_DIRS ${CUDNN_BIN_SUBDIRS}) + endif() + # Add build output directory and MLX dependency build directories + list(APPEND MLX_RUNTIME_DIRS ${OLLAMA_BUILD_DIR}) + # OpenBLAS DLL location (pre-built zip extracts into openblas-src/bin/) + list(APPEND MLX_RUNTIME_DIRS ${CMAKE_BINARY_DIR}/_deps/openblas-src/bin) + # NCCL: on Linux, if real NCCL is found, cmake bundles libnccl.so via the + # regex below. If NCCL is not found, MLX links a static stub (OBJECT lib) + # so there is no runtime dependency. This path covers the stub build dir + # for windows so we include the DLL in our dependencies. + list(APPEND MLX_RUNTIME_DIRS ${CMAKE_BINARY_DIR}/_deps/mlx-build/mlx/distributed/nccl/nccl_stub-prefix/src/nccl_stub-build/Release) + + # Base regexes for runtime dependencies (cross-platform) + set(MLX_INCLUDE_REGEXES cublas cublasLt cudart cufft nvrtc nvrtc-builtins cudnn nccl openblas gfortran) + # On Windows, also include dl.dll (dlfcn-win32 POSIX emulation layer) + if(WIN32) + list(APPEND MLX_INCLUDE_REGEXES "^dl\\.dll$") + endif() + install(TARGETS mlx mlxc RUNTIME_DEPENDENCIES - DIRECTORIES ${CUDAToolkit_BIN_DIR} ${CUDAToolkit_BIN_DIR}/x64 ${CUDAToolkit_LIBRARY_DIR} - PRE_INCLUDE_REGEXES cublas cublasLt cudart nvrtc nvrtc-builtins cudnn nccl openblas gfortran + DIRECTORIES ${MLX_RUNTIME_DIRS} + PRE_INCLUDE_REGEXES ${MLX_INCLUDE_REGEXES} PRE_EXCLUDE_REGEXES ".*" RUNTIME DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT MLX LIBRARY DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT MLX @@ -205,13 +246,54 @@ if(MLX_ENGINE) COMPONENT MLX) endif() - # Manually install cudart and cublas since they might not be picked up as direct dependencies + # Install CCCL headers for NVRTC JIT compilation at runtime. + # MLX's own install rules use the default component so they get skipped by + # --component MLX. Headers are installed alongside libmlx in OLLAMA_INSTALL_DIR. + # On Linux, MLX's jit_module.cpp resolves CCCL via + # current_binary_dir().parent_path() / "include" / "cccl", so we create a + # symlink from lib/ollama/include -> ${OLLAMA_RUNNER_DIR}/include + # This will need refinement if we add multiple CUDA versions for MLX in the future. + if(EXISTS ${CMAKE_BINARY_DIR}/_deps/cccl-src/include/cuda) + install(DIRECTORY ${CMAKE_BINARY_DIR}/_deps/cccl-src/include/cuda + DESTINATION ${OLLAMA_INSTALL_DIR}/include/cccl + COMPONENT MLX) + install(DIRECTORY ${CMAKE_BINARY_DIR}/_deps/cccl-src/include/nv + DESTINATION ${OLLAMA_INSTALL_DIR}/include/cccl + COMPONENT MLX) + if(NOT WIN32 AND NOT APPLE) + install(CODE " + set(_link \"${CMAKE_INSTALL_PREFIX}/lib/ollama/include\") + set(_target \"${OLLAMA_RUNNER_DIR}/include\") + if(NOT EXISTS \${_link}) + execute_process(COMMAND \${CMAKE_COMMAND} -E create_symlink \${_target} \${_link}) + endif() + " COMPONENT MLX) + endif() + endif() + + # On Windows, explicitly install dl.dll (dlfcn-win32 POSIX dlopen emulation) + # RUNTIME_DEPENDENCIES auto-excludes it via POST_EXCLUDE_FILES_STRICT because + # dlfcn-win32 is a known CMake target with its own install rules (which install + # to the wrong destination). We must install it explicitly here. + if(WIN32) + install(FILES ${OLLAMA_BUILD_DIR}/dl.dll + DESTINATION ${OLLAMA_INSTALL_DIR} + COMPONENT MLX) + endif() + + # Manually install CUDA runtime libraries that MLX loads via dlopen + # (not detected by RUNTIME_DEPENDENCIES since they aren't link-time deps) if(CUDAToolkit_FOUND) - file(GLOB CUDART_LIBS + file(GLOB MLX_CUDA_LIBS "${CUDAToolkit_LIBRARY_DIR}/libcudart.so*" - "${CUDAToolkit_LIBRARY_DIR}/libcublas.so*") - if(CUDART_LIBS) - install(FILES ${CUDART_LIBS} + "${CUDAToolkit_LIBRARY_DIR}/libcublas.so*" + "${CUDAToolkit_LIBRARY_DIR}/libcublasLt.so*" + "${CUDAToolkit_LIBRARY_DIR}/libnvrtc.so*" + "${CUDAToolkit_LIBRARY_DIR}/libnvrtc-builtins.so*" + "${CUDAToolkit_LIBRARY_DIR}/libcufft.so*" + "${CUDAToolkit_LIBRARY_DIR}/libcudnn.so*") + if(MLX_CUDA_LIBS) + install(FILES ${MLX_CUDA_LIBS} DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT MLX) endif() diff --git a/CMakePresets.json b/CMakePresets.json index 0d643038a..d099d3f16 100644 --- a/CMakePresets.json +++ b/CMakePresets.json @@ -112,6 +112,7 @@ "name": "MLX CUDA 13", "inherits": [ "MLX", "CUDA 13" ], "cacheVariables": { + "MLX_CUDA_ARCHITECTURES": "86;89;90;90a;100;103;75-virtual;80-virtual;110-virtual;120-virtual;121-virtual", "OLLAMA_RUNNER_DIR": "mlx_cuda_v13" } } diff --git a/Dockerfile b/Dockerfile index cabb6cc82..0743cc45d 100644 --- a/Dockerfile +++ b/Dockerfile @@ -9,6 +9,11 @@ ARG CMAKEVERSION=3.31.2 ARG NINJAVERSION=1.12.1 ARG VULKANVERSION=1.4.321.1 +# Default empty stages for local MLX source overrides. +# Override with: docker build --build-context local-mlx=../mlx --build-context local-mlx-c=../mlx-c +FROM scratch AS local-mlx +FROM scratch AS local-mlx-c + FROM --platform=linux/amd64 rocm/dev-almalinux-8:${ROCMVERSION}-complete AS base-amd64 RUN dnf install -y yum-utils ccache gcc-toolset-11-gcc gcc-toolset-11-gcc-c++ gcc-toolset-11-binutils \ && yum-config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel8/x86_64/cuda-rhel8.repo @@ -152,12 +157,20 @@ COPY CMakeLists.txt CMakePresets.json . COPY ml/backend/ggml/ggml ml/backend/ggml/ggml COPY x/imagegen/mlx x/imagegen/mlx COPY go.mod go.sum . -COPY MLX_VERSION . +COPY MLX_VERSION MLX_CORE_VERSION . RUN curl -fsSL https://golang.org/dl/go$(awk '/^go/ { print $2 }' go.mod).linux-$(case $(uname -m) in x86_64) echo amd64 ;; aarch64) echo arm64 ;; esac).tar.gz | tar xz -C /usr/local ENV PATH=/usr/local/go/bin:$PATH RUN go mod download RUN --mount=type=cache,target=/root/.ccache \ - cmake --preset 'MLX CUDA 13' -DBLAS_INCLUDE_DIRS=/usr/include/openblas -DLAPACK_INCLUDE_DIRS=/usr/include/openblas \ + --mount=type=bind,from=local-mlx,target=/tmp/local-mlx \ + --mount=type=bind,from=local-mlx-c,target=/tmp/local-mlx-c \ + if [ -f /tmp/local-mlx/CMakeLists.txt ]; then \ + export OLLAMA_MLX_SOURCE=/tmp/local-mlx; \ + fi \ + && if [ -f /tmp/local-mlx-c/CMakeLists.txt ]; then \ + export OLLAMA_MLX_C_SOURCE=/tmp/local-mlx-c; \ + fi \ + && cmake --preset 'MLX CUDA 13' -DBLAS_INCLUDE_DIRS=/usr/include/openblas -DLAPACK_INCLUDE_DIRS=/usr/include/openblas \ && cmake --build --preset 'MLX CUDA 13' -- -l $(nproc) \ && cmake --install build --component MLX --strip @@ -168,16 +181,14 @@ RUN curl -fsSL https://golang.org/dl/go$(awk '/^go/ { print $2 }' go.mod).linux- ENV PATH=/usr/local/go/bin:$PATH RUN go mod download COPY . . -# Clone mlx-c headers for CGO (version from MLX_VERSION file) -RUN git clone --depth 1 --branch "$(cat MLX_VERSION)" https://github.com/ml-explore/mlx-c.git build/_deps/mlx-c-src ARG GOFLAGS="'-ldflags=-w -s'" ENV CGO_ENABLED=1 ARG CGO_CFLAGS ARG CGO_CXXFLAGS -ENV CGO_CFLAGS="${CGO_CFLAGS} -I/go/src/github.com/ollama/ollama/build/_deps/mlx-c-src" +ENV CGO_CFLAGS="${CGO_CFLAGS}" ENV CGO_CXXFLAGS="${CGO_CXXFLAGS}" RUN --mount=type=cache,target=/root/.cache/go-build \ - go build -tags mlx -trimpath -buildmode=pie -o /bin/ollama . + go build -trimpath -buildmode=pie -o /bin/ollama . FROM --platform=linux/amd64 scratch AS amd64 # COPY --from=cuda-11 dist/lib/ollama/ /lib/ollama/ diff --git a/MLX_CORE_VERSION b/MLX_CORE_VERSION new file mode 100644 index 000000000..912750052 --- /dev/null +++ b/MLX_CORE_VERSION @@ -0,0 +1 @@ +v0.30.6 diff --git a/docs/development.md b/docs/development.md index d0120a191..12c69c204 100644 --- a/docs/development.md +++ b/docs/development.md @@ -51,6 +51,9 @@ Install prerequisites: - [CUDA SDK](https://developer.nvidia.com/cuda-downloads?target_os=Windows&target_arch=x86_64&target_version=11&target_type=exe_network) - (Optional) VULKAN GPU support - [VULKAN SDK](https://vulkan.lunarg.com/sdk/home) - useful for AMD/Intel GPUs +- (Optional) MLX engine support + - [CUDA 13+ SDK](https://developer.nvidia.com/cuda-downloads) + - [cuDNN 9+](https://developer.nvidia.com/cudnn) Then, configure and build the project: @@ -101,6 +104,10 @@ Install prerequisites: - (Optional) VULKAN GPU support - [VULKAN SDK](https://vulkan.lunarg.com/sdk/home) - useful for AMD/Intel GPUs - Or install via package manager: `sudo apt install vulkan-sdk` (Ubuntu/Debian) or `sudo dnf install vulkan-sdk` (Fedora/CentOS) +- (Optional) MLX engine support + - [CUDA 13+ SDK](https://developer.nvidia.com/cuda-downloads) + - [cuDNN 9+](https://developer.nvidia.com/cudnn) + - OpenBLAS/LAPACK: `sudo apt install libopenblas-dev liblapack-dev liblapacke-dev` (Ubuntu/Debian) > [!IMPORTANT] > Ensure prerequisites are in `PATH` before running CMake. @@ -118,6 +125,67 @@ Lastly, run Ollama: go run . serve ``` +## MLX Engine (Optional) + +The MLX engine enables running safetensor based models. It requires building the [MLX](https://github.com/ml-explore/mlx) and [MLX-C](https://github.com/ml-explore/mlx-c) shared libraries separately via CMake. On MacOS, MLX leverages the Metal library to run on the GPU, and on Windows and Linux, runs on NVIDIA GPUs via CUDA v13. + +### macOS (Apple Silicon) + +Requires the Metal toolchain. Install [Xcode](https://developer.apple.com/xcode/) first, then: + +```shell +xcodebuild -downloadComponent MetalToolchain +``` + +Verify it's installed correctly (should print "no input files"): + +```shell +xcrun metal +``` + +Then build: + +```shell +cmake -B build --preset MLX +cmake --build build --preset MLX --parallel +cmake --install build --component MLX +``` + +> [!NOTE] +> Without the Metal toolchain, cmake will silently complete with Metal disabled. Check the cmake output for `Setting MLX_BUILD_METAL=OFF` which indicates the toolchain is missing. + +### Windows / Linux (CUDA) + +Requires CUDA 13+ and [cuDNN](https://developer.nvidia.com/cudnn) 9+. + +```shell +cmake -B build --preset "MLX CUDA 13" +cmake --build build --target mlx --target mlxc --config Release --parallel +cmake --install build --component MLX --strip +``` + +### Local MLX source overrides + +To build against a local checkout of MLX and/or MLX-C (useful for development), set environment variables before running CMake: + +```shell +export OLLAMA_MLX_SOURCE=/path/to/mlx +export OLLAMA_MLX_C_SOURCE=/path/to/mlx-c +``` + +For example, using the helper scripts with local mlx and mlx-c repos: +```shell +OLLAMA_MLX_SOURCE=../mlx OLLAMA_MLX_C_SOURCE=../mlx-c ./scripts/build_linux.sh + +OLLAMA_MLX_SOURCE=../mlx OLLAMA_MLX_C_SOURCE=../mlx-c ./scripts/build_darwin.sh +``` + +```powershell +$env:OLLAMA_MLX_SOURCE="../mlx" +$env:OLLAMA_MLX_C_SOURCE="../mlx-c" +./scripts/build_darwin.ps1 +``` + ## Docker ```shell diff --git a/parser/parser.go b/parser/parser.go index 5ef918bf2..f3b6dcb55 100644 --- a/parser/parser.go +++ b/parser/parser.go @@ -181,6 +181,9 @@ func fileDigestMap(path string) (map[string]string, error) { } if !filepath.IsLocal(rel) { + if strings.Contains(rel, ".cache") { + return nil, fmt.Errorf("insecure path: %s\n\nUse --local-dir when downloading model to disable caching", rel) + } return nil, fmt.Errorf("insecure path: %s", rel) } diff --git a/scripts/build_darwin.sh b/scripts/build_darwin.sh index 4325a9787..4bec54cf8 100755 --- a/scripts/build_darwin.sh +++ b/scripts/build_darwin.sh @@ -59,7 +59,7 @@ _build_darwin() { cmake --install $BUILD_DIR --component CPU cmake --install $BUILD_DIR --component MLX # Override CGO flags to point to the amd64 build directory - MLX_CGO_CFLAGS="-O3 -I$(pwd)/$BUILD_DIR/_deps/mlx-c-src -mmacosx-version-min=14.0" + MLX_CGO_CFLAGS="-O3 -mmacosx-version-min=14.0" MLX_CGO_LDFLAGS="-ldl -lc++ -framework Accelerate -mmacosx-version-min=14.0" else BUILD_DIR=build @@ -70,10 +70,10 @@ _build_darwin() { cmake --build --preset MLX --parallel cmake --install $BUILD_DIR --component MLX # Use default CGO flags from mlx.go for arm64 - MLX_CGO_CFLAGS="-O3 -I$(pwd)/$BUILD_DIR/_deps/mlx-c-src -mmacosx-version-min=14.0" + MLX_CGO_CFLAGS="-O3 -mmacosx-version-min=14.0" MLX_CGO_LDFLAGS="-lc++ -framework Metal -framework Foundation -framework Accelerate -mmacosx-version-min=14.0" fi - GOOS=darwin GOARCH=$ARCH CGO_ENABLED=1 CGO_CFLAGS="$MLX_CGO_CFLAGS" CGO_LDFLAGS="$MLX_CGO_LDFLAGS" go build -tags mlx -o $INSTALL_PREFIX . + GOOS=darwin GOARCH=$ARCH CGO_ENABLED=1 CGO_CFLAGS="$MLX_CGO_CFLAGS" CGO_LDFLAGS="$MLX_CGO_LDFLAGS" go build -o $INSTALL_PREFIX . # Copy MLX libraries to same directory as executable for dlopen cp $INSTALL_PREFIX/lib/ollama/libmlxc.dylib $INSTALL_PREFIX/ cp $INSTALL_PREFIX/lib/ollama/libmlx.dylib $INSTALL_PREFIX/ diff --git a/scripts/build_windows.ps1 b/scripts/build_windows.ps1 index 21e6f3be0..f162e4f7e 100644 --- a/scripts/build_windows.ps1 +++ b/scripts/build_windows.ps1 @@ -4,7 +4,10 @@ # # gcloud auth application-default login -$ErrorActionPreference = "Stop" +# Use "Continue" so that stderr output from native commands (e.g. CGo warnings) +# is not promoted to a terminating exception by the try/catch block. +# All native commands already check $LASTEXITCODE explicitly. +$ErrorActionPreference = "Continue" mkdir -Force -path .\dist | Out-Null @@ -16,13 +19,13 @@ function checkEnv { if ($null -ne $arch) { $script:ARCH = ($arch.ToString().ToLower()).Replace("x64", "amd64") } else { - write-host "WARNING: old powershell detected, assuming amd64 architecture - set `$env:ARCH to override" + Write-Output "WARNING: old powershell detected, assuming amd64 architecture - set `$env:ARCH to override" $script:ARCH="amd64" } } $script:TARGET_ARCH=$script:ARCH Write-host "Building for ${script:TARGET_ARCH}" - write-host "Locating required tools and paths" + Write-Output "Locating required tools and paths" $script:SRC_DIR=$PWD # Locate CUDA versions @@ -37,16 +40,17 @@ function checkEnv { $script:CUDA_DIRS=($cudaList | sort-object -Descending) } if ($script:CUDA_DIRS.length -gt 0) { - write-host "Available CUDA Versions: $script:CUDA_DIRS" + Write-Output "Available CUDA Versions: $script:CUDA_DIRS" } else { - write-host "No CUDA versions detected" + Write-Output "No CUDA versions detected" } - # Locate ROCm version - if ($null -ne $env:HIP_PATH) { + # Locate ROCm v6 + $rocmDir=(get-item "C:\Program Files\AMD\ROCm\6.*" -ea 'silentlycontinue' | sort-object -Descending | select-object -First 1) + if ($null -ne $rocmDir) { + $script:HIP_PATH=$rocmDir.FullName + } elseif ($null -ne $env:HIP_PATH -and $env:HIP_PATH -match '[/\\]6\.') { $script:HIP_PATH=$env:HIP_PATH - } else { - $script:HIP_PATH=(get-item "C:\Program Files\AMD\ROCm\*\bin\" -ea 'silentlycontinue' | sort-object -Descending) } $inoSetup=(get-item "C:\Program Files*\Inno Setup*\") @@ -78,7 +82,7 @@ function checkEnv { } else { $script:PKG_VERSION="0.0.0" } - write-host "Building Ollama $script:VERSION with package version $script:PKG_VERSION" + Write-Output "Building Ollama $script:VERSION with package version $script:PKG_VERSION" # Note: Windows Kits 10 signtool crashes with GCP's plugin if ($null -eq $env:SIGN_TOOL) { @@ -87,12 +91,32 @@ function checkEnv { ${script:SignTool}=${env:SIGN_TOOL} } if ("${env:KEY_CONTAINER}") { - ${script:OLLAMA_CERT}=$(resolve-path "${script:SRC_DIR}\ollama_inc.crt") - Write-host "Code signing enabled" + if (Test-Path "${script:SRC_DIR}\ollama_inc.crt") { + ${script:OLLAMA_CERT}=$(resolve-path "${script:SRC_DIR}\ollama_inc.crt") + Write-host "Code signing enabled" + } else { + Write-Output "WARNING: KEY_CONTAINER is set but ollama_inc.crt not found at ${script:SRC_DIR}\ollama_inc.crt - code signing disabled" + } } else { - write-host "Code signing disabled - please set KEY_CONTAINERS to sign and copy ollama_inc.crt to the top of the source tree" + Write-Output "Code signing disabled - please set KEY_CONTAINERS to sign and copy ollama_inc.crt to the top of the source tree" } - $script:JOBS=([Environment]::ProcessorCount) + if ($env:OLLAMA_BUILD_PARALLEL) { + $script:JOBS=[int]$env:OLLAMA_BUILD_PARALLEL + } else { + # Use physical core count rather than logical processors (hyperthreads) + # to avoid saturating the system during builds + try { + $cores = (Get-CimInstance Win32_Processor | Measure-Object -Property NumberOfCores -Sum).Sum + } catch { + $cores = 0 + } + if ($cores -gt 0) { + $script:JOBS = $cores + } else { + $script:JOBS = [Environment]::ProcessorCount + } + } + Write-Output "Build parallelism: $script:JOBS (set OLLAMA_BUILD_PARALLEL to override)" } @@ -127,7 +151,7 @@ function cuda11 { } } } - write-host "Building CUDA v$cudaMajorVer backend libraries $cuda" + Write-Output "Building CUDA v$cudaMajorVer backend libraries $cuda" $env:CUDAToolkit_ROOT=$cuda & cmake -B build\cuda_v$cudaMajorVer --preset "CUDA $cudaMajorVer" -T cuda="$cuda" -DCMAKE_CUDA_COMPILER="$cuda\bin\nvcc.exe" -G "Visual Studio 16 2019" --install-prefix "$script:DIST_DIR" if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)} @@ -136,12 +160,12 @@ function cuda11 { & cmake --install build\cuda_v$cudaMajorVer --component "CUDA" --strip if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)} } else { - write-host "CUDA v$cudaMajorVer not detected, skipping" + Write-Output "CUDA v$cudaMajorVer not detected, skipping" } } else { - write-host "not arch we wanted" + Write-Output "not arch we wanted" } - write-host "done" + Write-Output "done" } function cudaCommon { @@ -159,7 +183,7 @@ function cudaCommon { } } } - write-host "Building CUDA v$cudaMajorVer backend libraries $cuda" + Write-Output "Building CUDA v$cudaMajorVer backend libraries $cuda" $env:CUDAToolkit_ROOT=$cuda & cmake -B build\cuda_v$cudaMajorVer --preset "CUDA $cudaMajorVer" -T cuda="$cuda" --install-prefix "$script:DIST_DIR" if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)} @@ -168,7 +192,7 @@ function cudaCommon { & cmake --install build\cuda_v$cudaMajorVer --component "CUDA" --strip if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)} } else { - write-host "CUDA v$cudaMajorVer not detected, skipping" + Write-Output "CUDA v$cudaMajorVer not detected, skipping" } } } @@ -181,11 +205,11 @@ function cuda13 { cudaCommon("13") } -function rocm { +function rocm6 { mkdir -Force -path "${script:DIST_DIR}\" | Out-Null if ($script:ARCH -ne "arm64") { if ($script:HIP_PATH) { - write-host "Building ROCm backend libraries $script:HIP_PATH" + Write-Output "Building ROCm backend libraries $script:HIP_PATH" if (-Not (get-command -ErrorAction silent ninja)) { $NINJA_DIR=(gci -path (Get-CimInstance MSFT_VSInstance -Namespace root/cimv2/vs)[0].InstallLocation -r -fi ninja.exe).Directory.FullName $env:PATH="$NINJA_DIR;$env:PATH" @@ -193,9 +217,11 @@ function rocm { $env:HIPCXX="${script:HIP_PATH}\bin\clang++.exe" $env:HIP_PLATFORM="amd" $env:CMAKE_PREFIX_PATH="${script:HIP_PATH}" + # Set CC/CXX via environment instead of -D flags to avoid triggering + # spurious compiler-change reconfigures that reset CMAKE_INSTALL_PREFIX + $env:CC="${script:HIP_PATH}\bin\clang.exe" + $env:CXX="${script:HIP_PATH}\bin\clang++.exe" & cmake -B build\rocm --preset "ROCm 6" -G Ninja ` - -DCMAKE_C_COMPILER=clang ` - -DCMAKE_CXX_COMPILER=clang++ ` -DCMAKE_C_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma" ` -DCMAKE_CXX_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma" ` --install-prefix $script:DIST_DIR @@ -203,20 +229,22 @@ function rocm { $env:HIPCXX="" $env:HIP_PLATFORM="" $env:CMAKE_PREFIX_PATH="" + $env:CC="" + $env:CXX="" & cmake --build build\rocm --target ggml-hip --config Release --parallel $script:JOBS if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)} & cmake --install build\rocm --component "HIP" --strip if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)} Remove-Item -Path $script:DIST_DIR\lib\ollama\rocm\rocblas\library\*gfx906* -ErrorAction SilentlyContinue } else { - write-host "ROCm not detected, skipping" + Write-Output "ROCm not detected, skipping" } } } function vulkan { if ($env:VULKAN_SDK) { - write-host "Building Vulkan backend libraries" + Write-Output "Building Vulkan backend libraries" & cmake -B build\vulkan --preset Vulkan --install-prefix $script:DIST_DIR if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)} & cmake --build build\vulkan --target ggml-vulkan --config Release --parallel $script:JOBS @@ -224,33 +252,91 @@ function vulkan { & cmake --install build\vulkan --component Vulkan --strip if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)} } else { - write-host "Vulkan not detected, skipping" + Write-Output "Vulkan not detected, skipping" + } +} + +function mlxCuda13 { + mkdir -Force -path "${script:DIST_DIR}\" | Out-Null + $cudaMajorVer="13" + if ($script:ARCH -ne "arm64") { + if ("$script:CUDA_DIRS".Contains("v$cudaMajorVer")) { + foreach ($d in $Script:CUDA_DIRS){ + if ($d.FullName.Contains("v$cudaMajorVer")) { + if (test-path -literalpath (join-path -path $d -childpath "nvcc.exe" ) ) { + $cuda=($d.FullName|split-path -parent) + break + } + } + } + + # Check for cuDNN - required for MLX CUDA backend + # Supports two layouts: + # 1. CI/zip extract: CUDNN\include\cudnn.h, lib\x64\, bin\x64\ + # 2. Official installer: CUDNN\v*\include\{cuda-ver}\cudnn.h, lib\{cuda-ver}\x64\, bin\{cuda-ver}\ + if ($env:CUDNN_INCLUDE_PATH -and $env:CUDNN_LIBRARY_PATH) { + Write-Output "Using cuDNN from environment: $env:CUDNN_INCLUDE_PATH" + } elseif (Test-Path "C:\Program Files\NVIDIA\CUDNN\include\cudnn.h") { + # CI/zip layout (flat) + $cudnnRoot = "C:\Program Files\NVIDIA\CUDNN" + $env:CUDNN_ROOT_DIR = $cudnnRoot + $env:CUDNN_INCLUDE_PATH = "$cudnnRoot\include" + $env:CUDNN_LIBRARY_PATH = "$cudnnRoot\lib\x64" + Write-Output "Found cuDNN at $cudnnRoot (flat layout)" + } else { + # Official installer layout (versioned) + $cudnnRoot = $null + $resolved = Resolve-Path -Path "C:\Program Files\NVIDIA\CUDNN\v*" -ErrorAction SilentlyContinue | Sort-Object -Descending | Select-Object -First 1 + if ($resolved -and (Test-Path "$($resolved.Path)\include\$cudaMajorVer.0\cudnn.h")) { + $cudnnRoot = $resolved.Path + $env:CUDNN_ROOT_DIR = $cudnnRoot + $env:CUDNN_INCLUDE_PATH = "$cudnnRoot\include\$cudaMajorVer.0" + $env:CUDNN_LIBRARY_PATH = "$cudnnRoot\lib\$cudaMajorVer.0\x64" + Write-Output "Found cuDNN at $cudnnRoot (official installer, CUDA $cudaMajorVer.0)" + } else { + Write-Output "cuDNN not found - set CUDNN_INCLUDE_PATH and CUDNN_LIBRARY_PATH environment variables" + Write-Output "Skipping MLX build" + return + } + } + + Write-Output "Building MLX CUDA v$cudaMajorVer backend libraries $cuda" + $env:CUDAToolkit_ROOT=$cuda + & cmake -B build\mlx_cuda_v$cudaMajorVer --preset "MLX CUDA $cudaMajorVer" -T cuda="$cuda" --install-prefix "$script:DIST_DIR" + if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)} + & cmake --build build\mlx_cuda_v$cudaMajorVer --target mlx --target mlxc --config Release --parallel $script:JOBS -- /nodeReuse:false + if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)} + & cmake --install build\mlx_cuda_v$cudaMajorVer --component "MLX" --strip + if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)} + } else { + Write-Output "CUDA v$cudaMajorVer not detected, skipping MLX build" + } } } function ollama { mkdir -Force -path "${script:DIST_DIR}\" | Out-Null - write-host "Building ollama CLI" + Write-Output "Building ollama CLI" & go build -trimpath -ldflags "-s -w -X=github.com/ollama/ollama/version.Version=$script:VERSION -X=github.com/ollama/ollama/server.mode=release" . if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)} cp .\ollama.exe "${script:DIST_DIR}\" } function app { - write-host "Building Ollama App $script:VERSION with package version $script:PKG_VERSION" + Write-Output "Building Ollama App $script:VERSION with package version $script:PKG_VERSION" if (!(Get-Command npm -ErrorAction SilentlyContinue)) { - write-host "npm is not installed. Please install Node.js and npm first:" - write-host " Visit: https://nodejs.org/" + Write-Output "npm is not installed. Please install Node.js and npm first:" + Write-Output " Visit: https://nodejs.org/" exit 1 } if (!(Get-Command tsc -ErrorAction SilentlyContinue)) { - write-host "Installing TypeScript compiler..." + Write-Output "Installing TypeScript compiler..." npm install -g typescript } if (!(Get-Command tscriptify -ErrorAction SilentlyContinue)) { - write-host "Installing tscriptify..." + Write-Output "Installing tscriptify..." go install github.com/tkrajina/typescriptify-golang-structs/tscriptify@latest } if (!(Get-Command tscriptify -ErrorAction SilentlyContinue)) { @@ -260,32 +346,32 @@ function app { Push-Location app/ui/app npm install if ($LASTEXITCODE -ne 0) { - write-host "ERROR: npm install failed with exit code $LASTEXITCODE" + Write-Output "ERROR: npm install failed with exit code $LASTEXITCODE" exit $LASTEXITCODE } - write-host "Building React application..." + Write-Output "Building React application..." npm run build if ($LASTEXITCODE -ne 0) { - write-host "ERROR: npm run build failed with exit code $LASTEXITCODE" + Write-Output "ERROR: npm run build failed with exit code $LASTEXITCODE" exit $LASTEXITCODE } # Check if dist directory exists and has content if (!(Test-Path "dist")) { - write-host "ERROR: dist directory was not created by npm run build" + Write-Output "ERROR: dist directory was not created by npm run build" exit 1 } $distFiles = Get-ChildItem "dist" -Recurse if ($distFiles.Count -eq 0) { - write-host "ERROR: dist directory is empty after npm run build" + Write-Output "ERROR: dist directory is empty after npm run build" exit 1 } Pop-Location - write-host "Running go generate" + Write-Output "Running go generate" & go generate ./... if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)} & go build -trimpath -ldflags "-s -w -H windowsgui -X=github.com/ollama/ollama/app/version.Version=$script:VERSION" -o .\dist\windows-ollama-app-${script:ARCH}.exe ./app/cmd/app/ @@ -293,42 +379,42 @@ function app { } function deps { - write-host "Download MSVC Redistributables" + Write-Output "Download MSVC Redistributables" mkdir -Force -path "${script:SRC_DIR}\dist\\windows-arm64" | Out-Null mkdir -Force -path "${script:SRC_DIR}\dist\\windows-amd64" | Out-Null - invoke-webrequest -Uri "https://aka.ms/vs/17/release/vc_redist.arm64.exe" -OutFile "${script:SRC_DIR}\dist\windows-arm64\vc_redist.arm64.exe" - invoke-webrequest -Uri "https://aka.ms/vs/17/release/vc_redist.x64.exe" -OutFile "${script:SRC_DIR}\dist\windows-amd64\vc_redist.x64.exe" - write-host "Done." + invoke-webrequest -Uri "https://aka.ms/vs/17/release/vc_redist.arm64.exe" -OutFile "${script:SRC_DIR}\dist\windows-arm64\vc_redist.arm64.exe" -ErrorAction Stop + invoke-webrequest -Uri "https://aka.ms/vs/17/release/vc_redist.x64.exe" -OutFile "${script:SRC_DIR}\dist\windows-amd64\vc_redist.x64.exe" -ErrorAction Stop + Write-Output "Done." } function sign { # Copy install.ps1 to dist for release packaging - write-host "Copying install.ps1 to dist" - Copy-Item -Path "${script:SRC_DIR}\scripts\install.ps1" -Destination "${script:SRC_DIR}\dist\install.ps1" + Write-Output "Copying install.ps1 to dist" + Copy-Item -Path "${script:SRC_DIR}\scripts\install.ps1" -Destination "${script:SRC_DIR}\dist\install.ps1" -ErrorAction Stop if ("${env:KEY_CONTAINER}") { - write-host "Signing Ollama executables, scripts and libraries" + Write-Output "Signing Ollama executables, scripts and libraries" & "${script:SignTool}" sign /v /fd sha256 /t http://timestamp.digicert.com /f "${script:OLLAMA_CERT}" ` /csp "Google Cloud KMS Provider" /kc ${env:KEY_CONTAINER} ` $(get-childitem -path "${script:SRC_DIR}\dist\windows-*" -r -include @('*.exe', '*.dll')) if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)} - write-host "Signing install.ps1" + Write-Output "Signing install.ps1" & "${script:SignTool}" sign /v /fd sha256 /t http://timestamp.digicert.com /f "${script:OLLAMA_CERT}" ` /csp "Google Cloud KMS Provider" /kc ${env:KEY_CONTAINER} ` "${script:SRC_DIR}\dist\install.ps1" if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)} } else { - write-host "Signing not enabled" + Write-Output "Signing not enabled" } } function installer { if ($null -eq ${script:INNO_SETUP_DIR}) { - write-host "ERROR: missing Inno Setup installation directory - install from https://jrsoftware.org/isdl.php" + Write-Output "ERROR: missing Inno Setup installation directory - install from https://jrsoftware.org/isdl.php" exit 1 } - write-host "Building Ollama Installer" + Write-Output "Building Ollama Installer" cd "${script:SRC_DIR}\app" $env:PKG_VERSION=$script:PKG_VERSION if ("${env:KEY_CONTAINER}") { @@ -342,24 +428,24 @@ function installer { function zip { if (Test-Path -Path "${script:SRC_DIR}\dist\windows-amd64") { if (Test-Path -Path "${script:SRC_DIR}\dist\windows-amd64\lib\ollama\rocm") { - write-host "Generating stand-alone distribution zip file ${script:SRC_DIR}\dist\ollama-windows-amd64-rocm.zip" + Write-Output "Generating stand-alone distribution zip file ${script:SRC_DIR}\dist\ollama-windows-amd64-rocm.zip" # Temporarily adjust paths so we can retain the same directory structure Remove-Item -ea 0 -r "${script:SRC_DIR}\dist\windows-amd64-rocm" mkdir -Force -path "${script:SRC_DIR}\dist\windows-amd64-rocm\lib\ollama" Write-Output "Extract this ROCm zip file to the same location where you extracted ollama-windows-amd64.zip" > "${script:SRC_DIR}\dist\windows-amd64-rocm\README.txt" - Move-Item -path "${script:SRC_DIR}\dist\windows-amd64\lib\ollama\rocm" -destination "${script:SRC_DIR}\dist\windows-amd64-rocm\lib\ollama" + Move-Item -path "${script:SRC_DIR}\dist\windows-amd64\lib\ollama\rocm" -destination "${script:SRC_DIR}\dist\windows-amd64-rocm\lib\ollama" -ErrorAction Stop Compress-Archive -CompressionLevel Optimal -Path "${script:SRC_DIR}\dist\windows-amd64-rocm\*" -DestinationPath "${script:SRC_DIR}\dist\ollama-windows-amd64-rocm.zip" -Force } - write-host "Generating stand-alone distribution zip file ${script:SRC_DIR}\dist\ollama-windows-amd64.zip" + Write-Output "Generating stand-alone distribution zip file ${script:SRC_DIR}\dist\ollama-windows-amd64.zip" Compress-Archive -CompressionLevel Optimal -Path "${script:SRC_DIR}\dist\windows-amd64\*" -DestinationPath "${script:SRC_DIR}\dist\ollama-windows-amd64.zip" -Force if (Test-Path -Path "${script:SRC_DIR}\dist\windows-amd64-rocm") { - Move-Item -destination "${script:SRC_DIR}\dist\windows-amd64\lib\ollama\rocm" -path "${script:SRC_DIR}\dist\windows-amd64-rocm\lib\ollama\rocm" + Move-Item -destination "${script:SRC_DIR}\dist\windows-amd64\lib\ollama\rocm" -path "${script:SRC_DIR}\dist\windows-amd64-rocm\lib\ollama\rocm" -ErrorAction Stop } } if (Test-Path -Path "${script:SRC_DIR}\dist\windows-arm64") { - write-host "Generating stand-alone distribution zip file ${script:SRC_DIR}\dist\ollama-windows-arm64.zip" + Write-Output "Generating stand-alone distribution zip file ${script:SRC_DIR}\dist\ollama-windows-arm64.zip" Compress-Archive -CompressionLevel Optimal -Path "${script:SRC_DIR}\dist\windows-arm64\*" -DestinationPath "${script:SRC_DIR}\dist\ollama-windows-arm64.zip" -Force } } @@ -375,8 +461,9 @@ try { cpu cuda12 cuda13 - rocm + rocm6 vulkan + mlxCuda13 ollama app deps @@ -385,13 +472,13 @@ try { zip } else { for ( $i = 0; $i -lt $args.count; $i++ ) { - write-host "running build step $($args[$i])" + Write-Output "running build step $($args[$i])" & $($args[$i]) } } } catch { - write-host "Build Failed" - write-host $_ + Write-Error "Build Failed: $($_.Exception.Message)" + Write-Error "$($_.ScriptStackTrace)" } finally { set-location $script:SRC_DIR $env:PKG_VERSION="" diff --git a/scripts/env.sh b/scripts/env.sh index 65a970bdc..3e6d69cc9 100644 --- a/scripts/env.sh +++ b/scripts/env.sh @@ -18,6 +18,14 @@ OLLAMA_COMMON_BUILD_ARGS="--build-arg=VERSION \ --build-arg=GPU_RUNNER_CPU_FLAGS \ --build-arg=AMDGPU_TARGETS" +# Forward local MLX source overrides as Docker build contexts +if [ -n "${OLLAMA_MLX_SOURCE:-}" ]; then + OLLAMA_COMMON_BUILD_ARGS="$OLLAMA_COMMON_BUILD_ARGS --build-context local-mlx=$(cd "$OLLAMA_MLX_SOURCE" && pwd)" +fi +if [ -n "${OLLAMA_MLX_C_SOURCE:-}" ]; then + OLLAMA_COMMON_BUILD_ARGS="$OLLAMA_COMMON_BUILD_ARGS --build-context local-mlx-c=$(cd "$OLLAMA_MLX_C_SOURCE" && pwd)" +fi + echo "Building Ollama" echo "VERSION=$VERSION" echo "PLATFORM=$PLATFORM" \ No newline at end of file diff --git a/x/create/client/quantize.go b/x/create/client/quantize.go index 6ccb1ad6d..bb379bbac 100644 --- a/x/create/client/quantize.go +++ b/x/create/client/quantize.go @@ -1,5 +1,3 @@ -//go:build mlx - package client import ( @@ -194,9 +192,10 @@ func quantizePackedGroup(inputs []create.PackedTensorInput) ([]byte, error) { return blobData, nil } -// QuantizeSupported returns true if quantization is supported (MLX build) +// QuantizeSupported returns true if quantization is supported (MLX library available) func QuantizeSupported() bool { - return true + mlx.InitMLX() + return mlx.IsMLXAvailable() } // ensureTempDir creates the temp directory for quantization if it doesn't exist diff --git a/x/create/client/quantize_stub.go b/x/create/client/quantize_stub.go deleted file mode 100644 index 7a75671a0..000000000 --- a/x/create/client/quantize_stub.go +++ /dev/null @@ -1,25 +0,0 @@ -//go:build !mlx - -package client - -import ( - "fmt" - "io" - - "github.com/ollama/ollama/x/create" -) - -// quantizeTensor is not available without MLX -func quantizeTensor(r io.Reader, tensorName, dtype string, shape []int32, quantize string) (blobData []byte, err error) { - return nil, fmt.Errorf("quantization requires MLX support (build with mlx tag)") -} - -// quantizePackedGroup is not available without MLX -func quantizePackedGroup(inputs []create.PackedTensorInput) ([]byte, error) { - return nil, fmt.Errorf("quantization requires MLX support (build with mlx tag)") -} - -// QuantizeSupported returns false when MLX is not available -func QuantizeSupported() bool { - return false -} diff --git a/x/imagegen/cache/cache.go b/x/imagegen/cache/cache.go index 8a25193cd..d4e19ba55 100644 --- a/x/imagegen/cache/cache.go +++ b/x/imagegen/cache/cache.go @@ -1,5 +1,3 @@ -//go:build mlx - package cache import "github.com/ollama/ollama/x/imagegen/mlx" diff --git a/x/imagegen/cache/step.go b/x/imagegen/cache/step.go index f91f22fa0..066f2f645 100644 --- a/x/imagegen/cache/step.go +++ b/x/imagegen/cache/step.go @@ -1,5 +1,3 @@ -//go:build mlx - package cache import "github.com/ollama/ollama/x/imagegen/mlx" diff --git a/x/imagegen/cache/teacache.go b/x/imagegen/cache/teacache.go index 60031d8cb..fb06047ea 100644 --- a/x/imagegen/cache/teacache.go +++ b/x/imagegen/cache/teacache.go @@ -1,5 +1,3 @@ -//go:build mlx - // Package cache provides caching mechanisms for diffusion model inference. package cache diff --git a/x/imagegen/cmd/engine/README.md b/x/imagegen/cmd/engine/README.md index 3991c02a8..e6ab2f5d3 100644 --- a/x/imagegen/cmd/engine/README.md +++ b/x/imagegen/cmd/engine/README.md @@ -5,7 +5,7 @@ Experimental MLX backend for running models on Apple Silicon and CUDA. ## Build ```bash -go build -tags mlx -o engine ./x/imagegen/cmd/engine +go build -o engine ./x/imagegen/cmd/engine ``` ## Text Generation diff --git a/x/imagegen/cmd/engine/generate.go b/x/imagegen/cmd/engine/generate.go index 51118afc1..95173cd80 100644 --- a/x/imagegen/cmd/engine/generate.go +++ b/x/imagegen/cmd/engine/generate.go @@ -1,5 +1,3 @@ -//go:build mlx - package main import ( diff --git a/x/imagegen/cmd/engine/image.go b/x/imagegen/cmd/engine/image.go index e8af2222a..3c393cf66 100644 --- a/x/imagegen/cmd/engine/image.go +++ b/x/imagegen/cmd/engine/image.go @@ -1,5 +1,3 @@ -//go:build mlx - package main import ( diff --git a/x/imagegen/cmd/engine/main.go b/x/imagegen/cmd/engine/main.go index 6ec7de9e1..31411f466 100644 --- a/x/imagegen/cmd/engine/main.go +++ b/x/imagegen/cmd/engine/main.go @@ -1,5 +1,3 @@ -//go:build mlx - package main import ( diff --git a/x/imagegen/cmd/engine/sample.go b/x/imagegen/cmd/engine/sample.go index 5d723e6dc..165c40774 100644 --- a/x/imagegen/cmd/engine/sample.go +++ b/x/imagegen/cmd/engine/sample.go @@ -1,5 +1,3 @@ -//go:build mlx - package main import "github.com/ollama/ollama/x/imagegen/mlx" diff --git a/x/imagegen/image.go b/x/imagegen/image.go index 2dca0ee1d..9bdd95304 100644 --- a/x/imagegen/image.go +++ b/x/imagegen/image.go @@ -1,5 +1,3 @@ -//go:build mlx - package imagegen import ( diff --git a/x/imagegen/image_processor.go b/x/imagegen/image_processor.go index 7a562feb5..c3e68ebb3 100644 --- a/x/imagegen/image_processor.go +++ b/x/imagegen/image_processor.go @@ -1,5 +1,3 @@ -//go:build mlx - package imagegen import ( diff --git a/x/imagegen/imagegen.go b/x/imagegen/imagegen.go index d870bed9b..c4b586505 100644 --- a/x/imagegen/imagegen.go +++ b/x/imagegen/imagegen.go @@ -1,5 +1,3 @@ -//go:build mlx - package imagegen import ( diff --git a/x/imagegen/manifest/weights.go b/x/imagegen/manifest/weights.go index e1209c9db..fcb30449e 100644 --- a/x/imagegen/manifest/weights.go +++ b/x/imagegen/manifest/weights.go @@ -1,5 +1,3 @@ -//go:build mlx - package manifest import ( diff --git a/x/imagegen/mlx/CMakeLists.txt b/x/imagegen/mlx/CMakeLists.txt index b62cbf2eb..70246ef4b 100644 --- a/x/imagegen/mlx/CMakeLists.txt +++ b/x/imagegen/mlx/CMakeLists.txt @@ -4,6 +4,10 @@ include(FetchContent) file(READ "${CMAKE_SOURCE_DIR}/MLX_VERSION" MLX_C_GIT_TAG) string(STRIP "${MLX_C_GIT_TAG}" MLX_C_GIT_TAG) +# Read MLX core version from top-level file +file(READ "${CMAKE_SOURCE_DIR}/MLX_CORE_VERSION" MLX_GIT_TAG) +string(STRIP "${MLX_GIT_TAG}" MLX_GIT_TAG) + set(MLX_C_BUILD_EXAMPLES OFF) set(MLX_BUILD_GGUF OFF) @@ -43,6 +47,17 @@ if(NOT MLX_CUDA_ARCHITECTURES AND CMAKE_CUDA_ARCHITECTURES) message(STATUS "Using CMAKE_CUDA_ARCHITECTURES for MLX: ${MLX_CUDA_ARCHITECTURES}") endif() +# Forward cuDNN environment variables to cmake variables so MLX's FindCUDNN.cmake +# can find them via HINTS ${CUDNN_INCLUDE_PATH} / ${CUDNN_LIBRARY_PATH}. +if(DEFINED ENV{CUDNN_INCLUDE_PATH} AND NOT CUDNN_INCLUDE_PATH) + set(CUDNN_INCLUDE_PATH "$ENV{CUDNN_INCLUDE_PATH}" CACHE PATH "cuDNN include path") + message(STATUS "Using CUDNN_INCLUDE_PATH from environment: ${CUDNN_INCLUDE_PATH}") +endif() +if(DEFINED ENV{CUDNN_LIBRARY_PATH} AND NOT CUDNN_LIBRARY_PATH) + set(CUDNN_LIBRARY_PATH "$ENV{CUDNN_LIBRARY_PATH}" CACHE PATH "cuDNN library path") + message(STATUS "Using CUDNN_LIBRARY_PATH from environment: ${CUDNN_LIBRARY_PATH}") +endif() + # Enable CUDA backend if CUDA architectures are specified and CUDA compiler is available if(MLX_CUDA_ARCHITECTURES AND CMAKE_CUDA_COMPILER) set(MLX_BUILD_CUDA ON CACHE BOOL "Build CUDA backend for MLX" FORCE) @@ -51,11 +66,58 @@ elseif(MLX_CUDA_ARCHITECTURES) message(WARNING "MLX_CUDA_ARCHITECTURES specified but CUDA compiler not found, CUDA backend will be disabled") endif() +# Allow local source overrides via environment variables +# Resolve to absolute paths so FetchContent doesn't break on relative dirs. +if(DEFINED ENV{OLLAMA_MLX_SOURCE}) + get_filename_component(_mlx_src "$ENV{OLLAMA_MLX_SOURCE}" ABSOLUTE BASE_DIR ${CMAKE_SOURCE_DIR}) + set(FETCHCONTENT_SOURCE_DIR_MLX "${_mlx_src}" CACHE PATH "" FORCE) + message(STATUS "Using local MLX source: ${_mlx_src}") +endif() +if(DEFINED ENV{OLLAMA_MLX_C_SOURCE}) + get_filename_component(_mlx_c_src "$ENV{OLLAMA_MLX_C_SOURCE}" ABSOLUTE BASE_DIR ${CMAKE_SOURCE_DIR}) + set(FETCHCONTENT_SOURCE_DIR_MLX-C "${_mlx_c_src}" CACHE PATH "" FORCE) + message(STATUS "Using local MLX-C source: ${_mlx_c_src}") +endif() + +# Pre-declare mlx so our pinned version takes precedence over the one +# hardcoded in mlx-c's CMakeLists.txt (first FetchContent_Declare wins). FetchContent_Declare( - mlx-c - GIT_REPOSITORY "https://github.com/ml-explore/mlx-c.git" - GIT_TAG ${MLX_C_GIT_TAG}) + mlx + GIT_REPOSITORY "https://github.com/ml-explore/mlx.git" + GIT_TAG ${MLX_GIT_TAG} +) + +FetchContent_Declare( + mlx-c + GIT_REPOSITORY "https://github.com/ml-explore/mlx-c.git" + GIT_TAG ${MLX_C_GIT_TAG} +) FetchContent_MakeAvailable(mlx-c) +# Sync vendored headers with fetched version +file(GLOB _mlx_c_hdrs "${mlx-c_SOURCE_DIR}/mlx/c/*.h") +file(COPY ${_mlx_c_hdrs} DESTINATION "${CMAKE_SOURCE_DIR}/x/mlxrunner/mlx/include/mlx/c/") + +# For local dev builds, override MLX_VERSION with git describe output +if(TARGET mlx_version AND DEFINED FETCHCONTENT_SOURCE_DIR_MLX) + execute_process( + COMMAND git describe --tags --first-parent --abbrev=7 --long --dirty --always + WORKING_DIRECTORY ${mlx_SOURCE_DIR} + OUTPUT_VARIABLE _mlx_git_version + OUTPUT_STRIP_TRAILING_WHITESPACE + ERROR_QUIET + RESULT_VARIABLE _mlx_git_result + ) + if(_mlx_git_result EQUAL 0 AND _mlx_git_version) + # Strip leading "v" prefix for consistency + string(REGEX REPLACE "^v" "" _mlx_git_version "${_mlx_git_version}") + get_target_property(_mlx_defs mlx_version COMPILE_DEFINITIONS) + list(FILTER _mlx_defs EXCLUDE REGEX "^MLX_VERSION=") + set_target_properties(mlx_version PROPERTIES COMPILE_DEFINITIONS "${_mlx_defs}") + target_compile_definitions(mlx_version PRIVATE "MLX_VERSION=\"${_mlx_git_version}\"") + message(STATUS "MLX version (local dev): ${_mlx_git_version}") + endif() +endif() + set_target_output_directory(mlx) set_target_output_directory(mlxc) diff --git a/x/imagegen/mlx/compile.go b/x/imagegen/mlx/compile.go index 0dd2dd02a..746e6eaf6 100644 --- a/x/imagegen/mlx/compile.go +++ b/x/imagegen/mlx/compile.go @@ -1,5 +1,3 @@ -//go:build mlx - package mlx /* diff --git a/x/imagegen/mlx/doc.go b/x/imagegen/mlx/doc.go index ced1802b0..5410f3b2d 100644 --- a/x/imagegen/mlx/doc.go +++ b/x/imagegen/mlx/doc.go @@ -1,6 +1,4 @@ -//go:build mlx - // Package mlx provides Go bindings for the MLX-C library with dynamic loading support. // -//go:generate go run generate_wrappers.go ../../../build/_deps/mlx-c-src/mlx/c mlx.h mlx.c +//go:generate go run generate_wrappers.go ../../mlxrunner/mlx/include/mlx/c mlx.h mlx.c package mlx diff --git a/x/imagegen/mlx/generate_wrappers.go b/x/imagegen/mlx/generate_wrappers.go index 8aa5bd0c8..114ac2a15 100644 --- a/x/imagegen/mlx/generate_wrappers.go +++ b/x/imagegen/mlx/generate_wrappers.go @@ -291,8 +291,15 @@ func generateWrapperFiles(functions []Function, headerPath, implPath string) err implBuf.WriteString("#include \"mlx/c/mlx.h\"\n") implBuf.WriteString("#include \"mlx_dynamic.h\"\n") - implBuf.WriteString("#include \n") - implBuf.WriteString("#include \n\n") + implBuf.WriteString("#include \n\n") + implBuf.WriteString("// Platform-specific dynamic loading\n") + implBuf.WriteString("#ifdef _WIN32\n") + implBuf.WriteString("#include \n") + implBuf.WriteString("#define GET_SYM(handle, name) (void*)GetProcAddress((HMODULE)(handle), name)\n") + implBuf.WriteString("#else\n") + implBuf.WriteString("#include \n") + implBuf.WriteString("#define GET_SYM(handle, name) dlsym(handle, name)\n") + implBuf.WriteString("#endif\n\n") // Function pointer definitions implBuf.WriteString("// Function pointer definitions\n") @@ -308,7 +315,7 @@ func generateWrapperFiles(functions []Function, headerPath, implPath string) err implBuf.WriteString("\n") // Initialization function - implBuf.WriteString("// Initialize all function pointers via dlsym\n") + implBuf.WriteString("// Initialize all function pointers\n") implBuf.WriteString("int mlx_load_functions(void* handle) {\n") implBuf.WriteString(" if (handle == NULL) {\n") implBuf.WriteString(" fprintf(stderr, \"MLX: Invalid library handle\\n\");\n") @@ -319,7 +326,7 @@ func generateWrapperFiles(functions []Function, headerPath, implPath string) err if fn.NeedsARM64Guard { implBuf.WriteString("#if defined(__aarch64__) || defined(_M_ARM64)\n") } - implBuf.WriteString(fmt.Sprintf(" %s_ptr = dlsym(handle, \"%s\");\n", fn.Name, fn.Name)) + implBuf.WriteString(fmt.Sprintf(" %s_ptr = GET_SYM(handle, \"%s\");\n", fn.Name, fn.Name)) implBuf.WriteString(fmt.Sprintf(" if (%s_ptr == NULL) {\n", fn.Name)) implBuf.WriteString(fmt.Sprintf(" fprintf(stderr, \"MLX: Failed to load symbol: %s\\n\");\n", fn.Name)) implBuf.WriteString(" return -1;\n") diff --git a/x/imagegen/mlx/mlx.c b/x/imagegen/mlx/mlx.c index 770b60922..b0ccbacdf 100644 --- a/x/imagegen/mlx/mlx.c +++ b/x/imagegen/mlx/mlx.c @@ -5,7 +5,15 @@ #include "mlx/c/mlx.h" #include "mlx_dynamic.h" #include + +// Platform-specific dynamic loading +#ifdef _WIN32 +#include +#define GET_SYM(handle, name) (void*)GetProcAddress((HMODULE)(handle), name) +#else #include +#define GET_SYM(handle, name) dlsym(handle, name) +#endif // Function pointer definitions size_t (*mlx_dtype_size_ptr)(mlx_dtype dtype) = NULL; @@ -603,2947 +611,2947 @@ size_t (*mlx_vector_string_size_ptr)(mlx_vector_string vec) = NULL; int (*mlx_vector_string_get_ptr)(char** res, const mlx_vector_string vec, size_t idx) = NULL; int (*mlx_version_ptr)(mlx_string* str_) = NULL; -// Initialize all function pointers via dlsym +// Initialize all function pointers int mlx_load_functions(void* handle) { if (handle == NULL) { fprintf(stderr, "MLX: Invalid library handle\n"); return -1; } - mlx_dtype_size_ptr = dlsym(handle, "mlx_dtype_size"); + mlx_dtype_size_ptr = GET_SYM(handle, "mlx_dtype_size"); if (mlx_dtype_size_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_dtype_size\n"); return -1; } - mlx_array_tostring_ptr = dlsym(handle, "mlx_array_tostring"); + mlx_array_tostring_ptr = GET_SYM(handle, "mlx_array_tostring"); if (mlx_array_tostring_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_tostring\n"); return -1; } - mlx_array_new_ptr = dlsym(handle, "mlx_array_new"); + mlx_array_new_ptr = GET_SYM(handle, "mlx_array_new"); if (mlx_array_new_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_new\n"); return -1; } - mlx_array_free_ptr = dlsym(handle, "mlx_array_free"); + mlx_array_free_ptr = GET_SYM(handle, "mlx_array_free"); if (mlx_array_free_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_free\n"); return -1; } - mlx_array_new_bool_ptr = dlsym(handle, "mlx_array_new_bool"); + mlx_array_new_bool_ptr = GET_SYM(handle, "mlx_array_new_bool"); if (mlx_array_new_bool_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_new_bool\n"); return -1; } - mlx_array_new_int_ptr = dlsym(handle, "mlx_array_new_int"); + mlx_array_new_int_ptr = GET_SYM(handle, "mlx_array_new_int"); if (mlx_array_new_int_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_new_int\n"); return -1; } - mlx_array_new_float32_ptr = dlsym(handle, "mlx_array_new_float32"); + mlx_array_new_float32_ptr = GET_SYM(handle, "mlx_array_new_float32"); if (mlx_array_new_float32_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_new_float32\n"); return -1; } - mlx_array_new_float_ptr = dlsym(handle, "mlx_array_new_float"); + mlx_array_new_float_ptr = GET_SYM(handle, "mlx_array_new_float"); if (mlx_array_new_float_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_new_float\n"); return -1; } - mlx_array_new_float64_ptr = dlsym(handle, "mlx_array_new_float64"); + mlx_array_new_float64_ptr = GET_SYM(handle, "mlx_array_new_float64"); if (mlx_array_new_float64_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_new_float64\n"); return -1; } - mlx_array_new_double_ptr = dlsym(handle, "mlx_array_new_double"); + mlx_array_new_double_ptr = GET_SYM(handle, "mlx_array_new_double"); if (mlx_array_new_double_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_new_double\n"); return -1; } - mlx_array_new_complex_ptr = dlsym(handle, "mlx_array_new_complex"); + mlx_array_new_complex_ptr = GET_SYM(handle, "mlx_array_new_complex"); if (mlx_array_new_complex_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_new_complex\n"); return -1; } - mlx_array_new_data_ptr = dlsym(handle, "mlx_array_new_data"); + mlx_array_new_data_ptr = GET_SYM(handle, "mlx_array_new_data"); if (mlx_array_new_data_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_new_data\n"); return -1; } - mlx_array_new_data_managed_ptr = dlsym(handle, "mlx_array_new_data_managed"); + mlx_array_new_data_managed_ptr = GET_SYM(handle, "mlx_array_new_data_managed"); if (mlx_array_new_data_managed_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_new_data_managed\n"); return -1; } - mlx_array_new_data_managed_payload_ptr = dlsym(handle, "mlx_array_new_data_managed_payload"); + mlx_array_new_data_managed_payload_ptr = GET_SYM(handle, "mlx_array_new_data_managed_payload"); if (mlx_array_new_data_managed_payload_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_new_data_managed_payload\n"); return -1; } - mlx_array_set_ptr = dlsym(handle, "mlx_array_set"); + mlx_array_set_ptr = GET_SYM(handle, "mlx_array_set"); if (mlx_array_set_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_set\n"); return -1; } - mlx_array_set_bool_ptr = dlsym(handle, "mlx_array_set_bool"); + mlx_array_set_bool_ptr = GET_SYM(handle, "mlx_array_set_bool"); if (mlx_array_set_bool_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_set_bool\n"); return -1; } - mlx_array_set_int_ptr = dlsym(handle, "mlx_array_set_int"); + mlx_array_set_int_ptr = GET_SYM(handle, "mlx_array_set_int"); if (mlx_array_set_int_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_set_int\n"); return -1; } - mlx_array_set_float32_ptr = dlsym(handle, "mlx_array_set_float32"); + mlx_array_set_float32_ptr = GET_SYM(handle, "mlx_array_set_float32"); if (mlx_array_set_float32_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_set_float32\n"); return -1; } - mlx_array_set_float_ptr = dlsym(handle, "mlx_array_set_float"); + mlx_array_set_float_ptr = GET_SYM(handle, "mlx_array_set_float"); if (mlx_array_set_float_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_set_float\n"); return -1; } - mlx_array_set_float64_ptr = dlsym(handle, "mlx_array_set_float64"); + mlx_array_set_float64_ptr = GET_SYM(handle, "mlx_array_set_float64"); if (mlx_array_set_float64_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_set_float64\n"); return -1; } - mlx_array_set_double_ptr = dlsym(handle, "mlx_array_set_double"); + mlx_array_set_double_ptr = GET_SYM(handle, "mlx_array_set_double"); if (mlx_array_set_double_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_set_double\n"); return -1; } - mlx_array_set_complex_ptr = dlsym(handle, "mlx_array_set_complex"); + mlx_array_set_complex_ptr = GET_SYM(handle, "mlx_array_set_complex"); if (mlx_array_set_complex_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_set_complex\n"); return -1; } - mlx_array_set_data_ptr = dlsym(handle, "mlx_array_set_data"); + mlx_array_set_data_ptr = GET_SYM(handle, "mlx_array_set_data"); if (mlx_array_set_data_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_set_data\n"); return -1; } - mlx_array_itemsize_ptr = dlsym(handle, "mlx_array_itemsize"); + mlx_array_itemsize_ptr = GET_SYM(handle, "mlx_array_itemsize"); if (mlx_array_itemsize_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_itemsize\n"); return -1; } - mlx_array_size_ptr = dlsym(handle, "mlx_array_size"); + mlx_array_size_ptr = GET_SYM(handle, "mlx_array_size"); if (mlx_array_size_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_size\n"); return -1; } - mlx_array_nbytes_ptr = dlsym(handle, "mlx_array_nbytes"); + mlx_array_nbytes_ptr = GET_SYM(handle, "mlx_array_nbytes"); if (mlx_array_nbytes_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_nbytes\n"); return -1; } - mlx_array_ndim_ptr = dlsym(handle, "mlx_array_ndim"); + mlx_array_ndim_ptr = GET_SYM(handle, "mlx_array_ndim"); if (mlx_array_ndim_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_ndim\n"); return -1; } - mlx_array_shape_ptr = dlsym(handle, "mlx_array_shape"); + mlx_array_shape_ptr = GET_SYM(handle, "mlx_array_shape"); if (mlx_array_shape_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_shape\n"); return -1; } - mlx_array_strides_ptr = dlsym(handle, "mlx_array_strides"); + mlx_array_strides_ptr = GET_SYM(handle, "mlx_array_strides"); if (mlx_array_strides_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_strides\n"); return -1; } - mlx_array_dim_ptr = dlsym(handle, "mlx_array_dim"); + mlx_array_dim_ptr = GET_SYM(handle, "mlx_array_dim"); if (mlx_array_dim_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_dim\n"); return -1; } - mlx_array_dtype_ptr = dlsym(handle, "mlx_array_dtype"); + mlx_array_dtype_ptr = GET_SYM(handle, "mlx_array_dtype"); if (mlx_array_dtype_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_dtype\n"); return -1; } - mlx_array_eval_ptr = dlsym(handle, "mlx_array_eval"); + mlx_array_eval_ptr = GET_SYM(handle, "mlx_array_eval"); if (mlx_array_eval_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_eval\n"); return -1; } - mlx_array_item_bool_ptr = dlsym(handle, "mlx_array_item_bool"); + mlx_array_item_bool_ptr = GET_SYM(handle, "mlx_array_item_bool"); if (mlx_array_item_bool_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_item_bool\n"); return -1; } - mlx_array_item_uint8_ptr = dlsym(handle, "mlx_array_item_uint8"); + mlx_array_item_uint8_ptr = GET_SYM(handle, "mlx_array_item_uint8"); if (mlx_array_item_uint8_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_item_uint8\n"); return -1; } - mlx_array_item_uint16_ptr = dlsym(handle, "mlx_array_item_uint16"); + mlx_array_item_uint16_ptr = GET_SYM(handle, "mlx_array_item_uint16"); if (mlx_array_item_uint16_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_item_uint16\n"); return -1; } - mlx_array_item_uint32_ptr = dlsym(handle, "mlx_array_item_uint32"); + mlx_array_item_uint32_ptr = GET_SYM(handle, "mlx_array_item_uint32"); if (mlx_array_item_uint32_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_item_uint32\n"); return -1; } - mlx_array_item_uint64_ptr = dlsym(handle, "mlx_array_item_uint64"); + mlx_array_item_uint64_ptr = GET_SYM(handle, "mlx_array_item_uint64"); if (mlx_array_item_uint64_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_item_uint64\n"); return -1; } - mlx_array_item_int8_ptr = dlsym(handle, "mlx_array_item_int8"); + mlx_array_item_int8_ptr = GET_SYM(handle, "mlx_array_item_int8"); if (mlx_array_item_int8_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_item_int8\n"); return -1; } - mlx_array_item_int16_ptr = dlsym(handle, "mlx_array_item_int16"); + mlx_array_item_int16_ptr = GET_SYM(handle, "mlx_array_item_int16"); if (mlx_array_item_int16_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_item_int16\n"); return -1; } - mlx_array_item_int32_ptr = dlsym(handle, "mlx_array_item_int32"); + mlx_array_item_int32_ptr = GET_SYM(handle, "mlx_array_item_int32"); if (mlx_array_item_int32_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_item_int32\n"); return -1; } - mlx_array_item_int64_ptr = dlsym(handle, "mlx_array_item_int64"); + mlx_array_item_int64_ptr = GET_SYM(handle, "mlx_array_item_int64"); if (mlx_array_item_int64_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_item_int64\n"); return -1; } - mlx_array_item_float32_ptr = dlsym(handle, "mlx_array_item_float32"); + mlx_array_item_float32_ptr = GET_SYM(handle, "mlx_array_item_float32"); if (mlx_array_item_float32_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_item_float32\n"); return -1; } - mlx_array_item_float64_ptr = dlsym(handle, "mlx_array_item_float64"); + mlx_array_item_float64_ptr = GET_SYM(handle, "mlx_array_item_float64"); if (mlx_array_item_float64_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_item_float64\n"); return -1; } - mlx_array_item_complex64_ptr = dlsym(handle, "mlx_array_item_complex64"); + mlx_array_item_complex64_ptr = GET_SYM(handle, "mlx_array_item_complex64"); if (mlx_array_item_complex64_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_item_complex64\n"); return -1; } #if defined(__aarch64__) || defined(_M_ARM64) - mlx_array_item_float16_ptr = dlsym(handle, "mlx_array_item_float16"); + mlx_array_item_float16_ptr = GET_SYM(handle, "mlx_array_item_float16"); if (mlx_array_item_float16_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_item_float16\n"); return -1; } #endif #if defined(__aarch64__) || defined(_M_ARM64) - mlx_array_item_bfloat16_ptr = dlsym(handle, "mlx_array_item_bfloat16"); + mlx_array_item_bfloat16_ptr = GET_SYM(handle, "mlx_array_item_bfloat16"); if (mlx_array_item_bfloat16_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_item_bfloat16\n"); return -1; } #endif - mlx_array_data_bool_ptr = dlsym(handle, "mlx_array_data_bool"); + mlx_array_data_bool_ptr = GET_SYM(handle, "mlx_array_data_bool"); if (mlx_array_data_bool_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_data_bool\n"); return -1; } - mlx_array_data_uint8_ptr = dlsym(handle, "mlx_array_data_uint8"); + mlx_array_data_uint8_ptr = GET_SYM(handle, "mlx_array_data_uint8"); if (mlx_array_data_uint8_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_data_uint8\n"); return -1; } - mlx_array_data_uint16_ptr = dlsym(handle, "mlx_array_data_uint16"); + mlx_array_data_uint16_ptr = GET_SYM(handle, "mlx_array_data_uint16"); if (mlx_array_data_uint16_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_data_uint16\n"); return -1; } - mlx_array_data_uint32_ptr = dlsym(handle, "mlx_array_data_uint32"); + mlx_array_data_uint32_ptr = GET_SYM(handle, "mlx_array_data_uint32"); if (mlx_array_data_uint32_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_data_uint32\n"); return -1; } - mlx_array_data_uint64_ptr = dlsym(handle, "mlx_array_data_uint64"); + mlx_array_data_uint64_ptr = GET_SYM(handle, "mlx_array_data_uint64"); if (mlx_array_data_uint64_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_data_uint64\n"); return -1; } - mlx_array_data_int8_ptr = dlsym(handle, "mlx_array_data_int8"); + mlx_array_data_int8_ptr = GET_SYM(handle, "mlx_array_data_int8"); if (mlx_array_data_int8_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_data_int8\n"); return -1; } - mlx_array_data_int16_ptr = dlsym(handle, "mlx_array_data_int16"); + mlx_array_data_int16_ptr = GET_SYM(handle, "mlx_array_data_int16"); if (mlx_array_data_int16_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_data_int16\n"); return -1; } - mlx_array_data_int32_ptr = dlsym(handle, "mlx_array_data_int32"); + mlx_array_data_int32_ptr = GET_SYM(handle, "mlx_array_data_int32"); if (mlx_array_data_int32_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_data_int32\n"); return -1; } - mlx_array_data_int64_ptr = dlsym(handle, "mlx_array_data_int64"); + mlx_array_data_int64_ptr = GET_SYM(handle, "mlx_array_data_int64"); if (mlx_array_data_int64_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_data_int64\n"); return -1; } - mlx_array_data_float32_ptr = dlsym(handle, "mlx_array_data_float32"); + mlx_array_data_float32_ptr = GET_SYM(handle, "mlx_array_data_float32"); if (mlx_array_data_float32_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_data_float32\n"); return -1; } - mlx_array_data_float64_ptr = dlsym(handle, "mlx_array_data_float64"); + mlx_array_data_float64_ptr = GET_SYM(handle, "mlx_array_data_float64"); if (mlx_array_data_float64_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_data_float64\n"); return -1; } - mlx_array_data_complex64_ptr = dlsym(handle, "mlx_array_data_complex64"); + mlx_array_data_complex64_ptr = GET_SYM(handle, "mlx_array_data_complex64"); if (mlx_array_data_complex64_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_data_complex64\n"); return -1; } #if defined(__aarch64__) || defined(_M_ARM64) - mlx_array_data_float16_ptr = dlsym(handle, "mlx_array_data_float16"); + mlx_array_data_float16_ptr = GET_SYM(handle, "mlx_array_data_float16"); if (mlx_array_data_float16_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_data_float16\n"); return -1; } #endif #if defined(__aarch64__) || defined(_M_ARM64) - mlx_array_data_bfloat16_ptr = dlsym(handle, "mlx_array_data_bfloat16"); + mlx_array_data_bfloat16_ptr = GET_SYM(handle, "mlx_array_data_bfloat16"); if (mlx_array_data_bfloat16_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_data_bfloat16\n"); return -1; } #endif - _mlx_array_is_available_ptr = dlsym(handle, "_mlx_array_is_available"); + _mlx_array_is_available_ptr = GET_SYM(handle, "_mlx_array_is_available"); if (_mlx_array_is_available_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: _mlx_array_is_available\n"); return -1; } - _mlx_array_wait_ptr = dlsym(handle, "_mlx_array_wait"); + _mlx_array_wait_ptr = GET_SYM(handle, "_mlx_array_wait"); if (_mlx_array_wait_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: _mlx_array_wait\n"); return -1; } - _mlx_array_is_contiguous_ptr = dlsym(handle, "_mlx_array_is_contiguous"); + _mlx_array_is_contiguous_ptr = GET_SYM(handle, "_mlx_array_is_contiguous"); if (_mlx_array_is_contiguous_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: _mlx_array_is_contiguous\n"); return -1; } - _mlx_array_is_row_contiguous_ptr = dlsym(handle, "_mlx_array_is_row_contiguous"); + _mlx_array_is_row_contiguous_ptr = GET_SYM(handle, "_mlx_array_is_row_contiguous"); if (_mlx_array_is_row_contiguous_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: _mlx_array_is_row_contiguous\n"); return -1; } - _mlx_array_is_col_contiguous_ptr = dlsym(handle, "_mlx_array_is_col_contiguous"); + _mlx_array_is_col_contiguous_ptr = GET_SYM(handle, "_mlx_array_is_col_contiguous"); if (_mlx_array_is_col_contiguous_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: _mlx_array_is_col_contiguous\n"); return -1; } - mlx_closure_new_ptr = dlsym(handle, "mlx_closure_new"); + mlx_closure_new_ptr = GET_SYM(handle, "mlx_closure_new"); if (mlx_closure_new_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_new\n"); return -1; } - mlx_closure_free_ptr = dlsym(handle, "mlx_closure_free"); + mlx_closure_free_ptr = GET_SYM(handle, "mlx_closure_free"); if (mlx_closure_free_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_free\n"); return -1; } - mlx_closure_new_func_ptr = dlsym(handle, "mlx_closure_new_func"); + mlx_closure_new_func_ptr = GET_SYM(handle, "mlx_closure_new_func"); if (mlx_closure_new_func_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_new_func\n"); return -1; } - mlx_closure_new_func_payload_ptr = dlsym(handle, "mlx_closure_new_func_payload"); + mlx_closure_new_func_payload_ptr = GET_SYM(handle, "mlx_closure_new_func_payload"); if (mlx_closure_new_func_payload_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_new_func_payload\n"); return -1; } - mlx_closure_set_ptr = dlsym(handle, "mlx_closure_set"); + mlx_closure_set_ptr = GET_SYM(handle, "mlx_closure_set"); if (mlx_closure_set_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_set\n"); return -1; } - mlx_closure_apply_ptr = dlsym(handle, "mlx_closure_apply"); + mlx_closure_apply_ptr = GET_SYM(handle, "mlx_closure_apply"); if (mlx_closure_apply_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_apply\n"); return -1; } - mlx_closure_new_unary_ptr = dlsym(handle, "mlx_closure_new_unary"); + mlx_closure_new_unary_ptr = GET_SYM(handle, "mlx_closure_new_unary"); if (mlx_closure_new_unary_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_new_unary\n"); return -1; } - mlx_closure_kwargs_new_ptr = dlsym(handle, "mlx_closure_kwargs_new"); + mlx_closure_kwargs_new_ptr = GET_SYM(handle, "mlx_closure_kwargs_new"); if (mlx_closure_kwargs_new_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_kwargs_new\n"); return -1; } - mlx_closure_kwargs_free_ptr = dlsym(handle, "mlx_closure_kwargs_free"); + mlx_closure_kwargs_free_ptr = GET_SYM(handle, "mlx_closure_kwargs_free"); if (mlx_closure_kwargs_free_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_kwargs_free\n"); return -1; } - mlx_closure_kwargs_new_func_ptr = dlsym(handle, "mlx_closure_kwargs_new_func"); + mlx_closure_kwargs_new_func_ptr = GET_SYM(handle, "mlx_closure_kwargs_new_func"); if (mlx_closure_kwargs_new_func_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_kwargs_new_func\n"); return -1; } - mlx_closure_kwargs_new_func_payload_ptr = dlsym(handle, "mlx_closure_kwargs_new_func_payload"); + mlx_closure_kwargs_new_func_payload_ptr = GET_SYM(handle, "mlx_closure_kwargs_new_func_payload"); if (mlx_closure_kwargs_new_func_payload_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_kwargs_new_func_payload\n"); return -1; } - mlx_closure_kwargs_set_ptr = dlsym(handle, "mlx_closure_kwargs_set"); + mlx_closure_kwargs_set_ptr = GET_SYM(handle, "mlx_closure_kwargs_set"); if (mlx_closure_kwargs_set_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_kwargs_set\n"); return -1; } - mlx_closure_kwargs_apply_ptr = dlsym(handle, "mlx_closure_kwargs_apply"); + mlx_closure_kwargs_apply_ptr = GET_SYM(handle, "mlx_closure_kwargs_apply"); if (mlx_closure_kwargs_apply_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_kwargs_apply\n"); return -1; } - mlx_closure_value_and_grad_new_ptr = dlsym(handle, "mlx_closure_value_and_grad_new"); + mlx_closure_value_and_grad_new_ptr = GET_SYM(handle, "mlx_closure_value_and_grad_new"); if (mlx_closure_value_and_grad_new_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_value_and_grad_new\n"); return -1; } - mlx_closure_value_and_grad_free_ptr = dlsym(handle, "mlx_closure_value_and_grad_free"); + mlx_closure_value_and_grad_free_ptr = GET_SYM(handle, "mlx_closure_value_and_grad_free"); if (mlx_closure_value_and_grad_free_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_value_and_grad_free\n"); return -1; } - mlx_closure_value_and_grad_new_func_ptr = dlsym(handle, "mlx_closure_value_and_grad_new_func"); + mlx_closure_value_and_grad_new_func_ptr = GET_SYM(handle, "mlx_closure_value_and_grad_new_func"); if (mlx_closure_value_and_grad_new_func_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_value_and_grad_new_func\n"); return -1; } - mlx_closure_value_and_grad_new_func_payload_ptr = dlsym(handle, "mlx_closure_value_and_grad_new_func_payload"); + mlx_closure_value_and_grad_new_func_payload_ptr = GET_SYM(handle, "mlx_closure_value_and_grad_new_func_payload"); if (mlx_closure_value_and_grad_new_func_payload_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_value_and_grad_new_func_payload\n"); return -1; } - mlx_closure_value_and_grad_set_ptr = dlsym(handle, "mlx_closure_value_and_grad_set"); + mlx_closure_value_and_grad_set_ptr = GET_SYM(handle, "mlx_closure_value_and_grad_set"); if (mlx_closure_value_and_grad_set_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_value_and_grad_set\n"); return -1; } - mlx_closure_value_and_grad_apply_ptr = dlsym(handle, "mlx_closure_value_and_grad_apply"); + mlx_closure_value_and_grad_apply_ptr = GET_SYM(handle, "mlx_closure_value_and_grad_apply"); if (mlx_closure_value_and_grad_apply_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_value_and_grad_apply\n"); return -1; } - mlx_closure_custom_new_ptr = dlsym(handle, "mlx_closure_custom_new"); + mlx_closure_custom_new_ptr = GET_SYM(handle, "mlx_closure_custom_new"); if (mlx_closure_custom_new_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_custom_new\n"); return -1; } - mlx_closure_custom_free_ptr = dlsym(handle, "mlx_closure_custom_free"); + mlx_closure_custom_free_ptr = GET_SYM(handle, "mlx_closure_custom_free"); if (mlx_closure_custom_free_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_custom_free\n"); return -1; } - mlx_closure_custom_new_func_ptr = dlsym(handle, "mlx_closure_custom_new_func"); + mlx_closure_custom_new_func_ptr = GET_SYM(handle, "mlx_closure_custom_new_func"); if (mlx_closure_custom_new_func_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_custom_new_func\n"); return -1; } - mlx_closure_custom_new_func_payload_ptr = dlsym(handle, "mlx_closure_custom_new_func_payload"); + mlx_closure_custom_new_func_payload_ptr = GET_SYM(handle, "mlx_closure_custom_new_func_payload"); if (mlx_closure_custom_new_func_payload_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_custom_new_func_payload\n"); return -1; } - mlx_closure_custom_set_ptr = dlsym(handle, "mlx_closure_custom_set"); + mlx_closure_custom_set_ptr = GET_SYM(handle, "mlx_closure_custom_set"); if (mlx_closure_custom_set_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_custom_set\n"); return -1; } - mlx_closure_custom_apply_ptr = dlsym(handle, "mlx_closure_custom_apply"); + mlx_closure_custom_apply_ptr = GET_SYM(handle, "mlx_closure_custom_apply"); if (mlx_closure_custom_apply_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_custom_apply\n"); return -1; } - mlx_closure_custom_jvp_new_ptr = dlsym(handle, "mlx_closure_custom_jvp_new"); + mlx_closure_custom_jvp_new_ptr = GET_SYM(handle, "mlx_closure_custom_jvp_new"); if (mlx_closure_custom_jvp_new_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_custom_jvp_new\n"); return -1; } - mlx_closure_custom_jvp_free_ptr = dlsym(handle, "mlx_closure_custom_jvp_free"); + mlx_closure_custom_jvp_free_ptr = GET_SYM(handle, "mlx_closure_custom_jvp_free"); if (mlx_closure_custom_jvp_free_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_custom_jvp_free\n"); return -1; } - mlx_closure_custom_jvp_new_func_ptr = dlsym(handle, "mlx_closure_custom_jvp_new_func"); + mlx_closure_custom_jvp_new_func_ptr = GET_SYM(handle, "mlx_closure_custom_jvp_new_func"); if (mlx_closure_custom_jvp_new_func_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_custom_jvp_new_func\n"); return -1; } - mlx_closure_custom_jvp_new_func_payload_ptr = dlsym(handle, "mlx_closure_custom_jvp_new_func_payload"); + mlx_closure_custom_jvp_new_func_payload_ptr = GET_SYM(handle, "mlx_closure_custom_jvp_new_func_payload"); if (mlx_closure_custom_jvp_new_func_payload_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_custom_jvp_new_func_payload\n"); return -1; } - mlx_closure_custom_jvp_set_ptr = dlsym(handle, "mlx_closure_custom_jvp_set"); + mlx_closure_custom_jvp_set_ptr = GET_SYM(handle, "mlx_closure_custom_jvp_set"); if (mlx_closure_custom_jvp_set_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_custom_jvp_set\n"); return -1; } - mlx_closure_custom_jvp_apply_ptr = dlsym(handle, "mlx_closure_custom_jvp_apply"); + mlx_closure_custom_jvp_apply_ptr = GET_SYM(handle, "mlx_closure_custom_jvp_apply"); if (mlx_closure_custom_jvp_apply_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_custom_jvp_apply\n"); return -1; } - mlx_closure_custom_vmap_new_ptr = dlsym(handle, "mlx_closure_custom_vmap_new"); + mlx_closure_custom_vmap_new_ptr = GET_SYM(handle, "mlx_closure_custom_vmap_new"); if (mlx_closure_custom_vmap_new_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_custom_vmap_new\n"); return -1; } - mlx_closure_custom_vmap_free_ptr = dlsym(handle, "mlx_closure_custom_vmap_free"); + mlx_closure_custom_vmap_free_ptr = GET_SYM(handle, "mlx_closure_custom_vmap_free"); if (mlx_closure_custom_vmap_free_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_custom_vmap_free\n"); return -1; } - mlx_closure_custom_vmap_new_func_ptr = dlsym(handle, "mlx_closure_custom_vmap_new_func"); + mlx_closure_custom_vmap_new_func_ptr = GET_SYM(handle, "mlx_closure_custom_vmap_new_func"); if (mlx_closure_custom_vmap_new_func_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_custom_vmap_new_func\n"); return -1; } - mlx_closure_custom_vmap_new_func_payload_ptr = dlsym(handle, "mlx_closure_custom_vmap_new_func_payload"); + mlx_closure_custom_vmap_new_func_payload_ptr = GET_SYM(handle, "mlx_closure_custom_vmap_new_func_payload"); if (mlx_closure_custom_vmap_new_func_payload_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_custom_vmap_new_func_payload\n"); return -1; } - mlx_closure_custom_vmap_set_ptr = dlsym(handle, "mlx_closure_custom_vmap_set"); + mlx_closure_custom_vmap_set_ptr = GET_SYM(handle, "mlx_closure_custom_vmap_set"); if (mlx_closure_custom_vmap_set_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_custom_vmap_set\n"); return -1; } - mlx_closure_custom_vmap_apply_ptr = dlsym(handle, "mlx_closure_custom_vmap_apply"); + mlx_closure_custom_vmap_apply_ptr = GET_SYM(handle, "mlx_closure_custom_vmap_apply"); if (mlx_closure_custom_vmap_apply_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_custom_vmap_apply\n"); return -1; } - mlx_compile_ptr = dlsym(handle, "mlx_compile"); + mlx_compile_ptr = GET_SYM(handle, "mlx_compile"); if (mlx_compile_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_compile\n"); return -1; } - mlx_detail_compile_ptr = dlsym(handle, "mlx_detail_compile"); + mlx_detail_compile_ptr = GET_SYM(handle, "mlx_detail_compile"); if (mlx_detail_compile_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_detail_compile\n"); return -1; } - mlx_detail_compile_clear_cache_ptr = dlsym(handle, "mlx_detail_compile_clear_cache"); + mlx_detail_compile_clear_cache_ptr = GET_SYM(handle, "mlx_detail_compile_clear_cache"); if (mlx_detail_compile_clear_cache_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_detail_compile_clear_cache\n"); return -1; } - mlx_detail_compile_erase_ptr = dlsym(handle, "mlx_detail_compile_erase"); + mlx_detail_compile_erase_ptr = GET_SYM(handle, "mlx_detail_compile_erase"); if (mlx_detail_compile_erase_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_detail_compile_erase\n"); return -1; } - mlx_disable_compile_ptr = dlsym(handle, "mlx_disable_compile"); + mlx_disable_compile_ptr = GET_SYM(handle, "mlx_disable_compile"); if (mlx_disable_compile_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_disable_compile\n"); return -1; } - mlx_enable_compile_ptr = dlsym(handle, "mlx_enable_compile"); + mlx_enable_compile_ptr = GET_SYM(handle, "mlx_enable_compile"); if (mlx_enable_compile_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_enable_compile\n"); return -1; } - mlx_set_compile_mode_ptr = dlsym(handle, "mlx_set_compile_mode"); + mlx_set_compile_mode_ptr = GET_SYM(handle, "mlx_set_compile_mode"); if (mlx_set_compile_mode_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_set_compile_mode\n"); return -1; } - mlx_cuda_is_available_ptr = dlsym(handle, "mlx_cuda_is_available"); + mlx_cuda_is_available_ptr = GET_SYM(handle, "mlx_cuda_is_available"); if (mlx_cuda_is_available_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_cuda_is_available\n"); return -1; } - mlx_device_new_ptr = dlsym(handle, "mlx_device_new"); + mlx_device_new_ptr = GET_SYM(handle, "mlx_device_new"); if (mlx_device_new_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_device_new\n"); return -1; } - mlx_device_new_type_ptr = dlsym(handle, "mlx_device_new_type"); + mlx_device_new_type_ptr = GET_SYM(handle, "mlx_device_new_type"); if (mlx_device_new_type_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_device_new_type\n"); return -1; } - mlx_device_free_ptr = dlsym(handle, "mlx_device_free"); + mlx_device_free_ptr = GET_SYM(handle, "mlx_device_free"); if (mlx_device_free_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_device_free\n"); return -1; } - mlx_device_set_ptr = dlsym(handle, "mlx_device_set"); + mlx_device_set_ptr = GET_SYM(handle, "mlx_device_set"); if (mlx_device_set_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_device_set\n"); return -1; } - mlx_device_tostring_ptr = dlsym(handle, "mlx_device_tostring"); + mlx_device_tostring_ptr = GET_SYM(handle, "mlx_device_tostring"); if (mlx_device_tostring_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_device_tostring\n"); return -1; } - mlx_device_equal_ptr = dlsym(handle, "mlx_device_equal"); + mlx_device_equal_ptr = GET_SYM(handle, "mlx_device_equal"); if (mlx_device_equal_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_device_equal\n"); return -1; } - mlx_device_get_index_ptr = dlsym(handle, "mlx_device_get_index"); + mlx_device_get_index_ptr = GET_SYM(handle, "mlx_device_get_index"); if (mlx_device_get_index_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_device_get_index\n"); return -1; } - mlx_device_get_type_ptr = dlsym(handle, "mlx_device_get_type"); + mlx_device_get_type_ptr = GET_SYM(handle, "mlx_device_get_type"); if (mlx_device_get_type_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_device_get_type\n"); return -1; } - mlx_get_default_device_ptr = dlsym(handle, "mlx_get_default_device"); + mlx_get_default_device_ptr = GET_SYM(handle, "mlx_get_default_device"); if (mlx_get_default_device_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_get_default_device\n"); return -1; } - mlx_set_default_device_ptr = dlsym(handle, "mlx_set_default_device"); + mlx_set_default_device_ptr = GET_SYM(handle, "mlx_set_default_device"); if (mlx_set_default_device_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_set_default_device\n"); return -1; } - mlx_device_is_available_ptr = dlsym(handle, "mlx_device_is_available"); + mlx_device_is_available_ptr = GET_SYM(handle, "mlx_device_is_available"); if (mlx_device_is_available_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_device_is_available\n"); return -1; } - mlx_device_count_ptr = dlsym(handle, "mlx_device_count"); + mlx_device_count_ptr = GET_SYM(handle, "mlx_device_count"); if (mlx_device_count_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_device_count\n"); return -1; } - mlx_device_info_new_ptr = dlsym(handle, "mlx_device_info_new"); + mlx_device_info_new_ptr = GET_SYM(handle, "mlx_device_info_new"); if (mlx_device_info_new_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_device_info_new\n"); return -1; } - mlx_device_info_get_ptr = dlsym(handle, "mlx_device_info_get"); + mlx_device_info_get_ptr = GET_SYM(handle, "mlx_device_info_get"); if (mlx_device_info_get_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_device_info_get\n"); return -1; } - mlx_device_info_free_ptr = dlsym(handle, "mlx_device_info_free"); + mlx_device_info_free_ptr = GET_SYM(handle, "mlx_device_info_free"); if (mlx_device_info_free_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_device_info_free\n"); return -1; } - mlx_device_info_has_key_ptr = dlsym(handle, "mlx_device_info_has_key"); + mlx_device_info_has_key_ptr = GET_SYM(handle, "mlx_device_info_has_key"); if (mlx_device_info_has_key_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_device_info_has_key\n"); return -1; } - mlx_device_info_is_string_ptr = dlsym(handle, "mlx_device_info_is_string"); + mlx_device_info_is_string_ptr = GET_SYM(handle, "mlx_device_info_is_string"); if (mlx_device_info_is_string_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_device_info_is_string\n"); return -1; } - mlx_device_info_get_string_ptr = dlsym(handle, "mlx_device_info_get_string"); + mlx_device_info_get_string_ptr = GET_SYM(handle, "mlx_device_info_get_string"); if (mlx_device_info_get_string_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_device_info_get_string\n"); return -1; } - mlx_device_info_get_size_ptr = dlsym(handle, "mlx_device_info_get_size"); + mlx_device_info_get_size_ptr = GET_SYM(handle, "mlx_device_info_get_size"); if (mlx_device_info_get_size_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_device_info_get_size\n"); return -1; } - mlx_device_info_get_keys_ptr = dlsym(handle, "mlx_device_info_get_keys"); + mlx_device_info_get_keys_ptr = GET_SYM(handle, "mlx_device_info_get_keys"); if (mlx_device_info_get_keys_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_device_info_get_keys\n"); return -1; } - mlx_distributed_all_gather_ptr = dlsym(handle, "mlx_distributed_all_gather"); + mlx_distributed_all_gather_ptr = GET_SYM(handle, "mlx_distributed_all_gather"); if (mlx_distributed_all_gather_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_distributed_all_gather\n"); return -1; } - mlx_distributed_all_max_ptr = dlsym(handle, "mlx_distributed_all_max"); + mlx_distributed_all_max_ptr = GET_SYM(handle, "mlx_distributed_all_max"); if (mlx_distributed_all_max_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_distributed_all_max\n"); return -1; } - mlx_distributed_all_min_ptr = dlsym(handle, "mlx_distributed_all_min"); + mlx_distributed_all_min_ptr = GET_SYM(handle, "mlx_distributed_all_min"); if (mlx_distributed_all_min_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_distributed_all_min\n"); return -1; } - mlx_distributed_all_sum_ptr = dlsym(handle, "mlx_distributed_all_sum"); + mlx_distributed_all_sum_ptr = GET_SYM(handle, "mlx_distributed_all_sum"); if (mlx_distributed_all_sum_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_distributed_all_sum\n"); return -1; } - mlx_distributed_recv_ptr = dlsym(handle, "mlx_distributed_recv"); + mlx_distributed_recv_ptr = GET_SYM(handle, "mlx_distributed_recv"); if (mlx_distributed_recv_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_distributed_recv\n"); return -1; } - mlx_distributed_recv_like_ptr = dlsym(handle, "mlx_distributed_recv_like"); + mlx_distributed_recv_like_ptr = GET_SYM(handle, "mlx_distributed_recv_like"); if (mlx_distributed_recv_like_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_distributed_recv_like\n"); return -1; } - mlx_distributed_send_ptr = dlsym(handle, "mlx_distributed_send"); + mlx_distributed_send_ptr = GET_SYM(handle, "mlx_distributed_send"); if (mlx_distributed_send_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_distributed_send\n"); return -1; } - mlx_distributed_sum_scatter_ptr = dlsym(handle, "mlx_distributed_sum_scatter"); + mlx_distributed_sum_scatter_ptr = GET_SYM(handle, "mlx_distributed_sum_scatter"); if (mlx_distributed_sum_scatter_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_distributed_sum_scatter\n"); return -1; } - mlx_distributed_group_rank_ptr = dlsym(handle, "mlx_distributed_group_rank"); + mlx_distributed_group_rank_ptr = GET_SYM(handle, "mlx_distributed_group_rank"); if (mlx_distributed_group_rank_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_distributed_group_rank\n"); return -1; } - mlx_distributed_group_size_ptr = dlsym(handle, "mlx_distributed_group_size"); + mlx_distributed_group_size_ptr = GET_SYM(handle, "mlx_distributed_group_size"); if (mlx_distributed_group_size_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_distributed_group_size\n"); return -1; } - mlx_distributed_group_split_ptr = dlsym(handle, "mlx_distributed_group_split"); + mlx_distributed_group_split_ptr = GET_SYM(handle, "mlx_distributed_group_split"); if (mlx_distributed_group_split_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_distributed_group_split\n"); return -1; } - mlx_distributed_is_available_ptr = dlsym(handle, "mlx_distributed_is_available"); + mlx_distributed_is_available_ptr = GET_SYM(handle, "mlx_distributed_is_available"); if (mlx_distributed_is_available_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_distributed_is_available\n"); return -1; } - mlx_distributed_init_ptr = dlsym(handle, "mlx_distributed_init"); + mlx_distributed_init_ptr = GET_SYM(handle, "mlx_distributed_init"); if (mlx_distributed_init_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_distributed_init\n"); return -1; } - mlx_set_error_handler_ptr = dlsym(handle, "mlx_set_error_handler"); + mlx_set_error_handler_ptr = GET_SYM(handle, "mlx_set_error_handler"); if (mlx_set_error_handler_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_set_error_handler\n"); return -1; } - _mlx_error_ptr = dlsym(handle, "_mlx_error"); + _mlx_error_ptr = GET_SYM(handle, "_mlx_error"); if (_mlx_error_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: _mlx_error\n"); return -1; } - mlx_export_function_ptr = dlsym(handle, "mlx_export_function"); + mlx_export_function_ptr = GET_SYM(handle, "mlx_export_function"); if (mlx_export_function_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_export_function\n"); return -1; } - mlx_export_function_kwargs_ptr = dlsym(handle, "mlx_export_function_kwargs"); + mlx_export_function_kwargs_ptr = GET_SYM(handle, "mlx_export_function_kwargs"); if (mlx_export_function_kwargs_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_export_function_kwargs\n"); return -1; } - mlx_function_exporter_new_ptr = dlsym(handle, "mlx_function_exporter_new"); + mlx_function_exporter_new_ptr = GET_SYM(handle, "mlx_function_exporter_new"); if (mlx_function_exporter_new_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_function_exporter_new\n"); return -1; } - mlx_function_exporter_free_ptr = dlsym(handle, "mlx_function_exporter_free"); + mlx_function_exporter_free_ptr = GET_SYM(handle, "mlx_function_exporter_free"); if (mlx_function_exporter_free_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_function_exporter_free\n"); return -1; } - mlx_function_exporter_apply_ptr = dlsym(handle, "mlx_function_exporter_apply"); + mlx_function_exporter_apply_ptr = GET_SYM(handle, "mlx_function_exporter_apply"); if (mlx_function_exporter_apply_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_function_exporter_apply\n"); return -1; } - mlx_function_exporter_apply_kwargs_ptr = dlsym(handle, "mlx_function_exporter_apply_kwargs"); + mlx_function_exporter_apply_kwargs_ptr = GET_SYM(handle, "mlx_function_exporter_apply_kwargs"); if (mlx_function_exporter_apply_kwargs_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_function_exporter_apply_kwargs\n"); return -1; } - mlx_imported_function_new_ptr = dlsym(handle, "mlx_imported_function_new"); + mlx_imported_function_new_ptr = GET_SYM(handle, "mlx_imported_function_new"); if (mlx_imported_function_new_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_imported_function_new\n"); return -1; } - mlx_imported_function_free_ptr = dlsym(handle, "mlx_imported_function_free"); + mlx_imported_function_free_ptr = GET_SYM(handle, "mlx_imported_function_free"); if (mlx_imported_function_free_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_imported_function_free\n"); return -1; } - mlx_imported_function_apply_ptr = dlsym(handle, "mlx_imported_function_apply"); + mlx_imported_function_apply_ptr = GET_SYM(handle, "mlx_imported_function_apply"); if (mlx_imported_function_apply_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_imported_function_apply\n"); return -1; } - mlx_imported_function_apply_kwargs_ptr = dlsym(handle, "mlx_imported_function_apply_kwargs"); + mlx_imported_function_apply_kwargs_ptr = GET_SYM(handle, "mlx_imported_function_apply_kwargs"); if (mlx_imported_function_apply_kwargs_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_imported_function_apply_kwargs\n"); return -1; } - mlx_fast_cuda_kernel_config_new_ptr = dlsym(handle, "mlx_fast_cuda_kernel_config_new"); + mlx_fast_cuda_kernel_config_new_ptr = GET_SYM(handle, "mlx_fast_cuda_kernel_config_new"); if (mlx_fast_cuda_kernel_config_new_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_cuda_kernel_config_new\n"); return -1; } - mlx_fast_cuda_kernel_config_free_ptr = dlsym(handle, "mlx_fast_cuda_kernel_config_free"); + mlx_fast_cuda_kernel_config_free_ptr = GET_SYM(handle, "mlx_fast_cuda_kernel_config_free"); if (mlx_fast_cuda_kernel_config_free_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_cuda_kernel_config_free\n"); return -1; } - mlx_fast_cuda_kernel_config_add_output_arg_ptr = dlsym(handle, "mlx_fast_cuda_kernel_config_add_output_arg"); + mlx_fast_cuda_kernel_config_add_output_arg_ptr = GET_SYM(handle, "mlx_fast_cuda_kernel_config_add_output_arg"); if (mlx_fast_cuda_kernel_config_add_output_arg_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_cuda_kernel_config_add_output_arg\n"); return -1; } - mlx_fast_cuda_kernel_config_set_grid_ptr = dlsym(handle, "mlx_fast_cuda_kernel_config_set_grid"); + mlx_fast_cuda_kernel_config_set_grid_ptr = GET_SYM(handle, "mlx_fast_cuda_kernel_config_set_grid"); if (mlx_fast_cuda_kernel_config_set_grid_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_cuda_kernel_config_set_grid\n"); return -1; } - mlx_fast_cuda_kernel_config_set_thread_group_ptr = dlsym(handle, "mlx_fast_cuda_kernel_config_set_thread_group"); + mlx_fast_cuda_kernel_config_set_thread_group_ptr = GET_SYM(handle, "mlx_fast_cuda_kernel_config_set_thread_group"); if (mlx_fast_cuda_kernel_config_set_thread_group_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_cuda_kernel_config_set_thread_group\n"); return -1; } - mlx_fast_cuda_kernel_config_set_init_value_ptr = dlsym(handle, "mlx_fast_cuda_kernel_config_set_init_value"); + mlx_fast_cuda_kernel_config_set_init_value_ptr = GET_SYM(handle, "mlx_fast_cuda_kernel_config_set_init_value"); if (mlx_fast_cuda_kernel_config_set_init_value_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_cuda_kernel_config_set_init_value\n"); return -1; } - mlx_fast_cuda_kernel_config_set_verbose_ptr = dlsym(handle, "mlx_fast_cuda_kernel_config_set_verbose"); + mlx_fast_cuda_kernel_config_set_verbose_ptr = GET_SYM(handle, "mlx_fast_cuda_kernel_config_set_verbose"); if (mlx_fast_cuda_kernel_config_set_verbose_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_cuda_kernel_config_set_verbose\n"); return -1; } - mlx_fast_cuda_kernel_config_add_template_arg_dtype_ptr = dlsym(handle, "mlx_fast_cuda_kernel_config_add_template_arg_dtype"); + mlx_fast_cuda_kernel_config_add_template_arg_dtype_ptr = GET_SYM(handle, "mlx_fast_cuda_kernel_config_add_template_arg_dtype"); if (mlx_fast_cuda_kernel_config_add_template_arg_dtype_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_cuda_kernel_config_add_template_arg_dtype\n"); return -1; } - mlx_fast_cuda_kernel_config_add_template_arg_int_ptr = dlsym(handle, "mlx_fast_cuda_kernel_config_add_template_arg_int"); + mlx_fast_cuda_kernel_config_add_template_arg_int_ptr = GET_SYM(handle, "mlx_fast_cuda_kernel_config_add_template_arg_int"); if (mlx_fast_cuda_kernel_config_add_template_arg_int_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_cuda_kernel_config_add_template_arg_int\n"); return -1; } - mlx_fast_cuda_kernel_config_add_template_arg_bool_ptr = dlsym(handle, "mlx_fast_cuda_kernel_config_add_template_arg_bool"); + mlx_fast_cuda_kernel_config_add_template_arg_bool_ptr = GET_SYM(handle, "mlx_fast_cuda_kernel_config_add_template_arg_bool"); if (mlx_fast_cuda_kernel_config_add_template_arg_bool_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_cuda_kernel_config_add_template_arg_bool\n"); return -1; } - mlx_fast_cuda_kernel_new_ptr = dlsym(handle, "mlx_fast_cuda_kernel_new"); + mlx_fast_cuda_kernel_new_ptr = GET_SYM(handle, "mlx_fast_cuda_kernel_new"); if (mlx_fast_cuda_kernel_new_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_cuda_kernel_new\n"); return -1; } - mlx_fast_cuda_kernel_free_ptr = dlsym(handle, "mlx_fast_cuda_kernel_free"); + mlx_fast_cuda_kernel_free_ptr = GET_SYM(handle, "mlx_fast_cuda_kernel_free"); if (mlx_fast_cuda_kernel_free_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_cuda_kernel_free\n"); return -1; } - mlx_fast_cuda_kernel_apply_ptr = dlsym(handle, "mlx_fast_cuda_kernel_apply"); + mlx_fast_cuda_kernel_apply_ptr = GET_SYM(handle, "mlx_fast_cuda_kernel_apply"); if (mlx_fast_cuda_kernel_apply_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_cuda_kernel_apply\n"); return -1; } - mlx_fast_layer_norm_ptr = dlsym(handle, "mlx_fast_layer_norm"); + mlx_fast_layer_norm_ptr = GET_SYM(handle, "mlx_fast_layer_norm"); if (mlx_fast_layer_norm_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_layer_norm\n"); return -1; } - mlx_fast_metal_kernel_config_new_ptr = dlsym(handle, "mlx_fast_metal_kernel_config_new"); + mlx_fast_metal_kernel_config_new_ptr = GET_SYM(handle, "mlx_fast_metal_kernel_config_new"); if (mlx_fast_metal_kernel_config_new_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_metal_kernel_config_new\n"); return -1; } - mlx_fast_metal_kernel_config_free_ptr = dlsym(handle, "mlx_fast_metal_kernel_config_free"); + mlx_fast_metal_kernel_config_free_ptr = GET_SYM(handle, "mlx_fast_metal_kernel_config_free"); if (mlx_fast_metal_kernel_config_free_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_metal_kernel_config_free\n"); return -1; } - mlx_fast_metal_kernel_config_add_output_arg_ptr = dlsym(handle, "mlx_fast_metal_kernel_config_add_output_arg"); + mlx_fast_metal_kernel_config_add_output_arg_ptr = GET_SYM(handle, "mlx_fast_metal_kernel_config_add_output_arg"); if (mlx_fast_metal_kernel_config_add_output_arg_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_metal_kernel_config_add_output_arg\n"); return -1; } - mlx_fast_metal_kernel_config_set_grid_ptr = dlsym(handle, "mlx_fast_metal_kernel_config_set_grid"); + mlx_fast_metal_kernel_config_set_grid_ptr = GET_SYM(handle, "mlx_fast_metal_kernel_config_set_grid"); if (mlx_fast_metal_kernel_config_set_grid_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_metal_kernel_config_set_grid\n"); return -1; } - mlx_fast_metal_kernel_config_set_thread_group_ptr = dlsym(handle, "mlx_fast_metal_kernel_config_set_thread_group"); + mlx_fast_metal_kernel_config_set_thread_group_ptr = GET_SYM(handle, "mlx_fast_metal_kernel_config_set_thread_group"); if (mlx_fast_metal_kernel_config_set_thread_group_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_metal_kernel_config_set_thread_group\n"); return -1; } - mlx_fast_metal_kernel_config_set_init_value_ptr = dlsym(handle, "mlx_fast_metal_kernel_config_set_init_value"); + mlx_fast_metal_kernel_config_set_init_value_ptr = GET_SYM(handle, "mlx_fast_metal_kernel_config_set_init_value"); if (mlx_fast_metal_kernel_config_set_init_value_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_metal_kernel_config_set_init_value\n"); return -1; } - mlx_fast_metal_kernel_config_set_verbose_ptr = dlsym(handle, "mlx_fast_metal_kernel_config_set_verbose"); + mlx_fast_metal_kernel_config_set_verbose_ptr = GET_SYM(handle, "mlx_fast_metal_kernel_config_set_verbose"); if (mlx_fast_metal_kernel_config_set_verbose_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_metal_kernel_config_set_verbose\n"); return -1; } - mlx_fast_metal_kernel_config_add_template_arg_dtype_ptr = dlsym(handle, "mlx_fast_metal_kernel_config_add_template_arg_dtype"); + mlx_fast_metal_kernel_config_add_template_arg_dtype_ptr = GET_SYM(handle, "mlx_fast_metal_kernel_config_add_template_arg_dtype"); if (mlx_fast_metal_kernel_config_add_template_arg_dtype_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_metal_kernel_config_add_template_arg_dtype\n"); return -1; } - mlx_fast_metal_kernel_config_add_template_arg_int_ptr = dlsym(handle, "mlx_fast_metal_kernel_config_add_template_arg_int"); + mlx_fast_metal_kernel_config_add_template_arg_int_ptr = GET_SYM(handle, "mlx_fast_metal_kernel_config_add_template_arg_int"); if (mlx_fast_metal_kernel_config_add_template_arg_int_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_metal_kernel_config_add_template_arg_int\n"); return -1; } - mlx_fast_metal_kernel_config_add_template_arg_bool_ptr = dlsym(handle, "mlx_fast_metal_kernel_config_add_template_arg_bool"); + mlx_fast_metal_kernel_config_add_template_arg_bool_ptr = GET_SYM(handle, "mlx_fast_metal_kernel_config_add_template_arg_bool"); if (mlx_fast_metal_kernel_config_add_template_arg_bool_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_metal_kernel_config_add_template_arg_bool\n"); return -1; } - mlx_fast_metal_kernel_new_ptr = dlsym(handle, "mlx_fast_metal_kernel_new"); + mlx_fast_metal_kernel_new_ptr = GET_SYM(handle, "mlx_fast_metal_kernel_new"); if (mlx_fast_metal_kernel_new_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_metal_kernel_new\n"); return -1; } - mlx_fast_metal_kernel_free_ptr = dlsym(handle, "mlx_fast_metal_kernel_free"); + mlx_fast_metal_kernel_free_ptr = GET_SYM(handle, "mlx_fast_metal_kernel_free"); if (mlx_fast_metal_kernel_free_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_metal_kernel_free\n"); return -1; } - mlx_fast_metal_kernel_apply_ptr = dlsym(handle, "mlx_fast_metal_kernel_apply"); + mlx_fast_metal_kernel_apply_ptr = GET_SYM(handle, "mlx_fast_metal_kernel_apply"); if (mlx_fast_metal_kernel_apply_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_metal_kernel_apply\n"); return -1; } - mlx_fast_rms_norm_ptr = dlsym(handle, "mlx_fast_rms_norm"); + mlx_fast_rms_norm_ptr = GET_SYM(handle, "mlx_fast_rms_norm"); if (mlx_fast_rms_norm_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_rms_norm\n"); return -1; } - mlx_fast_rope_ptr = dlsym(handle, "mlx_fast_rope"); + mlx_fast_rope_ptr = GET_SYM(handle, "mlx_fast_rope"); if (mlx_fast_rope_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_rope\n"); return -1; } - mlx_fast_rope_dynamic_ptr = dlsym(handle, "mlx_fast_rope_dynamic"); + mlx_fast_rope_dynamic_ptr = GET_SYM(handle, "mlx_fast_rope_dynamic"); if (mlx_fast_rope_dynamic_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_rope_dynamic\n"); return -1; } - mlx_fast_scaled_dot_product_attention_ptr = dlsym(handle, "mlx_fast_scaled_dot_product_attention"); + mlx_fast_scaled_dot_product_attention_ptr = GET_SYM(handle, "mlx_fast_scaled_dot_product_attention"); if (mlx_fast_scaled_dot_product_attention_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_scaled_dot_product_attention\n"); return -1; } - mlx_fft_fft_ptr = dlsym(handle, "mlx_fft_fft"); + mlx_fft_fft_ptr = GET_SYM(handle, "mlx_fft_fft"); if (mlx_fft_fft_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_fft_fft\n"); return -1; } - mlx_fft_fft2_ptr = dlsym(handle, "mlx_fft_fft2"); + mlx_fft_fft2_ptr = GET_SYM(handle, "mlx_fft_fft2"); if (mlx_fft_fft2_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_fft_fft2\n"); return -1; } - mlx_fft_fftn_ptr = dlsym(handle, "mlx_fft_fftn"); + mlx_fft_fftn_ptr = GET_SYM(handle, "mlx_fft_fftn"); if (mlx_fft_fftn_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_fft_fftn\n"); return -1; } - mlx_fft_fftshift_ptr = dlsym(handle, "mlx_fft_fftshift"); + mlx_fft_fftshift_ptr = GET_SYM(handle, "mlx_fft_fftshift"); if (mlx_fft_fftshift_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_fft_fftshift\n"); return -1; } - mlx_fft_ifft_ptr = dlsym(handle, "mlx_fft_ifft"); + mlx_fft_ifft_ptr = GET_SYM(handle, "mlx_fft_ifft"); if (mlx_fft_ifft_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_fft_ifft\n"); return -1; } - mlx_fft_ifft2_ptr = dlsym(handle, "mlx_fft_ifft2"); + mlx_fft_ifft2_ptr = GET_SYM(handle, "mlx_fft_ifft2"); if (mlx_fft_ifft2_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_fft_ifft2\n"); return -1; } - mlx_fft_ifftn_ptr = dlsym(handle, "mlx_fft_ifftn"); + mlx_fft_ifftn_ptr = GET_SYM(handle, "mlx_fft_ifftn"); if (mlx_fft_ifftn_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_fft_ifftn\n"); return -1; } - mlx_fft_ifftshift_ptr = dlsym(handle, "mlx_fft_ifftshift"); + mlx_fft_ifftshift_ptr = GET_SYM(handle, "mlx_fft_ifftshift"); if (mlx_fft_ifftshift_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_fft_ifftshift\n"); return -1; } - mlx_fft_irfft_ptr = dlsym(handle, "mlx_fft_irfft"); + mlx_fft_irfft_ptr = GET_SYM(handle, "mlx_fft_irfft"); if (mlx_fft_irfft_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_fft_irfft\n"); return -1; } - mlx_fft_irfft2_ptr = dlsym(handle, "mlx_fft_irfft2"); + mlx_fft_irfft2_ptr = GET_SYM(handle, "mlx_fft_irfft2"); if (mlx_fft_irfft2_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_fft_irfft2\n"); return -1; } - mlx_fft_irfftn_ptr = dlsym(handle, "mlx_fft_irfftn"); + mlx_fft_irfftn_ptr = GET_SYM(handle, "mlx_fft_irfftn"); if (mlx_fft_irfftn_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_fft_irfftn\n"); return -1; } - mlx_fft_rfft_ptr = dlsym(handle, "mlx_fft_rfft"); + mlx_fft_rfft_ptr = GET_SYM(handle, "mlx_fft_rfft"); if (mlx_fft_rfft_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_fft_rfft\n"); return -1; } - mlx_fft_rfft2_ptr = dlsym(handle, "mlx_fft_rfft2"); + mlx_fft_rfft2_ptr = GET_SYM(handle, "mlx_fft_rfft2"); if (mlx_fft_rfft2_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_fft_rfft2\n"); return -1; } - mlx_fft_rfftn_ptr = dlsym(handle, "mlx_fft_rfftn"); + mlx_fft_rfftn_ptr = GET_SYM(handle, "mlx_fft_rfftn"); if (mlx_fft_rfftn_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_fft_rfftn\n"); return -1; } - mlx_load_reader_ptr = dlsym(handle, "mlx_load_reader"); + mlx_load_reader_ptr = GET_SYM(handle, "mlx_load_reader"); if (mlx_load_reader_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_load_reader\n"); return -1; } - mlx_load_ptr = dlsym(handle, "mlx_load"); + mlx_load_ptr = GET_SYM(handle, "mlx_load"); if (mlx_load_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_load\n"); return -1; } - mlx_load_safetensors_reader_ptr = dlsym(handle, "mlx_load_safetensors_reader"); + mlx_load_safetensors_reader_ptr = GET_SYM(handle, "mlx_load_safetensors_reader"); if (mlx_load_safetensors_reader_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_load_safetensors_reader\n"); return -1; } - mlx_load_safetensors_ptr = dlsym(handle, "mlx_load_safetensors"); + mlx_load_safetensors_ptr = GET_SYM(handle, "mlx_load_safetensors"); if (mlx_load_safetensors_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_load_safetensors\n"); return -1; } - mlx_save_writer_ptr = dlsym(handle, "mlx_save_writer"); + mlx_save_writer_ptr = GET_SYM(handle, "mlx_save_writer"); if (mlx_save_writer_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_save_writer\n"); return -1; } - mlx_save_ptr = dlsym(handle, "mlx_save"); + mlx_save_ptr = GET_SYM(handle, "mlx_save"); if (mlx_save_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_save\n"); return -1; } - mlx_save_safetensors_writer_ptr = dlsym(handle, "mlx_save_safetensors_writer"); + mlx_save_safetensors_writer_ptr = GET_SYM(handle, "mlx_save_safetensors_writer"); if (mlx_save_safetensors_writer_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_save_safetensors_writer\n"); return -1; } - mlx_save_safetensors_ptr = dlsym(handle, "mlx_save_safetensors"); + mlx_save_safetensors_ptr = GET_SYM(handle, "mlx_save_safetensors"); if (mlx_save_safetensors_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_save_safetensors\n"); return -1; } - mlx_io_reader_new_ptr = dlsym(handle, "mlx_io_reader_new"); + mlx_io_reader_new_ptr = GET_SYM(handle, "mlx_io_reader_new"); if (mlx_io_reader_new_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_io_reader_new\n"); return -1; } - mlx_io_reader_descriptor_ptr = dlsym(handle, "mlx_io_reader_descriptor"); + mlx_io_reader_descriptor_ptr = GET_SYM(handle, "mlx_io_reader_descriptor"); if (mlx_io_reader_descriptor_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_io_reader_descriptor\n"); return -1; } - mlx_io_reader_tostring_ptr = dlsym(handle, "mlx_io_reader_tostring"); + mlx_io_reader_tostring_ptr = GET_SYM(handle, "mlx_io_reader_tostring"); if (mlx_io_reader_tostring_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_io_reader_tostring\n"); return -1; } - mlx_io_reader_free_ptr = dlsym(handle, "mlx_io_reader_free"); + mlx_io_reader_free_ptr = GET_SYM(handle, "mlx_io_reader_free"); if (mlx_io_reader_free_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_io_reader_free\n"); return -1; } - mlx_io_writer_new_ptr = dlsym(handle, "mlx_io_writer_new"); + mlx_io_writer_new_ptr = GET_SYM(handle, "mlx_io_writer_new"); if (mlx_io_writer_new_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_io_writer_new\n"); return -1; } - mlx_io_writer_descriptor_ptr = dlsym(handle, "mlx_io_writer_descriptor"); + mlx_io_writer_descriptor_ptr = GET_SYM(handle, "mlx_io_writer_descriptor"); if (mlx_io_writer_descriptor_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_io_writer_descriptor\n"); return -1; } - mlx_io_writer_tostring_ptr = dlsym(handle, "mlx_io_writer_tostring"); + mlx_io_writer_tostring_ptr = GET_SYM(handle, "mlx_io_writer_tostring"); if (mlx_io_writer_tostring_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_io_writer_tostring\n"); return -1; } - mlx_io_writer_free_ptr = dlsym(handle, "mlx_io_writer_free"); + mlx_io_writer_free_ptr = GET_SYM(handle, "mlx_io_writer_free"); if (mlx_io_writer_free_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_io_writer_free\n"); return -1; } - mlx_linalg_cholesky_ptr = dlsym(handle, "mlx_linalg_cholesky"); + mlx_linalg_cholesky_ptr = GET_SYM(handle, "mlx_linalg_cholesky"); if (mlx_linalg_cholesky_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_linalg_cholesky\n"); return -1; } - mlx_linalg_cholesky_inv_ptr = dlsym(handle, "mlx_linalg_cholesky_inv"); + mlx_linalg_cholesky_inv_ptr = GET_SYM(handle, "mlx_linalg_cholesky_inv"); if (mlx_linalg_cholesky_inv_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_linalg_cholesky_inv\n"); return -1; } - mlx_linalg_cross_ptr = dlsym(handle, "mlx_linalg_cross"); + mlx_linalg_cross_ptr = GET_SYM(handle, "mlx_linalg_cross"); if (mlx_linalg_cross_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_linalg_cross\n"); return -1; } - mlx_linalg_eig_ptr = dlsym(handle, "mlx_linalg_eig"); + mlx_linalg_eig_ptr = GET_SYM(handle, "mlx_linalg_eig"); if (mlx_linalg_eig_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_linalg_eig\n"); return -1; } - mlx_linalg_eigh_ptr = dlsym(handle, "mlx_linalg_eigh"); + mlx_linalg_eigh_ptr = GET_SYM(handle, "mlx_linalg_eigh"); if (mlx_linalg_eigh_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_linalg_eigh\n"); return -1; } - mlx_linalg_eigvals_ptr = dlsym(handle, "mlx_linalg_eigvals"); + mlx_linalg_eigvals_ptr = GET_SYM(handle, "mlx_linalg_eigvals"); if (mlx_linalg_eigvals_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_linalg_eigvals\n"); return -1; } - mlx_linalg_eigvalsh_ptr = dlsym(handle, "mlx_linalg_eigvalsh"); + mlx_linalg_eigvalsh_ptr = GET_SYM(handle, "mlx_linalg_eigvalsh"); if (mlx_linalg_eigvalsh_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_linalg_eigvalsh\n"); return -1; } - mlx_linalg_inv_ptr = dlsym(handle, "mlx_linalg_inv"); + mlx_linalg_inv_ptr = GET_SYM(handle, "mlx_linalg_inv"); if (mlx_linalg_inv_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_linalg_inv\n"); return -1; } - mlx_linalg_lu_ptr = dlsym(handle, "mlx_linalg_lu"); + mlx_linalg_lu_ptr = GET_SYM(handle, "mlx_linalg_lu"); if (mlx_linalg_lu_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_linalg_lu\n"); return -1; } - mlx_linalg_lu_factor_ptr = dlsym(handle, "mlx_linalg_lu_factor"); + mlx_linalg_lu_factor_ptr = GET_SYM(handle, "mlx_linalg_lu_factor"); if (mlx_linalg_lu_factor_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_linalg_lu_factor\n"); return -1; } - mlx_linalg_norm_ptr = dlsym(handle, "mlx_linalg_norm"); + mlx_linalg_norm_ptr = GET_SYM(handle, "mlx_linalg_norm"); if (mlx_linalg_norm_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_linalg_norm\n"); return -1; } - mlx_linalg_norm_matrix_ptr = dlsym(handle, "mlx_linalg_norm_matrix"); + mlx_linalg_norm_matrix_ptr = GET_SYM(handle, "mlx_linalg_norm_matrix"); if (mlx_linalg_norm_matrix_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_linalg_norm_matrix\n"); return -1; } - mlx_linalg_norm_l2_ptr = dlsym(handle, "mlx_linalg_norm_l2"); + mlx_linalg_norm_l2_ptr = GET_SYM(handle, "mlx_linalg_norm_l2"); if (mlx_linalg_norm_l2_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_linalg_norm_l2\n"); return -1; } - mlx_linalg_pinv_ptr = dlsym(handle, "mlx_linalg_pinv"); + mlx_linalg_pinv_ptr = GET_SYM(handle, "mlx_linalg_pinv"); if (mlx_linalg_pinv_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_linalg_pinv\n"); return -1; } - mlx_linalg_qr_ptr = dlsym(handle, "mlx_linalg_qr"); + mlx_linalg_qr_ptr = GET_SYM(handle, "mlx_linalg_qr"); if (mlx_linalg_qr_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_linalg_qr\n"); return -1; } - mlx_linalg_solve_ptr = dlsym(handle, "mlx_linalg_solve"); + mlx_linalg_solve_ptr = GET_SYM(handle, "mlx_linalg_solve"); if (mlx_linalg_solve_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_linalg_solve\n"); return -1; } - mlx_linalg_solve_triangular_ptr = dlsym(handle, "mlx_linalg_solve_triangular"); + mlx_linalg_solve_triangular_ptr = GET_SYM(handle, "mlx_linalg_solve_triangular"); if (mlx_linalg_solve_triangular_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_linalg_solve_triangular\n"); return -1; } - mlx_linalg_svd_ptr = dlsym(handle, "mlx_linalg_svd"); + mlx_linalg_svd_ptr = GET_SYM(handle, "mlx_linalg_svd"); if (mlx_linalg_svd_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_linalg_svd\n"); return -1; } - mlx_linalg_tri_inv_ptr = dlsym(handle, "mlx_linalg_tri_inv"); + mlx_linalg_tri_inv_ptr = GET_SYM(handle, "mlx_linalg_tri_inv"); if (mlx_linalg_tri_inv_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_linalg_tri_inv\n"); return -1; } - mlx_map_string_to_array_new_ptr = dlsym(handle, "mlx_map_string_to_array_new"); + mlx_map_string_to_array_new_ptr = GET_SYM(handle, "mlx_map_string_to_array_new"); if (mlx_map_string_to_array_new_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_map_string_to_array_new\n"); return -1; } - mlx_map_string_to_array_set_ptr = dlsym(handle, "mlx_map_string_to_array_set"); + mlx_map_string_to_array_set_ptr = GET_SYM(handle, "mlx_map_string_to_array_set"); if (mlx_map_string_to_array_set_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_map_string_to_array_set\n"); return -1; } - mlx_map_string_to_array_free_ptr = dlsym(handle, "mlx_map_string_to_array_free"); + mlx_map_string_to_array_free_ptr = GET_SYM(handle, "mlx_map_string_to_array_free"); if (mlx_map_string_to_array_free_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_map_string_to_array_free\n"); return -1; } - mlx_map_string_to_array_insert_ptr = dlsym(handle, "mlx_map_string_to_array_insert"); + mlx_map_string_to_array_insert_ptr = GET_SYM(handle, "mlx_map_string_to_array_insert"); if (mlx_map_string_to_array_insert_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_map_string_to_array_insert\n"); return -1; } - mlx_map_string_to_array_get_ptr = dlsym(handle, "mlx_map_string_to_array_get"); + mlx_map_string_to_array_get_ptr = GET_SYM(handle, "mlx_map_string_to_array_get"); if (mlx_map_string_to_array_get_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_map_string_to_array_get\n"); return -1; } - mlx_map_string_to_array_iterator_new_ptr = dlsym(handle, "mlx_map_string_to_array_iterator_new"); + mlx_map_string_to_array_iterator_new_ptr = GET_SYM(handle, "mlx_map_string_to_array_iterator_new"); if (mlx_map_string_to_array_iterator_new_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_map_string_to_array_iterator_new\n"); return -1; } - mlx_map_string_to_array_iterator_free_ptr = dlsym(handle, "mlx_map_string_to_array_iterator_free"); + mlx_map_string_to_array_iterator_free_ptr = GET_SYM(handle, "mlx_map_string_to_array_iterator_free"); if (mlx_map_string_to_array_iterator_free_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_map_string_to_array_iterator_free\n"); return -1; } - mlx_map_string_to_array_iterator_next_ptr = dlsym(handle, "mlx_map_string_to_array_iterator_next"); + mlx_map_string_to_array_iterator_next_ptr = GET_SYM(handle, "mlx_map_string_to_array_iterator_next"); if (mlx_map_string_to_array_iterator_next_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_map_string_to_array_iterator_next\n"); return -1; } - mlx_map_string_to_string_new_ptr = dlsym(handle, "mlx_map_string_to_string_new"); + mlx_map_string_to_string_new_ptr = GET_SYM(handle, "mlx_map_string_to_string_new"); if (mlx_map_string_to_string_new_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_map_string_to_string_new\n"); return -1; } - mlx_map_string_to_string_set_ptr = dlsym(handle, "mlx_map_string_to_string_set"); + mlx_map_string_to_string_set_ptr = GET_SYM(handle, "mlx_map_string_to_string_set"); if (mlx_map_string_to_string_set_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_map_string_to_string_set\n"); return -1; } - mlx_map_string_to_string_free_ptr = dlsym(handle, "mlx_map_string_to_string_free"); + mlx_map_string_to_string_free_ptr = GET_SYM(handle, "mlx_map_string_to_string_free"); if (mlx_map_string_to_string_free_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_map_string_to_string_free\n"); return -1; } - mlx_map_string_to_string_insert_ptr = dlsym(handle, "mlx_map_string_to_string_insert"); + mlx_map_string_to_string_insert_ptr = GET_SYM(handle, "mlx_map_string_to_string_insert"); if (mlx_map_string_to_string_insert_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_map_string_to_string_insert\n"); return -1; } - mlx_map_string_to_string_get_ptr = dlsym(handle, "mlx_map_string_to_string_get"); + mlx_map_string_to_string_get_ptr = GET_SYM(handle, "mlx_map_string_to_string_get"); if (mlx_map_string_to_string_get_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_map_string_to_string_get\n"); return -1; } - mlx_map_string_to_string_iterator_new_ptr = dlsym(handle, "mlx_map_string_to_string_iterator_new"); + mlx_map_string_to_string_iterator_new_ptr = GET_SYM(handle, "mlx_map_string_to_string_iterator_new"); if (mlx_map_string_to_string_iterator_new_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_map_string_to_string_iterator_new\n"); return -1; } - mlx_map_string_to_string_iterator_free_ptr = dlsym(handle, "mlx_map_string_to_string_iterator_free"); + mlx_map_string_to_string_iterator_free_ptr = GET_SYM(handle, "mlx_map_string_to_string_iterator_free"); if (mlx_map_string_to_string_iterator_free_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_map_string_to_string_iterator_free\n"); return -1; } - mlx_map_string_to_string_iterator_next_ptr = dlsym(handle, "mlx_map_string_to_string_iterator_next"); + mlx_map_string_to_string_iterator_next_ptr = GET_SYM(handle, "mlx_map_string_to_string_iterator_next"); if (mlx_map_string_to_string_iterator_next_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_map_string_to_string_iterator_next\n"); return -1; } - mlx_clear_cache_ptr = dlsym(handle, "mlx_clear_cache"); + mlx_clear_cache_ptr = GET_SYM(handle, "mlx_clear_cache"); if (mlx_clear_cache_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_clear_cache\n"); return -1; } - mlx_get_active_memory_ptr = dlsym(handle, "mlx_get_active_memory"); + mlx_get_active_memory_ptr = GET_SYM(handle, "mlx_get_active_memory"); if (mlx_get_active_memory_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_get_active_memory\n"); return -1; } - mlx_get_cache_memory_ptr = dlsym(handle, "mlx_get_cache_memory"); + mlx_get_cache_memory_ptr = GET_SYM(handle, "mlx_get_cache_memory"); if (mlx_get_cache_memory_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_get_cache_memory\n"); return -1; } - mlx_get_memory_limit_ptr = dlsym(handle, "mlx_get_memory_limit"); + mlx_get_memory_limit_ptr = GET_SYM(handle, "mlx_get_memory_limit"); if (mlx_get_memory_limit_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_get_memory_limit\n"); return -1; } - mlx_get_peak_memory_ptr = dlsym(handle, "mlx_get_peak_memory"); + mlx_get_peak_memory_ptr = GET_SYM(handle, "mlx_get_peak_memory"); if (mlx_get_peak_memory_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_get_peak_memory\n"); return -1; } - mlx_reset_peak_memory_ptr = dlsym(handle, "mlx_reset_peak_memory"); + mlx_reset_peak_memory_ptr = GET_SYM(handle, "mlx_reset_peak_memory"); if (mlx_reset_peak_memory_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_reset_peak_memory\n"); return -1; } - mlx_set_cache_limit_ptr = dlsym(handle, "mlx_set_cache_limit"); + mlx_set_cache_limit_ptr = GET_SYM(handle, "mlx_set_cache_limit"); if (mlx_set_cache_limit_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_set_cache_limit\n"); return -1; } - mlx_set_memory_limit_ptr = dlsym(handle, "mlx_set_memory_limit"); + mlx_set_memory_limit_ptr = GET_SYM(handle, "mlx_set_memory_limit"); if (mlx_set_memory_limit_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_set_memory_limit\n"); return -1; } - mlx_set_wired_limit_ptr = dlsym(handle, "mlx_set_wired_limit"); + mlx_set_wired_limit_ptr = GET_SYM(handle, "mlx_set_wired_limit"); if (mlx_set_wired_limit_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_set_wired_limit\n"); return -1; } - mlx_metal_is_available_ptr = dlsym(handle, "mlx_metal_is_available"); + mlx_metal_is_available_ptr = GET_SYM(handle, "mlx_metal_is_available"); if (mlx_metal_is_available_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_metal_is_available\n"); return -1; } - mlx_metal_start_capture_ptr = dlsym(handle, "mlx_metal_start_capture"); + mlx_metal_start_capture_ptr = GET_SYM(handle, "mlx_metal_start_capture"); if (mlx_metal_start_capture_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_metal_start_capture\n"); return -1; } - mlx_metal_stop_capture_ptr = dlsym(handle, "mlx_metal_stop_capture"); + mlx_metal_stop_capture_ptr = GET_SYM(handle, "mlx_metal_stop_capture"); if (mlx_metal_stop_capture_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_metal_stop_capture\n"); return -1; } - mlx_abs_ptr = dlsym(handle, "mlx_abs"); + mlx_abs_ptr = GET_SYM(handle, "mlx_abs"); if (mlx_abs_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_abs\n"); return -1; } - mlx_add_ptr = dlsym(handle, "mlx_add"); + mlx_add_ptr = GET_SYM(handle, "mlx_add"); if (mlx_add_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_add\n"); return -1; } - mlx_addmm_ptr = dlsym(handle, "mlx_addmm"); + mlx_addmm_ptr = GET_SYM(handle, "mlx_addmm"); if (mlx_addmm_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_addmm\n"); return -1; } - mlx_all_axes_ptr = dlsym(handle, "mlx_all_axes"); + mlx_all_axes_ptr = GET_SYM(handle, "mlx_all_axes"); if (mlx_all_axes_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_all_axes\n"); return -1; } - mlx_all_axis_ptr = dlsym(handle, "mlx_all_axis"); + mlx_all_axis_ptr = GET_SYM(handle, "mlx_all_axis"); if (mlx_all_axis_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_all_axis\n"); return -1; } - mlx_all_ptr = dlsym(handle, "mlx_all"); + mlx_all_ptr = GET_SYM(handle, "mlx_all"); if (mlx_all_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_all\n"); return -1; } - mlx_allclose_ptr = dlsym(handle, "mlx_allclose"); + mlx_allclose_ptr = GET_SYM(handle, "mlx_allclose"); if (mlx_allclose_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_allclose\n"); return -1; } - mlx_any_axes_ptr = dlsym(handle, "mlx_any_axes"); + mlx_any_axes_ptr = GET_SYM(handle, "mlx_any_axes"); if (mlx_any_axes_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_any_axes\n"); return -1; } - mlx_any_axis_ptr = dlsym(handle, "mlx_any_axis"); + mlx_any_axis_ptr = GET_SYM(handle, "mlx_any_axis"); if (mlx_any_axis_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_any_axis\n"); return -1; } - mlx_any_ptr = dlsym(handle, "mlx_any"); + mlx_any_ptr = GET_SYM(handle, "mlx_any"); if (mlx_any_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_any\n"); return -1; } - mlx_arange_ptr = dlsym(handle, "mlx_arange"); + mlx_arange_ptr = GET_SYM(handle, "mlx_arange"); if (mlx_arange_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_arange\n"); return -1; } - mlx_arccos_ptr = dlsym(handle, "mlx_arccos"); + mlx_arccos_ptr = GET_SYM(handle, "mlx_arccos"); if (mlx_arccos_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_arccos\n"); return -1; } - mlx_arccosh_ptr = dlsym(handle, "mlx_arccosh"); + mlx_arccosh_ptr = GET_SYM(handle, "mlx_arccosh"); if (mlx_arccosh_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_arccosh\n"); return -1; } - mlx_arcsin_ptr = dlsym(handle, "mlx_arcsin"); + mlx_arcsin_ptr = GET_SYM(handle, "mlx_arcsin"); if (mlx_arcsin_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_arcsin\n"); return -1; } - mlx_arcsinh_ptr = dlsym(handle, "mlx_arcsinh"); + mlx_arcsinh_ptr = GET_SYM(handle, "mlx_arcsinh"); if (mlx_arcsinh_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_arcsinh\n"); return -1; } - mlx_arctan_ptr = dlsym(handle, "mlx_arctan"); + mlx_arctan_ptr = GET_SYM(handle, "mlx_arctan"); if (mlx_arctan_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_arctan\n"); return -1; } - mlx_arctan2_ptr = dlsym(handle, "mlx_arctan2"); + mlx_arctan2_ptr = GET_SYM(handle, "mlx_arctan2"); if (mlx_arctan2_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_arctan2\n"); return -1; } - mlx_arctanh_ptr = dlsym(handle, "mlx_arctanh"); + mlx_arctanh_ptr = GET_SYM(handle, "mlx_arctanh"); if (mlx_arctanh_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_arctanh\n"); return -1; } - mlx_argmax_axis_ptr = dlsym(handle, "mlx_argmax_axis"); + mlx_argmax_axis_ptr = GET_SYM(handle, "mlx_argmax_axis"); if (mlx_argmax_axis_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_argmax_axis\n"); return -1; } - mlx_argmax_ptr = dlsym(handle, "mlx_argmax"); + mlx_argmax_ptr = GET_SYM(handle, "mlx_argmax"); if (mlx_argmax_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_argmax\n"); return -1; } - mlx_argmin_axis_ptr = dlsym(handle, "mlx_argmin_axis"); + mlx_argmin_axis_ptr = GET_SYM(handle, "mlx_argmin_axis"); if (mlx_argmin_axis_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_argmin_axis\n"); return -1; } - mlx_argmin_ptr = dlsym(handle, "mlx_argmin"); + mlx_argmin_ptr = GET_SYM(handle, "mlx_argmin"); if (mlx_argmin_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_argmin\n"); return -1; } - mlx_argpartition_axis_ptr = dlsym(handle, "mlx_argpartition_axis"); + mlx_argpartition_axis_ptr = GET_SYM(handle, "mlx_argpartition_axis"); if (mlx_argpartition_axis_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_argpartition_axis\n"); return -1; } - mlx_argpartition_ptr = dlsym(handle, "mlx_argpartition"); + mlx_argpartition_ptr = GET_SYM(handle, "mlx_argpartition"); if (mlx_argpartition_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_argpartition\n"); return -1; } - mlx_argsort_axis_ptr = dlsym(handle, "mlx_argsort_axis"); + mlx_argsort_axis_ptr = GET_SYM(handle, "mlx_argsort_axis"); if (mlx_argsort_axis_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_argsort_axis\n"); return -1; } - mlx_argsort_ptr = dlsym(handle, "mlx_argsort"); + mlx_argsort_ptr = GET_SYM(handle, "mlx_argsort"); if (mlx_argsort_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_argsort\n"); return -1; } - mlx_array_equal_ptr = dlsym(handle, "mlx_array_equal"); + mlx_array_equal_ptr = GET_SYM(handle, "mlx_array_equal"); if (mlx_array_equal_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_equal\n"); return -1; } - mlx_as_strided_ptr = dlsym(handle, "mlx_as_strided"); + mlx_as_strided_ptr = GET_SYM(handle, "mlx_as_strided"); if (mlx_as_strided_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_as_strided\n"); return -1; } - mlx_astype_ptr = dlsym(handle, "mlx_astype"); + mlx_astype_ptr = GET_SYM(handle, "mlx_astype"); if (mlx_astype_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_astype\n"); return -1; } - mlx_atleast_1d_ptr = dlsym(handle, "mlx_atleast_1d"); + mlx_atleast_1d_ptr = GET_SYM(handle, "mlx_atleast_1d"); if (mlx_atleast_1d_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_atleast_1d\n"); return -1; } - mlx_atleast_2d_ptr = dlsym(handle, "mlx_atleast_2d"); + mlx_atleast_2d_ptr = GET_SYM(handle, "mlx_atleast_2d"); if (mlx_atleast_2d_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_atleast_2d\n"); return -1; } - mlx_atleast_3d_ptr = dlsym(handle, "mlx_atleast_3d"); + mlx_atleast_3d_ptr = GET_SYM(handle, "mlx_atleast_3d"); if (mlx_atleast_3d_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_atleast_3d\n"); return -1; } - mlx_bitwise_and_ptr = dlsym(handle, "mlx_bitwise_and"); + mlx_bitwise_and_ptr = GET_SYM(handle, "mlx_bitwise_and"); if (mlx_bitwise_and_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_bitwise_and\n"); return -1; } - mlx_bitwise_invert_ptr = dlsym(handle, "mlx_bitwise_invert"); + mlx_bitwise_invert_ptr = GET_SYM(handle, "mlx_bitwise_invert"); if (mlx_bitwise_invert_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_bitwise_invert\n"); return -1; } - mlx_bitwise_or_ptr = dlsym(handle, "mlx_bitwise_or"); + mlx_bitwise_or_ptr = GET_SYM(handle, "mlx_bitwise_or"); if (mlx_bitwise_or_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_bitwise_or\n"); return -1; } - mlx_bitwise_xor_ptr = dlsym(handle, "mlx_bitwise_xor"); + mlx_bitwise_xor_ptr = GET_SYM(handle, "mlx_bitwise_xor"); if (mlx_bitwise_xor_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_bitwise_xor\n"); return -1; } - mlx_block_masked_mm_ptr = dlsym(handle, "mlx_block_masked_mm"); + mlx_block_masked_mm_ptr = GET_SYM(handle, "mlx_block_masked_mm"); if (mlx_block_masked_mm_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_block_masked_mm\n"); return -1; } - mlx_broadcast_arrays_ptr = dlsym(handle, "mlx_broadcast_arrays"); + mlx_broadcast_arrays_ptr = GET_SYM(handle, "mlx_broadcast_arrays"); if (mlx_broadcast_arrays_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_broadcast_arrays\n"); return -1; } - mlx_broadcast_to_ptr = dlsym(handle, "mlx_broadcast_to"); + mlx_broadcast_to_ptr = GET_SYM(handle, "mlx_broadcast_to"); if (mlx_broadcast_to_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_broadcast_to\n"); return -1; } - mlx_ceil_ptr = dlsym(handle, "mlx_ceil"); + mlx_ceil_ptr = GET_SYM(handle, "mlx_ceil"); if (mlx_ceil_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_ceil\n"); return -1; } - mlx_clip_ptr = dlsym(handle, "mlx_clip"); + mlx_clip_ptr = GET_SYM(handle, "mlx_clip"); if (mlx_clip_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_clip\n"); return -1; } - mlx_concatenate_axis_ptr = dlsym(handle, "mlx_concatenate_axis"); + mlx_concatenate_axis_ptr = GET_SYM(handle, "mlx_concatenate_axis"); if (mlx_concatenate_axis_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_concatenate_axis\n"); return -1; } - mlx_concatenate_ptr = dlsym(handle, "mlx_concatenate"); + mlx_concatenate_ptr = GET_SYM(handle, "mlx_concatenate"); if (mlx_concatenate_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_concatenate\n"); return -1; } - mlx_conjugate_ptr = dlsym(handle, "mlx_conjugate"); + mlx_conjugate_ptr = GET_SYM(handle, "mlx_conjugate"); if (mlx_conjugate_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_conjugate\n"); return -1; } - mlx_contiguous_ptr = dlsym(handle, "mlx_contiguous"); + mlx_contiguous_ptr = GET_SYM(handle, "mlx_contiguous"); if (mlx_contiguous_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_contiguous\n"); return -1; } - mlx_conv1d_ptr = dlsym(handle, "mlx_conv1d"); + mlx_conv1d_ptr = GET_SYM(handle, "mlx_conv1d"); if (mlx_conv1d_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_conv1d\n"); return -1; } - mlx_conv2d_ptr = dlsym(handle, "mlx_conv2d"); + mlx_conv2d_ptr = GET_SYM(handle, "mlx_conv2d"); if (mlx_conv2d_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_conv2d\n"); return -1; } - mlx_conv3d_ptr = dlsym(handle, "mlx_conv3d"); + mlx_conv3d_ptr = GET_SYM(handle, "mlx_conv3d"); if (mlx_conv3d_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_conv3d\n"); return -1; } - mlx_conv_general_ptr = dlsym(handle, "mlx_conv_general"); + mlx_conv_general_ptr = GET_SYM(handle, "mlx_conv_general"); if (mlx_conv_general_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_conv_general\n"); return -1; } - mlx_conv_transpose1d_ptr = dlsym(handle, "mlx_conv_transpose1d"); + mlx_conv_transpose1d_ptr = GET_SYM(handle, "mlx_conv_transpose1d"); if (mlx_conv_transpose1d_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_conv_transpose1d\n"); return -1; } - mlx_conv_transpose2d_ptr = dlsym(handle, "mlx_conv_transpose2d"); + mlx_conv_transpose2d_ptr = GET_SYM(handle, "mlx_conv_transpose2d"); if (mlx_conv_transpose2d_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_conv_transpose2d\n"); return -1; } - mlx_conv_transpose3d_ptr = dlsym(handle, "mlx_conv_transpose3d"); + mlx_conv_transpose3d_ptr = GET_SYM(handle, "mlx_conv_transpose3d"); if (mlx_conv_transpose3d_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_conv_transpose3d\n"); return -1; } - mlx_copy_ptr = dlsym(handle, "mlx_copy"); + mlx_copy_ptr = GET_SYM(handle, "mlx_copy"); if (mlx_copy_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_copy\n"); return -1; } - mlx_cos_ptr = dlsym(handle, "mlx_cos"); + mlx_cos_ptr = GET_SYM(handle, "mlx_cos"); if (mlx_cos_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_cos\n"); return -1; } - mlx_cosh_ptr = dlsym(handle, "mlx_cosh"); + mlx_cosh_ptr = GET_SYM(handle, "mlx_cosh"); if (mlx_cosh_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_cosh\n"); return -1; } - mlx_cummax_ptr = dlsym(handle, "mlx_cummax"); + mlx_cummax_ptr = GET_SYM(handle, "mlx_cummax"); if (mlx_cummax_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_cummax\n"); return -1; } - mlx_cummin_ptr = dlsym(handle, "mlx_cummin"); + mlx_cummin_ptr = GET_SYM(handle, "mlx_cummin"); if (mlx_cummin_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_cummin\n"); return -1; } - mlx_cumprod_ptr = dlsym(handle, "mlx_cumprod"); + mlx_cumprod_ptr = GET_SYM(handle, "mlx_cumprod"); if (mlx_cumprod_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_cumprod\n"); return -1; } - mlx_cumsum_ptr = dlsym(handle, "mlx_cumsum"); + mlx_cumsum_ptr = GET_SYM(handle, "mlx_cumsum"); if (mlx_cumsum_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_cumsum\n"); return -1; } - mlx_degrees_ptr = dlsym(handle, "mlx_degrees"); + mlx_degrees_ptr = GET_SYM(handle, "mlx_degrees"); if (mlx_degrees_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_degrees\n"); return -1; } - mlx_depends_ptr = dlsym(handle, "mlx_depends"); + mlx_depends_ptr = GET_SYM(handle, "mlx_depends"); if (mlx_depends_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_depends\n"); return -1; } - mlx_dequantize_ptr = dlsym(handle, "mlx_dequantize"); + mlx_dequantize_ptr = GET_SYM(handle, "mlx_dequantize"); if (mlx_dequantize_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_dequantize\n"); return -1; } - mlx_diag_ptr = dlsym(handle, "mlx_diag"); + mlx_diag_ptr = GET_SYM(handle, "mlx_diag"); if (mlx_diag_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_diag\n"); return -1; } - mlx_diagonal_ptr = dlsym(handle, "mlx_diagonal"); + mlx_diagonal_ptr = GET_SYM(handle, "mlx_diagonal"); if (mlx_diagonal_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_diagonal\n"); return -1; } - mlx_divide_ptr = dlsym(handle, "mlx_divide"); + mlx_divide_ptr = GET_SYM(handle, "mlx_divide"); if (mlx_divide_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_divide\n"); return -1; } - mlx_divmod_ptr = dlsym(handle, "mlx_divmod"); + mlx_divmod_ptr = GET_SYM(handle, "mlx_divmod"); if (mlx_divmod_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_divmod\n"); return -1; } - mlx_einsum_ptr = dlsym(handle, "mlx_einsum"); + mlx_einsum_ptr = GET_SYM(handle, "mlx_einsum"); if (mlx_einsum_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_einsum\n"); return -1; } - mlx_equal_ptr = dlsym(handle, "mlx_equal"); + mlx_equal_ptr = GET_SYM(handle, "mlx_equal"); if (mlx_equal_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_equal\n"); return -1; } - mlx_erf_ptr = dlsym(handle, "mlx_erf"); + mlx_erf_ptr = GET_SYM(handle, "mlx_erf"); if (mlx_erf_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_erf\n"); return -1; } - mlx_erfinv_ptr = dlsym(handle, "mlx_erfinv"); + mlx_erfinv_ptr = GET_SYM(handle, "mlx_erfinv"); if (mlx_erfinv_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_erfinv\n"); return -1; } - mlx_exp_ptr = dlsym(handle, "mlx_exp"); + mlx_exp_ptr = GET_SYM(handle, "mlx_exp"); if (mlx_exp_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_exp\n"); return -1; } - mlx_expand_dims_axes_ptr = dlsym(handle, "mlx_expand_dims_axes"); + mlx_expand_dims_axes_ptr = GET_SYM(handle, "mlx_expand_dims_axes"); if (mlx_expand_dims_axes_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_expand_dims_axes\n"); return -1; } - mlx_expand_dims_ptr = dlsym(handle, "mlx_expand_dims"); + mlx_expand_dims_ptr = GET_SYM(handle, "mlx_expand_dims"); if (mlx_expand_dims_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_expand_dims\n"); return -1; } - mlx_expm1_ptr = dlsym(handle, "mlx_expm1"); + mlx_expm1_ptr = GET_SYM(handle, "mlx_expm1"); if (mlx_expm1_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_expm1\n"); return -1; } - mlx_eye_ptr = dlsym(handle, "mlx_eye"); + mlx_eye_ptr = GET_SYM(handle, "mlx_eye"); if (mlx_eye_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_eye\n"); return -1; } - mlx_flatten_ptr = dlsym(handle, "mlx_flatten"); + mlx_flatten_ptr = GET_SYM(handle, "mlx_flatten"); if (mlx_flatten_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_flatten\n"); return -1; } - mlx_floor_ptr = dlsym(handle, "mlx_floor"); + mlx_floor_ptr = GET_SYM(handle, "mlx_floor"); if (mlx_floor_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_floor\n"); return -1; } - mlx_floor_divide_ptr = dlsym(handle, "mlx_floor_divide"); + mlx_floor_divide_ptr = GET_SYM(handle, "mlx_floor_divide"); if (mlx_floor_divide_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_floor_divide\n"); return -1; } - mlx_from_fp8_ptr = dlsym(handle, "mlx_from_fp8"); + mlx_from_fp8_ptr = GET_SYM(handle, "mlx_from_fp8"); if (mlx_from_fp8_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_from_fp8\n"); return -1; } - mlx_full_ptr = dlsym(handle, "mlx_full"); + mlx_full_ptr = GET_SYM(handle, "mlx_full"); if (mlx_full_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_full\n"); return -1; } - mlx_full_like_ptr = dlsym(handle, "mlx_full_like"); + mlx_full_like_ptr = GET_SYM(handle, "mlx_full_like"); if (mlx_full_like_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_full_like\n"); return -1; } - mlx_gather_ptr = dlsym(handle, "mlx_gather"); + mlx_gather_ptr = GET_SYM(handle, "mlx_gather"); if (mlx_gather_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_gather\n"); return -1; } - mlx_gather_single_ptr = dlsym(handle, "mlx_gather_single"); + mlx_gather_single_ptr = GET_SYM(handle, "mlx_gather_single"); if (mlx_gather_single_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_gather_single\n"); return -1; } - mlx_gather_mm_ptr = dlsym(handle, "mlx_gather_mm"); + mlx_gather_mm_ptr = GET_SYM(handle, "mlx_gather_mm"); if (mlx_gather_mm_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_gather_mm\n"); return -1; } - mlx_gather_qmm_ptr = dlsym(handle, "mlx_gather_qmm"); + mlx_gather_qmm_ptr = GET_SYM(handle, "mlx_gather_qmm"); if (mlx_gather_qmm_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_gather_qmm\n"); return -1; } - mlx_greater_ptr = dlsym(handle, "mlx_greater"); + mlx_greater_ptr = GET_SYM(handle, "mlx_greater"); if (mlx_greater_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_greater\n"); return -1; } - mlx_greater_equal_ptr = dlsym(handle, "mlx_greater_equal"); + mlx_greater_equal_ptr = GET_SYM(handle, "mlx_greater_equal"); if (mlx_greater_equal_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_greater_equal\n"); return -1; } - mlx_hadamard_transform_ptr = dlsym(handle, "mlx_hadamard_transform"); + mlx_hadamard_transform_ptr = GET_SYM(handle, "mlx_hadamard_transform"); if (mlx_hadamard_transform_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_hadamard_transform\n"); return -1; } - mlx_identity_ptr = dlsym(handle, "mlx_identity"); + mlx_identity_ptr = GET_SYM(handle, "mlx_identity"); if (mlx_identity_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_identity\n"); return -1; } - mlx_imag_ptr = dlsym(handle, "mlx_imag"); + mlx_imag_ptr = GET_SYM(handle, "mlx_imag"); if (mlx_imag_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_imag\n"); return -1; } - mlx_inner_ptr = dlsym(handle, "mlx_inner"); + mlx_inner_ptr = GET_SYM(handle, "mlx_inner"); if (mlx_inner_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_inner\n"); return -1; } - mlx_isclose_ptr = dlsym(handle, "mlx_isclose"); + mlx_isclose_ptr = GET_SYM(handle, "mlx_isclose"); if (mlx_isclose_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_isclose\n"); return -1; } - mlx_isfinite_ptr = dlsym(handle, "mlx_isfinite"); + mlx_isfinite_ptr = GET_SYM(handle, "mlx_isfinite"); if (mlx_isfinite_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_isfinite\n"); return -1; } - mlx_isinf_ptr = dlsym(handle, "mlx_isinf"); + mlx_isinf_ptr = GET_SYM(handle, "mlx_isinf"); if (mlx_isinf_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_isinf\n"); return -1; } - mlx_isnan_ptr = dlsym(handle, "mlx_isnan"); + mlx_isnan_ptr = GET_SYM(handle, "mlx_isnan"); if (mlx_isnan_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_isnan\n"); return -1; } - mlx_isneginf_ptr = dlsym(handle, "mlx_isneginf"); + mlx_isneginf_ptr = GET_SYM(handle, "mlx_isneginf"); if (mlx_isneginf_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_isneginf\n"); return -1; } - mlx_isposinf_ptr = dlsym(handle, "mlx_isposinf"); + mlx_isposinf_ptr = GET_SYM(handle, "mlx_isposinf"); if (mlx_isposinf_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_isposinf\n"); return -1; } - mlx_kron_ptr = dlsym(handle, "mlx_kron"); + mlx_kron_ptr = GET_SYM(handle, "mlx_kron"); if (mlx_kron_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_kron\n"); return -1; } - mlx_left_shift_ptr = dlsym(handle, "mlx_left_shift"); + mlx_left_shift_ptr = GET_SYM(handle, "mlx_left_shift"); if (mlx_left_shift_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_left_shift\n"); return -1; } - mlx_less_ptr = dlsym(handle, "mlx_less"); + mlx_less_ptr = GET_SYM(handle, "mlx_less"); if (mlx_less_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_less\n"); return -1; } - mlx_less_equal_ptr = dlsym(handle, "mlx_less_equal"); + mlx_less_equal_ptr = GET_SYM(handle, "mlx_less_equal"); if (mlx_less_equal_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_less_equal\n"); return -1; } - mlx_linspace_ptr = dlsym(handle, "mlx_linspace"); + mlx_linspace_ptr = GET_SYM(handle, "mlx_linspace"); if (mlx_linspace_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_linspace\n"); return -1; } - mlx_log_ptr = dlsym(handle, "mlx_log"); + mlx_log_ptr = GET_SYM(handle, "mlx_log"); if (mlx_log_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_log\n"); return -1; } - mlx_log10_ptr = dlsym(handle, "mlx_log10"); + mlx_log10_ptr = GET_SYM(handle, "mlx_log10"); if (mlx_log10_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_log10\n"); return -1; } - mlx_log1p_ptr = dlsym(handle, "mlx_log1p"); + mlx_log1p_ptr = GET_SYM(handle, "mlx_log1p"); if (mlx_log1p_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_log1p\n"); return -1; } - mlx_log2_ptr = dlsym(handle, "mlx_log2"); + mlx_log2_ptr = GET_SYM(handle, "mlx_log2"); if (mlx_log2_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_log2\n"); return -1; } - mlx_logaddexp_ptr = dlsym(handle, "mlx_logaddexp"); + mlx_logaddexp_ptr = GET_SYM(handle, "mlx_logaddexp"); if (mlx_logaddexp_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_logaddexp\n"); return -1; } - mlx_logcumsumexp_ptr = dlsym(handle, "mlx_logcumsumexp"); + mlx_logcumsumexp_ptr = GET_SYM(handle, "mlx_logcumsumexp"); if (mlx_logcumsumexp_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_logcumsumexp\n"); return -1; } - mlx_logical_and_ptr = dlsym(handle, "mlx_logical_and"); + mlx_logical_and_ptr = GET_SYM(handle, "mlx_logical_and"); if (mlx_logical_and_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_logical_and\n"); return -1; } - mlx_logical_not_ptr = dlsym(handle, "mlx_logical_not"); + mlx_logical_not_ptr = GET_SYM(handle, "mlx_logical_not"); if (mlx_logical_not_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_logical_not\n"); return -1; } - mlx_logical_or_ptr = dlsym(handle, "mlx_logical_or"); + mlx_logical_or_ptr = GET_SYM(handle, "mlx_logical_or"); if (mlx_logical_or_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_logical_or\n"); return -1; } - mlx_logsumexp_axes_ptr = dlsym(handle, "mlx_logsumexp_axes"); + mlx_logsumexp_axes_ptr = GET_SYM(handle, "mlx_logsumexp_axes"); if (mlx_logsumexp_axes_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_logsumexp_axes\n"); return -1; } - mlx_logsumexp_axis_ptr = dlsym(handle, "mlx_logsumexp_axis"); + mlx_logsumexp_axis_ptr = GET_SYM(handle, "mlx_logsumexp_axis"); if (mlx_logsumexp_axis_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_logsumexp_axis\n"); return -1; } - mlx_logsumexp_ptr = dlsym(handle, "mlx_logsumexp"); + mlx_logsumexp_ptr = GET_SYM(handle, "mlx_logsumexp"); if (mlx_logsumexp_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_logsumexp\n"); return -1; } - mlx_masked_scatter_ptr = dlsym(handle, "mlx_masked_scatter"); + mlx_masked_scatter_ptr = GET_SYM(handle, "mlx_masked_scatter"); if (mlx_masked_scatter_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_masked_scatter\n"); return -1; } - mlx_matmul_ptr = dlsym(handle, "mlx_matmul"); + mlx_matmul_ptr = GET_SYM(handle, "mlx_matmul"); if (mlx_matmul_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_matmul\n"); return -1; } - mlx_max_axes_ptr = dlsym(handle, "mlx_max_axes"); + mlx_max_axes_ptr = GET_SYM(handle, "mlx_max_axes"); if (mlx_max_axes_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_max_axes\n"); return -1; } - mlx_max_axis_ptr = dlsym(handle, "mlx_max_axis"); + mlx_max_axis_ptr = GET_SYM(handle, "mlx_max_axis"); if (mlx_max_axis_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_max_axis\n"); return -1; } - mlx_max_ptr = dlsym(handle, "mlx_max"); + mlx_max_ptr = GET_SYM(handle, "mlx_max"); if (mlx_max_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_max\n"); return -1; } - mlx_maximum_ptr = dlsym(handle, "mlx_maximum"); + mlx_maximum_ptr = GET_SYM(handle, "mlx_maximum"); if (mlx_maximum_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_maximum\n"); return -1; } - mlx_mean_axes_ptr = dlsym(handle, "mlx_mean_axes"); + mlx_mean_axes_ptr = GET_SYM(handle, "mlx_mean_axes"); if (mlx_mean_axes_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_mean_axes\n"); return -1; } - mlx_mean_axis_ptr = dlsym(handle, "mlx_mean_axis"); + mlx_mean_axis_ptr = GET_SYM(handle, "mlx_mean_axis"); if (mlx_mean_axis_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_mean_axis\n"); return -1; } - mlx_mean_ptr = dlsym(handle, "mlx_mean"); + mlx_mean_ptr = GET_SYM(handle, "mlx_mean"); if (mlx_mean_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_mean\n"); return -1; } - mlx_median_ptr = dlsym(handle, "mlx_median"); + mlx_median_ptr = GET_SYM(handle, "mlx_median"); if (mlx_median_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_median\n"); return -1; } - mlx_meshgrid_ptr = dlsym(handle, "mlx_meshgrid"); + mlx_meshgrid_ptr = GET_SYM(handle, "mlx_meshgrid"); if (mlx_meshgrid_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_meshgrid\n"); return -1; } - mlx_min_axes_ptr = dlsym(handle, "mlx_min_axes"); + mlx_min_axes_ptr = GET_SYM(handle, "mlx_min_axes"); if (mlx_min_axes_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_min_axes\n"); return -1; } - mlx_min_axis_ptr = dlsym(handle, "mlx_min_axis"); + mlx_min_axis_ptr = GET_SYM(handle, "mlx_min_axis"); if (mlx_min_axis_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_min_axis\n"); return -1; } - mlx_min_ptr = dlsym(handle, "mlx_min"); + mlx_min_ptr = GET_SYM(handle, "mlx_min"); if (mlx_min_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_min\n"); return -1; } - mlx_minimum_ptr = dlsym(handle, "mlx_minimum"); + mlx_minimum_ptr = GET_SYM(handle, "mlx_minimum"); if (mlx_minimum_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_minimum\n"); return -1; } - mlx_moveaxis_ptr = dlsym(handle, "mlx_moveaxis"); + mlx_moveaxis_ptr = GET_SYM(handle, "mlx_moveaxis"); if (mlx_moveaxis_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_moveaxis\n"); return -1; } - mlx_multiply_ptr = dlsym(handle, "mlx_multiply"); + mlx_multiply_ptr = GET_SYM(handle, "mlx_multiply"); if (mlx_multiply_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_multiply\n"); return -1; } - mlx_nan_to_num_ptr = dlsym(handle, "mlx_nan_to_num"); + mlx_nan_to_num_ptr = GET_SYM(handle, "mlx_nan_to_num"); if (mlx_nan_to_num_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_nan_to_num\n"); return -1; } - mlx_negative_ptr = dlsym(handle, "mlx_negative"); + mlx_negative_ptr = GET_SYM(handle, "mlx_negative"); if (mlx_negative_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_negative\n"); return -1; } - mlx_not_equal_ptr = dlsym(handle, "mlx_not_equal"); + mlx_not_equal_ptr = GET_SYM(handle, "mlx_not_equal"); if (mlx_not_equal_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_not_equal\n"); return -1; } - mlx_number_of_elements_ptr = dlsym(handle, "mlx_number_of_elements"); + mlx_number_of_elements_ptr = GET_SYM(handle, "mlx_number_of_elements"); if (mlx_number_of_elements_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_number_of_elements\n"); return -1; } - mlx_ones_ptr = dlsym(handle, "mlx_ones"); + mlx_ones_ptr = GET_SYM(handle, "mlx_ones"); if (mlx_ones_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_ones\n"); return -1; } - mlx_ones_like_ptr = dlsym(handle, "mlx_ones_like"); + mlx_ones_like_ptr = GET_SYM(handle, "mlx_ones_like"); if (mlx_ones_like_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_ones_like\n"); return -1; } - mlx_outer_ptr = dlsym(handle, "mlx_outer"); + mlx_outer_ptr = GET_SYM(handle, "mlx_outer"); if (mlx_outer_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_outer\n"); return -1; } - mlx_pad_ptr = dlsym(handle, "mlx_pad"); + mlx_pad_ptr = GET_SYM(handle, "mlx_pad"); if (mlx_pad_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_pad\n"); return -1; } - mlx_pad_symmetric_ptr = dlsym(handle, "mlx_pad_symmetric"); + mlx_pad_symmetric_ptr = GET_SYM(handle, "mlx_pad_symmetric"); if (mlx_pad_symmetric_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_pad_symmetric\n"); return -1; } - mlx_partition_axis_ptr = dlsym(handle, "mlx_partition_axis"); + mlx_partition_axis_ptr = GET_SYM(handle, "mlx_partition_axis"); if (mlx_partition_axis_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_partition_axis\n"); return -1; } - mlx_partition_ptr = dlsym(handle, "mlx_partition"); + mlx_partition_ptr = GET_SYM(handle, "mlx_partition"); if (mlx_partition_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_partition\n"); return -1; } - mlx_power_ptr = dlsym(handle, "mlx_power"); + mlx_power_ptr = GET_SYM(handle, "mlx_power"); if (mlx_power_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_power\n"); return -1; } - mlx_prod_axes_ptr = dlsym(handle, "mlx_prod_axes"); + mlx_prod_axes_ptr = GET_SYM(handle, "mlx_prod_axes"); if (mlx_prod_axes_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_prod_axes\n"); return -1; } - mlx_prod_axis_ptr = dlsym(handle, "mlx_prod_axis"); + mlx_prod_axis_ptr = GET_SYM(handle, "mlx_prod_axis"); if (mlx_prod_axis_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_prod_axis\n"); return -1; } - mlx_prod_ptr = dlsym(handle, "mlx_prod"); + mlx_prod_ptr = GET_SYM(handle, "mlx_prod"); if (mlx_prod_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_prod\n"); return -1; } - mlx_put_along_axis_ptr = dlsym(handle, "mlx_put_along_axis"); + mlx_put_along_axis_ptr = GET_SYM(handle, "mlx_put_along_axis"); if (mlx_put_along_axis_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_put_along_axis\n"); return -1; } - mlx_qqmm_ptr = dlsym(handle, "mlx_qqmm"); + mlx_qqmm_ptr = GET_SYM(handle, "mlx_qqmm"); if (mlx_qqmm_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_qqmm\n"); return -1; } - mlx_quantize_ptr = dlsym(handle, "mlx_quantize"); + mlx_quantize_ptr = GET_SYM(handle, "mlx_quantize"); if (mlx_quantize_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_quantize\n"); return -1; } - mlx_quantized_matmul_ptr = dlsym(handle, "mlx_quantized_matmul"); + mlx_quantized_matmul_ptr = GET_SYM(handle, "mlx_quantized_matmul"); if (mlx_quantized_matmul_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_quantized_matmul\n"); return -1; } - mlx_radians_ptr = dlsym(handle, "mlx_radians"); + mlx_radians_ptr = GET_SYM(handle, "mlx_radians"); if (mlx_radians_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_radians\n"); return -1; } - mlx_real_ptr = dlsym(handle, "mlx_real"); + mlx_real_ptr = GET_SYM(handle, "mlx_real"); if (mlx_real_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_real\n"); return -1; } - mlx_reciprocal_ptr = dlsym(handle, "mlx_reciprocal"); + mlx_reciprocal_ptr = GET_SYM(handle, "mlx_reciprocal"); if (mlx_reciprocal_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_reciprocal\n"); return -1; } - mlx_remainder_ptr = dlsym(handle, "mlx_remainder"); + mlx_remainder_ptr = GET_SYM(handle, "mlx_remainder"); if (mlx_remainder_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_remainder\n"); return -1; } - mlx_repeat_axis_ptr = dlsym(handle, "mlx_repeat_axis"); + mlx_repeat_axis_ptr = GET_SYM(handle, "mlx_repeat_axis"); if (mlx_repeat_axis_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_repeat_axis\n"); return -1; } - mlx_repeat_ptr = dlsym(handle, "mlx_repeat"); + mlx_repeat_ptr = GET_SYM(handle, "mlx_repeat"); if (mlx_repeat_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_repeat\n"); return -1; } - mlx_reshape_ptr = dlsym(handle, "mlx_reshape"); + mlx_reshape_ptr = GET_SYM(handle, "mlx_reshape"); if (mlx_reshape_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_reshape\n"); return -1; } - mlx_right_shift_ptr = dlsym(handle, "mlx_right_shift"); + mlx_right_shift_ptr = GET_SYM(handle, "mlx_right_shift"); if (mlx_right_shift_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_right_shift\n"); return -1; } - mlx_roll_axis_ptr = dlsym(handle, "mlx_roll_axis"); + mlx_roll_axis_ptr = GET_SYM(handle, "mlx_roll_axis"); if (mlx_roll_axis_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_roll_axis\n"); return -1; } - mlx_roll_axes_ptr = dlsym(handle, "mlx_roll_axes"); + mlx_roll_axes_ptr = GET_SYM(handle, "mlx_roll_axes"); if (mlx_roll_axes_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_roll_axes\n"); return -1; } - mlx_roll_ptr = dlsym(handle, "mlx_roll"); + mlx_roll_ptr = GET_SYM(handle, "mlx_roll"); if (mlx_roll_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_roll\n"); return -1; } - mlx_round_ptr = dlsym(handle, "mlx_round"); + mlx_round_ptr = GET_SYM(handle, "mlx_round"); if (mlx_round_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_round\n"); return -1; } - mlx_rsqrt_ptr = dlsym(handle, "mlx_rsqrt"); + mlx_rsqrt_ptr = GET_SYM(handle, "mlx_rsqrt"); if (mlx_rsqrt_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_rsqrt\n"); return -1; } - mlx_scatter_ptr = dlsym(handle, "mlx_scatter"); + mlx_scatter_ptr = GET_SYM(handle, "mlx_scatter"); if (mlx_scatter_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_scatter\n"); return -1; } - mlx_scatter_single_ptr = dlsym(handle, "mlx_scatter_single"); + mlx_scatter_single_ptr = GET_SYM(handle, "mlx_scatter_single"); if (mlx_scatter_single_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_scatter_single\n"); return -1; } - mlx_scatter_add_ptr = dlsym(handle, "mlx_scatter_add"); + mlx_scatter_add_ptr = GET_SYM(handle, "mlx_scatter_add"); if (mlx_scatter_add_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_scatter_add\n"); return -1; } - mlx_scatter_add_single_ptr = dlsym(handle, "mlx_scatter_add_single"); + mlx_scatter_add_single_ptr = GET_SYM(handle, "mlx_scatter_add_single"); if (mlx_scatter_add_single_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_scatter_add_single\n"); return -1; } - mlx_scatter_add_axis_ptr = dlsym(handle, "mlx_scatter_add_axis"); + mlx_scatter_add_axis_ptr = GET_SYM(handle, "mlx_scatter_add_axis"); if (mlx_scatter_add_axis_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_scatter_add_axis\n"); return -1; } - mlx_scatter_max_ptr = dlsym(handle, "mlx_scatter_max"); + mlx_scatter_max_ptr = GET_SYM(handle, "mlx_scatter_max"); if (mlx_scatter_max_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_scatter_max\n"); return -1; } - mlx_scatter_max_single_ptr = dlsym(handle, "mlx_scatter_max_single"); + mlx_scatter_max_single_ptr = GET_SYM(handle, "mlx_scatter_max_single"); if (mlx_scatter_max_single_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_scatter_max_single\n"); return -1; } - mlx_scatter_min_ptr = dlsym(handle, "mlx_scatter_min"); + mlx_scatter_min_ptr = GET_SYM(handle, "mlx_scatter_min"); if (mlx_scatter_min_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_scatter_min\n"); return -1; } - mlx_scatter_min_single_ptr = dlsym(handle, "mlx_scatter_min_single"); + mlx_scatter_min_single_ptr = GET_SYM(handle, "mlx_scatter_min_single"); if (mlx_scatter_min_single_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_scatter_min_single\n"); return -1; } - mlx_scatter_prod_ptr = dlsym(handle, "mlx_scatter_prod"); + mlx_scatter_prod_ptr = GET_SYM(handle, "mlx_scatter_prod"); if (mlx_scatter_prod_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_scatter_prod\n"); return -1; } - mlx_scatter_prod_single_ptr = dlsym(handle, "mlx_scatter_prod_single"); + mlx_scatter_prod_single_ptr = GET_SYM(handle, "mlx_scatter_prod_single"); if (mlx_scatter_prod_single_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_scatter_prod_single\n"); return -1; } - mlx_segmented_mm_ptr = dlsym(handle, "mlx_segmented_mm"); + mlx_segmented_mm_ptr = GET_SYM(handle, "mlx_segmented_mm"); if (mlx_segmented_mm_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_segmented_mm\n"); return -1; } - mlx_sigmoid_ptr = dlsym(handle, "mlx_sigmoid"); + mlx_sigmoid_ptr = GET_SYM(handle, "mlx_sigmoid"); if (mlx_sigmoid_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_sigmoid\n"); return -1; } - mlx_sign_ptr = dlsym(handle, "mlx_sign"); + mlx_sign_ptr = GET_SYM(handle, "mlx_sign"); if (mlx_sign_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_sign\n"); return -1; } - mlx_sin_ptr = dlsym(handle, "mlx_sin"); + mlx_sin_ptr = GET_SYM(handle, "mlx_sin"); if (mlx_sin_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_sin\n"); return -1; } - mlx_sinh_ptr = dlsym(handle, "mlx_sinh"); + mlx_sinh_ptr = GET_SYM(handle, "mlx_sinh"); if (mlx_sinh_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_sinh\n"); return -1; } - mlx_slice_ptr = dlsym(handle, "mlx_slice"); + mlx_slice_ptr = GET_SYM(handle, "mlx_slice"); if (mlx_slice_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_slice\n"); return -1; } - mlx_slice_dynamic_ptr = dlsym(handle, "mlx_slice_dynamic"); + mlx_slice_dynamic_ptr = GET_SYM(handle, "mlx_slice_dynamic"); if (mlx_slice_dynamic_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_slice_dynamic\n"); return -1; } - mlx_slice_update_ptr = dlsym(handle, "mlx_slice_update"); + mlx_slice_update_ptr = GET_SYM(handle, "mlx_slice_update"); if (mlx_slice_update_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_slice_update\n"); return -1; } - mlx_slice_update_dynamic_ptr = dlsym(handle, "mlx_slice_update_dynamic"); + mlx_slice_update_dynamic_ptr = GET_SYM(handle, "mlx_slice_update_dynamic"); if (mlx_slice_update_dynamic_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_slice_update_dynamic\n"); return -1; } - mlx_softmax_axes_ptr = dlsym(handle, "mlx_softmax_axes"); + mlx_softmax_axes_ptr = GET_SYM(handle, "mlx_softmax_axes"); if (mlx_softmax_axes_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_softmax_axes\n"); return -1; } - mlx_softmax_axis_ptr = dlsym(handle, "mlx_softmax_axis"); + mlx_softmax_axis_ptr = GET_SYM(handle, "mlx_softmax_axis"); if (mlx_softmax_axis_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_softmax_axis\n"); return -1; } - mlx_softmax_ptr = dlsym(handle, "mlx_softmax"); + mlx_softmax_ptr = GET_SYM(handle, "mlx_softmax"); if (mlx_softmax_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_softmax\n"); return -1; } - mlx_sort_axis_ptr = dlsym(handle, "mlx_sort_axis"); + mlx_sort_axis_ptr = GET_SYM(handle, "mlx_sort_axis"); if (mlx_sort_axis_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_sort_axis\n"); return -1; } - mlx_sort_ptr = dlsym(handle, "mlx_sort"); + mlx_sort_ptr = GET_SYM(handle, "mlx_sort"); if (mlx_sort_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_sort\n"); return -1; } - mlx_split_ptr = dlsym(handle, "mlx_split"); + mlx_split_ptr = GET_SYM(handle, "mlx_split"); if (mlx_split_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_split\n"); return -1; } - mlx_split_sections_ptr = dlsym(handle, "mlx_split_sections"); + mlx_split_sections_ptr = GET_SYM(handle, "mlx_split_sections"); if (mlx_split_sections_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_split_sections\n"); return -1; } - mlx_sqrt_ptr = dlsym(handle, "mlx_sqrt"); + mlx_sqrt_ptr = GET_SYM(handle, "mlx_sqrt"); if (mlx_sqrt_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_sqrt\n"); return -1; } - mlx_square_ptr = dlsym(handle, "mlx_square"); + mlx_square_ptr = GET_SYM(handle, "mlx_square"); if (mlx_square_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_square\n"); return -1; } - mlx_squeeze_axes_ptr = dlsym(handle, "mlx_squeeze_axes"); + mlx_squeeze_axes_ptr = GET_SYM(handle, "mlx_squeeze_axes"); if (mlx_squeeze_axes_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_squeeze_axes\n"); return -1; } - mlx_squeeze_axis_ptr = dlsym(handle, "mlx_squeeze_axis"); + mlx_squeeze_axis_ptr = GET_SYM(handle, "mlx_squeeze_axis"); if (mlx_squeeze_axis_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_squeeze_axis\n"); return -1; } - mlx_squeeze_ptr = dlsym(handle, "mlx_squeeze"); + mlx_squeeze_ptr = GET_SYM(handle, "mlx_squeeze"); if (mlx_squeeze_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_squeeze\n"); return -1; } - mlx_stack_axis_ptr = dlsym(handle, "mlx_stack_axis"); + mlx_stack_axis_ptr = GET_SYM(handle, "mlx_stack_axis"); if (mlx_stack_axis_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_stack_axis\n"); return -1; } - mlx_stack_ptr = dlsym(handle, "mlx_stack"); + mlx_stack_ptr = GET_SYM(handle, "mlx_stack"); if (mlx_stack_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_stack\n"); return -1; } - mlx_std_axes_ptr = dlsym(handle, "mlx_std_axes"); + mlx_std_axes_ptr = GET_SYM(handle, "mlx_std_axes"); if (mlx_std_axes_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_std_axes\n"); return -1; } - mlx_std_axis_ptr = dlsym(handle, "mlx_std_axis"); + mlx_std_axis_ptr = GET_SYM(handle, "mlx_std_axis"); if (mlx_std_axis_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_std_axis\n"); return -1; } - mlx_std_ptr = dlsym(handle, "mlx_std"); + mlx_std_ptr = GET_SYM(handle, "mlx_std"); if (mlx_std_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_std\n"); return -1; } - mlx_stop_gradient_ptr = dlsym(handle, "mlx_stop_gradient"); + mlx_stop_gradient_ptr = GET_SYM(handle, "mlx_stop_gradient"); if (mlx_stop_gradient_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_stop_gradient\n"); return -1; } - mlx_subtract_ptr = dlsym(handle, "mlx_subtract"); + mlx_subtract_ptr = GET_SYM(handle, "mlx_subtract"); if (mlx_subtract_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_subtract\n"); return -1; } - mlx_sum_axes_ptr = dlsym(handle, "mlx_sum_axes"); + mlx_sum_axes_ptr = GET_SYM(handle, "mlx_sum_axes"); if (mlx_sum_axes_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_sum_axes\n"); return -1; } - mlx_sum_axis_ptr = dlsym(handle, "mlx_sum_axis"); + mlx_sum_axis_ptr = GET_SYM(handle, "mlx_sum_axis"); if (mlx_sum_axis_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_sum_axis\n"); return -1; } - mlx_sum_ptr = dlsym(handle, "mlx_sum"); + mlx_sum_ptr = GET_SYM(handle, "mlx_sum"); if (mlx_sum_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_sum\n"); return -1; } - mlx_swapaxes_ptr = dlsym(handle, "mlx_swapaxes"); + mlx_swapaxes_ptr = GET_SYM(handle, "mlx_swapaxes"); if (mlx_swapaxes_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_swapaxes\n"); return -1; } - mlx_take_axis_ptr = dlsym(handle, "mlx_take_axis"); + mlx_take_axis_ptr = GET_SYM(handle, "mlx_take_axis"); if (mlx_take_axis_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_take_axis\n"); return -1; } - mlx_take_ptr = dlsym(handle, "mlx_take"); + mlx_take_ptr = GET_SYM(handle, "mlx_take"); if (mlx_take_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_take\n"); return -1; } - mlx_take_along_axis_ptr = dlsym(handle, "mlx_take_along_axis"); + mlx_take_along_axis_ptr = GET_SYM(handle, "mlx_take_along_axis"); if (mlx_take_along_axis_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_take_along_axis\n"); return -1; } - mlx_tan_ptr = dlsym(handle, "mlx_tan"); + mlx_tan_ptr = GET_SYM(handle, "mlx_tan"); if (mlx_tan_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_tan\n"); return -1; } - mlx_tanh_ptr = dlsym(handle, "mlx_tanh"); + mlx_tanh_ptr = GET_SYM(handle, "mlx_tanh"); if (mlx_tanh_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_tanh\n"); return -1; } - mlx_tensordot_ptr = dlsym(handle, "mlx_tensordot"); + mlx_tensordot_ptr = GET_SYM(handle, "mlx_tensordot"); if (mlx_tensordot_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_tensordot\n"); return -1; } - mlx_tensordot_axis_ptr = dlsym(handle, "mlx_tensordot_axis"); + mlx_tensordot_axis_ptr = GET_SYM(handle, "mlx_tensordot_axis"); if (mlx_tensordot_axis_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_tensordot_axis\n"); return -1; } - mlx_tile_ptr = dlsym(handle, "mlx_tile"); + mlx_tile_ptr = GET_SYM(handle, "mlx_tile"); if (mlx_tile_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_tile\n"); return -1; } - mlx_to_fp8_ptr = dlsym(handle, "mlx_to_fp8"); + mlx_to_fp8_ptr = GET_SYM(handle, "mlx_to_fp8"); if (mlx_to_fp8_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_to_fp8\n"); return -1; } - mlx_topk_axis_ptr = dlsym(handle, "mlx_topk_axis"); + mlx_topk_axis_ptr = GET_SYM(handle, "mlx_topk_axis"); if (mlx_topk_axis_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_topk_axis\n"); return -1; } - mlx_topk_ptr = dlsym(handle, "mlx_topk"); + mlx_topk_ptr = GET_SYM(handle, "mlx_topk"); if (mlx_topk_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_topk\n"); return -1; } - mlx_trace_ptr = dlsym(handle, "mlx_trace"); + mlx_trace_ptr = GET_SYM(handle, "mlx_trace"); if (mlx_trace_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_trace\n"); return -1; } - mlx_transpose_axes_ptr = dlsym(handle, "mlx_transpose_axes"); + mlx_transpose_axes_ptr = GET_SYM(handle, "mlx_transpose_axes"); if (mlx_transpose_axes_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_transpose_axes\n"); return -1; } - mlx_transpose_ptr = dlsym(handle, "mlx_transpose"); + mlx_transpose_ptr = GET_SYM(handle, "mlx_transpose"); if (mlx_transpose_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_transpose\n"); return -1; } - mlx_tri_ptr = dlsym(handle, "mlx_tri"); + mlx_tri_ptr = GET_SYM(handle, "mlx_tri"); if (mlx_tri_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_tri\n"); return -1; } - mlx_tril_ptr = dlsym(handle, "mlx_tril"); + mlx_tril_ptr = GET_SYM(handle, "mlx_tril"); if (mlx_tril_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_tril\n"); return -1; } - mlx_triu_ptr = dlsym(handle, "mlx_triu"); + mlx_triu_ptr = GET_SYM(handle, "mlx_triu"); if (mlx_triu_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_triu\n"); return -1; } - mlx_unflatten_ptr = dlsym(handle, "mlx_unflatten"); + mlx_unflatten_ptr = GET_SYM(handle, "mlx_unflatten"); if (mlx_unflatten_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_unflatten\n"); return -1; } - mlx_var_axes_ptr = dlsym(handle, "mlx_var_axes"); + mlx_var_axes_ptr = GET_SYM(handle, "mlx_var_axes"); if (mlx_var_axes_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_var_axes\n"); return -1; } - mlx_var_axis_ptr = dlsym(handle, "mlx_var_axis"); + mlx_var_axis_ptr = GET_SYM(handle, "mlx_var_axis"); if (mlx_var_axis_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_var_axis\n"); return -1; } - mlx_var_ptr = dlsym(handle, "mlx_var"); + mlx_var_ptr = GET_SYM(handle, "mlx_var"); if (mlx_var_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_var\n"); return -1; } - mlx_view_ptr = dlsym(handle, "mlx_view"); + mlx_view_ptr = GET_SYM(handle, "mlx_view"); if (mlx_view_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_view\n"); return -1; } - mlx_where_ptr = dlsym(handle, "mlx_where"); + mlx_where_ptr = GET_SYM(handle, "mlx_where"); if (mlx_where_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_where\n"); return -1; } - mlx_zeros_ptr = dlsym(handle, "mlx_zeros"); + mlx_zeros_ptr = GET_SYM(handle, "mlx_zeros"); if (mlx_zeros_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_zeros\n"); return -1; } - mlx_zeros_like_ptr = dlsym(handle, "mlx_zeros_like"); + mlx_zeros_like_ptr = GET_SYM(handle, "mlx_zeros_like"); if (mlx_zeros_like_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_zeros_like\n"); return -1; } - mlx_random_bernoulli_ptr = dlsym(handle, "mlx_random_bernoulli"); + mlx_random_bernoulli_ptr = GET_SYM(handle, "mlx_random_bernoulli"); if (mlx_random_bernoulli_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_random_bernoulli\n"); return -1; } - mlx_random_bits_ptr = dlsym(handle, "mlx_random_bits"); + mlx_random_bits_ptr = GET_SYM(handle, "mlx_random_bits"); if (mlx_random_bits_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_random_bits\n"); return -1; } - mlx_random_categorical_shape_ptr = dlsym(handle, "mlx_random_categorical_shape"); + mlx_random_categorical_shape_ptr = GET_SYM(handle, "mlx_random_categorical_shape"); if (mlx_random_categorical_shape_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_random_categorical_shape\n"); return -1; } - mlx_random_categorical_num_samples_ptr = dlsym(handle, "mlx_random_categorical_num_samples"); + mlx_random_categorical_num_samples_ptr = GET_SYM(handle, "mlx_random_categorical_num_samples"); if (mlx_random_categorical_num_samples_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_random_categorical_num_samples\n"); return -1; } - mlx_random_categorical_ptr = dlsym(handle, "mlx_random_categorical"); + mlx_random_categorical_ptr = GET_SYM(handle, "mlx_random_categorical"); if (mlx_random_categorical_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_random_categorical\n"); return -1; } - mlx_random_gumbel_ptr = dlsym(handle, "mlx_random_gumbel"); + mlx_random_gumbel_ptr = GET_SYM(handle, "mlx_random_gumbel"); if (mlx_random_gumbel_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_random_gumbel\n"); return -1; } - mlx_random_key_ptr = dlsym(handle, "mlx_random_key"); + mlx_random_key_ptr = GET_SYM(handle, "mlx_random_key"); if (mlx_random_key_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_random_key\n"); return -1; } - mlx_random_laplace_ptr = dlsym(handle, "mlx_random_laplace"); + mlx_random_laplace_ptr = GET_SYM(handle, "mlx_random_laplace"); if (mlx_random_laplace_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_random_laplace\n"); return -1; } - mlx_random_multivariate_normal_ptr = dlsym(handle, "mlx_random_multivariate_normal"); + mlx_random_multivariate_normal_ptr = GET_SYM(handle, "mlx_random_multivariate_normal"); if (mlx_random_multivariate_normal_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_random_multivariate_normal\n"); return -1; } - mlx_random_normal_broadcast_ptr = dlsym(handle, "mlx_random_normal_broadcast"); + mlx_random_normal_broadcast_ptr = GET_SYM(handle, "mlx_random_normal_broadcast"); if (mlx_random_normal_broadcast_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_random_normal_broadcast\n"); return -1; } - mlx_random_normal_ptr = dlsym(handle, "mlx_random_normal"); + mlx_random_normal_ptr = GET_SYM(handle, "mlx_random_normal"); if (mlx_random_normal_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_random_normal\n"); return -1; } - mlx_random_permutation_ptr = dlsym(handle, "mlx_random_permutation"); + mlx_random_permutation_ptr = GET_SYM(handle, "mlx_random_permutation"); if (mlx_random_permutation_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_random_permutation\n"); return -1; } - mlx_random_permutation_arange_ptr = dlsym(handle, "mlx_random_permutation_arange"); + mlx_random_permutation_arange_ptr = GET_SYM(handle, "mlx_random_permutation_arange"); if (mlx_random_permutation_arange_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_random_permutation_arange\n"); return -1; } - mlx_random_randint_ptr = dlsym(handle, "mlx_random_randint"); + mlx_random_randint_ptr = GET_SYM(handle, "mlx_random_randint"); if (mlx_random_randint_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_random_randint\n"); return -1; } - mlx_random_seed_ptr = dlsym(handle, "mlx_random_seed"); + mlx_random_seed_ptr = GET_SYM(handle, "mlx_random_seed"); if (mlx_random_seed_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_random_seed\n"); return -1; } - mlx_random_split_num_ptr = dlsym(handle, "mlx_random_split_num"); + mlx_random_split_num_ptr = GET_SYM(handle, "mlx_random_split_num"); if (mlx_random_split_num_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_random_split_num\n"); return -1; } - mlx_random_split_ptr = dlsym(handle, "mlx_random_split"); + mlx_random_split_ptr = GET_SYM(handle, "mlx_random_split"); if (mlx_random_split_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_random_split\n"); return -1; } - mlx_random_truncated_normal_ptr = dlsym(handle, "mlx_random_truncated_normal"); + mlx_random_truncated_normal_ptr = GET_SYM(handle, "mlx_random_truncated_normal"); if (mlx_random_truncated_normal_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_random_truncated_normal\n"); return -1; } - mlx_random_uniform_ptr = dlsym(handle, "mlx_random_uniform"); + mlx_random_uniform_ptr = GET_SYM(handle, "mlx_random_uniform"); if (mlx_random_uniform_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_random_uniform\n"); return -1; } - mlx_stream_new_ptr = dlsym(handle, "mlx_stream_new"); + mlx_stream_new_ptr = GET_SYM(handle, "mlx_stream_new"); if (mlx_stream_new_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_stream_new\n"); return -1; } - mlx_stream_new_device_ptr = dlsym(handle, "mlx_stream_new_device"); + mlx_stream_new_device_ptr = GET_SYM(handle, "mlx_stream_new_device"); if (mlx_stream_new_device_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_stream_new_device\n"); return -1; } - mlx_stream_set_ptr = dlsym(handle, "mlx_stream_set"); + mlx_stream_set_ptr = GET_SYM(handle, "mlx_stream_set"); if (mlx_stream_set_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_stream_set\n"); return -1; } - mlx_stream_free_ptr = dlsym(handle, "mlx_stream_free"); + mlx_stream_free_ptr = GET_SYM(handle, "mlx_stream_free"); if (mlx_stream_free_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_stream_free\n"); return -1; } - mlx_stream_tostring_ptr = dlsym(handle, "mlx_stream_tostring"); + mlx_stream_tostring_ptr = GET_SYM(handle, "mlx_stream_tostring"); if (mlx_stream_tostring_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_stream_tostring\n"); return -1; } - mlx_stream_equal_ptr = dlsym(handle, "mlx_stream_equal"); + mlx_stream_equal_ptr = GET_SYM(handle, "mlx_stream_equal"); if (mlx_stream_equal_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_stream_equal\n"); return -1; } - mlx_stream_get_device_ptr = dlsym(handle, "mlx_stream_get_device"); + mlx_stream_get_device_ptr = GET_SYM(handle, "mlx_stream_get_device"); if (mlx_stream_get_device_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_stream_get_device\n"); return -1; } - mlx_stream_get_index_ptr = dlsym(handle, "mlx_stream_get_index"); + mlx_stream_get_index_ptr = GET_SYM(handle, "mlx_stream_get_index"); if (mlx_stream_get_index_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_stream_get_index\n"); return -1; } - mlx_synchronize_ptr = dlsym(handle, "mlx_synchronize"); + mlx_synchronize_ptr = GET_SYM(handle, "mlx_synchronize"); if (mlx_synchronize_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_synchronize\n"); return -1; } - mlx_get_default_stream_ptr = dlsym(handle, "mlx_get_default_stream"); + mlx_get_default_stream_ptr = GET_SYM(handle, "mlx_get_default_stream"); if (mlx_get_default_stream_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_get_default_stream\n"); return -1; } - mlx_set_default_stream_ptr = dlsym(handle, "mlx_set_default_stream"); + mlx_set_default_stream_ptr = GET_SYM(handle, "mlx_set_default_stream"); if (mlx_set_default_stream_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_set_default_stream\n"); return -1; } - mlx_default_cpu_stream_new_ptr = dlsym(handle, "mlx_default_cpu_stream_new"); + mlx_default_cpu_stream_new_ptr = GET_SYM(handle, "mlx_default_cpu_stream_new"); if (mlx_default_cpu_stream_new_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_default_cpu_stream_new\n"); return -1; } - mlx_default_gpu_stream_new_ptr = dlsym(handle, "mlx_default_gpu_stream_new"); + mlx_default_gpu_stream_new_ptr = GET_SYM(handle, "mlx_default_gpu_stream_new"); if (mlx_default_gpu_stream_new_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_default_gpu_stream_new\n"); return -1; } - mlx_string_new_ptr = dlsym(handle, "mlx_string_new"); + mlx_string_new_ptr = GET_SYM(handle, "mlx_string_new"); if (mlx_string_new_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_string_new\n"); return -1; } - mlx_string_new_data_ptr = dlsym(handle, "mlx_string_new_data"); + mlx_string_new_data_ptr = GET_SYM(handle, "mlx_string_new_data"); if (mlx_string_new_data_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_string_new_data\n"); return -1; } - mlx_string_set_ptr = dlsym(handle, "mlx_string_set"); + mlx_string_set_ptr = GET_SYM(handle, "mlx_string_set"); if (mlx_string_set_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_string_set\n"); return -1; } - mlx_string_data_ptr = dlsym(handle, "mlx_string_data"); + mlx_string_data_ptr = GET_SYM(handle, "mlx_string_data"); if (mlx_string_data_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_string_data\n"); return -1; } - mlx_string_free_ptr = dlsym(handle, "mlx_string_free"); + mlx_string_free_ptr = GET_SYM(handle, "mlx_string_free"); if (mlx_string_free_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_string_free\n"); return -1; } - mlx_async_eval_ptr = dlsym(handle, "mlx_async_eval"); + mlx_async_eval_ptr = GET_SYM(handle, "mlx_async_eval"); if (mlx_async_eval_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_async_eval\n"); return -1; } - mlx_checkpoint_ptr = dlsym(handle, "mlx_checkpoint"); + mlx_checkpoint_ptr = GET_SYM(handle, "mlx_checkpoint"); if (mlx_checkpoint_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_checkpoint\n"); return -1; } - mlx_custom_function_ptr = dlsym(handle, "mlx_custom_function"); + mlx_custom_function_ptr = GET_SYM(handle, "mlx_custom_function"); if (mlx_custom_function_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_custom_function\n"); return -1; } - mlx_custom_vjp_ptr = dlsym(handle, "mlx_custom_vjp"); + mlx_custom_vjp_ptr = GET_SYM(handle, "mlx_custom_vjp"); if (mlx_custom_vjp_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_custom_vjp\n"); return -1; } - mlx_eval_ptr = dlsym(handle, "mlx_eval"); + mlx_eval_ptr = GET_SYM(handle, "mlx_eval"); if (mlx_eval_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_eval\n"); return -1; } - mlx_jvp_ptr = dlsym(handle, "mlx_jvp"); + mlx_jvp_ptr = GET_SYM(handle, "mlx_jvp"); if (mlx_jvp_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_jvp\n"); return -1; } - mlx_value_and_grad_ptr = dlsym(handle, "mlx_value_and_grad"); + mlx_value_and_grad_ptr = GET_SYM(handle, "mlx_value_and_grad"); if (mlx_value_and_grad_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_value_and_grad\n"); return -1; } - mlx_vjp_ptr = dlsym(handle, "mlx_vjp"); + mlx_vjp_ptr = GET_SYM(handle, "mlx_vjp"); if (mlx_vjp_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_vjp\n"); return -1; } - mlx_detail_vmap_replace_ptr = dlsym(handle, "mlx_detail_vmap_replace"); + mlx_detail_vmap_replace_ptr = GET_SYM(handle, "mlx_detail_vmap_replace"); if (mlx_detail_vmap_replace_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_detail_vmap_replace\n"); return -1; } - mlx_detail_vmap_trace_ptr = dlsym(handle, "mlx_detail_vmap_trace"); + mlx_detail_vmap_trace_ptr = GET_SYM(handle, "mlx_detail_vmap_trace"); if (mlx_detail_vmap_trace_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_detail_vmap_trace\n"); return -1; } - mlx_vector_array_new_ptr = dlsym(handle, "mlx_vector_array_new"); + mlx_vector_array_new_ptr = GET_SYM(handle, "mlx_vector_array_new"); if (mlx_vector_array_new_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_array_new\n"); return -1; } - mlx_vector_array_set_ptr = dlsym(handle, "mlx_vector_array_set"); + mlx_vector_array_set_ptr = GET_SYM(handle, "mlx_vector_array_set"); if (mlx_vector_array_set_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_array_set\n"); return -1; } - mlx_vector_array_free_ptr = dlsym(handle, "mlx_vector_array_free"); + mlx_vector_array_free_ptr = GET_SYM(handle, "mlx_vector_array_free"); if (mlx_vector_array_free_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_array_free\n"); return -1; } - mlx_vector_array_new_data_ptr = dlsym(handle, "mlx_vector_array_new_data"); + mlx_vector_array_new_data_ptr = GET_SYM(handle, "mlx_vector_array_new_data"); if (mlx_vector_array_new_data_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_array_new_data\n"); return -1; } - mlx_vector_array_new_value_ptr = dlsym(handle, "mlx_vector_array_new_value"); + mlx_vector_array_new_value_ptr = GET_SYM(handle, "mlx_vector_array_new_value"); if (mlx_vector_array_new_value_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_array_new_value\n"); return -1; } - mlx_vector_array_set_data_ptr = dlsym(handle, "mlx_vector_array_set_data"); + mlx_vector_array_set_data_ptr = GET_SYM(handle, "mlx_vector_array_set_data"); if (mlx_vector_array_set_data_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_array_set_data\n"); return -1; } - mlx_vector_array_set_value_ptr = dlsym(handle, "mlx_vector_array_set_value"); + mlx_vector_array_set_value_ptr = GET_SYM(handle, "mlx_vector_array_set_value"); if (mlx_vector_array_set_value_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_array_set_value\n"); return -1; } - mlx_vector_array_append_data_ptr = dlsym(handle, "mlx_vector_array_append_data"); + mlx_vector_array_append_data_ptr = GET_SYM(handle, "mlx_vector_array_append_data"); if (mlx_vector_array_append_data_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_array_append_data\n"); return -1; } - mlx_vector_array_append_value_ptr = dlsym(handle, "mlx_vector_array_append_value"); + mlx_vector_array_append_value_ptr = GET_SYM(handle, "mlx_vector_array_append_value"); if (mlx_vector_array_append_value_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_array_append_value\n"); return -1; } - mlx_vector_array_size_ptr = dlsym(handle, "mlx_vector_array_size"); + mlx_vector_array_size_ptr = GET_SYM(handle, "mlx_vector_array_size"); if (mlx_vector_array_size_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_array_size\n"); return -1; } - mlx_vector_array_get_ptr = dlsym(handle, "mlx_vector_array_get"); + mlx_vector_array_get_ptr = GET_SYM(handle, "mlx_vector_array_get"); if (mlx_vector_array_get_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_array_get\n"); return -1; } - mlx_vector_vector_array_new_ptr = dlsym(handle, "mlx_vector_vector_array_new"); + mlx_vector_vector_array_new_ptr = GET_SYM(handle, "mlx_vector_vector_array_new"); if (mlx_vector_vector_array_new_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_vector_array_new\n"); return -1; } - mlx_vector_vector_array_set_ptr = dlsym(handle, "mlx_vector_vector_array_set"); + mlx_vector_vector_array_set_ptr = GET_SYM(handle, "mlx_vector_vector_array_set"); if (mlx_vector_vector_array_set_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_vector_array_set\n"); return -1; } - mlx_vector_vector_array_free_ptr = dlsym(handle, "mlx_vector_vector_array_free"); + mlx_vector_vector_array_free_ptr = GET_SYM(handle, "mlx_vector_vector_array_free"); if (mlx_vector_vector_array_free_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_vector_array_free\n"); return -1; } - mlx_vector_vector_array_new_data_ptr = dlsym(handle, "mlx_vector_vector_array_new_data"); + mlx_vector_vector_array_new_data_ptr = GET_SYM(handle, "mlx_vector_vector_array_new_data"); if (mlx_vector_vector_array_new_data_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_vector_array_new_data\n"); return -1; } - mlx_vector_vector_array_new_value_ptr = dlsym(handle, "mlx_vector_vector_array_new_value"); + mlx_vector_vector_array_new_value_ptr = GET_SYM(handle, "mlx_vector_vector_array_new_value"); if (mlx_vector_vector_array_new_value_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_vector_array_new_value\n"); return -1; } - mlx_vector_vector_array_set_data_ptr = dlsym(handle, "mlx_vector_vector_array_set_data"); + mlx_vector_vector_array_set_data_ptr = GET_SYM(handle, "mlx_vector_vector_array_set_data"); if (mlx_vector_vector_array_set_data_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_vector_array_set_data\n"); return -1; } - mlx_vector_vector_array_set_value_ptr = dlsym(handle, "mlx_vector_vector_array_set_value"); + mlx_vector_vector_array_set_value_ptr = GET_SYM(handle, "mlx_vector_vector_array_set_value"); if (mlx_vector_vector_array_set_value_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_vector_array_set_value\n"); return -1; } - mlx_vector_vector_array_append_data_ptr = dlsym(handle, "mlx_vector_vector_array_append_data"); + mlx_vector_vector_array_append_data_ptr = GET_SYM(handle, "mlx_vector_vector_array_append_data"); if (mlx_vector_vector_array_append_data_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_vector_array_append_data\n"); return -1; } - mlx_vector_vector_array_append_value_ptr = dlsym(handle, "mlx_vector_vector_array_append_value"); + mlx_vector_vector_array_append_value_ptr = GET_SYM(handle, "mlx_vector_vector_array_append_value"); if (mlx_vector_vector_array_append_value_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_vector_array_append_value\n"); return -1; } - mlx_vector_vector_array_size_ptr = dlsym(handle, "mlx_vector_vector_array_size"); + mlx_vector_vector_array_size_ptr = GET_SYM(handle, "mlx_vector_vector_array_size"); if (mlx_vector_vector_array_size_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_vector_array_size\n"); return -1; } - mlx_vector_vector_array_get_ptr = dlsym(handle, "mlx_vector_vector_array_get"); + mlx_vector_vector_array_get_ptr = GET_SYM(handle, "mlx_vector_vector_array_get"); if (mlx_vector_vector_array_get_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_vector_array_get\n"); return -1; } - mlx_vector_int_new_ptr = dlsym(handle, "mlx_vector_int_new"); + mlx_vector_int_new_ptr = GET_SYM(handle, "mlx_vector_int_new"); if (mlx_vector_int_new_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_int_new\n"); return -1; } - mlx_vector_int_set_ptr = dlsym(handle, "mlx_vector_int_set"); + mlx_vector_int_set_ptr = GET_SYM(handle, "mlx_vector_int_set"); if (mlx_vector_int_set_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_int_set\n"); return -1; } - mlx_vector_int_free_ptr = dlsym(handle, "mlx_vector_int_free"); + mlx_vector_int_free_ptr = GET_SYM(handle, "mlx_vector_int_free"); if (mlx_vector_int_free_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_int_free\n"); return -1; } - mlx_vector_int_new_data_ptr = dlsym(handle, "mlx_vector_int_new_data"); + mlx_vector_int_new_data_ptr = GET_SYM(handle, "mlx_vector_int_new_data"); if (mlx_vector_int_new_data_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_int_new_data\n"); return -1; } - mlx_vector_int_new_value_ptr = dlsym(handle, "mlx_vector_int_new_value"); + mlx_vector_int_new_value_ptr = GET_SYM(handle, "mlx_vector_int_new_value"); if (mlx_vector_int_new_value_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_int_new_value\n"); return -1; } - mlx_vector_int_set_data_ptr = dlsym(handle, "mlx_vector_int_set_data"); + mlx_vector_int_set_data_ptr = GET_SYM(handle, "mlx_vector_int_set_data"); if (mlx_vector_int_set_data_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_int_set_data\n"); return -1; } - mlx_vector_int_set_value_ptr = dlsym(handle, "mlx_vector_int_set_value"); + mlx_vector_int_set_value_ptr = GET_SYM(handle, "mlx_vector_int_set_value"); if (mlx_vector_int_set_value_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_int_set_value\n"); return -1; } - mlx_vector_int_append_data_ptr = dlsym(handle, "mlx_vector_int_append_data"); + mlx_vector_int_append_data_ptr = GET_SYM(handle, "mlx_vector_int_append_data"); if (mlx_vector_int_append_data_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_int_append_data\n"); return -1; } - mlx_vector_int_append_value_ptr = dlsym(handle, "mlx_vector_int_append_value"); + mlx_vector_int_append_value_ptr = GET_SYM(handle, "mlx_vector_int_append_value"); if (mlx_vector_int_append_value_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_int_append_value\n"); return -1; } - mlx_vector_int_size_ptr = dlsym(handle, "mlx_vector_int_size"); + mlx_vector_int_size_ptr = GET_SYM(handle, "mlx_vector_int_size"); if (mlx_vector_int_size_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_int_size\n"); return -1; } - mlx_vector_int_get_ptr = dlsym(handle, "mlx_vector_int_get"); + mlx_vector_int_get_ptr = GET_SYM(handle, "mlx_vector_int_get"); if (mlx_vector_int_get_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_int_get\n"); return -1; } - mlx_vector_string_new_ptr = dlsym(handle, "mlx_vector_string_new"); + mlx_vector_string_new_ptr = GET_SYM(handle, "mlx_vector_string_new"); if (mlx_vector_string_new_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_string_new\n"); return -1; } - mlx_vector_string_set_ptr = dlsym(handle, "mlx_vector_string_set"); + mlx_vector_string_set_ptr = GET_SYM(handle, "mlx_vector_string_set"); if (mlx_vector_string_set_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_string_set\n"); return -1; } - mlx_vector_string_free_ptr = dlsym(handle, "mlx_vector_string_free"); + mlx_vector_string_free_ptr = GET_SYM(handle, "mlx_vector_string_free"); if (mlx_vector_string_free_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_string_free\n"); return -1; } - mlx_vector_string_new_data_ptr = dlsym(handle, "mlx_vector_string_new_data"); + mlx_vector_string_new_data_ptr = GET_SYM(handle, "mlx_vector_string_new_data"); if (mlx_vector_string_new_data_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_string_new_data\n"); return -1; } - mlx_vector_string_new_value_ptr = dlsym(handle, "mlx_vector_string_new_value"); + mlx_vector_string_new_value_ptr = GET_SYM(handle, "mlx_vector_string_new_value"); if (mlx_vector_string_new_value_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_string_new_value\n"); return -1; } - mlx_vector_string_set_data_ptr = dlsym(handle, "mlx_vector_string_set_data"); + mlx_vector_string_set_data_ptr = GET_SYM(handle, "mlx_vector_string_set_data"); if (mlx_vector_string_set_data_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_string_set_data\n"); return -1; } - mlx_vector_string_set_value_ptr = dlsym(handle, "mlx_vector_string_set_value"); + mlx_vector_string_set_value_ptr = GET_SYM(handle, "mlx_vector_string_set_value"); if (mlx_vector_string_set_value_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_string_set_value\n"); return -1; } - mlx_vector_string_append_data_ptr = dlsym(handle, "mlx_vector_string_append_data"); + mlx_vector_string_append_data_ptr = GET_SYM(handle, "mlx_vector_string_append_data"); if (mlx_vector_string_append_data_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_string_append_data\n"); return -1; } - mlx_vector_string_append_value_ptr = dlsym(handle, "mlx_vector_string_append_value"); + mlx_vector_string_append_value_ptr = GET_SYM(handle, "mlx_vector_string_append_value"); if (mlx_vector_string_append_value_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_string_append_value\n"); return -1; } - mlx_vector_string_size_ptr = dlsym(handle, "mlx_vector_string_size"); + mlx_vector_string_size_ptr = GET_SYM(handle, "mlx_vector_string_size"); if (mlx_vector_string_size_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_string_size\n"); return -1; } - mlx_vector_string_get_ptr = dlsym(handle, "mlx_vector_string_get"); + mlx_vector_string_get_ptr = GET_SYM(handle, "mlx_vector_string_get"); if (mlx_vector_string_get_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_string_get\n"); return -1; } - mlx_version_ptr = dlsym(handle, "mlx_version"); + mlx_version_ptr = GET_SYM(handle, "mlx_version"); if (mlx_version_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_version\n"); return -1; diff --git a/x/imagegen/mlx/mlx.go b/x/imagegen/mlx/mlx.go index cf3e51572..b529b9088 100644 --- a/x/imagegen/mlx/mlx.go +++ b/x/imagegen/mlx/mlx.go @@ -1,9 +1,7 @@ -//go:build mlx - package mlx /* -#cgo CFLAGS: -O3 -I${SRCDIR}/../../../build/_deps/mlx-c-src -I${SRCDIR} +#cgo CFLAGS: -O3 -I${SRCDIR}/../../mlxrunner/mlx/include -I${SRCDIR} #cgo darwin LDFLAGS: -lc++ -framework Metal -framework Foundation -framework Accelerate #cgo linux LDFLAGS: -lstdc++ -ldl #cgo windows LDFLAGS: -lstdc++ @@ -32,7 +30,7 @@ static inline void set_default_stream(mlx_stream s) { _default_stream = s; } -// CPU stream for file loading (Load primitive only runs on CPU) +// CPU stream for operations that only support CPU evaluation static inline mlx_stream cpu_stream() { if (_cpu_stream.ctx == NULL) { _cpu_stream = mlx_default_cpu_stream_new(); @@ -45,8 +43,11 @@ static inline mlx_stream cpu_stream() { // nocallback: function won't call back into Go */ import "C" + import ( "fmt" + "os" + "path/filepath" "reflect" "runtime" "sync" @@ -1502,15 +1503,21 @@ type SafetensorsFile struct { metadata C.mlx_map_string_to_string } -// LoadSafetensorsNative loads a safetensors file using MLX's optimized loader -// Note: Uses CPU stream because Load primitive only runs on CPU +// LoadSafetensorsNative loads a safetensors file using MLX's optimized loader. +// On CUDA, Load::eval_gpu is implemented so we use the default (GPU) stream. +// On Metal, Load::eval_gpu is not implemented so we must use the CPU stream. func LoadSafetensorsNative(path string) (*SafetensorsFile, error) { cPath := C.CString(path) defer C.free(unsafe.Pointer(cPath)) + stream := C.default_stream() + if runtime.GOOS == "darwin" { + stream = C.cpu_stream() + } + var arrays C.mlx_map_string_to_array var metadata C.mlx_map_string_to_string - if C.mlx_load_safetensors(&arrays, &metadata, cPath, C.cpu_stream()) != 0 { + if C.mlx_load_safetensors(&arrays, &metadata, cPath, stream) != 0 { return nil, fmt.Errorf("failed to load safetensors: %s", path) } return &SafetensorsFile{arrays: arrays, metadata: metadata}, nil @@ -1689,11 +1696,91 @@ func ArgmaxKeepArray(logits *Array) *Array { // // Thread safety: Protected by randomStateMu, mimicking Python's GIL behavior. // All random functions that use global state acquire this lock. -var RandomState = []*Array{nil} -var randomStateMu sync.Mutex +var ( + RandomState = []*Array{nil} + randomStateMu sync.Mutex +) -var mlxInitialized bool -var mlxInitError error +var ( + mlxInitialized bool + mlxInitError error +) + +// mlxLibName returns the platform-specific shared library filename. +func mlxLibName() string { + switch runtime.GOOS { + case "windows": + return "mlxc.dll" + case "darwin": + return "libmlxc.dylib" + default: + return "libmlxc.so" + } +} + +// findMLXLibrary searches for the MLX shared library in standard locations. +// Returns the path to the library, or empty string if not found. +func findMLXLibrary() string { + libName := mlxLibName() + + // 1. OLLAMA_LIBRARY_PATH — check each dir and mlx_* subdirs + if paths, ok := os.LookupEnv("OLLAMA_LIBRARY_PATH"); ok { + for _, dir := range filepath.SplitList(paths) { + candidate := filepath.Join(dir, libName) + if _, err := os.Stat(candidate); err == nil { + return candidate + } + if mlxDirs, err := filepath.Glob(filepath.Join(dir, "mlx*")); err == nil { + for _, mlxDir := range mlxDirs { + candidate = filepath.Join(mlxDir, libName) + if _, err := os.Stat(candidate); err == nil { + return candidate + } + } + } + } + } + + // 2. Executable directory and lib/ollama/mlx* subdirs + if exe, err := os.Executable(); err == nil { + if eval, err := filepath.EvalSymlinks(exe); err == nil { + exe = eval + } + exeDir := filepath.Dir(exe) + + // Check exe dir directly (macOS copies dylib here) + candidate := filepath.Join(exeDir, libName) + if _, err := os.Stat(candidate); err == nil { + return candidate + } + + // Check exe_dir/lib/ollama/mlx* subdirectories + // and exe_dir/../lib/ollama/mlx* (standard bin/lib sibling layout) + for _, libOllamaDir := range []string{ + filepath.Join(exeDir, "lib", "ollama"), + filepath.Join(exeDir, "..", "lib", "ollama"), + } { + if mlxDirs, err := filepath.Glob(filepath.Join(libOllamaDir, "mlx*")); err == nil { + for _, mlxDir := range mlxDirs { + candidate = filepath.Join(mlxDir, libName) + if _, err := os.Stat(candidate); err == nil { + return candidate + } + } + } + } + } + + // 3. Build directory (for tests run from repo root) + if cwd, err := os.Getwd(); err == nil { + candidate := filepath.Join(cwd, "build", "lib", "ollama", libName) + if _, err := os.Stat(candidate); err == nil { + return candidate + } + } + + return "" +} // InitMLX initializes the MLX library by dynamically loading libmlxc. // This must be called before using any MLX functions. @@ -1703,9 +1790,16 @@ func InitMLX() error { return mlxInitError } - // Try to load the MLX dynamic library - ret := C.mlx_dynamic_init() - if ret != 0 { + // Search for the library using Go path discovery + libPath := findMLXLibrary() + if libPath == "" { + mlxInitError = fmt.Errorf("failed to initialize MLX: %s not found", mlxLibName()) + return mlxInitError + } + + cPath := C.CString(libPath) + defer C.free(unsafe.Pointer(cPath)) + if C.mlx_dynamic_init_path(cPath) != 0 { errMsg := C.GoString(C.mlx_dynamic_error()) mlxInitError = fmt.Errorf("failed to initialize MLX: %s", errMsg) return mlxInitError @@ -1713,8 +1807,7 @@ func InitMLX() error { // Initialize all function pointers via dlsym handle := C.mlx_get_handle() - ret = C.mlx_load_functions(handle) - if ret != 0 { + if C.mlx_load_functions(handle) != 0 { mlxInitError = fmt.Errorf("failed to load MLX function symbols") return mlxInitError } diff --git a/x/imagegen/mlx/mlx_dynamic.c b/x/imagegen/mlx/mlx_dynamic.c index aedef7a01..1281ec0e3 100644 --- a/x/imagegen/mlx/mlx_dynamic.c +++ b/x/imagegen/mlx/mlx_dynamic.c @@ -9,114 +9,76 @@ #ifdef _WIN32 #include typedef HMODULE lib_handle_t; -#define LOAD_LIB(path) LoadLibraryA(path) -#define GET_SYMBOL(handle, name) GetProcAddress(handle, name) -#define CLOSE_LIB(handle) FreeLibrary(handle) -#define LIB_ERROR() "LoadLibrary failed" +static char win_error_buffer[256] = {0}; +static const char* get_win_error(void) { + DWORD err = GetLastError(); + snprintf(win_error_buffer, sizeof(win_error_buffer), "error code %lu", err); + return win_error_buffer; +} +#define LIB_ERROR() get_win_error() #else #include typedef void* lib_handle_t; -#define LOAD_LIB(path) dlopen(path, RTLD_LAZY | RTLD_GLOBAL) -#define GET_SYMBOL(handle, name) dlsym(handle, name) -#define CLOSE_LIB(handle) dlclose(handle) #define LIB_ERROR() dlerror() -#ifdef __APPLE__ -#include -#include -#endif #endif static lib_handle_t mlx_handle = NULL; static int mlx_initialized = 0; static char mlx_error_buffer[512] = {0}; -#ifdef __APPLE__ -// Get path to library in same directory as executable -static char* get_exe_relative_path(const char* libname) { - static char path[1024]; - uint32_t size = sizeof(path); - if (_NSGetExecutablePath(path, &size) != 0) { - return NULL; +#ifdef _WIN32 +// Windows: Load library from a path with dependency resolution. +// Temporarily adds the library's directory to the DLL search path +// so that dependencies (like mlx.dll) in the same directory are found. +static int try_load_win(const char* path) { + if (!path) return 0; + + // Extract directory and add to DLL search path for dependency resolution + char dir_path[MAX_PATH]; + strncpy(dir_path, path, MAX_PATH - 1); + dir_path[MAX_PATH - 1] = '\0'; + char* last_slash = strrchr(dir_path, '\\'); + if (!last_slash) last_slash = strrchr(dir_path, '/'); + if (last_slash) { + *last_slash = '\0'; + SetDllDirectoryA(dir_path); } - // Get directory of executable - char* dir = dirname(path); - static char fullpath[1024]; - snprintf(fullpath, sizeof(fullpath), "%s/%s", dir, libname); - return fullpath; + + mlx_handle = LoadLibraryA(path); + SetDllDirectoryA(NULL); + return mlx_handle != NULL; } #endif // Try to load library from a specific path static int try_load_lib(const char* path) { if (!path) return 0; - mlx_handle = LOAD_LIB(path); +#ifdef _WIN32 + return try_load_win(path); +#else + mlx_handle = dlopen(path, RTLD_LAZY | RTLD_GLOBAL); return mlx_handle != NULL; +#endif } -// Initialize MLX dynamic library -// Returns 0 on success, -1 on failure -// On failure, call mlx_dynamic_error() to get error message -int mlx_dynamic_init(void) { +// Initialize the MLX dynamic library from a specific path. +// Returns 0 on success, -1 on failure. +int mlx_dynamic_init_path(const char* path) { if (mlx_initialized) { - return 0; // Already initialized + return 0; } - const char* lib_path = NULL; - const char* tried_paths[8] = {0}; - int num_tried = 0; - -#ifdef _WIN32 - // Windows: try same directory as executable - lib_path = "libmlxc.dll"; - tried_paths[num_tried++] = lib_path; - if (try_load_lib(lib_path)) goto success; -#elif defined(__APPLE__) - // macOS: try executable directory first - lib_path = get_exe_relative_path("libmlxc.dylib"); - if (lib_path) { - tried_paths[num_tried++] = lib_path; - if (try_load_lib(lib_path)) goto success; + if (try_load_lib(path)) { + mlx_initialized = 1; + snprintf(mlx_error_buffer, sizeof(mlx_error_buffer), + "MLX: Successfully loaded %s", path ? path : "library"); + return 0; } - // Try build directory (for tests run from repo root) - lib_path = "./build/lib/ollama/libmlxc.dylib"; - tried_paths[num_tried++] = lib_path; - if (try_load_lib(lib_path)) goto success; - // Fallback to system paths - lib_path = "libmlxc.dylib"; - tried_paths[num_tried++] = lib_path; - if (try_load_lib(lib_path)) goto success; -#else - // Linux: try build directory first (for tests) - lib_path = "./build/lib/ollama/libmlxc.so"; - tried_paths[num_tried++] = lib_path; - if (try_load_lib(lib_path)) goto success; - // Fallback to system paths - lib_path = "libmlxc.so"; - tried_paths[num_tried++] = lib_path; - if (try_load_lib(lib_path)) goto success; -#endif - // Failed to load library - build error message with all tried paths - { - const char* err = LIB_ERROR(); - int offset = snprintf(mlx_error_buffer, sizeof(mlx_error_buffer), - "MLX: Failed to load libmlxc library. Tried: "); - for (int i = 0; i < num_tried && offset < (int)sizeof(mlx_error_buffer) - 50; i++) { - offset += snprintf(mlx_error_buffer + offset, sizeof(mlx_error_buffer) - offset, - "%s%s", i > 0 ? ", " : "", tried_paths[i]); - } - if (err) { - snprintf(mlx_error_buffer + offset, sizeof(mlx_error_buffer) - offset, - ". Last error: %s", err); - } - } - return -1; - -success: - mlx_initialized = 1; + const char* err = LIB_ERROR(); snprintf(mlx_error_buffer, sizeof(mlx_error_buffer), - "MLX: Successfully loaded %s", lib_path ? lib_path : "library"); - return 0; + "MLX: Failed to load %s: %s", path ? path : "(null)", err ? err : "unknown error"); + return -1; } // Get the last error message @@ -124,21 +86,8 @@ const char* mlx_dynamic_error(void) { return mlx_error_buffer; } -// Check if MLX is initialized -int mlx_dynamic_is_initialized(void) { - return mlx_initialized; -} - // Get the library handle (for use by generated wrappers) void* mlx_get_handle(void) { return mlx_handle; } -// Cleanup (optional, called at program exit) -void mlx_dynamic_cleanup(void) { - if (mlx_handle != NULL) { - CLOSE_LIB(mlx_handle); - mlx_handle = NULL; - mlx_initialized = 0; - } -} diff --git a/x/imagegen/mlx/mlx_dynamic.h b/x/imagegen/mlx/mlx_dynamic.h index 9ca1473f9..3f4d0fd74 100644 --- a/x/imagegen/mlx/mlx_dynamic.h +++ b/x/imagegen/mlx/mlx_dynamic.h @@ -6,22 +6,16 @@ extern "C" { #endif -// Initialize the MLX dynamic library +// Initialize the MLX dynamic library from a specific path // Returns 0 on success, -1 on failure -int mlx_dynamic_init(void); +int mlx_dynamic_init_path(const char* path); // Get the last error message from dynamic loading const char* mlx_dynamic_error(void); -// Check if MLX is initialized -int mlx_dynamic_is_initialized(void); - // Get the library handle (for use by generated wrappers) void* mlx_get_handle(void); -// Cleanup resources (optional, for clean shutdown) -void mlx_dynamic_cleanup(void); - #ifdef __cplusplus } #endif diff --git a/x/imagegen/mlx/mlx_test.go b/x/imagegen/mlx/mlx_test.go index 37b3ac63b..82221cab7 100644 --- a/x/imagegen/mlx/mlx_test.go +++ b/x/imagegen/mlx/mlx_test.go @@ -1,5 +1,3 @@ -//go:build mlx - package mlx import ( diff --git a/x/imagegen/models/flux2/flux2.go b/x/imagegen/models/flux2/flux2.go index 894af41f8..908de3c87 100644 --- a/x/imagegen/models/flux2/flux2.go +++ b/x/imagegen/models/flux2/flux2.go @@ -1,5 +1,3 @@ -//go:build mlx - // Package flux2 implements the FLUX.2 Klein diffusion transformer model. // Klein is a 4B parameter distilled model that supports sub-second inference. package flux2 diff --git a/x/imagegen/models/flux2/rope.go b/x/imagegen/models/flux2/rope.go index c349e7010..bc245b7d8 100644 --- a/x/imagegen/models/flux2/rope.go +++ b/x/imagegen/models/flux2/rope.go @@ -1,5 +1,3 @@ -//go:build mlx - package flux2 import ( diff --git a/x/imagegen/models/flux2/scheduler.go b/x/imagegen/models/flux2/scheduler.go index aba3c871f..033ee8c3c 100644 --- a/x/imagegen/models/flux2/scheduler.go +++ b/x/imagegen/models/flux2/scheduler.go @@ -1,5 +1,3 @@ -//go:build mlx - package flux2 import ( diff --git a/x/imagegen/models/flux2/transformer.go b/x/imagegen/models/flux2/transformer.go index 93771a661..3a48f27b7 100644 --- a/x/imagegen/models/flux2/transformer.go +++ b/x/imagegen/models/flux2/transformer.go @@ -1,5 +1,3 @@ -//go:build mlx - package flux2 import ( diff --git a/x/imagegen/models/flux2/vae.go b/x/imagegen/models/flux2/vae.go index 4b09b1ba4..057523d17 100644 --- a/x/imagegen/models/flux2/vae.go +++ b/x/imagegen/models/flux2/vae.go @@ -1,5 +1,3 @@ -//go:build mlx - package flux2 import ( diff --git a/x/imagegen/models/qwen3/text_encoder.go b/x/imagegen/models/qwen3/text_encoder.go index de32bd347..46b5a5e29 100644 --- a/x/imagegen/models/qwen3/text_encoder.go +++ b/x/imagegen/models/qwen3/text_encoder.go @@ -1,5 +1,3 @@ -//go:build mlx - // Package qwen3 provides a shared Qwen3 text encoder used by multiple image generation models. package qwen3 diff --git a/x/imagegen/models/zimage/scheduler.go b/x/imagegen/models/zimage/scheduler.go index 6c474f5a4..f5e1ccc46 100644 --- a/x/imagegen/models/zimage/scheduler.go +++ b/x/imagegen/models/zimage/scheduler.go @@ -1,5 +1,3 @@ -//go:build mlx - package zimage import ( diff --git a/x/imagegen/models/zimage/text_encoder.go b/x/imagegen/models/zimage/text_encoder.go index 2c2688c31..65d9ab596 100644 --- a/x/imagegen/models/zimage/text_encoder.go +++ b/x/imagegen/models/zimage/text_encoder.go @@ -1,5 +1,3 @@ -//go:build mlx - package zimage import ( diff --git a/x/imagegen/models/zimage/transformer.go b/x/imagegen/models/zimage/transformer.go index 2c42d8c25..b27064507 100644 --- a/x/imagegen/models/zimage/transformer.go +++ b/x/imagegen/models/zimage/transformer.go @@ -1,5 +1,3 @@ -//go:build mlx - // Package zimage implements the Z-Image diffusion transformer model. package zimage diff --git a/x/imagegen/models/zimage/vae.go b/x/imagegen/models/zimage/vae.go index aca2b1bfc..a31ec210b 100644 --- a/x/imagegen/models/zimage/vae.go +++ b/x/imagegen/models/zimage/vae.go @@ -1,5 +1,3 @@ -//go:build mlx - package zimage import ( diff --git a/x/imagegen/models/zimage/zimage.go b/x/imagegen/models/zimage/zimage.go index e7ce8436d..4058819c7 100644 --- a/x/imagegen/models/zimage/zimage.go +++ b/x/imagegen/models/zimage/zimage.go @@ -1,5 +1,3 @@ -//go:build mlx - // Package zimage implements the Z-Image diffusion transformer model. package zimage diff --git a/x/imagegen/nn/nn.go b/x/imagegen/nn/nn.go index d72474358..0a08f05be 100644 --- a/x/imagegen/nn/nn.go +++ b/x/imagegen/nn/nn.go @@ -1,5 +1,3 @@ -//go:build mlx - // Package nn provides neural network layer types. package nn diff --git a/x/imagegen/nn/nn_test.go b/x/imagegen/nn/nn_test.go index 00e69ccb0..dbc5e2b30 100644 --- a/x/imagegen/nn/nn_test.go +++ b/x/imagegen/nn/nn_test.go @@ -1,5 +1,3 @@ -//go:build mlx - package nn import ( diff --git a/x/imagegen/runner.go b/x/imagegen/runner.go index 0409c4bf7..d92b59059 100644 --- a/x/imagegen/runner.go +++ b/x/imagegen/runner.go @@ -1,5 +1,3 @@ -//go:build mlx - // Package imagegen provides a unified MLX runner for both LLM and image generation models. package imagegen diff --git a/x/imagegen/runner_stub.go b/x/imagegen/runner_stub.go deleted file mode 100644 index 866a4408c..000000000 --- a/x/imagegen/runner_stub.go +++ /dev/null @@ -1,10 +0,0 @@ -//go:build !mlx - -package imagegen - -import "errors" - -// Execute returns an error when not built with MLX support. -func Execute(args []string) error { - return errors.New("MLX runner not available: build with mlx tag") -} diff --git a/x/imagegen/safetensors/loader.go b/x/imagegen/safetensors/loader.go index fbd443e05..e6e74929b 100644 --- a/x/imagegen/safetensors/loader.go +++ b/x/imagegen/safetensors/loader.go @@ -1,5 +1,3 @@ -//go:build mlx - package safetensors import ( diff --git a/x/imagegen/safetensors/safetensors.go b/x/imagegen/safetensors/safetensors.go index 4dbcf59a3..df7b52465 100644 --- a/x/imagegen/safetensors/safetensors.go +++ b/x/imagegen/safetensors/safetensors.go @@ -1,5 +1,3 @@ -//go:build mlx - package safetensors import ( diff --git a/x/imagegen/safetensors/safetensors_test.go b/x/imagegen/safetensors/safetensors_test.go index f00268751..5f3e10d71 100644 --- a/x/imagegen/safetensors/safetensors_test.go +++ b/x/imagegen/safetensors/safetensors_test.go @@ -1,5 +1,3 @@ -//go:build mlx - package safetensors import ( diff --git a/x/imagegen/tokenizer/tokenizer.go b/x/imagegen/tokenizer/tokenizer.go index bf8ff63af..d2f1aac18 100644 --- a/x/imagegen/tokenizer/tokenizer.go +++ b/x/imagegen/tokenizer/tokenizer.go @@ -1,5 +1,3 @@ -//go:build mlx - // tokenizer.go - BPE and SentencePiece tokenizer for HuggingFace models // // Based on standard BPE algorithm (Sennrich et al. 2015) with: diff --git a/x/imagegen/tokenizer/tokenizer_test.go b/x/imagegen/tokenizer/tokenizer_test.go index 2ac79ab1e..a72e447c6 100644 --- a/x/imagegen/tokenizer/tokenizer_test.go +++ b/x/imagegen/tokenizer/tokenizer_test.go @@ -1,5 +1,3 @@ -//go:build mlx - package tokenizer import ( diff --git a/x/imagegen/vae/tiling.go b/x/imagegen/vae/tiling.go index 1babfef98..fcb5af701 100644 --- a/x/imagegen/vae/tiling.go +++ b/x/imagegen/vae/tiling.go @@ -1,5 +1,3 @@ -//go:build mlx - // Package vae provides shared utilities for VAE (Variational Autoencoder) operations. package vae diff --git a/x/mlxrunner/cache.go b/x/mlxrunner/cache.go index a9ff8904c..0216ffeaa 100644 --- a/x/mlxrunner/cache.go +++ b/x/mlxrunner/cache.go @@ -1,5 +1,3 @@ -//go:build mlx - package mlxrunner import ( diff --git a/x/mlxrunner/cache/cache.go b/x/mlxrunner/cache/cache.go index 7d0d0b060..a452fbcb2 100644 --- a/x/mlxrunner/cache/cache.go +++ b/x/mlxrunner/cache/cache.go @@ -1,5 +1,3 @@ -//go:build mlx - package cache import ( diff --git a/x/mlxrunner/cache/recurrent.go b/x/mlxrunner/cache/recurrent.go index 0cbbc01e2..86c592be5 100644 --- a/x/mlxrunner/cache/recurrent.go +++ b/x/mlxrunner/cache/recurrent.go @@ -1,5 +1,3 @@ -//go:build mlx - package cache import "github.com/ollama/ollama/x/mlxrunner/mlx" diff --git a/x/mlxrunner/client.go b/x/mlxrunner/client.go index f1a0e4cca..2f18105af 100644 --- a/x/mlxrunner/client.go +++ b/x/mlxrunner/client.go @@ -72,14 +72,23 @@ func NewClient(modelName string) (*Client, error) { cmd := exec.Command(exe, "runner", "--mlx-engine", "--model", modelName, "--port", strconv.Itoa(port)) cmd.Env = os.Environ() - // On Linux, set LD_LIBRARY_PATH to include MLX library directories - if runtime.GOOS == "linux" { + // Set library path environment variable for MLX libraries + // Linux: LD_LIBRARY_PATH, Windows: PATH + var libPathEnvVar string + switch runtime.GOOS { + case "linux": + libPathEnvVar = "LD_LIBRARY_PATH" + case "windows": + libPathEnvVar = "PATH" + } + + if libPathEnvVar != "" { libraryPaths := []string{ml.LibOllamaPath} if mlxDirs, err := filepath.Glob(filepath.Join(ml.LibOllamaPath, "mlx_*")); err == nil { libraryPaths = append(libraryPaths, mlxDirs...) } - if existingPath, ok := os.LookupEnv("LD_LIBRARY_PATH"); ok { + if existingPath, ok := os.LookupEnv(libPathEnvVar); ok { libraryPaths = append(libraryPaths, filepath.SplitList(existingPath)...) } @@ -87,16 +96,20 @@ func NewClient(modelName string) (*Client, error) { found := false for i := range cmd.Env { - if strings.HasPrefix(cmd.Env[i], "LD_LIBRARY_PATH=") { - cmd.Env[i] = "LD_LIBRARY_PATH=" + pathEnvVal + envName := cmd.Env[i] + if runtime.GOOS == "windows" { + envName = strings.ToUpper(envName) + } + if strings.HasPrefix(envName, libPathEnvVar+"=") { + cmd.Env[i] = libPathEnvVar + "=" + pathEnvVal found = true break } } if !found { - cmd.Env = append(cmd.Env, "LD_LIBRARY_PATH="+pathEnvVal) + cmd.Env = append(cmd.Env, libPathEnvVar+"="+pathEnvVal) } - slog.Debug("mlx subprocess library path", "LD_LIBRARY_PATH", pathEnvVal) + slog.Debug("mlx subprocess library path", libPathEnvVar, pathEnvVal) } c := &Client{ diff --git a/x/mlxrunner/imports.go b/x/mlxrunner/imports.go index 9202dac34..6b6394d60 100644 --- a/x/mlxrunner/imports.go +++ b/x/mlxrunner/imports.go @@ -1,5 +1,3 @@ -//go:build mlx - package mlxrunner import ( diff --git a/x/mlxrunner/mlx/CMakeLists.txt b/x/mlxrunner/mlx/CMakeLists.txt index 1ca13bdaf..9825c441b 100644 --- a/x/mlxrunner/mlx/CMakeLists.txt +++ b/x/mlxrunner/mlx/CMakeLists.txt @@ -24,3 +24,7 @@ FetchContent_Declare( ) FetchContent_MakeAvailable(mlx-c) + +# Sync vendored headers with fetched version +file(GLOB _mlx_c_hdrs "${mlx-c_SOURCE_DIR}/mlx/c/*.h") +file(COPY ${_mlx_c_hdrs} DESTINATION "${CMAKE_CURRENT_SOURCE_DIR}/include/mlx/c/") diff --git a/x/mlxrunner/mlx/act.go b/x/mlxrunner/mlx/act.go index 3134a127a..ce0e48eda 100644 --- a/x/mlxrunner/mlx/act.go +++ b/x/mlxrunner/mlx/act.go @@ -1,5 +1,3 @@ -//go:build mlx - package mlx // #include "generated.h" diff --git a/x/mlxrunner/mlx/array.go b/x/mlxrunner/mlx/array.go index 6047aacec..de91813fc 100644 --- a/x/mlxrunner/mlx/array.go +++ b/x/mlxrunner/mlx/array.go @@ -1,5 +1,3 @@ -//go:build mlx - package mlx // #include "generated.h" diff --git a/x/mlxrunner/mlx/array_test.go b/x/mlxrunner/mlx/array_test.go index aab5db7ba..bc6a4ca4a 100644 --- a/x/mlxrunner/mlx/array_test.go +++ b/x/mlxrunner/mlx/array_test.go @@ -1,10 +1,16 @@ -//go:build mlx - package mlx import "testing" +func skipIfNoMLX(t *testing.T) { + t.Helper() + if err := CheckInit(); err != nil { + t.Skipf("MLX not available: %v", err) + } +} + func TestFromValue(t *testing.T) { + skipIfNoMLX(t) for got, want := range map[*Array]DType{ FromValue(true): DTypeBool, FromValue(false): DTypeBool, @@ -22,6 +28,7 @@ func TestFromValue(t *testing.T) { } func TestFromValues(t *testing.T) { + skipIfNoMLX(t) for got, want := range map[*Array]DType{ FromValues([]bool{true, false, true}, 3): DTypeBool, FromValues([]uint8{1, 2, 3}, 3): DTypeUint8, diff --git a/x/mlxrunner/mlx/dtype.go b/x/mlxrunner/mlx/dtype.go index 95237c792..b0a0ce6c1 100644 --- a/x/mlxrunner/mlx/dtype.go +++ b/x/mlxrunner/mlx/dtype.go @@ -1,5 +1,3 @@ -//go:build mlx - package mlx // #include "generated.h" diff --git a/x/mlxrunner/mlx/dynamic.go b/x/mlxrunner/mlx/dynamic.go index a1286da59..38f825d24 100644 --- a/x/mlxrunner/mlx/dynamic.go +++ b/x/mlxrunner/mlx/dynamic.go @@ -1,5 +1,3 @@ -//go:build mlx - package mlx // #include "dynamic.h" @@ -24,10 +22,16 @@ func CheckInit() error { return initError } -// tryLoadFromDir searches a directory for libmlxc.* and tries to load it. +// tryLoadFromDir searches a directory for the mlxc shared library and tries to load it. // Returns true if the library was successfully loaded. func tryLoadFromDir(dir string) bool { - matches, err := fs.Glob(os.DirFS(dir), "libmlxc.*") + // On Windows, MSVC produces mlxc.dll (no lib prefix) + // On Unix, it's libmlxc.so or libmlxc.dylib + pattern := "libmlxc.*" + if runtime.GOOS == "windows" { + pattern = "mlxc.*" + } + matches, err := fs.Glob(os.DirFS(dir), pattern) if err != nil || len(matches) == 0 { return false } @@ -60,7 +64,10 @@ func tryLoadFromDir(dir string) bool { // Returns true if the library was successfully loaded. func tryLoadByName() bool { libraryName := "libmlxc.dylib" - if runtime.GOOS == "linux" { + switch runtime.GOOS { + case "windows": + libraryName = "mlxc.dll" + case "linux": libraryName = "libmlxc.so" } @@ -81,19 +88,25 @@ func tryLoadByName() bool { func init() { switch runtime.GOOS { - case "darwin": + case "darwin", "linux", "windows": - case "windows": default: return } - // Try OLLAMA_LIBRARY_PATH first + // Try OLLAMA_LIBRARY_PATH first, including mlx_* subdirectories if paths, ok := os.LookupEnv("OLLAMA_LIBRARY_PATH"); ok { for _, dir := range filepath.SplitList(paths) { if tryLoadFromDir(dir) { return } + if mlxDirs, err := filepath.Glob(filepath.Join(dir, "mlx_*")); err == nil { + for _, mlxDir := range mlxDirs { + if tryLoadFromDir(mlxDir) { + return + } + } + } } } @@ -115,12 +128,21 @@ func init() { searchDirs = append(searchDirs, filepath.Join(cwd, "build", "lib", "ollama")) } + // Also scan mlx_* subdirectories within each search dir + var expanded []string for _, dir := range searchDirs { + expanded = append(expanded, dir) + if mlxDirs, err := filepath.Glob(filepath.Join(dir, "mlx_*")); err == nil { + expanded = append(expanded, mlxDirs...) + } + } + + for _, dir := range expanded { if tryLoadFromDir(dir) { return } } initError = fmt.Errorf("failed to load MLX dynamic library (searched: %v)", searchDirs) - slog.Warn("MLX dynamic library not available", "error", initError) + slog.Debug("MLX dynamic library not available", "error", initError) } diff --git a/x/mlxrunner/mlx/dynamic.h b/x/mlxrunner/mlx/dynamic.h index f93d8fab7..f29825ce6 100644 --- a/x/mlxrunner/mlx/dynamic.h +++ b/x/mlxrunner/mlx/dynamic.h @@ -3,7 +3,7 @@ #ifdef _WIN32 #include -#define DLSYM(handle, symbol) GetProcAddress((HMODULE)(handle), symbol) +#define DLSYM(handle, symbol) (void*)GetProcAddress((HMODULE)(handle.ctx), symbol) #else #include #define DLSYM(handle, symbol) dlsym(handle.ctx, symbol) @@ -23,9 +23,15 @@ typedef uint16_t float16_t; typedef uint16_t bfloat16_t; #endif -#define ERROR(fmt, ...) fprintf(stderr, "%s %s - ERROR - %s:%d - " fmt "\n", __DATE__, __TIME__, __FILE__, __LINE__, ##__VA_ARGS__); return 1 -#define CHECK(x) if (!(x)) { ERROR("CHECK failed: " #x); } -#define CHECK_LOAD(handle, x) x##_ = DLSYM(handle, #x); CHECK(x##_) +// Undef ERROR to avoid conflict with wingdi.h on Windows +#ifdef ERROR +#undef ERROR +#endif +#define MLX_ERROR(fmt, ...) fprintf(stderr, "%s %s - ERROR - %s:%d - " fmt "\n", __DATE__, __TIME__, __FILE__, __LINE__, ##__VA_ARGS__); return 1 +#define CHECK(x) if (!(x)) { MLX_ERROR("CHECK failed: " #x); } +#define CHECK_LOAD(handle, x) *(void**)(&x##_) = DLSYM(handle, #x); CHECK(x##_) +// OPTIONAL_LOAD: load symbol if available, leave function pointer NULL otherwise +#define OPTIONAL_LOAD(handle, x) *(void**)(&x##_) = DLSYM(handle, #x) typedef struct { void* ctx; diff --git a/x/mlxrunner/mlx/fast.go b/x/mlxrunner/mlx/fast.go index 0570840d6..7feca3b1e 100644 --- a/x/mlxrunner/mlx/fast.go +++ b/x/mlxrunner/mlx/fast.go @@ -1,5 +1,3 @@ -//go:build mlx - package mlx // #include "generated.h" diff --git a/x/mlxrunner/mlx/gated_delta.go b/x/mlxrunner/mlx/gated_delta.go index 7ace1f6d3..31550cef1 100644 --- a/x/mlxrunner/mlx/gated_delta.go +++ b/x/mlxrunner/mlx/gated_delta.go @@ -1,5 +1,3 @@ -//go:build mlx - package mlx // #include diff --git a/x/mlxrunner/mlx/generated.c b/x/mlxrunner/mlx/generated.c index 29d1330af..ecf9d30c8 100644 --- a/x/mlxrunner/mlx/generated.c +++ b/x/mlxrunner/mlx/generated.c @@ -2299,8 +2299,8 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) { CHECK_LOAD(handle, mlx_array_item_float32); CHECK_LOAD(handle, mlx_array_item_float64); CHECK_LOAD(handle, mlx_array_item_complex64); - CHECK_LOAD(handle, mlx_array_item_float16); - CHECK_LOAD(handle, mlx_array_item_bfloat16); + OPTIONAL_LOAD(handle, mlx_array_item_float16); + OPTIONAL_LOAD(handle, mlx_array_item_bfloat16); CHECK_LOAD(handle, mlx_array_data_bool); CHECK_LOAD(handle, mlx_array_data_uint8); CHECK_LOAD(handle, mlx_array_data_uint16); @@ -2313,8 +2313,8 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) { CHECK_LOAD(handle, mlx_array_data_float32); CHECK_LOAD(handle, mlx_array_data_float64); CHECK_LOAD(handle, mlx_array_data_complex64); - CHECK_LOAD(handle, mlx_array_data_float16); - CHECK_LOAD(handle, mlx_array_data_bfloat16); + OPTIONAL_LOAD(handle, mlx_array_data_float16); + OPTIONAL_LOAD(handle, mlx_array_data_bfloat16); CHECK_LOAD(handle, _mlx_array_is_available); CHECK_LOAD(handle, _mlx_array_wait); CHECK_LOAD(handle, _mlx_array_is_contiguous); diff --git a/x/mlxrunner/mlx/generator/generated.c.gotmpl b/x/mlxrunner/mlx/generator/generated.c.gotmpl index c31b34a76..227589aa8 100644 --- a/x/mlxrunner/mlx/generator/generated.c.gotmpl +++ b/x/mlxrunner/mlx/generator/generated.c.gotmpl @@ -11,7 +11,7 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) { {{- range .Functions }} - CHECK_LOAD(handle, {{ .Name }}); + {{ if .Optional }}OPTIONAL_LOAD{{ else }}CHECK_LOAD{{ end }}(handle, {{ .Name }}); {{- end }} return 0; } diff --git a/x/mlxrunner/mlx/generator/main.go b/x/mlxrunner/mlx/generator/main.go index a98046a2f..d1203add4 100644 --- a/x/mlxrunner/mlx/generator/main.go +++ b/x/mlxrunner/mlx/generator/main.go @@ -17,11 +17,21 @@ import ( //go:embed *.gotmpl var fsys embed.FS +// optionalSymbols lists symbols that may not be present in all builds +// (e.g., float16/bfloat16 are unavailable in CUDA builds of MLX). +var optionalSymbols = map[string]bool{ + "mlx_array_item_float16": true, + "mlx_array_item_bfloat16": true, + "mlx_array_data_float16": true, + "mlx_array_data_bfloat16": true, +} + type Function struct { Type, Name, Parameters, Args string + Optional bool } func ParseFunction(node *tree_sitter.Node, tc *tree_sitter.TreeCursor, source []byte) Function { @@ -104,7 +114,9 @@ func main() { matches := qc.Matches(query, tree.RootNode(), bts) for match := matches.Next(); match != nil; match = matches.Next() { for _, capture := range match.Captures { - funs = append(funs, ParseFunction(&capture.Node, tc, bts)) + fn := ParseFunction(&capture.Node, tc, bts) + fn.Optional = optionalSymbols[fn.Name] + funs = append(funs, fn) } } } diff --git a/x/mlxrunner/mlx/include/mlx/c/README.md b/x/mlxrunner/mlx/include/mlx/c/README.md new file mode 100644 index 000000000..905ca451c --- /dev/null +++ b/x/mlxrunner/mlx/include/mlx/c/README.md @@ -0,0 +1,12 @@ +# Vendored MLX-C Headers + +These header files are vendored from [mlx-c](https://github.com/ml-explore/mlx-c). +The pinned version is in `MLX_VERSION` at the repo root. + +Headers are automatically refreshed when you run a CMake build: + +```shell +cmake --preset 'MLX CUDA 13' +``` + +See the [MLX Engine](../../../../../../../docs/development.md#mlx-engine-optional) section of the development docs for full build instructions. diff --git a/x/mlxrunner/mlx/include/mlx/c/array.h b/x/mlxrunner/mlx/include/mlx/c/array.h new file mode 100644 index 000000000..a3b382bb2 --- /dev/null +++ b/x/mlxrunner/mlx/include/mlx/c/array.h @@ -0,0 +1,420 @@ +/* Copyright © 2023-2024 Apple Inc. */ + +#ifndef MLX_ARRAY_H +#define MLX_ARRAY_H + +#include "mlx/c/string.h" + +#include +#include +#include +#include + +// Complex number support +#ifdef _MSC_VER +#define _CRT_USE_C_COMPLEX_H +#include +typedef _Fcomplex mlx_complex64_t; +#else +#include +typedef float _Complex mlx_complex64_t; +#endif + +#include "half.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/** + * \defgroup mlx_array Array + * MLX N-dimensional array object. + */ +/**@{*/ + +/** + * A N-dimensional array object. + */ +typedef struct mlx_array_ { + void* ctx; +} mlx_array; + +static mlx_array mlx_array_empty; + +/** + * Array element type. + */ +typedef enum mlx_dtype_ { + MLX_BOOL, + MLX_UINT8, + MLX_UINT16, + MLX_UINT32, + MLX_UINT64, + MLX_INT8, + MLX_INT16, + MLX_INT32, + MLX_INT64, + MLX_FLOAT16, + MLX_FLOAT32, + MLX_FLOAT64, + MLX_BFLOAT16, + MLX_COMPLEX64, +} mlx_dtype; + +/** + * Size of given mlx_dtype datatype in bytes. + */ +size_t mlx_dtype_size(mlx_dtype dtype); + +/** + * Get array description. + */ +int mlx_array_tostring(mlx_string* str, const mlx_array arr); + +/** + * New empty array. + */ +mlx_array mlx_array_new(void); + +/** + * Free an array. + */ +int mlx_array_free(mlx_array arr); + +/** + * New array from a bool scalar. + */ +mlx_array mlx_array_new_bool(bool val); +/** + * New array from a int scalar. + */ +mlx_array mlx_array_new_int(int val); +/** + * New array from a float32 scalar. + */ +mlx_array mlx_array_new_float32(float val); +/** + * New array from a float scalar. + * Same as float32. + */ +mlx_array mlx_array_new_float(float val); +/** + * New array from a float64 scalar. + */ +mlx_array mlx_array_new_float64(double val); +/** + * New array from a double scalar. + * Same as float64. + */ +mlx_array mlx_array_new_double(double val); +/** + * New array from a complex scalar. + */ +mlx_array mlx_array_new_complex(float real_val, float imag_val); +/** + * New array from existing buffer. + * @param data A buffer which will be copied. + * @param shape Shape of the array. + * @param dim Number of dimensions (size of `shape`). + * @param dtype Type of array elements. + */ +mlx_array mlx_array_new_data( + const void* data, + const int* shape, + int dim, + mlx_dtype dtype); +/** + * New array from existing buffer. + * @param data A buffer which will be copied. + * @param shape Shape of the array. + * @param dim Number of dimensions (size of `shape`). + * @param dtype Type of array elements. + * @param dtor Callback for when the buffer is no longer needed. + */ +mlx_array mlx_array_new_data_managed( + void* data, + const int* shape, + int dim, + mlx_dtype dtype, + void (*dtor)(void*)); +/** + * New array from existing buffer. + * @param data A buffer which will be copied. + * @param shape Shape of the array. + * @param dim Number of dimensions (size of `shape`). + * @param dtype Type of array elements. + * @param payload Payload pointer passed to the `dtor` callback instead of + * `data`. + * @param dtor Callback for when the buffer is no longer needed. + */ +mlx_array mlx_array_new_data_managed_payload( + void* data, + const int* shape, + int dim, + mlx_dtype dtype, + void* payload, + void (*dtor)(void*)); +/** + * Set array to provided src array. + */ +int mlx_array_set(mlx_array* arr, const mlx_array src); +/** + * Set array to a bool scalar. + */ +int mlx_array_set_bool(mlx_array* arr, bool val); +/** + * Set array to a int scalar. + */ +int mlx_array_set_int(mlx_array* arr, int val); +/** + * Set array to a float32 scalar. + */ +int mlx_array_set_float32(mlx_array* arr, float val); +/** + * Set array to a float scalar. + */ +int mlx_array_set_float(mlx_array* arr, float val); +/** + * Set array to a float64 scalar. + */ +int mlx_array_set_float64(mlx_array* arr, double val); +/** + * Set array to a double scalar. + */ +int mlx_array_set_double(mlx_array* arr, double val); +/** + * Set array to a complex scalar. + */ +int mlx_array_set_complex(mlx_array* arr, float real_val, float imag_val); +/** + * Set array to specified data and shape. + * @param arr Destination array. + * @param data A buffer which will be copied. + * @param shape Shape of the array. + * @param dim Number of dimensions (size of `shape`). + * @param dtype Type of array elements. + */ +int mlx_array_set_data( + mlx_array* arr, + const void* data, + const int* shape, + int dim, + mlx_dtype dtype); + +/** + * The size of the array's datatype in bytes. + */ +size_t mlx_array_itemsize(const mlx_array arr); +/** + * Number of elements in the array. + */ +size_t mlx_array_size(const mlx_array arr); +/** + * The number of bytes in the array. + */ +size_t mlx_array_nbytes(const mlx_array arr); +/** + * The array's dimension. + */ +size_t mlx_array_ndim(const mlx_array arr); +/** + * The shape of the array. + * Returns: a pointer to the sizes of each dimension. + */ +const int* mlx_array_shape(const mlx_array arr); +/** + * The strides of the array. + * Returns: a pointer to the sizes of each dimension. + */ +const size_t* mlx_array_strides(const mlx_array arr); +/** + * The shape of the array in a particular dimension. + */ +int mlx_array_dim(const mlx_array arr, int dim); +/** + * The array element type. + */ +mlx_dtype mlx_array_dtype(const mlx_array arr); + +/** + * Evaluate the array. + */ +int mlx_array_eval(mlx_array arr); + +/** + * Access the value of a scalar array. + */ +int mlx_array_item_bool(bool* res, const mlx_array arr); +/** + * Access the value of a scalar array. + */ +int mlx_array_item_uint8(uint8_t* res, const mlx_array arr); +/** + * Access the value of a scalar array. + */ +int mlx_array_item_uint16(uint16_t* res, const mlx_array arr); +/** + * Access the value of a scalar array. + */ +int mlx_array_item_uint32(uint32_t* res, const mlx_array arr); +/** + * Access the value of a scalar array. + */ +int mlx_array_item_uint64(uint64_t* res, const mlx_array arr); +/** + * Access the value of a scalar array. + */ +int mlx_array_item_int8(int8_t* res, const mlx_array arr); +/** + * Access the value of a scalar array. + */ +int mlx_array_item_int16(int16_t* res, const mlx_array arr); +/** + * Access the value of a scalar array. + */ +int mlx_array_item_int32(int32_t* res, const mlx_array arr); +/** + * Access the value of a scalar array. + */ +int mlx_array_item_int64(int64_t* res, const mlx_array arr); +/** + * Access the value of a scalar array. + */ +int mlx_array_item_float32(float* res, const mlx_array arr); +/** + * Access the value of a scalar array. + */ +int mlx_array_item_float64(double* res, const mlx_array arr); +/** + * Access the value of a scalar array. + */ +int mlx_array_item_complex64(mlx_complex64_t* res, const mlx_array arr); + +#ifdef HAS_FLOAT16 +/** + * Access the value of a scalar array. + */ +int mlx_array_item_float16(float16_t* res, const mlx_array arr); +#endif + +#ifdef HAS_BFLOAT16 +/** + * Access the value of a scalar array. + */ +int mlx_array_item_bfloat16(bfloat16_t* res, const mlx_array arr); +#endif + +/** + * Returns a pointer to the array data, cast to `bool*`. + * Array must be evaluated, otherwise returns NULL. + */ +const bool* mlx_array_data_bool(const mlx_array arr); +/** + * Returns a pointer to the array data, cast to `uint8_t*`. + * Array must be evaluated, otherwise returns NULL. + */ +const uint8_t* mlx_array_data_uint8(const mlx_array arr); +/** + * Returns a pointer to the array data, cast to `uint16_t*`. + * Array must be evaluated, otherwise returns NULL. + */ +const uint16_t* mlx_array_data_uint16(const mlx_array arr); +/** + * Returns a pointer to the array data, cast to `uint32_t*`. + * Array must be evaluated, otherwise returns NULL. + */ +const uint32_t* mlx_array_data_uint32(const mlx_array arr); +/** + * Returns a pointer to the array data, cast to `uint64_t*`. + * Array must be evaluated, otherwise returns NULL. + */ +const uint64_t* mlx_array_data_uint64(const mlx_array arr); +/** + * Returns a pointer to the array data, cast to `int8_t*`. + * Array must be evaluated, otherwise returns NULL. + */ +const int8_t* mlx_array_data_int8(const mlx_array arr); +/** + * Returns a pointer to the array data, cast to `int16_t*`. + * Array must be evaluated, otherwise returns NULL. + */ +const int16_t* mlx_array_data_int16(const mlx_array arr); +/** + * Returns a pointer to the array data, cast to `int32_t*`. + * Array must be evaluated, otherwise returns NULL. + */ +const int32_t* mlx_array_data_int32(const mlx_array arr); +/** + * Returns a pointer to the array data, cast to `int64_t*`. + * Array must be evaluated, otherwise returns NULL. + */ +const int64_t* mlx_array_data_int64(const mlx_array arr); +/** + * Returns a pointer to the array data, cast to `float32*`. + * Array must be evaluated, otherwise returns NULL. + */ +const float* mlx_array_data_float32(const mlx_array arr); +/** + * Returns a pointer to the array data, cast to `float64*`. + * Array must be evaluated, otherwise returns NULL. + */ +const double* mlx_array_data_float64(const mlx_array arr); +/** + * Returns a pointer to the array data, cast to `_Complex*`. + * Array must be evaluated, otherwise returns NULL. + */ +const mlx_complex64_t* mlx_array_data_complex64(const mlx_array arr); + +#ifdef HAS_FLOAT16 +/** + * Returns a pointer to the array data, cast to `float16_t*`. + * Array must be evaluated, otherwise returns NULL. + */ +const float16_t* mlx_array_data_float16(const mlx_array arr); +#endif + +#ifdef HAS_BFLOAT16 +/** + * Returns a pointer to the array data, cast to `bfloat16_t*`. + * Array must be evaluated, otherwise returns NULL. + */ +const bfloat16_t* mlx_array_data_bfloat16(const mlx_array arr); +#endif + +/** + * Check if the array is available. + * Internal function: use at your own risk. + */ +int _mlx_array_is_available(bool* res, const mlx_array arr); + +/** + * Wait on the array to be available. After this `_mlx_array_is_available` + * returns `true`. Internal function: use at your own risk. + */ +int _mlx_array_wait(const mlx_array arr); + +/** + * Whether the array is contiguous in memory. + * Internal function: use at your own risk. + */ +int _mlx_array_is_contiguous(bool* res, const mlx_array arr); + +/** + * Whether the array's rows are contiguous in memory. + * Internal function: use at your own risk. + */ +int _mlx_array_is_row_contiguous(bool* res, const mlx_array arr); + +/** + * Whether the array's columns are contiguous in memory. + * Internal function: use at your own risk. + */ +int _mlx_array_is_col_contiguous(bool* res, const mlx_array arr); + +/**@}*/ + +#ifdef __cplusplus +} +#endif + +#endif diff --git a/x/mlxrunner/mlx/include/mlx/c/closure.h b/x/mlxrunner/mlx/include/mlx/c/closure.h new file mode 100644 index 000000000..33f711572 --- /dev/null +++ b/x/mlxrunner/mlx/include/mlx/c/closure.h @@ -0,0 +1,197 @@ +/* Copyright © 2023-2024 Apple Inc. */ +/* */ +/* This file is auto-generated. Do not edit manually. */ +/* */ + +#ifndef MLX_CLOSURE_H +#define MLX_CLOSURE_H + +#include "mlx/c/array.h" +#include "mlx/c/map.h" +#include "mlx/c/optional.h" +#include "mlx/c/stream.h" +#include "mlx/c/vector.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/** + * \defgroup mlx_closure Closures + * MLX closure objects. + */ +/**@{*/ + +typedef struct mlx_closure_ { + void* ctx; +} mlx_closure; +mlx_closure mlx_closure_new(void); +int mlx_closure_free(mlx_closure cls); +mlx_closure mlx_closure_new_func( + int (*fun)(mlx_vector_array*, const mlx_vector_array)); +mlx_closure mlx_closure_new_func_payload( + int (*fun)(mlx_vector_array*, const mlx_vector_array, void*), + void* payload, + void (*dtor)(void*)); +int mlx_closure_set(mlx_closure* cls, const mlx_closure src); +int mlx_closure_apply( + mlx_vector_array* res, + mlx_closure cls, + const mlx_vector_array input); + +mlx_closure mlx_closure_new_unary(int (*fun)(mlx_array*, const mlx_array)); + +typedef struct mlx_closure_kwargs_ { + void* ctx; +} mlx_closure_kwargs; +mlx_closure_kwargs mlx_closure_kwargs_new(void); +int mlx_closure_kwargs_free(mlx_closure_kwargs cls); +mlx_closure_kwargs mlx_closure_kwargs_new_func( + int (*fun)( + mlx_vector_array*, + const mlx_vector_array, + const mlx_map_string_to_array)); +mlx_closure_kwargs mlx_closure_kwargs_new_func_payload( + int (*fun)( + mlx_vector_array*, + const mlx_vector_array, + const mlx_map_string_to_array, + void*), + void* payload, + void (*dtor)(void*)); +int mlx_closure_kwargs_set( + mlx_closure_kwargs* cls, + const mlx_closure_kwargs src); +int mlx_closure_kwargs_apply( + mlx_vector_array* res, + mlx_closure_kwargs cls, + const mlx_vector_array input_0, + const mlx_map_string_to_array input_1); + +typedef struct mlx_closure_value_and_grad_ { + void* ctx; +} mlx_closure_value_and_grad; +mlx_closure_value_and_grad mlx_closure_value_and_grad_new(void); +int mlx_closure_value_and_grad_free(mlx_closure_value_and_grad cls); +mlx_closure_value_and_grad mlx_closure_value_and_grad_new_func( + int (*fun)(mlx_vector_array*, mlx_vector_array*, const mlx_vector_array)); +mlx_closure_value_and_grad mlx_closure_value_and_grad_new_func_payload( + int (*fun)( + mlx_vector_array*, + mlx_vector_array*, + const mlx_vector_array, + void*), + void* payload, + void (*dtor)(void*)); +int mlx_closure_value_and_grad_set( + mlx_closure_value_and_grad* cls, + const mlx_closure_value_and_grad src); +int mlx_closure_value_and_grad_apply( + mlx_vector_array* res_0, + mlx_vector_array* res_1, + mlx_closure_value_and_grad cls, + const mlx_vector_array input); + +typedef struct mlx_closure_custom_ { + void* ctx; +} mlx_closure_custom; +mlx_closure_custom mlx_closure_custom_new(void); +int mlx_closure_custom_free(mlx_closure_custom cls); +mlx_closure_custom mlx_closure_custom_new_func( + int (*fun)( + mlx_vector_array*, + const mlx_vector_array, + const mlx_vector_array, + const mlx_vector_array)); +mlx_closure_custom mlx_closure_custom_new_func_payload( + int (*fun)( + mlx_vector_array*, + const mlx_vector_array, + const mlx_vector_array, + const mlx_vector_array, + void*), + void* payload, + void (*dtor)(void*)); +int mlx_closure_custom_set( + mlx_closure_custom* cls, + const mlx_closure_custom src); +int mlx_closure_custom_apply( + mlx_vector_array* res, + mlx_closure_custom cls, + const mlx_vector_array input_0, + const mlx_vector_array input_1, + const mlx_vector_array input_2); + +typedef struct mlx_closure_custom_jvp_ { + void* ctx; +} mlx_closure_custom_jvp; +mlx_closure_custom_jvp mlx_closure_custom_jvp_new(void); +int mlx_closure_custom_jvp_free(mlx_closure_custom_jvp cls); +mlx_closure_custom_jvp mlx_closure_custom_jvp_new_func( + int (*fun)( + mlx_vector_array*, + const mlx_vector_array, + const mlx_vector_array, + const int*, + size_t _num)); +mlx_closure_custom_jvp mlx_closure_custom_jvp_new_func_payload( + int (*fun)( + mlx_vector_array*, + const mlx_vector_array, + const mlx_vector_array, + const int*, + size_t _num, + void*), + void* payload, + void (*dtor)(void*)); +int mlx_closure_custom_jvp_set( + mlx_closure_custom_jvp* cls, + const mlx_closure_custom_jvp src); +int mlx_closure_custom_jvp_apply( + mlx_vector_array* res, + mlx_closure_custom_jvp cls, + const mlx_vector_array input_0, + const mlx_vector_array input_1, + const int* input_2, + size_t input_2_num); + +typedef struct mlx_closure_custom_vmap_ { + void* ctx; +} mlx_closure_custom_vmap; +mlx_closure_custom_vmap mlx_closure_custom_vmap_new(void); +int mlx_closure_custom_vmap_free(mlx_closure_custom_vmap cls); +mlx_closure_custom_vmap mlx_closure_custom_vmap_new_func( + int (*fun)( + mlx_vector_array*, + mlx_vector_int*, + const mlx_vector_array, + const int*, + size_t _num)); +mlx_closure_custom_vmap mlx_closure_custom_vmap_new_func_payload( + int (*fun)( + mlx_vector_array*, + mlx_vector_int*, + const mlx_vector_array, + const int*, + size_t _num, + void*), + void* payload, + void (*dtor)(void*)); +int mlx_closure_custom_vmap_set( + mlx_closure_custom_vmap* cls, + const mlx_closure_custom_vmap src); +int mlx_closure_custom_vmap_apply( + mlx_vector_array* res_0, + mlx_vector_int* res_1, + mlx_closure_custom_vmap cls, + const mlx_vector_array input_0, + const int* input_1, + size_t input_1_num); + +/**@}*/ + +#ifdef __cplusplus +} +#endif + +#endif diff --git a/x/mlxrunner/mlx/include/mlx/c/compile.h b/x/mlxrunner/mlx/include/mlx/c/compile.h new file mode 100644 index 000000000..04567fb3a --- /dev/null +++ b/x/mlxrunner/mlx/include/mlx/c/compile.h @@ -0,0 +1,57 @@ +/* Copyright © 2023-2024 Apple Inc. */ +/* */ +/* This file is auto-generated. Do not edit manually. */ +/* */ + +#ifndef MLX_COMPILE_H +#define MLX_COMPILE_H + +#include +#include +#include + +#include "mlx/c/array.h" +#include "mlx/c/closure.h" +#include "mlx/c/distributed_group.h" +#include "mlx/c/io_types.h" +#include "mlx/c/map.h" +#include "mlx/c/stream.h" +#include "mlx/c/string.h" +#include "mlx/c/vector.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/** + * \defgroup compile Compilation operations + */ +/**@{*/ + +typedef enum mlx_compile_mode_ { + MLX_COMPILE_MODE_DISABLED, + MLX_COMPILE_MODE_NO_SIMPLIFY, + MLX_COMPILE_MODE_NO_FUSE, + MLX_COMPILE_MODE_ENABLED +} mlx_compile_mode; +int mlx_compile(mlx_closure* res, const mlx_closure fun, bool shapeless); +int mlx_detail_compile( + mlx_closure* res, + const mlx_closure fun, + uintptr_t fun_id, + bool shapeless, + const uint64_t* constants, + size_t constants_num); +int mlx_detail_compile_clear_cache(void); +int mlx_detail_compile_erase(uintptr_t fun_id); +int mlx_disable_compile(void); +int mlx_enable_compile(void); +int mlx_set_compile_mode(mlx_compile_mode mode); + +/**@}*/ + +#ifdef __cplusplus +} +#endif + +#endif diff --git a/x/mlxrunner/mlx/include/mlx/c/cuda.h b/x/mlxrunner/mlx/include/mlx/c/cuda.h new file mode 100644 index 000000000..4734f8c51 --- /dev/null +++ b/x/mlxrunner/mlx/include/mlx/c/cuda.h @@ -0,0 +1,39 @@ +/* Copyright © 2023-2024 Apple Inc. */ +/* */ +/* This file is auto-generated. Do not edit manually. */ +/* */ + +#ifndef MLX_CUDA_H +#define MLX_CUDA_H + +#include +#include +#include + +#include "mlx/c/array.h" +#include "mlx/c/closure.h" +#include "mlx/c/distributed_group.h" +#include "mlx/c/io_types.h" +#include "mlx/c/map.h" +#include "mlx/c/stream.h" +#include "mlx/c/string.h" +#include "mlx/c/vector.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/** + * \defgroup cuda Cuda specific operations + */ +/**@{*/ + +int mlx_cuda_is_available(bool* res); + +/**@}*/ + +#ifdef __cplusplus +} +#endif + +#endif diff --git a/x/mlxrunner/mlx/include/mlx/c/device.h b/x/mlxrunner/mlx/include/mlx/c/device.h new file mode 100644 index 000000000..4b74e39d3 --- /dev/null +++ b/x/mlxrunner/mlx/include/mlx/c/device.h @@ -0,0 +1,154 @@ +/* Copyright © 2023-2024 Apple Inc. */ + +#ifndef MLX_DEVICE_H +#define MLX_DEVICE_H + +#include +#include + +#include "mlx/c/string.h" +#include "mlx/c/vector.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/** + * \defgroup mlx_device Device + * MLX device object. + */ +/**@{*/ + +/** + * A MLX device object. + */ +typedef struct mlx_device_ { + void* ctx; +} mlx_device; + +/** + * Device type. + */ +typedef enum mlx_device_type_ { MLX_CPU, MLX_GPU } mlx_device_type; + +/** + * Returns a new empty device. + */ +mlx_device mlx_device_new(void); + +/** + * Returns a new device of specified `type`, with specified `index`. + */ +mlx_device mlx_device_new_type(mlx_device_type type, int index); +/** + * Free a device. + */ +int mlx_device_free(mlx_device dev); +/** + * Set device to provided src device. + */ +int mlx_device_set(mlx_device* dev, const mlx_device src); +/** + * Get device description. + */ +int mlx_device_tostring(mlx_string* str, mlx_device dev); +/** + * Check if devices are the same. + */ +bool mlx_device_equal(mlx_device lhs, mlx_device rhs); +/** + * Returns the index of the device. + */ +int mlx_device_get_index(int* index, mlx_device dev); +/** + * Returns the type of the device. + */ +int mlx_device_get_type(mlx_device_type* type, mlx_device dev); +/** + * Returns the default MLX device. + */ +int mlx_get_default_device(mlx_device* dev); +/** + * Set the default MLX device. + */ +int mlx_set_default_device(mlx_device dev); +/** + * Check if device is available. + */ +int mlx_device_is_available(bool* avail, mlx_device dev); +/** + * Get the number of available devices for a device type. + */ +int mlx_device_count(int* count, mlx_device_type type); + +/** + * A MLX device info object. + * Contains key-value pairs with device properties. + * Keys vary by backend but common keys include: + * - device_name (string): Device name + * - architecture (string): Architecture identifier + * Additional keys may be present depending on the backend. + */ +typedef struct mlx_device_info_ { + void* ctx; +} mlx_device_info; + +/** + * Returns a new empty device info object. + */ +mlx_device_info mlx_device_info_new(void); +/** + * Get device information for a device. + */ +int mlx_device_info_get(mlx_device_info* info, mlx_device dev); +/** + * Free a device info object. + */ +int mlx_device_info_free(mlx_device_info info); +/** + * Check if a key exists in the device info. + * Returns 0 on success, 1 on error. + * Sets *exists to true if the key exists, false otherwise. + */ +int mlx_device_info_has_key( + bool* exists, + mlx_device_info info, + const char* key); +/** + * Check if a value is a string type. + * Returns 0 on success, 1 on error. + * Sets *is_string to true if the value is a string, false if it's a size_t. + */ +int mlx_device_info_is_string( + bool* is_string, + mlx_device_info info, + const char* key); +/** + * Get a string value from device info. + * Returns 0 on success, 1 on error, 2 if key not found or wrong type. + */ +int mlx_device_info_get_string( + const char** value, + mlx_device_info info, + const char* key); +/** + * Get a size_t value from device info. + * Returns 0 on success, 1 on error, 2 if key not found or wrong type. + */ +int mlx_device_info_get_size( + size_t* value, + mlx_device_info info, + const char* key); +/** + * Get all keys from device info. + * Returns 0 on success, 1 on error. + */ +int mlx_device_info_get_keys(mlx_vector_string* keys, mlx_device_info info); + +/**@}*/ + +#ifdef __cplusplus +} +#endif + +#endif diff --git a/x/mlxrunner/mlx/include/mlx/c/distributed.h b/x/mlxrunner/mlx/include/mlx/c/distributed.h new file mode 100644 index 000000000..c3b0baeee --- /dev/null +++ b/x/mlxrunner/mlx/include/mlx/c/distributed.h @@ -0,0 +1,83 @@ +/* Copyright © 2023-2024 Apple Inc. */ +/* */ +/* This file is auto-generated. Do not edit manually. */ +/* */ + +#ifndef MLX_DISTRIBUTED_H +#define MLX_DISTRIBUTED_H + +#include +#include +#include + +#include "mlx/c/array.h" +#include "mlx/c/closure.h" +#include "mlx/c/distributed_group.h" +#include "mlx/c/io_types.h" +#include "mlx/c/map.h" +#include "mlx/c/stream.h" +#include "mlx/c/string.h" +#include "mlx/c/vector.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/** + * \defgroup distributed Distributed collectives + */ +/**@{*/ + +int mlx_distributed_all_gather( + mlx_array* res, + const mlx_array x, + const mlx_distributed_group group /* may be null */, + const mlx_stream S); +int mlx_distributed_all_max( + mlx_array* res, + const mlx_array x, + const mlx_distributed_group group /* may be null */, + const mlx_stream s); +int mlx_distributed_all_min( + mlx_array* res, + const mlx_array x, + const mlx_distributed_group group /* may be null */, + const mlx_stream s); +int mlx_distributed_all_sum( + mlx_array* res, + const mlx_array x, + const mlx_distributed_group group /* may be null */, + const mlx_stream s); +int mlx_distributed_recv( + mlx_array* res, + const int* shape, + size_t shape_num, + mlx_dtype dtype, + int src, + const mlx_distributed_group group /* may be null */, + const mlx_stream s); +int mlx_distributed_recv_like( + mlx_array* res, + const mlx_array x, + int src, + const mlx_distributed_group group /* may be null */, + const mlx_stream s); +int mlx_distributed_send( + mlx_array* res, + const mlx_array x, + int dst, + const mlx_distributed_group group /* may be null */, + const mlx_stream s); +int mlx_distributed_sum_scatter( + mlx_array* res, + const mlx_array x, + const mlx_distributed_group group /* may be null */, + const mlx_stream s); + +/**@}*/ + +#ifdef __cplusplus +} +#endif + +#endif diff --git a/x/mlxrunner/mlx/include/mlx/c/distributed_group.h b/x/mlxrunner/mlx/include/mlx/c/distributed_group.h new file mode 100644 index 000000000..3cfccc806 --- /dev/null +++ b/x/mlxrunner/mlx/include/mlx/c/distributed_group.h @@ -0,0 +1,58 @@ +/* Copyright © 2023-2024 Apple Inc. */ + +#ifndef MLX_DISTRIBUTED_GROUP_H +#define MLX_DISTRIBUTED_GROUP_H + +#include + +#include "mlx/c/stream.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/** + * \defgroup mlx_distributed_group MLX distributed + */ +/**@{*/ + +/** + * A MLX distributed group object. + */ +typedef struct mlx_distributed_group_ { + void* ctx; +} mlx_distributed_group; + +/** + * Get the rank. + */ +int mlx_distributed_group_rank(mlx_distributed_group group); + +/** + * Get the group size. + */ +int mlx_distributed_group_size(mlx_distributed_group group); + +/** + * Split the group. + */ +mlx_distributed_group +mlx_distributed_group_split(mlx_distributed_group group, int color, int key); + +/** + * Check if distributed is available. + */ +bool mlx_distributed_is_available(void); + +/** + * Initialize distributed. + */ +mlx_distributed_group mlx_distributed_init(bool strict); + +/**@}*/ + +#ifdef __cplusplus +} +#endif + +#endif diff --git a/x/mlxrunner/mlx/include/mlx/c/error.h b/x/mlxrunner/mlx/include/mlx/c/error.h new file mode 100644 index 000000000..8c063a403 --- /dev/null +++ b/x/mlxrunner/mlx/include/mlx/c/error.h @@ -0,0 +1,41 @@ +/* Copyright © 2023-2024 Apple Inc. */ + +#ifndef MLX_ERROR_H +#define MLX_ERROR_H + +#ifdef __cplusplus +extern "C" { +#endif + +/** + * \defgroup mlx_error Error management + */ +/**@{*/ + +typedef void (*mlx_error_handler_func)(const char* msg, void* data); + +/** + * Set the error handler. + */ +void mlx_set_error_handler( + mlx_error_handler_func handler, + void* data, + void (*dtor)(void*)); + +/** + * Throw an error. + */ +void _mlx_error(const char* file, const int line, const char* fmt, ...); + +/** + * Throw an error. Macro which passes file name and line number to _mlx_error(). + */ +#define mlx_error(...) _mlx_error(__FILE__, __LINE__, __VA_ARGS__) + +/**@}*/ + +#ifdef __cplusplus +} +#endif + +#endif diff --git a/x/mlxrunner/mlx/include/mlx/c/export.h b/x/mlxrunner/mlx/include/mlx/c/export.h new file mode 100644 index 000000000..52cb2835c --- /dev/null +++ b/x/mlxrunner/mlx/include/mlx/c/export.h @@ -0,0 +1,75 @@ +/* Copyright © 2023-2025 Apple Inc. */ + +#ifndef MLX_EXPORT_H +#define MLX_EXPORT_H + +#include +#include +#include + +#include "mlx/c/array.h" +#include "mlx/c/closure.h" +#include "mlx/c/distributed_group.h" +#include "mlx/c/io_types.h" +#include "mlx/c/map.h" +#include "mlx/c/stream.h" +#include "mlx/c/string.h" +#include "mlx/c/vector.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/** + * \defgroup export Function serialization + */ +/**@{*/ +int mlx_export_function( + const char* file, + const mlx_closure fun, + const mlx_vector_array args, + bool shapeless); +int mlx_export_function_kwargs( + const char* file, + const mlx_closure_kwargs fun, + const mlx_vector_array args, + const mlx_map_string_to_array kwargs, + bool shapeless); + +typedef struct mlx_function_exporter_ { + void* ctx; +} mlx_function_exporter; +mlx_function_exporter mlx_function_exporter_new( + const char* file, + const mlx_closure fun, + bool shapeless); +int mlx_function_exporter_free(mlx_function_exporter xfunc); +int mlx_function_exporter_apply( + const mlx_function_exporter xfunc, + const mlx_vector_array args); +int mlx_function_exporter_apply_kwargs( + const mlx_function_exporter xfunc, + const mlx_vector_array args, + const mlx_map_string_to_array kwargs); + +typedef struct mlx_imported_function_ { + void* ctx; +} mlx_imported_function; +mlx_imported_function mlx_imported_function_new(const char* file); +int mlx_imported_function_free(mlx_imported_function xfunc); +int mlx_imported_function_apply( + mlx_vector_array* res, + const mlx_imported_function xfunc, + const mlx_vector_array args); +int mlx_imported_function_apply_kwargs( + mlx_vector_array* res, + const mlx_imported_function xfunc, + const mlx_vector_array args, + const mlx_map_string_to_array kwargs); +/**@}*/ + +#ifdef __cplusplus +} +#endif + +#endif diff --git a/x/mlxrunner/mlx/include/mlx/c/fast.h b/x/mlxrunner/mlx/include/mlx/c/fast.h new file mode 100644 index 000000000..c825d00e5 --- /dev/null +++ b/x/mlxrunner/mlx/include/mlx/c/fast.h @@ -0,0 +1,206 @@ +/* Copyright © 2023-2024 Apple Inc. */ +/* */ +/* This file is auto-generated. Do not edit manually. */ +/* */ + +#ifndef MLX_FAST_H +#define MLX_FAST_H + +#include +#include +#include + +#include "mlx/c/array.h" +#include "mlx/c/closure.h" +#include "mlx/c/distributed_group.h" +#include "mlx/c/io_types.h" +#include "mlx/c/map.h" +#include "mlx/c/stream.h" +#include "mlx/c/string.h" +#include "mlx/c/vector.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/** + * \defgroup fast Fast custom operations + */ +/**@{*/ + +typedef struct mlx_fast_cuda_kernel_config_ { + void* ctx; +} mlx_fast_cuda_kernel_config; +mlx_fast_cuda_kernel_config mlx_fast_cuda_kernel_config_new(void); +void mlx_fast_cuda_kernel_config_free(mlx_fast_cuda_kernel_config cls); + +int mlx_fast_cuda_kernel_config_add_output_arg( + mlx_fast_cuda_kernel_config cls, + const int* shape, + size_t size, + mlx_dtype dtype); +int mlx_fast_cuda_kernel_config_set_grid( + mlx_fast_cuda_kernel_config cls, + int grid1, + int grid2, + int grid3); +int mlx_fast_cuda_kernel_config_set_thread_group( + mlx_fast_cuda_kernel_config cls, + int thread1, + int thread2, + int thread3); +int mlx_fast_cuda_kernel_config_set_init_value( + mlx_fast_cuda_kernel_config cls, + float value); +int mlx_fast_cuda_kernel_config_set_verbose( + mlx_fast_cuda_kernel_config cls, + bool verbose); +int mlx_fast_cuda_kernel_config_add_template_arg_dtype( + mlx_fast_cuda_kernel_config cls, + const char* name, + mlx_dtype dtype); +int mlx_fast_cuda_kernel_config_add_template_arg_int( + mlx_fast_cuda_kernel_config cls, + const char* name, + int value); +int mlx_fast_cuda_kernel_config_add_template_arg_bool( + mlx_fast_cuda_kernel_config cls, + const char* name, + bool value); + +typedef struct mlx_fast_cuda_kernel_ { + void* ctx; +} mlx_fast_cuda_kernel; + +mlx_fast_cuda_kernel mlx_fast_cuda_kernel_new( + const char* name, + const mlx_vector_string input_names, + const mlx_vector_string output_names, + const char* source, + const char* header, + bool ensure_row_contiguous, + int shared_memory); + +void mlx_fast_cuda_kernel_free(mlx_fast_cuda_kernel cls); + +int mlx_fast_cuda_kernel_apply( + mlx_vector_array* outputs, + mlx_fast_cuda_kernel cls, + const mlx_vector_array inputs, + const mlx_fast_cuda_kernel_config config, + const mlx_stream stream); + +int mlx_fast_layer_norm( + mlx_array* res, + const mlx_array x, + const mlx_array weight /* may be null */, + const mlx_array bias /* may be null */, + float eps, + const mlx_stream s); + +typedef struct mlx_fast_metal_kernel_config_ { + void* ctx; +} mlx_fast_metal_kernel_config; +mlx_fast_metal_kernel_config mlx_fast_metal_kernel_config_new(void); +void mlx_fast_metal_kernel_config_free(mlx_fast_metal_kernel_config cls); + +int mlx_fast_metal_kernel_config_add_output_arg( + mlx_fast_metal_kernel_config cls, + const int* shape, + size_t size, + mlx_dtype dtype); +int mlx_fast_metal_kernel_config_set_grid( + mlx_fast_metal_kernel_config cls, + int grid1, + int grid2, + int grid3); +int mlx_fast_metal_kernel_config_set_thread_group( + mlx_fast_metal_kernel_config cls, + int thread1, + int thread2, + int thread3); +int mlx_fast_metal_kernel_config_set_init_value( + mlx_fast_metal_kernel_config cls, + float value); +int mlx_fast_metal_kernel_config_set_verbose( + mlx_fast_metal_kernel_config cls, + bool verbose); +int mlx_fast_metal_kernel_config_add_template_arg_dtype( + mlx_fast_metal_kernel_config cls, + const char* name, + mlx_dtype dtype); +int mlx_fast_metal_kernel_config_add_template_arg_int( + mlx_fast_metal_kernel_config cls, + const char* name, + int value); +int mlx_fast_metal_kernel_config_add_template_arg_bool( + mlx_fast_metal_kernel_config cls, + const char* name, + bool value); + +typedef struct mlx_fast_metal_kernel_ { + void* ctx; +} mlx_fast_metal_kernel; + +mlx_fast_metal_kernel mlx_fast_metal_kernel_new( + const char* name, + const mlx_vector_string input_names, + const mlx_vector_string output_names, + const char* source, + const char* header, + bool ensure_row_contiguous, + bool atomic_outputs); + +void mlx_fast_metal_kernel_free(mlx_fast_metal_kernel cls); + +int mlx_fast_metal_kernel_apply( + mlx_vector_array* outputs, + mlx_fast_metal_kernel cls, + const mlx_vector_array inputs, + const mlx_fast_metal_kernel_config config, + const mlx_stream stream); + +int mlx_fast_rms_norm( + mlx_array* res, + const mlx_array x, + const mlx_array weight /* may be null */, + float eps, + const mlx_stream s); +int mlx_fast_rope( + mlx_array* res, + const mlx_array x, + int dims, + bool traditional, + mlx_optional_float base, + float scale, + int offset, + const mlx_array freqs /* may be null */, + const mlx_stream s); +int mlx_fast_rope_dynamic( + mlx_array* res, + const mlx_array x, + int dims, + bool traditional, + mlx_optional_float base, + float scale, + const mlx_array offset, + const mlx_array freqs /* may be null */, + const mlx_stream s); +int mlx_fast_scaled_dot_product_attention( + mlx_array* res, + const mlx_array queries, + const mlx_array keys, + const mlx_array values, + float scale, + const char* mask_mode, + const mlx_array mask_arr /* may be null */, + const mlx_array sinks /* may be null */, + const mlx_stream s); + +/**@}*/ + +#ifdef __cplusplus +} +#endif + +#endif diff --git a/x/mlxrunner/mlx/include/mlx/c/fft.h b/x/mlxrunner/mlx/include/mlx/c/fft.h new file mode 100644 index 000000000..779803e9b --- /dev/null +++ b/x/mlxrunner/mlx/include/mlx/c/fft.h @@ -0,0 +1,138 @@ +/* Copyright © 2023-2024 Apple Inc. */ +/* */ +/* This file is auto-generated. Do not edit manually. */ +/* */ + +#ifndef MLX_FFT_H +#define MLX_FFT_H + +#include +#include +#include + +#include "mlx/c/array.h" +#include "mlx/c/closure.h" +#include "mlx/c/distributed_group.h" +#include "mlx/c/io_types.h" +#include "mlx/c/map.h" +#include "mlx/c/stream.h" +#include "mlx/c/string.h" +#include "mlx/c/vector.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/** + * \defgroup fft FFT operations + */ +/**@{*/ + +int mlx_fft_fft( + mlx_array* res, + const mlx_array a, + int n, + int axis, + const mlx_stream s); +int mlx_fft_fft2( + mlx_array* res, + const mlx_array a, + const int* n, + size_t n_num, + const int* axes, + size_t axes_num, + const mlx_stream s); +int mlx_fft_fftn( + mlx_array* res, + const mlx_array a, + const int* n, + size_t n_num, + const int* axes, + size_t axes_num, + const mlx_stream s); +int mlx_fft_fftshift( + mlx_array* res, + const mlx_array a, + const int* axes, + size_t axes_num, + const mlx_stream s); +int mlx_fft_ifft( + mlx_array* res, + const mlx_array a, + int n, + int axis, + const mlx_stream s); +int mlx_fft_ifft2( + mlx_array* res, + const mlx_array a, + const int* n, + size_t n_num, + const int* axes, + size_t axes_num, + const mlx_stream s); +int mlx_fft_ifftn( + mlx_array* res, + const mlx_array a, + const int* n, + size_t n_num, + const int* axes, + size_t axes_num, + const mlx_stream s); +int mlx_fft_ifftshift( + mlx_array* res, + const mlx_array a, + const int* axes, + size_t axes_num, + const mlx_stream s); +int mlx_fft_irfft( + mlx_array* res, + const mlx_array a, + int n, + int axis, + const mlx_stream s); +int mlx_fft_irfft2( + mlx_array* res, + const mlx_array a, + const int* n, + size_t n_num, + const int* axes, + size_t axes_num, + const mlx_stream s); +int mlx_fft_irfftn( + mlx_array* res, + const mlx_array a, + const int* n, + size_t n_num, + const int* axes, + size_t axes_num, + const mlx_stream s); +int mlx_fft_rfft( + mlx_array* res, + const mlx_array a, + int n, + int axis, + const mlx_stream s); +int mlx_fft_rfft2( + mlx_array* res, + const mlx_array a, + const int* n, + size_t n_num, + const int* axes, + size_t axes_num, + const mlx_stream s); +int mlx_fft_rfftn( + mlx_array* res, + const mlx_array a, + const int* n, + size_t n_num, + const int* axes, + size_t axes_num, + const mlx_stream s); + +/**@}*/ + +#ifdef __cplusplus +} +#endif + +#endif diff --git a/x/mlxrunner/mlx/include/mlx/c/half.h b/x/mlxrunner/mlx/include/mlx/c/half.h new file mode 100644 index 000000000..958d555f5 --- /dev/null +++ b/x/mlxrunner/mlx/include/mlx/c/half.h @@ -0,0 +1,26 @@ +/* Copyright © 2023-2024 Apple Inc. */ + +#ifndef MLX_HALF_H +#define MLX_HALF_H + +#ifdef __cplusplus +extern "C" { +#endif + +#if defined(__ARM_FEATURE_FP16_SCALAR_ARITHMETIC) || defined(__aarch64__) +#define HAS_FLOAT16 +#include +typedef __fp16 float16_t; +#endif + +#if defined(__ARM_FEATURE_BF16) || defined(__aarch64__) +#define HAS_BFLOAT16 +#include +typedef __bf16 bfloat16_t; +#endif + +#ifdef __cplusplus +} +#endif + +#endif diff --git a/x/mlxrunner/mlx/include/mlx/c/io.h b/x/mlxrunner/mlx/include/mlx/c/io.h new file mode 100644 index 000000000..6eb205c9a --- /dev/null +++ b/x/mlxrunner/mlx/include/mlx/c/io.h @@ -0,0 +1,63 @@ +/* Copyright © 2023-2024 Apple Inc. */ +/* */ +/* This file is auto-generated. Do not edit manually. */ +/* */ + +#ifndef MLX_IO_H +#define MLX_IO_H + +#include +#include +#include + +#include "mlx/c/array.h" +#include "mlx/c/closure.h" +#include "mlx/c/distributed_group.h" +#include "mlx/c/io_types.h" +#include "mlx/c/map.h" +#include "mlx/c/stream.h" +#include "mlx/c/string.h" +#include "mlx/c/vector.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/** + * \defgroup io IO operations + */ +/**@{*/ + +int mlx_load_reader( + mlx_array* res, + mlx_io_reader in_stream, + const mlx_stream s); +int mlx_load(mlx_array* res, const char* file, const mlx_stream s); +int mlx_load_safetensors_reader( + mlx_map_string_to_array* res_0, + mlx_map_string_to_string* res_1, + mlx_io_reader in_stream, + const mlx_stream s); +int mlx_load_safetensors( + mlx_map_string_to_array* res_0, + mlx_map_string_to_string* res_1, + const char* file, + const mlx_stream s); +int mlx_save_writer(mlx_io_writer out_stream, const mlx_array a); +int mlx_save(const char* file, const mlx_array a); +int mlx_save_safetensors_writer( + mlx_io_writer in_stream, + const mlx_map_string_to_array param, + const mlx_map_string_to_string metadata); +int mlx_save_safetensors( + const char* file, + const mlx_map_string_to_array param, + const mlx_map_string_to_string metadata); + +/**@}*/ + +#ifdef __cplusplus +} +#endif + +#endif diff --git a/x/mlxrunner/mlx/include/mlx/c/io_types.h b/x/mlxrunner/mlx/include/mlx/c/io_types.h new file mode 100644 index 000000000..88349b57c --- /dev/null +++ b/x/mlxrunner/mlx/include/mlx/c/io_types.h @@ -0,0 +1,104 @@ +/* Copyright © 2023-2024 Apple Inc. */ + +#ifndef MLX_IO_TYPES_H +#define MLX_IO_TYPES_H + +#include + +#include "mlx/c/string.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/** + * \defgroup mlx_io_types IO Types + * MLX IO type objects. + */ +/**@{*/ + +/** + * A MLX IO reader object. + */ +typedef struct mlx_io_reader_ { + void* ctx; +} mlx_io_reader; +/** + * A MLX IO writer object. + */ +typedef struct mlx_io_writer_ { + void* ctx; +} mlx_io_writer; + +/** + * Virtual table for custom IO reader and writer objects. + */ +typedef struct mlx_io_vtable_ { + bool (*is_open)(void*); + bool (*good)(void*); + size_t (*tell)(void*); + void (*seek)(void*, int64_t off, int whence); + void (*read)(void*, char* data, size_t n); + void (*read_at_offset)(void*, char* data, size_t n, size_t off); + void (*write)(void*, const char* data, size_t n); + const char* (*label)(void*); + void (*free)(void*); +} mlx_io_vtable; + +/** + * Returns a new custom IO reader. + * `vtable` operates on user descriptor `desc`. + */ +mlx_io_reader mlx_io_reader_new(void* desc, mlx_io_vtable vtable); + +/** + * Get IO reader user descriptor. + */ +int mlx_io_reader_descriptor(void** desc_, mlx_io_reader io); + +/** + * Get IO reader description. + */ +int mlx_io_reader_tostring(mlx_string* str_, mlx_io_reader io); + +/** + * Free IO reader. + * + * Note that MLX arrays are lazily evaluated, so the underlying object may + * be not freed right away. The ``free()`` callback from ``mlx_io_vtable`` + * will be called when the underlying object is actually freed. + */ +int mlx_io_reader_free(mlx_io_reader io); + +/** + * Returns a new custom IO writer. + * `vtable` operates on user descriptor `desc`. + */ +mlx_io_writer mlx_io_writer_new(void* desc, mlx_io_vtable vtable); + +/** + * Get IO writer user descriptor. + */ +int mlx_io_writer_descriptor(void** desc_, mlx_io_writer io); + +/** + * Get IO writer description. + */ +int mlx_io_writer_tostring(mlx_string* str_, mlx_io_writer io); + +/** + * Free IO writer. + * + * Note that MLX arrays are lazily evaluated, so the underlying object may + * be not freed right away. The ``free()`` callback from ``mlx_io_vtable`` + * will be called when the underlying object is actually freed. + */ +int mlx_io_writer_free(mlx_io_writer io); + +/**@}*/ + +#ifdef __cplusplus +} +#endif + +#endif diff --git a/x/mlxrunner/mlx/include/mlx/c/linalg.h b/x/mlxrunner/mlx/include/mlx/c/linalg.h new file mode 100644 index 000000000..91d5d661e --- /dev/null +++ b/x/mlxrunner/mlx/include/mlx/c/linalg.h @@ -0,0 +1,128 @@ +/* Copyright © 2023-2024 Apple Inc. */ +/* */ +/* This file is auto-generated. Do not edit manually. */ +/* */ + +#ifndef MLX_LINALG_H +#define MLX_LINALG_H + +#include +#include +#include + +#include "mlx/c/array.h" +#include "mlx/c/closure.h" +#include "mlx/c/distributed_group.h" +#include "mlx/c/io_types.h" +#include "mlx/c/map.h" +#include "mlx/c/stream.h" +#include "mlx/c/string.h" +#include "mlx/c/vector.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/** + * \defgroup linalg Linear algebra operations + */ +/**@{*/ + +int mlx_linalg_cholesky( + mlx_array* res, + const mlx_array a, + bool upper, + const mlx_stream s); +int mlx_linalg_cholesky_inv( + mlx_array* res, + const mlx_array a, + bool upper, + const mlx_stream s); +int mlx_linalg_cross( + mlx_array* res, + const mlx_array a, + const mlx_array b, + int axis, + const mlx_stream s); +int mlx_linalg_eig( + mlx_array* res_0, + mlx_array* res_1, + const mlx_array a, + const mlx_stream s); +int mlx_linalg_eigh( + mlx_array* res_0, + mlx_array* res_1, + const mlx_array a, + const char* UPLO, + const mlx_stream s); +int mlx_linalg_eigvals(mlx_array* res, const mlx_array a, const mlx_stream s); +int mlx_linalg_eigvalsh( + mlx_array* res, + const mlx_array a, + const char* UPLO, + const mlx_stream s); +int mlx_linalg_inv(mlx_array* res, const mlx_array a, const mlx_stream s); +int mlx_linalg_lu(mlx_vector_array* res, const mlx_array a, const mlx_stream s); +int mlx_linalg_lu_factor( + mlx_array* res_0, + mlx_array* res_1, + const mlx_array a, + const mlx_stream s); +int mlx_linalg_norm( + mlx_array* res, + const mlx_array a, + double ord, + const int* axis /* may be null */, + size_t axis_num, + bool keepdims, + const mlx_stream s); +int mlx_linalg_norm_matrix( + mlx_array* res, + const mlx_array a, + const char* ord, + const int* axis /* may be null */, + size_t axis_num, + bool keepdims, + const mlx_stream s); +int mlx_linalg_norm_l2( + mlx_array* res, + const mlx_array a, + const int* axis /* may be null */, + size_t axis_num, + bool keepdims, + const mlx_stream s); +int mlx_linalg_pinv(mlx_array* res, const mlx_array a, const mlx_stream s); +int mlx_linalg_qr( + mlx_array* res_0, + mlx_array* res_1, + const mlx_array a, + const mlx_stream s); +int mlx_linalg_solve( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s); +int mlx_linalg_solve_triangular( + mlx_array* res, + const mlx_array a, + const mlx_array b, + bool upper, + const mlx_stream s); +int mlx_linalg_svd( + mlx_vector_array* res, + const mlx_array a, + bool compute_uv, + const mlx_stream s); +int mlx_linalg_tri_inv( + mlx_array* res, + const mlx_array a, + bool upper, + const mlx_stream s); + +/**@}*/ + +#ifdef __cplusplus +} +#endif + +#endif diff --git a/x/mlxrunner/mlx/include/mlx/c/map.h b/x/mlxrunner/mlx/include/mlx/c/map.h new file mode 100644 index 000000000..56abe84f1 --- /dev/null +++ b/x/mlxrunner/mlx/include/mlx/c/map.h @@ -0,0 +1,149 @@ +/* Copyright © 2023-2024 Apple Inc. */ +/* */ +/* This file is auto-generated. Do not edit manually. */ +/* */ + +#ifndef MLX_MAP_H +#define MLX_MAP_H + +#include "mlx/c/array.h" +#include "mlx/c/string.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/** + * \defgroup mlx_map Maps + * MLX map objects. + */ +/**@{*/ + +/** + * A string-to-array map + */ +typedef struct mlx_map_string_to_array_ { + void* ctx; +} mlx_map_string_to_array; + +/** + * Returns a new empty string-to-array map. + */ +mlx_map_string_to_array mlx_map_string_to_array_new(void); +/** + * Set map to provided src map. + */ +int mlx_map_string_to_array_set( + mlx_map_string_to_array* map, + const mlx_map_string_to_array src); +/** + * Free a string-to-array map. + */ +int mlx_map_string_to_array_free(mlx_map_string_to_array map); +/** + * Insert a new `value` at the specified `key` in the map. + */ +int mlx_map_string_to_array_insert( + mlx_map_string_to_array map, + const char* key, + const mlx_array value); +/** + * Returns the value indexed at the specified `key` in the map. + */ +int mlx_map_string_to_array_get( + mlx_array* value, + const mlx_map_string_to_array map, + const char* key); + +/** + * An iterator over a string-to-array map. + */ +typedef struct mlx_map_string_to_array_iterator_ { + void* ctx; + void* map_ctx; +} mlx_map_string_to_array_iterator; +/** + * Returns a new iterator over the given map. + */ +mlx_map_string_to_array_iterator mlx_map_string_to_array_iterator_new( + mlx_map_string_to_array map); +/** + * Free iterator. + */ +int mlx_map_string_to_array_iterator_free(mlx_map_string_to_array_iterator it); +/** + * Increment iterator. + */ +int mlx_map_string_to_array_iterator_next( + const char** key, + mlx_array* value, + mlx_map_string_to_array_iterator it); + +/** + * A string-to-string map + */ +typedef struct mlx_map_string_to_string_ { + void* ctx; +} mlx_map_string_to_string; + +/** + * Returns a new empty string-to-string map. + */ +mlx_map_string_to_string mlx_map_string_to_string_new(void); +/** + * Set map to provided src map. + */ +int mlx_map_string_to_string_set( + mlx_map_string_to_string* map, + const mlx_map_string_to_string src); +/** + * Free a string-to-string map. + */ +int mlx_map_string_to_string_free(mlx_map_string_to_string map); +/** + * Insert a new `value` at the specified `key` in the map. + */ +int mlx_map_string_to_string_insert( + mlx_map_string_to_string map, + const char* key, + const char* value); +/** + * Returns the value indexed at the specified `key` in the map. + */ +int mlx_map_string_to_string_get( + const char** value, + const mlx_map_string_to_string map, + const char* key); + +/** + * An iterator over a string-to-string map. + */ +typedef struct mlx_map_string_to_string_iterator_ { + void* ctx; + void* map_ctx; +} mlx_map_string_to_string_iterator; +/** + * Returns a new iterator over the given map. + */ +mlx_map_string_to_string_iterator mlx_map_string_to_string_iterator_new( + mlx_map_string_to_string map); +/** + * Free iterator. + */ +int mlx_map_string_to_string_iterator_free( + mlx_map_string_to_string_iterator it); +/** + * Increment iterator. + */ +int mlx_map_string_to_string_iterator_next( + const char** key, + const char** value, + mlx_map_string_to_string_iterator it); + +/**@}*/ + +#ifdef __cplusplus +} +#endif + +#endif diff --git a/x/mlxrunner/mlx/include/mlx/c/memory.h b/x/mlxrunner/mlx/include/mlx/c/memory.h new file mode 100644 index 000000000..bae9e08ec --- /dev/null +++ b/x/mlxrunner/mlx/include/mlx/c/memory.h @@ -0,0 +1,47 @@ +/* Copyright © 2023-2024 Apple Inc. */ +/* */ +/* This file is auto-generated. Do not edit manually. */ +/* */ + +#ifndef MLX_MEMORY_H +#define MLX_MEMORY_H + +#include +#include +#include + +#include "mlx/c/array.h" +#include "mlx/c/closure.h" +#include "mlx/c/distributed_group.h" +#include "mlx/c/io_types.h" +#include "mlx/c/map.h" +#include "mlx/c/stream.h" +#include "mlx/c/string.h" +#include "mlx/c/vector.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/** + * \defgroup memory Memory operations + */ +/**@{*/ + +int mlx_clear_cache(void); +int mlx_get_active_memory(size_t* res); +int mlx_get_cache_memory(size_t* res); +int mlx_get_memory_limit(size_t* res); +int mlx_get_peak_memory(size_t* res); +int mlx_reset_peak_memory(void); +int mlx_set_cache_limit(size_t* res, size_t limit); +int mlx_set_memory_limit(size_t* res, size_t limit); +int mlx_set_wired_limit(size_t* res, size_t limit); + +/**@}*/ + +#ifdef __cplusplus +} +#endif + +#endif diff --git a/x/mlxrunner/mlx/include/mlx/c/metal.h b/x/mlxrunner/mlx/include/mlx/c/metal.h new file mode 100644 index 000000000..5877b224b --- /dev/null +++ b/x/mlxrunner/mlx/include/mlx/c/metal.h @@ -0,0 +1,41 @@ +/* Copyright © 2023-2024 Apple Inc. */ +/* */ +/* This file is auto-generated. Do not edit manually. */ +/* */ + +#ifndef MLX_METAL_H +#define MLX_METAL_H + +#include +#include +#include + +#include "mlx/c/array.h" +#include "mlx/c/closure.h" +#include "mlx/c/distributed_group.h" +#include "mlx/c/io_types.h" +#include "mlx/c/map.h" +#include "mlx/c/stream.h" +#include "mlx/c/string.h" +#include "mlx/c/vector.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/** + * \defgroup metal Metal specific operations + */ +/**@{*/ + +int mlx_metal_is_available(bool* res); +int mlx_metal_start_capture(const char* path); +int mlx_metal_stop_capture(void); + +/**@}*/ + +#ifdef __cplusplus +} +#endif + +#endif diff --git a/x/mlxrunner/mlx/include/mlx/c/mlx.h b/x/mlxrunner/mlx/include/mlx/c/mlx.h new file mode 100644 index 000000000..ffadac89a --- /dev/null +++ b/x/mlxrunner/mlx/include/mlx/c/mlx.h @@ -0,0 +1,34 @@ +/* Copyright © 2023-2024 Apple Inc. */ + +#ifndef MLX_ALL_H +#define MLX_ALL_H + +#include "mlx/c/array.h" +#include "mlx/c/closure.h" +#include "mlx/c/compile.h" +#include "mlx/c/cuda.h" +#include "mlx/c/device.h" +#include "mlx/c/distributed.h" +#include "mlx/c/distributed_group.h" +#include "mlx/c/error.h" +#include "mlx/c/export.h" +#include "mlx/c/fast.h" +#include "mlx/c/fft.h" +#include "mlx/c/half.h" +#include "mlx/c/io.h" +#include "mlx/c/io_types.h" +#include "mlx/c/linalg.h" +#include "mlx/c/map.h" +#include "mlx/c/memory.h" +#include "mlx/c/metal.h" +#include "mlx/c/ops.h" +#include "mlx/c/optional.h" +#include "mlx/c/random.h" +#include "mlx/c/stream.h" +#include "mlx/c/string.h" +#include "mlx/c/transforms.h" +#include "mlx/c/transforms_impl.h" +#include "mlx/c/vector.h" +#include "mlx/c/version.h" + +#endif diff --git a/x/mlxrunner/mlx/include/mlx/c/ops.h b/x/mlxrunner/mlx/include/mlx/c/ops.h new file mode 100644 index 000000000..a1446fb9e --- /dev/null +++ b/x/mlxrunner/mlx/include/mlx/c/ops.h @@ -0,0 +1,1235 @@ +/* Copyright © 2023-2024 Apple Inc. */ +/* */ +/* This file is auto-generated. Do not edit manually. */ +/* */ + +#ifndef MLX_OPS_H +#define MLX_OPS_H + +#include +#include +#include + +#include "mlx/c/array.h" +#include "mlx/c/closure.h" +#include "mlx/c/distributed_group.h" +#include "mlx/c/io_types.h" +#include "mlx/c/map.h" +#include "mlx/c/stream.h" +#include "mlx/c/string.h" +#include "mlx/c/vector.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/** + * \defgroup ops Core array operations + */ +/**@{*/ + +int mlx_abs(mlx_array* res, const mlx_array a, const mlx_stream s); +int mlx_add( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s); +int mlx_addmm( + mlx_array* res, + const mlx_array c, + const mlx_array a, + const mlx_array b, + float alpha, + float beta, + const mlx_stream s); +int mlx_all_axes( + mlx_array* res, + const mlx_array a, + const int* axes, + size_t axes_num, + bool keepdims, + const mlx_stream s); +int mlx_all_axis( + mlx_array* res, + const mlx_array a, + int axis, + bool keepdims, + const mlx_stream s); +int mlx_all( + mlx_array* res, + const mlx_array a, + bool keepdims, + const mlx_stream s); +int mlx_allclose( + mlx_array* res, + const mlx_array a, + const mlx_array b, + double rtol, + double atol, + bool equal_nan, + const mlx_stream s); +int mlx_any_axes( + mlx_array* res, + const mlx_array a, + const int* axes, + size_t axes_num, + bool keepdims, + const mlx_stream s); +int mlx_any_axis( + mlx_array* res, + const mlx_array a, + int axis, + bool keepdims, + const mlx_stream s); +int mlx_any( + mlx_array* res, + const mlx_array a, + bool keepdims, + const mlx_stream s); +int mlx_arange( + mlx_array* res, + double start, + double stop, + double step, + mlx_dtype dtype, + const mlx_stream s); +int mlx_arccos(mlx_array* res, const mlx_array a, const mlx_stream s); +int mlx_arccosh(mlx_array* res, const mlx_array a, const mlx_stream s); +int mlx_arcsin(mlx_array* res, const mlx_array a, const mlx_stream s); +int mlx_arcsinh(mlx_array* res, const mlx_array a, const mlx_stream s); +int mlx_arctan(mlx_array* res, const mlx_array a, const mlx_stream s); +int mlx_arctan2( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s); +int mlx_arctanh(mlx_array* res, const mlx_array a, const mlx_stream s); +int mlx_argmax_axis( + mlx_array* res, + const mlx_array a, + int axis, + bool keepdims, + const mlx_stream s); +int mlx_argmax( + mlx_array* res, + const mlx_array a, + bool keepdims, + const mlx_stream s); +int mlx_argmin_axis( + mlx_array* res, + const mlx_array a, + int axis, + bool keepdims, + const mlx_stream s); +int mlx_argmin( + mlx_array* res, + const mlx_array a, + bool keepdims, + const mlx_stream s); +int mlx_argpartition_axis( + mlx_array* res, + const mlx_array a, + int kth, + int axis, + const mlx_stream s); +int mlx_argpartition( + mlx_array* res, + const mlx_array a, + int kth, + const mlx_stream s); +int mlx_argsort_axis( + mlx_array* res, + const mlx_array a, + int axis, + const mlx_stream s); +int mlx_argsort(mlx_array* res, const mlx_array a, const mlx_stream s); +int mlx_array_equal( + mlx_array* res, + const mlx_array a, + const mlx_array b, + bool equal_nan, + const mlx_stream s); +int mlx_as_strided( + mlx_array* res, + const mlx_array a, + const int* shape, + size_t shape_num, + const int64_t* strides, + size_t strides_num, + size_t offset, + const mlx_stream s); +int mlx_astype( + mlx_array* res, + const mlx_array a, + mlx_dtype dtype, + const mlx_stream s); +int mlx_atleast_1d(mlx_array* res, const mlx_array a, const mlx_stream s); +int mlx_atleast_2d(mlx_array* res, const mlx_array a, const mlx_stream s); +int mlx_atleast_3d(mlx_array* res, const mlx_array a, const mlx_stream s); +int mlx_bitwise_and( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s); +int mlx_bitwise_invert(mlx_array* res, const mlx_array a, const mlx_stream s); +int mlx_bitwise_or( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s); +int mlx_bitwise_xor( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s); +int mlx_block_masked_mm( + mlx_array* res, + const mlx_array a, + const mlx_array b, + int block_size, + const mlx_array mask_out /* may be null */, + const mlx_array mask_lhs /* may be null */, + const mlx_array mask_rhs /* may be null */, + const mlx_stream s); +int mlx_broadcast_arrays( + mlx_vector_array* res, + const mlx_vector_array inputs, + const mlx_stream s); +int mlx_broadcast_to( + mlx_array* res, + const mlx_array a, + const int* shape, + size_t shape_num, + const mlx_stream s); +int mlx_ceil(mlx_array* res, const mlx_array a, const mlx_stream s); +int mlx_clip( + mlx_array* res, + const mlx_array a, + const mlx_array a_min /* may be null */, + const mlx_array a_max /* may be null */, + const mlx_stream s); +int mlx_concatenate_axis( + mlx_array* res, + const mlx_vector_array arrays, + int axis, + const mlx_stream s); +int mlx_concatenate( + mlx_array* res, + const mlx_vector_array arrays, + const mlx_stream s); +int mlx_conjugate(mlx_array* res, const mlx_array a, const mlx_stream s); +int mlx_contiguous( + mlx_array* res, + const mlx_array a, + bool allow_col_major, + const mlx_stream s); +int mlx_conv1d( + mlx_array* res, + const mlx_array input, + const mlx_array weight, + int stride, + int padding, + int dilation, + int groups, + const mlx_stream s); +int mlx_conv2d( + mlx_array* res, + const mlx_array input, + const mlx_array weight, + int stride_0, + int stride_1, + int padding_0, + int padding_1, + int dilation_0, + int dilation_1, + int groups, + const mlx_stream s); +int mlx_conv3d( + mlx_array* res, + const mlx_array input, + const mlx_array weight, + int stride_0, + int stride_1, + int stride_2, + int padding_0, + int padding_1, + int padding_2, + int dilation_0, + int dilation_1, + int dilation_2, + int groups, + const mlx_stream s); +int mlx_conv_general( + mlx_array* res, + const mlx_array input, + const mlx_array weight, + const int* stride, + size_t stride_num, + const int* padding_lo, + size_t padding_lo_num, + const int* padding_hi, + size_t padding_hi_num, + const int* kernel_dilation, + size_t kernel_dilation_num, + const int* input_dilation, + size_t input_dilation_num, + int groups, + bool flip, + const mlx_stream s); +int mlx_conv_transpose1d( + mlx_array* res, + const mlx_array input, + const mlx_array weight, + int stride, + int padding, + int dilation, + int output_padding, + int groups, + const mlx_stream s); +int mlx_conv_transpose2d( + mlx_array* res, + const mlx_array input, + const mlx_array weight, + int stride_0, + int stride_1, + int padding_0, + int padding_1, + int dilation_0, + int dilation_1, + int output_padding_0, + int output_padding_1, + int groups, + const mlx_stream s); +int mlx_conv_transpose3d( + mlx_array* res, + const mlx_array input, + const mlx_array weight, + int stride_0, + int stride_1, + int stride_2, + int padding_0, + int padding_1, + int padding_2, + int dilation_0, + int dilation_1, + int dilation_2, + int output_padding_0, + int output_padding_1, + int output_padding_2, + int groups, + const mlx_stream s); +int mlx_copy(mlx_array* res, const mlx_array a, const mlx_stream s); +int mlx_cos(mlx_array* res, const mlx_array a, const mlx_stream s); +int mlx_cosh(mlx_array* res, const mlx_array a, const mlx_stream s); +int mlx_cummax( + mlx_array* res, + const mlx_array a, + int axis, + bool reverse, + bool inclusive, + const mlx_stream s); +int mlx_cummin( + mlx_array* res, + const mlx_array a, + int axis, + bool reverse, + bool inclusive, + const mlx_stream s); +int mlx_cumprod( + mlx_array* res, + const mlx_array a, + int axis, + bool reverse, + bool inclusive, + const mlx_stream s); +int mlx_cumsum( + mlx_array* res, + const mlx_array a, + int axis, + bool reverse, + bool inclusive, + const mlx_stream s); +int mlx_degrees(mlx_array* res, const mlx_array a, const mlx_stream s); +int mlx_depends( + mlx_vector_array* res, + const mlx_vector_array inputs, + const mlx_vector_array dependencies); +int mlx_dequantize( + mlx_array* res, + const mlx_array w, + const mlx_array scales, + const mlx_array biases /* may be null */, + mlx_optional_int group_size, + mlx_optional_int bits, + const char* mode, + mlx_optional_dtype dtype, + const mlx_stream s); +int mlx_diag(mlx_array* res, const mlx_array a, int k, const mlx_stream s); +int mlx_diagonal( + mlx_array* res, + const mlx_array a, + int offset, + int axis1, + int axis2, + const mlx_stream s); +int mlx_divide( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s); +int mlx_divmod( + mlx_vector_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s); +int mlx_einsum( + mlx_array* res, + const char* subscripts, + const mlx_vector_array operands, + const mlx_stream s); +int mlx_equal( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s); +int mlx_erf(mlx_array* res, const mlx_array a, const mlx_stream s); +int mlx_erfinv(mlx_array* res, const mlx_array a, const mlx_stream s); +int mlx_exp(mlx_array* res, const mlx_array a, const mlx_stream s); +int mlx_expand_dims_axes( + mlx_array* res, + const mlx_array a, + const int* axes, + size_t axes_num, + const mlx_stream s); +int mlx_expand_dims( + mlx_array* res, + const mlx_array a, + int axis, + const mlx_stream s); +int mlx_expm1(mlx_array* res, const mlx_array a, const mlx_stream s); +int mlx_eye( + mlx_array* res, + int n, + int m, + int k, + mlx_dtype dtype, + const mlx_stream s); +int mlx_flatten( + mlx_array* res, + const mlx_array a, + int start_axis, + int end_axis, + const mlx_stream s); +int mlx_floor(mlx_array* res, const mlx_array a, const mlx_stream s); +int mlx_floor_divide( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s); +int mlx_from_fp8( + mlx_array* res, + const mlx_array x, + mlx_dtype dtype, + const mlx_stream s); +int mlx_full( + mlx_array* res, + const int* shape, + size_t shape_num, + const mlx_array vals, + mlx_dtype dtype, + const mlx_stream s); +int mlx_full_like( + mlx_array* res, + const mlx_array a, + const mlx_array vals, + mlx_dtype dtype, + const mlx_stream s); +int mlx_gather( + mlx_array* res, + const mlx_array a, + const mlx_vector_array indices, + const int* axes, + size_t axes_num, + const int* slice_sizes, + size_t slice_sizes_num, + const mlx_stream s); +int mlx_gather_single( + mlx_array* res, + const mlx_array a, + const mlx_array indices, + int axis, + const int* slice_sizes, + size_t slice_sizes_num, + const mlx_stream s); +int mlx_gather_mm( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_array lhs_indices /* may be null */, + const mlx_array rhs_indices /* may be null */, + bool sorted_indices, + const mlx_stream s); +int mlx_gather_qmm( + mlx_array* res, + const mlx_array x, + const mlx_array w, + const mlx_array scales, + const mlx_array biases /* may be null */, + const mlx_array lhs_indices /* may be null */, + const mlx_array rhs_indices /* may be null */, + bool transpose, + mlx_optional_int group_size, + mlx_optional_int bits, + const char* mode, + bool sorted_indices, + const mlx_stream s); +int mlx_greater( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s); +int mlx_greater_equal( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s); +int mlx_hadamard_transform( + mlx_array* res, + const mlx_array a, + mlx_optional_float scale, + const mlx_stream s); +int mlx_identity(mlx_array* res, int n, mlx_dtype dtype, const mlx_stream s); +int mlx_imag(mlx_array* res, const mlx_array a, const mlx_stream s); +int mlx_inner( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s); +int mlx_isclose( + mlx_array* res, + const mlx_array a, + const mlx_array b, + double rtol, + double atol, + bool equal_nan, + const mlx_stream s); +int mlx_isfinite(mlx_array* res, const mlx_array a, const mlx_stream s); +int mlx_isinf(mlx_array* res, const mlx_array a, const mlx_stream s); +int mlx_isnan(mlx_array* res, const mlx_array a, const mlx_stream s); +int mlx_isneginf(mlx_array* res, const mlx_array a, const mlx_stream s); +int mlx_isposinf(mlx_array* res, const mlx_array a, const mlx_stream s); +int mlx_kron( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s); +int mlx_left_shift( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s); +int mlx_less( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s); +int mlx_less_equal( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s); +int mlx_linspace( + mlx_array* res, + double start, + double stop, + int num, + mlx_dtype dtype, + const mlx_stream s); +int mlx_log(mlx_array* res, const mlx_array a, const mlx_stream s); +int mlx_log10(mlx_array* res, const mlx_array a, const mlx_stream s); +int mlx_log1p(mlx_array* res, const mlx_array a, const mlx_stream s); +int mlx_log2(mlx_array* res, const mlx_array a, const mlx_stream s); +int mlx_logaddexp( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s); +int mlx_logcumsumexp( + mlx_array* res, + const mlx_array a, + int axis, + bool reverse, + bool inclusive, + const mlx_stream s); +int mlx_logical_and( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s); +int mlx_logical_not(mlx_array* res, const mlx_array a, const mlx_stream s); +int mlx_logical_or( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s); +int mlx_logsumexp_axes( + mlx_array* res, + const mlx_array a, + const int* axes, + size_t axes_num, + bool keepdims, + const mlx_stream s); +int mlx_logsumexp_axis( + mlx_array* res, + const mlx_array a, + int axis, + bool keepdims, + const mlx_stream s); +int mlx_logsumexp( + mlx_array* res, + const mlx_array a, + bool keepdims, + const mlx_stream s); +int mlx_masked_scatter( + mlx_array* res, + const mlx_array a, + const mlx_array mask, + const mlx_array src, + const mlx_stream s); +int mlx_matmul( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s); +int mlx_max_axes( + mlx_array* res, + const mlx_array a, + const int* axes, + size_t axes_num, + bool keepdims, + const mlx_stream s); +int mlx_max_axis( + mlx_array* res, + const mlx_array a, + int axis, + bool keepdims, + const mlx_stream s); +int mlx_max( + mlx_array* res, + const mlx_array a, + bool keepdims, + const mlx_stream s); +int mlx_maximum( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s); +int mlx_mean_axes( + mlx_array* res, + const mlx_array a, + const int* axes, + size_t axes_num, + bool keepdims, + const mlx_stream s); +int mlx_mean_axis( + mlx_array* res, + const mlx_array a, + int axis, + bool keepdims, + const mlx_stream s); +int mlx_mean( + mlx_array* res, + const mlx_array a, + bool keepdims, + const mlx_stream s); +int mlx_median( + mlx_array* res, + const mlx_array a, + const int* axes, + size_t axes_num, + bool keepdims, + const mlx_stream s); +int mlx_meshgrid( + mlx_vector_array* res, + const mlx_vector_array arrays, + bool sparse, + const char* indexing, + const mlx_stream s); +int mlx_min_axes( + mlx_array* res, + const mlx_array a, + const int* axes, + size_t axes_num, + bool keepdims, + const mlx_stream s); +int mlx_min_axis( + mlx_array* res, + const mlx_array a, + int axis, + bool keepdims, + const mlx_stream s); +int mlx_min( + mlx_array* res, + const mlx_array a, + bool keepdims, + const mlx_stream s); +int mlx_minimum( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s); +int mlx_moveaxis( + mlx_array* res, + const mlx_array a, + int source, + int destination, + const mlx_stream s); +int mlx_multiply( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s); +int mlx_nan_to_num( + mlx_array* res, + const mlx_array a, + float nan, + mlx_optional_float posinf, + mlx_optional_float neginf, + const mlx_stream s); +int mlx_negative(mlx_array* res, const mlx_array a, const mlx_stream s); +int mlx_not_equal( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s); +int mlx_number_of_elements( + mlx_array* res, + const mlx_array a, + const int* axes, + size_t axes_num, + bool inverted, + mlx_dtype dtype, + const mlx_stream s); +int mlx_ones( + mlx_array* res, + const int* shape, + size_t shape_num, + mlx_dtype dtype, + const mlx_stream s); +int mlx_ones_like(mlx_array* res, const mlx_array a, const mlx_stream s); +int mlx_outer( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s); +int mlx_pad( + mlx_array* res, + const mlx_array a, + const int* axes, + size_t axes_num, + const int* low_pad_size, + size_t low_pad_size_num, + const int* high_pad_size, + size_t high_pad_size_num, + const mlx_array pad_value, + const char* mode, + const mlx_stream s); +int mlx_pad_symmetric( + mlx_array* res, + const mlx_array a, + int pad_width, + const mlx_array pad_value, + const char* mode, + const mlx_stream s); +int mlx_partition_axis( + mlx_array* res, + const mlx_array a, + int kth, + int axis, + const mlx_stream s); +int mlx_partition( + mlx_array* res, + const mlx_array a, + int kth, + const mlx_stream s); +int mlx_power( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s); +int mlx_prod_axes( + mlx_array* res, + const mlx_array a, + const int* axes, + size_t axes_num, + bool keepdims, + const mlx_stream s); +int mlx_prod_axis( + mlx_array* res, + const mlx_array a, + int axis, + bool keepdims, + const mlx_stream s); +int mlx_prod( + mlx_array* res, + const mlx_array a, + bool keepdims, + const mlx_stream s); +int mlx_put_along_axis( + mlx_array* res, + const mlx_array a, + const mlx_array indices, + const mlx_array values, + int axis, + const mlx_stream s); +int mlx_qqmm( + mlx_array* res, + const mlx_array x, + const mlx_array w, + const mlx_array w_scales /* may be null */, + mlx_optional_int group_size, + mlx_optional_int bits, + const char* mode, + const mlx_stream s); +int mlx_quantize( + mlx_vector_array* res, + const mlx_array w, + mlx_optional_int group_size, + mlx_optional_int bits, + const char* mode, + const mlx_stream s); +int mlx_quantized_matmul( + mlx_array* res, + const mlx_array x, + const mlx_array w, + const mlx_array scales, + const mlx_array biases /* may be null */, + bool transpose, + mlx_optional_int group_size, + mlx_optional_int bits, + const char* mode, + const mlx_stream s); +int mlx_radians(mlx_array* res, const mlx_array a, const mlx_stream s); +int mlx_real(mlx_array* res, const mlx_array a, const mlx_stream s); +int mlx_reciprocal(mlx_array* res, const mlx_array a, const mlx_stream s); +int mlx_remainder( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s); +int mlx_repeat_axis( + mlx_array* res, + const mlx_array arr, + int repeats, + int axis, + const mlx_stream s); +int mlx_repeat( + mlx_array* res, + const mlx_array arr, + int repeats, + const mlx_stream s); +int mlx_reshape( + mlx_array* res, + const mlx_array a, + const int* shape, + size_t shape_num, + const mlx_stream s); +int mlx_right_shift( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s); +int mlx_roll_axis( + mlx_array* res, + const mlx_array a, + const int* shift, + size_t shift_num, + int axis, + const mlx_stream s); +int mlx_roll_axes( + mlx_array* res, + const mlx_array a, + const int* shift, + size_t shift_num, + const int* axes, + size_t axes_num, + const mlx_stream s); +int mlx_roll( + mlx_array* res, + const mlx_array a, + const int* shift, + size_t shift_num, + const mlx_stream s); +int mlx_round( + mlx_array* res, + const mlx_array a, + int decimals, + const mlx_stream s); +int mlx_rsqrt(mlx_array* res, const mlx_array a, const mlx_stream s); +int mlx_scatter( + mlx_array* res, + const mlx_array a, + const mlx_vector_array indices, + const mlx_array updates, + const int* axes, + size_t axes_num, + const mlx_stream s); +int mlx_scatter_single( + mlx_array* res, + const mlx_array a, + const mlx_array indices, + const mlx_array updates, + int axis, + const mlx_stream s); +int mlx_scatter_add( + mlx_array* res, + const mlx_array a, + const mlx_vector_array indices, + const mlx_array updates, + const int* axes, + size_t axes_num, + const mlx_stream s); +int mlx_scatter_add_single( + mlx_array* res, + const mlx_array a, + const mlx_array indices, + const mlx_array updates, + int axis, + const mlx_stream s); +int mlx_scatter_add_axis( + mlx_array* res, + const mlx_array a, + const mlx_array indices, + const mlx_array values, + int axis, + const mlx_stream s); +int mlx_scatter_max( + mlx_array* res, + const mlx_array a, + const mlx_vector_array indices, + const mlx_array updates, + const int* axes, + size_t axes_num, + const mlx_stream s); +int mlx_scatter_max_single( + mlx_array* res, + const mlx_array a, + const mlx_array indices, + const mlx_array updates, + int axis, + const mlx_stream s); +int mlx_scatter_min( + mlx_array* res, + const mlx_array a, + const mlx_vector_array indices, + const mlx_array updates, + const int* axes, + size_t axes_num, + const mlx_stream s); +int mlx_scatter_min_single( + mlx_array* res, + const mlx_array a, + const mlx_array indices, + const mlx_array updates, + int axis, + const mlx_stream s); +int mlx_scatter_prod( + mlx_array* res, + const mlx_array a, + const mlx_vector_array indices, + const mlx_array updates, + const int* axes, + size_t axes_num, + const mlx_stream s); +int mlx_scatter_prod_single( + mlx_array* res, + const mlx_array a, + const mlx_array indices, + const mlx_array updates, + int axis, + const mlx_stream s); +int mlx_segmented_mm( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_array segments, + const mlx_stream s); +int mlx_sigmoid(mlx_array* res, const mlx_array a, const mlx_stream s); +int mlx_sign(mlx_array* res, const mlx_array a, const mlx_stream s); +int mlx_sin(mlx_array* res, const mlx_array a, const mlx_stream s); +int mlx_sinh(mlx_array* res, const mlx_array a, const mlx_stream s); +int mlx_slice( + mlx_array* res, + const mlx_array a, + const int* start, + size_t start_num, + const int* stop, + size_t stop_num, + const int* strides, + size_t strides_num, + const mlx_stream s); +int mlx_slice_dynamic( + mlx_array* res, + const mlx_array a, + const mlx_array start, + const int* axes, + size_t axes_num, + const int* slice_size, + size_t slice_size_num, + const mlx_stream s); +int mlx_slice_update( + mlx_array* res, + const mlx_array src, + const mlx_array update, + const int* start, + size_t start_num, + const int* stop, + size_t stop_num, + const int* strides, + size_t strides_num, + const mlx_stream s); +int mlx_slice_update_dynamic( + mlx_array* res, + const mlx_array src, + const mlx_array update, + const mlx_array start, + const int* axes, + size_t axes_num, + const mlx_stream s); +int mlx_softmax_axes( + mlx_array* res, + const mlx_array a, + const int* axes, + size_t axes_num, + bool precise, + const mlx_stream s); +int mlx_softmax_axis( + mlx_array* res, + const mlx_array a, + int axis, + bool precise, + const mlx_stream s); +int mlx_softmax( + mlx_array* res, + const mlx_array a, + bool precise, + const mlx_stream s); +int mlx_sort_axis( + mlx_array* res, + const mlx_array a, + int axis, + const mlx_stream s); +int mlx_sort(mlx_array* res, const mlx_array a, const mlx_stream s); +int mlx_split( + mlx_vector_array* res, + const mlx_array a, + int num_splits, + int axis, + const mlx_stream s); +int mlx_split_sections( + mlx_vector_array* res, + const mlx_array a, + const int* indices, + size_t indices_num, + int axis, + const mlx_stream s); +int mlx_sqrt(mlx_array* res, const mlx_array a, const mlx_stream s); +int mlx_square(mlx_array* res, const mlx_array a, const mlx_stream s); +int mlx_squeeze_axes( + mlx_array* res, + const mlx_array a, + const int* axes, + size_t axes_num, + const mlx_stream s); +int mlx_squeeze_axis( + mlx_array* res, + const mlx_array a, + int axis, + const mlx_stream s); +int mlx_squeeze(mlx_array* res, const mlx_array a, const mlx_stream s); +int mlx_stack_axis( + mlx_array* res, + const mlx_vector_array arrays, + int axis, + const mlx_stream s); +int mlx_stack( + mlx_array* res, + const mlx_vector_array arrays, + const mlx_stream s); +int mlx_std_axes( + mlx_array* res, + const mlx_array a, + const int* axes, + size_t axes_num, + bool keepdims, + int ddof, + const mlx_stream s); +int mlx_std_axis( + mlx_array* res, + const mlx_array a, + int axis, + bool keepdims, + int ddof, + const mlx_stream s); +int mlx_std( + mlx_array* res, + const mlx_array a, + bool keepdims, + int ddof, + const mlx_stream s); +int mlx_stop_gradient(mlx_array* res, const mlx_array a, const mlx_stream s); +int mlx_subtract( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s); +int mlx_sum_axes( + mlx_array* res, + const mlx_array a, + const int* axes, + size_t axes_num, + bool keepdims, + const mlx_stream s); +int mlx_sum_axis( + mlx_array* res, + const mlx_array a, + int axis, + bool keepdims, + const mlx_stream s); +int mlx_sum( + mlx_array* res, + const mlx_array a, + bool keepdims, + const mlx_stream s); +int mlx_swapaxes( + mlx_array* res, + const mlx_array a, + int axis1, + int axis2, + const mlx_stream s); +int mlx_take_axis( + mlx_array* res, + const mlx_array a, + const mlx_array indices, + int axis, + const mlx_stream s); +int mlx_take( + mlx_array* res, + const mlx_array a, + const mlx_array indices, + const mlx_stream s); +int mlx_take_along_axis( + mlx_array* res, + const mlx_array a, + const mlx_array indices, + int axis, + const mlx_stream s); +int mlx_tan(mlx_array* res, const mlx_array a, const mlx_stream s); +int mlx_tanh(mlx_array* res, const mlx_array a, const mlx_stream s); +int mlx_tensordot( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const int* axes_a, + size_t axes_a_num, + const int* axes_b, + size_t axes_b_num, + const mlx_stream s); +int mlx_tensordot_axis( + mlx_array* res, + const mlx_array a, + const mlx_array b, + int axis, + const mlx_stream s); +int mlx_tile( + mlx_array* res, + const mlx_array arr, + const int* reps, + size_t reps_num, + const mlx_stream s); +int mlx_to_fp8(mlx_array* res, const mlx_array x, const mlx_stream s); +int mlx_topk_axis( + mlx_array* res, + const mlx_array a, + int k, + int axis, + const mlx_stream s); +int mlx_topk(mlx_array* res, const mlx_array a, int k, const mlx_stream s); +int mlx_trace( + mlx_array* res, + const mlx_array a, + int offset, + int axis1, + int axis2, + mlx_dtype dtype, + const mlx_stream s); +int mlx_transpose_axes( + mlx_array* res, + const mlx_array a, + const int* axes, + size_t axes_num, + const mlx_stream s); +int mlx_transpose(mlx_array* res, const mlx_array a, const mlx_stream s); +int mlx_tri( + mlx_array* res, + int n, + int m, + int k, + mlx_dtype type, + const mlx_stream s); +int mlx_tril(mlx_array* res, const mlx_array x, int k, const mlx_stream s); +int mlx_triu(mlx_array* res, const mlx_array x, int k, const mlx_stream s); +int mlx_unflatten( + mlx_array* res, + const mlx_array a, + int axis, + const int* shape, + size_t shape_num, + const mlx_stream s); +int mlx_var_axes( + mlx_array* res, + const mlx_array a, + const int* axes, + size_t axes_num, + bool keepdims, + int ddof, + const mlx_stream s); +int mlx_var_axis( + mlx_array* res, + const mlx_array a, + int axis, + bool keepdims, + int ddof, + const mlx_stream s); +int mlx_var( + mlx_array* res, + const mlx_array a, + bool keepdims, + int ddof, + const mlx_stream s); +int mlx_view( + mlx_array* res, + const mlx_array a, + mlx_dtype dtype, + const mlx_stream s); +int mlx_where( + mlx_array* res, + const mlx_array condition, + const mlx_array x, + const mlx_array y, + const mlx_stream s); +int mlx_zeros( + mlx_array* res, + const int* shape, + size_t shape_num, + mlx_dtype dtype, + const mlx_stream s); +int mlx_zeros_like(mlx_array* res, const mlx_array a, const mlx_stream s); + +/**@}*/ + +#ifdef __cplusplus +} +#endif + +#endif diff --git a/x/mlxrunner/mlx/include/mlx/c/optional.h b/x/mlxrunner/mlx/include/mlx/c/optional.h new file mode 100644 index 000000000..ff9ea14e5 --- /dev/null +++ b/x/mlxrunner/mlx/include/mlx/c/optional.h @@ -0,0 +1,51 @@ +/* Copyright © 2023-2024 Apple Inc. */ + +#ifndef MLX_OPTIONAL_H +#define MLX_OPTIONAL_H + +#include + +#include "mlx/c/array.h" +#include "mlx/c/string.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/** + * \defgroup mlx_optional Optionals + * MLX optional scalars. + */ +/**@{*/ + +/** + * A int optional. + */ +typedef struct mlx_optional_int_ { + int value; + bool has_value; +} mlx_optional_int; + +/** + * A float optional. + */ +typedef struct mlx_optional_float_ { + float value; + bool has_value; +} mlx_optional_float; + +/** + * A dtype optional. + */ +typedef struct mlx_optional_dtype_ { + mlx_dtype value; + bool has_value; +} mlx_optional_dtype; + +/**@}*/ + +#ifdef __cplusplus +} +#endif + +#endif diff --git a/x/mlxrunner/mlx/include/mlx/c/random.h b/x/mlxrunner/mlx/include/mlx/c/random.h new file mode 100644 index 000000000..dbce0be37 --- /dev/null +++ b/x/mlxrunner/mlx/include/mlx/c/random.h @@ -0,0 +1,166 @@ +/* Copyright © 2023-2024 Apple Inc. */ +/* */ +/* This file is auto-generated. Do not edit manually. */ +/* */ + +#ifndef MLX_RANDOM_H +#define MLX_RANDOM_H + +#include +#include +#include + +#include "mlx/c/array.h" +#include "mlx/c/closure.h" +#include "mlx/c/distributed_group.h" +#include "mlx/c/io_types.h" +#include "mlx/c/map.h" +#include "mlx/c/stream.h" +#include "mlx/c/string.h" +#include "mlx/c/vector.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/** + * \defgroup random Random number operations + */ +/**@{*/ + +int mlx_random_bernoulli( + mlx_array* res, + const mlx_array p, + const int* shape, + size_t shape_num, + const mlx_array key /* may be null */, + const mlx_stream s); +int mlx_random_bits( + mlx_array* res, + const int* shape, + size_t shape_num, + int width, + const mlx_array key /* may be null */, + const mlx_stream s); +int mlx_random_categorical_shape( + mlx_array* res, + const mlx_array logits, + int axis, + const int* shape, + size_t shape_num, + const mlx_array key /* may be null */, + const mlx_stream s); +int mlx_random_categorical_num_samples( + mlx_array* res, + const mlx_array logits_, + int axis, + int num_samples, + const mlx_array key /* may be null */, + const mlx_stream s); +int mlx_random_categorical( + mlx_array* res, + const mlx_array logits, + int axis, + const mlx_array key /* may be null */, + const mlx_stream s); +int mlx_random_gumbel( + mlx_array* res, + const int* shape, + size_t shape_num, + mlx_dtype dtype, + const mlx_array key /* may be null */, + const mlx_stream s); +int mlx_random_key(mlx_array* res, uint64_t seed); +int mlx_random_laplace( + mlx_array* res, + const int* shape, + size_t shape_num, + mlx_dtype dtype, + float loc, + float scale, + const mlx_array key /* may be null */, + const mlx_stream s); +int mlx_random_multivariate_normal( + mlx_array* res, + const mlx_array mean, + const mlx_array cov, + const int* shape, + size_t shape_num, + mlx_dtype dtype, + const mlx_array key /* may be null */, + const mlx_stream s); +int mlx_random_normal_broadcast( + mlx_array* res, + const int* shape, + size_t shape_num, + mlx_dtype dtype, + const mlx_array loc /* may be null */, + const mlx_array scale /* may be null */, + const mlx_array key /* may be null */, + const mlx_stream s); +int mlx_random_normal( + mlx_array* res, + const int* shape, + size_t shape_num, + mlx_dtype dtype, + float loc, + float scale, + const mlx_array key /* may be null */, + const mlx_stream s); +int mlx_random_permutation( + mlx_array* res, + const mlx_array x, + int axis, + const mlx_array key /* may be null */, + const mlx_stream s); +int mlx_random_permutation_arange( + mlx_array* res, + int x, + const mlx_array key /* may be null */, + const mlx_stream s); +int mlx_random_randint( + mlx_array* res, + const mlx_array low, + const mlx_array high, + const int* shape, + size_t shape_num, + mlx_dtype dtype, + const mlx_array key /* may be null */, + const mlx_stream s); +int mlx_random_seed(uint64_t seed); +int mlx_random_split_num( + mlx_array* res, + const mlx_array key, + int num, + const mlx_stream s); +int mlx_random_split( + mlx_array* res_0, + mlx_array* res_1, + const mlx_array key, + const mlx_stream s); +int mlx_random_truncated_normal( + mlx_array* res, + const mlx_array lower, + const mlx_array upper, + const int* shape, + size_t shape_num, + mlx_dtype dtype, + const mlx_array key /* may be null */, + const mlx_stream s); +int mlx_random_uniform( + mlx_array* res, + const mlx_array low, + const mlx_array high, + const int* shape, + size_t shape_num, + mlx_dtype dtype, + const mlx_array key /* may be null */, + const mlx_stream s); + +/**@}*/ + +#ifdef __cplusplus +} +#endif + +#endif diff --git a/x/mlxrunner/mlx/include/mlx/c/stream.h b/x/mlxrunner/mlx/include/mlx/c/stream.h new file mode 100644 index 000000000..d5865b806 --- /dev/null +++ b/x/mlxrunner/mlx/include/mlx/c/stream.h @@ -0,0 +1,88 @@ +/* Copyright © 2023-2024 Apple Inc. */ + +#ifndef MLX_STREAM_H +#define MLX_STREAM_H + +#include + +#include "mlx/c/device.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/** + * \defgroup mlx_stream Stream + * MLX stream object. + */ +/**@{*/ + +/** + * A MLX stream object. + */ +typedef struct mlx_stream_ { + void* ctx; +} mlx_stream; + +/** + * Returns a new empty stream. + */ +mlx_stream mlx_stream_new(void); + +/** + * Returns a new stream on a device. + */ +mlx_stream mlx_stream_new_device(mlx_device dev); +/** + * Set stream to provided src stream. + */ +int mlx_stream_set(mlx_stream* stream, const mlx_stream src); +/** + * Free a stream. + */ +int mlx_stream_free(mlx_stream stream); +/** + * Get stream description. + */ +int mlx_stream_tostring(mlx_string* str, mlx_stream stream); +/** + * Check if streams are the same. + */ +bool mlx_stream_equal(mlx_stream lhs, mlx_stream rhs); +/** + * Return the device of the stream. + */ +int mlx_stream_get_device(mlx_device* dev, mlx_stream stream); +/** + * Return the index of the stream. + */ +int mlx_stream_get_index(int* index, mlx_stream stream); +/** + * Synchronize with the provided stream. + */ +int mlx_synchronize(mlx_stream stream); +/** + * Returns the default stream on the given device. + */ +int mlx_get_default_stream(mlx_stream* stream, mlx_device dev); +/** + * Set default stream. + */ +int mlx_set_default_stream(mlx_stream stream); +/** + * Returns the current default CPU stream. + */ +mlx_stream mlx_default_cpu_stream_new(void); + +/** + * Returns the current default GPU stream. + */ +mlx_stream mlx_default_gpu_stream_new(void); + +/**@}*/ + +#ifdef __cplusplus +} +#endif + +#endif diff --git a/x/mlxrunner/mlx/include/mlx/c/string.h b/x/mlxrunner/mlx/include/mlx/c/string.h new file mode 100644 index 000000000..0d2a356ba --- /dev/null +++ b/x/mlxrunner/mlx/include/mlx/c/string.h @@ -0,0 +1,55 @@ +/* Copyright © 2023-2024 Apple Inc. */ + +#ifndef MLX_STRING_H +#define MLX_STRING_H + +#ifdef __cplusplus +extern "C" { +#endif + +/** + * \defgroup mlx_string String + * MLX string object. + */ +/**@{*/ + +/** + * A MLX string object. + */ +typedef struct mlx_string_ { + void* ctx; +} mlx_string; + +/** + * Returns a new empty string. + */ +mlx_string mlx_string_new(void); + +/** + * Returns a new string, copying contents from `str`, which must end with `\0`. + */ +mlx_string mlx_string_new_data(const char* str); + +/** + * Set string to src string. + */ +int mlx_string_set(mlx_string* str, const mlx_string src); + +/** + * Returns a pointer to the string contents. + * The pointer is valid for the life duration of the string. + */ +const char* mlx_string_data(mlx_string str); + +/** + * Free string. + */ +int mlx_string_free(mlx_string str); + +/**@}*/ + +#ifdef __cplusplus +} +#endif + +#endif diff --git a/x/mlxrunner/mlx/include/mlx/c/transforms.h b/x/mlxrunner/mlx/include/mlx/c/transforms.h new file mode 100644 index 000000000..b2434619b --- /dev/null +++ b/x/mlxrunner/mlx/include/mlx/c/transforms.h @@ -0,0 +1,68 @@ +/* Copyright © 2023-2024 Apple Inc. */ +/* */ +/* This file is auto-generated. Do not edit manually. */ +/* */ + +#ifndef MLX_TRANSFORMS_H +#define MLX_TRANSFORMS_H + +#include +#include +#include + +#include "mlx/c/array.h" +#include "mlx/c/closure.h" +#include "mlx/c/distributed_group.h" +#include "mlx/c/io_types.h" +#include "mlx/c/map.h" +#include "mlx/c/stream.h" +#include "mlx/c/string.h" +#include "mlx/c/vector.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/** + * \defgroup transforms Transform operations + */ +/**@{*/ + +int mlx_async_eval(const mlx_vector_array outputs); +int mlx_checkpoint(mlx_closure* res, const mlx_closure fun); +int mlx_custom_function( + mlx_closure* res, + const mlx_closure fun, + const mlx_closure_custom fun_vjp /* may be null */, + const mlx_closure_custom_jvp fun_jvp /* may be null */, + const mlx_closure_custom_vmap fun_vmap /* may be null */); +int mlx_custom_vjp( + mlx_closure* res, + const mlx_closure fun, + const mlx_closure_custom fun_vjp); +int mlx_eval(const mlx_vector_array outputs); +int mlx_jvp( + mlx_vector_array* res_0, + mlx_vector_array* res_1, + const mlx_closure fun, + const mlx_vector_array primals, + const mlx_vector_array tangents); +int mlx_value_and_grad( + mlx_closure_value_and_grad* res, + const mlx_closure fun, + const int* argnums, + size_t argnums_num); +int mlx_vjp( + mlx_vector_array* res_0, + mlx_vector_array* res_1, + const mlx_closure fun, + const mlx_vector_array primals, + const mlx_vector_array cotangents); + +/**@}*/ + +#ifdef __cplusplus +} +#endif + +#endif diff --git a/x/mlxrunner/mlx/include/mlx/c/transforms_impl.h b/x/mlxrunner/mlx/include/mlx/c/transforms_impl.h new file mode 100644 index 000000000..2b1356ebc --- /dev/null +++ b/x/mlxrunner/mlx/include/mlx/c/transforms_impl.h @@ -0,0 +1,54 @@ +/* Copyright © 2023-2024 Apple Inc. */ +/* */ +/* This file is auto-generated. Do not edit manually. */ +/* */ + +#ifndef MLX_TRANSFORMS_IMPL_H +#define MLX_TRANSFORMS_IMPL_H + +#include +#include +#include + +#include "mlx/c/array.h" +#include "mlx/c/closure.h" +#include "mlx/c/distributed_group.h" +#include "mlx/c/io_types.h" +#include "mlx/c/map.h" +#include "mlx/c/stream.h" +#include "mlx/c/string.h" +#include "mlx/c/vector.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/** + * \defgroup transforms_impl Implementation detail operations + */ +/**@{*/ + +int mlx_detail_vmap_replace( + mlx_vector_array* res, + const mlx_vector_array inputs, + const mlx_vector_array s_inputs, + const mlx_vector_array s_outputs, + const int* in_axes, + size_t in_axes_num, + const int* out_axes, + size_t out_axes_num); +int mlx_detail_vmap_trace( + mlx_vector_array* res_0, + mlx_vector_array* res_1, + const mlx_closure fun, + const mlx_vector_array inputs, + const int* in_axes, + size_t in_axes_num); + +/**@}*/ + +#ifdef __cplusplus +} +#endif + +#endif diff --git a/x/mlxrunner/mlx/include/mlx/c/vector.h b/x/mlxrunner/mlx/include/mlx/c/vector.h new file mode 100644 index 000000000..81bcf7495 --- /dev/null +++ b/x/mlxrunner/mlx/include/mlx/c/vector.h @@ -0,0 +1,133 @@ +/* Copyright © 2023-2024 Apple Inc. */ +/* */ +/* This file is auto-generated. Do not edit manually. */ +/* */ + +#ifndef MLX_VECTOR_H +#define MLX_VECTOR_H + +#include "mlx/c/array.h" +#include "mlx/c/string.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/** + * \defgroup mlx_vector Vectors + * MLX vector objects. + */ +/**@{*/ + +/** + * A vector of array. + */ +typedef struct mlx_vector_array_ { + void* ctx; +} mlx_vector_array; +mlx_vector_array mlx_vector_array_new(void); +int mlx_vector_array_set(mlx_vector_array* vec, const mlx_vector_array src); +int mlx_vector_array_free(mlx_vector_array vec); +mlx_vector_array mlx_vector_array_new_data(const mlx_array* data, size_t size); +mlx_vector_array mlx_vector_array_new_value(const mlx_array val); +int mlx_vector_array_set_data( + mlx_vector_array* vec, + const mlx_array* data, + size_t size); +int mlx_vector_array_set_value(mlx_vector_array* vec, const mlx_array val); +int mlx_vector_array_append_data( + mlx_vector_array vec, + const mlx_array* data, + size_t size); +int mlx_vector_array_append_value(mlx_vector_array vec, const mlx_array val); +size_t mlx_vector_array_size(mlx_vector_array vec); +int mlx_vector_array_get( + mlx_array* res, + const mlx_vector_array vec, + size_t idx); + +/** + * A vector of vector_array. + */ +typedef struct mlx_vector_vector_array_ { + void* ctx; +} mlx_vector_vector_array; +mlx_vector_vector_array mlx_vector_vector_array_new(void); +int mlx_vector_vector_array_set( + mlx_vector_vector_array* vec, + const mlx_vector_vector_array src); +int mlx_vector_vector_array_free(mlx_vector_vector_array vec); +mlx_vector_vector_array mlx_vector_vector_array_new_data( + const mlx_vector_array* data, + size_t size); +mlx_vector_vector_array mlx_vector_vector_array_new_value( + const mlx_vector_array val); +int mlx_vector_vector_array_set_data( + mlx_vector_vector_array* vec, + const mlx_vector_array* data, + size_t size); +int mlx_vector_vector_array_set_value( + mlx_vector_vector_array* vec, + const mlx_vector_array val); +int mlx_vector_vector_array_append_data( + mlx_vector_vector_array vec, + const mlx_vector_array* data, + size_t size); +int mlx_vector_vector_array_append_value( + mlx_vector_vector_array vec, + const mlx_vector_array val); +size_t mlx_vector_vector_array_size(mlx_vector_vector_array vec); +int mlx_vector_vector_array_get( + mlx_vector_array* res, + const mlx_vector_vector_array vec, + size_t idx); + +/** + * A vector of int. + */ +typedef struct mlx_vector_int_ { + void* ctx; +} mlx_vector_int; +mlx_vector_int mlx_vector_int_new(void); +int mlx_vector_int_set(mlx_vector_int* vec, const mlx_vector_int src); +int mlx_vector_int_free(mlx_vector_int vec); +mlx_vector_int mlx_vector_int_new_data(int* data, size_t size); +mlx_vector_int mlx_vector_int_new_value(int val); +int mlx_vector_int_set_data(mlx_vector_int* vec, int* data, size_t size); +int mlx_vector_int_set_value(mlx_vector_int* vec, int val); +int mlx_vector_int_append_data(mlx_vector_int vec, int* data, size_t size); +int mlx_vector_int_append_value(mlx_vector_int vec, int val); +size_t mlx_vector_int_size(mlx_vector_int vec); +int mlx_vector_int_get(int* res, const mlx_vector_int vec, size_t idx); + +/** + * A vector of string. + */ +typedef struct mlx_vector_string_ { + void* ctx; +} mlx_vector_string; +mlx_vector_string mlx_vector_string_new(void); +int mlx_vector_string_set(mlx_vector_string* vec, const mlx_vector_string src); +int mlx_vector_string_free(mlx_vector_string vec); +mlx_vector_string mlx_vector_string_new_data(const char** data, size_t size); +mlx_vector_string mlx_vector_string_new_value(const char* val); +int mlx_vector_string_set_data( + mlx_vector_string* vec, + const char** data, + size_t size); +int mlx_vector_string_set_value(mlx_vector_string* vec, const char* val); +int mlx_vector_string_append_data( + mlx_vector_string vec, + const char** data, + size_t size); +int mlx_vector_string_append_value(mlx_vector_string vec, const char* val); +size_t mlx_vector_string_size(mlx_vector_string vec); +int mlx_vector_string_get(char** res, const mlx_vector_string vec, size_t idx); + +/**@}*/ + +#ifdef __cplusplus +} +#endif + +#endif diff --git a/x/mlxrunner/mlx/include/mlx/c/version.h b/x/mlxrunner/mlx/include/mlx/c/version.h new file mode 100644 index 000000000..96dd23877 --- /dev/null +++ b/x/mlxrunner/mlx/include/mlx/c/version.h @@ -0,0 +1,18 @@ +/* Copyright © 2023-2024 Apple Inc. */ + +#ifndef MLX_VERSION_H +#define MLX_VERSION_H + +#include "mlx/c/string.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int mlx_version(mlx_string* str_); + +#ifdef __cplusplus +} +#endif + +#endif diff --git a/x/mlxrunner/mlx/io.go b/x/mlxrunner/mlx/io.go index 84868e005..0ddbd3b59 100644 --- a/x/mlxrunner/mlx/io.go +++ b/x/mlxrunner/mlx/io.go @@ -1,5 +1,3 @@ -//go:build mlx - package mlx // #include "generated.h" @@ -7,6 +5,7 @@ import "C" import ( "iter" + "runtime" "unsafe" ) @@ -21,10 +20,17 @@ func Load(path string) iter.Seq2[string, *Array] { cPath := C.CString(path) defer C.free(unsafe.Pointer(cPath)) - cpu := C.mlx_default_cpu_stream_new() - defer C.mlx_stream_free(cpu) + // Use GPU stream so tensors load directly to GPU memory (CUDA has Load::eval_gpu). + // macOS Metal doesn't implement eval_gpu for Load, so fall back to CPU stream. + var stream C.mlx_stream + if runtime.GOOS == "darwin" { + stream = C.mlx_default_cpu_stream_new() + } else { + stream = C.mlx_default_gpu_stream_new() + } + defer C.mlx_stream_free(stream) - C.mlx_load_safetensors(&string2array, &string2string, cPath, cpu) + C.mlx_load_safetensors(&string2array, &string2string, cPath, stream) it := C.mlx_map_string_to_array_iterator_new(string2array) defer C.mlx_map_string_to_array_iterator_free(it) diff --git a/x/mlxrunner/mlx/memory.go b/x/mlxrunner/mlx/memory.go index cf36c304c..a243b72c0 100644 --- a/x/mlxrunner/mlx/memory.go +++ b/x/mlxrunner/mlx/memory.go @@ -1,5 +1,3 @@ -//go:build mlx - package mlx // #include "generated.h" diff --git a/x/mlxrunner/mlx/mlx.go b/x/mlxrunner/mlx/mlx.go index 962d192ae..d64a6b0ee 100644 --- a/x/mlxrunner/mlx/mlx.go +++ b/x/mlxrunner/mlx/mlx.go @@ -1,19 +1,22 @@ -//go:build mlx - package mlx -//go:generate cmake -S . -B build -DCMAKE_INSTALL_PREFIX=dist -DCMAKE_BUILD_TYPE=Release -//go:generate cmake --build build --parallel -//go:generate cmake --install build -//go:generate sh -c "go run generator/main.go -output=. ./dist/include/mlx/c/*.h" +//go:generate sh -c "go run generator/main.go -output=. ./include/mlx/c/*.h" // #cgo CXXFLAGS: -std=c++17 -// #cgo CPPFLAGS: -I${SRCDIR}/dist/include -// #cgo LDFLAGS: -L${SRCDIR}/dist/lib -lstdc++ +// #cgo CPPFLAGS: -I${SRCDIR}/include +// #cgo LDFLAGS: -lstdc++ // #cgo darwin LDFLAGS: -framework Foundation -framework Metal -framework Accelerate // #include "generated.h" import "C" +// Version returns the MLX core library version string. +func Version() string { + str := C.mlx_string_new() + defer C.mlx_string_free(str) + C.mlx_version(&str) + return C.GoString(C.mlx_string_data(str)) +} + func doEval(outputs []*Array, async bool) { vector := C.mlx_vector_array_new() defer C.mlx_vector_array_free(vector) diff --git a/x/mlxrunner/mlx/nn.go b/x/mlxrunner/mlx/nn.go index 3d5691368..d3a99a6cd 100644 --- a/x/mlxrunner/mlx/nn.go +++ b/x/mlxrunner/mlx/nn.go @@ -1,5 +1,3 @@ -//go:build mlx - package mlx type Linear struct { diff --git a/x/mlxrunner/mlx/ops.go b/x/mlxrunner/mlx/ops.go index 2f97ba8d2..3d8ec7dec 100644 --- a/x/mlxrunner/mlx/ops.go +++ b/x/mlxrunner/mlx/ops.go @@ -1,5 +1,3 @@ -//go:build mlx - package mlx // #include "generated.h" diff --git a/x/mlxrunner/mlx/ops_extra.go b/x/mlxrunner/mlx/ops_extra.go index 283a2141c..a2e25d68b 100644 --- a/x/mlxrunner/mlx/ops_extra.go +++ b/x/mlxrunner/mlx/ops_extra.go @@ -1,5 +1,3 @@ -//go:build mlx - package mlx // #include "generated.h" diff --git a/x/mlxrunner/mlx/random.go b/x/mlxrunner/mlx/random.go index 6afdbbab4..82c3d6785 100644 --- a/x/mlxrunner/mlx/random.go +++ b/x/mlxrunner/mlx/random.go @@ -1,5 +1,3 @@ -//go:build mlx - package mlx // #include "generated.h" diff --git a/x/mlxrunner/mlx/slice.go b/x/mlxrunner/mlx/slice.go index ab1324774..ea642ebf7 100644 --- a/x/mlxrunner/mlx/slice.go +++ b/x/mlxrunner/mlx/slice.go @@ -1,5 +1,3 @@ -//go:build mlx - package mlx // #include "generated.h" diff --git a/x/mlxrunner/mlx/stream.go b/x/mlxrunner/mlx/stream.go index 83a3eeffd..9b01b4a85 100644 --- a/x/mlxrunner/mlx/stream.go +++ b/x/mlxrunner/mlx/stream.go @@ -1,5 +1,3 @@ -//go:build mlx - package mlx // #include "generated.h" @@ -27,6 +25,22 @@ var DefaultDevice = sync.OnceValue(func() Device { return Device{d} }) +// GPUIsAvailable returns true if a GPU device is available. +func GPUIsAvailable() bool { + dev := C.mlx_device_new_type(C.MLX_GPU, 0) + defer C.mlx_device_free(dev) + var avail C.bool + C.mlx_device_is_available(&avail, dev) + return bool(avail) +} + +// SetDefaultDeviceGPU sets the default MLX device to GPU. +func SetDefaultDeviceGPU() { + dev := C.mlx_device_new_type(C.MLX_GPU, 0) + C.mlx_set_default_device(dev) + C.mlx_device_free(dev) +} + type Stream struct { ctx C.mlx_stream } diff --git a/x/mlxrunner/model/base/base.go b/x/mlxrunner/model/base/base.go index 4cdf6df33..3a85b6eb0 100644 --- a/x/mlxrunner/model/base/base.go +++ b/x/mlxrunner/model/base/base.go @@ -1,5 +1,3 @@ -//go:build mlx - package base import ( diff --git a/x/mlxrunner/model/base/base_stub.go b/x/mlxrunner/model/base/base_stub.go deleted file mode 100644 index 318d8f911..000000000 --- a/x/mlxrunner/model/base/base_stub.go +++ /dev/null @@ -1,3 +0,0 @@ -//go:build !mlx - -package base diff --git a/x/mlxrunner/model/linear.go b/x/mlxrunner/model/linear.go index fffdbdb29..788e4e3f0 100644 --- a/x/mlxrunner/model/linear.go +++ b/x/mlxrunner/model/linear.go @@ -1,5 +1,3 @@ -//go:build mlx - package model import ( diff --git a/x/mlxrunner/model/quant.go b/x/mlxrunner/model/quant.go index 10896e4b4..d4a56c35c 100644 --- a/x/mlxrunner/model/quant.go +++ b/x/mlxrunner/model/quant.go @@ -1,5 +1,3 @@ -//go:build mlx - package model import ( diff --git a/x/mlxrunner/model/root.go b/x/mlxrunner/model/root.go index c912f7f4c..1c05ee6a8 100644 --- a/x/mlxrunner/model/root.go +++ b/x/mlxrunner/model/root.go @@ -1,5 +1,3 @@ -//go:build mlx - package model import ( diff --git a/x/mlxrunner/model/root_stub.go b/x/mlxrunner/model/root_stub.go deleted file mode 100644 index 3fcda9c25..000000000 --- a/x/mlxrunner/model/root_stub.go +++ /dev/null @@ -1,3 +0,0 @@ -//go:build !mlx - -package model diff --git a/x/mlxrunner/pipeline.go b/x/mlxrunner/pipeline.go index 852b04dcc..3ce148c02 100644 --- a/x/mlxrunner/pipeline.go +++ b/x/mlxrunner/pipeline.go @@ -1,5 +1,3 @@ -//go:build mlx - package mlxrunner import ( diff --git a/x/mlxrunner/runner.go b/x/mlxrunner/runner.go index acaef79bf..08a376d43 100644 --- a/x/mlxrunner/runner.go +++ b/x/mlxrunner/runner.go @@ -1,5 +1,3 @@ -//go:build mlx - package mlxrunner import ( diff --git a/x/mlxrunner/sample/sample.go b/x/mlxrunner/sample/sample.go index a25b23d03..df9da7a99 100644 --- a/x/mlxrunner/sample/sample.go +++ b/x/mlxrunner/sample/sample.go @@ -1,5 +1,3 @@ -//go:build mlx - package sample import ( diff --git a/x/mlxrunner/server.go b/x/mlxrunner/server.go index 9c7d7e775..a9972bfdc 100644 --- a/x/mlxrunner/server.go +++ b/x/mlxrunner/server.go @@ -1,5 +1,3 @@ -//go:build mlx - package mlxrunner import ( @@ -29,6 +27,13 @@ func Execute(args []string) error { return fmt.Errorf("MLX not available: %w", err) } + if mlx.GPUIsAvailable() { + mlx.SetDefaultDeviceGPU() + slog.Info("MLX engine initialized", "MLX version", mlx.Version(), "device", "gpu") + } else { + slog.Info("MLX engine initialized", "MLX version", mlx.Version(), "device", "cpu") + } + var ( modelName string port int diff --git a/x/mlxrunner/server_stub.go b/x/mlxrunner/server_stub.go deleted file mode 100644 index 3b0f35500..000000000 --- a/x/mlxrunner/server_stub.go +++ /dev/null @@ -1,10 +0,0 @@ -//go:build !mlx - -package mlxrunner - -import "errors" - -// Execute returns an error when not built with MLX support. -func Execute(args []string) error { - return errors.New("MLX runner not available: build with mlx tag") -} diff --git a/x/models/gemma3/gemma3.go b/x/models/gemma3/gemma3.go index edf66657c..01f0559a0 100644 --- a/x/models/gemma3/gemma3.go +++ b/x/models/gemma3/gemma3.go @@ -1,5 +1,3 @@ -//go:build mlx - // Package gemma3 provides the Gemma 3 text model implementation for MLX. package gemma3 diff --git a/x/models/glm4_moe_lite/glm4_moe_lite.go b/x/models/glm4_moe_lite/glm4_moe_lite.go index fb9c4af6f..2e8365580 100644 --- a/x/models/glm4_moe_lite/glm4_moe_lite.go +++ b/x/models/glm4_moe_lite/glm4_moe_lite.go @@ -1,5 +1,3 @@ -//go:build mlx - // Package glm4_moe_lite provides the GLM4-MoE-Lite implementation for MLX. // This model uses Multi-head Latent Attention (MLA) and Mixture of Experts (MoE). package glm4_moe_lite diff --git a/x/models/glm4_moe_lite/parser.go b/x/models/glm4_moe_lite/parser.go index de1b2cc17..9a9985f68 100644 --- a/x/models/glm4_moe_lite/parser.go +++ b/x/models/glm4_moe_lite/parser.go @@ -1,5 +1,3 @@ -//go:build mlx - package glm4_moe_lite import ( diff --git a/x/models/glm4_moe_lite/parser_test.go b/x/models/glm4_moe_lite/parser_test.go index 0ce382709..d15b4e803 100644 --- a/x/models/glm4_moe_lite/parser_test.go +++ b/x/models/glm4_moe_lite/parser_test.go @@ -1,5 +1,3 @@ -//go:build mlx - package glm4_moe_lite import ( diff --git a/x/models/glm4_moe_lite/render.go b/x/models/glm4_moe_lite/render.go index 4998604bf..d15a99e51 100644 --- a/x/models/glm4_moe_lite/render.go +++ b/x/models/glm4_moe_lite/render.go @@ -1,5 +1,3 @@ -//go:build mlx - package glm4_moe_lite import ( diff --git a/x/models/glm4_moe_lite/render_test.go b/x/models/glm4_moe_lite/render_test.go index f0d576bec..91b871acc 100644 --- a/x/models/glm4_moe_lite/render_test.go +++ b/x/models/glm4_moe_lite/render_test.go @@ -1,5 +1,3 @@ -//go:build mlx - package glm4_moe_lite import ( diff --git a/x/models/llama/llama.go b/x/models/llama/llama.go index fc7f34488..18e39bc0a 100644 --- a/x/models/llama/llama.go +++ b/x/models/llama/llama.go @@ -1,5 +1,3 @@ -//go:build mlx - // Package llama provides a Llama-style decoder-only transformer for MLX. package llama diff --git a/x/models/nn/nn.go b/x/models/nn/nn.go index 78f1b92b6..07047024d 100644 --- a/x/models/nn/nn.go +++ b/x/models/nn/nn.go @@ -1,5 +1,3 @@ -//go:build mlx - package nn import "github.com/ollama/ollama/x/mlxrunner/mlx" diff --git a/x/models/qwen3/qwen3.go b/x/models/qwen3/qwen3.go index 85d427f58..71596af98 100644 --- a/x/models/qwen3/qwen3.go +++ b/x/models/qwen3/qwen3.go @@ -1,5 +1,3 @@ -//go:build mlx - // Package qwen3 provides the Qwen3 text model implementation for MLX. package qwen3 diff --git a/x/models/qwen3_5/qwen3_5.go b/x/models/qwen3_5/qwen3_5.go index fbee82b59..642ea1bba 100644 --- a/x/models/qwen3_5/qwen3_5.go +++ b/x/models/qwen3_5/qwen3_5.go @@ -1,5 +1,3 @@ -//go:build mlx - // Package qwen3_5 provides the Qwen 3.5 text and MoE implementation for MLX. package qwen3_5 diff --git a/x/models/qwen3_5/qwen3_5_test.go b/x/models/qwen3_5/qwen3_5_test.go index 0a70da189..8165cd484 100644 --- a/x/models/qwen3_5/qwen3_5_test.go +++ b/x/models/qwen3_5/qwen3_5_test.go @@ -1,5 +1,3 @@ -//go:build mlx - package qwen3_5 import ( diff --git a/x/models/qwen3_5_moe/qwen3_5_moe.go b/x/models/qwen3_5_moe/qwen3_5_moe.go index 9e0be26be..a505b458e 100644 --- a/x/models/qwen3_5_moe/qwen3_5_moe.go +++ b/x/models/qwen3_5_moe/qwen3_5_moe.go @@ -1,5 +1,3 @@ -//go:build mlx - // Package qwen3_5_moe registers Qwen 3.5 MoE architecture aliases. package qwen3_5_moe diff --git a/x/tokenizer/tokenizer.go b/x/tokenizer/tokenizer.go index 301e51aea..a1ce5e8ee 100644 --- a/x/tokenizer/tokenizer.go +++ b/x/tokenizer/tokenizer.go @@ -1,5 +1,3 @@ -//go:build mlx - // tokenizer.go - BPE and SentencePiece tokenizer for HuggingFace models // // Based on standard BPE algorithm (Sennrich et al. 2015) with: diff --git a/x/tokenizer/tokenizer_benchmark_test.go b/x/tokenizer/tokenizer_benchmark_test.go index e65a59786..9f3d2a10e 100644 --- a/x/tokenizer/tokenizer_benchmark_test.go +++ b/x/tokenizer/tokenizer_benchmark_test.go @@ -1,5 +1,3 @@ -//go:build mlx - package tokenizer import ( diff --git a/x/tokenizer/tokenizer_bpe.go b/x/tokenizer/tokenizer_bpe.go index 1e625c20a..9037f15bc 100644 --- a/x/tokenizer/tokenizer_bpe.go +++ b/x/tokenizer/tokenizer_bpe.go @@ -1,5 +1,3 @@ -//go:build mlx - package tokenizer import "container/heap" diff --git a/x/tokenizer/tokenizer_correctness_test.go b/x/tokenizer/tokenizer_correctness_test.go index 2fe94d279..91adc167d 100644 --- a/x/tokenizer/tokenizer_correctness_test.go +++ b/x/tokenizer/tokenizer_correctness_test.go @@ -1,5 +1,3 @@ -//go:build mlx - package tokenizer import ( diff --git a/x/tokenizer/tokenizer_decode.go b/x/tokenizer/tokenizer_decode.go index e02d2a88b..0056948a2 100644 --- a/x/tokenizer/tokenizer_decode.go +++ b/x/tokenizer/tokenizer_decode.go @@ -1,5 +1,3 @@ -//go:build mlx - package tokenizer import ( diff --git a/x/tokenizer/tokenizer_encode.go b/x/tokenizer/tokenizer_encode.go index 1b71ea6d3..3eb629e56 100644 --- a/x/tokenizer/tokenizer_encode.go +++ b/x/tokenizer/tokenizer_encode.go @@ -1,5 +1,3 @@ -//go:build mlx - package tokenizer import ( diff --git a/x/tokenizer/tokenizer_ggml_parity_test.go b/x/tokenizer/tokenizer_ggml_parity_test.go index 4cef3d3dd..ee9b68f38 100644 --- a/x/tokenizer/tokenizer_ggml_parity_test.go +++ b/x/tokenizer/tokenizer_ggml_parity_test.go @@ -1,5 +1,3 @@ -//go:build mlx - package tokenizer import ( diff --git a/x/tokenizer/tokenizer_load.go b/x/tokenizer/tokenizer_load.go index d2a253e17..efd086628 100644 --- a/x/tokenizer/tokenizer_load.go +++ b/x/tokenizer/tokenizer_load.go @@ -1,5 +1,3 @@ -//go:build mlx - package tokenizer import ( diff --git a/x/tokenizer/tokenizer_load_test.go b/x/tokenizer/tokenizer_load_test.go index 136399c2e..caf2b0d35 100644 --- a/x/tokenizer/tokenizer_load_test.go +++ b/x/tokenizer/tokenizer_load_test.go @@ -1,5 +1,3 @@ -//go:build mlx - package tokenizer import (