mirror of
https://github.com/ollama/ollama.git
synced 2026-04-17 15:53:27 +02:00
MLX: add header vendoring and remove go build tag (#14642)
* prefer rocm v6 on windows Avoid building with v7 - more changes are needed * MLX: add header vendoring and remove go build tag This switches to using a vendoring approach for the mlx-c headers so that Go can build without requiring a cmake first. This enables building the new MLX based code by default. Every time cmake runs, the headers are refreshed, so we can easily keep them in sync when we bump mlx versions. Basic Windows and Linux support are verified. * ci: harden for flaky choco repo servers CI sometimes fails due to choco not actually installing cache. Since it just speeds up the build, we can proceed without. * review comments
This commit is contained in:
52
.github/workflows/release.yaml
vendored
52
.github/workflows/release.yaml
vendored
@@ -117,6 +117,25 @@ jobs:
|
||||
install: https://sdk.lunarg.com/sdk/download/1.4.321.1/windows/vulkansdk-windows-X64-1.4.321.1.exe
|
||||
flags: ''
|
||||
runner_dir: 'vulkan'
|
||||
- os: windows
|
||||
arch: amd64
|
||||
preset: 'MLX CUDA 13'
|
||||
install: https://developer.download.nvidia.com/compute/cuda/13.0.0/local_installers/cuda_13.0.0_windows.exe
|
||||
cudnn-install: https://developer.download.nvidia.com/compute/cudnn/redist/cudnn/windows-x86_64/cudnn-windows-x86_64-9.18.1.3_cuda13-archive.zip
|
||||
cuda-components:
|
||||
- '"cudart"'
|
||||
- '"nvcc"'
|
||||
- '"cublas"'
|
||||
- '"cublas_dev"'
|
||||
- '"cufft"'
|
||||
- '"cufft_dev"'
|
||||
- '"nvrtc"'
|
||||
- '"nvrtc_dev"'
|
||||
- '"crt"'
|
||||
- '"nvvm"'
|
||||
- '"nvptxcompiler"'
|
||||
cuda-version: '13.0'
|
||||
flags: ''
|
||||
runs-on: ${{ matrix.arch == 'arm64' && format('{0}-{1}', matrix.os, matrix.arch) || matrix.os }}
|
||||
environment: release
|
||||
env:
|
||||
@@ -125,8 +144,10 @@ jobs:
|
||||
- name: Install system dependencies
|
||||
run: |
|
||||
choco install -y --no-progress ccache ninja
|
||||
ccache -o cache_dir=${{ github.workspace }}\.ccache
|
||||
- if: startsWith(matrix.preset, 'CUDA ') || startsWith(matrix.preset, 'ROCm ') || startsWith(matrix.preset, 'Vulkan')
|
||||
if (Get-Command ccache -ErrorAction SilentlyContinue) {
|
||||
ccache -o cache_dir=${{ github.workspace }}\.ccache
|
||||
}
|
||||
- if: startsWith(matrix.preset, 'CUDA ') || startsWith(matrix.preset, 'ROCm ') || startsWith(matrix.preset, 'Vulkan') || startsWith(matrix.preset, 'MLX ')
|
||||
id: cache-install
|
||||
uses: actions/cache/restore@v4
|
||||
with:
|
||||
@@ -134,8 +155,9 @@ jobs:
|
||||
C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA
|
||||
C:\Program Files\AMD\ROCm
|
||||
C:\VulkanSDK
|
||||
key: ${{ matrix.install }}
|
||||
- if: startsWith(matrix.preset, 'CUDA ')
|
||||
C:\Program Files\NVIDIA\CUDNN
|
||||
key: ${{ matrix.install }}-${{ matrix.cudnn-install }}
|
||||
- if: startsWith(matrix.preset, 'CUDA ') || startsWith(matrix.preset, 'MLX ')
|
||||
name: Install CUDA ${{ matrix.cuda-version }}
|
||||
run: |
|
||||
$ErrorActionPreference = "Stop"
|
||||
@@ -179,6 +201,23 @@ jobs:
|
||||
run: |
|
||||
echo "CC=clang.exe" | Out-File -FilePath $env:GITHUB_ENV -Append
|
||||
echo "CXX=clang++.exe" | Out-File -FilePath $env:GITHUB_ENV -Append
|
||||
- if: startsWith(matrix.preset, 'MLX ')
|
||||
name: Install cuDNN for MLX
|
||||
run: |
|
||||
$ErrorActionPreference = "Stop"
|
||||
$cudnnRoot = "C:\Program Files\NVIDIA\CUDNN"
|
||||
if ("${{ steps.cache-install.outputs.cache-hit }}" -ne 'true') {
|
||||
Invoke-WebRequest -Uri "${{ matrix.cudnn-install }}" -OutFile "cudnn.zip"
|
||||
Expand-Archive -Path cudnn.zip -DestinationPath cudnn-extracted
|
||||
$cudnnDir = (Get-ChildItem -Path cudnn-extracted -Directory)[0].FullName
|
||||
New-Item -ItemType Directory -Force -Path $cudnnRoot
|
||||
Copy-Item -Path "$cudnnDir\*" -Destination "$cudnnRoot\" -Recurse
|
||||
}
|
||||
|
||||
echo "CUDNN_ROOT_DIR=$cudnnRoot" | Out-File -FilePath $env:GITHUB_ENV -Append
|
||||
echo "CUDNN_INCLUDE_PATH=$cudnnRoot\include" | Out-File -FilePath $env:GITHUB_ENV -Append
|
||||
echo "CUDNN_LIBRARY_PATH=$cudnnRoot\lib\x64" | Out-File -FilePath $env:GITHUB_ENV -Append
|
||||
echo "$cudnnRoot\bin\x64" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
|
||||
- if: ${{ !cancelled() && steps.cache-install.outputs.cache-hit != 'true' }}
|
||||
uses: actions/cache/save@v4
|
||||
with:
|
||||
@@ -186,7 +225,8 @@ jobs:
|
||||
C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA
|
||||
C:\Program Files\AMD\ROCm
|
||||
C:\VulkanSDK
|
||||
key: ${{ matrix.install }}
|
||||
C:\Program Files\NVIDIA\CUDNN
|
||||
key: ${{ matrix.install }}-${{ matrix.cudnn-install }}
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/cache@v4
|
||||
with:
|
||||
@@ -198,7 +238,7 @@ jobs:
|
||||
Enter-VsDevShell -VsInstallPath 'C:\Program Files\Microsoft Visual Studio\2022\Enterprise' -SkipAutomaticLocation -DevCmdArguments '-arch=x64 -no_logo'
|
||||
cmake --preset "${{ matrix.preset }}" ${{ matrix.flags }} --install-prefix "$((pwd).Path)\dist\${{ matrix.os }}-${{ matrix.arch }}"
|
||||
cmake --build --parallel ([Environment]::ProcessorCount) --preset "${{ matrix.preset }}"
|
||||
cmake --install build --component "${{ startsWith(matrix.preset, 'CUDA ') && 'CUDA' || startsWith(matrix.preset, 'ROCm ') && 'HIP' || startsWith(matrix.preset, 'Vulkan') && 'Vulkan' || 'CPU' }}" --strip
|
||||
cmake --install build --component "${{ startsWith(matrix.preset, 'MLX ') && 'MLX' || startsWith(matrix.preset, 'CUDA ') && 'CUDA' || startsWith(matrix.preset, 'ROCm ') && 'HIP' || startsWith(matrix.preset, 'Vulkan') && 'Vulkan' || 'CPU' }}" --strip
|
||||
Remove-Item -Path dist\lib\ollama\rocm\rocblas\library\*gfx906* -ErrorAction SilentlyContinue
|
||||
env:
|
||||
CMAKE_GENERATOR: Ninja
|
||||
|
||||
64
.github/workflows/test.yaml
vendored
64
.github/workflows/test.yaml
vendored
@@ -37,7 +37,7 @@ jobs:
|
||||
| xargs python3 -c "import sys; from pathlib import Path; print(any(Path(x).match(glob) for x in sys.argv[1:] for glob in '$*'.split(' ')))"
|
||||
}
|
||||
|
||||
echo changed=$(changed 'llama/llama.cpp/**/*' 'ml/backend/ggml/ggml/**/*') | tee -a $GITHUB_OUTPUT
|
||||
echo changed=$(changed 'llama/llama.cpp/**/*' 'ml/backend/ggml/ggml/**/*' '.github/**/*') | tee -a $GITHUB_OUTPUT
|
||||
echo vendorsha=$(make -f Makefile.sync print-base) | tee -a $GITHUB_OUTPUT
|
||||
|
||||
linux:
|
||||
@@ -60,6 +60,10 @@ jobs:
|
||||
mesa-vulkan-drivers vulkan-tools
|
||||
libvulkan1 libvulkan-dev
|
||||
vulkan-sdk cmake ccache g++ make
|
||||
- preset: 'MLX CUDA 13'
|
||||
container: nvidia/cuda:13.0.0-devel-ubuntu22.04
|
||||
extra-packages: libcudnn9-dev-cuda-13 libopenblas-dev liblapack-dev liblapacke-dev git curl
|
||||
flags: '-DCMAKE_CUDA_ARCHITECTURES=87 -DBLAS_INCLUDE_DIRS=/usr/include/x86_64-linux-gnu -DLAPACK_INCLUDE_DIRS=/usr/include/x86_64-linux-gnu'
|
||||
runs-on: linux
|
||||
container: ${{ matrix.container }}
|
||||
steps:
|
||||
@@ -76,6 +80,10 @@ jobs:
|
||||
$sudo apt-get update
|
||||
fi
|
||||
$sudo apt-get install -y cmake ccache ${{ matrix.extra-packages }}
|
||||
# MLX requires CMake 3.25+, install from official releases
|
||||
if [ "${{ matrix.preset }}" = "MLX CUDA 13" ]; then
|
||||
curl -fsSL https://github.com/Kitware/CMake/releases/download/v3.31.2/cmake-3.31.2-linux-$(uname -m).tar.gz | $sudo tar xz -C /usr/local --strip-components 1
|
||||
fi
|
||||
# Export VULKAN_SDK if provided by LunarG package (defensive)
|
||||
if [ -d "/usr/lib/x86_64-linux-gnu/vulkan" ] && [ "${{ matrix.preset }}" = "Vulkan" ]; then
|
||||
echo "VULKAN_SDK=/usr" >> $GITHUB_ENV
|
||||
@@ -87,8 +95,8 @@ jobs:
|
||||
path: /github/home/.cache/ccache
|
||||
key: ccache-${{ runner.os }}-${{ runner.arch }}-${{ matrix.preset }}-${{ needs.changes.outputs.vendorsha }}
|
||||
- run: |
|
||||
cmake --preset ${{ matrix.preset }} ${{ matrix.flags }}
|
||||
cmake --build --preset ${{ matrix.preset }} --parallel
|
||||
cmake --preset "${{ matrix.preset }}" ${{ matrix.flags }}
|
||||
cmake --build --preset "${{ matrix.preset }}" --parallel
|
||||
|
||||
windows:
|
||||
needs: [changes]
|
||||
@@ -114,12 +122,31 @@ jobs:
|
||||
flags: '-DAMDGPU_TARGETS=gfx1010 -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -DCMAKE_C_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma" -DCMAKE_CXX_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma"'
|
||||
- preset: Vulkan
|
||||
install: https://sdk.lunarg.com/sdk/download/1.4.321.1/windows/vulkansdk-windows-X64-1.4.321.1.exe
|
||||
- preset: 'MLX CUDA 13'
|
||||
install: https://developer.download.nvidia.com/compute/cuda/13.0.0/local_installers/cuda_13.0.0_windows.exe
|
||||
cudnn-install: https://developer.download.nvidia.com/compute/cudnn/redist/cudnn/windows-x86_64/cudnn-windows-x86_64-9.18.1.3_cuda13-archive.zip
|
||||
flags: '-DCMAKE_CUDA_ARCHITECTURES=80'
|
||||
cuda-components:
|
||||
- '"cudart"'
|
||||
- '"nvcc"'
|
||||
- '"cublas"'
|
||||
- '"cublas_dev"'
|
||||
- '"cufft"'
|
||||
- '"cufft_dev"'
|
||||
- '"nvrtc"'
|
||||
- '"nvrtc_dev"'
|
||||
- '"crt"'
|
||||
- '"nvvm"'
|
||||
- '"nvptxcompiler"'
|
||||
cuda-version: '13.0'
|
||||
runs-on: windows
|
||||
steps:
|
||||
- run: |
|
||||
choco install -y --no-progress ccache ninja
|
||||
ccache -o cache_dir=${{ github.workspace }}\.ccache
|
||||
- if: matrix.preset == 'CUDA' || matrix.preset == 'ROCm' || matrix.preset == 'Vulkan'
|
||||
if (Get-Command ccache -ErrorAction SilentlyContinue) {
|
||||
ccache -o cache_dir=${{ github.workspace }}\.ccache
|
||||
}
|
||||
- if: matrix.preset == 'CUDA' || matrix.preset == 'ROCm' || matrix.preset == 'Vulkan' || matrix.preset == 'MLX CUDA 13'
|
||||
id: cache-install
|
||||
uses: actions/cache/restore@v4
|
||||
with:
|
||||
@@ -127,8 +154,9 @@ jobs:
|
||||
C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA
|
||||
C:\Program Files\AMD\ROCm
|
||||
C:\VulkanSDK
|
||||
key: ${{ matrix.install }}
|
||||
- if: matrix.preset == 'CUDA'
|
||||
C:\Program Files\NVIDIA\CUDNN
|
||||
key: ${{ matrix.install }}-${{ matrix.cudnn-install }}
|
||||
- if: matrix.preset == 'CUDA' || matrix.preset == 'MLX CUDA 13'
|
||||
name: Install CUDA ${{ matrix.cuda-version }}
|
||||
run: |
|
||||
$ErrorActionPreference = "Stop"
|
||||
@@ -164,10 +192,27 @@ jobs:
|
||||
Invoke-WebRequest -Uri "${{ matrix.install }}" -OutFile "install.exe"
|
||||
Start-Process -FilePath .\install.exe -ArgumentList "-c","--am","--al","in" -NoNewWindow -Wait
|
||||
}
|
||||
|
||||
|
||||
$vulkanPath = (Resolve-Path "C:\VulkanSDK\*").path
|
||||
echo "$vulkanPath\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
|
||||
echo "VULKAN_SDK=$vulkanPath" >> $env:GITHUB_ENV
|
||||
- if: matrix.preset == 'MLX CUDA 13'
|
||||
name: Install cuDNN for MLX
|
||||
run: |
|
||||
$ErrorActionPreference = "Stop"
|
||||
$cudnnRoot = "C:\Program Files\NVIDIA\CUDNN"
|
||||
if ("${{ steps.cache-install.outputs.cache-hit }}" -ne 'true') {
|
||||
Invoke-WebRequest -Uri "${{ matrix.cudnn-install }}" -OutFile "cudnn.zip"
|
||||
Expand-Archive -Path cudnn.zip -DestinationPath cudnn-extracted
|
||||
$cudnnDir = (Get-ChildItem -Path cudnn-extracted -Directory)[0].FullName
|
||||
New-Item -ItemType Directory -Force -Path $cudnnRoot
|
||||
Copy-Item -Path "$cudnnDir\*" -Destination "$cudnnRoot\" -Recurse
|
||||
}
|
||||
|
||||
echo "CUDNN_ROOT_DIR=$cudnnRoot" | Out-File -FilePath $env:GITHUB_ENV -Append
|
||||
echo "CUDNN_INCLUDE_PATH=$cudnnRoot\include" | Out-File -FilePath $env:GITHUB_ENV -Append
|
||||
echo "CUDNN_LIBRARY_PATH=$cudnnRoot\lib\x64" | Out-File -FilePath $env:GITHUB_ENV -Append
|
||||
echo "$cudnnRoot\bin\x64" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
|
||||
- if: ${{ !cancelled() && steps.cache-install.outputs.cache-hit != 'true' }}
|
||||
uses: actions/cache/save@v4
|
||||
with:
|
||||
@@ -175,7 +220,8 @@ jobs:
|
||||
C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA
|
||||
C:\Program Files\AMD\ROCm
|
||||
C:\VulkanSDK
|
||||
key: ${{ matrix.install }}
|
||||
C:\Program Files\NVIDIA\CUDNN
|
||||
key: ${{ matrix.install }}-${{ matrix.cudnn-install }}
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/cache@v4
|
||||
with:
|
||||
|
||||
106
CMakeLists.txt
106
CMakeLists.txt
@@ -64,10 +64,15 @@ set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${OLLAMA_BUILD_DIR})
|
||||
set(CMAKE_LIBRARY_OUTPUT_DIRECTORY_DEBUG ${OLLAMA_BUILD_DIR})
|
||||
set(CMAKE_LIBRARY_OUTPUT_DIRECTORY_RELEASE ${OLLAMA_BUILD_DIR})
|
||||
|
||||
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src)
|
||||
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/include)
|
||||
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-cpu)
|
||||
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-cpu/amx)
|
||||
# Store ggml include paths for use with target_include_directories later.
|
||||
# We avoid global include_directories() to prevent polluting the include path
|
||||
# for other projects like MLX (whose openblas dependency has its own common.h).
|
||||
set(GGML_INCLUDE_DIRS
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/include
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-cpu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-cpu/amx
|
||||
)
|
||||
|
||||
add_compile_definitions(NDEBUG GGML_VERSION=0x0 GGML_COMMIT=0x0)
|
||||
|
||||
@@ -87,6 +92,14 @@ if(NOT CPU_VARIANTS)
|
||||
set(CPU_VARIANTS "ggml-cpu")
|
||||
endif()
|
||||
|
||||
# Apply ggml include directories to ggml targets only (not globally)
|
||||
target_include_directories(ggml-base PRIVATE ${GGML_INCLUDE_DIRS})
|
||||
foreach(variant ${CPU_VARIANTS})
|
||||
if(TARGET ${variant})
|
||||
target_include_directories(${variant} PRIVATE ${GGML_INCLUDE_DIRS})
|
||||
endif()
|
||||
endforeach()
|
||||
|
||||
install(TARGETS ggml-base ${CPU_VARIANTS}
|
||||
RUNTIME_DEPENDENCIES
|
||||
PRE_EXCLUDE_REGEXES ".*"
|
||||
@@ -103,6 +116,7 @@ if(CMAKE_CUDA_COMPILER)
|
||||
|
||||
find_package(CUDAToolkit)
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-cuda)
|
||||
target_include_directories(ggml-cuda PRIVATE ${GGML_INCLUDE_DIRS})
|
||||
install(TARGETS ggml-cuda
|
||||
RUNTIME_DEPENDENCIES
|
||||
DIRECTORIES ${CUDAToolkit_BIN_DIR} ${CUDAToolkit_BIN_DIR}/x64 ${CUDAToolkit_LIBRARY_DIR}
|
||||
@@ -134,6 +148,7 @@ if(CMAKE_HIP_COMPILER)
|
||||
if(AMDGPU_TARGETS)
|
||||
find_package(hip REQUIRED)
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-hip)
|
||||
target_include_directories(ggml-hip PRIVATE ${GGML_INCLUDE_DIRS})
|
||||
|
||||
if (WIN32)
|
||||
target_compile_definitions(ggml-hip PRIVATE GGML_CUDA_NO_PEER_COPY)
|
||||
@@ -168,6 +183,7 @@ if(NOT APPLE)
|
||||
find_package(Vulkan)
|
||||
if(Vulkan_FOUND)
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-vulkan)
|
||||
target_include_directories(ggml-vulkan PRIVATE ${GGML_INCLUDE_DIRS})
|
||||
install(TARGETS ggml-vulkan
|
||||
RUNTIME_DEPENDENCIES
|
||||
PRE_INCLUDE_REGEXES vulkan
|
||||
@@ -179,7 +195,6 @@ if(NOT APPLE)
|
||||
endif()
|
||||
|
||||
option(MLX_ENGINE "Enable MLX backend" OFF)
|
||||
|
||||
if(MLX_ENGINE)
|
||||
message(STATUS "Setting up MLX (this takes a while...)")
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/x/imagegen/mlx)
|
||||
@@ -187,10 +202,36 @@ if(MLX_ENGINE)
|
||||
# Find CUDA toolkit if MLX is built with CUDA support
|
||||
find_package(CUDAToolkit)
|
||||
|
||||
# Build list of directories for runtime dependency resolution
|
||||
set(MLX_RUNTIME_DIRS ${CUDAToolkit_BIN_DIR} ${CUDAToolkit_BIN_DIR}/x64 ${CUDAToolkit_LIBRARY_DIR})
|
||||
# Add cuDNN bin paths for DLLs (Windows MLX CUDA builds)
|
||||
# CUDNN_ROOT_DIR is the standard CMake variable for cuDNN location
|
||||
if(DEFINED ENV{CUDNN_ROOT_DIR})
|
||||
# cuDNN 9.x has versioned subdirectories under bin/ (e.g., bin/13.0/)
|
||||
file(GLOB CUDNN_BIN_SUBDIRS "$ENV{CUDNN_ROOT_DIR}/bin/*")
|
||||
list(APPEND MLX_RUNTIME_DIRS ${CUDNN_BIN_SUBDIRS})
|
||||
endif()
|
||||
# Add build output directory and MLX dependency build directories
|
||||
list(APPEND MLX_RUNTIME_DIRS ${OLLAMA_BUILD_DIR})
|
||||
# OpenBLAS DLL location (pre-built zip extracts into openblas-src/bin/)
|
||||
list(APPEND MLX_RUNTIME_DIRS ${CMAKE_BINARY_DIR}/_deps/openblas-src/bin)
|
||||
# NCCL: on Linux, if real NCCL is found, cmake bundles libnccl.so via the
|
||||
# regex below. If NCCL is not found, MLX links a static stub (OBJECT lib)
|
||||
# so there is no runtime dependency. This path covers the stub build dir
|
||||
# for windows so we include the DLL in our dependencies.
|
||||
list(APPEND MLX_RUNTIME_DIRS ${CMAKE_BINARY_DIR}/_deps/mlx-build/mlx/distributed/nccl/nccl_stub-prefix/src/nccl_stub-build/Release)
|
||||
|
||||
# Base regexes for runtime dependencies (cross-platform)
|
||||
set(MLX_INCLUDE_REGEXES cublas cublasLt cudart cufft nvrtc nvrtc-builtins cudnn nccl openblas gfortran)
|
||||
# On Windows, also include dl.dll (dlfcn-win32 POSIX emulation layer)
|
||||
if(WIN32)
|
||||
list(APPEND MLX_INCLUDE_REGEXES "^dl\\.dll$")
|
||||
endif()
|
||||
|
||||
install(TARGETS mlx mlxc
|
||||
RUNTIME_DEPENDENCIES
|
||||
DIRECTORIES ${CUDAToolkit_BIN_DIR} ${CUDAToolkit_BIN_DIR}/x64 ${CUDAToolkit_LIBRARY_DIR}
|
||||
PRE_INCLUDE_REGEXES cublas cublasLt cudart nvrtc nvrtc-builtins cudnn nccl openblas gfortran
|
||||
DIRECTORIES ${MLX_RUNTIME_DIRS}
|
||||
PRE_INCLUDE_REGEXES ${MLX_INCLUDE_REGEXES}
|
||||
PRE_EXCLUDE_REGEXES ".*"
|
||||
RUNTIME DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT MLX
|
||||
LIBRARY DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT MLX
|
||||
@@ -205,13 +246,54 @@ if(MLX_ENGINE)
|
||||
COMPONENT MLX)
|
||||
endif()
|
||||
|
||||
# Manually install cudart and cublas since they might not be picked up as direct dependencies
|
||||
# Install CCCL headers for NVRTC JIT compilation at runtime.
|
||||
# MLX's own install rules use the default component so they get skipped by
|
||||
# --component MLX. Headers are installed alongside libmlx in OLLAMA_INSTALL_DIR.
|
||||
# On Linux, MLX's jit_module.cpp resolves CCCL via
|
||||
# current_binary_dir().parent_path() / "include" / "cccl", so we create a
|
||||
# symlink from lib/ollama/include -> ${OLLAMA_RUNNER_DIR}/include
|
||||
# This will need refinement if we add multiple CUDA versions for MLX in the future.
|
||||
if(EXISTS ${CMAKE_BINARY_DIR}/_deps/cccl-src/include/cuda)
|
||||
install(DIRECTORY ${CMAKE_BINARY_DIR}/_deps/cccl-src/include/cuda
|
||||
DESTINATION ${OLLAMA_INSTALL_DIR}/include/cccl
|
||||
COMPONENT MLX)
|
||||
install(DIRECTORY ${CMAKE_BINARY_DIR}/_deps/cccl-src/include/nv
|
||||
DESTINATION ${OLLAMA_INSTALL_DIR}/include/cccl
|
||||
COMPONENT MLX)
|
||||
if(NOT WIN32 AND NOT APPLE)
|
||||
install(CODE "
|
||||
set(_link \"${CMAKE_INSTALL_PREFIX}/lib/ollama/include\")
|
||||
set(_target \"${OLLAMA_RUNNER_DIR}/include\")
|
||||
if(NOT EXISTS \${_link})
|
||||
execute_process(COMMAND \${CMAKE_COMMAND} -E create_symlink \${_target} \${_link})
|
||||
endif()
|
||||
" COMPONENT MLX)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
# On Windows, explicitly install dl.dll (dlfcn-win32 POSIX dlopen emulation)
|
||||
# RUNTIME_DEPENDENCIES auto-excludes it via POST_EXCLUDE_FILES_STRICT because
|
||||
# dlfcn-win32 is a known CMake target with its own install rules (which install
|
||||
# to the wrong destination). We must install it explicitly here.
|
||||
if(WIN32)
|
||||
install(FILES ${OLLAMA_BUILD_DIR}/dl.dll
|
||||
DESTINATION ${OLLAMA_INSTALL_DIR}
|
||||
COMPONENT MLX)
|
||||
endif()
|
||||
|
||||
# Manually install CUDA runtime libraries that MLX loads via dlopen
|
||||
# (not detected by RUNTIME_DEPENDENCIES since they aren't link-time deps)
|
||||
if(CUDAToolkit_FOUND)
|
||||
file(GLOB CUDART_LIBS
|
||||
file(GLOB MLX_CUDA_LIBS
|
||||
"${CUDAToolkit_LIBRARY_DIR}/libcudart.so*"
|
||||
"${CUDAToolkit_LIBRARY_DIR}/libcublas.so*")
|
||||
if(CUDART_LIBS)
|
||||
install(FILES ${CUDART_LIBS}
|
||||
"${CUDAToolkit_LIBRARY_DIR}/libcublas.so*"
|
||||
"${CUDAToolkit_LIBRARY_DIR}/libcublasLt.so*"
|
||||
"${CUDAToolkit_LIBRARY_DIR}/libnvrtc.so*"
|
||||
"${CUDAToolkit_LIBRARY_DIR}/libnvrtc-builtins.so*"
|
||||
"${CUDAToolkit_LIBRARY_DIR}/libcufft.so*"
|
||||
"${CUDAToolkit_LIBRARY_DIR}/libcudnn.so*")
|
||||
if(MLX_CUDA_LIBS)
|
||||
install(FILES ${MLX_CUDA_LIBS}
|
||||
DESTINATION ${OLLAMA_INSTALL_DIR}
|
||||
COMPONENT MLX)
|
||||
endif()
|
||||
|
||||
@@ -112,6 +112,7 @@
|
||||
"name": "MLX CUDA 13",
|
||||
"inherits": [ "MLX", "CUDA 13" ],
|
||||
"cacheVariables": {
|
||||
"MLX_CUDA_ARCHITECTURES": "86;89;90;90a;100;103;75-virtual;80-virtual;110-virtual;120-virtual;121-virtual",
|
||||
"OLLAMA_RUNNER_DIR": "mlx_cuda_v13"
|
||||
}
|
||||
}
|
||||
|
||||
23
Dockerfile
23
Dockerfile
@@ -9,6 +9,11 @@ ARG CMAKEVERSION=3.31.2
|
||||
ARG NINJAVERSION=1.12.1
|
||||
ARG VULKANVERSION=1.4.321.1
|
||||
|
||||
# Default empty stages for local MLX source overrides.
|
||||
# Override with: docker build --build-context local-mlx=../mlx --build-context local-mlx-c=../mlx-c
|
||||
FROM scratch AS local-mlx
|
||||
FROM scratch AS local-mlx-c
|
||||
|
||||
FROM --platform=linux/amd64 rocm/dev-almalinux-8:${ROCMVERSION}-complete AS base-amd64
|
||||
RUN dnf install -y yum-utils ccache gcc-toolset-11-gcc gcc-toolset-11-gcc-c++ gcc-toolset-11-binutils \
|
||||
&& yum-config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel8/x86_64/cuda-rhel8.repo
|
||||
@@ -152,12 +157,20 @@ COPY CMakeLists.txt CMakePresets.json .
|
||||
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
|
||||
COPY x/imagegen/mlx x/imagegen/mlx
|
||||
COPY go.mod go.sum .
|
||||
COPY MLX_VERSION .
|
||||
COPY MLX_VERSION MLX_CORE_VERSION .
|
||||
RUN curl -fsSL https://golang.org/dl/go$(awk '/^go/ { print $2 }' go.mod).linux-$(case $(uname -m) in x86_64) echo amd64 ;; aarch64) echo arm64 ;; esac).tar.gz | tar xz -C /usr/local
|
||||
ENV PATH=/usr/local/go/bin:$PATH
|
||||
RUN go mod download
|
||||
RUN --mount=type=cache,target=/root/.ccache \
|
||||
cmake --preset 'MLX CUDA 13' -DBLAS_INCLUDE_DIRS=/usr/include/openblas -DLAPACK_INCLUDE_DIRS=/usr/include/openblas \
|
||||
--mount=type=bind,from=local-mlx,target=/tmp/local-mlx \
|
||||
--mount=type=bind,from=local-mlx-c,target=/tmp/local-mlx-c \
|
||||
if [ -f /tmp/local-mlx/CMakeLists.txt ]; then \
|
||||
export OLLAMA_MLX_SOURCE=/tmp/local-mlx; \
|
||||
fi \
|
||||
&& if [ -f /tmp/local-mlx-c/CMakeLists.txt ]; then \
|
||||
export OLLAMA_MLX_C_SOURCE=/tmp/local-mlx-c; \
|
||||
fi \
|
||||
&& cmake --preset 'MLX CUDA 13' -DBLAS_INCLUDE_DIRS=/usr/include/openblas -DLAPACK_INCLUDE_DIRS=/usr/include/openblas \
|
||||
&& cmake --build --preset 'MLX CUDA 13' -- -l $(nproc) \
|
||||
&& cmake --install build --component MLX --strip
|
||||
|
||||
@@ -168,16 +181,14 @@ RUN curl -fsSL https://golang.org/dl/go$(awk '/^go/ { print $2 }' go.mod).linux-
|
||||
ENV PATH=/usr/local/go/bin:$PATH
|
||||
RUN go mod download
|
||||
COPY . .
|
||||
# Clone mlx-c headers for CGO (version from MLX_VERSION file)
|
||||
RUN git clone --depth 1 --branch "$(cat MLX_VERSION)" https://github.com/ml-explore/mlx-c.git build/_deps/mlx-c-src
|
||||
ARG GOFLAGS="'-ldflags=-w -s'"
|
||||
ENV CGO_ENABLED=1
|
||||
ARG CGO_CFLAGS
|
||||
ARG CGO_CXXFLAGS
|
||||
ENV CGO_CFLAGS="${CGO_CFLAGS} -I/go/src/github.com/ollama/ollama/build/_deps/mlx-c-src"
|
||||
ENV CGO_CFLAGS="${CGO_CFLAGS}"
|
||||
ENV CGO_CXXFLAGS="${CGO_CXXFLAGS}"
|
||||
RUN --mount=type=cache,target=/root/.cache/go-build \
|
||||
go build -tags mlx -trimpath -buildmode=pie -o /bin/ollama .
|
||||
go build -trimpath -buildmode=pie -o /bin/ollama .
|
||||
|
||||
FROM --platform=linux/amd64 scratch AS amd64
|
||||
# COPY --from=cuda-11 dist/lib/ollama/ /lib/ollama/
|
||||
|
||||
1
MLX_CORE_VERSION
Normal file
1
MLX_CORE_VERSION
Normal file
@@ -0,0 +1 @@
|
||||
v0.30.6
|
||||
@@ -51,6 +51,9 @@ Install prerequisites:
|
||||
- [CUDA SDK](https://developer.nvidia.com/cuda-downloads?target_os=Windows&target_arch=x86_64&target_version=11&target_type=exe_network)
|
||||
- (Optional) VULKAN GPU support
|
||||
- [VULKAN SDK](https://vulkan.lunarg.com/sdk/home) - useful for AMD/Intel GPUs
|
||||
- (Optional) MLX engine support
|
||||
- [CUDA 13+ SDK](https://developer.nvidia.com/cuda-downloads)
|
||||
- [cuDNN 9+](https://developer.nvidia.com/cudnn)
|
||||
|
||||
Then, configure and build the project:
|
||||
|
||||
@@ -101,6 +104,10 @@ Install prerequisites:
|
||||
- (Optional) VULKAN GPU support
|
||||
- [VULKAN SDK](https://vulkan.lunarg.com/sdk/home) - useful for AMD/Intel GPUs
|
||||
- Or install via package manager: `sudo apt install vulkan-sdk` (Ubuntu/Debian) or `sudo dnf install vulkan-sdk` (Fedora/CentOS)
|
||||
- (Optional) MLX engine support
|
||||
- [CUDA 13+ SDK](https://developer.nvidia.com/cuda-downloads)
|
||||
- [cuDNN 9+](https://developer.nvidia.com/cudnn)
|
||||
- OpenBLAS/LAPACK: `sudo apt install libopenblas-dev liblapack-dev liblapacke-dev` (Ubuntu/Debian)
|
||||
> [!IMPORTANT]
|
||||
> Ensure prerequisites are in `PATH` before running CMake.
|
||||
|
||||
@@ -118,6 +125,67 @@ Lastly, run Ollama:
|
||||
go run . serve
|
||||
```
|
||||
|
||||
## MLX Engine (Optional)
|
||||
|
||||
The MLX engine enables running safetensor based models. It requires building the [MLX](https://github.com/ml-explore/mlx) and [MLX-C](https://github.com/ml-explore/mlx-c) shared libraries separately via CMake. On MacOS, MLX leverages the Metal library to run on the GPU, and on Windows and Linux, runs on NVIDIA GPUs via CUDA v13.
|
||||
|
||||
### macOS (Apple Silicon)
|
||||
|
||||
Requires the Metal toolchain. Install [Xcode](https://developer.apple.com/xcode/) first, then:
|
||||
|
||||
```shell
|
||||
xcodebuild -downloadComponent MetalToolchain
|
||||
```
|
||||
|
||||
Verify it's installed correctly (should print "no input files"):
|
||||
|
||||
```shell
|
||||
xcrun metal
|
||||
```
|
||||
|
||||
Then build:
|
||||
|
||||
```shell
|
||||
cmake -B build --preset MLX
|
||||
cmake --build build --preset MLX --parallel
|
||||
cmake --install build --component MLX
|
||||
```
|
||||
|
||||
> [!NOTE]
|
||||
> Without the Metal toolchain, cmake will silently complete with Metal disabled. Check the cmake output for `Setting MLX_BUILD_METAL=OFF` which indicates the toolchain is missing.
|
||||
|
||||
### Windows / Linux (CUDA)
|
||||
|
||||
Requires CUDA 13+ and [cuDNN](https://developer.nvidia.com/cudnn) 9+.
|
||||
|
||||
```shell
|
||||
cmake -B build --preset "MLX CUDA 13"
|
||||
cmake --build build --target mlx --target mlxc --config Release --parallel
|
||||
cmake --install build --component MLX --strip
|
||||
```
|
||||
|
||||
### Local MLX source overrides
|
||||
|
||||
To build against a local checkout of MLX and/or MLX-C (useful for development), set environment variables before running CMake:
|
||||
|
||||
```shell
|
||||
export OLLAMA_MLX_SOURCE=/path/to/mlx
|
||||
export OLLAMA_MLX_C_SOURCE=/path/to/mlx-c
|
||||
```
|
||||
|
||||
For example, using the helper scripts with local mlx and mlx-c repos:
|
||||
```shell
|
||||
OLLAMA_MLX_SOURCE=../mlx OLLAMA_MLX_C_SOURCE=../mlx-c ./scripts/build_linux.sh
|
||||
|
||||
OLLAMA_MLX_SOURCE=../mlx OLLAMA_MLX_C_SOURCE=../mlx-c ./scripts/build_darwin.sh
|
||||
```
|
||||
|
||||
```powershell
|
||||
$env:OLLAMA_MLX_SOURCE="../mlx"
|
||||
$env:OLLAMA_MLX_C_SOURCE="../mlx-c"
|
||||
./scripts/build_darwin.ps1
|
||||
```
|
||||
|
||||
## Docker
|
||||
|
||||
```shell
|
||||
|
||||
@@ -181,6 +181,9 @@ func fileDigestMap(path string) (map[string]string, error) {
|
||||
}
|
||||
|
||||
if !filepath.IsLocal(rel) {
|
||||
if strings.Contains(rel, ".cache") {
|
||||
return nil, fmt.Errorf("insecure path: %s\n\nUse --local-dir <dir> when downloading model to disable caching", rel)
|
||||
}
|
||||
return nil, fmt.Errorf("insecure path: %s", rel)
|
||||
}
|
||||
|
||||
|
||||
@@ -59,7 +59,7 @@ _build_darwin() {
|
||||
cmake --install $BUILD_DIR --component CPU
|
||||
cmake --install $BUILD_DIR --component MLX
|
||||
# Override CGO flags to point to the amd64 build directory
|
||||
MLX_CGO_CFLAGS="-O3 -I$(pwd)/$BUILD_DIR/_deps/mlx-c-src -mmacosx-version-min=14.0"
|
||||
MLX_CGO_CFLAGS="-O3 -mmacosx-version-min=14.0"
|
||||
MLX_CGO_LDFLAGS="-ldl -lc++ -framework Accelerate -mmacosx-version-min=14.0"
|
||||
else
|
||||
BUILD_DIR=build
|
||||
@@ -70,10 +70,10 @@ _build_darwin() {
|
||||
cmake --build --preset MLX --parallel
|
||||
cmake --install $BUILD_DIR --component MLX
|
||||
# Use default CGO flags from mlx.go for arm64
|
||||
MLX_CGO_CFLAGS="-O3 -I$(pwd)/$BUILD_DIR/_deps/mlx-c-src -mmacosx-version-min=14.0"
|
||||
MLX_CGO_CFLAGS="-O3 -mmacosx-version-min=14.0"
|
||||
MLX_CGO_LDFLAGS="-lc++ -framework Metal -framework Foundation -framework Accelerate -mmacosx-version-min=14.0"
|
||||
fi
|
||||
GOOS=darwin GOARCH=$ARCH CGO_ENABLED=1 CGO_CFLAGS="$MLX_CGO_CFLAGS" CGO_LDFLAGS="$MLX_CGO_LDFLAGS" go build -tags mlx -o $INSTALL_PREFIX .
|
||||
GOOS=darwin GOARCH=$ARCH CGO_ENABLED=1 CGO_CFLAGS="$MLX_CGO_CFLAGS" CGO_LDFLAGS="$MLX_CGO_LDFLAGS" go build -o $INSTALL_PREFIX .
|
||||
# Copy MLX libraries to same directory as executable for dlopen
|
||||
cp $INSTALL_PREFIX/lib/ollama/libmlxc.dylib $INSTALL_PREFIX/
|
||||
cp $INSTALL_PREFIX/lib/ollama/libmlx.dylib $INSTALL_PREFIX/
|
||||
|
||||
@@ -4,7 +4,10 @@
|
||||
#
|
||||
# gcloud auth application-default login
|
||||
|
||||
$ErrorActionPreference = "Stop"
|
||||
# Use "Continue" so that stderr output from native commands (e.g. CGo warnings)
|
||||
# is not promoted to a terminating exception by the try/catch block.
|
||||
# All native commands already check $LASTEXITCODE explicitly.
|
||||
$ErrorActionPreference = "Continue"
|
||||
|
||||
mkdir -Force -path .\dist | Out-Null
|
||||
|
||||
@@ -16,13 +19,13 @@ function checkEnv {
|
||||
if ($null -ne $arch) {
|
||||
$script:ARCH = ($arch.ToString().ToLower()).Replace("x64", "amd64")
|
||||
} else {
|
||||
write-host "WARNING: old powershell detected, assuming amd64 architecture - set `$env:ARCH to override"
|
||||
Write-Output "WARNING: old powershell detected, assuming amd64 architecture - set `$env:ARCH to override"
|
||||
$script:ARCH="amd64"
|
||||
}
|
||||
}
|
||||
$script:TARGET_ARCH=$script:ARCH
|
||||
Write-host "Building for ${script:TARGET_ARCH}"
|
||||
write-host "Locating required tools and paths"
|
||||
Write-Output "Locating required tools and paths"
|
||||
$script:SRC_DIR=$PWD
|
||||
|
||||
# Locate CUDA versions
|
||||
@@ -37,16 +40,17 @@ function checkEnv {
|
||||
$script:CUDA_DIRS=($cudaList | sort-object -Descending)
|
||||
}
|
||||
if ($script:CUDA_DIRS.length -gt 0) {
|
||||
write-host "Available CUDA Versions: $script:CUDA_DIRS"
|
||||
Write-Output "Available CUDA Versions: $script:CUDA_DIRS"
|
||||
} else {
|
||||
write-host "No CUDA versions detected"
|
||||
Write-Output "No CUDA versions detected"
|
||||
}
|
||||
|
||||
# Locate ROCm version
|
||||
if ($null -ne $env:HIP_PATH) {
|
||||
# Locate ROCm v6
|
||||
$rocmDir=(get-item "C:\Program Files\AMD\ROCm\6.*" -ea 'silentlycontinue' | sort-object -Descending | select-object -First 1)
|
||||
if ($null -ne $rocmDir) {
|
||||
$script:HIP_PATH=$rocmDir.FullName
|
||||
} elseif ($null -ne $env:HIP_PATH -and $env:HIP_PATH -match '[/\\]6\.') {
|
||||
$script:HIP_PATH=$env:HIP_PATH
|
||||
} else {
|
||||
$script:HIP_PATH=(get-item "C:\Program Files\AMD\ROCm\*\bin\" -ea 'silentlycontinue' | sort-object -Descending)
|
||||
}
|
||||
|
||||
$inoSetup=(get-item "C:\Program Files*\Inno Setup*\")
|
||||
@@ -78,7 +82,7 @@ function checkEnv {
|
||||
} else {
|
||||
$script:PKG_VERSION="0.0.0"
|
||||
}
|
||||
write-host "Building Ollama $script:VERSION with package version $script:PKG_VERSION"
|
||||
Write-Output "Building Ollama $script:VERSION with package version $script:PKG_VERSION"
|
||||
|
||||
# Note: Windows Kits 10 signtool crashes with GCP's plugin
|
||||
if ($null -eq $env:SIGN_TOOL) {
|
||||
@@ -87,12 +91,32 @@ function checkEnv {
|
||||
${script:SignTool}=${env:SIGN_TOOL}
|
||||
}
|
||||
if ("${env:KEY_CONTAINER}") {
|
||||
${script:OLLAMA_CERT}=$(resolve-path "${script:SRC_DIR}\ollama_inc.crt")
|
||||
Write-host "Code signing enabled"
|
||||
if (Test-Path "${script:SRC_DIR}\ollama_inc.crt") {
|
||||
${script:OLLAMA_CERT}=$(resolve-path "${script:SRC_DIR}\ollama_inc.crt")
|
||||
Write-host "Code signing enabled"
|
||||
} else {
|
||||
Write-Output "WARNING: KEY_CONTAINER is set but ollama_inc.crt not found at ${script:SRC_DIR}\ollama_inc.crt - code signing disabled"
|
||||
}
|
||||
} else {
|
||||
write-host "Code signing disabled - please set KEY_CONTAINERS to sign and copy ollama_inc.crt to the top of the source tree"
|
||||
Write-Output "Code signing disabled - please set KEY_CONTAINERS to sign and copy ollama_inc.crt to the top of the source tree"
|
||||
}
|
||||
$script:JOBS=([Environment]::ProcessorCount)
|
||||
if ($env:OLLAMA_BUILD_PARALLEL) {
|
||||
$script:JOBS=[int]$env:OLLAMA_BUILD_PARALLEL
|
||||
} else {
|
||||
# Use physical core count rather than logical processors (hyperthreads)
|
||||
# to avoid saturating the system during builds
|
||||
try {
|
||||
$cores = (Get-CimInstance Win32_Processor | Measure-Object -Property NumberOfCores -Sum).Sum
|
||||
} catch {
|
||||
$cores = 0
|
||||
}
|
||||
if ($cores -gt 0) {
|
||||
$script:JOBS = $cores
|
||||
} else {
|
||||
$script:JOBS = [Environment]::ProcessorCount
|
||||
}
|
||||
}
|
||||
Write-Output "Build parallelism: $script:JOBS (set OLLAMA_BUILD_PARALLEL to override)"
|
||||
}
|
||||
|
||||
|
||||
@@ -127,7 +151,7 @@ function cuda11 {
|
||||
}
|
||||
}
|
||||
}
|
||||
write-host "Building CUDA v$cudaMajorVer backend libraries $cuda"
|
||||
Write-Output "Building CUDA v$cudaMajorVer backend libraries $cuda"
|
||||
$env:CUDAToolkit_ROOT=$cuda
|
||||
& cmake -B build\cuda_v$cudaMajorVer --preset "CUDA $cudaMajorVer" -T cuda="$cuda" -DCMAKE_CUDA_COMPILER="$cuda\bin\nvcc.exe" -G "Visual Studio 16 2019" --install-prefix "$script:DIST_DIR"
|
||||
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
|
||||
@@ -136,12 +160,12 @@ function cuda11 {
|
||||
& cmake --install build\cuda_v$cudaMajorVer --component "CUDA" --strip
|
||||
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
|
||||
} else {
|
||||
write-host "CUDA v$cudaMajorVer not detected, skipping"
|
||||
Write-Output "CUDA v$cudaMajorVer not detected, skipping"
|
||||
}
|
||||
} else {
|
||||
write-host "not arch we wanted"
|
||||
Write-Output "not arch we wanted"
|
||||
}
|
||||
write-host "done"
|
||||
Write-Output "done"
|
||||
}
|
||||
|
||||
function cudaCommon {
|
||||
@@ -159,7 +183,7 @@ function cudaCommon {
|
||||
}
|
||||
}
|
||||
}
|
||||
write-host "Building CUDA v$cudaMajorVer backend libraries $cuda"
|
||||
Write-Output "Building CUDA v$cudaMajorVer backend libraries $cuda"
|
||||
$env:CUDAToolkit_ROOT=$cuda
|
||||
& cmake -B build\cuda_v$cudaMajorVer --preset "CUDA $cudaMajorVer" -T cuda="$cuda" --install-prefix "$script:DIST_DIR"
|
||||
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
|
||||
@@ -168,7 +192,7 @@ function cudaCommon {
|
||||
& cmake --install build\cuda_v$cudaMajorVer --component "CUDA" --strip
|
||||
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
|
||||
} else {
|
||||
write-host "CUDA v$cudaMajorVer not detected, skipping"
|
||||
Write-Output "CUDA v$cudaMajorVer not detected, skipping"
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -181,11 +205,11 @@ function cuda13 {
|
||||
cudaCommon("13")
|
||||
}
|
||||
|
||||
function rocm {
|
||||
function rocm6 {
|
||||
mkdir -Force -path "${script:DIST_DIR}\" | Out-Null
|
||||
if ($script:ARCH -ne "arm64") {
|
||||
if ($script:HIP_PATH) {
|
||||
write-host "Building ROCm backend libraries $script:HIP_PATH"
|
||||
Write-Output "Building ROCm backend libraries $script:HIP_PATH"
|
||||
if (-Not (get-command -ErrorAction silent ninja)) {
|
||||
$NINJA_DIR=(gci -path (Get-CimInstance MSFT_VSInstance -Namespace root/cimv2/vs)[0].InstallLocation -r -fi ninja.exe).Directory.FullName
|
||||
$env:PATH="$NINJA_DIR;$env:PATH"
|
||||
@@ -193,9 +217,11 @@ function rocm {
|
||||
$env:HIPCXX="${script:HIP_PATH}\bin\clang++.exe"
|
||||
$env:HIP_PLATFORM="amd"
|
||||
$env:CMAKE_PREFIX_PATH="${script:HIP_PATH}"
|
||||
# Set CC/CXX via environment instead of -D flags to avoid triggering
|
||||
# spurious compiler-change reconfigures that reset CMAKE_INSTALL_PREFIX
|
||||
$env:CC="${script:HIP_PATH}\bin\clang.exe"
|
||||
$env:CXX="${script:HIP_PATH}\bin\clang++.exe"
|
||||
& cmake -B build\rocm --preset "ROCm 6" -G Ninja `
|
||||
-DCMAKE_C_COMPILER=clang `
|
||||
-DCMAKE_CXX_COMPILER=clang++ `
|
||||
-DCMAKE_C_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma" `
|
||||
-DCMAKE_CXX_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma" `
|
||||
--install-prefix $script:DIST_DIR
|
||||
@@ -203,20 +229,22 @@ function rocm {
|
||||
$env:HIPCXX=""
|
||||
$env:HIP_PLATFORM=""
|
||||
$env:CMAKE_PREFIX_PATH=""
|
||||
$env:CC=""
|
||||
$env:CXX=""
|
||||
& cmake --build build\rocm --target ggml-hip --config Release --parallel $script:JOBS
|
||||
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
|
||||
& cmake --install build\rocm --component "HIP" --strip
|
||||
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
|
||||
Remove-Item -Path $script:DIST_DIR\lib\ollama\rocm\rocblas\library\*gfx906* -ErrorAction SilentlyContinue
|
||||
} else {
|
||||
write-host "ROCm not detected, skipping"
|
||||
Write-Output "ROCm not detected, skipping"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
function vulkan {
|
||||
if ($env:VULKAN_SDK) {
|
||||
write-host "Building Vulkan backend libraries"
|
||||
Write-Output "Building Vulkan backend libraries"
|
||||
& cmake -B build\vulkan --preset Vulkan --install-prefix $script:DIST_DIR
|
||||
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
|
||||
& cmake --build build\vulkan --target ggml-vulkan --config Release --parallel $script:JOBS
|
||||
@@ -224,33 +252,91 @@ function vulkan {
|
||||
& cmake --install build\vulkan --component Vulkan --strip
|
||||
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
|
||||
} else {
|
||||
write-host "Vulkan not detected, skipping"
|
||||
Write-Output "Vulkan not detected, skipping"
|
||||
}
|
||||
}
|
||||
|
||||
function mlxCuda13 {
|
||||
mkdir -Force -path "${script:DIST_DIR}\" | Out-Null
|
||||
$cudaMajorVer="13"
|
||||
if ($script:ARCH -ne "arm64") {
|
||||
if ("$script:CUDA_DIRS".Contains("v$cudaMajorVer")) {
|
||||
foreach ($d in $Script:CUDA_DIRS){
|
||||
if ($d.FullName.Contains("v$cudaMajorVer")) {
|
||||
if (test-path -literalpath (join-path -path $d -childpath "nvcc.exe" ) ) {
|
||||
$cuda=($d.FullName|split-path -parent)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
# Check for cuDNN - required for MLX CUDA backend
|
||||
# Supports two layouts:
|
||||
# 1. CI/zip extract: CUDNN\include\cudnn.h, lib\x64\, bin\x64\
|
||||
# 2. Official installer: CUDNN\v*\include\{cuda-ver}\cudnn.h, lib\{cuda-ver}\x64\, bin\{cuda-ver}\
|
||||
if ($env:CUDNN_INCLUDE_PATH -and $env:CUDNN_LIBRARY_PATH) {
|
||||
Write-Output "Using cuDNN from environment: $env:CUDNN_INCLUDE_PATH"
|
||||
} elseif (Test-Path "C:\Program Files\NVIDIA\CUDNN\include\cudnn.h") {
|
||||
# CI/zip layout (flat)
|
||||
$cudnnRoot = "C:\Program Files\NVIDIA\CUDNN"
|
||||
$env:CUDNN_ROOT_DIR = $cudnnRoot
|
||||
$env:CUDNN_INCLUDE_PATH = "$cudnnRoot\include"
|
||||
$env:CUDNN_LIBRARY_PATH = "$cudnnRoot\lib\x64"
|
||||
Write-Output "Found cuDNN at $cudnnRoot (flat layout)"
|
||||
} else {
|
||||
# Official installer layout (versioned)
|
||||
$cudnnRoot = $null
|
||||
$resolved = Resolve-Path -Path "C:\Program Files\NVIDIA\CUDNN\v*" -ErrorAction SilentlyContinue | Sort-Object -Descending | Select-Object -First 1
|
||||
if ($resolved -and (Test-Path "$($resolved.Path)\include\$cudaMajorVer.0\cudnn.h")) {
|
||||
$cudnnRoot = $resolved.Path
|
||||
$env:CUDNN_ROOT_DIR = $cudnnRoot
|
||||
$env:CUDNN_INCLUDE_PATH = "$cudnnRoot\include\$cudaMajorVer.0"
|
||||
$env:CUDNN_LIBRARY_PATH = "$cudnnRoot\lib\$cudaMajorVer.0\x64"
|
||||
Write-Output "Found cuDNN at $cudnnRoot (official installer, CUDA $cudaMajorVer.0)"
|
||||
} else {
|
||||
Write-Output "cuDNN not found - set CUDNN_INCLUDE_PATH and CUDNN_LIBRARY_PATH environment variables"
|
||||
Write-Output "Skipping MLX build"
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
Write-Output "Building MLX CUDA v$cudaMajorVer backend libraries $cuda"
|
||||
$env:CUDAToolkit_ROOT=$cuda
|
||||
& cmake -B build\mlx_cuda_v$cudaMajorVer --preset "MLX CUDA $cudaMajorVer" -T cuda="$cuda" --install-prefix "$script:DIST_DIR"
|
||||
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
|
||||
& cmake --build build\mlx_cuda_v$cudaMajorVer --target mlx --target mlxc --config Release --parallel $script:JOBS -- /nodeReuse:false
|
||||
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
|
||||
& cmake --install build\mlx_cuda_v$cudaMajorVer --component "MLX" --strip
|
||||
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
|
||||
} else {
|
||||
Write-Output "CUDA v$cudaMajorVer not detected, skipping MLX build"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
function ollama {
|
||||
mkdir -Force -path "${script:DIST_DIR}\" | Out-Null
|
||||
write-host "Building ollama CLI"
|
||||
Write-Output "Building ollama CLI"
|
||||
& go build -trimpath -ldflags "-s -w -X=github.com/ollama/ollama/version.Version=$script:VERSION -X=github.com/ollama/ollama/server.mode=release" .
|
||||
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
|
||||
cp .\ollama.exe "${script:DIST_DIR}\"
|
||||
}
|
||||
|
||||
function app {
|
||||
write-host "Building Ollama App $script:VERSION with package version $script:PKG_VERSION"
|
||||
Write-Output "Building Ollama App $script:VERSION with package version $script:PKG_VERSION"
|
||||
|
||||
if (!(Get-Command npm -ErrorAction SilentlyContinue)) {
|
||||
write-host "npm is not installed. Please install Node.js and npm first:"
|
||||
write-host " Visit: https://nodejs.org/"
|
||||
Write-Output "npm is not installed. Please install Node.js and npm first:"
|
||||
Write-Output " Visit: https://nodejs.org/"
|
||||
exit 1
|
||||
}
|
||||
|
||||
if (!(Get-Command tsc -ErrorAction SilentlyContinue)) {
|
||||
write-host "Installing TypeScript compiler..."
|
||||
Write-Output "Installing TypeScript compiler..."
|
||||
npm install -g typescript
|
||||
}
|
||||
if (!(Get-Command tscriptify -ErrorAction SilentlyContinue)) {
|
||||
write-host "Installing tscriptify..."
|
||||
Write-Output "Installing tscriptify..."
|
||||
go install github.com/tkrajina/typescriptify-golang-structs/tscriptify@latest
|
||||
}
|
||||
if (!(Get-Command tscriptify -ErrorAction SilentlyContinue)) {
|
||||
@@ -260,32 +346,32 @@ function app {
|
||||
Push-Location app/ui/app
|
||||
npm install
|
||||
if ($LASTEXITCODE -ne 0) {
|
||||
write-host "ERROR: npm install failed with exit code $LASTEXITCODE"
|
||||
Write-Output "ERROR: npm install failed with exit code $LASTEXITCODE"
|
||||
exit $LASTEXITCODE
|
||||
}
|
||||
|
||||
write-host "Building React application..."
|
||||
Write-Output "Building React application..."
|
||||
npm run build
|
||||
if ($LASTEXITCODE -ne 0) {
|
||||
write-host "ERROR: npm run build failed with exit code $LASTEXITCODE"
|
||||
Write-Output "ERROR: npm run build failed with exit code $LASTEXITCODE"
|
||||
exit $LASTEXITCODE
|
||||
}
|
||||
|
||||
# Check if dist directory exists and has content
|
||||
if (!(Test-Path "dist")) {
|
||||
write-host "ERROR: dist directory was not created by npm run build"
|
||||
Write-Output "ERROR: dist directory was not created by npm run build"
|
||||
exit 1
|
||||
}
|
||||
|
||||
$distFiles = Get-ChildItem "dist" -Recurse
|
||||
if ($distFiles.Count -eq 0) {
|
||||
write-host "ERROR: dist directory is empty after npm run build"
|
||||
Write-Output "ERROR: dist directory is empty after npm run build"
|
||||
exit 1
|
||||
}
|
||||
|
||||
Pop-Location
|
||||
|
||||
write-host "Running go generate"
|
||||
Write-Output "Running go generate"
|
||||
& go generate ./...
|
||||
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
|
||||
& go build -trimpath -ldflags "-s -w -H windowsgui -X=github.com/ollama/ollama/app/version.Version=$script:VERSION" -o .\dist\windows-ollama-app-${script:ARCH}.exe ./app/cmd/app/
|
||||
@@ -293,42 +379,42 @@ function app {
|
||||
}
|
||||
|
||||
function deps {
|
||||
write-host "Download MSVC Redistributables"
|
||||
Write-Output "Download MSVC Redistributables"
|
||||
mkdir -Force -path "${script:SRC_DIR}\dist\\windows-arm64" | Out-Null
|
||||
mkdir -Force -path "${script:SRC_DIR}\dist\\windows-amd64" | Out-Null
|
||||
invoke-webrequest -Uri "https://aka.ms/vs/17/release/vc_redist.arm64.exe" -OutFile "${script:SRC_DIR}\dist\windows-arm64\vc_redist.arm64.exe"
|
||||
invoke-webrequest -Uri "https://aka.ms/vs/17/release/vc_redist.x64.exe" -OutFile "${script:SRC_DIR}\dist\windows-amd64\vc_redist.x64.exe"
|
||||
write-host "Done."
|
||||
invoke-webrequest -Uri "https://aka.ms/vs/17/release/vc_redist.arm64.exe" -OutFile "${script:SRC_DIR}\dist\windows-arm64\vc_redist.arm64.exe" -ErrorAction Stop
|
||||
invoke-webrequest -Uri "https://aka.ms/vs/17/release/vc_redist.x64.exe" -OutFile "${script:SRC_DIR}\dist\windows-amd64\vc_redist.x64.exe" -ErrorAction Stop
|
||||
Write-Output "Done."
|
||||
}
|
||||
|
||||
function sign {
|
||||
# Copy install.ps1 to dist for release packaging
|
||||
write-host "Copying install.ps1 to dist"
|
||||
Copy-Item -Path "${script:SRC_DIR}\scripts\install.ps1" -Destination "${script:SRC_DIR}\dist\install.ps1"
|
||||
Write-Output "Copying install.ps1 to dist"
|
||||
Copy-Item -Path "${script:SRC_DIR}\scripts\install.ps1" -Destination "${script:SRC_DIR}\dist\install.ps1" -ErrorAction Stop
|
||||
|
||||
if ("${env:KEY_CONTAINER}") {
|
||||
write-host "Signing Ollama executables, scripts and libraries"
|
||||
Write-Output "Signing Ollama executables, scripts and libraries"
|
||||
& "${script:SignTool}" sign /v /fd sha256 /t http://timestamp.digicert.com /f "${script:OLLAMA_CERT}" `
|
||||
/csp "Google Cloud KMS Provider" /kc ${env:KEY_CONTAINER} `
|
||||
$(get-childitem -path "${script:SRC_DIR}\dist\windows-*" -r -include @('*.exe', '*.dll'))
|
||||
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
|
||||
|
||||
write-host "Signing install.ps1"
|
||||
Write-Output "Signing install.ps1"
|
||||
& "${script:SignTool}" sign /v /fd sha256 /t http://timestamp.digicert.com /f "${script:OLLAMA_CERT}" `
|
||||
/csp "Google Cloud KMS Provider" /kc ${env:KEY_CONTAINER} `
|
||||
"${script:SRC_DIR}\dist\install.ps1"
|
||||
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
|
||||
} else {
|
||||
write-host "Signing not enabled"
|
||||
Write-Output "Signing not enabled"
|
||||
}
|
||||
}
|
||||
|
||||
function installer {
|
||||
if ($null -eq ${script:INNO_SETUP_DIR}) {
|
||||
write-host "ERROR: missing Inno Setup installation directory - install from https://jrsoftware.org/isdl.php"
|
||||
Write-Output "ERROR: missing Inno Setup installation directory - install from https://jrsoftware.org/isdl.php"
|
||||
exit 1
|
||||
}
|
||||
write-host "Building Ollama Installer"
|
||||
Write-Output "Building Ollama Installer"
|
||||
cd "${script:SRC_DIR}\app"
|
||||
$env:PKG_VERSION=$script:PKG_VERSION
|
||||
if ("${env:KEY_CONTAINER}") {
|
||||
@@ -342,24 +428,24 @@ function installer {
|
||||
function zip {
|
||||
if (Test-Path -Path "${script:SRC_DIR}\dist\windows-amd64") {
|
||||
if (Test-Path -Path "${script:SRC_DIR}\dist\windows-amd64\lib\ollama\rocm") {
|
||||
write-host "Generating stand-alone distribution zip file ${script:SRC_DIR}\dist\ollama-windows-amd64-rocm.zip"
|
||||
Write-Output "Generating stand-alone distribution zip file ${script:SRC_DIR}\dist\ollama-windows-amd64-rocm.zip"
|
||||
# Temporarily adjust paths so we can retain the same directory structure
|
||||
Remove-Item -ea 0 -r "${script:SRC_DIR}\dist\windows-amd64-rocm"
|
||||
mkdir -Force -path "${script:SRC_DIR}\dist\windows-amd64-rocm\lib\ollama"
|
||||
Write-Output "Extract this ROCm zip file to the same location where you extracted ollama-windows-amd64.zip" > "${script:SRC_DIR}\dist\windows-amd64-rocm\README.txt"
|
||||
Move-Item -path "${script:SRC_DIR}\dist\windows-amd64\lib\ollama\rocm" -destination "${script:SRC_DIR}\dist\windows-amd64-rocm\lib\ollama"
|
||||
Move-Item -path "${script:SRC_DIR}\dist\windows-amd64\lib\ollama\rocm" -destination "${script:SRC_DIR}\dist\windows-amd64-rocm\lib\ollama" -ErrorAction Stop
|
||||
Compress-Archive -CompressionLevel Optimal -Path "${script:SRC_DIR}\dist\windows-amd64-rocm\*" -DestinationPath "${script:SRC_DIR}\dist\ollama-windows-amd64-rocm.zip" -Force
|
||||
}
|
||||
|
||||
write-host "Generating stand-alone distribution zip file ${script:SRC_DIR}\dist\ollama-windows-amd64.zip"
|
||||
Write-Output "Generating stand-alone distribution zip file ${script:SRC_DIR}\dist\ollama-windows-amd64.zip"
|
||||
Compress-Archive -CompressionLevel Optimal -Path "${script:SRC_DIR}\dist\windows-amd64\*" -DestinationPath "${script:SRC_DIR}\dist\ollama-windows-amd64.zip" -Force
|
||||
if (Test-Path -Path "${script:SRC_DIR}\dist\windows-amd64-rocm") {
|
||||
Move-Item -destination "${script:SRC_DIR}\dist\windows-amd64\lib\ollama\rocm" -path "${script:SRC_DIR}\dist\windows-amd64-rocm\lib\ollama\rocm"
|
||||
Move-Item -destination "${script:SRC_DIR}\dist\windows-amd64\lib\ollama\rocm" -path "${script:SRC_DIR}\dist\windows-amd64-rocm\lib\ollama\rocm" -ErrorAction Stop
|
||||
}
|
||||
}
|
||||
|
||||
if (Test-Path -Path "${script:SRC_DIR}\dist\windows-arm64") {
|
||||
write-host "Generating stand-alone distribution zip file ${script:SRC_DIR}\dist\ollama-windows-arm64.zip"
|
||||
Write-Output "Generating stand-alone distribution zip file ${script:SRC_DIR}\dist\ollama-windows-arm64.zip"
|
||||
Compress-Archive -CompressionLevel Optimal -Path "${script:SRC_DIR}\dist\windows-arm64\*" -DestinationPath "${script:SRC_DIR}\dist\ollama-windows-arm64.zip" -Force
|
||||
}
|
||||
}
|
||||
@@ -375,8 +461,9 @@ try {
|
||||
cpu
|
||||
cuda12
|
||||
cuda13
|
||||
rocm
|
||||
rocm6
|
||||
vulkan
|
||||
mlxCuda13
|
||||
ollama
|
||||
app
|
||||
deps
|
||||
@@ -385,13 +472,13 @@ try {
|
||||
zip
|
||||
} else {
|
||||
for ( $i = 0; $i -lt $args.count; $i++ ) {
|
||||
write-host "running build step $($args[$i])"
|
||||
Write-Output "running build step $($args[$i])"
|
||||
& $($args[$i])
|
||||
}
|
||||
}
|
||||
} catch {
|
||||
write-host "Build Failed"
|
||||
write-host $_
|
||||
Write-Error "Build Failed: $($_.Exception.Message)"
|
||||
Write-Error "$($_.ScriptStackTrace)"
|
||||
} finally {
|
||||
set-location $script:SRC_DIR
|
||||
$env:PKG_VERSION=""
|
||||
|
||||
@@ -18,6 +18,14 @@ OLLAMA_COMMON_BUILD_ARGS="--build-arg=VERSION \
|
||||
--build-arg=GPU_RUNNER_CPU_FLAGS \
|
||||
--build-arg=AMDGPU_TARGETS"
|
||||
|
||||
# Forward local MLX source overrides as Docker build contexts
|
||||
if [ -n "${OLLAMA_MLX_SOURCE:-}" ]; then
|
||||
OLLAMA_COMMON_BUILD_ARGS="$OLLAMA_COMMON_BUILD_ARGS --build-context local-mlx=$(cd "$OLLAMA_MLX_SOURCE" && pwd)"
|
||||
fi
|
||||
if [ -n "${OLLAMA_MLX_C_SOURCE:-}" ]; then
|
||||
OLLAMA_COMMON_BUILD_ARGS="$OLLAMA_COMMON_BUILD_ARGS --build-context local-mlx-c=$(cd "$OLLAMA_MLX_C_SOURCE" && pwd)"
|
||||
fi
|
||||
|
||||
echo "Building Ollama"
|
||||
echo "VERSION=$VERSION"
|
||||
echo "PLATFORM=$PLATFORM"
|
||||
@@ -1,5 +1,3 @@
|
||||
//go:build mlx
|
||||
|
||||
package client
|
||||
|
||||
import (
|
||||
@@ -194,9 +192,10 @@ func quantizePackedGroup(inputs []create.PackedTensorInput) ([]byte, error) {
|
||||
return blobData, nil
|
||||
}
|
||||
|
||||
// QuantizeSupported returns true if quantization is supported (MLX build)
|
||||
// QuantizeSupported returns true if quantization is supported (MLX library available)
|
||||
func QuantizeSupported() bool {
|
||||
return true
|
||||
mlx.InitMLX()
|
||||
return mlx.IsMLXAvailable()
|
||||
}
|
||||
|
||||
// ensureTempDir creates the temp directory for quantization if it doesn't exist
|
||||
|
||||
@@ -1,25 +0,0 @@
|
||||
//go:build !mlx
|
||||
|
||||
package client
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/ollama/ollama/x/create"
|
||||
)
|
||||
|
||||
// quantizeTensor is not available without MLX
|
||||
func quantizeTensor(r io.Reader, tensorName, dtype string, shape []int32, quantize string) (blobData []byte, err error) {
|
||||
return nil, fmt.Errorf("quantization requires MLX support (build with mlx tag)")
|
||||
}
|
||||
|
||||
// quantizePackedGroup is not available without MLX
|
||||
func quantizePackedGroup(inputs []create.PackedTensorInput) ([]byte, error) {
|
||||
return nil, fmt.Errorf("quantization requires MLX support (build with mlx tag)")
|
||||
}
|
||||
|
||||
// QuantizeSupported returns false when MLX is not available
|
||||
func QuantizeSupported() bool {
|
||||
return false
|
||||
}
|
||||
2
x/imagegen/cache/cache.go
vendored
2
x/imagegen/cache/cache.go
vendored
@@ -1,5 +1,3 @@
|
||||
//go:build mlx
|
||||
|
||||
package cache
|
||||
|
||||
import "github.com/ollama/ollama/x/imagegen/mlx"
|
||||
|
||||
2
x/imagegen/cache/step.go
vendored
2
x/imagegen/cache/step.go
vendored
@@ -1,5 +1,3 @@
|
||||
//go:build mlx
|
||||
|
||||
package cache
|
||||
|
||||
import "github.com/ollama/ollama/x/imagegen/mlx"
|
||||
|
||||
2
x/imagegen/cache/teacache.go
vendored
2
x/imagegen/cache/teacache.go
vendored
@@ -1,5 +1,3 @@
|
||||
//go:build mlx
|
||||
|
||||
// Package cache provides caching mechanisms for diffusion model inference.
|
||||
package cache
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@ Experimental MLX backend for running models on Apple Silicon and CUDA.
|
||||
## Build
|
||||
|
||||
```bash
|
||||
go build -tags mlx -o engine ./x/imagegen/cmd/engine
|
||||
go build -o engine ./x/imagegen/cmd/engine
|
||||
```
|
||||
|
||||
## Text Generation
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
//go:build mlx
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
//go:build mlx
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
//go:build mlx
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
//go:build mlx
|
||||
|
||||
package main
|
||||
|
||||
import "github.com/ollama/ollama/x/imagegen/mlx"
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
//go:build mlx
|
||||
|
||||
package imagegen
|
||||
|
||||
import (
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
//go:build mlx
|
||||
|
||||
package imagegen
|
||||
|
||||
import (
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
//go:build mlx
|
||||
|
||||
package imagegen
|
||||
|
||||
import (
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
//go:build mlx
|
||||
|
||||
package manifest
|
||||
|
||||
import (
|
||||
|
||||
@@ -4,6 +4,10 @@ include(FetchContent)
|
||||
file(READ "${CMAKE_SOURCE_DIR}/MLX_VERSION" MLX_C_GIT_TAG)
|
||||
string(STRIP "${MLX_C_GIT_TAG}" MLX_C_GIT_TAG)
|
||||
|
||||
# Read MLX core version from top-level file
|
||||
file(READ "${CMAKE_SOURCE_DIR}/MLX_CORE_VERSION" MLX_GIT_TAG)
|
||||
string(STRIP "${MLX_GIT_TAG}" MLX_GIT_TAG)
|
||||
|
||||
set(MLX_C_BUILD_EXAMPLES OFF)
|
||||
|
||||
set(MLX_BUILD_GGUF OFF)
|
||||
@@ -43,6 +47,17 @@ if(NOT MLX_CUDA_ARCHITECTURES AND CMAKE_CUDA_ARCHITECTURES)
|
||||
message(STATUS "Using CMAKE_CUDA_ARCHITECTURES for MLX: ${MLX_CUDA_ARCHITECTURES}")
|
||||
endif()
|
||||
|
||||
# Forward cuDNN environment variables to cmake variables so MLX's FindCUDNN.cmake
|
||||
# can find them via HINTS ${CUDNN_INCLUDE_PATH} / ${CUDNN_LIBRARY_PATH}.
|
||||
if(DEFINED ENV{CUDNN_INCLUDE_PATH} AND NOT CUDNN_INCLUDE_PATH)
|
||||
set(CUDNN_INCLUDE_PATH "$ENV{CUDNN_INCLUDE_PATH}" CACHE PATH "cuDNN include path")
|
||||
message(STATUS "Using CUDNN_INCLUDE_PATH from environment: ${CUDNN_INCLUDE_PATH}")
|
||||
endif()
|
||||
if(DEFINED ENV{CUDNN_LIBRARY_PATH} AND NOT CUDNN_LIBRARY_PATH)
|
||||
set(CUDNN_LIBRARY_PATH "$ENV{CUDNN_LIBRARY_PATH}" CACHE PATH "cuDNN library path")
|
||||
message(STATUS "Using CUDNN_LIBRARY_PATH from environment: ${CUDNN_LIBRARY_PATH}")
|
||||
endif()
|
||||
|
||||
# Enable CUDA backend if CUDA architectures are specified and CUDA compiler is available
|
||||
if(MLX_CUDA_ARCHITECTURES AND CMAKE_CUDA_COMPILER)
|
||||
set(MLX_BUILD_CUDA ON CACHE BOOL "Build CUDA backend for MLX" FORCE)
|
||||
@@ -51,11 +66,58 @@ elseif(MLX_CUDA_ARCHITECTURES)
|
||||
message(WARNING "MLX_CUDA_ARCHITECTURES specified but CUDA compiler not found, CUDA backend will be disabled")
|
||||
endif()
|
||||
|
||||
# Allow local source overrides via environment variables
|
||||
# Resolve to absolute paths so FetchContent doesn't break on relative dirs.
|
||||
if(DEFINED ENV{OLLAMA_MLX_SOURCE})
|
||||
get_filename_component(_mlx_src "$ENV{OLLAMA_MLX_SOURCE}" ABSOLUTE BASE_DIR ${CMAKE_SOURCE_DIR})
|
||||
set(FETCHCONTENT_SOURCE_DIR_MLX "${_mlx_src}" CACHE PATH "" FORCE)
|
||||
message(STATUS "Using local MLX source: ${_mlx_src}")
|
||||
endif()
|
||||
if(DEFINED ENV{OLLAMA_MLX_C_SOURCE})
|
||||
get_filename_component(_mlx_c_src "$ENV{OLLAMA_MLX_C_SOURCE}" ABSOLUTE BASE_DIR ${CMAKE_SOURCE_DIR})
|
||||
set(FETCHCONTENT_SOURCE_DIR_MLX-C "${_mlx_c_src}" CACHE PATH "" FORCE)
|
||||
message(STATUS "Using local MLX-C source: ${_mlx_c_src}")
|
||||
endif()
|
||||
|
||||
# Pre-declare mlx so our pinned version takes precedence over the one
|
||||
# hardcoded in mlx-c's CMakeLists.txt (first FetchContent_Declare wins).
|
||||
FetchContent_Declare(
|
||||
mlx-c
|
||||
GIT_REPOSITORY "https://github.com/ml-explore/mlx-c.git"
|
||||
GIT_TAG ${MLX_C_GIT_TAG})
|
||||
mlx
|
||||
GIT_REPOSITORY "https://github.com/ml-explore/mlx.git"
|
||||
GIT_TAG ${MLX_GIT_TAG}
|
||||
)
|
||||
|
||||
FetchContent_Declare(
|
||||
mlx-c
|
||||
GIT_REPOSITORY "https://github.com/ml-explore/mlx-c.git"
|
||||
GIT_TAG ${MLX_C_GIT_TAG}
|
||||
)
|
||||
FetchContent_MakeAvailable(mlx-c)
|
||||
|
||||
# Sync vendored headers with fetched version
|
||||
file(GLOB _mlx_c_hdrs "${mlx-c_SOURCE_DIR}/mlx/c/*.h")
|
||||
file(COPY ${_mlx_c_hdrs} DESTINATION "${CMAKE_SOURCE_DIR}/x/mlxrunner/mlx/include/mlx/c/")
|
||||
|
||||
# For local dev builds, override MLX_VERSION with git describe output
|
||||
if(TARGET mlx_version AND DEFINED FETCHCONTENT_SOURCE_DIR_MLX)
|
||||
execute_process(
|
||||
COMMAND git describe --tags --first-parent --abbrev=7 --long --dirty --always
|
||||
WORKING_DIRECTORY ${mlx_SOURCE_DIR}
|
||||
OUTPUT_VARIABLE _mlx_git_version
|
||||
OUTPUT_STRIP_TRAILING_WHITESPACE
|
||||
ERROR_QUIET
|
||||
RESULT_VARIABLE _mlx_git_result
|
||||
)
|
||||
if(_mlx_git_result EQUAL 0 AND _mlx_git_version)
|
||||
# Strip leading "v" prefix for consistency
|
||||
string(REGEX REPLACE "^v" "" _mlx_git_version "${_mlx_git_version}")
|
||||
get_target_property(_mlx_defs mlx_version COMPILE_DEFINITIONS)
|
||||
list(FILTER _mlx_defs EXCLUDE REGEX "^MLX_VERSION=")
|
||||
set_target_properties(mlx_version PROPERTIES COMPILE_DEFINITIONS "${_mlx_defs}")
|
||||
target_compile_definitions(mlx_version PRIVATE "MLX_VERSION=\"${_mlx_git_version}\"")
|
||||
message(STATUS "MLX version (local dev): ${_mlx_git_version}")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
set_target_output_directory(mlx)
|
||||
set_target_output_directory(mlxc)
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
//go:build mlx
|
||||
|
||||
package mlx
|
||||
|
||||
/*
|
||||
|
||||
@@ -1,6 +1,4 @@
|
||||
//go:build mlx
|
||||
|
||||
// Package mlx provides Go bindings for the MLX-C library with dynamic loading support.
|
||||
//
|
||||
//go:generate go run generate_wrappers.go ../../../build/_deps/mlx-c-src/mlx/c mlx.h mlx.c
|
||||
//go:generate go run generate_wrappers.go ../../mlxrunner/mlx/include/mlx/c mlx.h mlx.c
|
||||
package mlx
|
||||
|
||||
@@ -291,8 +291,15 @@ func generateWrapperFiles(functions []Function, headerPath, implPath string) err
|
||||
|
||||
implBuf.WriteString("#include \"mlx/c/mlx.h\"\n")
|
||||
implBuf.WriteString("#include \"mlx_dynamic.h\"\n")
|
||||
implBuf.WriteString("#include <stdio.h>\n")
|
||||
implBuf.WriteString("#include <dlfcn.h>\n\n")
|
||||
implBuf.WriteString("#include <stdio.h>\n\n")
|
||||
implBuf.WriteString("// Platform-specific dynamic loading\n")
|
||||
implBuf.WriteString("#ifdef _WIN32\n")
|
||||
implBuf.WriteString("#include <windows.h>\n")
|
||||
implBuf.WriteString("#define GET_SYM(handle, name) (void*)GetProcAddress((HMODULE)(handle), name)\n")
|
||||
implBuf.WriteString("#else\n")
|
||||
implBuf.WriteString("#include <dlfcn.h>\n")
|
||||
implBuf.WriteString("#define GET_SYM(handle, name) dlsym(handle, name)\n")
|
||||
implBuf.WriteString("#endif\n\n")
|
||||
|
||||
// Function pointer definitions
|
||||
implBuf.WriteString("// Function pointer definitions\n")
|
||||
@@ -308,7 +315,7 @@ func generateWrapperFiles(functions []Function, headerPath, implPath string) err
|
||||
implBuf.WriteString("\n")
|
||||
|
||||
// Initialization function
|
||||
implBuf.WriteString("// Initialize all function pointers via dlsym\n")
|
||||
implBuf.WriteString("// Initialize all function pointers\n")
|
||||
implBuf.WriteString("int mlx_load_functions(void* handle) {\n")
|
||||
implBuf.WriteString(" if (handle == NULL) {\n")
|
||||
implBuf.WriteString(" fprintf(stderr, \"MLX: Invalid library handle\\n\");\n")
|
||||
@@ -319,7 +326,7 @@ func generateWrapperFiles(functions []Function, headerPath, implPath string) err
|
||||
if fn.NeedsARM64Guard {
|
||||
implBuf.WriteString("#if defined(__aarch64__) || defined(_M_ARM64)\n")
|
||||
}
|
||||
implBuf.WriteString(fmt.Sprintf(" %s_ptr = dlsym(handle, \"%s\");\n", fn.Name, fn.Name))
|
||||
implBuf.WriteString(fmt.Sprintf(" %s_ptr = GET_SYM(handle, \"%s\");\n", fn.Name, fn.Name))
|
||||
implBuf.WriteString(fmt.Sprintf(" if (%s_ptr == NULL) {\n", fn.Name))
|
||||
implBuf.WriteString(fmt.Sprintf(" fprintf(stderr, \"MLX: Failed to load symbol: %s\\n\");\n", fn.Name))
|
||||
implBuf.WriteString(" return -1;\n")
|
||||
|
||||
1182
x/imagegen/mlx/mlx.c
1182
x/imagegen/mlx/mlx.c
File diff suppressed because it is too large
Load Diff
@@ -1,9 +1,7 @@
|
||||
//go:build mlx
|
||||
|
||||
package mlx
|
||||
|
||||
/*
|
||||
#cgo CFLAGS: -O3 -I${SRCDIR}/../../../build/_deps/mlx-c-src -I${SRCDIR}
|
||||
#cgo CFLAGS: -O3 -I${SRCDIR}/../../mlxrunner/mlx/include -I${SRCDIR}
|
||||
#cgo darwin LDFLAGS: -lc++ -framework Metal -framework Foundation -framework Accelerate
|
||||
#cgo linux LDFLAGS: -lstdc++ -ldl
|
||||
#cgo windows LDFLAGS: -lstdc++
|
||||
@@ -32,7 +30,7 @@ static inline void set_default_stream(mlx_stream s) {
|
||||
_default_stream = s;
|
||||
}
|
||||
|
||||
// CPU stream for file loading (Load primitive only runs on CPU)
|
||||
// CPU stream for operations that only support CPU evaluation
|
||||
static inline mlx_stream cpu_stream() {
|
||||
if (_cpu_stream.ctx == NULL) {
|
||||
_cpu_stream = mlx_default_cpu_stream_new();
|
||||
@@ -45,8 +43,11 @@ static inline mlx_stream cpu_stream() {
|
||||
// nocallback: function won't call back into Go
|
||||
*/
|
||||
import "C"
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
"runtime"
|
||||
"sync"
|
||||
@@ -1502,15 +1503,21 @@ type SafetensorsFile struct {
|
||||
metadata C.mlx_map_string_to_string
|
||||
}
|
||||
|
||||
// LoadSafetensorsNative loads a safetensors file using MLX's optimized loader
|
||||
// Note: Uses CPU stream because Load primitive only runs on CPU
|
||||
// LoadSafetensorsNative loads a safetensors file using MLX's optimized loader.
|
||||
// On CUDA, Load::eval_gpu is implemented so we use the default (GPU) stream.
|
||||
// On Metal, Load::eval_gpu is not implemented so we must use the CPU stream.
|
||||
func LoadSafetensorsNative(path string) (*SafetensorsFile, error) {
|
||||
cPath := C.CString(path)
|
||||
defer C.free(unsafe.Pointer(cPath))
|
||||
|
||||
stream := C.default_stream()
|
||||
if runtime.GOOS == "darwin" {
|
||||
stream = C.cpu_stream()
|
||||
}
|
||||
|
||||
var arrays C.mlx_map_string_to_array
|
||||
var metadata C.mlx_map_string_to_string
|
||||
if C.mlx_load_safetensors(&arrays, &metadata, cPath, C.cpu_stream()) != 0 {
|
||||
if C.mlx_load_safetensors(&arrays, &metadata, cPath, stream) != 0 {
|
||||
return nil, fmt.Errorf("failed to load safetensors: %s", path)
|
||||
}
|
||||
return &SafetensorsFile{arrays: arrays, metadata: metadata}, nil
|
||||
@@ -1689,11 +1696,91 @@ func ArgmaxKeepArray(logits *Array) *Array {
|
||||
//
|
||||
// Thread safety: Protected by randomStateMu, mimicking Python's GIL behavior.
|
||||
// All random functions that use global state acquire this lock.
|
||||
var RandomState = []*Array{nil}
|
||||
var randomStateMu sync.Mutex
|
||||
var (
|
||||
RandomState = []*Array{nil}
|
||||
randomStateMu sync.Mutex
|
||||
)
|
||||
|
||||
var mlxInitialized bool
|
||||
var mlxInitError error
|
||||
var (
|
||||
mlxInitialized bool
|
||||
mlxInitError error
|
||||
)
|
||||
|
||||
// mlxLibName returns the platform-specific shared library filename.
|
||||
func mlxLibName() string {
|
||||
switch runtime.GOOS {
|
||||
case "windows":
|
||||
return "mlxc.dll"
|
||||
case "darwin":
|
||||
return "libmlxc.dylib"
|
||||
default:
|
||||
return "libmlxc.so"
|
||||
}
|
||||
}
|
||||
|
||||
// findMLXLibrary searches for the MLX shared library in standard locations.
|
||||
// Returns the path to the library, or empty string if not found.
|
||||
func findMLXLibrary() string {
|
||||
libName := mlxLibName()
|
||||
|
||||
// 1. OLLAMA_LIBRARY_PATH — check each dir and mlx_* subdirs
|
||||
if paths, ok := os.LookupEnv("OLLAMA_LIBRARY_PATH"); ok {
|
||||
for _, dir := range filepath.SplitList(paths) {
|
||||
candidate := filepath.Join(dir, libName)
|
||||
if _, err := os.Stat(candidate); err == nil {
|
||||
return candidate
|
||||
}
|
||||
if mlxDirs, err := filepath.Glob(filepath.Join(dir, "mlx*")); err == nil {
|
||||
for _, mlxDir := range mlxDirs {
|
||||
candidate = filepath.Join(mlxDir, libName)
|
||||
if _, err := os.Stat(candidate); err == nil {
|
||||
return candidate
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 2. Executable directory and lib/ollama/mlx* subdirs
|
||||
if exe, err := os.Executable(); err == nil {
|
||||
if eval, err := filepath.EvalSymlinks(exe); err == nil {
|
||||
exe = eval
|
||||
}
|
||||
exeDir := filepath.Dir(exe)
|
||||
|
||||
// Check exe dir directly (macOS copies dylib here)
|
||||
candidate := filepath.Join(exeDir, libName)
|
||||
if _, err := os.Stat(candidate); err == nil {
|
||||
return candidate
|
||||
}
|
||||
|
||||
// Check exe_dir/lib/ollama/mlx* subdirectories
|
||||
// and exe_dir/../lib/ollama/mlx* (standard bin/lib sibling layout)
|
||||
for _, libOllamaDir := range []string{
|
||||
filepath.Join(exeDir, "lib", "ollama"),
|
||||
filepath.Join(exeDir, "..", "lib", "ollama"),
|
||||
} {
|
||||
if mlxDirs, err := filepath.Glob(filepath.Join(libOllamaDir, "mlx*")); err == nil {
|
||||
for _, mlxDir := range mlxDirs {
|
||||
candidate = filepath.Join(mlxDir, libName)
|
||||
if _, err := os.Stat(candidate); err == nil {
|
||||
return candidate
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 3. Build directory (for tests run from repo root)
|
||||
if cwd, err := os.Getwd(); err == nil {
|
||||
candidate := filepath.Join(cwd, "build", "lib", "ollama", libName)
|
||||
if _, err := os.Stat(candidate); err == nil {
|
||||
return candidate
|
||||
}
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
// InitMLX initializes the MLX library by dynamically loading libmlxc.
|
||||
// This must be called before using any MLX functions.
|
||||
@@ -1703,9 +1790,16 @@ func InitMLX() error {
|
||||
return mlxInitError
|
||||
}
|
||||
|
||||
// Try to load the MLX dynamic library
|
||||
ret := C.mlx_dynamic_init()
|
||||
if ret != 0 {
|
||||
// Search for the library using Go path discovery
|
||||
libPath := findMLXLibrary()
|
||||
if libPath == "" {
|
||||
mlxInitError = fmt.Errorf("failed to initialize MLX: %s not found", mlxLibName())
|
||||
return mlxInitError
|
||||
}
|
||||
|
||||
cPath := C.CString(libPath)
|
||||
defer C.free(unsafe.Pointer(cPath))
|
||||
if C.mlx_dynamic_init_path(cPath) != 0 {
|
||||
errMsg := C.GoString(C.mlx_dynamic_error())
|
||||
mlxInitError = fmt.Errorf("failed to initialize MLX: %s", errMsg)
|
||||
return mlxInitError
|
||||
@@ -1713,8 +1807,7 @@ func InitMLX() error {
|
||||
|
||||
// Initialize all function pointers via dlsym
|
||||
handle := C.mlx_get_handle()
|
||||
ret = C.mlx_load_functions(handle)
|
||||
if ret != 0 {
|
||||
if C.mlx_load_functions(handle) != 0 {
|
||||
mlxInitError = fmt.Errorf("failed to load MLX function symbols")
|
||||
return mlxInitError
|
||||
}
|
||||
|
||||
@@ -9,114 +9,76 @@
|
||||
#ifdef _WIN32
|
||||
#include <windows.h>
|
||||
typedef HMODULE lib_handle_t;
|
||||
#define LOAD_LIB(path) LoadLibraryA(path)
|
||||
#define GET_SYMBOL(handle, name) GetProcAddress(handle, name)
|
||||
#define CLOSE_LIB(handle) FreeLibrary(handle)
|
||||
#define LIB_ERROR() "LoadLibrary failed"
|
||||
static char win_error_buffer[256] = {0};
|
||||
static const char* get_win_error(void) {
|
||||
DWORD err = GetLastError();
|
||||
snprintf(win_error_buffer, sizeof(win_error_buffer), "error code %lu", err);
|
||||
return win_error_buffer;
|
||||
}
|
||||
#define LIB_ERROR() get_win_error()
|
||||
#else
|
||||
#include <dlfcn.h>
|
||||
typedef void* lib_handle_t;
|
||||
#define LOAD_LIB(path) dlopen(path, RTLD_LAZY | RTLD_GLOBAL)
|
||||
#define GET_SYMBOL(handle, name) dlsym(handle, name)
|
||||
#define CLOSE_LIB(handle) dlclose(handle)
|
||||
#define LIB_ERROR() dlerror()
|
||||
#ifdef __APPLE__
|
||||
#include <mach-o/dyld.h>
|
||||
#include <libgen.h>
|
||||
#endif
|
||||
#endif
|
||||
|
||||
static lib_handle_t mlx_handle = NULL;
|
||||
static int mlx_initialized = 0;
|
||||
static char mlx_error_buffer[512] = {0};
|
||||
|
||||
#ifdef __APPLE__
|
||||
// Get path to library in same directory as executable
|
||||
static char* get_exe_relative_path(const char* libname) {
|
||||
static char path[1024];
|
||||
uint32_t size = sizeof(path);
|
||||
if (_NSGetExecutablePath(path, &size) != 0) {
|
||||
return NULL;
|
||||
#ifdef _WIN32
|
||||
// Windows: Load library from a path with dependency resolution.
|
||||
// Temporarily adds the library's directory to the DLL search path
|
||||
// so that dependencies (like mlx.dll) in the same directory are found.
|
||||
static int try_load_win(const char* path) {
|
||||
if (!path) return 0;
|
||||
|
||||
// Extract directory and add to DLL search path for dependency resolution
|
||||
char dir_path[MAX_PATH];
|
||||
strncpy(dir_path, path, MAX_PATH - 1);
|
||||
dir_path[MAX_PATH - 1] = '\0';
|
||||
char* last_slash = strrchr(dir_path, '\\');
|
||||
if (!last_slash) last_slash = strrchr(dir_path, '/');
|
||||
if (last_slash) {
|
||||
*last_slash = '\0';
|
||||
SetDllDirectoryA(dir_path);
|
||||
}
|
||||
// Get directory of executable
|
||||
char* dir = dirname(path);
|
||||
static char fullpath[1024];
|
||||
snprintf(fullpath, sizeof(fullpath), "%s/%s", dir, libname);
|
||||
return fullpath;
|
||||
|
||||
mlx_handle = LoadLibraryA(path);
|
||||
SetDllDirectoryA(NULL);
|
||||
return mlx_handle != NULL;
|
||||
}
|
||||
#endif
|
||||
|
||||
// Try to load library from a specific path
|
||||
static int try_load_lib(const char* path) {
|
||||
if (!path) return 0;
|
||||
mlx_handle = LOAD_LIB(path);
|
||||
#ifdef _WIN32
|
||||
return try_load_win(path);
|
||||
#else
|
||||
mlx_handle = dlopen(path, RTLD_LAZY | RTLD_GLOBAL);
|
||||
return mlx_handle != NULL;
|
||||
#endif
|
||||
}
|
||||
|
||||
// Initialize MLX dynamic library
|
||||
// Returns 0 on success, -1 on failure
|
||||
// On failure, call mlx_dynamic_error() to get error message
|
||||
int mlx_dynamic_init(void) {
|
||||
// Initialize the MLX dynamic library from a specific path.
|
||||
// Returns 0 on success, -1 on failure.
|
||||
int mlx_dynamic_init_path(const char* path) {
|
||||
if (mlx_initialized) {
|
||||
return 0; // Already initialized
|
||||
return 0;
|
||||
}
|
||||
|
||||
const char* lib_path = NULL;
|
||||
const char* tried_paths[8] = {0};
|
||||
int num_tried = 0;
|
||||
|
||||
#ifdef _WIN32
|
||||
// Windows: try same directory as executable
|
||||
lib_path = "libmlxc.dll";
|
||||
tried_paths[num_tried++] = lib_path;
|
||||
if (try_load_lib(lib_path)) goto success;
|
||||
#elif defined(__APPLE__)
|
||||
// macOS: try executable directory first
|
||||
lib_path = get_exe_relative_path("libmlxc.dylib");
|
||||
if (lib_path) {
|
||||
tried_paths[num_tried++] = lib_path;
|
||||
if (try_load_lib(lib_path)) goto success;
|
||||
if (try_load_lib(path)) {
|
||||
mlx_initialized = 1;
|
||||
snprintf(mlx_error_buffer, sizeof(mlx_error_buffer),
|
||||
"MLX: Successfully loaded %s", path ? path : "library");
|
||||
return 0;
|
||||
}
|
||||
// Try build directory (for tests run from repo root)
|
||||
lib_path = "./build/lib/ollama/libmlxc.dylib";
|
||||
tried_paths[num_tried++] = lib_path;
|
||||
if (try_load_lib(lib_path)) goto success;
|
||||
// Fallback to system paths
|
||||
lib_path = "libmlxc.dylib";
|
||||
tried_paths[num_tried++] = lib_path;
|
||||
if (try_load_lib(lib_path)) goto success;
|
||||
#else
|
||||
// Linux: try build directory first (for tests)
|
||||
lib_path = "./build/lib/ollama/libmlxc.so";
|
||||
tried_paths[num_tried++] = lib_path;
|
||||
if (try_load_lib(lib_path)) goto success;
|
||||
// Fallback to system paths
|
||||
lib_path = "libmlxc.so";
|
||||
tried_paths[num_tried++] = lib_path;
|
||||
if (try_load_lib(lib_path)) goto success;
|
||||
#endif
|
||||
|
||||
// Failed to load library - build error message with all tried paths
|
||||
{
|
||||
const char* err = LIB_ERROR();
|
||||
int offset = snprintf(mlx_error_buffer, sizeof(mlx_error_buffer),
|
||||
"MLX: Failed to load libmlxc library. Tried: ");
|
||||
for (int i = 0; i < num_tried && offset < (int)sizeof(mlx_error_buffer) - 50; i++) {
|
||||
offset += snprintf(mlx_error_buffer + offset, sizeof(mlx_error_buffer) - offset,
|
||||
"%s%s", i > 0 ? ", " : "", tried_paths[i]);
|
||||
}
|
||||
if (err) {
|
||||
snprintf(mlx_error_buffer + offset, sizeof(mlx_error_buffer) - offset,
|
||||
". Last error: %s", err);
|
||||
}
|
||||
}
|
||||
return -1;
|
||||
|
||||
success:
|
||||
mlx_initialized = 1;
|
||||
const char* err = LIB_ERROR();
|
||||
snprintf(mlx_error_buffer, sizeof(mlx_error_buffer),
|
||||
"MLX: Successfully loaded %s", lib_path ? lib_path : "library");
|
||||
return 0;
|
||||
"MLX: Failed to load %s: %s", path ? path : "(null)", err ? err : "unknown error");
|
||||
return -1;
|
||||
}
|
||||
|
||||
// Get the last error message
|
||||
@@ -124,21 +86,8 @@ const char* mlx_dynamic_error(void) {
|
||||
return mlx_error_buffer;
|
||||
}
|
||||
|
||||
// Check if MLX is initialized
|
||||
int mlx_dynamic_is_initialized(void) {
|
||||
return mlx_initialized;
|
||||
}
|
||||
|
||||
// Get the library handle (for use by generated wrappers)
|
||||
void* mlx_get_handle(void) {
|
||||
return mlx_handle;
|
||||
}
|
||||
|
||||
// Cleanup (optional, called at program exit)
|
||||
void mlx_dynamic_cleanup(void) {
|
||||
if (mlx_handle != NULL) {
|
||||
CLOSE_LIB(mlx_handle);
|
||||
mlx_handle = NULL;
|
||||
mlx_initialized = 0;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,22 +6,16 @@
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
// Initialize the MLX dynamic library
|
||||
// Initialize the MLX dynamic library from a specific path
|
||||
// Returns 0 on success, -1 on failure
|
||||
int mlx_dynamic_init(void);
|
||||
int mlx_dynamic_init_path(const char* path);
|
||||
|
||||
// Get the last error message from dynamic loading
|
||||
const char* mlx_dynamic_error(void);
|
||||
|
||||
// Check if MLX is initialized
|
||||
int mlx_dynamic_is_initialized(void);
|
||||
|
||||
// Get the library handle (for use by generated wrappers)
|
||||
void* mlx_get_handle(void);
|
||||
|
||||
// Cleanup resources (optional, for clean shutdown)
|
||||
void mlx_dynamic_cleanup(void);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
//go:build mlx
|
||||
|
||||
package mlx
|
||||
|
||||
import (
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
//go:build mlx
|
||||
|
||||
// Package flux2 implements the FLUX.2 Klein diffusion transformer model.
|
||||
// Klein is a 4B parameter distilled model that supports sub-second inference.
|
||||
package flux2
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
//go:build mlx
|
||||
|
||||
package flux2
|
||||
|
||||
import (
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
//go:build mlx
|
||||
|
||||
package flux2
|
||||
|
||||
import (
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
//go:build mlx
|
||||
|
||||
package flux2
|
||||
|
||||
import (
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
//go:build mlx
|
||||
|
||||
package flux2
|
||||
|
||||
import (
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
//go:build mlx
|
||||
|
||||
// Package qwen3 provides a shared Qwen3 text encoder used by multiple image generation models.
|
||||
package qwen3
|
||||
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
//go:build mlx
|
||||
|
||||
package zimage
|
||||
|
||||
import (
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
//go:build mlx
|
||||
|
||||
package zimage
|
||||
|
||||
import (
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
//go:build mlx
|
||||
|
||||
// Package zimage implements the Z-Image diffusion transformer model.
|
||||
package zimage
|
||||
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
//go:build mlx
|
||||
|
||||
package zimage
|
||||
|
||||
import (
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
//go:build mlx
|
||||
|
||||
// Package zimage implements the Z-Image diffusion transformer model.
|
||||
package zimage
|
||||
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
//go:build mlx
|
||||
|
||||
// Package nn provides neural network layer types.
|
||||
package nn
|
||||
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
//go:build mlx
|
||||
|
||||
package nn
|
||||
|
||||
import (
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
//go:build mlx
|
||||
|
||||
// Package imagegen provides a unified MLX runner for both LLM and image generation models.
|
||||
package imagegen
|
||||
|
||||
|
||||
@@ -1,10 +0,0 @@
|
||||
//go:build !mlx
|
||||
|
||||
package imagegen
|
||||
|
||||
import "errors"
|
||||
|
||||
// Execute returns an error when not built with MLX support.
|
||||
func Execute(args []string) error {
|
||||
return errors.New("MLX runner not available: build with mlx tag")
|
||||
}
|
||||
@@ -1,5 +1,3 @@
|
||||
//go:build mlx
|
||||
|
||||
package safetensors
|
||||
|
||||
import (
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
//go:build mlx
|
||||
|
||||
package safetensors
|
||||
|
||||
import (
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
//go:build mlx
|
||||
|
||||
package safetensors
|
||||
|
||||
import (
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
//go:build mlx
|
||||
|
||||
// tokenizer.go - BPE and SentencePiece tokenizer for HuggingFace models
|
||||
//
|
||||
// Based on standard BPE algorithm (Sennrich et al. 2015) with:
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
//go:build mlx
|
||||
|
||||
package tokenizer
|
||||
|
||||
import (
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
//go:build mlx
|
||||
|
||||
// Package vae provides shared utilities for VAE (Variational Autoencoder) operations.
|
||||
package vae
|
||||
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
//go:build mlx
|
||||
|
||||
package mlxrunner
|
||||
|
||||
import (
|
||||
|
||||
2
x/mlxrunner/cache/cache.go
vendored
2
x/mlxrunner/cache/cache.go
vendored
@@ -1,5 +1,3 @@
|
||||
//go:build mlx
|
||||
|
||||
package cache
|
||||
|
||||
import (
|
||||
|
||||
2
x/mlxrunner/cache/recurrent.go
vendored
2
x/mlxrunner/cache/recurrent.go
vendored
@@ -1,5 +1,3 @@
|
||||
//go:build mlx
|
||||
|
||||
package cache
|
||||
|
||||
import "github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
|
||||
@@ -72,14 +72,23 @@ func NewClient(modelName string) (*Client, error) {
|
||||
cmd := exec.Command(exe, "runner", "--mlx-engine", "--model", modelName, "--port", strconv.Itoa(port))
|
||||
cmd.Env = os.Environ()
|
||||
|
||||
// On Linux, set LD_LIBRARY_PATH to include MLX library directories
|
||||
if runtime.GOOS == "linux" {
|
||||
// Set library path environment variable for MLX libraries
|
||||
// Linux: LD_LIBRARY_PATH, Windows: PATH
|
||||
var libPathEnvVar string
|
||||
switch runtime.GOOS {
|
||||
case "linux":
|
||||
libPathEnvVar = "LD_LIBRARY_PATH"
|
||||
case "windows":
|
||||
libPathEnvVar = "PATH"
|
||||
}
|
||||
|
||||
if libPathEnvVar != "" {
|
||||
libraryPaths := []string{ml.LibOllamaPath}
|
||||
if mlxDirs, err := filepath.Glob(filepath.Join(ml.LibOllamaPath, "mlx_*")); err == nil {
|
||||
libraryPaths = append(libraryPaths, mlxDirs...)
|
||||
}
|
||||
|
||||
if existingPath, ok := os.LookupEnv("LD_LIBRARY_PATH"); ok {
|
||||
if existingPath, ok := os.LookupEnv(libPathEnvVar); ok {
|
||||
libraryPaths = append(libraryPaths, filepath.SplitList(existingPath)...)
|
||||
}
|
||||
|
||||
@@ -87,16 +96,20 @@ func NewClient(modelName string) (*Client, error) {
|
||||
|
||||
found := false
|
||||
for i := range cmd.Env {
|
||||
if strings.HasPrefix(cmd.Env[i], "LD_LIBRARY_PATH=") {
|
||||
cmd.Env[i] = "LD_LIBRARY_PATH=" + pathEnvVal
|
||||
envName := cmd.Env[i]
|
||||
if runtime.GOOS == "windows" {
|
||||
envName = strings.ToUpper(envName)
|
||||
}
|
||||
if strings.HasPrefix(envName, libPathEnvVar+"=") {
|
||||
cmd.Env[i] = libPathEnvVar + "=" + pathEnvVal
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
cmd.Env = append(cmd.Env, "LD_LIBRARY_PATH="+pathEnvVal)
|
||||
cmd.Env = append(cmd.Env, libPathEnvVar+"="+pathEnvVal)
|
||||
}
|
||||
slog.Debug("mlx subprocess library path", "LD_LIBRARY_PATH", pathEnvVal)
|
||||
slog.Debug("mlx subprocess library path", libPathEnvVar, pathEnvVal)
|
||||
}
|
||||
|
||||
c := &Client{
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
//go:build mlx
|
||||
|
||||
package mlxrunner
|
||||
|
||||
import (
|
||||
|
||||
@@ -24,3 +24,7 @@ FetchContent_Declare(
|
||||
)
|
||||
|
||||
FetchContent_MakeAvailable(mlx-c)
|
||||
|
||||
# Sync vendored headers with fetched version
|
||||
file(GLOB _mlx_c_hdrs "${mlx-c_SOURCE_DIR}/mlx/c/*.h")
|
||||
file(COPY ${_mlx_c_hdrs} DESTINATION "${CMAKE_CURRENT_SOURCE_DIR}/include/mlx/c/")
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
//go:build mlx
|
||||
|
||||
package mlx
|
||||
|
||||
// #include "generated.h"
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
//go:build mlx
|
||||
|
||||
package mlx
|
||||
|
||||
// #include "generated.h"
|
||||
|
||||
@@ -1,10 +1,16 @@
|
||||
//go:build mlx
|
||||
|
||||
package mlx
|
||||
|
||||
import "testing"
|
||||
|
||||
func skipIfNoMLX(t *testing.T) {
|
||||
t.Helper()
|
||||
if err := CheckInit(); err != nil {
|
||||
t.Skipf("MLX not available: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFromValue(t *testing.T) {
|
||||
skipIfNoMLX(t)
|
||||
for got, want := range map[*Array]DType{
|
||||
FromValue(true): DTypeBool,
|
||||
FromValue(false): DTypeBool,
|
||||
@@ -22,6 +28,7 @@ func TestFromValue(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestFromValues(t *testing.T) {
|
||||
skipIfNoMLX(t)
|
||||
for got, want := range map[*Array]DType{
|
||||
FromValues([]bool{true, false, true}, 3): DTypeBool,
|
||||
FromValues([]uint8{1, 2, 3}, 3): DTypeUint8,
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
//go:build mlx
|
||||
|
||||
package mlx
|
||||
|
||||
// #include "generated.h"
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
//go:build mlx
|
||||
|
||||
package mlx
|
||||
|
||||
// #include "dynamic.h"
|
||||
@@ -24,10 +22,16 @@ func CheckInit() error {
|
||||
return initError
|
||||
}
|
||||
|
||||
// tryLoadFromDir searches a directory for libmlxc.* and tries to load it.
|
||||
// tryLoadFromDir searches a directory for the mlxc shared library and tries to load it.
|
||||
// Returns true if the library was successfully loaded.
|
||||
func tryLoadFromDir(dir string) bool {
|
||||
matches, err := fs.Glob(os.DirFS(dir), "libmlxc.*")
|
||||
// On Windows, MSVC produces mlxc.dll (no lib prefix)
|
||||
// On Unix, it's libmlxc.so or libmlxc.dylib
|
||||
pattern := "libmlxc.*"
|
||||
if runtime.GOOS == "windows" {
|
||||
pattern = "mlxc.*"
|
||||
}
|
||||
matches, err := fs.Glob(os.DirFS(dir), pattern)
|
||||
if err != nil || len(matches) == 0 {
|
||||
return false
|
||||
}
|
||||
@@ -60,7 +64,10 @@ func tryLoadFromDir(dir string) bool {
|
||||
// Returns true if the library was successfully loaded.
|
||||
func tryLoadByName() bool {
|
||||
libraryName := "libmlxc.dylib"
|
||||
if runtime.GOOS == "linux" {
|
||||
switch runtime.GOOS {
|
||||
case "windows":
|
||||
libraryName = "mlxc.dll"
|
||||
case "linux":
|
||||
libraryName = "libmlxc.so"
|
||||
}
|
||||
|
||||
@@ -81,19 +88,25 @@ func tryLoadByName() bool {
|
||||
|
||||
func init() {
|
||||
switch runtime.GOOS {
|
||||
case "darwin":
|
||||
case "darwin", "linux", "windows":
|
||||
|
||||
case "windows":
|
||||
default:
|
||||
return
|
||||
}
|
||||
|
||||
// Try OLLAMA_LIBRARY_PATH first
|
||||
// Try OLLAMA_LIBRARY_PATH first, including mlx_* subdirectories
|
||||
if paths, ok := os.LookupEnv("OLLAMA_LIBRARY_PATH"); ok {
|
||||
for _, dir := range filepath.SplitList(paths) {
|
||||
if tryLoadFromDir(dir) {
|
||||
return
|
||||
}
|
||||
if mlxDirs, err := filepath.Glob(filepath.Join(dir, "mlx_*")); err == nil {
|
||||
for _, mlxDir := range mlxDirs {
|
||||
if tryLoadFromDir(mlxDir) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -115,12 +128,21 @@ func init() {
|
||||
searchDirs = append(searchDirs, filepath.Join(cwd, "build", "lib", "ollama"))
|
||||
}
|
||||
|
||||
// Also scan mlx_* subdirectories within each search dir
|
||||
var expanded []string
|
||||
for _, dir := range searchDirs {
|
||||
expanded = append(expanded, dir)
|
||||
if mlxDirs, err := filepath.Glob(filepath.Join(dir, "mlx_*")); err == nil {
|
||||
expanded = append(expanded, mlxDirs...)
|
||||
}
|
||||
}
|
||||
|
||||
for _, dir := range expanded {
|
||||
if tryLoadFromDir(dir) {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
initError = fmt.Errorf("failed to load MLX dynamic library (searched: %v)", searchDirs)
|
||||
slog.Warn("MLX dynamic library not available", "error", initError)
|
||||
slog.Debug("MLX dynamic library not available", "error", initError)
|
||||
}
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
|
||||
#ifdef _WIN32
|
||||
#include <windows.h>
|
||||
#define DLSYM(handle, symbol) GetProcAddress((HMODULE)(handle), symbol)
|
||||
#define DLSYM(handle, symbol) (void*)GetProcAddress((HMODULE)(handle.ctx), symbol)
|
||||
#else
|
||||
#include <dlfcn.h>
|
||||
#define DLSYM(handle, symbol) dlsym(handle.ctx, symbol)
|
||||
@@ -23,9 +23,15 @@ typedef uint16_t float16_t;
|
||||
typedef uint16_t bfloat16_t;
|
||||
#endif
|
||||
|
||||
#define ERROR(fmt, ...) fprintf(stderr, "%s %s - ERROR - %s:%d - " fmt "\n", __DATE__, __TIME__, __FILE__, __LINE__, ##__VA_ARGS__); return 1
|
||||
#define CHECK(x) if (!(x)) { ERROR("CHECK failed: " #x); }
|
||||
#define CHECK_LOAD(handle, x) x##_ = DLSYM(handle, #x); CHECK(x##_)
|
||||
// Undef ERROR to avoid conflict with wingdi.h on Windows
|
||||
#ifdef ERROR
|
||||
#undef ERROR
|
||||
#endif
|
||||
#define MLX_ERROR(fmt, ...) fprintf(stderr, "%s %s - ERROR - %s:%d - " fmt "\n", __DATE__, __TIME__, __FILE__, __LINE__, ##__VA_ARGS__); return 1
|
||||
#define CHECK(x) if (!(x)) { MLX_ERROR("CHECK failed: " #x); }
|
||||
#define CHECK_LOAD(handle, x) *(void**)(&x##_) = DLSYM(handle, #x); CHECK(x##_)
|
||||
// OPTIONAL_LOAD: load symbol if available, leave function pointer NULL otherwise
|
||||
#define OPTIONAL_LOAD(handle, x) *(void**)(&x##_) = DLSYM(handle, #x)
|
||||
|
||||
typedef struct {
|
||||
void* ctx;
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
//go:build mlx
|
||||
|
||||
package mlx
|
||||
|
||||
// #include "generated.h"
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
//go:build mlx
|
||||
|
||||
package mlx
|
||||
|
||||
// #include <stdlib.h>
|
||||
|
||||
@@ -2299,8 +2299,8 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) {
|
||||
CHECK_LOAD(handle, mlx_array_item_float32);
|
||||
CHECK_LOAD(handle, mlx_array_item_float64);
|
||||
CHECK_LOAD(handle, mlx_array_item_complex64);
|
||||
CHECK_LOAD(handle, mlx_array_item_float16);
|
||||
CHECK_LOAD(handle, mlx_array_item_bfloat16);
|
||||
OPTIONAL_LOAD(handle, mlx_array_item_float16);
|
||||
OPTIONAL_LOAD(handle, mlx_array_item_bfloat16);
|
||||
CHECK_LOAD(handle, mlx_array_data_bool);
|
||||
CHECK_LOAD(handle, mlx_array_data_uint8);
|
||||
CHECK_LOAD(handle, mlx_array_data_uint16);
|
||||
@@ -2313,8 +2313,8 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) {
|
||||
CHECK_LOAD(handle, mlx_array_data_float32);
|
||||
CHECK_LOAD(handle, mlx_array_data_float64);
|
||||
CHECK_LOAD(handle, mlx_array_data_complex64);
|
||||
CHECK_LOAD(handle, mlx_array_data_float16);
|
||||
CHECK_LOAD(handle, mlx_array_data_bfloat16);
|
||||
OPTIONAL_LOAD(handle, mlx_array_data_float16);
|
||||
OPTIONAL_LOAD(handle, mlx_array_data_bfloat16);
|
||||
CHECK_LOAD(handle, _mlx_array_is_available);
|
||||
CHECK_LOAD(handle, _mlx_array_wait);
|
||||
CHECK_LOAD(handle, _mlx_array_is_contiguous);
|
||||
|
||||
@@ -11,7 +11,7 @@
|
||||
|
||||
int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) {
|
||||
{{- range .Functions }}
|
||||
CHECK_LOAD(handle, {{ .Name }});
|
||||
{{ if .Optional }}OPTIONAL_LOAD{{ else }}CHECK_LOAD{{ end }}(handle, {{ .Name }});
|
||||
{{- end }}
|
||||
return 0;
|
||||
}
|
||||
|
||||
@@ -17,11 +17,21 @@ import (
|
||||
//go:embed *.gotmpl
|
||||
var fsys embed.FS
|
||||
|
||||
// optionalSymbols lists symbols that may not be present in all builds
|
||||
// (e.g., float16/bfloat16 are unavailable in CUDA builds of MLX).
|
||||
var optionalSymbols = map[string]bool{
|
||||
"mlx_array_item_float16": true,
|
||||
"mlx_array_item_bfloat16": true,
|
||||
"mlx_array_data_float16": true,
|
||||
"mlx_array_data_bfloat16": true,
|
||||
}
|
||||
|
||||
type Function struct {
|
||||
Type,
|
||||
Name,
|
||||
Parameters,
|
||||
Args string
|
||||
Optional bool
|
||||
}
|
||||
|
||||
func ParseFunction(node *tree_sitter.Node, tc *tree_sitter.TreeCursor, source []byte) Function {
|
||||
@@ -104,7 +114,9 @@ func main() {
|
||||
matches := qc.Matches(query, tree.RootNode(), bts)
|
||||
for match := matches.Next(); match != nil; match = matches.Next() {
|
||||
for _, capture := range match.Captures {
|
||||
funs = append(funs, ParseFunction(&capture.Node, tc, bts))
|
||||
fn := ParseFunction(&capture.Node, tc, bts)
|
||||
fn.Optional = optionalSymbols[fn.Name]
|
||||
funs = append(funs, fn)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
12
x/mlxrunner/mlx/include/mlx/c/README.md
Normal file
12
x/mlxrunner/mlx/include/mlx/c/README.md
Normal file
@@ -0,0 +1,12 @@
|
||||
# Vendored MLX-C Headers
|
||||
|
||||
These header files are vendored from [mlx-c](https://github.com/ml-explore/mlx-c).
|
||||
The pinned version is in `MLX_VERSION` at the repo root.
|
||||
|
||||
Headers are automatically refreshed when you run a CMake build:
|
||||
|
||||
```shell
|
||||
cmake --preset 'MLX CUDA 13'
|
||||
```
|
||||
|
||||
See the [MLX Engine](../../../../../../../docs/development.md#mlx-engine-optional) section of the development docs for full build instructions.
|
||||
420
x/mlxrunner/mlx/include/mlx/c/array.h
Normal file
420
x/mlxrunner/mlx/include/mlx/c/array.h
Normal file
@@ -0,0 +1,420 @@
|
||||
/* Copyright © 2023-2024 Apple Inc. */
|
||||
|
||||
#ifndef MLX_ARRAY_H
|
||||
#define MLX_ARRAY_H
|
||||
|
||||
#include "mlx/c/string.h"
|
||||
|
||||
#include <float.h>
|
||||
#include <stdbool.h>
|
||||
#include <stdint.h>
|
||||
#include <stdlib.h>
|
||||
|
||||
// Complex number support
|
||||
#ifdef _MSC_VER
|
||||
#define _CRT_USE_C_COMPLEX_H
|
||||
#include <complex.h>
|
||||
typedef _Fcomplex mlx_complex64_t;
|
||||
#else
|
||||
#include <complex.h>
|
||||
typedef float _Complex mlx_complex64_t;
|
||||
#endif
|
||||
|
||||
#include "half.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
/**
|
||||
* \defgroup mlx_array Array
|
||||
* MLX N-dimensional array object.
|
||||
*/
|
||||
/**@{*/
|
||||
|
||||
/**
|
||||
* A N-dimensional array object.
|
||||
*/
|
||||
typedef struct mlx_array_ {
|
||||
void* ctx;
|
||||
} mlx_array;
|
||||
|
||||
static mlx_array mlx_array_empty;
|
||||
|
||||
/**
|
||||
* Array element type.
|
||||
*/
|
||||
typedef enum mlx_dtype_ {
|
||||
MLX_BOOL,
|
||||
MLX_UINT8,
|
||||
MLX_UINT16,
|
||||
MLX_UINT32,
|
||||
MLX_UINT64,
|
||||
MLX_INT8,
|
||||
MLX_INT16,
|
||||
MLX_INT32,
|
||||
MLX_INT64,
|
||||
MLX_FLOAT16,
|
||||
MLX_FLOAT32,
|
||||
MLX_FLOAT64,
|
||||
MLX_BFLOAT16,
|
||||
MLX_COMPLEX64,
|
||||
} mlx_dtype;
|
||||
|
||||
/**
|
||||
* Size of given mlx_dtype datatype in bytes.
|
||||
*/
|
||||
size_t mlx_dtype_size(mlx_dtype dtype);
|
||||
|
||||
/**
|
||||
* Get array description.
|
||||
*/
|
||||
int mlx_array_tostring(mlx_string* str, const mlx_array arr);
|
||||
|
||||
/**
|
||||
* New empty array.
|
||||
*/
|
||||
mlx_array mlx_array_new(void);
|
||||
|
||||
/**
|
||||
* Free an array.
|
||||
*/
|
||||
int mlx_array_free(mlx_array arr);
|
||||
|
||||
/**
|
||||
* New array from a bool scalar.
|
||||
*/
|
||||
mlx_array mlx_array_new_bool(bool val);
|
||||
/**
|
||||
* New array from a int scalar.
|
||||
*/
|
||||
mlx_array mlx_array_new_int(int val);
|
||||
/**
|
||||
* New array from a float32 scalar.
|
||||
*/
|
||||
mlx_array mlx_array_new_float32(float val);
|
||||
/**
|
||||
* New array from a float scalar.
|
||||
* Same as float32.
|
||||
*/
|
||||
mlx_array mlx_array_new_float(float val);
|
||||
/**
|
||||
* New array from a float64 scalar.
|
||||
*/
|
||||
mlx_array mlx_array_new_float64(double val);
|
||||
/**
|
||||
* New array from a double scalar.
|
||||
* Same as float64.
|
||||
*/
|
||||
mlx_array mlx_array_new_double(double val);
|
||||
/**
|
||||
* New array from a complex scalar.
|
||||
*/
|
||||
mlx_array mlx_array_new_complex(float real_val, float imag_val);
|
||||
/**
|
||||
* New array from existing buffer.
|
||||
* @param data A buffer which will be copied.
|
||||
* @param shape Shape of the array.
|
||||
* @param dim Number of dimensions (size of `shape`).
|
||||
* @param dtype Type of array elements.
|
||||
*/
|
||||
mlx_array mlx_array_new_data(
|
||||
const void* data,
|
||||
const int* shape,
|
||||
int dim,
|
||||
mlx_dtype dtype);
|
||||
/**
|
||||
* New array from existing buffer.
|
||||
* @param data A buffer which will be copied.
|
||||
* @param shape Shape of the array.
|
||||
* @param dim Number of dimensions (size of `shape`).
|
||||
* @param dtype Type of array elements.
|
||||
* @param dtor Callback for when the buffer is no longer needed.
|
||||
*/
|
||||
mlx_array mlx_array_new_data_managed(
|
||||
void* data,
|
||||
const int* shape,
|
||||
int dim,
|
||||
mlx_dtype dtype,
|
||||
void (*dtor)(void*));
|
||||
/**
|
||||
* New array from existing buffer.
|
||||
* @param data A buffer which will be copied.
|
||||
* @param shape Shape of the array.
|
||||
* @param dim Number of dimensions (size of `shape`).
|
||||
* @param dtype Type of array elements.
|
||||
* @param payload Payload pointer passed to the `dtor` callback instead of
|
||||
* `data`.
|
||||
* @param dtor Callback for when the buffer is no longer needed.
|
||||
*/
|
||||
mlx_array mlx_array_new_data_managed_payload(
|
||||
void* data,
|
||||
const int* shape,
|
||||
int dim,
|
||||
mlx_dtype dtype,
|
||||
void* payload,
|
||||
void (*dtor)(void*));
|
||||
/**
|
||||
* Set array to provided src array.
|
||||
*/
|
||||
int mlx_array_set(mlx_array* arr, const mlx_array src);
|
||||
/**
|
||||
* Set array to a bool scalar.
|
||||
*/
|
||||
int mlx_array_set_bool(mlx_array* arr, bool val);
|
||||
/**
|
||||
* Set array to a int scalar.
|
||||
*/
|
||||
int mlx_array_set_int(mlx_array* arr, int val);
|
||||
/**
|
||||
* Set array to a float32 scalar.
|
||||
*/
|
||||
int mlx_array_set_float32(mlx_array* arr, float val);
|
||||
/**
|
||||
* Set array to a float scalar.
|
||||
*/
|
||||
int mlx_array_set_float(mlx_array* arr, float val);
|
||||
/**
|
||||
* Set array to a float64 scalar.
|
||||
*/
|
||||
int mlx_array_set_float64(mlx_array* arr, double val);
|
||||
/**
|
||||
* Set array to a double scalar.
|
||||
*/
|
||||
int mlx_array_set_double(mlx_array* arr, double val);
|
||||
/**
|
||||
* Set array to a complex scalar.
|
||||
*/
|
||||
int mlx_array_set_complex(mlx_array* arr, float real_val, float imag_val);
|
||||
/**
|
||||
* Set array to specified data and shape.
|
||||
* @param arr Destination array.
|
||||
* @param data A buffer which will be copied.
|
||||
* @param shape Shape of the array.
|
||||
* @param dim Number of dimensions (size of `shape`).
|
||||
* @param dtype Type of array elements.
|
||||
*/
|
||||
int mlx_array_set_data(
|
||||
mlx_array* arr,
|
||||
const void* data,
|
||||
const int* shape,
|
||||
int dim,
|
||||
mlx_dtype dtype);
|
||||
|
||||
/**
|
||||
* The size of the array's datatype in bytes.
|
||||
*/
|
||||
size_t mlx_array_itemsize(const mlx_array arr);
|
||||
/**
|
||||
* Number of elements in the array.
|
||||
*/
|
||||
size_t mlx_array_size(const mlx_array arr);
|
||||
/**
|
||||
* The number of bytes in the array.
|
||||
*/
|
||||
size_t mlx_array_nbytes(const mlx_array arr);
|
||||
/**
|
||||
* The array's dimension.
|
||||
*/
|
||||
size_t mlx_array_ndim(const mlx_array arr);
|
||||
/**
|
||||
* The shape of the array.
|
||||
* Returns: a pointer to the sizes of each dimension.
|
||||
*/
|
||||
const int* mlx_array_shape(const mlx_array arr);
|
||||
/**
|
||||
* The strides of the array.
|
||||
* Returns: a pointer to the sizes of each dimension.
|
||||
*/
|
||||
const size_t* mlx_array_strides(const mlx_array arr);
|
||||
/**
|
||||
* The shape of the array in a particular dimension.
|
||||
*/
|
||||
int mlx_array_dim(const mlx_array arr, int dim);
|
||||
/**
|
||||
* The array element type.
|
||||
*/
|
||||
mlx_dtype mlx_array_dtype(const mlx_array arr);
|
||||
|
||||
/**
|
||||
* Evaluate the array.
|
||||
*/
|
||||
int mlx_array_eval(mlx_array arr);
|
||||
|
||||
/**
|
||||
* Access the value of a scalar array.
|
||||
*/
|
||||
int mlx_array_item_bool(bool* res, const mlx_array arr);
|
||||
/**
|
||||
* Access the value of a scalar array.
|
||||
*/
|
||||
int mlx_array_item_uint8(uint8_t* res, const mlx_array arr);
|
||||
/**
|
||||
* Access the value of a scalar array.
|
||||
*/
|
||||
int mlx_array_item_uint16(uint16_t* res, const mlx_array arr);
|
||||
/**
|
||||
* Access the value of a scalar array.
|
||||
*/
|
||||
int mlx_array_item_uint32(uint32_t* res, const mlx_array arr);
|
||||
/**
|
||||
* Access the value of a scalar array.
|
||||
*/
|
||||
int mlx_array_item_uint64(uint64_t* res, const mlx_array arr);
|
||||
/**
|
||||
* Access the value of a scalar array.
|
||||
*/
|
||||
int mlx_array_item_int8(int8_t* res, const mlx_array arr);
|
||||
/**
|
||||
* Access the value of a scalar array.
|
||||
*/
|
||||
int mlx_array_item_int16(int16_t* res, const mlx_array arr);
|
||||
/**
|
||||
* Access the value of a scalar array.
|
||||
*/
|
||||
int mlx_array_item_int32(int32_t* res, const mlx_array arr);
|
||||
/**
|
||||
* Access the value of a scalar array.
|
||||
*/
|
||||
int mlx_array_item_int64(int64_t* res, const mlx_array arr);
|
||||
/**
|
||||
* Access the value of a scalar array.
|
||||
*/
|
||||
int mlx_array_item_float32(float* res, const mlx_array arr);
|
||||
/**
|
||||
* Access the value of a scalar array.
|
||||
*/
|
||||
int mlx_array_item_float64(double* res, const mlx_array arr);
|
||||
/**
|
||||
* Access the value of a scalar array.
|
||||
*/
|
||||
int mlx_array_item_complex64(mlx_complex64_t* res, const mlx_array arr);
|
||||
|
||||
#ifdef HAS_FLOAT16
|
||||
/**
|
||||
* Access the value of a scalar array.
|
||||
*/
|
||||
int mlx_array_item_float16(float16_t* res, const mlx_array arr);
|
||||
#endif
|
||||
|
||||
#ifdef HAS_BFLOAT16
|
||||
/**
|
||||
* Access the value of a scalar array.
|
||||
*/
|
||||
int mlx_array_item_bfloat16(bfloat16_t* res, const mlx_array arr);
|
||||
#endif
|
||||
|
||||
/**
|
||||
* Returns a pointer to the array data, cast to `bool*`.
|
||||
* Array must be evaluated, otherwise returns NULL.
|
||||
*/
|
||||
const bool* mlx_array_data_bool(const mlx_array arr);
|
||||
/**
|
||||
* Returns a pointer to the array data, cast to `uint8_t*`.
|
||||
* Array must be evaluated, otherwise returns NULL.
|
||||
*/
|
||||
const uint8_t* mlx_array_data_uint8(const mlx_array arr);
|
||||
/**
|
||||
* Returns a pointer to the array data, cast to `uint16_t*`.
|
||||
* Array must be evaluated, otherwise returns NULL.
|
||||
*/
|
||||
const uint16_t* mlx_array_data_uint16(const mlx_array arr);
|
||||
/**
|
||||
* Returns a pointer to the array data, cast to `uint32_t*`.
|
||||
* Array must be evaluated, otherwise returns NULL.
|
||||
*/
|
||||
const uint32_t* mlx_array_data_uint32(const mlx_array arr);
|
||||
/**
|
||||
* Returns a pointer to the array data, cast to `uint64_t*`.
|
||||
* Array must be evaluated, otherwise returns NULL.
|
||||
*/
|
||||
const uint64_t* mlx_array_data_uint64(const mlx_array arr);
|
||||
/**
|
||||
* Returns a pointer to the array data, cast to `int8_t*`.
|
||||
* Array must be evaluated, otherwise returns NULL.
|
||||
*/
|
||||
const int8_t* mlx_array_data_int8(const mlx_array arr);
|
||||
/**
|
||||
* Returns a pointer to the array data, cast to `int16_t*`.
|
||||
* Array must be evaluated, otherwise returns NULL.
|
||||
*/
|
||||
const int16_t* mlx_array_data_int16(const mlx_array arr);
|
||||
/**
|
||||
* Returns a pointer to the array data, cast to `int32_t*`.
|
||||
* Array must be evaluated, otherwise returns NULL.
|
||||
*/
|
||||
const int32_t* mlx_array_data_int32(const mlx_array arr);
|
||||
/**
|
||||
* Returns a pointer to the array data, cast to `int64_t*`.
|
||||
* Array must be evaluated, otherwise returns NULL.
|
||||
*/
|
||||
const int64_t* mlx_array_data_int64(const mlx_array arr);
|
||||
/**
|
||||
* Returns a pointer to the array data, cast to `float32*`.
|
||||
* Array must be evaluated, otherwise returns NULL.
|
||||
*/
|
||||
const float* mlx_array_data_float32(const mlx_array arr);
|
||||
/**
|
||||
* Returns a pointer to the array data, cast to `float64*`.
|
||||
* Array must be evaluated, otherwise returns NULL.
|
||||
*/
|
||||
const double* mlx_array_data_float64(const mlx_array arr);
|
||||
/**
|
||||
* Returns a pointer to the array data, cast to `_Complex*`.
|
||||
* Array must be evaluated, otherwise returns NULL.
|
||||
*/
|
||||
const mlx_complex64_t* mlx_array_data_complex64(const mlx_array arr);
|
||||
|
||||
#ifdef HAS_FLOAT16
|
||||
/**
|
||||
* Returns a pointer to the array data, cast to `float16_t*`.
|
||||
* Array must be evaluated, otherwise returns NULL.
|
||||
*/
|
||||
const float16_t* mlx_array_data_float16(const mlx_array arr);
|
||||
#endif
|
||||
|
||||
#ifdef HAS_BFLOAT16
|
||||
/**
|
||||
* Returns a pointer to the array data, cast to `bfloat16_t*`.
|
||||
* Array must be evaluated, otherwise returns NULL.
|
||||
*/
|
||||
const bfloat16_t* mlx_array_data_bfloat16(const mlx_array arr);
|
||||
#endif
|
||||
|
||||
/**
|
||||
* Check if the array is available.
|
||||
* Internal function: use at your own risk.
|
||||
*/
|
||||
int _mlx_array_is_available(bool* res, const mlx_array arr);
|
||||
|
||||
/**
|
||||
* Wait on the array to be available. After this `_mlx_array_is_available`
|
||||
* returns `true`. Internal function: use at your own risk.
|
||||
*/
|
||||
int _mlx_array_wait(const mlx_array arr);
|
||||
|
||||
/**
|
||||
* Whether the array is contiguous in memory.
|
||||
* Internal function: use at your own risk.
|
||||
*/
|
||||
int _mlx_array_is_contiguous(bool* res, const mlx_array arr);
|
||||
|
||||
/**
|
||||
* Whether the array's rows are contiguous in memory.
|
||||
* Internal function: use at your own risk.
|
||||
*/
|
||||
int _mlx_array_is_row_contiguous(bool* res, const mlx_array arr);
|
||||
|
||||
/**
|
||||
* Whether the array's columns are contiguous in memory.
|
||||
* Internal function: use at your own risk.
|
||||
*/
|
||||
int _mlx_array_is_col_contiguous(bool* res, const mlx_array arr);
|
||||
|
||||
/**@}*/
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif
|
||||
197
x/mlxrunner/mlx/include/mlx/c/closure.h
Normal file
197
x/mlxrunner/mlx/include/mlx/c/closure.h
Normal file
@@ -0,0 +1,197 @@
|
||||
/* Copyright © 2023-2024 Apple Inc. */
|
||||
/* */
|
||||
/* This file is auto-generated. Do not edit manually. */
|
||||
/* */
|
||||
|
||||
#ifndef MLX_CLOSURE_H
|
||||
#define MLX_CLOSURE_H
|
||||
|
||||
#include "mlx/c/array.h"
|
||||
#include "mlx/c/map.h"
|
||||
#include "mlx/c/optional.h"
|
||||
#include "mlx/c/stream.h"
|
||||
#include "mlx/c/vector.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
/**
|
||||
* \defgroup mlx_closure Closures
|
||||
* MLX closure objects.
|
||||
*/
|
||||
/**@{*/
|
||||
|
||||
typedef struct mlx_closure_ {
|
||||
void* ctx;
|
||||
} mlx_closure;
|
||||
mlx_closure mlx_closure_new(void);
|
||||
int mlx_closure_free(mlx_closure cls);
|
||||
mlx_closure mlx_closure_new_func(
|
||||
int (*fun)(mlx_vector_array*, const mlx_vector_array));
|
||||
mlx_closure mlx_closure_new_func_payload(
|
||||
int (*fun)(mlx_vector_array*, const mlx_vector_array, void*),
|
||||
void* payload,
|
||||
void (*dtor)(void*));
|
||||
int mlx_closure_set(mlx_closure* cls, const mlx_closure src);
|
||||
int mlx_closure_apply(
|
||||
mlx_vector_array* res,
|
||||
mlx_closure cls,
|
||||
const mlx_vector_array input);
|
||||
|
||||
mlx_closure mlx_closure_new_unary(int (*fun)(mlx_array*, const mlx_array));
|
||||
|
||||
typedef struct mlx_closure_kwargs_ {
|
||||
void* ctx;
|
||||
} mlx_closure_kwargs;
|
||||
mlx_closure_kwargs mlx_closure_kwargs_new(void);
|
||||
int mlx_closure_kwargs_free(mlx_closure_kwargs cls);
|
||||
mlx_closure_kwargs mlx_closure_kwargs_new_func(
|
||||
int (*fun)(
|
||||
mlx_vector_array*,
|
||||
const mlx_vector_array,
|
||||
const mlx_map_string_to_array));
|
||||
mlx_closure_kwargs mlx_closure_kwargs_new_func_payload(
|
||||
int (*fun)(
|
||||
mlx_vector_array*,
|
||||
const mlx_vector_array,
|
||||
const mlx_map_string_to_array,
|
||||
void*),
|
||||
void* payload,
|
||||
void (*dtor)(void*));
|
||||
int mlx_closure_kwargs_set(
|
||||
mlx_closure_kwargs* cls,
|
||||
const mlx_closure_kwargs src);
|
||||
int mlx_closure_kwargs_apply(
|
||||
mlx_vector_array* res,
|
||||
mlx_closure_kwargs cls,
|
||||
const mlx_vector_array input_0,
|
||||
const mlx_map_string_to_array input_1);
|
||||
|
||||
typedef struct mlx_closure_value_and_grad_ {
|
||||
void* ctx;
|
||||
} mlx_closure_value_and_grad;
|
||||
mlx_closure_value_and_grad mlx_closure_value_and_grad_new(void);
|
||||
int mlx_closure_value_and_grad_free(mlx_closure_value_and_grad cls);
|
||||
mlx_closure_value_and_grad mlx_closure_value_and_grad_new_func(
|
||||
int (*fun)(mlx_vector_array*, mlx_vector_array*, const mlx_vector_array));
|
||||
mlx_closure_value_and_grad mlx_closure_value_and_grad_new_func_payload(
|
||||
int (*fun)(
|
||||
mlx_vector_array*,
|
||||
mlx_vector_array*,
|
||||
const mlx_vector_array,
|
||||
void*),
|
||||
void* payload,
|
||||
void (*dtor)(void*));
|
||||
int mlx_closure_value_and_grad_set(
|
||||
mlx_closure_value_and_grad* cls,
|
||||
const mlx_closure_value_and_grad src);
|
||||
int mlx_closure_value_and_grad_apply(
|
||||
mlx_vector_array* res_0,
|
||||
mlx_vector_array* res_1,
|
||||
mlx_closure_value_and_grad cls,
|
||||
const mlx_vector_array input);
|
||||
|
||||
typedef struct mlx_closure_custom_ {
|
||||
void* ctx;
|
||||
} mlx_closure_custom;
|
||||
mlx_closure_custom mlx_closure_custom_new(void);
|
||||
int mlx_closure_custom_free(mlx_closure_custom cls);
|
||||
mlx_closure_custom mlx_closure_custom_new_func(
|
||||
int (*fun)(
|
||||
mlx_vector_array*,
|
||||
const mlx_vector_array,
|
||||
const mlx_vector_array,
|
||||
const mlx_vector_array));
|
||||
mlx_closure_custom mlx_closure_custom_new_func_payload(
|
||||
int (*fun)(
|
||||
mlx_vector_array*,
|
||||
const mlx_vector_array,
|
||||
const mlx_vector_array,
|
||||
const mlx_vector_array,
|
||||
void*),
|
||||
void* payload,
|
||||
void (*dtor)(void*));
|
||||
int mlx_closure_custom_set(
|
||||
mlx_closure_custom* cls,
|
||||
const mlx_closure_custom src);
|
||||
int mlx_closure_custom_apply(
|
||||
mlx_vector_array* res,
|
||||
mlx_closure_custom cls,
|
||||
const mlx_vector_array input_0,
|
||||
const mlx_vector_array input_1,
|
||||
const mlx_vector_array input_2);
|
||||
|
||||
typedef struct mlx_closure_custom_jvp_ {
|
||||
void* ctx;
|
||||
} mlx_closure_custom_jvp;
|
||||
mlx_closure_custom_jvp mlx_closure_custom_jvp_new(void);
|
||||
int mlx_closure_custom_jvp_free(mlx_closure_custom_jvp cls);
|
||||
mlx_closure_custom_jvp mlx_closure_custom_jvp_new_func(
|
||||
int (*fun)(
|
||||
mlx_vector_array*,
|
||||
const mlx_vector_array,
|
||||
const mlx_vector_array,
|
||||
const int*,
|
||||
size_t _num));
|
||||
mlx_closure_custom_jvp mlx_closure_custom_jvp_new_func_payload(
|
||||
int (*fun)(
|
||||
mlx_vector_array*,
|
||||
const mlx_vector_array,
|
||||
const mlx_vector_array,
|
||||
const int*,
|
||||
size_t _num,
|
||||
void*),
|
||||
void* payload,
|
||||
void (*dtor)(void*));
|
||||
int mlx_closure_custom_jvp_set(
|
||||
mlx_closure_custom_jvp* cls,
|
||||
const mlx_closure_custom_jvp src);
|
||||
int mlx_closure_custom_jvp_apply(
|
||||
mlx_vector_array* res,
|
||||
mlx_closure_custom_jvp cls,
|
||||
const mlx_vector_array input_0,
|
||||
const mlx_vector_array input_1,
|
||||
const int* input_2,
|
||||
size_t input_2_num);
|
||||
|
||||
typedef struct mlx_closure_custom_vmap_ {
|
||||
void* ctx;
|
||||
} mlx_closure_custom_vmap;
|
||||
mlx_closure_custom_vmap mlx_closure_custom_vmap_new(void);
|
||||
int mlx_closure_custom_vmap_free(mlx_closure_custom_vmap cls);
|
||||
mlx_closure_custom_vmap mlx_closure_custom_vmap_new_func(
|
||||
int (*fun)(
|
||||
mlx_vector_array*,
|
||||
mlx_vector_int*,
|
||||
const mlx_vector_array,
|
||||
const int*,
|
||||
size_t _num));
|
||||
mlx_closure_custom_vmap mlx_closure_custom_vmap_new_func_payload(
|
||||
int (*fun)(
|
||||
mlx_vector_array*,
|
||||
mlx_vector_int*,
|
||||
const mlx_vector_array,
|
||||
const int*,
|
||||
size_t _num,
|
||||
void*),
|
||||
void* payload,
|
||||
void (*dtor)(void*));
|
||||
int mlx_closure_custom_vmap_set(
|
||||
mlx_closure_custom_vmap* cls,
|
||||
const mlx_closure_custom_vmap src);
|
||||
int mlx_closure_custom_vmap_apply(
|
||||
mlx_vector_array* res_0,
|
||||
mlx_vector_int* res_1,
|
||||
mlx_closure_custom_vmap cls,
|
||||
const mlx_vector_array input_0,
|
||||
const int* input_1,
|
||||
size_t input_1_num);
|
||||
|
||||
/**@}*/
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif
|
||||
57
x/mlxrunner/mlx/include/mlx/c/compile.h
Normal file
57
x/mlxrunner/mlx/include/mlx/c/compile.h
Normal file
@@ -0,0 +1,57 @@
|
||||
/* Copyright © 2023-2024 Apple Inc. */
|
||||
/* */
|
||||
/* This file is auto-generated. Do not edit manually. */
|
||||
/* */
|
||||
|
||||
#ifndef MLX_COMPILE_H
|
||||
#define MLX_COMPILE_H
|
||||
|
||||
#include <stdbool.h>
|
||||
#include <stdint.h>
|
||||
#include <stdio.h>
|
||||
|
||||
#include "mlx/c/array.h"
|
||||
#include "mlx/c/closure.h"
|
||||
#include "mlx/c/distributed_group.h"
|
||||
#include "mlx/c/io_types.h"
|
||||
#include "mlx/c/map.h"
|
||||
#include "mlx/c/stream.h"
|
||||
#include "mlx/c/string.h"
|
||||
#include "mlx/c/vector.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
/**
|
||||
* \defgroup compile Compilation operations
|
||||
*/
|
||||
/**@{*/
|
||||
|
||||
typedef enum mlx_compile_mode_ {
|
||||
MLX_COMPILE_MODE_DISABLED,
|
||||
MLX_COMPILE_MODE_NO_SIMPLIFY,
|
||||
MLX_COMPILE_MODE_NO_FUSE,
|
||||
MLX_COMPILE_MODE_ENABLED
|
||||
} mlx_compile_mode;
|
||||
int mlx_compile(mlx_closure* res, const mlx_closure fun, bool shapeless);
|
||||
int mlx_detail_compile(
|
||||
mlx_closure* res,
|
||||
const mlx_closure fun,
|
||||
uintptr_t fun_id,
|
||||
bool shapeless,
|
||||
const uint64_t* constants,
|
||||
size_t constants_num);
|
||||
int mlx_detail_compile_clear_cache(void);
|
||||
int mlx_detail_compile_erase(uintptr_t fun_id);
|
||||
int mlx_disable_compile(void);
|
||||
int mlx_enable_compile(void);
|
||||
int mlx_set_compile_mode(mlx_compile_mode mode);
|
||||
|
||||
/**@}*/
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif
|
||||
39
x/mlxrunner/mlx/include/mlx/c/cuda.h
Normal file
39
x/mlxrunner/mlx/include/mlx/c/cuda.h
Normal file
@@ -0,0 +1,39 @@
|
||||
/* Copyright © 2023-2024 Apple Inc. */
|
||||
/* */
|
||||
/* This file is auto-generated. Do not edit manually. */
|
||||
/* */
|
||||
|
||||
#ifndef MLX_CUDA_H
|
||||
#define MLX_CUDA_H
|
||||
|
||||
#include <stdbool.h>
|
||||
#include <stdint.h>
|
||||
#include <stdio.h>
|
||||
|
||||
#include "mlx/c/array.h"
|
||||
#include "mlx/c/closure.h"
|
||||
#include "mlx/c/distributed_group.h"
|
||||
#include "mlx/c/io_types.h"
|
||||
#include "mlx/c/map.h"
|
||||
#include "mlx/c/stream.h"
|
||||
#include "mlx/c/string.h"
|
||||
#include "mlx/c/vector.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
/**
|
||||
* \defgroup cuda Cuda specific operations
|
||||
*/
|
||||
/**@{*/
|
||||
|
||||
int mlx_cuda_is_available(bool* res);
|
||||
|
||||
/**@}*/
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif
|
||||
154
x/mlxrunner/mlx/include/mlx/c/device.h
Normal file
154
x/mlxrunner/mlx/include/mlx/c/device.h
Normal file
@@ -0,0 +1,154 @@
|
||||
/* Copyright © 2023-2024 Apple Inc. */
|
||||
|
||||
#ifndef MLX_DEVICE_H
|
||||
#define MLX_DEVICE_H
|
||||
|
||||
#include <stdbool.h>
|
||||
#include <stddef.h>
|
||||
|
||||
#include "mlx/c/string.h"
|
||||
#include "mlx/c/vector.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
/**
|
||||
* \defgroup mlx_device Device
|
||||
* MLX device object.
|
||||
*/
|
||||
/**@{*/
|
||||
|
||||
/**
|
||||
* A MLX device object.
|
||||
*/
|
||||
typedef struct mlx_device_ {
|
||||
void* ctx;
|
||||
} mlx_device;
|
||||
|
||||
/**
|
||||
* Device type.
|
||||
*/
|
||||
typedef enum mlx_device_type_ { MLX_CPU, MLX_GPU } mlx_device_type;
|
||||
|
||||
/**
|
||||
* Returns a new empty device.
|
||||
*/
|
||||
mlx_device mlx_device_new(void);
|
||||
|
||||
/**
|
||||
* Returns a new device of specified `type`, with specified `index`.
|
||||
*/
|
||||
mlx_device mlx_device_new_type(mlx_device_type type, int index);
|
||||
/**
|
||||
* Free a device.
|
||||
*/
|
||||
int mlx_device_free(mlx_device dev);
|
||||
/**
|
||||
* Set device to provided src device.
|
||||
*/
|
||||
int mlx_device_set(mlx_device* dev, const mlx_device src);
|
||||
/**
|
||||
* Get device description.
|
||||
*/
|
||||
int mlx_device_tostring(mlx_string* str, mlx_device dev);
|
||||
/**
|
||||
* Check if devices are the same.
|
||||
*/
|
||||
bool mlx_device_equal(mlx_device lhs, mlx_device rhs);
|
||||
/**
|
||||
* Returns the index of the device.
|
||||
*/
|
||||
int mlx_device_get_index(int* index, mlx_device dev);
|
||||
/**
|
||||
* Returns the type of the device.
|
||||
*/
|
||||
int mlx_device_get_type(mlx_device_type* type, mlx_device dev);
|
||||
/**
|
||||
* Returns the default MLX device.
|
||||
*/
|
||||
int mlx_get_default_device(mlx_device* dev);
|
||||
/**
|
||||
* Set the default MLX device.
|
||||
*/
|
||||
int mlx_set_default_device(mlx_device dev);
|
||||
/**
|
||||
* Check if device is available.
|
||||
*/
|
||||
int mlx_device_is_available(bool* avail, mlx_device dev);
|
||||
/**
|
||||
* Get the number of available devices for a device type.
|
||||
*/
|
||||
int mlx_device_count(int* count, mlx_device_type type);
|
||||
|
||||
/**
|
||||
* A MLX device info object.
|
||||
* Contains key-value pairs with device properties.
|
||||
* Keys vary by backend but common keys include:
|
||||
* - device_name (string): Device name
|
||||
* - architecture (string): Architecture identifier
|
||||
* Additional keys may be present depending on the backend.
|
||||
*/
|
||||
typedef struct mlx_device_info_ {
|
||||
void* ctx;
|
||||
} mlx_device_info;
|
||||
|
||||
/**
|
||||
* Returns a new empty device info object.
|
||||
*/
|
||||
mlx_device_info mlx_device_info_new(void);
|
||||
/**
|
||||
* Get device information for a device.
|
||||
*/
|
||||
int mlx_device_info_get(mlx_device_info* info, mlx_device dev);
|
||||
/**
|
||||
* Free a device info object.
|
||||
*/
|
||||
int mlx_device_info_free(mlx_device_info info);
|
||||
/**
|
||||
* Check if a key exists in the device info.
|
||||
* Returns 0 on success, 1 on error.
|
||||
* Sets *exists to true if the key exists, false otherwise.
|
||||
*/
|
||||
int mlx_device_info_has_key(
|
||||
bool* exists,
|
||||
mlx_device_info info,
|
||||
const char* key);
|
||||
/**
|
||||
* Check if a value is a string type.
|
||||
* Returns 0 on success, 1 on error.
|
||||
* Sets *is_string to true if the value is a string, false if it's a size_t.
|
||||
*/
|
||||
int mlx_device_info_is_string(
|
||||
bool* is_string,
|
||||
mlx_device_info info,
|
||||
const char* key);
|
||||
/**
|
||||
* Get a string value from device info.
|
||||
* Returns 0 on success, 1 on error, 2 if key not found or wrong type.
|
||||
*/
|
||||
int mlx_device_info_get_string(
|
||||
const char** value,
|
||||
mlx_device_info info,
|
||||
const char* key);
|
||||
/**
|
||||
* Get a size_t value from device info.
|
||||
* Returns 0 on success, 1 on error, 2 if key not found or wrong type.
|
||||
*/
|
||||
int mlx_device_info_get_size(
|
||||
size_t* value,
|
||||
mlx_device_info info,
|
||||
const char* key);
|
||||
/**
|
||||
* Get all keys from device info.
|
||||
* Returns 0 on success, 1 on error.
|
||||
*/
|
||||
int mlx_device_info_get_keys(mlx_vector_string* keys, mlx_device_info info);
|
||||
|
||||
/**@}*/
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif
|
||||
83
x/mlxrunner/mlx/include/mlx/c/distributed.h
Normal file
83
x/mlxrunner/mlx/include/mlx/c/distributed.h
Normal file
@@ -0,0 +1,83 @@
|
||||
/* Copyright © 2023-2024 Apple Inc. */
|
||||
/* */
|
||||
/* This file is auto-generated. Do not edit manually. */
|
||||
/* */
|
||||
|
||||
#ifndef MLX_DISTRIBUTED_H
|
||||
#define MLX_DISTRIBUTED_H
|
||||
|
||||
#include <stdbool.h>
|
||||
#include <stdint.h>
|
||||
#include <stdio.h>
|
||||
|
||||
#include "mlx/c/array.h"
|
||||
#include "mlx/c/closure.h"
|
||||
#include "mlx/c/distributed_group.h"
|
||||
#include "mlx/c/io_types.h"
|
||||
#include "mlx/c/map.h"
|
||||
#include "mlx/c/stream.h"
|
||||
#include "mlx/c/string.h"
|
||||
#include "mlx/c/vector.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
/**
|
||||
* \defgroup distributed Distributed collectives
|
||||
*/
|
||||
/**@{*/
|
||||
|
||||
int mlx_distributed_all_gather(
|
||||
mlx_array* res,
|
||||
const mlx_array x,
|
||||
const mlx_distributed_group group /* may be null */,
|
||||
const mlx_stream S);
|
||||
int mlx_distributed_all_max(
|
||||
mlx_array* res,
|
||||
const mlx_array x,
|
||||
const mlx_distributed_group group /* may be null */,
|
||||
const mlx_stream s);
|
||||
int mlx_distributed_all_min(
|
||||
mlx_array* res,
|
||||
const mlx_array x,
|
||||
const mlx_distributed_group group /* may be null */,
|
||||
const mlx_stream s);
|
||||
int mlx_distributed_all_sum(
|
||||
mlx_array* res,
|
||||
const mlx_array x,
|
||||
const mlx_distributed_group group /* may be null */,
|
||||
const mlx_stream s);
|
||||
int mlx_distributed_recv(
|
||||
mlx_array* res,
|
||||
const int* shape,
|
||||
size_t shape_num,
|
||||
mlx_dtype dtype,
|
||||
int src,
|
||||
const mlx_distributed_group group /* may be null */,
|
||||
const mlx_stream s);
|
||||
int mlx_distributed_recv_like(
|
||||
mlx_array* res,
|
||||
const mlx_array x,
|
||||
int src,
|
||||
const mlx_distributed_group group /* may be null */,
|
||||
const mlx_stream s);
|
||||
int mlx_distributed_send(
|
||||
mlx_array* res,
|
||||
const mlx_array x,
|
||||
int dst,
|
||||
const mlx_distributed_group group /* may be null */,
|
||||
const mlx_stream s);
|
||||
int mlx_distributed_sum_scatter(
|
||||
mlx_array* res,
|
||||
const mlx_array x,
|
||||
const mlx_distributed_group group /* may be null */,
|
||||
const mlx_stream s);
|
||||
|
||||
/**@}*/
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif
|
||||
58
x/mlxrunner/mlx/include/mlx/c/distributed_group.h
Normal file
58
x/mlxrunner/mlx/include/mlx/c/distributed_group.h
Normal file
@@ -0,0 +1,58 @@
|
||||
/* Copyright © 2023-2024 Apple Inc. */
|
||||
|
||||
#ifndef MLX_DISTRIBUTED_GROUP_H
|
||||
#define MLX_DISTRIBUTED_GROUP_H
|
||||
|
||||
#include <stdbool.h>
|
||||
|
||||
#include "mlx/c/stream.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
/**
|
||||
* \defgroup mlx_distributed_group MLX distributed
|
||||
*/
|
||||
/**@{*/
|
||||
|
||||
/**
|
||||
* A MLX distributed group object.
|
||||
*/
|
||||
typedef struct mlx_distributed_group_ {
|
||||
void* ctx;
|
||||
} mlx_distributed_group;
|
||||
|
||||
/**
|
||||
* Get the rank.
|
||||
*/
|
||||
int mlx_distributed_group_rank(mlx_distributed_group group);
|
||||
|
||||
/**
|
||||
* Get the group size.
|
||||
*/
|
||||
int mlx_distributed_group_size(mlx_distributed_group group);
|
||||
|
||||
/**
|
||||
* Split the group.
|
||||
*/
|
||||
mlx_distributed_group
|
||||
mlx_distributed_group_split(mlx_distributed_group group, int color, int key);
|
||||
|
||||
/**
|
||||
* Check if distributed is available.
|
||||
*/
|
||||
bool mlx_distributed_is_available(void);
|
||||
|
||||
/**
|
||||
* Initialize distributed.
|
||||
*/
|
||||
mlx_distributed_group mlx_distributed_init(bool strict);
|
||||
|
||||
/**@}*/
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif
|
||||
41
x/mlxrunner/mlx/include/mlx/c/error.h
Normal file
41
x/mlxrunner/mlx/include/mlx/c/error.h
Normal file
@@ -0,0 +1,41 @@
|
||||
/* Copyright © 2023-2024 Apple Inc. */
|
||||
|
||||
#ifndef MLX_ERROR_H
|
||||
#define MLX_ERROR_H
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
/**
|
||||
* \defgroup mlx_error Error management
|
||||
*/
|
||||
/**@{*/
|
||||
|
||||
typedef void (*mlx_error_handler_func)(const char* msg, void* data);
|
||||
|
||||
/**
|
||||
* Set the error handler.
|
||||
*/
|
||||
void mlx_set_error_handler(
|
||||
mlx_error_handler_func handler,
|
||||
void* data,
|
||||
void (*dtor)(void*));
|
||||
|
||||
/**
|
||||
* Throw an error.
|
||||
*/
|
||||
void _mlx_error(const char* file, const int line, const char* fmt, ...);
|
||||
|
||||
/**
|
||||
* Throw an error. Macro which passes file name and line number to _mlx_error().
|
||||
*/
|
||||
#define mlx_error(...) _mlx_error(__FILE__, __LINE__, __VA_ARGS__)
|
||||
|
||||
/**@}*/
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif
|
||||
75
x/mlxrunner/mlx/include/mlx/c/export.h
Normal file
75
x/mlxrunner/mlx/include/mlx/c/export.h
Normal file
@@ -0,0 +1,75 @@
|
||||
/* Copyright © 2023-2025 Apple Inc. */
|
||||
|
||||
#ifndef MLX_EXPORT_H
|
||||
#define MLX_EXPORT_H
|
||||
|
||||
#include <stdbool.h>
|
||||
#include <stdint.h>
|
||||
#include <stdio.h>
|
||||
|
||||
#include "mlx/c/array.h"
|
||||
#include "mlx/c/closure.h"
|
||||
#include "mlx/c/distributed_group.h"
|
||||
#include "mlx/c/io_types.h"
|
||||
#include "mlx/c/map.h"
|
||||
#include "mlx/c/stream.h"
|
||||
#include "mlx/c/string.h"
|
||||
#include "mlx/c/vector.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
/**
|
||||
* \defgroup export Function serialization
|
||||
*/
|
||||
/**@{*/
|
||||
int mlx_export_function(
|
||||
const char* file,
|
||||
const mlx_closure fun,
|
||||
const mlx_vector_array args,
|
||||
bool shapeless);
|
||||
int mlx_export_function_kwargs(
|
||||
const char* file,
|
||||
const mlx_closure_kwargs fun,
|
||||
const mlx_vector_array args,
|
||||
const mlx_map_string_to_array kwargs,
|
||||
bool shapeless);
|
||||
|
||||
typedef struct mlx_function_exporter_ {
|
||||
void* ctx;
|
||||
} mlx_function_exporter;
|
||||
mlx_function_exporter mlx_function_exporter_new(
|
||||
const char* file,
|
||||
const mlx_closure fun,
|
||||
bool shapeless);
|
||||
int mlx_function_exporter_free(mlx_function_exporter xfunc);
|
||||
int mlx_function_exporter_apply(
|
||||
const mlx_function_exporter xfunc,
|
||||
const mlx_vector_array args);
|
||||
int mlx_function_exporter_apply_kwargs(
|
||||
const mlx_function_exporter xfunc,
|
||||
const mlx_vector_array args,
|
||||
const mlx_map_string_to_array kwargs);
|
||||
|
||||
typedef struct mlx_imported_function_ {
|
||||
void* ctx;
|
||||
} mlx_imported_function;
|
||||
mlx_imported_function mlx_imported_function_new(const char* file);
|
||||
int mlx_imported_function_free(mlx_imported_function xfunc);
|
||||
int mlx_imported_function_apply(
|
||||
mlx_vector_array* res,
|
||||
const mlx_imported_function xfunc,
|
||||
const mlx_vector_array args);
|
||||
int mlx_imported_function_apply_kwargs(
|
||||
mlx_vector_array* res,
|
||||
const mlx_imported_function xfunc,
|
||||
const mlx_vector_array args,
|
||||
const mlx_map_string_to_array kwargs);
|
||||
/**@}*/
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif
|
||||
206
x/mlxrunner/mlx/include/mlx/c/fast.h
Normal file
206
x/mlxrunner/mlx/include/mlx/c/fast.h
Normal file
@@ -0,0 +1,206 @@
|
||||
/* Copyright © 2023-2024 Apple Inc. */
|
||||
/* */
|
||||
/* This file is auto-generated. Do not edit manually. */
|
||||
/* */
|
||||
|
||||
#ifndef MLX_FAST_H
|
||||
#define MLX_FAST_H
|
||||
|
||||
#include <stdbool.h>
|
||||
#include <stdint.h>
|
||||
#include <stdio.h>
|
||||
|
||||
#include "mlx/c/array.h"
|
||||
#include "mlx/c/closure.h"
|
||||
#include "mlx/c/distributed_group.h"
|
||||
#include "mlx/c/io_types.h"
|
||||
#include "mlx/c/map.h"
|
||||
#include "mlx/c/stream.h"
|
||||
#include "mlx/c/string.h"
|
||||
#include "mlx/c/vector.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
/**
|
||||
* \defgroup fast Fast custom operations
|
||||
*/
|
||||
/**@{*/
|
||||
|
||||
typedef struct mlx_fast_cuda_kernel_config_ {
|
||||
void* ctx;
|
||||
} mlx_fast_cuda_kernel_config;
|
||||
mlx_fast_cuda_kernel_config mlx_fast_cuda_kernel_config_new(void);
|
||||
void mlx_fast_cuda_kernel_config_free(mlx_fast_cuda_kernel_config cls);
|
||||
|
||||
int mlx_fast_cuda_kernel_config_add_output_arg(
|
||||
mlx_fast_cuda_kernel_config cls,
|
||||
const int* shape,
|
||||
size_t size,
|
||||
mlx_dtype dtype);
|
||||
int mlx_fast_cuda_kernel_config_set_grid(
|
||||
mlx_fast_cuda_kernel_config cls,
|
||||
int grid1,
|
||||
int grid2,
|
||||
int grid3);
|
||||
int mlx_fast_cuda_kernel_config_set_thread_group(
|
||||
mlx_fast_cuda_kernel_config cls,
|
||||
int thread1,
|
||||
int thread2,
|
||||
int thread3);
|
||||
int mlx_fast_cuda_kernel_config_set_init_value(
|
||||
mlx_fast_cuda_kernel_config cls,
|
||||
float value);
|
||||
int mlx_fast_cuda_kernel_config_set_verbose(
|
||||
mlx_fast_cuda_kernel_config cls,
|
||||
bool verbose);
|
||||
int mlx_fast_cuda_kernel_config_add_template_arg_dtype(
|
||||
mlx_fast_cuda_kernel_config cls,
|
||||
const char* name,
|
||||
mlx_dtype dtype);
|
||||
int mlx_fast_cuda_kernel_config_add_template_arg_int(
|
||||
mlx_fast_cuda_kernel_config cls,
|
||||
const char* name,
|
||||
int value);
|
||||
int mlx_fast_cuda_kernel_config_add_template_arg_bool(
|
||||
mlx_fast_cuda_kernel_config cls,
|
||||
const char* name,
|
||||
bool value);
|
||||
|
||||
typedef struct mlx_fast_cuda_kernel_ {
|
||||
void* ctx;
|
||||
} mlx_fast_cuda_kernel;
|
||||
|
||||
mlx_fast_cuda_kernel mlx_fast_cuda_kernel_new(
|
||||
const char* name,
|
||||
const mlx_vector_string input_names,
|
||||
const mlx_vector_string output_names,
|
||||
const char* source,
|
||||
const char* header,
|
||||
bool ensure_row_contiguous,
|
||||
int shared_memory);
|
||||
|
||||
void mlx_fast_cuda_kernel_free(mlx_fast_cuda_kernel cls);
|
||||
|
||||
int mlx_fast_cuda_kernel_apply(
|
||||
mlx_vector_array* outputs,
|
||||
mlx_fast_cuda_kernel cls,
|
||||
const mlx_vector_array inputs,
|
||||
const mlx_fast_cuda_kernel_config config,
|
||||
const mlx_stream stream);
|
||||
|
||||
int mlx_fast_layer_norm(
|
||||
mlx_array* res,
|
||||
const mlx_array x,
|
||||
const mlx_array weight /* may be null */,
|
||||
const mlx_array bias /* may be null */,
|
||||
float eps,
|
||||
const mlx_stream s);
|
||||
|
||||
typedef struct mlx_fast_metal_kernel_config_ {
|
||||
void* ctx;
|
||||
} mlx_fast_metal_kernel_config;
|
||||
mlx_fast_metal_kernel_config mlx_fast_metal_kernel_config_new(void);
|
||||
void mlx_fast_metal_kernel_config_free(mlx_fast_metal_kernel_config cls);
|
||||
|
||||
int mlx_fast_metal_kernel_config_add_output_arg(
|
||||
mlx_fast_metal_kernel_config cls,
|
||||
const int* shape,
|
||||
size_t size,
|
||||
mlx_dtype dtype);
|
||||
int mlx_fast_metal_kernel_config_set_grid(
|
||||
mlx_fast_metal_kernel_config cls,
|
||||
int grid1,
|
||||
int grid2,
|
||||
int grid3);
|
||||
int mlx_fast_metal_kernel_config_set_thread_group(
|
||||
mlx_fast_metal_kernel_config cls,
|
||||
int thread1,
|
||||
int thread2,
|
||||
int thread3);
|
||||
int mlx_fast_metal_kernel_config_set_init_value(
|
||||
mlx_fast_metal_kernel_config cls,
|
||||
float value);
|
||||
int mlx_fast_metal_kernel_config_set_verbose(
|
||||
mlx_fast_metal_kernel_config cls,
|
||||
bool verbose);
|
||||
int mlx_fast_metal_kernel_config_add_template_arg_dtype(
|
||||
mlx_fast_metal_kernel_config cls,
|
||||
const char* name,
|
||||
mlx_dtype dtype);
|
||||
int mlx_fast_metal_kernel_config_add_template_arg_int(
|
||||
mlx_fast_metal_kernel_config cls,
|
||||
const char* name,
|
||||
int value);
|
||||
int mlx_fast_metal_kernel_config_add_template_arg_bool(
|
||||
mlx_fast_metal_kernel_config cls,
|
||||
const char* name,
|
||||
bool value);
|
||||
|
||||
typedef struct mlx_fast_metal_kernel_ {
|
||||
void* ctx;
|
||||
} mlx_fast_metal_kernel;
|
||||
|
||||
mlx_fast_metal_kernel mlx_fast_metal_kernel_new(
|
||||
const char* name,
|
||||
const mlx_vector_string input_names,
|
||||
const mlx_vector_string output_names,
|
||||
const char* source,
|
||||
const char* header,
|
||||
bool ensure_row_contiguous,
|
||||
bool atomic_outputs);
|
||||
|
||||
void mlx_fast_metal_kernel_free(mlx_fast_metal_kernel cls);
|
||||
|
||||
int mlx_fast_metal_kernel_apply(
|
||||
mlx_vector_array* outputs,
|
||||
mlx_fast_metal_kernel cls,
|
||||
const mlx_vector_array inputs,
|
||||
const mlx_fast_metal_kernel_config config,
|
||||
const mlx_stream stream);
|
||||
|
||||
int mlx_fast_rms_norm(
|
||||
mlx_array* res,
|
||||
const mlx_array x,
|
||||
const mlx_array weight /* may be null */,
|
||||
float eps,
|
||||
const mlx_stream s);
|
||||
int mlx_fast_rope(
|
||||
mlx_array* res,
|
||||
const mlx_array x,
|
||||
int dims,
|
||||
bool traditional,
|
||||
mlx_optional_float base,
|
||||
float scale,
|
||||
int offset,
|
||||
const mlx_array freqs /* may be null */,
|
||||
const mlx_stream s);
|
||||
int mlx_fast_rope_dynamic(
|
||||
mlx_array* res,
|
||||
const mlx_array x,
|
||||
int dims,
|
||||
bool traditional,
|
||||
mlx_optional_float base,
|
||||
float scale,
|
||||
const mlx_array offset,
|
||||
const mlx_array freqs /* may be null */,
|
||||
const mlx_stream s);
|
||||
int mlx_fast_scaled_dot_product_attention(
|
||||
mlx_array* res,
|
||||
const mlx_array queries,
|
||||
const mlx_array keys,
|
||||
const mlx_array values,
|
||||
float scale,
|
||||
const char* mask_mode,
|
||||
const mlx_array mask_arr /* may be null */,
|
||||
const mlx_array sinks /* may be null */,
|
||||
const mlx_stream s);
|
||||
|
||||
/**@}*/
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif
|
||||
138
x/mlxrunner/mlx/include/mlx/c/fft.h
Normal file
138
x/mlxrunner/mlx/include/mlx/c/fft.h
Normal file
@@ -0,0 +1,138 @@
|
||||
/* Copyright © 2023-2024 Apple Inc. */
|
||||
/* */
|
||||
/* This file is auto-generated. Do not edit manually. */
|
||||
/* */
|
||||
|
||||
#ifndef MLX_FFT_H
|
||||
#define MLX_FFT_H
|
||||
|
||||
#include <stdbool.h>
|
||||
#include <stdint.h>
|
||||
#include <stdio.h>
|
||||
|
||||
#include "mlx/c/array.h"
|
||||
#include "mlx/c/closure.h"
|
||||
#include "mlx/c/distributed_group.h"
|
||||
#include "mlx/c/io_types.h"
|
||||
#include "mlx/c/map.h"
|
||||
#include "mlx/c/stream.h"
|
||||
#include "mlx/c/string.h"
|
||||
#include "mlx/c/vector.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
/**
|
||||
* \defgroup fft FFT operations
|
||||
*/
|
||||
/**@{*/
|
||||
|
||||
int mlx_fft_fft(
|
||||
mlx_array* res,
|
||||
const mlx_array a,
|
||||
int n,
|
||||
int axis,
|
||||
const mlx_stream s);
|
||||
int mlx_fft_fft2(
|
||||
mlx_array* res,
|
||||
const mlx_array a,
|
||||
const int* n,
|
||||
size_t n_num,
|
||||
const int* axes,
|
||||
size_t axes_num,
|
||||
const mlx_stream s);
|
||||
int mlx_fft_fftn(
|
||||
mlx_array* res,
|
||||
const mlx_array a,
|
||||
const int* n,
|
||||
size_t n_num,
|
||||
const int* axes,
|
||||
size_t axes_num,
|
||||
const mlx_stream s);
|
||||
int mlx_fft_fftshift(
|
||||
mlx_array* res,
|
||||
const mlx_array a,
|
||||
const int* axes,
|
||||
size_t axes_num,
|
||||
const mlx_stream s);
|
||||
int mlx_fft_ifft(
|
||||
mlx_array* res,
|
||||
const mlx_array a,
|
||||
int n,
|
||||
int axis,
|
||||
const mlx_stream s);
|
||||
int mlx_fft_ifft2(
|
||||
mlx_array* res,
|
||||
const mlx_array a,
|
||||
const int* n,
|
||||
size_t n_num,
|
||||
const int* axes,
|
||||
size_t axes_num,
|
||||
const mlx_stream s);
|
||||
int mlx_fft_ifftn(
|
||||
mlx_array* res,
|
||||
const mlx_array a,
|
||||
const int* n,
|
||||
size_t n_num,
|
||||
const int* axes,
|
||||
size_t axes_num,
|
||||
const mlx_stream s);
|
||||
int mlx_fft_ifftshift(
|
||||
mlx_array* res,
|
||||
const mlx_array a,
|
||||
const int* axes,
|
||||
size_t axes_num,
|
||||
const mlx_stream s);
|
||||
int mlx_fft_irfft(
|
||||
mlx_array* res,
|
||||
const mlx_array a,
|
||||
int n,
|
||||
int axis,
|
||||
const mlx_stream s);
|
||||
int mlx_fft_irfft2(
|
||||
mlx_array* res,
|
||||
const mlx_array a,
|
||||
const int* n,
|
||||
size_t n_num,
|
||||
const int* axes,
|
||||
size_t axes_num,
|
||||
const mlx_stream s);
|
||||
int mlx_fft_irfftn(
|
||||
mlx_array* res,
|
||||
const mlx_array a,
|
||||
const int* n,
|
||||
size_t n_num,
|
||||
const int* axes,
|
||||
size_t axes_num,
|
||||
const mlx_stream s);
|
||||
int mlx_fft_rfft(
|
||||
mlx_array* res,
|
||||
const mlx_array a,
|
||||
int n,
|
||||
int axis,
|
||||
const mlx_stream s);
|
||||
int mlx_fft_rfft2(
|
||||
mlx_array* res,
|
||||
const mlx_array a,
|
||||
const int* n,
|
||||
size_t n_num,
|
||||
const int* axes,
|
||||
size_t axes_num,
|
||||
const mlx_stream s);
|
||||
int mlx_fft_rfftn(
|
||||
mlx_array* res,
|
||||
const mlx_array a,
|
||||
const int* n,
|
||||
size_t n_num,
|
||||
const int* axes,
|
||||
size_t axes_num,
|
||||
const mlx_stream s);
|
||||
|
||||
/**@}*/
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif
|
||||
26
x/mlxrunner/mlx/include/mlx/c/half.h
Normal file
26
x/mlxrunner/mlx/include/mlx/c/half.h
Normal file
@@ -0,0 +1,26 @@
|
||||
/* Copyright © 2023-2024 Apple Inc. */
|
||||
|
||||
#ifndef MLX_HALF_H
|
||||
#define MLX_HALF_H
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
#if defined(__ARM_FEATURE_FP16_SCALAR_ARITHMETIC) || defined(__aarch64__)
|
||||
#define HAS_FLOAT16
|
||||
#include <arm_fp16.h>
|
||||
typedef __fp16 float16_t;
|
||||
#endif
|
||||
|
||||
#if defined(__ARM_FEATURE_BF16) || defined(__aarch64__)
|
||||
#define HAS_BFLOAT16
|
||||
#include <arm_bf16.h>
|
||||
typedef __bf16 bfloat16_t;
|
||||
#endif
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif
|
||||
63
x/mlxrunner/mlx/include/mlx/c/io.h
Normal file
63
x/mlxrunner/mlx/include/mlx/c/io.h
Normal file
@@ -0,0 +1,63 @@
|
||||
/* Copyright © 2023-2024 Apple Inc. */
|
||||
/* */
|
||||
/* This file is auto-generated. Do not edit manually. */
|
||||
/* */
|
||||
|
||||
#ifndef MLX_IO_H
|
||||
#define MLX_IO_H
|
||||
|
||||
#include <stdbool.h>
|
||||
#include <stdint.h>
|
||||
#include <stdio.h>
|
||||
|
||||
#include "mlx/c/array.h"
|
||||
#include "mlx/c/closure.h"
|
||||
#include "mlx/c/distributed_group.h"
|
||||
#include "mlx/c/io_types.h"
|
||||
#include "mlx/c/map.h"
|
||||
#include "mlx/c/stream.h"
|
||||
#include "mlx/c/string.h"
|
||||
#include "mlx/c/vector.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
/**
|
||||
* \defgroup io IO operations
|
||||
*/
|
||||
/**@{*/
|
||||
|
||||
int mlx_load_reader(
|
||||
mlx_array* res,
|
||||
mlx_io_reader in_stream,
|
||||
const mlx_stream s);
|
||||
int mlx_load(mlx_array* res, const char* file, const mlx_stream s);
|
||||
int mlx_load_safetensors_reader(
|
||||
mlx_map_string_to_array* res_0,
|
||||
mlx_map_string_to_string* res_1,
|
||||
mlx_io_reader in_stream,
|
||||
const mlx_stream s);
|
||||
int mlx_load_safetensors(
|
||||
mlx_map_string_to_array* res_0,
|
||||
mlx_map_string_to_string* res_1,
|
||||
const char* file,
|
||||
const mlx_stream s);
|
||||
int mlx_save_writer(mlx_io_writer out_stream, const mlx_array a);
|
||||
int mlx_save(const char* file, const mlx_array a);
|
||||
int mlx_save_safetensors_writer(
|
||||
mlx_io_writer in_stream,
|
||||
const mlx_map_string_to_array param,
|
||||
const mlx_map_string_to_string metadata);
|
||||
int mlx_save_safetensors(
|
||||
const char* file,
|
||||
const mlx_map_string_to_array param,
|
||||
const mlx_map_string_to_string metadata);
|
||||
|
||||
/**@}*/
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif
|
||||
104
x/mlxrunner/mlx/include/mlx/c/io_types.h
Normal file
104
x/mlxrunner/mlx/include/mlx/c/io_types.h
Normal file
@@ -0,0 +1,104 @@
|
||||
/* Copyright © 2023-2024 Apple Inc. */
|
||||
|
||||
#ifndef MLX_IO_TYPES_H
|
||||
#define MLX_IO_TYPES_H
|
||||
|
||||
#include <stdbool.h>
|
||||
|
||||
#include "mlx/c/string.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
/**
|
||||
* \defgroup mlx_io_types IO Types
|
||||
* MLX IO type objects.
|
||||
*/
|
||||
/**@{*/
|
||||
|
||||
/**
|
||||
* A MLX IO reader object.
|
||||
*/
|
||||
typedef struct mlx_io_reader_ {
|
||||
void* ctx;
|
||||
} mlx_io_reader;
|
||||
/**
|
||||
* A MLX IO writer object.
|
||||
*/
|
||||
typedef struct mlx_io_writer_ {
|
||||
void* ctx;
|
||||
} mlx_io_writer;
|
||||
|
||||
/**
|
||||
* Virtual table for custom IO reader and writer objects.
|
||||
*/
|
||||
typedef struct mlx_io_vtable_ {
|
||||
bool (*is_open)(void*);
|
||||
bool (*good)(void*);
|
||||
size_t (*tell)(void*);
|
||||
void (*seek)(void*, int64_t off, int whence);
|
||||
void (*read)(void*, char* data, size_t n);
|
||||
void (*read_at_offset)(void*, char* data, size_t n, size_t off);
|
||||
void (*write)(void*, const char* data, size_t n);
|
||||
const char* (*label)(void*);
|
||||
void (*free)(void*);
|
||||
} mlx_io_vtable;
|
||||
|
||||
/**
|
||||
* Returns a new custom IO reader.
|
||||
* `vtable` operates on user descriptor `desc`.
|
||||
*/
|
||||
mlx_io_reader mlx_io_reader_new(void* desc, mlx_io_vtable vtable);
|
||||
|
||||
/**
|
||||
* Get IO reader user descriptor.
|
||||
*/
|
||||
int mlx_io_reader_descriptor(void** desc_, mlx_io_reader io);
|
||||
|
||||
/**
|
||||
* Get IO reader description.
|
||||
*/
|
||||
int mlx_io_reader_tostring(mlx_string* str_, mlx_io_reader io);
|
||||
|
||||
/**
|
||||
* Free IO reader.
|
||||
*
|
||||
* Note that MLX arrays are lazily evaluated, so the underlying object may
|
||||
* be not freed right away. The ``free()`` callback from ``mlx_io_vtable``
|
||||
* will be called when the underlying object is actually freed.
|
||||
*/
|
||||
int mlx_io_reader_free(mlx_io_reader io);
|
||||
|
||||
/**
|
||||
* Returns a new custom IO writer.
|
||||
* `vtable` operates on user descriptor `desc`.
|
||||
*/
|
||||
mlx_io_writer mlx_io_writer_new(void* desc, mlx_io_vtable vtable);
|
||||
|
||||
/**
|
||||
* Get IO writer user descriptor.
|
||||
*/
|
||||
int mlx_io_writer_descriptor(void** desc_, mlx_io_writer io);
|
||||
|
||||
/**
|
||||
* Get IO writer description.
|
||||
*/
|
||||
int mlx_io_writer_tostring(mlx_string* str_, mlx_io_writer io);
|
||||
|
||||
/**
|
||||
* Free IO writer.
|
||||
*
|
||||
* Note that MLX arrays are lazily evaluated, so the underlying object may
|
||||
* be not freed right away. The ``free()`` callback from ``mlx_io_vtable``
|
||||
* will be called when the underlying object is actually freed.
|
||||
*/
|
||||
int mlx_io_writer_free(mlx_io_writer io);
|
||||
|
||||
/**@}*/
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif
|
||||
128
x/mlxrunner/mlx/include/mlx/c/linalg.h
Normal file
128
x/mlxrunner/mlx/include/mlx/c/linalg.h
Normal file
@@ -0,0 +1,128 @@
|
||||
/* Copyright © 2023-2024 Apple Inc. */
|
||||
/* */
|
||||
/* This file is auto-generated. Do not edit manually. */
|
||||
/* */
|
||||
|
||||
#ifndef MLX_LINALG_H
|
||||
#define MLX_LINALG_H
|
||||
|
||||
#include <stdbool.h>
|
||||
#include <stdint.h>
|
||||
#include <stdio.h>
|
||||
|
||||
#include "mlx/c/array.h"
|
||||
#include "mlx/c/closure.h"
|
||||
#include "mlx/c/distributed_group.h"
|
||||
#include "mlx/c/io_types.h"
|
||||
#include "mlx/c/map.h"
|
||||
#include "mlx/c/stream.h"
|
||||
#include "mlx/c/string.h"
|
||||
#include "mlx/c/vector.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
/**
|
||||
* \defgroup linalg Linear algebra operations
|
||||
*/
|
||||
/**@{*/
|
||||
|
||||
int mlx_linalg_cholesky(
|
||||
mlx_array* res,
|
||||
const mlx_array a,
|
||||
bool upper,
|
||||
const mlx_stream s);
|
||||
int mlx_linalg_cholesky_inv(
|
||||
mlx_array* res,
|
||||
const mlx_array a,
|
||||
bool upper,
|
||||
const mlx_stream s);
|
||||
int mlx_linalg_cross(
|
||||
mlx_array* res,
|
||||
const mlx_array a,
|
||||
const mlx_array b,
|
||||
int axis,
|
||||
const mlx_stream s);
|
||||
int mlx_linalg_eig(
|
||||
mlx_array* res_0,
|
||||
mlx_array* res_1,
|
||||
const mlx_array a,
|
||||
const mlx_stream s);
|
||||
int mlx_linalg_eigh(
|
||||
mlx_array* res_0,
|
||||
mlx_array* res_1,
|
||||
const mlx_array a,
|
||||
const char* UPLO,
|
||||
const mlx_stream s);
|
||||
int mlx_linalg_eigvals(mlx_array* res, const mlx_array a, const mlx_stream s);
|
||||
int mlx_linalg_eigvalsh(
|
||||
mlx_array* res,
|
||||
const mlx_array a,
|
||||
const char* UPLO,
|
||||
const mlx_stream s);
|
||||
int mlx_linalg_inv(mlx_array* res, const mlx_array a, const mlx_stream s);
|
||||
int mlx_linalg_lu(mlx_vector_array* res, const mlx_array a, const mlx_stream s);
|
||||
int mlx_linalg_lu_factor(
|
||||
mlx_array* res_0,
|
||||
mlx_array* res_1,
|
||||
const mlx_array a,
|
||||
const mlx_stream s);
|
||||
int mlx_linalg_norm(
|
||||
mlx_array* res,
|
||||
const mlx_array a,
|
||||
double ord,
|
||||
const int* axis /* may be null */,
|
||||
size_t axis_num,
|
||||
bool keepdims,
|
||||
const mlx_stream s);
|
||||
int mlx_linalg_norm_matrix(
|
||||
mlx_array* res,
|
||||
const mlx_array a,
|
||||
const char* ord,
|
||||
const int* axis /* may be null */,
|
||||
size_t axis_num,
|
||||
bool keepdims,
|
||||
const mlx_stream s);
|
||||
int mlx_linalg_norm_l2(
|
||||
mlx_array* res,
|
||||
const mlx_array a,
|
||||
const int* axis /* may be null */,
|
||||
size_t axis_num,
|
||||
bool keepdims,
|
||||
const mlx_stream s);
|
||||
int mlx_linalg_pinv(mlx_array* res, const mlx_array a, const mlx_stream s);
|
||||
int mlx_linalg_qr(
|
||||
mlx_array* res_0,
|
||||
mlx_array* res_1,
|
||||
const mlx_array a,
|
||||
const mlx_stream s);
|
||||
int mlx_linalg_solve(
|
||||
mlx_array* res,
|
||||
const mlx_array a,
|
||||
const mlx_array b,
|
||||
const mlx_stream s);
|
||||
int mlx_linalg_solve_triangular(
|
||||
mlx_array* res,
|
||||
const mlx_array a,
|
||||
const mlx_array b,
|
||||
bool upper,
|
||||
const mlx_stream s);
|
||||
int mlx_linalg_svd(
|
||||
mlx_vector_array* res,
|
||||
const mlx_array a,
|
||||
bool compute_uv,
|
||||
const mlx_stream s);
|
||||
int mlx_linalg_tri_inv(
|
||||
mlx_array* res,
|
||||
const mlx_array a,
|
||||
bool upper,
|
||||
const mlx_stream s);
|
||||
|
||||
/**@}*/
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif
|
||||
149
x/mlxrunner/mlx/include/mlx/c/map.h
Normal file
149
x/mlxrunner/mlx/include/mlx/c/map.h
Normal file
@@ -0,0 +1,149 @@
|
||||
/* Copyright © 2023-2024 Apple Inc. */
|
||||
/* */
|
||||
/* This file is auto-generated. Do not edit manually. */
|
||||
/* */
|
||||
|
||||
#ifndef MLX_MAP_H
|
||||
#define MLX_MAP_H
|
||||
|
||||
#include "mlx/c/array.h"
|
||||
#include "mlx/c/string.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
/**
|
||||
* \defgroup mlx_map Maps
|
||||
* MLX map objects.
|
||||
*/
|
||||
/**@{*/
|
||||
|
||||
/**
|
||||
* A string-to-array map
|
||||
*/
|
||||
typedef struct mlx_map_string_to_array_ {
|
||||
void* ctx;
|
||||
} mlx_map_string_to_array;
|
||||
|
||||
/**
|
||||
* Returns a new empty string-to-array map.
|
||||
*/
|
||||
mlx_map_string_to_array mlx_map_string_to_array_new(void);
|
||||
/**
|
||||
* Set map to provided src map.
|
||||
*/
|
||||
int mlx_map_string_to_array_set(
|
||||
mlx_map_string_to_array* map,
|
||||
const mlx_map_string_to_array src);
|
||||
/**
|
||||
* Free a string-to-array map.
|
||||
*/
|
||||
int mlx_map_string_to_array_free(mlx_map_string_to_array map);
|
||||
/**
|
||||
* Insert a new `value` at the specified `key` in the map.
|
||||
*/
|
||||
int mlx_map_string_to_array_insert(
|
||||
mlx_map_string_to_array map,
|
||||
const char* key,
|
||||
const mlx_array value);
|
||||
/**
|
||||
* Returns the value indexed at the specified `key` in the map.
|
||||
*/
|
||||
int mlx_map_string_to_array_get(
|
||||
mlx_array* value,
|
||||
const mlx_map_string_to_array map,
|
||||
const char* key);
|
||||
|
||||
/**
|
||||
* An iterator over a string-to-array map.
|
||||
*/
|
||||
typedef struct mlx_map_string_to_array_iterator_ {
|
||||
void* ctx;
|
||||
void* map_ctx;
|
||||
} mlx_map_string_to_array_iterator;
|
||||
/**
|
||||
* Returns a new iterator over the given map.
|
||||
*/
|
||||
mlx_map_string_to_array_iterator mlx_map_string_to_array_iterator_new(
|
||||
mlx_map_string_to_array map);
|
||||
/**
|
||||
* Free iterator.
|
||||
*/
|
||||
int mlx_map_string_to_array_iterator_free(mlx_map_string_to_array_iterator it);
|
||||
/**
|
||||
* Increment iterator.
|
||||
*/
|
||||
int mlx_map_string_to_array_iterator_next(
|
||||
const char** key,
|
||||
mlx_array* value,
|
||||
mlx_map_string_to_array_iterator it);
|
||||
|
||||
/**
|
||||
* A string-to-string map
|
||||
*/
|
||||
typedef struct mlx_map_string_to_string_ {
|
||||
void* ctx;
|
||||
} mlx_map_string_to_string;
|
||||
|
||||
/**
|
||||
* Returns a new empty string-to-string map.
|
||||
*/
|
||||
mlx_map_string_to_string mlx_map_string_to_string_new(void);
|
||||
/**
|
||||
* Set map to provided src map.
|
||||
*/
|
||||
int mlx_map_string_to_string_set(
|
||||
mlx_map_string_to_string* map,
|
||||
const mlx_map_string_to_string src);
|
||||
/**
|
||||
* Free a string-to-string map.
|
||||
*/
|
||||
int mlx_map_string_to_string_free(mlx_map_string_to_string map);
|
||||
/**
|
||||
* Insert a new `value` at the specified `key` in the map.
|
||||
*/
|
||||
int mlx_map_string_to_string_insert(
|
||||
mlx_map_string_to_string map,
|
||||
const char* key,
|
||||
const char* value);
|
||||
/**
|
||||
* Returns the value indexed at the specified `key` in the map.
|
||||
*/
|
||||
int mlx_map_string_to_string_get(
|
||||
const char** value,
|
||||
const mlx_map_string_to_string map,
|
||||
const char* key);
|
||||
|
||||
/**
|
||||
* An iterator over a string-to-string map.
|
||||
*/
|
||||
typedef struct mlx_map_string_to_string_iterator_ {
|
||||
void* ctx;
|
||||
void* map_ctx;
|
||||
} mlx_map_string_to_string_iterator;
|
||||
/**
|
||||
* Returns a new iterator over the given map.
|
||||
*/
|
||||
mlx_map_string_to_string_iterator mlx_map_string_to_string_iterator_new(
|
||||
mlx_map_string_to_string map);
|
||||
/**
|
||||
* Free iterator.
|
||||
*/
|
||||
int mlx_map_string_to_string_iterator_free(
|
||||
mlx_map_string_to_string_iterator it);
|
||||
/**
|
||||
* Increment iterator.
|
||||
*/
|
||||
int mlx_map_string_to_string_iterator_next(
|
||||
const char** key,
|
||||
const char** value,
|
||||
mlx_map_string_to_string_iterator it);
|
||||
|
||||
/**@}*/
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif
|
||||
47
x/mlxrunner/mlx/include/mlx/c/memory.h
Normal file
47
x/mlxrunner/mlx/include/mlx/c/memory.h
Normal file
@@ -0,0 +1,47 @@
|
||||
/* Copyright © 2023-2024 Apple Inc. */
|
||||
/* */
|
||||
/* This file is auto-generated. Do not edit manually. */
|
||||
/* */
|
||||
|
||||
#ifndef MLX_MEMORY_H
|
||||
#define MLX_MEMORY_H
|
||||
|
||||
#include <stdbool.h>
|
||||
#include <stdint.h>
|
||||
#include <stdio.h>
|
||||
|
||||
#include "mlx/c/array.h"
|
||||
#include "mlx/c/closure.h"
|
||||
#include "mlx/c/distributed_group.h"
|
||||
#include "mlx/c/io_types.h"
|
||||
#include "mlx/c/map.h"
|
||||
#include "mlx/c/stream.h"
|
||||
#include "mlx/c/string.h"
|
||||
#include "mlx/c/vector.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
/**
|
||||
* \defgroup memory Memory operations
|
||||
*/
|
||||
/**@{*/
|
||||
|
||||
int mlx_clear_cache(void);
|
||||
int mlx_get_active_memory(size_t* res);
|
||||
int mlx_get_cache_memory(size_t* res);
|
||||
int mlx_get_memory_limit(size_t* res);
|
||||
int mlx_get_peak_memory(size_t* res);
|
||||
int mlx_reset_peak_memory(void);
|
||||
int mlx_set_cache_limit(size_t* res, size_t limit);
|
||||
int mlx_set_memory_limit(size_t* res, size_t limit);
|
||||
int mlx_set_wired_limit(size_t* res, size_t limit);
|
||||
|
||||
/**@}*/
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif
|
||||
41
x/mlxrunner/mlx/include/mlx/c/metal.h
Normal file
41
x/mlxrunner/mlx/include/mlx/c/metal.h
Normal file
@@ -0,0 +1,41 @@
|
||||
/* Copyright © 2023-2024 Apple Inc. */
|
||||
/* */
|
||||
/* This file is auto-generated. Do not edit manually. */
|
||||
/* */
|
||||
|
||||
#ifndef MLX_METAL_H
|
||||
#define MLX_METAL_H
|
||||
|
||||
#include <stdbool.h>
|
||||
#include <stdint.h>
|
||||
#include <stdio.h>
|
||||
|
||||
#include "mlx/c/array.h"
|
||||
#include "mlx/c/closure.h"
|
||||
#include "mlx/c/distributed_group.h"
|
||||
#include "mlx/c/io_types.h"
|
||||
#include "mlx/c/map.h"
|
||||
#include "mlx/c/stream.h"
|
||||
#include "mlx/c/string.h"
|
||||
#include "mlx/c/vector.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
/**
|
||||
* \defgroup metal Metal specific operations
|
||||
*/
|
||||
/**@{*/
|
||||
|
||||
int mlx_metal_is_available(bool* res);
|
||||
int mlx_metal_start_capture(const char* path);
|
||||
int mlx_metal_stop_capture(void);
|
||||
|
||||
/**@}*/
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif
|
||||
34
x/mlxrunner/mlx/include/mlx/c/mlx.h
Normal file
34
x/mlxrunner/mlx/include/mlx/c/mlx.h
Normal file
@@ -0,0 +1,34 @@
|
||||
/* Copyright © 2023-2024 Apple Inc. */
|
||||
|
||||
#ifndef MLX_ALL_H
|
||||
#define MLX_ALL_H
|
||||
|
||||
#include "mlx/c/array.h"
|
||||
#include "mlx/c/closure.h"
|
||||
#include "mlx/c/compile.h"
|
||||
#include "mlx/c/cuda.h"
|
||||
#include "mlx/c/device.h"
|
||||
#include "mlx/c/distributed.h"
|
||||
#include "mlx/c/distributed_group.h"
|
||||
#include "mlx/c/error.h"
|
||||
#include "mlx/c/export.h"
|
||||
#include "mlx/c/fast.h"
|
||||
#include "mlx/c/fft.h"
|
||||
#include "mlx/c/half.h"
|
||||
#include "mlx/c/io.h"
|
||||
#include "mlx/c/io_types.h"
|
||||
#include "mlx/c/linalg.h"
|
||||
#include "mlx/c/map.h"
|
||||
#include "mlx/c/memory.h"
|
||||
#include "mlx/c/metal.h"
|
||||
#include "mlx/c/ops.h"
|
||||
#include "mlx/c/optional.h"
|
||||
#include "mlx/c/random.h"
|
||||
#include "mlx/c/stream.h"
|
||||
#include "mlx/c/string.h"
|
||||
#include "mlx/c/transforms.h"
|
||||
#include "mlx/c/transforms_impl.h"
|
||||
#include "mlx/c/vector.h"
|
||||
#include "mlx/c/version.h"
|
||||
|
||||
#endif
|
||||
1235
x/mlxrunner/mlx/include/mlx/c/ops.h
Normal file
1235
x/mlxrunner/mlx/include/mlx/c/ops.h
Normal file
File diff suppressed because it is too large
Load Diff
51
x/mlxrunner/mlx/include/mlx/c/optional.h
Normal file
51
x/mlxrunner/mlx/include/mlx/c/optional.h
Normal file
@@ -0,0 +1,51 @@
|
||||
/* Copyright © 2023-2024 Apple Inc. */
|
||||
|
||||
#ifndef MLX_OPTIONAL_H
|
||||
#define MLX_OPTIONAL_H
|
||||
|
||||
#include <stdbool.h>
|
||||
|
||||
#include "mlx/c/array.h"
|
||||
#include "mlx/c/string.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
/**
|
||||
* \defgroup mlx_optional Optionals
|
||||
* MLX optional scalars.
|
||||
*/
|
||||
/**@{*/
|
||||
|
||||
/**
|
||||
* A int optional.
|
||||
*/
|
||||
typedef struct mlx_optional_int_ {
|
||||
int value;
|
||||
bool has_value;
|
||||
} mlx_optional_int;
|
||||
|
||||
/**
|
||||
* A float optional.
|
||||
*/
|
||||
typedef struct mlx_optional_float_ {
|
||||
float value;
|
||||
bool has_value;
|
||||
} mlx_optional_float;
|
||||
|
||||
/**
|
||||
* A dtype optional.
|
||||
*/
|
||||
typedef struct mlx_optional_dtype_ {
|
||||
mlx_dtype value;
|
||||
bool has_value;
|
||||
} mlx_optional_dtype;
|
||||
|
||||
/**@}*/
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif
|
||||
166
x/mlxrunner/mlx/include/mlx/c/random.h
Normal file
166
x/mlxrunner/mlx/include/mlx/c/random.h
Normal file
@@ -0,0 +1,166 @@
|
||||
/* Copyright © 2023-2024 Apple Inc. */
|
||||
/* */
|
||||
/* This file is auto-generated. Do not edit manually. */
|
||||
/* */
|
||||
|
||||
#ifndef MLX_RANDOM_H
|
||||
#define MLX_RANDOM_H
|
||||
|
||||
#include <stdbool.h>
|
||||
#include <stdint.h>
|
||||
#include <stdio.h>
|
||||
|
||||
#include "mlx/c/array.h"
|
||||
#include "mlx/c/closure.h"
|
||||
#include "mlx/c/distributed_group.h"
|
||||
#include "mlx/c/io_types.h"
|
||||
#include "mlx/c/map.h"
|
||||
#include "mlx/c/stream.h"
|
||||
#include "mlx/c/string.h"
|
||||
#include "mlx/c/vector.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
/**
|
||||
* \defgroup random Random number operations
|
||||
*/
|
||||
/**@{*/
|
||||
|
||||
int mlx_random_bernoulli(
|
||||
mlx_array* res,
|
||||
const mlx_array p,
|
||||
const int* shape,
|
||||
size_t shape_num,
|
||||
const mlx_array key /* may be null */,
|
||||
const mlx_stream s);
|
||||
int mlx_random_bits(
|
||||
mlx_array* res,
|
||||
const int* shape,
|
||||
size_t shape_num,
|
||||
int width,
|
||||
const mlx_array key /* may be null */,
|
||||
const mlx_stream s);
|
||||
int mlx_random_categorical_shape(
|
||||
mlx_array* res,
|
||||
const mlx_array logits,
|
||||
int axis,
|
||||
const int* shape,
|
||||
size_t shape_num,
|
||||
const mlx_array key /* may be null */,
|
||||
const mlx_stream s);
|
||||
int mlx_random_categorical_num_samples(
|
||||
mlx_array* res,
|
||||
const mlx_array logits_,
|
||||
int axis,
|
||||
int num_samples,
|
||||
const mlx_array key /* may be null */,
|
||||
const mlx_stream s);
|
||||
int mlx_random_categorical(
|
||||
mlx_array* res,
|
||||
const mlx_array logits,
|
||||
int axis,
|
||||
const mlx_array key /* may be null */,
|
||||
const mlx_stream s);
|
||||
int mlx_random_gumbel(
|
||||
mlx_array* res,
|
||||
const int* shape,
|
||||
size_t shape_num,
|
||||
mlx_dtype dtype,
|
||||
const mlx_array key /* may be null */,
|
||||
const mlx_stream s);
|
||||
int mlx_random_key(mlx_array* res, uint64_t seed);
|
||||
int mlx_random_laplace(
|
||||
mlx_array* res,
|
||||
const int* shape,
|
||||
size_t shape_num,
|
||||
mlx_dtype dtype,
|
||||
float loc,
|
||||
float scale,
|
||||
const mlx_array key /* may be null */,
|
||||
const mlx_stream s);
|
||||
int mlx_random_multivariate_normal(
|
||||
mlx_array* res,
|
||||
const mlx_array mean,
|
||||
const mlx_array cov,
|
||||
const int* shape,
|
||||
size_t shape_num,
|
||||
mlx_dtype dtype,
|
||||
const mlx_array key /* may be null */,
|
||||
const mlx_stream s);
|
||||
int mlx_random_normal_broadcast(
|
||||
mlx_array* res,
|
||||
const int* shape,
|
||||
size_t shape_num,
|
||||
mlx_dtype dtype,
|
||||
const mlx_array loc /* may be null */,
|
||||
const mlx_array scale /* may be null */,
|
||||
const mlx_array key /* may be null */,
|
||||
const mlx_stream s);
|
||||
int mlx_random_normal(
|
||||
mlx_array* res,
|
||||
const int* shape,
|
||||
size_t shape_num,
|
||||
mlx_dtype dtype,
|
||||
float loc,
|
||||
float scale,
|
||||
const mlx_array key /* may be null */,
|
||||
const mlx_stream s);
|
||||
int mlx_random_permutation(
|
||||
mlx_array* res,
|
||||
const mlx_array x,
|
||||
int axis,
|
||||
const mlx_array key /* may be null */,
|
||||
const mlx_stream s);
|
||||
int mlx_random_permutation_arange(
|
||||
mlx_array* res,
|
||||
int x,
|
||||
const mlx_array key /* may be null */,
|
||||
const mlx_stream s);
|
||||
int mlx_random_randint(
|
||||
mlx_array* res,
|
||||
const mlx_array low,
|
||||
const mlx_array high,
|
||||
const int* shape,
|
||||
size_t shape_num,
|
||||
mlx_dtype dtype,
|
||||
const mlx_array key /* may be null */,
|
||||
const mlx_stream s);
|
||||
int mlx_random_seed(uint64_t seed);
|
||||
int mlx_random_split_num(
|
||||
mlx_array* res,
|
||||
const mlx_array key,
|
||||
int num,
|
||||
const mlx_stream s);
|
||||
int mlx_random_split(
|
||||
mlx_array* res_0,
|
||||
mlx_array* res_1,
|
||||
const mlx_array key,
|
||||
const mlx_stream s);
|
||||
int mlx_random_truncated_normal(
|
||||
mlx_array* res,
|
||||
const mlx_array lower,
|
||||
const mlx_array upper,
|
||||
const int* shape,
|
||||
size_t shape_num,
|
||||
mlx_dtype dtype,
|
||||
const mlx_array key /* may be null */,
|
||||
const mlx_stream s);
|
||||
int mlx_random_uniform(
|
||||
mlx_array* res,
|
||||
const mlx_array low,
|
||||
const mlx_array high,
|
||||
const int* shape,
|
||||
size_t shape_num,
|
||||
mlx_dtype dtype,
|
||||
const mlx_array key /* may be null */,
|
||||
const mlx_stream s);
|
||||
|
||||
/**@}*/
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif
|
||||
88
x/mlxrunner/mlx/include/mlx/c/stream.h
Normal file
88
x/mlxrunner/mlx/include/mlx/c/stream.h
Normal file
@@ -0,0 +1,88 @@
|
||||
/* Copyright © 2023-2024 Apple Inc. */
|
||||
|
||||
#ifndef MLX_STREAM_H
|
||||
#define MLX_STREAM_H
|
||||
|
||||
#include <stdbool.h>
|
||||
|
||||
#include "mlx/c/device.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
/**
|
||||
* \defgroup mlx_stream Stream
|
||||
* MLX stream object.
|
||||
*/
|
||||
/**@{*/
|
||||
|
||||
/**
|
||||
* A MLX stream object.
|
||||
*/
|
||||
typedef struct mlx_stream_ {
|
||||
void* ctx;
|
||||
} mlx_stream;
|
||||
|
||||
/**
|
||||
* Returns a new empty stream.
|
||||
*/
|
||||
mlx_stream mlx_stream_new(void);
|
||||
|
||||
/**
|
||||
* Returns a new stream on a device.
|
||||
*/
|
||||
mlx_stream mlx_stream_new_device(mlx_device dev);
|
||||
/**
|
||||
* Set stream to provided src stream.
|
||||
*/
|
||||
int mlx_stream_set(mlx_stream* stream, const mlx_stream src);
|
||||
/**
|
||||
* Free a stream.
|
||||
*/
|
||||
int mlx_stream_free(mlx_stream stream);
|
||||
/**
|
||||
* Get stream description.
|
||||
*/
|
||||
int mlx_stream_tostring(mlx_string* str, mlx_stream stream);
|
||||
/**
|
||||
* Check if streams are the same.
|
||||
*/
|
||||
bool mlx_stream_equal(mlx_stream lhs, mlx_stream rhs);
|
||||
/**
|
||||
* Return the device of the stream.
|
||||
*/
|
||||
int mlx_stream_get_device(mlx_device* dev, mlx_stream stream);
|
||||
/**
|
||||
* Return the index of the stream.
|
||||
*/
|
||||
int mlx_stream_get_index(int* index, mlx_stream stream);
|
||||
/**
|
||||
* Synchronize with the provided stream.
|
||||
*/
|
||||
int mlx_synchronize(mlx_stream stream);
|
||||
/**
|
||||
* Returns the default stream on the given device.
|
||||
*/
|
||||
int mlx_get_default_stream(mlx_stream* stream, mlx_device dev);
|
||||
/**
|
||||
* Set default stream.
|
||||
*/
|
||||
int mlx_set_default_stream(mlx_stream stream);
|
||||
/**
|
||||
* Returns the current default CPU stream.
|
||||
*/
|
||||
mlx_stream mlx_default_cpu_stream_new(void);
|
||||
|
||||
/**
|
||||
* Returns the current default GPU stream.
|
||||
*/
|
||||
mlx_stream mlx_default_gpu_stream_new(void);
|
||||
|
||||
/**@}*/
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif
|
||||
55
x/mlxrunner/mlx/include/mlx/c/string.h
Normal file
55
x/mlxrunner/mlx/include/mlx/c/string.h
Normal file
@@ -0,0 +1,55 @@
|
||||
/* Copyright © 2023-2024 Apple Inc. */
|
||||
|
||||
#ifndef MLX_STRING_H
|
||||
#define MLX_STRING_H
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
/**
|
||||
* \defgroup mlx_string String
|
||||
* MLX string object.
|
||||
*/
|
||||
/**@{*/
|
||||
|
||||
/**
|
||||
* A MLX string object.
|
||||
*/
|
||||
typedef struct mlx_string_ {
|
||||
void* ctx;
|
||||
} mlx_string;
|
||||
|
||||
/**
|
||||
* Returns a new empty string.
|
||||
*/
|
||||
mlx_string mlx_string_new(void);
|
||||
|
||||
/**
|
||||
* Returns a new string, copying contents from `str`, which must end with `\0`.
|
||||
*/
|
||||
mlx_string mlx_string_new_data(const char* str);
|
||||
|
||||
/**
|
||||
* Set string to src string.
|
||||
*/
|
||||
int mlx_string_set(mlx_string* str, const mlx_string src);
|
||||
|
||||
/**
|
||||
* Returns a pointer to the string contents.
|
||||
* The pointer is valid for the life duration of the string.
|
||||
*/
|
||||
const char* mlx_string_data(mlx_string str);
|
||||
|
||||
/**
|
||||
* Free string.
|
||||
*/
|
||||
int mlx_string_free(mlx_string str);
|
||||
|
||||
/**@}*/
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif
|
||||
68
x/mlxrunner/mlx/include/mlx/c/transforms.h
Normal file
68
x/mlxrunner/mlx/include/mlx/c/transforms.h
Normal file
@@ -0,0 +1,68 @@
|
||||
/* Copyright © 2023-2024 Apple Inc. */
|
||||
/* */
|
||||
/* This file is auto-generated. Do not edit manually. */
|
||||
/* */
|
||||
|
||||
#ifndef MLX_TRANSFORMS_H
|
||||
#define MLX_TRANSFORMS_H
|
||||
|
||||
#include <stdbool.h>
|
||||
#include <stdint.h>
|
||||
#include <stdio.h>
|
||||
|
||||
#include "mlx/c/array.h"
|
||||
#include "mlx/c/closure.h"
|
||||
#include "mlx/c/distributed_group.h"
|
||||
#include "mlx/c/io_types.h"
|
||||
#include "mlx/c/map.h"
|
||||
#include "mlx/c/stream.h"
|
||||
#include "mlx/c/string.h"
|
||||
#include "mlx/c/vector.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
/**
|
||||
* \defgroup transforms Transform operations
|
||||
*/
|
||||
/**@{*/
|
||||
|
||||
int mlx_async_eval(const mlx_vector_array outputs);
|
||||
int mlx_checkpoint(mlx_closure* res, const mlx_closure fun);
|
||||
int mlx_custom_function(
|
||||
mlx_closure* res,
|
||||
const mlx_closure fun,
|
||||
const mlx_closure_custom fun_vjp /* may be null */,
|
||||
const mlx_closure_custom_jvp fun_jvp /* may be null */,
|
||||
const mlx_closure_custom_vmap fun_vmap /* may be null */);
|
||||
int mlx_custom_vjp(
|
||||
mlx_closure* res,
|
||||
const mlx_closure fun,
|
||||
const mlx_closure_custom fun_vjp);
|
||||
int mlx_eval(const mlx_vector_array outputs);
|
||||
int mlx_jvp(
|
||||
mlx_vector_array* res_0,
|
||||
mlx_vector_array* res_1,
|
||||
const mlx_closure fun,
|
||||
const mlx_vector_array primals,
|
||||
const mlx_vector_array tangents);
|
||||
int mlx_value_and_grad(
|
||||
mlx_closure_value_and_grad* res,
|
||||
const mlx_closure fun,
|
||||
const int* argnums,
|
||||
size_t argnums_num);
|
||||
int mlx_vjp(
|
||||
mlx_vector_array* res_0,
|
||||
mlx_vector_array* res_1,
|
||||
const mlx_closure fun,
|
||||
const mlx_vector_array primals,
|
||||
const mlx_vector_array cotangents);
|
||||
|
||||
/**@}*/
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif
|
||||
54
x/mlxrunner/mlx/include/mlx/c/transforms_impl.h
Normal file
54
x/mlxrunner/mlx/include/mlx/c/transforms_impl.h
Normal file
@@ -0,0 +1,54 @@
|
||||
/* Copyright © 2023-2024 Apple Inc. */
|
||||
/* */
|
||||
/* This file is auto-generated. Do not edit manually. */
|
||||
/* */
|
||||
|
||||
#ifndef MLX_TRANSFORMS_IMPL_H
|
||||
#define MLX_TRANSFORMS_IMPL_H
|
||||
|
||||
#include <stdbool.h>
|
||||
#include <stdint.h>
|
||||
#include <stdio.h>
|
||||
|
||||
#include "mlx/c/array.h"
|
||||
#include "mlx/c/closure.h"
|
||||
#include "mlx/c/distributed_group.h"
|
||||
#include "mlx/c/io_types.h"
|
||||
#include "mlx/c/map.h"
|
||||
#include "mlx/c/stream.h"
|
||||
#include "mlx/c/string.h"
|
||||
#include "mlx/c/vector.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
/**
|
||||
* \defgroup transforms_impl Implementation detail operations
|
||||
*/
|
||||
/**@{*/
|
||||
|
||||
int mlx_detail_vmap_replace(
|
||||
mlx_vector_array* res,
|
||||
const mlx_vector_array inputs,
|
||||
const mlx_vector_array s_inputs,
|
||||
const mlx_vector_array s_outputs,
|
||||
const int* in_axes,
|
||||
size_t in_axes_num,
|
||||
const int* out_axes,
|
||||
size_t out_axes_num);
|
||||
int mlx_detail_vmap_trace(
|
||||
mlx_vector_array* res_0,
|
||||
mlx_vector_array* res_1,
|
||||
const mlx_closure fun,
|
||||
const mlx_vector_array inputs,
|
||||
const int* in_axes,
|
||||
size_t in_axes_num);
|
||||
|
||||
/**@}*/
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif
|
||||
133
x/mlxrunner/mlx/include/mlx/c/vector.h
Normal file
133
x/mlxrunner/mlx/include/mlx/c/vector.h
Normal file
@@ -0,0 +1,133 @@
|
||||
/* Copyright © 2023-2024 Apple Inc. */
|
||||
/* */
|
||||
/* This file is auto-generated. Do not edit manually. */
|
||||
/* */
|
||||
|
||||
#ifndef MLX_VECTOR_H
|
||||
#define MLX_VECTOR_H
|
||||
|
||||
#include "mlx/c/array.h"
|
||||
#include "mlx/c/string.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
/**
|
||||
* \defgroup mlx_vector Vectors
|
||||
* MLX vector objects.
|
||||
*/
|
||||
/**@{*/
|
||||
|
||||
/**
|
||||
* A vector of array.
|
||||
*/
|
||||
typedef struct mlx_vector_array_ {
|
||||
void* ctx;
|
||||
} mlx_vector_array;
|
||||
mlx_vector_array mlx_vector_array_new(void);
|
||||
int mlx_vector_array_set(mlx_vector_array* vec, const mlx_vector_array src);
|
||||
int mlx_vector_array_free(mlx_vector_array vec);
|
||||
mlx_vector_array mlx_vector_array_new_data(const mlx_array* data, size_t size);
|
||||
mlx_vector_array mlx_vector_array_new_value(const mlx_array val);
|
||||
int mlx_vector_array_set_data(
|
||||
mlx_vector_array* vec,
|
||||
const mlx_array* data,
|
||||
size_t size);
|
||||
int mlx_vector_array_set_value(mlx_vector_array* vec, const mlx_array val);
|
||||
int mlx_vector_array_append_data(
|
||||
mlx_vector_array vec,
|
||||
const mlx_array* data,
|
||||
size_t size);
|
||||
int mlx_vector_array_append_value(mlx_vector_array vec, const mlx_array val);
|
||||
size_t mlx_vector_array_size(mlx_vector_array vec);
|
||||
int mlx_vector_array_get(
|
||||
mlx_array* res,
|
||||
const mlx_vector_array vec,
|
||||
size_t idx);
|
||||
|
||||
/**
|
||||
* A vector of vector_array.
|
||||
*/
|
||||
typedef struct mlx_vector_vector_array_ {
|
||||
void* ctx;
|
||||
} mlx_vector_vector_array;
|
||||
mlx_vector_vector_array mlx_vector_vector_array_new(void);
|
||||
int mlx_vector_vector_array_set(
|
||||
mlx_vector_vector_array* vec,
|
||||
const mlx_vector_vector_array src);
|
||||
int mlx_vector_vector_array_free(mlx_vector_vector_array vec);
|
||||
mlx_vector_vector_array mlx_vector_vector_array_new_data(
|
||||
const mlx_vector_array* data,
|
||||
size_t size);
|
||||
mlx_vector_vector_array mlx_vector_vector_array_new_value(
|
||||
const mlx_vector_array val);
|
||||
int mlx_vector_vector_array_set_data(
|
||||
mlx_vector_vector_array* vec,
|
||||
const mlx_vector_array* data,
|
||||
size_t size);
|
||||
int mlx_vector_vector_array_set_value(
|
||||
mlx_vector_vector_array* vec,
|
||||
const mlx_vector_array val);
|
||||
int mlx_vector_vector_array_append_data(
|
||||
mlx_vector_vector_array vec,
|
||||
const mlx_vector_array* data,
|
||||
size_t size);
|
||||
int mlx_vector_vector_array_append_value(
|
||||
mlx_vector_vector_array vec,
|
||||
const mlx_vector_array val);
|
||||
size_t mlx_vector_vector_array_size(mlx_vector_vector_array vec);
|
||||
int mlx_vector_vector_array_get(
|
||||
mlx_vector_array* res,
|
||||
const mlx_vector_vector_array vec,
|
||||
size_t idx);
|
||||
|
||||
/**
|
||||
* A vector of int.
|
||||
*/
|
||||
typedef struct mlx_vector_int_ {
|
||||
void* ctx;
|
||||
} mlx_vector_int;
|
||||
mlx_vector_int mlx_vector_int_new(void);
|
||||
int mlx_vector_int_set(mlx_vector_int* vec, const mlx_vector_int src);
|
||||
int mlx_vector_int_free(mlx_vector_int vec);
|
||||
mlx_vector_int mlx_vector_int_new_data(int* data, size_t size);
|
||||
mlx_vector_int mlx_vector_int_new_value(int val);
|
||||
int mlx_vector_int_set_data(mlx_vector_int* vec, int* data, size_t size);
|
||||
int mlx_vector_int_set_value(mlx_vector_int* vec, int val);
|
||||
int mlx_vector_int_append_data(mlx_vector_int vec, int* data, size_t size);
|
||||
int mlx_vector_int_append_value(mlx_vector_int vec, int val);
|
||||
size_t mlx_vector_int_size(mlx_vector_int vec);
|
||||
int mlx_vector_int_get(int* res, const mlx_vector_int vec, size_t idx);
|
||||
|
||||
/**
|
||||
* A vector of string.
|
||||
*/
|
||||
typedef struct mlx_vector_string_ {
|
||||
void* ctx;
|
||||
} mlx_vector_string;
|
||||
mlx_vector_string mlx_vector_string_new(void);
|
||||
int mlx_vector_string_set(mlx_vector_string* vec, const mlx_vector_string src);
|
||||
int mlx_vector_string_free(mlx_vector_string vec);
|
||||
mlx_vector_string mlx_vector_string_new_data(const char** data, size_t size);
|
||||
mlx_vector_string mlx_vector_string_new_value(const char* val);
|
||||
int mlx_vector_string_set_data(
|
||||
mlx_vector_string* vec,
|
||||
const char** data,
|
||||
size_t size);
|
||||
int mlx_vector_string_set_value(mlx_vector_string* vec, const char* val);
|
||||
int mlx_vector_string_append_data(
|
||||
mlx_vector_string vec,
|
||||
const char** data,
|
||||
size_t size);
|
||||
int mlx_vector_string_append_value(mlx_vector_string vec, const char* val);
|
||||
size_t mlx_vector_string_size(mlx_vector_string vec);
|
||||
int mlx_vector_string_get(char** res, const mlx_vector_string vec, size_t idx);
|
||||
|
||||
/**@}*/
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user