diff --git a/Dockerfile b/Dockerfile index bf55cf9f5..43f511465 100644 --- a/Dockerfile +++ b/Dockerfile @@ -9,15 +9,10 @@ ARG JETPACK6VERSION=r36.4.0 ARG CMAKEVERSION=3.31.2 ARG VULKANVERSION=1.4.321.1 -# We require gcc v10 minimum. v10.3 has regressions, so the rockylinux 8.5 AppStream has the latest compatible version FROM --platform=linux/amd64 rocm/dev-almalinux-8:${ROCMVERSION}-complete AS base-amd64 -RUN yum install -y yum-utils \ - && yum-config-manager --add-repo https://dl.rockylinux.org/vault/rocky/8.5/AppStream/\$basearch/os/ \ - && rpm --import https://dl.rockylinux.org/pub/rocky/RPM-GPG-KEY-Rocky-8 \ - && dnf install -y yum-utils ccache gcc-toolset-10-gcc-10.2.1-8.2.el8 gcc-toolset-10-gcc-c++-10.2.1-8.2.el8 gcc-toolset-10-binutils-2.35-11.el8 \ - && dnf install -y ccache \ +RUN dnf install -y yum-utils ccache gcc-toolset-11-gcc gcc-toolset-11-gcc-c++ gcc-toolset-11-binutils \ && yum-config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel8/x86_64/cuda-rhel8.repo -ENV PATH=/opt/rh/gcc-toolset-10/root/usr/bin:$PATH +ENV PATH=/opt/rh/gcc-toolset-11/root/usr/bin:$PATH ARG VULKANVERSION RUN wget https://sdk.lunarg.com/sdk/download/${VULKANVERSION}/linux/vulkansdk-linux-x86_64-${VULKANVERSION}.tar.xz -O /tmp/vulkansdk-linux-x86_64-${VULKANVERSION}.tar.xz \ && tar xvf /tmp/vulkansdk-linux-x86_64-${VULKANVERSION}.tar.xz \ diff --git a/MLX_VERSION b/MLX_VERSION index 5aff472dd..b043aa648 100644 --- a/MLX_VERSION +++ b/MLX_VERSION @@ -1 +1 @@ -v0.4.1 +v0.5.0 diff --git a/x/imagegen/mlx/generate_wrappers.go b/x/imagegen/mlx/generate_wrappers.go index a55def02b..8aa5bd0c8 100644 --- a/x/imagegen/mlx/generate_wrappers.go +++ b/x/imagegen/mlx/generate_wrappers.go @@ -16,10 +16,10 @@ import ( ) type Function struct { - Name string - ReturnType string - Params string - ParamNames []string + Name string + ReturnType string + Params string + ParamNames []string NeedsARM64Guard bool } @@ -29,6 +29,11 @@ func findHeaders(directory string) ([]string, error) { if err != nil { return err } + // Private headers contain C++ implementation helpers and are not part of + // the C API surface; parsing them can produce invalid wrapper signatures. + if d.IsDir() && d.Name() == "private" { + return fs.SkipDir + } if !d.IsDir() && strings.HasSuffix(path, ".h") { headers = append(headers, path) } @@ -194,10 +199,10 @@ func parseFunctions(content string) []Function { needsGuard := needsARM64Guard(funcName, returnType, params) functions = append(functions, Function{ - Name: funcName, - ReturnType: returnType, - Params: params, - ParamNames: paramNames, + Name: funcName, + ReturnType: returnType, + Params: params, + ParamNames: paramNames, NeedsARM64Guard: needsGuard, }) } diff --git a/x/imagegen/mlx/mlx.c b/x/imagegen/mlx/mlx.c index 564076f30..770b60922 100644 --- a/x/imagegen/mlx/mlx.c +++ b/x/imagegen/mlx/mlx.c @@ -20,6 +20,8 @@ mlx_array (*mlx_array_new_float64_ptr)(double val) = NULL; mlx_array (*mlx_array_new_double_ptr)(double val) = NULL; mlx_array (*mlx_array_new_complex_ptr)(float real_val, float imag_val) = NULL; mlx_array (*mlx_array_new_data_ptr)(const void* data, const int* shape, int dim, mlx_dtype dtype) = NULL; +mlx_array (*mlx_array_new_data_managed_ptr)(void* data, const int* shape, int dim, mlx_dtype dtype, void (*dtor)(void*)) = NULL; +mlx_array (*mlx_array_new_data_managed_payload_ptr)(void* data, const int* shape, int dim, mlx_dtype dtype, void* payload, void (*dtor)(void*)) = NULL; int (*mlx_array_set_ptr)(mlx_array* arr, const mlx_array src) = NULL; int (*mlx_array_set_bool_ptr)(mlx_array* arr, bool val) = NULL; int (*mlx_array_set_int_ptr)(mlx_array* arr, int val) = NULL; @@ -49,7 +51,7 @@ int (*mlx_array_item_int32_ptr)(int32_t* res, const mlx_array arr) = NULL; int (*mlx_array_item_int64_ptr)(int64_t* res, const mlx_array arr) = NULL; int (*mlx_array_item_float32_ptr)(float* res, const mlx_array arr) = NULL; int (*mlx_array_item_float64_ptr)(double* res, const mlx_array arr) = NULL; -int (*mlx_array_item_complex64_ptr)(float _Complex* res, const mlx_array arr) = NULL; +int (*mlx_array_item_complex64_ptr)(mlx_complex64_t* res, const mlx_array arr) = NULL; #if defined(__aarch64__) || defined(_M_ARM64) int (*mlx_array_item_float16_ptr)(float16_t* res, const mlx_array arr) = NULL; #endif @@ -67,7 +69,7 @@ const int32_t* (*mlx_array_data_int32_ptr)(const mlx_array arr) = NULL; const int64_t* (*mlx_array_data_int64_ptr)(const mlx_array arr) = NULL; const float* (*mlx_array_data_float32_ptr)(const mlx_array arr) = NULL; const double* (*mlx_array_data_float64_ptr)(const mlx_array arr) = NULL; -const float _Complex* (*mlx_array_data_complex64_ptr)(const mlx_array arr) = NULL; +const mlx_complex64_t* (*mlx_array_data_complex64_ptr)(const mlx_array arr) = NULL; #if defined(__aarch64__) || defined(_M_ARM64) const float16_t* (*mlx_array_data_float16_ptr)(const mlx_array arr) = NULL; #endif @@ -123,6 +125,7 @@ int (*mlx_detail_compile_erase_ptr)(uintptr_t fun_id) = NULL; int (*mlx_disable_compile_ptr)(void) = NULL; int (*mlx_enable_compile_ptr)(void) = NULL; int (*mlx_set_compile_mode_ptr)(mlx_compile_mode mode) = NULL; +int (*mlx_cuda_is_available_ptr)(bool* res) = NULL; mlx_device (*mlx_device_new_ptr)(void) = NULL; mlx_device (*mlx_device_new_type_ptr)(mlx_device_type type, int index) = NULL; int (*mlx_device_free_ptr)(mlx_device dev) = NULL; @@ -133,6 +136,16 @@ int (*mlx_device_get_index_ptr)(int* index, mlx_device dev) = NULL; int (*mlx_device_get_type_ptr)(mlx_device_type* type, mlx_device dev) = NULL; int (*mlx_get_default_device_ptr)(mlx_device* dev) = NULL; int (*mlx_set_default_device_ptr)(mlx_device dev) = NULL; +int (*mlx_device_is_available_ptr)(bool* avail, mlx_device dev) = NULL; +int (*mlx_device_count_ptr)(int* count, mlx_device_type type) = NULL; +mlx_device_info (*mlx_device_info_new_ptr)(void) = NULL; +int (*mlx_device_info_get_ptr)(mlx_device_info* info, mlx_device dev) = NULL; +int (*mlx_device_info_free_ptr)(mlx_device_info info) = NULL; +int (*mlx_device_info_has_key_ptr)(bool* exists, mlx_device_info info, const char* key) = NULL; +int (*mlx_device_info_is_string_ptr)(bool* is_string, mlx_device_info info, const char* key) = NULL; +int (*mlx_device_info_get_string_ptr)(const char** value, mlx_device_info info, const char* key) = NULL; +int (*mlx_device_info_get_size_ptr)(size_t* value, mlx_device_info info, const char* key) = NULL; +int (*mlx_device_info_get_keys_ptr)(mlx_vector_string* keys, mlx_device_info info) = NULL; int (*mlx_distributed_all_gather_ptr)(mlx_array* res, const mlx_array x, const mlx_distributed_group group , const mlx_stream S) = NULL; int (*mlx_distributed_all_max_ptr)(mlx_array* res, const mlx_array x, const mlx_distributed_group group , const mlx_stream s) = NULL; int (*mlx_distributed_all_min_ptr)(mlx_array* res, const mlx_array x, const mlx_distributed_group group , const mlx_stream s) = NULL; @@ -263,7 +276,6 @@ int (*mlx_reset_peak_memory_ptr)(void) = NULL; int (*mlx_set_cache_limit_ptr)(size_t* res, size_t limit) = NULL; int (*mlx_set_memory_limit_ptr)(size_t* res, size_t limit) = NULL; int (*mlx_set_wired_limit_ptr)(size_t* res, size_t limit) = NULL; -mlx_metal_device_info_t (*mlx_metal_device_info_ptr)(void) = NULL; int (*mlx_metal_is_available_ptr)(bool* res) = NULL; int (*mlx_metal_start_capture_ptr)(const char* path) = NULL; int (*mlx_metal_stop_capture_ptr)(void) = NULL; @@ -658,6 +670,16 @@ int mlx_load_functions(void* handle) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_new_data\n"); return -1; } + mlx_array_new_data_managed_ptr = dlsym(handle, "mlx_array_new_data_managed"); + if (mlx_array_new_data_managed_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_array_new_data_managed\n"); + return -1; + } + mlx_array_new_data_managed_payload_ptr = dlsym(handle, "mlx_array_new_data_managed_payload"); + if (mlx_array_new_data_managed_payload_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_array_new_data_managed_payload\n"); + return -1; + } mlx_array_set_ptr = dlsym(handle, "mlx_array_set"); if (mlx_array_set_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_set\n"); @@ -1141,6 +1163,11 @@ int mlx_load_functions(void* handle) { fprintf(stderr, "MLX: Failed to load symbol: mlx_set_compile_mode\n"); return -1; } + mlx_cuda_is_available_ptr = dlsym(handle, "mlx_cuda_is_available"); + if (mlx_cuda_is_available_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_cuda_is_available\n"); + return -1; + } mlx_device_new_ptr = dlsym(handle, "mlx_device_new"); if (mlx_device_new_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_device_new\n"); @@ -1191,6 +1218,56 @@ int mlx_load_functions(void* handle) { fprintf(stderr, "MLX: Failed to load symbol: mlx_set_default_device\n"); return -1; } + mlx_device_is_available_ptr = dlsym(handle, "mlx_device_is_available"); + if (mlx_device_is_available_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_device_is_available\n"); + return -1; + } + mlx_device_count_ptr = dlsym(handle, "mlx_device_count"); + if (mlx_device_count_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_device_count\n"); + return -1; + } + mlx_device_info_new_ptr = dlsym(handle, "mlx_device_info_new"); + if (mlx_device_info_new_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_device_info_new\n"); + return -1; + } + mlx_device_info_get_ptr = dlsym(handle, "mlx_device_info_get"); + if (mlx_device_info_get_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_device_info_get\n"); + return -1; + } + mlx_device_info_free_ptr = dlsym(handle, "mlx_device_info_free"); + if (mlx_device_info_free_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_device_info_free\n"); + return -1; + } + mlx_device_info_has_key_ptr = dlsym(handle, "mlx_device_info_has_key"); + if (mlx_device_info_has_key_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_device_info_has_key\n"); + return -1; + } + mlx_device_info_is_string_ptr = dlsym(handle, "mlx_device_info_is_string"); + if (mlx_device_info_is_string_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_device_info_is_string\n"); + return -1; + } + mlx_device_info_get_string_ptr = dlsym(handle, "mlx_device_info_get_string"); + if (mlx_device_info_get_string_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_device_info_get_string\n"); + return -1; + } + mlx_device_info_get_size_ptr = dlsym(handle, "mlx_device_info_get_size"); + if (mlx_device_info_get_size_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_device_info_get_size\n"); + return -1; + } + mlx_device_info_get_keys_ptr = dlsym(handle, "mlx_device_info_get_keys"); + if (mlx_device_info_get_keys_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_device_info_get_keys\n"); + return -1; + } mlx_distributed_all_gather_ptr = dlsym(handle, "mlx_distributed_all_gather"); if (mlx_distributed_all_gather_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_distributed_all_gather\n"); @@ -1841,11 +1918,6 @@ int mlx_load_functions(void* handle) { fprintf(stderr, "MLX: Failed to load symbol: mlx_set_wired_limit\n"); return -1; } - mlx_metal_device_info_ptr = dlsym(handle, "mlx_metal_device_info"); - if (mlx_metal_device_info_ptr == NULL) { - fprintf(stderr, "MLX: Failed to load symbol: mlx_metal_device_info\n"); - return -1; - } mlx_metal_is_available_ptr = dlsym(handle, "mlx_metal_is_available"); if (mlx_metal_is_available_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_metal_is_available\n"); @@ -3528,6 +3600,14 @@ mlx_array mlx_array_new_data(const void* data, const int* shape, int dim, mlx_dt return mlx_array_new_data_ptr(data, shape, dim, dtype); } +mlx_array mlx_array_new_data_managed(void* data, const int* shape, int dim, mlx_dtype dtype, void (*dtor)(void*)) { + return mlx_array_new_data_managed_ptr(data, shape, dim, dtype, dtor); +} + +mlx_array mlx_array_new_data_managed_payload(void* data, const int* shape, int dim, mlx_dtype dtype, void* payload, void (*dtor)(void*)) { + return mlx_array_new_data_managed_payload_ptr(data, shape, dim, dtype, payload, dtor); +} + int mlx_array_set(mlx_array* arr, const mlx_array src) { return mlx_array_set_ptr(arr, src); } @@ -3644,7 +3724,7 @@ int mlx_array_item_float64(double* res, const mlx_array arr) { return mlx_array_item_float64_ptr(res, arr); } -int mlx_array_item_complex64(float _Complex* res, const mlx_array arr) { +int mlx_array_item_complex64(mlx_complex64_t* res, const mlx_array arr) { return mlx_array_item_complex64_ptr(res, arr); } @@ -3704,7 +3784,7 @@ const double* mlx_array_data_float64(const mlx_array arr) { return mlx_array_data_float64_ptr(arr); } -const float _Complex* mlx_array_data_complex64(const mlx_array arr) { +const mlx_complex64_t* mlx_array_data_complex64(const mlx_array arr) { return mlx_array_data_complex64_ptr(arr); } @@ -3916,6 +3996,10 @@ int mlx_set_compile_mode(mlx_compile_mode mode) { return mlx_set_compile_mode_ptr(mode); } +int mlx_cuda_is_available(bool* res) { + return mlx_cuda_is_available_ptr(res); +} + mlx_device mlx_device_new(void) { return mlx_device_new_ptr(); } @@ -3956,6 +4040,46 @@ int mlx_set_default_device(mlx_device dev) { return mlx_set_default_device_ptr(dev); } +int mlx_device_is_available(bool* avail, mlx_device dev) { + return mlx_device_is_available_ptr(avail, dev); +} + +int mlx_device_count(int* count, mlx_device_type type) { + return mlx_device_count_ptr(count, type); +} + +mlx_device_info mlx_device_info_new(void) { + return mlx_device_info_new_ptr(); +} + +int mlx_device_info_get(mlx_device_info* info, mlx_device dev) { + return mlx_device_info_get_ptr(info, dev); +} + +int mlx_device_info_free(mlx_device_info info) { + return mlx_device_info_free_ptr(info); +} + +int mlx_device_info_has_key(bool* exists, mlx_device_info info, const char* key) { + return mlx_device_info_has_key_ptr(exists, info, key); +} + +int mlx_device_info_is_string(bool* is_string, mlx_device_info info, const char* key) { + return mlx_device_info_is_string_ptr(is_string, info, key); +} + +int mlx_device_info_get_string(const char** value, mlx_device_info info, const char* key) { + return mlx_device_info_get_string_ptr(value, info, key); +} + +int mlx_device_info_get_size(size_t* value, mlx_device_info info, const char* key) { + return mlx_device_info_get_size_ptr(value, info, key); +} + +int mlx_device_info_get_keys(mlx_vector_string* keys, mlx_device_info info) { + return mlx_device_info_get_keys_ptr(keys, info); +} + int mlx_distributed_all_gather(mlx_array* res, const mlx_array x, const mlx_distributed_group group , const mlx_stream S) { return mlx_distributed_all_gather_ptr(res, x, group, S); } @@ -4476,10 +4600,6 @@ int mlx_set_wired_limit(size_t* res, size_t limit) { return mlx_set_wired_limit_ptr(res, limit); } -mlx_metal_device_info_t mlx_metal_device_info(void) { - return mlx_metal_device_info_ptr(); -} - int mlx_metal_is_available(bool* res) { return mlx_metal_is_available_ptr(res); } diff --git a/x/imagegen/mlx/mlx.h b/x/imagegen/mlx/mlx.h index d4ed1a905..34829d732 100644 --- a/x/imagegen/mlx/mlx.h +++ b/x/imagegen/mlx/mlx.h @@ -26,6 +26,8 @@ #undef mlx_array_new_double #undef mlx_array_new_complex #undef mlx_array_new_data +#undef mlx_array_new_data_managed +#undef mlx_array_new_data_managed_payload #undef mlx_array_set #undef mlx_array_set_bool #undef mlx_array_set_int @@ -121,6 +123,7 @@ #undef mlx_disable_compile #undef mlx_enable_compile #undef mlx_set_compile_mode +#undef mlx_cuda_is_available #undef mlx_device_new #undef mlx_device_new_type #undef mlx_device_free @@ -131,6 +134,16 @@ #undef mlx_device_get_type #undef mlx_get_default_device #undef mlx_set_default_device +#undef mlx_device_is_available +#undef mlx_device_count +#undef mlx_device_info_new +#undef mlx_device_info_get +#undef mlx_device_info_free +#undef mlx_device_info_has_key +#undef mlx_device_info_is_string +#undef mlx_device_info_get_string +#undef mlx_device_info_get_size +#undef mlx_device_info_get_keys #undef mlx_distributed_all_gather #undef mlx_distributed_all_max #undef mlx_distributed_all_min @@ -261,7 +274,6 @@ #undef mlx_set_cache_limit #undef mlx_set_memory_limit #undef mlx_set_wired_limit -#undef mlx_metal_device_info #undef mlx_metal_is_available #undef mlx_metal_start_capture #undef mlx_metal_stop_capture @@ -602,6 +614,8 @@ extern mlx_array (*mlx_array_new_float64_ptr)(double val); extern mlx_array (*mlx_array_new_double_ptr)(double val); extern mlx_array (*mlx_array_new_complex_ptr)(float real_val, float imag_val); extern mlx_array (*mlx_array_new_data_ptr)(const void* data, const int* shape, int dim, mlx_dtype dtype); +extern mlx_array (*mlx_array_new_data_managed_ptr)(void* data, const int* shape, int dim, mlx_dtype dtype, void (*dtor)(void*)); +extern mlx_array (*mlx_array_new_data_managed_payload_ptr)(void* data, const int* shape, int dim, mlx_dtype dtype, void* payload, void (*dtor)(void*)); extern int (*mlx_array_set_ptr)(mlx_array* arr, const mlx_array src); extern int (*mlx_array_set_bool_ptr)(mlx_array* arr, bool val); extern int (*mlx_array_set_int_ptr)(mlx_array* arr, int val); @@ -631,7 +645,7 @@ extern int (*mlx_array_item_int32_ptr)(int32_t* res, const mlx_array arr); extern int (*mlx_array_item_int64_ptr)(int64_t* res, const mlx_array arr); extern int (*mlx_array_item_float32_ptr)(float* res, const mlx_array arr); extern int (*mlx_array_item_float64_ptr)(double* res, const mlx_array arr); -extern int (*mlx_array_item_complex64_ptr)(float _Complex* res, const mlx_array arr); +extern int (*mlx_array_item_complex64_ptr)(mlx_complex64_t* res, const mlx_array arr); #if defined(__aarch64__) || defined(_M_ARM64) extern int (*mlx_array_item_float16_ptr)(float16_t* res, const mlx_array arr); #endif @@ -649,7 +663,7 @@ extern const int32_t* (*mlx_array_data_int32_ptr)(const mlx_array arr); extern const int64_t* (*mlx_array_data_int64_ptr)(const mlx_array arr); extern const float* (*mlx_array_data_float32_ptr)(const mlx_array arr); extern const double* (*mlx_array_data_float64_ptr)(const mlx_array arr); -extern const float _Complex* (*mlx_array_data_complex64_ptr)(const mlx_array arr); +extern const mlx_complex64_t* (*mlx_array_data_complex64_ptr)(const mlx_array arr); #if defined(__aarch64__) || defined(_M_ARM64) extern const float16_t* (*mlx_array_data_float16_ptr)(const mlx_array arr); #endif @@ -705,6 +719,7 @@ extern int (*mlx_detail_compile_erase_ptr)(uintptr_t fun_id); extern int (*mlx_disable_compile_ptr)(void); extern int (*mlx_enable_compile_ptr)(void); extern int (*mlx_set_compile_mode_ptr)(mlx_compile_mode mode); +extern int (*mlx_cuda_is_available_ptr)(bool* res); extern mlx_device (*mlx_device_new_ptr)(void); extern mlx_device (*mlx_device_new_type_ptr)(mlx_device_type type, int index); extern int (*mlx_device_free_ptr)(mlx_device dev); @@ -715,6 +730,16 @@ extern int (*mlx_device_get_index_ptr)(int* index, mlx_device dev); extern int (*mlx_device_get_type_ptr)(mlx_device_type* type, mlx_device dev); extern int (*mlx_get_default_device_ptr)(mlx_device* dev); extern int (*mlx_set_default_device_ptr)(mlx_device dev); +extern int (*mlx_device_is_available_ptr)(bool* avail, mlx_device dev); +extern int (*mlx_device_count_ptr)(int* count, mlx_device_type type); +extern mlx_device_info (*mlx_device_info_new_ptr)(void); +extern int (*mlx_device_info_get_ptr)(mlx_device_info* info, mlx_device dev); +extern int (*mlx_device_info_free_ptr)(mlx_device_info info); +extern int (*mlx_device_info_has_key_ptr)(bool* exists, mlx_device_info info, const char* key); +extern int (*mlx_device_info_is_string_ptr)(bool* is_string, mlx_device_info info, const char* key); +extern int (*mlx_device_info_get_string_ptr)(const char** value, mlx_device_info info, const char* key); +extern int (*mlx_device_info_get_size_ptr)(size_t* value, mlx_device_info info, const char* key); +extern int (*mlx_device_info_get_keys_ptr)(mlx_vector_string* keys, mlx_device_info info); extern int (*mlx_distributed_all_gather_ptr)(mlx_array* res, const mlx_array x, const mlx_distributed_group group , const mlx_stream S); extern int (*mlx_distributed_all_max_ptr)(mlx_array* res, const mlx_array x, const mlx_distributed_group group , const mlx_stream s); extern int (*mlx_distributed_all_min_ptr)(mlx_array* res, const mlx_array x, const mlx_distributed_group group , const mlx_stream s); @@ -845,7 +870,6 @@ extern int (*mlx_reset_peak_memory_ptr)(void); extern int (*mlx_set_cache_limit_ptr)(size_t* res, size_t limit); extern int (*mlx_set_memory_limit_ptr)(size_t* res, size_t limit); extern int (*mlx_set_wired_limit_ptr)(size_t* res, size_t limit); -extern mlx_metal_device_info_t (*mlx_metal_device_info_ptr)(void); extern int (*mlx_metal_is_available_ptr)(bool* res); extern int (*mlx_metal_start_capture_ptr)(const char* path); extern int (*mlx_metal_stop_capture_ptr)(void); @@ -1202,6 +1226,10 @@ mlx_array mlx_array_new_complex(float real_val, float imag_val); mlx_array mlx_array_new_data(const void* data, const int* shape, int dim, mlx_dtype dtype); +mlx_array mlx_array_new_data_managed(void* data, const int* shape, int dim, mlx_dtype dtype, void (*dtor)(void*)); + +mlx_array mlx_array_new_data_managed_payload(void* data, const int* shape, int dim, mlx_dtype dtype, void* payload, void (*dtor)(void*)); + int mlx_array_set(mlx_array* arr, const mlx_array src); int mlx_array_set_bool(mlx_array* arr, bool val); @@ -1260,7 +1288,7 @@ int mlx_array_item_float32(float* res, const mlx_array arr); int mlx_array_item_float64(double* res, const mlx_array arr); -int mlx_array_item_complex64(float _Complex* res, const mlx_array arr); +int mlx_array_item_complex64(mlx_complex64_t* res, const mlx_array arr); #if defined(__aarch64__) || defined(_M_ARM64) int mlx_array_item_float16(float16_t* res, const mlx_array arr); @@ -1292,7 +1320,7 @@ const float* mlx_array_data_float32(const mlx_array arr); const double* mlx_array_data_float64(const mlx_array arr); -const float _Complex* mlx_array_data_complex64(const mlx_array arr); +const mlx_complex64_t* mlx_array_data_complex64(const mlx_array arr); #if defined(__aarch64__) || defined(_M_ARM64) const float16_t* mlx_array_data_float16(const mlx_array arr); @@ -1400,6 +1428,8 @@ int mlx_enable_compile(void); int mlx_set_compile_mode(mlx_compile_mode mode); +int mlx_cuda_is_available(bool* res); + mlx_device mlx_device_new(void); mlx_device mlx_device_new_type(mlx_device_type type, int index); @@ -1420,6 +1450,26 @@ int mlx_get_default_device(mlx_device* dev); int mlx_set_default_device(mlx_device dev); +int mlx_device_is_available(bool* avail, mlx_device dev); + +int mlx_device_count(int* count, mlx_device_type type); + +mlx_device_info mlx_device_info_new(void); + +int mlx_device_info_get(mlx_device_info* info, mlx_device dev); + +int mlx_device_info_free(mlx_device_info info); + +int mlx_device_info_has_key(bool* exists, mlx_device_info info, const char* key); + +int mlx_device_info_is_string(bool* is_string, mlx_device_info info, const char* key); + +int mlx_device_info_get_string(const char** value, mlx_device_info info, const char* key); + +int mlx_device_info_get_size(size_t* value, mlx_device_info info, const char* key); + +int mlx_device_info_get_keys(mlx_vector_string* keys, mlx_device_info info); + int mlx_distributed_all_gather(mlx_array* res, const mlx_array x, const mlx_distributed_group group , const mlx_stream S); int mlx_distributed_all_max(mlx_array* res, const mlx_array x, const mlx_distributed_group group , const mlx_stream s); @@ -1680,8 +1730,6 @@ int mlx_set_memory_limit(size_t* res, size_t limit); int mlx_set_wired_limit(size_t* res, size_t limit); -mlx_metal_device_info_t mlx_metal_device_info(void); - int mlx_metal_is_available(bool* res); int mlx_metal_start_capture(const char* path); diff --git a/x/mlxrunner/mlx/CMakeLists.txt b/x/mlxrunner/mlx/CMakeLists.txt index c41ce46f7..1ca13bdaf 100644 --- a/x/mlxrunner/mlx/CMakeLists.txt +++ b/x/mlxrunner/mlx/CMakeLists.txt @@ -15,7 +15,7 @@ set(CMAKE_INSTALL_RPATH "@loader_path") include(FetchContent) -set(MLX_C_GIT_TAG "v0.4.1" CACHE STRING "") +set(MLX_C_GIT_TAG "v0.5.0" CACHE STRING "") FetchContent_Declare( mlx-c diff --git a/x/mlxrunner/mlx/generated.c b/x/mlxrunner/mlx/generated.c index af99b631e..29d1330af 100644 --- a/x/mlxrunner/mlx/generated.c +++ b/x/mlxrunner/mlx/generated.c @@ -22,6 +22,19 @@ mlx_array (*mlx_array_new_data_)( const int* shape, int dim, mlx_dtype dtype) = NULL; +mlx_array (*mlx_array_new_data_managed_)( + void* data, + const int* shape, + int dim, + mlx_dtype dtype, + void (*dtor)(void*)) = NULL; +mlx_array (*mlx_array_new_data_managed_payload_)( + void* data, + const int* shape, + int dim, + mlx_dtype dtype, + void* payload, + void (*dtor)(void*)) = NULL; int (*mlx_array_set_)(mlx_array* arr, const mlx_array src) = NULL; int (*mlx_array_set_bool_)(mlx_array* arr, bool val) = NULL; int (*mlx_array_set_int_)(mlx_array* arr, int val) = NULL; @@ -56,7 +69,7 @@ int (*mlx_array_item_int32_)(int32_t* res, const mlx_array arr) = NULL; int (*mlx_array_item_int64_)(int64_t* res, const mlx_array arr) = NULL; int (*mlx_array_item_float32_)(float* res, const mlx_array arr) = NULL; int (*mlx_array_item_float64_)(double* res, const mlx_array arr) = NULL; -int (*mlx_array_item_complex64_)(float _Complex* res, const mlx_array arr) = NULL; +int (*mlx_array_item_complex64_)(mlx_complex64_t* res, const mlx_array arr) = NULL; int (*mlx_array_item_float16_)(float16_t* res, const mlx_array arr) = NULL; int (*mlx_array_item_bfloat16_)(bfloat16_t* res, const mlx_array arr) = NULL; const bool * (*mlx_array_data_bool_)(const mlx_array arr) = NULL; @@ -70,7 +83,7 @@ const int32_t * (*mlx_array_data_int32_)(const mlx_array arr) = NULL; const int64_t * (*mlx_array_data_int64_)(const mlx_array arr) = NULL; const float * (*mlx_array_data_float32_)(const mlx_array arr) = NULL; const double * (*mlx_array_data_float64_)(const mlx_array arr) = NULL; -const float _Complex * (*mlx_array_data_complex64_)(const mlx_array arr) = NULL; +const mlx_complex64_t * (*mlx_array_data_complex64_)(const mlx_array arr) = NULL; const float16_t * (*mlx_array_data_float16_)(const mlx_array arr) = NULL; const bfloat16_t * (*mlx_array_data_bfloat16_)(const mlx_array arr) = NULL; int (*_mlx_array_is_available_)(bool* res, const mlx_array arr) = NULL; @@ -94,10 +107,11 @@ int (*mlx_closure_apply_)( mlx_closure (*mlx_closure_new_unary_)(int (*fun)(mlx_array*, const mlx_array)) = NULL; mlx_closure_kwargs (*mlx_closure_kwargs_new_)(void) = NULL; int (*mlx_closure_kwargs_free_)(mlx_closure_kwargs cls) = NULL; -mlx_closure_kwargs (*mlx_closure_kwargs_new_func_)(int (*fun)( - mlx_vector_array*, - const mlx_vector_array, - const mlx_map_string_to_array)) = NULL; +mlx_closure_kwargs (*mlx_closure_kwargs_new_func_)( + int (*fun)( + mlx_vector_array*, + const mlx_vector_array, + const mlx_map_string_to_array)) = NULL; mlx_closure_kwargs (*mlx_closure_kwargs_new_func_payload_)( int (*fun)( mlx_vector_array*, @@ -136,11 +150,12 @@ int (*mlx_closure_value_and_grad_apply_)( const mlx_vector_array input) = NULL; mlx_closure_custom (*mlx_closure_custom_new_)(void) = NULL; int (*mlx_closure_custom_free_)(mlx_closure_custom cls) = NULL; -mlx_closure_custom (*mlx_closure_custom_new_func_)(int (*fun)( - mlx_vector_array*, - const mlx_vector_array, - const mlx_vector_array, - const mlx_vector_array)) = NULL; +mlx_closure_custom (*mlx_closure_custom_new_func_)( + int (*fun)( + mlx_vector_array*, + const mlx_vector_array, + const mlx_vector_array, + const mlx_vector_array)) = NULL; mlx_closure_custom (*mlx_closure_custom_new_func_payload_)( int (*fun)( mlx_vector_array*, @@ -161,12 +176,13 @@ int (*mlx_closure_custom_apply_)( const mlx_vector_array input_2) = NULL; mlx_closure_custom_jvp (*mlx_closure_custom_jvp_new_)(void) = NULL; int (*mlx_closure_custom_jvp_free_)(mlx_closure_custom_jvp cls) = NULL; -mlx_closure_custom_jvp (*mlx_closure_custom_jvp_new_func_)(int (*fun)( - mlx_vector_array*, - const mlx_vector_array, - const mlx_vector_array, - const int*, - size_t _num)) = NULL; +mlx_closure_custom_jvp (*mlx_closure_custom_jvp_new_func_)( + int (*fun)( + mlx_vector_array*, + const mlx_vector_array, + const mlx_vector_array, + const int*, + size_t _num)) = NULL; mlx_closure_custom_jvp (*mlx_closure_custom_jvp_new_func_payload_)( int (*fun)( mlx_vector_array*, @@ -189,12 +205,13 @@ int (*mlx_closure_custom_jvp_apply_)( size_t input_2_num) = NULL; mlx_closure_custom_vmap (*mlx_closure_custom_vmap_new_)(void) = NULL; int (*mlx_closure_custom_vmap_free_)(mlx_closure_custom_vmap cls) = NULL; -mlx_closure_custom_vmap (*mlx_closure_custom_vmap_new_func_)(int (*fun)( - mlx_vector_array*, - mlx_vector_int*, - const mlx_vector_array, - const int*, - size_t _num)) = NULL; +mlx_closure_custom_vmap (*mlx_closure_custom_vmap_new_func_)( + int (*fun)( + mlx_vector_array*, + mlx_vector_int*, + const mlx_vector_array, + const int*, + size_t _num)) = NULL; mlx_closure_custom_vmap (*mlx_closure_custom_vmap_new_func_payload_)( int (*fun)( mlx_vector_array*, @@ -228,6 +245,7 @@ int (*mlx_detail_compile_erase_)(uintptr_t fun_id) = NULL; int (*mlx_disable_compile_)(void) = NULL; int (*mlx_enable_compile_)(void) = NULL; int (*mlx_set_compile_mode_)(mlx_compile_mode mode) = NULL; +int (*mlx_cuda_is_available_)(bool* res) = NULL; mlx_device (*mlx_device_new_)(void) = NULL; mlx_device (*mlx_device_new_type_)(mlx_device_type type, int index) = NULL; int (*mlx_device_free_)(mlx_device dev) = NULL; @@ -238,11 +256,28 @@ int (*mlx_device_get_index_)(int* index, mlx_device dev) = NULL; int (*mlx_device_get_type_)(mlx_device_type* type, mlx_device dev) = NULL; int (*mlx_get_default_device_)(mlx_device* dev) = NULL; int (*mlx_set_default_device_)(mlx_device dev) = NULL; -int (*mlx_distributed_group_rank_)(mlx_distributed_group group) = NULL; -int (*mlx_distributed_group_size_)(mlx_distributed_group group) = NULL; -mlx_distributed_group (*mlx_distributed_group_split_)(mlx_distributed_group group, int color, int key) = NULL; -bool (*mlx_distributed_is_available_)(void) = NULL; -mlx_distributed_group (*mlx_distributed_init_)(bool strict) = NULL; +int (*mlx_device_is_available_)(bool* avail, mlx_device dev) = NULL; +int (*mlx_device_count_)(int* count, mlx_device_type type) = NULL; +mlx_device_info (*mlx_device_info_new_)(void) = NULL; +int (*mlx_device_info_get_)(mlx_device_info* info, mlx_device dev) = NULL; +int (*mlx_device_info_free_)(mlx_device_info info) = NULL; +int (*mlx_device_info_has_key_)( + bool* exists, + mlx_device_info info, + const char* key) = NULL; +int (*mlx_device_info_is_string_)( + bool* is_string, + mlx_device_info info, + const char* key) = NULL; +int (*mlx_device_info_get_string_)( + const char** value, + mlx_device_info info, + const char* key) = NULL; +int (*mlx_device_info_get_size_)( + size_t* value, + mlx_device_info info, + const char* key) = NULL; +int (*mlx_device_info_get_keys_)(mlx_vector_string* keys, mlx_device_info info) = NULL; int (*mlx_distributed_all_gather_)( mlx_array* res, const mlx_array x, @@ -288,6 +323,11 @@ int (*mlx_distributed_sum_scatter_)( const mlx_array x, const mlx_distributed_group group /* may be null */, const mlx_stream s) = NULL; +int (*mlx_distributed_group_rank_)(mlx_distributed_group group) = NULL; +int (*mlx_distributed_group_size_)(mlx_distributed_group group) = NULL; +mlx_distributed_group (*mlx_distributed_group_split_)(mlx_distributed_group group, int color, int key) = NULL; +bool (*mlx_distributed_is_available_)(void) = NULL; +mlx_distributed_group (*mlx_distributed_init_)(bool strict) = NULL; void (*mlx_set_error_handler_)( mlx_error_handler_func handler, void* data, @@ -450,6 +490,16 @@ int (*mlx_fast_rope_)( int offset, const mlx_array freqs /* may be null */, const mlx_stream s) = NULL; +int (*mlx_fast_rope_dynamic_)( + mlx_array* res, + const mlx_array x, + int dims, + bool traditional, + mlx_optional_float base, + float scale, + const mlx_array offset, + const mlx_array freqs /* may be null */, + const mlx_stream s) = NULL; int (*mlx_fast_scaled_dot_product_attention_)( mlx_array* res, const mlx_array queries, @@ -560,14 +610,6 @@ int (*mlx_fft_rfftn_)( const int* axes, size_t axes_num, const mlx_stream s) = NULL; -mlx_io_reader (*mlx_io_reader_new_)(void* desc, mlx_io_vtable vtable) = NULL; -int (*mlx_io_reader_descriptor_)(void** desc_, mlx_io_reader io) = NULL; -int (*mlx_io_reader_tostring_)(mlx_string* str_, mlx_io_reader io) = NULL; -int (*mlx_io_reader_free_)(mlx_io_reader io) = NULL; -mlx_io_writer (*mlx_io_writer_new_)(void* desc, mlx_io_vtable vtable) = NULL; -int (*mlx_io_writer_descriptor_)(void** desc_, mlx_io_writer io) = NULL; -int (*mlx_io_writer_tostring_)(mlx_string* str_, mlx_io_writer io) = NULL; -int (*mlx_io_writer_free_)(mlx_io_writer io) = NULL; int (*mlx_load_reader_)( mlx_array* res, mlx_io_reader in_stream, @@ -593,6 +635,14 @@ int (*mlx_save_safetensors_)( const char* file, const mlx_map_string_to_array param, const mlx_map_string_to_string metadata) = NULL; +mlx_io_reader (*mlx_io_reader_new_)(void* desc, mlx_io_vtable vtable) = NULL; +int (*mlx_io_reader_descriptor_)(void** desc_, mlx_io_reader io) = NULL; +int (*mlx_io_reader_tostring_)(mlx_string* str_, mlx_io_reader io) = NULL; +int (*mlx_io_reader_free_)(mlx_io_reader io) = NULL; +mlx_io_writer (*mlx_io_writer_new_)(void* desc, mlx_io_vtable vtable) = NULL; +int (*mlx_io_writer_descriptor_)(void** desc_, mlx_io_writer io) = NULL; +int (*mlx_io_writer_tostring_)(mlx_string* str_, mlx_io_writer io) = NULL; +int (*mlx_io_writer_free_)(mlx_io_writer io) = NULL; int (*mlx_linalg_cholesky_)( mlx_array* res, const mlx_array a, @@ -733,7 +783,6 @@ int (*mlx_reset_peak_memory_)(void) = NULL; int (*mlx_set_cache_limit_)(size_t* res, size_t limit) = NULL; int (*mlx_set_memory_limit_)(size_t* res, size_t limit) = NULL; int (*mlx_set_wired_limit_)(size_t* res, size_t limit) = NULL; -mlx_metal_device_info_t (*mlx_metal_device_info_)(void) = NULL; int (*mlx_metal_is_available_)(bool* res) = NULL; int (*mlx_metal_start_capture_)(const char* path) = NULL; int (*mlx_metal_stop_capture_)(void) = NULL; @@ -1162,6 +1211,14 @@ int (*mlx_gather_)( const int* slice_sizes, size_t slice_sizes_num, const mlx_stream s) = NULL; +int (*mlx_gather_single_)( + mlx_array* res, + const mlx_array a, + const mlx_array indices, + int axis, + const int* slice_sizes, + size_t slice_sizes_num, + const mlx_stream s) = NULL; int (*mlx_gather_mm_)( mlx_array* res, const mlx_array a, @@ -1483,6 +1540,15 @@ int (*mlx_put_along_axis_)( const mlx_array values, int axis, const mlx_stream s) = NULL; +int (*mlx_qqmm_)( + mlx_array* res, + const mlx_array x, + const mlx_array w, + const mlx_array w_scales /* may be null */, + mlx_optional_int group_size, + mlx_optional_int bits, + const char* mode, + const mlx_stream s) = NULL; int (*mlx_quantize_)( mlx_vector_array* res, const mlx_array w, @@ -1566,6 +1632,13 @@ int (*mlx_scatter_)( const int* axes, size_t axes_num, const mlx_stream s) = NULL; +int (*mlx_scatter_single_)( + mlx_array* res, + const mlx_array a, + const mlx_array indices, + const mlx_array updates, + int axis, + const mlx_stream s) = NULL; int (*mlx_scatter_add_)( mlx_array* res, const mlx_array a, @@ -1574,6 +1647,13 @@ int (*mlx_scatter_add_)( const int* axes, size_t axes_num, const mlx_stream s) = NULL; +int (*mlx_scatter_add_single_)( + mlx_array* res, + const mlx_array a, + const mlx_array indices, + const mlx_array updates, + int axis, + const mlx_stream s) = NULL; int (*mlx_scatter_add_axis_)( mlx_array* res, const mlx_array a, @@ -1589,6 +1669,13 @@ int (*mlx_scatter_max_)( const int* axes, size_t axes_num, const mlx_stream s) = NULL; +int (*mlx_scatter_max_single_)( + mlx_array* res, + const mlx_array a, + const mlx_array indices, + const mlx_array updates, + int axis, + const mlx_stream s) = NULL; int (*mlx_scatter_min_)( mlx_array* res, const mlx_array a, @@ -1597,6 +1684,13 @@ int (*mlx_scatter_min_)( const int* axes, size_t axes_num, const mlx_stream s) = NULL; +int (*mlx_scatter_min_single_)( + mlx_array* res, + const mlx_array a, + const mlx_array indices, + const mlx_array updates, + int axis, + const mlx_stream s) = NULL; int (*mlx_scatter_prod_)( mlx_array* res, const mlx_array a, @@ -1605,6 +1699,13 @@ int (*mlx_scatter_prod_)( const int* axes, size_t axes_num, const mlx_stream s) = NULL; +int (*mlx_scatter_prod_single_)( + mlx_array* res, + const mlx_array a, + const mlx_array indices, + const mlx_array updates, + int axis, + const mlx_stream s) = NULL; int (*mlx_segmented_mm_)( mlx_array* res, const mlx_array a, @@ -2028,22 +2129,6 @@ mlx_string (*mlx_string_new_data_)(const char* str) = NULL; int (*mlx_string_set_)(mlx_string* str, const mlx_string src) = NULL; const char * (*mlx_string_data_)(mlx_string str) = NULL; int (*mlx_string_free_)(mlx_string str) = NULL; -int (*mlx_detail_vmap_replace_)( - mlx_vector_array* res, - const mlx_vector_array inputs, - const mlx_vector_array s_inputs, - const mlx_vector_array s_outputs, - const int* in_axes, - size_t in_axes_num, - const int* out_axes, - size_t out_axes_num) = NULL; -int (*mlx_detail_vmap_trace_)( - mlx_vector_array* res_0, - mlx_vector_array* res_1, - const mlx_closure fun, - const mlx_vector_array inputs, - const int* in_axes, - size_t in_axes_num) = NULL; int (*mlx_async_eval_)(const mlx_vector_array outputs) = NULL; int (*mlx_checkpoint_)(mlx_closure* res, const mlx_closure fun) = NULL; int (*mlx_custom_function_)( @@ -2074,6 +2159,22 @@ int (*mlx_vjp_)( const mlx_closure fun, const mlx_vector_array primals, const mlx_vector_array cotangents) = NULL; +int (*mlx_detail_vmap_replace_)( + mlx_vector_array* res, + const mlx_vector_array inputs, + const mlx_vector_array s_inputs, + const mlx_vector_array s_outputs, + const int* in_axes, + size_t in_axes_num, + const int* out_axes, + size_t out_axes_num) = NULL; +int (*mlx_detail_vmap_trace_)( + mlx_vector_array* res_0, + mlx_vector_array* res_1, + const mlx_closure fun, + const mlx_vector_array inputs, + const int* in_axes, + size_t in_axes_num) = NULL; mlx_vector_array (*mlx_vector_array_new_)(void) = NULL; int (*mlx_vector_array_set_)(mlx_vector_array* vec, const mlx_vector_array src) = NULL; int (*mlx_vector_array_free_)(mlx_vector_array vec) = NULL; @@ -2166,6 +2267,8 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) { CHECK_LOAD(handle, mlx_array_new_double); CHECK_LOAD(handle, mlx_array_new_complex); CHECK_LOAD(handle, mlx_array_new_data); + CHECK_LOAD(handle, mlx_array_new_data_managed); + CHECK_LOAD(handle, mlx_array_new_data_managed_payload); CHECK_LOAD(handle, mlx_array_set); CHECK_LOAD(handle, mlx_array_set_bool); CHECK_LOAD(handle, mlx_array_set_int); @@ -2261,6 +2364,7 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) { CHECK_LOAD(handle, mlx_disable_compile); CHECK_LOAD(handle, mlx_enable_compile); CHECK_LOAD(handle, mlx_set_compile_mode); + CHECK_LOAD(handle, mlx_cuda_is_available); CHECK_LOAD(handle, mlx_device_new); CHECK_LOAD(handle, mlx_device_new_type); CHECK_LOAD(handle, mlx_device_free); @@ -2271,11 +2375,16 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) { CHECK_LOAD(handle, mlx_device_get_type); CHECK_LOAD(handle, mlx_get_default_device); CHECK_LOAD(handle, mlx_set_default_device); - CHECK_LOAD(handle, mlx_distributed_group_rank); - CHECK_LOAD(handle, mlx_distributed_group_size); - CHECK_LOAD(handle, mlx_distributed_group_split); - CHECK_LOAD(handle, mlx_distributed_is_available); - CHECK_LOAD(handle, mlx_distributed_init); + CHECK_LOAD(handle, mlx_device_is_available); + CHECK_LOAD(handle, mlx_device_count); + CHECK_LOAD(handle, mlx_device_info_new); + CHECK_LOAD(handle, mlx_device_info_get); + CHECK_LOAD(handle, mlx_device_info_free); + CHECK_LOAD(handle, mlx_device_info_has_key); + CHECK_LOAD(handle, mlx_device_info_is_string); + CHECK_LOAD(handle, mlx_device_info_get_string); + CHECK_LOAD(handle, mlx_device_info_get_size); + CHECK_LOAD(handle, mlx_device_info_get_keys); CHECK_LOAD(handle, mlx_distributed_all_gather); CHECK_LOAD(handle, mlx_distributed_all_max); CHECK_LOAD(handle, mlx_distributed_all_min); @@ -2284,6 +2393,11 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) { CHECK_LOAD(handle, mlx_distributed_recv_like); CHECK_LOAD(handle, mlx_distributed_send); CHECK_LOAD(handle, mlx_distributed_sum_scatter); + CHECK_LOAD(handle, mlx_distributed_group_rank); + CHECK_LOAD(handle, mlx_distributed_group_size); + CHECK_LOAD(handle, mlx_distributed_group_split); + CHECK_LOAD(handle, mlx_distributed_is_available); + CHECK_LOAD(handle, mlx_distributed_init); CHECK_LOAD(handle, mlx_set_error_handler); CHECK_LOAD(handle, _mlx_error); CHECK_LOAD(handle, mlx_export_function); @@ -2325,6 +2439,7 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) { CHECK_LOAD(handle, mlx_fast_metal_kernel_apply); CHECK_LOAD(handle, mlx_fast_rms_norm); CHECK_LOAD(handle, mlx_fast_rope); + CHECK_LOAD(handle, mlx_fast_rope_dynamic); CHECK_LOAD(handle, mlx_fast_scaled_dot_product_attention); CHECK_LOAD(handle, mlx_fft_fft); CHECK_LOAD(handle, mlx_fft_fft2); @@ -2340,14 +2455,6 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) { CHECK_LOAD(handle, mlx_fft_rfft); CHECK_LOAD(handle, mlx_fft_rfft2); CHECK_LOAD(handle, mlx_fft_rfftn); - CHECK_LOAD(handle, mlx_io_reader_new); - CHECK_LOAD(handle, mlx_io_reader_descriptor); - CHECK_LOAD(handle, mlx_io_reader_tostring); - CHECK_LOAD(handle, mlx_io_reader_free); - CHECK_LOAD(handle, mlx_io_writer_new); - CHECK_LOAD(handle, mlx_io_writer_descriptor); - CHECK_LOAD(handle, mlx_io_writer_tostring); - CHECK_LOAD(handle, mlx_io_writer_free); CHECK_LOAD(handle, mlx_load_reader); CHECK_LOAD(handle, mlx_load); CHECK_LOAD(handle, mlx_load_safetensors_reader); @@ -2356,6 +2463,14 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) { CHECK_LOAD(handle, mlx_save); CHECK_LOAD(handle, mlx_save_safetensors_writer); CHECK_LOAD(handle, mlx_save_safetensors); + CHECK_LOAD(handle, mlx_io_reader_new); + CHECK_LOAD(handle, mlx_io_reader_descriptor); + CHECK_LOAD(handle, mlx_io_reader_tostring); + CHECK_LOAD(handle, mlx_io_reader_free); + CHECK_LOAD(handle, mlx_io_writer_new); + CHECK_LOAD(handle, mlx_io_writer_descriptor); + CHECK_LOAD(handle, mlx_io_writer_tostring); + CHECK_LOAD(handle, mlx_io_writer_free); CHECK_LOAD(handle, mlx_linalg_cholesky); CHECK_LOAD(handle, mlx_linalg_cholesky_inv); CHECK_LOAD(handle, mlx_linalg_cross); @@ -2400,7 +2515,6 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) { CHECK_LOAD(handle, mlx_set_cache_limit); CHECK_LOAD(handle, mlx_set_memory_limit); CHECK_LOAD(handle, mlx_set_wired_limit); - CHECK_LOAD(handle, mlx_metal_device_info); CHECK_LOAD(handle, mlx_metal_is_available); CHECK_LOAD(handle, mlx_metal_start_capture); CHECK_LOAD(handle, mlx_metal_stop_capture); @@ -2486,6 +2600,7 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) { CHECK_LOAD(handle, mlx_full); CHECK_LOAD(handle, mlx_full_like); CHECK_LOAD(handle, mlx_gather); + CHECK_LOAD(handle, mlx_gather_single); CHECK_LOAD(handle, mlx_gather_mm); CHECK_LOAD(handle, mlx_gather_qmm); CHECK_LOAD(handle, mlx_greater); @@ -2550,6 +2665,7 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) { CHECK_LOAD(handle, mlx_prod_axis); CHECK_LOAD(handle, mlx_prod); CHECK_LOAD(handle, mlx_put_along_axis); + CHECK_LOAD(handle, mlx_qqmm); CHECK_LOAD(handle, mlx_quantize); CHECK_LOAD(handle, mlx_quantized_matmul); CHECK_LOAD(handle, mlx_radians); @@ -2566,11 +2682,16 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) { CHECK_LOAD(handle, mlx_round); CHECK_LOAD(handle, mlx_rsqrt); CHECK_LOAD(handle, mlx_scatter); + CHECK_LOAD(handle, mlx_scatter_single); CHECK_LOAD(handle, mlx_scatter_add); + CHECK_LOAD(handle, mlx_scatter_add_single); CHECK_LOAD(handle, mlx_scatter_add_axis); CHECK_LOAD(handle, mlx_scatter_max); + CHECK_LOAD(handle, mlx_scatter_max_single); CHECK_LOAD(handle, mlx_scatter_min); + CHECK_LOAD(handle, mlx_scatter_min_single); CHECK_LOAD(handle, mlx_scatter_prod); + CHECK_LOAD(handle, mlx_scatter_prod_single); CHECK_LOAD(handle, mlx_segmented_mm); CHECK_LOAD(handle, mlx_sigmoid); CHECK_LOAD(handle, mlx_sign); @@ -2665,8 +2786,6 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) { CHECK_LOAD(handle, mlx_string_set); CHECK_LOAD(handle, mlx_string_data); CHECK_LOAD(handle, mlx_string_free); - CHECK_LOAD(handle, mlx_detail_vmap_replace); - CHECK_LOAD(handle, mlx_detail_vmap_trace); CHECK_LOAD(handle, mlx_async_eval); CHECK_LOAD(handle, mlx_checkpoint); CHECK_LOAD(handle, mlx_custom_function); @@ -2675,6 +2794,8 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) { CHECK_LOAD(handle, mlx_jvp); CHECK_LOAD(handle, mlx_value_and_grad); CHECK_LOAD(handle, mlx_vjp); + CHECK_LOAD(handle, mlx_detail_vmap_replace); + CHECK_LOAD(handle, mlx_detail_vmap_trace); CHECK_LOAD(handle, mlx_vector_array_new); CHECK_LOAD(handle, mlx_vector_array_set); CHECK_LOAD(handle, mlx_vector_array_free); diff --git a/x/mlxrunner/mlx/generated.h b/x/mlxrunner/mlx/generated.h index c88946d9f..e8dfa7b90 100644 --- a/x/mlxrunner/mlx/generated.h +++ b/x/mlxrunner/mlx/generated.h @@ -17,6 +17,8 @@ #define mlx_array_new_double mlx_array_new_double_mlx_gen_orig_ #define mlx_array_new_complex mlx_array_new_complex_mlx_gen_orig_ #define mlx_array_new_data mlx_array_new_data_mlx_gen_orig_ +#define mlx_array_new_data_managed mlx_array_new_data_managed_mlx_gen_orig_ +#define mlx_array_new_data_managed_payload mlx_array_new_data_managed_payload_mlx_gen_orig_ #define mlx_array_set mlx_array_set_mlx_gen_orig_ #define mlx_array_set_bool mlx_array_set_bool_mlx_gen_orig_ #define mlx_array_set_int mlx_array_set_int_mlx_gen_orig_ @@ -112,6 +114,7 @@ #define mlx_disable_compile mlx_disable_compile_mlx_gen_orig_ #define mlx_enable_compile mlx_enable_compile_mlx_gen_orig_ #define mlx_set_compile_mode mlx_set_compile_mode_mlx_gen_orig_ +#define mlx_cuda_is_available mlx_cuda_is_available_mlx_gen_orig_ #define mlx_device_new mlx_device_new_mlx_gen_orig_ #define mlx_device_new_type mlx_device_new_type_mlx_gen_orig_ #define mlx_device_free mlx_device_free_mlx_gen_orig_ @@ -122,11 +125,16 @@ #define mlx_device_get_type mlx_device_get_type_mlx_gen_orig_ #define mlx_get_default_device mlx_get_default_device_mlx_gen_orig_ #define mlx_set_default_device mlx_set_default_device_mlx_gen_orig_ -#define mlx_distributed_group_rank mlx_distributed_group_rank_mlx_gen_orig_ -#define mlx_distributed_group_size mlx_distributed_group_size_mlx_gen_orig_ -#define mlx_distributed_group_split mlx_distributed_group_split_mlx_gen_orig_ -#define mlx_distributed_is_available mlx_distributed_is_available_mlx_gen_orig_ -#define mlx_distributed_init mlx_distributed_init_mlx_gen_orig_ +#define mlx_device_is_available mlx_device_is_available_mlx_gen_orig_ +#define mlx_device_count mlx_device_count_mlx_gen_orig_ +#define mlx_device_info_new mlx_device_info_new_mlx_gen_orig_ +#define mlx_device_info_get mlx_device_info_get_mlx_gen_orig_ +#define mlx_device_info_free mlx_device_info_free_mlx_gen_orig_ +#define mlx_device_info_has_key mlx_device_info_has_key_mlx_gen_orig_ +#define mlx_device_info_is_string mlx_device_info_is_string_mlx_gen_orig_ +#define mlx_device_info_get_string mlx_device_info_get_string_mlx_gen_orig_ +#define mlx_device_info_get_size mlx_device_info_get_size_mlx_gen_orig_ +#define mlx_device_info_get_keys mlx_device_info_get_keys_mlx_gen_orig_ #define mlx_distributed_all_gather mlx_distributed_all_gather_mlx_gen_orig_ #define mlx_distributed_all_max mlx_distributed_all_max_mlx_gen_orig_ #define mlx_distributed_all_min mlx_distributed_all_min_mlx_gen_orig_ @@ -135,6 +143,11 @@ #define mlx_distributed_recv_like mlx_distributed_recv_like_mlx_gen_orig_ #define mlx_distributed_send mlx_distributed_send_mlx_gen_orig_ #define mlx_distributed_sum_scatter mlx_distributed_sum_scatter_mlx_gen_orig_ +#define mlx_distributed_group_rank mlx_distributed_group_rank_mlx_gen_orig_ +#define mlx_distributed_group_size mlx_distributed_group_size_mlx_gen_orig_ +#define mlx_distributed_group_split mlx_distributed_group_split_mlx_gen_orig_ +#define mlx_distributed_is_available mlx_distributed_is_available_mlx_gen_orig_ +#define mlx_distributed_init mlx_distributed_init_mlx_gen_orig_ #define mlx_set_error_handler mlx_set_error_handler_mlx_gen_orig_ #define _mlx_error _mlx_error_mlx_gen_orig_ #define mlx_export_function mlx_export_function_mlx_gen_orig_ @@ -176,6 +189,7 @@ #define mlx_fast_metal_kernel_apply mlx_fast_metal_kernel_apply_mlx_gen_orig_ #define mlx_fast_rms_norm mlx_fast_rms_norm_mlx_gen_orig_ #define mlx_fast_rope mlx_fast_rope_mlx_gen_orig_ +#define mlx_fast_rope_dynamic mlx_fast_rope_dynamic_mlx_gen_orig_ #define mlx_fast_scaled_dot_product_attention mlx_fast_scaled_dot_product_attention_mlx_gen_orig_ #define mlx_fft_fft mlx_fft_fft_mlx_gen_orig_ #define mlx_fft_fft2 mlx_fft_fft2_mlx_gen_orig_ @@ -191,14 +205,6 @@ #define mlx_fft_rfft mlx_fft_rfft_mlx_gen_orig_ #define mlx_fft_rfft2 mlx_fft_rfft2_mlx_gen_orig_ #define mlx_fft_rfftn mlx_fft_rfftn_mlx_gen_orig_ -#define mlx_io_reader_new mlx_io_reader_new_mlx_gen_orig_ -#define mlx_io_reader_descriptor mlx_io_reader_descriptor_mlx_gen_orig_ -#define mlx_io_reader_tostring mlx_io_reader_tostring_mlx_gen_orig_ -#define mlx_io_reader_free mlx_io_reader_free_mlx_gen_orig_ -#define mlx_io_writer_new mlx_io_writer_new_mlx_gen_orig_ -#define mlx_io_writer_descriptor mlx_io_writer_descriptor_mlx_gen_orig_ -#define mlx_io_writer_tostring mlx_io_writer_tostring_mlx_gen_orig_ -#define mlx_io_writer_free mlx_io_writer_free_mlx_gen_orig_ #define mlx_load_reader mlx_load_reader_mlx_gen_orig_ #define mlx_load mlx_load_mlx_gen_orig_ #define mlx_load_safetensors_reader mlx_load_safetensors_reader_mlx_gen_orig_ @@ -207,6 +213,14 @@ #define mlx_save mlx_save_mlx_gen_orig_ #define mlx_save_safetensors_writer mlx_save_safetensors_writer_mlx_gen_orig_ #define mlx_save_safetensors mlx_save_safetensors_mlx_gen_orig_ +#define mlx_io_reader_new mlx_io_reader_new_mlx_gen_orig_ +#define mlx_io_reader_descriptor mlx_io_reader_descriptor_mlx_gen_orig_ +#define mlx_io_reader_tostring mlx_io_reader_tostring_mlx_gen_orig_ +#define mlx_io_reader_free mlx_io_reader_free_mlx_gen_orig_ +#define mlx_io_writer_new mlx_io_writer_new_mlx_gen_orig_ +#define mlx_io_writer_descriptor mlx_io_writer_descriptor_mlx_gen_orig_ +#define mlx_io_writer_tostring mlx_io_writer_tostring_mlx_gen_orig_ +#define mlx_io_writer_free mlx_io_writer_free_mlx_gen_orig_ #define mlx_linalg_cholesky mlx_linalg_cholesky_mlx_gen_orig_ #define mlx_linalg_cholesky_inv mlx_linalg_cholesky_inv_mlx_gen_orig_ #define mlx_linalg_cross mlx_linalg_cross_mlx_gen_orig_ @@ -251,7 +265,6 @@ #define mlx_set_cache_limit mlx_set_cache_limit_mlx_gen_orig_ #define mlx_set_memory_limit mlx_set_memory_limit_mlx_gen_orig_ #define mlx_set_wired_limit mlx_set_wired_limit_mlx_gen_orig_ -#define mlx_metal_device_info mlx_metal_device_info_mlx_gen_orig_ #define mlx_metal_is_available mlx_metal_is_available_mlx_gen_orig_ #define mlx_metal_start_capture mlx_metal_start_capture_mlx_gen_orig_ #define mlx_metal_stop_capture mlx_metal_stop_capture_mlx_gen_orig_ @@ -337,6 +350,7 @@ #define mlx_full mlx_full_mlx_gen_orig_ #define mlx_full_like mlx_full_like_mlx_gen_orig_ #define mlx_gather mlx_gather_mlx_gen_orig_ +#define mlx_gather_single mlx_gather_single_mlx_gen_orig_ #define mlx_gather_mm mlx_gather_mm_mlx_gen_orig_ #define mlx_gather_qmm mlx_gather_qmm_mlx_gen_orig_ #define mlx_greater mlx_greater_mlx_gen_orig_ @@ -401,6 +415,7 @@ #define mlx_prod_axis mlx_prod_axis_mlx_gen_orig_ #define mlx_prod mlx_prod_mlx_gen_orig_ #define mlx_put_along_axis mlx_put_along_axis_mlx_gen_orig_ +#define mlx_qqmm mlx_qqmm_mlx_gen_orig_ #define mlx_quantize mlx_quantize_mlx_gen_orig_ #define mlx_quantized_matmul mlx_quantized_matmul_mlx_gen_orig_ #define mlx_radians mlx_radians_mlx_gen_orig_ @@ -417,11 +432,16 @@ #define mlx_round mlx_round_mlx_gen_orig_ #define mlx_rsqrt mlx_rsqrt_mlx_gen_orig_ #define mlx_scatter mlx_scatter_mlx_gen_orig_ +#define mlx_scatter_single mlx_scatter_single_mlx_gen_orig_ #define mlx_scatter_add mlx_scatter_add_mlx_gen_orig_ +#define mlx_scatter_add_single mlx_scatter_add_single_mlx_gen_orig_ #define mlx_scatter_add_axis mlx_scatter_add_axis_mlx_gen_orig_ #define mlx_scatter_max mlx_scatter_max_mlx_gen_orig_ +#define mlx_scatter_max_single mlx_scatter_max_single_mlx_gen_orig_ #define mlx_scatter_min mlx_scatter_min_mlx_gen_orig_ +#define mlx_scatter_min_single mlx_scatter_min_single_mlx_gen_orig_ #define mlx_scatter_prod mlx_scatter_prod_mlx_gen_orig_ +#define mlx_scatter_prod_single mlx_scatter_prod_single_mlx_gen_orig_ #define mlx_segmented_mm mlx_segmented_mm_mlx_gen_orig_ #define mlx_sigmoid mlx_sigmoid_mlx_gen_orig_ #define mlx_sign mlx_sign_mlx_gen_orig_ @@ -516,8 +536,6 @@ #define mlx_string_set mlx_string_set_mlx_gen_orig_ #define mlx_string_data mlx_string_data_mlx_gen_orig_ #define mlx_string_free mlx_string_free_mlx_gen_orig_ -#define mlx_detail_vmap_replace mlx_detail_vmap_replace_mlx_gen_orig_ -#define mlx_detail_vmap_trace mlx_detail_vmap_trace_mlx_gen_orig_ #define mlx_async_eval mlx_async_eval_mlx_gen_orig_ #define mlx_checkpoint mlx_checkpoint_mlx_gen_orig_ #define mlx_custom_function mlx_custom_function_mlx_gen_orig_ @@ -526,6 +544,8 @@ #define mlx_jvp mlx_jvp_mlx_gen_orig_ #define mlx_value_and_grad mlx_value_and_grad_mlx_gen_orig_ #define mlx_vjp mlx_vjp_mlx_gen_orig_ +#define mlx_detail_vmap_replace mlx_detail_vmap_replace_mlx_gen_orig_ +#define mlx_detail_vmap_trace mlx_detail_vmap_trace_mlx_gen_orig_ #define mlx_vector_array_new mlx_vector_array_new_mlx_gen_orig_ #define mlx_vector_array_set mlx_vector_array_set_mlx_gen_orig_ #define mlx_vector_array_free mlx_vector_array_free_mlx_gen_orig_ @@ -586,6 +606,8 @@ #undef mlx_array_new_double #undef mlx_array_new_complex #undef mlx_array_new_data +#undef mlx_array_new_data_managed +#undef mlx_array_new_data_managed_payload #undef mlx_array_set #undef mlx_array_set_bool #undef mlx_array_set_int @@ -681,6 +703,7 @@ #undef mlx_disable_compile #undef mlx_enable_compile #undef mlx_set_compile_mode +#undef mlx_cuda_is_available #undef mlx_device_new #undef mlx_device_new_type #undef mlx_device_free @@ -691,11 +714,16 @@ #undef mlx_device_get_type #undef mlx_get_default_device #undef mlx_set_default_device -#undef mlx_distributed_group_rank -#undef mlx_distributed_group_size -#undef mlx_distributed_group_split -#undef mlx_distributed_is_available -#undef mlx_distributed_init +#undef mlx_device_is_available +#undef mlx_device_count +#undef mlx_device_info_new +#undef mlx_device_info_get +#undef mlx_device_info_free +#undef mlx_device_info_has_key +#undef mlx_device_info_is_string +#undef mlx_device_info_get_string +#undef mlx_device_info_get_size +#undef mlx_device_info_get_keys #undef mlx_distributed_all_gather #undef mlx_distributed_all_max #undef mlx_distributed_all_min @@ -704,6 +732,11 @@ #undef mlx_distributed_recv_like #undef mlx_distributed_send #undef mlx_distributed_sum_scatter +#undef mlx_distributed_group_rank +#undef mlx_distributed_group_size +#undef mlx_distributed_group_split +#undef mlx_distributed_is_available +#undef mlx_distributed_init #undef mlx_set_error_handler #undef _mlx_error #undef mlx_export_function @@ -745,6 +778,7 @@ #undef mlx_fast_metal_kernel_apply #undef mlx_fast_rms_norm #undef mlx_fast_rope +#undef mlx_fast_rope_dynamic #undef mlx_fast_scaled_dot_product_attention #undef mlx_fft_fft #undef mlx_fft_fft2 @@ -760,14 +794,6 @@ #undef mlx_fft_rfft #undef mlx_fft_rfft2 #undef mlx_fft_rfftn -#undef mlx_io_reader_new -#undef mlx_io_reader_descriptor -#undef mlx_io_reader_tostring -#undef mlx_io_reader_free -#undef mlx_io_writer_new -#undef mlx_io_writer_descriptor -#undef mlx_io_writer_tostring -#undef mlx_io_writer_free #undef mlx_load_reader #undef mlx_load #undef mlx_load_safetensors_reader @@ -776,6 +802,14 @@ #undef mlx_save #undef mlx_save_safetensors_writer #undef mlx_save_safetensors +#undef mlx_io_reader_new +#undef mlx_io_reader_descriptor +#undef mlx_io_reader_tostring +#undef mlx_io_reader_free +#undef mlx_io_writer_new +#undef mlx_io_writer_descriptor +#undef mlx_io_writer_tostring +#undef mlx_io_writer_free #undef mlx_linalg_cholesky #undef mlx_linalg_cholesky_inv #undef mlx_linalg_cross @@ -820,7 +854,6 @@ #undef mlx_set_cache_limit #undef mlx_set_memory_limit #undef mlx_set_wired_limit -#undef mlx_metal_device_info #undef mlx_metal_is_available #undef mlx_metal_start_capture #undef mlx_metal_stop_capture @@ -906,6 +939,7 @@ #undef mlx_full #undef mlx_full_like #undef mlx_gather +#undef mlx_gather_single #undef mlx_gather_mm #undef mlx_gather_qmm #undef mlx_greater @@ -970,6 +1004,7 @@ #undef mlx_prod_axis #undef mlx_prod #undef mlx_put_along_axis +#undef mlx_qqmm #undef mlx_quantize #undef mlx_quantized_matmul #undef mlx_radians @@ -986,11 +1021,16 @@ #undef mlx_round #undef mlx_rsqrt #undef mlx_scatter +#undef mlx_scatter_single #undef mlx_scatter_add +#undef mlx_scatter_add_single #undef mlx_scatter_add_axis #undef mlx_scatter_max +#undef mlx_scatter_max_single #undef mlx_scatter_min +#undef mlx_scatter_min_single #undef mlx_scatter_prod +#undef mlx_scatter_prod_single #undef mlx_segmented_mm #undef mlx_sigmoid #undef mlx_sign @@ -1085,8 +1125,6 @@ #undef mlx_string_set #undef mlx_string_data #undef mlx_string_free -#undef mlx_detail_vmap_replace -#undef mlx_detail_vmap_trace #undef mlx_async_eval #undef mlx_checkpoint #undef mlx_custom_function @@ -1095,6 +1133,8 @@ #undef mlx_jvp #undef mlx_value_and_grad #undef mlx_vjp +#undef mlx_detail_vmap_replace +#undef mlx_detail_vmap_trace #undef mlx_vector_array_new #undef mlx_vector_array_set #undef mlx_vector_array_free @@ -1157,6 +1197,19 @@ extern mlx_array (*mlx_array_new_data_)( const int* shape, int dim, mlx_dtype dtype); +extern mlx_array (*mlx_array_new_data_managed_)( + void* data, + const int* shape, + int dim, + mlx_dtype dtype, + void (*dtor)(void*)); +extern mlx_array (*mlx_array_new_data_managed_payload_)( + void* data, + const int* shape, + int dim, + mlx_dtype dtype, + void* payload, + void (*dtor)(void*)); extern int (*mlx_array_set_)(mlx_array* arr, const mlx_array src); extern int (*mlx_array_set_bool_)(mlx_array* arr, bool val); extern int (*mlx_array_set_int_)(mlx_array* arr, int val); @@ -1191,7 +1244,7 @@ extern int (*mlx_array_item_int32_)(int32_t* res, const mlx_array arr); extern int (*mlx_array_item_int64_)(int64_t* res, const mlx_array arr); extern int (*mlx_array_item_float32_)(float* res, const mlx_array arr); extern int (*mlx_array_item_float64_)(double* res, const mlx_array arr); -extern int (*mlx_array_item_complex64_)(float _Complex* res, const mlx_array arr); +extern int (*mlx_array_item_complex64_)(mlx_complex64_t* res, const mlx_array arr); extern int (*mlx_array_item_float16_)(float16_t* res, const mlx_array arr); extern int (*mlx_array_item_bfloat16_)(bfloat16_t* res, const mlx_array arr); extern const bool * (*mlx_array_data_bool_)(const mlx_array arr); @@ -1205,7 +1258,7 @@ extern const int32_t * (*mlx_array_data_int32_)(const mlx_array arr); extern const int64_t * (*mlx_array_data_int64_)(const mlx_array arr); extern const float * (*mlx_array_data_float32_)(const mlx_array arr); extern const double * (*mlx_array_data_float64_)(const mlx_array arr); -extern const float _Complex * (*mlx_array_data_complex64_)(const mlx_array arr); +extern const mlx_complex64_t * (*mlx_array_data_complex64_)(const mlx_array arr); extern const float16_t * (*mlx_array_data_float16_)(const mlx_array arr); extern const bfloat16_t * (*mlx_array_data_bfloat16_)(const mlx_array arr); extern int (*_mlx_array_is_available_)(bool* res, const mlx_array arr); @@ -1229,10 +1282,11 @@ extern int (*mlx_closure_apply_)( extern mlx_closure (*mlx_closure_new_unary_)(int (*fun)(mlx_array*, const mlx_array)); extern mlx_closure_kwargs (*mlx_closure_kwargs_new_)(void); extern int (*mlx_closure_kwargs_free_)(mlx_closure_kwargs cls); -extern mlx_closure_kwargs (*mlx_closure_kwargs_new_func_)(int (*fun)( - mlx_vector_array*, - const mlx_vector_array, - const mlx_map_string_to_array)); +extern mlx_closure_kwargs (*mlx_closure_kwargs_new_func_)( + int (*fun)( + mlx_vector_array*, + const mlx_vector_array, + const mlx_map_string_to_array)); extern mlx_closure_kwargs (*mlx_closure_kwargs_new_func_payload_)( int (*fun)( mlx_vector_array*, @@ -1271,11 +1325,12 @@ extern int (*mlx_closure_value_and_grad_apply_)( const mlx_vector_array input); extern mlx_closure_custom (*mlx_closure_custom_new_)(void); extern int (*mlx_closure_custom_free_)(mlx_closure_custom cls); -extern mlx_closure_custom (*mlx_closure_custom_new_func_)(int (*fun)( - mlx_vector_array*, - const mlx_vector_array, - const mlx_vector_array, - const mlx_vector_array)); +extern mlx_closure_custom (*mlx_closure_custom_new_func_)( + int (*fun)( + mlx_vector_array*, + const mlx_vector_array, + const mlx_vector_array, + const mlx_vector_array)); extern mlx_closure_custom (*mlx_closure_custom_new_func_payload_)( int (*fun)( mlx_vector_array*, @@ -1296,12 +1351,13 @@ extern int (*mlx_closure_custom_apply_)( const mlx_vector_array input_2); extern mlx_closure_custom_jvp (*mlx_closure_custom_jvp_new_)(void); extern int (*mlx_closure_custom_jvp_free_)(mlx_closure_custom_jvp cls); -extern mlx_closure_custom_jvp (*mlx_closure_custom_jvp_new_func_)(int (*fun)( - mlx_vector_array*, - const mlx_vector_array, - const mlx_vector_array, - const int*, - size_t _num)); +extern mlx_closure_custom_jvp (*mlx_closure_custom_jvp_new_func_)( + int (*fun)( + mlx_vector_array*, + const mlx_vector_array, + const mlx_vector_array, + const int*, + size_t _num)); extern mlx_closure_custom_jvp (*mlx_closure_custom_jvp_new_func_payload_)( int (*fun)( mlx_vector_array*, @@ -1324,12 +1380,13 @@ extern int (*mlx_closure_custom_jvp_apply_)( size_t input_2_num); extern mlx_closure_custom_vmap (*mlx_closure_custom_vmap_new_)(void); extern int (*mlx_closure_custom_vmap_free_)(mlx_closure_custom_vmap cls); -extern mlx_closure_custom_vmap (*mlx_closure_custom_vmap_new_func_)(int (*fun)( - mlx_vector_array*, - mlx_vector_int*, - const mlx_vector_array, - const int*, - size_t _num)); +extern mlx_closure_custom_vmap (*mlx_closure_custom_vmap_new_func_)( + int (*fun)( + mlx_vector_array*, + mlx_vector_int*, + const mlx_vector_array, + const int*, + size_t _num)); extern mlx_closure_custom_vmap (*mlx_closure_custom_vmap_new_func_payload_)( int (*fun)( mlx_vector_array*, @@ -1363,6 +1420,7 @@ extern int (*mlx_detail_compile_erase_)(uintptr_t fun_id); extern int (*mlx_disable_compile_)(void); extern int (*mlx_enable_compile_)(void); extern int (*mlx_set_compile_mode_)(mlx_compile_mode mode); +extern int (*mlx_cuda_is_available_)(bool* res); extern mlx_device (*mlx_device_new_)(void); extern mlx_device (*mlx_device_new_type_)(mlx_device_type type, int index); extern int (*mlx_device_free_)(mlx_device dev); @@ -1373,11 +1431,28 @@ extern int (*mlx_device_get_index_)(int* index, mlx_device dev); extern int (*mlx_device_get_type_)(mlx_device_type* type, mlx_device dev); extern int (*mlx_get_default_device_)(mlx_device* dev); extern int (*mlx_set_default_device_)(mlx_device dev); -extern int (*mlx_distributed_group_rank_)(mlx_distributed_group group); -extern int (*mlx_distributed_group_size_)(mlx_distributed_group group); -extern mlx_distributed_group (*mlx_distributed_group_split_)(mlx_distributed_group group, int color, int key); -extern bool (*mlx_distributed_is_available_)(void); -extern mlx_distributed_group (*mlx_distributed_init_)(bool strict); +extern int (*mlx_device_is_available_)(bool* avail, mlx_device dev); +extern int (*mlx_device_count_)(int* count, mlx_device_type type); +extern mlx_device_info (*mlx_device_info_new_)(void); +extern int (*mlx_device_info_get_)(mlx_device_info* info, mlx_device dev); +extern int (*mlx_device_info_free_)(mlx_device_info info); +extern int (*mlx_device_info_has_key_)( + bool* exists, + mlx_device_info info, + const char* key); +extern int (*mlx_device_info_is_string_)( + bool* is_string, + mlx_device_info info, + const char* key); +extern int (*mlx_device_info_get_string_)( + const char** value, + mlx_device_info info, + const char* key); +extern int (*mlx_device_info_get_size_)( + size_t* value, + mlx_device_info info, + const char* key); +extern int (*mlx_device_info_get_keys_)(mlx_vector_string* keys, mlx_device_info info); extern int (*mlx_distributed_all_gather_)( mlx_array* res, const mlx_array x, @@ -1423,6 +1498,11 @@ extern int (*mlx_distributed_sum_scatter_)( const mlx_array x, const mlx_distributed_group group /* may be null */, const mlx_stream s); +extern int (*mlx_distributed_group_rank_)(mlx_distributed_group group); +extern int (*mlx_distributed_group_size_)(mlx_distributed_group group); +extern mlx_distributed_group (*mlx_distributed_group_split_)(mlx_distributed_group group, int color, int key); +extern bool (*mlx_distributed_is_available_)(void); +extern mlx_distributed_group (*mlx_distributed_init_)(bool strict); extern void (*mlx_set_error_handler_)( mlx_error_handler_func handler, void* data, @@ -1585,6 +1665,16 @@ extern int (*mlx_fast_rope_)( int offset, const mlx_array freqs /* may be null */, const mlx_stream s); +extern int (*mlx_fast_rope_dynamic_)( + mlx_array* res, + const mlx_array x, + int dims, + bool traditional, + mlx_optional_float base, + float scale, + const mlx_array offset, + const mlx_array freqs /* may be null */, + const mlx_stream s); extern int (*mlx_fast_scaled_dot_product_attention_)( mlx_array* res, const mlx_array queries, @@ -1695,14 +1785,6 @@ extern int (*mlx_fft_rfftn_)( const int* axes, size_t axes_num, const mlx_stream s); -extern mlx_io_reader (*mlx_io_reader_new_)(void* desc, mlx_io_vtable vtable); -extern int (*mlx_io_reader_descriptor_)(void** desc_, mlx_io_reader io); -extern int (*mlx_io_reader_tostring_)(mlx_string* str_, mlx_io_reader io); -extern int (*mlx_io_reader_free_)(mlx_io_reader io); -extern mlx_io_writer (*mlx_io_writer_new_)(void* desc, mlx_io_vtable vtable); -extern int (*mlx_io_writer_descriptor_)(void** desc_, mlx_io_writer io); -extern int (*mlx_io_writer_tostring_)(mlx_string* str_, mlx_io_writer io); -extern int (*mlx_io_writer_free_)(mlx_io_writer io); extern int (*mlx_load_reader_)( mlx_array* res, mlx_io_reader in_stream, @@ -1728,6 +1810,14 @@ extern int (*mlx_save_safetensors_)( const char* file, const mlx_map_string_to_array param, const mlx_map_string_to_string metadata); +extern mlx_io_reader (*mlx_io_reader_new_)(void* desc, mlx_io_vtable vtable); +extern int (*mlx_io_reader_descriptor_)(void** desc_, mlx_io_reader io); +extern int (*mlx_io_reader_tostring_)(mlx_string* str_, mlx_io_reader io); +extern int (*mlx_io_reader_free_)(mlx_io_reader io); +extern mlx_io_writer (*mlx_io_writer_new_)(void* desc, mlx_io_vtable vtable); +extern int (*mlx_io_writer_descriptor_)(void** desc_, mlx_io_writer io); +extern int (*mlx_io_writer_tostring_)(mlx_string* str_, mlx_io_writer io); +extern int (*mlx_io_writer_free_)(mlx_io_writer io); extern int (*mlx_linalg_cholesky_)( mlx_array* res, const mlx_array a, @@ -1868,7 +1958,6 @@ extern int (*mlx_reset_peak_memory_)(void); extern int (*mlx_set_cache_limit_)(size_t* res, size_t limit); extern int (*mlx_set_memory_limit_)(size_t* res, size_t limit); extern int (*mlx_set_wired_limit_)(size_t* res, size_t limit); -extern mlx_metal_device_info_t (*mlx_metal_device_info_)(void); extern int (*mlx_metal_is_available_)(bool* res); extern int (*mlx_metal_start_capture_)(const char* path); extern int (*mlx_metal_stop_capture_)(void); @@ -2297,6 +2386,14 @@ extern int (*mlx_gather_)( const int* slice_sizes, size_t slice_sizes_num, const mlx_stream s); +extern int (*mlx_gather_single_)( + mlx_array* res, + const mlx_array a, + const mlx_array indices, + int axis, + const int* slice_sizes, + size_t slice_sizes_num, + const mlx_stream s); extern int (*mlx_gather_mm_)( mlx_array* res, const mlx_array a, @@ -2618,6 +2715,15 @@ extern int (*mlx_put_along_axis_)( const mlx_array values, int axis, const mlx_stream s); +extern int (*mlx_qqmm_)( + mlx_array* res, + const mlx_array x, + const mlx_array w, + const mlx_array w_scales /* may be null */, + mlx_optional_int group_size, + mlx_optional_int bits, + const char* mode, + const mlx_stream s); extern int (*mlx_quantize_)( mlx_vector_array* res, const mlx_array w, @@ -2701,6 +2807,13 @@ extern int (*mlx_scatter_)( const int* axes, size_t axes_num, const mlx_stream s); +extern int (*mlx_scatter_single_)( + mlx_array* res, + const mlx_array a, + const mlx_array indices, + const mlx_array updates, + int axis, + const mlx_stream s); extern int (*mlx_scatter_add_)( mlx_array* res, const mlx_array a, @@ -2709,6 +2822,13 @@ extern int (*mlx_scatter_add_)( const int* axes, size_t axes_num, const mlx_stream s); +extern int (*mlx_scatter_add_single_)( + mlx_array* res, + const mlx_array a, + const mlx_array indices, + const mlx_array updates, + int axis, + const mlx_stream s); extern int (*mlx_scatter_add_axis_)( mlx_array* res, const mlx_array a, @@ -2724,6 +2844,13 @@ extern int (*mlx_scatter_max_)( const int* axes, size_t axes_num, const mlx_stream s); +extern int (*mlx_scatter_max_single_)( + mlx_array* res, + const mlx_array a, + const mlx_array indices, + const mlx_array updates, + int axis, + const mlx_stream s); extern int (*mlx_scatter_min_)( mlx_array* res, const mlx_array a, @@ -2732,6 +2859,13 @@ extern int (*mlx_scatter_min_)( const int* axes, size_t axes_num, const mlx_stream s); +extern int (*mlx_scatter_min_single_)( + mlx_array* res, + const mlx_array a, + const mlx_array indices, + const mlx_array updates, + int axis, + const mlx_stream s); extern int (*mlx_scatter_prod_)( mlx_array* res, const mlx_array a, @@ -2740,6 +2874,13 @@ extern int (*mlx_scatter_prod_)( const int* axes, size_t axes_num, const mlx_stream s); +extern int (*mlx_scatter_prod_single_)( + mlx_array* res, + const mlx_array a, + const mlx_array indices, + const mlx_array updates, + int axis, + const mlx_stream s); extern int (*mlx_segmented_mm_)( mlx_array* res, const mlx_array a, @@ -3163,22 +3304,6 @@ extern mlx_string (*mlx_string_new_data_)(const char* str); extern int (*mlx_string_set_)(mlx_string* str, const mlx_string src); extern const char * (*mlx_string_data_)(mlx_string str); extern int (*mlx_string_free_)(mlx_string str); -extern int (*mlx_detail_vmap_replace_)( - mlx_vector_array* res, - const mlx_vector_array inputs, - const mlx_vector_array s_inputs, - const mlx_vector_array s_outputs, - const int* in_axes, - size_t in_axes_num, - const int* out_axes, - size_t out_axes_num); -extern int (*mlx_detail_vmap_trace_)( - mlx_vector_array* res_0, - mlx_vector_array* res_1, - const mlx_closure fun, - const mlx_vector_array inputs, - const int* in_axes, - size_t in_axes_num); extern int (*mlx_async_eval_)(const mlx_vector_array outputs); extern int (*mlx_checkpoint_)(mlx_closure* res, const mlx_closure fun); extern int (*mlx_custom_function_)( @@ -3209,6 +3334,22 @@ extern int (*mlx_vjp_)( const mlx_closure fun, const mlx_vector_array primals, const mlx_vector_array cotangents); +extern int (*mlx_detail_vmap_replace_)( + mlx_vector_array* res, + const mlx_vector_array inputs, + const mlx_vector_array s_inputs, + const mlx_vector_array s_outputs, + const int* in_axes, + size_t in_axes_num, + const int* out_axes, + size_t out_axes_num); +extern int (*mlx_detail_vmap_trace_)( + mlx_vector_array* res_0, + mlx_vector_array* res_1, + const mlx_closure fun, + const mlx_vector_array inputs, + const int* in_axes, + size_t in_axes_num); extern mlx_vector_array (*mlx_vector_array_new_)(void); extern int (*mlx_vector_array_set_)(mlx_vector_array* vec, const mlx_vector_array src); extern int (*mlx_vector_array_free_)(mlx_vector_array vec); @@ -3293,47 +3434,36 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle); static inline size_t mlx_dtype_size(mlx_dtype dtype) { return mlx_dtype_size_(dtype); } - static inline int mlx_array_tostring(mlx_string* str, const mlx_array arr) { return mlx_array_tostring_(str, arr); } - static inline mlx_array mlx_array_new(void) { return mlx_array_new_(); } - static inline int mlx_array_free(mlx_array arr) { return mlx_array_free_(arr); } - static inline mlx_array mlx_array_new_bool(bool val) { return mlx_array_new_bool_(val); } - static inline mlx_array mlx_array_new_int(int val) { return mlx_array_new_int_(val); } - static inline mlx_array mlx_array_new_float32(float val) { return mlx_array_new_float32_(val); } - static inline mlx_array mlx_array_new_float(float val) { return mlx_array_new_float_(val); } - static inline mlx_array mlx_array_new_float64(double val) { return mlx_array_new_float64_(val); } - static inline mlx_array mlx_array_new_double(double val) { return mlx_array_new_double_(val); } - static inline mlx_array mlx_array_new_complex(float real_val, float imag_val) { return mlx_array_new_complex_(real_val, imag_val); } - static inline mlx_array mlx_array_new_data( const void* data, const int* shape, @@ -3341,39 +3471,47 @@ static inline mlx_array mlx_array_new_data( mlx_dtype dtype) { return mlx_array_new_data_(data, shape, dim, dtype); } - +static inline mlx_array mlx_array_new_data_managed( + void* data, + const int* shape, + int dim, + mlx_dtype dtype, + void (*dtor)(void*)) { + return mlx_array_new_data_managed_(data, shape, dim, dtype, dtor); +} +static inline mlx_array mlx_array_new_data_managed_payload( + void* data, + const int* shape, + int dim, + mlx_dtype dtype, + void* payload, + void (*dtor)(void*)) { + return mlx_array_new_data_managed_payload_(data, shape, dim, dtype, payload, dtor); +} static inline int mlx_array_set(mlx_array* arr, const mlx_array src) { return mlx_array_set_(arr, src); } - static inline int mlx_array_set_bool(mlx_array* arr, bool val) { return mlx_array_set_bool_(arr, val); } - static inline int mlx_array_set_int(mlx_array* arr, int val) { return mlx_array_set_int_(arr, val); } - static inline int mlx_array_set_float32(mlx_array* arr, float val) { return mlx_array_set_float32_(arr, val); } - static inline int mlx_array_set_float(mlx_array* arr, float val) { return mlx_array_set_float_(arr, val); } - static inline int mlx_array_set_float64(mlx_array* arr, double val) { return mlx_array_set_float64_(arr, val); } - static inline int mlx_array_set_double(mlx_array* arr, double val) { return mlx_array_set_double_(arr, val); } - static inline int mlx_array_set_complex(mlx_array* arr, float real_val, float imag_val) { return mlx_array_set_complex_(arr, real_val, imag_val); } - static inline int mlx_array_set_data( mlx_array* arr, const void* data, @@ -3382,225 +3520,173 @@ static inline int mlx_array_set_data( mlx_dtype dtype) { return mlx_array_set_data_(arr, data, shape, dim, dtype); } - static inline size_t mlx_array_itemsize(const mlx_array arr) { return mlx_array_itemsize_(arr); } - static inline size_t mlx_array_size(const mlx_array arr) { return mlx_array_size_(arr); } - static inline size_t mlx_array_nbytes(const mlx_array arr) { return mlx_array_nbytes_(arr); } - static inline size_t mlx_array_ndim(const mlx_array arr) { return mlx_array_ndim_(arr); } - static inline const int * mlx_array_shape(const mlx_array arr) { return mlx_array_shape_(arr); } - static inline const size_t * mlx_array_strides(const mlx_array arr) { return mlx_array_strides_(arr); } - static inline int mlx_array_dim(const mlx_array arr, int dim) { return mlx_array_dim_(arr, dim); } - static inline mlx_dtype mlx_array_dtype(const mlx_array arr) { return mlx_array_dtype_(arr); } - static inline int mlx_array_eval(mlx_array arr) { return mlx_array_eval_(arr); } - static inline int mlx_array_item_bool(bool* res, const mlx_array arr) { return mlx_array_item_bool_(res, arr); } - static inline int mlx_array_item_uint8(uint8_t* res, const mlx_array arr) { return mlx_array_item_uint8_(res, arr); } - static inline int mlx_array_item_uint16(uint16_t* res, const mlx_array arr) { return mlx_array_item_uint16_(res, arr); } - static inline int mlx_array_item_uint32(uint32_t* res, const mlx_array arr) { return mlx_array_item_uint32_(res, arr); } - static inline int mlx_array_item_uint64(uint64_t* res, const mlx_array arr) { return mlx_array_item_uint64_(res, arr); } - static inline int mlx_array_item_int8(int8_t* res, const mlx_array arr) { return mlx_array_item_int8_(res, arr); } - static inline int mlx_array_item_int16(int16_t* res, const mlx_array arr) { return mlx_array_item_int16_(res, arr); } - static inline int mlx_array_item_int32(int32_t* res, const mlx_array arr) { return mlx_array_item_int32_(res, arr); } - static inline int mlx_array_item_int64(int64_t* res, const mlx_array arr) { return mlx_array_item_int64_(res, arr); } - static inline int mlx_array_item_float32(float* res, const mlx_array arr) { return mlx_array_item_float32_(res, arr); } - static inline int mlx_array_item_float64(double* res, const mlx_array arr) { return mlx_array_item_float64_(res, arr); } - -static inline int mlx_array_item_complex64(float _Complex* res, const mlx_array arr) { +static inline int mlx_array_item_complex64(mlx_complex64_t* res, const mlx_array arr) { return mlx_array_item_complex64_(res, arr); } - static inline int mlx_array_item_float16(float16_t* res, const mlx_array arr) { return mlx_array_item_float16_(res, arr); } - static inline int mlx_array_item_bfloat16(bfloat16_t* res, const mlx_array arr) { return mlx_array_item_bfloat16_(res, arr); } - static inline const bool * mlx_array_data_bool(const mlx_array arr) { return mlx_array_data_bool_(arr); } - static inline const uint8_t * mlx_array_data_uint8(const mlx_array arr) { return mlx_array_data_uint8_(arr); } - static inline const uint16_t * mlx_array_data_uint16(const mlx_array arr) { return mlx_array_data_uint16_(arr); } - static inline const uint32_t * mlx_array_data_uint32(const mlx_array arr) { return mlx_array_data_uint32_(arr); } - static inline const uint64_t * mlx_array_data_uint64(const mlx_array arr) { return mlx_array_data_uint64_(arr); } - static inline const int8_t * mlx_array_data_int8(const mlx_array arr) { return mlx_array_data_int8_(arr); } - static inline const int16_t * mlx_array_data_int16(const mlx_array arr) { return mlx_array_data_int16_(arr); } - static inline const int32_t * mlx_array_data_int32(const mlx_array arr) { return mlx_array_data_int32_(arr); } - static inline const int64_t * mlx_array_data_int64(const mlx_array arr) { return mlx_array_data_int64_(arr); } - static inline const float * mlx_array_data_float32(const mlx_array arr) { return mlx_array_data_float32_(arr); } - static inline const double * mlx_array_data_float64(const mlx_array arr) { return mlx_array_data_float64_(arr); } - -static inline const float _Complex * mlx_array_data_complex64(const mlx_array arr) { +static inline const mlx_complex64_t * mlx_array_data_complex64(const mlx_array arr) { return mlx_array_data_complex64_(arr); } - static inline const float16_t * mlx_array_data_float16(const mlx_array arr) { return mlx_array_data_float16_(arr); } - static inline const bfloat16_t * mlx_array_data_bfloat16(const mlx_array arr) { return mlx_array_data_bfloat16_(arr); } - static inline int _mlx_array_is_available(bool* res, const mlx_array arr) { return _mlx_array_is_available_(res, arr); } - static inline int _mlx_array_wait(const mlx_array arr) { return _mlx_array_wait_(arr); } - static inline int _mlx_array_is_contiguous(bool* res, const mlx_array arr) { return _mlx_array_is_contiguous_(res, arr); } - static inline int _mlx_array_is_row_contiguous(bool* res, const mlx_array arr) { return _mlx_array_is_row_contiguous_(res, arr); } - static inline int _mlx_array_is_col_contiguous(bool* res, const mlx_array arr) { return _mlx_array_is_col_contiguous_(res, arr); } - static inline mlx_closure mlx_closure_new(void) { return mlx_closure_new_(); } - static inline int mlx_closure_free(mlx_closure cls) { return mlx_closure_free_(cls); } - static inline mlx_closure mlx_closure_new_func( int (*fun)(mlx_vector_array*, const mlx_vector_array)) { return mlx_closure_new_func_(fun); } - static inline mlx_closure mlx_closure_new_func_payload( int (*fun)(mlx_vector_array*, const mlx_vector_array, void*), void* payload, void (*dtor)(void*)) { return mlx_closure_new_func_payload_(fun, payload, dtor); } - static inline int mlx_closure_set(mlx_closure* cls, const mlx_closure src) { return mlx_closure_set_(cls, src); } - static inline int mlx_closure_apply( mlx_vector_array* res, mlx_closure cls, const mlx_vector_array input) { return mlx_closure_apply_(res, cls, input); } - static inline mlx_closure mlx_closure_new_unary(int (*fun)(mlx_array*, const mlx_array)) { return mlx_closure_new_unary_(fun); } - static inline mlx_closure_kwargs mlx_closure_kwargs_new(void) { return mlx_closure_kwargs_new_(); } - static inline int mlx_closure_kwargs_free(mlx_closure_kwargs cls) { return mlx_closure_kwargs_free_(cls); } - -static inline mlx_closure_kwargs mlx_closure_kwargs_new_func(int (*fun)( - mlx_vector_array*, - const mlx_vector_array, - const mlx_map_string_to_array)) { +static inline mlx_closure_kwargs mlx_closure_kwargs_new_func( + int (*fun)( + mlx_vector_array*, + const mlx_vector_array, + const mlx_map_string_to_array)) { return mlx_closure_kwargs_new_func_(fun); } - static inline mlx_closure_kwargs mlx_closure_kwargs_new_func_payload( int (*fun)( mlx_vector_array*, @@ -3611,13 +3697,11 @@ static inline mlx_closure_kwargs mlx_closure_kwargs_new_func_payload( void (*dtor)(void*)) { return mlx_closure_kwargs_new_func_payload_(fun, payload, dtor); } - static inline int mlx_closure_kwargs_set( mlx_closure_kwargs* cls, const mlx_closure_kwargs src) { return mlx_closure_kwargs_set_(cls, src); } - static inline int mlx_closure_kwargs_apply( mlx_vector_array* res, mlx_closure_kwargs cls, @@ -3625,20 +3709,16 @@ static inline int mlx_closure_kwargs_apply( const mlx_map_string_to_array input_1) { return mlx_closure_kwargs_apply_(res, cls, input_0, input_1); } - static inline mlx_closure_value_and_grad mlx_closure_value_and_grad_new(void) { return mlx_closure_value_and_grad_new_(); } - static inline int mlx_closure_value_and_grad_free(mlx_closure_value_and_grad cls) { return mlx_closure_value_and_grad_free_(cls); } - static inline mlx_closure_value_and_grad mlx_closure_value_and_grad_new_func( int (*fun)(mlx_vector_array*, mlx_vector_array*, const mlx_vector_array)) { return mlx_closure_value_and_grad_new_func_(fun); } - static inline mlx_closure_value_and_grad mlx_closure_value_and_grad_new_func_payload( int (*fun)( mlx_vector_array*, @@ -3649,13 +3729,11 @@ static inline mlx_closure_value_and_grad mlx_closure_value_and_grad_new_func_pay void (*dtor)(void*)) { return mlx_closure_value_and_grad_new_func_payload_(fun, payload, dtor); } - static inline int mlx_closure_value_and_grad_set( mlx_closure_value_and_grad* cls, const mlx_closure_value_and_grad src) { return mlx_closure_value_and_grad_set_(cls, src); } - static inline int mlx_closure_value_and_grad_apply( mlx_vector_array* res_0, mlx_vector_array* res_1, @@ -3663,23 +3741,20 @@ static inline int mlx_closure_value_and_grad_apply( const mlx_vector_array input) { return mlx_closure_value_and_grad_apply_(res_0, res_1, cls, input); } - static inline mlx_closure_custom mlx_closure_custom_new(void) { return mlx_closure_custom_new_(); } - static inline int mlx_closure_custom_free(mlx_closure_custom cls) { return mlx_closure_custom_free_(cls); } - -static inline mlx_closure_custom mlx_closure_custom_new_func(int (*fun)( - mlx_vector_array*, - const mlx_vector_array, - const mlx_vector_array, - const mlx_vector_array)) { +static inline mlx_closure_custom mlx_closure_custom_new_func( + int (*fun)( + mlx_vector_array*, + const mlx_vector_array, + const mlx_vector_array, + const mlx_vector_array)) { return mlx_closure_custom_new_func_(fun); } - static inline mlx_closure_custom mlx_closure_custom_new_func_payload( int (*fun)( mlx_vector_array*, @@ -3691,13 +3766,11 @@ static inline mlx_closure_custom mlx_closure_custom_new_func_payload( void (*dtor)(void*)) { return mlx_closure_custom_new_func_payload_(fun, payload, dtor); } - static inline int mlx_closure_custom_set( mlx_closure_custom* cls, const mlx_closure_custom src) { return mlx_closure_custom_set_(cls, src); } - static inline int mlx_closure_custom_apply( mlx_vector_array* res, mlx_closure_custom cls, @@ -3706,24 +3779,21 @@ static inline int mlx_closure_custom_apply( const mlx_vector_array input_2) { return mlx_closure_custom_apply_(res, cls, input_0, input_1, input_2); } - static inline mlx_closure_custom_jvp mlx_closure_custom_jvp_new(void) { return mlx_closure_custom_jvp_new_(); } - static inline int mlx_closure_custom_jvp_free(mlx_closure_custom_jvp cls) { return mlx_closure_custom_jvp_free_(cls); } - -static inline mlx_closure_custom_jvp mlx_closure_custom_jvp_new_func(int (*fun)( - mlx_vector_array*, - const mlx_vector_array, - const mlx_vector_array, - const int*, - size_t _num)) { +static inline mlx_closure_custom_jvp mlx_closure_custom_jvp_new_func( + int (*fun)( + mlx_vector_array*, + const mlx_vector_array, + const mlx_vector_array, + const int*, + size_t _num)) { return mlx_closure_custom_jvp_new_func_(fun); } - static inline mlx_closure_custom_jvp mlx_closure_custom_jvp_new_func_payload( int (*fun)( mlx_vector_array*, @@ -3736,13 +3806,11 @@ static inline mlx_closure_custom_jvp mlx_closure_custom_jvp_new_func_payload( void (*dtor)(void*)) { return mlx_closure_custom_jvp_new_func_payload_(fun, payload, dtor); } - static inline int mlx_closure_custom_jvp_set( mlx_closure_custom_jvp* cls, const mlx_closure_custom_jvp src) { return mlx_closure_custom_jvp_set_(cls, src); } - static inline int mlx_closure_custom_jvp_apply( mlx_vector_array* res, mlx_closure_custom_jvp cls, @@ -3752,24 +3820,21 @@ static inline int mlx_closure_custom_jvp_apply( size_t input_2_num) { return mlx_closure_custom_jvp_apply_(res, cls, input_0, input_1, input_2, input_2_num); } - static inline mlx_closure_custom_vmap mlx_closure_custom_vmap_new(void) { return mlx_closure_custom_vmap_new_(); } - static inline int mlx_closure_custom_vmap_free(mlx_closure_custom_vmap cls) { return mlx_closure_custom_vmap_free_(cls); } - -static inline mlx_closure_custom_vmap mlx_closure_custom_vmap_new_func(int (*fun)( - mlx_vector_array*, - mlx_vector_int*, - const mlx_vector_array, - const int*, - size_t _num)) { +static inline mlx_closure_custom_vmap mlx_closure_custom_vmap_new_func( + int (*fun)( + mlx_vector_array*, + mlx_vector_int*, + const mlx_vector_array, + const int*, + size_t _num)) { return mlx_closure_custom_vmap_new_func_(fun); } - static inline mlx_closure_custom_vmap mlx_closure_custom_vmap_new_func_payload( int (*fun)( mlx_vector_array*, @@ -3782,13 +3847,11 @@ static inline mlx_closure_custom_vmap mlx_closure_custom_vmap_new_func_payload( void (*dtor)(void*)) { return mlx_closure_custom_vmap_new_func_payload_(fun, payload, dtor); } - static inline int mlx_closure_custom_vmap_set( mlx_closure_custom_vmap* cls, const mlx_closure_custom_vmap src) { return mlx_closure_custom_vmap_set_(cls, src); } - static inline int mlx_closure_custom_vmap_apply( mlx_vector_array* res_0, mlx_vector_int* res_1, @@ -3798,11 +3861,9 @@ static inline int mlx_closure_custom_vmap_apply( size_t input_1_num) { return mlx_closure_custom_vmap_apply_(res_0, res_1, cls, input_0, input_1, input_1_num); } - static inline int mlx_compile(mlx_closure* res, const mlx_closure fun, bool shapeless) { return mlx_compile_(res, fun, shapeless); } - static inline int mlx_detail_compile( mlx_closure* res, const mlx_closure fun, @@ -3812,87 +3873,96 @@ static inline int mlx_detail_compile( size_t constants_num) { return mlx_detail_compile_(res, fun, fun_id, shapeless, constants, constants_num); } - static inline int mlx_detail_compile_clear_cache(void) { return mlx_detail_compile_clear_cache_(); } - static inline int mlx_detail_compile_erase(uintptr_t fun_id) { return mlx_detail_compile_erase_(fun_id); } - static inline int mlx_disable_compile(void) { return mlx_disable_compile_(); } - static inline int mlx_enable_compile(void) { return mlx_enable_compile_(); } - static inline int mlx_set_compile_mode(mlx_compile_mode mode) { return mlx_set_compile_mode_(mode); } - +static inline int mlx_cuda_is_available(bool* res) { + return mlx_cuda_is_available_(res); +} static inline mlx_device mlx_device_new(void) { return mlx_device_new_(); } - static inline mlx_device mlx_device_new_type(mlx_device_type type, int index) { return mlx_device_new_type_(type, index); } - static inline int mlx_device_free(mlx_device dev) { return mlx_device_free_(dev); } - static inline int mlx_device_set(mlx_device* dev, const mlx_device src) { return mlx_device_set_(dev, src); } - static inline int mlx_device_tostring(mlx_string* str, mlx_device dev) { return mlx_device_tostring_(str, dev); } - static inline bool mlx_device_equal(mlx_device lhs, mlx_device rhs) { return mlx_device_equal_(lhs, rhs); } - static inline int mlx_device_get_index(int* index, mlx_device dev) { return mlx_device_get_index_(index, dev); } - static inline int mlx_device_get_type(mlx_device_type* type, mlx_device dev) { return mlx_device_get_type_(type, dev); } - static inline int mlx_get_default_device(mlx_device* dev) { return mlx_get_default_device_(dev); } - static inline int mlx_set_default_device(mlx_device dev) { return mlx_set_default_device_(dev); } - -static inline int mlx_distributed_group_rank(mlx_distributed_group group) { - return mlx_distributed_group_rank_(group); +static inline int mlx_device_is_available(bool* avail, mlx_device dev) { + return mlx_device_is_available_(avail, dev); } - -static inline int mlx_distributed_group_size(mlx_distributed_group group) { - return mlx_distributed_group_size_(group); +static inline int mlx_device_count(int* count, mlx_device_type type) { + return mlx_device_count_(count, type); } - -static inline mlx_distributed_group mlx_distributed_group_split(mlx_distributed_group group, int color, int key) { - return mlx_distributed_group_split_(group, color, key); +static inline mlx_device_info mlx_device_info_new(void) { + return mlx_device_info_new_(); } - -static inline bool mlx_distributed_is_available(void) { - return mlx_distributed_is_available_(); +static inline int mlx_device_info_get(mlx_device_info* info, mlx_device dev) { + return mlx_device_info_get_(info, dev); } - -static inline mlx_distributed_group mlx_distributed_init(bool strict) { - return mlx_distributed_init_(strict); +static inline int mlx_device_info_free(mlx_device_info info) { + return mlx_device_info_free_(info); +} +static inline int mlx_device_info_has_key( + bool* exists, + mlx_device_info info, + const char* key) { + return mlx_device_info_has_key_(exists, info, key); +} +static inline int mlx_device_info_is_string( + bool* is_string, + mlx_device_info info, + const char* key) { + return mlx_device_info_is_string_(is_string, info, key); +} +static inline int mlx_device_info_get_string( + const char** value, + mlx_device_info info, + const char* key) { + return mlx_device_info_get_string_(value, info, key); +} +static inline int mlx_device_info_get_size( + size_t* value, + mlx_device_info info, + const char* key) { + return mlx_device_info_get_size_(value, info, key); +} +static inline int mlx_device_info_get_keys(mlx_vector_string* keys, mlx_device_info info) { + return mlx_device_info_get_keys_(keys, info); } - static inline int mlx_distributed_all_gather( mlx_array* res, const mlx_array x, @@ -3900,7 +3970,6 @@ static inline int mlx_distributed_all_gather( const mlx_stream S) { return mlx_distributed_all_gather_(res, x, group, S); } - static inline int mlx_distributed_all_max( mlx_array* res, const mlx_array x, @@ -3908,7 +3977,6 @@ static inline int mlx_distributed_all_max( const mlx_stream s) { return mlx_distributed_all_max_(res, x, group, s); } - static inline int mlx_distributed_all_min( mlx_array* res, const mlx_array x, @@ -3916,7 +3984,6 @@ static inline int mlx_distributed_all_min( const mlx_stream s) { return mlx_distributed_all_min_(res, x, group, s); } - static inline int mlx_distributed_all_sum( mlx_array* res, const mlx_array x, @@ -3924,7 +3991,6 @@ static inline int mlx_distributed_all_sum( const mlx_stream s) { return mlx_distributed_all_sum_(res, x, group, s); } - static inline int mlx_distributed_recv( mlx_array* res, const int* shape, @@ -3935,7 +4001,6 @@ static inline int mlx_distributed_recv( const mlx_stream s) { return mlx_distributed_recv_(res, shape, shape_num, dtype, src, group, s); } - static inline int mlx_distributed_recv_like( mlx_array* res, const mlx_array x, @@ -3944,7 +4009,6 @@ static inline int mlx_distributed_recv_like( const mlx_stream s) { return mlx_distributed_recv_like_(res, x, src, group, s); } - static inline int mlx_distributed_send( mlx_array* res, const mlx_array x, @@ -3953,7 +4017,6 @@ static inline int mlx_distributed_send( const mlx_stream s) { return mlx_distributed_send_(res, x, dst, group, s); } - static inline int mlx_distributed_sum_scatter( mlx_array* res, const mlx_array x, @@ -3961,16 +4024,30 @@ static inline int mlx_distributed_sum_scatter( const mlx_stream s) { return mlx_distributed_sum_scatter_(res, x, group, s); } - +static inline int mlx_distributed_group_rank(mlx_distributed_group group) { + return mlx_distributed_group_rank_(group); +} +static inline int mlx_distributed_group_size(mlx_distributed_group group) { + return mlx_distributed_group_size_(group); +} +static inline mlx_distributed_group mlx_distributed_group_split(mlx_distributed_group group, int color, int key) { + return mlx_distributed_group_split_(group, color, key); +} +static inline bool mlx_distributed_is_available(void) { + return mlx_distributed_is_available_(); +} +static inline mlx_distributed_group mlx_distributed_init(bool strict) { + return mlx_distributed_init_(strict); +} static inline void mlx_set_error_handler( mlx_error_handler_func handler, void* data, void (*dtor)(void*)) { - mlx_set_error_handler_(handler, data, dtor); + return mlx_set_error_handler_(handler, data, dtor); +} +static inline void _mlx_error(const char* file, const int line, const char* fmt, ...) { + return _mlx_error_(file, line, fmt); } - -#define _mlx_error(file, line, fmt, ...) _mlx_error_(file, line, fmt, __VA_ARGS__) - static inline int mlx_export_function( const char* file, const mlx_closure fun, @@ -3978,7 +4055,6 @@ static inline int mlx_export_function( bool shapeless) { return mlx_export_function_(file, fun, args, shapeless); } - static inline int mlx_export_function_kwargs( const char* file, const mlx_closure_kwargs fun, @@ -3987,46 +4063,38 @@ static inline int mlx_export_function_kwargs( bool shapeless) { return mlx_export_function_kwargs_(file, fun, args, kwargs, shapeless); } - static inline mlx_function_exporter mlx_function_exporter_new( const char* file, const mlx_closure fun, bool shapeless) { return mlx_function_exporter_new_(file, fun, shapeless); } - static inline int mlx_function_exporter_free(mlx_function_exporter xfunc) { return mlx_function_exporter_free_(xfunc); } - static inline int mlx_function_exporter_apply( const mlx_function_exporter xfunc, const mlx_vector_array args) { return mlx_function_exporter_apply_(xfunc, args); } - static inline int mlx_function_exporter_apply_kwargs( const mlx_function_exporter xfunc, const mlx_vector_array args, const mlx_map_string_to_array kwargs) { return mlx_function_exporter_apply_kwargs_(xfunc, args, kwargs); } - static inline mlx_imported_function mlx_imported_function_new(const char* file) { return mlx_imported_function_new_(file); } - static inline int mlx_imported_function_free(mlx_imported_function xfunc) { return mlx_imported_function_free_(xfunc); } - static inline int mlx_imported_function_apply( mlx_vector_array* res, const mlx_imported_function xfunc, const mlx_vector_array args) { return mlx_imported_function_apply_(res, xfunc, args); } - static inline int mlx_imported_function_apply_kwargs( mlx_vector_array* res, const mlx_imported_function xfunc, @@ -4034,15 +4102,12 @@ static inline int mlx_imported_function_apply_kwargs( const mlx_map_string_to_array kwargs) { return mlx_imported_function_apply_kwargs_(res, xfunc, args, kwargs); } - static inline mlx_fast_cuda_kernel_config mlx_fast_cuda_kernel_config_new(void) { return mlx_fast_cuda_kernel_config_new_(); } - static inline void mlx_fast_cuda_kernel_config_free(mlx_fast_cuda_kernel_config cls) { - mlx_fast_cuda_kernel_config_free_(cls); + return mlx_fast_cuda_kernel_config_free_(cls); } - static inline int mlx_fast_cuda_kernel_config_add_output_arg( mlx_fast_cuda_kernel_config cls, const int* shape, @@ -4050,7 +4115,6 @@ static inline int mlx_fast_cuda_kernel_config_add_output_arg( mlx_dtype dtype) { return mlx_fast_cuda_kernel_config_add_output_arg_(cls, shape, size, dtype); } - static inline int mlx_fast_cuda_kernel_config_set_grid( mlx_fast_cuda_kernel_config cls, int grid1, @@ -4058,7 +4122,6 @@ static inline int mlx_fast_cuda_kernel_config_set_grid( int grid3) { return mlx_fast_cuda_kernel_config_set_grid_(cls, grid1, grid2, grid3); } - static inline int mlx_fast_cuda_kernel_config_set_thread_group( mlx_fast_cuda_kernel_config cls, int thread1, @@ -4066,40 +4129,34 @@ static inline int mlx_fast_cuda_kernel_config_set_thread_group( int thread3) { return mlx_fast_cuda_kernel_config_set_thread_group_(cls, thread1, thread2, thread3); } - static inline int mlx_fast_cuda_kernel_config_set_init_value( mlx_fast_cuda_kernel_config cls, float value) { return mlx_fast_cuda_kernel_config_set_init_value_(cls, value); } - static inline int mlx_fast_cuda_kernel_config_set_verbose( mlx_fast_cuda_kernel_config cls, bool verbose) { return mlx_fast_cuda_kernel_config_set_verbose_(cls, verbose); } - static inline int mlx_fast_cuda_kernel_config_add_template_arg_dtype( mlx_fast_cuda_kernel_config cls, const char* name, mlx_dtype dtype) { return mlx_fast_cuda_kernel_config_add_template_arg_dtype_(cls, name, dtype); } - static inline int mlx_fast_cuda_kernel_config_add_template_arg_int( mlx_fast_cuda_kernel_config cls, const char* name, int value) { return mlx_fast_cuda_kernel_config_add_template_arg_int_(cls, name, value); } - static inline int mlx_fast_cuda_kernel_config_add_template_arg_bool( mlx_fast_cuda_kernel_config cls, const char* name, bool value) { return mlx_fast_cuda_kernel_config_add_template_arg_bool_(cls, name, value); } - static inline mlx_fast_cuda_kernel mlx_fast_cuda_kernel_new( const char* name, const mlx_vector_string input_names, @@ -4110,11 +4167,9 @@ static inline mlx_fast_cuda_kernel mlx_fast_cuda_kernel_new( int shared_memory) { return mlx_fast_cuda_kernel_new_(name, input_names, output_names, source, header, ensure_row_contiguous, shared_memory); } - static inline void mlx_fast_cuda_kernel_free(mlx_fast_cuda_kernel cls) { - mlx_fast_cuda_kernel_free_(cls); + return mlx_fast_cuda_kernel_free_(cls); } - static inline int mlx_fast_cuda_kernel_apply( mlx_vector_array* outputs, mlx_fast_cuda_kernel cls, @@ -4123,7 +4178,6 @@ static inline int mlx_fast_cuda_kernel_apply( const mlx_stream stream) { return mlx_fast_cuda_kernel_apply_(outputs, cls, inputs, config, stream); } - static inline int mlx_fast_layer_norm( mlx_array* res, const mlx_array x, @@ -4133,15 +4187,12 @@ static inline int mlx_fast_layer_norm( const mlx_stream s) { return mlx_fast_layer_norm_(res, x, weight, bias, eps, s); } - static inline mlx_fast_metal_kernel_config mlx_fast_metal_kernel_config_new(void) { return mlx_fast_metal_kernel_config_new_(); } - static inline void mlx_fast_metal_kernel_config_free(mlx_fast_metal_kernel_config cls) { - mlx_fast_metal_kernel_config_free_(cls); + return mlx_fast_metal_kernel_config_free_(cls); } - static inline int mlx_fast_metal_kernel_config_add_output_arg( mlx_fast_metal_kernel_config cls, const int* shape, @@ -4149,7 +4200,6 @@ static inline int mlx_fast_metal_kernel_config_add_output_arg( mlx_dtype dtype) { return mlx_fast_metal_kernel_config_add_output_arg_(cls, shape, size, dtype); } - static inline int mlx_fast_metal_kernel_config_set_grid( mlx_fast_metal_kernel_config cls, int grid1, @@ -4157,7 +4207,6 @@ static inline int mlx_fast_metal_kernel_config_set_grid( int grid3) { return mlx_fast_metal_kernel_config_set_grid_(cls, grid1, grid2, grid3); } - static inline int mlx_fast_metal_kernel_config_set_thread_group( mlx_fast_metal_kernel_config cls, int thread1, @@ -4165,40 +4214,34 @@ static inline int mlx_fast_metal_kernel_config_set_thread_group( int thread3) { return mlx_fast_metal_kernel_config_set_thread_group_(cls, thread1, thread2, thread3); } - static inline int mlx_fast_metal_kernel_config_set_init_value( mlx_fast_metal_kernel_config cls, float value) { return mlx_fast_metal_kernel_config_set_init_value_(cls, value); } - static inline int mlx_fast_metal_kernel_config_set_verbose( mlx_fast_metal_kernel_config cls, bool verbose) { return mlx_fast_metal_kernel_config_set_verbose_(cls, verbose); } - static inline int mlx_fast_metal_kernel_config_add_template_arg_dtype( mlx_fast_metal_kernel_config cls, const char* name, mlx_dtype dtype) { return mlx_fast_metal_kernel_config_add_template_arg_dtype_(cls, name, dtype); } - static inline int mlx_fast_metal_kernel_config_add_template_arg_int( mlx_fast_metal_kernel_config cls, const char* name, int value) { return mlx_fast_metal_kernel_config_add_template_arg_int_(cls, name, value); } - static inline int mlx_fast_metal_kernel_config_add_template_arg_bool( mlx_fast_metal_kernel_config cls, const char* name, bool value) { return mlx_fast_metal_kernel_config_add_template_arg_bool_(cls, name, value); } - static inline mlx_fast_metal_kernel mlx_fast_metal_kernel_new( const char* name, const mlx_vector_string input_names, @@ -4209,11 +4252,9 @@ static inline mlx_fast_metal_kernel mlx_fast_metal_kernel_new( bool atomic_outputs) { return mlx_fast_metal_kernel_new_(name, input_names, output_names, source, header, ensure_row_contiguous, atomic_outputs); } - static inline void mlx_fast_metal_kernel_free(mlx_fast_metal_kernel cls) { - mlx_fast_metal_kernel_free_(cls); + return mlx_fast_metal_kernel_free_(cls); } - static inline int mlx_fast_metal_kernel_apply( mlx_vector_array* outputs, mlx_fast_metal_kernel cls, @@ -4222,7 +4263,6 @@ static inline int mlx_fast_metal_kernel_apply( const mlx_stream stream) { return mlx_fast_metal_kernel_apply_(outputs, cls, inputs, config, stream); } - static inline int mlx_fast_rms_norm( mlx_array* res, const mlx_array x, @@ -4231,7 +4271,6 @@ static inline int mlx_fast_rms_norm( const mlx_stream s) { return mlx_fast_rms_norm_(res, x, weight, eps, s); } - static inline int mlx_fast_rope( mlx_array* res, const mlx_array x, @@ -4244,7 +4283,18 @@ static inline int mlx_fast_rope( const mlx_stream s) { return mlx_fast_rope_(res, x, dims, traditional, base, scale, offset, freqs, s); } - +static inline int mlx_fast_rope_dynamic( + mlx_array* res, + const mlx_array x, + int dims, + bool traditional, + mlx_optional_float base, + float scale, + const mlx_array offset, + const mlx_array freqs /* may be null */, + const mlx_stream s) { + return mlx_fast_rope_dynamic_(res, x, dims, traditional, base, scale, offset, freqs, s); +} static inline int mlx_fast_scaled_dot_product_attention( mlx_array* res, const mlx_array queries, @@ -4257,7 +4307,6 @@ static inline int mlx_fast_scaled_dot_product_attention( const mlx_stream s) { return mlx_fast_scaled_dot_product_attention_(res, queries, keys, values, scale, mask_mode, mask_arr, sinks, s); } - static inline int mlx_fft_fft( mlx_array* res, const mlx_array a, @@ -4266,7 +4315,6 @@ static inline int mlx_fft_fft( const mlx_stream s) { return mlx_fft_fft_(res, a, n, axis, s); } - static inline int mlx_fft_fft2( mlx_array* res, const mlx_array a, @@ -4277,7 +4325,6 @@ static inline int mlx_fft_fft2( const mlx_stream s) { return mlx_fft_fft2_(res, a, n, n_num, axes, axes_num, s); } - static inline int mlx_fft_fftn( mlx_array* res, const mlx_array a, @@ -4288,7 +4335,6 @@ static inline int mlx_fft_fftn( const mlx_stream s) { return mlx_fft_fftn_(res, a, n, n_num, axes, axes_num, s); } - static inline int mlx_fft_fftshift( mlx_array* res, const mlx_array a, @@ -4297,7 +4343,6 @@ static inline int mlx_fft_fftshift( const mlx_stream s) { return mlx_fft_fftshift_(res, a, axes, axes_num, s); } - static inline int mlx_fft_ifft( mlx_array* res, const mlx_array a, @@ -4306,7 +4351,6 @@ static inline int mlx_fft_ifft( const mlx_stream s) { return mlx_fft_ifft_(res, a, n, axis, s); } - static inline int mlx_fft_ifft2( mlx_array* res, const mlx_array a, @@ -4317,7 +4361,6 @@ static inline int mlx_fft_ifft2( const mlx_stream s) { return mlx_fft_ifft2_(res, a, n, n_num, axes, axes_num, s); } - static inline int mlx_fft_ifftn( mlx_array* res, const mlx_array a, @@ -4328,7 +4371,6 @@ static inline int mlx_fft_ifftn( const mlx_stream s) { return mlx_fft_ifftn_(res, a, n, n_num, axes, axes_num, s); } - static inline int mlx_fft_ifftshift( mlx_array* res, const mlx_array a, @@ -4337,7 +4379,6 @@ static inline int mlx_fft_ifftshift( const mlx_stream s) { return mlx_fft_ifftshift_(res, a, axes, axes_num, s); } - static inline int mlx_fft_irfft( mlx_array* res, const mlx_array a, @@ -4346,7 +4387,6 @@ static inline int mlx_fft_irfft( const mlx_stream s) { return mlx_fft_irfft_(res, a, n, axis, s); } - static inline int mlx_fft_irfft2( mlx_array* res, const mlx_array a, @@ -4357,7 +4397,6 @@ static inline int mlx_fft_irfft2( const mlx_stream s) { return mlx_fft_irfft2_(res, a, n, n_num, axes, axes_num, s); } - static inline int mlx_fft_irfftn( mlx_array* res, const mlx_array a, @@ -4368,7 +4407,6 @@ static inline int mlx_fft_irfftn( const mlx_stream s) { return mlx_fft_irfftn_(res, a, n, n_num, axes, axes_num, s); } - static inline int mlx_fft_rfft( mlx_array* res, const mlx_array a, @@ -4377,7 +4415,6 @@ static inline int mlx_fft_rfft( const mlx_stream s) { return mlx_fft_rfft_(res, a, n, axis, s); } - static inline int mlx_fft_rfft2( mlx_array* res, const mlx_array a, @@ -4388,7 +4425,6 @@ static inline int mlx_fft_rfft2( const mlx_stream s) { return mlx_fft_rfft2_(res, a, n, n_num, axes, axes_num, s); } - static inline int mlx_fft_rfftn( mlx_array* res, const mlx_array a, @@ -4399,50 +4435,15 @@ static inline int mlx_fft_rfftn( const mlx_stream s) { return mlx_fft_rfftn_(res, a, n, n_num, axes, axes_num, s); } - -static inline mlx_io_reader mlx_io_reader_new(void* desc, mlx_io_vtable vtable) { - return mlx_io_reader_new_(desc, vtable); -} - -static inline int mlx_io_reader_descriptor(void** desc_, mlx_io_reader io) { - return mlx_io_reader_descriptor_(desc_, io); -} - -static inline int mlx_io_reader_tostring(mlx_string* str_, mlx_io_reader io) { - return mlx_io_reader_tostring_(str_, io); -} - -static inline int mlx_io_reader_free(mlx_io_reader io) { - return mlx_io_reader_free_(io); -} - -static inline mlx_io_writer mlx_io_writer_new(void* desc, mlx_io_vtable vtable) { - return mlx_io_writer_new_(desc, vtable); -} - -static inline int mlx_io_writer_descriptor(void** desc_, mlx_io_writer io) { - return mlx_io_writer_descriptor_(desc_, io); -} - -static inline int mlx_io_writer_tostring(mlx_string* str_, mlx_io_writer io) { - return mlx_io_writer_tostring_(str_, io); -} - -static inline int mlx_io_writer_free(mlx_io_writer io) { - return mlx_io_writer_free_(io); -} - static inline int mlx_load_reader( mlx_array* res, mlx_io_reader in_stream, const mlx_stream s) { return mlx_load_reader_(res, in_stream, s); } - static inline int mlx_load(mlx_array* res, const char* file, const mlx_stream s) { return mlx_load_(res, file, s); } - static inline int mlx_load_safetensors_reader( mlx_map_string_to_array* res_0, mlx_map_string_to_string* res_1, @@ -4450,7 +4451,6 @@ static inline int mlx_load_safetensors_reader( const mlx_stream s) { return mlx_load_safetensors_reader_(res_0, res_1, in_stream, s); } - static inline int mlx_load_safetensors( mlx_map_string_to_array* res_0, mlx_map_string_to_string* res_1, @@ -4458,29 +4458,48 @@ static inline int mlx_load_safetensors( const mlx_stream s) { return mlx_load_safetensors_(res_0, res_1, file, s); } - static inline int mlx_save_writer(mlx_io_writer out_stream, const mlx_array a) { return mlx_save_writer_(out_stream, a); } - static inline int mlx_save(const char* file, const mlx_array a) { return mlx_save_(file, a); } - static inline int mlx_save_safetensors_writer( mlx_io_writer in_stream, const mlx_map_string_to_array param, const mlx_map_string_to_string metadata) { return mlx_save_safetensors_writer_(in_stream, param, metadata); } - static inline int mlx_save_safetensors( const char* file, const mlx_map_string_to_array param, const mlx_map_string_to_string metadata) { return mlx_save_safetensors_(file, param, metadata); } - +static inline mlx_io_reader mlx_io_reader_new(void* desc, mlx_io_vtable vtable) { + return mlx_io_reader_new_(desc, vtable); +} +static inline int mlx_io_reader_descriptor(void** desc_, mlx_io_reader io) { + return mlx_io_reader_descriptor_(desc_, io); +} +static inline int mlx_io_reader_tostring(mlx_string* str_, mlx_io_reader io) { + return mlx_io_reader_tostring_(str_, io); +} +static inline int mlx_io_reader_free(mlx_io_reader io) { + return mlx_io_reader_free_(io); +} +static inline mlx_io_writer mlx_io_writer_new(void* desc, mlx_io_vtable vtable) { + return mlx_io_writer_new_(desc, vtable); +} +static inline int mlx_io_writer_descriptor(void** desc_, mlx_io_writer io) { + return mlx_io_writer_descriptor_(desc_, io); +} +static inline int mlx_io_writer_tostring(mlx_string* str_, mlx_io_writer io) { + return mlx_io_writer_tostring_(str_, io); +} +static inline int mlx_io_writer_free(mlx_io_writer io) { + return mlx_io_writer_free_(io); +} static inline int mlx_linalg_cholesky( mlx_array* res, const mlx_array a, @@ -4488,7 +4507,6 @@ static inline int mlx_linalg_cholesky( const mlx_stream s) { return mlx_linalg_cholesky_(res, a, upper, s); } - static inline int mlx_linalg_cholesky_inv( mlx_array* res, const mlx_array a, @@ -4496,7 +4514,6 @@ static inline int mlx_linalg_cholesky_inv( const mlx_stream s) { return mlx_linalg_cholesky_inv_(res, a, upper, s); } - static inline int mlx_linalg_cross( mlx_array* res, const mlx_array a, @@ -4505,7 +4522,6 @@ static inline int mlx_linalg_cross( const mlx_stream s) { return mlx_linalg_cross_(res, a, b, axis, s); } - static inline int mlx_linalg_eig( mlx_array* res_0, mlx_array* res_1, @@ -4513,7 +4529,6 @@ static inline int mlx_linalg_eig( const mlx_stream s) { return mlx_linalg_eig_(res_0, res_1, a, s); } - static inline int mlx_linalg_eigh( mlx_array* res_0, mlx_array* res_1, @@ -4522,11 +4537,9 @@ static inline int mlx_linalg_eigh( const mlx_stream s) { return mlx_linalg_eigh_(res_0, res_1, a, UPLO, s); } - static inline int mlx_linalg_eigvals(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_linalg_eigvals_(res, a, s); } - static inline int mlx_linalg_eigvalsh( mlx_array* res, const mlx_array a, @@ -4534,15 +4547,12 @@ static inline int mlx_linalg_eigvalsh( const mlx_stream s) { return mlx_linalg_eigvalsh_(res, a, UPLO, s); } - static inline int mlx_linalg_inv(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_linalg_inv_(res, a, s); } - static inline int mlx_linalg_lu(mlx_vector_array* res, const mlx_array a, const mlx_stream s) { return mlx_linalg_lu_(res, a, s); } - static inline int mlx_linalg_lu_factor( mlx_array* res_0, mlx_array* res_1, @@ -4550,7 +4560,6 @@ static inline int mlx_linalg_lu_factor( const mlx_stream s) { return mlx_linalg_lu_factor_(res_0, res_1, a, s); } - static inline int mlx_linalg_norm( mlx_array* res, const mlx_array a, @@ -4561,7 +4570,6 @@ static inline int mlx_linalg_norm( const mlx_stream s) { return mlx_linalg_norm_(res, a, ord, axis, axis_num, keepdims, s); } - static inline int mlx_linalg_norm_matrix( mlx_array* res, const mlx_array a, @@ -4572,7 +4580,6 @@ static inline int mlx_linalg_norm_matrix( const mlx_stream s) { return mlx_linalg_norm_matrix_(res, a, ord, axis, axis_num, keepdims, s); } - static inline int mlx_linalg_norm_l2( mlx_array* res, const mlx_array a, @@ -4582,11 +4589,9 @@ static inline int mlx_linalg_norm_l2( const mlx_stream s) { return mlx_linalg_norm_l2_(res, a, axis, axis_num, keepdims, s); } - static inline int mlx_linalg_pinv(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_linalg_pinv_(res, a, s); } - static inline int mlx_linalg_qr( mlx_array* res_0, mlx_array* res_1, @@ -4594,7 +4599,6 @@ static inline int mlx_linalg_qr( const mlx_stream s) { return mlx_linalg_qr_(res_0, res_1, a, s); } - static inline int mlx_linalg_solve( mlx_array* res, const mlx_array a, @@ -4602,7 +4606,6 @@ static inline int mlx_linalg_solve( const mlx_stream s) { return mlx_linalg_solve_(res, a, b, s); } - static inline int mlx_linalg_solve_triangular( mlx_array* res, const mlx_array a, @@ -4611,7 +4614,6 @@ static inline int mlx_linalg_solve_triangular( const mlx_stream s) { return mlx_linalg_solve_triangular_(res, a, b, upper, s); } - static inline int mlx_linalg_svd( mlx_vector_array* res, const mlx_array a, @@ -4619,7 +4621,6 @@ static inline int mlx_linalg_svd( const mlx_stream s) { return mlx_linalg_svd_(res, a, compute_uv, s); } - static inline int mlx_linalg_tri_inv( mlx_array* res, const mlx_array a, @@ -4627,152 +4628,118 @@ static inline int mlx_linalg_tri_inv( const mlx_stream s) { return mlx_linalg_tri_inv_(res, a, upper, s); } - static inline mlx_map_string_to_array mlx_map_string_to_array_new(void) { return mlx_map_string_to_array_new_(); } - static inline int mlx_map_string_to_array_set( mlx_map_string_to_array* map, const mlx_map_string_to_array src) { return mlx_map_string_to_array_set_(map, src); } - static inline int mlx_map_string_to_array_free(mlx_map_string_to_array map) { return mlx_map_string_to_array_free_(map); } - static inline int mlx_map_string_to_array_insert( mlx_map_string_to_array map, const char* key, const mlx_array value) { return mlx_map_string_to_array_insert_(map, key, value); } - static inline int mlx_map_string_to_array_get( mlx_array* value, const mlx_map_string_to_array map, const char* key) { return mlx_map_string_to_array_get_(value, map, key); } - static inline mlx_map_string_to_array_iterator mlx_map_string_to_array_iterator_new( mlx_map_string_to_array map) { return mlx_map_string_to_array_iterator_new_(map); } - static inline int mlx_map_string_to_array_iterator_free(mlx_map_string_to_array_iterator it) { return mlx_map_string_to_array_iterator_free_(it); } - static inline int mlx_map_string_to_array_iterator_next( const char** key, mlx_array* value, mlx_map_string_to_array_iterator it) { return mlx_map_string_to_array_iterator_next_(key, value, it); } - static inline mlx_map_string_to_string mlx_map_string_to_string_new(void) { return mlx_map_string_to_string_new_(); } - static inline int mlx_map_string_to_string_set( mlx_map_string_to_string* map, const mlx_map_string_to_string src) { return mlx_map_string_to_string_set_(map, src); } - static inline int mlx_map_string_to_string_free(mlx_map_string_to_string map) { return mlx_map_string_to_string_free_(map); } - static inline int mlx_map_string_to_string_insert( mlx_map_string_to_string map, const char* key, const char* value) { return mlx_map_string_to_string_insert_(map, key, value); } - static inline int mlx_map_string_to_string_get( const char** value, const mlx_map_string_to_string map, const char* key) { return mlx_map_string_to_string_get_(value, map, key); } - static inline mlx_map_string_to_string_iterator mlx_map_string_to_string_iterator_new( mlx_map_string_to_string map) { return mlx_map_string_to_string_iterator_new_(map); } - static inline int mlx_map_string_to_string_iterator_free( mlx_map_string_to_string_iterator it) { return mlx_map_string_to_string_iterator_free_(it); } - static inline int mlx_map_string_to_string_iterator_next( const char** key, const char** value, mlx_map_string_to_string_iterator it) { return mlx_map_string_to_string_iterator_next_(key, value, it); } - static inline int mlx_clear_cache(void) { return mlx_clear_cache_(); } - static inline int mlx_get_active_memory(size_t* res) { return mlx_get_active_memory_(res); } - static inline int mlx_get_cache_memory(size_t* res) { return mlx_get_cache_memory_(res); } - static inline int mlx_get_memory_limit(size_t* res) { return mlx_get_memory_limit_(res); } - static inline int mlx_get_peak_memory(size_t* res) { return mlx_get_peak_memory_(res); } - static inline int mlx_reset_peak_memory(void) { return mlx_reset_peak_memory_(); } - static inline int mlx_set_cache_limit(size_t* res, size_t limit) { return mlx_set_cache_limit_(res, limit); } - static inline int mlx_set_memory_limit(size_t* res, size_t limit) { return mlx_set_memory_limit_(res, limit); } - static inline int mlx_set_wired_limit(size_t* res, size_t limit) { return mlx_set_wired_limit_(res, limit); } - -static inline mlx_metal_device_info_t mlx_metal_device_info(void) { - return mlx_metal_device_info_(); -} - static inline int mlx_metal_is_available(bool* res) { return mlx_metal_is_available_(res); } - static inline int mlx_metal_start_capture(const char* path) { return mlx_metal_start_capture_(path); } - static inline int mlx_metal_stop_capture(void) { return mlx_metal_stop_capture_(); } - static inline int mlx_abs(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_abs_(res, a, s); } - static inline int mlx_add( mlx_array* res, const mlx_array a, @@ -4780,7 +4747,6 @@ static inline int mlx_add( const mlx_stream s) { return mlx_add_(res, a, b, s); } - static inline int mlx_addmm( mlx_array* res, const mlx_array c, @@ -4791,7 +4757,6 @@ static inline int mlx_addmm( const mlx_stream s) { return mlx_addmm_(res, c, a, b, alpha, beta, s); } - static inline int mlx_all_axes( mlx_array* res, const mlx_array a, @@ -4801,7 +4766,6 @@ static inline int mlx_all_axes( const mlx_stream s) { return mlx_all_axes_(res, a, axes, axes_num, keepdims, s); } - static inline int mlx_all_axis( mlx_array* res, const mlx_array a, @@ -4810,7 +4774,6 @@ static inline int mlx_all_axis( const mlx_stream s) { return mlx_all_axis_(res, a, axis, keepdims, s); } - static inline int mlx_all( mlx_array* res, const mlx_array a, @@ -4818,7 +4781,6 @@ static inline int mlx_all( const mlx_stream s) { return mlx_all_(res, a, keepdims, s); } - static inline int mlx_allclose( mlx_array* res, const mlx_array a, @@ -4829,7 +4791,6 @@ static inline int mlx_allclose( const mlx_stream s) { return mlx_allclose_(res, a, b, rtol, atol, equal_nan, s); } - static inline int mlx_any_axes( mlx_array* res, const mlx_array a, @@ -4839,7 +4800,6 @@ static inline int mlx_any_axes( const mlx_stream s) { return mlx_any_axes_(res, a, axes, axes_num, keepdims, s); } - static inline int mlx_any_axis( mlx_array* res, const mlx_array a, @@ -4848,7 +4808,6 @@ static inline int mlx_any_axis( const mlx_stream s) { return mlx_any_axis_(res, a, axis, keepdims, s); } - static inline int mlx_any( mlx_array* res, const mlx_array a, @@ -4856,7 +4815,6 @@ static inline int mlx_any( const mlx_stream s) { return mlx_any_(res, a, keepdims, s); } - static inline int mlx_arange( mlx_array* res, double start, @@ -4866,27 +4824,21 @@ static inline int mlx_arange( const mlx_stream s) { return mlx_arange_(res, start, stop, step, dtype, s); } - static inline int mlx_arccos(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_arccos_(res, a, s); } - static inline int mlx_arccosh(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_arccosh_(res, a, s); } - static inline int mlx_arcsin(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_arcsin_(res, a, s); } - static inline int mlx_arcsinh(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_arcsinh_(res, a, s); } - static inline int mlx_arctan(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_arctan_(res, a, s); } - static inline int mlx_arctan2( mlx_array* res, const mlx_array a, @@ -4894,11 +4846,9 @@ static inline int mlx_arctan2( const mlx_stream s) { return mlx_arctan2_(res, a, b, s); } - static inline int mlx_arctanh(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_arctanh_(res, a, s); } - static inline int mlx_argmax_axis( mlx_array* res, const mlx_array a, @@ -4907,7 +4857,6 @@ static inline int mlx_argmax_axis( const mlx_stream s) { return mlx_argmax_axis_(res, a, axis, keepdims, s); } - static inline int mlx_argmax( mlx_array* res, const mlx_array a, @@ -4915,7 +4864,6 @@ static inline int mlx_argmax( const mlx_stream s) { return mlx_argmax_(res, a, keepdims, s); } - static inline int mlx_argmin_axis( mlx_array* res, const mlx_array a, @@ -4924,7 +4872,6 @@ static inline int mlx_argmin_axis( const mlx_stream s) { return mlx_argmin_axis_(res, a, axis, keepdims, s); } - static inline int mlx_argmin( mlx_array* res, const mlx_array a, @@ -4932,7 +4879,6 @@ static inline int mlx_argmin( const mlx_stream s) { return mlx_argmin_(res, a, keepdims, s); } - static inline int mlx_argpartition_axis( mlx_array* res, const mlx_array a, @@ -4941,7 +4887,6 @@ static inline int mlx_argpartition_axis( const mlx_stream s) { return mlx_argpartition_axis_(res, a, kth, axis, s); } - static inline int mlx_argpartition( mlx_array* res, const mlx_array a, @@ -4949,7 +4894,6 @@ static inline int mlx_argpartition( const mlx_stream s) { return mlx_argpartition_(res, a, kth, s); } - static inline int mlx_argsort_axis( mlx_array* res, const mlx_array a, @@ -4957,11 +4901,9 @@ static inline int mlx_argsort_axis( const mlx_stream s) { return mlx_argsort_axis_(res, a, axis, s); } - static inline int mlx_argsort(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_argsort_(res, a, s); } - static inline int mlx_array_equal( mlx_array* res, const mlx_array a, @@ -4970,7 +4912,6 @@ static inline int mlx_array_equal( const mlx_stream s) { return mlx_array_equal_(res, a, b, equal_nan, s); } - static inline int mlx_as_strided( mlx_array* res, const mlx_array a, @@ -4982,7 +4923,6 @@ static inline int mlx_as_strided( const mlx_stream s) { return mlx_as_strided_(res, a, shape, shape_num, strides, strides_num, offset, s); } - static inline int mlx_astype( mlx_array* res, const mlx_array a, @@ -4990,19 +4930,15 @@ static inline int mlx_astype( const mlx_stream s) { return mlx_astype_(res, a, dtype, s); } - static inline int mlx_atleast_1d(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_atleast_1d_(res, a, s); } - static inline int mlx_atleast_2d(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_atleast_2d_(res, a, s); } - static inline int mlx_atleast_3d(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_atleast_3d_(res, a, s); } - static inline int mlx_bitwise_and( mlx_array* res, const mlx_array a, @@ -5010,11 +4946,9 @@ static inline int mlx_bitwise_and( const mlx_stream s) { return mlx_bitwise_and_(res, a, b, s); } - static inline int mlx_bitwise_invert(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_bitwise_invert_(res, a, s); } - static inline int mlx_bitwise_or( mlx_array* res, const mlx_array a, @@ -5022,7 +4956,6 @@ static inline int mlx_bitwise_or( const mlx_stream s) { return mlx_bitwise_or_(res, a, b, s); } - static inline int mlx_bitwise_xor( mlx_array* res, const mlx_array a, @@ -5030,7 +4963,6 @@ static inline int mlx_bitwise_xor( const mlx_stream s) { return mlx_bitwise_xor_(res, a, b, s); } - static inline int mlx_block_masked_mm( mlx_array* res, const mlx_array a, @@ -5042,14 +4974,12 @@ static inline int mlx_block_masked_mm( const mlx_stream s) { return mlx_block_masked_mm_(res, a, b, block_size, mask_out, mask_lhs, mask_rhs, s); } - static inline int mlx_broadcast_arrays( mlx_vector_array* res, const mlx_vector_array inputs, const mlx_stream s) { return mlx_broadcast_arrays_(res, inputs, s); } - static inline int mlx_broadcast_to( mlx_array* res, const mlx_array a, @@ -5058,11 +4988,9 @@ static inline int mlx_broadcast_to( const mlx_stream s) { return mlx_broadcast_to_(res, a, shape, shape_num, s); } - static inline int mlx_ceil(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_ceil_(res, a, s); } - static inline int mlx_clip( mlx_array* res, const mlx_array a, @@ -5071,7 +4999,6 @@ static inline int mlx_clip( const mlx_stream s) { return mlx_clip_(res, a, a_min, a_max, s); } - static inline int mlx_concatenate_axis( mlx_array* res, const mlx_vector_array arrays, @@ -5079,18 +5006,15 @@ static inline int mlx_concatenate_axis( const mlx_stream s) { return mlx_concatenate_axis_(res, arrays, axis, s); } - static inline int mlx_concatenate( mlx_array* res, const mlx_vector_array arrays, const mlx_stream s) { return mlx_concatenate_(res, arrays, s); } - static inline int mlx_conjugate(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_conjugate_(res, a, s); } - static inline int mlx_contiguous( mlx_array* res, const mlx_array a, @@ -5098,7 +5022,6 @@ static inline int mlx_contiguous( const mlx_stream s) { return mlx_contiguous_(res, a, allow_col_major, s); } - static inline int mlx_conv1d( mlx_array* res, const mlx_array input, @@ -5110,7 +5033,6 @@ static inline int mlx_conv1d( const mlx_stream s) { return mlx_conv1d_(res, input, weight, stride, padding, dilation, groups, s); } - static inline int mlx_conv2d( mlx_array* res, const mlx_array input, @@ -5125,7 +5047,6 @@ static inline int mlx_conv2d( const mlx_stream s) { return mlx_conv2d_(res, input, weight, stride_0, stride_1, padding_0, padding_1, dilation_0, dilation_1, groups, s); } - static inline int mlx_conv3d( mlx_array* res, const mlx_array input, @@ -5143,7 +5064,6 @@ static inline int mlx_conv3d( const mlx_stream s) { return mlx_conv3d_(res, input, weight, stride_0, stride_1, stride_2, padding_0, padding_1, padding_2, dilation_0, dilation_1, dilation_2, groups, s); } - static inline int mlx_conv_general( mlx_array* res, const mlx_array input, @@ -5163,7 +5083,6 @@ static inline int mlx_conv_general( const mlx_stream s) { return mlx_conv_general_(res, input, weight, stride, stride_num, padding_lo, padding_lo_num, padding_hi, padding_hi_num, kernel_dilation, kernel_dilation_num, input_dilation, input_dilation_num, groups, flip, s); } - static inline int mlx_conv_transpose1d( mlx_array* res, const mlx_array input, @@ -5176,7 +5095,6 @@ static inline int mlx_conv_transpose1d( const mlx_stream s) { return mlx_conv_transpose1d_(res, input, weight, stride, padding, dilation, output_padding, groups, s); } - static inline int mlx_conv_transpose2d( mlx_array* res, const mlx_array input, @@ -5193,7 +5111,6 @@ static inline int mlx_conv_transpose2d( const mlx_stream s) { return mlx_conv_transpose2d_(res, input, weight, stride_0, stride_1, padding_0, padding_1, dilation_0, dilation_1, output_padding_0, output_padding_1, groups, s); } - static inline int mlx_conv_transpose3d( mlx_array* res, const mlx_array input, @@ -5214,19 +5131,15 @@ static inline int mlx_conv_transpose3d( const mlx_stream s) { return mlx_conv_transpose3d_(res, input, weight, stride_0, stride_1, stride_2, padding_0, padding_1, padding_2, dilation_0, dilation_1, dilation_2, output_padding_0, output_padding_1, output_padding_2, groups, s); } - static inline int mlx_copy(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_copy_(res, a, s); } - static inline int mlx_cos(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_cos_(res, a, s); } - static inline int mlx_cosh(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_cosh_(res, a, s); } - static inline int mlx_cummax( mlx_array* res, const mlx_array a, @@ -5236,7 +5149,6 @@ static inline int mlx_cummax( const mlx_stream s) { return mlx_cummax_(res, a, axis, reverse, inclusive, s); } - static inline int mlx_cummin( mlx_array* res, const mlx_array a, @@ -5246,7 +5158,6 @@ static inline int mlx_cummin( const mlx_stream s) { return mlx_cummin_(res, a, axis, reverse, inclusive, s); } - static inline int mlx_cumprod( mlx_array* res, const mlx_array a, @@ -5256,7 +5167,6 @@ static inline int mlx_cumprod( const mlx_stream s) { return mlx_cumprod_(res, a, axis, reverse, inclusive, s); } - static inline int mlx_cumsum( mlx_array* res, const mlx_array a, @@ -5266,18 +5176,15 @@ static inline int mlx_cumsum( const mlx_stream s) { return mlx_cumsum_(res, a, axis, reverse, inclusive, s); } - static inline int mlx_degrees(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_degrees_(res, a, s); } - static inline int mlx_depends( mlx_vector_array* res, const mlx_vector_array inputs, const mlx_vector_array dependencies) { return mlx_depends_(res, inputs, dependencies); } - static inline int mlx_dequantize( mlx_array* res, const mlx_array w, @@ -5290,11 +5197,9 @@ static inline int mlx_dequantize( const mlx_stream s) { return mlx_dequantize_(res, w, scales, biases, group_size, bits, mode, dtype, s); } - static inline int mlx_diag(mlx_array* res, const mlx_array a, int k, const mlx_stream s) { return mlx_diag_(res, a, k, s); } - static inline int mlx_diagonal( mlx_array* res, const mlx_array a, @@ -5304,7 +5209,6 @@ static inline int mlx_diagonal( const mlx_stream s) { return mlx_diagonal_(res, a, offset, axis1, axis2, s); } - static inline int mlx_divide( mlx_array* res, const mlx_array a, @@ -5312,7 +5216,6 @@ static inline int mlx_divide( const mlx_stream s) { return mlx_divide_(res, a, b, s); } - static inline int mlx_divmod( mlx_vector_array* res, const mlx_array a, @@ -5320,7 +5223,6 @@ static inline int mlx_divmod( const mlx_stream s) { return mlx_divmod_(res, a, b, s); } - static inline int mlx_einsum( mlx_array* res, const char* subscripts, @@ -5328,7 +5230,6 @@ static inline int mlx_einsum( const mlx_stream s) { return mlx_einsum_(res, subscripts, operands, s); } - static inline int mlx_equal( mlx_array* res, const mlx_array a, @@ -5336,19 +5237,15 @@ static inline int mlx_equal( const mlx_stream s) { return mlx_equal_(res, a, b, s); } - static inline int mlx_erf(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_erf_(res, a, s); } - static inline int mlx_erfinv(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_erfinv_(res, a, s); } - static inline int mlx_exp(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_exp_(res, a, s); } - static inline int mlx_expand_dims_axes( mlx_array* res, const mlx_array a, @@ -5357,7 +5254,6 @@ static inline int mlx_expand_dims_axes( const mlx_stream s) { return mlx_expand_dims_axes_(res, a, axes, axes_num, s); } - static inline int mlx_expand_dims( mlx_array* res, const mlx_array a, @@ -5365,11 +5261,9 @@ static inline int mlx_expand_dims( const mlx_stream s) { return mlx_expand_dims_(res, a, axis, s); } - static inline int mlx_expm1(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_expm1_(res, a, s); } - static inline int mlx_eye( mlx_array* res, int n, @@ -5379,7 +5273,6 @@ static inline int mlx_eye( const mlx_stream s) { return mlx_eye_(res, n, m, k, dtype, s); } - static inline int mlx_flatten( mlx_array* res, const mlx_array a, @@ -5388,11 +5281,9 @@ static inline int mlx_flatten( const mlx_stream s) { return mlx_flatten_(res, a, start_axis, end_axis, s); } - static inline int mlx_floor(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_floor_(res, a, s); } - static inline int mlx_floor_divide( mlx_array* res, const mlx_array a, @@ -5400,7 +5291,6 @@ static inline int mlx_floor_divide( const mlx_stream s) { return mlx_floor_divide_(res, a, b, s); } - static inline int mlx_from_fp8( mlx_array* res, const mlx_array x, @@ -5408,7 +5298,6 @@ static inline int mlx_from_fp8( const mlx_stream s) { return mlx_from_fp8_(res, x, dtype, s); } - static inline int mlx_full( mlx_array* res, const int* shape, @@ -5418,7 +5307,6 @@ static inline int mlx_full( const mlx_stream s) { return mlx_full_(res, shape, shape_num, vals, dtype, s); } - static inline int mlx_full_like( mlx_array* res, const mlx_array a, @@ -5427,7 +5315,6 @@ static inline int mlx_full_like( const mlx_stream s) { return mlx_full_like_(res, a, vals, dtype, s); } - static inline int mlx_gather( mlx_array* res, const mlx_array a, @@ -5439,7 +5326,16 @@ static inline int mlx_gather( const mlx_stream s) { return mlx_gather_(res, a, indices, axes, axes_num, slice_sizes, slice_sizes_num, s); } - +static inline int mlx_gather_single( + mlx_array* res, + const mlx_array a, + const mlx_array indices, + int axis, + const int* slice_sizes, + size_t slice_sizes_num, + const mlx_stream s) { + return mlx_gather_single_(res, a, indices, axis, slice_sizes, slice_sizes_num, s); +} static inline int mlx_gather_mm( mlx_array* res, const mlx_array a, @@ -5450,7 +5346,6 @@ static inline int mlx_gather_mm( const mlx_stream s) { return mlx_gather_mm_(res, a, b, lhs_indices, rhs_indices, sorted_indices, s); } - static inline int mlx_gather_qmm( mlx_array* res, const mlx_array x, @@ -5467,7 +5362,6 @@ static inline int mlx_gather_qmm( const mlx_stream s) { return mlx_gather_qmm_(res, x, w, scales, biases, lhs_indices, rhs_indices, transpose, group_size, bits, mode, sorted_indices, s); } - static inline int mlx_greater( mlx_array* res, const mlx_array a, @@ -5475,7 +5369,6 @@ static inline int mlx_greater( const mlx_stream s) { return mlx_greater_(res, a, b, s); } - static inline int mlx_greater_equal( mlx_array* res, const mlx_array a, @@ -5483,7 +5376,6 @@ static inline int mlx_greater_equal( const mlx_stream s) { return mlx_greater_equal_(res, a, b, s); } - static inline int mlx_hadamard_transform( mlx_array* res, const mlx_array a, @@ -5491,15 +5383,12 @@ static inline int mlx_hadamard_transform( const mlx_stream s) { return mlx_hadamard_transform_(res, a, scale, s); } - static inline int mlx_identity(mlx_array* res, int n, mlx_dtype dtype, const mlx_stream s) { return mlx_identity_(res, n, dtype, s); } - static inline int mlx_imag(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_imag_(res, a, s); } - static inline int mlx_inner( mlx_array* res, const mlx_array a, @@ -5507,7 +5396,6 @@ static inline int mlx_inner( const mlx_stream s) { return mlx_inner_(res, a, b, s); } - static inline int mlx_isclose( mlx_array* res, const mlx_array a, @@ -5518,27 +5406,21 @@ static inline int mlx_isclose( const mlx_stream s) { return mlx_isclose_(res, a, b, rtol, atol, equal_nan, s); } - static inline int mlx_isfinite(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_isfinite_(res, a, s); } - static inline int mlx_isinf(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_isinf_(res, a, s); } - static inline int mlx_isnan(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_isnan_(res, a, s); } - static inline int mlx_isneginf(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_isneginf_(res, a, s); } - static inline int mlx_isposinf(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_isposinf_(res, a, s); } - static inline int mlx_kron( mlx_array* res, const mlx_array a, @@ -5546,7 +5428,6 @@ static inline int mlx_kron( const mlx_stream s) { return mlx_kron_(res, a, b, s); } - static inline int mlx_left_shift( mlx_array* res, const mlx_array a, @@ -5554,7 +5435,6 @@ static inline int mlx_left_shift( const mlx_stream s) { return mlx_left_shift_(res, a, b, s); } - static inline int mlx_less( mlx_array* res, const mlx_array a, @@ -5562,7 +5442,6 @@ static inline int mlx_less( const mlx_stream s) { return mlx_less_(res, a, b, s); } - static inline int mlx_less_equal( mlx_array* res, const mlx_array a, @@ -5570,7 +5449,6 @@ static inline int mlx_less_equal( const mlx_stream s) { return mlx_less_equal_(res, a, b, s); } - static inline int mlx_linspace( mlx_array* res, double start, @@ -5580,23 +5458,18 @@ static inline int mlx_linspace( const mlx_stream s) { return mlx_linspace_(res, start, stop, num, dtype, s); } - static inline int mlx_log(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_log_(res, a, s); } - static inline int mlx_log10(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_log10_(res, a, s); } - static inline int mlx_log1p(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_log1p_(res, a, s); } - static inline int mlx_log2(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_log2_(res, a, s); } - static inline int mlx_logaddexp( mlx_array* res, const mlx_array a, @@ -5604,7 +5477,6 @@ static inline int mlx_logaddexp( const mlx_stream s) { return mlx_logaddexp_(res, a, b, s); } - static inline int mlx_logcumsumexp( mlx_array* res, const mlx_array a, @@ -5614,7 +5486,6 @@ static inline int mlx_logcumsumexp( const mlx_stream s) { return mlx_logcumsumexp_(res, a, axis, reverse, inclusive, s); } - static inline int mlx_logical_and( mlx_array* res, const mlx_array a, @@ -5622,11 +5493,9 @@ static inline int mlx_logical_and( const mlx_stream s) { return mlx_logical_and_(res, a, b, s); } - static inline int mlx_logical_not(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_logical_not_(res, a, s); } - static inline int mlx_logical_or( mlx_array* res, const mlx_array a, @@ -5634,7 +5503,6 @@ static inline int mlx_logical_or( const mlx_stream s) { return mlx_logical_or_(res, a, b, s); } - static inline int mlx_logsumexp_axes( mlx_array* res, const mlx_array a, @@ -5644,7 +5512,6 @@ static inline int mlx_logsumexp_axes( const mlx_stream s) { return mlx_logsumexp_axes_(res, a, axes, axes_num, keepdims, s); } - static inline int mlx_logsumexp_axis( mlx_array* res, const mlx_array a, @@ -5653,7 +5520,6 @@ static inline int mlx_logsumexp_axis( const mlx_stream s) { return mlx_logsumexp_axis_(res, a, axis, keepdims, s); } - static inline int mlx_logsumexp( mlx_array* res, const mlx_array a, @@ -5661,7 +5527,6 @@ static inline int mlx_logsumexp( const mlx_stream s) { return mlx_logsumexp_(res, a, keepdims, s); } - static inline int mlx_masked_scatter( mlx_array* res, const mlx_array a, @@ -5670,7 +5535,6 @@ static inline int mlx_masked_scatter( const mlx_stream s) { return mlx_masked_scatter_(res, a, mask, src, s); } - static inline int mlx_matmul( mlx_array* res, const mlx_array a, @@ -5678,7 +5542,6 @@ static inline int mlx_matmul( const mlx_stream s) { return mlx_matmul_(res, a, b, s); } - static inline int mlx_max_axes( mlx_array* res, const mlx_array a, @@ -5688,7 +5551,6 @@ static inline int mlx_max_axes( const mlx_stream s) { return mlx_max_axes_(res, a, axes, axes_num, keepdims, s); } - static inline int mlx_max_axis( mlx_array* res, const mlx_array a, @@ -5697,7 +5559,6 @@ static inline int mlx_max_axis( const mlx_stream s) { return mlx_max_axis_(res, a, axis, keepdims, s); } - static inline int mlx_max( mlx_array* res, const mlx_array a, @@ -5705,7 +5566,6 @@ static inline int mlx_max( const mlx_stream s) { return mlx_max_(res, a, keepdims, s); } - static inline int mlx_maximum( mlx_array* res, const mlx_array a, @@ -5713,7 +5573,6 @@ static inline int mlx_maximum( const mlx_stream s) { return mlx_maximum_(res, a, b, s); } - static inline int mlx_mean_axes( mlx_array* res, const mlx_array a, @@ -5723,7 +5582,6 @@ static inline int mlx_mean_axes( const mlx_stream s) { return mlx_mean_axes_(res, a, axes, axes_num, keepdims, s); } - static inline int mlx_mean_axis( mlx_array* res, const mlx_array a, @@ -5732,7 +5590,6 @@ static inline int mlx_mean_axis( const mlx_stream s) { return mlx_mean_axis_(res, a, axis, keepdims, s); } - static inline int mlx_mean( mlx_array* res, const mlx_array a, @@ -5740,7 +5597,6 @@ static inline int mlx_mean( const mlx_stream s) { return mlx_mean_(res, a, keepdims, s); } - static inline int mlx_median( mlx_array* res, const mlx_array a, @@ -5750,7 +5606,6 @@ static inline int mlx_median( const mlx_stream s) { return mlx_median_(res, a, axes, axes_num, keepdims, s); } - static inline int mlx_meshgrid( mlx_vector_array* res, const mlx_vector_array arrays, @@ -5759,7 +5614,6 @@ static inline int mlx_meshgrid( const mlx_stream s) { return mlx_meshgrid_(res, arrays, sparse, indexing, s); } - static inline int mlx_min_axes( mlx_array* res, const mlx_array a, @@ -5769,7 +5623,6 @@ static inline int mlx_min_axes( const mlx_stream s) { return mlx_min_axes_(res, a, axes, axes_num, keepdims, s); } - static inline int mlx_min_axis( mlx_array* res, const mlx_array a, @@ -5778,7 +5631,6 @@ static inline int mlx_min_axis( const mlx_stream s) { return mlx_min_axis_(res, a, axis, keepdims, s); } - static inline int mlx_min( mlx_array* res, const mlx_array a, @@ -5786,7 +5638,6 @@ static inline int mlx_min( const mlx_stream s) { return mlx_min_(res, a, keepdims, s); } - static inline int mlx_minimum( mlx_array* res, const mlx_array a, @@ -5794,7 +5645,6 @@ static inline int mlx_minimum( const mlx_stream s) { return mlx_minimum_(res, a, b, s); } - static inline int mlx_moveaxis( mlx_array* res, const mlx_array a, @@ -5803,7 +5653,6 @@ static inline int mlx_moveaxis( const mlx_stream s) { return mlx_moveaxis_(res, a, source, destination, s); } - static inline int mlx_multiply( mlx_array* res, const mlx_array a, @@ -5811,7 +5660,6 @@ static inline int mlx_multiply( const mlx_stream s) { return mlx_multiply_(res, a, b, s); } - static inline int mlx_nan_to_num( mlx_array* res, const mlx_array a, @@ -5821,11 +5669,9 @@ static inline int mlx_nan_to_num( const mlx_stream s) { return mlx_nan_to_num_(res, a, nan, posinf, neginf, s); } - static inline int mlx_negative(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_negative_(res, a, s); } - static inline int mlx_not_equal( mlx_array* res, const mlx_array a, @@ -5833,7 +5679,6 @@ static inline int mlx_not_equal( const mlx_stream s) { return mlx_not_equal_(res, a, b, s); } - static inline int mlx_number_of_elements( mlx_array* res, const mlx_array a, @@ -5844,7 +5689,6 @@ static inline int mlx_number_of_elements( const mlx_stream s) { return mlx_number_of_elements_(res, a, axes, axes_num, inverted, dtype, s); } - static inline int mlx_ones( mlx_array* res, const int* shape, @@ -5853,11 +5697,9 @@ static inline int mlx_ones( const mlx_stream s) { return mlx_ones_(res, shape, shape_num, dtype, s); } - static inline int mlx_ones_like(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_ones_like_(res, a, s); } - static inline int mlx_outer( mlx_array* res, const mlx_array a, @@ -5865,7 +5707,6 @@ static inline int mlx_outer( const mlx_stream s) { return mlx_outer_(res, a, b, s); } - static inline int mlx_pad( mlx_array* res, const mlx_array a, @@ -5880,7 +5721,6 @@ static inline int mlx_pad( const mlx_stream s) { return mlx_pad_(res, a, axes, axes_num, low_pad_size, low_pad_size_num, high_pad_size, high_pad_size_num, pad_value, mode, s); } - static inline int mlx_pad_symmetric( mlx_array* res, const mlx_array a, @@ -5890,7 +5730,6 @@ static inline int mlx_pad_symmetric( const mlx_stream s) { return mlx_pad_symmetric_(res, a, pad_width, pad_value, mode, s); } - static inline int mlx_partition_axis( mlx_array* res, const mlx_array a, @@ -5899,7 +5738,6 @@ static inline int mlx_partition_axis( const mlx_stream s) { return mlx_partition_axis_(res, a, kth, axis, s); } - static inline int mlx_partition( mlx_array* res, const mlx_array a, @@ -5907,7 +5745,6 @@ static inline int mlx_partition( const mlx_stream s) { return mlx_partition_(res, a, kth, s); } - static inline int mlx_power( mlx_array* res, const mlx_array a, @@ -5915,7 +5752,6 @@ static inline int mlx_power( const mlx_stream s) { return mlx_power_(res, a, b, s); } - static inline int mlx_prod_axes( mlx_array* res, const mlx_array a, @@ -5925,7 +5761,6 @@ static inline int mlx_prod_axes( const mlx_stream s) { return mlx_prod_axes_(res, a, axes, axes_num, keepdims, s); } - static inline int mlx_prod_axis( mlx_array* res, const mlx_array a, @@ -5934,7 +5769,6 @@ static inline int mlx_prod_axis( const mlx_stream s) { return mlx_prod_axis_(res, a, axis, keepdims, s); } - static inline int mlx_prod( mlx_array* res, const mlx_array a, @@ -5942,7 +5776,6 @@ static inline int mlx_prod( const mlx_stream s) { return mlx_prod_(res, a, keepdims, s); } - static inline int mlx_put_along_axis( mlx_array* res, const mlx_array a, @@ -5952,7 +5785,17 @@ static inline int mlx_put_along_axis( const mlx_stream s) { return mlx_put_along_axis_(res, a, indices, values, axis, s); } - +static inline int mlx_qqmm( + mlx_array* res, + const mlx_array x, + const mlx_array w, + const mlx_array w_scales /* may be null */, + mlx_optional_int group_size, + mlx_optional_int bits, + const char* mode, + const mlx_stream s) { + return mlx_qqmm_(res, x, w, w_scales, group_size, bits, mode, s); +} static inline int mlx_quantize( mlx_vector_array* res, const mlx_array w, @@ -5962,7 +5805,6 @@ static inline int mlx_quantize( const mlx_stream s) { return mlx_quantize_(res, w, group_size, bits, mode, s); } - static inline int mlx_quantized_matmul( mlx_array* res, const mlx_array x, @@ -5976,19 +5818,15 @@ static inline int mlx_quantized_matmul( const mlx_stream s) { return mlx_quantized_matmul_(res, x, w, scales, biases, transpose, group_size, bits, mode, s); } - static inline int mlx_radians(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_radians_(res, a, s); } - static inline int mlx_real(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_real_(res, a, s); } - static inline int mlx_reciprocal(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_reciprocal_(res, a, s); } - static inline int mlx_remainder( mlx_array* res, const mlx_array a, @@ -5996,7 +5834,6 @@ static inline int mlx_remainder( const mlx_stream s) { return mlx_remainder_(res, a, b, s); } - static inline int mlx_repeat_axis( mlx_array* res, const mlx_array arr, @@ -6005,7 +5842,6 @@ static inline int mlx_repeat_axis( const mlx_stream s) { return mlx_repeat_axis_(res, arr, repeats, axis, s); } - static inline int mlx_repeat( mlx_array* res, const mlx_array arr, @@ -6013,7 +5849,6 @@ static inline int mlx_repeat( const mlx_stream s) { return mlx_repeat_(res, arr, repeats, s); } - static inline int mlx_reshape( mlx_array* res, const mlx_array a, @@ -6022,7 +5857,6 @@ static inline int mlx_reshape( const mlx_stream s) { return mlx_reshape_(res, a, shape, shape_num, s); } - static inline int mlx_right_shift( mlx_array* res, const mlx_array a, @@ -6030,7 +5864,6 @@ static inline int mlx_right_shift( const mlx_stream s) { return mlx_right_shift_(res, a, b, s); } - static inline int mlx_roll_axis( mlx_array* res, const mlx_array a, @@ -6040,7 +5873,6 @@ static inline int mlx_roll_axis( const mlx_stream s) { return mlx_roll_axis_(res, a, shift, shift_num, axis, s); } - static inline int mlx_roll_axes( mlx_array* res, const mlx_array a, @@ -6051,7 +5883,6 @@ static inline int mlx_roll_axes( const mlx_stream s) { return mlx_roll_axes_(res, a, shift, shift_num, axes, axes_num, s); } - static inline int mlx_roll( mlx_array* res, const mlx_array a, @@ -6060,7 +5891,6 @@ static inline int mlx_roll( const mlx_stream s) { return mlx_roll_(res, a, shift, shift_num, s); } - static inline int mlx_round( mlx_array* res, const mlx_array a, @@ -6068,11 +5898,9 @@ static inline int mlx_round( const mlx_stream s) { return mlx_round_(res, a, decimals, s); } - static inline int mlx_rsqrt(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_rsqrt_(res, a, s); } - static inline int mlx_scatter( mlx_array* res, const mlx_array a, @@ -6083,7 +5911,15 @@ static inline int mlx_scatter( const mlx_stream s) { return mlx_scatter_(res, a, indices, updates, axes, axes_num, s); } - +static inline int mlx_scatter_single( + mlx_array* res, + const mlx_array a, + const mlx_array indices, + const mlx_array updates, + int axis, + const mlx_stream s) { + return mlx_scatter_single_(res, a, indices, updates, axis, s); +} static inline int mlx_scatter_add( mlx_array* res, const mlx_array a, @@ -6094,7 +5930,15 @@ static inline int mlx_scatter_add( const mlx_stream s) { return mlx_scatter_add_(res, a, indices, updates, axes, axes_num, s); } - +static inline int mlx_scatter_add_single( + mlx_array* res, + const mlx_array a, + const mlx_array indices, + const mlx_array updates, + int axis, + const mlx_stream s) { + return mlx_scatter_add_single_(res, a, indices, updates, axis, s); +} static inline int mlx_scatter_add_axis( mlx_array* res, const mlx_array a, @@ -6104,7 +5948,6 @@ static inline int mlx_scatter_add_axis( const mlx_stream s) { return mlx_scatter_add_axis_(res, a, indices, values, axis, s); } - static inline int mlx_scatter_max( mlx_array* res, const mlx_array a, @@ -6115,7 +5958,15 @@ static inline int mlx_scatter_max( const mlx_stream s) { return mlx_scatter_max_(res, a, indices, updates, axes, axes_num, s); } - +static inline int mlx_scatter_max_single( + mlx_array* res, + const mlx_array a, + const mlx_array indices, + const mlx_array updates, + int axis, + const mlx_stream s) { + return mlx_scatter_max_single_(res, a, indices, updates, axis, s); +} static inline int mlx_scatter_min( mlx_array* res, const mlx_array a, @@ -6126,7 +5977,15 @@ static inline int mlx_scatter_min( const mlx_stream s) { return mlx_scatter_min_(res, a, indices, updates, axes, axes_num, s); } - +static inline int mlx_scatter_min_single( + mlx_array* res, + const mlx_array a, + const mlx_array indices, + const mlx_array updates, + int axis, + const mlx_stream s) { + return mlx_scatter_min_single_(res, a, indices, updates, axis, s); +} static inline int mlx_scatter_prod( mlx_array* res, const mlx_array a, @@ -6137,7 +5996,15 @@ static inline int mlx_scatter_prod( const mlx_stream s) { return mlx_scatter_prod_(res, a, indices, updates, axes, axes_num, s); } - +static inline int mlx_scatter_prod_single( + mlx_array* res, + const mlx_array a, + const mlx_array indices, + const mlx_array updates, + int axis, + const mlx_stream s) { + return mlx_scatter_prod_single_(res, a, indices, updates, axis, s); +} static inline int mlx_segmented_mm( mlx_array* res, const mlx_array a, @@ -6146,23 +6013,18 @@ static inline int mlx_segmented_mm( const mlx_stream s) { return mlx_segmented_mm_(res, a, b, segments, s); } - static inline int mlx_sigmoid(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_sigmoid_(res, a, s); } - static inline int mlx_sign(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_sign_(res, a, s); } - static inline int mlx_sin(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_sin_(res, a, s); } - static inline int mlx_sinh(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_sinh_(res, a, s); } - static inline int mlx_slice( mlx_array* res, const mlx_array a, @@ -6175,7 +6037,6 @@ static inline int mlx_slice( const mlx_stream s) { return mlx_slice_(res, a, start, start_num, stop, stop_num, strides, strides_num, s); } - static inline int mlx_slice_dynamic( mlx_array* res, const mlx_array a, @@ -6187,7 +6048,6 @@ static inline int mlx_slice_dynamic( const mlx_stream s) { return mlx_slice_dynamic_(res, a, start, axes, axes_num, slice_size, slice_size_num, s); } - static inline int mlx_slice_update( mlx_array* res, const mlx_array src, @@ -6201,7 +6061,6 @@ static inline int mlx_slice_update( const mlx_stream s) { return mlx_slice_update_(res, src, update, start, start_num, stop, stop_num, strides, strides_num, s); } - static inline int mlx_slice_update_dynamic( mlx_array* res, const mlx_array src, @@ -6212,7 +6071,6 @@ static inline int mlx_slice_update_dynamic( const mlx_stream s) { return mlx_slice_update_dynamic_(res, src, update, start, axes, axes_num, s); } - static inline int mlx_softmax_axes( mlx_array* res, const mlx_array a, @@ -6222,7 +6080,6 @@ static inline int mlx_softmax_axes( const mlx_stream s) { return mlx_softmax_axes_(res, a, axes, axes_num, precise, s); } - static inline int mlx_softmax_axis( mlx_array* res, const mlx_array a, @@ -6231,7 +6088,6 @@ static inline int mlx_softmax_axis( const mlx_stream s) { return mlx_softmax_axis_(res, a, axis, precise, s); } - static inline int mlx_softmax( mlx_array* res, const mlx_array a, @@ -6239,7 +6095,6 @@ static inline int mlx_softmax( const mlx_stream s) { return mlx_softmax_(res, a, precise, s); } - static inline int mlx_sort_axis( mlx_array* res, const mlx_array a, @@ -6247,11 +6102,9 @@ static inline int mlx_sort_axis( const mlx_stream s) { return mlx_sort_axis_(res, a, axis, s); } - static inline int mlx_sort(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_sort_(res, a, s); } - static inline int mlx_split( mlx_vector_array* res, const mlx_array a, @@ -6260,7 +6113,6 @@ static inline int mlx_split( const mlx_stream s) { return mlx_split_(res, a, num_splits, axis, s); } - static inline int mlx_split_sections( mlx_vector_array* res, const mlx_array a, @@ -6270,15 +6122,12 @@ static inline int mlx_split_sections( const mlx_stream s) { return mlx_split_sections_(res, a, indices, indices_num, axis, s); } - static inline int mlx_sqrt(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_sqrt_(res, a, s); } - static inline int mlx_square(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_square_(res, a, s); } - static inline int mlx_squeeze_axes( mlx_array* res, const mlx_array a, @@ -6287,7 +6136,6 @@ static inline int mlx_squeeze_axes( const mlx_stream s) { return mlx_squeeze_axes_(res, a, axes, axes_num, s); } - static inline int mlx_squeeze_axis( mlx_array* res, const mlx_array a, @@ -6295,11 +6143,9 @@ static inline int mlx_squeeze_axis( const mlx_stream s) { return mlx_squeeze_axis_(res, a, axis, s); } - static inline int mlx_squeeze(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_squeeze_(res, a, s); } - static inline int mlx_stack_axis( mlx_array* res, const mlx_vector_array arrays, @@ -6307,14 +6153,12 @@ static inline int mlx_stack_axis( const mlx_stream s) { return mlx_stack_axis_(res, arrays, axis, s); } - static inline int mlx_stack( mlx_array* res, const mlx_vector_array arrays, const mlx_stream s) { return mlx_stack_(res, arrays, s); } - static inline int mlx_std_axes( mlx_array* res, const mlx_array a, @@ -6325,7 +6169,6 @@ static inline int mlx_std_axes( const mlx_stream s) { return mlx_std_axes_(res, a, axes, axes_num, keepdims, ddof, s); } - static inline int mlx_std_axis( mlx_array* res, const mlx_array a, @@ -6335,7 +6178,6 @@ static inline int mlx_std_axis( const mlx_stream s) { return mlx_std_axis_(res, a, axis, keepdims, ddof, s); } - static inline int mlx_std( mlx_array* res, const mlx_array a, @@ -6344,11 +6186,9 @@ static inline int mlx_std( const mlx_stream s) { return mlx_std_(res, a, keepdims, ddof, s); } - static inline int mlx_stop_gradient(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_stop_gradient_(res, a, s); } - static inline int mlx_subtract( mlx_array* res, const mlx_array a, @@ -6356,7 +6196,6 @@ static inline int mlx_subtract( const mlx_stream s) { return mlx_subtract_(res, a, b, s); } - static inline int mlx_sum_axes( mlx_array* res, const mlx_array a, @@ -6366,7 +6205,6 @@ static inline int mlx_sum_axes( const mlx_stream s) { return mlx_sum_axes_(res, a, axes, axes_num, keepdims, s); } - static inline int mlx_sum_axis( mlx_array* res, const mlx_array a, @@ -6375,7 +6213,6 @@ static inline int mlx_sum_axis( const mlx_stream s) { return mlx_sum_axis_(res, a, axis, keepdims, s); } - static inline int mlx_sum( mlx_array* res, const mlx_array a, @@ -6383,7 +6220,6 @@ static inline int mlx_sum( const mlx_stream s) { return mlx_sum_(res, a, keepdims, s); } - static inline int mlx_swapaxes( mlx_array* res, const mlx_array a, @@ -6392,7 +6228,6 @@ static inline int mlx_swapaxes( const mlx_stream s) { return mlx_swapaxes_(res, a, axis1, axis2, s); } - static inline int mlx_take_axis( mlx_array* res, const mlx_array a, @@ -6401,7 +6236,6 @@ static inline int mlx_take_axis( const mlx_stream s) { return mlx_take_axis_(res, a, indices, axis, s); } - static inline int mlx_take( mlx_array* res, const mlx_array a, @@ -6409,7 +6243,6 @@ static inline int mlx_take( const mlx_stream s) { return mlx_take_(res, a, indices, s); } - static inline int mlx_take_along_axis( mlx_array* res, const mlx_array a, @@ -6418,15 +6251,12 @@ static inline int mlx_take_along_axis( const mlx_stream s) { return mlx_take_along_axis_(res, a, indices, axis, s); } - static inline int mlx_tan(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_tan_(res, a, s); } - static inline int mlx_tanh(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_tanh_(res, a, s); } - static inline int mlx_tensordot( mlx_array* res, const mlx_array a, @@ -6438,7 +6268,6 @@ static inline int mlx_tensordot( const mlx_stream s) { return mlx_tensordot_(res, a, b, axes_a, axes_a_num, axes_b, axes_b_num, s); } - static inline int mlx_tensordot_axis( mlx_array* res, const mlx_array a, @@ -6447,7 +6276,6 @@ static inline int mlx_tensordot_axis( const mlx_stream s) { return mlx_tensordot_axis_(res, a, b, axis, s); } - static inline int mlx_tile( mlx_array* res, const mlx_array arr, @@ -6456,11 +6284,9 @@ static inline int mlx_tile( const mlx_stream s) { return mlx_tile_(res, arr, reps, reps_num, s); } - static inline int mlx_to_fp8(mlx_array* res, const mlx_array x, const mlx_stream s) { return mlx_to_fp8_(res, x, s); } - static inline int mlx_topk_axis( mlx_array* res, const mlx_array a, @@ -6469,11 +6295,9 @@ static inline int mlx_topk_axis( const mlx_stream s) { return mlx_topk_axis_(res, a, k, axis, s); } - static inline int mlx_topk(mlx_array* res, const mlx_array a, int k, const mlx_stream s) { return mlx_topk_(res, a, k, s); } - static inline int mlx_trace( mlx_array* res, const mlx_array a, @@ -6484,7 +6308,6 @@ static inline int mlx_trace( const mlx_stream s) { return mlx_trace_(res, a, offset, axis1, axis2, dtype, s); } - static inline int mlx_transpose_axes( mlx_array* res, const mlx_array a, @@ -6493,11 +6316,9 @@ static inline int mlx_transpose_axes( const mlx_stream s) { return mlx_transpose_axes_(res, a, axes, axes_num, s); } - static inline int mlx_transpose(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_transpose_(res, a, s); } - static inline int mlx_tri( mlx_array* res, int n, @@ -6507,15 +6328,12 @@ static inline int mlx_tri( const mlx_stream s) { return mlx_tri_(res, n, m, k, type, s); } - static inline int mlx_tril(mlx_array* res, const mlx_array x, int k, const mlx_stream s) { return mlx_tril_(res, x, k, s); } - static inline int mlx_triu(mlx_array* res, const mlx_array x, int k, const mlx_stream s) { return mlx_triu_(res, x, k, s); } - static inline int mlx_unflatten( mlx_array* res, const mlx_array a, @@ -6525,7 +6343,6 @@ static inline int mlx_unflatten( const mlx_stream s) { return mlx_unflatten_(res, a, axis, shape, shape_num, s); } - static inline int mlx_var_axes( mlx_array* res, const mlx_array a, @@ -6536,7 +6353,6 @@ static inline int mlx_var_axes( const mlx_stream s) { return mlx_var_axes_(res, a, axes, axes_num, keepdims, ddof, s); } - static inline int mlx_var_axis( mlx_array* res, const mlx_array a, @@ -6546,7 +6362,6 @@ static inline int mlx_var_axis( const mlx_stream s) { return mlx_var_axis_(res, a, axis, keepdims, ddof, s); } - static inline int mlx_var( mlx_array* res, const mlx_array a, @@ -6555,7 +6370,6 @@ static inline int mlx_var( const mlx_stream s) { return mlx_var_(res, a, keepdims, ddof, s); } - static inline int mlx_view( mlx_array* res, const mlx_array a, @@ -6563,7 +6377,6 @@ static inline int mlx_view( const mlx_stream s) { return mlx_view_(res, a, dtype, s); } - static inline int mlx_where( mlx_array* res, const mlx_array condition, @@ -6572,7 +6385,6 @@ static inline int mlx_where( const mlx_stream s) { return mlx_where_(res, condition, x, y, s); } - static inline int mlx_zeros( mlx_array* res, const int* shape, @@ -6581,11 +6393,9 @@ static inline int mlx_zeros( const mlx_stream s) { return mlx_zeros_(res, shape, shape_num, dtype, s); } - static inline int mlx_zeros_like(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_zeros_like_(res, a, s); } - static inline int mlx_random_bernoulli( mlx_array* res, const mlx_array p, @@ -6595,7 +6405,6 @@ static inline int mlx_random_bernoulli( const mlx_stream s) { return mlx_random_bernoulli_(res, p, shape, shape_num, key, s); } - static inline int mlx_random_bits( mlx_array* res, const int* shape, @@ -6605,7 +6414,6 @@ static inline int mlx_random_bits( const mlx_stream s) { return mlx_random_bits_(res, shape, shape_num, width, key, s); } - static inline int mlx_random_categorical_shape( mlx_array* res, const mlx_array logits, @@ -6616,7 +6424,6 @@ static inline int mlx_random_categorical_shape( const mlx_stream s) { return mlx_random_categorical_shape_(res, logits, axis, shape, shape_num, key, s); } - static inline int mlx_random_categorical_num_samples( mlx_array* res, const mlx_array logits_, @@ -6626,7 +6433,6 @@ static inline int mlx_random_categorical_num_samples( const mlx_stream s) { return mlx_random_categorical_num_samples_(res, logits_, axis, num_samples, key, s); } - static inline int mlx_random_categorical( mlx_array* res, const mlx_array logits, @@ -6635,7 +6441,6 @@ static inline int mlx_random_categorical( const mlx_stream s) { return mlx_random_categorical_(res, logits, axis, key, s); } - static inline int mlx_random_gumbel( mlx_array* res, const int* shape, @@ -6645,11 +6450,9 @@ static inline int mlx_random_gumbel( const mlx_stream s) { return mlx_random_gumbel_(res, shape, shape_num, dtype, key, s); } - static inline int mlx_random_key(mlx_array* res, uint64_t seed) { return mlx_random_key_(res, seed); } - static inline int mlx_random_laplace( mlx_array* res, const int* shape, @@ -6661,7 +6464,6 @@ static inline int mlx_random_laplace( const mlx_stream s) { return mlx_random_laplace_(res, shape, shape_num, dtype, loc, scale, key, s); } - static inline int mlx_random_multivariate_normal( mlx_array* res, const mlx_array mean, @@ -6673,7 +6475,6 @@ static inline int mlx_random_multivariate_normal( const mlx_stream s) { return mlx_random_multivariate_normal_(res, mean, cov, shape, shape_num, dtype, key, s); } - static inline int mlx_random_normal_broadcast( mlx_array* res, const int* shape, @@ -6685,7 +6486,6 @@ static inline int mlx_random_normal_broadcast( const mlx_stream s) { return mlx_random_normal_broadcast_(res, shape, shape_num, dtype, loc, scale, key, s); } - static inline int mlx_random_normal( mlx_array* res, const int* shape, @@ -6697,7 +6497,6 @@ static inline int mlx_random_normal( const mlx_stream s) { return mlx_random_normal_(res, shape, shape_num, dtype, loc, scale, key, s); } - static inline int mlx_random_permutation( mlx_array* res, const mlx_array x, @@ -6706,7 +6505,6 @@ static inline int mlx_random_permutation( const mlx_stream s) { return mlx_random_permutation_(res, x, axis, key, s); } - static inline int mlx_random_permutation_arange( mlx_array* res, int x, @@ -6714,7 +6512,6 @@ static inline int mlx_random_permutation_arange( const mlx_stream s) { return mlx_random_permutation_arange_(res, x, key, s); } - static inline int mlx_random_randint( mlx_array* res, const mlx_array low, @@ -6726,11 +6523,9 @@ static inline int mlx_random_randint( const mlx_stream s) { return mlx_random_randint_(res, low, high, shape, shape_num, dtype, key, s); } - static inline int mlx_random_seed(uint64_t seed) { return mlx_random_seed_(seed); } - static inline int mlx_random_split_num( mlx_array* res, const mlx_array key, @@ -6738,7 +6533,6 @@ static inline int mlx_random_split_num( const mlx_stream s) { return mlx_random_split_num_(res, key, num, s); } - static inline int mlx_random_split( mlx_array* res_0, mlx_array* res_1, @@ -6746,7 +6540,6 @@ static inline int mlx_random_split( const mlx_stream s) { return mlx_random_split_(res_0, res_1, key, s); } - static inline int mlx_random_truncated_normal( mlx_array* res, const mlx_array lower, @@ -6758,7 +6551,6 @@ static inline int mlx_random_truncated_normal( const mlx_stream s) { return mlx_random_truncated_normal_(res, lower, upper, shape, shape_num, dtype, key, s); } - static inline int mlx_random_uniform( mlx_array* res, const mlx_array low, @@ -6770,79 +6562,106 @@ static inline int mlx_random_uniform( const mlx_stream s) { return mlx_random_uniform_(res, low, high, shape, shape_num, dtype, key, s); } - static inline mlx_stream mlx_stream_new(void) { return mlx_stream_new_(); } - static inline mlx_stream mlx_stream_new_device(mlx_device dev) { return mlx_stream_new_device_(dev); } - static inline int mlx_stream_set(mlx_stream* stream, const mlx_stream src) { return mlx_stream_set_(stream, src); } - static inline int mlx_stream_free(mlx_stream stream) { return mlx_stream_free_(stream); } - static inline int mlx_stream_tostring(mlx_string* str, mlx_stream stream) { return mlx_stream_tostring_(str, stream); } - static inline bool mlx_stream_equal(mlx_stream lhs, mlx_stream rhs) { return mlx_stream_equal_(lhs, rhs); } - static inline int mlx_stream_get_device(mlx_device* dev, mlx_stream stream) { return mlx_stream_get_device_(dev, stream); } - static inline int mlx_stream_get_index(int* index, mlx_stream stream) { return mlx_stream_get_index_(index, stream); } - static inline int mlx_synchronize(mlx_stream stream) { return mlx_synchronize_(stream); } - static inline int mlx_get_default_stream(mlx_stream* stream, mlx_device dev) { return mlx_get_default_stream_(stream, dev); } - static inline int mlx_set_default_stream(mlx_stream stream) { return mlx_set_default_stream_(stream); } - static inline mlx_stream mlx_default_cpu_stream_new(void) { return mlx_default_cpu_stream_new_(); } - static inline mlx_stream mlx_default_gpu_stream_new(void) { return mlx_default_gpu_stream_new_(); } - static inline mlx_string mlx_string_new(void) { return mlx_string_new_(); } - static inline mlx_string mlx_string_new_data(const char* str) { return mlx_string_new_data_(str); } - static inline int mlx_string_set(mlx_string* str, const mlx_string src) { return mlx_string_set_(str, src); } - static inline const char * mlx_string_data(mlx_string str) { return mlx_string_data_(str); } - static inline int mlx_string_free(mlx_string str) { return mlx_string_free_(str); } - +static inline int mlx_async_eval(const mlx_vector_array outputs) { + return mlx_async_eval_(outputs); +} +static inline int mlx_checkpoint(mlx_closure* res, const mlx_closure fun) { + return mlx_checkpoint_(res, fun); +} +static inline int mlx_custom_function( + mlx_closure* res, + const mlx_closure fun, + const mlx_closure_custom fun_vjp /* may be null */, + const mlx_closure_custom_jvp fun_jvp /* may be null */, + const mlx_closure_custom_vmap fun_vmap /* may be null */) { + return mlx_custom_function_(res, fun, fun_vjp, fun_jvp, fun_vmap); +} +static inline int mlx_custom_vjp( + mlx_closure* res, + const mlx_closure fun, + const mlx_closure_custom fun_vjp) { + return mlx_custom_vjp_(res, fun, fun_vjp); +} +static inline int mlx_eval(const mlx_vector_array outputs) { + return mlx_eval_(outputs); +} +static inline int mlx_jvp( + mlx_vector_array* res_0, + mlx_vector_array* res_1, + const mlx_closure fun, + const mlx_vector_array primals, + const mlx_vector_array tangents) { + return mlx_jvp_(res_0, res_1, fun, primals, tangents); +} +static inline int mlx_value_and_grad( + mlx_closure_value_and_grad* res, + const mlx_closure fun, + const int* argnums, + size_t argnums_num) { + return mlx_value_and_grad_(res, fun, argnums, argnums_num); +} +static inline int mlx_vjp( + mlx_vector_array* res_0, + mlx_vector_array* res_1, + const mlx_closure fun, + const mlx_vector_array primals, + const mlx_vector_array cotangents) { + return mlx_vjp_(res_0, res_1, fun, primals, cotangents); +} static inline int mlx_detail_vmap_replace( mlx_vector_array* res, const mlx_vector_array inputs, @@ -6854,7 +6673,6 @@ static inline int mlx_detail_vmap_replace( size_t out_axes_num) { return mlx_detail_vmap_replace_(res, inputs, s_inputs, s_outputs, in_axes, in_axes_num, out_axes, out_axes_num); } - static inline int mlx_detail_vmap_trace( mlx_vector_array* res_0, mlx_vector_array* res_1, @@ -6864,272 +6682,173 @@ static inline int mlx_detail_vmap_trace( size_t in_axes_num) { return mlx_detail_vmap_trace_(res_0, res_1, fun, inputs, in_axes, in_axes_num); } - -static inline int mlx_async_eval(const mlx_vector_array outputs) { - return mlx_async_eval_(outputs); -} - -static inline int mlx_checkpoint(mlx_closure* res, const mlx_closure fun) { - return mlx_checkpoint_(res, fun); -} - -static inline int mlx_custom_function( - mlx_closure* res, - const mlx_closure fun, - const mlx_closure_custom fun_vjp /* may be null */, - const mlx_closure_custom_jvp fun_jvp /* may be null */, - const mlx_closure_custom_vmap fun_vmap /* may be null */) { - return mlx_custom_function_(res, fun, fun_vjp, fun_jvp, fun_vmap); -} - -static inline int mlx_custom_vjp( - mlx_closure* res, - const mlx_closure fun, - const mlx_closure_custom fun_vjp) { - return mlx_custom_vjp_(res, fun, fun_vjp); -} - -static inline int mlx_eval(const mlx_vector_array outputs) { - return mlx_eval_(outputs); -} - -static inline int mlx_jvp( - mlx_vector_array* res_0, - mlx_vector_array* res_1, - const mlx_closure fun, - const mlx_vector_array primals, - const mlx_vector_array tangents) { - return mlx_jvp_(res_0, res_1, fun, primals, tangents); -} - -static inline int mlx_value_and_grad( - mlx_closure_value_and_grad* res, - const mlx_closure fun, - const int* argnums, - size_t argnums_num) { - return mlx_value_and_grad_(res, fun, argnums, argnums_num); -} - -static inline int mlx_vjp( - mlx_vector_array* res_0, - mlx_vector_array* res_1, - const mlx_closure fun, - const mlx_vector_array primals, - const mlx_vector_array cotangents) { - return mlx_vjp_(res_0, res_1, fun, primals, cotangents); -} - static inline mlx_vector_array mlx_vector_array_new(void) { return mlx_vector_array_new_(); } - static inline int mlx_vector_array_set(mlx_vector_array* vec, const mlx_vector_array src) { return mlx_vector_array_set_(vec, src); } - static inline int mlx_vector_array_free(mlx_vector_array vec) { return mlx_vector_array_free_(vec); } - static inline mlx_vector_array mlx_vector_array_new_data(const mlx_array* data, size_t size) { return mlx_vector_array_new_data_(data, size); } - static inline mlx_vector_array mlx_vector_array_new_value(const mlx_array val) { return mlx_vector_array_new_value_(val); } - static inline int mlx_vector_array_set_data( mlx_vector_array* vec, const mlx_array* data, size_t size) { return mlx_vector_array_set_data_(vec, data, size); } - static inline int mlx_vector_array_set_value(mlx_vector_array* vec, const mlx_array val) { return mlx_vector_array_set_value_(vec, val); } - static inline int mlx_vector_array_append_data( mlx_vector_array vec, const mlx_array* data, size_t size) { return mlx_vector_array_append_data_(vec, data, size); } - static inline int mlx_vector_array_append_value(mlx_vector_array vec, const mlx_array val) { return mlx_vector_array_append_value_(vec, val); } - static inline size_t mlx_vector_array_size(mlx_vector_array vec) { return mlx_vector_array_size_(vec); } - static inline int mlx_vector_array_get( mlx_array* res, const mlx_vector_array vec, size_t idx) { return mlx_vector_array_get_(res, vec, idx); } - static inline mlx_vector_vector_array mlx_vector_vector_array_new(void) { return mlx_vector_vector_array_new_(); } - static inline int mlx_vector_vector_array_set( mlx_vector_vector_array* vec, const mlx_vector_vector_array src) { return mlx_vector_vector_array_set_(vec, src); } - static inline int mlx_vector_vector_array_free(mlx_vector_vector_array vec) { return mlx_vector_vector_array_free_(vec); } - static inline mlx_vector_vector_array mlx_vector_vector_array_new_data( const mlx_vector_array* data, size_t size) { return mlx_vector_vector_array_new_data_(data, size); } - static inline mlx_vector_vector_array mlx_vector_vector_array_new_value( const mlx_vector_array val) { return mlx_vector_vector_array_new_value_(val); } - static inline int mlx_vector_vector_array_set_data( mlx_vector_vector_array* vec, const mlx_vector_array* data, size_t size) { return mlx_vector_vector_array_set_data_(vec, data, size); } - static inline int mlx_vector_vector_array_set_value( mlx_vector_vector_array* vec, const mlx_vector_array val) { return mlx_vector_vector_array_set_value_(vec, val); } - static inline int mlx_vector_vector_array_append_data( mlx_vector_vector_array vec, const mlx_vector_array* data, size_t size) { return mlx_vector_vector_array_append_data_(vec, data, size); } - static inline int mlx_vector_vector_array_append_value( mlx_vector_vector_array vec, const mlx_vector_array val) { return mlx_vector_vector_array_append_value_(vec, val); } - static inline size_t mlx_vector_vector_array_size(mlx_vector_vector_array vec) { return mlx_vector_vector_array_size_(vec); } - static inline int mlx_vector_vector_array_get( mlx_vector_array* res, const mlx_vector_vector_array vec, size_t idx) { return mlx_vector_vector_array_get_(res, vec, idx); } - static inline mlx_vector_int mlx_vector_int_new(void) { return mlx_vector_int_new_(); } - static inline int mlx_vector_int_set(mlx_vector_int* vec, const mlx_vector_int src) { return mlx_vector_int_set_(vec, src); } - static inline int mlx_vector_int_free(mlx_vector_int vec) { return mlx_vector_int_free_(vec); } - static inline mlx_vector_int mlx_vector_int_new_data(int* data, size_t size) { return mlx_vector_int_new_data_(data, size); } - static inline mlx_vector_int mlx_vector_int_new_value(int val) { return mlx_vector_int_new_value_(val); } - static inline int mlx_vector_int_set_data(mlx_vector_int* vec, int* data, size_t size) { return mlx_vector_int_set_data_(vec, data, size); } - static inline int mlx_vector_int_set_value(mlx_vector_int* vec, int val) { return mlx_vector_int_set_value_(vec, val); } - static inline int mlx_vector_int_append_data(mlx_vector_int vec, int* data, size_t size) { return mlx_vector_int_append_data_(vec, data, size); } - static inline int mlx_vector_int_append_value(mlx_vector_int vec, int val) { return mlx_vector_int_append_value_(vec, val); } - static inline size_t mlx_vector_int_size(mlx_vector_int vec) { return mlx_vector_int_size_(vec); } - static inline int mlx_vector_int_get(int* res, const mlx_vector_int vec, size_t idx) { return mlx_vector_int_get_(res, vec, idx); } - static inline mlx_vector_string mlx_vector_string_new(void) { return mlx_vector_string_new_(); } - static inline int mlx_vector_string_set(mlx_vector_string* vec, const mlx_vector_string src) { return mlx_vector_string_set_(vec, src); } - static inline int mlx_vector_string_free(mlx_vector_string vec) { return mlx_vector_string_free_(vec); } - static inline mlx_vector_string mlx_vector_string_new_data(const char** data, size_t size) { return mlx_vector_string_new_data_(data, size); } - static inline mlx_vector_string mlx_vector_string_new_value(const char* val) { return mlx_vector_string_new_value_(val); } - static inline int mlx_vector_string_set_data( mlx_vector_string* vec, const char** data, size_t size) { return mlx_vector_string_set_data_(vec, data, size); } - static inline int mlx_vector_string_set_value(mlx_vector_string* vec, const char* val) { return mlx_vector_string_set_value_(vec, val); } - static inline int mlx_vector_string_append_data( mlx_vector_string vec, const char** data, size_t size) { return mlx_vector_string_append_data_(vec, data, size); } - static inline int mlx_vector_string_append_value(mlx_vector_string vec, const char* val) { return mlx_vector_string_append_value_(vec, val); } - static inline size_t mlx_vector_string_size(mlx_vector_string vec) { return mlx_vector_string_size_(vec); } - static inline int mlx_vector_string_get(char** res, const mlx_vector_string vec, size_t idx) { return mlx_vector_string_get_(res, vec, idx); } - static inline int mlx_version(mlx_string* str_) { return mlx_version_(str_); } -#endif // MLX_GENERATED_H \ No newline at end of file +#endif // MLX_GENERATED_H diff --git a/x/mlxrunner/mlx/generator/generated.h.gotmpl b/x/mlxrunner/mlx/generator/generated.h.gotmpl index 8f043573b..594a3f3e3 100644 --- a/x/mlxrunner/mlx/generator/generated.h.gotmpl +++ b/x/mlxrunner/mlx/generator/generated.h.gotmpl @@ -4,6 +4,10 @@ #define MLX_GENERATED_H #include "dynamic.h" +{{ range .Functions }} +#define {{ .Name }} {{ .Name }}_mlx_gen_orig_ +{{- end }} + #include "mlx/c/mlx.h" {{ range .Functions }} #undef {{ .Name }}