// AUTO-GENERATED by generate_wrappers.go - DO NOT EDIT // This file contains the function pointer definitions and initialization // All function pointers are in a single compilation unit to avoid duplication #include "mlx/c/mlx.h" #include "mlx_dynamic.h" #include // Platform-specific dynamic loading #ifdef _WIN32 #include #define GET_SYM(handle, name) (void*)GetProcAddress((HMODULE)(handle), name) #else #include #define GET_SYM(handle, name) dlsym(handle, name) #endif // Function pointer definitions size_t (*mlx_dtype_size_ptr)(mlx_dtype dtype) = NULL; int (*mlx_array_tostring_ptr)(mlx_string* str, const mlx_array arr) = NULL; mlx_array (*mlx_array_new_ptr)(void) = NULL; int (*mlx_array_free_ptr)(mlx_array arr) = NULL; mlx_array (*mlx_array_new_bool_ptr)(bool val) = NULL; mlx_array (*mlx_array_new_int_ptr)(int val) = NULL; mlx_array (*mlx_array_new_float32_ptr)(float val) = NULL; mlx_array (*mlx_array_new_float_ptr)(float val) = NULL; mlx_array (*mlx_array_new_float64_ptr)(double val) = NULL; mlx_array (*mlx_array_new_double_ptr)(double val) = NULL; mlx_array (*mlx_array_new_complex_ptr)(float real_val, float imag_val) = NULL; mlx_array (*mlx_array_new_data_ptr)(const void* data, const int* shape, int dim, mlx_dtype dtype) = NULL; mlx_array (*mlx_array_new_data_managed_ptr)(void* data, const int* shape, int dim, mlx_dtype dtype, void (*dtor)(void*)) = NULL; mlx_array (*mlx_array_new_data_managed_payload_ptr)(void* data, const int* shape, int dim, mlx_dtype dtype, void* payload, void (*dtor)(void*)) = NULL; int (*mlx_array_set_ptr)(mlx_array* arr, const mlx_array src) = NULL; int (*mlx_array_set_bool_ptr)(mlx_array* arr, bool val) = NULL; int (*mlx_array_set_int_ptr)(mlx_array* arr, int val) = NULL; int (*mlx_array_set_float32_ptr)(mlx_array* arr, float val) = NULL; int (*mlx_array_set_float_ptr)(mlx_array* arr, float val) = NULL; int (*mlx_array_set_float64_ptr)(mlx_array* arr, double val) = NULL; int (*mlx_array_set_double_ptr)(mlx_array* arr, double val) = 
NULL; int (*mlx_array_set_complex_ptr)(mlx_array* arr, float real_val, float imag_val) = NULL; int (*mlx_array_set_data_ptr)(mlx_array* arr, const void* data, const int* shape, int dim, mlx_dtype dtype) = NULL; size_t (*mlx_array_itemsize_ptr)(const mlx_array arr) = NULL; size_t (*mlx_array_size_ptr)(const mlx_array arr) = NULL; size_t (*mlx_array_nbytes_ptr)(const mlx_array arr) = NULL; size_t (*mlx_array_ndim_ptr)(const mlx_array arr) = NULL; const int* (*mlx_array_shape_ptr)(const mlx_array arr) = NULL; const size_t* (*mlx_array_strides_ptr)(const mlx_array arr) = NULL; int (*mlx_array_dim_ptr)(const mlx_array arr, int dim) = NULL; mlx_dtype (*mlx_array_dtype_ptr)(const mlx_array arr) = NULL; int (*mlx_array_eval_ptr)(mlx_array arr) = NULL; int (*mlx_array_item_bool_ptr)(bool* res, const mlx_array arr) = NULL; int (*mlx_array_item_uint8_ptr)(uint8_t* res, const mlx_array arr) = NULL; int (*mlx_array_item_uint16_ptr)(uint16_t* res, const mlx_array arr) = NULL; int (*mlx_array_item_uint32_ptr)(uint32_t* res, const mlx_array arr) = NULL; int (*mlx_array_item_uint64_ptr)(uint64_t* res, const mlx_array arr) = NULL; int (*mlx_array_item_int8_ptr)(int8_t* res, const mlx_array arr) = NULL; int (*mlx_array_item_int16_ptr)(int16_t* res, const mlx_array arr) = NULL; int (*mlx_array_item_int32_ptr)(int32_t* res, const mlx_array arr) = NULL; int (*mlx_array_item_int64_ptr)(int64_t* res, const mlx_array arr) = NULL; int (*mlx_array_item_float32_ptr)(float* res, const mlx_array arr) = NULL; int (*mlx_array_item_float64_ptr)(double* res, const mlx_array arr) = NULL; int (*mlx_array_item_complex64_ptr)(mlx_complex64_t* res, const mlx_array arr) = NULL; #if defined(__aarch64__) || defined(_M_ARM64) int (*mlx_array_item_float16_ptr)(float16_t* res, const mlx_array arr) = NULL; #endif #if defined(__aarch64__) || defined(_M_ARM64) int (*mlx_array_item_bfloat16_ptr)(bfloat16_t* res, const mlx_array arr) = NULL; #endif const bool* (*mlx_array_data_bool_ptr)(const mlx_array arr) = 
NULL; const uint8_t* (*mlx_array_data_uint8_ptr)(const mlx_array arr) = NULL; const uint16_t* (*mlx_array_data_uint16_ptr)(const mlx_array arr) = NULL; const uint32_t* (*mlx_array_data_uint32_ptr)(const mlx_array arr) = NULL; const uint64_t* (*mlx_array_data_uint64_ptr)(const mlx_array arr) = NULL; const int8_t* (*mlx_array_data_int8_ptr)(const mlx_array arr) = NULL; const int16_t* (*mlx_array_data_int16_ptr)(const mlx_array arr) = NULL; const int32_t* (*mlx_array_data_int32_ptr)(const mlx_array arr) = NULL; const int64_t* (*mlx_array_data_int64_ptr)(const mlx_array arr) = NULL; const float* (*mlx_array_data_float32_ptr)(const mlx_array arr) = NULL; const double* (*mlx_array_data_float64_ptr)(const mlx_array arr) = NULL; const mlx_complex64_t* (*mlx_array_data_complex64_ptr)(const mlx_array arr) = NULL; #if defined(__aarch64__) || defined(_M_ARM64) const float16_t* (*mlx_array_data_float16_ptr)(const mlx_array arr) = NULL; #endif #if defined(__aarch64__) || defined(_M_ARM64) const bfloat16_t* (*mlx_array_data_bfloat16_ptr)(const mlx_array arr) = NULL; #endif int (*_mlx_array_is_available_ptr)(bool* res, const mlx_array arr) = NULL; int (*_mlx_array_wait_ptr)(const mlx_array arr) = NULL; int (*_mlx_array_is_contiguous_ptr)(bool* res, const mlx_array arr) = NULL; int (*_mlx_array_is_row_contiguous_ptr)(bool* res, const mlx_array arr) = NULL; int (*_mlx_array_is_col_contiguous_ptr)(bool* res, const mlx_array arr) = NULL; mlx_closure (*mlx_closure_new_ptr)(void) = NULL; int (*mlx_closure_free_ptr)(mlx_closure cls) = NULL; mlx_closure (*mlx_closure_new_func_ptr)(int (*fun)(mlx_vector_array*, const mlx_vector_array)) = NULL; mlx_closure (*mlx_closure_new_func_payload_ptr)(int (*fun)(mlx_vector_array*, const mlx_vector_array, void*), void* payload, void (*dtor)(void*)) = NULL; int (*mlx_closure_set_ptr)(mlx_closure* cls, const mlx_closure src) = NULL; int (*mlx_closure_apply_ptr)(mlx_vector_array* res, mlx_closure cls, const mlx_vector_array input) = NULL; mlx_closure 
(*mlx_closure_new_unary_ptr)(int (*fun)(mlx_array*, const mlx_array)) = NULL; mlx_closure_kwargs (*mlx_closure_kwargs_new_ptr)(void) = NULL; int (*mlx_closure_kwargs_free_ptr)(mlx_closure_kwargs cls) = NULL; mlx_closure_kwargs (*mlx_closure_kwargs_new_func_ptr)(int (*fun)( mlx_vector_array*, const mlx_vector_array, const mlx_map_string_to_array)) = NULL; mlx_closure_kwargs (*mlx_closure_kwargs_new_func_payload_ptr)(int (*fun)( mlx_vector_array*, const mlx_vector_array, const mlx_map_string_to_array, void*), void* payload, void (*dtor)(void*)) = NULL; int (*mlx_closure_kwargs_set_ptr)(mlx_closure_kwargs* cls, const mlx_closure_kwargs src) = NULL; int (*mlx_closure_kwargs_apply_ptr)(mlx_vector_array* res, mlx_closure_kwargs cls, const mlx_vector_array input_0, const mlx_map_string_to_array input_1) = NULL; mlx_closure_value_and_grad (*mlx_closure_value_and_grad_new_ptr)(void) = NULL; int (*mlx_closure_value_and_grad_free_ptr)(mlx_closure_value_and_grad cls) = NULL; mlx_closure_value_and_grad (*mlx_closure_value_and_grad_new_func_ptr)(int (*fun)(mlx_vector_array*, mlx_vector_array*, const mlx_vector_array)) = NULL; mlx_closure_value_and_grad (*mlx_closure_value_and_grad_new_func_payload_ptr)(int (*fun)( mlx_vector_array*, mlx_vector_array*, const mlx_vector_array, void*), void* payload, void (*dtor)(void*)) = NULL; int (*mlx_closure_value_and_grad_set_ptr)(mlx_closure_value_and_grad* cls, const mlx_closure_value_and_grad src) = NULL; int (*mlx_closure_value_and_grad_apply_ptr)(mlx_vector_array* res_0, mlx_vector_array* res_1, mlx_closure_value_and_grad cls, const mlx_vector_array input) = NULL; mlx_closure_custom (*mlx_closure_custom_new_ptr)(void) = NULL; int (*mlx_closure_custom_free_ptr)(mlx_closure_custom cls) = NULL; mlx_closure_custom (*mlx_closure_custom_new_func_ptr)(int (*fun)( mlx_vector_array*, const mlx_vector_array, const mlx_vector_array, const mlx_vector_array)) = NULL; mlx_closure_custom (*mlx_closure_custom_new_func_payload_ptr)(int (*fun)( 
mlx_vector_array*, const mlx_vector_array, const mlx_vector_array, const mlx_vector_array, void*), void* payload, void (*dtor)(void*)) = NULL; int (*mlx_closure_custom_set_ptr)(mlx_closure_custom* cls, const mlx_closure_custom src) = NULL; int (*mlx_closure_custom_apply_ptr)(mlx_vector_array* res, mlx_closure_custom cls, const mlx_vector_array input_0, const mlx_vector_array input_1, const mlx_vector_array input_2) = NULL; mlx_closure_custom_jvp (*mlx_closure_custom_jvp_new_ptr)(void) = NULL; int (*mlx_closure_custom_jvp_free_ptr)(mlx_closure_custom_jvp cls) = NULL; mlx_closure_custom_jvp (*mlx_closure_custom_jvp_new_func_ptr)(int (*fun)( mlx_vector_array*, const mlx_vector_array, const mlx_vector_array, const int*, size_t _num)) = NULL; mlx_closure_custom_jvp (*mlx_closure_custom_jvp_new_func_payload_ptr)(int (*fun)( mlx_vector_array*, const mlx_vector_array, const mlx_vector_array, const int*, size_t _num, void*), void* payload, void (*dtor)(void*)) = NULL; int (*mlx_closure_custom_jvp_set_ptr)(mlx_closure_custom_jvp* cls, const mlx_closure_custom_jvp src) = NULL; int (*mlx_closure_custom_jvp_apply_ptr)(mlx_vector_array* res, mlx_closure_custom_jvp cls, const mlx_vector_array input_0, const mlx_vector_array input_1, const int* input_2, size_t input_2_num) = NULL; mlx_closure_custom_vmap (*mlx_closure_custom_vmap_new_ptr)(void) = NULL; int (*mlx_closure_custom_vmap_free_ptr)(mlx_closure_custom_vmap cls) = NULL; mlx_closure_custom_vmap (*mlx_closure_custom_vmap_new_func_ptr)(int (*fun)( mlx_vector_array*, mlx_vector_int*, const mlx_vector_array, const int*, size_t _num)) = NULL; mlx_closure_custom_vmap (*mlx_closure_custom_vmap_new_func_payload_ptr)(int (*fun)( mlx_vector_array*, mlx_vector_int*, const mlx_vector_array, const int*, size_t _num, void*), void* payload, void (*dtor)(void*)) = NULL; int (*mlx_closure_custom_vmap_set_ptr)(mlx_closure_custom_vmap* cls, const mlx_closure_custom_vmap src) = NULL; int (*mlx_closure_custom_vmap_apply_ptr)(mlx_vector_array* 
res_0, mlx_vector_int* res_1, mlx_closure_custom_vmap cls, const mlx_vector_array input_0, const int* input_1, size_t input_1_num) = NULL; int (*mlx_compile_ptr)(mlx_closure* res, const mlx_closure fun, bool shapeless) = NULL; int (*mlx_detail_compile_ptr)(mlx_closure* res, const mlx_closure fun, uintptr_t fun_id, bool shapeless, const uint64_t* constants, size_t constants_num) = NULL; int (*mlx_detail_compile_clear_cache_ptr)(void) = NULL; int (*mlx_detail_compile_erase_ptr)(uintptr_t fun_id) = NULL; int (*mlx_disable_compile_ptr)(void) = NULL; int (*mlx_enable_compile_ptr)(void) = NULL; int (*mlx_set_compile_mode_ptr)(mlx_compile_mode mode) = NULL; int (*mlx_cuda_is_available_ptr)(bool* res) = NULL; mlx_device (*mlx_device_new_ptr)(void) = NULL; mlx_device (*mlx_device_new_type_ptr)(mlx_device_type type, int index) = NULL; int (*mlx_device_free_ptr)(mlx_device dev) = NULL; int (*mlx_device_set_ptr)(mlx_device* dev, const mlx_device src) = NULL; int (*mlx_device_tostring_ptr)(mlx_string* str, mlx_device dev) = NULL; bool (*mlx_device_equal_ptr)(mlx_device lhs, mlx_device rhs) = NULL; int (*mlx_device_get_index_ptr)(int* index, mlx_device dev) = NULL; int (*mlx_device_get_type_ptr)(mlx_device_type* type, mlx_device dev) = NULL; int (*mlx_get_default_device_ptr)(mlx_device* dev) = NULL; int (*mlx_set_default_device_ptr)(mlx_device dev) = NULL; int (*mlx_device_is_available_ptr)(bool* avail, mlx_device dev) = NULL; int (*mlx_device_count_ptr)(int* count, mlx_device_type type) = NULL; mlx_device_info (*mlx_device_info_new_ptr)(void) = NULL; int (*mlx_device_info_get_ptr)(mlx_device_info* info, mlx_device dev) = NULL; int (*mlx_device_info_free_ptr)(mlx_device_info info) = NULL; int (*mlx_device_info_has_key_ptr)(bool* exists, mlx_device_info info, const char* key) = NULL; int (*mlx_device_info_is_string_ptr)(bool* is_string, mlx_device_info info, const char* key) = NULL; int (*mlx_device_info_get_string_ptr)(const char** value, mlx_device_info info, const char* key) = 
NULL; int (*mlx_device_info_get_size_ptr)(size_t* value, mlx_device_info info, const char* key) = NULL; int (*mlx_device_info_get_keys_ptr)(mlx_vector_string* keys, mlx_device_info info) = NULL; int (*mlx_distributed_all_gather_ptr)(mlx_array* res, const mlx_array x, const mlx_distributed_group group , const mlx_stream S) = NULL; int (*mlx_distributed_all_max_ptr)(mlx_array* res, const mlx_array x, const mlx_distributed_group group , const mlx_stream s) = NULL; int (*mlx_distributed_all_min_ptr)(mlx_array* res, const mlx_array x, const mlx_distributed_group group , const mlx_stream s) = NULL; int (*mlx_distributed_all_sum_ptr)(mlx_array* res, const mlx_array x, const mlx_distributed_group group , const mlx_stream s) = NULL; int (*mlx_distributed_recv_ptr)(mlx_array* res, const int* shape, size_t shape_num, mlx_dtype dtype, int src, const mlx_distributed_group group , const mlx_stream s) = NULL; int (*mlx_distributed_recv_like_ptr)(mlx_array* res, const mlx_array x, int src, const mlx_distributed_group group , const mlx_stream s) = NULL; int (*mlx_distributed_send_ptr)(mlx_array* res, const mlx_array x, int dst, const mlx_distributed_group group , const mlx_stream s) = NULL; int (*mlx_distributed_sum_scatter_ptr)(mlx_array* res, const mlx_array x, const mlx_distributed_group group , const mlx_stream s) = NULL; int (*mlx_distributed_group_rank_ptr)(mlx_distributed_group group) = NULL; int (*mlx_distributed_group_size_ptr)(mlx_distributed_group group) = NULL; mlx_distributed_group (*mlx_distributed_group_split_ptr)(mlx_distributed_group group, int color, int key) = NULL; bool (*mlx_distributed_is_available_ptr)(const char* bk) = NULL; mlx_distributed_group (*mlx_distributed_init_ptr)(bool strict, const char* bk) = NULL; void (*mlx_set_error_handler_ptr)(mlx_error_handler_func handler, void* data, void (*dtor)(void*)) = NULL; void (*_mlx_error_ptr)(const char* file, const int line, const char* fmt, ...) 
= NULL; int (*mlx_export_function_ptr)(const char* file, const mlx_closure fun, const mlx_vector_array args, bool shapeless) = NULL; int (*mlx_export_function_kwargs_ptr)(const char* file, const mlx_closure_kwargs fun, const mlx_vector_array args, const mlx_map_string_to_array kwargs, bool shapeless) = NULL; mlx_function_exporter (*mlx_function_exporter_new_ptr)(const char* file, const mlx_closure fun, bool shapeless) = NULL; int (*mlx_function_exporter_free_ptr)(mlx_function_exporter xfunc) = NULL; int (*mlx_function_exporter_apply_ptr)(const mlx_function_exporter xfunc, const mlx_vector_array args) = NULL; int (*mlx_function_exporter_apply_kwargs_ptr)(const mlx_function_exporter xfunc, const mlx_vector_array args, const mlx_map_string_to_array kwargs) = NULL; mlx_imported_function (*mlx_imported_function_new_ptr)(const char* file) = NULL; int (*mlx_imported_function_free_ptr)(mlx_imported_function xfunc) = NULL; int (*mlx_imported_function_apply_ptr)(mlx_vector_array* res, const mlx_imported_function xfunc, const mlx_vector_array args) = NULL; int (*mlx_imported_function_apply_kwargs_ptr)(mlx_vector_array* res, const mlx_imported_function xfunc, const mlx_vector_array args, const mlx_map_string_to_array kwargs) = NULL; mlx_fast_cuda_kernel_config (*mlx_fast_cuda_kernel_config_new_ptr)(void) = NULL; void (*mlx_fast_cuda_kernel_config_free_ptr)(mlx_fast_cuda_kernel_config cls) = NULL; int (*mlx_fast_cuda_kernel_config_add_output_arg_ptr)(mlx_fast_cuda_kernel_config cls, const int* shape, size_t size, mlx_dtype dtype) = NULL; int (*mlx_fast_cuda_kernel_config_set_grid_ptr)(mlx_fast_cuda_kernel_config cls, int grid1, int grid2, int grid3) = NULL; int (*mlx_fast_cuda_kernel_config_set_thread_group_ptr)(mlx_fast_cuda_kernel_config cls, int thread1, int thread2, int thread3) = NULL; int (*mlx_fast_cuda_kernel_config_set_init_value_ptr)(mlx_fast_cuda_kernel_config cls, float value) = NULL; int (*mlx_fast_cuda_kernel_config_set_verbose_ptr)(mlx_fast_cuda_kernel_config 
cls, bool verbose) = NULL; int (*mlx_fast_cuda_kernel_config_add_template_arg_dtype_ptr)(mlx_fast_cuda_kernel_config cls, const char* name, mlx_dtype dtype) = NULL; int (*mlx_fast_cuda_kernel_config_add_template_arg_int_ptr)(mlx_fast_cuda_kernel_config cls, const char* name, int value) = NULL; int (*mlx_fast_cuda_kernel_config_add_template_arg_bool_ptr)(mlx_fast_cuda_kernel_config cls, const char* name, bool value) = NULL; mlx_fast_cuda_kernel (*mlx_fast_cuda_kernel_new_ptr)(const char* name, const mlx_vector_string input_names, const mlx_vector_string output_names, const char* source, const char* header, bool ensure_row_contiguous, int shared_memory) = NULL; void (*mlx_fast_cuda_kernel_free_ptr)(mlx_fast_cuda_kernel cls) = NULL; int (*mlx_fast_cuda_kernel_apply_ptr)(mlx_vector_array* outputs, mlx_fast_cuda_kernel cls, const mlx_vector_array inputs, const mlx_fast_cuda_kernel_config config, const mlx_stream stream) = NULL; int (*mlx_fast_layer_norm_ptr)(mlx_array* res, const mlx_array x, const mlx_array weight , const mlx_array bias , float eps, const mlx_stream s) = NULL; mlx_fast_metal_kernel_config (*mlx_fast_metal_kernel_config_new_ptr)(void) = NULL; void (*mlx_fast_metal_kernel_config_free_ptr)(mlx_fast_metal_kernel_config cls) = NULL; int (*mlx_fast_metal_kernel_config_add_output_arg_ptr)(mlx_fast_metal_kernel_config cls, const int* shape, size_t size, mlx_dtype dtype) = NULL; int (*mlx_fast_metal_kernel_config_set_grid_ptr)(mlx_fast_metal_kernel_config cls, int grid1, int grid2, int grid3) = NULL; int (*mlx_fast_metal_kernel_config_set_thread_group_ptr)(mlx_fast_metal_kernel_config cls, int thread1, int thread2, int thread3) = NULL; int (*mlx_fast_metal_kernel_config_set_init_value_ptr)(mlx_fast_metal_kernel_config cls, float value) = NULL; int (*mlx_fast_metal_kernel_config_set_verbose_ptr)(mlx_fast_metal_kernel_config cls, bool verbose) = NULL; int (*mlx_fast_metal_kernel_config_add_template_arg_dtype_ptr)(mlx_fast_metal_kernel_config cls, const char* 
name, mlx_dtype dtype) = NULL; int (*mlx_fast_metal_kernel_config_add_template_arg_int_ptr)(mlx_fast_metal_kernel_config cls, const char* name, int value) = NULL; int (*mlx_fast_metal_kernel_config_add_template_arg_bool_ptr)(mlx_fast_metal_kernel_config cls, const char* name, bool value) = NULL; mlx_fast_metal_kernel (*mlx_fast_metal_kernel_new_ptr)(const char* name, const mlx_vector_string input_names, const mlx_vector_string output_names, const char* source, const char* header, bool ensure_row_contiguous, bool atomic_outputs) = NULL; void (*mlx_fast_metal_kernel_free_ptr)(mlx_fast_metal_kernel cls) = NULL; int (*mlx_fast_metal_kernel_apply_ptr)(mlx_vector_array* outputs, mlx_fast_metal_kernel cls, const mlx_vector_array inputs, const mlx_fast_metal_kernel_config config, const mlx_stream stream) = NULL; int (*mlx_fast_rms_norm_ptr)(mlx_array* res, const mlx_array x, const mlx_array weight , float eps, const mlx_stream s) = NULL; int (*mlx_fast_rope_ptr)(mlx_array* res, const mlx_array x, int dims, bool traditional, mlx_optional_float base, float scale, int offset, const mlx_array freqs , const mlx_stream s) = NULL; int (*mlx_fast_rope_dynamic_ptr)(mlx_array* res, const mlx_array x, int dims, bool traditional, mlx_optional_float base, float scale, const mlx_array offset, const mlx_array freqs , const mlx_stream s) = NULL; int (*mlx_fast_scaled_dot_product_attention_ptr)(mlx_array* res, const mlx_array queries, const mlx_array keys, const mlx_array values, float scale, const char* mask_mode, const mlx_array mask_arr , const mlx_array sinks , const mlx_stream s) = NULL; int (*mlx_fft_fft_ptr)(mlx_array* res, const mlx_array a, int n, int axis, const mlx_stream s) = NULL; int (*mlx_fft_fft2_ptr)(mlx_array* res, const mlx_array a, const int* n, size_t n_num, const int* axes, size_t axes_num, const mlx_stream s) = NULL; int (*mlx_fft_fftn_ptr)(mlx_array* res, const mlx_array a, const int* n, size_t n_num, const int* axes, size_t axes_num, const mlx_stream s) = NULL; int 
(*mlx_fft_fftshift_ptr)(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, const mlx_stream s) = NULL; int (*mlx_fft_ifft_ptr)(mlx_array* res, const mlx_array a, int n, int axis, const mlx_stream s) = NULL; int (*mlx_fft_ifft2_ptr)(mlx_array* res, const mlx_array a, const int* n, size_t n_num, const int* axes, size_t axes_num, const mlx_stream s) = NULL; int (*mlx_fft_ifftn_ptr)(mlx_array* res, const mlx_array a, const int* n, size_t n_num, const int* axes, size_t axes_num, const mlx_stream s) = NULL; int (*mlx_fft_ifftshift_ptr)(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, const mlx_stream s) = NULL; int (*mlx_fft_irfft_ptr)(mlx_array* res, const mlx_array a, int n, int axis, const mlx_stream s) = NULL; int (*mlx_fft_irfft2_ptr)(mlx_array* res, const mlx_array a, const int* n, size_t n_num, const int* axes, size_t axes_num, const mlx_stream s) = NULL; int (*mlx_fft_irfftn_ptr)(mlx_array* res, const mlx_array a, const int* n, size_t n_num, const int* axes, size_t axes_num, const mlx_stream s) = NULL; int (*mlx_fft_rfft_ptr)(mlx_array* res, const mlx_array a, int n, int axis, const mlx_stream s) = NULL; int (*mlx_fft_rfft2_ptr)(mlx_array* res, const mlx_array a, const int* n, size_t n_num, const int* axes, size_t axes_num, const mlx_stream s) = NULL; int (*mlx_fft_rfftn_ptr)(mlx_array* res, const mlx_array a, const int* n, size_t n_num, const int* axes, size_t axes_num, const mlx_stream s) = NULL; int (*mlx_load_reader_ptr)(mlx_array* res, mlx_io_reader in_stream, const mlx_stream s) = NULL; int (*mlx_load_ptr)(mlx_array* res, const char* file, const mlx_stream s) = NULL; int (*mlx_load_safetensors_reader_ptr)(mlx_map_string_to_array* res_0, mlx_map_string_to_string* res_1, mlx_io_reader in_stream, const mlx_stream s) = NULL; int (*mlx_load_safetensors_ptr)(mlx_map_string_to_array* res_0, mlx_map_string_to_string* res_1, const char* file, const mlx_stream s) = NULL; int (*mlx_save_writer_ptr)(mlx_io_writer out_stream, const 
mlx_array a) = NULL; int (*mlx_save_ptr)(const char* file, const mlx_array a) = NULL; int (*mlx_save_safetensors_writer_ptr)(mlx_io_writer in_stream, const mlx_map_string_to_array param, const mlx_map_string_to_string metadata) = NULL; int (*mlx_save_safetensors_ptr)(const char* file, const mlx_map_string_to_array param, const mlx_map_string_to_string metadata) = NULL; mlx_io_reader (*mlx_io_reader_new_ptr)(void* desc, mlx_io_vtable vtable) = NULL; int (*mlx_io_reader_descriptor_ptr)(void** desc_, mlx_io_reader io) = NULL; int (*mlx_io_reader_tostring_ptr)(mlx_string* str_, mlx_io_reader io) = NULL; int (*mlx_io_reader_free_ptr)(mlx_io_reader io) = NULL; mlx_io_writer (*mlx_io_writer_new_ptr)(void* desc, mlx_io_vtable vtable) = NULL; int (*mlx_io_writer_descriptor_ptr)(void** desc_, mlx_io_writer io) = NULL; int (*mlx_io_writer_tostring_ptr)(mlx_string* str_, mlx_io_writer io) = NULL; int (*mlx_io_writer_free_ptr)(mlx_io_writer io) = NULL; int (*mlx_linalg_cholesky_ptr)(mlx_array* res, const mlx_array a, bool upper, const mlx_stream s) = NULL; int (*mlx_linalg_cholesky_inv_ptr)(mlx_array* res, const mlx_array a, bool upper, const mlx_stream s) = NULL; int (*mlx_linalg_cross_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, int axis, const mlx_stream s) = NULL; int (*mlx_linalg_eig_ptr)(mlx_array* res_0, mlx_array* res_1, const mlx_array a, const mlx_stream s) = NULL; int (*mlx_linalg_eigh_ptr)(mlx_array* res_0, mlx_array* res_1, const mlx_array a, const char* UPLO, const mlx_stream s) = NULL; int (*mlx_linalg_eigvals_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; int (*mlx_linalg_eigvalsh_ptr)(mlx_array* res, const mlx_array a, const char* UPLO, const mlx_stream s) = NULL; int (*mlx_linalg_inv_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; int (*mlx_linalg_lu_ptr)(mlx_vector_array* res, const mlx_array a, const mlx_stream s) = NULL; int (*mlx_linalg_lu_factor_ptr)(mlx_array* res_0, mlx_array* res_1, const mlx_array 
a, const mlx_stream s) = NULL; int (*mlx_linalg_norm_ptr)(mlx_array* res, const mlx_array a, double ord, const int* axis , size_t axis_num, bool keepdims, const mlx_stream s) = NULL; int (*mlx_linalg_norm_matrix_ptr)(mlx_array* res, const mlx_array a, const char* ord, const int* axis , size_t axis_num, bool keepdims, const mlx_stream s) = NULL; int (*mlx_linalg_norm_l2_ptr)(mlx_array* res, const mlx_array a, const int* axis , size_t axis_num, bool keepdims, const mlx_stream s) = NULL; int (*mlx_linalg_pinv_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; int (*mlx_linalg_qr_ptr)(mlx_array* res_0, mlx_array* res_1, const mlx_array a, const mlx_stream s) = NULL; int (*mlx_linalg_solve_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL; int (*mlx_linalg_solve_triangular_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, bool upper, const mlx_stream s) = NULL; int (*mlx_linalg_svd_ptr)(mlx_vector_array* res, const mlx_array a, bool compute_uv, const mlx_stream s) = NULL; int (*mlx_linalg_tri_inv_ptr)(mlx_array* res, const mlx_array a, bool upper, const mlx_stream s) = NULL; mlx_map_string_to_array (*mlx_map_string_to_array_new_ptr)(void) = NULL; int (*mlx_map_string_to_array_set_ptr)(mlx_map_string_to_array* map, const mlx_map_string_to_array src) = NULL; int (*mlx_map_string_to_array_free_ptr)(mlx_map_string_to_array map) = NULL; int (*mlx_map_string_to_array_insert_ptr)(mlx_map_string_to_array map, const char* key, const mlx_array value) = NULL; int (*mlx_map_string_to_array_get_ptr)(mlx_array* value, const mlx_map_string_to_array map, const char* key) = NULL; mlx_map_string_to_array_iterator (*mlx_map_string_to_array_iterator_new_ptr)(mlx_map_string_to_array map) = NULL; int (*mlx_map_string_to_array_iterator_free_ptr)(mlx_map_string_to_array_iterator it) = NULL; int (*mlx_map_string_to_array_iterator_next_ptr)(const char** key, mlx_array* value, mlx_map_string_to_array_iterator it) = NULL; 
// --- map_string_to_string ---
// Like every *_ptr in this file, each pointer below starts out NULL and must be
// resolved against the loaded library before it is called.
mlx_map_string_to_string (*mlx_map_string_to_string_new_ptr)(void) = NULL;
int (*mlx_map_string_to_string_set_ptr)(mlx_map_string_to_string* map, const mlx_map_string_to_string src) = NULL;
int (*mlx_map_string_to_string_free_ptr)(mlx_map_string_to_string map) = NULL;
int (*mlx_map_string_to_string_insert_ptr)(mlx_map_string_to_string map, const char* key, const char* value) = NULL;
int (*mlx_map_string_to_string_get_ptr)(const char** value, const mlx_map_string_to_string map, const char* key) = NULL;
mlx_map_string_to_string_iterator (*mlx_map_string_to_string_iterator_new_ptr)(mlx_map_string_to_string map) = NULL;
int (*mlx_map_string_to_string_iterator_free_ptr)(mlx_map_string_to_string_iterator it) = NULL;
int (*mlx_map_string_to_string_iterator_next_ptr)(const char** key, const char** value, mlx_map_string_to_string_iterator it) = NULL;

// --- memory / cache queries and limits ---
// NOTE(review): the `size_t* res` out-parameter of the set_*_limit entries
// presumably receives the previous limit -- confirm against the mlx-c docs.
int (*mlx_clear_cache_ptr)(void) = NULL;
int (*mlx_get_active_memory_ptr)(size_t* res) = NULL;
int (*mlx_get_cache_memory_ptr)(size_t* res) = NULL;
int (*mlx_get_memory_limit_ptr)(size_t* res) = NULL;
int (*mlx_get_peak_memory_ptr)(size_t* res) = NULL;
int (*mlx_reset_peak_memory_ptr)(void) = NULL;
int (*mlx_set_cache_limit_ptr)(size_t* res, size_t limit) = NULL;
int (*mlx_set_memory_limit_ptr)(size_t* res, size_t limit) = NULL;
int (*mlx_set_wired_limit_ptr)(size_t* res, size_t limit) = NULL;

// --- metal ---
int (*mlx_metal_is_available_ptr)(bool* res) = NULL;
int (*mlx_metal_start_capture_ptr)(const char* path) = NULL;
int (*mlx_metal_stop_capture_ptr)(void) = NULL;

// --- array operations ---
// Most entries write their result through a leading `res` out-parameter and
// take a trailing `const mlx_stream s`; the declared int return is presumably
// an mlx-c status code (confirm against mlx-c).
int (*mlx_abs_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
int (*mlx_add_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL;
int (*mlx_addmm_ptr)(mlx_array* res, const mlx_array c, const mlx_array a, const mlx_array b, float alpha, float beta, const mlx_stream s) = NULL;
int (*mlx_all_axes_ptr)(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, bool keepdims, const mlx_stream s) = NULL;
int (*mlx_all_axis_ptr)(mlx_array* res, const mlx_array a, int axis, bool keepdims, const mlx_stream s) = NULL;
int (*mlx_all_ptr)(mlx_array* res, const mlx_array a, bool keepdims, const mlx_stream s) = NULL;
int (*mlx_allclose_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, double rtol, double atol, bool equal_nan, const mlx_stream s) = NULL;
int (*mlx_any_axes_ptr)(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, bool keepdims, const mlx_stream s) = NULL;
int (*mlx_any_axis_ptr)(mlx_array* res, const mlx_array a, int axis, bool keepdims, const mlx_stream s) = NULL;
int (*mlx_any_ptr)(mlx_array* res, const mlx_array a, bool keepdims, const mlx_stream s) = NULL;
int (*mlx_arange_ptr)(mlx_array* res, double start, double stop, double step, mlx_dtype dtype, const mlx_stream s) = NULL;
int (*mlx_arccos_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
int (*mlx_arccosh_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
int (*mlx_arcsin_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
int (*mlx_arcsinh_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
int (*mlx_arctan_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
int (*mlx_arctan2_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL;
int (*mlx_arctanh_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
int (*mlx_argmax_axis_ptr)(mlx_array* res, const mlx_array a, int axis, bool keepdims, const mlx_stream s) = NULL;
int (*mlx_argmax_ptr)(mlx_array* res, const mlx_array a, bool keepdims, const mlx_stream s) = NULL;
int (*mlx_argmin_axis_ptr)(mlx_array* res, const mlx_array a, int axis, bool keepdims, const mlx_stream s) = NULL;
int (*mlx_argmin_ptr)(mlx_array* res, const mlx_array a, bool keepdims, const mlx_stream s) = NULL;
int (*mlx_argpartition_axis_ptr)(mlx_array* res, const mlx_array a, int kth, int axis, const mlx_stream s) = NULL;
int (*mlx_argpartition_ptr)(mlx_array* res, const mlx_array a, int kth, const mlx_stream s) = NULL;
int (*mlx_argsort_axis_ptr)(mlx_array* res, const mlx_array a, int axis, const mlx_stream s) = NULL;
int (*mlx_argsort_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
int (*mlx_array_equal_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, bool equal_nan, const mlx_stream s) = NULL;
int (*mlx_as_strided_ptr)(mlx_array* res, const mlx_array a, const int* shape, size_t shape_num, const int64_t* strides, size_t strides_num, size_t offset, const mlx_stream s) = NULL;
int (*mlx_astype_ptr)(mlx_array* res, const mlx_array a, mlx_dtype dtype, const mlx_stream s) = NULL;
int (*mlx_atleast_1d_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
int (*mlx_atleast_2d_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
int (*mlx_atleast_3d_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
int (*mlx_bartlett_ptr)(mlx_array* res, int M, const mlx_stream s) = NULL;
int (*mlx_bitwise_and_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL;
int (*mlx_bitwise_invert_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
int (*mlx_bitwise_or_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL;
int (*mlx_bitwise_xor_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL;
int (*mlx_blackman_ptr)(mlx_array* res, int M, const mlx_stream s) = NULL;
int (*mlx_block_masked_mm_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, int block_size, const mlx_array mask_out , const mlx_array mask_lhs , const mlx_array mask_rhs , const mlx_stream s) = NULL;
int (*mlx_broadcast_arrays_ptr)(mlx_vector_array* res, const mlx_vector_array inputs, const mlx_stream s) = NULL;
int (*mlx_broadcast_to_ptr)(mlx_array* res, const mlx_array a, const int* shape, size_t shape_num, const mlx_stream s) = NULL;
int (*mlx_ceil_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
int (*mlx_clip_ptr)(mlx_array* res, const mlx_array a, const mlx_array a_min , const mlx_array a_max , const mlx_stream s) = NULL;
int (*mlx_concatenate_axis_ptr)(mlx_array* res, const mlx_vector_array arrays, int axis, const mlx_stream s) = NULL;
int (*mlx_concatenate_ptr)(mlx_array* res, const mlx_vector_array arrays, const mlx_stream s) = NULL;
int (*mlx_conjugate_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
int (*mlx_contiguous_ptr)(mlx_array* res, const mlx_array a, bool allow_col_major, const mlx_stream s) = NULL;
int (*mlx_conv1d_ptr)(mlx_array* res, const mlx_array input, const mlx_array weight, int stride, int padding, int dilation, int groups, const mlx_stream s) = NULL;
int (*mlx_conv2d_ptr)(mlx_array* res, const mlx_array input, const mlx_array weight, int stride_0, int stride_1, int padding_0, int padding_1, int dilation_0, int dilation_1, int groups, const mlx_stream s) = NULL;
int (*mlx_conv3d_ptr)(mlx_array* res, const mlx_array input, const mlx_array weight, int stride_0, int stride_1, int stride_2, int padding_0, int padding_1, int padding_2, int dilation_0, int dilation_1, int dilation_2, int groups, const mlx_stream s) = NULL;
int (*mlx_conv_general_ptr)(mlx_array* res, const mlx_array input, const mlx_array weight, const int* stride, size_t stride_num, const int* padding_lo, size_t padding_lo_num, const int* padding_hi, size_t padding_hi_num, const int* kernel_dilation, size_t kernel_dilation_num, const int* input_dilation, size_t input_dilation_num, int groups, bool flip, const mlx_stream s) = NULL;
int (*mlx_conv_transpose1d_ptr)(mlx_array* res, const mlx_array input, const mlx_array weight, int stride, int padding, int dilation, int output_padding, int groups, const mlx_stream s) = NULL;
int (*mlx_conv_transpose2d_ptr)(mlx_array* res, const mlx_array input, const mlx_array weight, int stride_0, int stride_1, int padding_0, int padding_1, int dilation_0, int dilation_1, int output_padding_0, int output_padding_1, int groups, const mlx_stream s) = NULL;
int (*mlx_conv_transpose3d_ptr)(mlx_array* res, const mlx_array input, const mlx_array weight, int stride_0, int stride_1, int stride_2, int padding_0, int padding_1, int padding_2, int dilation_0, int dilation_1, int dilation_2, int output_padding_0, int output_padding_1, int output_padding_2, int groups, const mlx_stream s) = NULL;
int (*mlx_copy_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
int (*mlx_cos_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
int (*mlx_cosh_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
int (*mlx_cummax_ptr)(mlx_array* res, const mlx_array a, int axis, bool reverse, bool inclusive, const mlx_stream s) = NULL;
int (*mlx_cummin_ptr)(mlx_array* res, const mlx_array a, int axis, bool reverse, bool inclusive, const mlx_stream s) = NULL;
int (*mlx_cumprod_ptr)(mlx_array* res, const mlx_array a, int axis, bool reverse, bool inclusive, const mlx_stream s) = NULL;
int (*mlx_cumsum_ptr)(mlx_array* res, const mlx_array a, int axis, bool reverse, bool inclusive, const mlx_stream s) = NULL;
int (*mlx_degrees_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
// Unlike its neighbours, mlx_depends takes no stream argument.
int (*mlx_depends_ptr)(mlx_vector_array* res, const mlx_vector_array inputs, const mlx_vector_array dependencies) = NULL;
int (*mlx_dequantize_ptr)(mlx_array* res, const mlx_array w, const mlx_array scales, const mlx_array biases , mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_array global_scale , mlx_optional_dtype dtype, const mlx_stream s) = NULL;
int (*mlx_diag_ptr)(mlx_array* res, const mlx_array a, int k, const mlx_stream s) = NULL;
int (*mlx_diagonal_ptr)(mlx_array* res, const mlx_array a, int offset, int axis1, int axis2, const mlx_stream s) = NULL;
int (*mlx_divide_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL;
int (*mlx_divmod_ptr)(mlx_vector_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL;
int (*mlx_einsum_ptr)(mlx_array* res, const char* subscripts, const mlx_vector_array operands, const mlx_stream s) = NULL;
int (*mlx_equal_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL;
int (*mlx_erf_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
int (*mlx_erfinv_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
int (*mlx_exp_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
int (*mlx_expand_dims_axes_ptr)(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, const mlx_stream s) = NULL;
int (*mlx_expand_dims_ptr)(mlx_array* res, const mlx_array a, int axis, const mlx_stream s) = NULL;
int (*mlx_expm1_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
int (*mlx_eye_ptr)(mlx_array* res, int n, int m, int k, mlx_dtype dtype, const mlx_stream s) = NULL;
int (*mlx_flatten_ptr)(mlx_array* res, const mlx_array a, int start_axis, int end_axis, const mlx_stream s) = NULL;
int (*mlx_floor_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
int (*mlx_floor_divide_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL;
int (*mlx_from_fp8_ptr)(mlx_array* res, const mlx_array x, mlx_dtype dtype, const mlx_stream s) = NULL;
int (*mlx_full_ptr)(mlx_array* res, const int* shape, size_t shape_num, const mlx_array vals, mlx_dtype dtype, const mlx_stream s) = NULL;
int (*mlx_full_like_ptr)(mlx_array* res, const mlx_array a, const mlx_array vals, mlx_dtype dtype, const mlx_stream s) = NULL;
int (*mlx_gather_ptr)(mlx_array* res, const mlx_array a, const mlx_vector_array indices, const int* axes, size_t axes_num, const int* slice_sizes, size_t slice_sizes_num, const mlx_stream s) = NULL;
int (*mlx_gather_single_ptr)(mlx_array* res, const mlx_array a, const mlx_array indices, int axis, const int* slice_sizes,
size_t slice_sizes_num, const mlx_stream s) = NULL; int (*mlx_gather_mm_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_array lhs_indices , const mlx_array rhs_indices , bool sorted_indices, const mlx_stream s) = NULL; int (*mlx_gather_qmm_ptr)(mlx_array* res, const mlx_array x, const mlx_array w, const mlx_array scales, const mlx_array biases , const mlx_array lhs_indices , const mlx_array rhs_indices , bool transpose, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, bool sorted_indices, const mlx_stream s) = NULL; int (*mlx_greater_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL; int (*mlx_greater_equal_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL; int (*mlx_hadamard_transform_ptr)(mlx_array* res, const mlx_array a, mlx_optional_float scale, const mlx_stream s) = NULL; int (*mlx_hamming_ptr)(mlx_array* res, int M, const mlx_stream s) = NULL; int (*mlx_hanning_ptr)(mlx_array* res, int M, const mlx_stream s) = NULL; int (*mlx_identity_ptr)(mlx_array* res, int n, mlx_dtype dtype, const mlx_stream s) = NULL; int (*mlx_imag_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; int (*mlx_inner_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL; int (*mlx_isclose_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, double rtol, double atol, bool equal_nan, const mlx_stream s) = NULL; int (*mlx_isfinite_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; int (*mlx_isinf_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; int (*mlx_isnan_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; int (*mlx_isneginf_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; int (*mlx_isposinf_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; int (*mlx_kron_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) 
= NULL; int (*mlx_left_shift_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL; int (*mlx_less_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL; int (*mlx_less_equal_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL; int (*mlx_linspace_ptr)(mlx_array* res, double start, double stop, int num, mlx_dtype dtype, const mlx_stream s) = NULL; int (*mlx_log_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; int (*mlx_log10_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; int (*mlx_log1p_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; int (*mlx_log2_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; int (*mlx_logaddexp_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL; int (*mlx_logcumsumexp_ptr)(mlx_array* res, const mlx_array a, int axis, bool reverse, bool inclusive, const mlx_stream s) = NULL; int (*mlx_logical_and_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL; int (*mlx_logical_not_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; int (*mlx_logical_or_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL; int (*mlx_logsumexp_axes_ptr)(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, bool keepdims, const mlx_stream s) = NULL; int (*mlx_logsumexp_axis_ptr)(mlx_array* res, const mlx_array a, int axis, bool keepdims, const mlx_stream s) = NULL; int (*mlx_logsumexp_ptr)(mlx_array* res, const mlx_array a, bool keepdims, const mlx_stream s) = NULL; int (*mlx_masked_scatter_ptr)(mlx_array* res, const mlx_array a, const mlx_array mask, const mlx_array src, const mlx_stream s) = NULL; int (*mlx_matmul_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL; int (*mlx_max_axes_ptr)(mlx_array* res, const mlx_array a, const int* 
axes, size_t axes_num, bool keepdims, const mlx_stream s) = NULL; int (*mlx_max_axis_ptr)(mlx_array* res, const mlx_array a, int axis, bool keepdims, const mlx_stream s) = NULL; int (*mlx_max_ptr)(mlx_array* res, const mlx_array a, bool keepdims, const mlx_stream s) = NULL; int (*mlx_maximum_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL; int (*mlx_mean_axes_ptr)(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, bool keepdims, const mlx_stream s) = NULL; int (*mlx_mean_axis_ptr)(mlx_array* res, const mlx_array a, int axis, bool keepdims, const mlx_stream s) = NULL; int (*mlx_mean_ptr)(mlx_array* res, const mlx_array a, bool keepdims, const mlx_stream s) = NULL; int (*mlx_median_ptr)(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, bool keepdims, const mlx_stream s) = NULL; int (*mlx_meshgrid_ptr)(mlx_vector_array* res, const mlx_vector_array arrays, bool sparse, const char* indexing, const mlx_stream s) = NULL; int (*mlx_min_axes_ptr)(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, bool keepdims, const mlx_stream s) = NULL; int (*mlx_min_axis_ptr)(mlx_array* res, const mlx_array a, int axis, bool keepdims, const mlx_stream s) = NULL; int (*mlx_min_ptr)(mlx_array* res, const mlx_array a, bool keepdims, const mlx_stream s) = NULL; int (*mlx_minimum_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL; int (*mlx_moveaxis_ptr)(mlx_array* res, const mlx_array a, int source, int destination, const mlx_stream s) = NULL; int (*mlx_multiply_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL; int (*mlx_nan_to_num_ptr)(mlx_array* res, const mlx_array a, float nan, mlx_optional_float posinf, mlx_optional_float neginf, const mlx_stream s) = NULL; int (*mlx_negative_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; int (*mlx_not_equal_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, 
const mlx_stream s) = NULL; int (*mlx_number_of_elements_ptr)(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, bool inverted, mlx_dtype dtype, const mlx_stream s) = NULL; int (*mlx_ones_ptr)(mlx_array* res, const int* shape, size_t shape_num, mlx_dtype dtype, const mlx_stream s) = NULL; int (*mlx_ones_like_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; int (*mlx_outer_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL; int (*mlx_pad_ptr)(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, const int* low_pad_size, size_t low_pad_size_num, const int* high_pad_size, size_t high_pad_size_num, const mlx_array pad_value, const char* mode, const mlx_stream s) = NULL; int (*mlx_pad_symmetric_ptr)(mlx_array* res, const mlx_array a, int pad_width, const mlx_array pad_value, const char* mode, const mlx_stream s) = NULL; int (*mlx_partition_axis_ptr)(mlx_array* res, const mlx_array a, int kth, int axis, const mlx_stream s) = NULL; int (*mlx_partition_ptr)(mlx_array* res, const mlx_array a, int kth, const mlx_stream s) = NULL; int (*mlx_power_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL; int (*mlx_prod_axes_ptr)(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, bool keepdims, const mlx_stream s) = NULL; int (*mlx_prod_axis_ptr)(mlx_array* res, const mlx_array a, int axis, bool keepdims, const mlx_stream s) = NULL; int (*mlx_prod_ptr)(mlx_array* res, const mlx_array a, bool keepdims, const mlx_stream s) = NULL; int (*mlx_put_along_axis_ptr)(mlx_array* res, const mlx_array a, const mlx_array indices, const mlx_array values, int axis, const mlx_stream s) = NULL; int (*mlx_qqmm_ptr)(mlx_array* res, const mlx_array x, const mlx_array w, const mlx_array w_scales , mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_array global_scale_x , const mlx_array global_scale_w , const mlx_stream s) = NULL; int 
(*mlx_quantize_ptr)(mlx_vector_array* res, const mlx_array w, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_array global_scale , const mlx_stream s) = NULL; int (*mlx_quantized_matmul_ptr)(mlx_array* res, const mlx_array x, const mlx_array w, const mlx_array scales, const mlx_array biases , bool transpose, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_stream s) = NULL; int (*mlx_radians_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; int (*mlx_real_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; int (*mlx_reciprocal_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; int (*mlx_remainder_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL; int (*mlx_repeat_axis_ptr)(mlx_array* res, const mlx_array arr, int repeats, int axis, const mlx_stream s) = NULL; int (*mlx_repeat_ptr)(mlx_array* res, const mlx_array arr, int repeats, const mlx_stream s) = NULL; int (*mlx_reshape_ptr)(mlx_array* res, const mlx_array a, const int* shape, size_t shape_num, const mlx_stream s) = NULL; int (*mlx_right_shift_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL; int (*mlx_roll_axis_ptr)(mlx_array* res, const mlx_array a, const int* shift, size_t shift_num, int axis, const mlx_stream s) = NULL; int (*mlx_roll_axes_ptr)(mlx_array* res, const mlx_array a, const int* shift, size_t shift_num, const int* axes, size_t axes_num, const mlx_stream s) = NULL; int (*mlx_roll_ptr)(mlx_array* res, const mlx_array a, const int* shift, size_t shift_num, const mlx_stream s) = NULL; int (*mlx_round_ptr)(mlx_array* res, const mlx_array a, int decimals, const mlx_stream s) = NULL; int (*mlx_rsqrt_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; int (*mlx_scatter_ptr)(mlx_array* res, const mlx_array a, const mlx_vector_array indices, const mlx_array updates, const int* axes, size_t axes_num, const 
mlx_stream s) = NULL; int (*mlx_scatter_single_ptr)(mlx_array* res, const mlx_array a, const mlx_array indices, const mlx_array updates, int axis, const mlx_stream s) = NULL; int (*mlx_scatter_add_ptr)(mlx_array* res, const mlx_array a, const mlx_vector_array indices, const mlx_array updates, const int* axes, size_t axes_num, const mlx_stream s) = NULL; int (*mlx_scatter_add_single_ptr)(mlx_array* res, const mlx_array a, const mlx_array indices, const mlx_array updates, int axis, const mlx_stream s) = NULL; int (*mlx_scatter_add_axis_ptr)(mlx_array* res, const mlx_array a, const mlx_array indices, const mlx_array values, int axis, const mlx_stream s) = NULL; int (*mlx_scatter_max_ptr)(mlx_array* res, const mlx_array a, const mlx_vector_array indices, const mlx_array updates, const int* axes, size_t axes_num, const mlx_stream s) = NULL; int (*mlx_scatter_max_single_ptr)(mlx_array* res, const mlx_array a, const mlx_array indices, const mlx_array updates, int axis, const mlx_stream s) = NULL; int (*mlx_scatter_min_ptr)(mlx_array* res, const mlx_array a, const mlx_vector_array indices, const mlx_array updates, const int* axes, size_t axes_num, const mlx_stream s) = NULL; int (*mlx_scatter_min_single_ptr)(mlx_array* res, const mlx_array a, const mlx_array indices, const mlx_array updates, int axis, const mlx_stream s) = NULL; int (*mlx_scatter_prod_ptr)(mlx_array* res, const mlx_array a, const mlx_vector_array indices, const mlx_array updates, const int* axes, size_t axes_num, const mlx_stream s) = NULL; int (*mlx_scatter_prod_single_ptr)(mlx_array* res, const mlx_array a, const mlx_array indices, const mlx_array updates, int axis, const mlx_stream s) = NULL; int (*mlx_segmented_mm_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_array segments, const mlx_stream s) = NULL; int (*mlx_sigmoid_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; int (*mlx_sign_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; int 
(*mlx_sin_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; int (*mlx_sinh_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; int (*mlx_slice_ptr)(mlx_array* res, const mlx_array a, const int* start, size_t start_num, const int* stop, size_t stop_num, const int* strides, size_t strides_num, const mlx_stream s) = NULL; int (*mlx_slice_dynamic_ptr)(mlx_array* res, const mlx_array a, const mlx_array start, const int* axes, size_t axes_num, const int* slice_size, size_t slice_size_num, const mlx_stream s) = NULL; int (*mlx_slice_update_ptr)(mlx_array* res, const mlx_array src, const mlx_array update, const int* start, size_t start_num, const int* stop, size_t stop_num, const int* strides, size_t strides_num, const mlx_stream s) = NULL; int (*mlx_slice_update_dynamic_ptr)(mlx_array* res, const mlx_array src, const mlx_array update, const mlx_array start, const int* axes, size_t axes_num, const mlx_stream s) = NULL; int (*mlx_softmax_axes_ptr)(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, bool precise, const mlx_stream s) = NULL; int (*mlx_softmax_axis_ptr)(mlx_array* res, const mlx_array a, int axis, bool precise, const mlx_stream s) = NULL; int (*mlx_softmax_ptr)(mlx_array* res, const mlx_array a, bool precise, const mlx_stream s) = NULL; int (*mlx_sort_axis_ptr)(mlx_array* res, const mlx_array a, int axis, const mlx_stream s) = NULL; int (*mlx_sort_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; int (*mlx_split_ptr)(mlx_vector_array* res, const mlx_array a, int num_splits, int axis, const mlx_stream s) = NULL; int (*mlx_split_sections_ptr)(mlx_vector_array* res, const mlx_array a, const int* indices, size_t indices_num, int axis, const mlx_stream s) = NULL; int (*mlx_sqrt_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; int (*mlx_square_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; int (*mlx_squeeze_axes_ptr)(mlx_array* res, const mlx_array a, const 
int* axes, size_t axes_num, const mlx_stream s) = NULL; int (*mlx_squeeze_axis_ptr)(mlx_array* res, const mlx_array a, int axis, const mlx_stream s) = NULL; int (*mlx_squeeze_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; int (*mlx_stack_axis_ptr)(mlx_array* res, const mlx_vector_array arrays, int axis, const mlx_stream s) = NULL; int (*mlx_stack_ptr)(mlx_array* res, const mlx_vector_array arrays, const mlx_stream s) = NULL; int (*mlx_std_axes_ptr)(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, bool keepdims, int ddof, const mlx_stream s) = NULL; int (*mlx_std_axis_ptr)(mlx_array* res, const mlx_array a, int axis, bool keepdims, int ddof, const mlx_stream s) = NULL; int (*mlx_std_ptr)(mlx_array* res, const mlx_array a, bool keepdims, int ddof, const mlx_stream s) = NULL; int (*mlx_stop_gradient_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; int (*mlx_subtract_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL; int (*mlx_sum_axes_ptr)(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, bool keepdims, const mlx_stream s) = NULL; int (*mlx_sum_axis_ptr)(mlx_array* res, const mlx_array a, int axis, bool keepdims, const mlx_stream s) = NULL; int (*mlx_sum_ptr)(mlx_array* res, const mlx_array a, bool keepdims, const mlx_stream s) = NULL; int (*mlx_swapaxes_ptr)(mlx_array* res, const mlx_array a, int axis1, int axis2, const mlx_stream s) = NULL; int (*mlx_take_axis_ptr)(mlx_array* res, const mlx_array a, const mlx_array indices, int axis, const mlx_stream s) = NULL; int (*mlx_take_ptr)(mlx_array* res, const mlx_array a, const mlx_array indices, const mlx_stream s) = NULL; int (*mlx_take_along_axis_ptr)(mlx_array* res, const mlx_array a, const mlx_array indices, int axis, const mlx_stream s) = NULL; int (*mlx_tan_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; int (*mlx_tanh_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = 
NULL; int (*mlx_tensordot_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const int* axes_a, size_t axes_a_num, const int* axes_b, size_t axes_b_num, const mlx_stream s) = NULL; int (*mlx_tensordot_axis_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, int axis, const mlx_stream s) = NULL; int (*mlx_tile_ptr)(mlx_array* res, const mlx_array arr, const int* reps, size_t reps_num, const mlx_stream s) = NULL; int (*mlx_to_fp8_ptr)(mlx_array* res, const mlx_array x, const mlx_stream s) = NULL; int (*mlx_topk_axis_ptr)(mlx_array* res, const mlx_array a, int k, int axis, const mlx_stream s) = NULL; int (*mlx_topk_ptr)(mlx_array* res, const mlx_array a, int k, const mlx_stream s) = NULL; int (*mlx_trace_ptr)(mlx_array* res, const mlx_array a, int offset, int axis1, int axis2, mlx_dtype dtype, const mlx_stream s) = NULL; int (*mlx_transpose_axes_ptr)(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, const mlx_stream s) = NULL; int (*mlx_transpose_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; int (*mlx_tri_ptr)(mlx_array* res, int n, int m, int k, mlx_dtype type, const mlx_stream s) = NULL; int (*mlx_tril_ptr)(mlx_array* res, const mlx_array x, int k, const mlx_stream s) = NULL; int (*mlx_triu_ptr)(mlx_array* res, const mlx_array x, int k, const mlx_stream s) = NULL; int (*mlx_unflatten_ptr)(mlx_array* res, const mlx_array a, int axis, const int* shape, size_t shape_num, const mlx_stream s) = NULL; int (*mlx_var_axes_ptr)(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, bool keepdims, int ddof, const mlx_stream s) = NULL; int (*mlx_var_axis_ptr)(mlx_array* res, const mlx_array a, int axis, bool keepdims, int ddof, const mlx_stream s) = NULL; int (*mlx_var_ptr)(mlx_array* res, const mlx_array a, bool keepdims, int ddof, const mlx_stream s) = NULL; int (*mlx_view_ptr)(mlx_array* res, const mlx_array a, mlx_dtype dtype, const mlx_stream s) = NULL; int (*mlx_where_ptr)(mlx_array* res, const 
mlx_array condition, const mlx_array x, const mlx_array y, const mlx_stream s) = NULL; int (*mlx_zeros_ptr)(mlx_array* res, const int* shape, size_t shape_num, mlx_dtype dtype, const mlx_stream s) = NULL; int (*mlx_zeros_like_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; int (*mlx_random_bernoulli_ptr)(mlx_array* res, const mlx_array p, const int* shape, size_t shape_num, const mlx_array key , const mlx_stream s) = NULL; int (*mlx_random_bits_ptr)(mlx_array* res, const int* shape, size_t shape_num, int width, const mlx_array key , const mlx_stream s) = NULL; int (*mlx_random_categorical_shape_ptr)(mlx_array* res, const mlx_array logits, int axis, const int* shape, size_t shape_num, const mlx_array key , const mlx_stream s) = NULL; int (*mlx_random_categorical_num_samples_ptr)(mlx_array* res, const mlx_array logits_, int axis, int num_samples, const mlx_array key , const mlx_stream s) = NULL; int (*mlx_random_categorical_ptr)(mlx_array* res, const mlx_array logits, int axis, const mlx_array key , const mlx_stream s) = NULL; int (*mlx_random_gumbel_ptr)(mlx_array* res, const int* shape, size_t shape_num, mlx_dtype dtype, const mlx_array key , const mlx_stream s) = NULL; int (*mlx_random_key_ptr)(mlx_array* res, uint64_t seed) = NULL; int (*mlx_random_laplace_ptr)(mlx_array* res, const int* shape, size_t shape_num, mlx_dtype dtype, float loc, float scale, const mlx_array key , const mlx_stream s) = NULL; int (*mlx_random_multivariate_normal_ptr)(mlx_array* res, const mlx_array mean, const mlx_array cov, const int* shape, size_t shape_num, mlx_dtype dtype, const mlx_array key , const mlx_stream s) = NULL; int (*mlx_random_normal_broadcast_ptr)(mlx_array* res, const int* shape, size_t shape_num, mlx_dtype dtype, const mlx_array loc , const mlx_array scale , const mlx_array key , const mlx_stream s) = NULL; int (*mlx_random_normal_ptr)(mlx_array* res, const int* shape, size_t shape_num, mlx_dtype dtype, float loc, float scale, const mlx_array key , 
const mlx_stream s) = NULL; int (*mlx_random_permutation_ptr)(mlx_array* res, const mlx_array x, int axis, const mlx_array key , const mlx_stream s) = NULL; int (*mlx_random_permutation_arange_ptr)(mlx_array* res, int x, const mlx_array key , const mlx_stream s) = NULL; int (*mlx_random_randint_ptr)(mlx_array* res, const mlx_array low, const mlx_array high, const int* shape, size_t shape_num, mlx_dtype dtype, const mlx_array key , const mlx_stream s) = NULL; int (*mlx_random_seed_ptr)(uint64_t seed) = NULL; int (*mlx_random_split_num_ptr)(mlx_array* res, const mlx_array key, int num, const mlx_stream s) = NULL; int (*mlx_random_split_ptr)(mlx_array* res_0, mlx_array* res_1, const mlx_array key, const mlx_stream s) = NULL; int (*mlx_random_truncated_normal_ptr)(mlx_array* res, const mlx_array lower, const mlx_array upper, const int* shape, size_t shape_num, mlx_dtype dtype, const mlx_array key , const mlx_stream s) = NULL; int (*mlx_random_uniform_ptr)(mlx_array* res, const mlx_array low, const mlx_array high, const int* shape, size_t shape_num, mlx_dtype dtype, const mlx_array key , const mlx_stream s) = NULL; mlx_stream (*mlx_stream_new_ptr)(void) = NULL; mlx_stream (*mlx_stream_new_device_ptr)(mlx_device dev) = NULL; int (*mlx_stream_set_ptr)(mlx_stream* stream, const mlx_stream src) = NULL; int (*mlx_stream_free_ptr)(mlx_stream stream) = NULL; int (*mlx_stream_tostring_ptr)(mlx_string* str, mlx_stream stream) = NULL; bool (*mlx_stream_equal_ptr)(mlx_stream lhs, mlx_stream rhs) = NULL; int (*mlx_stream_get_device_ptr)(mlx_device* dev, mlx_stream stream) = NULL; int (*mlx_stream_get_index_ptr)(int* index, mlx_stream stream) = NULL; int (*mlx_synchronize_ptr)(mlx_stream stream) = NULL; int (*mlx_get_default_stream_ptr)(mlx_stream* stream, mlx_device dev) = NULL; int (*mlx_set_default_stream_ptr)(mlx_stream stream) = NULL; mlx_stream (*mlx_default_cpu_stream_new_ptr)(void) = NULL; mlx_stream (*mlx_default_gpu_stream_new_ptr)(void) = NULL; mlx_string 
(*mlx_string_new_ptr)(void) = NULL; mlx_string (*mlx_string_new_data_ptr)(const char* str) = NULL; int (*mlx_string_set_ptr)(mlx_string* str, const mlx_string src) = NULL; const char* (*mlx_string_data_ptr)(mlx_string str) = NULL; int (*mlx_string_free_ptr)(mlx_string str) = NULL; int (*mlx_async_eval_ptr)(const mlx_vector_array outputs) = NULL; int (*mlx_checkpoint_ptr)(mlx_closure* res, const mlx_closure fun) = NULL; int (*mlx_custom_function_ptr)(mlx_closure* res, const mlx_closure fun, const mlx_closure_custom fun_vjp , const mlx_closure_custom_jvp fun_jvp , const mlx_closure_custom_vmap fun_vmap) = NULL; int (*mlx_custom_vjp_ptr)(mlx_closure* res, const mlx_closure fun, const mlx_closure_custom fun_vjp) = NULL; int (*mlx_eval_ptr)(const mlx_vector_array outputs) = NULL; int (*mlx_jvp_ptr)(mlx_vector_array* res_0, mlx_vector_array* res_1, const mlx_closure fun, const mlx_vector_array primals, const mlx_vector_array tangents) = NULL; int (*mlx_value_and_grad_ptr)(mlx_closure_value_and_grad* res, const mlx_closure fun, const int* argnums, size_t argnums_num) = NULL; int (*mlx_vjp_ptr)(mlx_vector_array* res_0, mlx_vector_array* res_1, const mlx_closure fun, const mlx_vector_array primals, const mlx_vector_array cotangents) = NULL; int (*mlx_detail_vmap_replace_ptr)(mlx_vector_array* res, const mlx_vector_array inputs, const mlx_vector_array s_inputs, const mlx_vector_array s_outputs, const int* in_axes, size_t in_axes_num, const int* out_axes, size_t out_axes_num) = NULL; int (*mlx_detail_vmap_trace_ptr)(mlx_vector_array* res_0, mlx_vector_array* res_1, const mlx_closure fun, const mlx_vector_array inputs, const int* in_axes, size_t in_axes_num) = NULL; mlx_vector_array (*mlx_vector_array_new_ptr)(void) = NULL; int (*mlx_vector_array_set_ptr)(mlx_vector_array* vec, const mlx_vector_array src) = NULL; int (*mlx_vector_array_free_ptr)(mlx_vector_array vec) = NULL; mlx_vector_array (*mlx_vector_array_new_data_ptr)(const mlx_array* data, size_t size) = NULL; 
mlx_vector_array (*mlx_vector_array_new_value_ptr)(const mlx_array val) = NULL; int (*mlx_vector_array_set_data_ptr)(mlx_vector_array* vec, const mlx_array* data, size_t size) = NULL; int (*mlx_vector_array_set_value_ptr)(mlx_vector_array* vec, const mlx_array val) = NULL; int (*mlx_vector_array_append_data_ptr)(mlx_vector_array vec, const mlx_array* data, size_t size) = NULL; int (*mlx_vector_array_append_value_ptr)(mlx_vector_array vec, const mlx_array val) = NULL; size_t (*mlx_vector_array_size_ptr)(mlx_vector_array vec) = NULL; int (*mlx_vector_array_get_ptr)(mlx_array* res, const mlx_vector_array vec, size_t idx) = NULL; mlx_vector_vector_array (*mlx_vector_vector_array_new_ptr)(void) = NULL; int (*mlx_vector_vector_array_set_ptr)(mlx_vector_vector_array* vec, const mlx_vector_vector_array src) = NULL; int (*mlx_vector_vector_array_free_ptr)(mlx_vector_vector_array vec) = NULL; mlx_vector_vector_array (*mlx_vector_vector_array_new_data_ptr)(const mlx_vector_array* data, size_t size) = NULL; mlx_vector_vector_array (*mlx_vector_vector_array_new_value_ptr)(const mlx_vector_array val) = NULL; int (*mlx_vector_vector_array_set_data_ptr)(mlx_vector_vector_array* vec, const mlx_vector_array* data, size_t size) = NULL; int (*mlx_vector_vector_array_set_value_ptr)(mlx_vector_vector_array* vec, const mlx_vector_array val) = NULL; int (*mlx_vector_vector_array_append_data_ptr)(mlx_vector_vector_array vec, const mlx_vector_array* data, size_t size) = NULL; int (*mlx_vector_vector_array_append_value_ptr)(mlx_vector_vector_array vec, const mlx_vector_array val) = NULL; size_t (*mlx_vector_vector_array_size_ptr)(mlx_vector_vector_array vec) = NULL; int (*mlx_vector_vector_array_get_ptr)(mlx_vector_array* res, const mlx_vector_vector_array vec, size_t idx) = NULL; mlx_vector_int (*mlx_vector_int_new_ptr)(void) = NULL; int (*mlx_vector_int_set_ptr)(mlx_vector_int* vec, const mlx_vector_int src) = NULL; int (*mlx_vector_int_free_ptr)(mlx_vector_int vec) = NULL; mlx_vector_int 
(*mlx_vector_int_new_data_ptr)(int* data, size_t size) = NULL; mlx_vector_int (*mlx_vector_int_new_value_ptr)(int val) = NULL; int (*mlx_vector_int_set_data_ptr)(mlx_vector_int* vec, int* data, size_t size) = NULL; int (*mlx_vector_int_set_value_ptr)(mlx_vector_int* vec, int val) = NULL; int (*mlx_vector_int_append_data_ptr)(mlx_vector_int vec, int* data, size_t size) = NULL; int (*mlx_vector_int_append_value_ptr)(mlx_vector_int vec, int val) = NULL; size_t (*mlx_vector_int_size_ptr)(mlx_vector_int vec) = NULL; int (*mlx_vector_int_get_ptr)(int* res, const mlx_vector_int vec, size_t idx) = NULL; mlx_vector_string (*mlx_vector_string_new_ptr)(void) = NULL; int (*mlx_vector_string_set_ptr)(mlx_vector_string* vec, const mlx_vector_string src) = NULL; int (*mlx_vector_string_free_ptr)(mlx_vector_string vec) = NULL; mlx_vector_string (*mlx_vector_string_new_data_ptr)(const char** data, size_t size) = NULL; mlx_vector_string (*mlx_vector_string_new_value_ptr)(const char* val) = NULL; int (*mlx_vector_string_set_data_ptr)(mlx_vector_string* vec, const char** data, size_t size) = NULL; int (*mlx_vector_string_set_value_ptr)(mlx_vector_string* vec, const char* val) = NULL; int (*mlx_vector_string_append_data_ptr)(mlx_vector_string vec, const char** data, size_t size) = NULL; int (*mlx_vector_string_append_value_ptr)(mlx_vector_string vec, const char* val) = NULL; size_t (*mlx_vector_string_size_ptr)(mlx_vector_string vec) = NULL; int (*mlx_vector_string_get_ptr)(char** res, const mlx_vector_string vec, size_t idx) = NULL; int (*mlx_version_ptr)(mlx_string* str_) = NULL; // Initialize all function pointers int mlx_load_functions(void* handle) { if (handle == NULL) { fprintf(stderr, "MLX: Invalid library handle\n"); return -1; } mlx_dtype_size_ptr = GET_SYM(handle, "mlx_dtype_size"); if (mlx_dtype_size_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_dtype_size\n"); return -1; } mlx_array_tostring_ptr = GET_SYM(handle, "mlx_array_tostring"); if 
(mlx_array_tostring_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_tostring\n"); return -1; } mlx_array_new_ptr = GET_SYM(handle, "mlx_array_new"); if (mlx_array_new_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_new\n"); return -1; } mlx_array_free_ptr = GET_SYM(handle, "mlx_array_free"); if (mlx_array_free_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_free\n"); return -1; } mlx_array_new_bool_ptr = GET_SYM(handle, "mlx_array_new_bool"); if (mlx_array_new_bool_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_new_bool\n"); return -1; } mlx_array_new_int_ptr = GET_SYM(handle, "mlx_array_new_int"); if (mlx_array_new_int_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_new_int\n"); return -1; } mlx_array_new_float32_ptr = GET_SYM(handle, "mlx_array_new_float32"); if (mlx_array_new_float32_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_new_float32\n"); return -1; } mlx_array_new_float_ptr = GET_SYM(handle, "mlx_array_new_float"); if (mlx_array_new_float_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_new_float\n"); return -1; } mlx_array_new_float64_ptr = GET_SYM(handle, "mlx_array_new_float64"); if (mlx_array_new_float64_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_new_float64\n"); return -1; } mlx_array_new_double_ptr = GET_SYM(handle, "mlx_array_new_double"); if (mlx_array_new_double_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_new_double\n"); return -1; } mlx_array_new_complex_ptr = GET_SYM(handle, "mlx_array_new_complex"); if (mlx_array_new_complex_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_new_complex\n"); return -1; } mlx_array_new_data_ptr = GET_SYM(handle, "mlx_array_new_data"); if (mlx_array_new_data_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_new_data\n"); return -1; } mlx_array_new_data_managed_ptr = 
GET_SYM(handle, "mlx_array_new_data_managed"); if (mlx_array_new_data_managed_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_new_data_managed\n"); return -1; } mlx_array_new_data_managed_payload_ptr = GET_SYM(handle, "mlx_array_new_data_managed_payload"); if (mlx_array_new_data_managed_payload_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_new_data_managed_payload\n"); return -1; } mlx_array_set_ptr = GET_SYM(handle, "mlx_array_set"); if (mlx_array_set_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_set\n"); return -1; } mlx_array_set_bool_ptr = GET_SYM(handle, "mlx_array_set_bool"); if (mlx_array_set_bool_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_set_bool\n"); return -1; } mlx_array_set_int_ptr = GET_SYM(handle, "mlx_array_set_int"); if (mlx_array_set_int_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_set_int\n"); return -1; } mlx_array_set_float32_ptr = GET_SYM(handle, "mlx_array_set_float32"); if (mlx_array_set_float32_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_set_float32\n"); return -1; } mlx_array_set_float_ptr = GET_SYM(handle, "mlx_array_set_float"); if (mlx_array_set_float_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_set_float\n"); return -1; } mlx_array_set_float64_ptr = GET_SYM(handle, "mlx_array_set_float64"); if (mlx_array_set_float64_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_set_float64\n"); return -1; } mlx_array_set_double_ptr = GET_SYM(handle, "mlx_array_set_double"); if (mlx_array_set_double_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_set_double\n"); return -1; } mlx_array_set_complex_ptr = GET_SYM(handle, "mlx_array_set_complex"); if (mlx_array_set_complex_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_set_complex\n"); return -1; } mlx_array_set_data_ptr = GET_SYM(handle, "mlx_array_set_data"); if 
(mlx_array_set_data_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_set_data\n"); return -1; } mlx_array_itemsize_ptr = GET_SYM(handle, "mlx_array_itemsize"); if (mlx_array_itemsize_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_itemsize\n"); return -1; } mlx_array_size_ptr = GET_SYM(handle, "mlx_array_size"); if (mlx_array_size_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_size\n"); return -1; } mlx_array_nbytes_ptr = GET_SYM(handle, "mlx_array_nbytes"); if (mlx_array_nbytes_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_nbytes\n"); return -1; } mlx_array_ndim_ptr = GET_SYM(handle, "mlx_array_ndim"); if (mlx_array_ndim_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_ndim\n"); return -1; } mlx_array_shape_ptr = GET_SYM(handle, "mlx_array_shape"); if (mlx_array_shape_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_shape\n"); return -1; } mlx_array_strides_ptr = GET_SYM(handle, "mlx_array_strides"); if (mlx_array_strides_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_strides\n"); return -1; } mlx_array_dim_ptr = GET_SYM(handle, "mlx_array_dim"); if (mlx_array_dim_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_dim\n"); return -1; } mlx_array_dtype_ptr = GET_SYM(handle, "mlx_array_dtype"); if (mlx_array_dtype_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_dtype\n"); return -1; } mlx_array_eval_ptr = GET_SYM(handle, "mlx_array_eval"); if (mlx_array_eval_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_eval\n"); return -1; } mlx_array_item_bool_ptr = GET_SYM(handle, "mlx_array_item_bool"); if (mlx_array_item_bool_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_item_bool\n"); return -1; } mlx_array_item_uint8_ptr = GET_SYM(handle, "mlx_array_item_uint8"); if (mlx_array_item_uint8_ptr == NULL) { fprintf(stderr, "MLX: Failed to 
load symbol: mlx_array_item_uint8\n"); return -1; } mlx_array_item_uint16_ptr = GET_SYM(handle, "mlx_array_item_uint16"); if (mlx_array_item_uint16_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_item_uint16\n"); return -1; } mlx_array_item_uint32_ptr = GET_SYM(handle, "mlx_array_item_uint32"); if (mlx_array_item_uint32_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_item_uint32\n"); return -1; } mlx_array_item_uint64_ptr = GET_SYM(handle, "mlx_array_item_uint64"); if (mlx_array_item_uint64_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_item_uint64\n"); return -1; } mlx_array_item_int8_ptr = GET_SYM(handle, "mlx_array_item_int8"); if (mlx_array_item_int8_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_item_int8\n"); return -1; } mlx_array_item_int16_ptr = GET_SYM(handle, "mlx_array_item_int16"); if (mlx_array_item_int16_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_item_int16\n"); return -1; } mlx_array_item_int32_ptr = GET_SYM(handle, "mlx_array_item_int32"); if (mlx_array_item_int32_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_item_int32\n"); return -1; } mlx_array_item_int64_ptr = GET_SYM(handle, "mlx_array_item_int64"); if (mlx_array_item_int64_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_item_int64\n"); return -1; } mlx_array_item_float32_ptr = GET_SYM(handle, "mlx_array_item_float32"); if (mlx_array_item_float32_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_item_float32\n"); return -1; } mlx_array_item_float64_ptr = GET_SYM(handle, "mlx_array_item_float64"); if (mlx_array_item_float64_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_item_float64\n"); return -1; } mlx_array_item_complex64_ptr = GET_SYM(handle, "mlx_array_item_complex64"); if (mlx_array_item_complex64_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_item_complex64\n"); 
return -1; } #if defined(__aarch64__) || defined(_M_ARM64) mlx_array_item_float16_ptr = GET_SYM(handle, "mlx_array_item_float16"); if (mlx_array_item_float16_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_item_float16\n"); return -1; } #endif #if defined(__aarch64__) || defined(_M_ARM64) mlx_array_item_bfloat16_ptr = GET_SYM(handle, "mlx_array_item_bfloat16"); if (mlx_array_item_bfloat16_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_item_bfloat16\n"); return -1; } #endif mlx_array_data_bool_ptr = GET_SYM(handle, "mlx_array_data_bool"); if (mlx_array_data_bool_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_data_bool\n"); return -1; } mlx_array_data_uint8_ptr = GET_SYM(handle, "mlx_array_data_uint8"); if (mlx_array_data_uint8_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_data_uint8\n"); return -1; } mlx_array_data_uint16_ptr = GET_SYM(handle, "mlx_array_data_uint16"); if (mlx_array_data_uint16_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_data_uint16\n"); return -1; } mlx_array_data_uint32_ptr = GET_SYM(handle, "mlx_array_data_uint32"); if (mlx_array_data_uint32_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_data_uint32\n"); return -1; } mlx_array_data_uint64_ptr = GET_SYM(handle, "mlx_array_data_uint64"); if (mlx_array_data_uint64_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_data_uint64\n"); return -1; } mlx_array_data_int8_ptr = GET_SYM(handle, "mlx_array_data_int8"); if (mlx_array_data_int8_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_data_int8\n"); return -1; } mlx_array_data_int16_ptr = GET_SYM(handle, "mlx_array_data_int16"); if (mlx_array_data_int16_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_data_int16\n"); return -1; } mlx_array_data_int32_ptr = GET_SYM(handle, "mlx_array_data_int32"); if (mlx_array_data_int32_ptr == NULL) { fprintf(stderr, 
"MLX: Failed to load symbol: mlx_array_data_int32\n"); return -1; } mlx_array_data_int64_ptr = GET_SYM(handle, "mlx_array_data_int64"); if (mlx_array_data_int64_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_data_int64\n"); return -1; } mlx_array_data_float32_ptr = GET_SYM(handle, "mlx_array_data_float32"); if (mlx_array_data_float32_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_data_float32\n"); return -1; } mlx_array_data_float64_ptr = GET_SYM(handle, "mlx_array_data_float64"); if (mlx_array_data_float64_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_data_float64\n"); return -1; } mlx_array_data_complex64_ptr = GET_SYM(handle, "mlx_array_data_complex64"); if (mlx_array_data_complex64_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_data_complex64\n"); return -1; } #if defined(__aarch64__) || defined(_M_ARM64) mlx_array_data_float16_ptr = GET_SYM(handle, "mlx_array_data_float16"); if (mlx_array_data_float16_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_data_float16\n"); return -1; } #endif #if defined(__aarch64__) || defined(_M_ARM64) mlx_array_data_bfloat16_ptr = GET_SYM(handle, "mlx_array_data_bfloat16"); if (mlx_array_data_bfloat16_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_data_bfloat16\n"); return -1; } #endif _mlx_array_is_available_ptr = GET_SYM(handle, "_mlx_array_is_available"); if (_mlx_array_is_available_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: _mlx_array_is_available\n"); return -1; } _mlx_array_wait_ptr = GET_SYM(handle, "_mlx_array_wait"); if (_mlx_array_wait_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: _mlx_array_wait\n"); return -1; } _mlx_array_is_contiguous_ptr = GET_SYM(handle, "_mlx_array_is_contiguous"); if (_mlx_array_is_contiguous_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: _mlx_array_is_contiguous\n"); return -1; } _mlx_array_is_row_contiguous_ptr = 
GET_SYM(handle, "_mlx_array_is_row_contiguous"); if (_mlx_array_is_row_contiguous_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: _mlx_array_is_row_contiguous\n"); return -1; } _mlx_array_is_col_contiguous_ptr = GET_SYM(handle, "_mlx_array_is_col_contiguous"); if (_mlx_array_is_col_contiguous_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: _mlx_array_is_col_contiguous\n"); return -1; } mlx_closure_new_ptr = GET_SYM(handle, "mlx_closure_new"); if (mlx_closure_new_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_new\n"); return -1; } mlx_closure_free_ptr = GET_SYM(handle, "mlx_closure_free"); if (mlx_closure_free_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_free\n"); return -1; } mlx_closure_new_func_ptr = GET_SYM(handle, "mlx_closure_new_func"); if (mlx_closure_new_func_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_new_func\n"); return -1; } mlx_closure_new_func_payload_ptr = GET_SYM(handle, "mlx_closure_new_func_payload"); if (mlx_closure_new_func_payload_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_new_func_payload\n"); return -1; } mlx_closure_set_ptr = GET_SYM(handle, "mlx_closure_set"); if (mlx_closure_set_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_set\n"); return -1; } mlx_closure_apply_ptr = GET_SYM(handle, "mlx_closure_apply"); if (mlx_closure_apply_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_apply\n"); return -1; } mlx_closure_new_unary_ptr = GET_SYM(handle, "mlx_closure_new_unary"); if (mlx_closure_new_unary_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_new_unary\n"); return -1; } mlx_closure_kwargs_new_ptr = GET_SYM(handle, "mlx_closure_kwargs_new"); if (mlx_closure_kwargs_new_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_kwargs_new\n"); return -1; } mlx_closure_kwargs_free_ptr = GET_SYM(handle, "mlx_closure_kwargs_free"); 
if (mlx_closure_kwargs_free_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_kwargs_free\n"); return -1; } mlx_closure_kwargs_new_func_ptr = GET_SYM(handle, "mlx_closure_kwargs_new_func"); if (mlx_closure_kwargs_new_func_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_kwargs_new_func\n"); return -1; } mlx_closure_kwargs_new_func_payload_ptr = GET_SYM(handle, "mlx_closure_kwargs_new_func_payload"); if (mlx_closure_kwargs_new_func_payload_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_kwargs_new_func_payload\n"); return -1; } mlx_closure_kwargs_set_ptr = GET_SYM(handle, "mlx_closure_kwargs_set"); if (mlx_closure_kwargs_set_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_kwargs_set\n"); return -1; } mlx_closure_kwargs_apply_ptr = GET_SYM(handle, "mlx_closure_kwargs_apply"); if (mlx_closure_kwargs_apply_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_kwargs_apply\n"); return -1; } mlx_closure_value_and_grad_new_ptr = GET_SYM(handle, "mlx_closure_value_and_grad_new"); if (mlx_closure_value_and_grad_new_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_value_and_grad_new\n"); return -1; } mlx_closure_value_and_grad_free_ptr = GET_SYM(handle, "mlx_closure_value_and_grad_free"); if (mlx_closure_value_and_grad_free_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_value_and_grad_free\n"); return -1; } mlx_closure_value_and_grad_new_func_ptr = GET_SYM(handle, "mlx_closure_value_and_grad_new_func"); if (mlx_closure_value_and_grad_new_func_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_value_and_grad_new_func\n"); return -1; } mlx_closure_value_and_grad_new_func_payload_ptr = GET_SYM(handle, "mlx_closure_value_and_grad_new_func_payload"); if (mlx_closure_value_and_grad_new_func_payload_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: 
mlx_closure_value_and_grad_new_func_payload\n"); return -1; } mlx_closure_value_and_grad_set_ptr = GET_SYM(handle, "mlx_closure_value_and_grad_set"); if (mlx_closure_value_and_grad_set_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_value_and_grad_set\n"); return -1; } mlx_closure_value_and_grad_apply_ptr = GET_SYM(handle, "mlx_closure_value_and_grad_apply"); if (mlx_closure_value_and_grad_apply_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_value_and_grad_apply\n"); return -1; } mlx_closure_custom_new_ptr = GET_SYM(handle, "mlx_closure_custom_new"); if (mlx_closure_custom_new_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_custom_new\n"); return -1; } mlx_closure_custom_free_ptr = GET_SYM(handle, "mlx_closure_custom_free"); if (mlx_closure_custom_free_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_custom_free\n"); return -1; } mlx_closure_custom_new_func_ptr = GET_SYM(handle, "mlx_closure_custom_new_func"); if (mlx_closure_custom_new_func_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_custom_new_func\n"); return -1; } mlx_closure_custom_new_func_payload_ptr = GET_SYM(handle, "mlx_closure_custom_new_func_payload"); if (mlx_closure_custom_new_func_payload_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_custom_new_func_payload\n"); return -1; } mlx_closure_custom_set_ptr = GET_SYM(handle, "mlx_closure_custom_set"); if (mlx_closure_custom_set_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_custom_set\n"); return -1; } mlx_closure_custom_apply_ptr = GET_SYM(handle, "mlx_closure_custom_apply"); if (mlx_closure_custom_apply_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_custom_apply\n"); return -1; } mlx_closure_custom_jvp_new_ptr = GET_SYM(handle, "mlx_closure_custom_jvp_new"); if (mlx_closure_custom_jvp_new_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: 
mlx_closure_custom_jvp_new\n"); return -1; } mlx_closure_custom_jvp_free_ptr = GET_SYM(handle, "mlx_closure_custom_jvp_free"); if (mlx_closure_custom_jvp_free_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_custom_jvp_free\n"); return -1; } mlx_closure_custom_jvp_new_func_ptr = GET_SYM(handle, "mlx_closure_custom_jvp_new_func"); if (mlx_closure_custom_jvp_new_func_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_custom_jvp_new_func\n"); return -1; } mlx_closure_custom_jvp_new_func_payload_ptr = GET_SYM(handle, "mlx_closure_custom_jvp_new_func_payload"); if (mlx_closure_custom_jvp_new_func_payload_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_custom_jvp_new_func_payload\n"); return -1; } mlx_closure_custom_jvp_set_ptr = GET_SYM(handle, "mlx_closure_custom_jvp_set"); if (mlx_closure_custom_jvp_set_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_custom_jvp_set\n"); return -1; } mlx_closure_custom_jvp_apply_ptr = GET_SYM(handle, "mlx_closure_custom_jvp_apply"); if (mlx_closure_custom_jvp_apply_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_custom_jvp_apply\n"); return -1; } mlx_closure_custom_vmap_new_ptr = GET_SYM(handle, "mlx_closure_custom_vmap_new"); if (mlx_closure_custom_vmap_new_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_custom_vmap_new\n"); return -1; } mlx_closure_custom_vmap_free_ptr = GET_SYM(handle, "mlx_closure_custom_vmap_free"); if (mlx_closure_custom_vmap_free_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_custom_vmap_free\n"); return -1; } mlx_closure_custom_vmap_new_func_ptr = GET_SYM(handle, "mlx_closure_custom_vmap_new_func"); if (mlx_closure_custom_vmap_new_func_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_custom_vmap_new_func\n"); return -1; } mlx_closure_custom_vmap_new_func_payload_ptr = GET_SYM(handle, 
"mlx_closure_custom_vmap_new_func_payload"); if (mlx_closure_custom_vmap_new_func_payload_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_custom_vmap_new_func_payload\n"); return -1; } mlx_closure_custom_vmap_set_ptr = GET_SYM(handle, "mlx_closure_custom_vmap_set"); if (mlx_closure_custom_vmap_set_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_custom_vmap_set\n"); return -1; } mlx_closure_custom_vmap_apply_ptr = GET_SYM(handle, "mlx_closure_custom_vmap_apply"); if (mlx_closure_custom_vmap_apply_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_custom_vmap_apply\n"); return -1; } mlx_compile_ptr = GET_SYM(handle, "mlx_compile"); if (mlx_compile_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_compile\n"); return -1; } mlx_detail_compile_ptr = GET_SYM(handle, "mlx_detail_compile"); if (mlx_detail_compile_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_detail_compile\n"); return -1; } mlx_detail_compile_clear_cache_ptr = GET_SYM(handle, "mlx_detail_compile_clear_cache"); if (mlx_detail_compile_clear_cache_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_detail_compile_clear_cache\n"); return -1; } mlx_detail_compile_erase_ptr = GET_SYM(handle, "mlx_detail_compile_erase"); if (mlx_detail_compile_erase_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_detail_compile_erase\n"); return -1; } mlx_disable_compile_ptr = GET_SYM(handle, "mlx_disable_compile"); if (mlx_disable_compile_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_disable_compile\n"); return -1; } mlx_enable_compile_ptr = GET_SYM(handle, "mlx_enable_compile"); if (mlx_enable_compile_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_enable_compile\n"); return -1; } mlx_set_compile_mode_ptr = GET_SYM(handle, "mlx_set_compile_mode"); if (mlx_set_compile_mode_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_set_compile_mode\n"); return -1; 
} mlx_cuda_is_available_ptr = GET_SYM(handle, "mlx_cuda_is_available"); if (mlx_cuda_is_available_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_cuda_is_available\n"); return -1; } mlx_device_new_ptr = GET_SYM(handle, "mlx_device_new"); if (mlx_device_new_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_device_new\n"); return -1; } mlx_device_new_type_ptr = GET_SYM(handle, "mlx_device_new_type"); if (mlx_device_new_type_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_device_new_type\n"); return -1; } mlx_device_free_ptr = GET_SYM(handle, "mlx_device_free"); if (mlx_device_free_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_device_free\n"); return -1; } mlx_device_set_ptr = GET_SYM(handle, "mlx_device_set"); if (mlx_device_set_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_device_set\n"); return -1; } mlx_device_tostring_ptr = GET_SYM(handle, "mlx_device_tostring"); if (mlx_device_tostring_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_device_tostring\n"); return -1; } mlx_device_equal_ptr = GET_SYM(handle, "mlx_device_equal"); if (mlx_device_equal_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_device_equal\n"); return -1; } mlx_device_get_index_ptr = GET_SYM(handle, "mlx_device_get_index"); if (mlx_device_get_index_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_device_get_index\n"); return -1; } mlx_device_get_type_ptr = GET_SYM(handle, "mlx_device_get_type"); if (mlx_device_get_type_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_device_get_type\n"); return -1; } mlx_get_default_device_ptr = GET_SYM(handle, "mlx_get_default_device"); if (mlx_get_default_device_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_get_default_device\n"); return -1; } mlx_set_default_device_ptr = GET_SYM(handle, "mlx_set_default_device"); if (mlx_set_default_device_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: 
mlx_set_default_device\n"); return -1; } mlx_device_is_available_ptr = GET_SYM(handle, "mlx_device_is_available"); if (mlx_device_is_available_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_device_is_available\n"); return -1; } mlx_device_count_ptr = GET_SYM(handle, "mlx_device_count"); if (mlx_device_count_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_device_count\n"); return -1; } mlx_device_info_new_ptr = GET_SYM(handle, "mlx_device_info_new"); if (mlx_device_info_new_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_device_info_new\n"); return -1; } mlx_device_info_get_ptr = GET_SYM(handle, "mlx_device_info_get"); if (mlx_device_info_get_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_device_info_get\n"); return -1; } mlx_device_info_free_ptr = GET_SYM(handle, "mlx_device_info_free"); if (mlx_device_info_free_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_device_info_free\n"); return -1; } mlx_device_info_has_key_ptr = GET_SYM(handle, "mlx_device_info_has_key"); if (mlx_device_info_has_key_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_device_info_has_key\n"); return -1; } mlx_device_info_is_string_ptr = GET_SYM(handle, "mlx_device_info_is_string"); if (mlx_device_info_is_string_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_device_info_is_string\n"); return -1; } mlx_device_info_get_string_ptr = GET_SYM(handle, "mlx_device_info_get_string"); if (mlx_device_info_get_string_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_device_info_get_string\n"); return -1; } mlx_device_info_get_size_ptr = GET_SYM(handle, "mlx_device_info_get_size"); if (mlx_device_info_get_size_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_device_info_get_size\n"); return -1; } mlx_device_info_get_keys_ptr = GET_SYM(handle, "mlx_device_info_get_keys"); if (mlx_device_info_get_keys_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: 
mlx_device_info_get_keys\n"); return -1; } mlx_distributed_all_gather_ptr = GET_SYM(handle, "mlx_distributed_all_gather"); if (mlx_distributed_all_gather_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_distributed_all_gather\n"); return -1; } mlx_distributed_all_max_ptr = GET_SYM(handle, "mlx_distributed_all_max"); if (mlx_distributed_all_max_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_distributed_all_max\n"); return -1; } mlx_distributed_all_min_ptr = GET_SYM(handle, "mlx_distributed_all_min"); if (mlx_distributed_all_min_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_distributed_all_min\n"); return -1; } mlx_distributed_all_sum_ptr = GET_SYM(handle, "mlx_distributed_all_sum"); if (mlx_distributed_all_sum_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_distributed_all_sum\n"); return -1; } mlx_distributed_recv_ptr = GET_SYM(handle, "mlx_distributed_recv"); if (mlx_distributed_recv_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_distributed_recv\n"); return -1; } mlx_distributed_recv_like_ptr = GET_SYM(handle, "mlx_distributed_recv_like"); if (mlx_distributed_recv_like_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_distributed_recv_like\n"); return -1; } mlx_distributed_send_ptr = GET_SYM(handle, "mlx_distributed_send"); if (mlx_distributed_send_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_distributed_send\n"); return -1; } mlx_distributed_sum_scatter_ptr = GET_SYM(handle, "mlx_distributed_sum_scatter"); if (mlx_distributed_sum_scatter_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_distributed_sum_scatter\n"); return -1; } mlx_distributed_group_rank_ptr = GET_SYM(handle, "mlx_distributed_group_rank"); if (mlx_distributed_group_rank_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_distributed_group_rank\n"); return -1; } mlx_distributed_group_size_ptr = GET_SYM(handle, "mlx_distributed_group_size"); if 
(mlx_distributed_group_size_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_distributed_group_size\n"); return -1; } mlx_distributed_group_split_ptr = GET_SYM(handle, "mlx_distributed_group_split"); if (mlx_distributed_group_split_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_distributed_group_split\n"); return -1; } mlx_distributed_is_available_ptr = GET_SYM(handle, "mlx_distributed_is_available"); if (mlx_distributed_is_available_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_distributed_is_available\n"); return -1; } mlx_distributed_init_ptr = GET_SYM(handle, "mlx_distributed_init"); if (mlx_distributed_init_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_distributed_init\n"); return -1; } mlx_set_error_handler_ptr = GET_SYM(handle, "mlx_set_error_handler"); if (mlx_set_error_handler_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_set_error_handler\n"); return -1; } _mlx_error_ptr = GET_SYM(handle, "_mlx_error"); if (_mlx_error_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: _mlx_error\n"); return -1; } mlx_export_function_ptr = GET_SYM(handle, "mlx_export_function"); if (mlx_export_function_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_export_function\n"); return -1; } mlx_export_function_kwargs_ptr = GET_SYM(handle, "mlx_export_function_kwargs"); if (mlx_export_function_kwargs_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_export_function_kwargs\n"); return -1; } mlx_function_exporter_new_ptr = GET_SYM(handle, "mlx_function_exporter_new"); if (mlx_function_exporter_new_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_function_exporter_new\n"); return -1; } mlx_function_exporter_free_ptr = GET_SYM(handle, "mlx_function_exporter_free"); if (mlx_function_exporter_free_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_function_exporter_free\n"); return -1; } mlx_function_exporter_apply_ptr = GET_SYM(handle, 
"mlx_function_exporter_apply"); if (mlx_function_exporter_apply_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_function_exporter_apply\n"); return -1; } mlx_function_exporter_apply_kwargs_ptr = GET_SYM(handle, "mlx_function_exporter_apply_kwargs"); if (mlx_function_exporter_apply_kwargs_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_function_exporter_apply_kwargs\n"); return -1; } mlx_imported_function_new_ptr = GET_SYM(handle, "mlx_imported_function_new"); if (mlx_imported_function_new_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_imported_function_new\n"); return -1; } mlx_imported_function_free_ptr = GET_SYM(handle, "mlx_imported_function_free"); if (mlx_imported_function_free_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_imported_function_free\n"); return -1; } mlx_imported_function_apply_ptr = GET_SYM(handle, "mlx_imported_function_apply"); if (mlx_imported_function_apply_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_imported_function_apply\n"); return -1; } mlx_imported_function_apply_kwargs_ptr = GET_SYM(handle, "mlx_imported_function_apply_kwargs"); if (mlx_imported_function_apply_kwargs_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_imported_function_apply_kwargs\n"); return -1; } mlx_fast_cuda_kernel_config_new_ptr = GET_SYM(handle, "mlx_fast_cuda_kernel_config_new"); if (mlx_fast_cuda_kernel_config_new_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_cuda_kernel_config_new\n"); return -1; } mlx_fast_cuda_kernel_config_free_ptr = GET_SYM(handle, "mlx_fast_cuda_kernel_config_free"); if (mlx_fast_cuda_kernel_config_free_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_cuda_kernel_config_free\n"); return -1; } mlx_fast_cuda_kernel_config_add_output_arg_ptr = GET_SYM(handle, "mlx_fast_cuda_kernel_config_add_output_arg"); if (mlx_fast_cuda_kernel_config_add_output_arg_ptr == NULL) { fprintf(stderr, "MLX: Failed to 
load symbol: mlx_fast_cuda_kernel_config_add_output_arg\n"); return -1; } mlx_fast_cuda_kernel_config_set_grid_ptr = GET_SYM(handle, "mlx_fast_cuda_kernel_config_set_grid"); if (mlx_fast_cuda_kernel_config_set_grid_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_cuda_kernel_config_set_grid\n"); return -1; } mlx_fast_cuda_kernel_config_set_thread_group_ptr = GET_SYM(handle, "mlx_fast_cuda_kernel_config_set_thread_group"); if (mlx_fast_cuda_kernel_config_set_thread_group_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_cuda_kernel_config_set_thread_group\n"); return -1; } mlx_fast_cuda_kernel_config_set_init_value_ptr = GET_SYM(handle, "mlx_fast_cuda_kernel_config_set_init_value"); if (mlx_fast_cuda_kernel_config_set_init_value_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_cuda_kernel_config_set_init_value\n"); return -1; } mlx_fast_cuda_kernel_config_set_verbose_ptr = GET_SYM(handle, "mlx_fast_cuda_kernel_config_set_verbose"); if (mlx_fast_cuda_kernel_config_set_verbose_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_cuda_kernel_config_set_verbose\n"); return -1; } mlx_fast_cuda_kernel_config_add_template_arg_dtype_ptr = GET_SYM(handle, "mlx_fast_cuda_kernel_config_add_template_arg_dtype"); if (mlx_fast_cuda_kernel_config_add_template_arg_dtype_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_cuda_kernel_config_add_template_arg_dtype\n"); return -1; } mlx_fast_cuda_kernel_config_add_template_arg_int_ptr = GET_SYM(handle, "mlx_fast_cuda_kernel_config_add_template_arg_int"); if (mlx_fast_cuda_kernel_config_add_template_arg_int_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_cuda_kernel_config_add_template_arg_int\n"); return -1; } mlx_fast_cuda_kernel_config_add_template_arg_bool_ptr = GET_SYM(handle, "mlx_fast_cuda_kernel_config_add_template_arg_bool"); if (mlx_fast_cuda_kernel_config_add_template_arg_bool_ptr == NULL) { fprintf(stderr, "MLX: 
Failed to load symbol: mlx_fast_cuda_kernel_config_add_template_arg_bool\n"); return -1; } mlx_fast_cuda_kernel_new_ptr = GET_SYM(handle, "mlx_fast_cuda_kernel_new"); if (mlx_fast_cuda_kernel_new_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_cuda_kernel_new\n"); return -1; } mlx_fast_cuda_kernel_free_ptr = GET_SYM(handle, "mlx_fast_cuda_kernel_free"); if (mlx_fast_cuda_kernel_free_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_cuda_kernel_free\n"); return -1; } mlx_fast_cuda_kernel_apply_ptr = GET_SYM(handle, "mlx_fast_cuda_kernel_apply"); if (mlx_fast_cuda_kernel_apply_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_cuda_kernel_apply\n"); return -1; } mlx_fast_layer_norm_ptr = GET_SYM(handle, "mlx_fast_layer_norm"); if (mlx_fast_layer_norm_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_layer_norm\n"); return -1; } mlx_fast_metal_kernel_config_new_ptr = GET_SYM(handle, "mlx_fast_metal_kernel_config_new"); if (mlx_fast_metal_kernel_config_new_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_metal_kernel_config_new\n"); return -1; } mlx_fast_metal_kernel_config_free_ptr = GET_SYM(handle, "mlx_fast_metal_kernel_config_free"); if (mlx_fast_metal_kernel_config_free_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_metal_kernel_config_free\n"); return -1; } mlx_fast_metal_kernel_config_add_output_arg_ptr = GET_SYM(handle, "mlx_fast_metal_kernel_config_add_output_arg"); if (mlx_fast_metal_kernel_config_add_output_arg_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_metal_kernel_config_add_output_arg\n"); return -1; } mlx_fast_metal_kernel_config_set_grid_ptr = GET_SYM(handle, "mlx_fast_metal_kernel_config_set_grid"); if (mlx_fast_metal_kernel_config_set_grid_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_metal_kernel_config_set_grid\n"); return -1; } mlx_fast_metal_kernel_config_set_thread_group_ptr = 
GET_SYM(handle, "mlx_fast_metal_kernel_config_set_thread_group"); if (mlx_fast_metal_kernel_config_set_thread_group_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_metal_kernel_config_set_thread_group\n"); return -1; } mlx_fast_metal_kernel_config_set_init_value_ptr = GET_SYM(handle, "mlx_fast_metal_kernel_config_set_init_value"); if (mlx_fast_metal_kernel_config_set_init_value_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_metal_kernel_config_set_init_value\n"); return -1; } mlx_fast_metal_kernel_config_set_verbose_ptr = GET_SYM(handle, "mlx_fast_metal_kernel_config_set_verbose"); if (mlx_fast_metal_kernel_config_set_verbose_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_metal_kernel_config_set_verbose\n"); return -1; } mlx_fast_metal_kernel_config_add_template_arg_dtype_ptr = GET_SYM(handle, "mlx_fast_metal_kernel_config_add_template_arg_dtype"); if (mlx_fast_metal_kernel_config_add_template_arg_dtype_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_metal_kernel_config_add_template_arg_dtype\n"); return -1; } mlx_fast_metal_kernel_config_add_template_arg_int_ptr = GET_SYM(handle, "mlx_fast_metal_kernel_config_add_template_arg_int"); if (mlx_fast_metal_kernel_config_add_template_arg_int_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_metal_kernel_config_add_template_arg_int\n"); return -1; } mlx_fast_metal_kernel_config_add_template_arg_bool_ptr = GET_SYM(handle, "mlx_fast_metal_kernel_config_add_template_arg_bool"); if (mlx_fast_metal_kernel_config_add_template_arg_bool_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_metal_kernel_config_add_template_arg_bool\n"); return -1; } mlx_fast_metal_kernel_new_ptr = GET_SYM(handle, "mlx_fast_metal_kernel_new"); if (mlx_fast_metal_kernel_new_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_metal_kernel_new\n"); return -1; } mlx_fast_metal_kernel_free_ptr = GET_SYM(handle, 
"mlx_fast_metal_kernel_free"); if (mlx_fast_metal_kernel_free_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_metal_kernel_free\n"); return -1; } mlx_fast_metal_kernel_apply_ptr = GET_SYM(handle, "mlx_fast_metal_kernel_apply"); if (mlx_fast_metal_kernel_apply_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_metal_kernel_apply\n"); return -1; } mlx_fast_rms_norm_ptr = GET_SYM(handle, "mlx_fast_rms_norm"); if (mlx_fast_rms_norm_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_rms_norm\n"); return -1; } mlx_fast_rope_ptr = GET_SYM(handle, "mlx_fast_rope"); if (mlx_fast_rope_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_rope\n"); return -1; } mlx_fast_rope_dynamic_ptr = GET_SYM(handle, "mlx_fast_rope_dynamic"); if (mlx_fast_rope_dynamic_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_rope_dynamic\n"); return -1; } mlx_fast_scaled_dot_product_attention_ptr = GET_SYM(handle, "mlx_fast_scaled_dot_product_attention"); if (mlx_fast_scaled_dot_product_attention_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_scaled_dot_product_attention\n"); return -1; } mlx_fft_fft_ptr = GET_SYM(handle, "mlx_fft_fft"); if (mlx_fft_fft_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_fft_fft\n"); return -1; } mlx_fft_fft2_ptr = GET_SYM(handle, "mlx_fft_fft2"); if (mlx_fft_fft2_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_fft_fft2\n"); return -1; } mlx_fft_fftn_ptr = GET_SYM(handle, "mlx_fft_fftn"); if (mlx_fft_fftn_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_fft_fftn\n"); return -1; } mlx_fft_fftshift_ptr = GET_SYM(handle, "mlx_fft_fftshift"); if (mlx_fft_fftshift_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_fft_fftshift\n"); return -1; } mlx_fft_ifft_ptr = GET_SYM(handle, "mlx_fft_ifft"); if (mlx_fft_ifft_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_fft_ifft\n"); return -1; } 
mlx_fft_ifft2_ptr = GET_SYM(handle, "mlx_fft_ifft2"); if (mlx_fft_ifft2_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_fft_ifft2\n"); return -1; } mlx_fft_ifftn_ptr = GET_SYM(handle, "mlx_fft_ifftn"); if (mlx_fft_ifftn_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_fft_ifftn\n"); return -1; } mlx_fft_ifftshift_ptr = GET_SYM(handle, "mlx_fft_ifftshift"); if (mlx_fft_ifftshift_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_fft_ifftshift\n"); return -1; } mlx_fft_irfft_ptr = GET_SYM(handle, "mlx_fft_irfft"); if (mlx_fft_irfft_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_fft_irfft\n"); return -1; } mlx_fft_irfft2_ptr = GET_SYM(handle, "mlx_fft_irfft2"); if (mlx_fft_irfft2_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_fft_irfft2\n"); return -1; } mlx_fft_irfftn_ptr = GET_SYM(handle, "mlx_fft_irfftn"); if (mlx_fft_irfftn_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_fft_irfftn\n"); return -1; } mlx_fft_rfft_ptr = GET_SYM(handle, "mlx_fft_rfft"); if (mlx_fft_rfft_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_fft_rfft\n"); return -1; } mlx_fft_rfft2_ptr = GET_SYM(handle, "mlx_fft_rfft2"); if (mlx_fft_rfft2_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_fft_rfft2\n"); return -1; } mlx_fft_rfftn_ptr = GET_SYM(handle, "mlx_fft_rfftn"); if (mlx_fft_rfftn_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_fft_rfftn\n"); return -1; } mlx_load_reader_ptr = GET_SYM(handle, "mlx_load_reader"); if (mlx_load_reader_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_load_reader\n"); return -1; } mlx_load_ptr = GET_SYM(handle, "mlx_load"); if (mlx_load_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_load\n"); return -1; } mlx_load_safetensors_reader_ptr = GET_SYM(handle, "mlx_load_safetensors_reader"); if (mlx_load_safetensors_reader_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: 
mlx_load_safetensors_reader\n"); return -1; } mlx_load_safetensors_ptr = GET_SYM(handle, "mlx_load_safetensors"); if (mlx_load_safetensors_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_load_safetensors\n"); return -1; } mlx_save_writer_ptr = GET_SYM(handle, "mlx_save_writer"); if (mlx_save_writer_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_save_writer\n"); return -1; } mlx_save_ptr = GET_SYM(handle, "mlx_save"); if (mlx_save_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_save\n"); return -1; } mlx_save_safetensors_writer_ptr = GET_SYM(handle, "mlx_save_safetensors_writer"); if (mlx_save_safetensors_writer_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_save_safetensors_writer\n"); return -1; } mlx_save_safetensors_ptr = GET_SYM(handle, "mlx_save_safetensors"); if (mlx_save_safetensors_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_save_safetensors\n"); return -1; } mlx_io_reader_new_ptr = GET_SYM(handle, "mlx_io_reader_new"); if (mlx_io_reader_new_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_io_reader_new\n"); return -1; } mlx_io_reader_descriptor_ptr = GET_SYM(handle, "mlx_io_reader_descriptor"); if (mlx_io_reader_descriptor_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_io_reader_descriptor\n"); return -1; } mlx_io_reader_tostring_ptr = GET_SYM(handle, "mlx_io_reader_tostring"); if (mlx_io_reader_tostring_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_io_reader_tostring\n"); return -1; } mlx_io_reader_free_ptr = GET_SYM(handle, "mlx_io_reader_free"); if (mlx_io_reader_free_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_io_reader_free\n"); return -1; } mlx_io_writer_new_ptr = GET_SYM(handle, "mlx_io_writer_new"); if (mlx_io_writer_new_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_io_writer_new\n"); return -1; } mlx_io_writer_descriptor_ptr = GET_SYM(handle, "mlx_io_writer_descriptor"); if 
(mlx_io_writer_descriptor_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_io_writer_descriptor\n"); return -1; } mlx_io_writer_tostring_ptr = GET_SYM(handle, "mlx_io_writer_tostring"); if (mlx_io_writer_tostring_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_io_writer_tostring\n"); return -1; } mlx_io_writer_free_ptr = GET_SYM(handle, "mlx_io_writer_free"); if (mlx_io_writer_free_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_io_writer_free\n"); return -1; } mlx_linalg_cholesky_ptr = GET_SYM(handle, "mlx_linalg_cholesky"); if (mlx_linalg_cholesky_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_linalg_cholesky\n"); return -1; } mlx_linalg_cholesky_inv_ptr = GET_SYM(handle, "mlx_linalg_cholesky_inv"); if (mlx_linalg_cholesky_inv_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_linalg_cholesky_inv\n"); return -1; } mlx_linalg_cross_ptr = GET_SYM(handle, "mlx_linalg_cross"); if (mlx_linalg_cross_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_linalg_cross\n"); return -1; } mlx_linalg_eig_ptr = GET_SYM(handle, "mlx_linalg_eig"); if (mlx_linalg_eig_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_linalg_eig\n"); return -1; } mlx_linalg_eigh_ptr = GET_SYM(handle, "mlx_linalg_eigh"); if (mlx_linalg_eigh_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_linalg_eigh\n"); return -1; } mlx_linalg_eigvals_ptr = GET_SYM(handle, "mlx_linalg_eigvals"); if (mlx_linalg_eigvals_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_linalg_eigvals\n"); return -1; } mlx_linalg_eigvalsh_ptr = GET_SYM(handle, "mlx_linalg_eigvalsh"); if (mlx_linalg_eigvalsh_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_linalg_eigvalsh\n"); return -1; } mlx_linalg_inv_ptr = GET_SYM(handle, "mlx_linalg_inv"); if (mlx_linalg_inv_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_linalg_inv\n"); return -1; } mlx_linalg_lu_ptr = GET_SYM(handle, 
"mlx_linalg_lu"); if (mlx_linalg_lu_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_linalg_lu\n"); return -1; } mlx_linalg_lu_factor_ptr = GET_SYM(handle, "mlx_linalg_lu_factor"); if (mlx_linalg_lu_factor_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_linalg_lu_factor\n"); return -1; } mlx_linalg_norm_ptr = GET_SYM(handle, "mlx_linalg_norm"); if (mlx_linalg_norm_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_linalg_norm\n"); return -1; } mlx_linalg_norm_matrix_ptr = GET_SYM(handle, "mlx_linalg_norm_matrix"); if (mlx_linalg_norm_matrix_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_linalg_norm_matrix\n"); return -1; } mlx_linalg_norm_l2_ptr = GET_SYM(handle, "mlx_linalg_norm_l2"); if (mlx_linalg_norm_l2_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_linalg_norm_l2\n"); return -1; } mlx_linalg_pinv_ptr = GET_SYM(handle, "mlx_linalg_pinv"); if (mlx_linalg_pinv_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_linalg_pinv\n"); return -1; } mlx_linalg_qr_ptr = GET_SYM(handle, "mlx_linalg_qr"); if (mlx_linalg_qr_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_linalg_qr\n"); return -1; } mlx_linalg_solve_ptr = GET_SYM(handle, "mlx_linalg_solve"); if (mlx_linalg_solve_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_linalg_solve\n"); return -1; } mlx_linalg_solve_triangular_ptr = GET_SYM(handle, "mlx_linalg_solve_triangular"); if (mlx_linalg_solve_triangular_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_linalg_solve_triangular\n"); return -1; } mlx_linalg_svd_ptr = GET_SYM(handle, "mlx_linalg_svd"); if (mlx_linalg_svd_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_linalg_svd\n"); return -1; } mlx_linalg_tri_inv_ptr = GET_SYM(handle, "mlx_linalg_tri_inv"); if (mlx_linalg_tri_inv_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_linalg_tri_inv\n"); return -1; } mlx_map_string_to_array_new_ptr = 
GET_SYM(handle, "mlx_map_string_to_array_new"); if (mlx_map_string_to_array_new_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_map_string_to_array_new\n"); return -1; } mlx_map_string_to_array_set_ptr = GET_SYM(handle, "mlx_map_string_to_array_set"); if (mlx_map_string_to_array_set_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_map_string_to_array_set\n"); return -1; } mlx_map_string_to_array_free_ptr = GET_SYM(handle, "mlx_map_string_to_array_free"); if (mlx_map_string_to_array_free_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_map_string_to_array_free\n"); return -1; } mlx_map_string_to_array_insert_ptr = GET_SYM(handle, "mlx_map_string_to_array_insert"); if (mlx_map_string_to_array_insert_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_map_string_to_array_insert\n"); return -1; } mlx_map_string_to_array_get_ptr = GET_SYM(handle, "mlx_map_string_to_array_get"); if (mlx_map_string_to_array_get_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_map_string_to_array_get\n"); return -1; } mlx_map_string_to_array_iterator_new_ptr = GET_SYM(handle, "mlx_map_string_to_array_iterator_new"); if (mlx_map_string_to_array_iterator_new_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_map_string_to_array_iterator_new\n"); return -1; } mlx_map_string_to_array_iterator_free_ptr = GET_SYM(handle, "mlx_map_string_to_array_iterator_free"); if (mlx_map_string_to_array_iterator_free_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_map_string_to_array_iterator_free\n"); return -1; } mlx_map_string_to_array_iterator_next_ptr = GET_SYM(handle, "mlx_map_string_to_array_iterator_next"); if (mlx_map_string_to_array_iterator_next_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_map_string_to_array_iterator_next\n"); return -1; } mlx_map_string_to_string_new_ptr = GET_SYM(handle, "mlx_map_string_to_string_new"); if (mlx_map_string_to_string_new_ptr == NULL) { 
fprintf(stderr, "MLX: Failed to load symbol: mlx_map_string_to_string_new\n"); return -1; } mlx_map_string_to_string_set_ptr = GET_SYM(handle, "mlx_map_string_to_string_set"); if (mlx_map_string_to_string_set_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_map_string_to_string_set\n"); return -1; } mlx_map_string_to_string_free_ptr = GET_SYM(handle, "mlx_map_string_to_string_free"); if (mlx_map_string_to_string_free_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_map_string_to_string_free\n"); return -1; } mlx_map_string_to_string_insert_ptr = GET_SYM(handle, "mlx_map_string_to_string_insert"); if (mlx_map_string_to_string_insert_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_map_string_to_string_insert\n"); return -1; } mlx_map_string_to_string_get_ptr = GET_SYM(handle, "mlx_map_string_to_string_get"); if (mlx_map_string_to_string_get_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_map_string_to_string_get\n"); return -1; } mlx_map_string_to_string_iterator_new_ptr = GET_SYM(handle, "mlx_map_string_to_string_iterator_new"); if (mlx_map_string_to_string_iterator_new_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_map_string_to_string_iterator_new\n"); return -1; } mlx_map_string_to_string_iterator_free_ptr = GET_SYM(handle, "mlx_map_string_to_string_iterator_free"); if (mlx_map_string_to_string_iterator_free_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_map_string_to_string_iterator_free\n"); return -1; } mlx_map_string_to_string_iterator_next_ptr = GET_SYM(handle, "mlx_map_string_to_string_iterator_next"); if (mlx_map_string_to_string_iterator_next_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_map_string_to_string_iterator_next\n"); return -1; } mlx_clear_cache_ptr = GET_SYM(handle, "mlx_clear_cache"); if (mlx_clear_cache_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_clear_cache\n"); return -1; } mlx_get_active_memory_ptr = 
GET_SYM(handle, "mlx_get_active_memory"); if (mlx_get_active_memory_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_get_active_memory\n"); return -1; } mlx_get_cache_memory_ptr = GET_SYM(handle, "mlx_get_cache_memory"); if (mlx_get_cache_memory_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_get_cache_memory\n"); return -1; } mlx_get_memory_limit_ptr = GET_SYM(handle, "mlx_get_memory_limit"); if (mlx_get_memory_limit_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_get_memory_limit\n"); return -1; } mlx_get_peak_memory_ptr = GET_SYM(handle, "mlx_get_peak_memory"); if (mlx_get_peak_memory_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_get_peak_memory\n"); return -1; } mlx_reset_peak_memory_ptr = GET_SYM(handle, "mlx_reset_peak_memory"); if (mlx_reset_peak_memory_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_reset_peak_memory\n"); return -1; } mlx_set_cache_limit_ptr = GET_SYM(handle, "mlx_set_cache_limit"); if (mlx_set_cache_limit_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_set_cache_limit\n"); return -1; } mlx_set_memory_limit_ptr = GET_SYM(handle, "mlx_set_memory_limit"); if (mlx_set_memory_limit_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_set_memory_limit\n"); return -1; } mlx_set_wired_limit_ptr = GET_SYM(handle, "mlx_set_wired_limit"); if (mlx_set_wired_limit_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_set_wired_limit\n"); return -1; } mlx_metal_is_available_ptr = GET_SYM(handle, "mlx_metal_is_available"); if (mlx_metal_is_available_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_metal_is_available\n"); return -1; } mlx_metal_start_capture_ptr = GET_SYM(handle, "mlx_metal_start_capture"); if (mlx_metal_start_capture_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_metal_start_capture\n"); return -1; } mlx_metal_stop_capture_ptr = GET_SYM(handle, "mlx_metal_stop_capture"); if 
(mlx_metal_stop_capture_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_metal_stop_capture\n"); return -1; } mlx_abs_ptr = GET_SYM(handle, "mlx_abs"); if (mlx_abs_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_abs\n"); return -1; } mlx_add_ptr = GET_SYM(handle, "mlx_add"); if (mlx_add_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_add\n"); return -1; } mlx_addmm_ptr = GET_SYM(handle, "mlx_addmm"); if (mlx_addmm_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_addmm\n"); return -1; } mlx_all_axes_ptr = GET_SYM(handle, "mlx_all_axes"); if (mlx_all_axes_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_all_axes\n"); return -1; } mlx_all_axis_ptr = GET_SYM(handle, "mlx_all_axis"); if (mlx_all_axis_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_all_axis\n"); return -1; } mlx_all_ptr = GET_SYM(handle, "mlx_all"); if (mlx_all_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_all\n"); return -1; } mlx_allclose_ptr = GET_SYM(handle, "mlx_allclose"); if (mlx_allclose_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_allclose\n"); return -1; } mlx_any_axes_ptr = GET_SYM(handle, "mlx_any_axes"); if (mlx_any_axes_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_any_axes\n"); return -1; } mlx_any_axis_ptr = GET_SYM(handle, "mlx_any_axis"); if (mlx_any_axis_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_any_axis\n"); return -1; } mlx_any_ptr = GET_SYM(handle, "mlx_any"); if (mlx_any_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_any\n"); return -1; } mlx_arange_ptr = GET_SYM(handle, "mlx_arange"); if (mlx_arange_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_arange\n"); return -1; } mlx_arccos_ptr = GET_SYM(handle, "mlx_arccos"); if (mlx_arccos_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_arccos\n"); return -1; } mlx_arccosh_ptr = GET_SYM(handle, "mlx_arccosh"); if 
(mlx_arccosh_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_arccosh\n"); return -1; } mlx_arcsin_ptr = GET_SYM(handle, "mlx_arcsin"); if (mlx_arcsin_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_arcsin\n"); return -1; } mlx_arcsinh_ptr = GET_SYM(handle, "mlx_arcsinh"); if (mlx_arcsinh_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_arcsinh\n"); return -1; } mlx_arctan_ptr = GET_SYM(handle, "mlx_arctan"); if (mlx_arctan_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_arctan\n"); return -1; } mlx_arctan2_ptr = GET_SYM(handle, "mlx_arctan2"); if (mlx_arctan2_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_arctan2\n"); return -1; } mlx_arctanh_ptr = GET_SYM(handle, "mlx_arctanh"); if (mlx_arctanh_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_arctanh\n"); return -1; } mlx_argmax_axis_ptr = GET_SYM(handle, "mlx_argmax_axis"); if (mlx_argmax_axis_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_argmax_axis\n"); return -1; } mlx_argmax_ptr = GET_SYM(handle, "mlx_argmax"); if (mlx_argmax_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_argmax\n"); return -1; } mlx_argmin_axis_ptr = GET_SYM(handle, "mlx_argmin_axis"); if (mlx_argmin_axis_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_argmin_axis\n"); return -1; } mlx_argmin_ptr = GET_SYM(handle, "mlx_argmin"); if (mlx_argmin_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_argmin\n"); return -1; } mlx_argpartition_axis_ptr = GET_SYM(handle, "mlx_argpartition_axis"); if (mlx_argpartition_axis_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_argpartition_axis\n"); return -1; } mlx_argpartition_ptr = GET_SYM(handle, "mlx_argpartition"); if (mlx_argpartition_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_argpartition\n"); return -1; } mlx_argsort_axis_ptr = GET_SYM(handle, "mlx_argsort_axis"); if (mlx_argsort_axis_ptr == NULL) { 
fprintf(stderr, "MLX: Failed to load symbol: mlx_argsort_axis\n"); return -1; } mlx_argsort_ptr = GET_SYM(handle, "mlx_argsort"); if (mlx_argsort_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_argsort\n"); return -1; } mlx_array_equal_ptr = GET_SYM(handle, "mlx_array_equal"); if (mlx_array_equal_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_array_equal\n"); return -1; } mlx_as_strided_ptr = GET_SYM(handle, "mlx_as_strided"); if (mlx_as_strided_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_as_strided\n"); return -1; } mlx_astype_ptr = GET_SYM(handle, "mlx_astype"); if (mlx_astype_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_astype\n"); return -1; } mlx_atleast_1d_ptr = GET_SYM(handle, "mlx_atleast_1d"); if (mlx_atleast_1d_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_atleast_1d\n"); return -1; } mlx_atleast_2d_ptr = GET_SYM(handle, "mlx_atleast_2d"); if (mlx_atleast_2d_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_atleast_2d\n"); return -1; } mlx_atleast_3d_ptr = GET_SYM(handle, "mlx_atleast_3d"); if (mlx_atleast_3d_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_atleast_3d\n"); return -1; } mlx_bartlett_ptr = GET_SYM(handle, "mlx_bartlett"); if (mlx_bartlett_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_bartlett\n"); return -1; } mlx_bitwise_and_ptr = GET_SYM(handle, "mlx_bitwise_and"); if (mlx_bitwise_and_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_bitwise_and\n"); return -1; } mlx_bitwise_invert_ptr = GET_SYM(handle, "mlx_bitwise_invert"); if (mlx_bitwise_invert_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_bitwise_invert\n"); return -1; } mlx_bitwise_or_ptr = GET_SYM(handle, "mlx_bitwise_or"); if (mlx_bitwise_or_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_bitwise_or\n"); return -1; } mlx_bitwise_xor_ptr = GET_SYM(handle, "mlx_bitwise_xor"); if (mlx_bitwise_xor_ptr 
== NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_bitwise_xor\n"); return -1; } mlx_blackman_ptr = GET_SYM(handle, "mlx_blackman"); if (mlx_blackman_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_blackman\n"); return -1; } mlx_block_masked_mm_ptr = GET_SYM(handle, "mlx_block_masked_mm"); if (mlx_block_masked_mm_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_block_masked_mm\n"); return -1; } mlx_broadcast_arrays_ptr = GET_SYM(handle, "mlx_broadcast_arrays"); if (mlx_broadcast_arrays_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_broadcast_arrays\n"); return -1; } mlx_broadcast_to_ptr = GET_SYM(handle, "mlx_broadcast_to"); if (mlx_broadcast_to_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_broadcast_to\n"); return -1; } mlx_ceil_ptr = GET_SYM(handle, "mlx_ceil"); if (mlx_ceil_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_ceil\n"); return -1; } mlx_clip_ptr = GET_SYM(handle, "mlx_clip"); if (mlx_clip_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_clip\n"); return -1; } mlx_concatenate_axis_ptr = GET_SYM(handle, "mlx_concatenate_axis"); if (mlx_concatenate_axis_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_concatenate_axis\n"); return -1; } mlx_concatenate_ptr = GET_SYM(handle, "mlx_concatenate"); if (mlx_concatenate_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_concatenate\n"); return -1; } mlx_conjugate_ptr = GET_SYM(handle, "mlx_conjugate"); if (mlx_conjugate_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_conjugate\n"); return -1; } mlx_contiguous_ptr = GET_SYM(handle, "mlx_contiguous"); if (mlx_contiguous_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_contiguous\n"); return -1; } mlx_conv1d_ptr = GET_SYM(handle, "mlx_conv1d"); if (mlx_conv1d_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_conv1d\n"); return -1; } mlx_conv2d_ptr = GET_SYM(handle, "mlx_conv2d"); if 
(mlx_conv2d_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_conv2d\n"); return -1; } mlx_conv3d_ptr = GET_SYM(handle, "mlx_conv3d"); if (mlx_conv3d_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_conv3d\n"); return -1; } mlx_conv_general_ptr = GET_SYM(handle, "mlx_conv_general"); if (mlx_conv_general_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_conv_general\n"); return -1; } mlx_conv_transpose1d_ptr = GET_SYM(handle, "mlx_conv_transpose1d"); if (mlx_conv_transpose1d_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_conv_transpose1d\n"); return -1; } mlx_conv_transpose2d_ptr = GET_SYM(handle, "mlx_conv_transpose2d"); if (mlx_conv_transpose2d_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_conv_transpose2d\n"); return -1; } mlx_conv_transpose3d_ptr = GET_SYM(handle, "mlx_conv_transpose3d"); if (mlx_conv_transpose3d_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_conv_transpose3d\n"); return -1; } mlx_copy_ptr = GET_SYM(handle, "mlx_copy"); if (mlx_copy_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_copy\n"); return -1; } mlx_cos_ptr = GET_SYM(handle, "mlx_cos"); if (mlx_cos_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_cos\n"); return -1; } mlx_cosh_ptr = GET_SYM(handle, "mlx_cosh"); if (mlx_cosh_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_cosh\n"); return -1; } mlx_cummax_ptr = GET_SYM(handle, "mlx_cummax"); if (mlx_cummax_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_cummax\n"); return -1; } mlx_cummin_ptr = GET_SYM(handle, "mlx_cummin"); if (mlx_cummin_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_cummin\n"); return -1; } mlx_cumprod_ptr = GET_SYM(handle, "mlx_cumprod"); if (mlx_cumprod_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_cumprod\n"); return -1; } mlx_cumsum_ptr = GET_SYM(handle, "mlx_cumsum"); if (mlx_cumsum_ptr == NULL) { fprintf(stderr, "MLX: Failed 
to load symbol: mlx_cumsum\n"); return -1; } mlx_degrees_ptr = GET_SYM(handle, "mlx_degrees"); if (mlx_degrees_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_degrees\n"); return -1; } mlx_depends_ptr = GET_SYM(handle, "mlx_depends"); if (mlx_depends_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_depends\n"); return -1; } mlx_dequantize_ptr = GET_SYM(handle, "mlx_dequantize"); if (mlx_dequantize_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_dequantize\n"); return -1; } mlx_diag_ptr = GET_SYM(handle, "mlx_diag"); if (mlx_diag_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_diag\n"); return -1; } mlx_diagonal_ptr = GET_SYM(handle, "mlx_diagonal"); if (mlx_diagonal_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_diagonal\n"); return -1; } mlx_divide_ptr = GET_SYM(handle, "mlx_divide"); if (mlx_divide_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_divide\n"); return -1; } mlx_divmod_ptr = GET_SYM(handle, "mlx_divmod"); if (mlx_divmod_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_divmod\n"); return -1; } mlx_einsum_ptr = GET_SYM(handle, "mlx_einsum"); if (mlx_einsum_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_einsum\n"); return -1; } mlx_equal_ptr = GET_SYM(handle, "mlx_equal"); if (mlx_equal_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_equal\n"); return -1; } mlx_erf_ptr = GET_SYM(handle, "mlx_erf"); if (mlx_erf_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_erf\n"); return -1; } mlx_erfinv_ptr = GET_SYM(handle, "mlx_erfinv"); if (mlx_erfinv_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_erfinv\n"); return -1; } mlx_exp_ptr = GET_SYM(handle, "mlx_exp"); if (mlx_exp_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_exp\n"); return -1; } mlx_expand_dims_axes_ptr = GET_SYM(handle, "mlx_expand_dims_axes"); if (mlx_expand_dims_axes_ptr == NULL) { fprintf(stderr, "MLX: 
Failed to load symbol: mlx_expand_dims_axes\n"); return -1; } mlx_expand_dims_ptr = GET_SYM(handle, "mlx_expand_dims"); if (mlx_expand_dims_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_expand_dims\n"); return -1; } mlx_expm1_ptr = GET_SYM(handle, "mlx_expm1"); if (mlx_expm1_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_expm1\n"); return -1; } mlx_eye_ptr = GET_SYM(handle, "mlx_eye"); if (mlx_eye_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_eye\n"); return -1; } mlx_flatten_ptr = GET_SYM(handle, "mlx_flatten"); if (mlx_flatten_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_flatten\n"); return -1; } mlx_floor_ptr = GET_SYM(handle, "mlx_floor"); if (mlx_floor_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_floor\n"); return -1; } mlx_floor_divide_ptr = GET_SYM(handle, "mlx_floor_divide"); if (mlx_floor_divide_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_floor_divide\n"); return -1; } mlx_from_fp8_ptr = GET_SYM(handle, "mlx_from_fp8"); if (mlx_from_fp8_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_from_fp8\n"); return -1; } mlx_full_ptr = GET_SYM(handle, "mlx_full"); if (mlx_full_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_full\n"); return -1; } mlx_full_like_ptr = GET_SYM(handle, "mlx_full_like"); if (mlx_full_like_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_full_like\n"); return -1; } mlx_gather_ptr = GET_SYM(handle, "mlx_gather"); if (mlx_gather_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_gather\n"); return -1; } mlx_gather_single_ptr = GET_SYM(handle, "mlx_gather_single"); if (mlx_gather_single_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_gather_single\n"); return -1; } mlx_gather_mm_ptr = GET_SYM(handle, "mlx_gather_mm"); if (mlx_gather_mm_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_gather_mm\n"); return -1; } mlx_gather_qmm_ptr = 
GET_SYM(handle, "mlx_gather_qmm"); if (mlx_gather_qmm_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_gather_qmm\n"); return -1; } mlx_greater_ptr = GET_SYM(handle, "mlx_greater"); if (mlx_greater_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_greater\n"); return -1; } mlx_greater_equal_ptr = GET_SYM(handle, "mlx_greater_equal"); if (mlx_greater_equal_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_greater_equal\n"); return -1; } mlx_hadamard_transform_ptr = GET_SYM(handle, "mlx_hadamard_transform"); if (mlx_hadamard_transform_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_hadamard_transform\n"); return -1; } mlx_hamming_ptr = GET_SYM(handle, "mlx_hamming"); if (mlx_hamming_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_hamming\n"); return -1; } mlx_hanning_ptr = GET_SYM(handle, "mlx_hanning"); if (mlx_hanning_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_hanning\n"); return -1; } mlx_identity_ptr = GET_SYM(handle, "mlx_identity"); if (mlx_identity_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_identity\n"); return -1; } mlx_imag_ptr = GET_SYM(handle, "mlx_imag"); if (mlx_imag_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_imag\n"); return -1; } mlx_inner_ptr = GET_SYM(handle, "mlx_inner"); if (mlx_inner_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_inner\n"); return -1; } mlx_isclose_ptr = GET_SYM(handle, "mlx_isclose"); if (mlx_isclose_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_isclose\n"); return -1; } mlx_isfinite_ptr = GET_SYM(handle, "mlx_isfinite"); if (mlx_isfinite_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_isfinite\n"); return -1; } mlx_isinf_ptr = GET_SYM(handle, "mlx_isinf"); if (mlx_isinf_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_isinf\n"); return -1; } mlx_isnan_ptr = GET_SYM(handle, "mlx_isnan"); if (mlx_isnan_ptr == NULL) { 
fprintf(stderr, "MLX: Failed to load symbol: mlx_isnan\n"); return -1; } mlx_isneginf_ptr = GET_SYM(handle, "mlx_isneginf"); if (mlx_isneginf_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_isneginf\n"); return -1; } mlx_isposinf_ptr = GET_SYM(handle, "mlx_isposinf"); if (mlx_isposinf_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_isposinf\n"); return -1; } mlx_kron_ptr = GET_SYM(handle, "mlx_kron"); if (mlx_kron_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_kron\n"); return -1; } mlx_left_shift_ptr = GET_SYM(handle, "mlx_left_shift"); if (mlx_left_shift_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_left_shift\n"); return -1; } mlx_less_ptr = GET_SYM(handle, "mlx_less"); if (mlx_less_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_less\n"); return -1; } mlx_less_equal_ptr = GET_SYM(handle, "mlx_less_equal"); if (mlx_less_equal_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_less_equal\n"); return -1; } mlx_linspace_ptr = GET_SYM(handle, "mlx_linspace"); if (mlx_linspace_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_linspace\n"); return -1; } mlx_log_ptr = GET_SYM(handle, "mlx_log"); if (mlx_log_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_log\n"); return -1; } mlx_log10_ptr = GET_SYM(handle, "mlx_log10"); if (mlx_log10_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_log10\n"); return -1; } mlx_log1p_ptr = GET_SYM(handle, "mlx_log1p"); if (mlx_log1p_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_log1p\n"); return -1; } mlx_log2_ptr = GET_SYM(handle, "mlx_log2"); if (mlx_log2_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_log2\n"); return -1; } mlx_logaddexp_ptr = GET_SYM(handle, "mlx_logaddexp"); if (mlx_logaddexp_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_logaddexp\n"); return -1; } mlx_logcumsumexp_ptr = GET_SYM(handle, "mlx_logcumsumexp"); if 
(mlx_logcumsumexp_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_logcumsumexp\n"); return -1; } mlx_logical_and_ptr = GET_SYM(handle, "mlx_logical_and"); if (mlx_logical_and_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_logical_and\n"); return -1; } mlx_logical_not_ptr = GET_SYM(handle, "mlx_logical_not"); if (mlx_logical_not_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_logical_not\n"); return -1; } mlx_logical_or_ptr = GET_SYM(handle, "mlx_logical_or"); if (mlx_logical_or_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_logical_or\n"); return -1; } mlx_logsumexp_axes_ptr = GET_SYM(handle, "mlx_logsumexp_axes"); if (mlx_logsumexp_axes_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_logsumexp_axes\n"); return -1; } mlx_logsumexp_axis_ptr = GET_SYM(handle, "mlx_logsumexp_axis"); if (mlx_logsumexp_axis_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_logsumexp_axis\n"); return -1; } mlx_logsumexp_ptr = GET_SYM(handle, "mlx_logsumexp"); if (mlx_logsumexp_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_logsumexp\n"); return -1; } mlx_masked_scatter_ptr = GET_SYM(handle, "mlx_masked_scatter"); if (mlx_masked_scatter_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_masked_scatter\n"); return -1; } mlx_matmul_ptr = GET_SYM(handle, "mlx_matmul"); if (mlx_matmul_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_matmul\n"); return -1; } mlx_max_axes_ptr = GET_SYM(handle, "mlx_max_axes"); if (mlx_max_axes_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_max_axes\n"); return -1; } mlx_max_axis_ptr = GET_SYM(handle, "mlx_max_axis"); if (mlx_max_axis_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_max_axis\n"); return -1; } mlx_max_ptr = GET_SYM(handle, "mlx_max"); if (mlx_max_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_max\n"); return -1; } mlx_maximum_ptr = GET_SYM(handle, 
"mlx_maximum"); if (mlx_maximum_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_maximum\n"); return -1; } mlx_mean_axes_ptr = GET_SYM(handle, "mlx_mean_axes"); if (mlx_mean_axes_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_mean_axes\n"); return -1; } mlx_mean_axis_ptr = GET_SYM(handle, "mlx_mean_axis"); if (mlx_mean_axis_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_mean_axis\n"); return -1; } mlx_mean_ptr = GET_SYM(handle, "mlx_mean"); if (mlx_mean_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_mean\n"); return -1; } mlx_median_ptr = GET_SYM(handle, "mlx_median"); if (mlx_median_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_median\n"); return -1; } mlx_meshgrid_ptr = GET_SYM(handle, "mlx_meshgrid"); if (mlx_meshgrid_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_meshgrid\n"); return -1; } mlx_min_axes_ptr = GET_SYM(handle, "mlx_min_axes"); if (mlx_min_axes_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_min_axes\n"); return -1; } mlx_min_axis_ptr = GET_SYM(handle, "mlx_min_axis"); if (mlx_min_axis_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_min_axis\n"); return -1; } mlx_min_ptr = GET_SYM(handle, "mlx_min"); if (mlx_min_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_min\n"); return -1; } mlx_minimum_ptr = GET_SYM(handle, "mlx_minimum"); if (mlx_minimum_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_minimum\n"); return -1; } mlx_moveaxis_ptr = GET_SYM(handle, "mlx_moveaxis"); if (mlx_moveaxis_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_moveaxis\n"); return -1; } mlx_multiply_ptr = GET_SYM(handle, "mlx_multiply"); if (mlx_multiply_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_multiply\n"); return -1; } mlx_nan_to_num_ptr = GET_SYM(handle, "mlx_nan_to_num"); if (mlx_nan_to_num_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_nan_to_num\n"); 
return -1; } mlx_negative_ptr = GET_SYM(handle, "mlx_negative"); if (mlx_negative_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_negative\n"); return -1; } mlx_not_equal_ptr = GET_SYM(handle, "mlx_not_equal"); if (mlx_not_equal_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_not_equal\n"); return -1; } mlx_number_of_elements_ptr = GET_SYM(handle, "mlx_number_of_elements"); if (mlx_number_of_elements_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_number_of_elements\n"); return -1; } mlx_ones_ptr = GET_SYM(handle, "mlx_ones"); if (mlx_ones_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_ones\n"); return -1; } mlx_ones_like_ptr = GET_SYM(handle, "mlx_ones_like"); if (mlx_ones_like_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_ones_like\n"); return -1; } mlx_outer_ptr = GET_SYM(handle, "mlx_outer"); if (mlx_outer_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_outer\n"); return -1; } mlx_pad_ptr = GET_SYM(handle, "mlx_pad"); if (mlx_pad_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_pad\n"); return -1; } mlx_pad_symmetric_ptr = GET_SYM(handle, "mlx_pad_symmetric"); if (mlx_pad_symmetric_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_pad_symmetric\n"); return -1; } mlx_partition_axis_ptr = GET_SYM(handle, "mlx_partition_axis"); if (mlx_partition_axis_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_partition_axis\n"); return -1; } mlx_partition_ptr = GET_SYM(handle, "mlx_partition"); if (mlx_partition_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_partition\n"); return -1; } mlx_power_ptr = GET_SYM(handle, "mlx_power"); if (mlx_power_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_power\n"); return -1; } mlx_prod_axes_ptr = GET_SYM(handle, "mlx_prod_axes"); if (mlx_prod_axes_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_prod_axes\n"); return -1; } mlx_prod_axis_ptr = 
GET_SYM(handle, "mlx_prod_axis"); if (mlx_prod_axis_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_prod_axis\n"); return -1; } mlx_prod_ptr = GET_SYM(handle, "mlx_prod"); if (mlx_prod_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_prod\n"); return -1; } mlx_put_along_axis_ptr = GET_SYM(handle, "mlx_put_along_axis"); if (mlx_put_along_axis_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_put_along_axis\n"); return -1; } mlx_qqmm_ptr = GET_SYM(handle, "mlx_qqmm"); if (mlx_qqmm_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_qqmm\n"); return -1; } mlx_quantize_ptr = GET_SYM(handle, "mlx_quantize"); if (mlx_quantize_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_quantize\n"); return -1; } mlx_quantized_matmul_ptr = GET_SYM(handle, "mlx_quantized_matmul"); if (mlx_quantized_matmul_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_quantized_matmul\n"); return -1; } mlx_radians_ptr = GET_SYM(handle, "mlx_radians"); if (mlx_radians_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_radians\n"); return -1; } mlx_real_ptr = GET_SYM(handle, "mlx_real"); if (mlx_real_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_real\n"); return -1; } mlx_reciprocal_ptr = GET_SYM(handle, "mlx_reciprocal"); if (mlx_reciprocal_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_reciprocal\n"); return -1; } mlx_remainder_ptr = GET_SYM(handle, "mlx_remainder"); if (mlx_remainder_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_remainder\n"); return -1; } mlx_repeat_axis_ptr = GET_SYM(handle, "mlx_repeat_axis"); if (mlx_repeat_axis_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_repeat_axis\n"); return -1; } mlx_repeat_ptr = GET_SYM(handle, "mlx_repeat"); if (mlx_repeat_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_repeat\n"); return -1; } mlx_reshape_ptr = GET_SYM(handle, "mlx_reshape"); if (mlx_reshape_ptr == 
NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_reshape\n"); return -1; } mlx_right_shift_ptr = GET_SYM(handle, "mlx_right_shift"); if (mlx_right_shift_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_right_shift\n"); return -1; } mlx_roll_axis_ptr = GET_SYM(handle, "mlx_roll_axis"); if (mlx_roll_axis_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_roll_axis\n"); return -1; } mlx_roll_axes_ptr = GET_SYM(handle, "mlx_roll_axes"); if (mlx_roll_axes_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_roll_axes\n"); return -1; } mlx_roll_ptr = GET_SYM(handle, "mlx_roll"); if (mlx_roll_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_roll\n"); return -1; } mlx_round_ptr = GET_SYM(handle, "mlx_round"); if (mlx_round_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_round\n"); return -1; } mlx_rsqrt_ptr = GET_SYM(handle, "mlx_rsqrt"); if (mlx_rsqrt_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_rsqrt\n"); return -1; } mlx_scatter_ptr = GET_SYM(handle, "mlx_scatter"); if (mlx_scatter_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_scatter\n"); return -1; } mlx_scatter_single_ptr = GET_SYM(handle, "mlx_scatter_single"); if (mlx_scatter_single_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_scatter_single\n"); return -1; } mlx_scatter_add_ptr = GET_SYM(handle, "mlx_scatter_add"); if (mlx_scatter_add_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_scatter_add\n"); return -1; } mlx_scatter_add_single_ptr = GET_SYM(handle, "mlx_scatter_add_single"); if (mlx_scatter_add_single_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_scatter_add_single\n"); return -1; } mlx_scatter_add_axis_ptr = GET_SYM(handle, "mlx_scatter_add_axis"); if (mlx_scatter_add_axis_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_scatter_add_axis\n"); return -1; } mlx_scatter_max_ptr = GET_SYM(handle, "mlx_scatter_max"); if 
(mlx_scatter_max_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_scatter_max\n"); return -1; } mlx_scatter_max_single_ptr = GET_SYM(handle, "mlx_scatter_max_single"); if (mlx_scatter_max_single_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_scatter_max_single\n"); return -1; } mlx_scatter_min_ptr = GET_SYM(handle, "mlx_scatter_min"); if (mlx_scatter_min_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_scatter_min\n"); return -1; } mlx_scatter_min_single_ptr = GET_SYM(handle, "mlx_scatter_min_single"); if (mlx_scatter_min_single_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_scatter_min_single\n"); return -1; } mlx_scatter_prod_ptr = GET_SYM(handle, "mlx_scatter_prod"); if (mlx_scatter_prod_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_scatter_prod\n"); return -1; } mlx_scatter_prod_single_ptr = GET_SYM(handle, "mlx_scatter_prod_single"); if (mlx_scatter_prod_single_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_scatter_prod_single\n"); return -1; } mlx_segmented_mm_ptr = GET_SYM(handle, "mlx_segmented_mm"); if (mlx_segmented_mm_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_segmented_mm\n"); return -1; } mlx_sigmoid_ptr = GET_SYM(handle, "mlx_sigmoid"); if (mlx_sigmoid_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_sigmoid\n"); return -1; } mlx_sign_ptr = GET_SYM(handle, "mlx_sign"); if (mlx_sign_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_sign\n"); return -1; } mlx_sin_ptr = GET_SYM(handle, "mlx_sin"); if (mlx_sin_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_sin\n"); return -1; } mlx_sinh_ptr = GET_SYM(handle, "mlx_sinh"); if (mlx_sinh_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_sinh\n"); return -1; } mlx_slice_ptr = GET_SYM(handle, "mlx_slice"); if (mlx_slice_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_slice\n"); return -1; } mlx_slice_dynamic_ptr = 
GET_SYM(handle, "mlx_slice_dynamic"); if (mlx_slice_dynamic_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_slice_dynamic\n"); return -1; } mlx_slice_update_ptr = GET_SYM(handle, "mlx_slice_update"); if (mlx_slice_update_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_slice_update\n"); return -1; } mlx_slice_update_dynamic_ptr = GET_SYM(handle, "mlx_slice_update_dynamic"); if (mlx_slice_update_dynamic_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_slice_update_dynamic\n"); return -1; } mlx_softmax_axes_ptr = GET_SYM(handle, "mlx_softmax_axes"); if (mlx_softmax_axes_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_softmax_axes\n"); return -1; } mlx_softmax_axis_ptr = GET_SYM(handle, "mlx_softmax_axis"); if (mlx_softmax_axis_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_softmax_axis\n"); return -1; } mlx_softmax_ptr = GET_SYM(handle, "mlx_softmax"); if (mlx_softmax_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_softmax\n"); return -1; } mlx_sort_axis_ptr = GET_SYM(handle, "mlx_sort_axis"); if (mlx_sort_axis_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_sort_axis\n"); return -1; } mlx_sort_ptr = GET_SYM(handle, "mlx_sort"); if (mlx_sort_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_sort\n"); return -1; } mlx_split_ptr = GET_SYM(handle, "mlx_split"); if (mlx_split_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_split\n"); return -1; } mlx_split_sections_ptr = GET_SYM(handle, "mlx_split_sections"); if (mlx_split_sections_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_split_sections\n"); return -1; } mlx_sqrt_ptr = GET_SYM(handle, "mlx_sqrt"); if (mlx_sqrt_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_sqrt\n"); return -1; } mlx_square_ptr = GET_SYM(handle, "mlx_square"); if (mlx_square_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_square\n"); return -1; } 
mlx_squeeze_axes_ptr = GET_SYM(handle, "mlx_squeeze_axes"); if (mlx_squeeze_axes_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_squeeze_axes\n"); return -1; } mlx_squeeze_axis_ptr = GET_SYM(handle, "mlx_squeeze_axis"); if (mlx_squeeze_axis_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_squeeze_axis\n"); return -1; } mlx_squeeze_ptr = GET_SYM(handle, "mlx_squeeze"); if (mlx_squeeze_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_squeeze\n"); return -1; } mlx_stack_axis_ptr = GET_SYM(handle, "mlx_stack_axis"); if (mlx_stack_axis_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_stack_axis\n"); return -1; } mlx_stack_ptr = GET_SYM(handle, "mlx_stack"); if (mlx_stack_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_stack\n"); return -1; } mlx_std_axes_ptr = GET_SYM(handle, "mlx_std_axes"); if (mlx_std_axes_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_std_axes\n"); return -1; } mlx_std_axis_ptr = GET_SYM(handle, "mlx_std_axis"); if (mlx_std_axis_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_std_axis\n"); return -1; } mlx_std_ptr = GET_SYM(handle, "mlx_std"); if (mlx_std_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_std\n"); return -1; } mlx_stop_gradient_ptr = GET_SYM(handle, "mlx_stop_gradient"); if (mlx_stop_gradient_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_stop_gradient\n"); return -1; } mlx_subtract_ptr = GET_SYM(handle, "mlx_subtract"); if (mlx_subtract_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_subtract\n"); return -1; } mlx_sum_axes_ptr = GET_SYM(handle, "mlx_sum_axes"); if (mlx_sum_axes_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_sum_axes\n"); return -1; } mlx_sum_axis_ptr = GET_SYM(handle, "mlx_sum_axis"); if (mlx_sum_axis_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_sum_axis\n"); return -1; } mlx_sum_ptr = GET_SYM(handle, "mlx_sum"); if 
(mlx_sum_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_sum\n"); return -1; } mlx_swapaxes_ptr = GET_SYM(handle, "mlx_swapaxes"); if (mlx_swapaxes_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_swapaxes\n"); return -1; } mlx_take_axis_ptr = GET_SYM(handle, "mlx_take_axis"); if (mlx_take_axis_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_take_axis\n"); return -1; } mlx_take_ptr = GET_SYM(handle, "mlx_take"); if (mlx_take_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_take\n"); return -1; } mlx_take_along_axis_ptr = GET_SYM(handle, "mlx_take_along_axis"); if (mlx_take_along_axis_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_take_along_axis\n"); return -1; } mlx_tan_ptr = GET_SYM(handle, "mlx_tan"); if (mlx_tan_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_tan\n"); return -1; } mlx_tanh_ptr = GET_SYM(handle, "mlx_tanh"); if (mlx_tanh_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_tanh\n"); return -1; } mlx_tensordot_ptr = GET_SYM(handle, "mlx_tensordot"); if (mlx_tensordot_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_tensordot\n"); return -1; } mlx_tensordot_axis_ptr = GET_SYM(handle, "mlx_tensordot_axis"); if (mlx_tensordot_axis_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_tensordot_axis\n"); return -1; } mlx_tile_ptr = GET_SYM(handle, "mlx_tile"); if (mlx_tile_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_tile\n"); return -1; } mlx_to_fp8_ptr = GET_SYM(handle, "mlx_to_fp8"); if (mlx_to_fp8_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_to_fp8\n"); return -1; } mlx_topk_axis_ptr = GET_SYM(handle, "mlx_topk_axis"); if (mlx_topk_axis_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_topk_axis\n"); return -1; } mlx_topk_ptr = GET_SYM(handle, "mlx_topk"); if (mlx_topk_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_topk\n"); return -1; } 
mlx_trace_ptr = GET_SYM(handle, "mlx_trace"); if (mlx_trace_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_trace\n"); return -1; } mlx_transpose_axes_ptr = GET_SYM(handle, "mlx_transpose_axes"); if (mlx_transpose_axes_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_transpose_axes\n"); return -1; } mlx_transpose_ptr = GET_SYM(handle, "mlx_transpose"); if (mlx_transpose_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_transpose\n"); return -1; } mlx_tri_ptr = GET_SYM(handle, "mlx_tri"); if (mlx_tri_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_tri\n"); return -1; } mlx_tril_ptr = GET_SYM(handle, "mlx_tril"); if (mlx_tril_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_tril\n"); return -1; } mlx_triu_ptr = GET_SYM(handle, "mlx_triu"); if (mlx_triu_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_triu\n"); return -1; } mlx_unflatten_ptr = GET_SYM(handle, "mlx_unflatten"); if (mlx_unflatten_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_unflatten\n"); return -1; } mlx_var_axes_ptr = GET_SYM(handle, "mlx_var_axes"); if (mlx_var_axes_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_var_axes\n"); return -1; } mlx_var_axis_ptr = GET_SYM(handle, "mlx_var_axis"); if (mlx_var_axis_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_var_axis\n"); return -1; } mlx_var_ptr = GET_SYM(handle, "mlx_var"); if (mlx_var_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_var\n"); return -1; } mlx_view_ptr = GET_SYM(handle, "mlx_view"); if (mlx_view_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_view\n"); return -1; } mlx_where_ptr = GET_SYM(handle, "mlx_where"); if (mlx_where_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_where\n"); return -1; } mlx_zeros_ptr = GET_SYM(handle, "mlx_zeros"); if (mlx_zeros_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_zeros\n"); return -1; } 
mlx_zeros_like_ptr = GET_SYM(handle, "mlx_zeros_like"); if (mlx_zeros_like_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_zeros_like\n"); return -1; } mlx_random_bernoulli_ptr = GET_SYM(handle, "mlx_random_bernoulli"); if (mlx_random_bernoulli_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_random_bernoulli\n"); return -1; } mlx_random_bits_ptr = GET_SYM(handle, "mlx_random_bits"); if (mlx_random_bits_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_random_bits\n"); return -1; } mlx_random_categorical_shape_ptr = GET_SYM(handle, "mlx_random_categorical_shape"); if (mlx_random_categorical_shape_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_random_categorical_shape\n"); return -1; } mlx_random_categorical_num_samples_ptr = GET_SYM(handle, "mlx_random_categorical_num_samples"); if (mlx_random_categorical_num_samples_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_random_categorical_num_samples\n"); return -1; } mlx_random_categorical_ptr = GET_SYM(handle, "mlx_random_categorical"); if (mlx_random_categorical_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_random_categorical\n"); return -1; } mlx_random_gumbel_ptr = GET_SYM(handle, "mlx_random_gumbel"); if (mlx_random_gumbel_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_random_gumbel\n"); return -1; } mlx_random_key_ptr = GET_SYM(handle, "mlx_random_key"); if (mlx_random_key_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_random_key\n"); return -1; } mlx_random_laplace_ptr = GET_SYM(handle, "mlx_random_laplace"); if (mlx_random_laplace_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_random_laplace\n"); return -1; } mlx_random_multivariate_normal_ptr = GET_SYM(handle, "mlx_random_multivariate_normal"); if (mlx_random_multivariate_normal_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_random_multivariate_normal\n"); return -1; } mlx_random_normal_broadcast_ptr 
= GET_SYM(handle, "mlx_random_normal_broadcast"); if (mlx_random_normal_broadcast_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_random_normal_broadcast\n"); return -1; } mlx_random_normal_ptr = GET_SYM(handle, "mlx_random_normal"); if (mlx_random_normal_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_random_normal\n"); return -1; } mlx_random_permutation_ptr = GET_SYM(handle, "mlx_random_permutation"); if (mlx_random_permutation_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_random_permutation\n"); return -1; } mlx_random_permutation_arange_ptr = GET_SYM(handle, "mlx_random_permutation_arange"); if (mlx_random_permutation_arange_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_random_permutation_arange\n"); return -1; } mlx_random_randint_ptr = GET_SYM(handle, "mlx_random_randint"); if (mlx_random_randint_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_random_randint\n"); return -1; } mlx_random_seed_ptr = GET_SYM(handle, "mlx_random_seed"); if (mlx_random_seed_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_random_seed\n"); return -1; } mlx_random_split_num_ptr = GET_SYM(handle, "mlx_random_split_num"); if (mlx_random_split_num_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_random_split_num\n"); return -1; } mlx_random_split_ptr = GET_SYM(handle, "mlx_random_split"); if (mlx_random_split_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_random_split\n"); return -1; } mlx_random_truncated_normal_ptr = GET_SYM(handle, "mlx_random_truncated_normal"); if (mlx_random_truncated_normal_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_random_truncated_normal\n"); return -1; } mlx_random_uniform_ptr = GET_SYM(handle, "mlx_random_uniform"); if (mlx_random_uniform_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_random_uniform\n"); return -1; } mlx_stream_new_ptr = GET_SYM(handle, "mlx_stream_new"); if 
(mlx_stream_new_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_stream_new\n"); return -1; } mlx_stream_new_device_ptr = GET_SYM(handle, "mlx_stream_new_device"); if (mlx_stream_new_device_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_stream_new_device\n"); return -1; } mlx_stream_set_ptr = GET_SYM(handle, "mlx_stream_set"); if (mlx_stream_set_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_stream_set\n"); return -1; } mlx_stream_free_ptr = GET_SYM(handle, "mlx_stream_free"); if (mlx_stream_free_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_stream_free\n"); return -1; } mlx_stream_tostring_ptr = GET_SYM(handle, "mlx_stream_tostring"); if (mlx_stream_tostring_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_stream_tostring\n"); return -1; } mlx_stream_equal_ptr = GET_SYM(handle, "mlx_stream_equal"); if (mlx_stream_equal_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_stream_equal\n"); return -1; } mlx_stream_get_device_ptr = GET_SYM(handle, "mlx_stream_get_device"); if (mlx_stream_get_device_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_stream_get_device\n"); return -1; } mlx_stream_get_index_ptr = GET_SYM(handle, "mlx_stream_get_index"); if (mlx_stream_get_index_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_stream_get_index\n"); return -1; } mlx_synchronize_ptr = GET_SYM(handle, "mlx_synchronize"); if (mlx_synchronize_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_synchronize\n"); return -1; } mlx_get_default_stream_ptr = GET_SYM(handle, "mlx_get_default_stream"); if (mlx_get_default_stream_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_get_default_stream\n"); return -1; } mlx_set_default_stream_ptr = GET_SYM(handle, "mlx_set_default_stream"); if (mlx_set_default_stream_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_set_default_stream\n"); return -1; } mlx_default_cpu_stream_new_ptr 
= GET_SYM(handle, "mlx_default_cpu_stream_new"); if (mlx_default_cpu_stream_new_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_default_cpu_stream_new\n"); return -1; } mlx_default_gpu_stream_new_ptr = GET_SYM(handle, "mlx_default_gpu_stream_new"); if (mlx_default_gpu_stream_new_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_default_gpu_stream_new\n"); return -1; } mlx_string_new_ptr = GET_SYM(handle, "mlx_string_new"); if (mlx_string_new_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_string_new\n"); return -1; } mlx_string_new_data_ptr = GET_SYM(handle, "mlx_string_new_data"); if (mlx_string_new_data_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_string_new_data\n"); return -1; } mlx_string_set_ptr = GET_SYM(handle, "mlx_string_set"); if (mlx_string_set_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_string_set\n"); return -1; } mlx_string_data_ptr = GET_SYM(handle, "mlx_string_data"); if (mlx_string_data_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_string_data\n"); return -1; } mlx_string_free_ptr = GET_SYM(handle, "mlx_string_free"); if (mlx_string_free_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_string_free\n"); return -1; } mlx_async_eval_ptr = GET_SYM(handle, "mlx_async_eval"); if (mlx_async_eval_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_async_eval\n"); return -1; } mlx_checkpoint_ptr = GET_SYM(handle, "mlx_checkpoint"); if (mlx_checkpoint_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_checkpoint\n"); return -1; } mlx_custom_function_ptr = GET_SYM(handle, "mlx_custom_function"); if (mlx_custom_function_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_custom_function\n"); return -1; } mlx_custom_vjp_ptr = GET_SYM(handle, "mlx_custom_vjp"); if (mlx_custom_vjp_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_custom_vjp\n"); return -1; } mlx_eval_ptr = GET_SYM(handle, 
"mlx_eval"); if (mlx_eval_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_eval\n"); return -1; }

  // Tail of the generated symbol loader (its function header lies earlier in the file).
  // Every entry repeats one pattern: resolve the exported name from `handle` with
  // GET_SYM (dlsym on POSIX, GetProcAddress on Windows) into the matching *_ptr
  // global; if the result is NULL, print the missing name to stderr and return -1
  // immediately. Pointers resolved before a failing symbol are left set; nothing
  // here resets them.
  // NOTE(review): GET_SYM yields a void*, and storing that in a function pointer
  // relies on a widely supported extension (required by POSIX dlsym) rather than
  // strict ISO C.

  // mlx_jvp, mlx_value_and_grad, mlx_vjp and the mlx_detail_vmap_* entry points.
  mlx_jvp_ptr = GET_SYM(handle, "mlx_jvp");
  if (mlx_jvp_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_jvp\n"); return -1; }
  mlx_value_and_grad_ptr = GET_SYM(handle, "mlx_value_and_grad");
  if (mlx_value_and_grad_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_value_and_grad\n"); return -1; }
  mlx_vjp_ptr = GET_SYM(handle, "mlx_vjp");
  if (mlx_vjp_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_vjp\n"); return -1; }
  mlx_detail_vmap_replace_ptr = GET_SYM(handle, "mlx_detail_vmap_replace");
  if (mlx_detail_vmap_replace_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_detail_vmap_replace\n"); return -1; }
  mlx_detail_vmap_trace_ptr = GET_SYM(handle, "mlx_detail_vmap_trace");
  if (mlx_detail_vmap_trace_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_detail_vmap_trace\n"); return -1; }

  // mlx_vector_array_* entry points.
  mlx_vector_array_new_ptr = GET_SYM(handle, "mlx_vector_array_new");
  if (mlx_vector_array_new_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_array_new\n"); return -1; }
  mlx_vector_array_set_ptr = GET_SYM(handle, "mlx_vector_array_set");
  if (mlx_vector_array_set_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_array_set\n"); return -1; }
  mlx_vector_array_free_ptr = GET_SYM(handle, "mlx_vector_array_free");
  if (mlx_vector_array_free_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_array_free\n"); return -1; }
  mlx_vector_array_new_data_ptr = GET_SYM(handle, "mlx_vector_array_new_data");
  if (mlx_vector_array_new_data_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_array_new_data\n"); return -1; }
  mlx_vector_array_new_value_ptr = GET_SYM(handle, "mlx_vector_array_new_value");
  if (mlx_vector_array_new_value_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_array_new_value\n"); return -1; }
  mlx_vector_array_set_data_ptr = GET_SYM(handle, "mlx_vector_array_set_data");
  if (mlx_vector_array_set_data_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_array_set_data\n"); return -1; }
  mlx_vector_array_set_value_ptr = GET_SYM(handle, "mlx_vector_array_set_value");
  if (mlx_vector_array_set_value_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_array_set_value\n"); return -1; }
  mlx_vector_array_append_data_ptr = GET_SYM(handle, "mlx_vector_array_append_data");
  if (mlx_vector_array_append_data_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_array_append_data\n"); return -1; }
  mlx_vector_array_append_value_ptr = GET_SYM(handle, "mlx_vector_array_append_value");
  if (mlx_vector_array_append_value_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_array_append_value\n"); return -1; }
  mlx_vector_array_size_ptr = GET_SYM(handle, "mlx_vector_array_size");
  if (mlx_vector_array_size_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_array_size\n"); return -1; }
  mlx_vector_array_get_ptr = GET_SYM(handle, "mlx_vector_array_get");
  if (mlx_vector_array_get_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_array_get\n"); return -1; }

  // mlx_vector_vector_array_* entry points.
  mlx_vector_vector_array_new_ptr = GET_SYM(handle, "mlx_vector_vector_array_new");
  if (mlx_vector_vector_array_new_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_vector_array_new\n"); return -1; }
  mlx_vector_vector_array_set_ptr = GET_SYM(handle, "mlx_vector_vector_array_set");
  if (mlx_vector_vector_array_set_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_vector_array_set\n"); return -1; }
  mlx_vector_vector_array_free_ptr = GET_SYM(handle, "mlx_vector_vector_array_free");
  if (mlx_vector_vector_array_free_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_vector_array_free\n"); return -1; }
  mlx_vector_vector_array_new_data_ptr = GET_SYM(handle, "mlx_vector_vector_array_new_data");
  if (mlx_vector_vector_array_new_data_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_vector_array_new_data\n"); return -1; }
  mlx_vector_vector_array_new_value_ptr = GET_SYM(handle, "mlx_vector_vector_array_new_value");
  if (mlx_vector_vector_array_new_value_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_vector_array_new_value\n"); return -1; }
  mlx_vector_vector_array_set_data_ptr = GET_SYM(handle, "mlx_vector_vector_array_set_data");
  if (mlx_vector_vector_array_set_data_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_vector_array_set_data\n"); return -1; }
  mlx_vector_vector_array_set_value_ptr = GET_SYM(handle, "mlx_vector_vector_array_set_value");
  if (mlx_vector_vector_array_set_value_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_vector_array_set_value\n"); return -1; }
  mlx_vector_vector_array_append_data_ptr = GET_SYM(handle, "mlx_vector_vector_array_append_data");
  if (mlx_vector_vector_array_append_data_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_vector_array_append_data\n"); return -1; }
  mlx_vector_vector_array_append_value_ptr = GET_SYM(handle, "mlx_vector_vector_array_append_value");
  if (mlx_vector_vector_array_append_value_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_vector_array_append_value\n"); return -1; }
  mlx_vector_vector_array_size_ptr = GET_SYM(handle, "mlx_vector_vector_array_size");
  if (mlx_vector_vector_array_size_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_vector_array_size\n"); return -1; }
  mlx_vector_vector_array_get_ptr = GET_SYM(handle, "mlx_vector_vector_array_get");
  if (mlx_vector_vector_array_get_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_vector_array_get\n"); return -1; }

  // mlx_vector_int_* entry points.
  mlx_vector_int_new_ptr = GET_SYM(handle, "mlx_vector_int_new");
  if (mlx_vector_int_new_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_int_new\n"); return -1; }
  mlx_vector_int_set_ptr = GET_SYM(handle, "mlx_vector_int_set");
  if (mlx_vector_int_set_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_int_set\n"); return -1; }
  mlx_vector_int_free_ptr = GET_SYM(handle, "mlx_vector_int_free");
  if (mlx_vector_int_free_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_int_free\n"); return -1; }
  mlx_vector_int_new_data_ptr = GET_SYM(handle, "mlx_vector_int_new_data");
  if (mlx_vector_int_new_data_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_int_new_data\n"); return -1; }
  mlx_vector_int_new_value_ptr = GET_SYM(handle, "mlx_vector_int_new_value");
  if (mlx_vector_int_new_value_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_int_new_value\n"); return -1; }
  mlx_vector_int_set_data_ptr = GET_SYM(handle, "mlx_vector_int_set_data");
  if (mlx_vector_int_set_data_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_int_set_data\n"); return -1; }
  mlx_vector_int_set_value_ptr = GET_SYM(handle, "mlx_vector_int_set_value");
  if (mlx_vector_int_set_value_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_int_set_value\n"); return -1; }
  mlx_vector_int_append_data_ptr = GET_SYM(handle, "mlx_vector_int_append_data");
  if (mlx_vector_int_append_data_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_int_append_data\n"); return -1; }
  mlx_vector_int_append_value_ptr = GET_SYM(handle, "mlx_vector_int_append_value");
  if (mlx_vector_int_append_value_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_int_append_value\n"); return -1; }
  mlx_vector_int_size_ptr = GET_SYM(handle, "mlx_vector_int_size");
  if (mlx_vector_int_size_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_int_size\n"); return -1; }
  mlx_vector_int_get_ptr = GET_SYM(handle, "mlx_vector_int_get");
  if (mlx_vector_int_get_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_int_get\n"); return -1; }

  // mlx_vector_string_* entry points (the size/get entries follow these lines).
  mlx_vector_string_new_ptr = GET_SYM(handle, "mlx_vector_string_new");
  if (mlx_vector_string_new_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_string_new\n"); return -1; }
  mlx_vector_string_set_ptr = GET_SYM(handle, "mlx_vector_string_set");
  if (mlx_vector_string_set_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_string_set\n"); return -1; }
  mlx_vector_string_free_ptr = GET_SYM(handle, "mlx_vector_string_free");
  if (mlx_vector_string_free_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_string_free\n"); return -1; }
  mlx_vector_string_new_data_ptr = GET_SYM(handle, "mlx_vector_string_new_data");
  if (mlx_vector_string_new_data_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_string_new_data\n"); return -1; }
  mlx_vector_string_new_value_ptr = GET_SYM(handle, "mlx_vector_string_new_value");
  if (mlx_vector_string_new_value_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_string_new_value\n"); return -1; }
  mlx_vector_string_set_data_ptr = GET_SYM(handle, "mlx_vector_string_set_data");
  if (mlx_vector_string_set_data_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_string_set_data\n"); return -1; }
  mlx_vector_string_set_value_ptr = GET_SYM(handle, "mlx_vector_string_set_value");
  if (mlx_vector_string_set_value_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_string_set_value\n"); return -1; }
  mlx_vector_string_append_data_ptr = GET_SYM(handle, "mlx_vector_string_append_data");
  if (mlx_vector_string_append_data_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_string_append_data\n"); return -1; }
  mlx_vector_string_append_value_ptr = GET_SYM(handle, "mlx_vector_string_append_value");
  if (mlx_vector_string_append_value_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_string_append_value\n"); return -1; }
mlx_vector_string_size_ptr = GET_SYM(handle, "mlx_vector_string_size");
  if (mlx_vector_string_size_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_string_size\n"); return -1; }
  mlx_vector_string_get_ptr = GET_SYM(handle, "mlx_vector_string_get");
  if (mlx_vector_string_get_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_string_get\n"); return -1; }
  mlx_version_ptr = GET_SYM(handle, "mlx_version");
  if (mlx_version_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_version\n"); return -1; }
  // Every symbol was resolved: report success to the caller.
  return 0;
}

// Wrapper function implementations that call through function pointers.
// Each wrapper keeps the mlx-c signature, forwards its arguments unchanged to the
// matching *_ptr global filled in by the symbol loader, and returns that call's
// result. No wrapper checks its pointer: the pointers are initialised to NULL, so
// calling any wrapper before the loader has returned 0 dereferences NULL.

// mlx_dtype_size, mlx_array_tostring, mlx_array_new*, mlx_array_free.
size_t mlx_dtype_size(mlx_dtype dtype) { return mlx_dtype_size_ptr(dtype); }
int mlx_array_tostring(mlx_string* str, const mlx_array arr) { return mlx_array_tostring_ptr(str, arr); }
mlx_array mlx_array_new(void) { return mlx_array_new_ptr(); }
int mlx_array_free(mlx_array arr) { return mlx_array_free_ptr(arr); }
mlx_array mlx_array_new_bool(bool val) { return mlx_array_new_bool_ptr(val); }
mlx_array mlx_array_new_int(int val) { return mlx_array_new_int_ptr(val); }
mlx_array mlx_array_new_float32(float val) { return mlx_array_new_float32_ptr(val); }
mlx_array mlx_array_new_float(float val) { return mlx_array_new_float_ptr(val); }
mlx_array mlx_array_new_float64(double val) { return mlx_array_new_float64_ptr(val); }
mlx_array mlx_array_new_double(double val) { return mlx_array_new_double_ptr(val); }
mlx_array mlx_array_new_complex(float real_val, float imag_val) { return mlx_array_new_complex_ptr(real_val, imag_val); }
mlx_array mlx_array_new_data(const void* data, const int* shape, int dim, mlx_dtype dtype) { return mlx_array_new_data_ptr(data, shape, dim, dtype); }
mlx_array mlx_array_new_data_managed(void* data, const int* shape, int dim, mlx_dtype dtype, void (*dtor)(void*)) { return mlx_array_new_data_managed_ptr(data, shape, dim, dtype, dtor); }
mlx_array mlx_array_new_data_managed_payload(void* data, const int* shape, int dim, mlx_dtype dtype, void* payload, void (*dtor)(void*)) { return mlx_array_new_data_managed_payload_ptr(data, shape, dim, dtype, payload, dtor); }

// mlx_array_set* forwarders.
int mlx_array_set(mlx_array* arr, const mlx_array src) { return mlx_array_set_ptr(arr, src); }
int mlx_array_set_bool(mlx_array* arr, bool val) { return mlx_array_set_bool_ptr(arr, val); }
int mlx_array_set_int(mlx_array* arr, int val) { return mlx_array_set_int_ptr(arr, val); }
int mlx_array_set_float32(mlx_array* arr, float val) { return mlx_array_set_float32_ptr(arr, val); }
int mlx_array_set_float(mlx_array* arr, float val) { return mlx_array_set_float_ptr(arr, val); }
int mlx_array_set_float64(mlx_array* arr, double val) { return mlx_array_set_float64_ptr(arr, val); }
int mlx_array_set_double(mlx_array* arr, double val) { return mlx_array_set_double_ptr(arr, val); }
int mlx_array_set_complex(mlx_array* arr, float real_val, float imag_val) { return mlx_array_set_complex_ptr(arr, real_val, imag_val); }
int mlx_array_set_data(mlx_array* arr, const void* data, const int* shape, int dim, mlx_dtype dtype) { return mlx_array_set_data_ptr(arr, data, shape, dim, dtype); }

// mlx_array_itemsize ... mlx_array_dtype queries and mlx_array_eval.
size_t mlx_array_itemsize(const mlx_array arr) { return mlx_array_itemsize_ptr(arr); }
size_t mlx_array_size(const mlx_array arr) { return mlx_array_size_ptr(arr); }
size_t mlx_array_nbytes(const mlx_array arr) { return mlx_array_nbytes_ptr(arr); }
size_t mlx_array_ndim(const mlx_array arr) { return mlx_array_ndim_ptr(arr); }
const int* mlx_array_shape(const mlx_array arr) { return mlx_array_shape_ptr(arr); }
const size_t* mlx_array_strides(const mlx_array arr) { return mlx_array_strides_ptr(arr); }
int mlx_array_dim(const mlx_array arr, int dim) { return mlx_array_dim_ptr(arr, dim); }
mlx_dtype mlx_array_dtype(const mlx_array arr) { return mlx_array_dtype_ptr(arr); }
int mlx_array_eval(mlx_array arr) { return mlx_array_eval_ptr(arr); }

// mlx_array_item_* forwarders. The float16/bfloat16 variants are compiled only when
// __aarch64__ or _M_ARM64 is defined, the targets on which this file uses the
// float16_t/bfloat16_t types.
int mlx_array_item_bool(bool* res, const mlx_array arr) { return mlx_array_item_bool_ptr(res, arr); }
int mlx_array_item_uint8(uint8_t* res, const mlx_array arr) { return mlx_array_item_uint8_ptr(res, arr); }
int mlx_array_item_uint16(uint16_t* res, const mlx_array arr) { return mlx_array_item_uint16_ptr(res, arr); }
int mlx_array_item_uint32(uint32_t* res, const mlx_array arr) { return mlx_array_item_uint32_ptr(res, arr); }
int mlx_array_item_uint64(uint64_t* res, const mlx_array arr) { return mlx_array_item_uint64_ptr(res, arr); }
int mlx_array_item_int8(int8_t* res, const mlx_array arr) { return mlx_array_item_int8_ptr(res, arr); }
int mlx_array_item_int16(int16_t* res, const mlx_array arr) { return mlx_array_item_int16_ptr(res, arr); }
int mlx_array_item_int32(int32_t* res, const mlx_array arr) { return mlx_array_item_int32_ptr(res, arr); }
int mlx_array_item_int64(int64_t* res, const mlx_array arr) { return mlx_array_item_int64_ptr(res, arr); }
int mlx_array_item_float32(float* res, const mlx_array arr) { return mlx_array_item_float32_ptr(res, arr); }
int mlx_array_item_float64(double* res, const mlx_array arr) { return mlx_array_item_float64_ptr(res, arr); }
int mlx_array_item_complex64(mlx_complex64_t* res, const mlx_array arr) { return mlx_array_item_complex64_ptr(res, arr); }
#if defined(__aarch64__) || defined(_M_ARM64)
int mlx_array_item_float16(float16_t* res, const mlx_array arr) { return mlx_array_item_float16_ptr(res, arr); }
#endif
#if defined(__aarch64__) || defined(_M_ARM64)
int mlx_array_item_bfloat16(bfloat16_t* res, const mlx_array arr) { return mlx_array_item_bfloat16_ptr(res, arr); }
#endif

// mlx_array_data_* forwarders; the same ARM64-only guard applies to float16/bfloat16.
const bool* mlx_array_data_bool(const mlx_array arr) { return mlx_array_data_bool_ptr(arr); }
const uint8_t* mlx_array_data_uint8(const mlx_array arr) { return mlx_array_data_uint8_ptr(arr); }
const uint16_t* mlx_array_data_uint16(const mlx_array arr) { return mlx_array_data_uint16_ptr(arr); }
const uint32_t* mlx_array_data_uint32(const mlx_array arr) { return mlx_array_data_uint32_ptr(arr); }
const uint64_t* mlx_array_data_uint64(const mlx_array arr) { return mlx_array_data_uint64_ptr(arr); }
const int8_t* mlx_array_data_int8(const mlx_array arr) { return mlx_array_data_int8_ptr(arr); }
const int16_t* mlx_array_data_int16(const mlx_array arr) { return mlx_array_data_int16_ptr(arr); }
const int32_t* mlx_array_data_int32(const mlx_array arr) { return mlx_array_data_int32_ptr(arr); }
const int64_t* mlx_array_data_int64(const mlx_array arr) { return mlx_array_data_int64_ptr(arr); }
const float* mlx_array_data_float32(const mlx_array arr) { return mlx_array_data_float32_ptr(arr); }
const double* mlx_array_data_float64(const mlx_array arr) { return mlx_array_data_float64_ptr(arr); }
const mlx_complex64_t* mlx_array_data_complex64(const mlx_array arr) { return mlx_array_data_complex64_ptr(arr); }
#if defined(__aarch64__) || defined(_M_ARM64)
const float16_t* mlx_array_data_float16(const mlx_array arr) { return mlx_array_data_float16_ptr(arr); }
#endif
#if defined(__aarch64__) || defined(_M_ARM64)
const bfloat16_t* mlx_array_data_bfloat16(const mlx_array arr) { return mlx_array_data_bfloat16_ptr(arr); }
#endif

// Underscore-prefixed _mlx_array_* helpers.
// NOTE(review): file-scope identifiers beginning with an underscore are reserved
// by ISO C; the names mirror the mlx-c exports, so they cannot be renamed here.
int _mlx_array_is_available(bool* res, const mlx_array arr) { return _mlx_array_is_available_ptr(res, arr); }
int _mlx_array_wait(const mlx_array arr) { return _mlx_array_wait_ptr(arr); }
int _mlx_array_is_contiguous(bool* res, const mlx_array arr) { return _mlx_array_is_contiguous_ptr(res, arr); }
int _mlx_array_is_row_contiguous(bool* res, const mlx_array arr) { return _mlx_array_is_row_contiguous_ptr(res, arr); }
int _mlx_array_is_col_contiguous(bool* res, const mlx_array arr) { return _mlx_array_is_col_contiguous_ptr(res, arr); }

// mlx_closure_* forwarders.
mlx_closure mlx_closure_new(void) { return mlx_closure_new_ptr(); }
int mlx_closure_free(mlx_closure cls) { return mlx_closure_free_ptr(cls); }
mlx_closure mlx_closure_new_func(int (*fun)(mlx_vector_array*, const mlx_vector_array)) { return mlx_closure_new_func_ptr(fun); }
mlx_closure mlx_closure_new_func_payload(int (*fun)(mlx_vector_array*, const mlx_vector_array, void*), void* payload,
void (*dtor)(void*)) { return mlx_closure_new_func_payload_ptr(fun, payload, dtor); }
// Generated forwarders: each passes its arguments unchanged to the matching *_ptr
// global resolved by the symbol loader and returns that call's result, without
// checking the pointer for NULL first.
int mlx_closure_set(mlx_closure* cls, const mlx_closure src) { return mlx_closure_set_ptr(cls, src); }
int mlx_closure_apply(mlx_vector_array* res, mlx_closure cls, const mlx_vector_array input) { return mlx_closure_apply_ptr(res, cls, input); }
mlx_closure mlx_closure_new_unary(int (*fun)(mlx_array*, const mlx_array)) { return mlx_closure_new_unary_ptr(fun); }

// mlx_closure_kwargs_* forwarders.
mlx_closure_kwargs mlx_closure_kwargs_new(void) { return mlx_closure_kwargs_new_ptr(); }
int mlx_closure_kwargs_free(mlx_closure_kwargs cls) { return mlx_closure_kwargs_free_ptr(cls); }
mlx_closure_kwargs mlx_closure_kwargs_new_func(int (*fun)( mlx_vector_array*, const mlx_vector_array, const mlx_map_string_to_array)) { return mlx_closure_kwargs_new_func_ptr(fun); }
mlx_closure_kwargs mlx_closure_kwargs_new_func_payload(int (*fun)( mlx_vector_array*, const mlx_vector_array, const mlx_map_string_to_array, void*), void* payload, void (*dtor)(void*)) { return mlx_closure_kwargs_new_func_payload_ptr(fun, payload, dtor); }
int mlx_closure_kwargs_set(mlx_closure_kwargs* cls, const mlx_closure_kwargs src) { return mlx_closure_kwargs_set_ptr(cls, src); }
int mlx_closure_kwargs_apply(mlx_vector_array* res, mlx_closure_kwargs cls, const mlx_vector_array input_0, const mlx_map_string_to_array input_1) { return mlx_closure_kwargs_apply_ptr(res, cls, input_0, input_1); }

// mlx_closure_value_and_grad_* forwarders.
mlx_closure_value_and_grad mlx_closure_value_and_grad_new(void) { return mlx_closure_value_and_grad_new_ptr(); }
int mlx_closure_value_and_grad_free(mlx_closure_value_and_grad cls) { return mlx_closure_value_and_grad_free_ptr(cls); }
mlx_closure_value_and_grad mlx_closure_value_and_grad_new_func(int (*fun)(mlx_vector_array*, mlx_vector_array*, const mlx_vector_array)) { return mlx_closure_value_and_grad_new_func_ptr(fun); }
mlx_closure_value_and_grad mlx_closure_value_and_grad_new_func_payload(int (*fun)( mlx_vector_array*, mlx_vector_array*, const mlx_vector_array, void*), void* payload, void (*dtor)(void*)) { return mlx_closure_value_and_grad_new_func_payload_ptr(fun, payload, dtor); }
int mlx_closure_value_and_grad_set(mlx_closure_value_and_grad* cls, const mlx_closure_value_and_grad src) { return mlx_closure_value_and_grad_set_ptr(cls, src); }
int mlx_closure_value_and_grad_apply(mlx_vector_array* res_0, mlx_vector_array* res_1, mlx_closure_value_and_grad cls, const mlx_vector_array input) { return mlx_closure_value_and_grad_apply_ptr(res_0, res_1, cls, input); }

// mlx_closure_custom_* forwarders.
mlx_closure_custom mlx_closure_custom_new(void) { return mlx_closure_custom_new_ptr(); }
int mlx_closure_custom_free(mlx_closure_custom cls) { return mlx_closure_custom_free_ptr(cls); }
mlx_closure_custom mlx_closure_custom_new_func(int (*fun)( mlx_vector_array*, const mlx_vector_array, const mlx_vector_array, const mlx_vector_array)) { return mlx_closure_custom_new_func_ptr(fun); }
mlx_closure_custom mlx_closure_custom_new_func_payload(int (*fun)( mlx_vector_array*, const mlx_vector_array, const mlx_vector_array, const mlx_vector_array, void*), void* payload, void (*dtor)(void*)) { return mlx_closure_custom_new_func_payload_ptr(fun, payload, dtor); }
int mlx_closure_custom_set(mlx_closure_custom* cls, const mlx_closure_custom src) { return mlx_closure_custom_set_ptr(cls, src); }
int mlx_closure_custom_apply(mlx_vector_array* res, mlx_closure_custom cls, const mlx_vector_array input_0, const mlx_vector_array input_1, const mlx_vector_array input_2) { return mlx_closure_custom_apply_ptr(res, cls, input_0, input_1, input_2); }

// mlx_closure_custom_jvp_* forwarders (int array + explicit count parameters).
mlx_closure_custom_jvp mlx_closure_custom_jvp_new(void) { return mlx_closure_custom_jvp_new_ptr(); }
int mlx_closure_custom_jvp_free(mlx_closure_custom_jvp cls) { return mlx_closure_custom_jvp_free_ptr(cls); }
mlx_closure_custom_jvp mlx_closure_custom_jvp_new_func(int (*fun)( mlx_vector_array*, const mlx_vector_array, const mlx_vector_array, const int*, size_t _num)) { return mlx_closure_custom_jvp_new_func_ptr(fun); }
mlx_closure_custom_jvp mlx_closure_custom_jvp_new_func_payload(int (*fun)( mlx_vector_array*, const mlx_vector_array, const mlx_vector_array, const int*, size_t _num, void*), void* payload, void (*dtor)(void*)) { return mlx_closure_custom_jvp_new_func_payload_ptr(fun, payload, dtor); }
int mlx_closure_custom_jvp_set(mlx_closure_custom_jvp* cls, const mlx_closure_custom_jvp src) { return mlx_closure_custom_jvp_set_ptr(cls, src); }
int mlx_closure_custom_jvp_apply(mlx_vector_array* res, mlx_closure_custom_jvp cls, const mlx_vector_array input_0, const mlx_vector_array input_1, const int* input_2, size_t input_2_num) { return mlx_closure_custom_jvp_apply_ptr(res, cls, input_0, input_1, input_2, input_2_num); }

// mlx_closure_custom_vmap_* forwarders.
mlx_closure_custom_vmap mlx_closure_custom_vmap_new(void) { return mlx_closure_custom_vmap_new_ptr(); }
int mlx_closure_custom_vmap_free(mlx_closure_custom_vmap cls) { return mlx_closure_custom_vmap_free_ptr(cls); }
mlx_closure_custom_vmap mlx_closure_custom_vmap_new_func(int (*fun)( mlx_vector_array*, mlx_vector_int*, const mlx_vector_array, const int*, size_t _num)) { return mlx_closure_custom_vmap_new_func_ptr(fun); }
mlx_closure_custom_vmap mlx_closure_custom_vmap_new_func_payload(int (*fun)( mlx_vector_array*, mlx_vector_int*, const mlx_vector_array, const int*, size_t _num, void*), void* payload, void (*dtor)(void*)) { return mlx_closure_custom_vmap_new_func_payload_ptr(fun, payload, dtor); }
int mlx_closure_custom_vmap_set(mlx_closure_custom_vmap* cls, const mlx_closure_custom_vmap src) { return mlx_closure_custom_vmap_set_ptr(cls, src); }
int mlx_closure_custom_vmap_apply(mlx_vector_array* res_0, mlx_vector_int* res_1, mlx_closure_custom_vmap cls, const mlx_vector_array input_0, const int* input_1, size_t input_1_num) { return mlx_closure_custom_vmap_apply_ptr(res_0, res_1, cls, input_0, input_1, input_1_num); }

// mlx_compile, mlx_detail_compile*, compile enable/disable/mode forwarders.
int mlx_compile(mlx_closure* res, const mlx_closure fun, bool shapeless) { return mlx_compile_ptr(res, fun, shapeless); }
int mlx_detail_compile(mlx_closure* res, const mlx_closure fun, uintptr_t fun_id, bool shapeless, const uint64_t* constants, size_t constants_num) { return mlx_detail_compile_ptr(res, fun, fun_id, shapeless, constants, constants_num); }
int mlx_detail_compile_clear_cache(void) { return mlx_detail_compile_clear_cache_ptr(); }
int mlx_detail_compile_erase(uintptr_t fun_id) { return mlx_detail_compile_erase_ptr(fun_id); }
int mlx_disable_compile(void) { return mlx_disable_compile_ptr(); }
int mlx_enable_compile(void) { return mlx_enable_compile_ptr(); }
int mlx_set_compile_mode(mlx_compile_mode mode) { return mlx_set_compile_mode_ptr(mode); }

// mlx_cuda_is_available and mlx_device_* forwarders.
int mlx_cuda_is_available(bool* res) { return mlx_cuda_is_available_ptr(res); }
mlx_device mlx_device_new(void) { return mlx_device_new_ptr(); }
mlx_device mlx_device_new_type(mlx_device_type type, int index) { return mlx_device_new_type_ptr(type, index); }
int mlx_device_free(mlx_device dev) { return mlx_device_free_ptr(dev); }
int mlx_device_set(mlx_device* dev, const mlx_device src) { return mlx_device_set_ptr(dev, src); }
int mlx_device_tostring(mlx_string* str, mlx_device dev) { return mlx_device_tostring_ptr(str, dev); }
bool mlx_device_equal(mlx_device lhs, mlx_device rhs) { return mlx_device_equal_ptr(lhs, rhs); }
int mlx_device_get_index(int* index, mlx_device dev) { return mlx_device_get_index_ptr(index, dev); }
int mlx_device_get_type(mlx_device_type* type, mlx_device dev) { return mlx_device_get_type_ptr(type, dev); }
int mlx_get_default_device(mlx_device* dev) { return mlx_get_default_device_ptr(dev); }
int mlx_set_default_device(mlx_device dev) { return mlx_set_default_device_ptr(dev); }
int mlx_device_is_available(bool* avail, mlx_device dev) { return mlx_device_is_available_ptr(avail, dev); }
int mlx_device_count(int* count, mlx_device_type type) { return mlx_device_count_ptr(count, type); }

// mlx_device_info_* forwarders.
mlx_device_info mlx_device_info_new(void) { return mlx_device_info_new_ptr(); }
int mlx_device_info_get(mlx_device_info* info, mlx_device dev) { return mlx_device_info_get_ptr(info, dev); }
int mlx_device_info_free(mlx_device_info info) { return mlx_device_info_free_ptr(info); }
int mlx_device_info_has_key(bool* exists, mlx_device_info info, const char* key) { return mlx_device_info_has_key_ptr(exists, info, key); }
int mlx_device_info_is_string(bool* is_string, mlx_device_info info, const char* key) { return mlx_device_info_is_string_ptr(is_string, info, key); }
int mlx_device_info_get_string(const char** value, mlx_device_info info, const char* key) { return mlx_device_info_get_string_ptr(value, info, key); }
int mlx_device_info_get_size(size_t* value, mlx_device_info info, const char* key) { return mlx_device_info_get_size_ptr(value, info, key); }
int mlx_device_info_get_keys(mlx_vector_string* keys, mlx_device_info info) { return mlx_device_info_get_keys_ptr(keys, info); }

// mlx_distributed_* forwarders.
// NOTE(review): mlx_distributed_all_gather names its stream parameter `S` while its
// siblings use `s`; harmless, but worth aligning in generate_wrappers.go.
// NOTE(review): the space before some commas (`group ,`) looks like a generator
// artifact, possibly where a "may be null" annotation was stripped - confirm.
int mlx_distributed_all_gather(mlx_array* res, const mlx_array x, const mlx_distributed_group group , const mlx_stream S) { return mlx_distributed_all_gather_ptr(res, x, group, S); }
int mlx_distributed_all_max(mlx_array* res, const mlx_array x, const mlx_distributed_group group , const mlx_stream s) { return mlx_distributed_all_max_ptr(res, x, group, s); }
int mlx_distributed_all_min(mlx_array* res, const mlx_array x, const mlx_distributed_group group , const mlx_stream s) { return mlx_distributed_all_min_ptr(res, x, group, s); }
int mlx_distributed_all_sum(mlx_array* res, const mlx_array x, const mlx_distributed_group group , const mlx_stream s) { return mlx_distributed_all_sum_ptr(res, x, group, s); }
int mlx_distributed_recv(mlx_array* res, const int* shape, size_t shape_num, mlx_dtype dtype, int src, const mlx_distributed_group group , const mlx_stream s) { return mlx_distributed_recv_ptr(res, shape, shape_num, dtype, src, group, s); }
int mlx_distributed_recv_like(mlx_array* res, const mlx_array x, int src, const mlx_distributed_group group , const mlx_stream s) { return mlx_distributed_recv_like_ptr(res, x, src, group, s); }
int
mlx_distributed_send(mlx_array* res, const mlx_array x, int dst, const mlx_distributed_group group , const mlx_stream s) { return mlx_distributed_send_ptr(res, x, dst, group, s); } int mlx_distributed_sum_scatter(mlx_array* res, const mlx_array x, const mlx_distributed_group group , const mlx_stream s) { return mlx_distributed_sum_scatter_ptr(res, x, group, s); } int mlx_distributed_group_rank(mlx_distributed_group group) { return mlx_distributed_group_rank_ptr(group); } int mlx_distributed_group_size(mlx_distributed_group group) { return mlx_distributed_group_size_ptr(group); } mlx_distributed_group mlx_distributed_group_split(mlx_distributed_group group, int color, int key) { return mlx_distributed_group_split_ptr(group, color, key); } bool mlx_distributed_is_available(const char* bk) { return mlx_distributed_is_available_ptr(bk); } mlx_distributed_group mlx_distributed_init(bool strict, const char* bk) { return mlx_distributed_init_ptr(strict, bk); } void mlx_set_error_handler(mlx_error_handler_func handler, void* data, void (*dtor)(void*)) { mlx_set_error_handler_ptr(handler, data, dtor); } void _mlx_error(const char* file, const int line, const char* fmt, ...) 
{ _mlx_error_ptr(file, line, fmt); } int mlx_export_function(const char* file, const mlx_closure fun, const mlx_vector_array args, bool shapeless) { return mlx_export_function_ptr(file, fun, args, shapeless); } int mlx_export_function_kwargs(const char* file, const mlx_closure_kwargs fun, const mlx_vector_array args, const mlx_map_string_to_array kwargs, bool shapeless) { return mlx_export_function_kwargs_ptr(file, fun, args, kwargs, shapeless); } mlx_function_exporter mlx_function_exporter_new(const char* file, const mlx_closure fun, bool shapeless) { return mlx_function_exporter_new_ptr(file, fun, shapeless); } int mlx_function_exporter_free(mlx_function_exporter xfunc) { return mlx_function_exporter_free_ptr(xfunc); } int mlx_function_exporter_apply(const mlx_function_exporter xfunc, const mlx_vector_array args) { return mlx_function_exporter_apply_ptr(xfunc, args); } int mlx_function_exporter_apply_kwargs(const mlx_function_exporter xfunc, const mlx_vector_array args, const mlx_map_string_to_array kwargs) { return mlx_function_exporter_apply_kwargs_ptr(xfunc, args, kwargs); } mlx_imported_function mlx_imported_function_new(const char* file) { return mlx_imported_function_new_ptr(file); } int mlx_imported_function_free(mlx_imported_function xfunc) { return mlx_imported_function_free_ptr(xfunc); } int mlx_imported_function_apply(mlx_vector_array* res, const mlx_imported_function xfunc, const mlx_vector_array args) { return mlx_imported_function_apply_ptr(res, xfunc, args); } int mlx_imported_function_apply_kwargs(mlx_vector_array* res, const mlx_imported_function xfunc, const mlx_vector_array args, const mlx_map_string_to_array kwargs) { return mlx_imported_function_apply_kwargs_ptr(res, xfunc, args, kwargs); } mlx_fast_cuda_kernel_config mlx_fast_cuda_kernel_config_new(void) { return mlx_fast_cuda_kernel_config_new_ptr(); } void mlx_fast_cuda_kernel_config_free(mlx_fast_cuda_kernel_config cls) { mlx_fast_cuda_kernel_config_free_ptr(cls); } int 
mlx_fast_cuda_kernel_config_add_output_arg(mlx_fast_cuda_kernel_config cls, const int* shape, size_t size, mlx_dtype dtype) { return mlx_fast_cuda_kernel_config_add_output_arg_ptr(cls, shape, size, dtype); }
// Generated forwarders: each passes its arguments unchanged to the matching *_ptr
// global resolved by the symbol loader and returns that call's result (the void
// *_free wrappers just make the call), without checking the pointer for NULL.
int mlx_fast_cuda_kernel_config_set_grid(mlx_fast_cuda_kernel_config cls, int grid1, int grid2, int grid3) { return mlx_fast_cuda_kernel_config_set_grid_ptr(cls, grid1, grid2, grid3); }
int mlx_fast_cuda_kernel_config_set_thread_group(mlx_fast_cuda_kernel_config cls, int thread1, int thread2, int thread3) { return mlx_fast_cuda_kernel_config_set_thread_group_ptr(cls, thread1, thread2, thread3); }
int mlx_fast_cuda_kernel_config_set_init_value(mlx_fast_cuda_kernel_config cls, float value) { return mlx_fast_cuda_kernel_config_set_init_value_ptr(cls, value); }
int mlx_fast_cuda_kernel_config_set_verbose(mlx_fast_cuda_kernel_config cls, bool verbose) { return mlx_fast_cuda_kernel_config_set_verbose_ptr(cls, verbose); }
int mlx_fast_cuda_kernel_config_add_template_arg_dtype(mlx_fast_cuda_kernel_config cls, const char* name, mlx_dtype dtype) { return mlx_fast_cuda_kernel_config_add_template_arg_dtype_ptr(cls, name, dtype); }
int mlx_fast_cuda_kernel_config_add_template_arg_int(mlx_fast_cuda_kernel_config cls, const char* name, int value) { return mlx_fast_cuda_kernel_config_add_template_arg_int_ptr(cls, name, value); }
int mlx_fast_cuda_kernel_config_add_template_arg_bool(mlx_fast_cuda_kernel_config cls, const char* name, bool value) { return mlx_fast_cuda_kernel_config_add_template_arg_bool_ptr(cls, name, value); }

// mlx_fast_cuda_kernel_* forwarders.
mlx_fast_cuda_kernel mlx_fast_cuda_kernel_new(const char* name, const mlx_vector_string input_names, const mlx_vector_string output_names, const char* source, const char* header, bool ensure_row_contiguous, int shared_memory) { return mlx_fast_cuda_kernel_new_ptr(name, input_names, output_names, source, header, ensure_row_contiguous, shared_memory); }
void mlx_fast_cuda_kernel_free(mlx_fast_cuda_kernel cls) { mlx_fast_cuda_kernel_free_ptr(cls); }
int mlx_fast_cuda_kernel_apply(mlx_vector_array* outputs, mlx_fast_cuda_kernel cls, const mlx_vector_array inputs, const mlx_fast_cuda_kernel_config config, const mlx_stream stream) { return mlx_fast_cuda_kernel_apply_ptr(outputs, cls, inputs, config, stream); }

// mlx_fast_layer_norm.
// NOTE(review): the space before some commas (`weight ,`, `bias ,`) looks like a
// generator artifact, possibly where a "may be null" annotation was stripped - confirm.
int mlx_fast_layer_norm(mlx_array* res, const mlx_array x, const mlx_array weight , const mlx_array bias , float eps, const mlx_stream s) { return mlx_fast_layer_norm_ptr(res, x, weight, bias, eps, s); }

// mlx_fast_metal_kernel_config_* forwarders (same shape as the CUDA config API above).
mlx_fast_metal_kernel_config mlx_fast_metal_kernel_config_new(void) { return mlx_fast_metal_kernel_config_new_ptr(); }
void mlx_fast_metal_kernel_config_free(mlx_fast_metal_kernel_config cls) { mlx_fast_metal_kernel_config_free_ptr(cls); }
int mlx_fast_metal_kernel_config_add_output_arg(mlx_fast_metal_kernel_config cls, const int* shape, size_t size, mlx_dtype dtype) { return mlx_fast_metal_kernel_config_add_output_arg_ptr(cls, shape, size, dtype); }
int mlx_fast_metal_kernel_config_set_grid(mlx_fast_metal_kernel_config cls, int grid1, int grid2, int grid3) { return mlx_fast_metal_kernel_config_set_grid_ptr(cls, grid1, grid2, grid3); }
int mlx_fast_metal_kernel_config_set_thread_group(mlx_fast_metal_kernel_config cls, int thread1, int thread2, int thread3) { return mlx_fast_metal_kernel_config_set_thread_group_ptr(cls, thread1, thread2, thread3); }
int mlx_fast_metal_kernel_config_set_init_value(mlx_fast_metal_kernel_config cls, float value) { return mlx_fast_metal_kernel_config_set_init_value_ptr(cls, value); }
int mlx_fast_metal_kernel_config_set_verbose(mlx_fast_metal_kernel_config cls, bool verbose) { return mlx_fast_metal_kernel_config_set_verbose_ptr(cls, verbose); }
int mlx_fast_metal_kernel_config_add_template_arg_dtype(mlx_fast_metal_kernel_config cls, const char* name, mlx_dtype dtype) { return mlx_fast_metal_kernel_config_add_template_arg_dtype_ptr(cls, name, dtype); }
int mlx_fast_metal_kernel_config_add_template_arg_int(mlx_fast_metal_kernel_config cls, const char* name, int value) { return mlx_fast_metal_kernel_config_add_template_arg_int_ptr(cls, name, value); }
int mlx_fast_metal_kernel_config_add_template_arg_bool(mlx_fast_metal_kernel_config cls, const char* name, bool value) { return mlx_fast_metal_kernel_config_add_template_arg_bool_ptr(cls, name, value); }

// mlx_fast_metal_kernel_* forwarders.
mlx_fast_metal_kernel mlx_fast_metal_kernel_new(const char* name, const mlx_vector_string input_names, const mlx_vector_string output_names, const char* source, const char* header, bool ensure_row_contiguous, bool atomic_outputs) { return mlx_fast_metal_kernel_new_ptr(name, input_names, output_names, source, header, ensure_row_contiguous, atomic_outputs); }
void mlx_fast_metal_kernel_free(mlx_fast_metal_kernel cls) { mlx_fast_metal_kernel_free_ptr(cls); }
int mlx_fast_metal_kernel_apply(mlx_vector_array* outputs, mlx_fast_metal_kernel cls, const mlx_vector_array inputs, const mlx_fast_metal_kernel_config config, const mlx_stream stream) { return mlx_fast_metal_kernel_apply_ptr(outputs, cls, inputs, config, stream); }

// mlx_fast_rms_norm, mlx_fast_rope*, mlx_fast_scaled_dot_product_attention.
int mlx_fast_rms_norm(mlx_array* res, const mlx_array x, const mlx_array weight , float eps, const mlx_stream s) { return mlx_fast_rms_norm_ptr(res, x, weight, eps, s); }
int mlx_fast_rope(mlx_array* res, const mlx_array x, int dims, bool traditional, mlx_optional_float base, float scale, int offset, const mlx_array freqs , const mlx_stream s) { return mlx_fast_rope_ptr(res, x, dims, traditional, base, scale, offset, freqs, s); }
int mlx_fast_rope_dynamic(mlx_array* res, const mlx_array x, int dims, bool traditional, mlx_optional_float base, float scale, const mlx_array offset, const mlx_array freqs , const mlx_stream s) { return mlx_fast_rope_dynamic_ptr(res, x, dims, traditional, base, scale, offset, freqs, s); }
int mlx_fast_scaled_dot_product_attention(mlx_array* res, const mlx_array queries, const mlx_array keys, const mlx_array values, float scale, const char* mask_mode, const mlx_array mask_arr , const mlx_array sinks , const mlx_stream s) { return mlx_fast_scaled_dot_product_attention_ptr(res, queries, keys, values, scale, mask_mode, mask_arr, sinks, s); }

// mlx_fft_* forwarders: the 1-D forms take (n, axis); the 2-D/N-D forms take
// pointer + count pairs for n and axes; the *shift forms take only axes.
int mlx_fft_fft(mlx_array* res, const mlx_array a, int n, int axis, const mlx_stream s) { return mlx_fft_fft_ptr(res, a, n, axis, s); }
int mlx_fft_fft2(mlx_array* res, const mlx_array a, const int* n, size_t n_num, const int* axes, size_t axes_num, const mlx_stream s) { return mlx_fft_fft2_ptr(res, a, n, n_num, axes, axes_num, s); }
int mlx_fft_fftn(mlx_array* res, const mlx_array a, const int* n, size_t n_num, const int* axes, size_t axes_num, const mlx_stream s) { return mlx_fft_fftn_ptr(res, a, n, n_num, axes, axes_num, s); }
int mlx_fft_fftshift(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, const mlx_stream s) { return mlx_fft_fftshift_ptr(res, a, axes, axes_num, s); }
int mlx_fft_ifft(mlx_array* res, const mlx_array a, int n, int axis, const mlx_stream s) { return mlx_fft_ifft_ptr(res, a, n, axis, s); }
int mlx_fft_ifft2(mlx_array* res, const mlx_array a, const int* n, size_t n_num, const int* axes, size_t axes_num, const mlx_stream s) { return mlx_fft_ifft2_ptr(res, a, n, n_num, axes, axes_num, s); }
int mlx_fft_ifftn(mlx_array* res, const mlx_array a, const int* n, size_t n_num, const int* axes, size_t axes_num, const mlx_stream s) { return mlx_fft_ifftn_ptr(res, a, n, n_num, axes, axes_num, s); }
int mlx_fft_ifftshift(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, const mlx_stream s) { return mlx_fft_ifftshift_ptr(res, a, axes, axes_num, s); }
int mlx_fft_irfft(mlx_array* res, const mlx_array a, int n, int axis, const mlx_stream s) { return mlx_fft_irfft_ptr(res, a, n, axis, s); }
int mlx_fft_irfft2(mlx_array* res, const mlx_array a, const int* n, size_t n_num, const int* axes, size_t axes_num, const mlx_stream s) { return mlx_fft_irfft2_ptr(res, a, n, n_num, axes, axes_num, s); }
int mlx_fft_irfftn(mlx_array* res, const mlx_array a, const int* n, size_t n_num, const int* axes, size_t axes_num, const mlx_stream s) { return mlx_fft_irfftn_ptr(res, a, n, n_num, axes, axes_num, s); }
int mlx_fft_rfft(mlx_array* res, const mlx_array a, int n, int axis, const mlx_stream s) { return mlx_fft_rfft_ptr(res, a, n, axis, s); }
int mlx_fft_rfft2(mlx_array* res, const mlx_array a, const int* n, size_t n_num, const int* axes, size_t axes_num, const mlx_stream s) { return mlx_fft_rfft2_ptr(res, a, n, n_num, axes, axes_num, s); }
int mlx_fft_rfftn(mlx_array* res, const mlx_array a, const int* n, size_t n_num, const int* axes, size_t axes_num, const mlx_stream s) { return mlx_fft_rfftn_ptr(res, a, n, n_num, axes, axes_num, s); }

// mlx_load*/mlx_save* forwarders: *_reader/*_writer variants take an
// mlx_io_reader/mlx_io_writer, the others a file path.
int mlx_load_reader(mlx_array* res, mlx_io_reader in_stream, const mlx_stream s) { return mlx_load_reader_ptr(res, in_stream, s); }
int mlx_load(mlx_array* res, const char* file, const mlx_stream s) { return mlx_load_ptr(res, file, s); }
int mlx_load_safetensors_reader(mlx_map_string_to_array* res_0, mlx_map_string_to_string* res_1, mlx_io_reader in_stream, const mlx_stream s) { return mlx_load_safetensors_reader_ptr(res_0, res_1, in_stream, s); }
int mlx_load_safetensors(mlx_map_string_to_array* res_0, mlx_map_string_to_string* res_1, const char* file, const mlx_stream s) { return mlx_load_safetensors_ptr(res_0, res_1, file, s); }
int mlx_save_writer(mlx_io_writer out_stream, const mlx_array a) { return mlx_save_writer_ptr(out_stream, a); }
int mlx_save(const char* file, const mlx_array a) { return mlx_save_ptr(file, a); }
int mlx_save_safetensors_writer(mlx_io_writer in_stream, const mlx_map_string_to_array param, const mlx_map_string_to_string metadata) { return mlx_save_safetensors_writer_ptr(in_stream, param, metadata); }
int mlx_save_safetensors(const char* file, const mlx_map_string_to_array param, const mlx_map_string_to_string metadata) { return mlx_save_safetensors_ptr(file, param, metadata); }

// mlx_io_reader_* forwarders.
mlx_io_reader mlx_io_reader_new(void* desc, mlx_io_vtable vtable) { return mlx_io_reader_new_ptr(desc, vtable); }
int mlx_io_reader_descriptor(void**
desc_, mlx_io_reader io) { return mlx_io_reader_descriptor_ptr(desc_, io); } int mlx_io_reader_tostring(mlx_string* str_, mlx_io_reader io) { return mlx_io_reader_tostring_ptr(str_, io); } int mlx_io_reader_free(mlx_io_reader io) { return mlx_io_reader_free_ptr(io); } mlx_io_writer mlx_io_writer_new(void* desc, mlx_io_vtable vtable) { return mlx_io_writer_new_ptr(desc, vtable); } int mlx_io_writer_descriptor(void** desc_, mlx_io_writer io) { return mlx_io_writer_descriptor_ptr(desc_, io); } int mlx_io_writer_tostring(mlx_string* str_, mlx_io_writer io) { return mlx_io_writer_tostring_ptr(str_, io); } int mlx_io_writer_free(mlx_io_writer io) { return mlx_io_writer_free_ptr(io); } int mlx_linalg_cholesky(mlx_array* res, const mlx_array a, bool upper, const mlx_stream s) { return mlx_linalg_cholesky_ptr(res, a, upper, s); } int mlx_linalg_cholesky_inv(mlx_array* res, const mlx_array a, bool upper, const mlx_stream s) { return mlx_linalg_cholesky_inv_ptr(res, a, upper, s); } int mlx_linalg_cross(mlx_array* res, const mlx_array a, const mlx_array b, int axis, const mlx_stream s) { return mlx_linalg_cross_ptr(res, a, b, axis, s); } int mlx_linalg_eig(mlx_array* res_0, mlx_array* res_1, const mlx_array a, const mlx_stream s) { return mlx_linalg_eig_ptr(res_0, res_1, a, s); } int mlx_linalg_eigh(mlx_array* res_0, mlx_array* res_1, const mlx_array a, const char* UPLO, const mlx_stream s) { return mlx_linalg_eigh_ptr(res_0, res_1, a, UPLO, s); } int mlx_linalg_eigvals(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_linalg_eigvals_ptr(res, a, s); } int mlx_linalg_eigvalsh(mlx_array* res, const mlx_array a, const char* UPLO, const mlx_stream s) { return mlx_linalg_eigvalsh_ptr(res, a, UPLO, s); } int mlx_linalg_inv(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_linalg_inv_ptr(res, a, s); } int mlx_linalg_lu(mlx_vector_array* res, const mlx_array a, const mlx_stream s) { return mlx_linalg_lu_ptr(res, a, s); } int 
mlx_linalg_lu_factor(mlx_array* res_0, mlx_array* res_1, const mlx_array a, const mlx_stream s) { return mlx_linalg_lu_factor_ptr(res_0, res_1, a, s); } int mlx_linalg_norm(mlx_array* res, const mlx_array a, double ord, const int* axis , size_t axis_num, bool keepdims, const mlx_stream s) { return mlx_linalg_norm_ptr(res, a, ord, axis, axis_num, keepdims, s); } int mlx_linalg_norm_matrix(mlx_array* res, const mlx_array a, const char* ord, const int* axis , size_t axis_num, bool keepdims, const mlx_stream s) { return mlx_linalg_norm_matrix_ptr(res, a, ord, axis, axis_num, keepdims, s); } int mlx_linalg_norm_l2(mlx_array* res, const mlx_array a, const int* axis , size_t axis_num, bool keepdims, const mlx_stream s) { return mlx_linalg_norm_l2_ptr(res, a, axis, axis_num, keepdims, s); } int mlx_linalg_pinv(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_linalg_pinv_ptr(res, a, s); } int mlx_linalg_qr(mlx_array* res_0, mlx_array* res_1, const mlx_array a, const mlx_stream s) { return mlx_linalg_qr_ptr(res_0, res_1, a, s); } int mlx_linalg_solve(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) { return mlx_linalg_solve_ptr(res, a, b, s); } int mlx_linalg_solve_triangular(mlx_array* res, const mlx_array a, const mlx_array b, bool upper, const mlx_stream s) { return mlx_linalg_solve_triangular_ptr(res, a, b, upper, s); } int mlx_linalg_svd(mlx_vector_array* res, const mlx_array a, bool compute_uv, const mlx_stream s) { return mlx_linalg_svd_ptr(res, a, compute_uv, s); } int mlx_linalg_tri_inv(mlx_array* res, const mlx_array a, bool upper, const mlx_stream s) { return mlx_linalg_tri_inv_ptr(res, a, upper, s); } mlx_map_string_to_array mlx_map_string_to_array_new(void) { return mlx_map_string_to_array_new_ptr(); } int mlx_map_string_to_array_set(mlx_map_string_to_array* map, const mlx_map_string_to_array src) { return mlx_map_string_to_array_set_ptr(map, src); } int mlx_map_string_to_array_free(mlx_map_string_to_array map) { 
return mlx_map_string_to_array_free_ptr(map); } int mlx_map_string_to_array_insert(mlx_map_string_to_array map, const char* key, const mlx_array value) { return mlx_map_string_to_array_insert_ptr(map, key, value); } int mlx_map_string_to_array_get(mlx_array* value, const mlx_map_string_to_array map, const char* key) { return mlx_map_string_to_array_get_ptr(value, map, key); } mlx_map_string_to_array_iterator mlx_map_string_to_array_iterator_new(mlx_map_string_to_array map) { return mlx_map_string_to_array_iterator_new_ptr(map); } int mlx_map_string_to_array_iterator_free(mlx_map_string_to_array_iterator it) { return mlx_map_string_to_array_iterator_free_ptr(it); } int mlx_map_string_to_array_iterator_next(const char** key, mlx_array* value, mlx_map_string_to_array_iterator it) { return mlx_map_string_to_array_iterator_next_ptr(key, value, it); } mlx_map_string_to_string mlx_map_string_to_string_new(void) { return mlx_map_string_to_string_new_ptr(); } int mlx_map_string_to_string_set(mlx_map_string_to_string* map, const mlx_map_string_to_string src) { return mlx_map_string_to_string_set_ptr(map, src); } int mlx_map_string_to_string_free(mlx_map_string_to_string map) { return mlx_map_string_to_string_free_ptr(map); } int mlx_map_string_to_string_insert(mlx_map_string_to_string map, const char* key, const char* value) { return mlx_map_string_to_string_insert_ptr(map, key, value); } int mlx_map_string_to_string_get(const char** value, const mlx_map_string_to_string map, const char* key) { return mlx_map_string_to_string_get_ptr(value, map, key); } mlx_map_string_to_string_iterator mlx_map_string_to_string_iterator_new(mlx_map_string_to_string map) { return mlx_map_string_to_string_iterator_new_ptr(map); } int mlx_map_string_to_string_iterator_free(mlx_map_string_to_string_iterator it) { return mlx_map_string_to_string_iterator_free_ptr(it); } int mlx_map_string_to_string_iterator_next(const char** key, const char** value, mlx_map_string_to_string_iterator it) { return 
mlx_map_string_to_string_iterator_next_ptr(key, value, it); } int mlx_clear_cache(void) { return mlx_clear_cache_ptr(); } int mlx_get_active_memory(size_t* res) { return mlx_get_active_memory_ptr(res); } int mlx_get_cache_memory(size_t* res) { return mlx_get_cache_memory_ptr(res); } int mlx_get_memory_limit(size_t* res) { return mlx_get_memory_limit_ptr(res); } int mlx_get_peak_memory(size_t* res) { return mlx_get_peak_memory_ptr(res); } int mlx_reset_peak_memory(void) { return mlx_reset_peak_memory_ptr(); } int mlx_set_cache_limit(size_t* res, size_t limit) { return mlx_set_cache_limit_ptr(res, limit); } int mlx_set_memory_limit(size_t* res, size_t limit) { return mlx_set_memory_limit_ptr(res, limit); } int mlx_set_wired_limit(size_t* res, size_t limit) { return mlx_set_wired_limit_ptr(res, limit); } int mlx_metal_is_available(bool* res) { return mlx_metal_is_available_ptr(res); } int mlx_metal_start_capture(const char* path) { return mlx_metal_start_capture_ptr(path); } int mlx_metal_stop_capture(void) { return mlx_metal_stop_capture_ptr(); } int mlx_abs(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_abs_ptr(res, a, s); } int mlx_add(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) { return mlx_add_ptr(res, a, b, s); } int mlx_addmm(mlx_array* res, const mlx_array c, const mlx_array a, const mlx_array b, float alpha, float beta, const mlx_stream s) { return mlx_addmm_ptr(res, c, a, b, alpha, beta, s); } int mlx_all_axes(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, bool keepdims, const mlx_stream s) { return mlx_all_axes_ptr(res, a, axes, axes_num, keepdims, s); } int mlx_all_axis(mlx_array* res, const mlx_array a, int axis, bool keepdims, const mlx_stream s) { return mlx_all_axis_ptr(res, a, axis, keepdims, s); } int mlx_all(mlx_array* res, const mlx_array a, bool keepdims, const mlx_stream s) { return mlx_all_ptr(res, a, keepdims, s); } int mlx_allclose(mlx_array* res, const mlx_array a, const 
mlx_array b, double rtol, double atol, bool equal_nan, const mlx_stream s) { return mlx_allclose_ptr(res, a, b, rtol, atol, equal_nan, s); } int mlx_any_axes(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, bool keepdims, const mlx_stream s) { return mlx_any_axes_ptr(res, a, axes, axes_num, keepdims, s); } int mlx_any_axis(mlx_array* res, const mlx_array a, int axis, bool keepdims, const mlx_stream s) { return mlx_any_axis_ptr(res, a, axis, keepdims, s); } int mlx_any(mlx_array* res, const mlx_array a, bool keepdims, const mlx_stream s) { return mlx_any_ptr(res, a, keepdims, s); } int mlx_arange(mlx_array* res, double start, double stop, double step, mlx_dtype dtype, const mlx_stream s) { return mlx_arange_ptr(res, start, stop, step, dtype, s); } int mlx_arccos(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_arccos_ptr(res, a, s); } int mlx_arccosh(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_arccosh_ptr(res, a, s); } int mlx_arcsin(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_arcsin_ptr(res, a, s); } int mlx_arcsinh(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_arcsinh_ptr(res, a, s); } int mlx_arctan(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_arctan_ptr(res, a, s); } int mlx_arctan2(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) { return mlx_arctan2_ptr(res, a, b, s); } int mlx_arctanh(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_arctanh_ptr(res, a, s); } int mlx_argmax_axis(mlx_array* res, const mlx_array a, int axis, bool keepdims, const mlx_stream s) { return mlx_argmax_axis_ptr(res, a, axis, keepdims, s); } int mlx_argmax(mlx_array* res, const mlx_array a, bool keepdims, const mlx_stream s) { return mlx_argmax_ptr(res, a, keepdims, s); } int mlx_argmin_axis(mlx_array* res, const mlx_array a, int axis, bool keepdims, const mlx_stream s) { return mlx_argmin_axis_ptr(res, a, axis, 
keepdims, s); } int mlx_argmin(mlx_array* res, const mlx_array a, bool keepdims, const mlx_stream s) { return mlx_argmin_ptr(res, a, keepdims, s); } int mlx_argpartition_axis(mlx_array* res, const mlx_array a, int kth, int axis, const mlx_stream s) { return mlx_argpartition_axis_ptr(res, a, kth, axis, s); } int mlx_argpartition(mlx_array* res, const mlx_array a, int kth, const mlx_stream s) { return mlx_argpartition_ptr(res, a, kth, s); } int mlx_argsort_axis(mlx_array* res, const mlx_array a, int axis, const mlx_stream s) { return mlx_argsort_axis_ptr(res, a, axis, s); } int mlx_argsort(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_argsort_ptr(res, a, s); } int mlx_array_equal(mlx_array* res, const mlx_array a, const mlx_array b, bool equal_nan, const mlx_stream s) { return mlx_array_equal_ptr(res, a, b, equal_nan, s); } int mlx_as_strided(mlx_array* res, const mlx_array a, const int* shape, size_t shape_num, const int64_t* strides, size_t strides_num, size_t offset, const mlx_stream s) { return mlx_as_strided_ptr(res, a, shape, shape_num, strides, strides_num, offset, s); } int mlx_astype(mlx_array* res, const mlx_array a, mlx_dtype dtype, const mlx_stream s) { return mlx_astype_ptr(res, a, dtype, s); } int mlx_atleast_1d(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_atleast_1d_ptr(res, a, s); } int mlx_atleast_2d(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_atleast_2d_ptr(res, a, s); } int mlx_atleast_3d(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_atleast_3d_ptr(res, a, s); } int mlx_bartlett(mlx_array* res, int M, const mlx_stream s) { return mlx_bartlett_ptr(res, M, s); } int mlx_bitwise_and(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) { return mlx_bitwise_and_ptr(res, a, b, s); } int mlx_bitwise_invert(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_bitwise_invert_ptr(res, a, s); } int mlx_bitwise_or(mlx_array* res, const 
mlx_array a, const mlx_array b, const mlx_stream s) { return mlx_bitwise_or_ptr(res, a, b, s); } int mlx_bitwise_xor(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) { return mlx_bitwise_xor_ptr(res, a, b, s); } int mlx_blackman(mlx_array* res, int M, const mlx_stream s) { return mlx_blackman_ptr(res, M, s); } int mlx_block_masked_mm(mlx_array* res, const mlx_array a, const mlx_array b, int block_size, const mlx_array mask_out , const mlx_array mask_lhs , const mlx_array mask_rhs , const mlx_stream s) { return mlx_block_masked_mm_ptr(res, a, b, block_size, mask_out, mask_lhs, mask_rhs, s); } int mlx_broadcast_arrays(mlx_vector_array* res, const mlx_vector_array inputs, const mlx_stream s) { return mlx_broadcast_arrays_ptr(res, inputs, s); } int mlx_broadcast_to(mlx_array* res, const mlx_array a, const int* shape, size_t shape_num, const mlx_stream s) { return mlx_broadcast_to_ptr(res, a, shape, shape_num, s); } int mlx_ceil(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_ceil_ptr(res, a, s); } int mlx_clip(mlx_array* res, const mlx_array a, const mlx_array a_min , const mlx_array a_max , const mlx_stream s) { return mlx_clip_ptr(res, a, a_min, a_max, s); } int mlx_concatenate_axis(mlx_array* res, const mlx_vector_array arrays, int axis, const mlx_stream s) { return mlx_concatenate_axis_ptr(res, arrays, axis, s); } int mlx_concatenate(mlx_array* res, const mlx_vector_array arrays, const mlx_stream s) { return mlx_concatenate_ptr(res, arrays, s); } int mlx_conjugate(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_conjugate_ptr(res, a, s); } int mlx_contiguous(mlx_array* res, const mlx_array a, bool allow_col_major, const mlx_stream s) { return mlx_contiguous_ptr(res, a, allow_col_major, s); } int mlx_conv1d(mlx_array* res, const mlx_array input, const mlx_array weight, int stride, int padding, int dilation, int groups, const mlx_stream s) { return mlx_conv1d_ptr(res, input, weight, stride, padding, 
dilation, groups, s); } int mlx_conv2d(mlx_array* res, const mlx_array input, const mlx_array weight, int stride_0, int stride_1, int padding_0, int padding_1, int dilation_0, int dilation_1, int groups, const mlx_stream s) { return mlx_conv2d_ptr(res, input, weight, stride_0, stride_1, padding_0, padding_1, dilation_0, dilation_1, groups, s); } int mlx_conv3d(mlx_array* res, const mlx_array input, const mlx_array weight, int stride_0, int stride_1, int stride_2, int padding_0, int padding_1, int padding_2, int dilation_0, int dilation_1, int dilation_2, int groups, const mlx_stream s) { return mlx_conv3d_ptr(res, input, weight, stride_0, stride_1, stride_2, padding_0, padding_1, padding_2, dilation_0, dilation_1, dilation_2, groups, s); } int mlx_conv_general(mlx_array* res, const mlx_array input, const mlx_array weight, const int* stride, size_t stride_num, const int* padding_lo, size_t padding_lo_num, const int* padding_hi, size_t padding_hi_num, const int* kernel_dilation, size_t kernel_dilation_num, const int* input_dilation, size_t input_dilation_num, int groups, bool flip, const mlx_stream s) { return mlx_conv_general_ptr(res, input, weight, stride, stride_num, padding_lo, padding_lo_num, padding_hi, padding_hi_num, kernel_dilation, kernel_dilation_num, input_dilation, input_dilation_num, groups, flip, s); } int mlx_conv_transpose1d(mlx_array* res, const mlx_array input, const mlx_array weight, int stride, int padding, int dilation, int output_padding, int groups, const mlx_stream s) { return mlx_conv_transpose1d_ptr(res, input, weight, stride, padding, dilation, output_padding, groups, s); } int mlx_conv_transpose2d(mlx_array* res, const mlx_array input, const mlx_array weight, int stride_0, int stride_1, int padding_0, int padding_1, int dilation_0, int dilation_1, int output_padding_0, int output_padding_1, int groups, const mlx_stream s) { return mlx_conv_transpose2d_ptr(res, input, weight, stride_0, stride_1, padding_0, padding_1, dilation_0, 
dilation_1, output_padding_0, output_padding_1, groups, s); } int mlx_conv_transpose3d(mlx_array* res, const mlx_array input, const mlx_array weight, int stride_0, int stride_1, int stride_2, int padding_0, int padding_1, int padding_2, int dilation_0, int dilation_1, int dilation_2, int output_padding_0, int output_padding_1, int output_padding_2, int groups, const mlx_stream s) { return mlx_conv_transpose3d_ptr(res, input, weight, stride_0, stride_1, stride_2, padding_0, padding_1, padding_2, dilation_0, dilation_1, dilation_2, output_padding_0, output_padding_1, output_padding_2, groups, s); } int mlx_copy(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_copy_ptr(res, a, s); } int mlx_cos(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_cos_ptr(res, a, s); } int mlx_cosh(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_cosh_ptr(res, a, s); } int mlx_cummax(mlx_array* res, const mlx_array a, int axis, bool reverse, bool inclusive, const mlx_stream s) { return mlx_cummax_ptr(res, a, axis, reverse, inclusive, s); } int mlx_cummin(mlx_array* res, const mlx_array a, int axis, bool reverse, bool inclusive, const mlx_stream s) { return mlx_cummin_ptr(res, a, axis, reverse, inclusive, s); } int mlx_cumprod(mlx_array* res, const mlx_array a, int axis, bool reverse, bool inclusive, const mlx_stream s) { return mlx_cumprod_ptr(res, a, axis, reverse, inclusive, s); } int mlx_cumsum(mlx_array* res, const mlx_array a, int axis, bool reverse, bool inclusive, const mlx_stream s) { return mlx_cumsum_ptr(res, a, axis, reverse, inclusive, s); } int mlx_degrees(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_degrees_ptr(res, a, s); } int mlx_depends(mlx_vector_array* res, const mlx_vector_array inputs, const mlx_vector_array dependencies) { return mlx_depends_ptr(res, inputs, dependencies); } int mlx_dequantize(mlx_array* res, const mlx_array w, const mlx_array scales, const mlx_array biases , 
mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_array global_scale , mlx_optional_dtype dtype, const mlx_stream s) { return mlx_dequantize_ptr(res, w, scales, biases, group_size, bits, mode, global_scale, dtype, s); } int mlx_diag(mlx_array* res, const mlx_array a, int k, const mlx_stream s) { return mlx_diag_ptr(res, a, k, s); } int mlx_diagonal(mlx_array* res, const mlx_array a, int offset, int axis1, int axis2, const mlx_stream s) { return mlx_diagonal_ptr(res, a, offset, axis1, axis2, s); } int mlx_divide(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) { return mlx_divide_ptr(res, a, b, s); } int mlx_divmod(mlx_vector_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) { return mlx_divmod_ptr(res, a, b, s); } int mlx_einsum(mlx_array* res, const char* subscripts, const mlx_vector_array operands, const mlx_stream s) { return mlx_einsum_ptr(res, subscripts, operands, s); } int mlx_equal(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) { return mlx_equal_ptr(res, a, b, s); } int mlx_erf(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_erf_ptr(res, a, s); } int mlx_erfinv(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_erfinv_ptr(res, a, s); } int mlx_exp(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_exp_ptr(res, a, s); } int mlx_expand_dims_axes(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, const mlx_stream s) { return mlx_expand_dims_axes_ptr(res, a, axes, axes_num, s); } int mlx_expand_dims(mlx_array* res, const mlx_array a, int axis, const mlx_stream s) { return mlx_expand_dims_ptr(res, a, axis, s); } int mlx_expm1(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_expm1_ptr(res, a, s); } int mlx_eye(mlx_array* res, int n, int m, int k, mlx_dtype dtype, const mlx_stream s) { return mlx_eye_ptr(res, n, m, k, dtype, s); } int mlx_flatten(mlx_array* res, const 
mlx_array a, int start_axis, int end_axis, const mlx_stream s) { return mlx_flatten_ptr(res, a, start_axis, end_axis, s); } int mlx_floor(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_floor_ptr(res, a, s); } int mlx_floor_divide(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) { return mlx_floor_divide_ptr(res, a, b, s); } int mlx_from_fp8(mlx_array* res, const mlx_array x, mlx_dtype dtype, const mlx_stream s) { return mlx_from_fp8_ptr(res, x, dtype, s); } int mlx_full(mlx_array* res, const int* shape, size_t shape_num, const mlx_array vals, mlx_dtype dtype, const mlx_stream s) { return mlx_full_ptr(res, shape, shape_num, vals, dtype, s); } int mlx_full_like(mlx_array* res, const mlx_array a, const mlx_array vals, mlx_dtype dtype, const mlx_stream s) { return mlx_full_like_ptr(res, a, vals, dtype, s); } int mlx_gather(mlx_array* res, const mlx_array a, const mlx_vector_array indices, const int* axes, size_t axes_num, const int* slice_sizes, size_t slice_sizes_num, const mlx_stream s) { return mlx_gather_ptr(res, a, indices, axes, axes_num, slice_sizes, slice_sizes_num, s); } int mlx_gather_single(mlx_array* res, const mlx_array a, const mlx_array indices, int axis, const int* slice_sizes, size_t slice_sizes_num, const mlx_stream s) { return mlx_gather_single_ptr(res, a, indices, axis, slice_sizes, slice_sizes_num, s); } int mlx_gather_mm(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_array lhs_indices , const mlx_array rhs_indices , bool sorted_indices, const mlx_stream s) { return mlx_gather_mm_ptr(res, a, b, lhs_indices, rhs_indices, sorted_indices, s); } int mlx_gather_qmm(mlx_array* res, const mlx_array x, const mlx_array w, const mlx_array scales, const mlx_array biases , const mlx_array lhs_indices , const mlx_array rhs_indices , bool transpose, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, bool sorted_indices, const mlx_stream s) { return mlx_gather_qmm_ptr(res, x, w, 
scales, biases, lhs_indices, rhs_indices, transpose, group_size, bits, mode, sorted_indices, s); } int mlx_greater(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) { return mlx_greater_ptr(res, a, b, s); } int mlx_greater_equal(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) { return mlx_greater_equal_ptr(res, a, b, s); } int mlx_hadamard_transform(mlx_array* res, const mlx_array a, mlx_optional_float scale, const mlx_stream s) { return mlx_hadamard_transform_ptr(res, a, scale, s); } int mlx_hamming(mlx_array* res, int M, const mlx_stream s) { return mlx_hamming_ptr(res, M, s); } int mlx_hanning(mlx_array* res, int M, const mlx_stream s) { return mlx_hanning_ptr(res, M, s); } int mlx_identity(mlx_array* res, int n, mlx_dtype dtype, const mlx_stream s) { return mlx_identity_ptr(res, n, dtype, s); } int mlx_imag(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_imag_ptr(res, a, s); } int mlx_inner(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) { return mlx_inner_ptr(res, a, b, s); } int mlx_isclose(mlx_array* res, const mlx_array a, const mlx_array b, double rtol, double atol, bool equal_nan, const mlx_stream s) { return mlx_isclose_ptr(res, a, b, rtol, atol, equal_nan, s); } int mlx_isfinite(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_isfinite_ptr(res, a, s); } int mlx_isinf(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_isinf_ptr(res, a, s); } int mlx_isnan(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_isnan_ptr(res, a, s); } int mlx_isneginf(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_isneginf_ptr(res, a, s); } int mlx_isposinf(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_isposinf_ptr(res, a, s); } int mlx_kron(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) { return mlx_kron_ptr(res, a, b, s); } int mlx_left_shift(mlx_array* res, 
const mlx_array a, const mlx_array b, const mlx_stream s) { return mlx_left_shift_ptr(res, a, b, s); }

/*
 * Generated forwarding wrappers (continued).
 *
 * Each function below has the name and parameter list of an mlx-c entry
 * point and does exactly one thing: it calls the function pointer named
 * `<function name>_ptr` with its own arguments, in the same order, and
 * returns that call's result unchanged. No argument is inspected, validated
 * or converted here.
 *
 * NOTE(review): the signatures presumably have to match the prototypes in
 * "mlx/c/mlx.h" (included at the top of this file) exactly; a mismatch
 * would be a conflicting-types compile error. TODO confirm.
 *
 * NOTE(review): the `*_ptr` globals are defined at the top of this file with
 * an initial value of NULL, alongside the GET_SYM macro (dlsym /
 * GetProcAddress). Presumably an init routine outside this chunk resolves
 * them (TODO confirm). These wrappers do no NULL check, so calling one whose
 * pointer was never resolved is a call through a null function pointer.
 *
 * NOTE(review): the file header marks this file AUTO-GENERATED by
 * generate_wrappers.go ("DO NOT EDIT"). Mirror any change to these comments
 * in the generator, or it will be lost on regeneration.
 *
 * Every wrapper in this run returns `int`: whatever the resolved function
 * returns. In mlx-c that is conventionally a status code (0 on success);
 * confirm against the mlx-c headers. Every wrapper here also takes the
 * output pointer `res` first and the stream `s` last, passing both through
 * untouched.
 */
int mlx_less(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) { return mlx_less_ptr(res, a, b, s); }
int mlx_less_equal(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) { return mlx_less_equal_ptr(res, a, b, s); }
int mlx_linspace(mlx_array* res, double start, double stop, int num, mlx_dtype dtype, const mlx_stream s) { return mlx_linspace_ptr(res, start, stop, num, dtype, s); }
int mlx_log(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_log_ptr(res, a, s); }
int mlx_log10(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_log10_ptr(res, a, s); }
int mlx_log1p(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_log1p_ptr(res, a, s); }
int mlx_log2(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_log2_ptr(res, a, s); }
int mlx_logaddexp(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) { return mlx_logaddexp_ptr(res, a, b, s); }
int mlx_logcumsumexp(mlx_array* res, const mlx_array a, int axis, bool reverse, bool inclusive, const mlx_stream s) { return mlx_logcumsumexp_ptr(res, a, axis, reverse, inclusive, s); }
int mlx_logical_and(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) { return mlx_logical_and_ptr(res, a, b, s); }
int mlx_logical_not(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_logical_not_ptr(res, a, s); }
int mlx_logical_or(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) { return mlx_logical_or_ptr(res, a, b, s); }

/* Reduction-style wrappers come in `_axes` (array of axes + count), `_axis`
 * (single axis) and no-suffix variants; each forwards to its own pointer. */
int mlx_logsumexp_axes(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, bool keepdims, const mlx_stream s) { return mlx_logsumexp_axes_ptr(res, a, axes, axes_num, keepdims, s); }
int mlx_logsumexp_axis(mlx_array* res, const mlx_array a, int axis, bool keepdims, const mlx_stream s) { return mlx_logsumexp_axis_ptr(res, a, axis, keepdims, s); }
int mlx_logsumexp(mlx_array* res, const mlx_array a, bool keepdims, const mlx_stream s) { return mlx_logsumexp_ptr(res, a, keepdims, s); }
int mlx_masked_scatter(mlx_array* res, const mlx_array a, const mlx_array mask, const mlx_array src, const mlx_stream s) { return mlx_masked_scatter_ptr(res, a, mask, src, s); }
int mlx_matmul(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) { return mlx_matmul_ptr(res, a, b, s); }
int mlx_max_axes(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, bool keepdims, const mlx_stream s) { return mlx_max_axes_ptr(res, a, axes, axes_num, keepdims, s); }
int mlx_max_axis(mlx_array* res, const mlx_array a, int axis, bool keepdims, const mlx_stream s) { return mlx_max_axis_ptr(res, a, axis, keepdims, s); }
int mlx_max(mlx_array* res, const mlx_array a, bool keepdims, const mlx_stream s) { return mlx_max_ptr(res, a, keepdims, s); }
int mlx_maximum(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) { return mlx_maximum_ptr(res, a, b, s); }
int mlx_mean_axes(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, bool keepdims, const mlx_stream s) { return mlx_mean_axes_ptr(res, a, axes, axes_num, keepdims, s); }
int mlx_mean_axis(mlx_array* res, const mlx_array a, int axis, bool keepdims, const mlx_stream s) { return mlx_mean_axis_ptr(res, a, axis, keepdims, s); }
int mlx_mean(mlx_array* res, const mlx_array a, bool keepdims, const mlx_stream s) { return mlx_mean_ptr(res, a, keepdims, s); }
int mlx_median(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, bool keepdims, const mlx_stream s) { return mlx_median_ptr(res, a, axes, axes_num, keepdims, s); }
/* Output is an mlx_vector_array rather than a single mlx_array. */
int mlx_meshgrid(mlx_vector_array* res, const mlx_vector_array arrays, bool sparse, const char* indexing, const mlx_stream s) { return mlx_meshgrid_ptr(res, arrays, sparse, indexing, s); }
int mlx_min_axes(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, bool keepdims, const mlx_stream s) { return mlx_min_axes_ptr(res, a, axes, axes_num, keepdims, s); }
int mlx_min_axis(mlx_array* res, const mlx_array a, int axis, bool keepdims, const mlx_stream s) { return mlx_min_axis_ptr(res, a, axis, keepdims, s); }
int mlx_min(mlx_array* res, const mlx_array a, bool keepdims, const mlx_stream s) { return mlx_min_ptr(res, a, keepdims, s); }
int mlx_minimum(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) { return mlx_minimum_ptr(res, a, b, s); }
int mlx_moveaxis(mlx_array* res, const mlx_array a, int source, int destination, const mlx_stream s) { return mlx_moveaxis_ptr(res, a, source, destination, s); }
int mlx_multiply(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) { return mlx_multiply_ptr(res, a, b, s); }
int mlx_nan_to_num(mlx_array* res, const mlx_array a, float nan, mlx_optional_float posinf, mlx_optional_float neginf, const mlx_stream s) { return mlx_nan_to_num_ptr(res, a, nan, posinf, neginf, s); }
int mlx_negative(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_negative_ptr(res, a, s); }
int mlx_not_equal(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) { return mlx_not_equal_ptr(res, a, b, s); }
int mlx_number_of_elements(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, bool inverted, mlx_dtype dtype, const mlx_stream s) { return mlx_number_of_elements_ptr(res, a, axes, axes_num, inverted, dtype, s); }
int mlx_ones(mlx_array* res, const int* shape, size_t shape_num, mlx_dtype dtype, const mlx_stream s) { return mlx_ones_ptr(res, shape, shape_num, dtype, s); }
int mlx_ones_like(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_ones_like_ptr(res, a, s); }
int mlx_outer(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) { return mlx_outer_ptr(res, a, b, s); }
int mlx_pad(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, const int* low_pad_size, size_t low_pad_size_num, const int* high_pad_size, size_t high_pad_size_num, const mlx_array pad_value, const char* mode, const mlx_stream s) { return mlx_pad_ptr(res, a, axes, axes_num, low_pad_size, low_pad_size_num, high_pad_size, high_pad_size_num, pad_value, mode, s); }
int mlx_pad_symmetric(mlx_array* res, const mlx_array a, int pad_width, const mlx_array pad_value, const char* mode, const mlx_stream s) { return mlx_pad_symmetric_ptr(res, a, pad_width, pad_value, mode, s); }
int mlx_partition_axis(mlx_array* res, const mlx_array a, int kth, int axis, const mlx_stream s) { return mlx_partition_axis_ptr(res, a, kth, axis, s); }
int mlx_partition(mlx_array* res, const mlx_array a, int kth, const mlx_stream s) { return mlx_partition_ptr(res, a, kth, s); }
int mlx_power(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) { return mlx_power_ptr(res, a, b, s); }
int mlx_prod_axes(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, bool keepdims, const mlx_stream s) { return mlx_prod_axes_ptr(res, a, axes, axes_num, keepdims, s); }
int mlx_prod_axis(mlx_array* res, const mlx_array a, int axis, bool keepdims, const mlx_stream s) { return mlx_prod_axis_ptr(res, a, axis, keepdims, s); }
int mlx_prod(mlx_array* res, const mlx_array a, bool keepdims, const mlx_stream s) { return mlx_prod_ptr(res, a, keepdims, s); }
int mlx_put_along_axis(mlx_array* res, const mlx_array a, const mlx_array indices, const mlx_array values, int axis, const mlx_stream s) { return mlx_put_along_axis_ptr(res, a, indices, values, axis, s); }

/* Quantization wrappers: optional settings are passed as mlx_optional_int /
 * const char* values and forwarded as-is; no defaults are applied here. */
int mlx_qqmm(mlx_array* res, const mlx_array x, const mlx_array w, const mlx_array w_scales , mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_array global_scale_x , const mlx_array global_scale_w , const mlx_stream s) { return mlx_qqmm_ptr(res, x, w, w_scales, group_size, bits, mode, global_scale_x, global_scale_w, s); }
/* Output is an mlx_vector_array rather than a single mlx_array. */
int mlx_quantize(mlx_vector_array* res, const mlx_array w, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_array global_scale , const mlx_stream s) { return mlx_quantize_ptr(res, w, group_size, bits, mode, global_scale, s); }
int mlx_quantized_matmul(mlx_array* res, const mlx_array x, const mlx_array w, const mlx_array scales, const mlx_array biases , bool transpose, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_stream s) { return mlx_quantized_matmul_ptr(res, x, w, scales, biases, transpose, group_size, bits, mode, s); }

int mlx_radians(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_radians_ptr(res, a, s); }
int mlx_real(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_real_ptr(res, a, s); }
int mlx_reciprocal(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_reciprocal_ptr(res, a, s); }
int mlx_remainder(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) { return mlx_remainder_ptr(res, a, b, s); }
int mlx_repeat_axis(mlx_array* res, const mlx_array arr, int repeats, int axis, const mlx_stream s) { return mlx_repeat_axis_ptr(res, arr, repeats, axis, s); }
int mlx_repeat(mlx_array* res, const mlx_array arr, int repeats, const mlx_stream s) { return mlx_repeat_ptr(res, arr, repeats, s); }
int mlx_reshape(mlx_array* res, const mlx_array a, const int* shape, size_t shape_num, const mlx_stream s) { return mlx_reshape_ptr(res, a, shape, shape_num, s); }
int mlx_right_shift(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) { return mlx_right_shift_ptr(res, a, b, s); }
int mlx_roll_axis(mlx_array* res, const mlx_array a, const int* shift, size_t shift_num, int axis, const mlx_stream s) { return mlx_roll_axis_ptr(res, a, shift, shift_num, axis, s); }
int mlx_roll_axes(mlx_array* res, const mlx_array a, const int* shift, size_t shift_num, const int* axes, size_t axes_num, const mlx_stream s) { return mlx_roll_axes_ptr(res, a, shift, shift_num, axes, axes_num, s); }
int mlx_roll(mlx_array* res, const mlx_array a, const int* shift, size_t shift_num, const mlx_stream s) { return mlx_roll_ptr(res, a, shift, shift_num, s); }
int mlx_round(mlx_array* res, const mlx_array a, int decimals, const mlx_stream s) { return mlx_round_ptr(res, a, decimals, s); }
int mlx_rsqrt(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_rsqrt_ptr(res, a, s); }

/* Scatter wrappers: the plain form takes an mlx_vector_array of index arrays
 * plus an axes list; the `_single` form takes one index array and one axis. */
int mlx_scatter(mlx_array* res, const mlx_array a, const mlx_vector_array indices, const mlx_array updates, const int* axes, size_t axes_num, const mlx_stream s) { return mlx_scatter_ptr(res, a, indices, updates, axes, axes_num, s); }
int mlx_scatter_single(mlx_array* res, const mlx_array a, const mlx_array indices, const mlx_array updates, int axis, const mlx_stream s) { return mlx_scatter_single_ptr(res, a, indices, updates, axis, s); }
int mlx_scatter_add(mlx_array* res, const mlx_array a, const mlx_vector_array indices, const mlx_array updates, const int* axes, size_t axes_num, const mlx_stream s) { return mlx_scatter_add_ptr(res, a, indices, updates, axes, axes_num, s); }
int mlx_scatter_add_single(mlx_array* res, const mlx_array a, const mlx_array indices, const mlx_array updates, int axis, const mlx_stream s) { return mlx_scatter_add_single_ptr(res, a, indices, updates, axis, s); }
int mlx_scatter_add_axis(mlx_array* res, const mlx_array a, const mlx_array indices, const mlx_array values, int axis, const mlx_stream s) { return mlx_scatter_add_axis_ptr(res, a, indices, values, axis, s); }
int mlx_scatter_max(mlx_array* res, const mlx_array a, const mlx_vector_array indices, const mlx_array updates, const int* axes, size_t axes_num, const mlx_stream s) { return mlx_scatter_max_ptr(res, a, indices, updates, axes, axes_num, s); }
int mlx_scatter_max_single(mlx_array* res, const mlx_array a, const mlx_array indices, const mlx_array updates, int axis, const mlx_stream s) { return mlx_scatter_max_single_ptr(res, a, indices, updates, axis, s); }
/*
 * Generated forwarding wrappers (continued).
 *
 * Each function below has the name and parameter list of an mlx-c entry
 * point. It calls the function pointer named `<function name>_ptr` with its
 * own arguments, in the same order, and returns that call's result
 * unchanged. Nothing is validated or converted here.
 *
 * NOTE(review): the `*_ptr` globals are defined at the top of this file with
 * an initial value of NULL. Presumably a loader outside this chunk resolves
 * them via GET_SYM (dlsym / GetProcAddress); TODO confirm. No wrapper checks
 * for NULL, so calling one whose symbol was never resolved is a call through
 * a null function pointer.
 *
 * NOTE(review): the file is marked AUTO-GENERATED by generate_wrappers.go
 * ("DO NOT EDIT"). Mirror any change to these comments in the generator.
 */

/* Remaining array-op wrappers: output pointer `res` first, stream `s` last,
 * `int` result passed straight through from the resolved function. */
int mlx_scatter_min(mlx_array* res, const mlx_array a, const mlx_vector_array indices, const mlx_array updates, const int* axes, size_t axes_num, const mlx_stream s) { return mlx_scatter_min_ptr(res, a, indices, updates, axes, axes_num, s); }
int mlx_scatter_min_single(mlx_array* res, const mlx_array a, const mlx_array indices, const mlx_array updates, int axis, const mlx_stream s) { return mlx_scatter_min_single_ptr(res, a, indices, updates, axis, s); }
int mlx_scatter_prod(mlx_array* res, const mlx_array a, const mlx_vector_array indices, const mlx_array updates, const int* axes, size_t axes_num, const mlx_stream s) { return mlx_scatter_prod_ptr(res, a, indices, updates, axes, axes_num, s); }
int mlx_scatter_prod_single(mlx_array* res, const mlx_array a, const mlx_array indices, const mlx_array updates, int axis, const mlx_stream s) { return mlx_scatter_prod_single_ptr(res, a, indices, updates, axis, s); }
int mlx_segmented_mm(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_array segments, const mlx_stream s) { return mlx_segmented_mm_ptr(res, a, b, segments, s); }
int mlx_sigmoid(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_sigmoid_ptr(res, a, s); }
int mlx_sign(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_sign_ptr(res, a, s); }
int mlx_sin(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_sin_ptr(res, a, s); }
int mlx_sinh(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_sinh_ptr(res, a, s); }
int mlx_slice(mlx_array* res, const mlx_array a, const int* start, size_t start_num, const int* stop, size_t stop_num, const int* strides, size_t strides_num, const mlx_stream s) { return mlx_slice_ptr(res, a, start, start_num, stop, stop_num, strides, strides_num, s); }
/* `_dynamic` variants take the start position as an mlx_array instead of an int list. */
int mlx_slice_dynamic(mlx_array* res, const mlx_array a, const mlx_array start, const int* axes, size_t axes_num, const int* slice_size, size_t slice_size_num, const mlx_stream s) { return mlx_slice_dynamic_ptr(res, a, start, axes, axes_num, slice_size, slice_size_num, s); }
int mlx_slice_update(mlx_array* res, const mlx_array src, const mlx_array update, const int* start, size_t start_num, const int* stop, size_t stop_num, const int* strides, size_t strides_num, const mlx_stream s) { return mlx_slice_update_ptr(res, src, update, start, start_num, stop, stop_num, strides, strides_num, s); }
int mlx_slice_update_dynamic(mlx_array* res, const mlx_array src, const mlx_array update, const mlx_array start, const int* axes, size_t axes_num, const mlx_stream s) { return mlx_slice_update_dynamic_ptr(res, src, update, start, axes, axes_num, s); }
int mlx_softmax_axes(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, bool precise, const mlx_stream s) { return mlx_softmax_axes_ptr(res, a, axes, axes_num, precise, s); }
int mlx_softmax_axis(mlx_array* res, const mlx_array a, int axis, bool precise, const mlx_stream s) { return mlx_softmax_axis_ptr(res, a, axis, precise, s); }
int mlx_softmax(mlx_array* res, const mlx_array a, bool precise, const mlx_stream s) { return mlx_softmax_ptr(res, a, precise, s); }
int mlx_sort_axis(mlx_array* res, const mlx_array a, int axis, const mlx_stream s) { return mlx_sort_axis_ptr(res, a, axis, s); }
int mlx_sort(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_sort_ptr(res, a, s); }
/* The two split wrappers write into an mlx_vector_array output. */
int mlx_split(mlx_vector_array* res, const mlx_array a, int num_splits, int axis, const mlx_stream s) { return mlx_split_ptr(res, a, num_splits, axis, s); }
int mlx_split_sections(mlx_vector_array* res, const mlx_array a, const int* indices, size_t indices_num, int axis, const mlx_stream s) { return mlx_split_sections_ptr(res, a, indices, indices_num, axis, s); }
int mlx_sqrt(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_sqrt_ptr(res, a, s); }
int mlx_square(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_square_ptr(res, a, s); }
int mlx_squeeze_axes(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, const mlx_stream s) { return mlx_squeeze_axes_ptr(res, a, axes, axes_num, s); }
int mlx_squeeze_axis(mlx_array* res, const mlx_array a, int axis, const mlx_stream s) { return mlx_squeeze_axis_ptr(res, a, axis, s); }
int mlx_squeeze(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_squeeze_ptr(res, a, s); }
int mlx_stack_axis(mlx_array* res, const mlx_vector_array arrays, int axis, const mlx_stream s) { return mlx_stack_axis_ptr(res, arrays, axis, s); }
int mlx_stack(mlx_array* res, const mlx_vector_array arrays, const mlx_stream s) { return mlx_stack_ptr(res, arrays, s); }
int mlx_std_axes(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, bool keepdims, int ddof, const mlx_stream s) { return mlx_std_axes_ptr(res, a, axes, axes_num, keepdims, ddof, s); }
int mlx_std_axis(mlx_array* res, const mlx_array a, int axis, bool keepdims, int ddof, const mlx_stream s) { return mlx_std_axis_ptr(res, a, axis, keepdims, ddof, s); }
int mlx_std(mlx_array* res, const mlx_array a, bool keepdims, int ddof, const mlx_stream s) { return mlx_std_ptr(res, a, keepdims, ddof, s); }
int mlx_stop_gradient(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_stop_gradient_ptr(res, a, s); }
int mlx_subtract(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) { return mlx_subtract_ptr(res, a, b, s); }
int mlx_sum_axes(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, bool keepdims, const mlx_stream s) { return mlx_sum_axes_ptr(res, a, axes, axes_num, keepdims, s); }
int mlx_sum_axis(mlx_array* res, const mlx_array a, int axis, bool keepdims, const mlx_stream s) { return mlx_sum_axis_ptr(res, a, axis, keepdims, s); }
int mlx_sum(mlx_array* res, const mlx_array a, bool keepdims, const mlx_stream s) { return mlx_sum_ptr(res, a, keepdims, s); }
int mlx_swapaxes(mlx_array* res, const mlx_array a, int axis1, int axis2, const mlx_stream s) { return mlx_swapaxes_ptr(res, a, axis1, axis2, s); }
int mlx_take_axis(mlx_array* res, const mlx_array a, const mlx_array indices, int axis, const mlx_stream s) { return mlx_take_axis_ptr(res, a, indices, axis, s); }
int mlx_take(mlx_array* res, const mlx_array a, const mlx_array indices, const mlx_stream s) { return mlx_take_ptr(res, a, indices, s); }
int mlx_take_along_axis(mlx_array* res, const mlx_array a, const mlx_array indices, int axis, const mlx_stream s) { return mlx_take_along_axis_ptr(res, a, indices, axis, s); }
int mlx_tan(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_tan_ptr(res, a, s); }
int mlx_tanh(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_tanh_ptr(res, a, s); }
int mlx_tensordot(mlx_array* res, const mlx_array a, const mlx_array b, const int* axes_a, size_t axes_a_num, const int* axes_b, size_t axes_b_num, const mlx_stream s) { return mlx_tensordot_ptr(res, a, b, axes_a, axes_a_num, axes_b, axes_b_num, s); }
int mlx_tensordot_axis(mlx_array* res, const mlx_array a, const mlx_array b, int axis, const mlx_stream s) { return mlx_tensordot_axis_ptr(res, a, b, axis, s); }
int mlx_tile(mlx_array* res, const mlx_array arr, const int* reps, size_t reps_num, const mlx_stream s) { return mlx_tile_ptr(res, arr, reps, reps_num, s); }
int mlx_to_fp8(mlx_array* res, const mlx_array x, const mlx_stream s) { return mlx_to_fp8_ptr(res, x, s); }
int mlx_topk_axis(mlx_array* res, const mlx_array a, int k, int axis, const mlx_stream s) { return mlx_topk_axis_ptr(res, a, k, axis, s); }
int mlx_topk(mlx_array* res, const mlx_array a, int k, const mlx_stream s) { return mlx_topk_ptr(res, a, k, s); }
int mlx_trace(mlx_array* res, const mlx_array a, int offset, int axis1, int axis2, mlx_dtype dtype, const mlx_stream s) { return mlx_trace_ptr(res, a, offset, axis1, axis2, dtype, s); }
int mlx_transpose_axes(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, const mlx_stream s) { return mlx_transpose_axes_ptr(res, a, axes, axes_num, s); }
int mlx_transpose(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_transpose_ptr(res, a, s); }
int mlx_tri(mlx_array* res, int n, int m, int k, mlx_dtype type, const mlx_stream s) { return mlx_tri_ptr(res, n, m, k, type, s); }
int mlx_tril(mlx_array* res, const mlx_array x, int k, const mlx_stream s) { return mlx_tril_ptr(res, x, k, s); }
int mlx_triu(mlx_array* res, const mlx_array x, int k, const mlx_stream s) { return mlx_triu_ptr(res, x, k, s); }
int mlx_unflatten(mlx_array* res, const mlx_array a, int axis, const int* shape, size_t shape_num, const mlx_stream s) { return mlx_unflatten_ptr(res, a, axis, shape, shape_num, s); }
int mlx_var_axes(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, bool keepdims, int ddof, const mlx_stream s) { return mlx_var_axes_ptr(res, a, axes, axes_num, keepdims, ddof, s); }
int mlx_var_axis(mlx_array* res, const mlx_array a, int axis, bool keepdims, int ddof, const mlx_stream s) { return mlx_var_axis_ptr(res, a, axis, keepdims, ddof, s); }
int mlx_var(mlx_array* res, const mlx_array a, bool keepdims, int ddof, const mlx_stream s) { return mlx_var_ptr(res, a, keepdims, ddof, s); }
int mlx_view(mlx_array* res, const mlx_array a, mlx_dtype dtype, const mlx_stream s) { return mlx_view_ptr(res, a, dtype, s); }
int mlx_where(mlx_array* res, const mlx_array condition, const mlx_array x, const mlx_array y, const mlx_stream s) { return mlx_where_ptr(res, condition, x, y, s); }
int mlx_zeros(mlx_array* res, const int* shape, size_t shape_num, mlx_dtype dtype, const mlx_stream s) { return mlx_zeros_ptr(res, shape, shape_num, dtype, s); }
int mlx_zeros_like(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_zeros_like_ptr(res, a, s); }

/* mlx_random_* wrappers. Most take an explicit `key` array and the stream
 * `s`. Exceptions visible in the signatures:
 *   - mlx_random_key and mlx_random_seed take only a uint64_t seed (plus
 *     `res` for mlx_random_key) and no stream;
 *   - mlx_random_split has two output pointers, res_0 and res_1. */
int mlx_random_bernoulli(mlx_array* res, const mlx_array p, const int* shape, size_t shape_num, const mlx_array key , const mlx_stream s) { return mlx_random_bernoulli_ptr(res, p, shape, shape_num, key, s); }
int mlx_random_bits(mlx_array* res, const int* shape, size_t shape_num, int width, const mlx_array key , const mlx_stream s) { return mlx_random_bits_ptr(res, shape, shape_num, width, key, s); }
int mlx_random_categorical_shape(mlx_array* res, const mlx_array logits, int axis, const int* shape, size_t shape_num, const mlx_array key , const mlx_stream s) { return mlx_random_categorical_shape_ptr(res, logits, axis, shape, shape_num, key, s); }
int mlx_random_categorical_num_samples(mlx_array* res, const mlx_array logits_, int axis, int num_samples, const mlx_array key , const mlx_stream s) { return mlx_random_categorical_num_samples_ptr(res, logits_, axis, num_samples, key, s); }
int mlx_random_categorical(mlx_array* res, const mlx_array logits, int axis, const mlx_array key , const mlx_stream s) { return mlx_random_categorical_ptr(res, logits, axis, key, s); }
int mlx_random_gumbel(mlx_array* res, const int* shape, size_t shape_num, mlx_dtype dtype, const mlx_array key , const mlx_stream s) { return mlx_random_gumbel_ptr(res, shape, shape_num, dtype, key, s); }
int mlx_random_key(mlx_array* res, uint64_t seed) { return mlx_random_key_ptr(res, seed); }
int mlx_random_laplace(mlx_array* res, const int* shape, size_t shape_num, mlx_dtype dtype, float loc, float scale, const mlx_array key , const mlx_stream s) { return mlx_random_laplace_ptr(res, shape, shape_num, dtype, loc, scale, key, s); }
int mlx_random_multivariate_normal(mlx_array* res, const mlx_array mean, const mlx_array cov, const int* shape, size_t shape_num, mlx_dtype dtype, const mlx_array key , const mlx_stream s) { return mlx_random_multivariate_normal_ptr(res, mean, cov, shape, shape_num, dtype, key, s); }
/* Same as mlx_random_normal below, except loc/scale are mlx_array values
 * instead of scalars. */
int mlx_random_normal_broadcast(mlx_array* res, const int* shape, size_t shape_num, mlx_dtype dtype, const mlx_array loc , const mlx_array scale , const mlx_array key , const mlx_stream s) { return mlx_random_normal_broadcast_ptr(res, shape, shape_num, dtype, loc, scale, key, s); }
int mlx_random_normal(mlx_array* res, const int* shape, size_t shape_num, mlx_dtype dtype, float loc, float scale, const mlx_array key , const mlx_stream s) { return mlx_random_normal_ptr(res, shape, shape_num, dtype, loc, scale, key, s); }
int mlx_random_permutation(mlx_array* res, const mlx_array x, int axis, const mlx_array key , const mlx_stream s) { return mlx_random_permutation_ptr(res, x, axis, key, s); }
int mlx_random_permutation_arange(mlx_array* res, int x, const mlx_array key , const mlx_stream s) { return mlx_random_permutation_arange_ptr(res, x, key, s); }
int mlx_random_randint(mlx_array* res, const mlx_array low, const mlx_array high, const int* shape, size_t shape_num, mlx_dtype dtype, const mlx_array key , const mlx_stream s) { return mlx_random_randint_ptr(res, low, high, shape, shape_num, dtype, key, s); }
int mlx_random_seed(uint64_t seed) { return mlx_random_seed_ptr(seed); }
int mlx_random_split_num(mlx_array* res, const mlx_array key, int num, const mlx_stream s) { return mlx_random_split_num_ptr(res, key, num, s); }
int mlx_random_split(mlx_array* res_0, mlx_array* res_1, const mlx_array key, const mlx_stream s) { return mlx_random_split_ptr(res_0, res_1, key, s); }
int mlx_random_truncated_normal(mlx_array* res, const mlx_array lower, const mlx_array upper, const int* shape, size_t shape_num, mlx_dtype dtype, const mlx_array key , const mlx_stream s) { return mlx_random_truncated_normal_ptr(res, lower, upper, shape, shape_num, dtype, key, s); }
int mlx_random_uniform(mlx_array* res, const mlx_array low, const mlx_array high, const int* shape, size_t shape_num, mlx_dtype dtype, const mlx_array key , const mlx_stream s) { return mlx_random_uniform_ptr(res, low, high, shape, shape_num, dtype, key, s); }

/* Stream wrappers. The `*_new` functions return an mlx_stream by value
 * (whatever the resolved function returned), and mlx_stream_equal returns
 * bool. The others return the resolved function's int result. */
mlx_stream mlx_stream_new(void) { return mlx_stream_new_ptr(); }
mlx_stream mlx_stream_new_device(mlx_device dev) { return mlx_stream_new_device_ptr(dev); }
int mlx_stream_set(mlx_stream* stream, const mlx_stream src) { return mlx_stream_set_ptr(stream, src); }
int mlx_stream_free(mlx_stream stream) { return mlx_stream_free_ptr(stream); }
int mlx_stream_tostring(mlx_string* str, mlx_stream stream) { return mlx_stream_tostring_ptr(str, stream); }
bool mlx_stream_equal(mlx_stream lhs, mlx_stream rhs) { return mlx_stream_equal_ptr(lhs, rhs); }
int mlx_stream_get_device(mlx_device* dev, mlx_stream stream) { return mlx_stream_get_device_ptr(dev, stream); }
int mlx_stream_get_index(int* index, mlx_stream stream) { return mlx_stream_get_index_ptr(index, stream); }
int mlx_synchronize(mlx_stream stream) { return mlx_synchronize_ptr(stream); }
int mlx_get_default_stream(mlx_stream* stream, mlx_device dev) { return mlx_get_default_stream_ptr(stream, dev); }
int mlx_set_default_stream(mlx_stream stream) { return mlx_set_default_stream_ptr(stream); }
mlx_stream mlx_default_cpu_stream_new(void) { return mlx_default_cpu_stream_new_ptr(); }
mlx_stream mlx_default_gpu_stream_new(void) { return mlx_default_gpu_stream_new_ptr(); }

/* String wrappers. mlx_string_data returns the `const char*` produced by the
 * resolved function; ownership of that buffer is decided by mlx-c, not here. */
mlx_string mlx_string_new(void) { return mlx_string_new_ptr(); }
mlx_string mlx_string_new_data(const char* str) { return mlx_string_new_data_ptr(str); }
int mlx_string_set(mlx_string* str, const mlx_string src) { return mlx_string_set_ptr(str, src); }
const char* mlx_string_data(mlx_string str) { return mlx_string_data_ptr(str); }
int mlx_string_free(mlx_string str) { return mlx_string_free_ptr(str); }

/* Evaluation / transform wrappers. None of these take a stream parameter.
 * mlx_jvp and mlx_vjp have two output pointers, res_0 and res_1. */
int mlx_async_eval(const mlx_vector_array outputs) { return mlx_async_eval_ptr(outputs); }
int mlx_checkpoint(mlx_closure* res, const mlx_closure fun) { return mlx_checkpoint_ptr(res, fun); }
int mlx_custom_function(mlx_closure* res, const mlx_closure fun, const mlx_closure_custom fun_vjp , const mlx_closure_custom_jvp fun_jvp , const mlx_closure_custom_vmap fun_vmap) { return mlx_custom_function_ptr(res, fun, fun_vjp, fun_jvp, fun_vmap); }
int mlx_custom_vjp(mlx_closure* res, const mlx_closure fun, const mlx_closure_custom fun_vjp) { return mlx_custom_vjp_ptr(res, fun, fun_vjp); }
int mlx_eval(const mlx_vector_array outputs) { return mlx_eval_ptr(outputs); }
int mlx_jvp(mlx_vector_array* res_0, mlx_vector_array* res_1, const mlx_closure fun, const mlx_vector_array primals, const mlx_vector_array tangents) { return mlx_jvp_ptr(res_0, res_1, fun, primals, tangents); }
int mlx_value_and_grad(mlx_closure_value_and_grad* res, const mlx_closure fun, const int* argnums, size_t argnums_num) { return mlx_value_and_grad_ptr(res, fun, argnums, argnums_num); }
int mlx_vjp(mlx_vector_array* res_0, mlx_vector_array* res_1, const mlx_closure fun, const mlx_vector_array primals, const mlx_vector_array cotangents) { return mlx_vjp_ptr(res_0, res_1, fun, primals, cotangents); }

/* mlx_detail_* wrappers (vmap helpers); forwarded like everything else. */
int mlx_detail_vmap_replace(mlx_vector_array* res, const mlx_vector_array inputs, const mlx_vector_array s_inputs, const mlx_vector_array s_outputs, const int* in_axes, size_t in_axes_num, const int* out_axes, size_t out_axes_num) { return mlx_detail_vmap_replace_ptr(res, inputs, s_inputs, s_outputs, in_axes, in_axes_num, out_axes, out_axes_num); }
int mlx_detail_vmap_trace(mlx_vector_array* res_0, mlx_vector_array* res_1, const mlx_closure fun, const mlx_vector_array inputs, const int* in_axes, size_t in_axes_num) { return mlx_detail_vmap_trace_ptr(res_0, res_1, fun, inputs, in_axes, in_axes_num); }

/* Container wrappers: mlx_vector_array, mlx_vector_vector_array,
 * mlx_vector_int, mlx_vector_string. For each container:
 *   - `_new`, `_new_data` and `_new_value` return the container by value;
 *   - `_size` returns size_t;
 *   - every other function returns the resolved function's int result. */
mlx_vector_array mlx_vector_array_new(void) { return mlx_vector_array_new_ptr(); }
int mlx_vector_array_set(mlx_vector_array* vec, const mlx_vector_array src) { return mlx_vector_array_set_ptr(vec, src); }
int mlx_vector_array_free(mlx_vector_array vec) { return mlx_vector_array_free_ptr(vec); }
mlx_vector_array mlx_vector_array_new_data(const mlx_array* data, size_t size) { return mlx_vector_array_new_data_ptr(data, size); }
mlx_vector_array mlx_vector_array_new_value(const mlx_array val) { return mlx_vector_array_new_value_ptr(val); }
int mlx_vector_array_set_data(mlx_vector_array* vec, const mlx_array* data, size_t size) { return mlx_vector_array_set_data_ptr(vec, data, size); }
int mlx_vector_array_set_value(mlx_vector_array* vec, const mlx_array val) { return mlx_vector_array_set_value_ptr(vec, val); }
int mlx_vector_array_append_data(mlx_vector_array vec, const mlx_array* data, size_t size) { return mlx_vector_array_append_data_ptr(vec, data, size); }
int mlx_vector_array_append_value(mlx_vector_array vec, const mlx_array val) { return mlx_vector_array_append_value_ptr(vec, val); }
size_t mlx_vector_array_size(mlx_vector_array vec) { return mlx_vector_array_size_ptr(vec); }
int mlx_vector_array_get(mlx_array* res, const mlx_vector_array vec, size_t idx) { return mlx_vector_array_get_ptr(res, vec, idx); }
mlx_vector_vector_array mlx_vector_vector_array_new(void) { return mlx_vector_vector_array_new_ptr(); }
int mlx_vector_vector_array_set(mlx_vector_vector_array* vec, const mlx_vector_vector_array src) { return mlx_vector_vector_array_set_ptr(vec, src); }
int mlx_vector_vector_array_free(mlx_vector_vector_array vec) { return mlx_vector_vector_array_free_ptr(vec); }
mlx_vector_vector_array mlx_vector_vector_array_new_data(const mlx_vector_array* data, size_t size) { return mlx_vector_vector_array_new_data_ptr(data, size); }
mlx_vector_vector_array mlx_vector_vector_array_new_value(const mlx_vector_array val) { return mlx_vector_vector_array_new_value_ptr(val); }
int mlx_vector_vector_array_set_data(mlx_vector_vector_array* vec, const mlx_vector_array* data, size_t size) { return mlx_vector_vector_array_set_data_ptr(vec, data, size); }
int mlx_vector_vector_array_set_value(mlx_vector_vector_array* vec, const mlx_vector_array val) { return mlx_vector_vector_array_set_value_ptr(vec, val); }
int mlx_vector_vector_array_append_data(mlx_vector_vector_array vec, const mlx_vector_array* data, size_t size) { return mlx_vector_vector_array_append_data_ptr(vec, data, size); }
int mlx_vector_vector_array_append_value(mlx_vector_vector_array vec, const mlx_vector_array val) { return mlx_vector_vector_array_append_value_ptr(vec, val); }
size_t mlx_vector_vector_array_size(mlx_vector_vector_array vec) { return mlx_vector_vector_array_size_ptr(vec); }
int mlx_vector_vector_array_get(mlx_vector_array* res, const mlx_vector_vector_array vec, size_t idx) { return mlx_vector_vector_array_get_ptr(res, vec, idx); }
/* Note: the int-vector data parameters are `int*` (not `const int*`),
 * unlike the array/string variants; kept as generated. */
mlx_vector_int mlx_vector_int_new(void) { return mlx_vector_int_new_ptr(); }
int mlx_vector_int_set(mlx_vector_int* vec, const mlx_vector_int src) { return mlx_vector_int_set_ptr(vec, src); }
int mlx_vector_int_free(mlx_vector_int vec) { return mlx_vector_int_free_ptr(vec); }
mlx_vector_int mlx_vector_int_new_data(int* data, size_t size) { return mlx_vector_int_new_data_ptr(data, size); }
mlx_vector_int mlx_vector_int_new_value(int val) { return mlx_vector_int_new_value_ptr(val); }
int mlx_vector_int_set_data(mlx_vector_int* vec, int* data, size_t size) { return mlx_vector_int_set_data_ptr(vec, data, size); }
int mlx_vector_int_set_value(mlx_vector_int* vec, int val) { return mlx_vector_int_set_value_ptr(vec, val); }
int mlx_vector_int_append_data(mlx_vector_int vec, int* data, size_t size) { return mlx_vector_int_append_data_ptr(vec, data, size); }
int mlx_vector_int_append_value(mlx_vector_int vec, int val) { return mlx_vector_int_append_value_ptr(vec, val); }
size_t mlx_vector_int_size(mlx_vector_int vec) { return mlx_vector_int_size_ptr(vec); }
int mlx_vector_int_get(int* res, const mlx_vector_int vec, size_t idx) { return mlx_vector_int_get_ptr(res, vec, idx); }
mlx_vector_string mlx_vector_string_new(void) { return mlx_vector_string_new_ptr(); }
int mlx_vector_string_set(mlx_vector_string* vec, const mlx_vector_string src) { return mlx_vector_string_set_ptr(vec, src); }
int mlx_vector_string_free(mlx_vector_string vec) { return mlx_vector_string_free_ptr(vec); }
mlx_vector_string mlx_vector_string_new_data(const char** data, size_t size) { return mlx_vector_string_new_data_ptr(data, size); }
mlx_vector_string mlx_vector_string_new_value(const char* val) { return mlx_vector_string_new_value_ptr(val); }
int mlx_vector_string_set_data(mlx_vector_string* vec, const char** data, size_t size) { return mlx_vector_string_set_data_ptr(vec, data, size); }
int mlx_vector_string_set_value(mlx_vector_string* vec, const char* val) { return mlx_vector_string_set_value_ptr(vec, val); }
int mlx_vector_string_append_data(mlx_vector_string vec, const char** data, size_t size) { return mlx_vector_string_append_data_ptr(vec, data, size); }
int mlx_vector_string_append_value(mlx_vector_string vec, const char* val) { return mlx_vector_string_append_value_ptr(vec, val); }
size_t mlx_vector_string_size(mlx_vector_string vec) { return mlx_vector_string_size_ptr(vec); }
/* Output is a `char**`; who owns the returned string is decided by mlx-c. */
int mlx_vector_string_get(char** res, const mlx_vector_string vec, size_t idx) { return mlx_vector_string_get_ptr(res, vec, idx); }

/* Library version query: writes into the caller-supplied mlx_string. */
int mlx_version(mlx_string* str_) { return mlx_version_ptr(str_); }