Compare commits

...

7 Commits

Author SHA1 Message Date
ParthSareen
00af64a0ae docs: add more info 2026-03-24 13:00:18 -07:00
ParthSareen
21f0db0d37 docs: update claude code docs heading 2026-03-24 12:17:03 -07:00
Jesse Gross
95ee7fbd29 mlxrunner: panic on double unpin 2026-03-23 17:44:19 -07:00
Jesse Gross
ec55536734 mlxrunner: show time since last used in cache dump tree 2026-03-23 17:44:19 -07:00
Jesse Gross
77491439c2 mlxrunner: support partial match on pure transformer caches
Previously, a partial match within a node's edge would truncate the path
to the parent snapshot - effectively making all cache types behave as
recurrent caches. Caches with only transformer layers can rewind to
arbitrary boundary so this restores this capability to improve cache
hits
2026-03-23 17:44:19 -07:00
Parth Sareen
b166b36cd2 docs: update Claude Code with Telegram guide (#15026) 2026-03-23 16:31:21 -07:00
Daniel Hiltgen
c2b0bb7a52 mlx: update as of 3/23 (#14789)
* mlx: update to HEAD on 3/23

Also fixes a few misc vendoring bugs uncovered with this first update.
This also renames the version files to make them clearer.

* CUDA Fast Gated Delta kernel

* mlx: detect eval errors and panic

On model errors or missing kernels, don't mask the error, bubble it up.
2026-03-23 11:28:44 -07:00
25 changed files with 673 additions and 167 deletions

View File

@@ -157,7 +157,7 @@ COPY CMakeLists.txt CMakePresets.json .
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
COPY x/imagegen/mlx x/imagegen/mlx COPY x/imagegen/mlx x/imagegen/mlx
COPY go.mod go.sum . COPY go.mod go.sum .
COPY MLX_VERSION MLX_CORE_VERSION . COPY MLX_VERSION MLX_C_VERSION .
RUN curl -fsSL https://golang.org/dl/go$(awk '/^go/ { print $2 }' go.mod).linux-$(case $(uname -m) in x86_64) echo amd64 ;; aarch64) echo arm64 ;; esac).tar.gz | tar xz -C /usr/local RUN curl -fsSL https://golang.org/dl/go$(awk '/^go/ { print $2 }' go.mod).linux-$(case $(uname -m) in x86_64) echo amd64 ;; aarch64) echo arm64 ;; esac).tar.gz | tar xz -C /usr/local
ENV PATH=/usr/local/go/bin:$PATH ENV PATH=/usr/local/go/bin:$PATH
RUN go mod download RUN go mod download

View File

@@ -1 +0,0 @@
v0.30.6

1
MLX_C_VERSION Normal file
View File

@@ -0,0 +1 @@
0726ca922fc902c4c61ef9c27d94132be418e945

View File

@@ -1 +1 @@
v0.5.0 38ad257088fb2193ad47e527cf6534a689f30943

View File

@@ -96,6 +96,33 @@ The `/loop` command runs a prompt or slash command on a recurring schedule insid
/loop 1h Remind me to review the deploy status /loop 1h Remind me to review the deploy status
``` ```
## Channels
Chat with Claude Code from Telegram by connecting a bot to your session. Create a bot via [@BotFather](https://t.me/BotFather).
Install the telegram plugin:
```shell
/plugin install telegram@claude-plugins-official
```
Configure the token:
```shell
/telegram:configure 123456789:ABCdEF...
```
Launch with Ollama:
```shell
ollama launch claude -- --channels plugin:telegram@claude-plugins-official
```
See the [plugin README](https://github.com/anthropics/claude-plugins-official/tree/main/external_plugins/telegram) for full setup instructions including pairing and access control.
Claude Code will prompt for permission on most actions. To allow the bot to work autonomously, configure [permission rules](https://code.claude.com/docs/en/permissions) or pass `--dangerously-skip-permissions` in isolated environments.
Other channels may also be added by following the [Claude Code docs](https://code.claude.com/docs/en/channels-reference).
## Manual setup ## Manual setup
Claude Code connects to Ollama using the Anthropic-compatible API. Claude Code connects to Ollama using the Anthropic-compatible API.

View File

@@ -1,11 +1,11 @@
include(FetchContent) include(FetchContent)
# Read MLX version from top-level file (shared with Dockerfile) # Read MLX-C version from top-level file (shared with Dockerfile)
file(READ "${CMAKE_SOURCE_DIR}/MLX_VERSION" MLX_C_GIT_TAG) file(READ "${CMAKE_SOURCE_DIR}/MLX_C_VERSION" MLX_C_GIT_TAG)
string(STRIP "${MLX_C_GIT_TAG}" MLX_C_GIT_TAG) string(STRIP "${MLX_C_GIT_TAG}" MLX_C_GIT_TAG)
# Read MLX core version from top-level file # Read MLX version from top-level file
file(READ "${CMAKE_SOURCE_DIR}/MLX_CORE_VERSION" MLX_GIT_TAG) file(READ "${CMAKE_SOURCE_DIR}/MLX_VERSION" MLX_GIT_TAG)
string(STRIP "${MLX_GIT_TAG}" MLX_GIT_TAG) string(STRIP "${MLX_GIT_TAG}" MLX_GIT_TAG)
set(MLX_C_BUILD_EXAMPLES OFF) set(MLX_C_BUILD_EXAMPLES OFF)
@@ -98,6 +98,15 @@ FetchContent_MakeAvailable(mlx-c)
file(GLOB _mlx_c_hdrs "${mlx-c_SOURCE_DIR}/mlx/c/*.h") file(GLOB _mlx_c_hdrs "${mlx-c_SOURCE_DIR}/mlx/c/*.h")
file(COPY ${_mlx_c_hdrs} DESTINATION "${CMAKE_SOURCE_DIR}/x/mlxrunner/mlx/include/mlx/c/") file(COPY ${_mlx_c_hdrs} DESTINATION "${CMAKE_SOURCE_DIR}/x/mlxrunner/mlx/include/mlx/c/")
# Regenerate Go/C shim wrappers from the (possibly updated) headers.
find_program(GO_EXECUTABLE go REQUIRED)
message(STATUS "Regenerating MLX Go wrappers")
execute_process(
COMMAND ${GO_EXECUTABLE} generate ./x/...
WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}
COMMAND_ERROR_IS_FATAL ANY
)
# For local dev builds, override MLX_VERSION with git describe output # For local dev builds, override MLX_VERSION with git describe output
if(TARGET mlx_version AND DEFINED FETCHCONTENT_SOURCE_DIR_MLX) if(TARGET mlx_version AND DEFINED FETCHCONTENT_SOURCE_DIR_MLX)
execute_process( execute_process(

View File

@@ -165,8 +165,8 @@ int (*mlx_distributed_sum_scatter_ptr)(mlx_array* res, const mlx_array x, const
int (*mlx_distributed_group_rank_ptr)(mlx_distributed_group group) = NULL; int (*mlx_distributed_group_rank_ptr)(mlx_distributed_group group) = NULL;
int (*mlx_distributed_group_size_ptr)(mlx_distributed_group group) = NULL; int (*mlx_distributed_group_size_ptr)(mlx_distributed_group group) = NULL;
mlx_distributed_group (*mlx_distributed_group_split_ptr)(mlx_distributed_group group, int color, int key) = NULL; mlx_distributed_group (*mlx_distributed_group_split_ptr)(mlx_distributed_group group, int color, int key) = NULL;
bool (*mlx_distributed_is_available_ptr)(void) = NULL; bool (*mlx_distributed_is_available_ptr)(const char* bk) = NULL;
mlx_distributed_group (*mlx_distributed_init_ptr)(bool strict) = NULL; mlx_distributed_group (*mlx_distributed_init_ptr)(bool strict, const char* bk) = NULL;
void (*mlx_set_error_handler_ptr)(mlx_error_handler_func handler, void* data, void (*dtor)(void*)) = NULL; void (*mlx_set_error_handler_ptr)(mlx_error_handler_func handler, void* data, void (*dtor)(void*)) = NULL;
void (*_mlx_error_ptr)(const char* file, const int line, const char* fmt, ...) = NULL; void (*_mlx_error_ptr)(const char* file, const int line, const char* fmt, ...) = NULL;
int (*mlx_export_function_ptr)(const char* file, const mlx_closure fun, const mlx_vector_array args, bool shapeless) = NULL; int (*mlx_export_function_ptr)(const char* file, const mlx_closure fun, const mlx_vector_array args, bool shapeless) = NULL;
@@ -319,10 +319,12 @@ int (*mlx_astype_ptr)(mlx_array* res, const mlx_array a, mlx_dtype dtype, const
int (*mlx_atleast_1d_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; int (*mlx_atleast_1d_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
int (*mlx_atleast_2d_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; int (*mlx_atleast_2d_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
int (*mlx_atleast_3d_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; int (*mlx_atleast_3d_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
int (*mlx_bartlett_ptr)(mlx_array* res, int M, const mlx_stream s) = NULL;
int (*mlx_bitwise_and_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL; int (*mlx_bitwise_and_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL;
int (*mlx_bitwise_invert_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; int (*mlx_bitwise_invert_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
int (*mlx_bitwise_or_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL; int (*mlx_bitwise_or_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL;
int (*mlx_bitwise_xor_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL; int (*mlx_bitwise_xor_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL;
int (*mlx_blackman_ptr)(mlx_array* res, int M, const mlx_stream s) = NULL;
int (*mlx_block_masked_mm_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, int block_size, const mlx_array mask_out , const mlx_array mask_lhs , const mlx_array mask_rhs , const mlx_stream s) = NULL; int (*mlx_block_masked_mm_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, int block_size, const mlx_array mask_out , const mlx_array mask_lhs , const mlx_array mask_rhs , const mlx_stream s) = NULL;
int (*mlx_broadcast_arrays_ptr)(mlx_vector_array* res, const mlx_vector_array inputs, const mlx_stream s) = NULL; int (*mlx_broadcast_arrays_ptr)(mlx_vector_array* res, const mlx_vector_array inputs, const mlx_stream s) = NULL;
int (*mlx_broadcast_to_ptr)(mlx_array* res, const mlx_array a, const int* shape, size_t shape_num, const mlx_stream s) = NULL; int (*mlx_broadcast_to_ptr)(mlx_array* res, const mlx_array a, const int* shape, size_t shape_num, const mlx_stream s) = NULL;
@@ -348,7 +350,7 @@ int (*mlx_cumprod_ptr)(mlx_array* res, const mlx_array a, int axis, bool reverse
int (*mlx_cumsum_ptr)(mlx_array* res, const mlx_array a, int axis, bool reverse, bool inclusive, const mlx_stream s) = NULL; int (*mlx_cumsum_ptr)(mlx_array* res, const mlx_array a, int axis, bool reverse, bool inclusive, const mlx_stream s) = NULL;
int (*mlx_degrees_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; int (*mlx_degrees_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
int (*mlx_depends_ptr)(mlx_vector_array* res, const mlx_vector_array inputs, const mlx_vector_array dependencies) = NULL; int (*mlx_depends_ptr)(mlx_vector_array* res, const mlx_vector_array inputs, const mlx_vector_array dependencies) = NULL;
int (*mlx_dequantize_ptr)(mlx_array* res, const mlx_array w, const mlx_array scales, const mlx_array biases , mlx_optional_int group_size, mlx_optional_int bits, const char* mode, mlx_optional_dtype dtype, const mlx_stream s) = NULL; int (*mlx_dequantize_ptr)(mlx_array* res, const mlx_array w, const mlx_array scales, const mlx_array biases , mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_array global_scale , mlx_optional_dtype dtype, const mlx_stream s) = NULL;
int (*mlx_diag_ptr)(mlx_array* res, const mlx_array a, int k, const mlx_stream s) = NULL; int (*mlx_diag_ptr)(mlx_array* res, const mlx_array a, int k, const mlx_stream s) = NULL;
int (*mlx_diagonal_ptr)(mlx_array* res, const mlx_array a, int offset, int axis1, int axis2, const mlx_stream s) = NULL; int (*mlx_diagonal_ptr)(mlx_array* res, const mlx_array a, int offset, int axis1, int axis2, const mlx_stream s) = NULL;
int (*mlx_divide_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL; int (*mlx_divide_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL;
@@ -375,6 +377,8 @@ int (*mlx_gather_qmm_ptr)(mlx_array* res, const mlx_array x, const mlx_array w,
int (*mlx_greater_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL; int (*mlx_greater_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL;
int (*mlx_greater_equal_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL; int (*mlx_greater_equal_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL;
int (*mlx_hadamard_transform_ptr)(mlx_array* res, const mlx_array a, mlx_optional_float scale, const mlx_stream s) = NULL; int (*mlx_hadamard_transform_ptr)(mlx_array* res, const mlx_array a, mlx_optional_float scale, const mlx_stream s) = NULL;
int (*mlx_hamming_ptr)(mlx_array* res, int M, const mlx_stream s) = NULL;
int (*mlx_hanning_ptr)(mlx_array* res, int M, const mlx_stream s) = NULL;
int (*mlx_identity_ptr)(mlx_array* res, int n, mlx_dtype dtype, const mlx_stream s) = NULL; int (*mlx_identity_ptr)(mlx_array* res, int n, mlx_dtype dtype, const mlx_stream s) = NULL;
int (*mlx_imag_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; int (*mlx_imag_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
int (*mlx_inner_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL; int (*mlx_inner_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL;
@@ -434,8 +438,8 @@ int (*mlx_prod_axes_ptr)(mlx_array* res, const mlx_array a, const int* axes, siz
int (*mlx_prod_axis_ptr)(mlx_array* res, const mlx_array a, int axis, bool keepdims, const mlx_stream s) = NULL; int (*mlx_prod_axis_ptr)(mlx_array* res, const mlx_array a, int axis, bool keepdims, const mlx_stream s) = NULL;
int (*mlx_prod_ptr)(mlx_array* res, const mlx_array a, bool keepdims, const mlx_stream s) = NULL; int (*mlx_prod_ptr)(mlx_array* res, const mlx_array a, bool keepdims, const mlx_stream s) = NULL;
int (*mlx_put_along_axis_ptr)(mlx_array* res, const mlx_array a, const mlx_array indices, const mlx_array values, int axis, const mlx_stream s) = NULL; int (*mlx_put_along_axis_ptr)(mlx_array* res, const mlx_array a, const mlx_array indices, const mlx_array values, int axis, const mlx_stream s) = NULL;
int (*mlx_qqmm_ptr)(mlx_array* res, const mlx_array x, const mlx_array w, const mlx_array w_scales , mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_stream s) = NULL; int (*mlx_qqmm_ptr)(mlx_array* res, const mlx_array x, const mlx_array w, const mlx_array w_scales , mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_array global_scale_x , const mlx_array global_scale_w , const mlx_stream s) = NULL;
int (*mlx_quantize_ptr)(mlx_vector_array* res, const mlx_array w, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_stream s) = NULL; int (*mlx_quantize_ptr)(mlx_vector_array* res, const mlx_array w, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_array global_scale , const mlx_stream s) = NULL;
int (*mlx_quantized_matmul_ptr)(mlx_array* res, const mlx_array x, const mlx_array w, const mlx_array scales, const mlx_array biases , bool transpose, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_stream s) = NULL; int (*mlx_quantized_matmul_ptr)(mlx_array* res, const mlx_array x, const mlx_array w, const mlx_array scales, const mlx_array biases , bool transpose, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_stream s) = NULL;
int (*mlx_radians_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; int (*mlx_radians_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
int (*mlx_real_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; int (*mlx_real_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
@@ -2101,6 +2105,11 @@ int mlx_load_functions(void* handle) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_atleast_3d\n"); fprintf(stderr, "MLX: Failed to load symbol: mlx_atleast_3d\n");
return -1; return -1;
} }
mlx_bartlett_ptr = GET_SYM(handle, "mlx_bartlett");
if (mlx_bartlett_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_bartlett\n");
return -1;
}
mlx_bitwise_and_ptr = GET_SYM(handle, "mlx_bitwise_and"); mlx_bitwise_and_ptr = GET_SYM(handle, "mlx_bitwise_and");
if (mlx_bitwise_and_ptr == NULL) { if (mlx_bitwise_and_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_bitwise_and\n"); fprintf(stderr, "MLX: Failed to load symbol: mlx_bitwise_and\n");
@@ -2121,6 +2130,11 @@ int mlx_load_functions(void* handle) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_bitwise_xor\n"); fprintf(stderr, "MLX: Failed to load symbol: mlx_bitwise_xor\n");
return -1; return -1;
} }
mlx_blackman_ptr = GET_SYM(handle, "mlx_blackman");
if (mlx_blackman_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_blackman\n");
return -1;
}
mlx_block_masked_mm_ptr = GET_SYM(handle, "mlx_block_masked_mm"); mlx_block_masked_mm_ptr = GET_SYM(handle, "mlx_block_masked_mm");
if (mlx_block_masked_mm_ptr == NULL) { if (mlx_block_masked_mm_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_block_masked_mm\n"); fprintf(stderr, "MLX: Failed to load symbol: mlx_block_masked_mm\n");
@@ -2381,6 +2395,16 @@ int mlx_load_functions(void* handle) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_hadamard_transform\n"); fprintf(stderr, "MLX: Failed to load symbol: mlx_hadamard_transform\n");
return -1; return -1;
} }
mlx_hamming_ptr = GET_SYM(handle, "mlx_hamming");
if (mlx_hamming_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_hamming\n");
return -1;
}
mlx_hanning_ptr = GET_SYM(handle, "mlx_hanning");
if (mlx_hanning_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_hanning\n");
return -1;
}
mlx_identity_ptr = GET_SYM(handle, "mlx_identity"); mlx_identity_ptr = GET_SYM(handle, "mlx_identity");
if (mlx_identity_ptr == NULL) { if (mlx_identity_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_identity\n"); fprintf(stderr, "MLX: Failed to load symbol: mlx_identity\n");
@@ -4132,12 +4156,12 @@ mlx_distributed_group mlx_distributed_group_split(mlx_distributed_group group, i
return mlx_distributed_group_split_ptr(group, color, key); return mlx_distributed_group_split_ptr(group, color, key);
} }
bool mlx_distributed_is_available(void) { bool mlx_distributed_is_available(const char* bk) {
return mlx_distributed_is_available_ptr(); return mlx_distributed_is_available_ptr(bk);
} }
mlx_distributed_group mlx_distributed_init(bool strict) { mlx_distributed_group mlx_distributed_init(bool strict, const char* bk) {
return mlx_distributed_init_ptr(strict); return mlx_distributed_init_ptr(strict, bk);
} }
void mlx_set_error_handler(mlx_error_handler_func handler, void* data, void (*dtor)(void*)) { void mlx_set_error_handler(mlx_error_handler_func handler, void* data, void (*dtor)(void*)) {
@@ -4748,6 +4772,10 @@ int mlx_atleast_3d(mlx_array* res, const mlx_array a, const mlx_stream s) {
return mlx_atleast_3d_ptr(res, a, s); return mlx_atleast_3d_ptr(res, a, s);
} }
int mlx_bartlett(mlx_array* res, int M, const mlx_stream s) {
return mlx_bartlett_ptr(res, M, s);
}
int mlx_bitwise_and(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) { int mlx_bitwise_and(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) {
return mlx_bitwise_and_ptr(res, a, b, s); return mlx_bitwise_and_ptr(res, a, b, s);
} }
@@ -4764,6 +4792,10 @@ int mlx_bitwise_xor(mlx_array* res, const mlx_array a, const mlx_array b, const
return mlx_bitwise_xor_ptr(res, a, b, s); return mlx_bitwise_xor_ptr(res, a, b, s);
} }
int mlx_blackman(mlx_array* res, int M, const mlx_stream s) {
return mlx_blackman_ptr(res, M, s);
}
int mlx_block_masked_mm(mlx_array* res, const mlx_array a, const mlx_array b, int block_size, const mlx_array mask_out , const mlx_array mask_lhs , const mlx_array mask_rhs , const mlx_stream s) { int mlx_block_masked_mm(mlx_array* res, const mlx_array a, const mlx_array b, int block_size, const mlx_array mask_out , const mlx_array mask_lhs , const mlx_array mask_rhs , const mlx_stream s) {
return mlx_block_masked_mm_ptr(res, a, b, block_size, mask_out, mask_lhs, mask_rhs, s); return mlx_block_masked_mm_ptr(res, a, b, block_size, mask_out, mask_lhs, mask_rhs, s);
} }
@@ -4864,8 +4896,8 @@ int mlx_depends(mlx_vector_array* res, const mlx_vector_array inputs, const mlx_
return mlx_depends_ptr(res, inputs, dependencies); return mlx_depends_ptr(res, inputs, dependencies);
} }
int mlx_dequantize(mlx_array* res, const mlx_array w, const mlx_array scales, const mlx_array biases , mlx_optional_int group_size, mlx_optional_int bits, const char* mode, mlx_optional_dtype dtype, const mlx_stream s) { int mlx_dequantize(mlx_array* res, const mlx_array w, const mlx_array scales, const mlx_array biases , mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_array global_scale , mlx_optional_dtype dtype, const mlx_stream s) {
return mlx_dequantize_ptr(res, w, scales, biases, group_size, bits, mode, dtype, s); return mlx_dequantize_ptr(res, w, scales, biases, group_size, bits, mode, global_scale, dtype, s);
} }
int mlx_diag(mlx_array* res, const mlx_array a, int k, const mlx_stream s) { int mlx_diag(mlx_array* res, const mlx_array a, int k, const mlx_stream s) {
@@ -4972,6 +5004,14 @@ int mlx_hadamard_transform(mlx_array* res, const mlx_array a, mlx_optional_float
return mlx_hadamard_transform_ptr(res, a, scale, s); return mlx_hadamard_transform_ptr(res, a, scale, s);
} }
int mlx_hamming(mlx_array* res, int M, const mlx_stream s) {
return mlx_hamming_ptr(res, M, s);
}
int mlx_hanning(mlx_array* res, int M, const mlx_stream s) {
return mlx_hanning_ptr(res, M, s);
}
int mlx_identity(mlx_array* res, int n, mlx_dtype dtype, const mlx_stream s) { int mlx_identity(mlx_array* res, int n, mlx_dtype dtype, const mlx_stream s) {
return mlx_identity_ptr(res, n, dtype, s); return mlx_identity_ptr(res, n, dtype, s);
} }
@@ -5208,12 +5248,12 @@ int mlx_put_along_axis(mlx_array* res, const mlx_array a, const mlx_array indice
return mlx_put_along_axis_ptr(res, a, indices, values, axis, s); return mlx_put_along_axis_ptr(res, a, indices, values, axis, s);
} }
int mlx_qqmm(mlx_array* res, const mlx_array x, const mlx_array w, const mlx_array w_scales , mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_stream s) { int mlx_qqmm(mlx_array* res, const mlx_array x, const mlx_array w, const mlx_array w_scales , mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_array global_scale_x , const mlx_array global_scale_w , const mlx_stream s) {
return mlx_qqmm_ptr(res, x, w, w_scales, group_size, bits, mode, s); return mlx_qqmm_ptr(res, x, w, w_scales, group_size, bits, mode, global_scale_x, global_scale_w, s);
} }
int mlx_quantize(mlx_vector_array* res, const mlx_array w, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_stream s) { int mlx_quantize(mlx_vector_array* res, const mlx_array w, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_array global_scale , const mlx_stream s) {
return mlx_quantize_ptr(res, w, group_size, bits, mode, s); return mlx_quantize_ptr(res, w, group_size, bits, mode, global_scale, s);
} }
int mlx_quantized_matmul(mlx_array* res, const mlx_array x, const mlx_array w, const mlx_array scales, const mlx_array biases , bool transpose, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_stream s) { int mlx_quantized_matmul(mlx_array* res, const mlx_array x, const mlx_array w, const mlx_array scales, const mlx_array biases , bool transpose, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_stream s) {

View File

@@ -2125,7 +2125,8 @@ func Quantize(w *Array, groupSize, bits int, mode string) (weights, scales, bias
optGroupSize := C.mlx_optional_int{value: C.int(groupSize), has_value: true} optGroupSize := C.mlx_optional_int{value: C.int(groupSize), has_value: true}
optBits := C.mlx_optional_int{value: C.int(bits), has_value: true} optBits := C.mlx_optional_int{value: C.int(bits), has_value: true}
res := C.mlx_vector_array_new() res := C.mlx_vector_array_new()
C.mlx_quantize(&res, w.c, optGroupSize, optBits, cMode, C.default_stream()) var globalScale C.mlx_array
C.mlx_quantize(&res, w.c, optGroupSize, optBits, cMode, globalScale, C.default_stream())
// Result is a vector of arrays: [weights, scales, biases?] // Result is a vector of arrays: [weights, scales, biases?]
// mxfp8 mode returns only 2 elements (no biases) // mxfp8 mode returns only 2 elements (no biases)
@@ -2161,7 +2162,8 @@ func Dequantize(w, scales, biases *Array, groupSize, bits int, mode string) *Arr
} }
res := C.mlx_array_new() res := C.mlx_array_new()
C.mlx_dequantize(&res, w.c, scales.c, b, optGroupSize, optBits, cMode, optDtype, C.default_stream()) var globalScale C.mlx_array
C.mlx_dequantize(&res, w.c, scales.c, b, optGroupSize, optBits, cMode, globalScale, optDtype, C.default_stream())
return newArray(res) return newArray(res)
} }

View File

@@ -309,10 +309,12 @@
#undef mlx_atleast_1d #undef mlx_atleast_1d
#undef mlx_atleast_2d #undef mlx_atleast_2d
#undef mlx_atleast_3d #undef mlx_atleast_3d
#undef mlx_bartlett
#undef mlx_bitwise_and #undef mlx_bitwise_and
#undef mlx_bitwise_invert #undef mlx_bitwise_invert
#undef mlx_bitwise_or #undef mlx_bitwise_or
#undef mlx_bitwise_xor #undef mlx_bitwise_xor
#undef mlx_blackman
#undef mlx_block_masked_mm #undef mlx_block_masked_mm
#undef mlx_broadcast_arrays #undef mlx_broadcast_arrays
#undef mlx_broadcast_to #undef mlx_broadcast_to
@@ -365,6 +367,8 @@
#undef mlx_greater #undef mlx_greater
#undef mlx_greater_equal #undef mlx_greater_equal
#undef mlx_hadamard_transform #undef mlx_hadamard_transform
#undef mlx_hamming
#undef mlx_hanning
#undef mlx_identity #undef mlx_identity
#undef mlx_imag #undef mlx_imag
#undef mlx_inner #undef mlx_inner
@@ -751,8 +755,8 @@ extern int (*mlx_distributed_sum_scatter_ptr)(mlx_array* res, const mlx_array x,
extern int (*mlx_distributed_group_rank_ptr)(mlx_distributed_group group); extern int (*mlx_distributed_group_rank_ptr)(mlx_distributed_group group);
extern int (*mlx_distributed_group_size_ptr)(mlx_distributed_group group); extern int (*mlx_distributed_group_size_ptr)(mlx_distributed_group group);
extern mlx_distributed_group (*mlx_distributed_group_split_ptr)(mlx_distributed_group group, int color, int key); extern mlx_distributed_group (*mlx_distributed_group_split_ptr)(mlx_distributed_group group, int color, int key);
extern bool (*mlx_distributed_is_available_ptr)(void); extern bool (*mlx_distributed_is_available_ptr)(const char* bk);
extern mlx_distributed_group (*mlx_distributed_init_ptr)(bool strict); extern mlx_distributed_group (*mlx_distributed_init_ptr)(bool strict, const char* bk);
extern void (*mlx_set_error_handler_ptr)(mlx_error_handler_func handler, void* data, void (*dtor)(void*)); extern void (*mlx_set_error_handler_ptr)(mlx_error_handler_func handler, void* data, void (*dtor)(void*));
extern void (*_mlx_error_ptr)(const char* file, const int line, const char* fmt, ...); extern void (*_mlx_error_ptr)(const char* file, const int line, const char* fmt, ...);
extern int (*mlx_export_function_ptr)(const char* file, const mlx_closure fun, const mlx_vector_array args, bool shapeless); extern int (*mlx_export_function_ptr)(const char* file, const mlx_closure fun, const mlx_vector_array args, bool shapeless);
@@ -905,10 +909,12 @@ extern int (*mlx_astype_ptr)(mlx_array* res, const mlx_array a, mlx_dtype dtype,
extern int (*mlx_atleast_1d_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s); extern int (*mlx_atleast_1d_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s);
extern int (*mlx_atleast_2d_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s); extern int (*mlx_atleast_2d_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s);
extern int (*mlx_atleast_3d_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s); extern int (*mlx_atleast_3d_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s);
extern int (*mlx_bartlett_ptr)(mlx_array* res, int M, const mlx_stream s);
extern int (*mlx_bitwise_and_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s); extern int (*mlx_bitwise_and_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s);
extern int (*mlx_bitwise_invert_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s); extern int (*mlx_bitwise_invert_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s);
extern int (*mlx_bitwise_or_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s); extern int (*mlx_bitwise_or_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s);
extern int (*mlx_bitwise_xor_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s); extern int (*mlx_bitwise_xor_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s);
extern int (*mlx_blackman_ptr)(mlx_array* res, int M, const mlx_stream s);
extern int (*mlx_block_masked_mm_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, int block_size, const mlx_array mask_out , const mlx_array mask_lhs , const mlx_array mask_rhs , const mlx_stream s); extern int (*mlx_block_masked_mm_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, int block_size, const mlx_array mask_out , const mlx_array mask_lhs , const mlx_array mask_rhs , const mlx_stream s);
extern int (*mlx_broadcast_arrays_ptr)(mlx_vector_array* res, const mlx_vector_array inputs, const mlx_stream s); extern int (*mlx_broadcast_arrays_ptr)(mlx_vector_array* res, const mlx_vector_array inputs, const mlx_stream s);
extern int (*mlx_broadcast_to_ptr)(mlx_array* res, const mlx_array a, const int* shape, size_t shape_num, const mlx_stream s); extern int (*mlx_broadcast_to_ptr)(mlx_array* res, const mlx_array a, const int* shape, size_t shape_num, const mlx_stream s);
@@ -934,7 +940,7 @@ extern int (*mlx_cumprod_ptr)(mlx_array* res, const mlx_array a, int axis, bool
extern int (*mlx_cumsum_ptr)(mlx_array* res, const mlx_array a, int axis, bool reverse, bool inclusive, const mlx_stream s); extern int (*mlx_cumsum_ptr)(mlx_array* res, const mlx_array a, int axis, bool reverse, bool inclusive, const mlx_stream s);
extern int (*mlx_degrees_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s); extern int (*mlx_degrees_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s);
extern int (*mlx_depends_ptr)(mlx_vector_array* res, const mlx_vector_array inputs, const mlx_vector_array dependencies); extern int (*mlx_depends_ptr)(mlx_vector_array* res, const mlx_vector_array inputs, const mlx_vector_array dependencies);
extern int (*mlx_dequantize_ptr)(mlx_array* res, const mlx_array w, const mlx_array scales, const mlx_array biases , mlx_optional_int group_size, mlx_optional_int bits, const char* mode, mlx_optional_dtype dtype, const mlx_stream s); extern int (*mlx_dequantize_ptr)(mlx_array* res, const mlx_array w, const mlx_array scales, const mlx_array biases , mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_array global_scale , mlx_optional_dtype dtype, const mlx_stream s);
extern int (*mlx_diag_ptr)(mlx_array* res, const mlx_array a, int k, const mlx_stream s); extern int (*mlx_diag_ptr)(mlx_array* res, const mlx_array a, int k, const mlx_stream s);
extern int (*mlx_diagonal_ptr)(mlx_array* res, const mlx_array a, int offset, int axis1, int axis2, const mlx_stream s); extern int (*mlx_diagonal_ptr)(mlx_array* res, const mlx_array a, int offset, int axis1, int axis2, const mlx_stream s);
extern int (*mlx_divide_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s); extern int (*mlx_divide_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s);
@@ -961,6 +967,8 @@ extern int (*mlx_gather_qmm_ptr)(mlx_array* res, const mlx_array x, const mlx_ar
extern int (*mlx_greater_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s); extern int (*mlx_greater_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s);
extern int (*mlx_greater_equal_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s); extern int (*mlx_greater_equal_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s);
extern int (*mlx_hadamard_transform_ptr)(mlx_array* res, const mlx_array a, mlx_optional_float scale, const mlx_stream s); extern int (*mlx_hadamard_transform_ptr)(mlx_array* res, const mlx_array a, mlx_optional_float scale, const mlx_stream s);
extern int (*mlx_hamming_ptr)(mlx_array* res, int M, const mlx_stream s);
extern int (*mlx_hanning_ptr)(mlx_array* res, int M, const mlx_stream s);
extern int (*mlx_identity_ptr)(mlx_array* res, int n, mlx_dtype dtype, const mlx_stream s); extern int (*mlx_identity_ptr)(mlx_array* res, int n, mlx_dtype dtype, const mlx_stream s);
extern int (*mlx_imag_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s); extern int (*mlx_imag_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s);
extern int (*mlx_inner_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s); extern int (*mlx_inner_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s);
@@ -1020,8 +1028,8 @@ extern int (*mlx_prod_axes_ptr)(mlx_array* res, const mlx_array a, const int* ax
extern int (*mlx_prod_axis_ptr)(mlx_array* res, const mlx_array a, int axis, bool keepdims, const mlx_stream s); extern int (*mlx_prod_axis_ptr)(mlx_array* res, const mlx_array a, int axis, bool keepdims, const mlx_stream s);
extern int (*mlx_prod_ptr)(mlx_array* res, const mlx_array a, bool keepdims, const mlx_stream s); extern int (*mlx_prod_ptr)(mlx_array* res, const mlx_array a, bool keepdims, const mlx_stream s);
extern int (*mlx_put_along_axis_ptr)(mlx_array* res, const mlx_array a, const mlx_array indices, const mlx_array values, int axis, const mlx_stream s); extern int (*mlx_put_along_axis_ptr)(mlx_array* res, const mlx_array a, const mlx_array indices, const mlx_array values, int axis, const mlx_stream s);
extern int (*mlx_qqmm_ptr)(mlx_array* res, const mlx_array x, const mlx_array w, const mlx_array w_scales , mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_stream s); extern int (*mlx_qqmm_ptr)(mlx_array* res, const mlx_array x, const mlx_array w, const mlx_array w_scales , mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_array global_scale_x , const mlx_array global_scale_w , const mlx_stream s);
extern int (*mlx_quantize_ptr)(mlx_vector_array* res, const mlx_array w, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_stream s); extern int (*mlx_quantize_ptr)(mlx_vector_array* res, const mlx_array w, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_array global_scale , const mlx_stream s);
extern int (*mlx_quantized_matmul_ptr)(mlx_array* res, const mlx_array x, const mlx_array w, const mlx_array scales, const mlx_array biases , bool transpose, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_stream s); extern int (*mlx_quantized_matmul_ptr)(mlx_array* res, const mlx_array x, const mlx_array w, const mlx_array scales, const mlx_array biases , bool transpose, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_stream s);
extern int (*mlx_radians_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s); extern int (*mlx_radians_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s);
extern int (*mlx_real_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s); extern int (*mlx_real_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s);
@@ -1492,9 +1500,9 @@ int mlx_distributed_group_size(mlx_distributed_group group);
mlx_distributed_group mlx_distributed_group_split(mlx_distributed_group group, int color, int key); mlx_distributed_group mlx_distributed_group_split(mlx_distributed_group group, int color, int key);
bool mlx_distributed_is_available(void); bool mlx_distributed_is_available(const char* bk);
mlx_distributed_group mlx_distributed_init(bool strict); mlx_distributed_group mlx_distributed_init(bool strict, const char* bk);
void mlx_set_error_handler(mlx_error_handler_func handler, void* data, void (*dtor)(void*)); void mlx_set_error_handler(mlx_error_handler_func handler, void* data, void (*dtor)(void*));
@@ -1800,6 +1808,8 @@ int mlx_atleast_2d(mlx_array* res, const mlx_array a, const mlx_stream s);
int mlx_atleast_3d(mlx_array* res, const mlx_array a, const mlx_stream s); int mlx_atleast_3d(mlx_array* res, const mlx_array a, const mlx_stream s);
int mlx_bartlett(mlx_array* res, int M, const mlx_stream s);
int mlx_bitwise_and(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s); int mlx_bitwise_and(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s);
int mlx_bitwise_invert(mlx_array* res, const mlx_array a, const mlx_stream s); int mlx_bitwise_invert(mlx_array* res, const mlx_array a, const mlx_stream s);
@@ -1808,6 +1818,8 @@ int mlx_bitwise_or(mlx_array* res, const mlx_array a, const mlx_array b, const m
int mlx_bitwise_xor(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s); int mlx_bitwise_xor(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s);
int mlx_blackman(mlx_array* res, int M, const mlx_stream s);
int mlx_block_masked_mm(mlx_array* res, const mlx_array a, const mlx_array b, int block_size, const mlx_array mask_out , const mlx_array mask_lhs , const mlx_array mask_rhs , const mlx_stream s); int mlx_block_masked_mm(mlx_array* res, const mlx_array a, const mlx_array b, int block_size, const mlx_array mask_out , const mlx_array mask_lhs , const mlx_array mask_rhs , const mlx_stream s);
int mlx_broadcast_arrays(mlx_vector_array* res, const mlx_vector_array inputs, const mlx_stream s); int mlx_broadcast_arrays(mlx_vector_array* res, const mlx_vector_array inputs, const mlx_stream s);
@@ -1858,7 +1870,7 @@ int mlx_degrees(mlx_array* res, const mlx_array a, const mlx_stream s);
int mlx_depends(mlx_vector_array* res, const mlx_vector_array inputs, const mlx_vector_array dependencies); int mlx_depends(mlx_vector_array* res, const mlx_vector_array inputs, const mlx_vector_array dependencies);
int mlx_dequantize(mlx_array* res, const mlx_array w, const mlx_array scales, const mlx_array biases , mlx_optional_int group_size, mlx_optional_int bits, const char* mode, mlx_optional_dtype dtype, const mlx_stream s); int mlx_dequantize(mlx_array* res, const mlx_array w, const mlx_array scales, const mlx_array biases , mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_array global_scale , mlx_optional_dtype dtype, const mlx_stream s);
int mlx_diag(mlx_array* res, const mlx_array a, int k, const mlx_stream s); int mlx_diag(mlx_array* res, const mlx_array a, int k, const mlx_stream s);
@@ -1912,6 +1924,10 @@ int mlx_greater_equal(mlx_array* res, const mlx_array a, const mlx_array b, cons
int mlx_hadamard_transform(mlx_array* res, const mlx_array a, mlx_optional_float scale, const mlx_stream s); int mlx_hadamard_transform(mlx_array* res, const mlx_array a, mlx_optional_float scale, const mlx_stream s);
int mlx_hamming(mlx_array* res, int M, const mlx_stream s);
int mlx_hanning(mlx_array* res, int M, const mlx_stream s);
int mlx_identity(mlx_array* res, int n, mlx_dtype dtype, const mlx_stream s); int mlx_identity(mlx_array* res, int n, mlx_dtype dtype, const mlx_stream s);
int mlx_imag(mlx_array* res, const mlx_array a, const mlx_stream s); int mlx_imag(mlx_array* res, const mlx_array a, const mlx_stream s);
@@ -2030,9 +2046,9 @@ int mlx_prod(mlx_array* res, const mlx_array a, bool keepdims, const mlx_stream
int mlx_put_along_axis(mlx_array* res, const mlx_array a, const mlx_array indices, const mlx_array values, int axis, const mlx_stream s); int mlx_put_along_axis(mlx_array* res, const mlx_array a, const mlx_array indices, const mlx_array values, int axis, const mlx_stream s);
int mlx_qqmm(mlx_array* res, const mlx_array x, const mlx_array w, const mlx_array w_scales , mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_stream s); int mlx_qqmm(mlx_array* res, const mlx_array x, const mlx_array w, const mlx_array w_scales , mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_array global_scale_x , const mlx_array global_scale_w , const mlx_stream s);
int mlx_quantize(mlx_vector_array* res, const mlx_array w, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_stream s); int mlx_quantize(mlx_vector_array* res, const mlx_array w, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_array global_scale , const mlx_stream s);
int mlx_quantized_matmul(mlx_array* res, const mlx_array x, const mlx_array w, const mlx_array scales, const mlx_array biases , bool transpose, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_stream s); int mlx_quantized_matmul(mlx_array* res, const mlx_array x, const mlx_array w, const mlx_array scales, const mlx_array biases , bool transpose, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_stream s);

View File

@@ -93,21 +93,8 @@ func (c *kvCache) begin(m base.Model, inputs []int32) *cacheSession {
matchPath, matched = findBestMatch(c.root, inputs[:len(inputs)-1]) matchPath, matched = findBestMatch(c.root, inputs[:len(inputs)-1])
} }
// Check for partial match within a node's edge — truncate path
// to the parent boundary. snapshot() will split the node and
// create the branch point during prefill when caches are ready.
partialMatch := false
if len(matchPath) > 1 {
lastNode := matchPath[len(matchPath)-1]
matchedInEdge := matched - lastNode.startOffset()
if matchedInEdge > 0 && matchedInEdge < len(lastNode.tokens) {
matchPath = matchPath[:len(matchPath)-1]
partialMatch = true
}
}
// Switch to the matched path, paging in/out as needed. // Switch to the matched path, paging in/out as needed.
c.switchToPath(matchPath) c.switchToPath(matchPath, matched)
// switchToPath aligns caches to a common offset // switchToPath aligns caches to a common offset
prefix := c.minCacheOffset() prefix := c.minCacheOffset()
@@ -116,7 +103,7 @@ func (c *kvCache) begin(m base.Model, inputs []int32) *cacheSession {
// Schedule a snapshot at the branch point during prefill so future // Schedule a snapshot at the branch point during prefill so future
// requests diverging here can restore instead of re-evaluating. // requests diverging here can restore instead of re-evaluating.
var snapshotAt int var snapshotAt int
if partialMatch || (prefix == 0 && matched > 0) { if prefix < matched {
snapshotAt = matched snapshotAt = matched
} }
@@ -142,7 +129,7 @@ func (c *kvCache) begin(m base.Model, inputs []int32) *cacheSession {
// switchToPath transitions from the current active path to a new path, // switchToPath transitions from the current active path to a new path,
// paging out diverging segments and paging in the new path. // paging out diverging segments and paging in the new path.
func (c *kvCache) switchToPath(newPath []*trieNode) { func (c *kvCache) switchToPath(newPath []*trieNode, matched int) {
defer c.enforceEvictionPolicy() defer c.enforceEvictionPolicy()
// Find common ancestor index. // Find common ancestor index.
@@ -167,7 +154,10 @@ func (c *kvCache) switchToPath(newPath []*trieNode) {
// non-leaf nodes here would produce wrong results for non-rewindable // non-leaf nodes here would produce wrong results for non-rewindable
// caches (e.g. RecurrentCache) whose state reflects the leaf, not // caches (e.g. RecurrentCache) whose state reflects the leaf, not
// the intermediate boundary. // the intermediate boundary.
if leaf := len(c.activePath) - 1; leaf >= commonLen { leaf := len(c.activePath) - 1
leafDiverges := leaf >= commonLen
leafNeedsRewind := matched < c.activePath[leaf].endOffset
if leafDiverges || leafNeedsRewind {
node := c.activePath[leaf] node := c.activePath[leaf]
if !node.hasAllSnapshots() { if !node.hasAllSnapshots() {
fromOffset := node.startOffset() fromOffset := node.startOffset()
@@ -184,14 +174,16 @@ func (c *kvCache) switchToPath(newPath []*trieNode) {
} }
} }
// Rewind each cache to the ancestor offset or free it. Freed // Rewind each cache to the target offset or free it. When matched
// caches (e.g. RecurrentCache that can't rewind) will be restored // falls within the ancestor's range (same-path case), we rewind
// from snapshots during page-in. // directly to the match point. Otherwise we rewind to the ancestor
// and let page-in bring us forward to matched.
rewindTarget := min(ancestorOffset, matched)
for _, kv := range c.caches { for _, kv := range c.caches {
if kv == nil { if kv == nil {
continue continue
} }
if !kv.Restore(nil, ancestorOffset) { if !kv.Restore(nil, rewindTarget) {
kv.Free() kv.Free()
} }
} }
@@ -199,10 +191,12 @@ func (c *kvCache) switchToPath(newPath []*trieNode) {
// Page in — walk the full new path, restoring from snapshots. // Page in — walk the full new path, restoring from snapshots.
// Freed caches naturally pick up the first available snapshot. // Freed caches naturally pick up the first available snapshot.
// Caches already past a node skip it via offset check. // Caches already past a node skip it via offset check.
pageIn:
for _, node := range newPath { for _, node := range newPath {
if len(node.snapshots) == 0 { if !node.hasSnapshots() {
continue continue
} }
nodeTarget := min(node.endOffset, matched)
for j, kv := range c.caches { for j, kv := range c.caches {
if kv == nil { if kv == nil {
continue continue
@@ -210,19 +204,18 @@ func (c *kvCache) switchToPath(newPath []*trieNode) {
if j >= len(node.snapshots) || node.snapshots[j] == nil { if j >= len(node.snapshots) || node.snapshots[j] == nil {
continue continue
} }
if kv.Offset() >= node.endOffset { if kv.Offset() >= nodeTarget {
continue continue
} }
if !kv.Restore(node.snapshots[j], node.endOffset) { if !kv.Restore(node.snapshots[j], nodeTarget) {
slog.Warn("cache restore failure during page-in, freeing all caches", "layer", j, "offset", node.startOffset()) // Restore failed — stop page-in and let alignment
c.freeAll() // bring all caches to a consistent offset.
c.activePath = []*trieNode{c.root} break pageIn
return
} }
} }
if node.endOffset > ancestorOffset { if node.endOffset > ancestorOffset {
pageInCount++ pageInCount++
logutil.Trace(fmt.Sprintf("page in: [%d, %d)", node.startOffset(), node.endOffset)) logutil.Trace(fmt.Sprintf("page in: [%d, %d)", node.startOffset(), nodeTarget))
} }
} }
@@ -536,6 +529,9 @@ func (c *kvCache) dumpTree() {
if nodeBytes > 0 { if nodeBytes > 0 {
label += " " + mlx.PrettyBytes(int(nodeBytes)).String() label += " " + mlx.PrettyBytes(int(nodeBytes)).String()
} }
if !n.lastUsed.IsZero() {
label += fmt.Sprintf(" %s ago", time.Since(n.lastUsed).Truncate(time.Millisecond))
}
var flags []string var flags []string
if n.user { if n.user {
flags = append(flags, "user") flags = append(flags, "user")

View File

@@ -17,7 +17,8 @@ type Cache interface {
Snapshot(fromOffset int) Snapshot Snapshot(fromOffset int) Snapshot
// Restore brings the cache to target. If snapshot is nil, rewinds // Restore brings the cache to target. If snapshot is nil, rewinds
// using the cache's own live state. // using the cache's own live state. Returns false if the target is
// unreachable (e.g. target > current offset, or negative).
Restore(snapshot Snapshot, target int) bool Restore(snapshot Snapshot, target int) bool
// Merge combines two sequential snapshots [a,b) and [b,c) into [a,c). // Merge combines two sequential snapshots [a,b) and [b,c) into [a,c).
@@ -122,17 +123,21 @@ func (c *KVCache) Snapshot(fromOffset int) Snapshot {
} }
func (c *KVCache) Restore(snapshot Snapshot, target int) bool { func (c *KVCache) Restore(snapshot Snapshot, target int) bool {
if target < 0 {
return false
}
if snapshot == nil { if snapshot == nil {
// Rewind using live state — just clamp offset. if target > c.offset {
target = max(0, min(target, c.offset)) return false
}
c.offset = target c.offset = target
return true return true
} }
snap := snapshot.(*kvSnapshot) snap := snapshot.(*kvSnapshot)
// Check that the cache has data up to the snapshot's starting point. if target > snap.toOffset || c.offset < snap.fromOffset {
if c.offset < snap.fromOffset {
return false return false
} }
@@ -354,7 +359,14 @@ func (c *RotatingKVCache) Snapshot(fromOffset int) Snapshot {
} }
func (c *RotatingKVCache) Restore(snapshot Snapshot, target int) bool { func (c *RotatingKVCache) Restore(snapshot Snapshot, target int) bool {
if target < 0 {
return false
}
if snapshot == nil { if snapshot == nil {
if target >= c.offset {
return target == c.offset
}
// Live rewind is only safe when the buffer hasn't filled yet // Live rewind is only safe when the buffer hasn't filled yet
// (offset <= maxSize). Once the window has shifted, rewinding // (offset <= maxSize). Once the window has shifted, rewinding
// leaves fewer than maxSize trailing tokens to attend to — // leaves fewer than maxSize trailing tokens to attend to —
@@ -362,7 +374,6 @@ func (c *RotatingKVCache) Restore(snapshot Snapshot, target int) bool {
if c.offset > c.maxSize { if c.offset > c.maxSize {
return false return false
} }
target = max(0, min(target, c.offset))
c.offset = target c.offset = target
c.idx = target c.idx = target
return true return true
@@ -370,6 +381,10 @@ func (c *RotatingKVCache) Restore(snapshot Snapshot, target int) bool {
snap := snapshot.(*rotatingSnapshot) snap := snapshot.(*rotatingSnapshot)
if target > snap.toOffset {
return false
}
// Reject if clamping would leave an incomplete window. // Reject if clamping would leave an incomplete window.
if target < snap.toOffset && snap.toOffset > c.maxSize { if target < snap.toOffset && snap.toOffset > c.maxSize {
return false return false
@@ -388,7 +403,6 @@ func (c *RotatingKVCache) Restore(snapshot Snapshot, target int) bool {
// Clamp to target if needed. // Clamp to target if needed.
if target < c.offset { if target < c.offset {
target = max(0, target)
c.offset = target c.offset = target
c.idx = target c.idx = target
} }

View File

@@ -22,14 +22,9 @@ func (c *RecurrentCache) setStateRaw(old, v *mlx.Array) *mlx.Array {
if v == nil || !v.Valid() { if v == nil || !v.Valid() {
return old return old
} }
if old == v {
return old
}
mlx.Pin(v) mlx.Pin(v)
if old != nil && old != v { mlx.Unpin(old)
mlx.Unpin(old)
}
return v return v
} }
@@ -38,9 +33,6 @@ func (c *RecurrentCache) setStateDetached(old, v *mlx.Array, ensureContiguous bo
if v == nil || !v.Valid() { if v == nil || !v.Valid() {
return old return old
} }
if old == v {
return old
}
root := v root := v
if ensureContiguous { if ensureContiguous {
@@ -49,9 +41,7 @@ func (c *RecurrentCache) setStateDetached(old, v *mlx.Array, ensureContiguous bo
detached := root.Clone() detached := root.Clone()
mlx.Pin(detached) mlx.Pin(detached)
if old != nil && old != detached { mlx.Unpin(old)
mlx.Unpin(old)
}
return detached return detached
} }
@@ -150,10 +140,10 @@ func (c *RecurrentCache) Restore(snapshot Snapshot, target int) bool {
snap := snapshot.(*recurrentSnapshot) snap := snapshot.(*recurrentSnapshot)
// Recurrent state encodes all tokens up to snap.offset. Restoring // Recurrent snapshots encode cumulative state up to exactly
// to a target before that would leave stale state from tokens // snap.offset. Target must match — rewinding would leave stale
// [target, snap.offset) baked in. Only allow restoring forward. // state, and advancing isn't possible without feeding tokens.
if target < snap.offset { if target != snap.offset {
return false return false
} }

View File

@@ -6,39 +6,35 @@ import (
"github.com/ollama/ollama/x/mlxrunner/mlx" "github.com/ollama/ollama/x/mlxrunner/mlx"
) )
// TestRecurrentCacheRestoreDirectionality verifies that RecurrentCache only // TestRecurrentCacheRestoreExactOffset verifies that RecurrentCache restore
// allows restoring forward (target >= snapshot offset), never backward. // only succeeds when target exactly matches the snapshot's offset. Recurrent
func TestRecurrentCacheRestoreDirectionality(t *testing.T) { // state is cumulative, so it can't be rewound or fast-forwarded.
func TestRecurrentCacheRestoreExactOffset(t *testing.T) {
skipIfNoMLX(t) skipIfNoMLX(t)
c := NewRecurrentCache(3, 12, 4, 8, 8) c := NewRecurrentCache(3, 12, 4, 8, 8)
_ = c.ConvState(1, mlx.DTypeFloat16) _ = c.ConvState(1, mlx.DTypeFloat16)
_ = c.DeltaState(1, mlx.DTypeFloat16) _ = c.DeltaState(1, mlx.DTypeFloat16)
c.Advance(10) c.Advance(10)
snap := c.Snapshot(0) snap := c.Snapshot(0) // snap.offset == 10
c.Advance(5) // now at 15 c.Advance(5) // cache now at 15
// Restore backward should fail. // target < snap.offset: fails (can't rewind past snapshot)
if c.Restore(snap, 5) { if c.Restore(snap, 5) {
t.Fatal("Restore(snap, 5) should fail — target < snap.offset") t.Fatal("Restore(snap, 5) should fail — target != snap.offset")
} }
// Restore to exact snap offset should succeed. // target > snap.offset: fails (can't advance without feeding tokens)
if c.Restore(snap, 15) {
t.Fatal("Restore(snap, 15) should fail — target != snap.offset")
}
// target == snap.offset: succeeds
if !c.Restore(snap, 10) { if !c.Restore(snap, 10) {
t.Fatal("Restore(snap, 10) should succeed") t.Fatal("Restore(snap, 10) should succeed — target == snap.offset")
} }
if c.Offset() != 10 { if c.Offset() != 10 {
t.Fatalf("offset = %d, want 10", c.Offset()) t.Fatalf("offset = %d, want 10", c.Offset())
} }
// Restore forward (target > snap offset) should succeed, offset = snap.offset.
snap2 := c.Snapshot(0)
if !c.Restore(snap2, 15) {
t.Fatal("Restore(snap, 15) should succeed")
}
// Recurrent state is at snap.offset (10), not target (15).
if c.Offset() != 10 {
t.Fatalf("offset = %d, want 10 (snap offset)", c.Offset())
}
} }

View File

@@ -79,20 +79,20 @@ func (c *fakeRewindableCache) Snapshot(fromOffset int) cache.Snapshot {
} }
func (c *fakeRewindableCache) Restore(snapshot cache.Snapshot, target int) bool { func (c *fakeRewindableCache) Restore(snapshot cache.Snapshot, target int) bool {
if target < 0 {
return false
}
if snapshot == nil { if snapshot == nil {
// Rewind live state.
if target < 0 {
target = 0
}
if target > len(c.tokens) { if target > len(c.tokens) {
target = len(c.tokens) return false
} }
c.tokens = c.tokens[:target] c.tokens = c.tokens[:target]
return true return true
} }
s := snapshot.(*fakeSnapshot) s := snapshot.(*fakeSnapshot)
if len(c.tokens) < s.from { if target > s.to || len(c.tokens) < s.from {
return false // don't have base data up to snapshot start return false
} }
c.tokens = append(c.tokens[:s.from], s.tokens...) c.tokens = append(c.tokens[:s.from], s.tokens...)
if target < len(c.tokens) { if target < len(c.tokens) {
@@ -196,9 +196,13 @@ func (c *fakeSlidingWindowCache) Snapshot(fromOffset int) cache.Snapshot {
} }
func (c *fakeSlidingWindowCache) Restore(snapshot cache.Snapshot, target int) bool { func (c *fakeSlidingWindowCache) Restore(snapshot cache.Snapshot, target int) bool {
if target < 0 {
return false
}
if snapshot == nil { if snapshot == nil {
if target == len(c.tokens) { if target >= len(c.tokens) {
return true return target == len(c.tokens)
} }
// Live rewind only works when buffer hasn't filled (offset <= maxSize). // Live rewind only works when buffer hasn't filled (offset <= maxSize).
if len(c.tokens) > c.maxSize { if len(c.tokens) > c.maxSize {
@@ -208,6 +212,14 @@ func (c *fakeSlidingWindowCache) Restore(snapshot cache.Snapshot, target int) bo
return true return true
} }
s := snapshot.(*fakeSnapshot) s := snapshot.(*fakeSnapshot)
if target > s.to {
return false
}
// Reject if clamping would leave an incomplete window
// (matches RotatingKVCache behavior).
if target < s.to && s.to > c.maxSize {
return false
}
c.tokens = slices.Clone(s.tokens) c.tokens = slices.Clone(s.tokens)
if target < len(c.tokens) { if target < len(c.tokens) {
c.tokens = c.tokens[:target] c.tokens = c.tokens[:target]
@@ -268,8 +280,8 @@ func (c *fakeRecurrentCache) Restore(snapshot cache.Snapshot, target int) bool {
return target == len(c.tokens) // can only no-op return target == len(c.tokens) // can only no-op
} }
s := snapshot.(*fakeSnapshot) s := snapshot.(*fakeSnapshot)
if target < s.to { if target != s.to {
return false // can't go backward return false // cumulative state requires exact match
} }
c.tokens = slices.Clone(s.tokens) c.tokens = slices.Clone(s.tokens)
return true return true
@@ -294,9 +306,10 @@ type feedableCache interface {
// testEnv encapsulates a kvCache and its fake caches for a test scenario. // testEnv encapsulates a kvCache and its fake caches for a test scenario.
type testEnv struct { type testEnv struct {
kvc *kvCache kvc *kvCache
caches []cache.Cache // typed references for assertions caches []cache.Cache // typed references for assertions
tracker *snapshotTracker tracker *snapshotTracker
rewindable bool // true when all caches support arbitrary Restore(nil, target)
} }
// newTransformerEnv creates a test environment with a single rewindable cache // newTransformerEnv creates a test environment with a single rewindable cache
@@ -305,23 +318,28 @@ func newTransformerEnv() *testEnv {
tracker := &snapshotTracker{} tracker := &snapshotTracker{}
caches := []cache.Cache{&fakeRewindableCache{tracker: tracker}} caches := []cache.Cache{&fakeRewindableCache{tracker: tracker}}
return &testEnv{ return &testEnv{
kvc: &kvCache{caches: caches}, kvc: &kvCache{caches: caches},
caches: caches, caches: caches,
tracker: tracker, tracker: tracker,
rewindable: true,
} }
} }
// newSlidingWindowEnv creates a test environment with one rewindable cache and // newSlidingWindowEnv creates a test environment with one rewindable cache and
// one sliding window cache (Mistral-style architecture). // one sliding window cache (Mistral-style architecture). The sliding window
// maxSize is set small enough that test sequences fill it, making
// Restore(nil, target) fail — the same behavior as production models where
// the window fills after a few turns.
func newSlidingWindowEnv() *testEnv { func newSlidingWindowEnv() *testEnv {
tr := &snapshotTracker{} tr := &snapshotTracker{}
rc := &fakeRewindableCache{tracker: tr} rc := &fakeRewindableCache{tracker: tr}
sw := &fakeSlidingWindowCache{maxSize: 32, tracker: tr} sw := &fakeSlidingWindowCache{maxSize: 4, tracker: tr}
caches := []cache.Cache{rc, sw} caches := []cache.Cache{rc, sw}
return &testEnv{ return &testEnv{
kvc: &kvCache{caches: caches}, kvc: &kvCache{caches: caches},
caches: caches, caches: caches,
tracker: tr, tracker: tr,
rewindable: false,
} }
} }
@@ -333,9 +351,10 @@ func newRecurrentEnv() *testEnv {
nrc := &fakeRecurrentCache{tracker: tr} nrc := &fakeRecurrentCache{tracker: tr}
caches := []cache.Cache{rc, nrc} caches := []cache.Cache{rc, nrc}
return &testEnv{ return &testEnv{
kvc: &kvCache{caches: caches}, kvc: &kvCache{caches: caches},
caches: caches, caches: caches,
tracker: tr, tracker: tr,
rewindable: false,
} }
} }
@@ -590,15 +609,24 @@ func TestBranchCreationAndReuse(t *testing.T) {
} }
// Request B: [1,2,3,4,5,10,11,12] — shares 5-token prefix with A. // Request B: [1,2,3,4,5,10,11,12] — shares 5-token prefix with A.
// Partial match in A's edge triggers snapshotOffset. // For rewindable caches, switchToPath rewinds to the match point
// so only the non-matching suffix needs evaluation. For non-rewindable
// caches (RecurrentCache), the rewind fails and freeAll fires.
resB := simulateRequest(t, kvc, []int32{1, 2, 3, 4, 5, 10, 11, 12}, []int32{30, 31}) resB := simulateRequest(t, kvc, []int32{1, 2, 3, 4, 5, 10, 11, 12}, []int32{30, 31})
if resB.snapshotOffset != 5 { if env.rewindable {
t.Fatalf("B: snapshotOffset = %d, want 5", resB.snapshotOffset) if resB.snapshotOffset != 0 {
} t.Fatalf("B: snapshotOffset = %d, want 0 (rewind succeeded)", resB.snapshotOffset)
// Cache was rewound to 0 (partial match truncates path to root), }
// so all tokens were re-evaluated. if len(resB.remaining) != 3 {
if len(resB.remaining) != 8 { t.Fatalf("B: remaining = %d, want 3 (rewind to match point)", len(resB.remaining))
t.Fatalf("B: remaining = %d, want 8", len(resB.remaining)) }
} else {
if resB.snapshotOffset != 5 {
t.Fatalf("B: snapshotOffset = %d, want 5", resB.snapshotOffset)
}
if len(resB.remaining) != 8 {
t.Fatalf("B: remaining = %d, want 8 (freeAll fallback)", len(resB.remaining))
}
} }
env.assertAllTokens(t, "after B", []int32{1, 2, 3, 4, 5, 10, 11, 12, 30, 31}) env.assertAllTokens(t, "after B", []int32{1, 2, 3, 4, 5, 10, 11, 12, 30, 31})
@@ -635,14 +663,24 @@ func TestExactMatchSeedBehavior(t *testing.T) {
simulateRequest(t, kvc, []int32{1, 2, 3, 4, 5}, []int32{10, 11}) simulateRequest(t, kvc, []int32{1, 2, 3, 4, 5}, []int32{10, 11})
// Request B: identical prompt. Holdback means matched=4, partial in // Request B: identical prompt. Holdback means matched=4, partial in
// the 5-token edge, so path truncates to root and all tokens are // the 5-token edge. For rewindable caches, switchToPath rewinds to
// re-evaluated. snapshotOffset should be set at the holdback point. // offset 4, so only the held-back token needs re-evaluation. For
// non-rewindable caches, the rewind fails and freeAll fires.
resB := simulateRequest(t, kvc, []int32{1, 2, 3, 4, 5}, []int32{20, 21}) resB := simulateRequest(t, kvc, []int32{1, 2, 3, 4, 5}, []int32{20, 21})
if len(resB.remaining) != 5 { if env.rewindable {
t.Fatalf("B: remaining = %d, want 5 (full re-eval due to holdback)", len(resB.remaining)) if len(resB.remaining) != 1 {
} t.Fatalf("B: remaining = %d, want 1 (rewind to holdback point)", len(resB.remaining))
if resB.snapshotOffset != 4 { }
t.Fatalf("B: snapshotOffset = %d, want 4", resB.snapshotOffset) if resB.snapshotOffset != 0 {
t.Fatalf("B: snapshotOffset = %d, want 0 (rewind succeeded)", resB.snapshotOffset)
}
} else {
if len(resB.remaining) != 5 {
t.Fatalf("B: remaining = %d, want 5 (freeAll fallback)", len(resB.remaining))
}
if resB.snapshotOffset != 4 {
t.Fatalf("B: snapshotOffset = %d, want 4", resB.snapshotOffset)
}
} }
env.assertAllTokens(t, "after B", []int32{1, 2, 3, 4, 5, 20, 21}) env.assertAllTokens(t, "after B", []int32{1, 2, 3, 4, 5, 20, 21})

View File

@@ -230,6 +230,9 @@ func (c *Client) Completion(ctx context.Context, req llm.CompletionRequest, fn f
resp, err := c.client.Do(httpReq) resp, err := c.client.Do(httpReq)
if err != nil { if err != nil {
if errMsg := c.status.getLastErr(); errMsg != "" {
return fmt.Errorf("mlx runner failed: %s", errMsg)
}
return err return err
} }
defer resp.Body.Close() defer resp.Body.Close()
@@ -267,7 +270,13 @@ func (c *Client) Completion(ctx context.Context, req llm.CompletionRequest, fn f
} }
} }
return scanner.Err() if err := scanner.Err(); err != nil {
if errMsg := c.status.getLastErr(); errMsg != "" {
return fmt.Errorf("mlx runner failed: %s", errMsg)
}
return err
}
return nil
} }
func (c *Client) ContextLength() int { func (c *Client) ContextLength() int {

View File

@@ -15,7 +15,9 @@ set(CMAKE_INSTALL_RPATH "@loader_path")
include(FetchContent) include(FetchContent)
set(MLX_C_GIT_TAG "v0.5.0" CACHE STRING "") # Read MLX-C version from top-level file (shared with imagegen CMakeLists)
file(READ "${CMAKE_SOURCE_DIR}/MLX_C_VERSION" MLX_C_GIT_TAG)
string(STRIP "${MLX_C_GIT_TAG}" MLX_C_GIT_TAG)
FetchContent_Declare( FetchContent_Declare(
mlx-c mlx-c

View File

@@ -137,6 +137,9 @@ func Unpin(s ...*Array) {
for _, t := range s { for _, t := range s {
if t != nil { if t != nil {
t.pinned-- t.pinned--
if t.pinned < 0 {
panic(fmt.Sprintf("mlx.Unpin: negative pin count on array %q", t.name))
}
} }
} }
} }
@@ -261,7 +264,7 @@ func LogArrays() {
for _, t := range arrays { for _, t := range arrays {
nb := t.NumBytes() nb := t.NumBytes()
logutil.Trace(fmt.Sprintf("tensor %-60s %5s %5s %v", t.name, t.DType(), PrettyBytes(nb), t.Dims())) logutil.Trace(fmt.Sprintf("tensor %-60s %5s %5s pinned=%d %v", t.name, t.DType(), PrettyBytes(nb), t.pinned, t.Dims()))
} }
logutil.Trace(fmt.Sprintf("tensors total: %d, size: %s", len(arrays), PrettyBytes(ActiveMemory()))) logutil.Trace(fmt.Sprintf("tensors total: %d, size: %s", len(arrays), PrettyBytes(ActiveMemory())))
} }

View File

@@ -13,6 +13,10 @@ var (
gatedDeltaMetalKernelOnce sync.Once gatedDeltaMetalKernelOnce sync.Once
gatedDeltaMetalKernel C.mlx_fast_metal_kernel gatedDeltaMetalKernel C.mlx_fast_metal_kernel
gatedDeltaMetalDisabled bool gatedDeltaMetalDisabled bool
gatedDeltaCUDAKernelOnce sync.Once
gatedDeltaCUDAKernel C.mlx_fast_cuda_kernel
gatedDeltaCUDADisabled bool
) )
const gatedDeltaMetalKernelSource = ` const gatedDeltaMetalKernelSource = `
@@ -83,6 +87,86 @@ for (int i = 0; i < n_per_t; ++i) {
} }
` `
const gatedDeltaCUDAKernelSource = `
auto tid_x = threadIdx.x;
auto tid_y = threadIdx.y;
auto grid_y = blockIdx.y * blockDim.y + tid_y;
auto grid_z = blockIdx.z;
int T_val = static_cast<int>(*T);
auto n = grid_z;
auto b_idx = n / Hv;
auto hv_idx = n % Hv;
auto hk_idx = hv_idx / (Hv / Hk);
constexpr int n_per_t = Dk / 32;
// q, k: [B, T, Hk, Dk]
auto q_ = q + b_idx * T_val * Hk * Dk + hk_idx * Dk;
auto k_ = k + b_idx * T_val * Hk * Dk + hk_idx * Dk;
// v, y: [B, T, Hv, Dv]
auto dv_idx = grid_y;
auto v_ = v + b_idx * T_val * Hv * Dv + hv_idx * Dv;
y += b_idx * T_val * Hv * Dv + hv_idx * Dv;
auto dk_idx = tid_x;
// state_in, state_out: [B, Hv, Dv, Dk]
auto i_state = state_in + (n * Dv + dv_idx) * Dk;
auto o_state = state_out + (n * Dv + dv_idx) * Dk;
float state[n_per_t];
for (int i = 0; i < n_per_t; ++i) {
auto s_idx = n_per_t * dk_idx + i;
state[i] = static_cast<float>(i_state[s_idx]);
}
// g: [B, T, Hv]
auto g_ = g + b_idx * T_val * Hv;
auto beta_ = beta + b_idx * T_val * Hv;
for (int t = 0; t < T_val; ++t) {
float kv_mem = 0.0f;
for (int i = 0; i < n_per_t; ++i) {
auto s_idx = n_per_t * dk_idx + i;
state[i] = state[i] * static_cast<float>(g_[hv_idx]);
kv_mem += state[i] * static_cast<float>(k_[s_idx]);
}
// Warp reduction (full warp, 32 threads in x)
for (int offset = 16; offset > 0; offset >>= 1)
kv_mem += __shfl_down_sync(0xffffffff, kv_mem, offset);
kv_mem = __shfl_sync(0xffffffff, kv_mem, 0);
auto delta = (static_cast<float>(v_[dv_idx]) - kv_mem) * static_cast<float>(beta_[hv_idx]);
float out = 0.0f;
for (int i = 0; i < n_per_t; ++i) {
auto s_idx = n_per_t * dk_idx + i;
state[i] = state[i] + static_cast<float>(k_[s_idx]) * delta;
out += state[i] * static_cast<float>(q_[s_idx]);
}
// Warp reduction
for (int offset = 16; offset > 0; offset >>= 1)
out += __shfl_down_sync(0xffffffff, out, offset);
if (tid_x == 0) {
y[dv_idx] = static_cast<InT>(out);
}
q_ += Hk * Dk;
k_ += Hk * Dk;
v_ += Hv * Dv;
y += Hv * Dv;
g_ += Hv;
beta_ += Hv;
}
for (int i = 0; i < n_per_t; ++i) {
auto s_idx = n_per_t * dk_idx + i;
o_state[s_idx] = static_cast<InT>(state[i]);
}
`
func cStringVector(values []string) (C.mlx_vector_string, func(), bool) { func cStringVector(values []string) (C.mlx_vector_string, func(), bool) {
vec := C.mlx_vector_string_new() vec := C.mlx_vector_string_new()
ok := true ok := true
@@ -352,11 +436,184 @@ func gatedDeltaFallback(q, k, v, g, beta, state *Array) (y, nextState *Array) {
return Concatenate(outs, 1), nextState return Concatenate(outs, 1), nextState
} }
func initGatedDeltaCUDAKernel() {
var cudaAvail C.bool
if C.mlx_cuda_is_available(&cudaAvail) != 0 || !bool(cudaAvail) {
gatedDeltaCUDADisabled = true
return
}
inputs, freeInputs, ok := cStringVector([]string{"q", "k", "v", "g", "beta", "state_in", "T"})
if !ok {
gatedDeltaCUDADisabled = true
freeInputs()
return
}
defer freeInputs()
outputs, freeOutputs, ok := cStringVector([]string{"y", "state_out"})
if !ok {
gatedDeltaCUDADisabled = true
freeOutputs()
return
}
defer freeOutputs()
cName := C.CString("gated_delta_step")
defer C.free(unsafe.Pointer(cName))
cSource := C.CString(gatedDeltaCUDAKernelSource)
defer C.free(unsafe.Pointer(cSource))
cHeader := C.CString("")
defer C.free(unsafe.Pointer(cHeader))
gatedDeltaCUDAKernel = C.mlx_fast_cuda_kernel_new(
cName,
inputs,
outputs,
cSource,
cHeader,
C.bool(true),
C.int(0),
)
}
func gatedDeltaCUDAKernelApply(q, k, v, g, beta, state *Array) (y, nextState *Array, ok bool) {
if gatedDeltaCUDADisabled {
return nil, nil, false
}
if q == nil || k == nil || v == nil || g == nil || beta == nil || state == nil {
return nil, nil, false
}
qd := q.Dims()
kd := k.Dims()
vd := v.Dims()
gd := g.Dims()
bd := beta.Dims()
sd := state.Dims()
if len(qd) != 4 || len(kd) != 4 || len(vd) != 4 || len(gd) != 3 || len(bd) != 3 || len(sd) != 4 {
return nil, nil, false
}
B, T, Hk, Dk := qd[0], qd[1], qd[2], qd[3]
if T <= 0 || Hk <= 0 || Dk <= 0 || Dk%32 != 0 {
return nil, nil, false
}
if kd[0] != B || kd[1] != T || kd[2] != Hk || kd[3] != Dk {
return nil, nil, false
}
Hv, Dv := vd[2], vd[3]
if vd[0] != B || vd[1] != T || Hv <= 0 || Dv <= 0 || Hv%Hk != 0 {
return nil, nil, false
}
if gd[0] != B || gd[1] != T || gd[2] != Hv {
return nil, nil, false
}
if bd[0] != B || bd[1] != T || bd[2] != Hv {
return nil, nil, false
}
if sd[0] != B || sd[1] != Hv || sd[2] != Dv || sd[3] != Dk {
return nil, nil, false
}
dtype := q.DType()
if k.DType() != dtype || v.DType() != dtype || g.DType() != dtype || beta.DType() != dtype || state.DType() != dtype {
return nil, nil, false
}
gatedDeltaCUDAKernelOnce.Do(initGatedDeltaCUDAKernel)
if gatedDeltaCUDADisabled {
return nil, nil, false
}
cfg := C.mlx_fast_cuda_kernel_config_new()
defer C.mlx_fast_cuda_kernel_config_free(cfg)
cInT := C.CString("InT")
defer C.free(unsafe.Pointer(cInT))
if C.mlx_fast_cuda_kernel_config_add_template_arg_dtype(cfg, cInT, C.mlx_dtype(dtype)) != 0 {
gatedDeltaCUDADisabled = true
return nil, nil, false
}
for _, tpl := range []struct {
name string
value int
}{
{name: "Dk", value: Dk},
{name: "Dv", value: Dv},
{name: "Hk", value: Hk},
{name: "Hv", value: Hv},
} {
cn := C.CString(tpl.name)
rc := C.mlx_fast_cuda_kernel_config_add_template_arg_int(cfg, cn, C.int(tpl.value))
C.free(unsafe.Pointer(cn))
if rc != 0 {
gatedDeltaCUDADisabled = true
return nil, nil, false
}
}
yShape := []C.int{C.int(B), C.int(T), C.int(Hv), C.int(Dv)}
stateShape := []C.int{C.int(B), C.int(Hv), C.int(Dv), C.int(Dk)}
if C.mlx_fast_cuda_kernel_config_add_output_arg(cfg, unsafe.SliceData(yShape), C.size_t(len(yShape)), C.mlx_dtype(dtype)) != 0 {
gatedDeltaCUDADisabled = true
return nil, nil, false
}
if C.mlx_fast_cuda_kernel_config_add_output_arg(cfg, unsafe.SliceData(stateShape), C.size_t(len(stateShape)), C.mlx_dtype(dtype)) != 0 {
gatedDeltaCUDADisabled = true
return nil, nil, false
}
if C.mlx_fast_cuda_kernel_config_set_grid(cfg, 32, C.int(Dv), C.int(B*Hv)) != 0 {
gatedDeltaCUDADisabled = true
return nil, nil, false
}
threadY := Dv
if threadY > 4 {
threadY = 4
}
if C.mlx_fast_cuda_kernel_config_set_thread_group(cfg, 32, C.int(threadY), 1) != 0 {
gatedDeltaCUDADisabled = true
return nil, nil, false
}
tScalar := FromValue(T)
inputs := []C.mlx_array{
q.ctx,
k.ctx,
v.ctx,
g.ctx,
beta.ctx,
state.ctx,
tScalar.ctx,
}
inVec := C.mlx_vector_array_new_data(unsafe.SliceData(inputs), C.size_t(len(inputs)))
defer C.mlx_vector_array_free(inVec)
outVec := C.mlx_vector_array_new()
defer C.mlx_vector_array_free(outVec)
if C.mlx_fast_cuda_kernel_apply(&outVec, gatedDeltaCUDAKernel, inVec, cfg, DefaultStream().ctx) != 0 {
gatedDeltaCUDADisabled = true
return nil, nil, false
}
if int(C.mlx_vector_array_size(outVec)) < 2 {
return nil, nil, false
}
y = New("GATED_DELTA_CUDA_Y")
nextState = New("GATED_DELTA_CUDA_STATE")
C.mlx_vector_array_get(&y.ctx, outVec, 0)
C.mlx_vector_array_get(&nextState.ctx, outVec, 1)
return y, nextState, true
}
// GatedDelta runs the recurrent update operation. // GatedDelta runs the recurrent update operation.
// //
// It uses the fused Metal kernel when available and otherwise falls back to a // It tries the fused CUDA kernel first, then Metal, then falls back to a
// backend-agnostic MLX implementation with identical inputs/outputs. // backend-agnostic MLX implementation with identical inputs/outputs.
func GatedDelta(q, k, v, g, beta, state *Array) (y, nextState *Array) { func GatedDelta(q, k, v, g, beta, state *Array) (y, nextState *Array) {
if y, nextState, ok := gatedDeltaCUDAKernelApply(q, k, v, g, beta, state); ok {
return y, nextState
}
if y, nextState, ok := gatedDeltaKernel(q, k, v, g, beta, state); ok { if y, nextState, ok := gatedDeltaKernel(q, k, v, g, beta, state); ok {
return y, nextState return y, nextState
} }

View File

@@ -326,8 +326,10 @@ int (*mlx_distributed_sum_scatter_)(
int (*mlx_distributed_group_rank_)(mlx_distributed_group group) = NULL; int (*mlx_distributed_group_rank_)(mlx_distributed_group group) = NULL;
int (*mlx_distributed_group_size_)(mlx_distributed_group group) = NULL; int (*mlx_distributed_group_size_)(mlx_distributed_group group) = NULL;
mlx_distributed_group (*mlx_distributed_group_split_)(mlx_distributed_group group, int color, int key) = NULL; mlx_distributed_group (*mlx_distributed_group_split_)(mlx_distributed_group group, int color, int key) = NULL;
bool (*mlx_distributed_is_available_)(void) = NULL; bool (*mlx_distributed_is_available_)(const char* bk /* may be null */) = NULL;
mlx_distributed_group (*mlx_distributed_init_)(bool strict) = NULL; mlx_distributed_group (*mlx_distributed_init_)(
bool strict,
const char* bk /* may be null */) = NULL;
void (*mlx_set_error_handler_)( void (*mlx_set_error_handler_)(
mlx_error_handler_func handler, mlx_error_handler_func handler,
void* data, void* data,
@@ -924,6 +926,7 @@ int (*mlx_astype_)(
int (*mlx_atleast_1d_)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; int (*mlx_atleast_1d_)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
int (*mlx_atleast_2d_)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; int (*mlx_atleast_2d_)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
int (*mlx_atleast_3d_)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; int (*mlx_atleast_3d_)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
int (*mlx_bartlett_)(mlx_array* res, int M, const mlx_stream s) = NULL;
int (*mlx_bitwise_and_)( int (*mlx_bitwise_and_)(
mlx_array* res, mlx_array* res,
const mlx_array a, const mlx_array a,
@@ -940,6 +943,7 @@ int (*mlx_bitwise_xor_)(
const mlx_array a, const mlx_array a,
const mlx_array b, const mlx_array b,
const mlx_stream s) = NULL; const mlx_stream s) = NULL;
int (*mlx_blackman_)(mlx_array* res, int M, const mlx_stream s) = NULL;
int (*mlx_block_masked_mm_)( int (*mlx_block_masked_mm_)(
mlx_array* res, mlx_array* res,
const mlx_array a, const mlx_array a,
@@ -1120,6 +1124,7 @@ int (*mlx_dequantize_)(
mlx_optional_int group_size, mlx_optional_int group_size,
mlx_optional_int bits, mlx_optional_int bits,
const char* mode, const char* mode,
const mlx_array global_scale /* may be null */,
mlx_optional_dtype dtype, mlx_optional_dtype dtype,
const mlx_stream s) = NULL; const mlx_stream s) = NULL;
int (*mlx_diag_)(mlx_array* res, const mlx_array a, int k, const mlx_stream s) = NULL; int (*mlx_diag_)(mlx_array* res, const mlx_array a, int k, const mlx_stream s) = NULL;
@@ -1256,6 +1261,8 @@ int (*mlx_hadamard_transform_)(
const mlx_array a, const mlx_array a,
mlx_optional_float scale, mlx_optional_float scale,
const mlx_stream s) = NULL; const mlx_stream s) = NULL;
int (*mlx_hamming_)(mlx_array* res, int M, const mlx_stream s) = NULL;
int (*mlx_hanning_)(mlx_array* res, int M, const mlx_stream s) = NULL;
int (*mlx_identity_)(mlx_array* res, int n, mlx_dtype dtype, const mlx_stream s) = NULL; int (*mlx_identity_)(mlx_array* res, int n, mlx_dtype dtype, const mlx_stream s) = NULL;
int (*mlx_imag_)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; int (*mlx_imag_)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
int (*mlx_inner_)( int (*mlx_inner_)(
@@ -1548,6 +1555,8 @@ int (*mlx_qqmm_)(
mlx_optional_int group_size, mlx_optional_int group_size,
mlx_optional_int bits, mlx_optional_int bits,
const char* mode, const char* mode,
const mlx_array global_scale_x /* may be null */,
const mlx_array global_scale_w /* may be null */,
const mlx_stream s) = NULL; const mlx_stream s) = NULL;
int (*mlx_quantize_)( int (*mlx_quantize_)(
mlx_vector_array* res, mlx_vector_array* res,
@@ -1555,6 +1564,7 @@ int (*mlx_quantize_)(
mlx_optional_int group_size, mlx_optional_int group_size,
mlx_optional_int bits, mlx_optional_int bits,
const char* mode, const char* mode,
const mlx_array global_scale /* may be null */,
const mlx_stream s) = NULL; const mlx_stream s) = NULL;
int (*mlx_quantized_matmul_)( int (*mlx_quantized_matmul_)(
mlx_array* res, mlx_array* res,
@@ -2550,10 +2560,12 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) {
CHECK_LOAD(handle, mlx_atleast_1d); CHECK_LOAD(handle, mlx_atleast_1d);
CHECK_LOAD(handle, mlx_atleast_2d); CHECK_LOAD(handle, mlx_atleast_2d);
CHECK_LOAD(handle, mlx_atleast_3d); CHECK_LOAD(handle, mlx_atleast_3d);
CHECK_LOAD(handle, mlx_bartlett);
CHECK_LOAD(handle, mlx_bitwise_and); CHECK_LOAD(handle, mlx_bitwise_and);
CHECK_LOAD(handle, mlx_bitwise_invert); CHECK_LOAD(handle, mlx_bitwise_invert);
CHECK_LOAD(handle, mlx_bitwise_or); CHECK_LOAD(handle, mlx_bitwise_or);
CHECK_LOAD(handle, mlx_bitwise_xor); CHECK_LOAD(handle, mlx_bitwise_xor);
CHECK_LOAD(handle, mlx_blackman);
CHECK_LOAD(handle, mlx_block_masked_mm); CHECK_LOAD(handle, mlx_block_masked_mm);
CHECK_LOAD(handle, mlx_broadcast_arrays); CHECK_LOAD(handle, mlx_broadcast_arrays);
CHECK_LOAD(handle, mlx_broadcast_to); CHECK_LOAD(handle, mlx_broadcast_to);
@@ -2606,6 +2618,8 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) {
CHECK_LOAD(handle, mlx_greater); CHECK_LOAD(handle, mlx_greater);
CHECK_LOAD(handle, mlx_greater_equal); CHECK_LOAD(handle, mlx_greater_equal);
CHECK_LOAD(handle, mlx_hadamard_transform); CHECK_LOAD(handle, mlx_hadamard_transform);
CHECK_LOAD(handle, mlx_hamming);
CHECK_LOAD(handle, mlx_hanning);
CHECK_LOAD(handle, mlx_identity); CHECK_LOAD(handle, mlx_identity);
CHECK_LOAD(handle, mlx_imag); CHECK_LOAD(handle, mlx_imag);
CHECK_LOAD(handle, mlx_inner); CHECK_LOAD(handle, mlx_inner);

View File

@@ -300,10 +300,12 @@
#define mlx_atleast_1d mlx_atleast_1d_mlx_gen_orig_ #define mlx_atleast_1d mlx_atleast_1d_mlx_gen_orig_
#define mlx_atleast_2d mlx_atleast_2d_mlx_gen_orig_ #define mlx_atleast_2d mlx_atleast_2d_mlx_gen_orig_
#define mlx_atleast_3d mlx_atleast_3d_mlx_gen_orig_ #define mlx_atleast_3d mlx_atleast_3d_mlx_gen_orig_
#define mlx_bartlett mlx_bartlett_mlx_gen_orig_
#define mlx_bitwise_and mlx_bitwise_and_mlx_gen_orig_ #define mlx_bitwise_and mlx_bitwise_and_mlx_gen_orig_
#define mlx_bitwise_invert mlx_bitwise_invert_mlx_gen_orig_ #define mlx_bitwise_invert mlx_bitwise_invert_mlx_gen_orig_
#define mlx_bitwise_or mlx_bitwise_or_mlx_gen_orig_ #define mlx_bitwise_or mlx_bitwise_or_mlx_gen_orig_
#define mlx_bitwise_xor mlx_bitwise_xor_mlx_gen_orig_ #define mlx_bitwise_xor mlx_bitwise_xor_mlx_gen_orig_
#define mlx_blackman mlx_blackman_mlx_gen_orig_
#define mlx_block_masked_mm mlx_block_masked_mm_mlx_gen_orig_ #define mlx_block_masked_mm mlx_block_masked_mm_mlx_gen_orig_
#define mlx_broadcast_arrays mlx_broadcast_arrays_mlx_gen_orig_ #define mlx_broadcast_arrays mlx_broadcast_arrays_mlx_gen_orig_
#define mlx_broadcast_to mlx_broadcast_to_mlx_gen_orig_ #define mlx_broadcast_to mlx_broadcast_to_mlx_gen_orig_
@@ -356,6 +358,8 @@
#define mlx_greater mlx_greater_mlx_gen_orig_ #define mlx_greater mlx_greater_mlx_gen_orig_
#define mlx_greater_equal mlx_greater_equal_mlx_gen_orig_ #define mlx_greater_equal mlx_greater_equal_mlx_gen_orig_
#define mlx_hadamard_transform mlx_hadamard_transform_mlx_gen_orig_ #define mlx_hadamard_transform mlx_hadamard_transform_mlx_gen_orig_
#define mlx_hamming mlx_hamming_mlx_gen_orig_
#define mlx_hanning mlx_hanning_mlx_gen_orig_
#define mlx_identity mlx_identity_mlx_gen_orig_ #define mlx_identity mlx_identity_mlx_gen_orig_
#define mlx_imag mlx_imag_mlx_gen_orig_ #define mlx_imag mlx_imag_mlx_gen_orig_
#define mlx_inner mlx_inner_mlx_gen_orig_ #define mlx_inner mlx_inner_mlx_gen_orig_
@@ -889,10 +893,12 @@
#undef mlx_atleast_1d #undef mlx_atleast_1d
#undef mlx_atleast_2d #undef mlx_atleast_2d
#undef mlx_atleast_3d #undef mlx_atleast_3d
#undef mlx_bartlett
#undef mlx_bitwise_and #undef mlx_bitwise_and
#undef mlx_bitwise_invert #undef mlx_bitwise_invert
#undef mlx_bitwise_or #undef mlx_bitwise_or
#undef mlx_bitwise_xor #undef mlx_bitwise_xor
#undef mlx_blackman
#undef mlx_block_masked_mm #undef mlx_block_masked_mm
#undef mlx_broadcast_arrays #undef mlx_broadcast_arrays
#undef mlx_broadcast_to #undef mlx_broadcast_to
@@ -945,6 +951,8 @@
#undef mlx_greater #undef mlx_greater
#undef mlx_greater_equal #undef mlx_greater_equal
#undef mlx_hadamard_transform #undef mlx_hadamard_transform
#undef mlx_hamming
#undef mlx_hanning
#undef mlx_identity #undef mlx_identity
#undef mlx_imag #undef mlx_imag
#undef mlx_inner #undef mlx_inner
@@ -1501,8 +1509,10 @@ extern int (*mlx_distributed_sum_scatter_)(
extern int (*mlx_distributed_group_rank_)(mlx_distributed_group group); extern int (*mlx_distributed_group_rank_)(mlx_distributed_group group);
extern int (*mlx_distributed_group_size_)(mlx_distributed_group group); extern int (*mlx_distributed_group_size_)(mlx_distributed_group group);
extern mlx_distributed_group (*mlx_distributed_group_split_)(mlx_distributed_group group, int color, int key); extern mlx_distributed_group (*mlx_distributed_group_split_)(mlx_distributed_group group, int color, int key);
extern bool (*mlx_distributed_is_available_)(void); extern bool (*mlx_distributed_is_available_)(const char* bk /* may be null */);
extern mlx_distributed_group (*mlx_distributed_init_)(bool strict); extern mlx_distributed_group (*mlx_distributed_init_)(
bool strict,
const char* bk /* may be null */);
extern void (*mlx_set_error_handler_)( extern void (*mlx_set_error_handler_)(
mlx_error_handler_func handler, mlx_error_handler_func handler,
void* data, void* data,
@@ -2099,6 +2109,7 @@ extern int (*mlx_astype_)(
extern int (*mlx_atleast_1d_)(mlx_array* res, const mlx_array a, const mlx_stream s); extern int (*mlx_atleast_1d_)(mlx_array* res, const mlx_array a, const mlx_stream s);
extern int (*mlx_atleast_2d_)(mlx_array* res, const mlx_array a, const mlx_stream s); extern int (*mlx_atleast_2d_)(mlx_array* res, const mlx_array a, const mlx_stream s);
extern int (*mlx_atleast_3d_)(mlx_array* res, const mlx_array a, const mlx_stream s); extern int (*mlx_atleast_3d_)(mlx_array* res, const mlx_array a, const mlx_stream s);
extern int (*mlx_bartlett_)(mlx_array* res, int M, const mlx_stream s);
extern int (*mlx_bitwise_and_)( extern int (*mlx_bitwise_and_)(
mlx_array* res, mlx_array* res,
const mlx_array a, const mlx_array a,
@@ -2115,6 +2126,7 @@ extern int (*mlx_bitwise_xor_)(
const mlx_array a, const mlx_array a,
const mlx_array b, const mlx_array b,
const mlx_stream s); const mlx_stream s);
extern int (*mlx_blackman_)(mlx_array* res, int M, const mlx_stream s);
extern int (*mlx_block_masked_mm_)( extern int (*mlx_block_masked_mm_)(
mlx_array* res, mlx_array* res,
const mlx_array a, const mlx_array a,
@@ -2295,6 +2307,7 @@ extern int (*mlx_dequantize_)(
mlx_optional_int group_size, mlx_optional_int group_size,
mlx_optional_int bits, mlx_optional_int bits,
const char* mode, const char* mode,
const mlx_array global_scale /* may be null */,
mlx_optional_dtype dtype, mlx_optional_dtype dtype,
const mlx_stream s); const mlx_stream s);
extern int (*mlx_diag_)(mlx_array* res, const mlx_array a, int k, const mlx_stream s); extern int (*mlx_diag_)(mlx_array* res, const mlx_array a, int k, const mlx_stream s);
@@ -2431,6 +2444,8 @@ extern int (*mlx_hadamard_transform_)(
const mlx_array a, const mlx_array a,
mlx_optional_float scale, mlx_optional_float scale,
const mlx_stream s); const mlx_stream s);
extern int (*mlx_hamming_)(mlx_array* res, int M, const mlx_stream s);
extern int (*mlx_hanning_)(mlx_array* res, int M, const mlx_stream s);
extern int (*mlx_identity_)(mlx_array* res, int n, mlx_dtype dtype, const mlx_stream s); extern int (*mlx_identity_)(mlx_array* res, int n, mlx_dtype dtype, const mlx_stream s);
extern int (*mlx_imag_)(mlx_array* res, const mlx_array a, const mlx_stream s); extern int (*mlx_imag_)(mlx_array* res, const mlx_array a, const mlx_stream s);
extern int (*mlx_inner_)( extern int (*mlx_inner_)(
@@ -2723,6 +2738,8 @@ extern int (*mlx_qqmm_)(
mlx_optional_int group_size, mlx_optional_int group_size,
mlx_optional_int bits, mlx_optional_int bits,
const char* mode, const char* mode,
const mlx_array global_scale_x /* may be null */,
const mlx_array global_scale_w /* may be null */,
const mlx_stream s); const mlx_stream s);
extern int (*mlx_quantize_)( extern int (*mlx_quantize_)(
mlx_vector_array* res, mlx_vector_array* res,
@@ -2730,6 +2747,7 @@ extern int (*mlx_quantize_)(
mlx_optional_int group_size, mlx_optional_int group_size,
mlx_optional_int bits, mlx_optional_int bits,
const char* mode, const char* mode,
const mlx_array global_scale /* may be null */,
const mlx_stream s); const mlx_stream s);
extern int (*mlx_quantized_matmul_)( extern int (*mlx_quantized_matmul_)(
mlx_array* res, mlx_array* res,
@@ -4033,11 +4051,13 @@ static inline int mlx_distributed_group_size(mlx_distributed_group group) {
static inline mlx_distributed_group mlx_distributed_group_split(mlx_distributed_group group, int color, int key) { static inline mlx_distributed_group mlx_distributed_group_split(mlx_distributed_group group, int color, int key) {
return mlx_distributed_group_split_(group, color, key); return mlx_distributed_group_split_(group, color, key);
} }
static inline bool mlx_distributed_is_available(void) { static inline bool mlx_distributed_is_available(const char* bk /* may be null */) {
return mlx_distributed_is_available_(); return mlx_distributed_is_available_(bk);
} }
static inline mlx_distributed_group mlx_distributed_init(bool strict) { static inline mlx_distributed_group mlx_distributed_init(
return mlx_distributed_init_(strict); bool strict,
const char* bk /* may be null */) {
return mlx_distributed_init_(strict, bk);
} }
static inline void mlx_set_error_handler( static inline void mlx_set_error_handler(
mlx_error_handler_func handler, mlx_error_handler_func handler,
@@ -4939,6 +4959,9 @@ static inline int mlx_atleast_2d(mlx_array* res, const mlx_array a, const mlx_st
static inline int mlx_atleast_3d(mlx_array* res, const mlx_array a, const mlx_stream s) { static inline int mlx_atleast_3d(mlx_array* res, const mlx_array a, const mlx_stream s) {
return mlx_atleast_3d_(res, a, s); return mlx_atleast_3d_(res, a, s);
} }
static inline int mlx_bartlett(mlx_array* res, int M, const mlx_stream s) {
return mlx_bartlett_(res, M, s);
}
static inline int mlx_bitwise_and( static inline int mlx_bitwise_and(
mlx_array* res, mlx_array* res,
const mlx_array a, const mlx_array a,
@@ -4963,6 +4986,9 @@ static inline int mlx_bitwise_xor(
const mlx_stream s) { const mlx_stream s) {
return mlx_bitwise_xor_(res, a, b, s); return mlx_bitwise_xor_(res, a, b, s);
} }
static inline int mlx_blackman(mlx_array* res, int M, const mlx_stream s) {
return mlx_blackman_(res, M, s);
}
static inline int mlx_block_masked_mm( static inline int mlx_block_masked_mm(
mlx_array* res, mlx_array* res,
const mlx_array a, const mlx_array a,
@@ -5193,9 +5219,10 @@ static inline int mlx_dequantize(
mlx_optional_int group_size, mlx_optional_int group_size,
mlx_optional_int bits, mlx_optional_int bits,
const char* mode, const char* mode,
const mlx_array global_scale /* may be null */,
mlx_optional_dtype dtype, mlx_optional_dtype dtype,
const mlx_stream s) { const mlx_stream s) {
return mlx_dequantize_(res, w, scales, biases, group_size, bits, mode, dtype, s); return mlx_dequantize_(res, w, scales, biases, group_size, bits, mode, global_scale, dtype, s);
} }
static inline int mlx_diag(mlx_array* res, const mlx_array a, int k, const mlx_stream s) { static inline int mlx_diag(mlx_array* res, const mlx_array a, int k, const mlx_stream s) {
return mlx_diag_(res, a, k, s); return mlx_diag_(res, a, k, s);
@@ -5383,6 +5410,12 @@ static inline int mlx_hadamard_transform(
const mlx_stream s) { const mlx_stream s) {
return mlx_hadamard_transform_(res, a, scale, s); return mlx_hadamard_transform_(res, a, scale, s);
} }
static inline int mlx_hamming(mlx_array* res, int M, const mlx_stream s) {
return mlx_hamming_(res, M, s);
}
static inline int mlx_hanning(mlx_array* res, int M, const mlx_stream s) {
return mlx_hanning_(res, M, s);
}
static inline int mlx_identity(mlx_array* res, int n, mlx_dtype dtype, const mlx_stream s) { static inline int mlx_identity(mlx_array* res, int n, mlx_dtype dtype, const mlx_stream s) {
return mlx_identity_(res, n, dtype, s); return mlx_identity_(res, n, dtype, s);
} }
@@ -5793,8 +5826,10 @@ static inline int mlx_qqmm(
mlx_optional_int group_size, mlx_optional_int group_size,
mlx_optional_int bits, mlx_optional_int bits,
const char* mode, const char* mode,
const mlx_array global_scale_x /* may be null */,
const mlx_array global_scale_w /* may be null */,
const mlx_stream s) { const mlx_stream s) {
return mlx_qqmm_(res, x, w, w_scales, group_size, bits, mode, s); return mlx_qqmm_(res, x, w, w_scales, group_size, bits, mode, global_scale_x, global_scale_w, s);
} }
static inline int mlx_quantize( static inline int mlx_quantize(
mlx_vector_array* res, mlx_vector_array* res,
@@ -5802,8 +5837,9 @@ static inline int mlx_quantize(
mlx_optional_int group_size, mlx_optional_int group_size,
mlx_optional_int bits, mlx_optional_int bits,
const char* mode, const char* mode,
const mlx_array global_scale /* may be null */,
const mlx_stream s) { const mlx_stream s) {
return mlx_quantize_(res, w, group_size, bits, mode, s); return mlx_quantize_(res, w, group_size, bits, mode, global_scale, s);
} }
static inline int mlx_quantized_matmul( static inline int mlx_quantized_matmul(
mlx_array* res, mlx_array* res,

View File

@@ -1,7 +1,7 @@
# Vendored MLX-C Headers # Vendored MLX-C Headers
These header files are vendored from [mlx-c](https://github.com/ml-explore/mlx-c). These header files are vendored from [mlx-c](https://github.com/ml-explore/mlx-c).
The pinned version is in `MLX_VERSION` at the repo root. The pinned version is in `MLX_C_VERSION` at the repo root.
Headers are automatically refreshed when you run a CMake build: Headers are automatically refreshed when you run a CMake build:

View File

@@ -42,12 +42,14 @@ mlx_distributed_group_split(mlx_distributed_group group, int color, int key);
/** /**
* Check if distributed is available. * Check if distributed is available.
*/ */
bool mlx_distributed_is_available(void); bool mlx_distributed_is_available(const char* bk /* may be null */);
/** /**
* Initialize distributed. * Initialize distributed.
*/ */
mlx_distributed_group mlx_distributed_init(bool strict); mlx_distributed_group mlx_distributed_init(
bool strict,
const char* bk /* may be null */);
/**@}*/ /**@}*/

View File

@@ -166,6 +166,7 @@ int mlx_astype(
int mlx_atleast_1d(mlx_array* res, const mlx_array a, const mlx_stream s); int mlx_atleast_1d(mlx_array* res, const mlx_array a, const mlx_stream s);
int mlx_atleast_2d(mlx_array* res, const mlx_array a, const mlx_stream s); int mlx_atleast_2d(mlx_array* res, const mlx_array a, const mlx_stream s);
int mlx_atleast_3d(mlx_array* res, const mlx_array a, const mlx_stream s); int mlx_atleast_3d(mlx_array* res, const mlx_array a, const mlx_stream s);
int mlx_bartlett(mlx_array* res, int M, const mlx_stream s);
int mlx_bitwise_and( int mlx_bitwise_and(
mlx_array* res, mlx_array* res,
const mlx_array a, const mlx_array a,
@@ -182,6 +183,7 @@ int mlx_bitwise_xor(
const mlx_array a, const mlx_array a,
const mlx_array b, const mlx_array b,
const mlx_stream s); const mlx_stream s);
int mlx_blackman(mlx_array* res, int M, const mlx_stream s);
int mlx_block_masked_mm( int mlx_block_masked_mm(
mlx_array* res, mlx_array* res,
const mlx_array a, const mlx_array a,
@@ -362,6 +364,7 @@ int mlx_dequantize(
mlx_optional_int group_size, mlx_optional_int group_size,
mlx_optional_int bits, mlx_optional_int bits,
const char* mode, const char* mode,
const mlx_array global_scale /* may be null */,
mlx_optional_dtype dtype, mlx_optional_dtype dtype,
const mlx_stream s); const mlx_stream s);
int mlx_diag(mlx_array* res, const mlx_array a, int k, const mlx_stream s); int mlx_diag(mlx_array* res, const mlx_array a, int k, const mlx_stream s);
@@ -498,6 +501,8 @@ int mlx_hadamard_transform(
const mlx_array a, const mlx_array a,
mlx_optional_float scale, mlx_optional_float scale,
const mlx_stream s); const mlx_stream s);
int mlx_hamming(mlx_array* res, int M, const mlx_stream s);
int mlx_hanning(mlx_array* res, int M, const mlx_stream s);
int mlx_identity(mlx_array* res, int n, mlx_dtype dtype, const mlx_stream s); int mlx_identity(mlx_array* res, int n, mlx_dtype dtype, const mlx_stream s);
int mlx_imag(mlx_array* res, const mlx_array a, const mlx_stream s); int mlx_imag(mlx_array* res, const mlx_array a, const mlx_stream s);
int mlx_inner( int mlx_inner(
@@ -790,6 +795,8 @@ int mlx_qqmm(
mlx_optional_int group_size, mlx_optional_int group_size,
mlx_optional_int bits, mlx_optional_int bits,
const char* mode, const char* mode,
const mlx_array global_scale_x /* may be null */,
const mlx_array global_scale_w /* may be null */,
const mlx_stream s); const mlx_stream s);
int mlx_quantize( int mlx_quantize(
mlx_vector_array* res, mlx_vector_array* res,
@@ -797,6 +804,7 @@ int mlx_quantize(
mlx_optional_int group_size, mlx_optional_int group_size,
mlx_optional_int bits, mlx_optional_int bits,
const char* mode, const char* mode,
const mlx_array global_scale /* may be null */,
const mlx_stream s); const mlx_stream s);
int mlx_quantized_matmul( int mlx_quantized_matmul(
mlx_array* res, mlx_array* res,

View File

@@ -7,8 +7,44 @@ package mlx
// #cgo LDFLAGS: -lstdc++ // #cgo LDFLAGS: -lstdc++
// #cgo darwin LDFLAGS: -framework Foundation -framework Metal -framework Accelerate // #cgo darwin LDFLAGS: -framework Foundation -framework Metal -framework Accelerate
// #include "generated.h" // #include "generated.h"
// #include <string.h>
//
// static char _mlx_last_error_msg[1024] = {0};
// static int _mlx_last_error_flag = 0;
//
// static void _mlx_capture_error_handler(const char* msg, void* data) {
// (void)data;
// strncpy(_mlx_last_error_msg, msg, sizeof(_mlx_last_error_msg) - 1);
// _mlx_last_error_msg[sizeof(_mlx_last_error_msg) - 1] = '\0';
// _mlx_last_error_flag = 1;
// }
//
// static void mlx_install_capture_handler(void) {
// if (mlx_set_error_handler_) {
// mlx_set_error_handler_(_mlx_capture_error_handler, NULL, NULL);
// }
// }
//
// static void mlx_clear_last_error(void) {
// _mlx_last_error_flag = 0;
// _mlx_last_error_msg[0] = '\0';
// }
//
// static int mlx_had_last_error(void) {
// return _mlx_last_error_flag;
// }
//
// static const char* mlx_get_last_error(void) {
// return _mlx_last_error_flag ? _mlx_last_error_msg : NULL;
// }
import "C" import "C"
func init() {
// Replace the default exit(-1) error handler with one that captures
// the error message so we can surface it in Go.
C.mlx_install_capture_handler()
}
// Version returns the MLX core library version string. // Version returns the MLX core library version string.
func Version() string { func Version() string {
str := C.mlx_string_new() str := C.mlx_string_new()
@@ -31,10 +67,19 @@ func doEval(outputs []*Array, async bool) {
} }
} }
C.mlx_clear_last_error()
var rc C.int
if async { if async {
C.mlx_async_eval(vector) rc = C.mlx_async_eval(vector)
} else { } else {
C.mlx_eval(vector) rc = C.mlx_eval(vector)
}
if rc != 0 {
msg := "mlx eval failed"
if C.mlx_had_last_error() != 0 {
msg = C.GoString(C.mlx_get_last_error())
}
panic("mlx: " + msg)
} }
} }

View File

@@ -17,7 +17,8 @@ func Quantize(w *Array, groupSize, bits int, mode string) (weights, scales, bias
optBits := C.mlx_optional_int{value: C.int(bits), has_value: true} optBits := C.mlx_optional_int{value: C.int(bits), has_value: true}
res := C.mlx_vector_array_new() res := C.mlx_vector_array_new()
defer C.mlx_vector_array_free(res) defer C.mlx_vector_array_free(res)
C.mlx_quantize(&res, w.ctx, optGroupSize, optBits, cMode, DefaultStream().ctx) var globalScale C.mlx_array
C.mlx_quantize(&res, w.ctx, optGroupSize, optBits, cMode, globalScale, DefaultStream().ctx)
vecSize := int(C.mlx_vector_array_size(res)) vecSize := int(C.mlx_vector_array_size(res))
w0 := New("QUANTIZE_W") w0 := New("QUANTIZE_W")
@@ -45,7 +46,8 @@ func Dequantize(w, scales, biases *Array, groupSize, bits int, mode string) *Arr
} }
out := New("DEQUANTIZE") out := New("DEQUANTIZE")
C.mlx_dequantize(&out.ctx, w.ctx, scales.ctx, b, optGroupSize, optBits, cMode, optDtype, DefaultStream().ctx) var globalScale C.mlx_array
C.mlx_dequantize(&out.ctx, w.ctx, scales.ctx, b, optGroupSize, optBits, cMode, globalScale, optDtype, DefaultStream().ctx)
return out return out
} }