Revert "chore: update mlx-c bindings to 0.5.0 (#14303)" (#14316)

This reverts commit f01a9a7859.
This commit is contained in:
Patrick Devine
2026-02-18 17:01:25 -08:00
committed by GitHub
parent 325b72bc31
commit 0759fface9
8 changed files with 864 additions and 881 deletions

View File

@@ -1 +1 @@
v0.5.0
v0.4.1

View File

@@ -16,10 +16,10 @@ import (
)
type Function struct {
Name string
ReturnType string
Params string
ParamNames []string
Name string
ReturnType string
Params string
ParamNames []string
NeedsARM64Guard bool
}
@@ -29,11 +29,6 @@ func findHeaders(directory string) ([]string, error) {
if err != nil {
return err
}
// Private headers contain C++ implementation helpers and are not part of
// the C API surface; parsing them can produce invalid wrapper signatures.
if d.IsDir() && d.Name() == "private" {
return fs.SkipDir
}
if !d.IsDir() && strings.HasSuffix(path, ".h") {
headers = append(headers, path)
}
@@ -199,10 +194,10 @@ func parseFunctions(content string) []Function {
needsGuard := needsARM64Guard(funcName, returnType, params)
functions = append(functions, Function{
Name: funcName,
ReturnType: returnType,
Params: params,
ParamNames: paramNames,
Name: funcName,
ReturnType: returnType,
Params: params,
ParamNames: paramNames,
NeedsARM64Guard: needsGuard,
})
}

View File

@@ -20,8 +20,6 @@ mlx_array (*mlx_array_new_float64_ptr)(double val) = NULL;
mlx_array (*mlx_array_new_double_ptr)(double val) = NULL;
mlx_array (*mlx_array_new_complex_ptr)(float real_val, float imag_val) = NULL;
mlx_array (*mlx_array_new_data_ptr)(const void* data, const int* shape, int dim, mlx_dtype dtype) = NULL;
mlx_array (*mlx_array_new_data_managed_ptr)(void* data, const int* shape, int dim, mlx_dtype dtype, void (*dtor)(void*)) = NULL;
mlx_array (*mlx_array_new_data_managed_payload_ptr)(void* data, const int* shape, int dim, mlx_dtype dtype, void* payload, void (*dtor)(void*)) = NULL;
int (*mlx_array_set_ptr)(mlx_array* arr, const mlx_array src) = NULL;
int (*mlx_array_set_bool_ptr)(mlx_array* arr, bool val) = NULL;
int (*mlx_array_set_int_ptr)(mlx_array* arr, int val) = NULL;
@@ -51,7 +49,7 @@ int (*mlx_array_item_int32_ptr)(int32_t* res, const mlx_array arr) = NULL;
int (*mlx_array_item_int64_ptr)(int64_t* res, const mlx_array arr) = NULL;
int (*mlx_array_item_float32_ptr)(float* res, const mlx_array arr) = NULL;
int (*mlx_array_item_float64_ptr)(double* res, const mlx_array arr) = NULL;
int (*mlx_array_item_complex64_ptr)(mlx_complex64_t* res, const mlx_array arr) = NULL;
int (*mlx_array_item_complex64_ptr)(float _Complex* res, const mlx_array arr) = NULL;
#if defined(__aarch64__) || defined(_M_ARM64)
int (*mlx_array_item_float16_ptr)(float16_t* res, const mlx_array arr) = NULL;
#endif
@@ -69,7 +67,7 @@ const int32_t* (*mlx_array_data_int32_ptr)(const mlx_array arr) = NULL;
const int64_t* (*mlx_array_data_int64_ptr)(const mlx_array arr) = NULL;
const float* (*mlx_array_data_float32_ptr)(const mlx_array arr) = NULL;
const double* (*mlx_array_data_float64_ptr)(const mlx_array arr) = NULL;
const mlx_complex64_t* (*mlx_array_data_complex64_ptr)(const mlx_array arr) = NULL;
const float _Complex* (*mlx_array_data_complex64_ptr)(const mlx_array arr) = NULL;
#if defined(__aarch64__) || defined(_M_ARM64)
const float16_t* (*mlx_array_data_float16_ptr)(const mlx_array arr) = NULL;
#endif
@@ -125,7 +123,6 @@ int (*mlx_detail_compile_erase_ptr)(uintptr_t fun_id) = NULL;
int (*mlx_disable_compile_ptr)(void) = NULL;
int (*mlx_enable_compile_ptr)(void) = NULL;
int (*mlx_set_compile_mode_ptr)(mlx_compile_mode mode) = NULL;
int (*mlx_cuda_is_available_ptr)(bool* res) = NULL;
mlx_device (*mlx_device_new_ptr)(void) = NULL;
mlx_device (*mlx_device_new_type_ptr)(mlx_device_type type, int index) = NULL;
int (*mlx_device_free_ptr)(mlx_device dev) = NULL;
@@ -136,16 +133,6 @@ int (*mlx_device_get_index_ptr)(int* index, mlx_device dev) = NULL;
int (*mlx_device_get_type_ptr)(mlx_device_type* type, mlx_device dev) = NULL;
int (*mlx_get_default_device_ptr)(mlx_device* dev) = NULL;
int (*mlx_set_default_device_ptr)(mlx_device dev) = NULL;
int (*mlx_device_is_available_ptr)(bool* avail, mlx_device dev) = NULL;
int (*mlx_device_count_ptr)(int* count, mlx_device_type type) = NULL;
mlx_device_info (*mlx_device_info_new_ptr)(void) = NULL;
int (*mlx_device_info_get_ptr)(mlx_device_info* info, mlx_device dev) = NULL;
int (*mlx_device_info_free_ptr)(mlx_device_info info) = NULL;
int (*mlx_device_info_has_key_ptr)(bool* exists, mlx_device_info info, const char* key) = NULL;
int (*mlx_device_info_is_string_ptr)(bool* is_string, mlx_device_info info, const char* key) = NULL;
int (*mlx_device_info_get_string_ptr)(const char** value, mlx_device_info info, const char* key) = NULL;
int (*mlx_device_info_get_size_ptr)(size_t* value, mlx_device_info info, const char* key) = NULL;
int (*mlx_device_info_get_keys_ptr)(mlx_vector_string* keys, mlx_device_info info) = NULL;
int (*mlx_distributed_all_gather_ptr)(mlx_array* res, const mlx_array x, const mlx_distributed_group group , const mlx_stream S) = NULL;
int (*mlx_distributed_all_max_ptr)(mlx_array* res, const mlx_array x, const mlx_distributed_group group , const mlx_stream s) = NULL;
int (*mlx_distributed_all_min_ptr)(mlx_array* res, const mlx_array x, const mlx_distributed_group group , const mlx_stream s) = NULL;
@@ -276,6 +263,7 @@ int (*mlx_reset_peak_memory_ptr)(void) = NULL;
int (*mlx_set_cache_limit_ptr)(size_t* res, size_t limit) = NULL;
int (*mlx_set_memory_limit_ptr)(size_t* res, size_t limit) = NULL;
int (*mlx_set_wired_limit_ptr)(size_t* res, size_t limit) = NULL;
mlx_metal_device_info_t (*mlx_metal_device_info_ptr)(void) = NULL;
int (*mlx_metal_is_available_ptr)(bool* res) = NULL;
int (*mlx_metal_start_capture_ptr)(const char* path) = NULL;
int (*mlx_metal_stop_capture_ptr)(void) = NULL;
@@ -670,16 +658,6 @@ int mlx_load_functions(void* handle) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_new_data\n");
return -1;
}
mlx_array_new_data_managed_ptr = dlsym(handle, "mlx_array_new_data_managed");
if (mlx_array_new_data_managed_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_new_data_managed\n");
return -1;
}
mlx_array_new_data_managed_payload_ptr = dlsym(handle, "mlx_array_new_data_managed_payload");
if (mlx_array_new_data_managed_payload_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_new_data_managed_payload\n");
return -1;
}
mlx_array_set_ptr = dlsym(handle, "mlx_array_set");
if (mlx_array_set_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_set\n");
@@ -1163,11 +1141,6 @@ int mlx_load_functions(void* handle) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_set_compile_mode\n");
return -1;
}
mlx_cuda_is_available_ptr = dlsym(handle, "mlx_cuda_is_available");
if (mlx_cuda_is_available_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_cuda_is_available\n");
return -1;
}
mlx_device_new_ptr = dlsym(handle, "mlx_device_new");
if (mlx_device_new_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_device_new\n");
@@ -1218,56 +1191,6 @@ int mlx_load_functions(void* handle) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_set_default_device\n");
return -1;
}
mlx_device_is_available_ptr = dlsym(handle, "mlx_device_is_available");
if (mlx_device_is_available_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_device_is_available\n");
return -1;
}
mlx_device_count_ptr = dlsym(handle, "mlx_device_count");
if (mlx_device_count_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_device_count\n");
return -1;
}
mlx_device_info_new_ptr = dlsym(handle, "mlx_device_info_new");
if (mlx_device_info_new_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_device_info_new\n");
return -1;
}
mlx_device_info_get_ptr = dlsym(handle, "mlx_device_info_get");
if (mlx_device_info_get_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_device_info_get\n");
return -1;
}
mlx_device_info_free_ptr = dlsym(handle, "mlx_device_info_free");
if (mlx_device_info_free_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_device_info_free\n");
return -1;
}
mlx_device_info_has_key_ptr = dlsym(handle, "mlx_device_info_has_key");
if (mlx_device_info_has_key_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_device_info_has_key\n");
return -1;
}
mlx_device_info_is_string_ptr = dlsym(handle, "mlx_device_info_is_string");
if (mlx_device_info_is_string_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_device_info_is_string\n");
return -1;
}
mlx_device_info_get_string_ptr = dlsym(handle, "mlx_device_info_get_string");
if (mlx_device_info_get_string_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_device_info_get_string\n");
return -1;
}
mlx_device_info_get_size_ptr = dlsym(handle, "mlx_device_info_get_size");
if (mlx_device_info_get_size_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_device_info_get_size\n");
return -1;
}
mlx_device_info_get_keys_ptr = dlsym(handle, "mlx_device_info_get_keys");
if (mlx_device_info_get_keys_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_device_info_get_keys\n");
return -1;
}
mlx_distributed_all_gather_ptr = dlsym(handle, "mlx_distributed_all_gather");
if (mlx_distributed_all_gather_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_distributed_all_gather\n");
@@ -1918,6 +1841,11 @@ int mlx_load_functions(void* handle) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_set_wired_limit\n");
return -1;
}
mlx_metal_device_info_ptr = dlsym(handle, "mlx_metal_device_info");
if (mlx_metal_device_info_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_metal_device_info\n");
return -1;
}
mlx_metal_is_available_ptr = dlsym(handle, "mlx_metal_is_available");
if (mlx_metal_is_available_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_metal_is_available\n");
@@ -3600,14 +3528,6 @@ mlx_array mlx_array_new_data(const void* data, const int* shape, int dim, mlx_dt
return mlx_array_new_data_ptr(data, shape, dim, dtype);
}
mlx_array mlx_array_new_data_managed(void* data, const int* shape, int dim, mlx_dtype dtype, void (*dtor)(void*)) {
return mlx_array_new_data_managed_ptr(data, shape, dim, dtype, dtor);
}
mlx_array mlx_array_new_data_managed_payload(void* data, const int* shape, int dim, mlx_dtype dtype, void* payload, void (*dtor)(void*)) {
return mlx_array_new_data_managed_payload_ptr(data, shape, dim, dtype, payload, dtor);
}
int mlx_array_set(mlx_array* arr, const mlx_array src) {
return mlx_array_set_ptr(arr, src);
}
@@ -3724,7 +3644,7 @@ int mlx_array_item_float64(double* res, const mlx_array arr) {
return mlx_array_item_float64_ptr(res, arr);
}
int mlx_array_item_complex64(mlx_complex64_t* res, const mlx_array arr) {
int mlx_array_item_complex64(float _Complex* res, const mlx_array arr) {
return mlx_array_item_complex64_ptr(res, arr);
}
@@ -3784,7 +3704,7 @@ const double* mlx_array_data_float64(const mlx_array arr) {
return mlx_array_data_float64_ptr(arr);
}
const mlx_complex64_t* mlx_array_data_complex64(const mlx_array arr) {
const float _Complex* mlx_array_data_complex64(const mlx_array arr) {
return mlx_array_data_complex64_ptr(arr);
}
@@ -3996,10 +3916,6 @@ int mlx_set_compile_mode(mlx_compile_mode mode) {
return mlx_set_compile_mode_ptr(mode);
}
int mlx_cuda_is_available(bool* res) {
return mlx_cuda_is_available_ptr(res);
}
mlx_device mlx_device_new(void) {
return mlx_device_new_ptr();
}
@@ -4040,46 +3956,6 @@ int mlx_set_default_device(mlx_device dev) {
return mlx_set_default_device_ptr(dev);
}
int mlx_device_is_available(bool* avail, mlx_device dev) {
return mlx_device_is_available_ptr(avail, dev);
}
int mlx_device_count(int* count, mlx_device_type type) {
return mlx_device_count_ptr(count, type);
}
mlx_device_info mlx_device_info_new(void) {
return mlx_device_info_new_ptr();
}
int mlx_device_info_get(mlx_device_info* info, mlx_device dev) {
return mlx_device_info_get_ptr(info, dev);
}
int mlx_device_info_free(mlx_device_info info) {
return mlx_device_info_free_ptr(info);
}
int mlx_device_info_has_key(bool* exists, mlx_device_info info, const char* key) {
return mlx_device_info_has_key_ptr(exists, info, key);
}
int mlx_device_info_is_string(bool* is_string, mlx_device_info info, const char* key) {
return mlx_device_info_is_string_ptr(is_string, info, key);
}
int mlx_device_info_get_string(const char** value, mlx_device_info info, const char* key) {
return mlx_device_info_get_string_ptr(value, info, key);
}
int mlx_device_info_get_size(size_t* value, mlx_device_info info, const char* key) {
return mlx_device_info_get_size_ptr(value, info, key);
}
int mlx_device_info_get_keys(mlx_vector_string* keys, mlx_device_info info) {
return mlx_device_info_get_keys_ptr(keys, info);
}
int mlx_distributed_all_gather(mlx_array* res, const mlx_array x, const mlx_distributed_group group , const mlx_stream S) {
return mlx_distributed_all_gather_ptr(res, x, group, S);
}
@@ -4600,6 +4476,10 @@ int mlx_set_wired_limit(size_t* res, size_t limit) {
return mlx_set_wired_limit_ptr(res, limit);
}
mlx_metal_device_info_t mlx_metal_device_info(void) {
return mlx_metal_device_info_ptr();
}
int mlx_metal_is_available(bool* res) {
return mlx_metal_is_available_ptr(res);
}

View File

@@ -26,8 +26,6 @@
#undef mlx_array_new_double
#undef mlx_array_new_complex
#undef mlx_array_new_data
#undef mlx_array_new_data_managed
#undef mlx_array_new_data_managed_payload
#undef mlx_array_set
#undef mlx_array_set_bool
#undef mlx_array_set_int
@@ -123,7 +121,6 @@
#undef mlx_disable_compile
#undef mlx_enable_compile
#undef mlx_set_compile_mode
#undef mlx_cuda_is_available
#undef mlx_device_new
#undef mlx_device_new_type
#undef mlx_device_free
@@ -134,16 +131,6 @@
#undef mlx_device_get_type
#undef mlx_get_default_device
#undef mlx_set_default_device
#undef mlx_device_is_available
#undef mlx_device_count
#undef mlx_device_info_new
#undef mlx_device_info_get
#undef mlx_device_info_free
#undef mlx_device_info_has_key
#undef mlx_device_info_is_string
#undef mlx_device_info_get_string
#undef mlx_device_info_get_size
#undef mlx_device_info_get_keys
#undef mlx_distributed_all_gather
#undef mlx_distributed_all_max
#undef mlx_distributed_all_min
@@ -274,6 +261,7 @@
#undef mlx_set_cache_limit
#undef mlx_set_memory_limit
#undef mlx_set_wired_limit
#undef mlx_metal_device_info
#undef mlx_metal_is_available
#undef mlx_metal_start_capture
#undef mlx_metal_stop_capture
@@ -614,8 +602,6 @@ extern mlx_array (*mlx_array_new_float64_ptr)(double val);
extern mlx_array (*mlx_array_new_double_ptr)(double val);
extern mlx_array (*mlx_array_new_complex_ptr)(float real_val, float imag_val);
extern mlx_array (*mlx_array_new_data_ptr)(const void* data, const int* shape, int dim, mlx_dtype dtype);
extern mlx_array (*mlx_array_new_data_managed_ptr)(void* data, const int* shape, int dim, mlx_dtype dtype, void (*dtor)(void*));
extern mlx_array (*mlx_array_new_data_managed_payload_ptr)(void* data, const int* shape, int dim, mlx_dtype dtype, void* payload, void (*dtor)(void*));
extern int (*mlx_array_set_ptr)(mlx_array* arr, const mlx_array src);
extern int (*mlx_array_set_bool_ptr)(mlx_array* arr, bool val);
extern int (*mlx_array_set_int_ptr)(mlx_array* arr, int val);
@@ -645,7 +631,7 @@ extern int (*mlx_array_item_int32_ptr)(int32_t* res, const mlx_array arr);
extern int (*mlx_array_item_int64_ptr)(int64_t* res, const mlx_array arr);
extern int (*mlx_array_item_float32_ptr)(float* res, const mlx_array arr);
extern int (*mlx_array_item_float64_ptr)(double* res, const mlx_array arr);
extern int (*mlx_array_item_complex64_ptr)(mlx_complex64_t* res, const mlx_array arr);
extern int (*mlx_array_item_complex64_ptr)(float _Complex* res, const mlx_array arr);
#if defined(__aarch64__) || defined(_M_ARM64)
extern int (*mlx_array_item_float16_ptr)(float16_t* res, const mlx_array arr);
#endif
@@ -663,7 +649,7 @@ extern const int32_t* (*mlx_array_data_int32_ptr)(const mlx_array arr);
extern const int64_t* (*mlx_array_data_int64_ptr)(const mlx_array arr);
extern const float* (*mlx_array_data_float32_ptr)(const mlx_array arr);
extern const double* (*mlx_array_data_float64_ptr)(const mlx_array arr);
extern const mlx_complex64_t* (*mlx_array_data_complex64_ptr)(const mlx_array arr);
extern const float _Complex* (*mlx_array_data_complex64_ptr)(const mlx_array arr);
#if defined(__aarch64__) || defined(_M_ARM64)
extern const float16_t* (*mlx_array_data_float16_ptr)(const mlx_array arr);
#endif
@@ -719,7 +705,6 @@ extern int (*mlx_detail_compile_erase_ptr)(uintptr_t fun_id);
extern int (*mlx_disable_compile_ptr)(void);
extern int (*mlx_enable_compile_ptr)(void);
extern int (*mlx_set_compile_mode_ptr)(mlx_compile_mode mode);
extern int (*mlx_cuda_is_available_ptr)(bool* res);
extern mlx_device (*mlx_device_new_ptr)(void);
extern mlx_device (*mlx_device_new_type_ptr)(mlx_device_type type, int index);
extern int (*mlx_device_free_ptr)(mlx_device dev);
@@ -730,16 +715,6 @@ extern int (*mlx_device_get_index_ptr)(int* index, mlx_device dev);
extern int (*mlx_device_get_type_ptr)(mlx_device_type* type, mlx_device dev);
extern int (*mlx_get_default_device_ptr)(mlx_device* dev);
extern int (*mlx_set_default_device_ptr)(mlx_device dev);
extern int (*mlx_device_is_available_ptr)(bool* avail, mlx_device dev);
extern int (*mlx_device_count_ptr)(int* count, mlx_device_type type);
extern mlx_device_info (*mlx_device_info_new_ptr)(void);
extern int (*mlx_device_info_get_ptr)(mlx_device_info* info, mlx_device dev);
extern int (*mlx_device_info_free_ptr)(mlx_device_info info);
extern int (*mlx_device_info_has_key_ptr)(bool* exists, mlx_device_info info, const char* key);
extern int (*mlx_device_info_is_string_ptr)(bool* is_string, mlx_device_info info, const char* key);
extern int (*mlx_device_info_get_string_ptr)(const char** value, mlx_device_info info, const char* key);
extern int (*mlx_device_info_get_size_ptr)(size_t* value, mlx_device_info info, const char* key);
extern int (*mlx_device_info_get_keys_ptr)(mlx_vector_string* keys, mlx_device_info info);
extern int (*mlx_distributed_all_gather_ptr)(mlx_array* res, const mlx_array x, const mlx_distributed_group group , const mlx_stream S);
extern int (*mlx_distributed_all_max_ptr)(mlx_array* res, const mlx_array x, const mlx_distributed_group group , const mlx_stream s);
extern int (*mlx_distributed_all_min_ptr)(mlx_array* res, const mlx_array x, const mlx_distributed_group group , const mlx_stream s);
@@ -870,6 +845,7 @@ extern int (*mlx_reset_peak_memory_ptr)(void);
extern int (*mlx_set_cache_limit_ptr)(size_t* res, size_t limit);
extern int (*mlx_set_memory_limit_ptr)(size_t* res, size_t limit);
extern int (*mlx_set_wired_limit_ptr)(size_t* res, size_t limit);
extern mlx_metal_device_info_t (*mlx_metal_device_info_ptr)(void);
extern int (*mlx_metal_is_available_ptr)(bool* res);
extern int (*mlx_metal_start_capture_ptr)(const char* path);
extern int (*mlx_metal_stop_capture_ptr)(void);
@@ -1226,10 +1202,6 @@ mlx_array mlx_array_new_complex(float real_val, float imag_val);
mlx_array mlx_array_new_data(const void* data, const int* shape, int dim, mlx_dtype dtype);
mlx_array mlx_array_new_data_managed(void* data, const int* shape, int dim, mlx_dtype dtype, void (*dtor)(void*));
mlx_array mlx_array_new_data_managed_payload(void* data, const int* shape, int dim, mlx_dtype dtype, void* payload, void (*dtor)(void*));
int mlx_array_set(mlx_array* arr, const mlx_array src);
int mlx_array_set_bool(mlx_array* arr, bool val);
@@ -1288,7 +1260,7 @@ int mlx_array_item_float32(float* res, const mlx_array arr);
int mlx_array_item_float64(double* res, const mlx_array arr);
int mlx_array_item_complex64(mlx_complex64_t* res, const mlx_array arr);
int mlx_array_item_complex64(float _Complex* res, const mlx_array arr);
#if defined(__aarch64__) || defined(_M_ARM64)
int mlx_array_item_float16(float16_t* res, const mlx_array arr);
@@ -1320,7 +1292,7 @@ const float* mlx_array_data_float32(const mlx_array arr);
const double* mlx_array_data_float64(const mlx_array arr);
const mlx_complex64_t* mlx_array_data_complex64(const mlx_array arr);
const float _Complex* mlx_array_data_complex64(const mlx_array arr);
#if defined(__aarch64__) || defined(_M_ARM64)
const float16_t* mlx_array_data_float16(const mlx_array arr);
@@ -1428,8 +1400,6 @@ int mlx_enable_compile(void);
int mlx_set_compile_mode(mlx_compile_mode mode);
int mlx_cuda_is_available(bool* res);
mlx_device mlx_device_new(void);
mlx_device mlx_device_new_type(mlx_device_type type, int index);
@@ -1450,26 +1420,6 @@ int mlx_get_default_device(mlx_device* dev);
int mlx_set_default_device(mlx_device dev);
int mlx_device_is_available(bool* avail, mlx_device dev);
int mlx_device_count(int* count, mlx_device_type type);
mlx_device_info mlx_device_info_new(void);
int mlx_device_info_get(mlx_device_info* info, mlx_device dev);
int mlx_device_info_free(mlx_device_info info);
int mlx_device_info_has_key(bool* exists, mlx_device_info info, const char* key);
int mlx_device_info_is_string(bool* is_string, mlx_device_info info, const char* key);
int mlx_device_info_get_string(const char** value, mlx_device_info info, const char* key);
int mlx_device_info_get_size(size_t* value, mlx_device_info info, const char* key);
int mlx_device_info_get_keys(mlx_vector_string* keys, mlx_device_info info);
int mlx_distributed_all_gather(mlx_array* res, const mlx_array x, const mlx_distributed_group group , const mlx_stream S);
int mlx_distributed_all_max(mlx_array* res, const mlx_array x, const mlx_distributed_group group , const mlx_stream s);
@@ -1730,6 +1680,8 @@ int mlx_set_memory_limit(size_t* res, size_t limit);
int mlx_set_wired_limit(size_t* res, size_t limit);
mlx_metal_device_info_t mlx_metal_device_info(void);
int mlx_metal_is_available(bool* res);
int mlx_metal_start_capture(const char* path);

View File

@@ -15,7 +15,7 @@ set(CMAKE_INSTALL_RPATH "@loader_path")
include(FetchContent)
set(MLX_C_GIT_TAG "v0.5.0" CACHE STRING "")
set(MLX_C_GIT_TAG "v0.4.1" CACHE STRING "")
FetchContent_Declare(
mlx-c

View File

@@ -22,19 +22,6 @@ mlx_array (*mlx_array_new_data_)(
const int* shape,
int dim,
mlx_dtype dtype) = NULL;
mlx_array (*mlx_array_new_data_managed_)(
void* data,
const int* shape,
int dim,
mlx_dtype dtype,
void (*dtor)(void*)) = NULL;
mlx_array (*mlx_array_new_data_managed_payload_)(
void* data,
const int* shape,
int dim,
mlx_dtype dtype,
void* payload,
void (*dtor)(void*)) = NULL;
int (*mlx_array_set_)(mlx_array* arr, const mlx_array src) = NULL;
int (*mlx_array_set_bool_)(mlx_array* arr, bool val) = NULL;
int (*mlx_array_set_int_)(mlx_array* arr, int val) = NULL;
@@ -69,7 +56,7 @@ int (*mlx_array_item_int32_)(int32_t* res, const mlx_array arr) = NULL;
int (*mlx_array_item_int64_)(int64_t* res, const mlx_array arr) = NULL;
int (*mlx_array_item_float32_)(float* res, const mlx_array arr) = NULL;
int (*mlx_array_item_float64_)(double* res, const mlx_array arr) = NULL;
int (*mlx_array_item_complex64_)(mlx_complex64_t* res, const mlx_array arr) = NULL;
int (*mlx_array_item_complex64_)(float _Complex* res, const mlx_array arr) = NULL;
int (*mlx_array_item_float16_)(float16_t* res, const mlx_array arr) = NULL;
int (*mlx_array_item_bfloat16_)(bfloat16_t* res, const mlx_array arr) = NULL;
const bool * (*mlx_array_data_bool_)(const mlx_array arr) = NULL;
@@ -83,7 +70,7 @@ const int32_t * (*mlx_array_data_int32_)(const mlx_array arr) = NULL;
const int64_t * (*mlx_array_data_int64_)(const mlx_array arr) = NULL;
const float * (*mlx_array_data_float32_)(const mlx_array arr) = NULL;
const double * (*mlx_array_data_float64_)(const mlx_array arr) = NULL;
const mlx_complex64_t * (*mlx_array_data_complex64_)(const mlx_array arr) = NULL;
const float _Complex * (*mlx_array_data_complex64_)(const mlx_array arr) = NULL;
const float16_t * (*mlx_array_data_float16_)(const mlx_array arr) = NULL;
const bfloat16_t * (*mlx_array_data_bfloat16_)(const mlx_array arr) = NULL;
int (*_mlx_array_is_available_)(bool* res, const mlx_array arr) = NULL;
@@ -107,11 +94,10 @@ int (*mlx_closure_apply_)(
mlx_closure (*mlx_closure_new_unary_)(int (*fun)(mlx_array*, const mlx_array)) = NULL;
mlx_closure_kwargs (*mlx_closure_kwargs_new_)(void) = NULL;
int (*mlx_closure_kwargs_free_)(mlx_closure_kwargs cls) = NULL;
mlx_closure_kwargs (*mlx_closure_kwargs_new_func_)(
int (*fun)(
mlx_vector_array*,
const mlx_vector_array,
const mlx_map_string_to_array)) = NULL;
mlx_closure_kwargs (*mlx_closure_kwargs_new_func_)(int (*fun)(
mlx_vector_array*,
const mlx_vector_array,
const mlx_map_string_to_array)) = NULL;
mlx_closure_kwargs (*mlx_closure_kwargs_new_func_payload_)(
int (*fun)(
mlx_vector_array*,
@@ -150,12 +136,11 @@ int (*mlx_closure_value_and_grad_apply_)(
const mlx_vector_array input) = NULL;
mlx_closure_custom (*mlx_closure_custom_new_)(void) = NULL;
int (*mlx_closure_custom_free_)(mlx_closure_custom cls) = NULL;
mlx_closure_custom (*mlx_closure_custom_new_func_)(
int (*fun)(
mlx_vector_array*,
const mlx_vector_array,
const mlx_vector_array,
const mlx_vector_array)) = NULL;
mlx_closure_custom (*mlx_closure_custom_new_func_)(int (*fun)(
mlx_vector_array*,
const mlx_vector_array,
const mlx_vector_array,
const mlx_vector_array)) = NULL;
mlx_closure_custom (*mlx_closure_custom_new_func_payload_)(
int (*fun)(
mlx_vector_array*,
@@ -176,13 +161,12 @@ int (*mlx_closure_custom_apply_)(
const mlx_vector_array input_2) = NULL;
mlx_closure_custom_jvp (*mlx_closure_custom_jvp_new_)(void) = NULL;
int (*mlx_closure_custom_jvp_free_)(mlx_closure_custom_jvp cls) = NULL;
mlx_closure_custom_jvp (*mlx_closure_custom_jvp_new_func_)(
int (*fun)(
mlx_vector_array*,
const mlx_vector_array,
const mlx_vector_array,
const int*,
size_t _num)) = NULL;
mlx_closure_custom_jvp (*mlx_closure_custom_jvp_new_func_)(int (*fun)(
mlx_vector_array*,
const mlx_vector_array,
const mlx_vector_array,
const int*,
size_t _num)) = NULL;
mlx_closure_custom_jvp (*mlx_closure_custom_jvp_new_func_payload_)(
int (*fun)(
mlx_vector_array*,
@@ -205,13 +189,12 @@ int (*mlx_closure_custom_jvp_apply_)(
size_t input_2_num) = NULL;
mlx_closure_custom_vmap (*mlx_closure_custom_vmap_new_)(void) = NULL;
int (*mlx_closure_custom_vmap_free_)(mlx_closure_custom_vmap cls) = NULL;
mlx_closure_custom_vmap (*mlx_closure_custom_vmap_new_func_)(
int (*fun)(
mlx_vector_array*,
mlx_vector_int*,
const mlx_vector_array,
const int*,
size_t _num)) = NULL;
mlx_closure_custom_vmap (*mlx_closure_custom_vmap_new_func_)(int (*fun)(
mlx_vector_array*,
mlx_vector_int*,
const mlx_vector_array,
const int*,
size_t _num)) = NULL;
mlx_closure_custom_vmap (*mlx_closure_custom_vmap_new_func_payload_)(
int (*fun)(
mlx_vector_array*,
@@ -245,7 +228,6 @@ int (*mlx_detail_compile_erase_)(uintptr_t fun_id) = NULL;
int (*mlx_disable_compile_)(void) = NULL;
int (*mlx_enable_compile_)(void) = NULL;
int (*mlx_set_compile_mode_)(mlx_compile_mode mode) = NULL;
int (*mlx_cuda_is_available_)(bool* res) = NULL;
mlx_device (*mlx_device_new_)(void) = NULL;
mlx_device (*mlx_device_new_type_)(mlx_device_type type, int index) = NULL;
int (*mlx_device_free_)(mlx_device dev) = NULL;
@@ -256,28 +238,11 @@ int (*mlx_device_get_index_)(int* index, mlx_device dev) = NULL;
int (*mlx_device_get_type_)(mlx_device_type* type, mlx_device dev) = NULL;
int (*mlx_get_default_device_)(mlx_device* dev) = NULL;
int (*mlx_set_default_device_)(mlx_device dev) = NULL;
int (*mlx_device_is_available_)(bool* avail, mlx_device dev) = NULL;
int (*mlx_device_count_)(int* count, mlx_device_type type) = NULL;
mlx_device_info (*mlx_device_info_new_)(void) = NULL;
int (*mlx_device_info_get_)(mlx_device_info* info, mlx_device dev) = NULL;
int (*mlx_device_info_free_)(mlx_device_info info) = NULL;
int (*mlx_device_info_has_key_)(
bool* exists,
mlx_device_info info,
const char* key) = NULL;
int (*mlx_device_info_is_string_)(
bool* is_string,
mlx_device_info info,
const char* key) = NULL;
int (*mlx_device_info_get_string_)(
const char** value,
mlx_device_info info,
const char* key) = NULL;
int (*mlx_device_info_get_size_)(
size_t* value,
mlx_device_info info,
const char* key) = NULL;
int (*mlx_device_info_get_keys_)(mlx_vector_string* keys, mlx_device_info info) = NULL;
int (*mlx_distributed_group_rank_)(mlx_distributed_group group) = NULL;
int (*mlx_distributed_group_size_)(mlx_distributed_group group) = NULL;
mlx_distributed_group (*mlx_distributed_group_split_)(mlx_distributed_group group, int color, int key) = NULL;
bool (*mlx_distributed_is_available_)(void) = NULL;
mlx_distributed_group (*mlx_distributed_init_)(bool strict) = NULL;
int (*mlx_distributed_all_gather_)(
mlx_array* res,
const mlx_array x,
@@ -323,11 +288,6 @@ int (*mlx_distributed_sum_scatter_)(
const mlx_array x,
const mlx_distributed_group group /* may be null */,
const mlx_stream s) = NULL;
int (*mlx_distributed_group_rank_)(mlx_distributed_group group) = NULL;
int (*mlx_distributed_group_size_)(mlx_distributed_group group) = NULL;
mlx_distributed_group (*mlx_distributed_group_split_)(mlx_distributed_group group, int color, int key) = NULL;
bool (*mlx_distributed_is_available_)(void) = NULL;
mlx_distributed_group (*mlx_distributed_init_)(bool strict) = NULL;
void (*mlx_set_error_handler_)(
mlx_error_handler_func handler,
void* data,
@@ -490,16 +450,6 @@ int (*mlx_fast_rope_)(
int offset,
const mlx_array freqs /* may be null */,
const mlx_stream s) = NULL;
int (*mlx_fast_rope_dynamic_)(
mlx_array* res,
const mlx_array x,
int dims,
bool traditional,
mlx_optional_float base,
float scale,
const mlx_array offset,
const mlx_array freqs /* may be null */,
const mlx_stream s) = NULL;
int (*mlx_fast_scaled_dot_product_attention_)(
mlx_array* res,
const mlx_array queries,
@@ -610,6 +560,14 @@ int (*mlx_fft_rfftn_)(
const int* axes,
size_t axes_num,
const mlx_stream s) = NULL;
mlx_io_reader (*mlx_io_reader_new_)(void* desc, mlx_io_vtable vtable) = NULL;
int (*mlx_io_reader_descriptor_)(void** desc_, mlx_io_reader io) = NULL;
int (*mlx_io_reader_tostring_)(mlx_string* str_, mlx_io_reader io) = NULL;
int (*mlx_io_reader_free_)(mlx_io_reader io) = NULL;
mlx_io_writer (*mlx_io_writer_new_)(void* desc, mlx_io_vtable vtable) = NULL;
int (*mlx_io_writer_descriptor_)(void** desc_, mlx_io_writer io) = NULL;
int (*mlx_io_writer_tostring_)(mlx_string* str_, mlx_io_writer io) = NULL;
int (*mlx_io_writer_free_)(mlx_io_writer io) = NULL;
int (*mlx_load_reader_)(
mlx_array* res,
mlx_io_reader in_stream,
@@ -635,14 +593,6 @@ int (*mlx_save_safetensors_)(
const char* file,
const mlx_map_string_to_array param,
const mlx_map_string_to_string metadata) = NULL;
mlx_io_reader (*mlx_io_reader_new_)(void* desc, mlx_io_vtable vtable) = NULL;
int (*mlx_io_reader_descriptor_)(void** desc_, mlx_io_reader io) = NULL;
int (*mlx_io_reader_tostring_)(mlx_string* str_, mlx_io_reader io) = NULL;
int (*mlx_io_reader_free_)(mlx_io_reader io) = NULL;
mlx_io_writer (*mlx_io_writer_new_)(void* desc, mlx_io_vtable vtable) = NULL;
int (*mlx_io_writer_descriptor_)(void** desc_, mlx_io_writer io) = NULL;
int (*mlx_io_writer_tostring_)(mlx_string* str_, mlx_io_writer io) = NULL;
int (*mlx_io_writer_free_)(mlx_io_writer io) = NULL;
int (*mlx_linalg_cholesky_)(
mlx_array* res,
const mlx_array a,
@@ -783,6 +733,7 @@ int (*mlx_reset_peak_memory_)(void) = NULL;
int (*mlx_set_cache_limit_)(size_t* res, size_t limit) = NULL;
int (*mlx_set_memory_limit_)(size_t* res, size_t limit) = NULL;
int (*mlx_set_wired_limit_)(size_t* res, size_t limit) = NULL;
mlx_metal_device_info_t (*mlx_metal_device_info_)(void) = NULL;
int (*mlx_metal_is_available_)(bool* res) = NULL;
int (*mlx_metal_start_capture_)(const char* path) = NULL;
int (*mlx_metal_stop_capture_)(void) = NULL;
@@ -1211,14 +1162,6 @@ int (*mlx_gather_)(
const int* slice_sizes,
size_t slice_sizes_num,
const mlx_stream s) = NULL;
int (*mlx_gather_single_)(
mlx_array* res,
const mlx_array a,
const mlx_array indices,
int axis,
const int* slice_sizes,
size_t slice_sizes_num,
const mlx_stream s) = NULL;
int (*mlx_gather_mm_)(
mlx_array* res,
const mlx_array a,
@@ -1540,15 +1483,6 @@ int (*mlx_put_along_axis_)(
const mlx_array values,
int axis,
const mlx_stream s) = NULL;
int (*mlx_qqmm_)(
mlx_array* res,
const mlx_array x,
const mlx_array w,
const mlx_array w_scales /* may be null */,
mlx_optional_int group_size,
mlx_optional_int bits,
const char* mode,
const mlx_stream s) = NULL;
int (*mlx_quantize_)(
mlx_vector_array* res,
const mlx_array w,
@@ -1632,13 +1566,6 @@ int (*mlx_scatter_)(
const int* axes,
size_t axes_num,
const mlx_stream s) = NULL;
int (*mlx_scatter_single_)(
mlx_array* res,
const mlx_array a,
const mlx_array indices,
const mlx_array updates,
int axis,
const mlx_stream s) = NULL;
int (*mlx_scatter_add_)(
mlx_array* res,
const mlx_array a,
@@ -1647,13 +1574,6 @@ int (*mlx_scatter_add_)(
const int* axes,
size_t axes_num,
const mlx_stream s) = NULL;
int (*mlx_scatter_add_single_)(
mlx_array* res,
const mlx_array a,
const mlx_array indices,
const mlx_array updates,
int axis,
const mlx_stream s) = NULL;
int (*mlx_scatter_add_axis_)(
mlx_array* res,
const mlx_array a,
@@ -1669,13 +1589,6 @@ int (*mlx_scatter_max_)(
const int* axes,
size_t axes_num,
const mlx_stream s) = NULL;
int (*mlx_scatter_max_single_)(
mlx_array* res,
const mlx_array a,
const mlx_array indices,
const mlx_array updates,
int axis,
const mlx_stream s) = NULL;
int (*mlx_scatter_min_)(
mlx_array* res,
const mlx_array a,
@@ -1684,13 +1597,6 @@ int (*mlx_scatter_min_)(
const int* axes,
size_t axes_num,
const mlx_stream s) = NULL;
int (*mlx_scatter_min_single_)(
mlx_array* res,
const mlx_array a,
const mlx_array indices,
const mlx_array updates,
int axis,
const mlx_stream s) = NULL;
int (*mlx_scatter_prod_)(
mlx_array* res,
const mlx_array a,
@@ -1699,13 +1605,6 @@ int (*mlx_scatter_prod_)(
const int* axes,
size_t axes_num,
const mlx_stream s) = NULL;
int (*mlx_scatter_prod_single_)(
mlx_array* res,
const mlx_array a,
const mlx_array indices,
const mlx_array updates,
int axis,
const mlx_stream s) = NULL;
int (*mlx_segmented_mm_)(
mlx_array* res,
const mlx_array a,
@@ -2129,6 +2028,22 @@ mlx_string (*mlx_string_new_data_)(const char* str) = NULL;
int (*mlx_string_set_)(mlx_string* str, const mlx_string src) = NULL;
const char * (*mlx_string_data_)(mlx_string str) = NULL;
int (*mlx_string_free_)(mlx_string str) = NULL;
int (*mlx_detail_vmap_replace_)(
mlx_vector_array* res,
const mlx_vector_array inputs,
const mlx_vector_array s_inputs,
const mlx_vector_array s_outputs,
const int* in_axes,
size_t in_axes_num,
const int* out_axes,
size_t out_axes_num) = NULL;
int (*mlx_detail_vmap_trace_)(
mlx_vector_array* res_0,
mlx_vector_array* res_1,
const mlx_closure fun,
const mlx_vector_array inputs,
const int* in_axes,
size_t in_axes_num) = NULL;
int (*mlx_async_eval_)(const mlx_vector_array outputs) = NULL;
int (*mlx_checkpoint_)(mlx_closure* res, const mlx_closure fun) = NULL;
int (*mlx_custom_function_)(
@@ -2159,22 +2074,6 @@ int (*mlx_vjp_)(
const mlx_closure fun,
const mlx_vector_array primals,
const mlx_vector_array cotangents) = NULL;
int (*mlx_detail_vmap_replace_)(
mlx_vector_array* res,
const mlx_vector_array inputs,
const mlx_vector_array s_inputs,
const mlx_vector_array s_outputs,
const int* in_axes,
size_t in_axes_num,
const int* out_axes,
size_t out_axes_num) = NULL;
int (*mlx_detail_vmap_trace_)(
mlx_vector_array* res_0,
mlx_vector_array* res_1,
const mlx_closure fun,
const mlx_vector_array inputs,
const int* in_axes,
size_t in_axes_num) = NULL;
mlx_vector_array (*mlx_vector_array_new_)(void) = NULL;
int (*mlx_vector_array_set_)(mlx_vector_array* vec, const mlx_vector_array src) = NULL;
int (*mlx_vector_array_free_)(mlx_vector_array vec) = NULL;
@@ -2267,8 +2166,6 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) {
CHECK_LOAD(handle, mlx_array_new_double);
CHECK_LOAD(handle, mlx_array_new_complex);
CHECK_LOAD(handle, mlx_array_new_data);
CHECK_LOAD(handle, mlx_array_new_data_managed);
CHECK_LOAD(handle, mlx_array_new_data_managed_payload);
CHECK_LOAD(handle, mlx_array_set);
CHECK_LOAD(handle, mlx_array_set_bool);
CHECK_LOAD(handle, mlx_array_set_int);
@@ -2364,7 +2261,6 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) {
CHECK_LOAD(handle, mlx_disable_compile);
CHECK_LOAD(handle, mlx_enable_compile);
CHECK_LOAD(handle, mlx_set_compile_mode);
CHECK_LOAD(handle, mlx_cuda_is_available);
CHECK_LOAD(handle, mlx_device_new);
CHECK_LOAD(handle, mlx_device_new_type);
CHECK_LOAD(handle, mlx_device_free);
@@ -2375,16 +2271,11 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) {
CHECK_LOAD(handle, mlx_device_get_type);
CHECK_LOAD(handle, mlx_get_default_device);
CHECK_LOAD(handle, mlx_set_default_device);
CHECK_LOAD(handle, mlx_device_is_available);
CHECK_LOAD(handle, mlx_device_count);
CHECK_LOAD(handle, mlx_device_info_new);
CHECK_LOAD(handle, mlx_device_info_get);
CHECK_LOAD(handle, mlx_device_info_free);
CHECK_LOAD(handle, mlx_device_info_has_key);
CHECK_LOAD(handle, mlx_device_info_is_string);
CHECK_LOAD(handle, mlx_device_info_get_string);
CHECK_LOAD(handle, mlx_device_info_get_size);
CHECK_LOAD(handle, mlx_device_info_get_keys);
CHECK_LOAD(handle, mlx_distributed_group_rank);
CHECK_LOAD(handle, mlx_distributed_group_size);
CHECK_LOAD(handle, mlx_distributed_group_split);
CHECK_LOAD(handle, mlx_distributed_is_available);
CHECK_LOAD(handle, mlx_distributed_init);
CHECK_LOAD(handle, mlx_distributed_all_gather);
CHECK_LOAD(handle, mlx_distributed_all_max);
CHECK_LOAD(handle, mlx_distributed_all_min);
@@ -2393,11 +2284,6 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) {
CHECK_LOAD(handle, mlx_distributed_recv_like);
CHECK_LOAD(handle, mlx_distributed_send);
CHECK_LOAD(handle, mlx_distributed_sum_scatter);
CHECK_LOAD(handle, mlx_distributed_group_rank);
CHECK_LOAD(handle, mlx_distributed_group_size);
CHECK_LOAD(handle, mlx_distributed_group_split);
CHECK_LOAD(handle, mlx_distributed_is_available);
CHECK_LOAD(handle, mlx_distributed_init);
CHECK_LOAD(handle, mlx_set_error_handler);
CHECK_LOAD(handle, _mlx_error);
CHECK_LOAD(handle, mlx_export_function);
@@ -2439,7 +2325,6 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) {
CHECK_LOAD(handle, mlx_fast_metal_kernel_apply);
CHECK_LOAD(handle, mlx_fast_rms_norm);
CHECK_LOAD(handle, mlx_fast_rope);
CHECK_LOAD(handle, mlx_fast_rope_dynamic);
CHECK_LOAD(handle, mlx_fast_scaled_dot_product_attention);
CHECK_LOAD(handle, mlx_fft_fft);
CHECK_LOAD(handle, mlx_fft_fft2);
@@ -2455,14 +2340,6 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) {
CHECK_LOAD(handle, mlx_fft_rfft);
CHECK_LOAD(handle, mlx_fft_rfft2);
CHECK_LOAD(handle, mlx_fft_rfftn);
CHECK_LOAD(handle, mlx_load_reader);
CHECK_LOAD(handle, mlx_load);
CHECK_LOAD(handle, mlx_load_safetensors_reader);
CHECK_LOAD(handle, mlx_load_safetensors);
CHECK_LOAD(handle, mlx_save_writer);
CHECK_LOAD(handle, mlx_save);
CHECK_LOAD(handle, mlx_save_safetensors_writer);
CHECK_LOAD(handle, mlx_save_safetensors);
CHECK_LOAD(handle, mlx_io_reader_new);
CHECK_LOAD(handle, mlx_io_reader_descriptor);
CHECK_LOAD(handle, mlx_io_reader_tostring);
@@ -2471,6 +2348,14 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) {
CHECK_LOAD(handle, mlx_io_writer_descriptor);
CHECK_LOAD(handle, mlx_io_writer_tostring);
CHECK_LOAD(handle, mlx_io_writer_free);
CHECK_LOAD(handle, mlx_load_reader);
CHECK_LOAD(handle, mlx_load);
CHECK_LOAD(handle, mlx_load_safetensors_reader);
CHECK_LOAD(handle, mlx_load_safetensors);
CHECK_LOAD(handle, mlx_save_writer);
CHECK_LOAD(handle, mlx_save);
CHECK_LOAD(handle, mlx_save_safetensors_writer);
CHECK_LOAD(handle, mlx_save_safetensors);
CHECK_LOAD(handle, mlx_linalg_cholesky);
CHECK_LOAD(handle, mlx_linalg_cholesky_inv);
CHECK_LOAD(handle, mlx_linalg_cross);
@@ -2515,6 +2400,7 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) {
CHECK_LOAD(handle, mlx_set_cache_limit);
CHECK_LOAD(handle, mlx_set_memory_limit);
CHECK_LOAD(handle, mlx_set_wired_limit);
CHECK_LOAD(handle, mlx_metal_device_info);
CHECK_LOAD(handle, mlx_metal_is_available);
CHECK_LOAD(handle, mlx_metal_start_capture);
CHECK_LOAD(handle, mlx_metal_stop_capture);
@@ -2600,7 +2486,6 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) {
CHECK_LOAD(handle, mlx_full);
CHECK_LOAD(handle, mlx_full_like);
CHECK_LOAD(handle, mlx_gather);
CHECK_LOAD(handle, mlx_gather_single);
CHECK_LOAD(handle, mlx_gather_mm);
CHECK_LOAD(handle, mlx_gather_qmm);
CHECK_LOAD(handle, mlx_greater);
@@ -2665,7 +2550,6 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) {
CHECK_LOAD(handle, mlx_prod_axis);
CHECK_LOAD(handle, mlx_prod);
CHECK_LOAD(handle, mlx_put_along_axis);
CHECK_LOAD(handle, mlx_qqmm);
CHECK_LOAD(handle, mlx_quantize);
CHECK_LOAD(handle, mlx_quantized_matmul);
CHECK_LOAD(handle, mlx_radians);
@@ -2682,16 +2566,11 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) {
CHECK_LOAD(handle, mlx_round);
CHECK_LOAD(handle, mlx_rsqrt);
CHECK_LOAD(handle, mlx_scatter);
CHECK_LOAD(handle, mlx_scatter_single);
CHECK_LOAD(handle, mlx_scatter_add);
CHECK_LOAD(handle, mlx_scatter_add_single);
CHECK_LOAD(handle, mlx_scatter_add_axis);
CHECK_LOAD(handle, mlx_scatter_max);
CHECK_LOAD(handle, mlx_scatter_max_single);
CHECK_LOAD(handle, mlx_scatter_min);
CHECK_LOAD(handle, mlx_scatter_min_single);
CHECK_LOAD(handle, mlx_scatter_prod);
CHECK_LOAD(handle, mlx_scatter_prod_single);
CHECK_LOAD(handle, mlx_segmented_mm);
CHECK_LOAD(handle, mlx_sigmoid);
CHECK_LOAD(handle, mlx_sign);
@@ -2786,6 +2665,8 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) {
CHECK_LOAD(handle, mlx_string_set);
CHECK_LOAD(handle, mlx_string_data);
CHECK_LOAD(handle, mlx_string_free);
CHECK_LOAD(handle, mlx_detail_vmap_replace);
CHECK_LOAD(handle, mlx_detail_vmap_trace);
CHECK_LOAD(handle, mlx_async_eval);
CHECK_LOAD(handle, mlx_checkpoint);
CHECK_LOAD(handle, mlx_custom_function);
@@ -2794,8 +2675,6 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) {
CHECK_LOAD(handle, mlx_jvp);
CHECK_LOAD(handle, mlx_value_and_grad);
CHECK_LOAD(handle, mlx_vjp);
CHECK_LOAD(handle, mlx_detail_vmap_replace);
CHECK_LOAD(handle, mlx_detail_vmap_trace);
CHECK_LOAD(handle, mlx_vector_array_new);
CHECK_LOAD(handle, mlx_vector_array_set);
CHECK_LOAD(handle, mlx_vector_array_free);

File diff suppressed because it is too large Load Diff

View File

@@ -4,10 +4,6 @@
#define MLX_GENERATED_H
#include "dynamic.h"
{{ range .Functions }}
#define {{ .Name }} {{ .Name }}_mlx_gen_orig_
{{- end }}
#include "mlx/c/mlx.h"
{{ range .Functions }}
#undef {{ .Name }}